Class - PyTorchDataLoader
- class openfl.federated.data.loader_pt.PyTorchDataLoader(batch_size, random_seed=None, **kwargs)
Bases: DataLoader
A class used to represent a Federation Data Loader for PyTorch models.
- batch_size
Size of batches used for all data loaders.
- Type:
int
- X_train
Training features.
- Type:
np.array
- y_train
Training labels.
- Type:
np.array
- X_valid
Validation features.
- Type:
np.array
- y_valid
Validation labels.
- Type:
np.array
- random_seed
Random seed for data shuffling.
- Type:
int, optional
- __init__(batch_size, random_seed=None, **kwargs)
Initializes the PyTorchDataLoader object with the batch size, random seed, and any additional arguments.
- Parameters:
batch_size (int) – The size of batches used for all data loaders.
random_seed (int, optional) – Random seed for data shuffling.
kwargs – Additional arguments to pass to the function.
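The attribute list suggests that a concrete loader is responsible for populating `X_train`, `y_train`, `X_valid`, and `y_valid` after the base initializer has stored `batch_size` and `random_seed`. The sketch below illustrates that pattern under stated assumptions: `_LoaderBase` is a hypothetical stand-in for the base initializer (it does not import OpenFL), `ToyLoader` is a hypothetical child class with made-up `n_samples` and `valid_fraction` parameters, and plain Python lists replace the documented `np.array` attributes so the example runs with the standard library alone.

```python
import random


class _LoaderBase:
    """Hypothetical stand-in mirroring PyTorchDataLoader's documented attributes."""

    def __init__(self, batch_size, random_seed=None, **kwargs):
        self.batch_size = batch_size
        self.random_seed = random_seed
        # Left empty here; child classes fill these in
        # (np.array in the real class, plain lists in this sketch).
        self.X_train = None
        self.y_train = None
        self.X_valid = None
        self.y_valid = None


class ToyLoader(_LoaderBase):
    """Hypothetical child class that builds a small synthetic dataset."""

    def __init__(self, batch_size, random_seed=None, n_samples=10,
                 valid_fraction=0.2, **kwargs):
        super().__init__(batch_size, random_seed=random_seed, **kwargs)
        rng = random.Random(random_seed)
        features = [[rng.random(), rng.random()] for _ in range(n_samples)]
        # Binary label: 1 if the two features sum to more than 1.
        labels = [int(x0 + x1 > 1.0) for x0, x1 in features]
        n_valid = int(n_samples * valid_fraction)
        # First n_valid samples become validation data, the rest training data.
        self.X_valid, self.y_valid = features[:n_valid], labels[:n_valid]
        self.X_train, self.y_train = features[n_valid:], labels[n_valid:]


loader = ToyLoader(batch_size=4, random_seed=0)
print(len(loader.X_train), len(loader.X_valid))  # 8 2
```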
Methods
- __init__(batch_size[, random_seed]) – Initializes the PyTorchDataLoader object with the batch size, random seed, and any additional arguments.
- get_feature_shape() – Returns the shape of an example feature array.
- get_infer_loader() – Returns the data loader for inference data.
- get_num_classes() – Returns the number of classes for classification tasks.
- get_train_data_size() – Returns the total number of training samples.
- get_train_loader([batch_size, num_batches]) – Returns the data loader for the training data.
- get_valid_data_size() – Returns the total number of validation samples.
- get_valid_loader([batch_size]) – Returns the data loader for the validation data.
- get_feature_shape()
Returns the shape of an example feature array.
Child classes must implement this method and return the feature shape.
- Returns:
The shape of an example feature array.
- Return type:
list
- Raises:
NotImplementedError – This method must be implemented by all derived classes.
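One way a child class could satisfy this contract is to report the shape of a single training example, that is, the training data without its sample axis. With NumPy arrays that would be `list(self.X_train[0].shape)`; the standalone sketch below instead assumes a hypothetical `ShapeLoader` whose `X_train` is a rectangular nested Python list, and walks the nesting to recover the same list of dimensions.

```python
def _nested_shape(example):
    """Return the dimensions of a rectangular nested list as a list of ints."""
    shape = []
    # Descend through nested lists, recording the length at each level.
    while isinstance(example, list):
        shape.append(len(example))
        example = example[0] if example else None
    return shape


class ShapeLoader:
    """Hypothetical child class; X_train is a nested list, not an np.array."""

    def __init__(self, X_train):
        self.X_train = X_train

    def get_feature_shape(self):
        # Shape of one example, i.e. the training data minus its sample axis.
        return _nested_shape(self.X_train[0])


print(ShapeLoader([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]).get_feature_shape())  # [3]
```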
- get_infer_loader()
Returns the data loader for inference data.
- Raises:
NotImplementedError – This method must be implemented by a child class.
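Because `get_infer_loader()` is left to child classes, what it yields is up to the implementer. A minimal sketch, assuming a hypothetical `X_infer` attribute that holds unlabeled examples: it yields feature-only batches in their original order, without shuffling, so that predictions can be matched back to their inputs.

```python
class InferLoader:
    """Hypothetical child class; `X_infer` is an assumed attribute name."""

    def __init__(self, X_infer, batch_size):
        self.X_infer = X_infer
        self.batch_size = batch_size

    def get_infer_loader(self):
        # Yield features only (no labels), in their original order,
        # so that each prediction lines up with its input.
        for start in range(0, len(self.X_infer), self.batch_size):
            yield self.X_infer[start:start + self.batch_size]


batches = list(InferLoader(list(range(7)), batch_size=3).get_infer_loader())
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```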
- get_num_classes()
Returns the number of classes for classification tasks.
Child classes must implement this method and return the number of classes.
- Returns:
The number of classes.
- Return type:
int
- Raises:
NotImplementedError – This method must be implemented by all derived classes.
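A child class might derive the class count from its labels, as in the hypothetical `LabelCountLoader` below, which assumes integer class labels stored in plain lists. In a federated setting a collaborator's local shard may not contain every class, so a count taken from the task definition (for example, a fixed value passed to the constructor) is often safer than counting local labels.

```python
class LabelCountLoader:
    """Hypothetical child class storing integer class labels in lists."""

    def __init__(self, y_train, y_valid):
        self.y_train = y_train
        self.y_valid = y_valid

    def get_num_classes(self):
        # Count distinct labels across both splits. This undercounts if a
        # class is absent from this collaborator's local data.
        return len(set(self.y_train) | set(self.y_valid))


print(LabelCountLoader([0, 1, 1, 2], [2, 0]).get_num_classes())  # 3
```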
- get_train_data_size()
Returns the total number of training samples.
- Returns:
The total number of training samples.
- Return type:
int
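One common reason a federated data loader reports its training-set size is sample-weighted averaging in the style of FedAvg, where each collaborator's update is weighted by its number of training samples. Whether and how OpenFL uses `get_train_data_size()` for this is not stated here; the sketch below is a standalone illustration with a hypothetical `fedavg` helper that operates on flat lists of parameters.

```python
def fedavg(weights_by_collaborator, train_sizes):
    """Average flat parameter lists, weighting each by its training-set size.

    Hypothetical helper: train_sizes would come from each collaborator's
    get_train_data_size().
    """
    total = sum(train_sizes)
    n_params = len(weights_by_collaborator[0])
    return [
        sum(w[i] * n for w, n in zip(weights_by_collaborator, train_sizes)) / total
        for i in range(n_params)
    ]


# Two hypothetical collaborators whose loaders report 100 and 300 samples.
print(fedavg([[1.0, 0.0], [3.0, 4.0]], [100, 300]))  # [2.5, 3.0]
```

The collaborator with three times as many samples pulls the average three times as hard, which is the intended effect of weighting by data size.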
- get_train_loader(batch_size=None, num_batches=None)
Returns the data loader for the training data.
- Parameters:
batch_size (int, optional) – The batch size for the data loader (default is None).
num_batches (int, optional) – The number of batches for the data loader (default is None).
- Returns:
The DataLoader object for the training data.
- Return type:
DataLoader
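The reference does not say how `batch_size=None` and `num_batches=None` are resolved. The sketch below assumes that `None` falls back to the loader's own `batch_size` and to a single pass over the data, respectively, and that `random_seed` fixes the shuffle order. `TrainLoaderSketch` is a hypothetical standalone class that yields `(features, labels)` batches from a generator; it is not OpenFL's implementation and does not reproduce its actual return type.

```python
import math
import random


class TrainLoaderSketch:
    """Hypothetical stand-in showing one way get_train_loader could behave."""

    def __init__(self, X_train, y_train, batch_size, random_seed=None):
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.random_seed = random_seed

    def get_train_loader(self, batch_size=None, num_batches=None):
        # Assumption: None means "use the loader-wide batch size".
        if batch_size is None:
            batch_size = self.batch_size
        # Shuffle sample indices; a fixed seed makes the order reproducible.
        idxs = list(range(len(self.X_train)))
        random.Random(self.random_seed).shuffle(idxs)
        # Assumption: None means "one full pass over the data".
        if num_batches is None:
            num_batches = math.ceil(len(idxs) / batch_size)
        return self._batches(idxs, batch_size, num_batches)

    def _batches(self, idxs, batch_size, num_batches):
        for i in range(num_batches):
            chunk = idxs[i * batch_size:(i + 1) * batch_size]
            if not chunk:
                break  # Stop early rather than yield empty batches.
            # Features and labels are gathered with the same indices.
            yield [self.X_train[j] for j in chunk], [self.y_train[j] for j in chunk]


loader = TrainLoaderSketch(list(range(10)), [x % 2 for x in range(10)],
                           batch_size=4, random_seed=42)
print([len(x) for x, _ in loader.get_train_loader()])  # [4, 4, 2]
```

Resolving the defaults and shuffling in `get_train_loader` before handing back a separate generator means invalid arguments surface when the method is called, not when iteration starts. Seeding a private `random.Random` instance keeps the shuffle reproducible without touching the global random state; with a fixed seed every call yields the same order, which a real loader may or may not want across epochs.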
- get_valid_data_size()
Returns the total number of validation samples.
- Returns:
The total number of validation samples.
- Return type:
int
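The validation-set size is handy for planning an evaluation pass, for example to size a progress bar or to check whether the final batch is smaller than the rest. The helper below is hypothetical and assumes batches are formed by slicing the validation data in order with a fixed batch size.

```python
import math


def valid_batch_plan(valid_size, batch_size):
    """Return (number of batches, size of the final batch) for one pass.

    Hypothetical helper: valid_size would come from get_valid_data_size().
    """
    if valid_size == 0:
        return 0, 0
    num_batches = math.ceil(valid_size / batch_size)
    # Every batch is full except possibly the last one.
    last = valid_size - (num_batches - 1) * batch_size
    return num_batches, last


print(valid_batch_plan(250, 64))  # (4, 58)
```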