Coverage for mlos_bench/mlos_bench/tunables/covariant_group.py: 98%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""Covariant group of tunable parameters that share a single change cost."""

6import copy 

7from typing import Dict, Iterable, Union 

8 

9from mlos_bench.tunables.tunable import Tunable, TunableValue 

10 

11 

class CovariantTunableGroup:
    """
    A collection of tunable parameters.

    Changing any of the parameters in the group incurs the same cost of the experiment.
    """

    def __init__(self, name: str, config: dict):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameters group.
        config : dict
            Python dict that represents a CovariantTunableGroup
            (e.g., deserialized from JSON). The optional ``"cost"`` entry is
            converted with ``int()`` and defaults to 0; the optional ``"params"``
            entry maps each tunable name to the config used to build its
            ``Tunable`` and defaults to no tunables.
        """
        # A new group starts out "updated": get_current_cost() reports the full
        # group cost until reset_is_updated() is called.
        self._is_updated = True
        self._name = name
        self._cost = int(config.get("cost", 0))
        # Distinct loop names keep the comprehension from shadowing the
        # group's own ``name`` parameter.
        self._tunables: Dict[str, Tunable] = {
            tunable_name: Tunable(tunable_name, tunable_config)
            for (tunable_name, tunable_config) in config.get("params", {}).items()
        }

    @staticmethod
    def _get_name(tunable: Union[str, Tunable]) -> str:
        """Resolve a tunable reference (its name or the Tunable itself) to a name."""
        return tunable.name if isinstance(tunable, Tunable) else tunable

    @property
    def name(self) -> str:
        """
        Get the name of the covariant group.

        Returns
        -------
        name : str
            Name (i.e., a string id) of the covariant group.
        """
        return self._name

    @property
    def cost(self) -> int:
        """
        Get the cost of changing the values in the covariant group. This value is a
        constant. Use `get_current_cost()` to get the cost given the group update
        status.

        Returns
        -------
        cost : int
            Cost of changing the values in the covariant group.
        """
        return self._cost

    def copy(self) -> "CovariantTunableGroup":
        """
        Deep copy of the CovariantTunableGroup object.

        Returns
        -------
        group : CovariantTunableGroup
            A new instance of the CovariantTunableGroup object
            that is a deep copy of the original one.
        """
        # ``copy`` here is the module-level import, not this method: class scope
        # is not visible from inside method bodies.
        return copy.deepcopy(self)

    def __eq__(self, other: object) -> bool:
        """
        Check if two CovariantTunableGroup objects are equal.

        Name, cost, update flag, and all tunables (including their current
        values) must match.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        is_equal : bool
            True if two CovariantTunableGroup objects are equal;
            False if they differ or `other` is not a CovariantTunableGroup.
        """
        if not isinstance(other, CovariantTunableGroup):
            return False
        # TODO: May need to provide logic to relax the equality check on the
        # tunables (e.g. "compatible" vs. "equal").
        return (
            self._name == other._name
            and self._cost == other._cost
            and self._is_updated == other._is_updated
            and self._tunables == other._tunables
        )

    def equals_defaults(self, other: "CovariantTunableGroup") -> bool:
        """
        Checks to see if the other CovariantTunableGroup is the same, ignoring the
        current values of the two groups' Tunables.

        Neither group is modified: the comparison is done on deep copies that
        have been reset to their default values with a cleared update flag.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        are_equal : bool
            True if the two CovariantTunableGroup objects' *metadata* are the same,
            False otherwise.
        """
        # NOTE: May be worth considering to implement this check without copies.
        self_defaults = self.copy()
        self_defaults.restore_defaults()
        self_defaults.reset_is_updated()

        other_defaults = other.copy()
        other_defaults.restore_defaults()
        other_defaults.reset_is_updated()
        return self_defaults == other_defaults

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every tunable in the group reports `is_default()`
            (trivially True for an empty group).
        """
        return all(tunable.is_default() for tunable in self._tunables.values())

    def restore_defaults(self) -> None:
        """
        Restore all tunable parameters to their default values.

        The group is marked as updated only if at least one tunable's value
        actually changed.
        """
        for tunable in self._tunables.values():
            if tunable.value != tunable.default:
                self._is_updated = True
                tunable.value = tunable.default

    def reset_is_updated(self) -> None:
        """
        Clear the update flag.

        That is, state that running an experiment with the current values of the
        tunables in this group has no extra cost.
        """
        self._is_updated = False

    def is_updated(self) -> bool:
        """
        Check if any of the tunable values in the group has been updated.

        Returns
        -------
        is_updated : bool
            True if any of the tunable values in the group has been updated, False otherwise.
        """
        return self._is_updated

    def get_current_cost(self) -> int:
        """
        Get the cost of the experiment given current tunable values.

        Returns
        -------
        cost : int
            Cost of the experiment or 0 if parameters have not been updated.
        """
        return self._cost if self._is_updated else 0

    def get_names(self) -> Iterable[str]:
        """
        Get the names of all tunables in the group.

        Returns
        -------
        Iterable[str]
            A live view of the tunable names (reflects the group's internal dict).
        """
        return self._tunables.keys()

    def get_tunable_values_dict(self) -> Dict[str, TunableValue]:
        """
        Get current values of all tunables in the group as a dict.

        Returns
        -------
        tunables : Dict[str, TunableValue]
            A new dict mapping each tunable name to its current value.
        """
        return {name: tunable.value for (name, tunable) in self._tunables.items()}

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the CovariantTunableGroup (mostly for
        logging).

        Returns
        -------
        string : str
            A human-readable version of the CovariantTunableGroup.
        """
        return f"{self._name}: {self._tunables}"

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable:
        """
        Access the entire Tunable in a group (not just its value).

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter, or a Tunable whose ``name`` is used
            for the lookup.

        Returns
        -------
        Tunable
            An instance of the Tunable parameter.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self._tunables[self._get_name(tunable)]

    def get_tunables(self) -> Iterable[Tunable]:
        """
        Gets the set of tunables for this CovariantTunableGroup.

        Returns
        -------
        Iterable[Tunable]
            A live view of the group's Tunable objects.
        """
        return self._tunables.values()

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """
        Check whether a tunable (given by name or as a Tunable) is in the group.

        Returns
        -------
        bool
            True if a tunable with that name belongs to this group.
        """
        return self._get_name(tunable) in self._tunables

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a tunable in the group.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self.get_tunable(tunable).value

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """
        Assign a new value to a tunable in the group.

        If `tunable_value` is itself a Tunable, its current value is used. The
        result of ``Tunable.update()`` is OR-ed into the group's update flag, so
        the flag is never cleared here.

        Returns
        -------
        TunableValue
            The value that was assigned.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        self._is_updated |= self.get_tunable(tunable).update(value)
        return value