openfl.federated.task.runner.TaskRunner
- class openfl.federated.task.runner.TaskRunner(data_loader, tensor_dict_split_fn_kwargs=None, **kwargs)
Bases: object
Federated Learning Task Runner Class.
- Parameters:
tensor_dict_split_fn_kwargs (dict)
- data_loader
The data_loader object.
- tensor_dict_split_fn_kwargs
Keyword arguments for determining which parameters to hold out from aggregation.
- Type:
dict
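Judging by its name, tensor_dict_split_fn_kwargs is a dict of keyword arguments passed to a function that splits the weight dictionary. The sketch below is only an illustration under that assumption: `split_for_holdout` and its `holdout_tensor_names` keyword are hypothetical and are not OpenFL's actual split function. It shows how such kwargs can partition weights into tensors that are aggregated and tensors that stay local.

```python
from typing import Dict, Iterable, List, Tuple

Tensors = Dict[str, List[float]]


def split_for_holdout(
    tensor_dict: Tensors,
    holdout_tensor_names: Iterable[str] = (),
) -> Tuple[Tensors, Tensors]:
    """Return (shared, held_out): shared tensors are aggregated, held-out ones stay local.

    Hypothetical helper; `holdout_tensor_names` is an illustrative keyword.
    """
    holdout = set(holdout_tensor_names)
    shared = {k: v for k, v in tensor_dict.items() if k not in holdout}
    held_out = {k: v for k, v in tensor_dict.items() if k in holdout}
    return shared, held_out


# Hypothetical kwargs dict playing the role of tensor_dict_split_fn_kwargs.
split_kwargs = {"holdout_tensor_names": ["bn.running_mean"]}
weights = {"fc.weight": [0.1, 0.2], "bn.running_mean": [0.0, 0.0]}

shared, held_out = split_for_holdout(weights, **split_kwargs)
print(sorted(shared))    # ['fc.weight']
print(sorted(held_out))  # ['bn.running_mean']
```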
- logger
Logger object for logging events.
- Type:
logging.Logger
- opt_treatment
Treatment of the current instance's optimizer.
- Type:
str
- __init__(data_loader, tensor_dict_split_fn_kwargs=None, **kwargs)
Initializes the TaskRunner object.
- Parameters:
data_loader – The data_loader object
tensor_dict_split_fn_kwargs (dict, optional) – Keyword arguments for determining which parameters to hold out from aggregation. Default is None.
**kwargs – Additional parameters to pass to the function.
Methods
__init__(data_loader[, ...]): Initializes the TaskRunner object.
get_data_loader(): Get the data_loader object.
get_required_tensorkeys_for_function(func_name, **kwargs): When running a task, a map of named tensorkeys must be provided to the function as dependencies.
get_tensor_dict(with_opt_vars): Get the weights.
get_train_data_size(): Get the number of training examples.
get_valid_data_size(): Get the number of validation examples.
Initialize all global variables.
load_native(filepath, **kwargs): Load model state from a filepath in ML-framework "native" format, e.g. PyTorch pickled models.
Reinitialize the optimizer variables.
save_native(filepath, **kwargs): Save model state in ML-framework "native" format, e.g. PyTorch pickled models.
set_data_loader(data_loader): Set data_loader object.
Set up the log object.
set_optimizer_treatment(opt_treatment): Change the treatment of the current instance's optimizer.
set_tensor_dict(tensor_dict, with_opt_vars): Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
train_batches([num_batches, use_tqdm]): Perform the training for a specified number of batches.
validate(): Run validation.
- get_data_loader()
Get the data_loader object.
The data_loader serves batches and provides information about the data it holds.
- Returns:
data_loader object
- get_required_tensorkeys_for_function(func_name, **kwargs)
When running a task, a map of named tensorkeys must be provided to the function as dependencies.
- Parameters:
func_name (str) – The function name.
**kwargs – Additional parameters to pass to the function.
- Returns:
- List of required TensorKey objects (TensorKey(tensor_name, origin, round_number)).
- Return type:
list
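To make "a map of named tensorkeys ... as dependencies" concrete, here is a minimal sketch. It assumes a hypothetical `TensorKey` namedtuple with only the three fields listed above (the real type may carry more), a hypothetical policy in which only training also needs optimizer tensors, and an illustrative `"GLOBAL"` origin label. None of these names are taken from OpenFL.

```python
from collections import namedtuple
from typing import List, Sequence

# Hypothetical stand-in with the three fields listed above; the real
# TensorKey type may carry additional fields.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number"])


def required_tensorkeys(
    func_name: str,
    model_tensor_names: Sequence[str],
    opt_tensor_names: Sequence[str],
    round_number: int,
) -> List[TensorKey]:
    """Sketch of a dependency map: which tensors a task needs before it runs."""
    names = list(model_tensor_names)
    if func_name == "train":
        # Assumed policy: training also depends on optimizer state.
        names += list(opt_tensor_names)
    # "GLOBAL" is an illustrative origin label for aggregated tensors.
    return [TensorKey(name, "GLOBAL", round_number) for name in names]


model_names = ["fc.weight", "fc.bias"]
opt_names = ["fc.weight.momentum"]
train_keys = required_tensorkeys("train", model_names, opt_names, 3)
val_keys = required_tensorkeys("validate", model_names, opt_names, 3)
print(len(train_keys), len(val_keys))  # 3 2
print(train_keys[0])  # TensorKey(tensor_name='fc.weight', origin='GLOBAL', round_number=3)
```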
- get_tensor_dict(with_opt_vars)
Get the weights.
- Parameters:
with_opt_vars (bool) – If True, also include the optimizer variables.
- Returns:
The weight dictionary {<tensor_name>: <value>}.
- Return type:
dict
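A minimal sketch of the with_opt_vars switch, assuming a hypothetical in-memory `ToyRunner` whose weights and optimizer state are plain lists (tensor names are illustrative). It is not a TaskRunner subclass; it only shows the shape of the returned {<tensor_name>: <value>} dictionary.

```python
from typing import Dict, List


class ToyRunner:
    """Hypothetical in-memory model used only to illustrate get_tensor_dict."""

    def __init__(self) -> None:
        self.weights: Dict[str, List[float]] = {"fc.weight": [0.5, -0.5], "fc.bias": [0.0]}
        # Optimizer state, e.g. a momentum buffer (name is illustrative).
        self.opt_vars: Dict[str, List[float]] = {"fc.weight.momentum": [0.0, 0.0]}

    def get_tensor_dict(self, with_opt_vars: bool) -> Dict[str, List[float]]:
        # Copy each list so callers cannot mutate the runner's state in place.
        tensors = {k: list(v) for k, v in self.weights.items()}
        if with_opt_vars:
            tensors.update({k: list(v) for k, v in self.opt_vars.items()})
        return tensors


runner = ToyRunner()
print(sorted(runner.get_tensor_dict(with_opt_vars=False)))  # ['fc.bias', 'fc.weight']
print(len(runner.get_tensor_dict(with_opt_vars=True)))      # 3
```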
- get_train_data_size()
Get the number of training examples.
It will be used for weighted averaging in aggregation.
- Returns:
The number of training examples.
- Return type:
int
- get_valid_data_size()
Get the number of validation examples.
It will be used for weighted averaging in aggregation.
- Returns:
The number of validation examples.
- Return type:
int
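Both size methods note that the counts feed weighted averaging during aggregation. A common scheme (FedAvg-style) weights collaborator k's tensor by n_k / Σ n_j. The sketch below assumes that scheme and plain-list tensors; it illustrates the arithmetic and is not presented as the aggregator's actual implementation.

```python
from typing import Dict, List, Sequence


def weighted_average(
    updates: Sequence[Dict[str, List[float]]],
    num_examples: Sequence[int],
) -> Dict[str, List[float]]:
    """Average each tensor element-wise, weighting update k by n_k / sum(n)."""
    total = sum(num_examples)
    result: Dict[str, List[float]] = {}
    for name in updates[0]:
        length = len(updates[0][name])
        result[name] = [
            sum(u[name][i] * n for u, n in zip(updates, num_examples)) / total
            for i in range(length)
        ]
    return result


# Two collaborators reporting 300 and 100 examples respectively.
avg = weighted_average([{"w": [1.0, 2.0]}, {"w": [5.0, 6.0]}], num_examples=[300, 100])
print(avg)  # {'w': [2.0, 3.0]}
```

With these counts the first collaborator contributes 3/4 of each averaged value: (1.0 · 300 + 5.0 · 100) / 400 = 2.0.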
- load_native(filepath, **kwargs)
Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.
May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs.
- Parameters:
filepath (str) – Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
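The description allows companion filepaths to be either derived from `filepath` or supplied through kwargs. A minimal sketch of both options, assuming a hypothetical `.opt` naming convention and a hypothetical `optimizer_path` keyword; neither is an OpenFL parameter.

```python
from pathlib import Path
from typing import Dict


def resolve_native_paths(filepath: str, **kwargs) -> Dict[str, Path]:
    """Resolve every file a multi-file load_native implementation might read."""
    base = Path(filepath)
    # Default: derive the companion path from `filepath` (model.pkl -> model.opt.pkl).
    derived = base.with_name(base.stem + ".opt" + base.suffix)
    # `optimizer_path` is a hypothetical keyword illustrating the "paths in kwargs" option.
    optimizer = Path(kwargs.get("optimizer_path", derived))
    return {"model": base, "optimizer": optimizer}


print(resolve_native_paths("checkpoints/model.pkl")["optimizer"].as_posix())
# checkpoints/model.opt.pkl
print(resolve_native_paths("model.pkl", optimizer_path="opt/state.pkl")["optimizer"].as_posix())
# opt/state.pkl
```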
- save_native(filepath, **kwargs)
Save model state in ML-framework “native” format, e.g. PyTorch pickled models.
May save one file or multiple files, depending on the framework.
- Parameters:
filepath (str) – If the framework stores a single file, this should be the path to that file. Frameworks that store multiple files may need to derive the other paths from this path.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
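A minimal save/load round trip, assuming a hypothetical `ToyRunner` whose "native" format is a pair of pickled dicts: the weights go to `filepath` and the optimizer state to a second file derived from it, as the docstrings permit. A real PyTorch runner would typically use its own serialization instead of raw `pickle`.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Dict, List


class ToyRunner:
    """Hypothetical runner whose 'native' format is two pickled dicts."""

    def __init__(self) -> None:
        self.weights: Dict[str, List[float]] = {"fc.weight": [0.1, 0.2]}
        self.opt_vars: Dict[str, List[float]] = {"fc.weight.momentum": [0.0, 0.0]}

    @staticmethod
    def _opt_path(filepath: str) -> Path:
        # Second file derived from the first (model.pkl -> model.opt.pkl).
        base = Path(filepath)
        return base.with_name(base.stem + ".opt" + base.suffix)

    def save_native(self, filepath: str, **kwargs) -> None:
        with open(filepath, "wb") as f:
            pickle.dump(self.weights, f)
        with open(self._opt_path(filepath), "wb") as f:
            pickle.dump(self.opt_vars, f)

    def load_native(self, filepath: str, **kwargs) -> None:
        with open(filepath, "rb") as f:
            self.weights = pickle.load(f)
        with open(self._opt_path(filepath), "rb") as f:
            self.opt_vars = pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    src = ToyRunner()
    src.weights["fc.weight"] = [9.0, 9.0]
    src.opt_vars["fc.weight.momentum"] = [0.5, 0.5]
    src.save_native(path)
    dst = ToyRunner()
    dst.load_native(path)

print(dst.weights)  # {'fc.weight': [9.0, 9.0]}
```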
- set_data_loader(data_loader)
Set data_loader object.
- Parameters:
data_loader – data_loader object to set.
- Returns:
None
- set_optimizer_treatment(opt_treatment)
Change the treatment of the current instance's optimizer.
- Parameters:
opt_treatment (str) – The optimizer treatment.
- Returns:
None
- set_tensor_dict(tensor_dict, with_opt_vars)
Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
- Parameters:
tensor_dict (dict) – The model weights dictionary.
with_opt_vars (bool) – If True, also set the optimizer variables.
- Returns:
None
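A minimal sketch of the inverse operation, again on a hypothetical in-memory `ToyRunner`. The choice to silently skip unknown names, and optimizer tensors when with_opt_vars is False, is an assumption of this sketch; a real runner might raise instead.

```python
from typing import Dict, List


class ToyRunner:
    """Hypothetical in-memory model used only to illustrate set_tensor_dict."""

    def __init__(self) -> None:
        self.weights: Dict[str, List[float]] = {"fc.weight": [0.0, 0.0]}
        self.opt_vars: Dict[str, List[float]] = {"fc.weight.momentum": [0.0, 0.0]}

    def set_tensor_dict(self, tensor_dict: Dict[str, List[float]], with_opt_vars: bool) -> None:
        for name, value in tensor_dict.items():
            if name in self.weights:
                self.weights[name] = list(value)
            elif with_opt_vars and name in self.opt_vars:
                self.opt_vars[name] = list(value)
            # Any other entry (including optimizer tensors when
            # with_opt_vars is False) is skipped in this sketch.


runner = ToyRunner()
incoming = {"fc.weight": [1.0, 2.0], "fc.weight.momentum": [0.3, 0.3]}
runner.set_tensor_dict(incoming, with_opt_vars=False)
print(runner.weights, runner.opt_vars)
# {'fc.weight': [1.0, 2.0]} {'fc.weight.momentum': [0.0, 0.0]}
```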
- train_batches(num_batches=None, use_tqdm=False)
Perform the training for a specified number of batches.
Implementations are expected to draw batches randomly without replacement until the data is exhausted; the data is then replaced and reshuffled, and drawing continues.
- Parameters:
num_batches (int, optional) – Number of batches to train. Default is None.
use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.
- Returns:
{<metric>: <value>}.
- Return type:
dict
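The sampling contract above (draw without replacement until exhausted, then reshuffle and continue) can be sketched with a generator that emits a fresh permutation per pass. The `train_batches` below is a hypothetical stand-in operating on a list of numbers: it treats `num_batches=None` as one full pass (an assumption, not documented behavior), ignores `use_tqdm`, and reports a placeholder "loss" in the documented {<metric>: <value>} shape.

```python
import math
import random
from typing import Dict, Iterator, List, Optional, Sequence


def batch_stream(data: Sequence[float], batch_size: int, seed: int = 0) -> Iterator[List[float]]:
    """Draw batches without replacement; reshuffle once every example has been used."""
    rng = random.Random(seed)
    while True:
        order = list(range(len(data)))
        rng.shuffle(order)  # fresh permutation for each pass over the data
        for start in range(0, len(order), batch_size):
            yield [data[i] for i in order[start:start + batch_size]]


def train_batches(data: Sequence[float], num_batches: Optional[int] = None,
                  batch_size: int = 2) -> Dict[str, float]:
    """Toy stand-in that 'trains' on num_batches batches and returns {metric: value}."""
    if num_batches is None:
        # Assumption: treat None as exactly one pass over the data.
        num_batches = math.ceil(len(data) / batch_size)
    stream = batch_stream(data, batch_size)
    seen, loss_sum = 0, 0.0
    for _ in range(num_batches):
        batch = next(stream)
        seen += len(batch)
        loss_sum += sum(x * x for x in batch)  # placeholder for a real loss
    return {"examples_seen": float(seen), "loss": loss_sum / seen}


metrics = train_batches([1.0, 2.0, 3.0, 4.0, 5.0], num_batches=4)
print(metrics["examples_seen"])  # 7.0
```

With five examples and a batch size of 2, the first pass yields batches of 2, 2 and 1 examples, so the fourth batch comes from a reshuffled second pass.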