Coverage for mlos_bench/mlos_bench/tunables/covariant_group.py: 98%

61 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6CovariantTunableGroup class definition. 

7 

8A collection of :py:class:`.Tunable` parameters that are updated together (e.g., 

9with the same cost). 

10 

11See Also 

12-------- 

13mlos_bench.tunables.tunable_groups : TunableGroups class definition. 

14""" 

15 

16import copy 

17from collections.abc import Iterable 

18 

19from mlos_bench.tunables.tunable import Tunable 

20from mlos_bench.tunables.tunable_types import TunableValue 

21 

22 

class CovariantTunableGroup:
    """
    A collection of :py:class:`.Tunable` parameters.

    Changing any of the parameters in the group incurs the same cost of the experiment.

    See Also
    --------
    mlos_bench.tunables.tunable_groups : TunableGroups class definition.
    """

    def __init__(self, name: str, config: dict):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameters group.
        config : dict
            Python dict that represents a CovariantTunableGroup
            (e.g., deserialized from JSON). The optional ``"cost"`` entry is
            converted with ``int()`` and defaults to 0; the optional ``"params"``
            entry maps each tunable name to that tunable's config and defaults
            to an empty mapping (no tunables).
        """
        # New groups start out flagged as updated, so get_current_cost() reports
        # the full cost until reset_is_updated() is called.
        self._is_updated = True
        self._name = name
        self._cost = int(config.get("cost", 0))
        # Use a distinct loop variable so it does not shadow the group's `name`.
        self._tunables: dict[str, Tunable] = {
            tunable_name: Tunable(tunable_name, tunable_config)
            for (tunable_name, tunable_config) in config.get("params", {}).items()
        }

    @property
    def name(self) -> str:
        """
        Get the name of the covariant group.

        Returns
        -------
        name : str
            Name (i.e., a string id) of the covariant group.
        """
        return self._name

    @property
    def cost(self) -> int:
        """
        Get the cost of changing the values in the covariant group.

        This value is a constant. Use :py:meth:`.get_current_cost` to get the cost
        given the group update status.

        Returns
        -------
        cost : int
            Cost of changing the values in the covariant group.
        """
        return self._cost

    def copy(self) -> "CovariantTunableGroup":
        """
        Deep copy of the CovariantTunableGroup object.

        Returns
        -------
        group : CovariantTunableGroup
            A new instance of the CovariantTunableGroup object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def __eq__(self, other: object) -> bool:
        """
        Check if two CovariantTunableGroup objects are equal.

        The comparison covers the name, the cost, the update flag, and the
        tunables themselves (including their current values).

        Parameters
        ----------
        other : object
            A covariant tunable group object to compare to.

        Returns
        -------
        is_equal : bool
            True if two CovariantTunableGroup objects are equal,
            False otherwise (including when `other` is not a CovariantTunableGroup).
        """
        if not isinstance(other, CovariantTunableGroup):
            return False
        # TODO: May need to provide logic to relax the equality check on the
        # tunables (e.g. "compatible" vs. "equal").
        return (
            self._name == other._name
            and self._cost == other._cost
            and self._is_updated == other._is_updated
            and self._tunables == other._tunables
        )

    def equals_defaults(self, other: "CovariantTunableGroup") -> bool:
        """
        Checks to see if the other CovariantTunableGroup is the same, ignoring the
        current values of the two groups' Tunables.

        Neither group is modified: the comparison is done on deep copies whose
        values are restored to defaults and whose update flags are cleared.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        are_equal : bool
            True if the two CovariantTunableGroup objects' *metadata* are the same,
            False otherwise.
        """
        # NOTE: May be worth considering to implement this check without copies.
        self_cpy = self.copy()
        self_cpy.restore_defaults()
        self_cpy.reset_is_updated()

        other_cpy = other.copy()
        other_cpy.restore_defaults()
        other_cpy.reset_is_updated()
        return self_cpy == other_cpy

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every tunable in the group reports being at its default
            value (vacuously True for a group with no tunables).
        """
        return all(tunable.is_default() for tunable in self._tunables.values())

    def restore_defaults(self) -> None:
        """
        Restore all tunable parameters to their default values.

        The group is marked as updated only if at least one value actually changed.
        """
        for tunable in self._tunables.values():
            if tunable.value != tunable.default:
                self._is_updated = True
            tunable.value = tunable.default

    def reset_is_updated(self) -> None:
        """
        Clear the update flag.

        That is, state that running an experiment with the current values of the
        tunables in this group has no extra cost.
        """
        self._is_updated = False

    def is_updated(self) -> bool:
        """
        Check if any of the tunable values in the group has been updated.

        Returns
        -------
        is_updated : bool
            True if any of the tunable values in the group has been updated, False otherwise.
        """
        return self._is_updated

    def get_current_cost(self) -> int:
        """
        Get the cost of the experiment given current tunable values.

        Returns
        -------
        cost : int
            Cost of the experiment or 0 if parameters have not been updated.
        """
        return self._cost if self._is_updated else 0

    def get_names(self) -> Iterable[str]:
        """
        Get the names of all tunables in the group.

        Returns
        -------
        names : Iterable[str]
            A view over the names of the tunables in this group.
        """
        return self._tunables.keys()

    def get_tunable_values_dict(self) -> dict[str, TunableValue]:
        """
        Get current values of all tunables in the group as a dict.

        Returns
        -------
        tunables : dict[str, TunableValue]
            A new dict mapping each tunable name to its current value.
        """
        return {name: tunable.value for (name, tunable) in self._tunables.items()}

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the CovariantTunableGroup (mostly for
        logging).

        Returns
        -------
        string : str
            A human-readable version of the CovariantTunableGroup.
        """
        return f"{self._name}: {self._tunables}"

    @staticmethod
    def _tunable_name(tunable: str | Tunable) -> str:
        """Resolve a tunable reference (a name or a Tunable instance) to its name."""
        return tunable.name if isinstance(tunable, Tunable) else tunable

    def get_tunable(self, tunable: str | Tunable) -> Tunable:
        """
        Access the entire Tunable in a group (not just its value).

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable object whose name is
            used for the lookup.

        Returns
        -------
        Tunable
            An instance of the Tunable parameter stored in this group.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self._tunables[self._tunable_name(tunable)]

    def get_tunables(self) -> Iterable[Tunable]:
        """
        Gets the set of tunables for this CovariantTunableGroup.

        Returns
        -------
        Iterable[Tunable]
            A view over the Tunable objects in this group.
        """
        return self._tunables.values()

    def __contains__(self, tunable: str | Tunable) -> bool:
        """
        Check whether a tunable (given by name or as a Tunable object) is in the group.

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable object whose name is used.

        Returns
        -------
        bool
            True if a tunable with that name belongs to this group.
        """
        return self._tunable_name(tunable) in self._tunables

    def __getitem__(self, tunable: str | Tunable) -> TunableValue:
        """
        Get the current value of a tunable in the group.

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable object whose name is used.

        Returns
        -------
        TunableValue
            The current value of the tunable.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self.get_tunable(tunable).value

    def __setitem__(
        self,
        tunable: str | Tunable,
        tunable_value: TunableValue | Tunable,
    ) -> TunableValue:
        """
        Assign a new value to a tunable in the group.

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter, or a Tunable object whose name is used.
        tunable_value : TunableValue | Tunable
            The new value; if a Tunable object is given, its current value is used.

        Returns
        -------
        TunableValue
            The value that was assigned. (Python ignores this for ``group[k] = v``
            syntax; it is only visible when calling ``__setitem__`` directly.)

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        # OR the result of Tunable.update() into the flag: an assignment can mark
        # the group as updated, but never clears a previously set flag.
        self._is_updated |= self.get_tunable(tunable).update(value)
        return value