# Source code for archai.supergraph.nas.evaluater

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, Optional

from overrides import EnforceOverrides
from torch import nn

from archai.common import ml_utils, utils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.datasets import data
from archai.supergraph.nas import nas_utils
from archai.supergraph.nas.model import Model
from archai.supergraph.nas.model_desc import ModelDesc
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
from archai.supergraph.utils.checkpoint import CheckPoint
from archai.supergraph.utils.metrics import Metrics
from archai.supergraph.utils.trainer import Trainer

logger = get_global_logger()


class EvalResult:
    """Container for the outcome of an evaluation run.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, train_metrics:Metrics) -> None:
        """Store the metrics produced while training the evaluated model.

        Args:
            train_metrics: Metrics object returned by the training step.
        """
        # metrics collected during training; exposed as a plain attribute
        self.train_metrics = train_metrics
class Evaluater(EnforceOverrides):
    """Default evaluater: builds a model from a saved description, trains it and saves the results.

    The work is split into separate steps (`create_model`, `model_from_desc`,
    `train_model`, `get_data`) that `evaluate` chains together.
    """

    def evaluate(self, conf_eval:Config, model_desc_builder:ModelDescBuilder)->EvalResult:
        """Create the model, train it and persist metrics (and optionally the model).

        Args:
            conf_eval: Evaluation config. Reads the keys 'checkpoint', 'resume',
                'model_filename' and 'metric_filename' here; it is also passed on
                to `create_model` and `train_model`.
            model_desc_builder: Builder used by `create_model` to produce the
                model description.

        Returns:
            `EvalResult` wrapping the training metrics.
        """
        logger.pushd('eval_arch')
        try:
            # region conf vars
            conf_checkpoint = conf_eval['checkpoint']
            resume = conf_eval['resume']
            model_filename = conf_eval['model_filename']
            metric_filename = conf_eval['metric_filename']
            # endregion

            model = self.create_model(conf_eval, model_desc_builder)

            checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
            train_metrics = self.train_model(conf_eval, model, checkpoint)
            train_metrics.save(metric_filename)

            # save model only when a filename is configured
            if model_filename:
                model_filename = utils.full_path(model_filename)
                ml_utils.save_model(model, model_filename)
                logger.info({'model_save_path': model_filename})
        finally:
            # always leave the 'eval_arch' logging scope, even if a step above
            # raised, so the logger's scope stack stays balanced
            logger.popd()

        return EvalResult(train_metrics)

    def train_model(self, conf_train:Config, model:nn.Module,
                    checkpoint:Optional[CheckPoint])->Metrics:
        """Train `model` and return the metrics reported by the trainer.

        Args:
            conf_train: Config containing a 'loader' section (passed to
                `get_data`) and a 'trainer' section (passed to `Trainer`).
            model: Model to train.
            checkpoint: Optional checkpoint object handed to `Trainer`.

        Returns:
            The value returned by `Trainer.fit`.
        """
        conf_loader = conf_train['loader']
        conf_trainer = conf_train['trainer']

        # get data
        data_loaders = self.get_data(conf_loader)

        trainer = Trainer(conf_trainer, model, checkpoint)
        return trainer.fit(data_loaders)

    def get_data(self, conf_loader:Config)->data.DataLoaders:
        """Return data loaders for `conf_loader`, loading them at most once per config object.

        Args:
            conf_loader: Loader config; the cache is keyed by this object's identity.

        Returns:
            The `data.DataLoaders` produced by `data.get_data(conf_loader)`.
        """
        # this dict caches the dataset objects per dataset config so we don't have to reload
        # the reason we do dynamic attribute is so that any dependent methods
        # can do ray.remote
        if not hasattr(self, '_data_cache'):
            self._data_cache:Dict[int, data.DataLoaders] = {}
        # id() values can be reused once an object is garbage collected, so we
        # keep a reference to every config we cache under; this guarantees a
        # cached id always refers to the config the loaders were built for
        if not hasattr(self, '_data_cache_confs'):
            self._data_cache_confs:Dict[int, Config] = {}

        cache_key = id(conf_loader)
        if cache_key not in self._data_cache:
            self._data_cache[cache_key] = data.get_data(conf_loader)
            self._data_cache_confs[cache_key] = conf_loader

        return self._data_cache[cache_key]

    def _default_module_name(self, dataset_name:str, function_name:str)->str:
        """Select PyTorch pre-defined network to support manual mode

        Args:
            dataset_name: Dataset name; matched by prefix ('cifar', 'imagenet', 'sport8').
            function_name: Model factory name; for cifar datasets matched by
                prefix ('res', 'dense').

        Returns:
            Dotted module name that should contain `function_name`.

        Raises:
            NotImplementedError: If no module is known for the combination.
        """
        module_name = ''
        # TODO: below detection code is too weak, need to improve, possibly encode image size in yaml and use that instead
        if dataset_name.startswith('cifar'):
            if function_name.startswith('res'): # support resnext as well
                module_name = 'archai.supergraph.models.resnet'
            elif function_name.startswith('dense'):
                module_name = 'archai.supergraph.models.densenet'
        elif dataset_name.startswith(('imagenet', 'sport8')):
            module_name = 'torchvision.models'

        if not module_name:
            raise NotImplementedError(f'Cannot get default module for {function_name} and dataset {dataset_name} because it is not supported yet')
        return module_name

    def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder,
                     final_desc_filename=None, full_desc_filename=None)->nn.Module:
        """Build the model from a stored model description used as a template.

        Args:
            conf_eval: Evaluation config. Reads 'model_desc' and, when
                `final_desc_filename` is not passed, 'final_desc_filename' and
                'full_desc_filename'.
            model_desc_builder: Builder that turns the config plus template into
                the full model description. Must not be None.
            final_desc_filename: Optional path of the template description to
                load. When falsy, both filenames are taken from `conf_eval`.
            full_desc_filename: Optional path where the built description is
                saved. Only honoured when `final_desc_filename` is also passed.

        Returns:
            The model created by `model_from_desc`.
        """
        assert model_desc_builder is not None, 'Default evaluater requires model_desc_builder'

        # region conf vars
        # if explicitly passed in then don't get from conf
        # NOTE(review): when final_desc_filename is passed without
        # full_desc_filename, None is forwarded to model_desc.save below;
        # confirm that save() tolerates an empty filename
        if not final_desc_filename:
            final_desc_filename = conf_eval['final_desc_filename']
            full_desc_filename = conf_eval['full_desc_filename']
        conf_model_desc = conf_eval['model_desc']
        # endregion

        # load model desc file to get template model
        template_model_desc = ModelDesc.load(final_desc_filename)
        model_desc = model_desc_builder.build(conf_model_desc,
                                              template=template_model_desc)

        # save desc for reference
        model_desc.save(full_desc_filename)

        model = self.model_from_desc(model_desc)

        logger.info({'model_factory':False,
                     'cells_len':len(model.desc.cell_descs()),
                     'init_node_ch': conf_model_desc['model_stems']['init_node_ch'],
                     'n_cells': conf_model_desc['n_cells'],
                     'n_reductions': conf_model_desc['n_reductions'],
                     'n_nodes': conf_model_desc['cell']['n_nodes']})

        return model

    def model_from_desc(self, model_desc)->Model:
        """Instantiate a `Model` from `model_desc` with droppath and affine enabled."""
        return Model(model_desc, droppath=True, affine=True)