Implementing a Custom Trainer#
Abstract base classes (ABCs) define a blueprint for a class, specifying its methods and attributes but not their implementation. They are useful for enforcing a consistent interface: every concrete subclass must provide the required methods, which makes it easier to write code that works with multiple implementations.
First, we define the boilerplate for the TrainerBase class, which mirrors the one implemented in the archai.api.trainer_base module.
[1]:
from abc import abstractmethod

from overrides import EnforceOverrides


class TrainerBase(EnforceOverrides):
    def __init__(self) -> None:
        pass

    @abstractmethod
    def train(self) -> None:
        pass

    @abstractmethod
    def evaluate(self) -> None:
        pass

    @abstractmethod
    def predict(self) -> None:
        pass
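Because TrainerBase inherits from EnforceOverrides, any subclass that redefines one of these methods is expected to mark it with the @overrides decorator; otherwise the overrides package rejects the class when it is defined (typically with a TypeError). The snippet below is a minimal sketch of this enforcement; the BadTrainer and GoodTrainer names are purely illustrative:

from overrides import overrides

# Redefining `train` without @overrides is rejected at class-definition time
try:

    class BadTrainer(TrainerBase):
        def train(self) -> None:  # missing @overrides
            pass

except TypeError as error:
    print(f"Definition rejected: {error}")


# Decorating the override with @overrides satisfies EnforceOverrides
class GoodTrainer(TrainerBase):
    @overrides
    def train(self) -> None:
        pass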
PyTorch-based Trainer#
In the context of a custom trainer, using ABCs helps ensure that every implementation exposes the required methods and offers a consistent interface for training, evaluating, and predicting. In this example, we implement a PyTorch-based trainer, as follows:
[2]:
from typing import Optional

import torch
from overrides import overrides
from torch.utils.data import Dataset


class PyTorchTrainer(TrainerBase):
    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
    ) -> None:
        super().__init__()

        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # Setup the trainer
        self._setup()

    def _setup(self) -> None:
        # Loss function and optimizer shared by the training and evaluation loops
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def _train_step(self, inputs: torch.Tensor, labels: torch.Tensor) -> float:
        self.optimizer.zero_grad()

        outputs = self.model(inputs)
        loss = self.loss_fn(outputs, labels)

        loss.backward()
        self.optimizer.step()

        return loss.item()

    @overrides
    def train(self) -> None:
        total_loss = 0.0
        train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=64, shuffle=True)

        self.model.train()
        for idx, (inputs, labels) in enumerate(train_loader):
            # Flatten the 28x28 images into vectors before the forward pass
            inputs = inputs.view(inputs.size(0), -1)
            total_loss += self._train_step(inputs, labels)

            if idx % 10 == 0:
                print(f"Batch {idx} loss: {total_loss / (idx + 1)}")

    def _eval_step(self, inputs: torch.Tensor, labels: torch.Tensor) -> float:
        with torch.no_grad():
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels)

        return loss.item()

    @overrides
    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> None:
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        assert eval_dataset is not None, "`eval_dataset` has not been provided."

        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=64, shuffle=False)

        eval_loss = 0.0
        self.model.eval()
        for idx, (inputs, labels) in enumerate(eval_loader):
            inputs = inputs.view(inputs.size(0), -1)
            eval_loss += self._eval_step(inputs, labels)
        self.model.train()

        # Average over the number of evaluation batches
        eval_loss /= len(eval_loader)

        return eval_loss

    @overrides
    def predict(self, inputs: torch.Tensor) -> None:
        self.model.eval()
        preds = self.model(inputs)
        self.model.train()

        return preds
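Since the loss function and optimizer are created inside the private _setup() hook, the trainer can be customized without touching the training loop by subclassing it and redefining that hook. The snippet below is only a sketch (the SgdTrainer name and the lr argument are illustrative, not part of Archai) that swaps Adam for SGD with a configurable learning rate:

class SgdTrainer(PyTorchTrainer):
    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        lr: float = 1e-2,
    ) -> None:
        # Store the hyperparameter first, because the parent constructor calls `_setup()`
        self._lr = lr

        super().__init__(model, train_dataset=train_dataset, eval_dataset=eval_dataset)

    @overrides
    def _setup(self) -> None:
        # Same loss as the parent trainer, but a different optimizer
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self._lr, momentum=0.9)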
Defining the Model#
With the trainer defined, we can use any CV-based model. In this example, we create a simple linear model using PyTorch:
[3]:
from torch import nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = Model()
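Any module that maps the flattened 28 * 28 pixels to 10 logits can be plugged into the trainer, since the training loop flattens each batch before the forward pass. As a sketch (the MlpModel name and hidden_dim default are illustrative), a slightly deeper multi-layer perceptron would work without any change to the trainer:

class MlpModel(nn.Module):
    def __init__(self, hidden_dim: int = 128) -> None:
        super().__init__()

        # Two linear layers with a ReLU in between, same input/output shapes as `Model`
        self.layers = nn.Sequential(
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)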
Creating and Training with the Trainer#
After loading the data and creating the model, we plug these instances into the PyTorchTrainer and start the training, as follows:
[4]:
from archai.datasets.cv.mnist_dataset_provider import MnistDatasetProvider
dataset_provider = MnistDatasetProvider()
train_dataset = dataset_provider.get_train_dataset()
trainer = PyTorchTrainer(model, train_dataset=train_dataset)
trainer.train()
Batch 0 loss: 2.3435773849487305
Batch 10 loss: 2.164100560274991
Batch 20 loss: 2.0146874416442144
Batch 30 loss: 1.875573092891324
Batch 40 loss: 1.755056075933503
Batch 50 loss: 1.65761978486005
Batch 60 loss: 1.5680492149024714
Batch 70 loss: 1.482287242378987
Batch 80 loss: 1.4176807028275948
Batch 90 loss: 1.3575652700204115
Batch 100 loss: 1.3116845883945427
Batch 110 loss: 1.264954976133398
Batch 120 loss: 1.2235281644773877
Batch 130 loss: 1.1893346013913628
Batch 140 loss: 1.1595103922465169
Batch 150 loss: 1.1271054373671676
Batch 160 loss: 1.098986664173766
Batch 170 loss: 1.0724144109159883
Batch 180 loss: 1.0449848247496463
Batch 190 loss: 1.0206239084610764
Batch 200 loss: 1.0005531422237852
Batch 210 loss: 0.9785312015863391
Batch 220 loss: 0.9595723239814534
Batch 230 loss: 0.9406399880394791
Batch 240 loss: 0.9242911396926864
Batch 250 loss: 0.9074264486947382
Batch 260 loss: 0.8933870223746903
Batch 270 loss: 0.8793117023482094
Batch 280 loss: 0.8656814331685945
Batch 290 loss: 0.8519475182511962
Batch 300 loss: 0.8420679773207123
Batch 310 loss: 0.8304505413368201
Batch 320 loss: 0.8197678343343586
Batch 330 loss: 0.8108695821999783
Batch 340 loss: 0.7994857667303504
Batch 350 loss: 0.7905768925308162
Batch 360 loss: 0.7818403058270008
Batch 370 loss: 0.7724894605717569
Batch 380 loss: 0.7644593056262009
Batch 390 loss: 0.7569716618494
Batch 400 loss: 0.7508149850844148
Batch 410 loss: 0.7446284586350703
Batch 420 loss: 0.737757377662738
Batch 430 loss: 0.731570034354856
Batch 440 loss: 0.7251622536623018
Batch 450 loss: 0.7184810209234643
Batch 460 loss: 0.7108740333685906
Batch 470 loss: 0.705083177854048
Batch 480 loss: 0.7006571214808743
Batch 490 loss: 0.6952635809748818
Batch 500 loss: 0.6899459136579327
Batch 510 loss: 0.6855154860392942
Batch 520 loss: 0.680475171004742
Batch 530 loss: 0.6750251736961964
Batch 540 loss: 0.6709054712612396
Batch 550 loss: 0.6657751858667107
Batch 560 loss: 0.6602577523804816
Batch 570 loss: 0.6550697249257628
Batch 580 loss: 0.6502643423780107
Batch 590 loss: 0.6459637832671857
Batch 600 loss: 0.6414088696092615
Batch 610 loss: 0.6377830508785435
Batch 620 loss: 0.6334477859322768
Batch 630 loss: 0.6295056212722971
Batch 640 loss: 0.625435800731833
Batch 650 loss: 0.6220091110046742
Batch 660 loss: 0.6180242504501847
Batch 670 loss: 0.6146804381900324
Batch 680 loss: 0.6113072523680377
Batch 690 loss: 0.6073371331098283
Batch 700 loss: 0.6044996650344805
Batch 710 loss: 0.6016613611356786
Batch 720 loss: 0.5985554359574259
Batch 730 loss: 0.5963995929618867
Batch 740 loss: 0.5928040172970086
Batch 750 loss: 0.5907960743227907
Batch 760 loss: 0.5879131768011389
Batch 770 loss: 0.5846300883322529
Batch 780 loss: 0.5818878866150193
Batch 790 loss: 0.5786483433055215
Batch 800 loss: 0.5762474060467864
Batch 810 loss: 0.5737036844364077
Batch 820 loss: 0.5711945340527397
Batch 830 loss: 0.568284649585702
Batch 840 loss: 0.5661225801203112
Batch 850 loss: 0.5635929554175546
Batch 860 loss: 0.561713579778195
Batch 870 loss: 0.5594659192074865
Batch 880 loss: 0.5575740954898949
Batch 890 loss: 0.555696272622589
Batch 900 loss: 0.5535918302197302
Batch 910 loss: 0.5511906604179305
Batch 920 loss: 0.5482265714001837
Batch 930 loss: 0.5461382602454512
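Note that train() performs a single pass over the training set. If more epochs are needed, one option (sketched below, not a built-in feature of this trainer) is simply to call it in a loop, since the optimizer state is kept on the trainer between calls:

# Hypothetical multi-epoch loop: each call to `train()` iterates once over the
# training set with a freshly shuffled DataLoader.
num_epochs = 3
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    trainer.train()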
Evaluating and Predicting with the Trainer#
Finally, we evaluate the trained model on the validation set and create a dummy input (a tensor of zeros) to compute the model’s predictions:
[5]:
val_dataset = dataset_provider.get_val_dataset()
eval_loss = trainer.evaluate(eval_dataset=val_dataset)
print(f"Eval loss: {eval_loss}")
inputs = torch.zeros(1, 28 * 28)
preds = trainer.predict(inputs)
print(f"Predictions: {preds}")
Eval loss: 0.3360353711610421
Predictions: tensor([[-0.1244, 0.2467, -0.0254, -0.0535, 0.0533, 0.1786, -0.0015, 0.1122,
-0.2270, -0.0415]], grad_fn=<AddmmBackward0>)
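The raw predictions are unnormalized logits, because the model’s forward pass does not apply a softmax (torch.nn.CrossEntropyLoss already expects logits during training). As a usage sketch, the logits can be turned into class probabilities and a predicted digit as follows:

# Convert the logits into probabilities and pick the most likely class
probs = torch.softmax(preds, dim=-1)
predicted_class = torch.argmax(probs, dim=-1)

print(f"Probabilities: {probs}")
print(f"Predicted class: {predicted_class.item()}")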