Implementing a Custom Trainer

Abstract base classes (ABCs) define a blueprint for a class: they declare the methods and attributes that subclasses must provide, but not their implementation. This makes them useful for defining a consistent interface, because they enforce a set of requirements on every implementing class and make it easier to write code that works with multiple implementations.
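As a minimal illustration using only Python's standard `abc` module (not the `overrides` package used below), a subclass that leaves an abstract method unimplemented cannot be instantiated. The class names `Base`, `Incomplete` and `Complete` and the helper `can_instantiate` are hypothetical, chosen only for this sketch.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def train(self) -> None:
        """Subclasses must provide a training routine."""


class Incomplete(Base):
    # Does not implement `train`, so the class is still abstract.
    pass


class Complete(Base):
    def train(self) -> None:
        print("training")


def can_instantiate(cls: type) -> bool:
    """Return True if `cls()` succeeds, False if abstract methods block it."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(Incomplete))  # False
print(can_instantiate(Complete))    # True
```

Python raises the `TypeError` at instantiation time, so a missing method is caught as soon as someone tries to create the object rather than later, when the method is first called.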

First, we define a boilerplate for the `TrainerBase` class, which is the same as the one implemented in the `archai.api.trainer_base` module.

```python
from abc import abstractmethod

from overrides import EnforceOverrides


class TrainerBase(EnforceOverrides):
    def __init__(self) -> None:
        pass

    # Every concrete trainer must implement the three methods below.
    @abstractmethod
    def train(self) -> None:
        pass

    @abstractmethod
    def evaluate(self) -> None:
        pass

    @abstractmethod
    def predict(self) -> None:
        pass
```
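To see how such an interface decouples calling code from any particular implementation, here is a deliberately tiny, hypothetical trainer that fits a constant (mean) predictor on plain Python lists. It swaps `abc.ABC` in for `EnforceOverrides` so it runs with the standard library alone; `SimpleTrainerBase`, `MeanTrainer` and `run` are illustrative names, not part of Archai.

```python
from abc import ABC, abstractmethod
from typing import List, Optional


class SimpleTrainerBase(ABC):
    """Stand-in for TrainerBase that only uses the standard library."""

    @abstractmethod
    def train(self) -> None: ...

    @abstractmethod
    def evaluate(self) -> float: ...

    @abstractmethod
    def predict(self, inputs: List[float]) -> List[float]: ...


class MeanTrainer(SimpleTrainerBase):
    """Predicts the mean of the training targets for every input."""

    def __init__(self, train_targets: List[float], eval_targets: List[float]) -> None:
        self.train_targets = train_targets
        self.eval_targets = eval_targets
        self.mean: Optional[float] = None

    def train(self) -> None:
        self.mean = sum(self.train_targets) / len(self.train_targets)

    def evaluate(self) -> float:
        # Mean squared error of the constant prediction on the eval targets.
        assert self.mean is not None, "call train() first"
        errors = [(y - self.mean) ** 2 for y in self.eval_targets]
        return sum(errors) / len(errors)

    def predict(self, inputs: List[float]) -> List[float]:
        assert self.mean is not None, "call train() first"
        return [self.mean] * len(inputs)


def run(trainer: SimpleTrainerBase) -> float:
    """Generic driver: works with any SimpleTrainerBase implementation."""
    trainer.train()
    return trainer.evaluate()


trainer = MeanTrainer(train_targets=[1.0, 2.0, 3.0], eval_targets=[2.0, 4.0])
print(run(trainer))                  # 2.0
print(trainer.predict([0.0, 5.0]))   # [2.0, 2.0]
```

`run` only relies on the abstract interface, so a PyTorch-based trainer could be passed to it unchanged as long as it implements the same three methods.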

PyTorch-based Trainer

In the context of a custom trainer, ABCs ensure that every implementation provides the required methods and exposes a consistent interface for training, evaluating and predicting. In this example, we implement a PyTorch-based trainer as follows:

```python
from typing import Optional

import torch
from overrides import overrides
from torch.utils.data import DataLoader, Dataset


class PyTorchTrainer(TrainerBase):
    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
    ) -> None:
        super().__init__()

        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # Set up the loss function and optimizer
        self._setup()

    def _setup(self) -> None:
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def _train_step(self, inputs: torch.Tensor, labels: torch.Tensor) -> float:
        # Reset gradients, run the forward pass, backpropagate and update the weights
        self.optimizer.zero_grad()

        outputs = self.model(inputs)

        loss = self.loss_fn(outputs, labels)
        loss.backward()

        self.optimizer.step()

        return loss.item()

    @overrides
    def train(self) -> None:
        assert self.train_dataset is not None, "`train_dataset` has not been provided."

        train_loader = DataLoader(self.train_dataset, batch_size=64, shuffle=True)
        total_loss = 0.0

        self.model.train()
        for idx, (inputs, labels) in enumerate(train_loader):
            # Flatten each image into a vector for the linear model
            inputs = inputs.view(inputs.size(0), -1)

            total_loss += self._train_step(inputs, labels)

            # Print the running average of the batch losses every 10 batches
            if idx % 10 == 0:
                print(f"Batch {idx} loss: {total_loss / (idx + 1)}")

    def _eval_step(self, inputs: torch.Tensor, labels: torch.Tensor) -> float:
        with torch.no_grad():
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels)

        return loss.item()

    @overrides
    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> float:
        # Fall back to the dataset passed to the constructor
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        assert eval_dataset is not None, "`eval_dataset` has not been provided."

        eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)
        eval_loss = 0.0

        self.model.eval()
        for inputs, labels in eval_loader:
            inputs = inputs.view(inputs.size(0), -1)
            eval_loss += self._eval_step(inputs, labels)
        self.model.train()

        # Average the loss over the number of batches
        return eval_loss / len(eval_loader)

    @overrides
    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        self.model.eval()
        preds = self.model(inputs)
        self.model.train()

        return preds
```
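`evaluate` averages the per-batch losses, and because `CrossEntropyLoss` uses mean reduction by default, each of those is already a per-batch mean. Averaging them weights every batch equally, even though the last batch is smaller whenever the dataset size is not a multiple of 64. A sample-weighted average (multiply each batch mean by its batch size, then divide by the total number of samples) gives the exact per-sample mean. The standard-library sketch below contrasts the two on made-up numbers, purely for illustration; `batch_mean` and `sample_mean` are hypothetical helpers.

```python
from typing import List, Tuple


def batch_mean(batches: List[Tuple[float, int]]) -> float:
    """Average of per-batch mean losses (every batch weighted equally)."""
    return sum(loss for loss, _ in batches) / len(batches)


def sample_mean(batches: List[Tuple[float, int]]) -> float:
    """Exact per-sample mean: weight each batch's mean loss by its size."""
    total = sum(loss * size for loss, size in batches)
    count = sum(size for _, size in batches)
    return total / count


# (mean loss of the batch, number of samples in it); the last batch is short.
batches = [(0.5, 64), (0.5, 64), (2.0, 2)]
print(batch_mean(batches))   # 1.0
print(sample_mean(batches))  # 0.523... (68 / 130)
```

With many full batches and one short one, the two numbers are usually close, so the simpler batch average is often good enough for monitoring.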

Defining the Model

Next, we define the model. Any computer vision (CV) model can be used; in this example, we create a simple linear model with PyTorch:

```python
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Maps a flattened 28x28 image to 10 class logits
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = Model()
```
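The trainer flattens each batch with `inputs.view(inputs.size(0), -1)` so that the 28×28 MNIST images match the input size of `nn.Linear(28 * 28, 10)`. The plain-Python arithmetic below checks the flattened shape for a batch of 64 single-channel images (the `(batch, 1, 28, 28)` layout is an assumption about the dataset's tensors) and counts the layer's parameters: a 10 × 784 weight matrix plus 10 biases. The helper names are hypothetical.

```python
from math import prod
from typing import Tuple


def flattened_shape(shape: Tuple[int, ...]) -> Tuple[int, int]:
    """Mimic tensor.view(shape[0], -1): keep the batch dim, merge the rest."""
    return (shape[0], prod(shape[1:]))


def linear_param_count(in_features: int, out_features: int) -> int:
    """A linear layer stores an (out, in) weight matrix plus an (out,) bias."""
    return in_features * out_features + out_features


print(flattened_shape((64, 1, 28, 28)))  # (64, 784)
print(linear_param_count(28 * 28, 10))   # 7850
```

In PyTorch itself, the same count can be obtained with `sum(p.numel() for p in model.parameters())`.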

Creating and Training with the Trainer

After creating the model, we load the MNIST training data, plug both into the `PyTorchTrainer` and start the training, as follows:

```python
from archai.datasets.cv.mnist_dataset_provider import MnistDatasetProvider

dataset_provider = MnistDatasetProvider()
train_dataset = dataset_provider.get_train_dataset()

trainer = PyTorchTrainer(model, train_dataset=train_dataset)
trainer.train()
```
Batch 0 loss: 2.3435773849487305
Batch 10 loss: 2.164100560274991
Batch 20 loss: 2.0146874416442144
Batch 30 loss: 1.875573092891324
Batch 40 loss: 1.755056075933503
Batch 50 loss: 1.65761978486005
Batch 60 loss: 1.5680492149024714
Batch 70 loss: 1.482287242378987
Batch 80 loss: 1.4176807028275948
Batch 90 loss: 1.3575652700204115
Batch 100 loss: 1.3116845883945427
Batch 110 loss: 1.264954976133398
Batch 120 loss: 1.2235281644773877
Batch 130 loss: 1.1893346013913628
Batch 140 loss: 1.1595103922465169
Batch 150 loss: 1.1271054373671676
Batch 160 loss: 1.098986664173766
Batch 170 loss: 1.0724144109159883
Batch 180 loss: 1.0449848247496463
Batch 190 loss: 1.0206239084610764
Batch 200 loss: 1.0005531422237852
Batch 210 loss: 0.9785312015863391
Batch 220 loss: 0.9595723239814534
Batch 230 loss: 0.9406399880394791
Batch 240 loss: 0.9242911396926864
Batch 250 loss: 0.9074264486947382
Batch 260 loss: 0.8933870223746903
Batch 270 loss: 0.8793117023482094
Batch 280 loss: 0.8656814331685945
Batch 290 loss: 0.8519475182511962
Batch 300 loss: 0.8420679773207123
Batch 310 loss: 0.8304505413368201
Batch 320 loss: 0.8197678343343586
Batch 330 loss: 0.8108695821999783
Batch 340 loss: 0.7994857667303504
Batch 350 loss: 0.7905768925308162
Batch 360 loss: 0.7818403058270008
Batch 370 loss: 0.7724894605717569
Batch 380 loss: 0.7644593056262009
Batch 390 loss: 0.7569716618494
Batch 400 loss: 0.7508149850844148
Batch 410 loss: 0.7446284586350703
Batch 420 loss: 0.737757377662738
Batch 430 loss: 0.731570034354856
Batch 440 loss: 0.7251622536623018
Batch 450 loss: 0.7184810209234643
Batch 460 loss: 0.7108740333685906
Batch 470 loss: 0.705083177854048
Batch 480 loss: 0.7006571214808743
Batch 490 loss: 0.6952635809748818
Batch 500 loss: 0.6899459136579327
Batch 510 loss: 0.6855154860392942
Batch 520 loss: 0.680475171004742
Batch 530 loss: 0.6750251736961964
Batch 540 loss: 0.6709054712612396
Batch 550 loss: 0.6657751858667107
Batch 560 loss: 0.6602577523804816
Batch 570 loss: 0.6550697249257628
Batch 580 loss: 0.6502643423780107
Batch 590 loss: 0.6459637832671857
Batch 600 loss: 0.6414088696092615
Batch 610 loss: 0.6377830508785435
Batch 620 loss: 0.6334477859322768
Batch 630 loss: 0.6295056212722971
Batch 640 loss: 0.625435800731833
Batch 650 loss: 0.6220091110046742
Batch 660 loss: 0.6180242504501847
Batch 670 loss: 0.6146804381900324
Batch 680 loss: 0.6113072523680377
Batch 690 loss: 0.6073371331098283
Batch 700 loss: 0.6044996650344805
Batch 710 loss: 0.6016613611356786
Batch 720 loss: 0.5985554359574259
Batch 730 loss: 0.5963995929618867
Batch 740 loss: 0.5928040172970086
Batch 750 loss: 0.5907960743227907
Batch 760 loss: 0.5879131768011389
Batch 770 loss: 0.5846300883322529
Batch 780 loss: 0.5818878866150193
Batch 790 loss: 0.5786483433055215
Batch 800 loss: 0.5762474060467864
Batch 810 loss: 0.5737036844364077
Batch 820 loss: 0.5711945340527397
Batch 830 loss: 0.568284649585702
Batch 840 loss: 0.5661225801203112
Batch 850 loss: 0.5635929554175546
Batch 860 loss: 0.561713579778195
Batch 870 loss: 0.5594659192074865
Batch 880 loss: 0.5575740954898949
Batch 890 loss: 0.555696272622589
Batch 900 loss: 0.5535918302197302
Batch 910 loss: 0.5511906604179305
Batch 920 loss: 0.5482265714001837
Batch 930 loss: 0.5461382602454512
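The log stops at batch 930 because `train` prints only when `idx % 10 == 0`. Assuming the provider returns the standard 60,000-image MNIST training split, a batch size of 64 with the default `drop_last=False` yields ⌈60000 / 64⌉ = 938 batches (indices 0 to 937), and the last multiple of 10 in that range is 930. The sketch below reproduces that arithmetic; `num_batches` and `logged_batches` are hypothetical helpers.

```python
from typing import List


def num_batches(num_samples: int, batch_size: int) -> int:
    """Batches per epoch when the last partial batch is kept (drop_last=False)."""
    return (num_samples + batch_size - 1) // batch_size


def logged_batches(n_batches: int, every: int = 10) -> List[int]:
    """Batch indices printed by a loop that logs when idx % every == 0."""
    return [idx for idx in range(n_batches) if idx % every == 0]


n = num_batches(60_000, 64)
logged = logged_batches(n)
print(n)            # 938
print(logged[-1])   # 930
print(len(logged))  # 94
```

The 94 logged indices match the number of "Batch ... loss" lines printed by a single epoch.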

Evaluating and Predicting with the Trainer

Finally, we evaluate the trained model on the validation set and compute the model's predictions for an all-zeros input:

```python
val_dataset = dataset_provider.get_val_dataset()

eval_loss = trainer.evaluate(eval_dataset=val_dataset)
print(f"Eval loss: {eval_loss}")

inputs = torch.zeros(1, 28 * 28)
preds = trainer.predict(inputs)
print(f"Predictions: {preds}")
```
Eval loss: 0.3360353711610421
Predictions: tensor([[-0.1244,  0.2467, -0.0254, -0.0535,  0.0533,  0.1786, -0.0015,  0.1122,
         -0.2270, -0.0415]], grad_fn=<AddmmBackward0>)
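`predict` returns raw logits: the model has no softmax layer because `torch.nn.CrossEntropyLoss` applies log-softmax internally during training. Since the input is all zeros, these logits are exactly the bias vector of the linear layer. To read them as class probabilities, apply a softmax and take the argmax. The standard-library sketch below does this on the logits printed above (rounded to four decimals, as displayed); `softmax` is a hand-written helper, not a library call.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [-0.1244, 0.2467, -0.0254, -0.0535, 0.0533,
          0.1786, -0.0015, 0.1122, -0.2270, -0.0415]
probs = softmax(logits)
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(predicted_class)  # 1
```

With PyTorch tensors, the equivalent calls are `torch.softmax(preds, dim=-1)` and `preds.argmax(dim=-1)`.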