# Source code for archai.supergraph.nas.search_combinations

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
import os
from typing import Iterator, Optional, Tuple

import yaml
from overrides import overrides

from archai.common import utils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.nas import nas_utils
from archai.supergraph.nas.arch_trainer import TArchTrainer
from archai.supergraph.nas.finalizers import Finalizers
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
from archai.supergraph.nas.searcher import ModelMetrics, Searcher, SearchResult
from archai.supergraph.utils.metrics import Metrics

logger = get_global_logger()


class SearchCombinations(Searcher):
    """Searcher that sweeps macro-architecture combinations.

    For every combination of (reductions, cells, nodes) between the minimums in
    the ``model_desc`` config section and the maximums in the ``pareto`` section,
    it runs a micro search, briefly post-trains the found model, logs per-model
    stats/metrics to disk, and keeps the best result (by validation top-1).
    Progress is checkpointed so an interrupted sweep can resume.
    """

    @overrides
    def search(self, conf_search:Config, model_desc_builder:ModelDescBuilder,
               trainer_class:TArchTrainer, finalizers:Finalizers)->SearchResult:
        """Run search over all macro combinations and return the best result."""
        # region config vars
        conf_model_desc = conf_search['model_desc']
        conf_post_train = conf_search['post_train']
        conf_checkpoint = conf_search['checkpoint']
        resume = conf_search['resume']
        # endregion

        self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
        macro_combinations = list(self.get_combinations(conf_search))
        start_macro_i, best_search_result = self.restore_checkpoint(
            conf_search, macro_combinations)
        best_macro_comb = -1, -1, -1  # reductions, cells, nodes

        for macro_comb_i in range(start_macro_i, len(macro_combinations)):
            reductions, cells, nodes = macro_combinations[macro_comb_i]
            logger.pushd(f'r{reductions}.c{cells}.n{nodes}')

            # build model description that we will search on
            model_desc = self.build_model_desc(model_desc_builder, conf_model_desc,
                                               reductions, cells, nodes)

            # perform search on model description
            model_desc, search_metrics = self.search_model_desc(
                conf_search, model_desc, trainer_class, finalizers)

            # train searched model for few epochs to get some perf metrics
            model_metrics = self.train_model_desc(model_desc, conf_post_train)
            assert model_metrics is not None, "'post_train' section in yaml should have non-zero epochs if running combinations search"

            # save result
            self.save_trained(conf_search, reductions, cells, nodes, model_metrics)

            # update the best result so far; on a fresh run best_search_result is
            # None until the first combination completes, so guard before
            # dereferencing it (otherwise the first iteration raises AttributeError)
            if best_search_result is None or \
                    self.is_better_metrics(best_search_result.search_metrics,
                                           model_metrics.metrics):
                best_search_result = SearchResult(model_desc, search_metrics,
                                                  model_metrics.metrics)
                best_macro_comb = reductions, cells, nodes

            # checkpoint
            assert best_search_result is not None
            self.record_checkpoint(macro_comb_i, best_search_result)
            logger.popd()  # reductions, cells, nodes

        assert best_search_result is not None
        self.clean_log_result(conf_search, best_search_result)
        logger.info({'best_macro_comb': best_macro_comb})

        return best_search_result

    def is_better_metrics(self, metrics1:Optional[Metrics],
                          metrics2:Optional[Metrics])->bool:
        """Return True if metrics2 is at least as good as metrics1.

        Missing metrics on either side count as "better" so that the first
        available result always replaces an absent one.
        """
        if metrics1 is None or metrics2 is None:
            return True
        return metrics2.best_val_top1() >= metrics1.best_val_top1()

    def restore_checkpoint(self, conf_search:Config, macro_combinations)\
            ->Tuple[int, Optional[SearchResult]]:
        """Restore sweep state from checkpoint, if any.

        Returns the index of the next macro combination to run and the best
        result so far (None when starting fresh). On a fresh start the pareto
        summary file is truncated so stale rows from a previous run don't mix in.
        """
        conf_pareto = conf_search['pareto']
        pareto_summary_filename = conf_pareto['summary_filename']

        summary_filepath = utils.full_path(pareto_summary_filename)

        # if checkpoint is available then restart from last combination we were running
        checkpoint_avail = self._checkpoint is not None
        resumed, state = False, None
        start_macro_i, best_result = 0, None
        if checkpoint_avail:
            state = self._checkpoint.get('search', None)
            if state is not None:
                start_macro_i = state['start_macro_i']
                assert start_macro_i >= 0 and start_macro_i < len(macro_combinations)
                # NOTE(review): yaml.Loader can construct arbitrary objects; this
                # is acceptable only because the checkpoint is written by
                # record_checkpoint below, i.e. it is trusted local data.
                best_result = yaml.load(state['best_result'], Loader=yaml.Loader)

                start_macro_i += 1  # resume after the last checkpoint
                resumed = True

        if not resumed:
            # erase previous file left over from run
            utils.zero_file(summary_filepath)

        logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                     'checkpoint_val': state is not None,
                     'start_macro_i': start_macro_i,
                     'total_macro_combinations': len(macro_combinations)})
        return start_macro_i, best_result

    def record_checkpoint(self, macro_comb_i:int, best_result:SearchResult)->None:
        """Persist the last finished combination index and best result so far."""
        if self._checkpoint is not None:
            state = {'start_macro_i': macro_comb_i,
                     'best_result': yaml.dump(best_result)}
            self._checkpoint.new()
            self._checkpoint['search'] = state
            self._checkpoint.commit()

    def get_combinations(self, conf_search:Config)->Iterator[Tuple[int, int, int]]:
        """Yield every (reductions, cells, nodes) combination in the sweep range.

        Minimums come from the ``model_desc`` config section, maximums from the
        ``pareto`` section; both ends are inclusive.
        """
        conf_pareto = conf_search['pareto']
        conf_model_desc = conf_search['model_desc']

        min_cells = conf_model_desc['n_cells']
        min_reductions = conf_model_desc['n_reductions']
        min_nodes = conf_model_desc['cell']['n_nodes']
        max_cells = conf_pareto['max_cells']
        max_reductions = conf_pareto['max_reductions']
        max_nodes = conf_pareto['max_nodes']

        logger.info({'min_reductions': min_reductions,
                     'min_cells': min_cells,
                     'min_nodes': min_nodes,
                     'max_reductions': max_reductions,
                     'max_cells': max_cells,
                     'max_nodes': max_nodes})

        # TODO: what happens when reductions is 3 but cells is 2? have to step
        # through code and check
        for reductions in range(min_reductions, max_reductions+1):
            for cells in range(min_cells, max_cells+1):
                for nodes in range(min_nodes, max_nodes+1):
                    yield reductions, cells, nodes

    def save_trained(self, conf_search:Config, reductions:int, cells:int, nodes:int,
                     model_metrics:ModelMetrics)->None:
        """Save the model and metric info into a log file"""
        metrics_dir = conf_search['metrics_dir']

        # construct path where we will save; metrics_dir is a format string that
        # may reference the local vars (reductions, cells, nodes, ...)
        subdir = utils.full_path(metrics_dir.format(**vars()), create=True)

        model_stats = nas_utils.get_model_stats(model_metrics.model)

        # save model_stats in its own file
        model_stats_filepath = os.path.join(subdir, 'model_stats.yaml')
        if model_stats_filepath:
            with open(model_stats_filepath, 'w') as f:
                yaml.dump(model_stats, f)

        # save just metrics separately for convenience
        # NOTE(review): dumps model_stats.metrics, not model_metrics.metrics —
        # presumably intentional; confirm against get_model_stats.
        metrics_filepath = os.path.join(subdir, 'metrics.yaml')
        if metrics_filepath:
            with open(metrics_filepath, 'w') as f:
                yaml.dump(model_stats.metrics, f)

        logger.info({'model_stats_filepath': model_stats_filepath,
                     'metrics_filepath': metrics_filepath})

        # append key info in root pareto data
        if self._summary_filepath:
            train_top1 = val_top1 = train_epoch = val_epoch = math.nan

            # extract metrics
            if model_metrics.metrics:
                best_metrics = model_metrics.metrics.run_metrics.best_epoch()
                train_top1 = best_metrics[0].top1.avg
                train_epoch = best_metrics[0].index
                # check length BEFORE indexing: a single-element result would
                # otherwise raise IndexError; val_* keep their NaN defaults then
                if len(best_metrics) > 1 and best_metrics[1]:
                    val_top1 = best_metrics[1].top1.avg
                    val_epoch = best_metrics[1].index

            # extract model stats
            flops = model_stats.Flops
            parameters = model_stats.parameters
            inference_memory = model_stats.inference_memory
            inference_duration = model_stats.duration

            utils.append_csv_file(self._summary_filepath, [
                ('reductions', reductions),
                ('cells', cells),
                ('nodes', nodes),
                ('train_top1', train_top1),
                ('train_epoch', train_epoch),
                ('val_top1', val_top1),
                ('val_epoch', val_epoch),
                ('flops', flops),
                ('params', parameters),
                ('inference_memory', inference_memory),
                ('inference_duration', inference_duration)
            ])