openfl.interface.interactive_api.experiment.TaskKeeper
- class openfl.interface.interactive_api.experiment.TaskKeeper
Bases: object
Task keeper class.
This class is responsible for managing tasks in a federated learning experiment. It keeps track of registered tasks, their settings, and aggregation functions.
A task should accept the following entities, which exist on collaborator nodes:
1. model: rebuilt with the relevant weights by TaskRunner for every task.
2. data_loader: a data loader equipped with a repository adapter that provides local data.
3. device: the device to be used on collaborator machines.
4. optimizer (optional).
A task returns a dictionary of the form {metric name: metric value for this task}.
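To make the contract concrete, here is a minimal framework-free sketch of a task with the expected shape: it accepts the contract entities and returns a metrics dictionary. The name `train_task`, the toy model (a dict holding a single weight `w`), the `(x, y)` pairs yielded by the data loader, and the optimizer dict with an `lr` key are all hypothetical; in a real experiment these objects come from the collaborator node.

```python
from typing import Dict, Iterable, List, Optional, Tuple


def train_task(model: Dict[str, float],
               data_loader: Iterable[Tuple[float, float]],
               device: str,
               optimizer: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    """Hypothetical task: one pass of gradient descent on y ~ w * x.

    `device` is accepted to satisfy the contract but is unused by this toy model.
    """
    lr = optimizer["lr"] if optimizer else 0.1  # default step size when no optimizer is given
    losses: List[float] = []
    for x, y in data_loader:
        error = model["w"] * x - y
        losses.append(error ** 2)
        model["w"] -= lr * 2 * error * x  # gradient of the squared error w.r.t. w
    # Every task returns {metric name: metric value}
    return {"train_loss": sum(losses) / len(losses)}


model = {"w": 0.0}
metrics = train_task(model, [(1.0, 2.0), (2.0, 4.0)], device="cpu")
```

Note that the model is updated in place, so the weights trained by the task remain available to the caller after it returns.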
- Class Attributes:
task_registry (dict) – A dictionary mapping task names to callable functions.
task_contract (dict) – A dictionary mapping task names to their contract.
task_settings (dict) – A dictionary mapping task names to their settings.
aggregation_functions (dict) – A dictionary mapping task names to their aggregation functions.
_tasks (dict) – A dictionary mapping task aliases to Task objects.
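Methods
add_kwargs(**task_kwargs) – Register task settings.
get_registered_tasks() – Return registered tasks.
register_fl_task(model, data_loader, device, optimizer=None, round_num=None) – Register FL tasks.
set_aggregation_function(aggregation_function) – Set the aggregation function for the task.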
- add_kwargs(**task_kwargs)
Register task settings.
Warning: the additional kwargs do not actually need to be registered; they are only serialized. This method is a decorator because it needs the task name and to stay consistent with the main registration method.
- Parameters:
**task_kwargs – Keyword arguments for the task settings.
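Why a decorator: the decorated function's name is the task name, and the decorator is the only point where that name is available. The following is a simplified, hypothetical re-creation of that behaviour, not OpenFL's source; a module-level `task_settings` dict stands in for the class attribute of the same name, and `foo_task` is illustrative.

```python
from typing import Any, Callable, Dict

# Stand-in for the TaskKeeper.task_settings class attribute (hypothetical).
task_settings: Dict[str, Dict[str, Any]] = {}


def add_kwargs(**task_kwargs: Any) -> Callable[[Callable], Callable]:
    """Record task_kwargs under the decorated function's name; store nothing else."""
    def decorator(func: Callable) -> Callable:
        # func.__name__ is the task name, known only once the function is decorated.
        task_settings[func.__name__] = task_kwargs
        return func  # the task function itself is returned unchanged
    return decorator


@add_kwargs(batch_size=32, some_arg=228)
def foo_task(model, data_loader, device, batch_size, some_arg=356):
    return {"metric_name": 0.0}
```

Because the function is returned unchanged, the stored settings only matter later, when they are serialized and passed to the task as keyword arguments.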
- get_registered_tasks() → Dict[str, Task]
Return registered tasks.
- Returns:
A dictionary mapping task names to Task objects.
- register_fl_task(model, data_loader, device, optimizer=None, round_num=None)
Register FL tasks.
The task contract is set up by providing variable names: model, data_loader, and device are required; optimizer is optional.
All tasks must accept the contract entities in order to run on a collaborator node. In addition, every task must return a dictionary of the form {'metric': value}. For example:
```python
from openfl.interface.interactive_api.experiment import TaskInterface

TI = TaskInterface()

task_settings = {
    'batch_size': 32,
    'some_arg': 228,
}

@TI.add_kwargs(**task_settings)
@TI.register_fl_task(model='my_model', data_loader='train_loader',
                     device='device', optimizer='my_Adam_opt')
def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=356):
    # ... train or evaluate my_model on train_loader here ...
    metric, metric_2 = 0.0, 0.0  # placeholders for the computed metric values
    return {'metric_name': metric, 'metric_name_2': metric_2}
```
- Parameters:
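model – Name of the task function's argument that receives the model.
data_loader – Name of the task function's argument that receives the data loader.
device – Name of the task function's argument that receives the device.
optimizer (optional) – Name of the task function's argument that receives the optimizer. Defaults to None.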
round_num (optional) – The round number for the task. Defaults to None.
- set_aggregation_function(aggregation_function: AggregationFunction)
Set aggregation function for the task.
The function is serialized and sent to the aggregator node.
Aggregation functions that rely on workspace-related libraries not yet present on the director are not supported.
You can override the default FedAvg aggregation with one of the built-in aggregation types:
openfl.interface.aggregation_functions.GeometricMedian
openfl.interface.aggregation_functions.Median
or define your own AggregationFunction subclass. See the "Overriding the aggregation function" documentation page for more details.
- Parameters:
aggregation_function – The aggregation function to be used for the task.
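To illustrate why a robust alternative to FedAvg can matter, the following pure-Python sketch compares a weighted per-coordinate mean (the idea behind FedAvg) with a per-coordinate median on toy updates from three collaborators, one of which is an outlier. It only illustrates the idea behind median-style aggregation. It is not the code of openfl.interface.aggregation_functions.Median (and GeometricMedian is a different, non-coordinate-wise estimator). A real override would be an AggregationFunction subclass, whose interface is not shown here.

```python
from statistics import median
from typing import List, Sequence


def weighted_average(updates: Sequence[Sequence[float]],
                     weights: Sequence[float]) -> List[float]:
    """FedAvg-style idea: per-coordinate mean, weighted (e.g. by sample counts)."""
    total = sum(weights)
    return [sum(w * u[i] for u, w in zip(updates, weights)) / total
            for i in range(len(updates[0]))]


def coordinate_median(updates: Sequence[Sequence[float]]) -> List[float]:
    """Median-style idea: per-coordinate median; weights are ignored."""
    return [median(u[i] for u in updates) for i in range(len(updates[0]))]


# Three collaborators; the third sends an outlier update.
updates = [[1.0, 2.0], [1.2, 2.2], [100.0, -50.0]]
weights = [1.0, 1.0, 1.0]
print(weighted_average(updates, weights))  # pulled far toward the outlier
print(coordinate_median(updates))          # [1.2, 2.0]
```

The mean moves with every collaborator's values, so a single extreme update shifts the result, while the median stays with the majority.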