Coverage for mlos_bench/mlos_bench/tunables/covariant_group.py: 98%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tunable parameter definition."""
6import copy
7from typing import Dict, Iterable, Union
9from mlos_bench.tunables.tunable import Tunable, TunableValue
class CovariantTunableGroup:
    """
    A collection of tunable parameters.

    Changing any of the parameters in the group incurs the same cost of the experiment.
    """

    def __init__(self, name: str, config: dict):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameters group.
        config : dict
            Python dict that represents a CovariantTunableGroup
            (e.g., deserialized from JSON). The optional "cost" entry is
            converted with `int()` and defaults to 0; the optional "params"
            entry maps each tunable's name to the config used to construct
            that `Tunable`.
        """
        # Start with the update flag set, so `get_current_cost()` reports the
        # full cost until `reset_is_updated()` is called.
        self._is_updated = True
        self._name = name
        self._cost = int(config.get("cost", 0))
        # Distinct loop-variable names avoid visually shadowing the group's
        # own `name` argument with each tunable's name.
        self._tunables: Dict[str, Tunable] = {
            tunable_name: Tunable(tunable_name, tunable_config)
            for (tunable_name, tunable_config) in config.get("params", {}).items()
        }

    @staticmethod
    def _tunable_name(tunable: Union[str, Tunable]) -> str:
        """Resolve a tunable reference (either its name or a Tunable) to its name."""
        return tunable.name if isinstance(tunable, Tunable) else tunable

    @property
    def name(self) -> str:
        """
        Get the name of the covariant group.

        Returns
        -------
        name : str
            Name (i.e., a string id) of the covariant group.
        """
        return self._name

    @property
    def cost(self) -> int:
        """
        Get the cost of changing the values in the covariant group. This value is a
        constant. Use `get_current_cost()` to get the cost given the group update
        status.

        Returns
        -------
        cost : int
            Cost of changing the values in the covariant group.
        """
        return self._cost

    def copy(self) -> "CovariantTunableGroup":
        """
        Deep copy of the CovariantTunableGroup object.

        Returns
        -------
        group : CovariantTunableGroup
            A new instance of the CovariantTunableGroup object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def __eq__(self, other: object) -> bool:
        """
        Check if two CovariantTunableGroup objects are equal.

        Two groups are equal when their names, costs, update flags and
        tunables all compare equal.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        is_equal : bool
            True if two CovariantTunableGroup objects are equal.
            False if `other` is not a CovariantTunableGroup.
        """
        if not isinstance(other, CovariantTunableGroup):
            return False
        # TODO: May need to provide logic to relax the equality check on the
        # tunables (e.g. "compatible" vs. "equal").
        return (
            self._name == other._name
            and self._cost == other._cost
            and self._is_updated == other._is_updated
            and self._tunables == other._tunables
        )

    def equals_defaults(self, other: "CovariantTunableGroup") -> bool:
        """
        Checks to see if the other CovariantTunableGroup is the same, ignoring the
        current values of the two groups' Tunables.

        Neither `self` nor `other` is modified; the comparison is done on copies.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        are_equal : bool
            True if the two CovariantTunableGroup objects' *metadata* are the same,
            False otherwise.
        """
        # NOTE: May be worth considering to implement this check without copies.
        this_at_defaults = self.copy()
        this_at_defaults.restore_defaults()
        # restore_defaults() may set the update flag; clear it so the flag
        # does not influence the metadata comparison.
        this_at_defaults.reset_is_updated()

        other_at_defaults = other.copy()
        other_at_defaults.restore_defaults()
        other_at_defaults.reset_is_updated()
        return this_at_defaults == other_at_defaults

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every tunable in the group reports `is_default()`
            (vacuously True for an empty group), False otherwise.
        """
        return all(tunable.is_default() for tunable in self._tunables.values())

    def restore_defaults(self) -> None:
        """
        Restore all tunable parameters to their default values.

        The update flag is set only if at least one value actually differed
        from its default.
        """
        for tunable in self._tunables.values():
            if tunable.value != tunable.default:
                self._is_updated = True
            tunable.value = tunable.default

    def reset_is_updated(self) -> None:
        """
        Clear the update flag.

        That is, state that running an experiment with the current values of the
        tunables in this group has no extra cost.
        """
        self._is_updated = False

    def is_updated(self) -> bool:
        """
        Check if any of the tunable values in the group has been updated.

        Returns
        -------
        is_updated : bool
            True if any of the tunable values in the group has been updated, False otherwise.
        """
        return self._is_updated

    def get_current_cost(self) -> int:
        """
        Get the cost of the experiment given current tunable values.

        Returns
        -------
        cost : int
            Cost of the experiment or 0 if parameters have not been updated.
        """
        return self._cost if self._is_updated else 0

    def get_names(self) -> Iterable[str]:
        """Get the names of all tunables in the group (a live view of the keys)."""
        return self._tunables.keys()

    def get_tunable_values_dict(self) -> Dict[str, TunableValue]:
        """
        Get current values of all tunables in the group as a dict.

        Returns
        -------
        tunables : Dict[str, TunableValue]
            A new dict mapping each tunable's name to its current value.
        """
        return {name: tunable.value for (name, tunable) in self._tunables.items()}

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the CovariantTunableGroup (mostly for
        logging).

        Returns
        -------
        string : str
            A human-readable version of the CovariantTunableGroup.
        """
        return f"{self._name}: {self._tunables}"

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable:
        """
        Access the entire Tunable in a group (not just its value).

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter, or a Tunable whose name is used
            for the lookup.

        Returns
        -------
        Tunable
            An instance of the Tunable parameter.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self._tunables[self._tunable_name(tunable)]

    def get_tunables(self) -> Iterable[Tunable]:
        """
        Gets the set of tunables for this CovariantTunableGroup.

        Returns
        -------
        Iterable[Tunable]
            A live view of the Tunable objects in this group.
        """
        return self._tunables.values()

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """
        Check whether a tunable (given by name or as a Tunable) is in the group.

        Only the name is compared, not the tunable's value or definition.
        """
        return self._tunable_name(tunable) in self._tunables

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a tunable in the group.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self.get_tunable(tunable).value

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """
        Assign a new value to a tunable in the group.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable, or a Tunable whose name is used for the lookup.
        tunable_value : Union[TunableValue, Tunable]
            The new value, or a Tunable whose current value is copied.

        Returns
        -------
        TunableValue
            The raw value that was assigned.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        # NOTE(review): presumably `Tunable.update()` returns True when the
        # value actually changed; OR it in so the flag stays set once raised.
        self._is_updated |= self.get_tunable(tunable).update(value)
        return value