# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from typing import List
# latest version of ray works on Windows as well
import ray
from overrides import overrides
from archai.common import common
from archai.common.common import CommonState
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.algos.petridish.petridish_utils import (
ConvexHullPoint,
JobStage,
plot_frontier,
plot_pool,
plot_seed_model_stats,
sample_from_hull,
save_hull,
save_hull_frontier,
)
from archai.supergraph.nas import nas_utils
from archai.supergraph.nas.arch_trainer import TArchTrainer
from archai.supergraph.nas.finalizers import Finalizers
from archai.supergraph.nas.model import Model
from archai.supergraph.nas.model_desc import (
CellType,
EdgeDesc,
ModelDesc,
NodeDesc,
OpDesc,
)
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
from archai.supergraph.nas.search_combinations import SearchCombinations
from archai.supergraph.nas.searcher import SearchResult

logger = get_global_logger()


class SearcherPetridish(SearchCombinations):
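    '''Petridish-style forward search: grows a pool of models by repeatedly
    adding a node to parents sampled from the convex hull of previously
    trained models, distributing search and training as Ray jobs.'''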
    @overrides
def search(self, conf_search:Config, model_desc_builder:ModelDescBuilder,
trainer_class:TArchTrainer, finalizers:Finalizers)->SearchResult:
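        '''Runs the Petridish search loop: seeds the pool with jobs unless it
        was restored from a checkpoint, then repeatedly waits for a Ray job to
        complete, updates the convex hull with trained points and launches
        follow-up search/training jobs until the MAdd or pool-size budget is
        exhausted. Returns the best point on the hull frontier.'''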
logger.pushd('search')
# region config vars
self.conf_search = conf_search
conf_checkpoint = conf_search['checkpoint']
resume = conf_search['resume']
conf_post_train = conf_search['post_train']
final_desc_foldername = conf_search['final_desc_foldername']
conf_petridish = conf_search['petridish']
# petridish distributed search related parameters
self._convex_hull_eps = conf_petridish['convex_hull_eps']
self._max_madd = conf_petridish['max_madd']
self._max_hull_points = conf_petridish['max_hull_points']
self._checkpoints_foldername = conf_petridish['checkpoints_foldername']
# endregion
self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
# parent models list
self._hull_points: List[ConvexHullPoint] = []
self._ensure_dataset_download(conf_search)
# checkpoint will restore the hull we had
is_restored = self._restore_checkpoint()
        # Seed the pool with models of different macro parameters (number of
        # cells, reductions, etc.) if the parent pool could not be restored,
        # i.e., this is the first run of this job.
future_ids = [] if is_restored else self._create_seed_jobs(conf_search,
model_desc_builder)
while not self._is_search_done():
logger.info(f'Ray jobs running: {len(future_ids)}')
if future_ids:
# get first completed job
job_id_done, future_ids = ray.wait(future_ids)
hull_point = ray.get(job_id_done[0])
logger.info(f'Hull point id {hull_point.id} with stage {hull_point.job_stage.name} completed')
if hull_point.is_trained_stage():
self._update_convex_hull(hull_point)
# sample a point and search
sampled_point = sample_from_hull(self._hull_points,
self._convex_hull_eps)
future_id = SearcherPetridish.search_model_desc_dist.remote(self,
conf_search, sampled_point, model_desc_builder, trainer_class,
finalizers, common.get_state())
future_ids.append(future_id)
logger.info(f'Added sampled point {sampled_point.id} for search')
elif hull_point.job_stage==JobStage.SEARCH:
# create the job to train the searched model
future_id = SearcherPetridish.train_model_desc_dist.remote(self,
conf_post_train, hull_point, common.get_state())
future_ids.append(future_id)
logger.info(f'Added sampled point {hull_point.id} for post-search training')
else:
raise RuntimeError(f'Job stage "{hull_point.job_stage}" is not expected in search loop')
        # cancel any remaining jobs to free up GPUs for the eval phase
        for future_id in future_ids:
            ray.cancel(future_id, force=True) # without force=True, the main process stops
ray.wait([future_id])
# plot and save the hull
expdir = common.get_expdir()
assert expdir
plot_frontier(self._hull_points, self._convex_hull_eps, expdir)
best_point = save_hull_frontier(self._hull_points, self._convex_hull_eps,
final_desc_foldername, expdir)
save_hull(self._hull_points, expdir)
        plot_pool(self._hull_points, expdir)
# return best point as search result
search_result = SearchResult(best_point.model_desc, search_metrics=None,
train_metrics=best_point.metrics)
self.clean_log_result(conf_search, search_result)
logger.popd()
return search_result

    @staticmethod
@ray.remote(num_gpus=1)
def search_model_desc_dist(searcher:'SearcherPetridish', conf_search:Config,
hull_point:ConvexHullPoint, model_desc_builder:ModelDescBuilder,
trainer_class:TArchTrainer, finalizers:Finalizers, common_state:CommonState)\
->ConvexHullPoint:
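        '''Ray remote worker: clones the model desc of a trained hull point,
        grows it by one node, runs architecture search on it and returns the
        result as a new hull point at the SEARCH stage.'''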
        # as this runs in a different process, initialize globals
        common.init_from(common_state)
        # register ops as we are in a different process now
conf_model_desc = conf_search['model_desc']
model_desc_builder.pre_build(conf_model_desc)
assert hull_point.is_trained_stage()
        # Cloning is not strictly needed, but if we ever run this function in
        # the same process it avoids surprises from shared mutable state.
model_desc = hull_point.model_desc.clone()
searcher._add_node(model_desc, model_desc_builder)
model_desc, search_metrics = searcher.search_model_desc(conf_search,
model_desc, trainer_class, finalizers)
cells, reductions, nodes = hull_point.cells_reductions_nodes
new_point = ConvexHullPoint(JobStage.SEARCH, hull_point.id,
hull_point.sampling_count, model_desc,
(cells, reductions, nodes+1), # we added a node
metrics=search_metrics)
return new_point

    @staticmethod
@ray.remote(num_gpus=1)
def train_model_desc_dist(searcher:'SearcherPetridish', conf_train:Config,
hull_point:ConvexHullPoint, common_state:CommonState)\
->ConvexHullPoint:
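        '''Ray remote worker: trains the model desc of a not-yet-trained hull
        point and returns a new point advanced to the next job stage, with
        training metrics and model stats attached.'''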
        # as this runs in a different process, initialize globals
common.init_from(common_state)
assert not hull_point.is_trained_stage()
model_metrics = searcher.train_model_desc(hull_point.model_desc, conf_train)
model_stats = nas_utils.get_model_stats(model_metrics.model)
        new_point = ConvexHullPoint(hull_point.next_stage(), hull_point.id,
                                    hull_point.sampling_count, hull_point.model_desc,
                                    hull_point.cells_reductions_nodes,
                                    model_metrics.metrics,
                                    model_stats)
return new_point

    def _add_node(self, model_desc:ModelDesc, model_desc_builder:ModelDescBuilder)->None:
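        '''Grows every cell of the model by one node. The new node is inserted
        just before the last node and connects to all earlier states through a
        single Petridish op; the cell post-op is then rebuilt because its
        number of inputs has changed.'''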
for ci, cell_desc in enumerate(model_desc.cell_descs()):
reduction = (cell_desc.cell_type==CellType.Reduction)
nodes = cell_desc.nodes()
            # Petridish must seed with at least one node
            assert len(nodes) > 0
            # input/output channels are the same for all nodes
            conv_params = nodes[0].conv_params
            # Assign input IDs: stems s0 and s1 have IDs 0 and 1 and node j has
            # ID j+2. Since the new node is inserted before the last node, it
            # can take inputs from the stems and every node except the last.
            input_ids = list(range(len(nodes) + 1))
            assert len(input_ids) >= 2 # 2 stem inputs
op_desc = OpDesc('petridish_reduction_op' if reduction else 'petridish_normal_op',
params={
'conv': conv_params,
# specify strides for each input, later we will
# give this to each primitive
'_strides':[2 if reduction and j < 2 else 1 \
for j in input_ids],
}, in_len=len(input_ids), trainables=None, children=None)
edge = EdgeDesc(op_desc, input_ids=input_ids)
new_node = NodeDesc(edges=[edge], conv_params=conv_params)
nodes.insert(len(nodes)-1, new_node)
            # output shapes of all nodes are the same
node_shapes = cell_desc.node_shapes
new_node_shape = copy.deepcopy(node_shapes[-1])
node_shapes.insert(len(node_shapes)-1, new_node_shape)
            # The post-op needs rebuilding because its number of inputs has
            # changed, so its input/output channels may differ.
post_op_shape, post_op_desc = model_desc_builder.build_cell_post_op(cell_desc.stem_shapes,
node_shapes, cell_desc.conf_cell, ci)
cell_desc.reset_nodes(nodes, node_shapes,
post_op_desc, post_op_shape)

    def _ensure_dataset_download(self, conf_search:Config)->None:
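        '''Creates the data loaders once up front so the dataset is downloaded
        before distributed jobs need it.'''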
conf_loader = conf_search['loader']
self.get_data(conf_loader)

    def _is_search_done(self)->bool:
        '''Terminates the search when the max MAdd or the max number of hull points is exceeded'''
if not self._hull_points:
return False
max_madd_parent = max(self._hull_points, key=lambda p:p.model_stats.MAdd)
return max_madd_parent.model_stats.MAdd > self._max_madd or \
len(self._hull_points) > self._max_hull_points

    def _create_seed_jobs(self, conf_search:Config, model_desc_builder:ModelDescBuilder)->list:
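        '''Builds one seed model per feasible macro combination of (reductions,
        cells, nodes), submits a Ray pre-training job for each, saves the seed
        model stats for visualization and returns the list of Ray job IDs.'''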
conf_model_desc = conf_search['model_desc']
conf_seed_train = conf_search['seed_train']
future_ids = [] # ray job IDs
seed_model_stats = [] # seed model stats for visualization and debugging
macro_combinations = list(self.get_combinations(conf_search))
for reductions, cells, nodes in macro_combinations:
            # Skip combinations that cannot satisfy the N R N R N R cell
            # pattern, which needs at least reductions*2 + 1 cells.
if cells < reductions * 2 + 1:
continue
# create seed model
model_desc = self.build_model_desc(model_desc_builder,
conf_model_desc,
reductions, cells, nodes)
hull_point = ConvexHullPoint(JobStage.SEED, 0, 0, model_desc,
(cells, reductions, nodes))
# pre-train the seed model
future_id = SearcherPetridish.train_model_desc_dist.remote(self,
conf_seed_train, hull_point, common.get_state())
future_ids.append(future_id)
# build a model so we can get its model stats
temp_model = Model(model_desc, droppath=True, affine=True)
seed_model_stats.append(nas_utils.get_model_stats(temp_model))
# save the model stats in a plot and tsv file so we can
# visualize the spread on the x-axis
expdir = common.get_expdir()
assert expdir
plot_seed_model_stats(seed_model_stats, expdir)
return future_ids

    def _update_convex_hull(self, new_point:ConvexHullPoint)->None:
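        '''Appends a trained point to the pool of hull points and checkpoints
        the updated pool.'''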
assert new_point.is_trained_stage() # only add models for which we have metrics and stats
self._hull_points.append(new_point)
if self._checkpoint is not None:
self._checkpoint.new()
self._checkpoint['convex_hull_points'] = self._hull_points
self._checkpoint.commit()
logger.info(f'Added to convex hull points: MAdd {new_point.model_stats.MAdd}, '
f'num cells {len(new_point.model_desc.cell_descs())}, '
f'num nodes in cell {len(new_point.model_desc.cell_descs()[0].nodes())}')

    def _restore_checkpoint(self)->bool:
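        '''Restores the pool of hull points from the checkpoint, if available;
        returns True if the restore happened.'''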
can_restore = self._checkpoint is not None \
and 'convex_hull_points' in self._checkpoint
if can_restore:
self._hull_points = self._checkpoint['convex_hull_points']
logger.warn({'Hull restored': True})
return can_restore

    @overrides
def build_model_desc(self, model_desc_builder:ModelDescBuilder,
conf_model_desc:Config,
reductions:int, cells:int, nodes:int)->ModelDesc:
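        '''Builds a model desc with the given macro parameters (number of
        reductions, cells and nodes per cell) overridden in a copy of the
        config.'''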
# reset macro params in copy of config
conf_model_desc = copy.deepcopy(conf_model_desc)
conf_model_desc['n_reductions'] = reductions
conf_model_desc['n_cells'] = cells
conf_model_desc['cell']['n_nodes'] = nodes
        # create the model desc for search from the model config; the model
        # for pre-training is then built directly from this desc without
        # another call to model_desc_builder
model_desc = model_desc_builder.build(conf_model_desc, template=None)
return model_desc