Training a CV-based Model#
Training a CV-based model with PyTorch-Lightning is straightforward: the model architecture, loss function, and training loop are defined in a LightningModule. Archai offers a set of dataset providers to load and pre-process the data. Additionally, Archai provides a PlTrainer, which wraps the TrainerBase abstraction and renames its methods so they fit the search interface.
Loading the Data#
When using a dataset provider, the data loading process is simplified: the provider takes care of downloading and pre-processing the required dataset.
This step is accomplished in the same way as in the previous notebook:
[1]:
from archai.datasets.cv.mnist_dataset_provider import MnistDatasetProvider
dataset_provider = MnistDatasetProvider()
train_dataset = dataset_provider.get_train_dataset()
val_dataset = dataset_provider.get_val_dataset()
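As a quick, optional sanity check, the returned objects can be inspected like standard PyTorch datasets. The shapes printed below are an assumption based on MNIST images pre-processed to tensors; the exact values may differ depending on the provider's transforms:

# Hypothetical inspection of the loaded datasets (assumes (image, label) samples)
print(len(train_dataset), len(val_dataset))
image, label = train_dataset[0]
print(image.shape, label)  # expected something like torch.Size([1, 28, 28]) and an integer class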
Defining the Model#
Once the data is loaded, we can define any CV-based model. In this example, we will create a simple linear model using PyTorch and wrap it in a LightningModule from PyTorch-Lightning.
Additionally, PyTorch-Lightning requires that some methods are implemented, such as:

- forward: Defines the forward pass of the model.
- training_step: Defines the training step (loop) of the model.
- test_step: If using evaluate, defines the evaluation step (loop) of the model.
- configure_optimizers: Defines the optimizer and attaches the model’s parameters.
[2]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # A single linear layer mapping flattened 28x28 images to 10 classes
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Flatten each image into a vector before the linear layer
        x = x.view(x.size(0), -1)
        x_hat = self.linear(x)
        loss = F.cross_entropy(x_hat, y)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        x_hat = self.linear(x)
        loss = F.cross_entropy(x_hat, y)
        # Logged as `val_loss` so it shows up under that name in the evaluation metrics
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
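Before handing the model to the trainer, it can be useful to confirm that the forward pass works on a dummy batch. This is purely an illustrative check; the batch size and random input below are arbitrary:

# Optional sanity check: run a fake batch of flattened images through the model
sanity_model = Model()
dummy = torch.randn(4, 28 * 28)
print(sanity_model(dummy).shape)  # expected: torch.Size([4, 10]), one logit per class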
Running the Trainer#
The final step is to use the PyTorch-Lightning trainer abstraction (PlTrainer) to conduct the training process, which optimizes the model’s parameters with the pre-defined optimization algorithm and loss function and updates them based on the training data. This process is repeated until the model converges to a satisfactory accuracy or performance level.
[3]:
from torch.utils.data import DataLoader
from archai.trainers.cv.pl_trainer import PlTrainer
model = Model()

# Keep the run tiny for demonstration purposes: a single optimization step and one batch per phase
trainer = PlTrainer(max_steps=1, limit_train_batches=1, limit_test_batches=1, limit_predict_batches=1)
trainer.train(model, DataLoader(train_dataset))
trainer.evaluate(model, DataLoader(val_dataset))
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
Missing logger folder: c:\Users\gderosa\Projects\archai\docs\getting_started\notebooks\cv\lightning_logs
| Name | Type | Params
----------------------------------
0 | linear | Linear | 7.9 K
----------------------------------
7.9 K Trainable params
0 Non-trainable params
7.9 K Total params
0.031 Total estimated model params size (MB)
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:229: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
category=PossibleUserWarning,
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\pytorch_lightning\trainer\trainer.py:1604: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
category=PossibleUserWarning,
`Trainer.fit` stopped: `max_steps=1` reached.
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:229: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
category=PossibleUserWarning,
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         val_loss          │    2.2003676891326904     │
└───────────────────────────┴───────────────────────────┘
[3]:
[{'val_loss': 2.2003676891326904}]
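Note that evaluate returns a list with one dictionary of logged metrics per dataloader, so the final loss can also be read back programmatically. A minimal sketch, assuming the single validation dataloader used above:

# Read the logged metric back from the evaluation results
results = trainer.evaluate(model, DataLoader(val_dataset))
val_loss = results[0]["val_loss"]  # same key as logged in test_step
print(f"Validation loss: {val_loss:.4f}")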