Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 90%

205 statements  

coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Base class for the optimization loop scheduling policies."""

import json
import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import Any, Literal

from pytz import UTC

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import merge_parameters

_LOG = logging.getLogger(__name__)


class Scheduler(ContextManager, metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """Base class for the optimization loop scheduling policies."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        config: dict[str, Any],
        global_config: dict[str, Any],
        trial_runners: Iterable[TrialRunner],
        optimizer: Optimizer,
        storage: Storage,
        root_env_config: str,
    ):
        """
        Create a new instance of the scheduler. The constructor of this and the derived
        classes is called by the persistence service after reading the class JSON
        configuration. Other objects like the TrialRunner(s) and their Environment(s)
        and Optimizer are provided by the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the Scheduler.
        global_config : dict
            The global configuration for the Experiment.
        trial_runners : Iterable[TrialRunner]
            The set of TrialRunner(s) (and associated Environment(s)) to benchmark/optimize.
        optimizer : Optimizer
            The Optimizer to use.
        storage : Storage
            The Storage to use.
        root_env_config : str
            Path to the root Environment configuration.
        """
        self.global_config = global_config
        config = merge_parameters(
            dest=config.copy(),
            source=global_config,
            required_keys=["experiment_id", "trial_id"],
        )
        self._validate_json_config(config)

        self._in_context = False
        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        self._config_id = int(config.get("config_id", -1))
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(
                f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
            )

        self._do_teardown = bool(config.get("teardown", True))

        self._experiment: Storage.Experiment | None = None

        assert trial_runners, "At least one TrialRunner is required"
        trial_runners = list(trial_runners)
        self._trial_runners = {
            trial_runner.trial_runner_id: trial_runner for trial_runner in trial_runners
        }
        self._current_trial_runner_idx = 0
        self._trial_runner_ids = list(self._trial_runners.keys())
        assert len(self._trial_runner_ids) == len(
            trial_runners
        ), f"Duplicate TrialRunner ids detected: {trial_runners}"

        self._optimizer = optimizer
        self._storage = storage
        self._root_env_config = root_env_config
        self._last_trial_id = -1
        self._ran_trials: list[Storage.Trial] = []

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config.copy()
            # The json schema does not allow for -1 as a valid value for config_id.
            # As it is just a default placeholder value, and not required, we can
            # remove it from the config copy prior to validation safely.
            config_id = json_config["config"].get("config_id")
            if config_id is not None and isinstance(config_id, int) and config_id < 0:
                json_config["config"].pop("config_id")
        ConfigSchema.SCHEDULER.validate(json_config)
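
    # For illustration only: a minimal ``config`` accepted by the constructor above,
    # using just the keys it consumes (values here are hypothetical). The required
    # "experiment_id" and "trial_id" entries may also be supplied via ``global_config``
    # through merge_parameters().
    #
    #   config = {
    #       "experiment_id": "MyExperiment",
    #       "trial_id": 1,
    #       "max_trials": 100,                 # optional; -1 means no limit
    #       "trial_config_repeat_count": 3,    # optional; must be >= 1
    #       "teardown": True,                  # optional; tear down TrialRunners at the end
    #   }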

    @property
    def trial_config_repeat_count(self) -> int:
        """Gets the number of trials to run for a given config."""
        return self._trial_config_repeat_count

    @property
    def trial_count(self) -> int:
        """Gets the current number of trials run for the experiment."""
        return self._trial_count

    @property
    def max_trials(self) -> int:
        """Gets the maximum number of trials to run for a given experiment, or -1 for no
        limit.
        """
        return self._max_trials

    @property
    def experiment(self) -> Storage.Experiment | None:
        """Gets the Experiment Storage."""
        return self._experiment

    @property
    def _root_trial_runner_id(self) -> int:
        # Use the first TrialRunner as the root.
        return self._trial_runner_ids[0]

    @property
    def root_environment(self) -> Environment:
        """
        Gets the root (prototypical) Environment from the first TrialRunner.

        Notes
        -----
        All TrialRunners have the same Environment config and are made
        unique by their use of the unique trial_runner_id assigned to each
        TrialRunner's Environment's global_config.
        """
        # Use the first TrialRunner's Environment as the root Environment.
        return self._trial_runners[self._root_trial_runner_id].environment

    @property
    def trial_runners(self) -> dict[int, TrialRunner]:
        """Gets the set of Trial Runners."""
        return self._trial_runners

    @property
    def environments(self) -> Iterable[Environment]:
        """Gets the Environments of the TrialRunners."""
        return (trial_runner.environment for trial_runner in self._trial_runners.values())

    @property
    def optimizer(self) -> Optimizer:
        """Gets the Optimizer."""
        return self._optimizer

    @property
    def storage(self) -> Storage:
        """Gets the Storage."""
        return self._storage

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Scheduler (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Scheduler.
        """
        return self.__class__.__name__

    def __enter__(self) -> "Scheduler":
        """Enter the scheduler's context."""
        _LOG.debug("Scheduler START :: %s", self)
        assert self.experiment is None
        assert not self._in_context
        # NOTE: We delay entering the context of trial_runners until it's time
        # to run the trial in order to avoid incompatibilities with
        # multiprocessing.Pool.
        self._optimizer.__enter__()
        # Start new or resume the existing experiment. Verify that the
        # experiment configuration is compatible with the previous runs.
        # If the `merge` config parameter is present, merge in the data
        # from other experiments and check for compatibility.
        self._experiment = self.storage.experiment(
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            root_env_config=self._root_env_config,
            description=self.root_environment.name,
            tunables=self.root_environment.tunable_params,
            opt_targets=self.optimizer.targets,
        ).__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """Exit the context of the scheduler."""
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        assert self._experiment is not None
        self._experiment.__exit__(ex_type, ex_val, ex_tb)
        self._optimizer.__exit__(ex_type, ex_val, ex_tb)
        for trial_runner in self._trial_runners.values():
            # TrialRunners should have already exited their context after running the Trial.
            assert not trial_runner._in_context  # pylint: disable=protected-access
        self._experiment = None
        self._in_context = False
        return False  # Do not suppress exceptions

    def start(self) -> None:
        """Start the scheduling loop."""
        assert self.experiment is not None
        _LOG.info(
            "START: Experiment: %s Env: %s Optimizer: %s",
            self._experiment,
            self.root_environment,
            self.optimizer,
        )
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.root_environment.pprint())

        if self._config_id > 0:
            tunables = self.load_tunable_config(self._config_id)
            # If a config_id is provided, assume it is expected to be run immediately.
            self.add_trial_to_queue(tunables, ts_start=datetime.now(UTC))

        is_warm_up: bool = self.optimizer.supports_preload
        if not is_warm_up:
            _LOG.warning("Skip pending trials and warm-up: %s", self.optimizer)

        not_done: bool = True
        while not_done:
            _LOG.info("Optimization loop: Last trial ID: %d", self._last_trial_id)
            self.run_schedule(is_warm_up)
            not_done = self.add_new_optimizer_suggestions()
            self.assign_trial_runners(
                self.experiment.pending_trials(
                    datetime.now(UTC),
                    running=False,
                    trial_runner_assigned=False,
                )
            )
            is_warm_up = False

    def teardown(self) -> None:
        """
        Tear down the TrialRunners/Environment(s).

        Call it after the completion of the :py:meth:`Scheduler.start` in the
        Scheduler context.
        """
        assert self.experiment is not None
        if self._do_teardown:
            for trial_runner in self._trial_runners.values():
                assert not trial_runner.is_running
                with trial_runner:
                    trial_runner.teardown()
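
    # Illustrative sketch of the typical lifecycle, assuming a concrete Scheduler
    # subclass instance created by the Launcher:
    #
    #   with scheduler:                      # __enter__: open the Optimizer and Experiment
    #       scheduler.start()                # run the optimization loop
    #       scheduler.teardown()             # tear down TrialRunners if "teardown" is set
    #       (score, config) = scheduler.get_best_observation()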

    def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
        """Get the best observation from the optimizer."""
        (best_score, best_config) = self.optimizer.get_best_observation()
        _LOG.info("Env: %s best score: %s", self.root_environment, best_score)
        return (best_score, best_config)

    def load_tunable_config(self, config_id: int) -> TunableGroups:
        """Load the existing tunable configuration from the storage."""
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = TunableGroups()
        for environment in self.environments:
            tunables = environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables.copy()

    def add_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop.

        Load the results of the executed trials into the
        :py:class:`~.Optimizer`, suggest new configurations, and add them to the
        queue.

        Returns
        -------
        bool
            The return value indicates whether the optimization process should
            continue to get suggestions from the Optimizer or not.
            See Also: :py:meth:`~.Scheduler.not_done`.
        """
        assert self.experiment is not None
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        # Check if the optimizer has converged or not.
        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.add_trial_to_queue(tunables)
        return not_done

    def add_trial_to_queue(
        self,
        tunables: TunableGroups,
        ts_start: datetime | None = None,
    ) -> None:
        """
        Add a configuration to the queue of trials 1 or more times.

        (e.g., according to the :py:attr:`~.Scheduler.trial_config_repeat_count`)

        Parameters
        ----------
        tunables : TunableGroups
            The tunable configuration to add to the queue.

        ts_start : datetime.datetime | None
            Optional timestamp to use to start the trial.

        Notes
        -----
        Alternative scheduling policies may prefer to expand repeats over
        time as well as space, or adjust the number of repeats (budget) of a given
        trial based on whether initial results are promising.
        """
        for repeat_i in range(1, self._trial_config_repeat_count + 1):
            self._add_trial_to_queue(
                tunables,
                ts_start=ts_start,
                config=self._augment_trial_config_metadata(tunables, repeat_i),
            )

    def _augment_trial_config_metadata(
        self,
        tunables: TunableGroups,
        repeat_i: int,
    ) -> dict[str, Any]:
        return {
            # Add some additional metadata to track for the trial, such as the
            # optimizer config used.
            # Note: these values are unfortunately mutable at the moment.
            # Consider them as hints of what the config was when the trial *started*.
            # It is possible that the experiment configs were changed
            # between resumes of the experiment (since that is not currently
            # prevented).
            "optimizer": self.optimizer.name,
            "repeat_i": repeat_i,
            "is_defaults": tunables.is_defaults(),
            **{
                f"opt_{key}_{i}": val
                for (i, opt_target) in enumerate(self.optimizer.targets.items())
                for (key, val) in zip(["target", "direction"], opt_target)
            },
        }
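
    # For a single-objective optimizer minimizing a "score" metric, the metadata
    # dict built above might look roughly like this (illustrative values):
    #
    #   {
    #       "optimizer": "MlosCoreOptimizer",
    #       "repeat_i": 1,
    #       "is_defaults": False,
    #       "opt_target_0": "score",
    #       "opt_direction_0": "min",
    #   }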

    def _add_trial_to_queue(
        self,
        tunables: TunableGroups,
        ts_start: datetime | None = None,
        config: dict[str, Any] | None = None,
    ) -> None:
        """
        Add a configuration to the queue of trials in the Storage backend.

        A wrapper for the `Experiment.new_trial` method.
        """
        assert self.experiment is not None
        trial = self.experiment.new_trial(tunables, ts_start, config)
        _LOG.info("QUEUE: Added new trial: %s", trial)

    def assign_trial_runners(self, trials: Iterable[Storage.Trial]) -> None:
        """
        Assigns a :py:class:`~.TrialRunner` to each :py:class:`~.Storage.Trial` in the
        batch.

        The base class implements a simple round-robin scheduling algorithm for
        each Trial in sequence.

        Subclasses can override this method to implement a more sophisticated policy.
        For instance::

            def assign_trial_runners(
                self,
                trials: Iterable[Storage.Trial],
            ) -> None:
                trial_runners_map = {}
                # Implement a more sophisticated policy here.
                # For example, assign the Trial to the TrialRunner with the least
                # number of running Trials.
                # Or assign the Trial to the TrialRunner that hasn't executed this
                # TunableValues Config yet.
                for (trial, trial_runner) in trial_runners_map.items():
                    # Call the base class method to assign the TrialRunner in the Trial's metadata.
                    trial.set_trial_runner(trial_runner)
                    ...

        Notes
        -----
        Subclasses are *not* required to assign a TrialRunner to the Trial
        (e.g., if the Trial should be deferred to a later time).

        Parameters
        ----------
        trials : Iterable[Storage.Trial]
            The trials to assign TrialRunners to.
        """
        for trial in trials:
            if trial.trial_runner_id is not None:
                _LOG.info(
                    "Trial %s already has a TrialRunner assigned: %s",
                    trial,
                    trial.trial_runner_id,
                )
                continue

            # Basic round-robin trial runner assignment policy:
            # fetch and increment the current TrialRunner index.
            # Override in the subclass for a more sophisticated policy.
            trial_runner_idx = self._current_trial_runner_idx
            self._current_trial_runner_idx += 1
            self._current_trial_runner_idx %= len(self._trial_runner_ids)
            trial_runner = self._trial_runners[self._trial_runner_ids[trial_runner_idx]]
            assert trial_runner
            _LOG.info(
                "Assigning TrialRunner %s to Trial %s via basic round-robin policy.",
                trial_runner,
                trial,
            )
            assigned_trial_runner_id = trial.set_trial_runner(trial_runner.trial_runner_id)
            if assigned_trial_runner_id != trial_runner.trial_runner_id:
                raise ValueError(
                    f"Failed to assign TrialRunner {trial_runner} to Trial {trial}: "
                    f"{assigned_trial_runner_id}"
                )

    def get_trial_runner(self, trial: Storage.Trial) -> TrialRunner:
        """
        Gets the :py:class:`~.TrialRunner` associated with the given
        :py:class:`~.Storage.Trial`.

        Parameters
        ----------
        trial : Storage.Trial
            The trial to get the associated TrialRunner for.

        Returns
        -------
        TrialRunner
        """
        if trial.trial_runner_id is None:
            self.assign_trial_runners([trial])
        assert trial.trial_runner_id is not None
        trial_runner = self._trial_runners.get(trial.trial_runner_id)
        if trial_runner is None:
            raise ValueError(
                f"TrialRunner {trial.trial_runner_id} for Trial {trial} "
                f"not found: {self._trial_runners}"
            )
        assert trial_runner.trial_runner_id == trial.trial_runner_id
        return trial_runner

    def run_schedule(self, running: bool = False) -> None:
        """
        Runs the current schedule of trials.

        Check for :py:class:`~.Storage.Trial` instances with :py:attr:`.Status.PENDING`
        and an assigned :py:attr:`~.Storage.Trial.trial_runner_id` in the queue and run
        them with :py:meth:`~.Scheduler.run_trial`.

        Subclasses can override this method to implement a more sophisticated
        scheduling policy.

        Parameters
        ----------
        running : bool
            If True, run the trials that are already in a "running" state (e.g., to resume them).
            If False (default), run the trials that are pending.
        """
        assert self.experiment is not None
        pending_trials = list(
            self.experiment.pending_trials(
                datetime.now(UTC),
                running=running,
                trial_runner_assigned=True,
            )
        )
        for trial in pending_trials:
            assert (
                trial.trial_runner_id is not None
            ), f"Trial {trial} has no TrialRunner assigned yet."
            self.run_trial(trial)

    def not_done(self) -> bool:
        """
        Check the stopping conditions.

        By default, stop when the :py:class:`.Optimizer` converges or the limit
        of :py:attr:`~.Scheduler.max_trials` is reached.
        """
        # TODO: Add more stopping conditions: https://github.com/microsoft/MLOS/issues/427
        return self.optimizer.not_converged() and (
            self._trial_count < self._max_trials or self._max_trials <= 0
        )

    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial.

        Save the results in the storage.
        """
        assert self._in_context
        assert self.experiment is not None
        self._trial_count += 1
        self._ran_trials.append(trial)
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)
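
    # A concrete subclass must implement run_trial(). A minimal synchronous sketch,
    # assuming the TrialRunner exposes a run_trial(trial, global_config) method, might be:
    #
    #   def run_trial(self, trial: Storage.Trial) -> None:
    #       super().run_trial(trial)            # bookkeeping in the base class above
    #       trial_runner = self.get_trial_runner(trial)
    #       with trial_runner:                  # enter the runner's context only for this trial
    #           trial_runner.run_trial(trial, self.global_config)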

    @property
    def ran_trials(self) -> list[Storage.Trial]:
        """Get the list of trials that were run."""
        return self._ran_trials