openfl.federated.task.fl_model.FederatedModel
- class openfl.federated.task.fl_model.FederatedModel(build_model, optimizer=None, loss_fn=None, **kwargs)
Bases: TaskRunner
A wrapper that adapts TensorFlow and PyTorch models to a federated context.
This class provides methods to manage and manipulate federated models.
- Class Attributes:
build_model (function or class) – TensorFlow/Keras (function), PyTorch (class). For Keras/TensorFlow models, expects a function that returns the model definition. For PyTorch models, expects a class (not an instance) containing the model definition and forward function.
lambda_opt (function) – Lambda function for the optimizer (only required for PyTorch). The optimizer should be defined within a lambda function, which allows the optimizer to be attached to the federated model spawned for each collaborator.
model (Model) – The built model.
optimizer (Optimizer) – Optimizer for the model.
runner (TaskRunner) – Task runner for the model.
loss_fn (Loss) – PyTorch loss function for the model (only required for PyTorch).
tensor_dict_split_fn_kwargs (dict) – Keyword arguments for the tensor dict split function.
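The reason the optimizer is supplied as a lambda can be illustrated with a minimal, framework-free sketch. ToyModel and ToyOptimizer are hypothetical stand-ins for a PyTorch model and optimizer (they are not OpenFL or PyTorch classes): calling the factory once per collaborator model binds a fresh optimizer to that model's own parameters, whereas a single pre-built optimizer would only ever update one model.

```python
class ToyModel:
    """Hypothetical stand-in for a PyTorch model: a flat list of weights."""

    def __init__(self, n_params):
        self.params = [0.0] * n_params


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer bound to one model's parameters."""

    def __init__(self, params, lr):
        self.params = params  # reference to the list this optimizer updates
        self.lr = lr

    def step(self, grads):
        # Plain gradient descent, applied in place to the bound parameters.
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g


# Optimizer factory in the spirit of lambda_opt: not an optimizer yet,
# only a recipe for building one around a given set of parameters.
lambda_opt = lambda params: ToyOptimizer(params, lr=0.1)

collaborator_models = [ToyModel(n_params=2) for _ in range(2)]
# Calling the factory once per model gives each collaborator its own optimizer.
optimizers = [lambda_opt(m.params) for m in collaborator_models]

optimizers[0].step([1.0, 2.0])
print(collaborator_models[0].params)  # [-0.1, -0.2]
print(collaborator_models[1].params)  # [0.0, 0.0]
```

With PyTorch, the analogous factory would be something like `lambda params: torch.optim.Adam(params, lr=1e-4)`.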
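Methods
get_data_loader() – Get the data_loader object.
get_required_tensorkeys_for_function(func_name, **kwargs) – When running a task, a map of named tensorkeys must be provided to the function as dependencies.
get_tensor_dict(with_opt_vars) – Get the weights.
get_train_data_size() – Get the number of training examples.
get_valid_data_size() – Get the number of validation examples.
initialize_globals() – Initialize all global variables.
load_native(filepath, **kwargs) – Load model state from a filepath in ML-framework "native" format, e.g. PyTorch pickled models.
reset_opt_vars() – Reinitialize the optimizer variables.
save_native(filepath, **kwargs) – Save model state in ML-framework "native" format, e.g. PyTorch pickled models.
set_data_loader(data_loader) – Set data_loader object.
set_logger() – Set up the log object.
set_optimizer_treatment(opt_treatment) – Change the optimizer treatment of the current instance.
set_tensor_dict(tensor_dict, with_opt_vars) – Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
setup(num_collaborators, **kwargs) – Create new models for all of the collaborators in the experiment.
train_batches(num_batches=None, use_tqdm=False) – Perform the training for a specified number of batches.
validate() – Run validation.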
- get_data_loader()
Get the data_loader object.
The data_loader object serves up batches and provides information about the underlying data.
- Returns:
data_loader object
- get_required_tensorkeys_for_function(func_name, **kwargs)
When running a task, a map of named tensorkeys must be provided to the function as dependencies.
- Parameters:
func_name (str) – The function name.
**kwargs – Additional parameters to pass to the function.
- Returns:
list – List of required TensorKey objects (TensorKey(tensor_name, origin, round_number)).
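The shape of the returned list can be sketched with a plain namedtuple. Everything here is hypothetical: the three-field TensorKey follows the form quoted above (the real OpenFL TensorKey may carry more fields), and the dependency table, tensor names, helper name, and "GLOBAL" origin value are made up for illustration.

```python
from collections import namedtuple

# Hypothetical three-field key following the form quoted above; the real
# OpenFL TensorKey may carry additional fields.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number"])

# Hypothetical dependency table: which tensors each task needs as input.
REQUIRED_TENSORS = {
    "train_batches": ["fc.weight", "fc.bias", "opt.momentum"],
    "validate": ["fc.weight", "fc.bias"],
}


def required_tensorkeys(func_name, round_number, origin="GLOBAL"):
    """Hypothetical helper: build the TensorKeys a task depends on for one round."""
    return [
        TensorKey(name, origin, round_number) for name in REQUIRED_TENSORS[func_name]
    ]


for key in required_tensorkeys("validate", round_number=3):
    print(key)
# TensorKey(tensor_name='fc.weight', origin='GLOBAL', round_number=3)
# TensorKey(tensor_name='fc.bias', origin='GLOBAL', round_number=3)
```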
- get_tensor_dict(with_opt_vars)
Get the weights.
- Parameters:
with_opt_vars (bool) – Specify if we also want to get the variables of the optimizer.
- Returns:
dict – The weight dictionary {<tensor_name>: <value>}.
- get_train_data_size()
Get the number of training examples.
It will be used for weighted averaging in aggregation.
- Returns:
int – The number of training examples.
- get_valid_data_size()
Get the number of validation examples.
It will be used for weighted averaging in aggregation.
- Returns:
int – The number of validation examples.
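Both data-size methods feed weighted averaging during aggregation. A minimal sketch of example-count weighting in the style of FedAvg, assuming each collaborator reports a single scalar together with its example count; the aggregation function a given experiment actually uses is configured elsewhere and may differ.

```python
def weighted_average(values, sizes):
    """Average per-collaborator values, weighting each by its example count."""
    total = sum(sizes)
    if total == 0:
        raise ValueError("total number of examples must be positive")
    return sum(v * n for v, n in zip(values, sizes)) / total


# Hypothetical: one scalar weight reported by each of three collaborators,
# paired with what each returned from get_train_data_size().
weights = [1.0, 2.0, 4.0]
train_sizes = [100, 300, 600]
print(weighted_average(weights, train_sizes))  # 3.1
```

The collaborator with 600 examples pulls the result toward its value far more than the one with 100, which is the point of reporting data sizes.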
- initialize_globals()
Initialize all global variables.
- Returns:
None
- load_native(filepath, **kwargs)
Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.
May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs.
- Parameters:
filepath (str) – Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
- reset_opt_vars()
Reinitialize the optimizer variables.
- Returns:
None
- save_native(filepath, **kwargs)
Save model state in ML-framework “native” format, e.g. PyTorch pickled models.
May save one file or multiple files, depending on the framework.
- Parameters:
filepath (str) – If framework stores a single file, this should be a single file path. Frameworks that store multiple files may need to derive the other paths from this path.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
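A minimal save/load round trip illustrates how a multi-file format can derive its extra paths from the single filepath argument. This is a hypothetical sketch, not FederatedModel's implementation: plain pickle stands in for a framework's native serializer (such as torch.save), and the ".opt" suffix for the optimizer file is an assumption.

```python
import os
import pickle
import tempfile


def save_native(state, opt_state, filepath):
    """Hypothetical multi-file save: the optimizer path is derived from filepath."""
    with open(filepath, "wb") as f:
        pickle.dump(state, f)
    with open(filepath + ".opt", "wb") as f:
        pickle.dump(opt_state, f)


def load_native(filepath):
    """Load both files, re-deriving the optimizer path from the same filepath."""
    with open(filepath, "rb") as f:
        state = pickle.load(f)
    with open(filepath + ".opt", "rb") as f:
        opt_state = pickle.load(f)
    return state, opt_state


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    save_native({"fc.weight": [0.5, -0.5]}, {"step": 10}, path)
    restored, restored_opt = load_native(path)

print(restored, restored_opt)  # {'fc.weight': [0.5, -0.5]} {'step': 10}
```

Because both functions apply the same derivation rule, callers only ever pass one path.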
- set_data_loader(data_loader)
Set data_loader object.
- Parameters:
data_loader – data_loader object to set.
- Returns:
None
- set_logger()
Set up the log object.
- Returns:
None
- set_optimizer_treatment(opt_treatment)
Change the optimizer treatment of the current instance, i.e. how its optimizer state is handled between rounds.
- Parameters:
opt_treatment (str) – The optimizer treatment.
- Returns:
None
- set_tensor_dict(tensor_dict, with_opt_vars)
Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
- Parameters:
tensor_dict (dict) – The model weights dictionary.
with_opt_vars (bool) – Specify if we also want to set the variables of the optimizer.
- Returns:
None
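get_tensor_dict and set_tensor_dict work as a pair, for example to copy weights from one model into another. The sketch below uses a hypothetical ToyRunner holding plain floats instead of framework tensors, with made-up tensor names; it only shows how with_opt_vars controls whether optimizer variables travel along with the weights.

```python
class ToyRunner:
    """Hypothetical runner holding model weights and optimizer variables."""

    def __init__(self):
        self.weights = {"fc.weight": 0.0, "fc.bias": 0.0}
        self.opt_vars = {"adam.step": 0}

    def get_tensor_dict(self, with_opt_vars):
        # Always export weights; add optimizer variables only on request.
        tensors = dict(self.weights)
        if with_opt_vars:
            tensors.update(self.opt_vars)
        return tensors

    def set_tensor_dict(self, tensor_dict, with_opt_vars):
        # Overwrite weights by name; touch optimizer variables only on request.
        for name in self.weights:
            self.weights[name] = tensor_dict[name]
        if with_opt_vars:
            for name in self.opt_vars:
                self.opt_vars[name] = tensor_dict[name]


src, dst = ToyRunner(), ToyRunner()
src.weights["fc.weight"] = 1.5
src.opt_vars["adam.step"] = 7

dst.set_tensor_dict(src.get_tensor_dict(with_opt_vars=False), with_opt_vars=False)
print(dst.weights, dst.opt_vars)  # {'fc.weight': 1.5, 'fc.bias': 0.0} {'adam.step': 0}
```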
- setup(num_collaborators, **kwargs)
Create new models for all of the collaborators in the experiment.
- Parameters:
num_collaborators (int) – Number of experiment collaborators.
**kwargs – Additional parameters to pass to the function.
- Returns:
List[FederatedModel] – List of models for each collaborator.
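What setup returns can be sketched with a hypothetical ToyFederatedModel. For illustration only, it assumes that each collaborator's model receives a round-robin shard of the data and a deep copy of the initial weights; this is not OpenFL's implementation, only a picture of "one independent model per collaborator".

```python
import copy


class ToyFederatedModel:
    """Hypothetical stand-in for FederatedModel: weights plus a list of samples."""

    def __init__(self, weights, data):
        self.weights = weights
        self.data = data

    def setup(self, num_collaborators):
        """Return one independent model per collaborator.

        Assumption for illustration: the data is split round-robin into shards.
        """
        shards = [self.data[i::num_collaborators] for i in range(num_collaborators)]
        return [
            ToyFederatedModel(copy.deepcopy(self.weights), shard) for shard in shards
        ]


base = ToyFederatedModel(weights={"fc.bias": 0.0}, data=list(range(10)))
models = base.setup(num_collaborators=3)
for m in models:
    print(m.data)
# [0, 3, 6, 9]
# [1, 4, 7]
# [2, 5, 8]

models[0].weights["fc.bias"] = 1.0  # local update on one collaborator only
print(base.weights, models[1].weights)  # {'fc.bias': 0.0} {'fc.bias': 0.0}
```

The deep copy is what keeps a local update on one collaborator from leaking into the others or into the original model.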
- train_batches(num_batches=None, use_tqdm=False)
Perform the training for a specified number of batches.
Batches are expected to be drawn randomly without replacement until the data is exhausted; the data is then replaced and shuffled, and drawing continues.
- Parameters:
num_batches (int, optional) – Number of batches to train. Default is None.
use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.
- Returns:
dict – {<metric>: <value>}.
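The sampling scheme described for train_batches can be sketched as a generator that shuffles the data, walks through it in batches, and reshuffles once it runs out. The function names and the choice to drop a trailing partial batch are assumptions for illustration; the sketch draws batches only and does not train or return metrics.

```python
import random


def batch_stream(data, batch_size, seed=0):
    """Yield random batches without replacement, reshuffling once the data is exhausted.

    Hypothetical sketch: a trailing partial batch is dropped for simplicity.
    """
    if not 0 < batch_size <= len(data):
        raise ValueError("batch_size must be between 1 and len(data)")
    rng = random.Random(seed)
    while True:
        order = list(data)
        rng.shuffle(order)  # fresh random order for each pass over the data
        for start in range(0, len(order) - batch_size + 1, batch_size):
            yield order[start:start + batch_size]


def draw_batches(data, num_batches, batch_size=2):
    """Draw num_batches batches; a real runner would take a training step on each."""
    stream = batch_stream(data, batch_size)
    return [next(stream) for _ in range(num_batches)]


batches = draw_batches(list(range(6)), num_batches=5)
# The first three batches form one full pass: every example appears exactly once.
print(sorted(x for b in batches[:3] for x in b))  # [0, 1, 2, 3, 4, 5]
```

Batches four and five come from a new shuffle, so an example may reappear only after the previous pass has used every example.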
- validate()
Run validation.
- Returns:
dict – {<metric>: <value>}.