Source code for archai.common.common

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import atexit
import os
import subprocess
from typing import Optional, Tuple, Union

import yaml
from send2trash import send2trash
from torch.utils.tensorboard.writer import SummaryWriter

from archai.common import utils
from archai.common.apex_utils import ApexUtils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger

logger = get_global_logger()


class SummaryWriterDummy:
    def __init__(self, log_dir):
        pass

    def add_scalar(self, *args, **kwargs):
        pass

    def flush(self):
        pass


SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]

_tb_writer: Optional[SummaryWriterAny] = None
_atexit_reg = False  # is hook for atexit registered?

def get_conf(conf: Optional[Config] = None) -> Config:
    if conf is not None:
        return conf
    return Config.get_inst()

def get_conf_common(conf: Optional[Config] = None) -> Config:
    return get_conf(conf)['common']


def get_conf_dataset(conf: Optional[Config] = None) -> Config:
    return get_conf(conf)['dataset']


def get_experiment_name(conf: Optional[Config] = None) -> str:
    return get_conf_common(conf)['experiment_name']


def get_expdir(conf: Optional[Config] = None) -> Optional[str]:
    return get_conf_common(conf)['expdir']


def get_datadir(conf: Optional[Config] = None) -> Optional[str]:
    return get_conf(conf)['dataset']['dataroot']

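# A sketch of typical getter usage; the keys are those read by the getters
# above:
#
#   conf = get_conf()
#   name = get_experiment_name(conf)   # == conf['common']['experiment_name']
#   dataroot = get_datadir(conf)       # == conf['dataset']['dataroot']
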
def get_tb_writer() -> SummaryWriterAny:
    global _tb_writer
    assert _tb_writer
    return _tb_writer

class CommonState:
    def __init__(self) -> None:
        global _tb_writer
        self.conf = get_conf()
        self.tb_writer = _tb_writer

def on_app_exit():
    print('Process exit:', os.getpid(), flush=True)
    writer = get_tb_writer()
    writer.flush()

def pt_dirs() -> Tuple[str, str]:
    # dirs for pt infrastructure are supplied in env vars
    pt_data_dir = os.environ.get('PT_DATA_DIR', '')

    # currently yaml should be copying dataset folder to local dir
    # so below is not needed. The hope is that fewer reads from cloud
    # storage will reduce overall latency.
    # if pt_data_dir:
    #     param_args = ['--nas.eval.loader.dataset.dataroot', pt_data_dir,
    #                   '--nas.search.loader.dataset.dataroot', pt_data_dir,
    #                   '--nas.search.seed_train.loader.dataset.dataroot', pt_data_dir,
    #                   '--nas.search.post_train.loader.dataset.dataroot', pt_data_dir,
    #                   '--autoaug.loader.dataset.dataroot', pt_data_dir] + param_args

    pt_output_dir = os.environ.get('PT_OUTPUT_DIR', '')
    return pt_data_dir, pt_output_dir

def _pt_params(param_args: list) -> list:
    pt_data_dir, pt_output_dir = pt_dirs()
    if pt_output_dir:
        # prepend so that a value supplied from outside overrides this one
        param_args = ['--common.logdir', pt_output_dir] + param_args
    return param_args

def get_state() -> CommonState:
    return CommonState()


def init_from(state: CommonState) -> None:
    global _tb_writer
    Config.set_inst(state.conf)
    _tb_writer = state.tb_writer

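# get_state()/init_from() let a child process re-establish the parent's global
# config and tensorboard writer. A minimal sketch, assuming some external
# spawning mechanism (not prescribed by this module):
#
#   state = get_state()      # in the parent process
#   # ... pass `state` to the child (e.g. as a process argument) ...
#   init_from(state)         # in the child process
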
def create_conf(config_filepath: Optional[str] = None, param_args: list = [], use_args=True) -> Config:
    # modify passed args for pt infrastructure;
    # if pt infrastructure doesn't exist then param_overrides == param_args
    param_overrides = _pt_params(param_args)

    # create env vars that might be used in paths in config
    if 'default_dataroot' not in os.environ:
        os.environ['default_dataroot'] = default_dataroot()

    conf = Config(config_filepath=config_filepath, param_args=param_overrides, use_args=use_args)
    _update_conf(conf)

    return conf

# TODO: rename this simply as init
# initializes random number gen, debugging etc.
def common_init(config_filepath: Optional[str] = None, param_args: list = [],
                use_args=True, clean_expdir=False) -> Config:
    # TODO: multiple child processes will create issues with shared state so we need to
    # detect multiple child processes but allow if there is only one child process.
    # if not utils.is_main_process():
    #     raise RuntimeError('common_init should not be called from child process. Please use Common.init_from()')

    # setup global instance
    conf = create_conf(config_filepath, param_args, use_args)
    Config.set_inst(conf)

    # setup env vars which might be used in paths
    update_envvars(conf)

    # create experiment dir
    create_dirs(conf, clean_expdir)

    _create_sysinfo(conf)

    # create apex to know distributed processing parameters
    conf_apex = get_conf_common(conf)['apex']
    apex = ApexUtils(conf_apex)

    # setup tensorboard
    global _tb_writer
    _tb_writer = create_tb_writer(conf, apex.is_master())

    # create hooks to execute code when script exits
    global _atexit_reg
    if not _atexit_reg:
        atexit.register(on_app_exit)
        _atexit_reg = True

    return conf

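# Typical startup sketch; the yaml path below is hypothetical:
#
#   conf = common_init(config_filepath='confs/my_experiment.yaml')
#   get_tb_writer().add_scalar('train/loss', 0.5, 0)  # no-op dummy writer if tb disabled
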
def _create_sysinfo(conf: Config) -> None:
    expdir = get_expdir(conf)

    if expdir and not utils.is_debugging():
        # copy net config to experiment folder for reference
        with open(expdir_abspath('config_used.yaml'), 'w') as f:
            yaml.dump(conf.to_dict(), f)
        if not utils.is_debugging():
            sysinfo_filepath = expdir_abspath('sysinfo.txt')
            subprocess.Popen([f'./scripts/sysinfo.sh "{expdir}" > "{sysinfo_filepath}"'],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             shell=True)

def expdir_abspath(path: str, create=False) -> str:
    """Returns full path for given relative path within experiment directory."""
    return utils.full_path(os.path.join('$expdir', path), create=create)

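# Example: assuming update_envvars() has set os.environ['expdir'] and
# utils.full_path expands such variables, expdir_abspath('model.pt')
# resolves to '<expdir>/model.pt'.
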
def create_tb_writer(conf: Config, is_master=True) -> SummaryWriterAny:
    conf_common = get_conf_common(conf)

    tb_dir, conf_enable_tb = utils.full_path(conf_common['tb_dir']), conf_common['tb_enable']
    tb_enable = conf_enable_tb and is_master and tb_dir is not None and len(tb_dir) > 0

    logger.info({'conf_enable_tb': conf_enable_tb,
                 'tb_enable': tb_enable,
                 'tb_dir': tb_dir})

    WriterClass = SummaryWriter if tb_enable else SummaryWriterDummy
    return WriterClass(log_dir=tb_dir)

def is_pt() -> bool:
    """Is this code running in pt infrastructure?"""
    return os.environ.get('PT_OUTPUT_DIR', '') != ''

def default_dataroot() -> str:
    # the home folder on ITP VMs is super slow so use local temp directory instead
    return '/var/tmp/dataroot' if is_pt() else '~/dataroot'

def _update_conf(conf: Config) -> None:
    """Updates conf with full paths, resolving environment vars."""
    conf_common = get_conf_common(conf)
    conf_dataset = get_conf_dataset(conf)
    experiment_name = conf_common['experiment_name']

    # make sure dataroot exists
    dataroot = conf_dataset['dataroot']
    dataroot = utils.full_path(dataroot)

    # make sure logdir and expdir exist
    logdir = conf_common['logdir']
    if logdir:
        logdir = utils.full_path(logdir)
        expdir = os.path.join(logdir, experiment_name)
        # directory for non-master replica logs
        distdir = os.path.join(expdir, 'dist')
    else:
        expdir = distdir = logdir

    # update conf so everyone gets expanded full paths from here on
    conf_common['logdir'] = logdir
    conf_dataset['dataroot'] = dataroot
    conf_common['expdir'] = expdir
    conf_common['distdir'] = distdir

def update_envvars(conf) -> None:
    """Gets values from config and puts them into env vars."""
    conf_common = get_conf_common(conf)
    logdir = conf_common['logdir']
    expdir = conf_common['expdir']
    distdir = conf_common['distdir']

    conf_dataset = get_conf_dataset(conf)
    dataroot = conf_dataset['dataroot']

    # set environment variables so they can be referenced in paths used in config
    os.environ['logdir'] = logdir
    os.environ['dataroot'] = dataroot
    os.environ['expdir'] = expdir
    os.environ['distdir'] = distdir

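# After this call, paths in the config can reference these variables; e.g. a
# hypothetical yaml value '$expdir/checkpoints' would expand to a path under
# the experiment directory once passed through utils.full_path.
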
def clean_ensure_expdir(conf: Optional[Config], clean_dir: bool, ensure_dir: bool) -> None:
    expdir = get_expdir(conf)
    assert expdir
    if clean_dir and os.path.exists(expdir):
        send2trash(expdir)
    if ensure_dir:
        os.makedirs(expdir, exist_ok=True)

def create_dirs(conf: Config, clean_expdir: bool) -> Optional[str]:
    conf_common = get_conf_common(conf)
    logdir = conf_common['logdir']
    expdir = conf_common['expdir']
    distdir = conf_common['distdir']

    conf_dataset = get_conf_dataset(conf)
    dataroot = utils.full_path(conf_dataset['dataroot'])

    # make sure dataroot exists
    os.makedirs(dataroot, exist_ok=True)

    # make sure logdir and expdir exist
    if logdir:
        clean_ensure_expdir(conf, clean_dir=clean_expdir, ensure_dir=True)
        os.makedirs(distdir, exist_ok=True)
    else:
        raise RuntimeError('The logdir setting must be specified for the output directory in yaml')

    # get cloud dirs if any
    pt_data_dir, pt_output_dir = pt_dirs()

    # validate dirs
    assert not pt_output_dir or not expdir.startswith(utils.full_path('~/logdir'))

    logger.info({'expdir': expdir,
                 'PT_DATA_DIR': pt_data_dir,
                 'PT_OUTPUT_DIR': pt_output_dir})