Source code for archai.trainers.gradual_warmup_scheduler
# Copyright (c) 2020 abhuse.
# Licensed under the MIT license.
# https://github.com/ildoonet/pytorch-gradual-warmup-lr

from typing import Any, Dict, List, Optional

from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
from torch.optim.optimizer import Optimizer


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.

    It was proposed in `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour`.

    """

    def __init__(
        self, optimizer: Optimizer, multiplier: float, total_epoch: int, after_scheduler: Optional[_LRScheduler] = None
    ) -> None:
"""Initialize the scheduler.
Args:
optimizer: Wrapped optimizer.
multiplier: Target learning rate = base lr * multiplier if multiplier > 1.0.
If multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
total_epoch: Target learning rate is reached gradually at total_epoch.
after_scheduler: After target_epoch, use this scheduler.
"""
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("Multiplier should be >= 1.")

        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False

        super(GradualWarmupScheduler, self).__init__(optimizer)
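
    # During warmup (last_epoch <= total_epoch) the learning rate grows linearly:
    #   lr(e) = base_lr * ((multiplier - 1) * e / total_epoch + 1)   if multiplier > 1
    #   lr(e) = base_lr * e / total_epoch                            if multiplier == 1
    # Once the warmup is over, `after_scheduler` (if given) takes over, with its base
    # learning rates scaled by `multiplier`.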
    def get_lr(self) -> List[float]:
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
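
    # ReduceLROnPlateau has no `get_lr`, so it is handled separately: during warmup the
    # learning rates are written directly into the optimizer's parameter groups, and
    # afterwards the validation metrics are forwarded to the plateau scheduler.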
    def _step_reduce_lr(self, epoch: Optional[int], metrics: Dict[str, Any]) -> None:
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1

        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)
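
    # `step` dispatches on the type of `after_scheduler`: a plateau scheduler needs the
    # validation metrics, while any other scheduler is stepped with the epoch count
    # shifted by the warmup length.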
    def step(self, epoch: Optional[int] = None, metrics: Optional[Dict[str, Any]] = None) -> None:
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self._step_reduce_lr(epoch, metrics)
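
A minimal usage sketch, not part of the module above: it assumes a small torch.nn model and a standard CosineAnnealingLR as the post-warmup scheduler, and the names `model`, `optimizer`, `cosine`, and `scheduler` are illustrative only. The learning rate warms up linearly from 0.1 to 0.8 over the first 5 epochs, then follows the cosine schedule.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from archai.trainers.gradual_warmup_scheduler import GradualWarmupScheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hand off to cosine annealing once the 5-epoch warmup is finished.
cosine = CosineAnnealingLR(optimizer, T_max=95)
scheduler = GradualWarmupScheduler(optimizer, multiplier=8.0, total_epoch=5, after_scheduler=cosine)

for epoch in range(100):
    ...  # one epoch of training (forward/backward passes omitted)
    optimizer.step()
    scheduler.step()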