Coverage for mlos_bench/mlos_bench/environments/script_env.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base scriptable benchmark environment. 

7""" 

8 

9import abc 

10import logging 

11import re 

12from typing import Dict, Iterable, Optional 

13 

14from mlos_bench.environments.base_environment import Environment 

15from mlos_bench.services.base_service import Service 

16from mlos_bench.tunables.tunable import TunableValue 

17from mlos_bench.tunables.tunable_groups import TunableGroups 

18 

19from mlos_bench.util import try_parse_val 

20 

21_LOG = logging.getLogger(__name__) 

22 

23 

class ScriptEnv(Environment, metaclass=abc.ABCMeta):
    """
    Common base for benchmark environments whose setup, run, and teardown
    phases are driven by scripts.
    """

    # Matches every character outside of [a-zA-Z0-9_]; used to derive
    # shell-safe variable names from parameter names.
    _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")

    def __init__(self,
                 *,
                 name: str,
                 config: dict,
                 global_config: Optional[dict] = None,
                 tunables: Optional[TunableGroups] = None,
                 service: Optional[Service] = None):
        """
        Initialize a script-based environment from its configuration.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary with the benchmark environment configuration.
            Every config must provide at least the `tunable_params` and
            the `const_args` sections, plus at least one of the script
            entries: {`setup`, `run`, `teardown`}.
            Optional entries read by this class:
                * `shell_env_params` - an array of parameter names to hand over
                  to the script as shell environment variables;
                * `shell_env_params_rename` - a dictionary of {to: from} mappings
                  of the script parameters. For parameters without an explicit
                  mapping, every non-alphanumeric character of the name is
                  replaced with an underscore;
                * `results_stdout_pattern` - a regular expression (compiled with
                  `re.MULTILINE`) whose matches are unpacked as (key, value)
                  pairs when extracting results from the script's stdout.
            When both `shell_env_params` and `shell_env_params_rename` are
            omitted, *no* additional shell parameters are passed to the script.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).
        """
        super().__init__(name=name, config=config, global_config=global_config,
                         tunables=tunables, service=service)

        # Script definitions for each phase; an absent entry is stored as None.
        self._script_setup = self.config.get("setup")
        self._script_run = self.config.get("run")
        self._script_teardown = self.config.get("teardown")

        # Which parameters are exported to the shell, and under what names.
        self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
        self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {})

        # Compile the stdout pattern only when a non-empty one is configured.
        pattern_source = self.config.get("results_stdout_pattern")
        self._results_stdout_pattern: Optional[re.Pattern[str]] = None
        if pattern_source:
            self._results_stdout_pattern = re.compile(pattern_source, flags=re.MULTILINE)

    def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
        """
        Build the *shell* environment variables to hand over to the script.

        Parameters
        ----------
        restrict : bool
            When True, start from the names listed in `_shell_env_params` only.
            When False, start from every key of `_params`.
            In both cases the `_shell_env_params_rename` mappings are applied
            on top.

        Returns
        -------
        env_params : Dict[str, str]
            Mapping of shell variable name to the stringified parameter value
            looked up in `_params`.
        """
        selected_names: Iterable[str]
        if restrict:
            selected_names = self._shell_env_params
        else:
            selected_names = self._params.keys()

        # Default shell name: the parameter name with invalid characters
        # replaced by underscores.
        env_to_param: Dict[str, str] = {}
        for param_name in selected_names:
            env_to_param[self._RE_INVALID.sub("_", param_name)] = param_name

        # Explicit {to: from} mappings are added last, so they override any
        # auto-generated entry that has the same shell name.
        env_to_param.update(self._shell_env_params_rename)

        env_params: Dict[str, str] = {}
        for (env_name, param_name) in env_to_param.items():
            env_params[env_name] = str(self._params[param_name])
        return env_params

    def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]:
        """
        Parse the script's stdout into a dictionary of results.

        Parameters
        ----------
        stdout : str
            The stdout of the script.

        Returns
        -------
        results : Dict[str, TunableValue]
            Results found in the stdout by the configured pattern;
            an empty dictionary if no pattern is configured.
        """
        pattern = self._results_stdout_pattern
        if not pattern:
            return {}

        _LOG.debug("Extract regex: '%s' from: '%s'", pattern, stdout)

        # Each match is unpacked as a (key, value) pair, so the pattern is
        # expected to yield exactly two items per match.
        results: Dict[str, TunableValue] = {}
        for (metric, raw_value) in pattern.findall(stdout):
            results[metric] = try_parse_val(raw_value)
        return results