openfl.databases
- class openfl.databases.TensorDB
The TensorDB stores a tensor key and the data that it corresponds to.
It is built on top of a pandas DataFrame for its easy insertion, retrieval and aggregation capabilities. Each collaborator, as well as the aggregator, has its own TensorDB.
- Class Attributes:
tensor_db – A pandas DataFrame that stores the tensor key and the data that it corresponds to.
mutex – A threading Lock object used to ensure thread-safe operations on the tensor_db DataFrame.
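The storage layout described above can be sketched as follows. This is a minimal illustration, not openfl's implementation: `MiniTensorDB`, its `lookup` method, the column names and the field layout of the `TensorKey` stand-in are all assumptions made for the example. It shows one DataFrame row per cached tensor and a lock held around every access to the DataFrame.

```python
import threading
from collections import namedtuple

import numpy as np
import pandas as pd

# Hypothetical stand-in for openfl's TensorKey; the field names are assumed.
TensorKey = namedtuple(
    "TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"]
)


class MiniTensorDB:
    """Toy pandas-backed tensor store; illustrative only, not openfl's class."""

    COLUMNS = ["tensor_name", "origin", "round", "report", "tags", "nparray"]

    def __init__(self):
        # One row per cached tensor; the array itself sits in "nparray".
        self.tensor_db = pd.DataFrame(columns=self.COLUMNS)
        # Every read or write of tensor_db happens while holding this lock.
        self.mutex = threading.Lock()

    def lookup(self, key):
        """Return the array stored under `key`, or None if there is none."""
        with self.mutex:
            df = self.tensor_db
            # Tags are tuples, so compare them row by row instead of letting
            # pandas try to broadcast the tuple against the column.
            tags_match = pd.Series(
                [t == key.tags for t in df["tags"]], index=df.index, dtype=bool
            )
            mask = (
                (df["tensor_name"] == key.tensor_name)
                & (df["origin"] == key.origin)
                & (df["round"] == key.round_number)
                & (df["report"] == key.report)
                & tags_match
            )
            hits = df.loc[mask, "nparray"]
        return None if hits.empty else hits.iloc[0]
```

Keeping each key field in its own column lets lookups and filters (by round, by origin) use vectorised pandas comparisons rather than Python loops, with the exception of the tuple-valued `tags` column.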
- cache_tensor(tensor_key_dict: Dict[TensorKey, ndarray]) → None
Insert one or more tensors into the TensorDB (dataframe).
- Parameters:
tensor_key_dict (Dict[TensorKey, np.ndarray]) – A dictionary where the key is a TensorKey and the value is a numpy array.
- Returns:
None
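A minimal sketch of the insertion cache_tensor describes, assuming the same hypothetical `MiniTensorDB` layout (one row per tensor, array in an `nparray` column, `TensorKey` field names assumed). It builds every row for the dictionary first and then appends them with a single `pd.concat` while holding the lock. It does not deduplicate repeated keys; how the real method treats duplicates is not documented here.

```python
import threading
from collections import namedtuple
from typing import Dict

import numpy as np
import pandas as pd

# Hypothetical stand-in for openfl's TensorKey; the field names are assumed.
TensorKey = namedtuple(
    "TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"]
)

COLUMNS = ["tensor_name", "origin", "round", "report", "tags", "nparray"]


class MiniTensorDB:
    """Toy store showing one way cache_tensor could append rows."""

    def __init__(self):
        self.tensor_db = pd.DataFrame(columns=COLUMNS)
        self.mutex = threading.Lock()

    def cache_tensor(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None:
        # Build the new rows outside the lock: one row per key/array pair.
        rows = [
            {
                "tensor_name": key.tensor_name,
                "origin": key.origin,
                "round": key.round_number,
                "report": key.report,
                "tags": key.tags,
                "nparray": array,
            }
            for key, array in tensor_key_dict.items()
        ]
        if not rows:
            return
        new_rows = pd.DataFrame(rows, columns=COLUMNS)
        with self.mutex:
            # One concat per call is cheaper than appending row by row.
            if self.tensor_db.empty:
                self.tensor_db = new_rows
            else:
                self.tensor_db = pd.concat(
                    [self.tensor_db, new_rows], ignore_index=True
                )
```

Doing the row construction before acquiring the lock keeps the critical section short, which matters when several threads cache tensors concurrently.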
- clean_up(remove_older_than: int = 1) → None
Removes old entries from the database to prevent it from becoming too large and slow.
- Parameters:
remove_older_than (int, optional) – Entries older than this number of rounds are removed. Defaults to 1.
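The retention rule can be illustrated with a hypothetical standalone function on a DataFrame that has a `round` column. The reading of "older than N rounds" used here (measured from the newest round in the table, so the default of 1 keeps only the latest round) is an assumption rather than a documented guarantee; in a TensorDB-style class this filtering would run while holding `mutex`.

```python
import pandas as pd


def clean_up(tensor_db: pd.DataFrame, remove_older_than: int = 1) -> pd.DataFrame:
    """Drop rows more than `remove_older_than` rounds behind the newest round.

    Assumption: age is measured from the newest round present in the
    table, so remove_older_than=1 keeps only the latest round.
    """
    if tensor_db.empty or remove_older_than < 1:
        # Nothing to prune, or no sensible window: leave the table as-is.
        return tensor_db
    rounds = tensor_db["round"].astype(int)
    newest = rounds.max()
    keep = rounds > newest - remove_older_than
    # reset_index keeps the positional index compact after rows are dropped.
    return tensor_db[keep].reset_index(drop=True)
```

For example, with rows from rounds 0, 1 and 2, the default keeps only round 2, while `remove_older_than=2` keeps rounds 1 and 2.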
- get_aggregated_tensor(tensor_key: TensorKey, collaborator_weight_dict: dict, aggregation_function: AggregationFunction) → ndarray | None
Determine whether all of the collaborator tensors are present for a given tensor key and, if so, return their weighted average.
- Parameters:
tensor_key (TensorKey) – The tensor key to be resolved. If the origin ‘agg_uuid’ is present, the tensor can be returned directly. Otherwise, the weighted average over all collaborators must be computed.
collaborator_weight_dict (dict) – A dictionary where the keys are collaborator names and the values are their respective weights.
aggregation_function (AggregationFunction) – The aggregation function used to compute the weighted average; it calls an underlying numpy aggregation function. Defaults to a plain weighted average.
- Returns:
agg_nparray (Optional[np.ndarray]) – The weighted average if all collaborator values are present; None if any value is missing.
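The completeness check and weighted average can be sketched in plain numpy. This is an illustration under stated assumptions, not openfl's code: `aggregate_if_complete` and `weighted_average` are hypothetical names, tensors are passed in already grouped by collaborator rather than looked up from the DataFrame, the ‘agg_uuid’ shortcut is omitted, and `aggregation_function` is a plain callable `(arrays, weights) -> ndarray` rather than openfl's `AggregationFunction` interface.

```python
from typing import Callable, Dict, Optional, Sequence

import numpy as np


def weighted_average(arrays: Sequence[np.ndarray], weights: Sequence[float]) -> np.ndarray:
    # np.average divides by the sum of the weights, so they need not sum to 1.
    return np.average(np.stack(arrays), axis=0, weights=weights)


def aggregate_if_complete(
    collaborator_tensors: Dict[str, np.ndarray],
    collaborator_weight_dict: Dict[str, float],
    aggregation_function: Callable = weighted_average,
) -> Optional[np.ndarray]:
    """Aggregate once every weighted collaborator has reported, else None."""
    names = list(collaborator_weight_dict)
    if any(name not in collaborator_tensors for name in names):
        # At least one collaborator's tensor is still missing.
        return None
    arrays = [collaborator_tensors[name] for name in names]
    weights = [collaborator_weight_dict[name] for name in names]
    return aggregation_function(arrays, weights)
```

Passing a different callable, such as an element-wise median that ignores the weights, changes the aggregation rule without touching the completeness check.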
TensorDB Module.