Class - TensorDB
- class openfl.databases.tensor_db.TensorDB
Bases: object
The TensorDB stores a tensor key and the data that it corresponds to.
It is built on top of a pandas DataFrame, which provides easy insertion, retrieval, and aggregation. Each collaborator and the aggregator has its own TensorDB.
- tensor_db
A pandas DataFrame that stores the tensor key and the data that it corresponds to.
- mutex
A threading Lock object used to ensure thread-safe operations on the tensor_db DataFrame.
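Why the mutex matters can be shown with a minimal, stdlib-only sketch. This is not OpenFL's implementation: `LockedTable`, `insert`, and `snapshot` are hypothetical names, and a plain list of row dictionaries stands in for the pandas DataFrame. Each insert rebinds the whole table, much like reassigning a DataFrame after concatenation, so the read-modify-write is done while holding the lock.

```python
import threading


class LockedTable:
    """Hypothetical stand-in for TensorDB: a list of row dicts guarded by a lock."""

    def __init__(self):
        self._rows = []                 # plays the role of the tensor_db DataFrame
        self._mutex = threading.Lock()  # plays the role of the mutex attribute

    def insert(self, row):
        # Rebinding self._rows reads the old table and writes a new one;
        # holding the lock keeps concurrent inserts from losing updates.
        with self._mutex:
            self._rows = self._rows + [row]

    def snapshot(self):
        # Return a copy so callers cannot mutate the shared table.
        with self._mutex:
            return list(self._rows)


def worker(table, worker_id, n):
    for i in range(n):
        table.insert({"worker": worker_id, "i": i})


table = LockedTable()
threads = [threading.Thread(target=worker, args=(table, w, 100)) for w in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(table.snapshot()))  # 400
```

Without the lock, two threads could read the same old table and each write back a version missing the other's row.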
Methods
__init__() – Initializes a new instance of the TensorDB class.
cache_tensor(tensor_key_dict) – Insert tensors into the TensorDB (DataFrame).
clean_up([remove_older_than]) – Removes old entries from the database to prevent it from becoming too large and slow.
get_aggregated_tensor(tensor_key, ...) – Determine whether all of the collaborator tensors are present for a given tensor key and, if so, return their weighted average.
get_tensor_from_cache(tensor_key) – Perform a lookup of the tensor_key in the TensorDB.
get_tensors_by_filter([custom_filter]) – Retrieve all tensors that match a custom filter.
get_tensors_by_round_and_tags(fl_round, tags) – Retrieve all tensors that match the specified round and tags.
- cache_tensor(tensor_key_dict)
Insert one or more tensors into the TensorDB (DataFrame).
- Parameters:
tensor_key_dict (Dict[TensorKey, np.ndarray]) – A dictionary where the key is a TensorKey and the value is a numpy array.
- Returns:
None
- Return type:
None
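Conceptually, insertion unpacks each key into its fields and stores them, together with the array, as one row. The sketch below assumes a five-field key (`tensor_name`, `origin`, `round_number`, `report`, `tags`) and column names chosen for illustration. The `TensorKey` here is a local stand-in, a list of dicts replaces the DataFrame, Python lists replace numpy arrays, and the locking a shared database needs is left out.

```python
from collections import namedtuple

# Hypothetical stand-in for OpenFL's TensorKey; the field names are assumed.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])

# Assumed column layout: the five key fields followed by the stored array.
COLUMNS = ("tensor_name", "origin", "round", "report", "tags", "nparray")


def cache_tensor(rows, tensor_key_dict):
    """Append one row per (TensorKey, array) pair to rows (a list of dicts)."""
    for tensor_key, nparray in tensor_key_dict.items():
        tensor_name, origin, fl_round, report, tags = tensor_key
        values = (tensor_name, origin, fl_round, report, tags, nparray)
        rows.append(dict(zip(COLUMNS, values)))


db_rows = []
key = TensorKey("conv1.weight", "col1", 0, False, ("trained",))
cache_tensor(db_rows, {key: [0.1, 0.2, 0.3]})
print(db_rows[0]["round"], db_rows[0]["nparray"])  # 0 [0.1, 0.2, 0.3]
```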
- clean_up(remove_older_than=1)
Removes old entries from the database to prevent it from becoming too large and slow.
- Parameters:
remove_older_than (int, optional) – Entries older than this number of rounds are removed. Defaults to 1.
- Return type:
None
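One plausible reading of "older than this number of rounds" is to keep a row only if its round lies within `remove_older_than` rounds of the newest round present. The sketch below implements that assumption over a list of row dicts. The exact cutoff rule OpenFL applies may differ, and the function name is reused here only for illustration.

```python
def clean_up(rows, remove_older_than=1):
    """Keep only rows from the most recent `remove_older_than` rounds (assumed rule)."""
    if not rows:
        return rows
    latest = max(row["round"] for row in rows)
    # A row survives when its round is greater than latest - remove_older_than.
    return [row for row in rows if row["round"] > latest - remove_older_than]


rows = [{"tensor_name": "w", "round": r} for r in range(5)]
print([row["round"] for row in clean_up(rows)])                       # [4]
print([row["round"] for row in clean_up(rows, remove_older_than=2)])  # [3, 4]
```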
- get_aggregated_tensor(tensor_key, collaborator_weight_dict, aggregation_function)
Determine whether all of the collaborator tensors are present for a given tensor key and, if so, return their weighted average.
- Parameters:
tensor_key (TensorKey) – The tensor key to be resolved. If a tensor with origin ‘agg_uuid’ is already present, it can be returned directly; otherwise, the weighted average of all collaborators' tensors must be computed.
collaborator_weight_dict (dict) – A dictionary where the keys are collaborator names and the values are their respective weights.
aggregation_function (AggregationFunction) – The aggregation function that wraps the underlying numpy computation used to combine the collaborator tensors. By default, this is a plain weighted average.
- Returns:
The weighted average (agg_nparray) if all collaborator values are present; otherwise, None.
- Return type:
Optional[np.ndarray]
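The presence check followed by the default weighted average can be sketched as below. This is not OpenFL's code. `weighted_average_if_complete` is a hypothetical helper, collaborator tensors are passed in a plain dict instead of being looked up in the DataFrame, and lists of floats stand in for numpy arrays. A custom `aggregation_function` would take the place of the averaging loop.

```python
def weighted_average_if_complete(collaborator_tensors, collaborator_weight_dict):
    """Return the weighted average, or None if any expected collaborator is missing.

    collaborator_tensors maps collaborator name -> list of floats (numpy stand-in).
    """
    # Bail out as soon as one expected collaborator has not reported.
    if not collaborator_weight_dict or any(
        name not in collaborator_tensors for name in collaborator_weight_dict
    ):
        return None
    total_weight = sum(collaborator_weight_dict.values())
    names = list(collaborator_weight_dict)
    agg = [0.0] * len(collaborator_tensors[names[0]])
    for name in names:
        # Normalize by the total weight so the weights need not sum to 1.
        scale = collaborator_weight_dict[name] / total_weight
        for i, value in enumerate(collaborator_tensors[name]):
            agg[i] += scale * value
    return agg


weights = {"col1": 0.25, "col2": 0.75}
print(weighted_average_if_complete({"col1": [1.0, 2.0], "col2": [3.0, 4.0]}, weights))  # [2.5, 3.5]
print(weighted_average_if_complete({"col1": [1.0, 2.0]}, weights))                      # None
```

Dividing by the total weight is a choice made in this sketch so that it does not depend on the weights being pre-normalized.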
- get_tensor_from_cache(tensor_key)
Perform a lookup of the tensor_key in the TensorDB.
- Parameters:
tensor_key (TensorKey) – The key of the tensor to look up.
- Returns:
The numpy array if it is available; otherwise, None.
- Return type:
Optional[np.ndarray]
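A cache lookup amounts to matching every key field and returning the stored array, or `None` on a miss. The sketch below assumes a hypothetical five-field key and a row layout in which a list of dicts stands in for the DataFrame. Returning the first matching row is a choice made here, not documented behavior.

```python
from collections import namedtuple

# Hypothetical stand-in for OpenFL's TensorKey; the field names are assumed.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])
KEY_COLUMNS = ("tensor_name", "origin", "round", "report", "tags")


def get_tensor_from_cache(rows, tensor_key):
    """Return the array of the first row whose key fields all match, else None."""
    for row in rows:
        # A row matches only if all five key fields are equal.
        if tuple(row[c] for c in KEY_COLUMNS) == tuple(tensor_key):
            return row["nparray"]
    return None  # cache miss


rows = [{"tensor_name": "w", "origin": "agg", "round": 1, "report": False,
         "tags": ("model",), "nparray": [1.0, 2.0]}]
hit = get_tensor_from_cache(rows, TensorKey("w", "agg", 1, False, ("model",)))
miss = get_tensor_from_cache(rows, TensorKey("w", "agg", 2, False, ("model",)))
print(hit, miss)  # [1.0, 2.0] None
```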
- get_tensors_by_filter(custom_filter=None)
Retrieve all tensors that match a custom filter.
- Parameters:
custom_filter (callable, optional) – A function that takes a DataFrame and returns a boolean mask to filter the DataFrame. If None, no filtering is applied.
- Returns:
A dictionary where the keys are TensorKey objects and the values are numpy arrays.
- Return type:
dict
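The mask-based filter contract can be mirrored without pandas. In this sketch the "table" is a list of row dicts, and `custom_filter` receives the whole table and returns one boolean per row, just as a DataFrame-to-boolean-mask callable would. `TensorKey`, the column names, and `only_reports` are hypothetical stand-ins for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for OpenFL's TensorKey; the field names are assumed.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])
KEY_COLUMNS = ("tensor_name", "origin", "round", "report", "tags")


def get_tensors_by_filter(rows, custom_filter=None):
    """Return {TensorKey: array} for the rows selected by custom_filter's mask."""
    # custom_filter maps the whole table to one bool per row; None keeps every row.
    mask = [True] * len(rows) if custom_filter is None else custom_filter(rows)
    return {
        TensorKey(*(row[c] for c in KEY_COLUMNS)): row["nparray"]
        for row, keep in zip(rows, mask)
        if keep
    }


def only_reports(table):
    # Example filter: keep only rows whose report flag is True.
    return [row["report"] for row in table]


rows = [
    {"tensor_name": "w", "origin": "col1", "round": 0, "report": False,
     "tags": ("trained",), "nparray": [1.0]},
    {"tensor_name": "loss", "origin": "col1", "round": 0, "report": True,
     "tags": ("metric",), "nparray": [0.3]},
]
print(list(get_tensors_by_filter(rows, only_reports).values()))  # [[0.3]]
print(len(get_tensors_by_filter(rows)))                         # 2
```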
- get_tensors_by_round_and_tags(fl_round, tags)
Retrieve all tensors that match the specified round and tags.
- Parameters:
fl_round (int) – The round number to filter tensors.
tags (tuple) – The tags to filter tensors.
- Returns:
A dictionary where the keys are TensorKey objects and the values are numpy arrays.
- Return type:
dict
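Filtering by round and tags is a conjunction of two equality tests. The sketch below assumes that tags must equal the given tuple exactly, including order. This is one reasonable reading of "match the specified tags", not confirmed behavior. It uses a hypothetical row layout, with a list of dicts in place of the DataFrame, and a stand-in `TensorKey`.

```python
from collections import namedtuple

# Hypothetical stand-in for OpenFL's TensorKey; the field names are assumed.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])
KEY_COLUMNS = ("tensor_name", "origin", "round", "report", "tags")


def get_tensors_by_round_and_tags(rows, fl_round, tags):
    """Return {TensorKey: array} for rows from round fl_round whose tags equal tags."""
    # Both conditions must hold; tags are compared as whole tuples (assumed).
    return {
        TensorKey(*(row[c] for c in KEY_COLUMNS)): row["nparray"]
        for row in rows
        if row["round"] == fl_round and row["tags"] == tags
    }


rows = [
    {"tensor_name": "w", "origin": "col1", "round": 1, "report": False,
     "tags": ("trained",), "nparray": [1.0]},
    {"tensor_name": "w", "origin": "col1", "round": 1, "report": False,
     "tags": ("trained", "delta"), "nparray": [0.1]},
    {"tensor_name": "w", "origin": "col1", "round": 2, "report": False,
     "tags": ("trained",), "nparray": [2.0]},
]
result = get_tensors_by_round_and_tags(rows, 1, ("trained",))
print(list(result.values()))  # [[1.0]]
```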