Coverage for mlos_bench/mlos_bench/environments/composite_env.py: 89%
97 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Composite benchmark environment.
7"""
9import logging
10from datetime import datetime
12from types import TracebackType
13from typing import Any, Dict, List, Optional, Tuple, Type
14from typing_extensions import Literal
16from mlos_bench.services.base_service import Service
17from mlos_bench.environments.status import Status
18from mlos_bench.environments.base_environment import Environment
19from mlos_bench.tunables.tunable import TunableValue
20from mlos_bench.tunables.tunable_groups import TunableGroups
# Module-level logger named after this module; used by CompositeEnv for debug/info/error messages.
22_LOG = logging.getLogger(__name__)
class CompositeEnv(Environment):
    """
    Composite benchmark environment.

    Aggregates a list of child environments and drives them as a single unit:
    children are entered, set up, and run in the order they were added, and
    exited and torn down in the reverse order.
    """

    def __init__(self,
                 *,
                 name: str,
                 config: dict,
                 global_config: Optional[dict] = None,
                 tunables: Optional[TunableGroups] = None,
                 service: Optional[Service] = None):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the environment
            configuration. Must produce at least one child environment via
            its "children" and/or "include_children" sections.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        ValueError
            If neither "include_children" nor "children" yields any child environment.
        """
        super().__init__(name=name, config=config, global_config=global_config,
                         tunables=tunables, service=service)

        # By default, the Environment includes only the tunables explicitly specified
        # in the "tunable_params" section of the config. `CompositeEnv`, however, must
        # retain all tunables from its children environments plus the ones that come
        # from the "include_tunables".
        tunables = tunables.copy() if tunables else TunableGroups()

        _LOG.debug("Build composite environment '%s' START: %s", self, tunables)
        self._children: List[Environment] = []
        # Values returned by the children's `__enter__()`; non-empty only while in context.
        self._child_contexts: List[Environment] = []

        # To support trees of composite environments (e.g. for multiple VM experiments),
        # each CompositeEnv gets a copy of the original global config and adjusts it with
        # the `const_args` specific to it.
        global_config = (global_config or {}).copy()
        for (key, val) in self._const_args.items():
            # `setdefault` keeps values already present in the global config.
            global_config.setdefault(key, val)

        # Children loaded from external config files come first ...
        for child_config_file in config.get("include_children", []):
            for env in self._config_loader_service.load_environment_list(
                    child_config_file, tunables, global_config, self._const_args, self._service):
                self._add_child(env, tunables)

        # ... followed by the children defined inline in this config.
        for child_config in config.get("children", []):
            env = self._config_loader_service.build_environment(
                child_config, tunables, global_config, self._const_args, self._service)
            self._add_child(env, tunables)

        _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)

        if not self._children:
            raise ValueError("At least one child environment must be present")

    def __enter__(self) -> Environment:
        """
        Enter the context of every child environment (in order), then of this one.

        If entering a child or this environment fails, the children that were
        already entered are exited in reverse order before the original
        exception is re-raised, so that no child is left in context.
        """
        prev_contexts = self._child_contexts
        entered: List[Environment] = []
        child_contexts: List[Environment] = []
        try:
            for env in self._children:
                child_contexts.append(env.__enter__())
                entered.append(env)
            self._child_contexts = child_contexts
            return super().__enter__()
        # pylint: disable=broad-exception-caught
        except BaseException as ex:
            # The `with` statement does not call `__exit__` when `__enter__` raises,
            # so roll back the partially entered children here.
            for env in reversed(entered):
                try:
                    env.__exit__(type(ex), ex, ex.__traceback__)
                except Exception as ex_child:
                    _LOG.error("Exception while exiting child environment '%s': %s", env, ex_child)
            self._child_contexts = prev_contexts
            raise

    def __exit__(self, ex_type: Optional[Type[BaseException]],
                 ex_val: Optional[BaseException],
                 ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Exit the context of every child environment (in reverse order), then of this one.

        An exception raised by a child's `__exit__` is logged and does not prevent
        the remaining children from being exited. After this environment itself
        has exited, the most recently caught child exception (if any) is re-raised.

        Returns
        -------
        False, i.e., an exception raised inside the `with` block is not suppressed.
        """
        ex_throw = None
        for env in reversed(self._children):
            try:
                env.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting child environment '%s': %s", env, ex)
                ex_throw = ex
        self._child_contexts = []
        super().__exit__(ex_type, ex_val, ex_tb)
        if ex_throw:
            raise ex_throw
        return False

    @property
    def children(self) -> List[Environment]:
        """
        Return the list of child environments.
        """
        return self._children

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment and its children.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output at each level. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
        """
        # Children are printed one level deeper than this environment.
        return super().pprint(indent, level) + '\n' + '\n'.join(
            child.pprint(indent, level + 1) for child in self._children)

    def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
        """
        Add a new child environment to the composite environment.
        This method is called from the constructor only.

        Parameters
        ----------
        env : Environment
            The child environment to append.
        tunables : TunableGroups
            The constructor's working collection of tunables; the child's
            tunable parameters are merged into it (in place) as well as into
            this environment's own tunable parameters.
        """
        _LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
        self._children.append(env)
        self._tunable_params.merge(env.tunable_params)
        tunables.merge(env.tunable_params)

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up the children environments.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if all children setup() operations are successful,
            false otherwise.
        """
        assert self._in_context
        # Both `and` and `all()` short-circuit: once one setup() returns a falsy
        # value, the remaining children are not set up.
        self._is_ready = super().setup(tunables, global_config) and all(
            env_context.setup(tunables, global_config) for env_context in self._child_contexts)
        return self._is_ready

    def teardown(self) -> None:
        """
        Tear down the children environments. This method is idempotent,
        i.e., calling it several times is equivalent to a single call.
        The environments are being torn down in the reverse order.
        """
        assert self._in_context
        for env_context in reversed(self._child_contexts):
            env_context.teardown()
        super().teardown()

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Submit a new experiment to the environment.
        Run the child environments in order and stop at the first one whose
        status is not good. On success, return the status and the timestamp of
        the *last* child along with the merged metrics of all children (on a
        key collision the later child's value wins). On failure, return the
        status and the timestamp of the failed child and no metrics.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        _LOG.info("Run: %s", self._children)
        (status, timestamp, metrics) = super().run()
        if not status.is_ready():
            return (status, timestamp, metrics)

        joint_metrics: Dict[str, TunableValue] = {}
        for env_context in self._child_contexts:
            _LOG.debug("Child env. run: %s", env_context)
            (status, timestamp, metrics) = env_context.run()
            _LOG.debug("Child env. run results: %s :: %s %s", env_context, status, metrics)
            if not status.is_good():
                _LOG.info("Run failed: %s :: %s", self, status)
                return (status, timestamp, None)
            joint_metrics.update(metrics or {})

        _LOG.info("Run completed: %s :: %s %s", self, status, joint_metrics)
        # Return the status and the timestamp of the last child environment.
        return (status, timestamp, joint_metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        All children are queried, even after one of them reports a status
        that is not good, and their telemetry is concatenated in child order.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, telemetry) = super().status()
        if not status.is_ready():
            return (status, timestamp, telemetry)

        joint_telemetry: List[Tuple[datetime, str, Any]] = []
        final_status: Optional[Status] = None
        for env_context in self._child_contexts:
            (status, timestamp, telemetry) = env_context.status()
            _LOG.debug("Child env. status: %s :: %s", env_context, status)
            joint_telemetry.extend(telemetry)
            # Remember only the first status that is not good.
            if not status.is_good() and final_status is None:
                final_status = status

        final_status = final_status or status
        _LOG.info("Final status: %s :: %s", self, final_status)
        # Return the status of the first failed child environment, or of the last
        # child environment if none failed. The timestamp is always that of the
        # last child environment queried.
        return (final_status, timestamp, joint_telemetry)