openfl.databases.tensor_db.TensorDB

class openfl.databases.tensor_db.TensorDB

Bases: object

TensorDB stores tensor keys and the data each key corresponds to.

It is built on top of a pandas DataFrame because of its easy insertion, retrieval, and aggregation capabilities. Each collaborator and the aggregator has its own TensorDB.

tensor_db

A pandas DataFrame that stores the tensor key and the data that it corresponds to.

mutex

A threading.Lock object used to ensure thread-safe operations on the tensor_db DataFrame.
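The two attributes above can be pictured with a minimal sketch: an empty, typed DataFrame with one row per cached tensor, plus one lock shared by every operation on it. The column names and dtypes in SCHEMA are assumptions chosen for illustration; this page does not list the actual columns.

```python
import threading

import pandas as pd

# Hypothetical schema: column names and dtypes are assumptions for illustration,
# not taken from this page. Each cached tensor would occupy one row.
SCHEMA = {
    "tensor_name": "string",
    "origin": "string",
    "round": "int32",
    "report": "bool",
    "tags": "object",     # tuples of tag strings
    "nparray": "object",  # the numpy arrays themselves
}

# An empty table with one typed column per field.
tensor_db = pd.DataFrame({col: pd.Series(dtype=dtype) for col, dtype in SCHEMA.items()})

# A single lock shared by every method that reads or writes tensor_db.
mutex = threading.Lock()

with mutex:
    # Any read-modify-write of tensor_db would go inside a block like this one.
    print(tensor_db.dtypes)
```

Storing the arrays in an object column keeps each tensor intact while the metadata columns stay filterable with ordinary pandas boolean masks.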

__init__()

Initializes a new instance of the TensorDB class.

Return type:

None

Methods

__init__()

Initializes a new instance of the TensorDB class.

cache_tensor(tensor_key_dict)

Insert a tensor into TensorDB (dataframe).

clean_up([remove_older_than])

Removes old entries from the database to prevent it from becoming too large and slow.

get_aggregated_tensor(tensor_key, ...)

Determine whether all of the collaborator tensors are present for a given tensor key and, if so, return their weighted average.

get_tensor_from_cache(tensor_key)

Perform a lookup of the tensor_key in the TensorDB.

cache_tensor(tensor_key_dict)

Insert a tensor into TensorDB (dataframe).

Parameters:

tensor_key_dict (Dict[TensorKey, np.ndarray]) – A dictionary where the key is a TensorKey and the value is a numpy array.

Returns:

None

Return type:

None
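A minimal sketch of what inserting into the DataFrame could look like, assuming each TensorKey unpacks into five fields (tensor_name, origin, round_number, report, tags) that become columns, with the array stored alongside them. The TensorKey class, the column names, and MiniTensorDB below are hypothetical stand-ins, not the library's own definitions.

```python
import threading
from typing import Dict, NamedTuple, Tuple

import numpy as np
import pandas as pd


class TensorKey(NamedTuple):
    """Hypothetical stand-in for TensorKey; the field names are assumed."""

    tensor_name: str
    origin: str
    round_number: int
    report: bool
    tags: Tuple[str, ...]


# Assumed column layout for illustration.
COLUMNS = ["tensor_name", "origin", "round", "report", "tags", "nparray"]


class MiniTensorDB:
    """Toy stand-in for TensorDB showing only cache_tensor."""

    def __init__(self) -> None:
        self.tensor_db = pd.DataFrame(columns=COLUMNS)
        self.mutex = threading.Lock()

    def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None:
        # One row per key: the TensorKey fields become columns, the array goes in "nparray".
        rows = [
            {
                "tensor_name": key.tensor_name,
                "origin": key.origin,
                "round": key.round_number,
                "report": key.report,
                "tags": key.tags,
                "nparray": array,
            }
            for key, array in tensor_key_dict.items()
        ]
        new_rows = pd.DataFrame(rows, columns=COLUMNS)
        with self.mutex:
            # Replace the empty table outright instead of concatenating onto it.
            if self.tensor_db.empty:
                self.tensor_db = new_rows
            else:
                self.tensor_db = pd.concat([self.tensor_db, new_rows], ignore_index=True)


db = MiniTensorDB()
db.cache_tensor(
    {
        TensorKey("conv1.weight", "col_1", 0, False, ("trained",)): np.ones(3),
        TensorKey("conv1.weight", "col_2", 0, False, ("trained",)): np.zeros(3),
    }
)
print(len(db.tensor_db))  # 2
```

Building all new rows first and taking the lock only for the final concat keeps the critical section short.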

clean_up(remove_older_than=1)

Removes old entries from the database to prevent it from becoming too large and slow.

Parameters:

remove_older_than (int, optional) – Entries older than this number of rounds are removed. Defaults to 1.

Return type:

None
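The round-based pruning can be sketched as a filter relative to the latest round in the table, assuming a "round" column holds each entry's round number. This is a simplified, hypothetical version: the actual method may treat some rows (for example, reported metrics) differently.

```python
import numpy as np
import pandas as pd


def clean_up(tensor_db: pd.DataFrame, remove_older_than: int = 1) -> pd.DataFrame:
    """Keep only rows whose round is newer than (latest round - remove_older_than).

    Simplified sketch; assumes a 'round' column and returns a new DataFrame.
    """
    if tensor_db.empty:
        return tensor_db
    current_round = int(tensor_db["round"].max())
    keep = tensor_db["round"] > current_round - remove_older_than
    return tensor_db[keep].reset_index(drop=True)


db = pd.DataFrame(
    [{"tensor_name": "w", "round": r, "nparray": np.full(2, float(r))} for r in range(3)]
)
print(clean_up(db)["round"].tolist())  # [2]
print(clean_up(db, remove_older_than=2)["round"].tolist())  # [1, 2]
```

With the default of 1, only entries from the latest round survive.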

get_aggregated_tensor(tensor_key, collaborator_weight_dict, aggregation_function)

Determine whether all of the collaborator tensors are present for a given tensor key and, if so, return their weighted average.

Parameters:
  • tensor_key (TensorKey) – The tensor key to be resolved. If an entry with origin ‘agg_uuid’ is already present, it can be returned directly. Otherwise, the weighted average over all collaborators must be computed.

  • collaborator_weight_dict (dict) – A dictionary where the keys are collaborator names and the values are their respective weights.

  • aggregation_function (AggregationFunction) – The aggregation function used to combine the collaborator tensors; it calls the underlying NumPy aggregation routine. The default is a plain weighted average.

Returns:

weighted_nparray: The weighted average if all collaborator values are present.

None: If not all values are present.

Return type:

Optional[np.ndarray]
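The "all present, then average" logic can be sketched with plain pandas and np.average, assuming columns named tensor_name, origin, round, and nparray, with each collaborator's copy stored under its own name as origin. The helper weighted_average_if_complete is hypothetical: it hard-codes the default weighted average instead of taking an AggregationFunction, and it skips the ‘agg_uuid’ shortcut.

```python
from typing import Dict, Optional

import numpy as np
import pandas as pd


def weighted_average_if_complete(
    tensor_db: pd.DataFrame,
    tensor_name: str,
    round_number: int,
    collaborator_weight_dict: Dict[str, float],
) -> Optional[np.ndarray]:
    """Weighted average of one tensor across collaborators, or None if any copy is missing."""
    rows = tensor_db[
        (tensor_db["tensor_name"] == tensor_name) & (tensor_db["round"] == round_number)
    ]
    # Map each origin (collaborator name) to its cached array.
    by_origin = dict(zip(rows["origin"], rows["nparray"]))
    if not all(name in by_origin for name in collaborator_weight_dict):
        return None
    names = list(collaborator_weight_dict)
    arrays = np.stack([by_origin[name] for name in names])
    weights = np.array([collaborator_weight_dict[name] for name in names], dtype=float)
    # np.average divides by the sum of the weights, so they need not sum to 1.
    return np.average(arrays, axis=0, weights=weights)


db = pd.DataFrame(
    [
        {"tensor_name": "w", "origin": "col_1", "round": 0, "nparray": np.array([1.0, 2.0])},
        {"tensor_name": "w", "origin": "col_2", "round": 0, "nparray": np.array([3.0, 4.0])},
    ]
)
weights = {"col_1": 0.25, "col_2": 0.75}
print(weighted_average_if_complete(db, "w", 0, weights))  # [2.5 3.5]
print(weighted_average_if_complete(db, "w", 0, {**weights, "col_3": 1.0}))  # None
```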

get_tensor_from_cache(tensor_key)

Perform a lookup of the tensor_key in the TensorDB.

Parameters:

tensor_key (TensorKey) – The key of the tensor to look up.

Returns:

The NumPy array if it is available; otherwise, None.

Return type:

Optional[np.ndarray]
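A lookup of this kind amounts to a boolean mask over the metadata columns. The sketch below assumes a hypothetical TensorKey with fields tensor_name, origin, round_number, report, and tags, stored in columns of similar names. Tags are compared element-wise with apply, because comparing an object column directly against a tuple would make pandas treat the tuple as a sequence of values.

```python
from typing import NamedTuple, Optional, Tuple

import numpy as np
import pandas as pd


class TensorKey(NamedTuple):
    """Hypothetical stand-in for TensorKey; the field names are assumed."""

    tensor_name: str
    origin: str
    round_number: int
    report: bool
    tags: Tuple[str, ...]


def get_tensor_from_cache(tensor_db: pd.DataFrame, tensor_key: TensorKey) -> Optional[np.ndarray]:
    """Return the cached array whose row matches every field of tensor_key, else None."""
    mask = (
        (tensor_db["tensor_name"] == tensor_key.tensor_name)
        & (tensor_db["origin"] == tensor_key.origin)
        & (tensor_db["round"] == tensor_key.round_number)
        & (tensor_db["report"] == tensor_key.report)
        # Compare each stored tags tuple as a whole.
        & tensor_db["tags"].apply(lambda t: t == tensor_key.tags).astype(bool)
    )
    matches = tensor_db.loc[mask, "nparray"]
    if matches.empty:
        return None
    return matches.iloc[0]


key = TensorKey("conv1.weight", "agg", 3, False, ("model",))
db = pd.DataFrame(
    [
        {
            "tensor_name": key.tensor_name,
            "origin": key.origin,
            "round": key.round_number,
            "report": key.report,
            "tags": key.tags,
            "nparray": np.arange(3.0),
        }
    ]
)
print(get_tensor_from_cache(db, key))  # [0. 1. 2.]
print(get_tensor_from_cache(db, key._replace(round_number=4)))  # None
```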