# Source code for archai.supergraph.nas.arch_trainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from typing import Callable, Optional, Type
import torch
from overrides import EnforceOverrides, overrides
from torch import Tensor
from archai.common import utils
from archai.common.config import Config
from archai.supergraph.datasets import data
from archai.supergraph.nas.model import Model
from archai.supergraph.nas.vis_model_desc import draw_model_desc
from archai.supergraph.utils.checkpoint import CheckPoint
from archai.supergraph.utils.trainer import Trainer
# Type alias for "an ArchTrainer class (or subclass), or None". The string
# forward reference is needed because ArchTrainer is defined further below.
TArchTrainer = Optional[Type["ArchTrainer"]]
class ArchTrainer(Trainer, EnforceOverrides):
    """Trainer used during architecture search.

    Extends ``Trainer`` with two features configured from ``conf_train``:

    * an optional L1 penalty on the model's ``'alphas'`` parameters, weighted
      by ``conf_train['l1_alphas']`` (disabled when the weight is <= 0);
    * after each epoch, drawing the finalized model description into
      ``conf_train['plotsdir']`` when that epoch is the best one so far
      (disabled when ``plotsdir`` is falsy).
    """

    def __init__(self, conf_train: Config, model: Model,
                 checkpoint: Optional[CheckPoint]) -> None:
        super().__init__(conf_train, model, checkpoint)

        # weight of the L1 penalty on alphas; values <= 0 disable the penalty
        self._l1_alphas = conf_train['l1_alphas']
        # output directory for model plots; a falsy value disables plotting
        self._plotsdir = conf_train['plotsdir']

        # Cache the alpha parameters once so compute_loss does not walk the
        # model on every step; only needed when the penalty is enabled.
        if self._l1_alphas > 0.0:
            self._alphas = list(self.model.all_owned().param_by_kind('alphas'))

    @overrides
    def compute_loss(self, lossfn: Callable,
                     y: Tensor, logits: Tensor,
                     aux_weight: float, aux_logits: Optional[Tensor]) -> Tensor:
        """Return the parent's loss plus the optional L1 penalty on alphas.

        Args:
            lossfn: Loss function, forwarded unchanged to
                ``Trainer.compute_loss``.
            y: Targets, forwarded unchanged.
            logits: Model output, forwarded unchanged.
            aux_weight: Forwarded unchanged.
            aux_logits: Forwarded unchanged (may be None).

        Returns:
            The parent's loss, plus ``l1_alphas * sum(|alpha|)`` summed over all
            cached alpha parameters when ``l1_alphas > 0``.
        """
        loss = super().compute_loss(lossfn, y, logits,
                                    aux_weight, aux_logits)

        if self._l1_alphas > 0.0:
            l_extra = sum(torch.sum(a.abs()) for a in self._alphas)
            # Out-of-place add: do not mutate the tensor returned by the parent
            # in place, which autograd can reject if that tensor is saved for
            # another node's backward pass.
            loss = loss + self._l1_alphas * l_extra
        return loss

    @overrides
    def post_epoch(self, data_loaders: data.DataLoaders) -> None:
        """Run the parent's end-of-epoch work, then plot the model if best."""
        super().post_epoch(data_loaders)
        self._draw_model()

    # TODO: move this outside as utility
    def _draw_model(self) -> None:
        """Draw the finalized model description if the current epoch is best.

        Does nothing when ``plotsdir`` is not set or no metrics are available.
        "Best" is judged on test and/or validation metrics when either is
        available; training metrics are used only as a fallback when neither
        is available.
        """
        if not self._plotsdir:
            return

        train_metrics = self.get_metrics()
        if not train_metrics:
            return

        best_train, best_val, best_test = train_metrics.run_metrics.best_epoch()
        cur_index = train_metrics.cur_epoch().index

        if best_val is None and best_test is None:
            # neither val nor test available: fall back to train metrics
            is_best = best_train is not None and best_train.index == cur_index
        else:
            # mark as best if test or val (whichever is available) is best
            # for this epoch
            is_best = (
                (best_test is not None and best_test.index == cur_index)
                or (best_val is not None and best_val.index == cur_index)
            )

        if not is_best:
            return

        # NOTE(review): full_path(..., create=True) receives the per-epoch
        # *file* path here; if it creates a directory at that path, an empty
        # EPxxx directory is left next to the plot -- confirm intended.
        plot_filepath = utils.full_path(
            os.path.join(self._plotsdir, f"EP{cur_index:03d}"),
            create=True)
        # log model_desc as an image
        draw_model_desc(self.model.finalize(), filepath=plot_filepath,
                        caption=f"Epoch {cur_index}")