openfl.federated.task.fl_model.FederatedModel

class openfl.federated.task.fl_model.FederatedModel(build_model, optimizer=None, loss_fn=None, **kwargs)

Bases: TaskRunner

A wrapper that adapts TensorFlow and PyTorch models to a federated context.

This class provides methods to manage and manipulate federated models.

Class Attributes:
  • build_model (function or class) – TensorFlow/Keras (function), PyTorch (class). For Keras/TensorFlow models, expects a function that returns the model definition. For PyTorch models, expects a class (not an instance) containing the model definition and forward function.

  • lambda_opt (function) – Lambda function for the optimizer (only required for PyTorch). The optimizer should be defined within a lambda function, which allows the optimizer to be attached to the federated models spawned for each collaborator.

  • model (Model) – The built model.

  • optimizer (Optimizer) – Optimizer for the model.

  • runner (TaskRunner) – Task runner for the model.

  • loss_fn (Loss) – PyTorch loss function for the model (only required for PyTorch).

  • tensor_dict_split_fn_kwargs (dict) – Keyword arguments for the tensor dict split function.
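Passing build_model as a class and the optimizer as a lambda lets every collaborator receive its own model instance with an optimizer bound to that instance's parameters. The following is a minimal plain-Python sketch of that pattern; ToyModel, ToySGD and spawn_collaborator_models are hypothetical stand-ins, not OpenFL or PyTorch classes. In real PyTorch code the lambda would typically look like `lambda params: torch.optim.Adam(params, lr=1e-4)`.

```python
class ToyModel:
    """Hypothetical stand-in for a PyTorch model class passed as build_model."""

    def __init__(self):
        # Each instance owns its own parameter list.
        self.params = [0.0, 0.0]

    def parameters(self):
        return self.params


class ToySGD:
    """Hypothetical stand-in for an optimizer bound to one set of parameters."""

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr


def spawn_collaborator_models(build_model, lambda_opt, num_collaborators):
    """Build one (model, optimizer) pair per collaborator."""
    pairs = []
    for _ in range(num_collaborators):
        model = build_model()                       # new instance from the class
        optimizer = lambda_opt(model.parameters())  # bound to this model's parameters
        pairs.append((model, optimizer))
    return pairs


pairs = spawn_collaborator_models(
    ToyModel, lambda params: ToySGD(params, lr=0.01), num_collaborators=3
)
```

Had an optimizer instance been passed instead, every spawned model would share one optimizer tied to the first model's parameters; the lambda defers construction until each model exists.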

Methods

get_data_loader

Get the data_loader object.

get_required_tensorkeys_for_function

When running a task, a map of named tensorkeys must be provided to the function as dependencies.

get_tensor_dict

Get the weights.

get_train_data_size

Get the number of training examples.

get_valid_data_size

Get the number of validation examples.

initialize_globals

Initialize all global variables.

load_native

Load model state from a filepath in ML-framework "native" format, e.g. PyTorch pickled models.

reset_opt_vars

Reinitialize the optimizer variables.

save_native

Save model state in ML-framework "native" format, e.g. PyTorch pickled models.

set_data_loader

Set data_loader object.

set_logger

Set up the log object.

set_optimizer_treatment

Change the treatment of current instance optimizer.

set_tensor_dict

Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.

setup

Create new models for all of the collaborators in the experiment.

train_batches

Perform the training for a specified number of batches.

validate

Run validation.

get_data_loader()

Get the data_loader object.

The data_loader serves up batches and provides information about the data it holds.

Returns:

data_loader object

get_required_tensorkeys_for_function(func_name, **kwargs)

When running a task, a map of named tensorkeys must be provided to the function as dependencies.

Parameters:
  • func_name (str) – The function name.

  • **kwargs – Additional parameters to pass to the function.

Returns:

list – List of required TensorKey objects (TensorKey(tensor_name, origin, round_number)).
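As an illustration of the return value, the sketch below builds one key per tensor a task depends on. It uses a simplified namedtuple carrying only the three fields named above; it is not OpenFL's own TensorKey class, and the helper name, tensor names and the "GLOBAL" origin label are hypothetical.

```python
from collections import namedtuple

# Simplified stand-in carrying only the three fields named above;
# this is not OpenFL's own TensorKey class.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number"])


def required_tensorkeys(tensor_names, origin, round_number):
    """Return one TensorKey per tensor the task depends on (hypothetical helper)."""
    return [TensorKey(name, origin, round_number) for name in tensor_names]


# "GLOBAL" is an illustrative origin label for weights coming from the aggregator.
keys = required_tensorkeys(["fc.weight", "fc.bias"], origin="GLOBAL", round_number=0)
```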

get_tensor_dict(with_opt_vars)

Get the weights.

Parameters:

with_opt_vars (bool) – Whether to also include the optimizer variables.

Returns:

dict – The weight dictionary {<tensor_name>: <value>}.
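The effect of with_opt_vars can be pictured with plain dictionaries. This is a minimal sketch, assuming the model weights and optimizer state are already available as {name: value} mappings; the helper and the "__opt_momentum" tensor name are hypothetical, not the framework-specific logic FederatedModel delegates to.

```python
def get_tensor_dict(model_weights, optimizer_state, with_opt_vars):
    """Return {tensor_name: value}, adding optimizer variables only on request.

    model_weights and optimizer_state are plain dicts standing in for the
    framework's model and optimizer state (hypothetical helper).
    """
    tensor_dict = dict(model_weights)  # new dict, so adding keys leaves model_weights untouched
    if with_opt_vars:
        tensor_dict.update(optimizer_state)
    return tensor_dict


weights = {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]}
opt_state = {"__opt_momentum": [0.01, 0.02]}  # hypothetical optimizer-tensor name

weights_only = get_tensor_dict(weights, opt_state, with_opt_vars=False)
with_opt = get_tensor_dict(weights, opt_state, with_opt_vars=True)
```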

get_train_data_size()

Get the number of training examples.

It will be used for weighted averaging in aggregation.

Returns:

int – The number of training examples.
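The training-set size matters because aggregation weights each collaborator's contribution by it. A minimal sketch, assuming FedAvg-style weighted averaging over flat lists of floats (the aggregation function actually used is configured elsewhere and is not shown in this reference); weighted_average is a hypothetical helper:

```python
def weighted_average(tensors, sizes):
    """Element-wise average of equal-length vectors, weighted by example counts.

    Computes sum_k(n_k * w_k) / sum_k(n_k), where n_k is a collaborator's
    training-set size and w_k its weight vector.
    """
    total = sum(sizes)
    if total <= 0:
        raise ValueError("the total number of examples must be positive")
    length = len(tensors[0])
    return [
        sum(t[i] * n for t, n in zip(tensors, sizes)) / total
        for i in range(length)
    ]


# Collaborator B holds three times as many training examples as A,
# so its weights count three times as much in the average.
averaged = weighted_average([[1.0, 2.0], [3.0, 4.0]], sizes=[1, 3])
```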

get_valid_data_size()

Get the number of validation examples.

It will be used for weighted averaging in aggregation.

Returns:

int – The number of validation examples.

initialize_globals()

Initialize all global variables.

Returns:

None

load_native(filepath, **kwargs)

Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.

May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs.

Parameters:
  • filepath (str) – Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.

  • **kwargs – Additional parameters to pass to the function. For future-proofing.

Returns:

None

reset_opt_vars()

Reinitialize the optimizer variables.

Returns:

None

save_native(filepath, **kwargs)

Save model state in ML-framework “native” format, e.g. PyTorch pickled models.

May save one file or multiple files, depending on the framework.

Parameters:
  • filepath (str) – If the framework stores a single file, this should be a single file path. Frameworks that store multiple files may need to derive the other paths from this path.

  • **kwargs – Additional parameters to pass to the function. For future-proofing.

Returns:

None
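The save_native/load_native contract (one path passed in, any additional paths derived from it) can be sketched with the standard pickle module standing in for a framework's native serializer. The "_opt" suffix convention and all function names here are hypothetical illustrations, not what FederatedModel actually writes to disk.

```python
import pickle
import tempfile
from pathlib import Path


def _optimizer_path(path):
    # Hypothetical convention: derive the second file's path from the first.
    return path.with_name(path.stem + "_opt" + path.suffix)


def save_native(filepath, model_state, optimizer_state):
    """Pickle model state to filepath and optimizer state to a derived path."""
    path = Path(filepath)
    with path.open("wb") as f:
        pickle.dump(model_state, f)
    with _optimizer_path(path).open("wb") as f:
        pickle.dump(optimizer_state, f)


def load_native(filepath):
    """Load both files back; only unpickle files from a trusted source."""
    path = Path(filepath)
    with path.open("rb") as f:
        model_state = pickle.load(f)
    with _optimizer_path(path).open("rb") as f:
        optimizer_state = pickle.load(f)
    return model_state, optimizer_state


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "model.pkl"
    save_native(target, {"fc.weight": [1.0]}, {"__opt_momentum": [0.0]})
    restored_model, restored_opt = load_native(target)
    saved_files = sorted(p.name for p in Path(tmp).iterdir())
```

Because both functions derive the companion path the same way, callers only ever handle the single filepath argument.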

set_data_loader(data_loader)

Set data_loader object.

Parameters:

data_loader – data_loader object to set.

Returns:

None

set_logger()

Set up the log object.

Returns:

None

set_optimizer_treatment(opt_treatment)

Change the treatment of current instance optimizer.

Parameters:

opt_treatment (str) – The optimizer treatment.

Returns:

None
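The accepted opt_treatment strings are not listed here. OpenFL's collaborator configuration uses the names RESET, CONTINUE_LOCAL and CONTINUE_GLOBAL; the sketch below assumes those names and illustrates their intended semantics with plain dicts. The enum and optimizer_state_for_round are hypothetical and are not OpenFL's own code.

```python
from enum import Enum


class OptTreatment(Enum):
    """Treatment names assumed from OpenFL collaborator configuration."""

    RESET = "RESET"
    CONTINUE_LOCAL = "CONTINUE_LOCAL"
    CONTINUE_GLOBAL = "CONTINUE_GLOBAL"


def optimizer_state_for_round(opt_treatment, local_state, global_state, fresh_state):
    """Choose the optimizer state a collaborator starts the next round with (hypothetical)."""
    treatment = OptTreatment(opt_treatment)  # raises ValueError for unknown strings
    if treatment is OptTreatment.RESET:
        return dict(fresh_state)   # discard accumulated optimizer variables
    if treatment is OptTreatment.CONTINUE_LOCAL:
        return dict(local_state)   # keep this collaborator's own optimizer state
    return dict(global_state)      # start from the aggregated optimizer state
```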

set_tensor_dict(tensor_dict, with_opt_vars)

Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.

Parameters:
  • tensor_dict (dict) – The model weights dictionary.

  • with_opt_vars (bool) – Whether to also set the optimizer variables.

Returns:

None
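Conversely, set_tensor_dict writes values back and honours with_opt_vars. The sketch below assumes optimizer tensors can be told apart by a name prefix; that "__opt_" prefix, the dict-based state and the helper itself are hypothetical illustrations rather than the framework-specific logic.

```python
OPT_PREFIX = "__opt_"  # hypothetical naming convention for optimizer tensors


def set_tensor_dict(model_weights, optimizer_state, tensor_dict, with_opt_vars):
    """Write tensor_dict values into model_weights (and optionally optimizer_state) in place."""
    for name, value in tensor_dict.items():
        if name.startswith(OPT_PREFIX):
            # Optimizer variables are applied only when explicitly requested.
            if with_opt_vars:
                optimizer_state[name] = value
        elif name in model_weights:
            model_weights[name] = value
        else:
            raise KeyError(f"unknown tensor name: {name}")


weights = {"fc.weight": [0.0], "fc.bias": [0.0]}
opt_state = {}
incoming = {"fc.weight": [1.0], "fc.bias": [0.5], "__opt_momentum": [0.1]}
set_tensor_dict(weights, opt_state, incoming, with_opt_vars=False)
```

Raising on unknown names is a design choice for the sketch: it surfaces mismatched model definitions early instead of silently ignoring tensors.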

setup(num_collaborators, **kwargs)

Create new models for all of the collaborators in the experiment.

Parameters:
  • num_collaborators (int) – Number of experiment collaborators.

  • **kwargs – Additional parameters to pass to the function.

Returns:

List[FederatedModel] – List of models for each collaborator.
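The models returned for each collaborator must be independent, so that a local update on one collaborator cannot leak into another. One way to achieve that is to deep-copy a template; the sketch below assumes that approach, with ToyFederatedModel as a hypothetical stand-in. The real setup may construct its copies differently.

```python
import copy


class ToyFederatedModel:
    """Hypothetical stand-in that keeps its weights in a dict."""

    def __init__(self, weights):
        self.weights = weights

    def setup(self, num_collaborators):
        """Return one independent deep copy per collaborator."""
        return [copy.deepcopy(self) for _ in range(num_collaborators)]


template = ToyFederatedModel({"fc.weight": [0.0, 0.0]})
collaborator_models = template.setup(num_collaborators=2)
collaborator_models[0].weights["fc.weight"][0] = 1.0  # local update on collaborator 0 only
```

A shallow copy would not be enough here: the nested weight lists would still be shared, and the update above would appear on every collaborator.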

train_batches(num_batches=None, use_tqdm=False)

Perform the training for a specified number of batches.

This method is expected to draw batches randomly, without replacement, until the data is exhausted; the data is then replaced and shuffled, and drawing continues.

Parameters:
  • num_batches (int, optional) – Number of batches to train. Default is None.

  • use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.

Returns:

dict – {<metric>: <value>}.
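The sampling schedule described above (without replacement until exhausted, then refill, reshuffle and continue) can be sketched on integer indices. This is a minimal plain-Python illustration; batch_indices and its parameters are hypothetical, and real task runners normally delegate sampling to the framework's data loader.

```python
import random


def batch_indices(num_examples, batch_size, num_batches, seed=0):
    """Yield num_batches lists of example indices (hypothetical helper).

    Indices are drawn randomly without replacement from a shuffled pool.
    When the pool is exhausted it is refilled with every index and
    reshuffled, and drawing continues. A batch that straddles a refill
    may therefore repeat an index from the end of the previous pass.
    """
    if num_examples <= 0 or batch_size <= 0:
        raise ValueError("num_examples and batch_size must be positive")
    rng = random.Random(seed)
    pool = []
    for _ in range(num_batches):
        batch = []
        while len(batch) < batch_size:
            if not pool:
                pool = list(range(num_examples))
                rng.shuffle(pool)
            batch.append(pool.pop())
        yield batch


# Six examples, batches of two: the first three batches form one full pass.
first_pass = list(batch_indices(num_examples=6, batch_size=2, num_batches=3))
```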

validate()

Run validation.

Returns:

dict – {<metric>: <value>}.