openfl.federated.task.runner.TaskRunner
- class openfl.federated.task.runner.TaskRunner(data_loader, tensor_dict_split_fn_kwargs: dict | None = None, **kwargs)
Bases: object
Federated Learning Task Runner Class.
- Class Attributes:
data_loader – The data_loader object.
tensor_dict_split_fn_kwargs (dict) – Keyword arguments that determine which parameters are held out from aggregation.
logger (logging.Logger) – Logger object for logging events.
opt_treatment (str) – Treatment of the current instance's optimizer.
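The role of tensor_dict_split_fn_kwargs can be illustrated with a small stand-alone sketch. It assumes the kwargs name tensors to keep local; the helper split_for_holdouts and its holdout_tensor_names argument are hypothetical and are not OpenFL's actual split function or parameters.

```python
from typing import Dict, Iterable, Tuple


def split_for_holdouts(
    tensor_dict: Dict[str, object],
    holdout_tensor_names: Iterable[str] = (),
) -> Tuple[Dict[str, object], Dict[str, object]]:
    """Hypothetical helper: split weights into (to_aggregate, held_out).

    `holdout_tensor_names` is an assumed keyword argument standing in for
    whatever tensor_dict_split_fn_kwargs actually carries.
    """
    holdout = set(holdout_tensor_names)
    to_aggregate = {k: v for k, v in tensor_dict.items() if k not in holdout}
    held_out = {k: v for k, v in tensor_dict.items() if k in holdout}
    return to_aggregate, held_out


# The runner would store these kwargs and forward them to the split function.
tensor_dict_split_fn_kwargs = {"holdout_tensor_names": ["bn.num_batches_tracked"]}
weights = {"fc.weight": [0.1, 0.2], "fc.bias": [0.0], "bn.num_batches_tracked": 42}
shared, local = split_for_holdouts(weights, **tensor_dict_split_fn_kwargs)
print(sorted(shared))  # ['fc.bias', 'fc.weight']
print(local)           # {'bn.num_batches_tracked': 42}
```

Only `shared` would be sent for aggregation; `local` stays on the collaborator.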
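Methods
get_data_loader() – Get the data_loader object.
get_required_tensorkeys_for_function(func_name, **kwargs) – When running a task, a map of named tensorkeys must be provided to the function as dependencies.
get_tensor_dict(with_opt_vars) – Get the weights.
get_train_data_size() – Get the number of training examples.
get_valid_data_size() – Get the number of validation examples.
initialize_globals() – Initialize all global variables.
load_native(filepath, **kwargs) – Load model state from a filepath in ML-framework "native" format, e.g. PyTorch pickled models.
reset_opt_vars() – Reinitialize the optimizer variables.
save_native(filepath, **kwargs) – Save model state in ML-framework "native" format, e.g. PyTorch pickled models.
set_data_loader(data_loader) – Set the data_loader object.
set_logger() – Set up the log object.
set_optimizer_treatment(opt_treatment) – Change the treatment of the current instance's optimizer.
set_tensor_dict(tensor_dict, with_opt_vars) – Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
train_batches(num_batches=None, use_tqdm=False) – Perform the training for a specified number of batches.
validate() – Run validation.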
- get_data_loader()
Get the data_loader object.
Serves up batches and provides information about the data_loader.
- Returns:
data_loader object
- get_required_tensorkeys_for_function(func_name, **kwargs)
When running a task, a map of named tensorkeys must be provided to the function as dependencies.
- Parameters:
func_name (str) – The function name.
**kwargs – Additional parameters to pass to the function.
- Returns:
list – List of required TensorKey objects, each of the form TensorKey(tensor_name, origin, round_number).
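A minimal sketch of what such a dependency list might look like. TensorKey here is a local namedtuple stand-in using only the three fields listed above (the real class may define more), and the rule that only training also needs optimizer tensors, the "GLOBAL" origin label, and the function required_tensorkeys_sketch are all illustrative assumptions.

```python
from collections import namedtuple

# Stand-in for openfl's TensorKey with only the three fields named above;
# the real class may define additional fields.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number"])


def required_tensorkeys_sketch(func_name, model_names, opt_names, round_number):
    """Hypothetical: list the tensors a task needs before it can run."""
    names = list(model_names)
    # Assumption for illustration: only training also depends on optimizer state.
    if func_name == "train_batches":
        names += list(opt_names)
    # "GLOBAL" is an illustrative origin label for aggregated tensors.
    return [TensorKey(name, "GLOBAL", round_number) for name in names]


keys = required_tensorkeys_sketch(
    "validate", ["fc.weight", "fc.bias"], ["fc.weight.momentum"], 3
)
print(keys[0])  # TensorKey(tensor_name='fc.weight', origin='GLOBAL', round_number=3)
```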
- get_tensor_dict(with_opt_vars)
Get the weights.
- Parameters:
with_opt_vars (bool) – Whether to also include the optimizer variables.
- Returns:
dict – The weight dictionary {<tensor_name>: <value>}.
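The effect of with_opt_vars can be sketched with a toy runner that stores plain Python lists instead of framework tensors. ToyRunner and its tensor names are hypothetical; a real subclass would read weights from its ML framework.

```python
class ToyRunner:
    """Hypothetical runner holding plain-Python 'tensors' for illustration."""

    def __init__(self):
        self.model_params = {"fc.weight": [0.5, -0.25], "fc.bias": [0.1]}
        self.opt_state = {"fc.weight.momentum": [0.0, 0.0], "fc.bias.momentum": [0.0]}

    def get_tensor_dict(self, with_opt_vars):
        # Copy values so callers cannot mutate the runner's state via the result.
        tensor_dict = {name: list(v) for name, v in self.model_params.items()}
        if with_opt_vars:
            tensor_dict.update({name: list(v) for name, v in self.opt_state.items()})
        return tensor_dict


runner = ToyRunner()
weights_only = runner.get_tensor_dict(with_opt_vars=False)
with_opt = runner.get_tensor_dict(with_opt_vars=True)
print(sorted(weights_only))  # ['fc.bias', 'fc.weight']
```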
- get_train_data_size()
Get the number of training examples.
It will be used for weighted averaging in aggregation.
- Returns:
int – The number of training examples.
- get_valid_data_size()
Get the number of validation examples.
It will be used for weighted averaging in aggregation.
- Returns:
int – The number of validation examples.
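Both data sizes feed weighted averaging during aggregation. A minimal sketch, assuming plain example-count weighting (FedAvg-style) of a single scalar per collaborator; weighted_average is a hypothetical helper, not the aggregator's actual code.

```python
def weighted_average(values, sizes):
    """Average per-collaborator values, weighting each by its example count."""
    total = sum(sizes)
    return sum(v * n for v, n in zip(values, sizes)) / total


# Two collaborators report one scalar each, with 100 and 300 training examples
# (the values a get_train_data_size() call might return).
print(weighted_average([1.0, 4.0], [100, 300]))  # 3.25
```

The collaborator with three times as many examples pulls the result three times as hard: (1.0 * 100 + 4.0 * 300) / 400 = 3.25.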
- initialize_globals()
Initialize all global variables.
- Returns:
None
- load_native(filepath, **kwargs)
Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.
May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs.
- Parameters:
filepath (str) – Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
- reset_opt_vars()
Reinitialize the optimizer variables.
- Returns:
None
- save_native(filepath, **kwargs)
Save model state in ML-framework “native” format, e.g. PyTorch pickled models.
May save one file or multiple files, depending on the framework.
- Parameters:
filepath (str) – If the framework stores a single file, this should be a single file path. Frameworks that store multiple files may need to derive the other paths from this path.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
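A save/load round trip can be sketched with pickle as a stand-in "native" format. The PickleRunner class and the convention of deriving a second path by appending ".opt" to filepath are hypothetical; they only show how a multi-file framework might derive extra paths from the single filepath argument.

```python
import os
import pickle
import tempfile


class PickleRunner:
    """Hypothetical runner whose 'native' format is a pair of pickled dicts."""

    def __init__(self):
        self.state = {"fc.weight": [0.5, -0.25]}
        self.opt_state = {"step": 10}

    def save_native(self, filepath, **kwargs):
        # Model state goes to filepath; optimizer state goes to a path
        # derived from it (an assumed naming scheme).
        with open(filepath, "wb") as f:
            pickle.dump(self.state, f)
        with open(filepath + ".opt", "wb") as f:
            pickle.dump(self.opt_state, f)

    def load_native(self, filepath, **kwargs):
        # Derive the second path exactly as save_native did.
        with open(filepath, "rb") as f:
            self.state = pickle.load(f)
        with open(filepath + ".opt", "rb") as f:
            self.opt_state = pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    src = PickleRunner()
    src.state["fc.weight"] = [1.0, 2.0]
    src.save_native(path)
    dst = PickleRunner()
    dst.load_native(path)
    print(dst.state)  # {'fc.weight': [1.0, 2.0]}
```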
- set_data_loader(data_loader)
Set the data_loader object.
- Parameters:
data_loader – data_loader object to set.
- Returns:
None
- set_logger()
Set up the log object.
- Returns:
None
- set_optimizer_treatment(opt_treatment)
Change the treatment of the current instance's optimizer.
- Parameters:
opt_treatment (str) – The optimizer treatment.
- Returns:
None
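A sketch of how a treatment string might gate reset_opt_vars() at the start of each round. It assumes the treatment values "RESET", "CONTINUE_LOCAL", and "CONTINUE_GLOBAL" seen in OpenFL plan files; TreatmentRunner, its resets counter, and start_round are hypothetical.

```python
VALID_TREATMENTS = {"RESET", "CONTINUE_LOCAL", "CONTINUE_GLOBAL"}


class TreatmentRunner:
    """Hypothetical runner that counts optimizer resets instead of doing them."""

    def __init__(self):
        self.opt_treatment = "RESET"
        self.resets = 0

    def set_optimizer_treatment(self, opt_treatment):
        if opt_treatment not in VALID_TREATMENTS:
            raise ValueError(f"unknown optimizer treatment: {opt_treatment!r}")
        self.opt_treatment = opt_treatment

    def reset_opt_vars(self):
        self.resets += 1  # stand-in for reinitializing optimizer state

    def start_round(self):
        # Under RESET, optimizer state is discarded at the start of every round.
        if self.opt_treatment == "RESET":
            self.reset_opt_vars()


runner = TreatmentRunner()
runner.start_round()  # RESET: optimizer reinitialized
runner.set_optimizer_treatment("CONTINUE_LOCAL")
runner.start_round()  # local optimizer state carried over
print(runner.resets)  # 1
```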
- set_tensor_dict(tensor_dict, with_opt_vars)
Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
- Parameters:
tensor_dict (dict) – The model weights dictionary.
with_opt_vars (bool) – Whether to also set the optimizer variables.
- Returns:
None
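A toy counterpart for setting weights, again with plain lists in place of framework tensors. SettableRunner and its tensor names are hypothetical; this sketch is strict and raises KeyError if a tensor it owns is missing from tensor_dict.

```python
class SettableRunner:
    """Hypothetical runner holding plain-Python 'tensors' for illustration."""

    def __init__(self):
        self.model_params = {"fc.weight": [0.0, 0.0]}
        self.opt_state = {"fc.weight.momentum": [0.0, 0.0]}

    def set_tensor_dict(self, tensor_dict, with_opt_vars):
        # Overwrite every model tensor; a missing name raises KeyError.
        for name in self.model_params:
            self.model_params[name] = list(tensor_dict[name])
        if with_opt_vars:
            for name in self.opt_state:
                self.opt_state[name] = list(tensor_dict[name])


runner = SettableRunner()
runner.set_tensor_dict({"fc.weight": [0.3, 0.7]}, with_opt_vars=False)
print(runner.model_params)  # {'fc.weight': [0.3, 0.7]}
print(runner.opt_state)     # {'fc.weight.momentum': [0.0, 0.0]}
```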
- train_batches(num_batches=None, use_tqdm=False)
Perform the training for a specified number of batches.
Draws are expected to be random and without replacement until the data is exhausted; the data is then replaced and shuffled, and drawing continues.
- Parameters:
num_batches (int, optional) – Number of batches to train. Default is None.
use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.
- Returns:
dict – {<metric>: <value>}.
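The sampling contract described above can be sketched as a generator of index batches. batch_indices is hypothetical, uses the standard library's random module rather than any framework loader, assumes batch_size <= num_examples, and drops any leftover partial batch before reshuffling.

```python
import random


def batch_indices(num_examples, batch_size, num_batches, seed=0):
    """Yield index batches drawn without replacement; once the data is
    exhausted, replace everything, reshuffle, and keep drawing."""
    rng = random.Random(seed)
    order = list(range(num_examples))
    rng.shuffle(order)
    pos = 0
    for _ in range(num_batches):
        if pos + batch_size > num_examples:
            # Data exhausted (a partial remainder is dropped): reshuffle all.
            rng.shuffle(order)
            pos = 0
        yield order[pos:pos + batch_size]
        pos += batch_size


batches = list(batch_indices(num_examples=10, batch_size=5, num_batches=4))
```

Each consecutive pair of batches here covers all ten indices exactly once, so no example repeats until the data has been exhausted.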
- validate()
Run validation.
- Returns:
dict – {<metric>: <value>}.
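A minimal sketch of the {<metric>: <value>} return shape, assuming accuracy as the metric and a plain Python predictor; validate_sketch is hypothetical and stands in for a framework-specific validation loop.

```python
def validate_sketch(predict, examples):
    """Hypothetical validate(): score a predictor on (x, label) pairs and
    return a {<metric>: <value>} dict."""
    correct = sum(1 for x, label in examples if predict(x) == label)
    return {"accuracy": correct / len(examples)}


data = [(0.2, 0), (0.9, 1), (0.6, 0), (0.4, 0)]
print(validate_sketch(lambda x: int(x > 0.5), data))  # {'accuracy': 0.75}
```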