openfl.federated.task.fl_model.FederatedModel
- class openfl.federated.task.fl_model.FederatedModel(build_model, optimizer=None, loss_fn=None, **kwargs)
Bases: TaskRunner
A wrapper that adapts TensorFlow and PyTorch models to a federated context.
This class provides methods to manage and manipulate federated models.
- build_model
Model builder: a function for TensorFlow/Keras, a class for PyTorch. For Keras/TensorFlow models, expects a function that returns the model definition. For PyTorch models, expects a class (not an instance) containing the model definition and forward function.
- Type:
function or class
- lambda_opt
Lambda function for the optimizer (only required for PyTorch). The optimizer should be defined within a lambda function, which allows it to be attached to the federated models spawned for each collaborator.
- Type:
function
- model
The built model.
- Type:
Model
- runner
Task runner for the model.
- Type:
- loss_fn
PyTorch loss function for the model (only required for PyTorch).
- Type:
Loss
- tensor_dict_split_fn_kwargs
Keyword arguments for the tensor dict split function.
- Type:
dict
- data_loader
A dataset to distribute among the collaborators; see TaskRunner for more details.
- Type:
- __init__(build_model, optimizer=None, loss_fn=None, **kwargs)
Initializes the FederatedModel object.
Sets up the initial state of the FederatedModel object, initializing various components needed for the federated model.
- Parameters:
build_model (function or class) – Function that returns the model definition or Class containing the model definition and forward function.
optimizer (function, optional) – Lambda function defining the optimizer. Defaults to None.
loss_fn (function, optional) – PyTorch loss function. Defaults to None.
**kwargs – Additional parameters to pass to the function.
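Passing build_model as a class or function (rather than a built instance) and the optimizer as a lambda lets a fresh model, plus an optimizer attached to that model's own parameters, be created for each collaborator. The sketch below illustrates that factory pattern in plain Python; ToyNet, ToySGD and spawn_collaborator_models are hypothetical stand-ins, not OpenFL or PyTorch APIs. With real PyTorch, the optimizer factory would typically look like lambda params: torch.optim.Adam(params, lr=1e-4).

```python
class ToyNet:
    """Hypothetical stand-in for a PyTorch model class (not a real API)."""

    def __init__(self):
        # Each instance owns its own, independent parameter list.
        self.params = [0.0, 0.0]

    def parameters(self):
        return self.params


class ToySGD:
    """Hypothetical stand-in for an optimizer bound to specific parameters."""

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr


def spawn_collaborator_models(build_model, optimizer, n):
    """Build n independent (model, optimizer) pairs, one per collaborator.

    build_model is *called*, so each collaborator gets a fresh model;
    optimizer is a factory, so each optimizer is attached to the
    parameters of its own collaborator's model.
    """
    pairs = []
    for _ in range(n):
        model = build_model()
        opt = optimizer(model.parameters())
        pairs.append((model, opt))
    return pairs


# Pass the class itself (not ToyNet()) and a lambda that builds the optimizer.
pairs = spawn_collaborator_models(ToyNet, lambda p: ToySGD(p, lr=0.01), 3)
```

Had a single model instance and a single optimizer been passed instead, every collaborator would share (and overwrite) the same parameters.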
Methods
__init__(build_model[, optimizer, loss_fn]) – Initializes the FederatedModel object.
get_data_loader() – Get the data_loader object.
get_required_tensorkeys_for_function(func_name, **kwargs) – When running a task, a map of named tensorkeys must be provided to the function as dependencies.
get_tensor_dict(with_opt_vars) – Get the weights.
get_train_data_size() – Get the number of training examples.
get_valid_data_size() – Get the number of validation examples.
Initialize all global variables.
load_native(filepath, **kwargs) – Load model state from a filepath in ML-framework "native" format, e.g. PyTorch pickled models.
Reinitialize the optimizer variables.
save_native(filepath, **kwargs) – Save model state in ML-framework "native" format, e.g. PyTorch pickled models.
set_data_loader(data_loader) – Set the data_loader object.
Set up the log object.
set_optimizer_treatment(opt_treatment) – Change the optimizer treatment of the current instance.
set_tensor_dict(tensor_dict, with_opt_vars) – Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
setup(num_collaborators, **kwargs) – Create new models for all of the collaborators in the experiment.
train_batches([num_batches, use_tqdm]) – Perform the training for a specified number of batches.
validate() – Run validation.
- get_data_loader()
Get the data_loader object.
Serves up batches and provides info regarding data_loader.
- Returns:
data_loader object
- get_required_tensorkeys_for_function(func_name, **kwargs)
Get the TensorKeys required by a task function. When running a task, a map of named TensorKeys must be provided to the function as dependencies.
- Parameters:
func_name (str) – The function name.
**kwargs – Additional parameters to pass to the function.
- Returns:
List of required TensorKeys, each of the form TensorKey(tensor_name, origin, round_number).
- Return type:
list
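The sketch below shows what such a dependency list can look like. It uses a simplified TensorKey with only the three fields listed above (the real OpenFL TensorKey may carry more), and the REQUIRED_TENSORS registry, the required_tensorkeys helper and the "GLOBAL" origin value are hypothetical placeholders for illustration.

```python
from collections import namedtuple

# Simplified TensorKey: only the fields named in this reference.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number"])

# Hypothetical registry of which model tensors each task function depends on.
REQUIRED_TENSORS = {
    "train_batches": ["conv1.weight", "conv1.bias"],
    "validate": ["conv1.weight", "conv1.bias"],
}


def required_tensorkeys(func_name, origin="GLOBAL", round_number=0):
    """Return the TensorKeys a task function needs before it can run."""
    return [
        TensorKey(name, origin, round_number)
        for name in REQUIRED_TENSORS[func_name]
    ]


keys = required_tensorkeys("validate", round_number=5)
```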
- get_tensor_dict(with_opt_vars)
Get the weights.
- Parameters:
with_opt_vars (bool) – Whether to also include the optimizer variables.
- Returns:
The weight dictionary {<tensor_name>: <value>}.
- Return type:
dict
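A minimal sketch of the with_opt_vars switch, assuming model weights and optimizer state are held in two separate dictionaries; model_weights, optimizer_vars and the tensor names are hypothetical, and this standalone function is not the OpenFL implementation.

```python
# Hypothetical internal state: model weights and optimizer state kept apart.
model_weights = {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]}
optimizer_vars = {"fc.weight.momentum": [0.01, 0.02]}


def get_tensor_dict(with_opt_vars):
    """Return {<tensor_name>: <value>}, optionally including optimizer state."""
    # Shallow copy so that merging optimizer variables never mutates model_weights.
    tensor_dict = dict(model_weights)
    if with_opt_vars:
        tensor_dict.update(optimizer_vars)
    return tensor_dict
```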
- get_train_data_size()
Get the number of training examples.
It is used for weighted averaging during aggregation.
- Returns:
The number of training examples.
- Return type:
int
- get_valid_data_size()
Get the number of validation examples.
It is used for weighted averaging during aggregation.
- Returns:
The number of validation examples.
- Return type:
int
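Both sizes feed a size-weighted average, where each collaborator's contribution is weighted by its example count n_i: the result is sum(n_i * v_i) / sum(n_i). The sketch below shows only that arithmetic (a FedAvg-style weighted mean) on scalars; weighted_average is a hypothetical helper, not the OpenFL aggregation function, which operates on whole tensors.

```python
def weighted_average(values, sizes):
    """Average per-collaborator values, weighting each by its example count."""
    total = sum(sizes)
    if total <= 0:
        raise ValueError("total number of examples must be positive")
    return sum(v * n for v, n in zip(values, sizes)) / total


# Weights: collaborator A trained on 100 examples, B on 300.
weights = weighted_average([1.0, 2.0], [100, 300])  # → 1.75
# Validation accuracy averaged by validation-set size (about 0.825).
accuracy = weighted_average([0.9, 0.8], [50, 150])
```

Collaborator B holds three quarters of the training data, so the averaged weight lands three quarters of the way from A's value to B's.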
- load_native(filepath, **kwargs)
Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.
May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs.
- Parameters:
filepath (str) – Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
- save_native(filepath, **kwargs)
Save model state in ML-framework “native” format, e.g. PyTorch pickled models.
May save one file or multiple files, depending on the framework.
- Parameters:
filepath (str) – If the framework stores a single file, this should be a single file path. Frameworks that store multiple files may need to derive the other paths from this path.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
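To illustrate how a multi-file format can derive its other paths from the single filepath argument, the sketch below pickles model state to filepath and optimizer state to filepath + ".opt". The ".opt" suffix, the state layout and the save_state/load_state helpers are hypothetical assumptions, not what OpenFL or PyTorch actually write.

```python
import os
import pickle
import tempfile


def save_state(state, filepath):
    """Pickle model state to filepath and optimizer state to a derived path."""
    with open(filepath, "wb") as f:
        pickle.dump(state["model"], f)
    # Hypothetical convention: the second path is derived from the first.
    with open(filepath + ".opt", "wb") as f:
        pickle.dump(state["optimizer"], f)


def load_state(filepath):
    """Load state written by save_state, deriving the optimizer path the same way."""
    with open(filepath, "rb") as f:
        model = pickle.load(f)
    with open(filepath + ".opt", "rb") as f:
        optimizer = pickle.load(f)
    return {"model": model, "optimizer": optimizer}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    original = {"model": {"w": [1, 2]}, "optimizer": {"lr": 0.1}}
    save_state(original, path)
    restored = load_state(path)
    opt_file_written = os.path.exists(path + ".opt")
```

Because both functions apply the same derivation rule, callers only ever need to pass one path, which matches the single filepath argument of save_native and load_native.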
- set_data_loader(data_loader)
Set the data_loader object.
- Parameters:
data_loader – data_loader object to set.
- Returns:
None
- set_optimizer_treatment(opt_treatment)
Change the optimizer treatment of the current instance.
- Parameters:
opt_treatment (str) – The optimizer treatment.
- Returns:
None
- set_tensor_dict(tensor_dict, with_opt_vars)
Set the model weights with a tensor dictionary: {<tensor_name>: <value>}.
- Parameters:
tensor_dict (dict) – The model weights dictionary.
with_opt_vars (bool) – Whether to also set the optimizer variables.
- Returns:
None
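A minimal sketch of applying a tensor dictionary, assuming weights and optimizer state live in separate dictionaries. When with_opt_vars is False, optimizer entries in the incoming dictionary are ignored. Rejecting unknown tensor names with KeyError is a design choice of this sketch, not documented OpenFL behaviour, and all names here are hypothetical.

```python
def set_tensor_dict(model_weights, optimizer_vars, tensor_dict, with_opt_vars):
    """Update weights (and optionally optimizer state) in place from tensor_dict."""
    for name, value in tensor_dict.items():
        if name in model_weights:
            model_weights[name] = value
        elif name in optimizer_vars:
            # Optimizer state is only overwritten when explicitly requested.
            if with_opt_vars:
                optimizer_vars[name] = value
        else:
            # Fail loudly on a mismatched dictionary instead of adding tensors.
            raise KeyError(f"unknown tensor name: {name}")


weights = {"fc.weight": [0.1], "fc.bias": [0.0]}
opt_state = {"fc.weight.momentum": [0.0]}
incoming = {"fc.weight": [0.5], "fc.bias": [0.2], "fc.weight.momentum": [0.9]}

set_tensor_dict(weights, opt_state, incoming, with_opt_vars=False)
momentum_after_weights_only = list(opt_state["fc.weight.momentum"])
set_tensor_dict(weights, opt_state, incoming, with_opt_vars=True)
```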
- setup(num_collaborators, **kwargs)
Create new models for all of the collaborators in the experiment.
- Parameters:
num_collaborators (int) – Number of experiment collaborators.
**kwargs – Additional parameters to pass to the function.
- Returns:
List of models for each collaborator.
- Return type:
List[FederatedModel]
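The sketch below shows the general shape of setup: one new model per collaborator, each with an independent copy of the weights and its own slice of the data. ToyFederatedModel is a hypothetical stand-in, and the round-robin slicing is an assumption made for illustration; how OpenFL actually splits the data depends on the data loader.

```python
import copy


class ToyFederatedModel:
    """Hypothetical stand-in for FederatedModel (not the OpenFL class)."""

    def __init__(self, weights, data):
        self.weights = weights
        self.data = data

    def setup(self, num_collaborators):
        """Return one model per collaborator, each with its own data shard."""
        # Round-robin split: shard i gets items i, i + n, i + 2n, ...
        shards = [self.data[i::num_collaborators] for i in range(num_collaborators)]
        # deepcopy so collaborators never share (and overwrite) the same weights.
        return [
            ToyFederatedModel(copy.deepcopy(self.weights), shard)
            for shard in shards
        ]


base = ToyFederatedModel({"w": [0.0]}, list(range(10)))
collaborators = base.setup(num_collaborators=3)
```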
- train_batches(num_batches=None, use_tqdm=False)
Perform the training for a specified number of batches.
Batches are expected to be drawn randomly without replacement until the data is exhausted; the data is then replaced and reshuffled, and drawing continues.
- Parameters:
num_batches (int, optional) – Number of batches to train. Default is None.
use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.
- Returns:
A dictionary of metrics: {<metric>: <value>}.
- Return type:
dict
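The sampling scheme described for train_batches (draw without replacement until exhausted, then refill, reshuffle and continue) can be sketched as an index generator. batch_indices and its num_examples, batch_size and seed parameters are hypothetical; in this sketch the final batch of a pass may be smaller than batch_size.

```python
import random


def batch_indices(num_examples, batch_size, num_batches, seed=0):
    """Yield num_batches batches of example indices, drawn without replacement.

    When every index has been drawn, the pool is refilled and reshuffled,
    and drawing continues from the new permutation.
    """
    if num_examples <= 0 or batch_size <= 0:
        raise ValueError("num_examples and batch_size must be positive")
    rng = random.Random(seed)
    pool = []
    for _ in range(num_batches):
        if not pool:
            # Data exhausted: replace it and shuffle again.
            pool = list(range(num_examples))
            rng.shuffle(pool)
        batch, pool = pool[:batch_size], pool[batch_size:]
        yield batch


batches = list(batch_indices(10, batch_size=4, num_batches=5))
```

With 10 examples and a batch size of 4, the first three batches (4, 4 and 2 indices) cover every example exactly once before the pool is reshuffled.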