Coverage for mlos_bench/mlos_bench/tunables/covariant_group.py: 98%
61 statements
Generated by coverage.py v7.8.0 on 2025-04-01 00:52 +0000.
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
5"""
CovariantTunableGroup class definition.

A collection of :py:class:`.Tunable` parameters that are updated together, where
changing any of them incurs the same cost.

See Also
--------
mlos_bench.tunables.tunable_groups : TunableGroups class definition.
14"""
16import copy
17from collections.abc import Iterable
19from mlos_bench.tunables.tunable import Tunable
20from mlos_bench.tunables.tunable_types import TunableValue
class CovariantTunableGroup:
    """
    A collection of :py:class:`.Tunable` parameters.

    Changing any of the parameters in the group incurs the same cost of the experiment.

    See Also
    --------
    mlos_bench.tunables.tunable_groups : TunableGroups class definition.
    """

    def __init__(self, name: str, config: dict):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameters group.
        config : dict
            Python dict that represents a CovariantTunableGroup
            (e.g., deserialized from JSON). The optional ``"cost"`` entry is
            converted with ``int()`` (0 if absent). The optional ``"params"``
            entry maps each tunable name to its :py:class:`.Tunable` config
            (no tunables if absent).
        """
        # A new group starts out "updated": get_current_cost() reports the full
        # cost until reset_is_updated() clears the flag.
        self._is_updated = True
        self._name = name
        self._cost = int(config.get("cost", 0))
        # Distinct loop variable names keep the comprehension from shadowing
        # the group's own `name` argument.
        self._tunables: dict[str, Tunable] = {
            tunable_name: Tunable(tunable_name, tunable_config)
            for (tunable_name, tunable_config) in config.get("params", {}).items()
        }

    @staticmethod
    def _tunable_name(tunable: str | Tunable) -> str:
        """
        Resolve a tunable reference to its name.

        Parameters
        ----------
        tunable : str | Tunable
            Either the tunable's name or a Tunable instance.

        Returns
        -------
        name : str
            ``tunable.name`` for a Tunable instance, otherwise ``tunable`` itself.
        """
        return tunable.name if isinstance(tunable, Tunable) else tunable

    @staticmethod
    def _at_defaults(group: "CovariantTunableGroup") -> "CovariantTunableGroup":
        """
        Deep-copy a group, reset all its tunables to defaults, and clear its update flag.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to copy. It is not modified.

        Returns
        -------
        group : CovariantTunableGroup
            The normalized copy.
        """
        cpy = group.copy()
        cpy.restore_defaults()
        cpy.reset_is_updated()
        return cpy

    @property
    def name(self) -> str:
        """
        Get the name of the covariant group.

        Returns
        -------
        name : str
            Name (i.e., a string id) of the covariant group.
        """
        return self._name

    @property
    def cost(self) -> int:
        """
        Get the cost of changing the values in the covariant group.

        This value is a constant. Use `get_current_cost()` to get the cost given
        the group update status.

        Returns
        -------
        cost : int
            Cost of changing the values in the covariant group.
        """
        return self._cost

    def copy(self) -> "CovariantTunableGroup":
        """
        Deep copy of the CovariantTunableGroup object.

        Returns
        -------
        group : CovariantTunableGroup
            A new instance of the CovariantTunableGroup object
            that is a deep copy of the original one.
        """
        # `copy` here resolves to the module-level `copy` module, not this method:
        # class attributes are not part of a method's name lookup.
        return copy.deepcopy(self)

    def __eq__(self, other: object) -> bool:
        """
        Check if two CovariantTunableGroup objects are equal.

        Compares the name, cost, update flag, and the tunables themselves.
        Because this class defines ``__eq__`` without ``__hash__``, Python sets
        ``__hash__`` to None, so instances are unhashable.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        is_equal : bool
            True if two CovariantTunableGroup objects are equal.
            False for objects that are not CovariantTunableGroup instances.
        """
        if not isinstance(other, CovariantTunableGroup):
            return False
        # TODO: May need to provide logic to relax the equality check on the
        # tunables (e.g. "compatible" vs. "equal").
        return (
            self._name == other._name
            and self._cost == other._cost
            and self._is_updated == other._is_updated
            and self._tunables == other._tunables
        )

    def equals_defaults(self, other: "CovariantTunableGroup") -> bool:
        """
        Check whether the other CovariantTunableGroup is the same, ignoring the
        current values of the two groups' Tunables.

        Both groups are compared as deep copies with all tunables restored to
        their defaults and the update flag cleared. Neither group is modified.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        are_equal : bool
            True if the two CovariantTunableGroup objects' *metadata* are the same,
            False otherwise.
        """
        # NOTE: May be worth considering to implement this check without copies.
        return self._at_defaults(self) == self._at_defaults(other)

    def is_defaults(self) -> bool:
        """
        Check whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every tunable reports ``is_default()``.
            This is also True for a group with no tunables.
        """
        return all(tunable.is_default() for tunable in self._tunables.values())

    def restore_defaults(self) -> None:
        """
        Restore all tunable parameters to their default values.

        The group is marked as updated only if at least one tunable's value
        actually differed from its default.
        """
        for tunable in self._tunables.values():
            if tunable.value != tunable.default:
                self._is_updated = True
            tunable.value = tunable.default

    def reset_is_updated(self) -> None:
        """
        Clear the update flag.

        That is, state that running an experiment with the current values of the
        tunables in this group has no extra cost.
        """
        self._is_updated = False

    def is_updated(self) -> bool:
        """
        Check if any of the tunable values in the group has been updated.

        Returns
        -------
        is_updated : bool
            True if any of the tunable values in the group has been updated, False otherwise.
        """
        return self._is_updated

    def get_current_cost(self) -> int:
        """
        Get the cost of the experiment given current tunable values.

        Returns
        -------
        cost : int
            Cost of the experiment or 0 if parameters have not been updated.
        """
        return self._cost if self._is_updated else 0

    def get_names(self) -> Iterable[str]:
        """
        Get the names of all tunables in the group.

        Returns
        -------
        Iterable[str]
            A view of the tunable names (the keys of the internal dict).
        """
        return self._tunables.keys()

    def get_tunable_values_dict(self) -> dict[str, TunableValue]:
        """
        Get current values of all tunables in the group as a dict.

        Returns
        -------
        tunables : dict[str, TunableValue]
            A new dict mapping each tunable name to its current value.
        """
        return {name: tunable.value for (name, tunable) in self._tunables.items()}

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the CovariantTunableGroup (mostly for
        logging).

        Returns
        -------
        string : str
            A human-readable version of the CovariantTunableGroup,
            formatted as ``"<name>: <tunables dict>"``.
        """
        return f"{self._name}: {self._tunables}"

    def get_tunable(self, tunable: str | Tunable) -> Tunable:
        """
        Access the entire Tunable in a group (not just its value).

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable whose name is looked up.

        Returns
        -------
        Tunable
            The instance of the Tunable parameter stored in this group.
            When a Tunable is passed in, this is the group's own Tunable of
            the same name, not necessarily the object that was passed.

        Raises
        ------
        KeyError
            If no tunable with that name is in the group.
        """
        return self._tunables[self._tunable_name(tunable)]

    def get_tunables(self) -> Iterable[Tunable]:
        """
        Get the set of tunables for this CovariantTunableGroup.

        Returns
        -------
        Iterable[Tunable]
            A view of the Tunable objects held by the group.
        """
        return self._tunables.values()

    def __contains__(self, tunable: str | Tunable) -> bool:
        """
        Check whether a tunable (by name, or by a Tunable's name) is in the group.

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable whose name is checked.

        Returns
        -------
        bool
            True if the group has a tunable with that name.
        """
        return self._tunable_name(tunable) in self._tunables

    def __getitem__(self, tunable: str | Tunable) -> TunableValue:
        """
        Get the current value of a tunable in the group.

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable whose name is looked up.

        Returns
        -------
        TunableValue
            The current value of the group's tunable with that name.

        Raises
        ------
        KeyError
            If no tunable with that name is in the group.
        """
        return self.get_tunable(tunable).value

    def __setitem__(
        self,
        tunable: str | Tunable,
        tunable_value: TunableValue | Tunable,
    ) -> TunableValue:
        """
        Assign a new value to a tunable in the group.

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable whose name is looked up.
        tunable_value : TunableValue | Tunable
            The new value, or a Tunable whose current value is copied.

        Returns
        -------
        value : TunableValue
            The value passed on to ``Tunable.update()``. The ``group[name] = value``
            syntax discards this result. Only an explicit ``__setitem__`` call sees it.

        Raises
        ------
        KeyError
            If no tunable with that name is in the group.
        """
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        # NOTE(review): assumes Tunable.update() returns True when the value actually
        # changed (confirm in tunable.py). OR-ing keeps the flag set once any
        # update has been recorded.
        self._is_updated |= self.get_tunable(tunable).update(value)
        return value