Coverage for mlos_bench/mlos_bench/dict_templater.py: 97%

30 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Simple class to help with nested dictionary $var templating. 

7""" 

8 

9from copy import deepcopy 

10from string import Template 

11from typing import Any, Dict, Optional 

12 

13from mlos_bench.os_environ import environ 

14 

15 

class DictTemplater:    # pylint: disable=too-few-public-methods
    """
    Helper that resolves ``$var`` / ``${var}`` placeholders inside a nested dict.

    The dictionary handed to the constructor is kept as an untouched template;
    every call to :py:meth:`expand_vars` works on a fresh copy of it.
    """

    def __init__(self, source_dict: Dict[str, Any]):
        """
        Create a templater for the given dictionary.

        Parameters
        ----------
        source_dict : Dict[str, Any]
            The dictionary holding the templated values. It is deep-copied, so
            the caller's object is never modified by this class.
        """
        # Working dictionary: both the lookup source and the expansion target.
        self._dict: Dict[str, Any] = {}
        # Pristine snapshot of the input, with all the $var templates intact.
        self._template_dict = deepcopy(source_dict)

    def expand_vars(self, *,
                    extra_source_dict: Optional[Dict[str, Any]] = None,
                    use_os_env: bool = False) -> Dict[str, Any]:
        """
        Produce an expanded copy of the template dictionary.

        Parameters
        ----------
        extra_source_dict : Optional[Dict[str, Any]]
            Optional additional mapping consulted for any variables that the
            dictionary itself could not resolve.
        use_os_env : bool
            Whether to use the os environment variables as a final fallback
            for expansion.

        Returns
        -------
        Dict[str, Any]
            The expanded dictionary. Placeholders that no source can resolve
            are left in place unchanged.
        """
        # Restart from the pristine template so repeated calls are independent.
        # The working copy must be stored on self *before* recursing, because
        # string expansion looks variables up in self._dict.
        self._dict = deepcopy(self._template_dict)
        expanded = self._expand_vars(self._dict, extra_source_dict, use_os_env)
        self._dict = expanded
        assert isinstance(expanded, dict)
        return expanded

    def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any:
        """
        Recursively expand the $var strings found in `value`.

        Strings are substituted, dicts are updated in place, lists are rebuilt,
        and int/float/bool/None values pass through untouched. Any other type
        raises a ValueError.
        """
        if isinstance(value, str):
            # Lookup order: the working dict itself, then the extra source
            # dict (if any), then the os environment (if requested). Each pass
            # only fills in whatever the earlier passes left unresolved.
            sources = [self._dict]
            if extra_source_dict:
                sources.append(extra_source_dict)
            if use_os_env:
                sources.append(dict(environ))
            text = value
            for mapping in sources:
                text = Template(text).safe_substitute(mapping)
            return text

        if isinstance(value, dict):
            # Deliberately an in-place loop rather than a dict comprehension:
            # each expanded entry is stored right away, so values processed
            # later can already see the results of earlier expansions.
            for name, item in value.items():
                value[name] = self._expand_vars(item, extra_source_dict, use_os_env)
            return value

        if isinstance(value, list):
            return [self._expand_vars(item, extra_source_dict, use_os_env) for item in value]

        if value is None or isinstance(value, (int, float, bool)):
            # Plain scalars carry no templates.
            return value

        raise ValueError(f"Unexpected type {type(value)} for value {value}")