Coverage for mlos_bench/mlos_bench/storage/base_storage.py: 97%

174 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base interface for saving and restoring the benchmark data. 

7 

8See Also 

9-------- 

10mlos_bench.storage.base_storage.Storage.experiments : 

11 Retrieves a dictionary of the Experiments' data. 

12mlos_bench.storage.base_experiment_data.ExperimentData.results_df : 

13 Retrieves a pandas DataFrame of the Experiment's trials' results data. 

14mlos_bench.storage.base_experiment_data.ExperimentData.trials : 

15 Retrieves a dictionary of the Experiment's trials' data. 

16mlos_bench.storage.base_experiment_data.ExperimentData.tunable_configs : 

17 Retrieves a dictionary of the Experiment's sampled configs data. 

18mlos_bench.storage.base_experiment_data.ExperimentData.tunable_config_trial_groups : 

19 Retrieves a dictionary of the Experiment's trials' data, grouped by shared 

20 tunable config. 

21mlos_bench.storage.base_trial_data.TrialData : 

22 Base interface for accessing the stored benchmark trial data. 

23""" 

24 

25from __future__ import annotations 

26 

27import logging 

28from abc import ABCMeta, abstractmethod 

29from collections.abc import Iterator, Mapping 

30from contextlib import AbstractContextManager as ContextManager 

31from datetime import datetime 

32from types import TracebackType 

33from typing import Any, Literal 

34 

35from mlos_bench.config.schemas import ConfigSchema 

36from mlos_bench.dict_templater import DictTemplater 

37from mlos_bench.environments.status import Status 

38from mlos_bench.services.base_service import Service 

39from mlos_bench.storage.base_experiment_data import ExperimentData 

40from mlos_bench.tunables.tunable_groups import TunableGroups 

41from mlos_bench.util import get_git_info 

42 

43_LOG = logging.getLogger(__name__) 

44 

45 

class Storage(metaclass=ABCMeta):
    """An abstract interface between the benchmarking framework and storage systems
    (e.g., SQLite or MLFLow).
    """

    def __init__(
        self,
        config: dict[str, Any],
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create a new storage object.

        Parameters
        ----------
        config : dict
            Free-format key/value pairs of configuration parameters.
        global_config : dict | None
            (Optional) global configuration parameters.
        service : Service | None
            (Optional) service object for the storage backend to use.
        """
        _LOG.debug("Storage config: %s", config)
        # Validate before storing anything so a bad config fails fast.
        self._validate_json_config(config)
        self._config = config.copy()
        self._global_config = global_config or {}
        self._service = service

70 

71 @abstractmethod 

72 def update_schema(self) -> None: 

73 """Update the schema of the storage backend if needed.""" 

74 

75 def _validate_json_config(self, config: dict) -> None: 

76 """Reconstructs a basic json config that this class might have been instantiated 

77 from in order to validate configs provided outside the file loading 

78 mechanism. 

79 """ 

80 json_config: dict = { 

81 "class": self.__class__.__module__ + "." + self.__class__.__name__, 

82 } 

83 if config: 

84 json_config["config"] = config 

85 ConfigSchema.STORAGE.validate(json_config) 

86 

87 @property 

88 @abstractmethod 

89 def experiments(self) -> dict[str, ExperimentData]: 

90 """ 

91 Retrieve the experiments' data from the storage. 

92 

93 Returns 

94 ------- 

95 experiments : dict[str, ExperimentData] 

96 A dictionary of the experiments' data, keyed by experiment id. 

97 """ 

98 

99 @abstractmethod 

100 def get_experiment_by_id( 

101 self, 

102 experiment_id: str, 

103 tunables: TunableGroups, 

104 opt_targets: dict[str, Literal["min", "max"]], 

105 ) -> Storage.Experiment | None: 

106 """ 

107 Gets an Experiment by its ID. 

108 

109 Parameters 

110 ---------- 

111 experiment_id : str 

112 ID of the Experiment to retrieve. 

113 tunables : TunableGroups 

114 The tunables for the Experiment. 

115 opt_targets : dict[str, Literal["min", "max"]] 

116 The optimization targets for the Experiment's 

117 :py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer`. 

118 

119 Returns 

120 ------- 

121 experiment : Storage.Experiment | None 

122 The Experiment object, or None if it doesn't exist. 

123 

124 Notes 

125 ----- 

126 Tunables are not stored in the database for the Experiment, only for the 

127 Trials, so currently they can change if the user (incorrectly) adjusts 

128 the configs on disk between resume runs. 

129 Since this method is generally meant to load th Experiment from the 

130 database for a child process to execute a Trial in the background we are 

131 generally safe to simply pass these values from the parent process 

132 rather than look them up in the database. 

133 """ 

134 

135 @abstractmethod 

136 def experiment( # pylint: disable=too-many-arguments 

137 self, 

138 *, 

139 experiment_id: str, 

140 trial_id: int, 

141 root_env_config: str, 

142 description: str, 

143 tunables: TunableGroups, 

144 opt_targets: dict[str, Literal["min", "max"]], 

145 ) -> Storage.Experiment: 

146 """ 

147 Create or reload an experiment in the Storage. 

148 

149 Notes 

150 ----- 

151 We need the `opt_target` parameter here to know what metric to retrieve 

152 when we load the data from previous trials. Later we will replace it with 

153 full metadata about the optimization direction, multiple objectives, etc. 

154 

155 Parameters 

156 ---------- 

157 experiment_id : str 

158 Unique identifier of the experiment. 

159 trial_id : int 

160 Starting number of the trial. 

161 root_env_config : str 

162 A path to the root JSON configuration file of the benchmarking environment. 

163 description : str 

164 Human-readable description of the experiment. 

165 tunables : TunableGroups 

166 opt_targets : dict[str, Literal["min", "max"]] 

167 Names of metrics we're optimizing for and the optimization direction {min, max}. 

168 

169 Returns 

170 ------- 

171 experiment : Storage.Experiment 

172 An object that allows to update the storage with 

173 the results of the experiment and related data. 

174 """ 

175 

176 class Experiment(ContextManager, metaclass=ABCMeta): 

177 # pylint: disable=too-many-instance-attributes 

178 """ 

179 Base interface for storing the results of the experiment. 

180 

181 This class is instantiated in the `Storage.experiment()` method. 

182 """ 

183 

184 def __init__( # pylint: disable=too-many-arguments 

185 self, 

186 *, 

187 tunables: TunableGroups, 

188 experiment_id: str, 

189 trial_id: int, 

190 root_env_config: str, 

191 description: str, 

192 opt_targets: dict[str, Literal["min", "max"]], 

193 ): 

194 self._tunables = tunables.copy() 

195 self._trial_id = trial_id 

196 self._experiment_id = experiment_id 

197 (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( 

198 root_env_config 

199 ) 

200 self._description = description 

201 self._opt_targets = opt_targets 

202 self._in_context = False 

203 

204 def __enter__(self) -> Storage.Experiment: 

205 """ 

206 Enter the context of the experiment. 

207 

208 Override the `_setup` method to add custom context initialization. 

209 """ 

210 _LOG.debug("Starting experiment: %s", self) 

211 assert not self._in_context 

212 self._setup() 

213 self._in_context = True 

214 return self 

215 

216 def __exit__( 

217 self, 

218 exc_type: type[BaseException] | None, 

219 exc_val: BaseException | None, 

220 exc_tb: TracebackType | None, 

221 ) -> Literal[False]: 

222 """ 

223 End the context of the experiment. 

224 

225 Override the `_teardown` method to add custom context teardown logic. 

226 """ 

227 is_ok = exc_val is None 

228 if is_ok: 

229 _LOG.debug("Finishing experiment: %s", self) 

230 else: 

231 assert exc_type and exc_val 

232 _LOG.warning( 

233 "Finishing experiment: %s", 

234 self, 

235 exc_info=(exc_type, exc_val, exc_tb), 

236 ) 

237 assert self._in_context 

238 self._teardown(is_ok) 

239 self._in_context = False 

240 return False # Do not suppress exceptions 

241 

242 def __repr__(self) -> str: 

243 return self._experiment_id 

244 

245 def _setup(self) -> None: 

246 """ 

247 Create a record of the new experiment or find an existing one in the 

248 storage. 

249 

250 This method is called by `Storage.Experiment.__enter__()`. 

251 """ 

252 

253 def _teardown(self, is_ok: bool) -> None: 

254 """ 

255 Finalize the experiment in the storage. 

256 

257 This method is called by `Storage.Experiment.__exit__()`. 

258 

259 Parameters 

260 ---------- 

261 is_ok : bool 

262 True if there were no exceptions during the experiment, False otherwise. 

263 """ 

264 

265 @property 

266 def experiment_id(self) -> str: 

267 """Get the Experiment's ID.""" 

268 return self._experiment_id 

269 

270 @property 

271 def trial_id(self) -> int: 

272 """Get the current Trial ID.""" 

273 return self._trial_id 

274 

275 @property 

276 def description(self) -> str: 

277 """Get the Experiment's description.""" 

278 return self._description 

279 

280 @property 

281 def root_env_config(self) -> str: 

282 """Get the Experiment's root Environment config file path.""" 

283 return self._root_env_config 

284 

285 @property 

286 def tunables(self) -> TunableGroups: 

287 """Get the Experiment's tunables.""" 

288 return self._tunables 

289 

290 @property 

291 def opt_targets(self) -> dict[str, Literal["min", "max"]]: 

292 """Get the Experiment's optimization targets and directions.""" 

293 return self._opt_targets 

294 

295 @abstractmethod 

296 def merge(self, experiment_ids: list[str]) -> None: 

297 """ 

298 Merge in the results of other (compatible) experiments trials. Used to help 

299 warm up the optimizer for this experiment. 

300 

301 Parameters 

302 ---------- 

303 experiment_ids : list[str] 

304 List of IDs of the experiments to merge in. 

305 """ 

306 

307 @abstractmethod 

308 def load_tunable_config(self, config_id: int) -> dict[str, Any]: 

309 """Load tunable values for a given config ID.""" 

310 

311 @abstractmethod 

312 def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]: 

313 """ 

314 Retrieve the telemetry data for a given trial. 

315 

316 Parameters 

317 ---------- 

318 trial_id : int 

319 Trial ID. 

320 

321 Returns 

322 ------- 

323 metrics : list[tuple[datetime.datetime, str, Any]] 

324 Telemetry data. 

325 """ 

326 

327 @abstractmethod 

328 def load( 

329 self, 

330 last_trial_id: int = -1, 

331 ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]: 

332 """ 

333 Load (tunable values, benchmark scores, status) to warm-up the optimizer. 

334 

335 If `last_trial_id` is present, load only the data from the (completed) trials 

336 that were scheduled *after* the given trial ID. Otherwise, return data from ALL 

337 merged-in experiments and attempt to impute the missing tunable values. 

338 

339 Parameters 

340 ---------- 

341 last_trial_id : int 

342 (Optional) Trial ID to start from. 

343 

344 Returns 

345 ------- 

346 (trial_ids, configs, scores, status) : ([int], [dict], [dict] | None, [Status]) 

347 Trial ids, Tunable values, benchmark scores, and status of the trials. 

348 """ 

349 

350 @abstractmethod 

351 def get_trial_by_id( 

352 self, 

353 trial_id: int, 

354 ) -> Storage.Trial | None: 

355 """ 

356 Gets a Trial by its ID. 

357 

358 Parameters 

359 ---------- 

360 trial_id : int 

361 ID of the Trial to retrieve for this Experiment. 

362 

363 Returns 

364 ------- 

365 trial : Storage.Trial | None 

366 The Trial object, or None if it doesn't exist. 

367 """ 

368 

369 @abstractmethod 

370 def pending_trials( 

371 self, 

372 timestamp: datetime, 

373 *, 

374 running: bool, 

375 trial_runner_assigned: bool | None = None, 

376 ) -> Iterator[Storage.Trial]: 

377 """ 

378 Return an iterator over :py:attr:`~.Status.PENDING` 

379 :py:class:`~.Storage.Trial` instances that have a scheduled start time to 

380 run on or before the specified timestamp. 

381 

382 Parameters 

383 ---------- 

384 timestamp : datetime.datetime 

385 The time in UTC to check for scheduled Trials. 

386 running : bool 

387 If True, include the Trials that are also 

388 :py:attr:`~.Status.RUNNING` or :py:attr:`~.Status.READY`. 

389 Otherwise, return only the scheduled trials. 

390 trial_runner_assigned : bool | None 

391 If True, include the Trials that are assigned to a 

392 :py:class:`~.TrialRunner`. If False, return only the trials 

393 that are not assigned to any :py:class:`~.TrialRunner`. 

394 If None, return all trials regardless of their assignment. 

395 

396 Returns 

397 ------- 

398 trials : Iterator[Storage.Trial] 

399 An iterator over the scheduled (and maybe running) trials. 

400 """ 

401 

402 def new_trial( 

403 self, 

404 tunables: TunableGroups, 

405 ts_start: datetime | None = None, 

406 config: dict[str, Any] | None = None, 

407 ) -> Storage.Trial: 

408 """ 

409 Create a new experiment run in the storage. 

410 

411 Parameters 

412 ---------- 

413 tunables : TunableGroups 

414 Tunable parameters to use for the trial. 

415 ts_start : datetime.datetime | None 

416 Timestamp of the trial start (can be in the future). 

417 config : dict 

418 Key/value pairs of additional non-tunable parameters of the trial. 

419 

420 Returns 

421 ------- 

422 trial : Storage.Trial 

423 An object that allows to update the storage with 

424 the results of the experiment trial run. 

425 """ 

426 # Check that `config` is json serializable (e.g., no callables) 

427 if config: 

428 try: 

429 # Relies on the fact that DictTemplater only accepts primitive 

430 # types in it's nested dict structure walk. 

431 _config = DictTemplater(config).expand_vars() 

432 assert isinstance(_config, dict) 

433 except ValueError as e: 

434 _LOG.error("Non-serializable config: %s", config, exc_info=e) 

435 raise e 

436 return self._new_trial(tunables, ts_start, config) 

437 

438 @abstractmethod 

439 def _new_trial( 

440 self, 

441 tunables: TunableGroups, 

442 ts_start: datetime | None = None, 

443 config: dict[str, Any] | None = None, 

444 ) -> Storage.Trial: 

445 """ 

446 Create a new experiment run in the storage. 

447 

448 Parameters 

449 ---------- 

450 tunables : TunableGroups 

451 Tunable parameters to use for the trial. 

452 ts_start : datetime.datetime | None 

453 Timestamp of the trial start (can be in the future). 

454 config : dict 

455 Key/value pairs of additional non-tunable parameters of the trial. 

456 

457 Returns 

458 ------- 

459 trial : Storage.Trial 

460 An object that allows to update the storage with 

461 the results of the experiment trial run. 

462 """ 

463 

464 class Trial(metaclass=ABCMeta): 

465 # pylint: disable=too-many-instance-attributes 

466 """ 

467 Base interface for storing the results of a single run of the experiment. 

468 

469 This class is instantiated in the `Storage.Experiment.trial()` method. 

470 """ 

471 

472 def __init__( # pylint: disable=too-many-arguments 

473 self, 

474 *, 

475 tunables: TunableGroups, 

476 experiment_id: str, 

477 trial_id: int, 

478 tunable_config_id: int, 

479 trial_runner_id: int | None, 

480 opt_targets: dict[str, Literal["min", "max"]], 

481 status: Status, 

482 restoring: bool, 

483 config: dict[str, Any] | None = None, 

484 ): 

485 if not restoring and status not in (Status.UNKNOWN, Status.PENDING): 

486 raise ValueError(f"Invalid status for a new trial: {status}") 

487 self._tunables = tunables 

488 self._experiment_id = experiment_id 

489 self._trial_id = trial_id 

490 self._tunable_config_id = tunable_config_id 

491 self._trial_runner_id = trial_runner_id 

492 self._opt_targets = opt_targets 

493 self._config = config or {} 

494 self._status = status 

495 

496 def __repr__(self) -> str: 

497 return ( 

498 f"{self._experiment_id}:{self._trial_id}:" 

499 f"{self._tunable_config_id}:{self.trial_runner_id}" 

500 ) 

501 

502 @property 

503 def experiment_id(self) -> str: 

504 """Experiment ID of the Trial.""" 

505 return self._experiment_id 

506 

507 @property 

508 def trial_id(self) -> int: 

509 """ID of the current trial.""" 

510 return self._trial_id 

511 

512 @property 

513 def tunable_config_id(self) -> int: 

514 """ID of the current trial (tunable) configuration.""" 

515 return self._tunable_config_id 

516 

517 @property 

518 def trial_runner_id(self) -> int | None: 

519 """ID of the TrialRunner this trial is assigned to.""" 

520 return self._trial_runner_id 

521 

522 def opt_targets(self) -> dict[str, Literal["min", "max"]]: 

523 """Get the Trial's optimization targets and directions.""" 

524 return self._opt_targets 

525 

526 @property 

527 def tunables(self) -> TunableGroups: 

528 """ 

529 Tunable parameters of the current trial. 

530 

531 (e.g., application Environment's "config") 

532 """ 

533 return self._tunables 

534 

535 @abstractmethod 

536 def set_trial_runner(self, trial_runner_id: int) -> int: 

537 """Assign the trial to a specific TrialRunner.""" 

538 if self._trial_runner_id is None or self._status.is_pending(): 

539 _LOG.debug( 

540 "%sSetting Trial %s to TrialRunner %d", 

541 "Re-" if self._trial_runner_id else "", 

542 self, 

543 trial_runner_id, 

544 ) 

545 self._trial_runner_id = trial_runner_id 

546 else: 

547 _LOG.warning( 

548 "Trial %s already assigned to a TrialRunner, cannot switch to %d", 

549 self, 

550 self._trial_runner_id, 

551 ) 

552 return self._trial_runner_id 

553 

554 def config(self, global_config: dict[str, Any] | None = None) -> dict[str, Any]: 

555 """ 

556 Produce a copy of the global configuration updated with the parameters of 

557 the current trial. 

558 

559 Note: this is not the target Environment's "config" (i.e., tunable 

560 params), but rather the internal "config" which consists of a 

561 combination of somewhat more static variables defined in the json config 

562 files. 

563 """ 

564 config = self._config.copy() 

565 config.update(global_config or {}) 

566 # Here we add some built-in variables for the trial to use while it's running. 

567 config["experiment_id"] = self._experiment_id 

568 config["trial_id"] = self._trial_id 

569 trial_runner_id = self.trial_runner_id 

570 if trial_runner_id is not None: 

571 config["trial_runner_id"] = trial_runner_id 

572 return config 

573 

574 def add_new_config_data( 

575 self, 

576 new_config_data: Mapping[str, int | float | str], 

577 ) -> None: 

578 """ 

579 Add new config data to the trial. 

580 

581 Parameters 

582 ---------- 

583 new_config_data : dict[str, int | float | str] 

584 New data to add (must not already exist for the trial). 

585 

586 Raises 

587 ------ 

588 ValueError 

589 If any of the data already exists. 

590 """ 

591 for key, value in new_config_data.items(): 

592 if key in self._config: 

593 raise ValueError( 

594 f"New config data {key}={value} already exists for trial {self}: " 

595 f"{self._config[key]}" 

596 ) 

597 self._config[key] = value 

598 self._save_new_config_data(new_config_data) 

599 

600 @abstractmethod 

601 def _save_new_config_data( 

602 self, 

603 new_config_data: Mapping[str, int | float | str], 

604 ) -> None: 

605 """ 

606 Save the new config data to the storage. 

607 

608 Parameters 

609 ---------- 

610 new_config_data : dict[str, int | float | str]] 

611 New data to add. 

612 """ 

613 

614 @property 

615 def status(self) -> Status: 

616 """Get the status of the current trial.""" 

617 return self._status 

618 

619 @abstractmethod 

620 def update( 

621 self, 

622 status: Status, 

623 timestamp: datetime, 

624 metrics: dict[str, Any] | None = None, 

625 ) -> dict[str, Any] | None: 

626 """ 

627 Update the storage with the results of the experiment. 

628 

629 Parameters 

630 ---------- 

631 status : Status 

632 Status of the experiment run. 

633 timestamp: datetime.datetime 

634 Timestamp of the status and metrics. 

635 metrics : Optional[dict[str, Any]] 

636 One or several metrics of the experiment run. 

637 Must contain the (float) optimization target if the status is SUCCEEDED. 

638 

639 Returns 

640 ------- 

641 metrics : Optional[dict[str, Any]] 

642 Same as `metrics`, but always in the dict format. 

643 """ 

644 _LOG.info("Store trial: %s :: %s %s", self, status, metrics) 

645 if status.is_succeeded(): 

646 assert metrics is not None 

647 opt_targets = set(self._opt_targets.keys()) 

648 if not opt_targets.issubset(metrics.keys()): 

649 _LOG.warning( 

650 "Trial %s :: opt.targets missing: %s", 

651 self, 

652 opt_targets.difference(metrics.keys()), 

653 ) 

654 # raise ValueError() 

655 self._status = status 

656 return metrics 

657 

658 @abstractmethod 

659 def update_telemetry( 

660 self, 

661 status: Status, 

662 timestamp: datetime, 

663 metrics: list[tuple[datetime, str, Any]], 

664 ) -> None: 

665 """ 

666 Save the experiment's telemetry data and intermediate status. 

667 

668 Parameters 

669 ---------- 

670 status : Status 

671 Current status of the trial. 

672 timestamp: datetime.datetime 

673 Timestamp of the status (but not the metrics). 

674 metrics : list[tuple[datetime.datetime, str, Any]] 

675 Telemetry data. 

676 """ 

677 _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics))