Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%
93 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""TunableGroups definition."""
6import copy
7import logging
8from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union
10from mlos_bench.config.schemas import ConfigSchema
11from mlos_bench.tunables.covariant_group import CovariantTunableGroup
12from mlos_bench.tunables.tunable import Tunable, TunableValue
14_LOG = logging.getLogger(__name__)
class TunableGroups:
    """
    A collection of covariant groups of tunable parameters.

    Internally two dictionaries are maintained:
    a map from the covariant group name to the group object, and
    an index from each tunable name to the group that owns it.
    """

    def __init__(self, config: Optional[dict] = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict
            Python dict of serialized representation of the covariant tunable groups.
            It is validated against the `ConfigSchema.TUNABLE_PARAMS` schema.
            An empty collection is created if omitted.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """Return True if the collection has at least one tunable parameter."""
        return bool(self._index)

    def __len__(self) -> int:
        """Return the number of tunable parameters (not groups) in the collection."""
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal, i.e., their
            (group name -> covariant group) mappings compare equal.
            False if `other` is not a TunableGroups instance.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.

        The collection is only modified once all checks have passed, so a
        rejected group leaves it unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to register. It is stored by reference (not copied).

        Raises
        ------
        AssertionError
            If a group with the same name is already in the collection.
            NOTE: this check is an `assert` and is skipped under `python -O`.
        ValueError
            If any tunable of the group is already present in the collection.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        # Validate all tunables first, so that we never end up with
        # a partially registered group when a duplicate is found.
        tunable_names = [tunable.name for tunable in group.get_tunables()]
        for tunable_name in tunable_names:
            if tunable_name in self._index:
                raise ValueError(
                    f"Duplicate Tunable {tunable_name} from group {group.name} in {self}"
                )
        self._tunable_groups[group.name] = group
        self._index.update(dict.fromkeys(tunable_names, group))

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Groups that are new to this collection are added by reference (not copied).

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a group with the same name exists in both collections but
            the two groups do not have equal defaults, or if a new group
            contains a tunable that is already present in this collection.
        """
        # pylint: disable=protected-access
        # Check that covariant groups are unique, else throw an error.
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
            else:
                # Check that there's no overlap in the tunables.
                # But allow for differing current values.
                if not self._tunable_groups[group.name].equals_defaults(group):
                    # Both literals need the `f` prefix to be interpolated.
                    raise ValueError(
                        f"Overlapping covariant tunable group name {group.name} "
                        f"in {self._tunable_groups[group.name]} and {tunables}"
                    )
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed in the order of decreasing cost, then by name;
        tunables within each group are sorted.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group._tunables.values())
            )
            + " }"
        )

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        If `tunable_value` is a Tunable, its `.value` is assigned.
        Raises KeyError if the tunable is not in the collection.
        """
        # Use double index to make sure we set the is_updated flag of the group
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : Generator[Tuple[Tunable, CovariantTunableGroup], None, None]
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter.

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups.
            This is a live view of the keys of the internal dictionary.
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of the requested group names is not in the collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Optional[Iterable[str]] = None,
        into_params: Optional[Dict[str, TunableValue]] = None,
    ) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (i.e., None).
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place; a new dict is created if omitted.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.

        Raises
        ------
        KeyError
            If any of the requested group names is not in the collection.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups.
            Check all groups if omitted or falsy (e.g., an empty list).

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports being at its defaults
            (also True for an empty collection).
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups.
            Restore all groups if omitted or falsy (e.g., an empty list).

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups.
            Reset all groups if omitted or falsy (e.g., an empty list).

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

            As a special behavior when the mapping is empty the method will restore
            the default values rather than no-op.
            This allows an empty dictionary in json configs to be used to reset the
            tunables to defaults without having to copy the original values from the
            tunable_params definition.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If a key of `param_values` is not a tunable in this collection.
        """
        if not param_values:
            _LOG.info("Empty tunable values set provided. Resetting all tunables to defaults.")
            return self.restore_defaults()
        for key, value in param_values.items():
            self[key] = value
        return self