API

Dataset Provider

class archai.api.dataset_provider.DatasetProvider

Abstract class for dataset providers.

This class serves as a base for dataset providers that return training, validation, and testing datasets. Subclasses must implement three methods: get_train_dataset, get_val_dataset, and get_test_dataset. Each method returns an instance of the corresponding dataset; no particular dataset type or structure is required.

Note

This class inherits from EnforceOverrides, so any method a subclass overrides must be decorated with @overrides.
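The enforcement this note describes can be illustrated with a simplified metaclass. This sketch is not the implementation in the `overrides` package: `mark_override` and `EnforceOverridesSketch` are hypothetical stand-ins for `@overrides` and `EnforceOverrides` that tag decorated methods and reject undecorated overrides when the subclass is defined. Raising `TypeError` is a choice of this sketch, and dunder methods such as `__init__` are skipped here, mirroring the examples on this page, which define `__init__` without `@overrides`.

```python
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def mark_override(method: F) -> F:
    """Hypothetical stand-in for @overrides: only tags the method."""
    method.__override__ = True  # type: ignore[attr-defined]
    return method


class EnforceOverridesSketch(ABCMeta):
    """Simplified illustration: reject undecorated overrides at class-creation time."""

    def __new__(mcls, name, bases, namespace, **kwargs):
        for attr, value in namespace.items():
            if attr.startswith("__"):
                continue  # dunder methods such as __init__ are not checked in this sketch
            overridden = any(callable(getattr(base, attr, None)) for base in bases)
            if overridden and callable(value) and not getattr(value, "__override__", False):
                raise TypeError(f"{name}.{attr} overrides a base method without @overrides")
        return super().__new__(mcls, name, bases, namespace, **kwargs)


class Base(metaclass=EnforceOverridesSketch):
    @abstractmethod
    def get_train_dataset(self) -> Any: ...


class Good(Base):
    @mark_override
    def get_train_dataset(self) -> Any:
        return [1, 2, 3]


try:
    class Bad(Base):
        def get_train_dataset(self) -> Any:  # missing the decorator, so class creation fails
            return []
except TypeError as exc:
    print(exc)  # prints: Bad.get_train_dataset overrides a base method without @overrides
```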

Examples

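>>> from typing import Any
>>> import torchvision
>>> from overrides import overrides
>>> from archai.api.dataset_provider import DatasetProvider
>>> class MyDatasetProvider(DatasetProvider):
...     def __init__(self, root: str = "dataroot") -> None:
...         super().__init__()
...         # Directory where MNIST is stored or downloaded to ("dataroot" is an arbitrary default)
...         self.root = root
...
...     @overrides
...     def get_train_dataset(self) -> Any:
...         return torchvision.datasets.MNIST(self.root, train=True, download=True)
...
...     @overrides
...     def get_val_dataset(self) -> Any:
...         # MNIST has no separate validation split, so the test split is used here
...         return torchvision.datasets.MNIST(self.root, train=False, download=True)
...
...     @overrides
...     def get_test_dataset(self) -> Any:
...         return torchvision.datasets.MNIST(self.root, train=False, download=True)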

abstract get_train_dataset() → Any

Get a training dataset.

Returns:

An instance of a training dataset.

abstract get_val_dataset() → Any

Get a validation dataset.

Returns:

An instance of a validation dataset, or the training dataset if no validation dataset is available.

abstract get_test_dataset() → Any

Get a testing dataset.

Returns:

An instance of a testing dataset, or the training or validation dataset if no testing dataset is available.
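The fallback behavior described in the Returns sections above can be sketched with in-memory splits. `DatasetProviderLike` is a hypothetical stand-in with the same three abstract methods as DatasetProvider, used so the sketch runs without archai installed; with the real base class, subclass `archai.api.dataset_provider.DatasetProvider` and decorate each method with `@overrides`. `InMemoryDatasetProvider` and its `train`, `val`, and `test` parameters are illustrative assumptions.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional


class DatasetProviderLike(ABC):
    """Hypothetical stand-in with the same abstract methods as DatasetProvider."""

    @abstractmethod
    def get_train_dataset(self) -> Any: ...

    @abstractmethod
    def get_val_dataset(self) -> Any: ...

    @abstractmethod
    def get_test_dataset(self) -> Any: ...


class InMemoryDatasetProvider(DatasetProviderLike):
    """Serves in-memory splits and falls back when a split is missing."""

    def __init__(self, train: List[Any], val: Optional[List[Any]] = None,
                 test: Optional[List[Any]] = None) -> None:
        self.train, self.val, self.test = train, val, test

    def get_train_dataset(self) -> Any:
        return self.train

    def get_val_dataset(self) -> Any:
        # No validation split: fall back to the training split.
        return self.val if self.val is not None else self.train

    def get_test_dataset(self) -> Any:
        # No testing split: fall back to validation, which itself falls back to training.
        if self.test is not None:
            return self.test
        return self.get_val_dataset()


provider = InMemoryDatasetProvider(train=[(0.0, 0), (1.0, 1)], val=[(0.5, 1)])
print(len(provider.get_train_dataset()), len(provider.get_val_dataset()),
      len(provider.get_test_dataset()))  # prints: 2 1 1
```

Routing the test fallback through `get_val_dataset` keeps the "validation, then training" order in one place.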

Trainer (Base Class)

class archai.api.trainer_base.TrainerBase

Abstract class for trainers.

The TrainerBase class provides an abstract interface for training a model. Subclasses must implement the train, evaluate, and predict methods, which contain the logic for training the model, evaluating it, and making predictions with it, respectively.

Note

This class inherits from EnforceOverrides, so any method a subclass overrides must be decorated with @overrides.

Examples

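>>> import pytorch_lightning as pl
>>> from overrides import overrides
>>> from torch.utils.data import DataLoader
>>> from archai.api.trainer_base import TrainerBase
>>> class MyTrainer(TrainerBase):
...     def __init__(self, model: pl.LightningModule, train_dataloader: DataLoader,
...                  val_dataloader: DataLoader, predict_dataloader: DataLoader) -> None:
...         super().__init__()
...         self.model = model
...         self.train_dataloader = train_dataloader
...         self.val_dataloader = val_dataloader
...         self.predict_dataloader = predict_dataloader
...         self.trainer = pl.Trainer()
...
...     @overrides
...     def train(self) -> None:
...         self.trainer.fit(self.model, train_dataloaders=self.train_dataloader)
...
...     @overrides
...     def evaluate(self) -> None:
...         self.trainer.test(self.model, dataloaders=self.val_dataloader)
...
...     @overrides
...     def predict(self) -> None:
...         # Trainer.predict returns the outputs; keep them since this method returns None
...         self.predictions = self.trainer.predict(self.model, dataloaders=self.predict_dataloader)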

abstract train() → None

Train a model.

This method should contain the logic for training the model.

abstract evaluate() → None

Evaluate a model.

This method should contain the logic for evaluating the model.

abstract predict() → None

Predict with a model.

This method should contain the logic for making predictions with the model.
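To see the train/evaluate/predict contract without PyTorch Lightning, here is a minimal, framework-free sketch. `TrainerLike` is a hypothetical stand-in with the same abstract methods as TrainerBase (with the real base class, import it from `archai.api.trainer_base` and decorate each override with `@overrides`); `LinearRegressionTrainer`, its constructor parameters, and the `val_mse` and `predictions` attributes are illustrative assumptions. The model is a one-variable linear regression fitted by plain gradient descent on mean squared error.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple


class TrainerLike(ABC):
    """Hypothetical stand-in with the same abstract methods as TrainerBase."""

    @abstractmethod
    def train(self) -> None: ...

    @abstractmethod
    def evaluate(self) -> None: ...

    @abstractmethod
    def predict(self) -> None: ...


class LinearRegressionTrainer(TrainerLike):
    """Fits y = w * x + b with gradient descent on mean squared error."""

    def __init__(self, train_data: List[Tuple[float, float]],
                 val_data: List[Tuple[float, float]], inputs: List[float],
                 lr: float = 0.05, epochs: int = 2000) -> None:
        self.train_data, self.val_data, self.inputs = train_data, val_data, inputs
        self.lr, self.epochs = lr, epochs
        self.w, self.b = 0.0, 0.0
        # The methods return None, so results are stored on the instance.
        self.val_mse: Optional[float] = None
        self.predictions: Optional[List[float]] = None

    def train(self) -> None:
        n = len(self.train_data)
        for _ in range(self.epochs):
            # Gradients of the mean squared error with respect to w and b.
            grad_w = sum(2 * (self.w * x + self.b - y) * x for x, y in self.train_data) / n
            grad_b = sum(2 * (self.w * x + self.b - y) for x, y in self.train_data) / n
            self.w -= self.lr * grad_w
            self.b -= self.lr * grad_b

    def evaluate(self) -> None:
        errors = [(self.w * x + self.b - y) ** 2 for x, y in self.val_data]
        self.val_mse = sum(errors) / len(errors)

    def predict(self) -> None:
        self.predictions = [self.w * x + self.b for x in self.inputs]


trainer = LinearRegressionTrainer(
    train_data=[(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)],  # points on y = 2x + 1
    val_data=[(4.0, 9.0)],
    inputs=[10.0],
)
trainer.train()
trainer.evaluate()
trainer.predict()
print(round(trainer.w, 3), round(trainer.b, 3), round(trainer.predictions[0], 3))  # prints: 2.0 1.0 21.0
```

Keeping results on the instance preserves the `→ None` signatures documented above; returning them instead would change the interface.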