# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, Optional, Tuple
import torch
from overrides import EnforceOverrides
from torch import Tensor, nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from archai.common import ml_utils
from archai.common.apex_utils import ApexUtils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.datasets import data
from archai.supergraph.utils.checkpoint import CheckPoint
from archai.supergraph.utils.metrics import Metrics
from archai.supergraph.utils.multi_optim import MultiOptim, OptimSched
from archai.supergraph.utils.tester import Tester
logger = get_global_logger()


class Trainer(EnforceOverrides):
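    """Trains the given model per the training config.

    Handles mixed precision and distributed execution via ApexUtils, periodic
    checkpointing, drop path scheduling, and validation through a shared
    Tester instance. Illustrative usage (config construction elided):

        trainer = Trainer(conf_train, model, checkpoint)
        metrics = trainer.fit(data_loaders)
    """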
def __init__(self, conf_train:Config, model:nn.Module,
checkpoint:Optional[CheckPoint]=None)->None:
# region config vars
self.conf_train = conf_train
conf_lossfn = conf_train['lossfn']
self._aux_weight = conf_train['aux_weight']
self._grad_clip = conf_train['grad_clip']
self._drop_path_prob = conf_train['drop_path_prob']
self._logger_freq = conf_train['logger_freq']
self._title = conf_train['title']
self._epochs = conf_train['epochs']
self.conf_optim = conf_train['optimizer']
self.conf_sched = conf_train['lr_schedule']
self.batch_chunks = conf_train['batch_chunks']
conf_validation = conf_train['validation']
conf_apex = conf_train['apex']
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
# endregion
logger.pushd(self._title + '__init__')
self._apex = ApexUtils(conf_apex)
self._checkpoint = checkpoint
self.model = model
self._lossfn = ml_utils.get_lossfn(conf_lossfn)
        # a separate apex instance for the Tester is not possible because it
        # must use the same distributed model as the Trainer, so they share apex
self._tester = Tester(conf_validation, model, self._apex) \
if conf_validation else None
self._metrics:Optional[Metrics] = None
self._droppath_module = self._get_droppath_module()
if self._droppath_module is None and self._drop_path_prob > 0.0:
logger.warn({'droppath_module': None})
self._start_epoch = -1 # nothing is started yet
logger.popd()

    def fit(self, data_loaders:data.DataLoaders)->Metrics:
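        """Run the full training loop and return the collected Metrics.

        Creates fresh optimizers and schedulers, converts the model for AMP,
        resumes from the checkpoint if it holds trainer state, then trains
        for the remaining epochs.
        """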
logger.pushd(self._title)
assert data_loaders.train_dl is not None
self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)
# create optimizers and schedulers
self._multi_optim = self.create_multi_optim(len(data_loaders.train_dl))
# before checkpoint restore, convert to amp
self.model = self._apex.to_amp(self.model, self._multi_optim,
batch_size=data_loaders.train_dl.batch_size) # pyright: ignore[reportGeneralTypeIssues]
self._lossfn = self._lossfn.to(self.get_device())
self.pre_fit(data_loaders)
# we need to restore checkpoint after all objects are created because
# restoring checkpoint requires load_state_dict calls on these objects
self._start_epoch = 0
# do we have a checkpoint
checkpoint_avail = self._checkpoint is not None
checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint # pyright: ignore[reportGeneralTypeIssues]
resumed = False
if checkpoint_val:
# restore checkpoint
resumed = True
self.restore_checkpoint()
elif checkpoint_avail: # TODO: bad checkpoint?
self._checkpoint.clear()
logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
'checkpoint_val': checkpoint_val,
'start_epoch': self._start_epoch,
'total_epochs': self._epochs})
logger.info({'aux_weight': self._aux_weight,
'grad_clip': self._grad_clip,
'drop_path_prob': self._drop_path_prob,
'validation_freq': self._validation_freq,
'batch_chunks': self.batch_chunks})
if self._start_epoch >= self._epochs:
            logger.warn(f'fit done: start_epoch {self._start_epoch} >= epochs {self._epochs}')
return self.get_metrics() # we already finished the run, we might be checkpointed
logger.pushd('epochs')
for epoch in range(self._start_epoch, self._epochs):
logger.pushd(epoch)
self._set_epoch(epoch, data_loaders)
self.pre_epoch(data_loaders)
self._train_epoch(data_loaders.train_dl)
self.post_epoch(data_loaders)
logger.popd()
logger.popd()
self.post_fit(data_loaders)
# make sure we don't keep references to the graph
del self._multi_optim
logger.popd()
return self.get_metrics()

    def create_multi_optim(self, train_len:int)->MultiOptim:
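        """Create a MultiOptim holding one (optimizer, scheduler) pair.

        Recreated on every fit() call because optimizer and scheduler state
        is specific to each run.
        """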
logger.info({'steps_per_epoch': train_len,
'conf_sched': self.conf_sched.to_dict()})
logger.info({'conf_optim': self.conf_optim.to_dict()})
        # optimizers and schedulers need to be recreated for each fit call
        # as they have state specific to each run
optim = self.create_optimizer(self.conf_optim, self.model.parameters())
# create scheduler for optim before applying amp
sched, sched_on_epoch = self.create_scheduler(self.conf_sched, optim, train_len)
multi_optim = MultiOptim()
multi_optim.append(OptimSched(optim, sched, sched_on_epoch))
logger.info({'multi_optim_len': len(multi_optim)})
return multi_optim

    def create_optimizer(self, conf_optim:Config, params)->Optimizer:
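        """Create the optimizer from config; override to customize."""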
optim = ml_utils.create_optimizer(conf_optim, params)
return optim

    def create_scheduler(self, conf_sched:Config, optim:Optimizer, steps_per_epoch:int) \
->Tuple[Optional[_LRScheduler],bool]:
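        """Create the LR scheduler from config; the returned flag indicates
        whether the scheduler is stepped per epoch rather than per optimizer
        step (see OptimSched)."""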
return ml_utils.create_lr_scheduler(conf_sched, self._epochs,
optim, steps_per_epoch)

    def get_optimizer(self, index=0)->Optimizer:
return self._multi_optim[index].optim

    def get_scheduler(self, index=0)->Optional[_LRScheduler]:
return self._multi_optim[index].sched

    def get_metrics(self)->Metrics:
return self._metrics # pyright: ignore[reportGeneralTypeIssues]

    def _set_epoch(self, epoch:int, data_loaders:data.DataLoaders)->None:
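        """Propagate the epoch number to the train/val samplers and update
        the drop path probability."""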
        # optimizers such as bi-level ones may use the val set for their own
        # purposes, which causes reshuffling due to automatic epoch counting;
        # here we make sure that val_dl has the same epoch as train_dl
if hasattr(data_loaders.train_dl.sampler, 'set_epoch'):
data_loaders.train_dl.sampler.set_epoch(epoch) # pyright: ignore[reportGeneralTypeIssues,reportOptionalMemberAccess]
if data_loaders.val_dl is not None and hasattr(data_loaders.val_dl.sampler, 'set_epoch'):
data_loaders.val_dl.sampler.set_epoch(epoch) # pyright: ignore[reportGeneralTypeIssues]
# apply droppath
self._set_drop_path(epoch, self._epochs)
assert self._metrics.epochs() == epoch
######################### hooks #########################

    def pre_fit(self, data_loaders:data.DataLoaders)->None:
self._metrics.pre_run()

    def post_fit(self, data_loaders:data.DataLoaders)->None:
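        """Run the final test pass, if a test set and tester exist, and
        finalize the run metrics."""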
test_metrics = None
        # run the test pass before finalizing metrics, otherwise we won't have test metrics
if data_loaders.test_dl and self._tester:
test_metrics = self._tester.test(data_loaders.test_dl)
self._metrics.post_run(test_metrics=test_metrics)

    def pre_epoch(self, data_loaders:data.DataLoaders)->None:
self._metrics.pre_epoch(lr=self._multi_optim.get_lr(0, 0))

    def post_epoch(self, data_loaders:data.DataLoaders)->None:
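        """Validate every validation_freq epochs (and on the last epoch),
        record epoch metrics, and checkpoint from the master process when due."""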
val_metrics = None
# first run test before checkpointing, otherwise we won't have val metrics
if data_loaders.val_dl and self._tester and self._validation_freq > 0:
if self._metrics.epochs() % self._validation_freq == 0 or \
self._metrics.epochs() >= self._epochs: # last epoch
                # these asserts make sure train and val sets are not overlapping
# assert train_dl.sampler.epoch == val_dl.sampler.epoch
# tidx = list(train_dl.sampler)
# vidx = list(val_dl.sampler)
# assert all(ti not in vidx for ti in tidx)
val_metrics = self._tester.test(data_loaders.val_dl)
# update val metrics
self._metrics.post_epoch(lr=self._multi_optim.get_lr(0, 0), val_metrics=val_metrics)
# checkpoint if enabled with given freq or if this is the last epoch
if self._checkpoint is not None and self._apex.is_master() and \
self._checkpoint.freq > 0 and (self._metrics.epochs() % self._checkpoint.freq == 0 or \
self._metrics.epochs() >= self._epochs):
self._checkpoint.new()
self.update_checkpoint(self._checkpoint)
self._checkpoint.commit()

    def pre_step(self, x:Tensor, y:Tensor)->None:
self._metrics.pre_step(x, y)

    def post_step(self, x:Tensor, y:Tensor, logits:Tensor, loss:Tensor,
steps:int)->None:
self._metrics.post_step(x, y, logits, loss, steps)
######################### hooks #########################

    def get_device(self):
return self._apex.device

    def restore_checkpoint(self)->None:
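        """Restore metrics, amp, model and optimizer state from the 'trainer'
        checkpoint entry and set the epoch to resume from."""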
state = self._checkpoint['trainer']
last_epoch = state['last_epoch']
assert last_epoch >= 0 and last_epoch < self._epochs
self._metrics.load_state_dict(state['metrics'])
assert self._metrics.epochs() == last_epoch+1
self._apex.load_state_dict(state['amp'])
self.model.load_state_dict(state['model'])
self._multi_optim.load_state_dict(state['multi_optim'])
self._start_epoch = last_epoch + 1

    def epoch(self)->int:
return self._metrics.epochs()

    def update_checkpoint(self, checkpoint:CheckPoint)->None:
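        """Write all state needed to resume training into the checkpoint."""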
        # TODO: don't need to pass checkpoint
        # save all necessary state
state = {
'last_epoch': self._metrics.epochs()-1,
'metrics': self._metrics.state_dict(),
'model': self.model.state_dict(),
'multi_optim': self._multi_optim.state_dict(),
'amp': self._apex.state_dict()
}
        checkpoint['trainer'] = state

    def _train_epoch(self, train_dl: DataLoader)->None:
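        """Train for one epoch: for each step, optionally split the batch into
        chunks that fit in GPU memory, run forward/backward per chunk under
        autocast, then clip gradients and step the optimizers once."""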
steps = len(train_dl)
self.model.train()
logger.pushd('steps')
for step, (x, y) in enumerate(train_dl):
logger.pushd(step)
assert self.model.training # derived class might alter the mode
# TODO: please check that no algorithm is invalidated by swapping prestep with zero grad
self._multi_optim.zero_grad()
self.pre_step(x, y)
            # divide the batch into chunks if needed so it fits in GPU memory
if self.batch_chunks > 1:
x_chunks, y_chunks = torch.chunk(x, self.batch_chunks), torch.chunk(y, self.batch_chunks)
else:
x_chunks, y_chunks = (x,), (y,)
logits_chunks = []
loss_sum, loss_count = 0.0, 0
for xc, yc in zip(x_chunks, y_chunks):
xc, yc = xc.to(self.get_device(), non_blocking=True), yc.to(self.get_device(), non_blocking=True)
with self._apex.autocast():
logits_c, aux_logits = self.model(xc), None
                    tupled_out = isinstance(logits_c, tuple) and len(logits_c) >= 2
# if self._aux_weight: # TODO: some other way to validate?
# assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
if tupled_out: # then we are using model created by desc
logits_c, aux_logits = logits_c[0], logits_c[1]
loss_c = self.compute_loss(self._lossfn, yc, logits_c, # pyright: ignore[reportGeneralTypeIssues]
self._aux_weight, aux_logits)
self._apex.backward(loss_c)
loss_sum += loss_c.item() * len(logits_c)
loss_count += len(logits_c)
# TODO: cannot place on CPU if it was half precision but should we somehow?
logits_chunks.append(logits_c.detach()) # pyright: ignore[reportGeneralTypeIssues]
# TODO: original darts clips alphas as well but pt.darts doesn't
self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
self._apex.step(self._multi_optim)
            # TODO: we possibly need to sync so all replicas are up to date
self._apex.sync_devices()
            # TODO: we need to put y on GPU because logits are on GPU. Is this a good idea from a GPU memory perspective?
self.post_step(x, y.to(self.get_device(), non_blocking=True),
ml_utils.join_chunks(logits_chunks),
torch.tensor(loss_sum/loss_count),
steps)
logger.popd()
# end of step
self._multi_optim.epoch()
logger.popd()

    def compute_loss(self, lossfn:Callable, y:Tensor, logits:Tensor,
aux_weight:float, aux_logits:Optional[Tensor])->Tensor:
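        """Primary loss, plus the weighted auxiliary-tower loss when enabled."""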
loss = lossfn(logits, y)
if aux_weight > 0.0 and aux_logits is not None:
loss += aux_weight * lossfn(aux_logits, y)
return loss

    def _get_droppath_module(self)->Optional[nn.Module]:
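        """Return the module exposing drop_path_prob (unwrapping DataParallel),
        or None if the model doesn't support drop path."""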
m = self.model
if hasattr(self.model, 'module'): # for data parallel model
m = self.model.module
if hasattr(m, 'drop_path_prob'):
return m # pyright: ignore[reportGeneralTypeIssues]
return None

    def _set_drop_path(self, epoch:int, epochs:int)->None:
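        """Linearly ramp the drop path probability from 0 toward the configured
        maximum over the course of training."""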
if self._drop_path_prob and self._droppath_module is not None:
drop_prob = self._drop_path_prob * epoch / epochs
# set value as property in model (it will be used by forward())
            # this is necessary when using DataParallel(model)
# https://github.com/pytorch/pytorch/issues/16885
m = self.model
if hasattr(self.model, 'module'): # for data parallel model
m = self.model.module
if hasattr(m, 'drop_path_prob'):
m.drop_path_prob(drop_prob) # pyright: ignore[reportGeneralTypeIssues]
else:
raise RuntimeError('Drop path value {} was specified but model'
' does not have drop_path_prob() method'\
.format(self._drop_path_prob))