Coverage for mlos_bench/mlos_bench/dict_templater.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Simple class to help with nested dictionary ``$var`` templating in configuration file 

6expansions. 

7""" 

8 

9from copy import deepcopy 

10from string import Template 

11from typing import Any, Dict, Optional 

12 

13from mlos_bench.os_environ import environ 

14 

15 

class DictTemplater:  # pylint: disable=too-few-public-methods
    """Simple class to help with nested dictionary ``$var`` templating."""

    def __init__(self, source_dict: Dict[str, Any]):
        """
        Initialize the templater.

        Parameters
        ----------
        source_dict : Dict[str, Any]
            The template dict to use for source variables.
            It is deep-copied, so later mutations by the caller do not affect
            the templates used by :py:meth:`expand_vars`.
        """
        # A copy of the initial data structure we were given with templates intact.
        self._template_dict = deepcopy(source_dict)
        # The source/target dictionary to expand.
        self._dict: Dict[str, Any] = {}

    def expand_vars(
        self,
        *,
        extra_source_dict: Optional[Dict[str, Any]] = None,
        use_os_env: bool = False,
    ) -> Dict[str, Any]:
        """
        Expand the template variables in the destination dictionary.

        Each call starts from a fresh deep copy of the original templates, so
        repeated calls (e.g., with different ``extra_source_dict`` values) do
        not see the results of earlier expansions.

        Each ``$var`` is looked up, in order, in: the top-level dictionary being
        expanded (including values already expanded earlier in iteration order),
        then ``extra_source_dict``, then (optionally) the OS environment.
        References that cannot be resolved are left in place unchanged.

        Parameters
        ----------
        extra_source_dict : Optional[Dict[str, Any]]
            An optional extra source dictionary to use for expansion.
        use_os_env : bool
            Whether to use the OS environment variables as a final fallback
            for expansion.

        Returns
        -------
        Dict[str, Any]
            The expanded dictionary.

        Raises
        ------
        ValueError
            On unsupported nested types.
        """
        self._dict = deepcopy(self._template_dict)
        # The dict branch of _expand_vars updates self._dict in place (so later
        # keys can refer to already-expanded earlier ones) and returns it.
        self._dict = self._expand_vars(self._dict, extra_source_dict, use_os_env)
        assert isinstance(self._dict, dict)
        return self._dict

    def _expand_vars(
        self,
        value: Any,
        extra_source_dict: Optional[Dict[str, Any]],
        use_os_env: bool,
        os_env: Optional[Dict[str, str]] = None,
    ) -> Any:
        """
        Recursively expand ``$var`` strings in the currently operating dictionary.

        Parameters
        ----------
        value : Any
            The (sub)value to expand. Dicts are updated in place; lists are rebuilt.
        extra_source_dict : Optional[Dict[str, Any]]
            Secondary lookup source, consulted after ``self._dict``.
        use_os_env : bool
            Whether to fall back to the OS environment.
        os_env : Optional[Dict[str, str]]
            Snapshot of the OS environment shared across the recursion.
            When ``use_os_env`` is set and no snapshot is given, one is taken
            here and passed down, so the environment is copied once per
            traversal instead of once per string leaf.

        Returns
        -------
        Any
            The expanded value.

        Raises
        ------
        ValueError
            On unsupported nested types.
        """
        if use_os_env and os_env is None:
            # Snapshot the environment once for the whole traversal.
            os_env = dict(environ)
        if isinstance(value, str):
            # First try to expand all $vars internally.
            # Note: nested values are also resolved against the top-level dict.
            value = Template(value).safe_substitute(self._dict)
            # Next, if there are any left, try to expand them from the extra source dict.
            if extra_source_dict:
                value = Template(value).safe_substitute(extra_source_dict)
            # Finally, fall back to the OS environment snapshot.
            if use_os_env and os_env is not None:
                value = Template(value).safe_substitute(os_env)
        elif isinstance(value, dict):
            # Note: we use a loop instead of dict comprehension in order to
            # allow secondary expansion of subsequent values immediately.
            # Reassigning existing keys during iteration is safe (size is unchanged).
            for key, val in value.items():
                value[key] = self._expand_vars(val, extra_source_dict, use_os_env, os_env)
        elif isinstance(value, list):
            value = [
                self._expand_vars(val, extra_source_dict, use_os_env, os_env) for val in value
            ]
        elif isinstance(value, (int, float, bool)) or value is None:
            return value
        else:
            raise ValueError(f"Unexpected type {type(value)} for value {value}")
        return value