Coverage for mlos_bench/mlos_bench/environments/local/local_env.py: 96%
133 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Scheduler-side benchmark environment to run scripts locally.
8TODO: Reference the script_env.py file for the base class.
9"""
11import json
12import logging
13import sys
14from collections.abc import Iterable, Mapping
15from contextlib import nullcontext
16from datetime import datetime
17from tempfile import TemporaryDirectory
18from types import TracebackType
19from typing import Any, Literal
21import pandas
23from mlos_bench.environments.base_environment import Environment
24from mlos_bench.environments.script_env import ScriptEnv
25from mlos_bench.environments.status import Status
26from mlos_bench.services.base_service import Service
27from mlos_bench.services.types.local_exec_type import SupportsLocalExec
28from mlos_bench.tunables.tunable_groups import TunableGroups
29from mlos_bench.tunables.tunable_types import TunableValue
30from mlos_bench.util import datetime_parser, path_join
# Module-scoped logger; its name follows this module's import path.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class LocalEnv(ScriptEnv):
    # pylint: disable=too-many-instance-attributes
    """Scheduler-side Environment that runs scripts locally."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
        """
        Create a new environment for local execution.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
            `LocalEnv` must also have at least some of the following parameters:
            {setup, run, teardown, dump_params_file, read_results_file}
            Other optional keys read by this class: "dump_meta_file",
            "read_telemetry_file", and "temp_dir" (read on context entry).
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            A service object (e.g., providing methods to deploy or reboot a VM,
            etc.). Although the parameter is optional in the signature, `LocalEnv`
            asserts that a service implementing `SupportsLocalExec` is available.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        assert self._service is not None and isinstance(
            self._service, SupportsLocalExec
        ), "LocalEnv requires a service that supports local execution"
        self._local_exec_service: SupportsLocalExec = self._service

        # Working directory for the scripts; only set while inside the context.
        self._temp_dir: str | None = None
        self._temp_dir_context: TemporaryDirectory | nullcontext | None = None

        # Optional file names, resolved relative to the working directory.
        self._dump_params_file: str | None = self.config.get("dump_params_file")
        self._dump_meta_file: str | None = self.config.get("dump_meta_file")

        self._read_results_file: str | None = self.config.get("read_results_file")
        self._read_telemetry_file: str | None = self.config.get("read_telemetry_file")

    def __enter__(self) -> Environment:
        """
        Enter the context: acquire the working directory, then enter the parent.

        If entering the parent context fails, the working directory is released
        again so that the environment is left in a clean, re-enterable state.
        """
        assert self._temp_dir is None and self._temp_dir_context is None
        temp_dir_context = self._local_exec_service.temp_dir_context(
            self.config.get("temp_dir"),
        )
        # Only record the context once it has actually been entered.
        self._temp_dir = temp_dir_context.__enter__()
        self._temp_dir_context = temp_dir_context
        try:
            return super().__enter__()
        except BaseException:
            # Don't leak the temp dir (or leave stale state behind) if the
            # parent context could not be entered.
            self._temp_dir = None
            self._temp_dir_context = None
            temp_dir_context.__exit__(*sys.exc_info())
            raise

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exit the context of the benchmarking environment.

        Releases the working directory, then exits the parent context. The parent
        context is exited even if releasing the working directory raises.
        """
        assert not (self._temp_dir is None or self._temp_dir_context is None)
        temp_dir_context = self._temp_dir_context
        # Reset state first so the object is consistent even if cleanup fails.
        self._temp_dir = None
        self._temp_dir_context = None
        try:
            temp_dir_context.__exit__(ex_type, ex_val, ex_tb)
        finally:
            parent_result = super().__exit__(ex_type, ex_val, ex_tb)
        return parent_result

    def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
        """
        Check if the environment is ready and set up the application and benchmarks, if
        necessary.

        Optionally dumps the tunable values and tunable metadata to JSON files in the
        working directory, then runs the setup script (if any) there.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable OS and application parameters along with their
            values. In a local environment these could be used to prepare a config
            file on the scheduler prior to transferring it to the remote environment,
            for instance.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        if not super().setup(tunables, global_config):
            return False

        _LOG.info("Set up the environment locally: '%s' at %s", self, self._temp_dir)
        assert self._temp_dir is not None

        if self._dump_params_file:
            fname = path_join(self._temp_dir, self._dump_params_file)
            _LOG.debug("Dump tunables to file: %s", fname)
            with open(fname, "w", encoding="utf-8") as fh_tunables:
                # Only the tunable values are dumped here (not the const_args).
                json.dump(self._tunable_params.get_param_values(), fh_tunables)

        if self._dump_meta_file:
            fname = path_join(self._temp_dir, self._dump_meta_file)
            _LOG.debug("Dump tunables metadata to file: %s", fname)
            with open(fname, "w", encoding="utf-8") as fh_meta:
                # Only tunables that actually carry metadata are written.
                json.dump(
                    {
                        tunable.name: tunable.meta
                        for (tunable, _group) in self._tunable_params
                        if tunable.meta
                    },
                    fh_meta,
                )

        if self._script_setup:
            (return_code, _output) = self._local_exec(self._script_setup, self._temp_dir)
            self._is_ready = return_code == 0
        else:
            self._is_ready = True

        return self._is_ready

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Run a script in the local scheduler environment.

        Results are collected from the run script's stdout and, if configured, from
        the `read_results_file` CSV; values from the CSV override stdout values.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results, or None if the run failed.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.

        Raises
        ------
        ValueError
            If the results CSV has more than one row and is not in the
            (metric, value) long format.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            return result

        assert self._temp_dir is not None

        stdout_data: dict[str, TunableValue] = {}
        if self._script_run:
            (return_code, output) = self._local_exec(self._script_run, self._temp_dir)
            if return_code != 0:
                return (Status.FAILED, timestamp, None)
            stdout_data = self._extract_stdout_results(output.get("stdout", ""))

        # FIXME: We should not be assuming that the only output file type is a CSV.
        if not self._read_results_file:
            _LOG.debug("Not reading the data at: %s", self)
            return (Status.SUCCEEDED, timestamp, stdout_data)

        try:
            data = self._normalize_columns(
                pandas.read_csv(
                    self._config_loader_service.resolve_path(
                        self._read_results_file,
                        extra_paths=[self._temp_dir],
                    ),
                    index_col=False,
                )
            )
        except pandas.errors.EmptyDataError:
            _LOG.warning("Empty metrics file - fail the run")
            return (Status.FAILED, timestamp, None)

        _LOG.debug("Read data:\n%s", data)
        if len(data) == 0:
            _LOG.warning("No data in the metrics file - fail the run")
            return (Status.FAILED, timestamp, None)
        elif list(data.columns) == ["metric", "value"]:
            _LOG.info(
                "Local results have (metric,value) header and %d rows: assume long format",
                len(data),
            )
            # Pivot long format into a single wide row: one column per metric.
            data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list())
            # Try to convert string metrics to numbers.
            # Non-numeric values fall back to their original (string) form.
            data = data.apply(
                pandas.to_numeric,
                errors="coerce",
            ).fillna(data)
        elif len(data) == 1:
            _LOG.info("Local results have 1 row: assume wide format")
        else:
            raise ValueError(f"Invalid data format: {data}")

        stdout_data.update(data.iloc[-1].to_dict())
        _LOG.info("Local run complete: %s ::\n%s", self, stdout_data)
        return (Status.SUCCEEDED, timestamp, stdout_data)

    @staticmethod
    def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame:
        """Strip trailing spaces from column names (Windows only)."""
        # Windows cmd interpretation of > redirect symbols can leave trailing spaces in
        # the final column, which leads to misnamed columns.
        # For now, we simply strip trailing spaces from column names to account for that.
        if sys.platform == "win32":
            data.rename(str.rstrip, axis="columns", inplace=True)
        return data

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Check the status of the environment and read any available telemetry.

        Returns
        -------
        (status, timestamp, telemetry) : (Status, datetime.datetime, list)
            The status and timestamp from the parent class, plus a list of
            (timestamp, metric, value) tuples read from `read_telemetry_file`.
            The list is empty if the environment is not ready, no telemetry file
            is configured, or the file is missing or empty.

        Raises
        ------
        ValueError
            If the telemetry file does not have exactly three columns
            (timestamp, metric, value).
        """
        (status, timestamp, _) = super().status()
        if not (self._is_ready and self._read_telemetry_file):
            return (status, timestamp, [])

        assert self._temp_dir is not None
        expected_col_names = ["timestamp", "metric", "value"]
        try:
            fname = self._config_loader_service.resolve_path(
                self._read_telemetry_file,
                extra_paths=[self._temp_dir],
            )

            # TODO: Use the timestamp of the CSV file as our status timestamp?

            # FIXME: We should not be assuming that the only output file type is a CSV.

            data = self._normalize_columns(pandas.read_csv(fname, index_col=False))

            # Validate the shape *before* parsing timestamps so that a malformed
            # file reports this error rather than an unrelated parse failure.
            if len(data.columns) != len(expected_col_names):
                raise ValueError(f"Telemetry data must have columns {expected_col_names}")

            if list(data.columns) != expected_col_names:
                # Assume no header - this is ok for telemetry data.
                # Re-read so the first line is treated as data, not as a header.
                data = pandas.read_csv(fname, index_col=False, names=expected_col_names)

            # Parse the timestamp column once, on the final frame.
            data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")

        except FileNotFoundError as ex:
            _LOG.warning("Telemetry CSV file not found: %s :: %s", self._read_telemetry_file, ex)
            return (status, timestamp, [])
        except pandas.errors.EmptyDataError:
            # The file may exist but not be written yet while the run is in progress.
            _LOG.warning("Telemetry CSV file is empty: %s", self._read_telemetry_file)
            return (status, timestamp, [])

        _LOG.debug("Read telemetry data:\n%s", data)
        col_dtypes: Mapping[int, type] = {0: datetime}
        return (
            status,
            timestamp,
            [
                (pandas.Timestamp(ts).to_pydatetime(), metric, value)
                for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
            ],
        )

    def teardown(self) -> None:
        """Clean up the local environment: run the teardown script (if any), then the parent teardown."""
        if self._script_teardown:
            _LOG.info("Local teardown: %s", self)
            # NOTE(review): unlike setup/run, no cwd is passed here, so the script
            # runs in the local exec service's default directory.
            (return_code, _output) = self._local_exec(self._script_teardown)
            _LOG.info("Local teardown complete: %s :: %s", self, return_code)
        super().teardown()

    def _local_exec(self, script: Iterable[str], cwd: str | None = None) -> tuple[int, dict]:
        """
        Execute a script locally in the scheduler environment.

        Parameters
        ----------
        script : Iterable[str]
            Lines of the script to run locally.
            Treat every line as a separate command to run.
        cwd : str | None
            Work directory to run the script at.

        Returns
        -------
        (return_code, output) : (int, dict)
            Return code of the script and a dict with stdout/stderr. Return code = 0 if successful.
        """
        env_params = self._get_env_params()
        # NOTE(review): env_params may carry values derived from global_config,
        # which can hold credentials; consider logging this at debug level.
        _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params)
        (return_code, stdout, stderr) = self._local_exec_service.local_exec(
            script,
            env=env_params,
            cwd=cwd,
        )
        if return_code != 0:
            _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr)
        return (return_code, {"stdout": stdout, "stderr": stderr})