Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%
94 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6TunableGroups definition.
8A collection of :py:class:`.CovariantTunableGroup` s of :py:class:`.Tunable`
9parameters.
11Used to define the configuration space for an
12:py:class:`~mlos_bench.environments.base_environment.Environment` for an
13:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer` to explore.
15Config
16++++++
18The configuration of the tunable parameters is generally given via a JSON config file.
19The syntax looks something like this:
21.. code-block:: json
23 { // starts a TunableGroups config (e.g., for one Environment)
24 "group1": { // starts a CovariantTunableGroup config
25 "cost": 7,
26 "params": {
27 "param1": { // starts a Tunable config, named "param1"
28 "type": "int",
29 "range": [0, 100],
30 "default": 50
31 },
32 "param2": { // starts a new Tunable config, named "param2", within that same group
33 "type": "float",
34 "range": [0.0, 100.0],
35 "default": 50.0
36 },
37 "param3": {
38 "type": "categorical",
39 "values": ["on", "off", "auto"],
40 "default": "auto"
41 }
42 }
43 },
44 "group2": { // starts a new CovariantTunableGroup config
45 "cost": 7,
46 "params": {
47 "some_param1": {
48 "type": "int",
49 "range": [0, 10],
50 "default": 5
51 },
52 "some_param2": {
53 "type": "float",
54 "range": [0.0, 100.0],
55 "default": 50.0
56 },
57 "some_param3": {
58 "type": "categorical",
59 "values": ["red", "green", "blue"],
60 "default": "green"
61 }
62 }
63 }
64 }
66The JSON config is expected to be a dictionary of covariant tunable groups.
68Each covariant group has a name and a cost associated with changing any/all of the
69parameters in that covariant group.
71Each group has a dictionary of :py:class:`.Tunable` parameters, where the key is
72the name of the parameter and the value is a dictionary of the parameter's
73configuration (see the :py:class:`.Tunable` class for more information on the
74different ways they can be configured).
Generally tunables are associated with an
:py:class:`~mlos_bench.environments.base_environment.Environment`, stored in a separate
file alongside that Environment's config (e.g., ``env-name-tunables.mlos.jsonc``), and
referenced in the Environment config using the ``include_tunables`` property.
81See Also
82--------
83:py:mod:`mlos_bench.tunables` :
84 For more information on tunable parameters and their configuration.
85:py:mod:`mlos_bench.tunables.tunable` :
86 Tunable parameter definition.
87:py:mod:`mlos_bench.config` :
88 Configuration system for mlos_bench.
89:py:mod:`mlos_bench.environments` :
90 Environment configuration and setup.
91"""
93import copy
94import logging
95from collections.abc import Generator, Iterable, Mapping
97from mlos_bench.config.schemas import ConfigSchema
98from mlos_bench.tunables.covariant_group import CovariantTunableGroup
99from mlos_bench.tunables.tunable import Tunable
100from mlos_bench.tunables.tunable_types import TunableValue
# Module-level logger (e.g., TunableGroups.assign logs when it resets tunables to defaults).
_LOG = logging.getLogger(__name__)
class TunableGroups:
    """
    A collection of :py:class:`.CovariantTunableGroup` s of :py:class:`.Tunable`
    parameters.

    Internally it keeps the covariant groups keyed by group name, plus an index
    from each tunable's name to the covariant group that owns it.
    """

    def __init__(self, config: dict | None = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict | None
            Python dict of serialized representation of the covariant tunable groups.
            It is validated against the tunable params config schema before use.
            An empty collection is created if omitted.

        See Also
        --------
        :py:mod:`mlos_bench.tunables` :
            For more information on tunable parameters and their configuration.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """Return True if at least one tunable is in the collection."""
        return bool(self._index)

    def __len__(self) -> int:
        """Return the number of tunables (not covariant groups) in the collection."""
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Two collections are equal when their covariant groups (keyed by group
        name) compare equal; the comparison of individual groups is delegated to
        :py:class:`.CovariantTunableGroup`.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal.
            False if ``other`` is not a TunableGroups instance.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        The group is added by reference (not copied). All checks run before any
        internal state is modified, so a rejected group leaves the collection
        unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup
            The covariant group to add.

        Raises
        ------
        AssertionError
            If a group with the same name is already in the collection.
        ValueError
            If any tunable of the group is already in the collection.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        tunables = list(group.get_tunables())
        # Validate everything up front so that a failure cannot leave the group
        # half-registered in ``_tunable_groups`` / ``_index``.
        for tunable in tunables:
            if tunable.name in self._index:
                raise ValueError(
                    f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
                )
        self._tunable_groups[group.name] = group
        for tunable in tunables:
            self._index[tunable.name] = group

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Groups from ``tunables`` that are not yet present are added by reference
        (not copied). Groups that are already present are kept as-is, provided
        they match according to ``CovariantTunableGroup.equals_defaults``
        (differing current values are allowed). Groups processed before an
        error is raised remain merged.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a covariant group with the same name exists in both collections
            but does not match on defaults, or if a newly added group contains a
            tunable that already belongs to another group.
        """
        # pylint: disable=protected-access
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
                continue
            existing = self._tunable_groups[group.name]
            # Check that there's no overlap in the tunables.
            # But allow for differing current values.
            if not existing.equals_defaults(group):
                raise ValueError(
                    f"Overlapping covariant tunable group name {group.name} "
                    f"in {existing} and {tunables}"
                )
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed in order of decreasing cost (ties broken by name), and
        the tunables within each group are listed in sorted order.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group.get_tunables())
            )
            + " }"
        )

    def __contains__(self, tunable: str | Tunable) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: str | Tunable) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: str | Tunable,
        tunable_value: TunableValue | Tunable,
    ) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        If ``tunable_value`` is a Tunable, its current ``value`` is assigned.
        Returns the value as read back from the owning group after assignment.
        Raises KeyError if the tunable is not in the collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        # Assign through the owning group (rather than the Tunable directly)
        # to make sure we set the is_updated flag of the group.
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[tuple[Tunable, CovariantTunableGroup]]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : Generator[tuple[Tunable, CovariantTunableGroup]]
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: str | Tunable) -> tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter.

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups
            (a live view of the collection's group names).
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of the names is not a covariant group in this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Iterable[str] | None = None,
        into_params: dict[str, TunableValue] | None = None,
    ) -> dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (``None``).
            Unlike the other group-selecting methods, an empty iterable selects
            no groups.
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.

        Raises
        ------
        KeyError
            If any of the names is not a covariant group in this collection.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Iterable[str] | None = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted.
            Note: an empty iterable is treated the same as omitted (all groups).

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.

        Raises
        ------
        KeyError
            If any of the names is not a covariant group in this collection.
        """
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports its tunables at defaults
            (also True for an empty collection).
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Iterable[str] | None = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted.
            Note: an empty iterable is treated the same as omitted (all groups).

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If any of the names is not a covariant group in this collection.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Iterable[str] | None = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted.
            Note: an empty iterable is treated the same as omitted (all groups).

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If any of the names is not a covariant group in this collection.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

            As a special behavior when the mapping is empty (``{}``) the method will
            restore the default values rather than no-op.
            This allows an empty dictionary in json configs to be used to reset the
            tunables to defaults without having to copy the original values from the
            tunable_params definition.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If any key is not the name of a tunable in this collection.
        """
        if not param_values:
            _LOG.info("Empty tunable values set provided. Resetting all tunables to defaults.")
            return self.restore_defaults()
        for key, value in param_values.items():
            # Goes through __setitem__ so the owning group's is_updated flag is set.
            self[key] = value
        return self