Source code for archai.trainers.cv.pl_trainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Union
from overrides import overrides
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.types import (
_EVALUATE_OUTPUT,
_PREDICT_OUTPUT,
EVAL_DATALOADERS,
TRAIN_DATALOADERS,
)
from archai.api.trainer_base import TrainerBase
class PlTrainer(Trainer, TrainerBase):
    """Archai trainer running on top of the PyTorch Lightning ``Trainer``.

    Implements ``train``, ``evaluate`` and ``predict`` by delegating to the
    Lightning methods ``fit``, ``test`` and ``predict`` respectively.
    """

    @overrides
    def train(
        self,
        model: LightningModule,
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        """Train ``model`` by delegating to Lightning's ``fit``.

        Every argument is forwarded unchanged to ``self.fit``.

        Args:
            model: Module to train (passed positionally to ``fit``).
            train_dataloaders: Training dataloader(s) or a datamodule.
            val_dataloaders: Validation dataloader(s).
            datamodule: Optional datamodule.
            ckpt_path: Optional checkpoint path, forwarded as-is.

        Returns:
            Whatever ``fit`` returns (annotated here as ``None``).
        """

        return self.fit(
            model,
            ckpt_path=ckpt_path,
            datamodule=datamodule,
            val_dataloaders=val_dataloaders,
            train_dataloaders=train_dataloaders,
        )

    @overrides
    def evaluate(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        ckpt_path: Optional[str] = None,
        verbose: Optional[bool] = True,
        datamodule: Optional[LightningDataModule] = None,
    ) -> _EVALUATE_OUTPUT:
        """Evaluate a model by delegating to Lightning's ``test``.

        All arguments are forwarded by keyword, unchanged, to ``self.test``.

        Args:
            model: Module to evaluate; may be ``None``.
            dataloaders: Evaluation dataloader(s) or a datamodule.
            ckpt_path: Optional checkpoint path, forwarded as-is.
            verbose: Verbosity flag forwarded to ``test`` (defaults to ``True``).
            datamodule: Optional datamodule.

        Returns:
            The value returned by ``test``.
        """

        return self.test(
            model=model,
            dataloaders=dataloaders,
            datamodule=datamodule,
            ckpt_path=ckpt_path,
            verbose=verbose,
        )

    @overrides
    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = None,
    ) -> Optional[_PREDICT_OUTPUT]:
        """Run inference by delegating to Lightning's ``Trainer.predict``.

        Args:
            model: Module used for prediction; may be ``None``.
            dataloaders: Prediction dataloader(s) or a datamodule.
            datamodule: Optional datamodule.
            return_predictions: Flag forwarded unchanged to ``Trainer.predict``.
            ckpt_path: Optional checkpoint path, forwarded as-is.

        Returns:
            The value returned by ``Trainer.predict``.
        """

        # This method shadows ``Trainer.predict`` (the name also exists on
        # ``TrainerBase``), so ``self.predict`` would resolve back here and
        # recurse forever. Bind the Lightning implementation explicitly.
        lightning_predict = Trainer.predict
        return lightning_predict(
            self,
            model=model,
            dataloaders=dataloaders,
            datamodule=datamodule,
            return_predictions=return_predictions,
            ckpt_path=ckpt_path,
        )