Coverage for mlos_bench/mlos_bench/tunables/covariant_group.py: 98%

61 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""
Definition of `CovariantTunableGroup`: a named collection of tunable
parameters where changing any of them incurs the same experiment cost.
"""

8import copy 

9 

10from typing import Dict, Iterable, Union 

11 

12from mlos_bench.tunables.tunable import Tunable, TunableValue 

13 

14 

class CovariantTunableGroup:
    """
    A collection of tunable parameters.
    Changing any of the parameters in the group incurs the same cost of the experiment.
    """

    def __init__(self, name: str, config: dict):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameters group.
        config : dict
            Python dict that represents a CovariantTunableGroup
            (e.g., deserialized from JSON). The optional "cost" entry is
            converted with `int()` (defaults to 0); the optional "params"
            entry maps each tunable name to that tunable's config.
        """
        # A new group starts out "updated", so `get_current_cost()` reports
        # the full cost until `reset_is_updated()` is called.
        self._is_updated = True
        self._name = name
        self._cost = int(config.get("cost", 0))
        # Use a distinct loop variable so the comprehension does not shadow
        # the `name` parameter (the group's name).
        self._tunables: Dict[str, Tunable] = {
            tunable_name: Tunable(tunable_name, tunable_config)
            for (tunable_name, tunable_config) in config.get("params", {}).items()
        }

    @property
    def name(self) -> str:
        """
        Get the name of the covariant group.

        Returns
        -------
        name : str
            Name (i.e., a string id) of the covariant group.
        """
        return self._name

    @property
    def cost(self) -> int:
        """
        Get the cost of changing the values in the covariant group.
        This value is a constant. Use `get_current_cost()` to get
        the cost given the group update status.

        Returns
        -------
        cost : int
            Cost of changing the values in the covariant group.
        """
        return self._cost

    def copy(self) -> "CovariantTunableGroup":
        """
        Deep copy of the CovariantTunableGroup object.

        Returns
        -------
        group : CovariantTunableGroup
            A new instance of the CovariantTunableGroup object
            that is a deep copy of the original one.
        """
        # `copy` here is the module imported at file level: class-scope names
        # (such as this method) are not visible inside method bodies.
        return copy.deepcopy(self)

    def __eq__(self, other: object) -> bool:
        """
        Check if two CovariantTunableGroup objects are equal.

        Compares the name, the cost, the update flag, and the tunables.
        Because this class defines `__eq__` without `__hash__`, Python makes
        its instances unhashable.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        is_equal : bool
            True if two CovariantTunableGroup objects are equal.
            False if `other` is not a CovariantTunableGroup.
        """
        if not isinstance(other, CovariantTunableGroup):
            return False
        # TODO: May need to provide logic to relax the equality check on the
        # tunables (e.g. "compatible" vs. "equal").
        return (self._name == other._name and
                self._cost == other._cost and
                self._is_updated == other._is_updated and
                self._tunables == other._tunables)

    def equals_defaults(self, other: "CovariantTunableGroup") -> bool:
        """
        Checks to see if the other CovariantTunableGroup is the same, ignoring
        the current values of the two groups' Tunables.

        Both groups are deep-copied, reset to their default values with a
        cleared update flag, and then compared with `==`. Neither input group
        is modified.

        Parameters
        ----------
        other : CovariantTunableGroup
            A covariant tunable group object to compare to.

        Returns
        -------
        are_equal : bool
            True if the two CovariantTunableGroup objects' *metadata* are the same,
            False otherwise.
        """
        # NOTE: May be worth considering to implement this check without copies.
        self_defaults = self.copy()
        self_defaults.restore_defaults()
        self_defaults.reset_is_updated()

        other_defaults = other.copy()
        other_defaults.restore_defaults()
        other_defaults.reset_is_updated()
        return self_defaults == other_defaults

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their defaults.

        Returns
        -------
        bool
            True if every tunable's `is_default()` is true (also True for an
            empty group), False otherwise.
        """
        return all(tunable.is_default() for tunable in self._tunables.values())

    def restore_defaults(self) -> None:
        """
        Restore all tunable parameters to their default values.

        The update flag is set only if at least one tunable's value differed
        from its default; otherwise the flag is left unchanged.
        """
        for tunable in self._tunables.values():
            if tunable.value != tunable.default:
                self._is_updated = True
                tunable.value = tunable.default

    def reset_is_updated(self) -> None:
        """
        Clear the update flag. That is, state that running an experiment with the
        current values of the tunables in this group has no extra cost.
        """
        self._is_updated = False

    def is_updated(self) -> bool:
        """
        Check if any of the tunable values in the group has been updated.

        Returns
        -------
        is_updated : bool
            True if any of the tunable values in the group has been updated, False otherwise.
        """
        return self._is_updated

    def get_current_cost(self) -> int:
        """
        Get the cost of the experiment given current tunable values.

        Returns
        -------
        cost : int
            Cost of the experiment or 0 if parameters have not been updated.
        """
        return self._cost if self._is_updated else 0

    def get_names(self) -> Iterable[str]:
        """
        Get the names of all tunables in the group.

        Returns
        -------
        names : Iterable[str]
            A live view of the tunable names (dict keys view).
        """
        return self._tunables.keys()

    def get_tunable_values_dict(self) -> Dict[str, TunableValue]:
        """
        Get current values of all tunables in the group as a dict.

        Returns
        -------
        tunables : Dict[str, TunableValue]
            A new dict mapping each tunable name to its current value.
        """
        return {name: tunable.value for (name, tunable) in self._tunables.items()}

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the CovariantTunableGroup
        (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the CovariantTunableGroup.
        """
        return f"{self._name}: {self._tunables}"

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable:
        """
        Access the entire Tunable in a group (not just its value).

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter, or a Tunable whose `name` is
            used for the lookup.

        Returns
        -------
        Tunable
            An instance of the Tunable parameter stored in this group.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._tunables[name]

    def get_tunables(self) -> Iterable[Tunable]:
        """
        Gets the set of tunables for this CovariantTunableGroup.

        Returns
        -------
        Iterable[Tunable]
            A live view of the Tunable objects (dict values view).
        """
        return self._tunables.values()

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """
        Check whether a tunable (given by name or as a Tunable) is in the group.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter, or a Tunable whose `name` is checked.

        Returns
        -------
        bool
            True if a tunable with that name belongs to this group.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._tunables

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a tunable in the group.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter, or a Tunable whose `name` is used.

        Returns
        -------
        TunableValue
            The tunable's current value.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        return self.get_tunable(tunable).value

    def __setitem__(self,
                    tunable: Union[str, Tunable],
                    tunable_value: Union[TunableValue, Tunable]) -> TunableValue:
        """
        Assign a new value to a tunable in the group.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter, or a Tunable whose `name` is used.
        tunable_value : Union[TunableValue, Tunable]
            The new value, or a Tunable whose current `value` is copied.

        Returns
        -------
        TunableValue
            The value that was assigned.

        Raises
        ------
        KeyError
            If the tunable is not in the group.
        """
        value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        # OR the flag with `Tunable.update()`'s result (presumably True when the
        # value actually changed) so an earlier update is never cleared here.
        self._is_updated |= self.get_tunable(tunable).update(value)
        return value

119 cpy = self.copy() 

120 cpy.restore_defaults() 

121 cpy.reset_is_updated() 

122 

123 other = other.copy() 

124 other.restore_defaults() 

125 other.reset_is_updated() 

126 return cpy == other 

127 

128 def is_defaults(self) -> bool: 

129 """ 

130 Checks whether the currently assigned values of all tunables are at their defaults. 

131 

132 Returns 

133 ------- 

134 bool 

135 """ 

136 return all(tunable.is_default() for tunable in self._tunables.values()) 

137 

138 def restore_defaults(self) -> None: 

139 """ 

140 Restore all tunable parameters to their default values. 

141 """ 

142 for tunable in self._tunables.values(): 

143 if tunable.value != tunable.default: 

144 self._is_updated = True 

145 tunable.value = tunable.default 

146 

147 def reset_is_updated(self) -> None: 

148 """ 

149 Clear the update flag. That is, state that running an experiment with the 

150 current values of the tunables in this group has no extra cost. 

151 """ 

152 self._is_updated = False 

153 

154 def is_updated(self) -> bool: 

155 """ 

156 Check if any of the tunable values in the group has been updated. 

157 

158 Returns 

159 ------- 

160 is_updated : bool 

161 True if any of the tunable values in the group has been updated, False otherwise. 

162 """ 

163 return self._is_updated 

164 

165 def get_current_cost(self) -> int: 

166 """ 

167 Get the cost of the experiment given current tunable values. 

168 

169 Returns 

170 ------- 

171 cost : int 

172 Cost of the experiment or 0 if parameters have not been updated. 

173 """ 

174 return self._cost if self._is_updated else 0 

175 

176 def get_names(self) -> Iterable[str]: 

177 """ 

178 Get the names of all tunables in the group. 

179 """ 

180 return self._tunables.keys() 

181 

182 def get_tunable_values_dict(self) -> Dict[str, TunableValue]: 

183 """ 

184 Get current values of all tunables in the group as a dict. 

185 

186 Returns 

187 ------- 

188 tunables : Dict[str, TunableValue] 

189 """ 

190 return {name: tunable.value for (name, tunable) in self._tunables.items()} 

191 

192 def __repr__(self) -> str: 

193 """ 

194 Produce a human-readable version of the CovariantTunableGroup 

195 (mostly for logging). 

196 

197 Returns 

198 ------- 

199 string : str 

200 A human-readable version of the CovariantTunableGroup. 

201 """ 

202 return f"{self._name}: {self._tunables}" 

203 

204 def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable: 

205 """ 

206 Access the entire Tunable in a group (not just its value). 

207 Throw KeyError if the tunable is not in the group. 

208 

209 Parameters 

210 ---------- 

211 tunable : str 

212 Name of the tunable parameter. 

213 

214 Returns 

215 ------- 

216 Tunable 

217 An instance of the Tunable parameter. 

218 """ 

219 name: str = tunable.name if isinstance(tunable, Tunable) else tunable 

220 return self._tunables[name] 

221 

222 def get_tunables(self) -> Iterable[Tunable]: 

223 """Gets the set of tunables for this CovariantTunableGroup. 

224 

225 Returns 

226 ------- 

227 Iterable[Tunable] 

228 """ 

229 return self._tunables.values() 

230 

231 def __contains__(self, tunable: Union[str, Tunable]) -> bool: 

232 name: str = tunable.name if isinstance(tunable, Tunable) else tunable 

233 return name in self._tunables 

234 

235 def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: 

236 return self.get_tunable(tunable).value 

237 

238 def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: 

239 value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value 

240 self._is_updated |= self.get_tunable(tunable).update(value) 

241 return value