Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%

94 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6TunableGroups definition. 

7 

8A collection of :py:class:`.CovariantTunableGroup` s of :py:class:`.Tunable` 

9parameters. 

10 

11Used to define the configuration space for an 

12:py:class:`~mlos_bench.environments.base_environment.Environment` for an 

13:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer` to explore. 

14 

15Config 

16++++++ 

17 

18The configuration of the tunable parameters is generally given via a JSON config file. 

19The syntax looks something like this: 

20 

21.. code-block:: json 

22 

23 { // starts a TunableGroups config (e.g., for one Environment) 

24 "group1": { // starts a CovariantTunableGroup config 

25 "cost": 7, 

26 "params": { 

27 "param1": { // starts a Tunable config, named "param1" 

28 "type": "int", 

29 "range": [0, 100], 

30 "default": 50 

31 }, 

32 "param2": { // starts a new Tunable config, named "param2", within that same group 

33 "type": "float", 

34 "range": [0.0, 100.0], 

35 "default": 50.0 

36 }, 

37 "param3": { 

38 "type": "categorical", 

39 "values": ["on", "off", "auto"], 

40 "default": "auto" 

41 } 

42 } 

43 }, 

44 "group2": { // starts a new CovariantTunableGroup config 

45 "cost": 7, 

46 "params": { 

47 "some_param1": { 

48 "type": "int", 

49 "range": [0, 10], 

50 "default": 5 

51 }, 

52 "some_param2": { 

53 "type": "float", 

54 "range": [0.0, 100.0], 

55 "default": 50.0 

56 }, 

57 "some_param3": { 

58 "type": "categorical", 

59 "values": ["red", "green", "blue"], 

60 "default": "green" 

61 } 

62 } 

63 } 

64 } 

65 

66The JSON config is expected to be a dictionary of covariant tunable groups. 

67 

68Each covariant group has a name and a cost associated with changing any/all of the 

69parameters in that covariant group. 

70 

71Each group has a dictionary of :py:class:`.Tunable` parameters, where the key is 

72the name of the parameter and the value is a dictionary of the parameter's 

73configuration (see the :py:class:`.Tunable` class for more information on the 

74different ways they can be configured). 

75 

76Generally tunables are associated with an 

77:py:class:`~mlos_bench.environments.base_environment.Environment`, stored in a file 

78alongside the Environment's config (e.g., ``env-name-tunables.mlos.jsonc``), and 

79referenced from the Environment config via the ``include_tunables`` property. 

80 

81See Also 

82-------- 

83:py:mod:`mlos_bench.tunables` : 

84 For more information on tunable parameters and their configuration. 

85:py:mod:`mlos_bench.tunables.tunable` : 

86 Tunable parameter definition. 

87:py:mod:`mlos_bench.config` : 

88 Configuration system for mlos_bench. 

89:py:mod:`mlos_bench.environments` : 

90 Environment configuration and setup. 

91""" 

92 

93import copy 

94import logging 

95from collections.abc import Generator, Iterable, Mapping 

96 

97from mlos_bench.config.schemas import ConfigSchema 

98from mlos_bench.tunables.covariant_group import CovariantTunableGroup 

99from mlos_bench.tunables.tunable import Tunable 

100from mlos_bench.tunables.tunable_types import TunableValue 

101 

# Logger for this module, looked up by the module's own dotted name.
_LOG = logging.getLogger(name=__name__)

103 

104 

class TunableGroups:
    """A collection of :py:class:`.CovariantTunableGroup` s of :py:class:`.Tunable`
    parameters.
    """

    def __init__(self, config: dict | None = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict, optional
            Python dict of serialized representation of the covariant tunable groups.
            It is validated against ``ConfigSchema.TUNABLE_PARAMS`` before use.
            An empty collection is created if omitted.

        See Also
        --------
        :py:mod:`mlos_bench.tunables` :
            For more information on tunable parameters and their configuration.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """True if the collection holds at least one tunable (empty groups don't count)."""
        return bool(self._index)

    def __len__(self) -> int:
        """Number of tunable parameters (not covariant groups) in the collection."""
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal, i.e., their covariant groups
            (keyed by group name) compare equal. False for non-TunableGroups.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        All of the group's tunables are checked before anything is registered,
        so a failed add leaves the collection unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to add. It is stored by reference, not copied.

        Raises
        ------
        ValueError
            If any tunable of ``group`` is already present in this collection.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        group_tunables = list(group.get_tunables())
        # Validate first, then mutate: registering the group (and part of its
        # tunables) before detecting a duplicate would leave the collection
        # half-updated after the ValueError.
        for tunable in group_tunables:
            if tunable.name in self._index:
                raise ValueError(
                    f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
                )
        self._tunable_groups[group.name] = group
        for tunable in group_tunables:
            self._index[tunable.name] = group

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        New groups are added by reference (not copied). Groups are processed one
        at a time, so groups merged before an error is raised remain merged.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a group name exists in both collections but the existing group is
            not ``equals_defaults`` to the incoming one, or if a new group
            contains a tunable already present in this collection.
        """
        # pylint: disable=protected-access
        for group in tunables._tunable_groups.values():
            existing = self._tunable_groups.get(group.name)
            if existing is None:
                self._add_group(group)
            elif not existing.equals_defaults(group):
                # Same group name on both sides: differing current values are
                # allowed, but the groups must otherwise match.
                raise ValueError(
                    f"Overlapping covariant tunable group name {group.name} "
                    f"in {existing} and {tunables}"
                )
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed by descending cost, then by name; tunables within a
        group are sorted.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group.get_tunables())
            )
            + " }"
        )

    def __contains__(self, tunable: str | Tunable) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: str | Tunable) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: str | Tunable,
        tunable_value: TunableValue | Tunable,
    ) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        If ``tunable_value`` is a Tunable, its current ``.value`` is used.
        Returns the value read back from the group after the assignment.
        Raises KeyError if the tunable is not in the collection.
        """
        # Use double index to make sure we set the is_updated flag of the group
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[tuple[Tunable, CovariantTunableGroup]]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : Generator[tuple[Tunable, CovariantTunableGroup]]
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: str | Tunable) -> tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : str | Tunable
            Name of the tunable parameter (or the Tunable itself).

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups. This is a live view of the
            underlying dict keys, so it reflects later additions.
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of the names is not a covariant group in this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Iterable[str] | None = None,
        into_params: dict[str, TunableValue] | None = None,
    ) -> dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (None).
            An empty iterable selects no groups.
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.

        Raises
        ------
        KeyError
            If any of the group names is unknown.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Iterable[str] | None = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted.
            Note: an empty list is treated the same as None (all groups).

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports its tunables at their defaults
            (vacuously True for an empty collection).
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Iterable[str] | None = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted.
            Note: an empty list is treated the same as None (all groups).

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Iterable[str] | None = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted.
            Note: an empty list is treated the same as None (all groups).

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

            As a special behavior when the mapping is empty (``{}``) the method will
            restore the default values rather than no-op.
            This allows an empty dictionary in json configs to be used to reset the
            tunables to defaults without having to copy the original values from the
            tunable_params definition.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        KeyError
            If a parameter name is not a tunable in this collection. Values are
            assigned one at a time, so earlier assignments are kept.
        """
        if not param_values:
            _LOG.info("Empty tunable values set provided. Resetting all tunables to defaults.")
            return self.restore_defaults()
        for key, value in param_values.items():
            self[key] = value
        return self