Coverage for mlos_bench/mlos_bench/tunables/covariant_group.py: 98%
61 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tunable parameter definition.
7"""
8import copy
10from typing import Dict, Iterable, Union
12from mlos_bench.tunables.tunable import Tunable, TunableValue
class CovariantTunableGroup:
    """
    A collection of tunable parameters.
    Changing any of the parameters in the group incurs the same cost of the experiment.
    """

    def __init__(self, name: str, config: dict):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameters group.
        config : dict
            Python dict that represents a CovariantTunableGroup
            (e.g., deserialized from JSON). The optional "cost" key is
            converted with `int()` (default 0); the optional "params" key
            maps each tunable name to the config used to construct its `Tunable`.
        """
        # A freshly created group counts as updated, so the first experiment
        # that uses it is charged the full cost (see `get_current_cost()`).
        self._is_updated = True
        self._name = name
        self._cost = int(config.get("cost", 0))
        # Use a distinct loop variable so it is not confused with the group's
        # own `name` argument above.
        self._tunables: Dict[str, Tunable] = {
            tunable_name: Tunable(tunable_name, tunable_config)
            for (tunable_name, tunable_config) in config.get("params", {}).items()
        }

    @property
    def name(self) -> str:
        """
        Get the name of the covariant group.

        Returns
        -------
        name : str
            Name (i.e., a string id) of the covariant group.
        """
        return self._name

    @property
    def cost(self) -> int:
        """
        Get the cost of changing the values in the covariant group.
        This value is a constant. Use `get_current_cost()` to get
        the cost given the group update status.

        Returns
        -------
        cost : int
            Cost of changing the values in the covariant group.
        """
        return self._cost

    def copy(self) -> "CovariantTunableGroup":
        """
        Deep copy of the CovariantTunableGroup object.

        Returns
        -------
        group : CovariantTunableGroup
            A new instance of the CovariantTunableGroup object
            that is a deep copy of the original one.
        """
        # `copy` here is the stdlib module: method scope does not see the
        # class attribute of the same name.
        return copy.deepcopy(self)

    def __eq__(self, other: object) -> bool:
        """
        Check if two CovariantTunableGroup objects are equal.

        Compares the name, the cost, the update flag, and the tunables
        (including their current values).

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        is_equal : bool
            True if two CovariantTunableGroup objects are equal.
            False for objects of any other type.
        """
        if not isinstance(other, CovariantTunableGroup):
            return False
        # TODO: May need to provide logic to relax the equality check on the
        # tunables (e.g. "compatible" vs. "equal").
        return (self._name == other._name and
                self._cost == other._cost and
                self._is_updated == other._is_updated and
                self._tunables == other._tunables)

    def equals_defaults(self, other: "CovariantTunableGroup") -> bool:
        """
        Checks to see if the other CovariantTunableGroup is the same, ignoring
        the current values of the two groups' Tunables.

        Neither group is modified: the comparison is done on deep copies
        that are reset to their defaults with the update flag cleared.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        are_equal : bool
            True if the two CovariantTunableGroup objects' *metadata* are the same,
            False otherwise.
        """
        # NOTE: May be worth considering to implement this check without copies.
        self_at_defaults = self.copy()
        self_at_defaults.restore_defaults()
        self_at_defaults.reset_is_updated()

        other_at_defaults = other.copy()
        other_at_defaults.restore_defaults()
        other_at_defaults.reset_is_updated()

        return self_at_defaults == other_at_defaults

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their defaults.

        Returns
        -------
        bool
            True if every tunable in the group reports `is_default()`
            (vacuously True for an empty group).
        """
        return all(tunable.is_default() for tunable in self._tunables.values())

    def restore_defaults(self) -> None:
        """
        Restore all tunable parameters to their default values.

        The update flag is set only if at least one value actually differed
        from its default.
        """
        for tunable in self._tunables.values():
            if tunable.value != tunable.default:
                self._is_updated = True
            tunable.value = tunable.default

    def reset_is_updated(self) -> None:
        """
        Clear the update flag. That is, state that running an experiment with the
        current values of the tunables in this group has no extra cost.
        """
        self._is_updated = False

    def is_updated(self) -> bool:
        """
        Check if any of the tunable values in the group has been updated.

        Returns
        -------
        is_updated : bool
            True if any of the tunable values in the group has been updated, False otherwise.
        """
        return self._is_updated

    def get_current_cost(self) -> int:
        """
        Get the cost of the experiment given current tunable values.

        Returns
        -------
        cost : int
            Cost of the experiment or 0 if parameters have not been updated.
        """
        return self._cost if self._is_updated else 0

    def get_names(self) -> Iterable[str]:
        """
        Get the names of all tunables in the group.

        Returns
        -------
        names : Iterable[str]
            A (live) keys view of the group's tunable names.
        """
        return self._tunables.keys()

    def get_tunable_values_dict(self) -> Dict[str, TunableValue]:
        """
        Get current values of all tunables in the group as a dict.

        Returns
        -------
        tunables : Dict[str, TunableValue]
            A new dict mapping each tunable name to its current value.
        """
        return {name: tunable.value for (name, tunable) in self._tunables.items()}

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the CovariantTunableGroup
        (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the CovariantTunableGroup.
        """
        return f"{self._name}: {self._tunables}"

    @staticmethod
    def _get_tunable_name(tunable: Union[str, Tunable]) -> str:
        """
        Resolve a tunable reference (either its name or a Tunable instance) to its name.
        """
        return tunable.name if isinstance(tunable, Tunable) else tunable

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable:
        """
        Access the entire Tunable in a group (not just its value).

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter, or a Tunable object whose name
            is used for the lookup.

        Returns
        -------
        Tunable
            The instance of the Tunable parameter stored in this group.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self._tunables[self._get_tunable_name(tunable)]

    def get_tunables(self) -> Iterable[Tunable]:
        """
        Gets the set of tunables for this CovariantTunableGroup.

        Returns
        -------
        Iterable[Tunable]
            A (live) values view of the group's Tunable objects.
        """
        return self._tunables.values()

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """
        Check whether a tunable (given by name or as a Tunable object) is in the group.
        """
        return self._get_tunable_name(tunable) in self._tunables

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a tunable in the group.
        Raises KeyError if the tunable is not in the group.
        """
        return self.get_tunable(tunable).value

    def __setitem__(self, tunable: Union[str, Tunable],
                    tunable_value: Union[TunableValue, Tunable]) -> TunableValue:
        """
        Assign a new value to a tunable in the group.

        `tunable_value` may be a plain value or a Tunable, in which case its
        current value is used. Raises KeyError if the tunable is not in the group.
        """
        value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        # OR-in the result of `update()` so the flag stays set once any tunable
        # has changed. Presumably `update()` returns True when the value
        # actually changed -- TODO confirm against Tunable.update.
        self._is_updated |= self.get_tunable(tunable).update(value)
        return value