openfl.federated.task

Task package.

class openfl.federated.task.TaskRunner(data_loader, tensor_dict_split_fn_kwargs: dict | None = None, **kwargs)

Federated Learning Task Runner Class.

Class Attributes:
  • data_loader – The data_loader object.

  • tensor_dict_split_fn_kwargs (dict) – Keyword arguments that determine which parameters are held out from aggregation.

  • logger (logging.Logger) – Logger object for logging events.

  • opt_treatment (str) – Treatment of the current instance's optimizer.
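To illustrate what "holding out" parameters from aggregation means, the sketch below splits a weight dictionary into tensors that are shared for aggregation and tensors that stay local to the collaborator. The helper split_for_holdout and its holdout_names / holdout_non_float arguments are hypothetical stand-ins for whatever tensor_dict_split_fn_kwargs actually controls, and plain Python lists stand in for tensors.

```python
from typing import Any, Dict, Iterable, Tuple


def split_for_holdout(
    tensor_dict: Dict[str, Any],
    holdout_names: Iterable[str] = (),
    holdout_non_float: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: split weights into (shared, held_out).

    A tensor is held out if its name is listed in holdout_names, or if
    holdout_non_float is True and it contains non-float values (e.g. an
    integer step counter), which are not meaningful to average.
    """
    names = set(holdout_names)
    shared, held_out = {}, {}
    for name, value in tensor_dict.items():
        is_float = all(isinstance(v, float) for v in value)
        if name in names or (holdout_non_float and not is_float):
            held_out[name] = value
        else:
            shared[name] = value
    return shared, held_out


weights = {
    "conv1.weight": [0.1, -0.2, 0.3],
    "bn1.num_batches_tracked": [42],  # integer counter, kept local
    "fc.bias": [0.0, 0.5],
}
shared, held_out = split_for_holdout(weights, holdout_names=["fc.bias"])
print(sorted(shared))    # ['conv1.weight']
print(sorted(held_out))  # ['bn1.num_batches_tracked', 'fc.bias']
```

Only the shared part would be sent for aggregation; the held-out part would be merged back into the local model afterwards.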

get_data_loader()

Get the data_loader object.

The data_loader serves up batches and provides information about the data.

Returns:

data_loader object

get_required_tensorkeys_for_function(func_name, **kwargs)

When running a task, a map of named TensorKeys must be provided to the function as dependencies. This method returns the TensorKeys required by the function named func_name.

Parameters:
  • func_name (str) – The function name.

  • **kwargs – Additional parameters to pass to the function.

Returns:

list – List of required TensorKey objects, each of the form TensorKey(tensor_name, origin, round_number).
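The shape of the returned list can be sketched with a simplified TensorKey namedtuple carrying only the three fields named above (the real TensorKey type may have more). The REQUIRED_TENSORS table, the required_tensorkeys function, and the "GLOBAL" origin value are hypothetical, chosen only for illustration.

```python
from collections import namedtuple
from typing import Dict, List

# Simplified stand-in with the fields named in the docstring; the real
# TensorKey type may carry additional fields.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number"])

# Hypothetical mapping from task function name to the tensors it consumes.
REQUIRED_TENSORS: Dict[str, List[str]] = {
    "train_batches": ["conv1.weight", "fc.bias"],
    "validate": ["conv1.weight", "fc.bias"],
}


def required_tensorkeys(func_name: str, round_number: int,
                        origin: str = "GLOBAL") -> List[TensorKey]:
    """Return the TensorKeys that func_name needs before it can run."""
    if func_name not in REQUIRED_TENSORS:
        raise KeyError(f"unknown task function: {func_name}")
    return [TensorKey(name, origin, round_number)
            for name in REQUIRED_TENSORS[func_name]]


keys = required_tensorkeys("validate", round_number=3)
print(keys[0])  # TensorKey(tensor_name='conv1.weight', origin='GLOBAL', round_number=3)
```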

get_tensor_dict(with_opt_vars)

Get the model weights as a tensor dictionary.

Parameters:

with_opt_vars (bool) – Whether to also include the optimizer variables.

Returns:

dict – The weight dictionary {<tensor_name>: <value>}.

get_train_data_size()

Get the number of training examples.

It will be used for weighted averaging in aggregation.

Returns:

int – The number of training examples.
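As a concrete picture of "weighted averaging in aggregation", the sketch below computes the sample-size-weighted mean familiar from FedAvg, weighting each collaborator's values by the example count its runner would report. The collaborator names and numbers are made up, and the aggregator's actual implementation may differ in details.

```python
from typing import Dict, List


def weighted_average(values: Dict[str, List[float]],
                     sizes: Dict[str, int]) -> List[float]:
    """Average per-collaborator tensors, weighting each by its example count."""
    total = sum(sizes.values())
    if total <= 0:
        raise ValueError("total data size must be positive")
    length = len(next(iter(values.values())))
    result = [0.0] * length
    for name, tensor in values.items():
        weight = sizes[name] / total
        for i, v in enumerate(tensor):
            result[i] += weight * v
    return result


# Hypothetical collaborators: "a" has three times as many training examples.
values = {"a": [1.0, 2.0], "b": [5.0, 6.0]}
sizes = {"a": 300, "b": 100}
print(weighted_average(values, sizes))  # [2.0, 3.0]
```

Collaborators with more data pull the average toward their values, which is why each runner must report an accurate example count.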

get_valid_data_size()

Get the number of validation examples.

It will be used for weighted averaging in aggregation.

Returns:

int – The number of validation examples.

initialize_globals()

Initialize all global variables.

Returns:

None

load_native(filepath, **kwargs)

Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.

May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs.

Parameters:
  • filepath (str) – Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.

  • **kwargs – Additional parameters to pass to the function. For future-proofing.

Returns:

None

reset_opt_vars()

Reinitialize the optimizer variables.

Returns:

None

save_native(filepath, **kwargs)

Save model state in ML-framework “native” format, e.g. PyTorch pickled models.

May save one file or multiple files, depending on the framework.

Parameters:
  • filepath (str) – If framework stores a single file, this should be a single file path. Frameworks that store multiple files may need to derive the other paths from this path.

  • **kwargs – Additional parameters to pass to the function. For future-proofing.

Returns:

None
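For a framework without its own serializer, a minimal save_native/load_native pair could simply pickle the model state. The ToyRunner class below is hypothetical and uses only the standard library; a real PyTorch runner would more likely use torch.save and torch.load, and a multi-file framework would derive sibling paths from filepath.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class ToyRunner:
    """Hypothetical runner whose model state is a plain dict of lists."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {"w": [0.0, 0.0], "b": [0.0]}

    def save_native(self, filepath: str, **kwargs) -> None:
        # Single-file format: pickle the whole state dict to filepath.
        with open(filepath, "wb") as f:
            pickle.dump(self.state, f)

    def load_native(self, filepath: str, **kwargs) -> None:
        # Only unpickle trusted files: pickle can execute arbitrary code.
        with open(filepath, "rb") as f:
            self.state = pickle.load(f)


src, dst = ToyRunner(), ToyRunner()
src.state["w"] = [1.5, -2.0]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    src.save_native(path)
    dst.load_native(path)
print(dst.state)  # {'w': [1.5, -2.0], 'b': [0.0]}
```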

set_data_loader(data_loader)

Set data_loader object.

Parameters:

data_loader – data_loader object to set.

Returns:

None

set_logger()

Set up the logger object.

Returns:

None

set_optimizer_treatment(opt_treatment)

Change the treatment of the current instance's optimizer.

Parameters:

opt_treatment (str) – The optimizer treatment.

Returns:

None

set_tensor_dict(tensor_dict, with_opt_vars)

Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.

Parameters:
  • tensor_dict (dict) – The model weights dictionary.

  • with_opt_vars (bool) – Whether to also set the optimizer variables.

Returns:

None
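The with_opt_vars flag on get_tensor_dict and set_tensor_dict decides whether optimizer state travels with the model weights. The sketch below shows one way a runner might honor it; the ToyRunner class and the "__opt_" name prefix used to tag optimizer variables are hypothetical conventions for this example, not something the API defines.

```python
from typing import Dict, List

OPT_PREFIX = "__opt_"  # hypothetical naming convention for optimizer tensors


class ToyRunner:
    """Hypothetical runner holding model weights and optimizer state."""

    def __init__(self) -> None:
        self.weights: Dict[str, List[float]] = {"w": [0.1, 0.2]}
        self.opt_state: Dict[str, List[float]] = {"momentum": [0.0, 0.0]}

    def get_tensor_dict(self, with_opt_vars: bool) -> Dict[str, List[float]]:
        # Copy so callers cannot mutate the runner's state by accident.
        tensors = {k: list(v) for k, v in self.weights.items()}
        if with_opt_vars:
            tensors.update({OPT_PREFIX + k: list(v)
                            for k, v in self.opt_state.items()})
        return tensors

    def set_tensor_dict(self, tensor_dict: Dict[str, List[float]],
                        with_opt_vars: bool) -> None:
        for name, value in tensor_dict.items():
            if name.startswith(OPT_PREFIX):
                # Optimizer tensors are applied only when requested.
                if with_opt_vars:
                    self.opt_state[name[len(OPT_PREFIX):]] = list(value)
            else:
                self.weights[name] = list(value)


runner = ToyRunner()
print(sorted(runner.get_tensor_dict(with_opt_vars=False)))  # ['w']
print(sorted(runner.get_tensor_dict(with_opt_vars=True)))   # ['__opt_momentum', 'w']
runner.set_tensor_dict({"w": [1.0, 1.0], "__opt_momentum": [9.0, 9.0]},
                       with_opt_vars=False)
print(runner.weights, runner.opt_state)
# {'w': [1.0, 1.0]} {'momentum': [0.0, 0.0]}
```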

train_batches(num_batches=None, use_tqdm=False)

Perform the training for a specified number of batches.

Batches are expected to be drawn randomly without replacement until the data is exhausted; the data is then replaced and reshuffled, and drawing continues.

Parameters:
  • num_batches (int, optional) – Number of batches to train. Default is None.

  • use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.

Returns:

dict – {<metric>: <value>}.
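The sampling behavior described for train_batches (draw without replacement until the data is exhausted, then reshuffle and continue) can be sketched as an endless batch-index generator. The batch_indices helper is hypothetical; in practice the data loader would implement this. Here the last batch of each pass may be smaller than batch_size, a detail the docstring leaves open.

```python
import itertools
import random
from typing import Iterator, List, Optional


def batch_indices(num_examples: int, batch_size: int,
                  seed: Optional[int] = None) -> Iterator[List[int]]:
    """Yield index batches forever: shuffle, draw without replacement until
    the pass is exhausted, then reshuffle and continue. The last batch of a
    pass may be smaller than batch_size."""
    if num_examples <= 0 or batch_size <= 0:
        raise ValueError("num_examples and batch_size must be positive")
    rng = random.Random(seed)
    indices = list(range(num_examples))
    while True:
        rng.shuffle(indices)  # start a fresh pass over the data
        for start in range(0, num_examples, batch_size):
            yield indices[start:start + batch_size]  # slice is a copy


# Train on exactly num_batches batches, which may span several passes.
num_batches = 5
for batch in itertools.islice(batch_indices(10, batch_size=4, seed=0), num_batches):
    print(batch)
```

Because the generator is endless, num_batches alone decides how much training happens, independent of where pass boundaries fall.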

validate()

Run validation.

Returns:

dict – {<metric>: <value>}.

class openfl.federated.task.catch_warnings(*, record=False, module=None)

A context manager that copies and restores the warnings filter upon exiting the context.

The ‘record’ argument specifies whether warnings should be captured by a custom implementation of warnings.showwarning() and be appended to a list returned by the context manager. Otherwise None is returned by the context manager. The objects appended to the list are arguments whose attributes mirror the arguments to showwarning().

The ‘module’ argument is to specify an alternative module to the module named ‘warnings’ and imported under that name. This argument is only useful when testing the warnings module itself.
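The signature and description match the standard library's warnings.catch_warnings, so this name is presumably re-exported from the warnings module. A typical use of record=True, shown with the standard module directly:

```python
import warnings


def legacy_api() -> int:
    warnings.warn("legacy_api is deprecated", DeprecationWarning)
    return 1


with warnings.catch_warnings(record=True) as caught:
    # Record the warning even if it was shown before or is ignored by
    # default (DeprecationWarning often is); the filter change is undone
    # when the block exits.
    warnings.simplefilter("always")
    result = legacy_api()

print(result, len(caught))          # 1 1
print(caught[0].category.__name__)  # DeprecationWarning
print(str(caught[0].message))       # legacy_api is deprecated
```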

openfl.federated.task.simplefilter(action, category=<class 'Warning'>, lineno=0, append=False)

Insert a simple entry into the list of warnings filters (at the front).

A simple filter matches all modules and messages.

Parameters:
  • action – One of “error”, “ignore”, “always”, “default”, “module”, or “once”.

  • category – A class that the warning must be a subclass of.

  • lineno – An integer line number; 0 matches all warnings.

  • append – If true, append to the list of filters.
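As with catch_warnings, this matches the standard library's warnings.simplefilter. With action "error", matching warnings are raised as exceptions; wrapping the call in catch_warnings keeps the filter change local. Standard-library example:

```python
import warnings


def risky() -> None:
    warnings.warn("precision loss", RuntimeWarning)


with warnings.catch_warnings():
    # Only RuntimeWarning becomes an error; other categories are untouched.
    warnings.simplefilter("error", category=RuntimeWarning)
    try:
        risky()
        raised = False
    except RuntimeWarning as exc:
        raised = True
        print("caught:", exc)  # caught: precision loss

# Outside the block, the previous filters are in force again.
```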

Modules:
  • fl_model – FederatedModel module.

  • runner – Mixin class for FL models.

  • task_runner – Interactive API package.