Source code for archai.supergraph.utils.augmented_searcher

import copy
import json
import os
import time
from collections import OrderedDict
from typing import Optional

import gorilla
import numpy as np
import ray
import torch
from hyperopt import hp
from ray.tune import register_trainable, run_experiments
from ray.tune.suggest import HyperOptSearch
from ray.tune.trial import Trial
from tqdm import tqdm

from archai.common.common import expdir_abspath
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.common.stopwatch import StopWatch
from archai.supergraph.datasets.augmentation import (
    augment_list,
    policy_decoder,
    remove_deplicates,
)
from archai.supergraph.datasets.data import get_dataloaders
from archai.supergraph.models import get_model, num_class
from archai.supergraph.utils.augmented_trainer import train_and_eval
from archai.supergraph.utils.metrics import Accumulator

logger = get_global_logger()


# this is an overridden version of ray.tune.trial_runner.TrialRunner.step, applied via monkey patching
def _step_w_log(self):
    original = gorilla.get_original_attribute(ray.tune.trial_runner.TrialRunner, "step")

    # collect counts by status for all trials
    cnts = OrderedDict()
    for status in [Trial.RUNNING, Trial.TERMINATED, Trial.PENDING, Trial.PAUSED, Trial.ERROR]:
        cnt = len(list(filter(lambda x: x.status == status, self._trials)))
        cnts[status] = cnt

    # get the best top1 accuracy from all finished trials so far
    best_top1_acc = 0.0
    for trial in filter(lambda x: x.status == Trial.TERMINATED, self._trials):
        if not trial.last_result:  # TODO: why would this happen?
            continue
        best_top1_acc = max(best_top1_acc, trial.last_result["top1_valid"])

    # display the best accuracy from all finished trials
    logger.info("iter", self._iteration, "top1_acc=%.3f" % best_top1_acc, cnts, end="\r")

    # call original step method
    return original(self)


# override the ray.tune.trial_runner.TrialRunner.step method so the best accuracy is printed at each step
patch = gorilla.Patch(ray.tune.trial_runner.TrialRunner, "step", _step_w_log, settings=gorilla.Settings(allow_hit=True))
gorilla.apply(patch)


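# Ray remote task that trains/evaluates a single CV fold. num_gpus reserves all
# visible GPUs for the task, and max_calls=1 makes Ray restart the worker process
# after each call so GPU memory is released between folds.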
@ray.remote(num_gpus=torch.cuda.device_count(), max_calls=1)
def _train_model(conf, dataroot, augment, val_ratio, val_fold, save_path=None, only_eval=False):
    Config.set_inst(conf)
    conf["autoaug"]["loader"]["aug"] = augment
    model_type = conf["autoaug"]["model"]["type"]

    result = train_and_eval(conf, val_ratio=val_ratio, val_fold=val_fold, save_path=save_path, only_eval=only_eval)
    return model_type, val_fold, result


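# Builds the checkpoint filename from dataset, model type, and a tag such as
# "ratio0.4_fold0", resolved relative to the experiment directory.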
def _get_model_filepath(dataset, model, tag) -> Optional[str]:
    filename = "%s_%s_%s.model" % (dataset, model, tag)
    return expdir_abspath(filename)


def _train_no_aug(conf):
    sw = StopWatch.get()

    # region conf vars
    conf_dataset = conf["dataset"]
    dataroot = conf["dataroot"]
    conf_loader = conf["autoaug"]["loader"]
    conf_model = conf["autoaug"]["model"]
    model_type = conf_model["type"]
    ds_name = conf_dataset["name"]
    aug = conf_loader["aug"]
    val_ratio = conf_loader["val_ratio"]
    epochs = conf_loader["epochs"]
    cv_num = conf_loader["cv_num"]
    # endregion

    logger.info("----- Train without Augmentations cv=%d ratio(test)=%.1f -----" % (cv_num, val_ratio))
    sw.start(tag="train_no_aug")

    # for each fold, the trained model is saved to its own path
    save_paths = [_get_model_filepath(ds_name, model_type, "ratio%.1f_fold%d" % (val_ratio, i)) for i in range(cv_num)]

    # Train a model for each fold, save it at the specified path, and collect
    # the Ray object refs in the reqs list. These models are trained with the
    # augmentation specified in the config.
    # TODO: the configuration is changed ('aug' key),
    #   but do we really need deepcopy everywhere?
    reqs = [
        # TODO: eliminate need for deep copy as only aug key is changed
        _train_model.remote(
            copy.deepcopy(conf), dataroot, aug, val_ratio, i, save_path=save_paths[i], only_eval=True
        )
        for i in range(cv_num)
    ]

    # Probe the saved checkpoint of each fold to find the epoch it has reached.
    # The progress bar advances once every fold has crossed the current epoch.
    tqdm_epoch = tqdm(range(epochs))
    is_done = False
    for epoch in tqdm_epoch:
        while True:
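            # poll the checkpoints on disk every 10 seconds until every fold
            # has reached (at least) the current epoch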
            epochs_per_cv = OrderedDict()
            for cv_idx in range(cv_num):
                try:
                    if os.path.exists(save_paths[cv_idx]):
                        latest_ckpt = torch.load(save_paths[cv_idx])
                        if "epoch" not in latest_ckpt:
                            epochs_per_cv["cv%d" % (cv_idx + 1)] = epochs
                            continue
                    else:
                        continue
                    epochs_per_cv["cv%d" % (cv_idx + 1)] = latest_ckpt["epoch"]
                except Exception:
                    continue
            tqdm_epoch.set_postfix(epochs_per_cv)
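            # all folds done with all epochs -> stop polling entirely;
            # all folds past the current epoch -> advance the progress bar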
            if len(epochs_per_cv) == cv_num and min(epochs_per_cv.values()) >= epochs:
                is_done = True
            if len(epochs_per_cv) == cv_num and min(epochs_per_cv.values()) >= epoch:
                break
            time.sleep(10)
        if is_done:
            break

    logger.info("getting results...")
    pretrain_results = ray.get(reqs)
    for r_model, r_cv, r_dict in pretrain_results:
        logger.info(
            "model=%s cv=%d top1_train=%.4f top1_valid=%.4f"
            % (r_model, r_cv + 1, r_dict["top1_train"], r_dict["top1_valid"])
        )
    logger.info("processed in %.4f secs" % sw.pause("train_no_aug"))





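# Evaluates one candidate augmentation policy: loads the pretrained model for the
# given fold and scores the decoded policy on the validation split using a
# test-time-augmentation style min-loss / max-correct reduction across loaders.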
def _eval_tta(conf, augment, reporter):
    Config.set_inst(conf)

    # region conf vars
    conf_dataset = conf["dataset"]
    conf_loader = conf["autoaug"]["loader"]
    conf_model = conf["autoaug"]["model"]
    ds_name = conf_dataset["name"]
    cutout = conf_loader["cutout"]
    n_workers = conf_loader["n_workers"]
    # endregion

    val_ratio, val_fold, save_path = augment["val_ratio"], augment["val_fold"], augment["save_path"]

    # decode the candidate's parameters into a list of augmentation policies
    aug = policy_decoder(augment, augment["num_policy"], augment["num_op"])

    # eval
    model = get_model(conf_model, num_class(ds_name))
    ckpt = torch.load(save_path)
    if "model" in ckpt:
        model.load_state_dict(ckpt["model"])
    else:
        model.load_state_dict(ckpt)
    model.eval()

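    # create num_policy validation loaders over the same split; each applies the
    # decoded augmentations with its own independent random draws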
    loaders = []
    for _ in range(augment["num_policy"]):
        tl, validloader, tl2 = get_dataloaders(
            augment["dataroot"],
            ds_name,
            aug,
            cutout,
            load_train=True,
            load_test=True,
            val_ratio=val_ratio,
            val_fold=val_fold,
            n_workers=n_workers,
        )
        loaders.append(iter(validloader))
        del tl, tl2  # TODO: why exclude validloader?

    start_t = time.time()
    metrics = Accumulator()
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
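    # iterate all loaders in lockstep; for every sample keep the minimum loss and
    # the maximum correctness across the augmented views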
    try:
        while True:
            losses = []
            corrects = []
            for loader in loaders:
                data, label = next(loader)
                data, label = data.cuda(), label.cuda()

                pred = model(data)

                loss = loss_fn(pred, label)
                losses.append(loss.detach().cpu().numpy())

                _, pred = pred.topk(1, 1, True, True)
                pred = pred.t()
                correct = pred.eq(label.view(1, -1).expand_as(pred)).detach().cpu().numpy()
                corrects.append(correct)
                del loss, correct, pred, data, label

            losses = np.concatenate(losses)
            losses_min = np.min(losses, axis=0).squeeze()

            corrects = np.concatenate(corrects)
            corrects_max = np.max(corrects, axis=0).squeeze()
            metrics.add_dict(
                {"minus_loss": -1 * np.sum(losses_min), "correct": np.sum(corrects_max), "cnt": len(corrects_max)}
            )
            del corrects, corrects_max
    except StopIteration:
        pass

    del model
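    # normalize accumulated sums by the sample count and report GPU-seconds
    # (wall-clock time scaled by the number of GPUs)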
    metrics = metrics / "cnt"
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    reporter(minus_loss=metrics["minus_loss"], top1_valid=metrics["correct"], elapsed_time=gpu_secs, done=True)
    return metrics["correct"]