Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""TunableGroups definition.""" 

6import copy 

7import logging 

8from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union 

9 

10from mlos_bench.config.schemas import ConfigSchema 

11from mlos_bench.tunables.covariant_group import CovariantTunableGroup 

12from mlos_bench.tunables.tunable import Tunable, TunableValue 

13 

# Logger for this module; the logger is named after the module itself.
_LOG: logging.Logger = logging.getLogger(name=__name__)

15 

16 

class TunableGroups:
    """A collection of :py:class:`.CovariantTunableGroup` s of :py:class:`.Tunable`
    parameters.

    Two dicts are kept in sync: one maps covariant group name -> group, the other
    maps tunable name -> the group that owns that tunable. Tunable names are
    therefore unique across the whole collection.
    """

    def __init__(self, config: Optional[dict] = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict
            Python dict of serialized representation of the covariant tunable groups,
            keyed by covariant group name. ``None`` (the default) is treated as an
            empty dict. The dict is passed to ``ConfigSchema.TUNABLE_PARAMS.validate``
            before any group is built.

        Raises
        ------
        ValueError
            If two groups declare a tunable with the same name (detected when the
            groups are added; schema validation may reject such a config earlier).

        See Also
        --------
        :py:mod:`mlos_bench.tunables` : for more information on tunable parameters and
            their configuration.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """True if the collection contains at least one tunable."""
        return bool(self._index)

    def __len__(self) -> int:
        """Number of tunables (not covariant groups) in the collection."""
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal, i.e. their name -> covariant group
            dicts compare equal. Always False for non-TunableGroups objects.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    @staticmethod
    def _name_of(tunable: Union[str, Tunable]) -> str:
        """Resolve a tunable reference (name string or Tunable object) to its name."""
        return tunable.name if isinstance(tunable, Tunable) else tunable

    def _check_no_duplicate_tunables(self, group: CovariantTunableGroup) -> None:
        """
        Ensure none of the tunables of `group` is already in this collection.

        Raises
        ------
        ValueError
            If any tunable name of `group` is already indexed.
        """
        for tunable in group.get_tunables():
            if tunable.name in self._index:
                raise ValueError(
                    f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
                )

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        The group is stored by reference (not copied). All checks run before the
        collection is modified, so a failed add leaves it unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup

        Raises
        ------
        AssertionError
            If a group with the same name is already present.
        ValueError
            If any tunable of `group` is already present in the collection.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        # Validate first so that a duplicate tunable does not leave a
        # half-registered group behind.
        self._check_no_duplicate_tunables(group)
        self._tunable_groups[group.name] = group
        for tunable in group.get_tunables():
            self._index[tunable.name] = group

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Groups whose name is new to `self` are added by reference. Groups whose
        name already exists are kept as-is, but must match according to
        ``CovariantTunableGroup.equals_defaults`` (current values may differ).
        All checks run before any group is added, so an error leaves `self`
        unchanged.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a same-named group differs in its definition, or a new group
            contains a tunable that already exists in `self`.
        """
        # pylint: disable=protected-access
        # First pass: validate everything without touching self.
        new_groups = []
        for group in tunables._tunable_groups.values():
            existing = self._tunable_groups.get(group.name)
            if existing is None:
                self._check_no_duplicate_tunables(group)
                new_groups.append(group)
            # Check that there's no overlap in the tunables.
            # But allow for differing current values.
            elif not existing.equals_defaults(group):
                raise ValueError(
                    f"Overlapping covariant tunable group name {group.name} "
                    f"in {existing} and {tunables}"
                )
        # Second pass: commit the new groups.
        for group in new_groups:
            self._add_group(group)
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed by descending cost, then by name; tunables within a
        group are listed in their sorted order.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group.get_tunables())
            )
            + " }"
        )

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        return self._name_of(tunable) in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name = self._name_of(tunable)
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        If `tunable_value` is a Tunable, its current value is used. Returns the
        value read back from the owning group (only visible to direct calls).
        Raises KeyError if the tunable is not in the collection.
        """
        name = self._name_of(tunable)
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        # Use double index to make sure we set the is_updated flag of the group
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : Generator[Tuple[Tunable, CovariantTunableGroup], None, None]
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter.

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name = self._name_of(tunable)
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups. This is a live view of the
            underlying dict keys, so it reflects groups added later.
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups. Listing the same name twice
            trips the duplicate-group assertion in `_add_group`.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any name is not a covariant group of this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Optional[Iterable[str]] = None,
        into_params: Optional[Dict[str, TunableValue]] = None,
    ) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (None). An empty
            iterable selects no groups.
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.

        Raises
        ------
        KeyError
            If any name is not a covariant group of this collection.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted.
            NOTE(review): any falsy value (e.g. an empty list or an empty keys
            view) also selects all groups, unlike `get_param_values`.

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports its tunables at defaults.
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted.
            As in `is_updated`, an empty iterable also selects all groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted.
            As in `is_updated`, an empty iterable also selects all groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

            As a special behavior when the mapping is empty the method will restore
            the default values rather than no-op.
            This allows an empty dictionary in json configs to be used to reset the
            tunables to defaults without having to copy the original values from the
            tunable_params definition.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If a key is not a tunable of this collection. Values assigned before
            the failing key are kept.
        """
        if not param_values:
            _LOG.info("Empty tunable values set provided. Resetting all tunables to defaults.")
            return self.restore_defaults()
        for key, value in param_values.items():
            self[key] = value
        return self