openfl.utilities.split.split_tensor_dict_for_holdouts


openfl.utilities.split.split_tensor_dict_for_holdouts(logger, tensor_dict, keep_types=(numpy.floating, numpy.integer), holdout_tensor_names=())

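Split a tensor dictionary according to tensor types.

This function splits a tensor dictionary into two dictionaries: one containing the tensors to send and the other containing the holdout tensors, that is, the tensors named in holdout_tensor_names and the tensors whose type is not in keep_types.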
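The split described above can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions, not the openfl implementation: it assumes tensors are NumPy arrays, that tensors requested by name are held out first (logging a warning for names that are absent), and that each remaining tensor is kept only if np.issubdtype matches its dtype against one of keep_types. The function split_for_holdouts_sketch and the example tensors are hypothetical; real code should call openfl.utilities.split.split_tensor_dict_for_holdouts.

```python
import logging
from typing import Dict, Iterable, Tuple

import numpy as np


def split_for_holdouts_sketch(
    logger: logging.Logger,
    tensor_dict: Dict[str, np.ndarray],
    keep_types: Tuple[type, ...] = (np.floating, np.integer),
    holdout_tensor_names: Iterable[str] = (),
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """Hypothetical sketch: return (tensors_to_send, holdout_tensors)."""
    tensors_to_send = dict(tensor_dict)  # shallow copy; input dict is untouched
    holdout_tensors: Dict[str, np.ndarray] = {}

    # Step 1 (assumed order): move tensors requested by name into the holdouts.
    for name in holdout_tensor_names:
        if name in holdout_tensors:
            continue  # duplicate name, already held out
        if name in tensors_to_send:
            holdout_tensors[name] = tensors_to_send.pop(name)
        else:
            logger.warning("Tensor %r is not in the tensor dict", name)

    # Step 2: move tensors whose dtype matches none of keep_types.
    for name in list(tensors_to_send):
        dtype = np.asarray(tensors_to_send[name]).dtype
        if not any(np.issubdtype(dtype, t) for t in keep_types):
            holdout_tensors[name] = tensors_to_send.pop(name)

    return tensors_to_send, holdout_tensors


# Usage example with made-up tensors.
logger = logging.getLogger("split_sketch")
tensors = {
    "weights": np.zeros((2, 2), dtype=np.float32),
    "step": np.array(3, dtype=np.int64),
    "mask": np.array([True, False]),
    "bias": np.ones(2, dtype=np.float32),
}
send, held = split_for_holdouts_sketch(logger, tensors, holdout_tensor_names=["bias"])
print(sorted(send))  # ['step', 'weights']
print(sorted(held))  # ['bias', 'mask']
```

In this sketch, bias is held out because it was requested by name, and mask is held out because np.bool_ is not a subtype of np.floating or np.integer.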

Parameters:
  • logger (Logger) – The logger to use for reporting warnings.

  • tensor_dict (dict) – A dictionary mapping tensor names to tensors.

  • keep_types (Tuple[type, ...], optional) – A tuple of types to keep in the dictionary of tensors to send; tensors of any other type are placed in the holdout dictionary. Defaults to (np.floating, np.integer).

  • holdout_tensor_names (Iterable[str], optional) – An iterable of tensor names to move from the dictionary of tensors into the holdout dictionary, regardless of their type. Defaults to ().
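The default keep_types are abstract NumPy scalar types, so a single entry covers a whole family of dtypes. The sketch below illustrates how such a type check could behave, assuming membership is decided with np.issubdtype on each tensor's dtype; KEEP_TYPES and is_kept are hypothetical helpers, not part of openfl.

```python
import numpy as np

# Default keep_types from the documented signature (abstract NumPy types).
KEEP_TYPES = (np.floating, np.integer)


def is_kept(array: np.ndarray, keep_types=KEEP_TYPES) -> bool:
    """Assumed rule: keep the array if its dtype is a subtype of any keep_type."""
    return any(np.issubdtype(array.dtype, t) for t in keep_types)


results = {
    np.dtype(dt).name: is_kept(np.zeros(1, dtype=dt))
    for dt in (np.float16, np.float64, np.int8, np.uint32, np.bool_, np.complex64)
}
print(results)
# float16, float64, int8 and uint32 map to True; bool and complex64 map to False
```

Under this assumption, unsigned integers are kept because np.unsignedinteger is a subtype of np.integer, while boolean and complex tensors would end up in the holdout dictionary unless their types are added to keep_types.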

Returns:

The first dictionary is the original tensor dictionary minus the holdout tensors; the second dictionary contains only the holdout tensors.

Return type:

Tuple[dict, dict]