Source code for archai.supergraph.algos.darts.bilevel_arch_trainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional

from overrides import overrides
from torch import Tensor

from archai.common import ml_utils
from archai.common.config import Config
from archai.supergraph.algos.darts.bilevel_optimizer import BilevelOptimizer
from archai.supergraph.datasets import data
from archai.supergraph.nas.arch_trainer import ArchTrainer
from archai.supergraph.nas.model import Model
from archai.supergraph.utils.checkpoint import CheckPoint


class BilevelArchTrainer(ArchTrainer):
    def __init__(self, conf_train: Config, model: Model,
                 checkpoint: Optional[CheckPoint]) -> None:
        super().__init__(conf_train, model, checkpoint)

        self._conf_w_optim = conf_train['optimizer']
        self._conf_w_lossfn = conf_train['lossfn']
        self._conf_alpha_optim = conf_train['alpha_optimizer']
    @overrides
    def pre_fit(self, data_loaders: data.DataLoaders) -> None:
        super().pre_fit(data_loaders)

        # optimizers and schedulers need to be recreated for each fit call
        # as they have state
        assert data_loaders.val_dl is not None

        w_momentum = self._conf_w_optim['momentum']
        w_decay = self._conf_w_optim['decay']
        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

        self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
                                               w_decay, self.model, lossfn,
                                               self.get_device(), self.batch_chunks)
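    # DARTS trains the architecture parameters (alphas) on a held-out
    # validation split while the regular optimizer trains the weights on the
    # training split, so the architecture is not fit to the same data as the
    # weights; that is why pre_fit requires val_dl and wraps the alpha
    # optimizer config in a separate BilevelOptimizer.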
    @overrides
    def post_fit(self, data_loaders: data.DataLoaders) -> None:
        # delete state we created in pre_fit
        del self._bilevel_optim
        return super().post_fit(data_loaders)
    @overrides
    def pre_epoch(self, data_loaders: data.DataLoaders) -> None:
        super().pre_epoch(data_loaders)

        # prep val set to train alphas
        assert data_loaders.val_dl is not None
        self._val_dl = data_loaders.val_dl
        self._valid_iter = iter(data_loaders.val_dl)  # type: ignore
    @overrides
    def post_epoch(self, data_loaders: data.DataLoaders) -> None:
        del self._val_dl
        del self._valid_iter  # clean up
        super().post_epoch(data_loaders)
    @overrides
    def pre_step(self, x: Tensor, y: Tensor) -> None:
        super().pre_step(x, y)

        # reset val loader if we exhausted it
        try:
            x_val, y_val = next(self._valid_iter)
        except StopIteration:
            # reinit iterator
            self._valid_iter = iter(self._val_dl)
            x_val, y_val = next(self._valid_iter)

        # update alphas
        self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())
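    # A minimal sketch of what BilevelOptimizer.step computes, assuming the
    # standard second-order formulation from the DARTS paper (the names w,
    # alpha, xi and eta below are illustrative, not Archai API):
    #
    #   w' = w - xi * grad_w L_train(w, alpha)              # virtual weight step
    #   alpha = alpha - eta * grad_alpha L_val(w', alpha)   # alpha step on val batch
    #
    # where the Hessian-vector product hidden in grad_alpha L_val(w', alpha)
    # is approximated by finite differences to keep the update tractable.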
    @overrides
    def update_checkpoint(self, check_point: CheckPoint) -> None:
        super().update_checkpoint(check_point)
        check_point['bilevel_optim'] = self._bilevel_optim.state_dict()
    @overrides
    def restore_checkpoint(self) -> None:
        super().restore_checkpoint()
        self._bilevel_optim.load_state_dict(self.check_point['bilevel_optim'])
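
# Hypothetical usage sketch (illustrative, not part of this module). It assumes
# ArchTrainer exposes the fit(data_loaders) entry point that drives the
# pre_fit/pre_epoch/pre_step hooks above, and that conf_train carries the
# 'optimizer', 'lossfn' and 'alpha_optimizer' sections read in __init__;
# the construction of conf_train, model and data_loaders is schematic.
#
#   trainer = BilevelArchTrainer(conf_train, model, checkpoint=None)
#   trainer.fit(data_loaders)  # weight steps on train batches, alpha steps on val batches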