# Source code for archai.supergraph.algos.petridish.petridish_exp_runner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import shutil
from overrides import overrides
from archai.common import utils
from archai.supergraph.algos.petridish.evaluater_petridish import EvaluaterPetridish
from archai.supergraph.algos.petridish.petridish_model_desc_builder import (
PetridishModelBuilder,
)
from archai.supergraph.algos.petridish.searcher_petridish import SearcherPetridish
from archai.supergraph.nas.arch_trainer import ArchTrainer, TArchTrainer
from archai.supergraph.nas.exp_runner import ExperimentRunner
def copy_dir(src_dir:str, dest_dir:str, use_shutil:bool=True)->None:
    """Copy ``src_dir`` to ``dest_dir``.

    If ``src_dir`` is a directory, it is copied either in one call to
    ``shutil.copytree`` (``use_shutil=True``) or by recursing over its entries
    and creating destination directories as needed (``use_shutil=False``).
    If ``src_dir`` is not a directory, it is treated as a single file and
    handed to ``utils.copy_file``.

    Args:
        src_dir: Source directory (or file) path.
        dest_dir: Destination path. With ``use_shutil=True`` and a directory
            source, ``shutil.copytree`` requires that this not already exist.
        use_shutil: Whether to use ``shutil.copytree`` for directories; also
            forwarded to ``utils.copy_file`` for files.
    """
    if not os.path.isdir(src_dir):
        # Not a directory: delegate single-file copying to the shared helper.
        utils.copy_file(src_dir, dest_dir, use_shutil=use_shutil)
        return

    if use_shutil:
        shutil.copytree(src_dir, dest_dir)
        return

    # exist_ok avoids the check-then-create race of testing isdir() first;
    # it still raises FileExistsError if dest_dir exists as a non-directory.
    os.makedirs(dest_dir, exist_ok=True)
    for entry in os.listdir(src_dir):
        copy_dir(os.path.join(src_dir, entry),
                 os.path.join(dest_dir, entry), use_shutil=use_shutil)
class PetridishExperimentRunner(ExperimentRunner):
    """Experiment runner that plugs the Petridish components into the NAS pipeline.

    Supplies the Petridish model-description builder, searcher and evaluater,
    uses the plain ``ArchTrainer`` as trainer class, and overrides how the
    search output folder is handed over to the eval phase.
    """

    @overrides
    def model_desc_builder(self)->PetridishModelBuilder:
        """Return a new ``PetridishModelBuilder``."""
        return PetridishModelBuilder()

    @overrides
    def trainer_class(self)->TArchTrainer:
        """Return the trainer class; Petridish uses ``ArchTrainer`` unchanged."""
        return ArchTrainer

    @overrides
    def searcher(self)->SearcherPetridish:
        """Return a new ``SearcherPetridish``."""
        return SearcherPetridish()

    @overrides
    def evaluater(self)->EvaluaterPetridish:
        """Return a new ``EvaluaterPetridish``."""
        return EvaluaterPetridish()

    @overrides
    def copy_search_to_eval(self)->None:
        """Copy the model-description folder produced by search to where eval reads it.

        The source folder name is read from ``['nas']['search']['final_desc_foldername']``
        of ``self.get_conf(True)`` and the destination from
        ``['nas']['eval']['final_desc_foldername']`` of ``self.get_conf(False)``.
        Any existing destination folder is removed before copying. If both
        names resolve to the same folder, nothing is copied or removed.

        Raises:
            AssertionError: if either folder name is empty, either resolved
                path is empty, or the search folder does not exist.
        """
        # get folder of model gallery that search has produced
        conf_search = self.get_conf(True)['nas']['search']
        search_desc_foldername = conf_search['final_desc_foldername']
        # Validate the configured name before resolving it.
        assert search_desc_foldername, \
            'nas.search.final_desc_foldername must be set'
        search_desc_folderpath = utils.full_path(search_desc_foldername)
        assert os.path.exists(search_desc_folderpath), \
            f'search output folder does not exist: {search_desc_folderpath}'

        # get folder path that eval would need
        conf_eval = self.get_conf(False)['nas']['eval']
        eval_desc_foldername = conf_eval['final_desc_foldername']
        # Check the configured name, not only the resolved path: an empty name
        # may still resolve to a non-empty path (depending on utils.full_path,
        # possibly the current directory), which rmtree below would delete.
        assert eval_desc_foldername, \
            'nas.eval.final_desc_foldername must be set'
        eval_desc_folderpath = utils.full_path(eval_desc_foldername)
        assert eval_desc_folderpath

        # If both phases point at the same folder the files are already in
        # place; removing the destination would delete the search output
        # before it could be copied.
        if os.path.realpath(search_desc_folderpath) == \
                os.path.realpath(eval_desc_folderpath):
            return

        # shutil.copytree only gained the dirs_exist_ok option in Python 3.8,
        # so remove any pre-existing destination to stay compatible with
        # older versions.
        if os.path.exists(eval_desc_folderpath):
            shutil.rmtree(eval_desc_folderpath)
        utils.copy_dir(search_desc_folderpath, eval_desc_folderpath, use_shutil=True)