Coverage for mlos_bench/mlos_bench/dict_templater.py: 100%
29 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Simple class to help with nested dictionary ``$var`` templating in configuration file
6expansions.
7"""
9from copy import deepcopy
10from string import Template
11from typing import Any, Dict, Optional
13from mlos_bench.os_environ import environ
class DictTemplater:  # pylint: disable=too-few-public-methods
    """
    Simple class to help with nested dictionary ``$var`` templating.

    The templater keeps a pristine deep copy of the dictionary it was given and
    produces a freshly expanded copy on every call to :py:meth:`expand_vars`.
    Variables are looked up among the top-level keys of the dictionary being
    expanded, then in an optional extra dictionary, and finally (optionally) in
    the OS environment. Placeholders that cannot be resolved are left as-is.
    """

    def __init__(self, source_dict: Dict[str, Any]):
        """
        Initialize the templater.

        Parameters
        ----------
        source_dict : Dict[str, Any]
            The template dict to use for source variables.
            It is deep-copied, so the caller's object is never modified.
        """
        # A copy of the initial data structure we were given with templates intact.
        self._template_dict = deepcopy(source_dict)
        # The source/target dictionary to expand.
        self._dict: Dict[str, Any] = {}

    def expand_vars(
        self,
        *,
        extra_source_dict: Optional[Dict[str, Any]] = None,
        use_os_env: bool = False,
    ) -> Dict[str, Any]:
        """
        Expand the template variables in the destination dictionary.

        Each call starts over from a fresh copy of the original template, so
        results of earlier calls do not leak into later ones.

        Parameters
        ----------
        extra_source_dict : Optional[Dict[str, Any]]
            An optional extra source dictionary to use for expansion.
        use_os_env : bool
            Whether to use the os environment variables as a final fallback
            for expansion.

        Returns
        -------
        Dict[str, Any]
            The expanded dictionary.

        Raises
        ------
        ValueError
            On unsupported nested types.
        """
        # Work on a copy so that the stored template keeps its ``$var`` strings.
        self._dict = deepcopy(self._template_dict)
        self._dict = self._expand_vars(self._dict, extra_source_dict, use_os_env)
        assert isinstance(self._dict, dict)
        return self._dict

    def _expand_vars(
        self,
        value: Any,
        extra_source_dict: Optional[Dict[str, Any]],
        use_os_env: bool,
    ) -> Any:
        """
        Recursively expand ``$var`` strings in the currently operating dictionary.

        Parameters
        ----------
        value : Any
            The value to expand: a string, a dict, a list, a scalar
            (int, float, bool) or None.
        extra_source_dict : Optional[Dict[str, Any]]
            An optional extra source dictionary to use for expansion.
        use_os_env : bool
            Whether to fall back to the os environment variables.

        Returns
        -------
        Any
            The expanded value. Dicts are updated in place and returned;
            lists are rebuilt; scalars and None are returned unchanged.

        Raises
        ------
        ValueError
            If the value (or anything nested in it) is of an unsupported type.
        """
        if isinstance(value, str):
            # First try to expand all $vars internally.
            # safe_substitute() leaves unknown placeholders untouched so that the
            # later sources still get a chance to resolve them.
            value = Template(value).safe_substitute(self._dict)
            # Next, if there are any left, try to expand them from the extra source dict.
            if extra_source_dict:
                value = Template(value).safe_substitute(extra_source_dict)
            # Finally, fallback to the os environment.
            if use_os_env:
                value = Template(value).safe_substitute(dict(environ))
        elif isinstance(value, dict):
            # Note: we use a loop instead of dict comprehension in order to
            # allow secondary expansion of subsequent values immediately.
            # Only values of existing keys are reassigned, so the dict does not
            # change size while it is being iterated.
            for key, val in value.items():
                value[key] = self._expand_vars(val, extra_source_dict, use_os_env)
        elif isinstance(value, list):
            value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value]
        elif isinstance(value, (int, float, bool)) or value is None:
            # Scalars have nothing to expand.
            return value
        else:
            raise ValueError(f"Unexpected type {type(value)} for value {value}")
        return value