Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%
137 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A hierarchy of benchmark environments."""
7import abc
8import json
9import logging
10from datetime import datetime
11from types import TracebackType
12from typing import (
13 TYPE_CHECKING,
14 Any,
15 Dict,
16 Iterable,
17 List,
18 Literal,
19 Optional,
20 Sequence,
21 Tuple,
22 Type,
23 Union,
24)
26from pytz import UTC
28from mlos_bench.config.schemas import ConfigSchema
29from mlos_bench.dict_templater import DictTemplater
30from mlos_bench.environments.status import Status
31from mlos_bench.services.base_service import Service
32from mlos_bench.tunables.tunable import TunableValue
33from mlos_bench.tunables.tunable_groups import TunableGroups
34from mlos_bench.util import instantiate_from_config, merge_parameters
36if TYPE_CHECKING:
37 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
39_LOG = logging.getLogger(__name__)
class Environment(metaclass=abc.ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """An abstract base of all benchmark environments."""

    @classmethod
    def new(  # pylint: disable=too-many-arguments
        cls,
        *,
        env_name: str,
        class_name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ) -> "Environment":
        """
        Factory method for a new environment with a given config.

        Parameters
        ----------
        env_name: str
            Human-readable name of the environment.
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.environments.remote.HostEnv".
            Must be derived from the `Environment` class.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. It will be passed as a constructor parameter of
            the class specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).

        Returns
        -------
        env : Environment
            An instance of the `Environment` class initialized with `config`.
        """
        assert issubclass(cls, Environment)
        return instantiate_from_config(
            cls,
            class_name,
            name=env_name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).
        """
        # Validate first so that configs built in code (not loaded from files)
        # are checked against the same schema as file-based ones.
        self._validate_json_config(config, name)
        self.name = name
        self.config = config
        self._service = service
        self._service_context: Optional[Service] = None
        self._is_ready = False
        self._in_context = False
        # NOTE(review): this is not a copy of `config["const_args"]`, and it is
        # passed as `dest` to `merge_parameters` below (whose return value is
        # ignored, suggesting in-place update) -- if so, the caller's config is
        # modified too. TODO confirm that aliasing is intended.
        self._const_args: Dict[str, TunableValue] = config.get("const_args", {})

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Environment: '%s' Service: %s",
                name,
                self._service.pprint() if self._service else None,
            )

        if tunables is None:
            _LOG.warning(
                (
                    "No tunables provided for %s. "
                    "Tunable inheritance across composite environments may be broken."
                ),
                name,
            )
            tunables = TunableGroups()

        # Resolve `$name` group references via the global "tunable_params_map".
        groups = self._expand_groups(
            config.get("tunable_params", []),
            (global_config or {}).get("tunable_params_map", {}),
        )
        _LOG.debug("Tunable groups for: '%s' :: %s", name, groups)

        self._tunable_params = tunables.subgroup(groups)

        # If a parameter comes from the tunables, do not require it in the const_args or globals
        req_args = set(config.get("required_args", [])) - set(
            self._tunable_params.get_param_values().keys()
        )
        merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
        self._const_args = self._expand_vars(self._const_args, global_config or {})

        self._params = self._combine_tunables(self._tunable_params)
        _LOG.debug("Parameters for '%s' :: %s", name, self._params)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))

    def _validate_json_config(self, config: dict, name: str) -> None:
        """
        Reconstruct a basic json config that this class might have been
        instantiated from, and validate it against the environment schema.

        This lets configs provided outside the file loading mechanism be
        validated the same way. Empty `name` / `config` are omitted from the
        reconstructed document.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if name:
            json_config["name"] = name
        if config:
            json_config["config"] = config
        ConfigSchema.ENVIRONMENT.validate(json_config)

    @staticmethod
    def _expand_groups(
        groups: Iterable[str],
        groups_exp: Dict[str, Union[str, Sequence[str]]],
    ) -> List[str]:
        """
        Expand `$tunable_group` into actual names of the tunable groups.

        Parameters
        ----------
        groups : List[str]
            Names of the groups of tunables, maybe with `$` prefix (subject to expansion).
        groups_exp : dict
            A dictionary that maps dollar variables for tunable groups to the lists
            of actual tunable groups IDs.

        Returns
        -------
        groups : List[str]
            A flat list of tunable groups IDs for the environment.

        Raises
        ------
        KeyError
            If a `$name` reference has no entry in `groups_exp`.
        """
        res: List[str] = []
        for grp in groups:
            if grp[:1] == "$":
                tunable_group_name = grp[1:]
                if tunable_group_name not in groups_exp:
                    # Both literals must be f-strings; previously the second
                    # one printed "{groups_exp}" verbatim.
                    raise KeyError(
                        f"Expected tunable group name ${tunable_group_name} "
                        f"undefined in {groups_exp}"
                    )
                add_groups = groups_exp[tunable_group_name]
                # A single group name may be given as a bare string.
                res += [add_groups] if isinstance(add_groups, str) else add_groups
            else:
                res.append(grp)
        return res

    @staticmethod
    def _expand_vars(
        params: Dict[str, TunableValue],
        global_config: Dict[str, TunableValue],
    ) -> dict:
        """Expand `$var` into actual values of the variables."""
        return DictTemplater(params).expand_vars(extra_source_dict=global_config)

    @property
    def _config_loader_service(self) -> "SupportsConfigLoading":
        """The config loader service of the attached service (one must be attached)."""
        assert self._service is not None
        return self._service.config_loader_service

    def __enter__(self) -> "Environment":
        """
        Enter the environment's benchmarking context.

        Also enters the attached service's context, if there is a service.
        Entering twice without exiting trips an assertion.
        """
        _LOG.debug("Environment START :: %s", self)
        assert not self._in_context
        if self._service:
            self._service_context = self._service.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of the benchmarking environment.

        Exits the service context entered in `__enter__` (if any). An exception
        raised while exiting the service is logged and then re-raised after the
        context state has been reset. Never suppresses the incoming exception.
        """
        ex_throw = None
        if ex_val is None:
            _LOG.debug("Environment END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        if self._service_context:
            try:
                self._service_context.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
                ex_throw = ex
            finally:
                self._service_context = None
        self._in_context = False
        if ex_throw:
            raise ex_throw
        return False  # Do not suppress exceptions

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"{self.__class__.__name__} :: '{self.name}'"

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment configuration. For composite environments, print
        all children environments as well.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
            Default output is the same as `__repr__`.
        """
        return f'{" " * indent * level}{repr(self)}'

    def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
        """
        Plug tunable values into the base config. If the tunable group is unknown,
        ignore it (it might belong to another environment). This method should never
        mutate the original config or the tunables.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of groups of tunable parameters
            along with the parameters' values.

        Returns
        -------
        params : Dict[str, Union[int, float, str]]
            Free-format dictionary that contains the new environment configuration.
        """
        # Copy const_args so the values written in do not leak back into them.
        return tunables.get_param_values(
            group_names=list(self._tunable_params.get_covariant_group_names()),
            into_params=self._const_args.copy(),
        )

    @property
    def tunable_params(self) -> TunableGroups:
        """
        Get the configuration space of the given environment.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant groups of tunable parameters.
        """
        return self._tunable_params

    @property
    def parameters(self) -> Dict[str, TunableValue]:
        """
        Key/value pairs of all environment parameters (i.e., `const_args` and
        `tunable_params`). Note that before `.setup()` is called, all tunables will be
        set to None.

        Returns
        -------
        parameters : Dict[str, TunableValue]
            Key/value pairs of all environment parameters
            (i.e., `const_args` and `tunable_params`).
        """
        return self._params

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up a new benchmark environment, if necessary. This method must be
        idempotent, i.e., calling it several times in a row should be equivalent to a
        single call.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        _LOG.info("Setup %s :: %s", self, tunables)
        assert isinstance(tunables, TunableGroups)

        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context

        # Assign new values to the environment's tunable parameters:
        groups = list(self._tunable_params.get_covariant_group_names())
        self._tunable_params.assign(tunables.get_param_values(groups))

        # Write to the log whether the environment needs to be reset.
        # (Derived classes still have to check `self._tunable_params.is_updated()`).
        is_updated = self._tunable_params.is_updated()
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Env '%s': Tunable groups reset = %s :: %s",
                self,
                is_updated,
                {
                    name: self._tunable_params.is_updated([name])
                    for name in self._tunable_params.get_covariant_group_names()
                },
            )
        else:
            _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)

        # Combine tunables, const_args, and global config into `self._params`:
        self._params = self._combine_tunables(tunables)
        merge_parameters(dest=self._params, source=global_config)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))

        return True

    def teardown(self) -> None:
        """
        Tear down the benchmark environment.

        This method must be idempotent, i.e., calling it several times in a row should
        be equivalent to a single call.
        """
        _LOG.info("Teardown %s", self)
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        self._is_ready = False

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Execute the run script for this environment.

        For instance, this may start a new experiment, download results, reconfigure
        the environment, etc. Details are configurable via the environment config.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        # The base implementation only reports status; it produces no output.
        (status, timestamp, _) = self.status()
        return (status, timestamp, None)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        timestamp = datetime.now(UTC)
        if self._is_ready:
            return (Status.READY, timestamp, [])
        _LOG.warning("Environment not ready: %s", self)
        return (Status.PENDING, timestamp, [])