Coverage for mlos_bench/mlos_bench/environments/script_env.py: 100%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Base scriptable benchmark environment.
8TODO: Document how variable propagation works in the script environments using
9shell_env_params, required_args, const_args, etc.
10"""
12import abc
13import logging
14import re
15from typing import Dict, Iterable, Optional
17from mlos_bench.environments.base_environment import Environment
18from mlos_bench.services.base_service import Service
19from mlos_bench.tunables.tunable import TunableValue
20from mlos_bench.tunables.tunable_groups import TunableGroups
21from mlos_bench.util import try_parse_val
23_LOG = logging.getLogger(__name__)
class ScriptEnv(Environment, metaclass=abc.ABCMeta):
    """
    Abstract base for environments whose phases are driven by scripts.

    The scripts for the individual phases (e.g., :py:meth:`.Environment.setup`,
    :py:meth:`.Environment.run`, :py:meth:`.Environment.teardown`, etc.) are read
    from the environment config.
    """

    # Matches every character that is not an ASCII letter, a digit, or an underscore;
    # used to derive shell variable names from parameter names.
    _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Set up a new script-based environment.

        Parameters
        ----------
        name : str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary with the benchmark environment configuration.
            Each config must have at least the `tunable_params` and the
            `const_args` sections, and at least one of the following
            parameters: {`setup`, `run`, `teardown`}.
            Additional parameters read here:

            - `shell_env_params` - an array of parameters to pass to the script
              as shell environment variables,
            - `shell_env_params_rename` - a dictionary of {to: from} mappings
              of the script parameters. Parameters without an explicit mapping
              get every character outside `[a-zA-Z0-9_]` replaced with an
              underscore, and
            - `results_stdout_pattern` - an optional regular expression
              (compiled with `re.MULTILINE`) for pulling results out of the
              script's stdout.

            If neither `shell_env_params` nor `shell_env_params_rename` are
            specified, *no* additional shell parameters will be passed to the script.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service : Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # Per-phase scripts; a phase that is absent from the config stays None.
        self._script_setup = self.config.get("setup")
        self._script_run = self.config.get("run")
        self._script_teardown = self.config.get("teardown")

        # Which parameters to export to the shell, and under which names.
        self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
        self._shell_env_params_rename: Dict[str, str] = self.config.get(
            "shell_env_params_rename", {}
        )

        # Compile the stdout pattern only when a non-empty one is configured.
        self._results_stdout_pattern: Optional[re.Pattern[str]] = None
        pattern_text = self.config.get("results_stdout_pattern")
        if pattern_text:
            self._results_stdout_pattern = re.compile(pattern_text, flags=re.MULTILINE)

    def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
        """
        Build the *shell* environment variables to hand over to the script.

        Parameters
        ----------
        restrict : bool
            If True, start from the parameters listed in `_shell_env_params` only.
            If False, start from every parameter in `_params`.
            In both cases the `_shell_env_params_rename` mappings are applied on top.

        Returns
        -------
        env_params : Dict[str, str]
            Shell variable name -> parameter value converted to a string.
            This is usually a subset of `_params` with some possible conversions.
        """
        if restrict:
            selected = self._shell_env_params
        else:
            selected = self._params.keys()

        # Map {shell variable name: parameter name}, sanitizing the names first ...
        shell_to_param = {}
        for param_name in selected:
            shell_to_param[self._RE_INVALID.sub("_", param_name)] = param_name
        # ... then let the explicit {to: from} renames override or extend that map.
        shell_to_param.update(self._shell_env_params_rename)

        env_params = {}
        for shell_name, param_name in shell_to_param.items():
            env_params[shell_name] = str(self._params[param_name])
        return env_params

    def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]:
        """
        Pull the benchmark results out of the script's stdout.

        Parameters
        ----------
        stdout : str
            The stdout of the script.

        Returns
        -------
        results : Dict[str, TunableValue]
            One entry per match of `_results_stdout_pattern`, with each matched
            value passed through `try_parse_val`. Empty if no pattern is configured.
        """
        pattern = self._results_stdout_pattern
        if pattern:
            _LOG.debug("Extract regex: '%s' from: '%s'", pattern, stdout)
            results = {}
            # Each match is unpacked as a (key, value) pair.
            for key, val in pattern.findall(stdout):
                results[key] = try_parse_val(val)
            return results
        return {}