# Source code for archai.api.trainer_base
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import abstractmethod
from overrides import EnforceOverrides
class TrainerBase(EnforceOverrides):
    """Abstract class for trainers.

    The `TrainerBase` class provides an abstract interface for training a model. The user
    is required to implement the `train`, `evaluate`, and `predict` methods. The `train` method
    should contain the logic for training the model, the `evaluate` method should contain
    the logic for evaluating the model, and the `predict` method should contain the logic
    for making predictions with the model.

    Note:
        This class inherits from `EnforceOverrides`, so any method overridden in a
        subclass should be decorated with `@overrides` to ensure it is properly overridden.

    Examples:
        >>> import pytorch_lightning
        >>> from overrides import overrides
        >>>
        >>> class MyTrainer(TrainerBase):
        ...     def __init__(self, model, train_dataloader, val_dataloader, predict_dataloader) -> None:
        ...         super().__init__()
        ...         self.model = model
        ...         self.train_dataloader = train_dataloader
        ...         self.val_dataloader = val_dataloader
        ...         self.predict_dataloader = predict_dataloader
        ...
        ...     @overrides
        ...     def train(self) -> None:
        ...         pytorch_lightning.trainer.Trainer().fit(
        ...             self.model, train_dataloaders=self.train_dataloader
        ...         )
        ...
        ...     @overrides
        ...     def evaluate(self) -> None:
        ...         pytorch_lightning.trainer.Trainer().test(
        ...             self.model, dataloaders=self.val_dataloader
        ...         )
        ...
        ...     @overrides
        ...     def predict(self) -> None:
        ...         pytorch_lightning.trainer.Trainer().predict(
        ...             self.model, dataloaders=self.predict_dataloader
        ...         )

    """

    def __init__(self) -> None:
        """Initialize the trainer.

        The base implementation holds no state; subclasses should still call
        ``super().__init__()`` from their own constructors.

        """

    @abstractmethod
    def train(self) -> None:
        """Train a model.

        This method should contain the logic for training the model and must be
        implemented by subclasses.

        """

    @abstractmethod
    def evaluate(self) -> None:
        """Evaluate a model.

        This method should contain the logic for evaluating the model and must be
        implemented by subclasses.

        """

    @abstractmethod
    def predict(self) -> None:
        """Predict with a model.

        This method should contain the logic for making predictions with the model
        and must be implemented by subclasses.

        """