# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
from typing import Dict, Optional, Tuple

from overrides import EnforceOverrides

from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.datasets import data
from archai.supergraph.nas.arch_trainer import TArchTrainer
from archai.supergraph.nas.finalizers import Finalizers
from archai.supergraph.nas.model import Model
from archai.supergraph.nas.model_desc import ModelDesc
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
from archai.supergraph.utils.metrics import Metrics
from archai.supergraph.utils.trainer import Trainer

logger = get_global_logger()


class ModelMetrics:
    def __init__(self, model: Model, metrics: Metrics) -> None:
        self.model = model
        self.metrics = metrics


class SearchResult:
    def __init__(self, model_desc: Optional[ModelDesc],
                 search_metrics: Optional[Metrics],
                 train_metrics: Optional[Metrics]) -> None:
        self.model_desc = model_desc
        self.search_metrics = search_metrics
        self.train_metrics = train_metrics


class Searcher(EnforceOverrides):
    def search(self, conf_search: Config, model_desc_builder: Optional[ModelDescBuilder],
               trainer_class: TArchTrainer, finalizers: Finalizers) -> SearchResult:
        # region config vars
        conf_model_desc = conf_search['model_desc']
        conf_post_train = conf_search['post_train']
        cells = conf_model_desc['n_cells']
        reductions = conf_model_desc['n_reductions']
        nodes = conf_model_desc['cell']['n_nodes']
        # endregion

        assert model_desc_builder is not None, \
            'Default search implementation requires model_desc_builder'

        # build the model description that we will search on
        model_desc = self.build_model_desc(model_desc_builder, conf_model_desc,
                                           reductions, cells, nodes)

        # perform search on the model description
        model_desc, search_metrics = self.search_model_desc(conf_search, model_desc,
                                                            trainer_class, finalizers)

        # train the searched model for a few epochs to get some perf metrics
        model_metrics = self.train_model_desc(model_desc, conf_post_train)

        search_result = SearchResult(model_desc, search_metrics,
                                     model_metrics.metrics if model_metrics is not None else None)
        self.clean_log_result(conf_search, search_result)

        return search_result
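
    # For reference, a minimal sketch of the search config shape consumed above;
    # the key names come from the accesses in this class, while the values and
    # YAML layout are assumptions:
    #
    #   final_desc_filename: 'final_model_desc.yaml'  # hypothetical value
    #   model_desc:
    #     n_cells: ...
    #     n_reductions: ...
    #     cell:
    #       n_nodes: ...
    #   trainer: ...   # consumed by search_model_desc
    #   loader: ...
    #   post_train:
    #     trainer:
    #       title: ...
    #       epochs: ...
    #       drop_path_prob: ...
    #     loader: ...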

    def clean_log_result(self, conf_search: Config, search_result: SearchResult) -> None:
        final_desc_filename = conf_search['final_desc_filename']

        # remove weights info from model_desc so it is more readable
        search_result.model_desc.clear_trainables()

        # if a file name was specified then save the model desc
        if final_desc_filename:
            search_result.model_desc.save(final_desc_filename)

        if search_result.search_metrics is not None:
            logger.info({'search_top1_val': search_result.search_metrics.best_val_top1()})
        if search_result.train_metrics is not None:
            logger.info({'train_top1_val': search_result.train_metrics.best_val_top1()})

    def build_model_desc(self, model_desc_builder: ModelDescBuilder, conf_model_desc: Config,
                         reductions: int, cells: int, nodes: int) -> ModelDesc:
        # reset macro params in a copy of the config
        conf_model_desc = copy.deepcopy(conf_model_desc)
        conf_model_desc['n_reductions'] = reductions
        conf_model_desc['n_cells'] = cells

        # create the model desc for search using the model config;
        # for pre-training, the model is built without a call to model_desc_builder
        model_desc = model_desc_builder.build(conf_model_desc, template=None)

        return model_desc

    def get_data(self, conf_loader: Config) -> data.DataLoaders:
        # this dict caches dataset objects per dataset config so we don't have to reload;
        # the attribute is created dynamically so that dependent methods
        # can still be used with ray.remote
        if not hasattr(self, '_data_cache'):
            self._data_cache: Dict[int, data.DataLoaders] = {}

        # first try the cache
        if id(conf_loader) in self._data_cache:
            data_loaders = self._data_cache[id(conf_loader)]
        else:
            data_loaders = data.get_data(conf_loader)
            self._data_cache[id(conf_loader)] = data_loaders

        return data_loaders
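
    # A note on the cache above, with an illustrative sketch (the `searcher`
    # name is hypothetical): the key is id(conf_loader), i.e. object identity
    # rather than config equality, so loaders are reused only when the very
    # same Config instance is passed again:
    #
    #   loaders_a = searcher.get_data(conf_loader)                  # loads data
    #   loaders_b = searcher.get_data(conf_loader)                  # cache hit
    #   loaders_c = searcher.get_data(copy.deepcopy(conf_loader))   # new id, reloads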

    def finalize_model(self, model: Model, finalizers: Finalizers) -> ModelDesc:
        return finalizers.finalize_model(model, restore_device=False)

    def search_model_desc(self, conf_search: Config, model_desc: ModelDesc,
                          trainer_class: TArchTrainer, finalizers: Finalizers) \
            -> Tuple[ModelDesc, Optional[Metrics]]:
        # if a trainer is not specified, as for algos like random search,
        # we return the same desc
        if trainer_class is None:
            return model_desc, None

        logger.pushd('arch_search')

        conf_trainer = conf_search['trainer']
        conf_loader = conf_search['loader']

        model = Model(model_desc, droppath=False, affine=False)

        # get data
        data_loaders = self.get_data(conf_loader)

        # search arch
        arch_trainer = trainer_class(conf_trainer, model, checkpoint=None)
        search_metrics = arch_trainer.fit(data_loaders)

        # finalize
        found_desc = self.finalize_model(model, finalizers)

        logger.popd()

        return found_desc, search_metrics

    def train_model_desc(self, model_desc: ModelDesc, conf_train: Config) \
            -> Optional[ModelMetrics]:
        """Train the given model description."""
        # region conf vars
        conf_trainer = conf_train['trainer']
        conf_loader = conf_train['loader']
        trainer_title = conf_trainer['title']
        epochs = conf_trainer['epochs']
        drop_path_prob = conf_trainer['drop_path_prob']
        # endregion

        # if there are no epochs then there is nothing to train, so save time
        if epochs <= 0:
            return None

        logger.pushd(trainer_title)

        model = Model(model_desc, droppath=drop_path_prob > 0.0, affine=True)

        # get data
        data_loaders = self.get_data(conf_loader)

        trainer = Trainer(conf_trainer, model, checkpoint=None)
        train_metrics = trainer.fit(data_loaders)

        logger.popd()

        return ModelMetrics(model, train_metrics)
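

# Example usage, as a minimal sketch (not part of this module): it assumes a
# search config containing the sections read above; `MyArchTrainer` (some
# TArchTrainer subclass), the config path, and the no-argument constructors
# are hypothetical:
#
#   conf_search = Config(config_filepath='confs/search.yaml')
#   searcher = Searcher()
#   result = searcher.search(conf_search,
#                            model_desc_builder=ModelDescBuilder(),
#                            trainer_class=MyArchTrainer,
#                            finalizers=Finalizers())
#   if result.search_metrics is not None:
#       logger.info({'best_search_top1': result.search_metrics.best_val_top1()})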