# Source code for archai.trainers.cv.pl_trainer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional, Union

from overrides import overrides
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.types import (
    _EVALUATE_OUTPUT,
    _PREDICT_OUTPUT,
    EVAL_DATALOADERS,
    TRAIN_DATALOADERS,
)

from archai.api.trainer_base import TrainerBase


class PlTrainer(Trainer, TrainerBase):
    """PyTorch-Lightning trainer.

    Provides `train`, `evaluate` and `predict` entry points that forward
    their arguments to `self.fit`, `self.test` and `Trainer.predict`
    respectively.
    """

    @overrides
    def train(
        self,
        model: LightningModule,
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        """Train a model by delegating to `self.fit`.

        Every argument is forwarded to `self.fit` unchanged; `model` is passed
        positionally and the rest by keyword.

        Args:
            model: Module to be trained.
            train_dataloaders: Training dataloaders or a data module.
            val_dataloaders: Validation dataloaders.
            datamodule: Data module.
            ckpt_path: Checkpoint path forwarded to `self.fit`.

        Returns:
            Whatever `self.fit` returns (annotated as `None`).

        """

        return self.fit(
            model,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            datamodule=datamodule,
            ckpt_path=ckpt_path,
        )

    @overrides
    def evaluate(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        ckpt_path: Optional[str] = None,
        verbose: Optional[bool] = True,
        datamodule: Optional[LightningDataModule] = None,
    ) -> _EVALUATE_OUTPUT:
        """Evaluate a model by delegating to `self.test`.

        Every argument is forwarded to `self.test` unchanged, by keyword.

        Args:
            model: Module to be evaluated.
            dataloaders: Evaluation dataloaders or a data module.
            ckpt_path: Checkpoint path forwarded to `self.test`.
            verbose: Verbosity flag forwarded to `self.test`.
            datamodule: Data module.

        Returns:
            The value returned by `self.test`.

        """

        return self.test(
            model=model,
            dataloaders=dataloaders,
            ckpt_path=ckpt_path,
            verbose=verbose,
            datamodule=datamodule,
        )

    @overrides
    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = None,
    ) -> Optional[_PREDICT_OUTPUT]:
        """Run prediction by delegating to `Trainer.predict`.

        Every argument is forwarded to `Trainer.predict` unchanged, by keyword.

        Args:
            model: Module used for prediction.
            dataloaders: Prediction dataloaders or a data module.
            datamodule: Data module.
            return_predictions: Flag forwarded to `Trainer.predict`.
            ckpt_path: Checkpoint path forwarded to `Trainer.predict`.

        Returns:
            The value returned by `Trainer.predict`.

        """

        # Needs to call the method directly from the `Trainer` class to avoid
        # infinite recursion: `self.predict` would resolve back to this
        # override, which shares its name with the method in `TrainerBase`.
        return Trainer.predict(
            self,
            model=model,
            dataloaders=dataloaders,
            datamodule=datamodule,
            return_predictions=return_predictions,
            ckpt_path=ckpt_path,
        )