A DataModule is a collection of at most 3 datasets that provides data for training, validation and testing in a flexible way. For more information about DataModules, see the original implementation in the pytorch-lightning repository. This folder extends the original DataModule with some simple utilities to:
- Do training only when a train_dataloader method is defined.
- Same for test.
Moreover, the SuperDataModule in this folder is written to accept Adapters as input and to use two different types of Dataset to offer maximum flexibility.
The general schema is the following:
- Adapters load data from disk or remote storage, generate it, and so on, and return an iterator along with some functions to do pre-processing.
- Datasets read the data from the Adapter and manage indexing, distributed parallel access and so on.
- DataModules collect some Datasets together and provide them to the training algorithm to do training, validation and test. Moreover, DataModules add DataLoaders to the datasets and offer some easy primitives to check whether training and testing should be run.
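The schema above can be sketched in plain Python. This is only an illustration of how the three layers relate: the names `ListAdapter`, `MapDataset` and `TinyDataModule` are hypothetical and are not the classes defined in this folder, and the real implementations also handle DataLoaders and distributed access.

```python
from typing import Any, Dict, Iterator, List, Optional


class ListAdapter:
    """Hypothetical adapter: yields raw rows and exposes a pre-processing function."""

    def __init__(self, rows: List[str]) -> None:
        self.rows = rows

    def __iter__(self) -> Iterator[str]:
        return iter(self.rows)

    def preprocess(self, row: str) -> Dict[str, Any]:
        # Turn a raw row into a simple dictionary sample.
        return {"text": row, "length": len(row)}


class MapDataset:
    """Hypothetical dataset: reads the adapter once and supports indexing."""

    def __init__(self, adapter: ListAdapter) -> None:
        self.adapter = adapter
        self.rows = list(adapter)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self.adapter.preprocess(self.rows[index])


class TinyDataModule:
    """Hypothetical data module: wraps each provided adapter in a dataset."""

    def __init__(
        self,
        train_adapter: Optional[ListAdapter] = None,
        test_adapter: Optional[ListAdapter] = None,
    ) -> None:
        self.train_dataset = MapDataset(train_adapter) if train_adapter is not None else None
        self.test_dataset = MapDataset(test_adapter) if test_adapter is not None else None


datamodule = TinyDataModule(train_adapter=ListAdapter(["a b", "c d e"]))
first_sample = datamodule.train_dataset[1]
```

Here only a training adapter is given, so `test_dataset` stays `None`; this is the kind of state the checks for training and testing can rely on.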
A simple PersonalDataModule can be defined in either of two ways:
train_adapter = SomeAdapter(hyperparameters, ...)
test_adapter = SomeAdapter(hyperparameters, ...)
PersonalDataModule = SuperDataModule
datamodule = PersonalDataModule(
    train_adapter=train_adapter,
    test_adapter=test_adapter
)
or by defining the Adapters internally:
class PersonalDataModule(SuperDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.train_adapter = SomeAdapter(self.hyperparameters, ...)
        self.test_adapter = SomeAdapter(self.hyperparameters, ...)

datamodule = PersonalDataModule()
Some useful methods include:
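# Train only if the datamodule provides training data
if datamodule.do_train():
    trainer.fit(model, datamodule=datamodule)
# Test only if the datamodule provides test data
if datamodule.do_test():
    trainer.test(model, datamodule=datamodule)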
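As a rough sketch of what such checks could look like, the following assumes that training or testing is enabled whenever the corresponding adapter was provided. The class `CheckedDataModule` and the helper `run` are hypothetical, and the actual SuperDataModule may decide differently (for example by checking whether a train_dataloader method is defined).

```python
from typing import Any, List, Optional


class CheckedDataModule:
    """Hypothetical data module exposing do_train / do_test checks."""

    def __init__(
        self,
        train_adapter: Optional[Any] = None,
        test_adapter: Optional[Any] = None,
    ) -> None:
        self.train_adapter = train_adapter
        self.test_adapter = test_adapter

    def do_train(self) -> bool:
        # Assumption: training is possible only when a train adapter exists.
        return self.train_adapter is not None

    def do_test(self) -> bool:
        # Assumption: testing is possible only when a test adapter exists.
        return self.test_adapter is not None


def run(datamodule: CheckedDataModule) -> List[str]:
    """Return the stages that would be executed for this data module."""
    stages = []
    if datamodule.do_train():
        stages.append("fit")  # stands in for trainer.fit(model, datamodule=datamodule)
    if datamodule.do_test():
        stages.append("test")  # stands in for trainer.test(model, datamodule=datamodule)
    return stages


stages = run(CheckedDataModule(train_adapter=object()))
```

With only a training adapter, `run` skips the test stage, so the same script can be used for train-only, test-only or full runs.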
If your Adapters return a custom data structure that is not a simple dictionary, you should define an appropriate collate function to merge entries together:
train_adapter = SomeAdapterWithCustomOutput(hyperparameters, ...)
test_adapter = SomeAdapterWithCustomOutput(hyperparameters, ...)
PersonalDataModule = SuperDataModule
datamodule = PersonalDataModule(
    train_adapter=train_adapter,
    test_adapter=test_adapter,
    collate_fn=my_collate_fn
)
For more information about collate_fn, see the original definition in PyTorch DataLoaders.
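A collate function is simply a callable that receives a list of samples and returns one batch. The sketch below assumes a hypothetical `Sample` structure returned by an adapter and pads token lists to a common length; it uses plain Python lists, while a real implementation would probably convert the result to tensors.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Sample:
    """Hypothetical custom structure returned by an adapter."""

    tokens: List[int]
    label: int


def my_collate_fn(batch: List[Sample]) -> Dict[str, list]:
    """Merge a list of samples into one batch, padding tokens with zeros."""
    max_len = max(len(sample.tokens) for sample in batch)
    padded = [
        sample.tokens + [0] * (max_len - len(sample.tokens))
        for sample in batch
    ]
    return {"tokens": padded, "labels": [sample.label for sample in batch]}


batch = my_collate_fn([Sample([1, 2, 3], 0), Sample([4], 1)])
```

A function with this shape can be passed as `collate_fn=my_collate_fn`, as in the example above.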