Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%

89 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6TunableGroups definition. 

7""" 

8import copy 

9 

10from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union 

11 

12from mlos_bench.config.schemas import ConfigSchema 

13from mlos_bench.tunables.tunable import Tunable, TunableValue 

14from mlos_bench.tunables.covariant_group import CovariantTunableGroup 

15 

16 

class TunableGroups:
    """
    A collection of covariant groups of tunable parameters.

    Two dicts are kept in sync:
    ``_tunable_groups`` maps each covariant group name to its group object, and
    ``_index`` maps each tunable name to the group that owns it, so per-tunable
    lookups (``in``, ``[]``, ``get_tunable``) need a single dict access.
    """

    def __init__(self, config: Optional[dict] = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict
            Python dict of serialized representation of the covariant tunable groups.
            It is validated against ``ConfigSchema.TUNABLE_PARAMS`` before use.
            If omitted (None), an empty collection is created.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for (name, group_config) in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """
        True if the collection holds at least one tunable.

        This counts tunables, not groups.
        """
        return bool(self._index)

    def __len__(self) -> int:
        """
        Number of tunables (not covariant groups) in the collection.
        """
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal, i.e., their name -> group
            mappings compare equal. Always False for non-TunableGroups objects.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        All checks run before anything is modified, so a rejected group
        leaves the collection unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to add. It is stored by reference, not copied.

        Raises
        ------
        ValueError
            If any tunable of `group` is already present in the collection
            (or is listed more than once by `group` itself).
        """
        assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}"
        # Stage the new index entries first. Previously the group was registered
        # (and the index partially updated) before a duplicate was detected,
        # which left callers such as merge() with a half-modified collection.
        new_index: Dict[str, CovariantTunableGroup] = {}
        for tunable in group.get_tunables():
            if tunable.name in self._index or tunable.name in new_index:
                raise ValueError(f"Duplicate Tunable {tunable.name} from group {group.name} in {self}")
            new_index[tunable.name] = group
        # Commit only after every tunable passed validation.
        self._tunable_groups[group.name] = group
        self._index.update(new_index)

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self-contained, potentially overlapping, but also
        overridable configs to be composed together.

        Groups from `tunables` are added by reference. Groups are merged one at
        a time, so if an error is raised, groups processed earlier have
        already been added.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a group with the same name exists in both collections but the two
            do not satisfy `equals_defaults`, or if a new group contains a
            tunable that is already present in this collection.
        """
        # pylint: disable=protected-access
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
            elif not self._tunable_groups[group.name].equals_defaults(group):
                # A group name may appear in both collections only if the
                # definitions match; differing current values are allowed.
                raise ValueError(f"Overlapping covariant tunable group name {group.name} "
                                 f"in {self._tunable_groups[group.name]} and {tunables}")
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed by descending cost, then by name. Tunables within a
        group are listed in sorted order.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        return "{ " + ", ".join(
            f"{group.name}::{tunable}"
            for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
            for tunable in sorted(group._tunables.values())) + " }"

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """
        Checks if the given name/tunable is in this tunable group.

        A Tunable argument is matched by its name only.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        If `tunable_value` is a Tunable, its `.value` is assigned.
        Returns the value read back from the owning group after assignment.
        Raises KeyError if the tunable is not in the collection.
        """
        # Use double index to make sure we set the is_updated flag of the group
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : iter(Tunable, CovariantTunableGroup)
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
            Order follows the insertion order of the internal tunable index.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group.
        Throw KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter (or a Tunable, matched by name).

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups, as a live view of the
            internal dict keys (not a copy).
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new
        TunableGroups object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of `group_names` is not a group in this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(self, group_names: Optional[Iterable[str]] = None,
                         into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted.
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.

        Raises
        ------
        KeyError
            If a name in `group_names` is not a group in this collection.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted.
            Note: an empty collection is also treated as "all groups".

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(self._tunable_groups[name].is_updated()
                   for name in (group_names or self.get_covariant_group_names()))

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their defaults.

        Returns
        -------
        bool
            True if every group reports `is_defaults()` (vacuously True when empty).
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted.
            Note: an empty collection is also treated as "all groups".

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in (group_names or self.get_covariant_group_names()):
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted.
            Note: an empty collection is also treated as "all groups".

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in (group_names or self.get_covariant_group_names()):
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary
        of (key, value) pairs.

        Each pair goes through `__setitem__`, so unknown names raise KeyError.
        Pairs applied before such an error stay applied.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for key, value in param_values.items():
            self[key] = value
        return self
169 Returns 

170 ------- 

171 [(tunable, group), ...] : iter(Tunable, CovariantTunableGroup) 

172 An iterator over all tunables in all groups. Each element is a 2-tuple 

173 of an instance of the Tunable parameter and covariant group it belongs to. 

174 """ 

175 return ((group.get_tunable(name), group) for (name, group) in self._index.items()) 

176 

177 def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]: 

178 """ 

179 Access the entire Tunable (not just its value) and its covariant group. 

180 Throw KeyError if the tunable is not found. 

181 

182 Parameters 

183 ---------- 

184 tunable : Union[str, Tunable] 

185 Name of the tunable parameter. 

186 

187 Returns 

188 ------- 

189 (tunable, group) : (Tunable, CovariantTunableGroup) 

190 A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to. 

191 """ 

192 name: str = tunable.name if isinstance(tunable, Tunable) else tunable 

193 group = self._index[name] 

194 return (group.get_tunable(name), group) 

195 

196 def get_covariant_group_names(self) -> Iterable[str]: 

197 """ 

198 Get the names of all covariance groups in the collection. 

199 

200 Returns 

201 ------- 

202 group_names : [str] 

203 IDs of the covariant tunable groups. 

204 """ 

205 return self._tunable_groups.keys() 

206 

207 def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": 

208 """ 

209 Select the covariance groups from the current set and create a new 

210 TunableGroups object that consists of those covariance groups. 

211 

212 Note: The new TunableGroup will include *references* (not copies) to 

213 original ones, so each will get updated together. 

214 This is often desirable to support the use case of multiple related 

215 Environments (e.g. Local vs Remote) using the same set of tunables 

216 within a CompositeEnvironment. 

217 

218 Parameters 

219 ---------- 

220 group_names : list of str 

221 IDs of the covariant tunable groups. 

222 

223 Returns 

224 ------- 

225 tunables : TunableGroups 

226 A collection of covariant tunable groups. 

227 """ 

228 # pylint: disable=protected-access 

229 tunables = TunableGroups() 

230 for name in group_names: 

231 if name not in self._tunable_groups: 

232 raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}") 

233 tunables._add_group(self._tunable_groups[name]) 

234 return tunables 

235 

236 def get_param_values(self, group_names: Optional[Iterable[str]] = None, 

237 into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: 

238 """ 

239 Get the current values of the tunables that belong to the specified covariance groups. 

240 

241 Parameters 

242 ---------- 

243 group_names : list of str or None 

244 IDs of the covariant tunable groups. 

245 Select parameters from all groups if omitted. 

246 into_params : dict 

247 An optional dict to copy the parameters and their values into. 

248 

249 Returns 

250 ------- 

251 into_params : dict 

252 Flat dict of all parameters and their values from given covariance groups. 

253 """ 

254 if group_names is None: 

255 group_names = self.get_covariant_group_names() 

256 if into_params is None: 

257 into_params = {} 

258 for name in group_names: 

259 into_params.update(self._tunable_groups[name].get_tunable_values_dict()) 

260 return into_params 

261 

262 def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: 

263 """ 

264 Check if any of the given covariant tunable groups has been updated. 

265 

266 Parameters 

267 ---------- 

268 group_names : list of str or None 

269 IDs of the (covariant) tunable groups. Check all groups if omitted. 

270 

271 Returns 

272 ------- 

273 is_updated : bool 

274 True if any of the specified tunable groups has been updated, False otherwise. 

275 """ 

276 return any(self._tunable_groups[name].is_updated() 

277 for name in (group_names or self.get_covariant_group_names())) 

278 

279 def is_defaults(self) -> bool: 

280 """ 

281 Checks whether the currently assigned values of all tunables are at their defaults. 

282 

283 Returns 

284 ------- 

285 bool 

286 """ 

287 return all(group.is_defaults() for group in self._tunable_groups.values()) 

288 

289 def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": 

290 """ 

291 Restore all tunable parameters to their default values. 

292 

293 Parameters 

294 ---------- 

295 group_names : list of str or None 

296 IDs of the (covariant) tunable groups. Restore all groups if omitted. 

297 

298 Returns 

299 ------- 

300 self : TunableGroups 

301 Self-reference for chaining. 

302 """ 

303 for name in (group_names or self.get_covariant_group_names()): 

304 self._tunable_groups[name].restore_defaults() 

305 return self 

306 

307 def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": 

308 """ 

309 Clear the update flag of given covariant groups. 

310 

311 Parameters 

312 ---------- 

313 group_names : list of str or None 

314 IDs of the (covariant) tunable groups. Reset all groups if omitted. 

315 

316 Returns 

317 ------- 

318 self : TunableGroups 

319 Self-reference for chaining. 

320 """ 

321 for name in (group_names or self.get_covariant_group_names()): 

322 self._tunable_groups[name].reset_is_updated() 

323 return self 

324 

325 def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups": 

326 """ 

327 In-place update the values of the tunables from the dictionary 

328 of (key, value) pairs. 

329 

330 Parameters 

331 ---------- 

332 param_values : Mapping[str, TunableValue] 

333 Dictionary mapping Tunable parameter names to new values. 

334 

335 Returns 

336 ------- 

337 self : TunableGroups 

338 Self-reference for chaining. 

339 """ 

340 for key, value in param_values.items(): 

341 self[key] = value 

342 return self