# Source code for archai.common.config

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
from typing import Callable, List, Optional, Any
from collections import UserDict
from typing import Sequence
from collections.abc import Mapping, MutableMapping
import os
from distutils.util import strtobool
import copy

import yaml

from archai.common import yaml_utils

# Module-level Config singleton. It starts out as None, is assigned by
# Config.set_inst() and is read back through Config.get_inst().
_config: Optional['Config'] = None


# TODO: de-duplicate this with the copy in utils.py once that can be done without creating a circular import
def deep_update(d: MutableMapping, u: Mapping, create_map: Callable[[], MutableMapping])\
        -> MutableMapping:
    """Recursively merge mapping ``u`` into mapping ``d`` in place.

    For every key in ``u``: a mapping value is merged into the corresponding
    nested mapping of ``d``; any other value simply overwrites ``d[key]``.

    Arguments:
        d -- destination mapping, modified in place
        u -- source mapping whose entries take precedence
        create_map -- zero-argument factory used to build a nested destination
            mapping when ``d`` has no usable mapping under a key

    Returns:
        The same object ``d`` that was passed in.
    """
    for key, value in u.items():
        if isinstance(value, Mapping):
            target = d.get(key)
            # Build a fresh container only when needed: either the key is
            # absent, or it currently holds a non-mapping (scalar/None) that a
            # mapping override must replace rather than be merged into.
            if not isinstance(target, MutableMapping):
                target = create_map()
            d[key] = deep_update(target, value, create_map)
        else:
            d[key] = value
    return d
class Config(UserDict):
    """Hierarchical dictionary-like configuration loaded from YAML files and
    optionally overridden from explicit parameters and the command line."""

    def __init__(self, config_filepath: Optional[str] = None,
                 app_desc: Optional[str] = None, use_args=False,
                 param_args: Sequence = (), resolve_redirects=True,
                 resolve_env_vars=False) -> None:
        """Create config from specified files and args

        Config is simply a dictionary of key, value map. The value can itself be
        a dictionary so config can be hierarchical. This class allows to load
        config from yaml. A special key '__include__' can specify another yaml
        relative file path (or list of file paths) which will be loaded first
        and the key-value pairs in the main file
        will override the ones in include file. You can think of included file as
        defaults provider. This allows to create one base config and then several
        environment/experiment specific configs. On the top of that you can use
        param_args to perform final overrides for a given run. You can also have
        values that reference environment variables using ${ENV_VAR_NAME} syntax.

        Keyword Arguments:
            config_filepath {[str]} -- [Yaml file to load config from, could be
                names of files separated by semicolon which will be loaded in
                sequence overriding previous config] (default: {None})
            app_desc {[str]} -- [app description that will show up in --help]
                (default: {None})
            use_args {bool} -- [if true then command line parameters will
                override parameters from config files] (default: {False})
            param_args {Sequence} -- [parameters specified as
                ['--key1',val1,'--key2',val2,...] which will override parameters
                from config file.] (default: {()})
            resolve_redirects -- [if True then yaml_utils.resolve_all() is run
                over the config; per the original documentation this executes
                the _copy commands in yaml]
            resolve_env_vars -- [if True then string values containing '$' are
                passed through os.path.expandvars] (default: {False})
        """
        super().__init__()

        self.args, self.extra_args = None, []

        if use_args:
            # let command line args specify/override config file
            parser = argparse.ArgumentParser(description=app_desc)
            parser.add_argument('--config', type=str, default=None,
                                help='config filepath in yaml format, can be list separated by ;')
            self.args, self.extra_args = parser.parse_known_args()
            config_filepath = self.args.config or config_filepath

        if config_filepath:
            for filepath in config_filepath.strip().split(';'):
                self._load_from_file(filepath.strip())

        # Create a copy of ourselves and do the resolution over it.
        # This resolved_conf then can be used to search for overrides that
        # wouldn't have existed before resolution.
        resolved_conf = copy.deepcopy(self)
        if resolve_redirects:
            yaml_utils.resolve_all(resolved_conf)

        # Let's do final overrides from args
        self._update_from_args(param_args, resolved_conf)       # merge from params
        self._update_from_args(self.extra_args, resolved_conf)  # merge from command line

        if resolve_env_vars:
            self._process_envvars(self)

        if resolve_redirects:
            yaml_utils.resolve_all(self)

        self.config_filepath = config_filepath

    def _load_from_file(self, filepath: Optional[str]) -> None:
        """Load one yaml file (after its includes) and merge it into self.

        A falsy ``filepath`` is a no-op. '~' and environment variables in the
        path are expanded before the file is opened.
        """
        if filepath:
            filepath = os.path.expanduser(os.path.expandvars(filepath))
            filepath = os.path.abspath(filepath)
            with open(filepath, 'r') as f:
                # NOTE(review): yaml.Loader can construct arbitrary Python
                # objects, so only trusted config files must be loaded here.
                config_yaml = yaml.load(f, Loader=yaml.Loader)
            if config_yaml is None:
                # an empty yaml document parses to None; treat it as no keys
                config_yaml = {}
            # includes act as defaults, so they are merged before this file
            self._process_includes(config_yaml, filepath)
            deep_update(self, config_yaml, lambda: Config(resolve_redirects=False))
            print('config loaded from: ', filepath)

    def _process_includes(self, config_yaml, filepath: str):
        """Load every file named by the '__include__' key of ``config_yaml``.

        Include paths are resolved relative to the directory of ``filepath``.
        """
        if '__include__' in config_yaml:
            # include could be file name or array of file names to apply in sequence
            includes = config_yaml['__include__']
            if isinstance(includes, str):
                includes = [includes]
            assert isinstance(includes, list), "'__include__' value must be string or list"
            for include in includes:
                include_filepath = os.path.join(os.path.dirname(filepath), include)
                self._load_from_file(include_filepath)

    def _process_envvars(self, config_yaml):
        """Expand environment variables in string values, recursing into
        nested Config sections. Values are replaced in place."""
        for key in config_yaml:
            value = config_yaml[key]
            if isinstance(value, Config):
                self._process_envvars(value)
            elif isinstance(value, str) and '$' in value:
                config_yaml[key] = os.path.expandvars(value)

    def _update_from_args(self, args: Sequence, resolved_section: 'Config') -> None:
        """Apply ``['--a.b.c', value, ...]`` style overrides to this config.

        Only keys that already exist in ``resolved_section`` are overridden;
        unknown keys and non-option entries are skipped.
        """
        i = 0
        while i < len(args) - 1:
            arg = args[i]
            # values supplied programmatically may be non-strings (e.g. the
            # value of a key that was just ignored), so guard before startswith
            if isinstance(arg, str) and arg.startswith('--'):
                path = arg[len('--'):].split('.')
                i += Config._update_section(self, path, args[i + 1], resolved_section)
            else:  # some other arg
                i += 1

    def save(self, filename: str) -> None:
        """Write the config, converted to plain dicts, to ``filename`` as yaml."""
        with open(filename, 'w') as f:
            yaml.dump(self.to_dict(), f)

    def to_dict(self) -> dict:
        """Return a copy of the config in which nested mappings are plain dicts."""
        return deep_update({}, self, dict)  # type: ignore

    @staticmethod
    def _str_to_bool(val: Any) -> bool:
        """Convert a textual truth value to bool.

        Accepts the same spellings as the former distutils ``strtobool``
        (distutils is removed from the standard library in Python 3.12) and
        additionally passes real bool values through unchanged.

        Raises:
            ValueError -- if the value is not a recognised truth value
        """
        if isinstance(val, bool):
            return val
        text = str(val).lower()
        if text in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if text in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        raise ValueError(f'invalid truth value {text!r}')

    @staticmethod
    def _update_section(section: 'Config', path: List[str], val: Any,
                        resolved_section: 'Config') -> int:
        """Set the value at dotted ``path`` inside ``section`` to ``val``.

        ``resolved_section`` decides whether the path exists and supplies the
        type of the current value, to which ``val`` is converted.

        Returns:
            2 if the key was found and set (the value argument was consumed),
            1 if the path does not exist and the override is ignored.

        Raises:
            KeyError -- if ``val`` cannot be converted to the existing type
        """
        for sub_path in path[:-1]:
            if sub_path not in resolved_section:
                return 1  # path not found, ignore this
            resolved_section = resolved_section[sub_path]
            if sub_path not in section:
                # exists only after resolution; create it so it can be overridden
                section[sub_path] = Config(resolve_redirects=False)
            section = section[sub_path]

        key = path[-1]  # final leaf node value
        if key not in resolved_section:
            return 1  # path not found, ignore this

        original_val = resolved_section[key]
        original_type = type(original_val)
        # bool('False') is True :( so booleans need a dedicated converter
        converter = Config._str_to_bool if original_type is bool else original_type
        try:
            section[key] = converter(val)
        except Exception as e:
            raise KeyError(
                f'The yaml key or command line argument "{key}" is likely not named correctly or value is of wrong data type. Error was occurred when setting it to value "{val}".'
                f'Originally it is set to {original_val} which is of type {original_type}.'
                f'Original exception: {e}') from e
        return 2  # path was found, increment arg pointer by 2 as we use up val

    def get_val(self, key, default_val):
        """Return the value stored under ``key``, or ``default_val`` if absent."""
        return super().get(key, default_val)

    @staticmethod
    def set_inst(instance: 'Config') -> None:
        """Store ``instance`` as the module-level global config."""
        global _config
        _config = instance

    @staticmethod
    def get_inst() -> 'Config':
        """Return the config most recently stored with set_inst()."""
        global _config
        return _config