openfl.utilities.split.split_tensor_dict_for_holdouts
- openfl.utilities.split.split_tensor_dict_for_holdouts(logger, tensor_dict, keep_types=(np.floating, np.integer), holdout_tensor_names=())
Split a tensor dictionary by tensor type and name.
This function splits a tensor dictionary into two dictionaries: one containing the tensors to send and the other containing the holdout tensors.
- Parameters:
logger (Logger) – The logger to use for reporting warnings.
tensor_dict (dict) – A dictionary of tensors.
keep_types (Tuple[type, ...], optional) – A tuple of types to keep in the dictionary of tensors. Defaults to (np.floating, np.integer).
holdout_tensor_names (Iterable[str], optional) – An iterable of tensor names to extract from the dictionary of tensors. Defaults to ().
- Returns:
- The first dictionary is the original tensor dictionary minus the holdout tensors, and the second dictionary is a tensor dictionary with only the holdout tensors.
- Return type:
Tuple[dict, dict]
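The behavior described above can be illustrated with a minimal, self-contained sketch. It does not import OpenFL, and `split_for_holdouts_sketch` is a hypothetical stand-in written only from this description, not the library's implementation. It assumes that a tensor whose dtype is not a subtype of any entry in `keep_types` is moved to the holdout dictionary, and that a name in `holdout_tensor_names` that is missing from `tensor_dict` is reported through the logger as a warning. The tensor names in the example are made up.

```python
import logging

import numpy as np


def split_for_holdouts_sketch(
    logger,
    tensor_dict,
    keep_types=(np.floating, np.integer),
    holdout_tensor_names=(),
):
    """Hypothetical sketch of the documented split; not OpenFL's own code."""
    # Work on a shallow copy so the caller's dictionary is left untouched.
    tensors_to_send = dict(tensor_dict)
    holdout_tensors = {}

    # Step 1: hold out tensors requested by name.
    for name in holdout_tensor_names:
        if name in tensors_to_send:
            holdout_tensors[name] = tensors_to_send.pop(name)
        else:
            # Assumption: a missing name is what the logger warns about.
            logger.warning("Cannot hold out %r: not in tensor_dict", name)

    # Step 2: hold out tensors whose dtype is not one of keep_types.
    # Iterate over a snapshot of the keys because the dict is mutated.
    for name in list(tensors_to_send):
        dtype = np.asarray(tensors_to_send[name]).dtype
        if not any(np.issubdtype(dtype, kept) for kept in keep_types):
            holdout_tensors[name] = tensors_to_send.pop(name)

    return tensors_to_send, holdout_tensors


# Illustrative input; the tensor names are invented for this example.
tensor_dict = {
    "conv1.weight": np.ones((2, 2), dtype=np.float32),
    "step_count": np.array(7, dtype=np.int64),
    "is_frozen": np.array([True, False]),
    "optimizer.momentum": np.zeros(3, dtype=np.float32),
}

to_send, held_out = split_for_holdouts_sketch(
    logging.getLogger(__name__),
    tensor_dict,
    holdout_tensor_names=("optimizer.momentum", "missing"),
)

print(sorted(to_send))   # ['conv1.weight', 'step_count']
print(sorted(held_out))  # ['is_frozen', 'optimizer.momentum']
```

In this sketch the boolean tensor is held out because of its dtype, `optimizer.momentum` is held out by name even though its dtype would otherwise be kept, and the unknown name `missing` only produces a warning. Calling the real function should follow the same argument order shown in the signature above.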