openfl.federated.data.federated_data.FederatedDataSet
- class openfl.federated.data.federated_data.FederatedDataSet(X_train, y_train, X_valid, y_valid, batch_size=1, num_classes=None, train_splitter=None, valid_splitter=None)
Bases: PyTorchDataLoader
A data loader class used to represent a federated dataset for in-memory NumPy data.
- Class Attributes:
train_splitter (NumPyDataSplitter) – An object that splits the training data.
valid_splitter (NumPyDataSplitter) – An object that splits the validation data.
Methods
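get_feature_shape() – Returns the shape of an example feature array.
get_infer_loader() – Returns the data loader for inferencing data.
get_train_data_size() – Returns the total number of training samples.
get_train_loader([batch_size, num_batches]) – Returns the data loader for the training data.
get_valid_data_size() – Returns the total number of validation samples.
get_valid_loader([batch_size]) – Returns the data loader for the validation data.
split(num_collaborators) – Splits the dataset into equal parts for each collaborator and returns a list of FederatedDataSet objects.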
Attributes
train_splitter
valid_splitter
- get_feature_shape()
Returns the shape of an example feature array.
- Returns:
tuple – The shape of an example feature array.
- get_infer_loader()
Returns the data loader for inferencing data.
- Raises:
NotImplementedError – This method must be implemented by a child class.
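Since the method raises NotImplementedError here, inference loading has to come from a subclass. The sketch below shows that override pattern with a stand-in base class so it runs without OpenFL installed. `LoaderBase` and `InferenceLoader` are hypothetical names, not part of OpenFL; with the real library the subclass would presumably derive from FederatedDataSet instead.

```python
class LoaderBase:
    """Hypothetical stand-in for the documented class; not OpenFL code."""

    def get_infer_loader(self):
        # Mirrors the documented behavior: no inference loader is provided.
        raise NotImplementedError


class InferenceLoader(LoaderBase):
    """Hypothetical child class that supplies inference data."""

    def __init__(self, samples, batch_size=2):
        self.samples = list(samples)
        self.batch_size = batch_size

    def get_infer_loader(self):
        # Yield consecutive batches of unlabeled samples.
        for start in range(0, len(self.samples), self.batch_size):
            yield self.samples[start:start + self.batch_size]


def base_raises():
    """Return True if the stand-in base class raises NotImplementedError."""
    try:
        LoaderBase().get_infer_loader()
    except NotImplementedError:
        return True
    return False
```

What the real inference loader should yield (raw arrays, tensors, or a DataLoader) depends on the task runner consuming it, which this page does not specify.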
- get_train_data_size()
Returns the total number of training samples.
- Returns:
int – The total number of training samples.
- get_train_loader(batch_size=None, num_batches=None)
Returns the data loader for the training data.
- Parameters:
batch_size (int, optional) – The batch size for the data loader (default is None).
num_batches (int, optional) – The number of batches for the data loader (default is None).
- Returns:
DataLoader – The DataLoader object for the training data.
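To illustrate how batch_size and num_batches could interact, here is a minimal sketch of a batch generator over in-memory data. It is not OpenFL's implementation. `batch_generator` is a hypothetical helper, plain lists stand in for NumPy arrays, and it assumes that when num_batches is None, enough batches are produced to cover every sample once.

```python
import math


def batch_generator(X, y, batch_size, num_batches=None):
    """Yield (features, labels) batches from in-memory sequences.

    Hypothetical helper, not OpenFL code. Assumption: num_batches=None
    means "one full pass over the data".
    """
    full_pass = math.ceil(len(X) / batch_size)
    if num_batches is None:
        num_batches = full_pass
    # Never yield more batches than one full pass provides.
    for i in range(min(num_batches, full_pass)):
        start = i * batch_size
        yield X[start:start + batch_size], y[start:start + batch_size]
```

Under these assumptions, five samples with batch_size=2 give three batches, the last holding a single sample. Passing num_batches=1 stops after the first batch.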
- get_valid_data_size()
Returns the total number of validation samples.
- Returns:
int – The total number of validation samples.
- get_valid_loader(batch_size=None)
Returns the data loader for the validation data.
- Parameters:
batch_size (int, optional) – The batch size for the data loader (default is None).
- Returns:
DataLoader – The DataLoader object for the validation data.
- split(num_collaborators)
Splits the dataset into equal parts for each collaborator and returns a list of FederatedDataSet objects.
- Parameters:
num_collaborators (int) – The number of collaborators to split the dataset between.
- Returns:
FederatedDataSets (list) – A list of FederatedDataSet objects, each representing a slice of the dataset for a collaborator.
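The idea of an equal split can be sketched as follows. This is an illustration only: the actual partitioning is delegated to the configured train_splitter and valid_splitter, and how they assign samples (for example, whether they shuffle first) is not described on this page. `equal_split` is a hypothetical helper that assumes contiguous shards whose sizes differ by at most one.

```python
def equal_split(num_samples, num_collaborators):
    """Return one list of sample indices per collaborator.

    Hypothetical helper, not OpenFL code. Assumption: shards are
    contiguous, and any remainder goes to the first shards.
    """
    base, extra = divmod(num_samples, num_collaborators)
    shards, start = [], 0
    for rank in range(num_collaborators):
        # The first `extra` collaborators receive one additional sample.
        size = base + (1 if rank < extra else 0)
        shards.append(list(range(start, start + size)))
        start += size
    return shards
```

Each index list would then select the rows of the training or validation arrays that go to one collaborator's FederatedDataSet.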