openfl.interface.interactive_api.experiment.TaskKeeper

class openfl.interface.interactive_api.experiment.TaskKeeper

Bases: object

Task keeper class.

This class is responsible for managing tasks in a federated learning experiment. It keeps track of registered tasks, their settings, and aggregation functions.

A task should accept the following entities that exist on collaborator nodes:

  1. model – will be rebuilt with relevant weights for every task by TaskRunner.

  2. data_loader – a data loader equipped with a repository adapter that provides local data.

  3. device – the device to be used on collaborator machines.

  4. optimizer (optional).

A task returns a dictionary of the form {metric name: metric value for this task}.
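A minimal sketch of a task that follows this contract is shown below. It uses hypothetical stand-ins rather than real framework objects: `ToyModel` is a one-parameter linear model, the data loader is a list of `(x, y)` pairs, the device is a plain string, and the optional optimizer is reduced to a learning rate. None of these names come from OpenFL.

```python
from typing import Dict, Iterable, Optional, Tuple


class ToyModel:
    """Hypothetical one-parameter linear model y = w * x (not an OpenFL class)."""

    def __init__(self, w: float = 0.0) -> None:
        self.w = w


def train_task(model: ToyModel,
               data_loader: Iterable[Tuple[float, float]],
               device: str,
               optimizer: Optional[float] = None) -> Dict[str, float]:
    """One pass of plain gradient descent on squared error.

    Assumption for illustration: `optimizer` is just a learning rate here.
    `device` is accepted to satisfy the contract but unused by this toy model.
    """
    lr = optimizer if optimizer is not None else 0.1
    total_loss = 0.0
    n = 0
    for x, y in data_loader:
        error = model.w * x - y
        total_loss += error ** 2
        # Gradient of (w * x - y) ** 2 with respect to w is 2 * (w * x - y) * x
        model.w -= lr * 2 * error * x
        n += 1
    # Every task returns {metric name: metric value}
    return {'train_loss': total_loss / max(n, 1)}
```

The model object is mutated in place, which mirrors the idea that the collaborator owns the model and the task only reports metrics back.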

Class Attributes:
  • task_registry (dict) – A dictionary mapping task names to callable functions.

  • task_contract (dict) – A dictionary mapping task names to their contract.

  • task_settings (dict) – A dictionary mapping task names to their settings.

  • aggregation_functions (dict) – A dictionary mapping task names to their aggregation functions.

  • _tasks (dict) – A dictionary mapping task aliases to Task objects.

Methods

add_kwargs

Register task settings.

get_registered_tasks

Return registered tasks.

register_fl_task

Register FL tasks.

set_aggregation_function

Set aggregation function for the task.

add_kwargs(**task_kwargs)

Register task settings.

Warning: the additional kwargs do not actually need to be registered; they are only serialized. This method is a decorator because it needs the task name and should stay consistent with the main registration method.

Parameters:

**task_kwargs – Keyword arguments for the task settings.
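Why a decorator is needed can be sketched with a simplified, hypothetical re-implementation (`MiniTaskKeeper` is not an OpenFL class and is not the library's code): the decorator records the keyword arguments under the decorated function's `__name__`, which is only known at decoration time, and returns the function unchanged.

```python
from typing import Any, Callable, Dict


class MiniTaskKeeper:
    """Hypothetical stand-in for TaskKeeper that only tracks task settings."""

    def __init__(self) -> None:
        self.task_settings: Dict[str, Dict[str, Any]] = {}

    def add_kwargs(self, **task_kwargs: Any) -> Callable[[Callable], Callable]:
        def decorator(task: Callable) -> Callable:
            # The task name is only available once a function is being decorated.
            self.task_settings[task.__name__] = task_kwargs
            # Nothing is wrapped: the settings are stored for later serialization.
            return task
        return decorator


keeper = MiniTaskKeeper()


@keeper.add_kwargs(batch_size=32, some_arg=228)
def foo_task(model, data_loader, device, batch_size, some_arg=356):
    return {'metric_name': 0.0}
```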

get_registered_tasks() → Dict[str, Task]

Return registered tasks.

Returns:

A dictionary mapping task names to Task objects.

register_fl_task(model, data_loader, device, optimizer=None, round_num=None)

Register FL tasks.

The task contract is set up by providing the names of the task function's parameters: model, data_loader and device are required, and optimizer is optional.

All tasks should accept the contract entities in order to run on a collaborator node. Every task should also return a dictionary of the form {'metric': value}. For example:

```python
from openfl.interface.interactive_api.experiment import TaskInterface

TI = TaskInterface()

task_settings = {
    'batch_size': 32,
    'some_arg': 228,
}

@TI.add_kwargs(**task_settings)
@TI.register_fl_task(model='my_model', data_loader='train_loader',
                     device='device', optimizer='my_Adam_opt')
def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=356):
    ...
    return {'metric_name': metric, 'metric_name_2': metric_2,}
```

Parameters:
  • model – Name of the task function parameter that receives the model.

  • data_loader – Name of the task function parameter that receives the data loader.

  • device – Name of the task function parameter that receives the device.

  • optimizer (optional) – Name of the task function parameter that receives the optimizer. Defaults to None.

  • round_num (optional) – The round number for the task. Defaults to None.
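How a contract of parameter names could be resolved into a call can be sketched as follows. This is a hypothetical illustration, not OpenFL's TaskRunner: `build_task_kwargs` and the sample `entities` are invented names. Each contract entry maps an entity kind to the parameter name the task expects, optional entries left as `None` are skipped, and the registered settings are merged in as extra keyword arguments.

```python
from typing import Any, Dict, Optional


def build_task_kwargs(contract: Dict[str, Optional[str]],
                      entities: Dict[str, Any],
                      settings: Dict[str, Any]) -> Dict[str, Any]:
    """Turn contract entries (entity kind -> parameter name) into keyword arguments."""
    kwargs: Dict[str, Any] = {}
    for kind, param_name in contract.items():
        if param_name is None:
            # Optional entries such as optimizer or round_num were not requested.
            continue
        kwargs[param_name] = entities[kind]
    # Settings registered for the task are passed as additional keyword arguments.
    kwargs.update(settings)
    return kwargs


def foo_task(my_model, train_loader, my_Adam_opt, device, batch_size, some_arg=356):
    # Placeholder body: report which settings the task actually received.
    return {'batch_size_seen': batch_size, 'some_arg_seen': some_arg}


contract = {'model': 'my_model', 'data_loader': 'train_loader',
            'device': 'device', 'optimizer': 'my_Adam_opt', 'round_num': None}
entities = {'model': object(), 'data_loader': [], 'device': 'cpu',
            'optimizer': object(), 'round_num': 3}
settings = {'batch_size': 32, 'some_arg': 228}

result = foo_task(**build_task_kwargs(contract, entities, settings))
```

In this sketch the registered `some_arg=228` replaces the function default of `356`, because it is passed explicitly as a keyword argument.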

set_aggregation_function(aggregation_function: AggregationFunction)

Set aggregation function for the task.

The function is serialized and sent to the aggregator node.

Aggregation functions that contain logic from workspace-related libraries not present on the director are not supported yet.

To override the default FedAvg aggregation, you can use one of the built-in aggregation types:

  • openfl.interface.aggregation_functions.GeometricMedian

  • openfl.interface.aggregation_functions.Median

or define your own AggregationFunction subclass. See the "Overriding the aggregation function" documentation page for more details.
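Conceptually, FedAvg takes a weighted mean of the collaborators' tensors (weights typically proportional to local dataset size), whereas a median aggregation takes the element-wise median, which is less sensitive to a single outlying collaborator. The plain-Python sketch below shows only that arithmetic. The function names are hypothetical, and the code does not implement OpenFL's AggregationFunction interface.

```python
from statistics import median
from typing import List, Sequence


def weighted_mean(tensors: Sequence[Sequence[float]],
                  weights: Sequence[float]) -> List[float]:
    """FedAvg-style aggregation: element-wise mean weighted by, e.g., sample counts."""
    total = sum(weights)
    return [sum(w * t[i] for t, w in zip(tensors, weights)) / total
            for i in range(len(tensors[0]))]


def elementwise_median(tensors: Sequence[Sequence[float]]) -> List[float]:
    """Median-style aggregation: median of each coordinate across collaborators."""
    return [median(t[i] for t in tensors) for i in range(len(tensors[0]))]


# Three collaborators with equal data; the third reports an outlying update.
local_tensors = [[1.0, 2.0], [1.2, 2.2], [10.0, -5.0]]
sample_counts = [100, 100, 100]

fedavg_result = weighted_mean(local_tensors, sample_counts)
median_result = elementwise_median(local_tensors)
```

Here the outlier pulls the weighted mean far from the two agreeing collaborators, while the median stays at one of their values.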

Parameters:

aggregation_function – The aggregation function to be used for the task.
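Because aggregation_functions maps task names to their aggregation functions, the method needs to know which task it applies to. The sketch below assumes it is applied as a decorator on the task function, the same way add_kwargs is, and that tasks without an override fall back to a default. `MiniAggregationRegistry` and `aggregation_for` are hypothetical names, and the default here is an unweighted mean standing in for FedAvg.

```python
from statistics import mean, median
from typing import Callable, Dict, Sequence

# Hypothetical aggregator type: one value per collaborator in, one value out.
Aggregator = Callable[[Sequence[float]], float]


class MiniAggregationRegistry:
    """Hypothetical stand-in for the aggregation_functions part of TaskKeeper."""

    def __init__(self, default: Aggregator = mean) -> None:
        # Plays the role of the FedAvg default (unweighted in this sketch).
        self.default = default
        self.aggregation_functions: Dict[str, Aggregator] = {}

    def set_aggregation_function(self, aggregation_function: Aggregator):
        def decorator(task: Callable) -> Callable:
            # Key the override by task name, as the aggregation_functions attribute does.
            self.aggregation_functions[task.__name__] = aggregation_function
            return task
        return decorator

    def aggregation_for(self, task_name: str) -> Aggregator:
        # Tasks without an explicit override fall back to the default.
        return self.aggregation_functions.get(task_name, self.default)


registry = MiniAggregationRegistry()


@registry.set_aggregation_function(median)
def robust_train(model, data_loader, device, optimizer):
    return {'loss': 0.0}


def plain_train(model, data_loader, device, optimizer):
    return {'loss': 0.0}
```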