openfl.federated.data.loader_pt.PyTorchDataLoader
- class openfl.federated.data.loader_pt.PyTorchDataLoader(batch_size, random_seed=None, **kwargs)
Bases: DataLoader
A class used to represent a Federation Data Loader for PyTorch models.
- Class Attributes:
batch_size (int) – Size of batches used for all data loaders.
X_train (np.array) – Training features.
y_train (np.array) – Training labels.
X_valid (np.array) – Validation features.
y_valid (np.array) – Validation labels.
random_seed (int, optional) – Random seed for data shuffling.
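The constructor signature and attributes above can be mirrored by a small stand-in, which helps when reasoning about the interface without installing OpenFL. This is a sketch under stated assumptions: the class name `ToyFederatedLoader` and the helper `load_shard` are hypothetical, plain Python lists stand in for the `np.array` attributes, extra `**kwargs` are simply ignored, and the real class may initialize its attributes differently.

```python
import random


class ToyFederatedLoader:
    """Hypothetical stand-in mirroring the documented attributes of PyTorchDataLoader."""

    def __init__(self, batch_size, random_seed=None, **kwargs):
        # Batch size shared by all loaders unless a call overrides it.
        self.batch_size = batch_size
        # Seed for reproducible shuffling; None means unseeded.
        self.random_seed = random_seed
        # Feature/label containers; a child class is expected to fill these.
        # (The real class forwards **kwargs to its base; this sketch ignores them.)
        self.X_train = None
        self.y_train = None
        self.X_valid = None
        self.y_valid = None


def load_shard(n_train, n_valid, n_features, seed=0):
    """Hypothetical helper producing synthetic features and binary labels."""
    rng = random.Random(seed)

    def features(n):
        return [[rng.random() for _ in range(n_features)] for _ in range(n)]

    def labels(n):
        return [rng.randint(0, 1) for _ in range(n)]

    return features(n_train), labels(n_train), features(n_valid), labels(n_valid)


loader = ToyFederatedLoader(batch_size=4, random_seed=42)
loader.X_train, loader.y_train, loader.X_valid, loader.y_valid = load_shard(10, 3, 5)
print(len(loader.X_train), len(loader.X_valid))  # 10 3
```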
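Methods
get_feature_shape() – Returns the shape of an example feature array.
get_infer_loader() – Returns the data loader for inference data.
get_train_data_size() – Returns the total number of training samples.
get_train_loader([batch_size, num_batches]) – Returns the data loader for the training data.
get_valid_data_size() – Returns the total number of validation samples.
get_valid_loader([batch_size]) – Returns the data loader for the validation data.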
- get_feature_shape()
Returns the shape of an example feature array.
- Returns:
tuple – The shape of an example feature array.
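One plausible reading of "an example feature array" is the shape of a single training sample, excluding the sample axis. With NumPy arrays that is `X_train[0].shape`; the stdlib-only sketch below computes the same thing for nested lists. The helper name `feature_shape` and the choice of the first sample are assumptions for illustration, not confirmed by this reference.

```python
def feature_shape(X):
    """Return the shape of the first sample in X, treating nested lists as arrays.

    Assumption: the "example feature array" is the first training sample.
    """
    shape = []
    item = X[0]
    # Descend one nesting level at a time, recording each level's length.
    while isinstance(item, (list, tuple)):
        shape.append(len(item))
        item = item[0] if item else None
    return tuple(shape)


X_train = [[[0.0] * 28 for _ in range(28)] for _ in range(5)]  # 5 samples of 28x28
print(feature_shape(X_train))  # (28, 28)
```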
- get_infer_loader()
Returns the data loader for inference data.
- Raises:
NotImplementedError – This method must be implemented by a child class.
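Because `get_infer_loader()` raises `NotImplementedError`, a subclass has to supply it. The sketch below shows that pattern with two hypothetical classes, `BaseLoaderSketch` and `InferenceLoaderSketch`; serving consecutive, unshuffled batches of unlabeled data is an assumed behavior chosen for illustration, not the library's.

```python
class BaseLoaderSketch:
    """Hypothetical stand-in for the base class's unimplemented inference hook."""

    def get_infer_loader(self):
        raise NotImplementedError("get_infer_loader must be implemented by a child class")


class InferenceLoaderSketch(BaseLoaderSketch):
    """Hypothetical child that serves unlabeled data in fixed-size batches."""

    def __init__(self, X_infer, batch_size):
        self.X_infer = X_infer
        self.batch_size = batch_size

    def get_infer_loader(self):
        # Yield consecutive, unshuffled slices; the last batch may be smaller.
        for start in range(0, len(self.X_infer), self.batch_size):
            yield self.X_infer[start:start + self.batch_size]


try:
    BaseLoaderSketch().get_infer_loader()
except NotImplementedError as exc:
    print("base:", exc)

batches = list(InferenceLoaderSketch(list(range(7)), batch_size=3).get_infer_loader())
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```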
- get_train_data_size()
Returns the total number of training samples.
- Returns:
int – The total number of training samples.
- get_train_loader(batch_size=None, num_batches=None)
Returns the data loader for the training data.
- Parameters:
batch_size (int, optional) – The batch size for the data loader (default is None).
num_batches (int, optional) – The number of batches for the data loader (default is None).
- Returns:
DataLoader – The DataLoader object for the training data.
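The two optional parameters of `get_train_loader` suggest a per-call batch size and a cap on the number of batches. The sketch below illustrates one way such a loader could behave; the class `TrainLoaderSketch` is hypothetical, and three behaviors are assumptions rather than documented facts: `batch_size=None` falls back to the constructor value, indices are shuffled with `random_seed`, and `num_batches` limits how many batches a single pass yields.

```python
import math
import random


class TrainLoaderSketch:
    """Hypothetical stand-in illustrating get_train_loader's documented parameters."""

    def __init__(self, X_train, y_train, batch_size, random_seed=None):
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.random_seed = random_seed

    def get_train_loader(self, batch_size=None, num_batches=None):
        # Assumption: None falls back to the batch size given at construction.
        if batch_size is None:
            batch_size = self.batch_size
        # Shuffle with an RNG seeded by random_seed; a fixed seed repeats the order.
        idxs = list(range(len(self.X_train)))
        random.Random(self.random_seed).shuffle(idxs)
        total = math.ceil(len(idxs) / batch_size)
        # Assumption: num_batches caps how many batches one pass yields.
        if num_batches is not None:
            total = min(total, num_batches)
        for b in range(total):
            chunk = idxs[b * batch_size:(b + 1) * batch_size]
            # Index features and labels with the same chunk so pairs stay aligned.
            yield [self.X_train[i] for i in chunk], [self.y_train[i] for i in chunk]


loader = TrainLoaderSketch(list(range(10)), [x % 2 for x in range(10)], batch_size=4, random_seed=0)
sizes = [len(xb) for xb, _ in loader.get_train_loader()]
print(sizes)  # [4, 4, 2]
capped = list(loader.get_train_loader(batch_size=3, num_batches=2))
print(len(capped))  # 2
```

Re-seeding on every call makes each pass reproducible but also identical across epochs; an implementation that wants a fresh permutation per epoch would keep one RNG alive between calls instead.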
- get_valid_data_size()
Returns the total number of validation samples.
- Returns:
int – The total number of validation samples.
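Sample counts such as those returned by `get_train_data_size()` and `get_valid_data_size()` are typically used to size an epoch and, in federated averaging, to weight each collaborator's contribution. The sketch below shows both computations; the helper names and the sizes are hypothetical, and using training size as the aggregation weight is an assumption, not something this reference states.

```python
import math


def batches_per_epoch(num_samples, batch_size):
    """Number of batches needed to cover every sample once (last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


def weighted_average(values, sample_counts):
    """FedAvg-style average: each collaborator's value weighted by its sample count."""
    total = sum(sample_counts)
    return sum(v * n for v, n in zip(values, sample_counts)) / total


# Hypothetical training sizes reported by three collaborators.
train_sizes = [600, 300, 100]
print(batches_per_epoch(train_sizes[0], 64))  # 10
print(weighted_average([0.9, 0.8, 0.5], train_sizes))
```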
- get_valid_loader(batch_size=None)
Returns the data loader for the validation data.
- Parameters:
batch_size (int, optional) – The batch size for the data loader (default is None).
- Returns:
DataLoader – The DataLoader object for the validation data.
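Unlike `get_train_loader`, `get_valid_loader` accepts no `num_batches`, which fits a loader that walks the whole validation set once. The sketch below assumes consecutive, unshuffled batches (an assumption, not documented here) and computes accuracy over them; `valid_batches`, `accuracy`, and the toy predictor are hypothetical.

```python
def valid_batches(X, y, batch_size):
    """Hypothetical validation loader: consecutive, unshuffled batches covering every sample."""
    for start in range(0, len(X), batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]


def accuracy(predict, batches):
    """Accumulate correct predictions over all batches, then divide by samples seen."""
    correct = seen = 0
    for xb, yb in batches:
        preds = [predict(x) for x in xb]
        correct += sum(p == t for p, t in zip(preds, yb))
        seen += len(yb)
    return correct / seen


X_valid = list(range(10))
y_valid = [x % 2 for x in X_valid]
# A deliberately imperfect "model": predicts 1 only for odd numbers below 5.
acc = accuracy(lambda x: int(x % 2 == 1 and x < 5), valid_batches(X_valid, y_valid, 4))
print(acc)  # 0.7
```

Counting correct predictions across batches, rather than averaging per-batch accuracies, keeps a smaller final batch from being over-weighted.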