Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%
89 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6TunableGroups definition.
7"""
8import copy
10from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union
12from mlos_bench.config.schemas import ConfigSchema
13from mlos_bench.tunables.tunable import Tunable, TunableValue
14from mlos_bench.tunables.covariant_group import CovariantTunableGroup
class TunableGroups:
    """
    A collection of covariant groups of tunable parameters.

    Internally keeps two maps: covariant group name -> group, and
    tunable name -> the covariant group that owns that tunable.
    """

    def __init__(self, config: Optional[dict] = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict, optional
            Python dict of serialized representation of the covariant tunable groups,
            keyed by covariant group name. It is validated against
            ConfigSchema.TUNABLE_PARAMS. An empty collection is created if omitted.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for (name, group_config) in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """
        True if the collection contains at least one tunable.
        """
        return bool(self._index)

    def __len__(self) -> int:
        """
        Number of tunables (not covariant groups) in the collection.
        """
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Two collections are equal when their covariant groups (keyed by name)
        compare equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        All tunable names are validated before anything is modified, so a
        failed call leaves the collection unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to add (by reference, not copied).

        Raises
        ------
        ValueError
            If any tunable of the group is already present in the collection
            (or appears twice within the group).
        """
        assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}"
        tunable_names = [tunable.name for tunable in group.get_tunables()]
        # Validate first so that a duplicate does not leave a half-registered group behind.
        seen = set()
        for name in tunable_names:
            if name in self._index or name in seen:
                raise ValueError(f"Duplicate Tunable {name} from group {group.name} in {self}")
            seen.add(name)
        self._tunable_groups[group.name] = group
        for name in tunable_names:
            self._index[name] = group

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Note: groups that are new to this collection are added by reference,
        not copied.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a group with the same name exists in both collections but their
            definitions (ignoring current values) differ, or if a new group
            contains a tunable already present in this collection.
        """
        # pylint: disable=protected-access
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
            elif not self._tunable_groups[group.name].equals_defaults(group):
                # Same group name is only allowed when the definitions match
                # (differing current values are fine).
                raise ValueError(f"Overlapping covariant tunable group name {group.name} " +
                                 f"in {self._tunable_groups[group.name]} and {tunables}")
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed by descending cost, then by name; tunables within a
        group are sorted.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        # pylint: disable=protected-access
        return "{ " + ", ".join(
            f"{group.name}::{tunable}"
            for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
            for tunable in sorted(group._tunables.values())) + " }"

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """
        Checks if the given name/tunable is in this tunable group.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        If `tunable_value` is a Tunable, its current value is used.
        Raises KeyError if the tunable is not in the collection.
        Returns the value stored after the assignment.
        """
        # Use double index to make sure we set the is_updated flag of the group
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : iter(Tunable, CovariantTunableGroup)
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group.
        Throw KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter.

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups. This is a live view of the
            internal dict keys; it reflects later additions to the collection.
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new
        TunableGroups object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of the group names is not in this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(self, group_names: Optional[Iterable[str]] = None,
                         into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (None).
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted.
            Note: an empty (falsy) collection is also treated as "all groups".

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(self._tunable_groups[name].is_updated()
                   for name in (group_names or self.get_covariant_group_names()))

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their defaults.

        Returns
        -------
        bool
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted.
            Note: an empty (falsy) collection is also treated as "all groups".

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in (group_names or self.get_covariant_group_names()):
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted.
            Note: an empty (falsy) collection is also treated as "all groups".

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in (group_names or self.get_covariant_group_names()):
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary
        of (key, value) pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If a name is not in the collection. Values assigned before the
            failing key remain applied.
        """
        for key, value in param_values.items():
            self[key] = value
        return self