Source code for archai.supergraph.utils.augmented_trainer

import itertools
import math
import os
from collections import OrderedDict

import torch
from torch import nn
from torch.nn.parallel.data_parallel import DataParallel
from tqdm import tqdm

from archai.common import ml_utils, utils
from archai.common.common import get_tb_writer
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.datasets.data import get_dataloaders
from archai.supergraph.models import get_model, num_class
from archai.supergraph.utils.metrics import Accumulator

logger = get_global_logger()


# TODO: remove scheduler parameter?
def run_epoch(
    conf, logger, model: nn.Module, loader, loss_fn, optimizer, split_type: str, epoch=0, verbose=1, scheduler=None
):
    """Runs one epoch for the given dataloader and model. If an optimizer is supplied
    then backprop and model updates are done as well; otherwise this is a pure
    evaluation pass, so it can be used for train, val and test splits.
    """
    writer = get_tb_writer()

    # region conf vars
    conf_loader = conf["autoaug"]["loader"]
    epochs = conf_loader["epochs"]
    conf_opt = conf["autoaug"]["optimizer"]
    grad_clip = conf_opt["clip"]
    # endregion

    tqdm_disable = bool(os.environ.get("TASK_NAME", ""))  # TODO: remove?
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description("[%s %04d/%04d]" % (split_type, epoch, epochs))

    metrics = Accumulator()
    cnt = 0
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        preds = model(data)
        loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            if getattr(optimizer, "synchronize", None):
                optimizer.synchronize()  # for horovod
            # grad clipping defaults to 5 (same as Darts)
            if grad_clip > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        top1, top5 = ml_utils.accuracy(preds, label, (1, 5))
        metrics.add_dict(
            {
                "loss": loss.item() * len(data),
                "top1": top1.item() * len(data),
                "top5": top5.item() * len(data),
            }
        )
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                if "lr" in optimizer.param_groups[0]:
                    postfix["lr"] = optimizer.param_groups[0]["lr"]
            loader.set_postfix(postfix)

        # below changes LR for every batch in epoch
        # TODO: should we do LR step at epoch start only?
        # if scheduler is not None:
        #     scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disable:
        if optimizer:
            logger.info(
                "[%s %03d/%03d] %s lr=%.6f",
                split_type,
                epoch,
                epochs,
                metrics / cnt,
                optimizer.param_groups[0]["lr"],
            )
        else:
            logger.info("[%s %03d/%03d] %s", split_type, epoch, epochs, metrics / cnt)

    metrics /= cnt
    if optimizer:
        if "lr" in optimizer.param_groups[0]:
            metrics.metrics["lr"] = optimizer.param_groups[0]["lr"]
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar("{}/{}".format(key, split_type), value, epoch)
    return metrics
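
# Illustrative usage sketch (not part of the original module): calling run_epoch without an
# optimizer gives a pure evaluation pass. The helper name and the torch.no_grad() wrapper are
# additions for the example; run_epoch itself only needs conf['autoaug']['loader']['epochs']
# and conf['autoaug']['optimizer']['clip'] plus a CUDA model and dataloader.
def _example_evaluate(conf, model: nn.Module, test_dl):
    model.eval()
    with torch.no_grad():  # no optimizer is passed, so run_epoch does no backprop
        metrics = run_epoch(
            conf, logger, model, test_dl, nn.CrossEntropyLoss(), optimizer=None, split_type="test", epoch=0
        )
    return metrics  # Accumulator; index as metrics["top1"], metrics["top5"], metrics["loss"]
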
# NOTE that 'eval' is overloaded in this code base. 'eval' here means
# taking a trained model and running it on val or test sets. In NAS, 'eval'
# often means taking a found model and training it fully (often termed 'final training').
# metric could be 'last', 'test', 'val', 'train'.
def train_and_eval(conf, val_ratio, val_fold, save_path, only_eval, reporter=None, metric="test"):
    writer = get_tb_writer()

    # region conf vars
    conf_dataset = conf["dataset"]
    dataroot = utils.full_path(conf_dataset["dataroot"])
    horovod = conf["common"]["horovod"]
    checkpoint_freq = conf["common"]["checkpoint"]["freq"]
    conf_loader = conf["autoaug"]["loader"]
    conf_model = conf["autoaug"]["model"]
    ds_name = conf_dataset["name"]
    aug = conf_loader["aug"]
    cutout = conf_loader["cutout"]
    batch_size = conf_loader["batch"]
    max_batches = conf_dataset["max_batches"]
    epochs = conf_loader["epochs"]
    conf_opt = conf["autoaug"]["optimizer"]
    conf_lr_sched = conf["autoaug"]["lr_schedule"]
    n_workers = conf_loader["n_workers"]
    # endregion

    # initialize horovod
    # TODO: move to common init
    if horovod:
        import horovod.torch as hvd

        hvd.init()
        device = torch.device("cuda", hvd.local_rank())
        torch.cuda.set_device(device)

    if not reporter:

        def reporter(**kwargs):
            return 0

    # get dataloaders with transformations and splits applied
    train_dl, valid_dl, test_dl = get_dataloaders(
        ds_name,
        batch_size,
        dataroot,
        aug,
        cutout,
        load_train=True,
        load_test=True,
        val_ratio=val_ratio,
        val_fold=val_fold,
        horovod=horovod,
        n_workers=n_workers,
        max_batches=max_batches,
    )

    # create a model & an optimizer
    model = get_model(conf_model, num_class(ds_name), data_parallel=(not horovod))

    # select loss function and optimizer
    lossfn = nn.CrossEntropyLoss()
    optimizer = ml_utils.create_optimizer(conf_opt, model.parameters())

    # distributed optimizer if horovod is used
    is_master = True
    if horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
        # issue: https://github.com/horovod/horovod/issues/1099
        optimizer._requires_update = set()
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        if hvd.rank() != 0:
            is_master = False
    logger.debug("is_master=%s" % is_master)

    # select LR schedule
    scheduler = ml_utils.create_lr_scheduler(conf_lr_sched, epochs, optimizer, len(train_dl))

    result = OrderedDict()
    epoch_start = 1
    # if a model is available from a previous checkpoint then load it
    if save_path and os.path.exists(save_path):
        logger.info("%s checkpoint found. loading..." % save_path)
        data = torch.load(save_path)
        # when checkpointing we add the 'model' key, so other cases are special cases
        if "model" in data or "state_dict" in data:
            key = "model" if "model" in data else "state_dict"
            logger.info("checkpoint epoch@%d" % data["epoch"])

            # TODO: do we need change here?
            if not isinstance(model, DataParallel):
                # for non-dataparallel models, remove default 'module.' prefix
                model.load_state_dict({k.replace("module.", ""): v for k, v in data[key].items()})
            else:
                # for dataparallel models, make sure 'module.' prefix exists
                model.load_state_dict({k if "module." in k else "module." + k: v for k, v in data[key].items()})

            # load optimizer
            optimizer.load_state_dict(data["optimizer"])

            # restore epoch count
            if data["epoch"] < epochs:
                epoch_start = data["epoch"]
            else:
                # epochs finished, switch to eval mode
                only_eval = True
        else:
            model.load_state_dict({k: v for k, v in data.items()})
        del data
    else:
        logger.info('model checkpoint does not exist at "%s". skip to pretrain weights...' % save_path)
        # we attempted to load the checkpoint but it does not exist, so switch to train mode
        only_eval = False

    # if eval only then run model on train, test and val sets
    if only_eval:
        logger.info("evaluation only+")
        model.eval()
        rs = dict()  # stores metrics for each set
        rs["train"] = run_epoch(conf, logger, model, train_dl, lossfn, None, split_type="train", epoch=0)
        if valid_dl:
            rs["valid"] = run_epoch(conf, logger, model, valid_dl, lossfn, None, split_type="valid", epoch=0)
        rs["test"] = run_epoch(conf, logger, model, test_dl, lossfn, None, split_type="test", epoch=0)
        for key, setname in itertools.product(["loss", "top1", "top5"], ["train", "valid", "test"]):
            result["%s_%s" % (key, setname)] = rs[setname][key]
        result["epoch"] = 0
        return result

    # train loop
    best_top1, best_valid_loss = 0, 10.0e10
    max_epoch = epochs
    for epoch in range(epoch_start, max_epoch + 1):
        # if horovod:
        #     trainsampler.set_epoch(epoch)

        # run train epoch and update the model
        model.train()
        rs = dict()
        rs["train"] = run_epoch(
            conf,
            logger,
            model,
            train_dl,
            lossfn,
            optimizer,
            split_type="train",
            epoch=epoch,
            verbose=is_master,
            scheduler=scheduler,
        )
        if scheduler[0]:
            scheduler[0].step()
        model.eval()

        # check for nan loss
        if math.isnan(rs["train"]["loss"]):
            raise Exception("train loss is NaN.")

        # collect metrics on val and test set, checkpoint
        if epoch % checkpoint_freq == 0 or epoch == max_epoch:
            if valid_dl:
                rs["valid"] = run_epoch(
                    conf, logger, model, valid_dl, lossfn, None, split_type="valid", epoch=epoch, verbose=is_master
                )
            rs["test"] = run_epoch(
                conf, logger, model, test_dl, lossfn, None, split_type="test", epoch=epoch, verbose=is_master
            )

            # TODO: is this a good enough condition?
            if rs[metric]["loss"] < best_valid_loss or rs[metric]["top1"] > best_top1:
                best_top1 = rs[metric]["top1"]
                best_valid_loss = rs[metric]["loss"]
                for key, setname in itertools.product(["loss", "top1", "top5"], ["train", "valid", "test"]):
                    result["%s_%s" % (key, setname)] = rs[setname][key]
                result["epoch"] = epoch

                writer.add_scalar("best_top1/valid", rs["valid"]["top1"], epoch)
                writer.add_scalar("best_top1/test", rs["test"]["top1"], epoch)

                reporter(
                    loss_valid=rs["valid"]["loss"],
                    top1_valid=rs["valid"]["top1"],
                    loss_test=rs["test"]["loss"],
                    top1_test=rs["test"]["top1"],
                )

                # save checkpoint
                if is_master and save_path:
                    logger.info("save model@%d to %s" % (epoch, save_path))
                    torch.save(
                        {
                            "epoch": epoch,
                            "log": {
                                "train": rs["train"].get_dict(),
                                "valid": rs["valid"].get_dict(),
                                "test": rs["test"].get_dict(),
                            },
                            "optimizer": optimizer.state_dict(),
                            "model": model.state_dict(),
                        },
                        save_path,
                    )

    del model

    result["top1_test"] = best_top1
    return result
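
# Illustrative usage sketch (not part of the original module). train_and_eval expects a nested,
# dict-like config; the keys shown below are the ones this module reads, while the values and the
# contents of the 'model', 'optimizer' and 'lr_schedule' sub-configs are hypothetical placeholders
# that must match what get_model(), ml_utils.create_optimizer() and ml_utils.create_lr_scheduler()
# expect.
#
#   conf = {
#       "dataset": {"name": "cifar10", "dataroot": "~/dataroot", "max_batches": -1},
#       "common": {"horovod": False, "checkpoint": {"freq": 10}},
#       "autoaug": {
#           "loader": {"aug": "fa_reduced_cifar10", "cutout": 16, "batch": 128,
#                      "epochs": 200, "n_workers": 4},
#           "model": {...},                   # passed to get_model()
#           "optimizer": {"clip": 5.0, ...},  # passed to ml_utils.create_optimizer()
#           "lr_schedule": {...},             # passed to ml_utils.create_lr_scheduler()
#       },
#   }
#   result = train_and_eval(conf, val_ratio=0.1, val_fold=0, save_path="model.pth",
#                           only_eval=False, metric="test")
#   print(result["top1_test"])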