Coverage for mlos_bench/mlos_bench/environments/script_env.py: 100%
32 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Base scriptable benchmark environment.
7"""
9import abc
10import logging
11import re
12from typing import Dict, Iterable, Optional
14from mlos_bench.environments.base_environment import Environment
15from mlos_bench.services.base_service import Service
16from mlos_bench.tunables.tunable import TunableValue
17from mlos_bench.tunables.tunable_groups import TunableGroups
19from mlos_bench.util import try_parse_val
_LOG = logging.getLogger(__name__)


class ScriptEnv(Environment, metaclass=abc.ABCMeta):
    """
    Base Environment that runs scripts for setup/run/teardown.
    """

    # Any character outside [a-zA-Z0-9_]; such characters are replaced with
    # underscores when turning parameter names into shell variable names.
    _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")

    def __init__(self,
                 *,
                 name: str,
                 config: dict,
                 global_config: Optional[dict] = None,
                 tunables: Optional[TunableGroups] = None,
                 service: Optional[Service] = None):
        """
        Create a new environment for script execution.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the `tunable_params`
            and the `const_args` sections. It must also have at least one of
            the following parameters: {`setup`, `run`, `teardown`}.
            Additional parameters:
              * `shell_env_params` - an array of parameters to pass to the script
                as shell environment variables,
              * `shell_env_params_rename` - a dictionary of {to: from} mappings
                of the script parameters. If not specified, replace all
                non-alphanumeric characters with underscores, and
              * `results_stdout_pattern` - an optional regular expression that is
                applied (in MULTILINE mode) to the script's stdout to extract
                results. It must have exactly two capturing groups: the result
                name and its value.
            If neither `shell_env_params` nor `shell_env_params_rename` are specified,
            *no* additional shell parameters will be passed to the script
            (unless `_get_env_params` is called with `restrict=False`).
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        ValueError
            If `results_stdout_pattern` is given but does not contain exactly
            two capturing groups.
        """
        super().__init__(name=name, config=config, global_config=global_config,
                         tunables=tunables, service=service)

        self._script_setup = self.config.get("setup")
        self._script_run = self.config.get("run")
        self._script_teardown = self.config.get("teardown")

        # `or` also covers an explicit null in the config, which would otherwise
        # break iteration / dict.update() later in `_get_env_params`.
        self._shell_env_params: Iterable[str] = self.config.get("shell_env_params") or []
        self._shell_env_params_rename: Dict[str, str] = \
            self.config.get("shell_env_params_rename") or {}

        results_stdout_pattern = self.config.get("results_stdout_pattern")
        self._results_stdout_pattern: Optional[re.Pattern[str]] = None
        if results_stdout_pattern:
            pattern = re.compile(results_stdout_pattern, flags=re.MULTILINE)
            # `_extract_stdout_results` unpacks every match into (name, value).
            # With any other group count `findall` returns bare strings or longer
            # tuples, which either fail to unpack much later or (for a 2-character
            # match) silently yield a bogus name/value pair - so fail fast here.
            if pattern.groups != 2:
                raise ValueError(
                    "results_stdout_pattern must have exactly 2 capturing groups "
                    f"(name, value), got {pattern.groups}: {results_stdout_pattern!r}")
            self._results_stdout_pattern = pattern

    def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
        """
        Get the *shell* environment parameters to be passed to the script.

        Each selected parameter is exported under its own name with every
        character outside ``[a-zA-Z0-9_]`` replaced by an underscore. The
        explicit {to: from} entries of `shell_env_params_rename` are then added
        on top, overriding an auto-generated name when the two collide.
        All values are converted with `str()`.

        Parameters
        ----------
        restrict : bool
            If True, only return the parameters that are in the `_shell_env_params`
            list (plus the explicit renames). If False, return all parameters in
            `_params` (plus the explicit renames) with some possible conversions.

        Returns
        -------
        env_params : Dict[str, str]
            Parameters to pass as *shell* environment variables into the script.
            This is usually a subset of `_params` with some possible conversions.

        Raises
        ------
        KeyError
            If a parameter listed in `shell_env_params`, or used as a source in
            `shell_env_params_rename`, is not present in `_params`.
        """
        input_params = self._shell_env_params if restrict else self._params.keys()
        # Maps shell variable name -> source parameter name.
        rename = {self._RE_INVALID.sub("_", key): key for key in input_params}
        rename.update(self._shell_env_params_rename)
        return {key_sub: str(self._params[key]) for (key_sub, key) in rename.items()}

    def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]:
        """
        Extract the results from the stdout of the script.

        Every match of `results_stdout_pattern` contributes one entry: the first
        capturing group is the result name, the second is its value (passed
        through `try_parse_val`). If the same name matches more than once, the
        last match wins.

        Parameters
        ----------
        stdout : str
            The stdout of the script.

        Returns
        -------
        results : Dict[str, TunableValue]
            A dictionary of results extracted from the stdout.
            Empty if no `results_stdout_pattern` was configured.
        """
        if not self._results_stdout_pattern:
            return {}
        _LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout)
        return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)}