Source code for archai.supergraph.utils.multi_optim

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Iterator, List, Optional

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from archai.common.utils import zip_eq


class OptimSched:
    """Holds the optimizer and scheduler"""
    def __init__(self, optim:Optimizer, sched:Optional[_LRScheduler],
                 sched_on_epoch:Optional[bool])->None:
        self.optim = optim
        self.sched = sched
        self.sched_on_epoch = sched_on_epoch


class MultiOptim:
    """Aggregates multiple optimizer/scheduler pairs so they can be stepped,
    checkpointed and restored together."""

    def __init__(self) -> None:
        self._optim_scheds:List[OptimSched] = []

    def append(self, optim_sched:OptimSched)->None:
        self._optim_scheds.append(optim_sched)

    def zero_grad(self)->None:
        for optim_sched in self._optim_scheds:
            optim_sched.optim.zero_grad()

    def step(self)->None:
        # step every optimizer; schedulers marked as per-batch are stepped here too
        for optim_sched in self._optim_scheds:
            optim_sched.optim.step()
            if optim_sched.sched and not optim_sched.sched_on_epoch:
                optim_sched.sched.step(epoch=None)

    def epoch(self, epoch:Optional[int]=None)->None:
        # step only the schedulers marked as per-epoch
        for optim_sched in self._optim_scheds:
            if optim_sched.sched and optim_sched.sched_on_epoch:
                optim_sched.sched.step(epoch=epoch)

    def get_lr(self, optim_index:int, param_index:int)->float:
        return self._optim_scheds[optim_index].optim.param_groups[param_index]['lr']

    def state_dict(self)->dict:
        optim_states = [optim_sched.optim.state_dict() for optim_sched in self]
        sched_states = [optim_sched.sched.state_dict() if optim_sched.sched else None
                        for optim_sched in self]

        return {'optim_states': optim_states, 'sched_states': sched_states}

    def load_state_dict(self, state_dict:dict)->None:
        optim_states = state_dict['optim_states']
        sched_states = state_dict['sched_states']

        # zip_eq raises if the number of saved states does not match the number of pairs
        for optim_sched, optim_state, sched_state in zip_eq(self, optim_states, sched_states):
            optim_sched.optim.load_state_dict(optim_state)
            if optim_sched.sched:
                assert sched_state is not None
                optim_sched.sched.load_state_dict(sched_state)
            else:
                assert sched_state is None

    def __getitem__(self, index)->OptimSched:
        return self._optim_scheds[index]

    def __len__(self)->int:
        return len(self._optim_scheds)

    def __iter__(self)->Iterator[OptimSched]:
        return iter(self._optim_scheds)
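
A minimal usage sketch of the module above: it wires two optimizer/scheduler pairs into a MultiOptim and drives one training step. The model split and the SGD/Adam/StepLR choices are illustrative assumptions, not part of this module.

# Usage sketch (assumptions: toy model, SGD/Adam/StepLR chosen only for illustration)
import torch
from torch import nn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import StepLR

from archai.supergraph.utils.multi_optim import MultiOptim, OptimSched

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
head, tail = model[0], model[2]

multi_optim = MultiOptim()
# first pair: SGD with a scheduler stepped once per epoch
opt1 = SGD(head.parameters(), lr=0.1)
multi_optim.append(OptimSched(opt1, StepLR(opt1, step_size=10), sched_on_epoch=True))
# second pair: Adam with no scheduler at all
opt2 = Adam(tail.parameters(), lr=1e-3)
multi_optim.append(OptimSched(opt2, None, sched_on_epoch=None))

x, y = torch.randn(4, 8), torch.randn(4, 1)
for epoch in range(2):
    multi_optim.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    multi_optim.step()    # steps both optimizers (+ any per-batch schedulers)
    multi_optim.epoch()   # steps per-epoch schedulers only

print(multi_optim.get_lr(0, 0))   # lr of first optimizer's first param group
state = multi_optim.state_dict()  # checkpoints all optimizers and schedulers together
multi_optim.load_state_dict(state)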