Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""TunableGroups definition.""" 

6import copy 

7import logging 

8from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union 

9 

10from mlos_bench.config.schemas import ConfigSchema 

11from mlos_bench.tunables.covariant_group import CovariantTunableGroup 

12from mlos_bench.tunables.tunable import Tunable, TunableValue 

13 

# Module-level logger, named after this module (used e.g. by `TunableGroups.assign`).
_LOG: logging.Logger = logging.getLogger(__name__)

15 

16 

class TunableGroups:
    """
    A collection of covariant groups of tunable parameters.

    Two dictionaries are kept in sync:

    - ``_tunable_groups`` maps each covariant group name to its
      CovariantTunableGroup.
    - ``_index`` maps each tunable name to the covariant group that owns it,
      so single-tunable lookups do not need to scan every group.

    Note: ``__eq__`` is defined without ``__hash__``, so instances are unhashable.
    """

    def __init__(self, config: Optional[dict] = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict
            Python dict of serialized representation of the covariant tunable groups.
            It is validated against the ``TUNABLE_PARAMS`` config schema before use.
            An empty collection is created if omitted.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """Return True if the collection contains at least one tunable."""
        return bool(self._index)

    def __len__(self) -> int:
        """Return the number of individual tunables (not covariant groups)."""
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if ``other`` is a TunableGroups whose covariant groups
            (by name) compare equal to ours; False for any other type.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Unlike :py:meth:`subgroup`, the copy does not share any covariant
        group objects with the original.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to register. It is stored by reference, not copied.

        Raises
        ------
        ValueError
            If any tunable of ``group`` already exists in this collection.
            The collection is left unchanged in that case.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        # Validate every tunable *before* mutating any state, so that a
        # duplicate does not leave the group half-registered in the index.
        new_tunables = list(group.get_tunables())
        for tunable in new_tunables:
            if tunable.name in self._index:
                raise ValueError(
                    f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
                )
        self._tunable_groups[group.name] = group
        for tunable in new_tunables:
            self._index[tunable.name] = group

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Groups that are new to ``self`` are added by reference (not copied).
        A group whose name already exists in ``self`` is kept as it is,
        provided ``CovariantTunableGroup.equals_defaults`` considers the two
        groups equal. Their current values are allowed to differ.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a same-named group fails the ``equals_defaults`` check, or if a
            new group contains a tunable already present in ``self``.
            Groups processed before the error remain merged.
        """
        # pylint: disable=protected-access
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
            elif not self._tunable_groups[group.name].equals_defaults(group):
                # Same group name: check that there's no overlap in the
                # tunables, but allow for differing current values.
                raise ValueError(
                    f"Overlapping covariant tunable group name {group.name} "
                    f"in {self._tunable_groups[group.name]} and {tunables}"
                )
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed by descending cost, then by name. Tunables within a
        group are listed in their natural sort order.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        # pylint: disable=protected-access
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group._tunables.values())
            )
            + " }"
        )

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in this collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        ``tunable_value`` may be a raw value or a Tunable, in which case its
        current ``.value`` is used. Raises KeyError if the tunable is unknown.
        Returns the value as read back from the owning group.
        """
        # Use double index to make sure we set the is_updated flag of the group
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : Generator[Tuple[Tunable, CovariantTunableGroup], None, None]
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter (or a Tunable, whose name is used).

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups, as a live view of the
            underlying dict keys (it reflects later additions).
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of the names is not a covariant group of this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Optional[Iterable[str]] = None,
        into_params: Optional[Dict[str, TunableValue]] = None,
    ) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if ``None``. An empty list
            selects no groups.
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.

        Raises
        ------
        KeyError
            If any of the group names is unknown.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted.
            Note: a falsy argument (e.g. an empty list) also selects all groups.

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports being at its defaults
            (vacuously True for an empty collection).
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted.
            Note: a falsy argument (e.g. an empty list) also selects all groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted.
            Note: a falsy argument (e.g. an empty list) also selects all groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

            As a special behavior when the mapping is empty the method will restore
            the default values rather than no-op.
            This allows an empty dictionary in json configs to be used to reset the
            tunables to defaults without having to copy the original values from the
            tunable_params definition.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If a key does not name a tunable in this collection. Keys processed
            before it have already been assigned.
        """
        if not param_values:
            _LOG.info("Empty tunable values set provided. Resetting all tunables to defaults.")
            return self.restore_defaults()
        for key, value in param_values.items():
            self[key] = value
        return self