Coverage for mlos_bench/mlos_bench/dict_templater.py: 97%
30 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Simple class to help with nested dictionary $var templating.
7"""
9from copy import deepcopy
10from string import Template
11from typing import Any, Dict, Optional
13from mlos_bench.os_environ import environ
class DictTemplater:  # pylint: disable=too-few-public-methods
    """
    Simple class to help with nested dictionary $var templating.

    The source dictionary is stored as an untouched template; every call to
    ``expand_vars`` expands a fresh deep copy of it.
    """

    def __init__(self, source_dict: Dict[str, Any]):
        """
        Initialize the templater.

        Parameters
        ----------
        source_dict : Dict[str, Any]
            The template dict to use for source variables.
            It is deep-copied, so later changes to the caller's dict do not
            affect the template.
        """
        # A copy of the initial data structure we were given with templates intact.
        self._template_dict = deepcopy(source_dict)
        # The source/target dictionary to expand.
        self._dict: Dict[str, Any] = {}

    def expand_vars(self, *,
                    extra_source_dict: Optional[Dict[str, Any]] = None,
                    use_os_env: bool = False) -> Dict[str, Any]:
        """
        Expand the template variables in the destination dictionary.

        Each call starts over from a fresh deep copy of the original template,
        so dictionaries returned by earlier calls are not modified.

        Parameters
        ----------
        extra_source_dict : Optional[Dict[str, Any]]
            An optional extra source dictionary to use for expansion.
            It is consulted after the dictionary's own values, for any $vars
            that are still unresolved.
        use_os_env : bool
            Whether to use the OS environment variables as a final fallback for expansion.

        Returns
        -------
        Dict[str, Any]
            The expanded dictionary.

        Raises
        ------
        ValueError
            If the template contains a value that is not a str, int, float,
            bool, None, list, or dict.
        """
        self._dict = deepcopy(self._template_dict)
        # The top-level dict is expanded in place, so values expanded earlier
        # are visible (through self._dict) when later values are expanded.
        self._dict = self._expand_vars(self._dict, extra_source_dict, use_os_env)
        assert isinstance(self._dict, dict)
        return self._dict

    def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]],
                     use_os_env: bool) -> Any:
        """
        Recursively expand $var strings in the currently operating dictionary.

        Strings are expanded in up to three successive passes, in priority
        order:

        1. the (partially expanded) top-level dictionary being built;
        2. ``extra_source_dict``, if it is non-empty;
        3. the OS environment, if ``use_os_env`` is set.

        Unresolved $vars are left as-is.

        Dicts are updated in place, lists are rebuilt, and int/float/bool/None
        values are returned unchanged.

        Raises
        ------
        ValueError
            For any other value type.
        """
        if isinstance(value, str):
            # First try to expand all $vars internally.
            value = Template(value).safe_substitute(self._dict)
            # Next, if there are any left, try to expand them from the extra source dict.
            # Because this is a separate pass, $vars introduced by values
            # inserted during the first pass get expanded here too.
            if extra_source_dict:
                value = Template(value).safe_substitute(extra_source_dict)
            # Finally, fallback to the os environment.
            # safe_substitute() accepts any mapping, so look names up in the
            # environment directly instead of copying it for every string.
            if use_os_env:
                value = Template(value).safe_substitute(environ)
            # NOTE(review): every pass also unescapes "$$" to "$", so when more
            # than one pass runs, an escaped "$$name" may still be expanded by a
            # later pass. Kept as-is to preserve the chained-expansion behavior.
        elif isinstance(value, dict):
            # Note: we use a loop instead of dict comprehension in order to
            # allow secondary expansion of subsequent values immediately.
            # Reassigning existing keys does not change the dict's size, so
            # mutating it while iterating over items() is safe.
            for (key, val) in value.items():
                value[key] = self._expand_vars(val, extra_source_dict, use_os_env)
        elif isinstance(value, list):
            value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value]
        elif isinstance(value, (int, float, bool)) or value is None:
            # Scalars contain nothing to expand.
            pass
        else:
            raise ValueError(f"Unexpected type {type(value)} for value {value}")
        return value