openfl.federated.task.runner.TaskRunner

class openfl.federated.task.runner.TaskRunner(data_loader, tensor_dict_split_fn_kwargs=None, **kwargs)

Bases: object

Federated Learning Task Runner Class.

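Parameters:

  • data_loader – The data_loader object.

  • tensor_dict_split_fn_kwargs (dict, optional) – Keyword arguments for determining which parameters to hold out from aggregation. Default is None.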

data_loader

The data_loader object.

tensor_dict_split_fn_kwargs

Keyword arguments for determining which parameters to hold out from aggregation.

Type:

dict
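As a rough illustration of what holding parameters out of aggregation means, the sketch below splits a weight dictionary into a shared part and a held-out part. The function `split_tensor_dict` and the key `holdout_tensor_names` are hypothetical; the keys that tensor_dict_split_fn_kwargs actually accepts depend on the split function in use, which this page does not document.

```python
def split_tensor_dict(tensor_dict, holdout_tensor_names=()):
    """Split weights into (shared, held_out) dictionaries.

    Hypothetical helper: ``holdout_tensor_names`` stands in for whatever
    keys tensor_dict_split_fn_kwargs really carries.
    """
    holdout = set(holdout_tensor_names)
    # Tensors not named in the holdout set are sent for aggregation.
    shared = {k: v for k, v in tensor_dict.items() if k not in holdout}
    # Named tensors stay local and are never aggregated.
    held_out = {k: v for k, v in tensor_dict.items() if k in holdout}
    return shared, held_out


split_kwargs = {"holdout_tensor_names": ["bn.running_mean"]}
weights = {"fc.weight": [1.0, 1.0], "bn.running_mean": [0.0, 0.0]}
shared, held_out = split_tensor_dict(weights, **split_kwargs)
print(sorted(shared))    # ['fc.weight']
print(sorted(held_out))  # ['bn.running_mean']
```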

logger

Logger object for logging events.

Type:

logging.Logger

opt_treatment

Treatment of the current instance's optimizer.

Type:

str

__init__(data_loader, tensor_dict_split_fn_kwargs=None, **kwargs)

Initializes the TaskRunner object.

Parameters:
  • data_loader – The data_loader object.

  • tensor_dict_split_fn_kwargs (dict, optional) – Keyword arguments for determining which parameters to hold out from aggregation. Default is None.

  • **kwargs – Additional parameters to pass to the function.
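A minimal sketch of how a subclass might build on the documented constructor. To stay runnable without OpenFL installed, it defines a hypothetical stand-in, `TaskRunnerBase`, that mirrors only the signature and attributes listed above; in practice a framework-specific runner would presumably subclass TaskRunner itself. Normalizing tensor_dict_split_fn_kwargs=None to an empty dict is an assumption of this sketch, not confirmed behavior.

```python
import logging


class TaskRunnerBase:
    """Hypothetical stand-in mirroring the documented constructor."""

    def __init__(self, data_loader, tensor_dict_split_fn_kwargs=None, **kwargs):
        self.data_loader = data_loader
        # Assumption: None becomes {} so it can be unpacked with ** later.
        self.tensor_dict_split_fn_kwargs = tensor_dict_split_fn_kwargs or {}
        self.logger = logging.getLogger(__name__)
        # Left unset here; a real runner would choose its own default.
        self.opt_treatment = None


class ToyRunner(TaskRunnerBase):
    """Hypothetical subclass that adds its own model state."""

    def __init__(self, data_loader, **kwargs):
        super().__init__(data_loader, **kwargs)
        self.weights = {"w": [0.0, 0.0]}


runner = ToyRunner(data_loader=[(1.0, 2.0)], tensor_dict_split_fn_kwargs=None)
print(runner.tensor_dict_split_fn_kwargs)  # {}
```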

Methods

__init__(data_loader[, ...])

Initializes the TaskRunner object.

get_data_loader()

Get the data_loader object.

get_required_tensorkeys_for_function(...)

Get the TensorKeys that must be provided to a function as dependencies when running a task.

get_tensor_dict(with_opt_vars)

Get the weights.

get_train_data_size()

Get the number of training examples.

get_valid_data_size()

Get the number of validation examples.

initialize_globals()

Initialize all global variables.

load_native(filepath, **kwargs)

Load model state from a filepath in ML-framework "native" format, e.g. PyTorch pickled models.

reset_opt_vars()

Reinitialize the optimizer variables.

save_native(filepath, **kwargs)

Save model state in ML-framework "native" format, e.g. PyTorch pickled models.

set_data_loader(data_loader)

Set the data_loader object.

set_logger()

Set up the log object.

set_optimizer_treatment(opt_treatment)

Change the treatment of the current instance's optimizer.

set_tensor_dict(tensor_dict, with_opt_vars)

Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.

train_batches([num_batches, use_tqdm])

Perform the training for a specified number of batches.

validate()

Run validation.

get_data_loader()

Get the data_loader object.

The data loader serves up batches and provides information about the data.

Returns:

The data_loader object.

get_required_tensorkeys_for_function(func_name, **kwargs)

When running a task, a map of named TensorKeys must be provided to the function as dependencies. This method returns the TensorKeys required by the function func_name.

Parameters:
  • func_name (str) – The function name.

  • **kwargs – Additional parameters to pass to the function.

Returns:

List of required TensorKeys, each of the form TensorKey(tensor_name, origin, round_number).

Return type:

list
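The sketch below shows the shape of the return value using a hypothetical `TensorKey` namedtuple with only the three fields named above, plus a hypothetical per-task registry of required tensor names. The origin label "GLOBAL" and the `REQUIRED_TENSORS` table are illustrative assumptions, not OpenFL's actual values.

```python
from collections import namedtuple

# Hypothetical stand-in with only the three documented fields.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number"])

# Hypothetical registry: which model tensors each task consumes.
REQUIRED_TENSORS = {
    "train": ["fc.weight", "fc.bias"],
    "validate": ["fc.weight", "fc.bias"],
}


def get_required_tensorkeys_for_function(func_name, origin="GLOBAL", round_number=0):
    """Build one TensorKey per tensor the named task depends on."""
    return [
        TensorKey(name, origin, round_number)
        for name in REQUIRED_TENSORS[func_name]
    ]


keys = get_required_tensorkeys_for_function("train", round_number=3)
print(keys[0])  # TensorKey(tensor_name='fc.weight', origin='GLOBAL', round_number=3)
```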

get_tensor_dict(with_opt_vars)

Get the model weights, optionally including the optimizer variables.

Parameters:

with_opt_vars (bool) – Whether to also include the optimizer variables.

Returns:

The weight dictionary {<tensor_name>: <value>}.

Return type:

dict
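A minimal sketch of the with_opt_vars switch on a hypothetical `ToyModel` that stores weights as plain lists. The "opt." key prefix for optimizer variables is an assumption made here only to keep the two groups distinguishable.

```python
import copy


class ToyModel:
    """Hypothetical model holding plain-list weights and optimizer state."""

    def __init__(self):
        self.weights = {"fc.weight": [0.5, -0.5], "fc.bias": [0.1]}
        # e.g. a momentum buffer, keyed with an assumed "opt." prefix
        self.opt_state = {"opt.fc.weight.momentum": [0.0, 0.0]}

    def get_tensor_dict(self, with_opt_vars):
        # Deep-copy so callers cannot mutate the live model state.
        tensor_dict = copy.deepcopy(self.weights)
        if with_opt_vars:
            tensor_dict.update(copy.deepcopy(self.opt_state))
        return tensor_dict


model = ToyModel()
print(sorted(model.get_tensor_dict(with_opt_vars=False)))  # ['fc.bias', 'fc.weight']
print(len(model.get_tensor_dict(with_opt_vars=True)))  # 3
```

Returning deep copies is a design choice of this sketch: it keeps callers from changing the model through the returned dictionary.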

get_train_data_size()

Get the number of training examples.

This count is used as the weight for weighted averaging during aggregation.

Returns:

The number of training examples.

Return type:

int

get_valid_data_size()

Get the number of validation examples.

This count is used as the weight for weighted averaging during aggregation.

Returns:

The number of validation examples.

Return type:

int
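To make the weighting concrete, here is a sketch of sample-size-weighted averaging in the spirit of FedAvg, where each participant's tensors count in proportion to its example count (for instance, the value reported by get_train_data_size()). The function `weighted_average` and its input layout are hypothetical; this does not reproduce OpenFL's aggregator.

```python
def weighted_average(client_updates):
    """Average tensors weighted by each client's example count.

    client_updates: list of (num_examples, {tensor_name: [floats]}).
    """
    total = sum(n for n, _ in client_updates)
    names = client_updates[0][1].keys()
    result = {}
    for name in names:
        length = len(client_updates[0][1][name])
        # Element-wise: sum_k n_k * x_k[i] / sum_k n_k
        result[name] = [
            sum(n * tensors[name][i] for n, tensors in client_updates) / total
            for i in range(length)
        ]
    return result


# Client A reports 300 examples, client B reports 100.
updates = [(300, {"w": [1.0, 2.0]}), (100, {"w": [5.0, 6.0]})]
print(weighted_average(updates))  # {'w': [2.0, 3.0]}
```

Client A's values dominate because it reports three times as many examples.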

initialize_globals()

Initialize all global variables.

Returns:

None

load_native(filepath, **kwargs)

Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.

May load from multiple files. The other filepaths may be derived from the passed filepath or supplied through kwargs.

Parameters:
  • filepath (str) – Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.

  • **kwargs – Additional parameters to pass to the function. For future-proofing.

Returns:

None
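For frameworks that use several files, one way to derive the other filepaths from the single filepath argument is a fixed naming convention. The `derive_paths` helper and its ".optimizer" infix below are hypothetical, shown only to illustrate the idea.

```python
from pathlib import Path


def derive_paths(filepath):
    """Derive companion file paths from one base path (hypothetical scheme)."""
    base = Path(filepath)
    return {
        "model": base,
        # Hypothetical convention: same stem and suffix, extra ".optimizer".
        "optimizer": base.with_name(base.stem + ".optimizer" + base.suffix),
    }


paths = derive_paths("checkpoints/model.pt")
print(paths["optimizer"].as_posix())  # checkpoints/model.optimizer.pt
```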

reset_opt_vars()

Reinitialize the optimizer variables.

Returns:

None

save_native(filepath, **kwargs)

Save model state in ML-framework “native” format, e.g. PyTorch pickled models.

May save one file or multiple files, depending on the framework.

Parameters:
  • filepath (str) – If the framework stores a single file, this should be that file's path. Frameworks that store multiple files may need to derive the other paths from this path.

  • **kwargs – Additional parameters to pass to the function. For future-proofing.

Returns:

None
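A minimal save/load round trip, assuming a hypothetical `PickleRunner` whose "native" format is simply a pickled weight dictionary; a real framework runner would use that framework's own serializer instead.

```python
import os
import pickle
import tempfile


class PickleRunner:
    """Hypothetical runner whose native format is a pickled weight dict."""

    def __init__(self):
        self.weights = {"w": [1.0, 2.0]}

    def save_native(self, filepath, **kwargs):
        # Single-file framework: the whole state goes to filepath.
        with open(filepath, "wb") as f:
            pickle.dump(self.weights, f)

    def load_native(self, filepath, **kwargs):
        with open(filepath, "rb") as f:
            self.weights = pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    src = PickleRunner()
    src.save_native(path)
    dst = PickleRunner()
    dst.weights = {}  # start empty to show the load restores state
    dst.load_native(path)
    print(dst.weights)  # {'w': [1.0, 2.0]}
```

Because pickle can execute arbitrary code while loading, a loader like this should only be pointed at trusted files.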

set_data_loader(data_loader)

Set the data_loader object.

Parameters:

data_loader – The data_loader object to set.

Returns:

None

set_logger()

Set up the log object.

Returns:

None

set_optimizer_treatment(opt_treatment)

Change the treatment of the current instance's optimizer.

Parameters:

opt_treatment (str) – The optimizer treatment.

Returns:

None
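The sketch below shows one way an optimizer treatment could drive reset_opt_vars(). The treatment names (including the "RESET" default) and the `start_round` hook are assumptions for illustration; this page documents only that opt_treatment is a string.

```python
class OptTreatmentDemo:
    """Hypothetical runner showing how a treatment might be applied."""

    VALID = {"RESET", "CONTINUE_LOCAL", "CONTINUE_GLOBAL"}  # assumed names

    def __init__(self):
        self.opt_treatment = "RESET"  # assumed default
        self.opt_state = {"momentum": [0.3, 0.7]}

    def set_optimizer_treatment(self, opt_treatment):
        if opt_treatment not in self.VALID:
            raise ValueError(f"unknown optimizer treatment: {opt_treatment!r}")
        self.opt_treatment = opt_treatment

    def reset_opt_vars(self):
        # Reinitialize optimizer variables to their starting values.
        self.opt_state = {"momentum": [0.0, 0.0]}

    def start_round(self):
        # Under the assumed "RESET" treatment, optimizer state is
        # reinitialized at the start of every round.
        if self.opt_treatment == "RESET":
            self.reset_opt_vars()


runner = OptTreatmentDemo()
runner.start_round()
print(runner.opt_state)  # {'momentum': [0.0, 0.0]}
runner.set_optimizer_treatment("CONTINUE_LOCAL")
```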

set_tensor_dict(tensor_dict, with_opt_vars)

Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.

Parameters:
  • tensor_dict (dict) – The model weights dictionary.

  • with_opt_vars (bool) – Whether to also set the optimizer variables.

Returns:

None
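A minimal sketch of set_tensor_dict() on a hypothetical `WeightHolder` class: with with_opt_vars=False, only model weights are overwritten and optimizer state is left untouched. Raising KeyError when an expected tensor is missing, rather than skipping it silently, is a design choice of this sketch.

```python
class WeightHolder:
    """Hypothetical model with plain-list weights and optimizer state."""

    def __init__(self):
        self.weights = {"fc.weight": [0.0, 0.0]}
        self.opt_state = {"opt.fc.weight.momentum": [0.0, 0.0]}

    def set_tensor_dict(self, tensor_dict, with_opt_vars):
        # A missing model tensor raises KeyError instead of being skipped.
        for name in self.weights:
            self.weights[name] = list(tensor_dict[name])
        if with_opt_vars:
            for name in self.opt_state:
                self.opt_state[name] = list(tensor_dict[name])


model = WeightHolder()
incoming = {"fc.weight": [0.2, 0.4], "opt.fc.weight.momentum": [0.1, 0.1]}
model.set_tensor_dict(incoming, with_opt_vars=False)
print(model.weights)    # {'fc.weight': [0.2, 0.4]}
print(model.opt_state)  # {'opt.fc.weight.momentum': [0.0, 0.0]}
```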

train_batches(num_batches=None, use_tqdm=False)

Perform the training for a specified number of batches.

Implementations are expected to draw batches randomly without replacement until the data is exhausted; the data is then replaced and reshuffled, and drawing continues.

Parameters:
  • num_batches (int, optional) – Number of batches to train. Default is None.

  • use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.

Returns:

A dictionary of training metrics, {<metric>: <value>}.

Return type:

dict
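The sampling behavior described above (random draws without replacement, then reshuffling once the data is exhausted) can be sketched with a generator. `batch_stream`, the fixed seed, dropping a trailing partial batch, and the returned metric names are all assumptions of this example rather than documented behavior.

```python
import random


def batch_stream(data, batch_size, seed=0):
    """Yield batches drawn without replacement; reshuffle once exhausted."""
    if not 1 <= batch_size <= len(data):
        raise ValueError("batch_size must be between 1 and len(data)")
    rng = random.Random(seed)
    while True:
        order = list(range(len(data)))
        rng.shuffle(order)  # a fresh permutation for each pass over the data
        # Drop a trailing partial batch so every batch has batch_size items.
        for start in range(0, len(order) - batch_size + 1, batch_size):
            yield [data[i] for i in order[start:start + batch_size]]


def train_batches(data, num_batches, batch_size=2):
    """Hypothetical loop returning {<metric>: <value>}."""
    stream = batch_stream(data, batch_size)
    examples_seen = 0
    for _ in range(num_batches):
        batch = next(stream)
        examples_seen += len(batch)  # a real runner would take a step here
    return {"batches": num_batches, "examples_seen": examples_seen}


print(train_batches(list(range(6)), num_batches=5))
# {'batches': 5, 'examples_seen': 10}
```

With six examples and a batch size of two, each pass yields three batches, so the fourth and fifth batches come from a reshuffled second pass.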

validate()

Run validation.

Returns:

A dictionary of validation metrics, {<metric>: <value>}.

Return type:

dict
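As an illustration of the return shape, here is a hypothetical `validate` function that computes accuracy for a toy threshold classifier and returns it as {<metric>: <value>}. The metric name "accuracy" and the `predict` function are assumptions of this sketch.

```python
def validate(predict, dataset):
    """Return {<metric>: <value>} for a classifier over (x, label) pairs."""
    correct = sum(1 for x, label in dataset if predict(x) == label)
    return {"accuracy": correct / len(dataset)}


def predict(x):
    """Hypothetical threshold classifier."""
    return int(x > 0.5)


dataset = [(0.9, 1), (0.2, 0), (0.7, 0), (0.1, 0)]
print(validate(predict, dataset))  # {'accuracy': 0.75}
```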