Training a CV-based Model#

Training a CV-based model with PyTorch-Lightning is a streamlined process: the model architecture, loss function, and training loop are all defined in a LightningModule. Archai offers a set of dataset providers to load and pre-process the data. Additionally, Archai provides PlTrainer, which follows the TrainerBase abstraction and renames PyTorch-Lightning's methods so that they fit the search interface.

Loading the Data#

When using a dataset provider, the data loading process is simplified, as the provider takes care of downloading and pre-processing the required dataset.

This step is accomplished in the same way as in the previous notebook:

[1]:
from archai.datasets.cv.mnist_dataset_provider import MnistDatasetProvider

dataset_provider = MnistDatasetProvider()

train_dataset = dataset_provider.get_train_dataset()
val_dataset = dataset_provider.get_val_dataset()
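
To confirm what was loaded, we can inspect a few samples. This is just a sanity-check sketch, assuming the provider returns a torchvision-style dataset of (image, label) pairs; the exact image shape depends on the transforms applied by the provider:

[ ]:
# Quick sanity check (assumes (image, label) pairs with tensor images)
print(len(train_dataset), len(val_dataset))

image, label = train_dataset[0]
print(image.shape, label)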

Defining the Model#

Once the data is loaded, we can define any CV-based model. In this example, we create a simple linear model using PyTorch and wrap it with PyTorch-Lightning's LightningModule.

Additionally, PyTorch-Lightning requires that a few methods be implemented, such as:

  • forward: Defines the forward pass of the model.

  • training_step: Defines the training step (loop) of the model.

  • test_step: Defines the evaluation step (loop) of the model (used when calling evaluate).

  • configure_optimizers: Defines the optimizer and attaches the model’s parameters.

[2]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # Single linear layer mapping flattened 28x28 images to 10 class logits
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Flatten images from (batch, 1, 28, 28) to (batch, 784)
        x = x.view(x.size(0), -1)

        x_hat = self.linear(x)
        loss = F.cross_entropy(x_hat, y)

        self.log("train_loss", loss)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)

        x_hat = self.linear(x)
        loss = F.cross_entropy(x_hat, y)

        self.log("val_loss", loss)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
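
Before handing the model to the trainer, a quick forward pass on a dummy batch is a simple way to confirm that the shapes line up. This is only a sanity-check sketch, assuming flattened 28 * 28 inputs to match the linear layer defined above:

[ ]:
# Sanity-check the forward pass with a random batch of 4 flattened images
untrained_model = Model()
print(untrained_model(torch.randn(4, 28 * 28)).shape)  # expected: torch.Size([4, 10])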

Running the Trainer#

The final step is to use the PyTorch-Lightning-based trainer abstraction (PlTrainer) to conduct the training process, which optimizes the model's parameters with the pre-defined optimizer and loss function, updating them based on the training data. This process is repeated until the model converges to a satisfactory accuracy or performance level.

[3]:
from torch.utils.data import DataLoader
from archai.trainers.cv.pl_trainer import PlTrainer

model = Model()
# Limit the run to a single step/batch so this example finishes quickly
trainer = PlTrainer(max_steps=1, limit_train_batches=1, limit_test_batches=1, limit_predict_batches=1)

trainer.train(model, DataLoader(train_dataset))
trainer.evaluate(model, DataLoader(val_dataset))
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
Missing logger folder: c:\Users\gderosa\Projects\archai\docs\getting_started\notebooks\cv\lightning_logs

  | Name   | Type   | Params
----------------------------------
0 | linear | Linear | 7.9 K
----------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:229: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  category=PossibleUserWarning,
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\pytorch_lightning\trainer\trainer.py:1604: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  category=PossibleUserWarning,
`Trainer.fit` stopped: `max_steps=1` reached.
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:229: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  category=PossibleUserWarning,
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         val_loss              2.2003676891326904     │
└───────────────────────────┴───────────────────────────┘
[3]:
[{'val_loss': 2.2003676891326904}]
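
Once trained, the model can also be used for inference with plain PyTorch, without going through the trainer. A minimal sketch, assuming the validation samples are tensor images of shape (1, 28, 28):

[ ]:
# Predict the class of the first validation sample
model.eval()
with torch.no_grad():
    image, label = val_dataset[0]
    prediction = model(image.view(1, -1)).argmax(dim=-1).item()

print(f"Predicted: {prediction}, actual: {label}")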