Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%
139 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6A hierarchy of benchmark environments.
7"""
9import abc
10import json
11import logging
12from datetime import datetime
13from types import TracebackType
14from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
15from typing_extensions import Literal
17from pytz import UTC
19from mlos_bench.config.schemas import ConfigSchema
20from mlos_bench.dict_templater import DictTemplater
21from mlos_bench.environments.status import Status
22from mlos_bench.services.base_service import Service
23from mlos_bench.tunables.tunable import TunableValue
24from mlos_bench.tunables.tunable_groups import TunableGroups
25from mlos_bench.util import instantiate_from_config, merge_parameters
27if TYPE_CHECKING:
28 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
30_LOG = logging.getLogger(__name__)
33class Environment(metaclass=abc.ABCMeta):
34 # pylint: disable=too-many-instance-attributes
35 """
36 An abstract base of all benchmark environments.
37 """
39 @classmethod
40 def new(cls,
41 *,
42 env_name: str,
43 class_name: str,
44 config: dict,
45 global_config: Optional[dict] = None,
46 tunables: Optional[TunableGroups] = None,
47 service: Optional[Service] = None,
48 ) -> "Environment":
49 """
50 Factory method for a new environment with a given config.
52 Parameters
53 ----------
54 env_name: str
55 Human-readable name of the environment.
56 class_name: str
57 FQN of a Python class to instantiate, e.g.,
58 "mlos_bench.environments.remote.HostEnv".
59 Must be derived from the `Environment` class.
60 config : dict
61 Free-format dictionary that contains the benchmark environment
62 configuration. It will be passed as a constructor parameter of
63 the class specified by `name`.
64 global_config : dict
65 Free-format dictionary of global parameters (e.g., security credentials)
66 to be mixed in into the "const_args" section of the local config.
67 tunables : TunableGroups
68 A collection of groups of tunable parameters for all environments.
69 service: Service
70 An optional service object (e.g., providing methods to
71 deploy or reboot a VM/Host, etc.).
73 Returns
74 -------
75 env : Environment
76 An instance of the `Environment` class initialized with `config`.
77 """
78 assert issubclass(cls, Environment)
79 return instantiate_from_config(
80 cls,
81 class_name,
82 name=env_name,
83 config=config,
84 global_config=global_config,
85 tunables=tunables,
86 service=service
87 )
89 def __init__(self,
90 *,
91 name: str,
92 config: dict,
93 global_config: Optional[dict] = None,
94 tunables: Optional[TunableGroups] = None,
95 service: Optional[Service] = None):
96 """
97 Create a new environment with a given config.
99 Parameters
100 ----------
101 name: str
102 Human-readable name of the environment.
103 config : dict
104 Free-format dictionary that contains the benchmark environment
105 configuration. Each config must have at least the "tunable_params"
106 and the "const_args" sections.
107 global_config : dict
108 Free-format dictionary of global parameters (e.g., security credentials)
109 to be mixed in into the "const_args" section of the local config.
110 tunables : TunableGroups
111 A collection of groups of tunable parameters for all environments.
112 service: Service
113 An optional service object (e.g., providing methods to
114 deploy or reboot a VM/Host, etc.).
115 """
116 self._validate_json_config(config, name)
117 self.name = name
118 self.config = config
119 self._service = service
120 self._service_context: Optional[Service] = None
121 self._is_ready = False
122 self._in_context = False
123 self._const_args: Dict[str, TunableValue] = config.get("const_args", {})
125 if _LOG.isEnabledFor(logging.DEBUG):
126 _LOG.debug("Environment: '%s' Service: %s", name,
127 self._service.pprint() if self._service else None)
129 if tunables is None:
130 _LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name)
131 tunables = TunableGroups()
133 groups = self._expand_groups(
134 config.get("tunable_params", []),
135 (global_config or {}).get("tunable_params_map", {}))
136 _LOG.debug("Tunable groups for: '%s' :: %s", name, groups)
138 self._tunable_params = tunables.subgroup(groups)
140 # If a parameter comes from the tunables, do not require it in the const_args or globals
141 req_args = (
142 set(config.get("required_args", [])) -
143 set(self._tunable_params.get_param_values().keys())
144 )
145 merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
146 self._const_args = self._expand_vars(self._const_args, global_config or {})
148 self._params = self._combine_tunables(self._tunable_params)
149 _LOG.debug("Parameters for '%s' :: %s", name, self._params)
151 if _LOG.isEnabledFor(logging.DEBUG):
152 _LOG.debug("Config for: '%s'\n%s",
153 name, json.dumps(self.config, indent=2))
155 def _validate_json_config(self, config: dict, name: str) -> None:
156 """
157 Reconstructs a basic json config that this class might have been
158 instantiated from in order to validate configs provided outside the
159 file loading mechanism.
160 """
161 json_config: dict = {
162 "class": self.__class__.__module__ + "." + self.__class__.__name__,
163 }
164 if name:
165 json_config["name"] = name
166 if config:
167 json_config["config"] = config
168 ConfigSchema.ENVIRONMENT.validate(json_config)
170 @staticmethod
171 def _expand_groups(groups: Iterable[str],
172 groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]:
173 """
174 Expand `$tunable_group` into actual names of the tunable groups.
176 Parameters
177 ----------
178 groups : List[str]
179 Names of the groups of tunables, maybe with `$` prefix (subject to expansion).
180 groups_exp : dict
181 A dictionary that maps dollar variables for tunable groups to the lists
182 of actual tunable groups IDs.
184 Returns
185 -------
186 groups : List[str]
187 A flat list of tunable groups IDs for the environment.
188 """
189 res: List[str] = []
190 for grp in groups:
191 if grp[:1] == "$":
192 tunable_group_name = grp[1:]
193 if tunable_group_name not in groups_exp:
194 raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}")
195 add_groups = groups_exp[tunable_group_name]
196 res += [add_groups] if isinstance(add_groups, str) else add_groups
197 else:
198 res.append(grp)
199 return res
201 @staticmethod
202 def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict:
203 """
204 Expand `$var` into actual values of the variables.
205 """
206 return DictTemplater(params).expand_vars(extra_source_dict=global_config)
208 @property
209 def _config_loader_service(self) -> "SupportsConfigLoading":
210 assert self._service is not None
211 return self._service.config_loader_service
213 def __enter__(self) -> 'Environment':
214 """
215 Enter the environment's benchmarking context.
216 """
217 _LOG.debug("Environment START :: %s", self)
218 assert not self._in_context
219 if self._service:
220 self._service_context = self._service.__enter__()
221 self._in_context = True
222 return self
224 def __exit__(self, ex_type: Optional[Type[BaseException]],
225 ex_val: Optional[BaseException],
226 ex_tb: Optional[TracebackType]) -> Literal[False]:
227 """
228 Exit the context of the benchmarking environment.
229 """
230 ex_throw = None
231 if ex_val is None:
232 _LOG.debug("Environment END :: %s", self)
233 else:
234 assert ex_type and ex_val
235 _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
236 assert self._in_context
237 if self._service_context:
238 try:
239 self._service_context.__exit__(ex_type, ex_val, ex_tb)
240 # pylint: disable=broad-exception-caught
241 except Exception as ex:
242 _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
243 ex_throw = ex
244 finally:
245 self._service_context = None
246 self._in_context = False
247 if ex_throw:
248 raise ex_throw
249 return False # Do not suppress exceptions
251 def __str__(self) -> str:
252 return self.name
254 def __repr__(self) -> str:
255 return f"{self.__class__.__name__} :: '{self.name}'"
257 def pprint(self, indent: int = 4, level: int = 0) -> str:
258 """
259 Pretty-print the environment configuration.
260 For composite environments, print all children environments as well.
262 Parameters
263 ----------
264 indent : int
265 Number of spaces to indent the output. Default is 4.
266 level : int
267 Current level of indentation. Default is 0.
269 Returns
270 -------
271 pretty : str
272 Pretty-printed environment configuration.
273 Default output is the same as `__repr__`.
274 """
275 return f'{" " * indent * level}{repr(self)}'
277 def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
278 """
279 Plug tunable values into the base config. If the tunable group is unknown,
280 ignore it (it might belong to another environment). This method should
281 never mutate the original config or the tunables.
283 Parameters
284 ----------
285 tunables : TunableGroups
286 A collection of groups of tunable parameters
287 along with the parameters' values.
289 Returns
290 -------
291 params : Dict[str, Union[int, float, str]]
292 Free-format dictionary that contains the new environment configuration.
293 """
294 return tunables.get_param_values(
295 group_names=list(self._tunable_params.get_covariant_group_names()),
296 into_params=self._const_args.copy())
298 @property
299 def tunable_params(self) -> TunableGroups:
300 """
301 Get the configuration space of the given environment.
303 Returns
304 -------
305 tunables : TunableGroups
306 A collection of covariant groups of tunable parameters.
307 """
308 return self._tunable_params
310 @property
311 def parameters(self) -> Dict[str, TunableValue]:
312 """
313 Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
314 Note that before `.setup()` is called, all tunables will be set to None.
316 Returns
317 -------
318 parameters : Dict[str, TunableValue]
319 Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
320 """
321 return self._params
323 def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
324 """
325 Set up a new benchmark environment, if necessary. This method must be
326 idempotent, i.e., calling it several times in a row should be
327 equivalent to a single call.
329 Parameters
330 ----------
331 tunables : TunableGroups
332 A collection of tunable parameters along with their values.
333 global_config : dict
334 Free-format dictionary of global parameters of the environment
335 that are not used in the optimization process.
337 Returns
338 -------
339 is_success : bool
340 True if operation is successful, false otherwise.
341 """
342 _LOG.info("Setup %s :: %s", self, tunables)
343 assert isinstance(tunables, TunableGroups)
345 # Make sure we create a context before invoking setup/run/status/teardown
346 assert self._in_context
348 # Assign new values to the environment's tunable parameters:
349 groups = list(self._tunable_params.get_covariant_group_names())
350 self._tunable_params.assign(tunables.get_param_values(groups))
352 # Write to the log whether the environment needs to be reset.
353 # (Derived classes still have to check `self._tunable_params.is_updated()`).
354 is_updated = self._tunable_params.is_updated()
355 if _LOG.isEnabledFor(logging.DEBUG):
356 _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, {
357 name: self._tunable_params.is_updated([name])
358 for name in self._tunable_params.get_covariant_group_names()
359 })
360 else:
361 _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)
363 # Combine tunables, const_args, and global config into `self._params`:
364 self._params = self._combine_tunables(tunables)
365 merge_parameters(dest=self._params, source=global_config)
367 if _LOG.isEnabledFor(logging.DEBUG):
368 _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))
370 return True
372 def teardown(self) -> None:
373 """
374 Tear down the benchmark environment. This method must be idempotent,
375 i.e., calling it several times in a row should be equivalent to a
376 single call.
377 """
378 _LOG.info("Teardown %s", self)
379 # Make sure we create a context before invoking setup/run/status/teardown
380 assert self._in_context
381 self._is_ready = False
383 def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
384 """
385 Execute the run script for this environment.
387 For instance, this may start a new experiment, download results, reconfigure
388 the environment, etc. Details are configurable via the environment config.
390 Returns
391 -------
392 (status, timestamp, output) : (Status, datetime, dict)
393 3-tuple of (Status, timestamp, output) values, where `output` is a dict
394 with the results or None if the status is not COMPLETED.
395 If run script is a benchmark, then the score is usually expected to
396 be in the `score` field.
397 """
398 # Make sure we create a context before invoking setup/run/status/teardown
399 assert self._in_context
400 (status, timestamp, _) = self.status()
401 return (status, timestamp, None)
403 def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
404 """
405 Check the status of the benchmark environment.
407 Returns
408 -------
409 (benchmark_status, timestamp, telemetry) : (Status, datetime, list)
410 3-tuple of (benchmark status, timestamp, telemetry) values.
411 `timestamp` is UTC time stamp of the status; it's current time by default.
412 `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
413 """
414 # Make sure we create a context before invoking setup/run/status/teardown
415 assert self._in_context
416 timestamp = datetime.now(UTC)
417 if self._is_ready:
418 return (Status.READY, timestamp, [])
419 _LOG.warning("Environment not ready: %s", self)
420 return (Status.PENDING, timestamp, [])