Class - FlowerTaskRunner
- class openfl.federated.task.runner_flower.FlowerTaskRunner(**kwargs)
Bases: TaskRunner
FlowerTaskRunner is a task runner that executes the Flower SuperNode to initialize and manage experiments from the client side.
This class is responsible for starting a local gRPC server and a Flower SuperNode in a subprocess. It provides options for both manual and automatic shutdown based on subprocess activity.
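Automatic shutdown based on subprocess activity can be pictured with a minimal sketch that polls a child process and returns once it exits, terminating it if a timeout passes first. This illustrates the idea under stated assumptions and is not FlowerTaskRunner's implementation: `run_and_monitor` is a hypothetical helper, and the child command is a placeholder Python one-liner rather than the real SuperNode command line.

```python
import subprocess
import sys
import time


def run_and_monitor(cmd, poll_interval=0.1, timeout=10.0):
    """Start cmd in a subprocess and return its exit code once it stops.

    Hypothetical helper: if the child is still running after `timeout`
    seconds it is terminated (a stand-in for a manual shutdown).
    """
    proc = subprocess.Popen(cmd)
    deadline = time.monotonic() + timeout
    try:
        # Automatic shutdown: stop monitoring as soon as the child exits.
        while proc.poll() is None:
            if time.monotonic() > deadline:
                proc.terminate()
                proc.wait()
                break
            time.sleep(poll_interval)
    finally:
        # Never leave an orphaned child behind.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
    return proc.returncode


if __name__ == "__main__":
    # Placeholder child process standing in for the Flower SuperNode.
    code = run_and_monitor([sys.executable, "-c", "import time; time.sleep(0.2)"])
    print(code)  # 0
```

Polling keeps the sketch dependency-free; a real launcher might instead watch the child's output or a health endpoint to judge activity.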
- __init__(**kwargs)
Initialize the FlowerTaskRunner.
- Parameters:
**kwargs – Additional parameters to pass to the functions.
Methods
__init__(**kwargs) – Initialize the FlowerTaskRunner.
get_data_loader() – Get the data_loader object.
get_required_tensorkeys_for_function(func_name, **kwargs) – Get tensor keys for functions.
get_tensor_dict(with_opt_vars) – Get the weights.
get_train_data_size() – Get the number of training examples.
get_valid_data_size() – Get the number of validation examples.
initialize_globals() – Initialize all global variables.
initialize_tensorkeys_for_functions([with_opt_vars]) – Initialize tensor keys for functions.
load_native(filepath, **kwargs) – Load model state from a filepath in ML-framework "native" format, e.g. PyTorch pickled models.
reset_opt_vars() – Reinitialize the optimizer variables.
save_native(filepath, **kwargs) – Save model weights to a .npz file specified by the filepath.
set_data_loader(data_loader) – Set the data_loader object.
set_optimizer_treatment(opt_treatment) – Change the treatment of the current instance's optimizer.
set_tensor_dict(tensor_dict[, with_opt_vars]) – Set the tensor dictionary for the task runner.
start_client_adapter([col_name, round_num, ...]) – Start the FlowerInteropServer and the Flower SuperNode.
train_batches([num_batches, use_tqdm]) – Perform the training for a specified number of batches.
validate() – Run validation.
- get_data_loader()
Get the data_loader object.
Serves up batches and provides information about the data loader.
- Returns:
data_loader object
- get_required_tensorkeys_for_function(func_name, **kwargs)
Get tensor keys for functions. Returns an empty dict.
- get_tensor_dict(with_opt_vars)
Get the weights.
- Parameters:
with_opt_vars (bool) – Specify if we also want to get the variables of the optimizer.
- Returns:
The weight dictionary {<tensor_name>: <value>}.
- Return type:
dict
- get_train_data_size()
Get the number of training examples.
It will be used for weighted averaging in aggregation.
- Returns:
The number of training examples.
- Return type:
int
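To show how this count can feed weighted averaging, here is a minimal FedAvg-style sketch. It assumes the aggregator weights each collaborator's tensor dictionary by its share of the total training examples; `weighted_average` is a hypothetical helper, not an OpenFL API, and OpenFL's own aggregation functions may differ.

```python
import numpy as np


def weighted_average(tensor_dicts, data_sizes):
    """Average {name: np.ndarray} dicts, weighting each by its example count."""
    total = sum(data_sizes)
    if total <= 0:
        raise ValueError("total data size must be positive")
    names = tensor_dicts[0].keys()
    return {
        name: sum(d[name] * (n / total) for d, n in zip(tensor_dicts, data_sizes))
        for name in names
    }


# Two collaborators whose get_train_data_size() returned 100 and 300.
a = {"w": np.array([1.0, 2.0])}
b = {"w": np.array([3.0, 4.0])}
print(weighted_average([a, b], [100, 300])["w"])  # [2.5 3.5]
```

The collaborator with three times as much data pulls the average three times as hard.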
- get_valid_data_size()
Get the number of validation examples.
It will be used for weighted averaging in aggregation.
- Returns:
The number of validation examples.
- Return type:
int
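The validation count plays the same role for metrics. A minimal sketch, assuming per-collaborator validation scores are combined as an example-weighted mean; `weighted_metric` is a hypothetical helper, not part of OpenFL.

```python
def weighted_metric(metrics, valid_sizes):
    """Combine per-collaborator validation metrics into one weighted score."""
    total = sum(valid_sizes)
    if total <= 0:
        raise ValueError("total validation size must be positive")
    return sum(m * n for m, n in zip(metrics, valid_sizes)) / total


# Accuracy 0.9 measured on 50 examples and 0.6 measured on 150 examples.
score = weighted_metric([0.9, 0.6], [50, 150])
print(score)
```

Here the result is about 0.675, closer to 0.6 because that score was measured on more examples.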
- initialize_tensorkeys_for_functions(with_opt_vars=False)
Initialize tensor keys for functions. Currently not implemented.
- load_native(filepath, **kwargs)
Load model state from a filepath in ML-framework “native” format, e.g. PyTorch pickled models.
May load from multiple files. Other filepaths may be derived from the passed filepath, or they may be in the kwargs.
- Parameters:
filepath (str) –
Path to the framework-specific file to load. For frameworks that use multiple files, this string must be used to derive the other filepaths.
**kwargs – Additional parameters to pass to the function. For future-proofing.
- Returns:
None
- save_native(filepath, **kwargs)
Save model weights to a .npz file specified by the filepath.
The model weights are stored as a dictionary of np.ndarray.
- Parameters:
filepath (str) – Path to the .npz file to be created by np.savez().
**kwargs – Additional parameters (currently not used).
- Returns:
None
- Raises:
AssertionError – If the file extension is not ‘.npz’.
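A minimal sketch of the documented behavior, assuming the weights are already held as a `{name: np.ndarray}` dictionary. `save_npz` and `load_npz` are hypothetical helpers; the read-back step only shows what the file contains and is not necessarily how `load_native` works.

```python
import os
import tempfile

import numpy as np


def save_npz(filepath, tensor_dict):
    """Save a {name: np.ndarray} dict to a .npz file, as save_native describes."""
    assert filepath.endswith(".npz"), "filepath must end with .npz"
    # Each dictionary key becomes an array name inside the archive.
    np.savez(filepath, **tensor_dict)


def load_npz(filepath):
    """Read every array in the archive back into a plain dict."""
    with np.load(filepath) as data:
        return {name: data[name] for name in data.files}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.npz")
    save_npz(path, {"layer1.weight": np.ones((2, 2))})
    print(load_npz(path)["layer1.weight"].shape)  # (2, 2)
```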
- set_data_loader(data_loader)
Set data_loader object.
- Parameters:
data_loader – data_loader object to set.
- Returns:
None
- set_optimizer_treatment(opt_treatment)
Change the treatment of the current instance's optimizer.
- Parameters:
opt_treatment (str) – The optimizer treatment.
- Returns:
None
- set_tensor_dict(tensor_dict, with_opt_vars=False)
Set the tensor dictionary for the task runner.
This method is framework-agnostic and does not attempt to load the weights into a model or save them in the native format. Instead, it stores and returns the dictionary directly.
- Parameters:
tensor_dict (dict) – The tensor dictionary.
with_opt_vars (bool) – This argument is inherited from the parent class but is not used in the FlowerTaskRunner.
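The dictionary-only behavior can be sketched with a stand-in class. This assumes that storing the dictionary as-is is all that "framework-agnostic" implies here; `DictOnlyRunner` is hypothetical, and the real class's `get_tensor_dict` is not guaranteed to return the stored object.

```python
import numpy as np


class DictOnlyRunner:
    """Hypothetical stand-in that keeps weights as a plain dictionary."""

    def __init__(self):
        self.tensor_dict = {}

    def set_tensor_dict(self, tensor_dict, with_opt_vars=False):
        # with_opt_vars is accepted for interface compatibility but ignored.
        # No model is touched: the dictionary is simply stored.
        self.tensor_dict = tensor_dict

    def get_tensor_dict(self, with_opt_vars=False):
        return self.tensor_dict


runner = DictOnlyRunner()
weights = {"dense.bias": np.zeros(3)}
runner.set_tensor_dict(weights, with_opt_vars=True)
print(runner.get_tensor_dict() is weights)  # True
```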
- start_client_adapter(col_name=None, round_num=None, input_tensor_dict=None, **kwargs)
Start the FlowerInteropServer and the Flower SuperNode.
- Parameters:
col_name (str, optional) – The collaborator name. Defaults to None.
round_num (int, optional) – The current round number. Defaults to None.
input_tensor_dict (dict, optional) – The input tensor dictionary. Defaults to None.
**kwargs –
Additional configuration parameters, including:
interop_server (object): The FlowerInteropServer instance.
interop_server_host (str): The address of the interop server.
clientappio_api_port (int): The port for the clientappio API.
local_simulation (bool): Flag for local simulation to dynamically adjust ports.
interop_server_port (int): The port for the interop server.
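One plausible reading of "dynamically adjust ports" is that, in a local simulation, each collaborator on the same machine needs its own free ports. The sketch below assumes that interpretation and asks the OS for unused ports. `find_free_ports` and `build_adapter_kwargs` are hypothetical helpers, the `interop_server` instance is omitted, and whether the runner already computes these ports itself is not stated here.

```python
import socket


def find_free_ports(count, host="127.0.0.1"):
    """Return `count` distinct TCP ports that the OS reports as unused.

    All sockets stay bound until every port is chosen, so the ports are
    distinct; another process could still claim one after they are closed.
    """
    sockets = []
    try:
        for _ in range(count):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sockets.append(s)
            s.bind((host, 0))  # port 0 asks the OS for any free port
        return [s.getsockname()[1] for s in sockets]
    finally:
        for s in sockets:
            s.close()


def build_adapter_kwargs(host="localhost", interop_server_port=None,
                         clientappio_api_port=None, local_simulation=False):
    """Assemble keyword arguments of the kind start_client_adapter accepts."""
    if local_simulation:
        # Several collaborators on one machine cannot share fixed ports.
        interop_server_port, clientappio_api_port = find_free_ports(2)
    elif interop_server_port is None or clientappio_api_port is None:
        raise ValueError("ports are required unless local_simulation is True")
    return {
        "interop_server_host": host,
        "interop_server_port": interop_server_port,
        "clientappio_api_port": clientappio_api_port,
        "local_simulation": local_simulation,
    }


kwargs = build_adapter_kwargs(local_simulation=True)
# Hypothetical call: runner.start_client_adapter(col_name="collaborator1", round_num=0, **kwargs)
```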
- train_batches(num_batches=None, use_tqdm=False)
Perform the training for a specified number of batches.
It is expected to draw batches randomly, without replacement, until the data is exhausted; the data is then replaced and reshuffled, and drawing continues.
- Parameters:
num_batches (int, optional) – Number of batches to train. Default is None.
use_tqdm (bool, optional) – If True, use tqdm to print a progress bar. Default is False.
- Returns:
{<metric>: <value>}.
- Return type:
dict
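The sampling contract described above can be sketched over example indices. This assumes an index-based dataset and lets the last batch of each pass be shorter than `batch_size`; `batch_indices` is a hypothetical helper, and FlowerTaskRunner may not sample this way, since training is driven by the Flower SuperNode.

```python
import random


def batch_indices(num_examples, batch_size, num_batches, seed=0):
    """Yield num_batches batches of example indices.

    Each pass draws without replacement; once every example has been
    used, the full index set is reshuffled and drawing continues.
    """
    if num_examples <= 0 or batch_size <= 0:
        raise ValueError("num_examples and batch_size must be positive")
    rng = random.Random(seed)
    produced = 0
    while produced < num_batches:
        order = list(range(num_examples))
        rng.shuffle(order)  # fresh shuffle each time the data is exhausted
        for start in range(0, num_examples, batch_size):
            if produced == num_batches:
                return
            yield order[start:start + batch_size]
            produced += 1


for batch in batch_indices(num_examples=10, batch_size=4, num_batches=5):
    print(batch)
```

Keeping each pass separate guarantees no index repeats within a batch; a training loop built on this would then report its metrics as a `{<metric>: <value>}` dictionary.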