# Source code for archai.discrete_search.search_spaces.config.utils

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict
from typing import Any, Callable, Dict, Union

from archai.discrete_search.search_spaces.config.discrete_choice import DiscreteChoice


def flatten_dict(odict: Dict[str, Any]) -> dict:
    """Flatten a nested dictionary into a single level dictionary.

    Nested keys are joined with ``"."``, e.g. ``{"a": {"b": 1}}`` becomes
    ``{"a.b": 1}``. Only non-dict values become entries of the result, so an
    empty nested dictionary contributes no keys. Leaf values are stored as-is,
    including ``None``.

    Args:
        odict: Nested dictionary. Non-string keys are converted with ``str()``
            when building the dotted key.

    Returns:
        Flattened dictionary whose keys appear in depth-first insertion order.
        A non-dict ``odict`` yields an empty dictionary.

    """
    fdict = dict()

    def _flatten(prefix: str, d: Dict[Any, Any]) -> None:
        # Leaves are written into ``fdict`` directly instead of being signalled
        # through the return value; using ``None`` as a "was a dict" marker
        # would make genuine ``None`` leaf values indistinguishable and drop them.
        for k, v in d.items():
            # An empty prefix means "top level" (or an empty-string parent key),
            # in which case no leading separator is added.
            key = f"{prefix}.{k}" if prefix else str(k)
            if isinstance(v, dict):
                _flatten(key, v)
            else:
                fdict[key] = v

    if isinstance(odict, dict):
        _flatten("", odict)

    return fdict
def order_dict_keys(base_dict: OrderedDict, target_dict: Dict[str, Any]) -> OrderedDict:
    """Reorder the entries of ``target_dict`` to follow the key order of ``base_dict``.

    Only keys present in both dictionaries are kept: keys of ``target_dict``
    that ``base_dict`` does not contain are dropped, and keys of ``base_dict``
    that ``target_dict`` lacks are skipped.

    Args:
        base_dict (OrderedDict[str, Any]): Dictionary providing the reference key order.
        target_dict (Dict[str, Any]): Dictionary whose entries are reordered.

    Returns:
        OrderedDict[str, Any]: New ordered dictionary holding the shared keys in
        ``base_dict`` order, with values taken from ``target_dict``.

    """
    shared_keys = (key for key in base_dict if key in target_dict)
    return OrderedDict((key, target_dict[key]) for key in shared_keys)
def replace_ptree_choices(
    config_tree: Union[Dict, DiscreteChoice], repl_fn: Callable[[DiscreteChoice], Any]
) -> OrderedDict:
    """Substitute every DiscreteChoice node of a tree with the result of ``repl_fn``.

    Dictionaries are rebuilt as ``OrderedDict`` objects with their original key
    order. Values that are neither dictionaries nor DiscreteChoice objects are
    returned unchanged. Replacements are cached by object identity, so the same
    DiscreteChoice instance appearing in several places is passed to ``repl_fn``
    only once and every occurrence receives that single result.

    Args:
        config_tree: Tree with DiscreteChoice nodes.
        repl_fn: Function mapping a DiscreteChoice node to its replacement.

    Returns:
        Tree with all DiscreteChoice nodes replaced.

    """
    cache: Dict[int, Any] = {}

    def _visit(node):
        if isinstance(node, dict):
            return OrderedDict((name, _visit(child)) for name, child in node.items())

        if not isinstance(node, DiscreteChoice):
            return node

        node_key = id(node)
        if node_key not in cache:
            cache[node_key] = repl_fn(node)
        return cache[node_key]

    return _visit(config_tree)
def replace_ptree_pair_choices(
    query_tree: Union[Dict, DiscreteChoice], aux_tree: Union[Dict, Any], repl_fn: Callable[[DiscreteChoice, Any], Any]
) -> OrderedDict:
    """Replace all DiscreteChoice nodes in a tree with the output of a function and an auxiliary tree.

    ``query_tree`` and ``aux_tree`` are walked in lockstep. Every key of a
    dictionary in ``query_tree`` must also be present in the corresponding
    ``aux_tree`` node; keys found only in ``aux_tree`` are ignored. Each
    DiscreteChoice in ``query_tree`` is replaced by ``repl_fn(query_node, aux_node)``,
    where ``aux_node`` is the value at the same path in ``aux_tree``.

    Replacements are cached by object identity: a DiscreteChoice instance that
    appears several times in ``query_tree`` is replaced by a single value,
    computed from the ``aux_node`` at its first visited occurrence.

    Args:
        query_tree: Tree with DiscreteChoice nodes.
        aux_tree: Auxiliary tree with the same dictionary structure as ``query_tree``.
        repl_fn: Function that takes a `query_node` and an `aux_node` and returns a
            replacement for `query_node`.

    Returns:
        Replaced tree. Dictionaries are rebuilt as ``OrderedDict``; values that are
        neither dictionaries nor DiscreteChoice objects are returned unchanged.

    Raises:
        AssertionError: If a key of a ``query_tree`` dictionary is missing from the
            corresponding ``aux_tree`` node.

    """

    def _replace_tree_nodes(query_node, aux_node, repl_fn, ref_map):
        if isinstance(query_node, dict):
            output_tree = OrderedDict()

            for param_name, param in query_node.items():
                # Explicit check rather than an ``assert`` statement so the
                # validation is not stripped when Python runs with ``-O``.
                # AssertionError is kept so existing callers that catch it still work.
                if param_name not in aux_node:
                    raise AssertionError(
                        "`aux_tree` must be identical to `query_tree` apart from terminal nodes"
                        f" (missing key: {param_name!r})"
                    )

                output_tree[param_name] = _replace_tree_nodes(param, aux_node[param_name], repl_fn, ref_map)

            return output_tree

        if isinstance(query_node, DiscreteChoice):
            # Keyed by object identity so repeated references to the same
            # choice share one replacement value.
            if id(query_node) not in ref_map:
                ref_map[id(query_node)] = repl_fn(query_node, aux_node)
            return ref_map[id(query_node)]

        return query_node

    return _replace_tree_nodes(query_tree, aux_tree, repl_fn, {})