Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 90%
205 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base class for the optimization loop scheduling policies."""
7import json
8import logging
9from abc import ABCMeta, abstractmethod
10from collections.abc import Iterable
11from contextlib import AbstractContextManager as ContextManager
12from datetime import datetime
13from types import TracebackType
14from typing import Any, Literal
16from pytz import UTC
18from mlos_bench.config.schemas import ConfigSchema
19from mlos_bench.environments.base_environment import Environment
20from mlos_bench.optimizers.base_optimizer import Optimizer
21from mlos_bench.schedulers.trial_runner import TrialRunner
22from mlos_bench.storage.base_storage import Storage
23from mlos_bench.tunables.tunable_groups import TunableGroups
24from mlos_bench.util import merge_parameters
26_LOG = logging.getLogger(__name__)
class Scheduler(ContextManager, metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """Base class for the optimization loop scheduling policies."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        config: dict[str, Any],
        global_config: dict[str, Any],
        trial_runners: Iterable[TrialRunner],
        optimizer: Optimizer,
        storage: Storage,
        root_env_config: str,
    ):
        """
        Create a new instance of the scheduler. The constructor of this and the derived
        classes is called by the persistence service after reading the class JSON
        configuration. Other objects like the TrialRunner(s) and their Environment(s)
        and Optimizer are provided by the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the Scheduler.
        global_config : dict
            The global configuration for the Experiment.
        trial_runners : Iterable[TrialRunner]
            The set of TrialRunner(s) (and associated Environment(s)) to benchmark/optimize.
        optimizer : Optimizer
            The Optimizer to use.
        storage : Storage
            The Storage to use.
        root_env_config : str
            Path to the root Environment configuration.

        Raises
        ------
        ValueError
            If ``trial_config_repeat_count`` is not a positive integer.
        """
        self.global_config = global_config
        # Overlay the global config onto a copy of the Scheduler config so that
        # experiment_id/trial_id (required below) are guaranteed to be present.
        config = merge_parameters(
            dest=config.copy(),
            source=global_config,
            required_keys=["experiment_id", "trial_id"],
        )
        # Validate the merged result against the Scheduler JSON schema, since
        # this config may not have come through the file-loading mechanism.
        self._validate_json_config(config)

        self._in_context = False
        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        # -1 acts as an "unset" placeholder for both of these limits.
        self._config_id = int(config.get("config_id", -1))
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(
                f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
            )

        self._do_teardown = bool(config.get("teardown", True))

        # Assigned on __enter__, cleared on __exit__.
        self._experiment: Storage.Experiment | None = None

        assert trial_runners, "At least one TrialRunner is required"
        trial_runners = list(trial_runners)
        # Index the runners by their unique id for O(1) lookup.
        self._trial_runners = {
            trial_runner.trial_runner_id: trial_runner for trial_runner in trial_runners
        }
        # Cursor for the default round-robin runner assignment policy.
        self._current_trial_runner_idx = 0
        self._trial_runner_ids = list(self._trial_runners.keys())
        # A shorter dict than the input list means two runners shared an id.
        assert len(self._trial_runner_ids) == len(
            trial_runners
        ), f"Duplicate TrialRunner ids detected: {trial_runners}"

        self._optimizer = optimizer
        self._storage = storage
        self._root_env_config = root_env_config
        # Highest trial id already loaded into the optimizer (see add_new_optimizer_suggestions).
        self._last_trial_id = -1
        self._ran_trials: list[Storage.Trial] = []

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)
108 def _validate_json_config(self, config: dict) -> None:
109 """Reconstructs a basic json config that this class might have been instantiated
110 from in order to validate configs provided outside the file loading
111 mechanism.
112 """
113 json_config: dict = {
114 "class": self.__class__.__module__ + "." + self.__class__.__name__,
115 }
116 if config:
117 json_config["config"] = config.copy()
118 # The json schema does not allow for -1 as a valid value for config_id.
119 # As it is just a default placeholder value, and not required, we can
120 # remove it from the config copy prior to validation safely.
121 config_id = json_config["config"].get("config_id")
122 if config_id is not None and isinstance(config_id, int) and config_id < 0:
123 json_config["config"].pop("config_id")
124 ConfigSchema.SCHEDULER.validate(json_config)
126 @property
127 def trial_config_repeat_count(self) -> int:
128 """Gets the number of trials to run for a given config."""
129 return self._trial_config_repeat_count
131 @property
132 def trial_count(self) -> int:
133 """Gets the current number of trials run for the experiment."""
134 return self._trial_count
136 @property
137 def max_trials(self) -> int:
138 """Gets the maximum number of trials to run for a given experiment, or -1 for no
139 limit.
140 """
141 return self._max_trials
143 @property
144 def experiment(self) -> Storage.Experiment | None:
145 """Gets the Experiment Storage."""
146 return self._experiment
148 @property
149 def _root_trial_runner_id(self) -> int:
150 # Use the first TrialRunner as the root.
151 return self._trial_runner_ids[0]
153 @property
154 def root_environment(self) -> Environment:
155 """
156 Gets the root (prototypical) Environment from the first TrialRunner.
158 Notes
159 -----
160 All TrialRunners have the same Environment config and are made
161 unique by their use of the unique trial_runner_id assigned to each
162 TrialRunner's Environment's global_config.
163 """
164 # Use the first TrialRunner's Environment as the root Environment.
165 return self._trial_runners[self._root_trial_runner_id].environment
167 @property
168 def trial_runners(self) -> dict[int, TrialRunner]:
169 """Gets the set of Trial Runners."""
170 return self._trial_runners
172 @property
173 def environments(self) -> Iterable[Environment]:
174 """Gets the Environment from the TrialRunners."""
175 return (trial_runner.environment for trial_runner in self._trial_runners.values())
177 @property
178 def optimizer(self) -> Optimizer:
179 """Gets the Optimizer."""
180 return self._optimizer
182 @property
183 def storage(self) -> Storage:
184 """Gets the Storage."""
185 return self._storage
187 def __repr__(self) -> str:
188 """
189 Produce a human-readable version of the Scheduler (mostly for logging).
191 Returns
192 -------
193 string : str
194 A human-readable version of the Scheduler.
195 """
196 return self.__class__.__name__
    def __enter__(self) -> "Scheduler":
        """
        Enter the scheduler's context.

        Opens the Optimizer and the Experiment Storage contexts (in that order)
        and marks the Scheduler as in-context.

        Returns
        -------
        Scheduler
            This Scheduler instance.
        """
        _LOG.debug("Scheduler START :: %s", self)
        # Re-entering an already-active Scheduler context is a programming error.
        assert self.experiment is None
        assert not self._in_context
        # NOTE: We delay entering the context of trial_runners until it's time
        # to run the trial in order to avoid incompatibilities with
        # multiprocessing.Pool.
        self._optimizer.__enter__()
        # Start new or resume the existing experiment. Verify that the
        # experiment configuration is compatible with the previous runs.
        # If the `merge` config parameter is present, merge in the data
        # from other experiments and check for compatibility.
        self._experiment = self.storage.experiment(
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            root_env_config=self._root_env_config,
            description=self.root_environment.name,
            tunables=self.root_environment.tunable_params,
            opt_targets=self.optimizer.targets,
        ).__enter__()
        self._in_context = True
        return self
    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exit the context of the scheduler.

        Closes the Experiment and Optimizer contexts (in that order — the
        reverse of __enter__) and never suppresses the in-flight exception.
        """
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        assert self._experiment is not None
        self._experiment.__exit__(ex_type, ex_val, ex_tb)
        self._optimizer.__exit__(ex_type, ex_val, ex_tb)
        for trial_runner in self._trial_runners.values():
            # TrialRunners should have already exited their context after running the Trial.
            assert not trial_runner._in_context  # pylint: disable=protected-access
        self._experiment = None
        self._in_context = False
        return False  # Do not suppress exceptions
    def start(self) -> None:
        """
        Start the scheduling loop.

        Alternates between running the currently scheduled trials
        (:py:meth:`~.Scheduler.run_schedule`) and loading their results into
        the Optimizer for new suggestions
        (:py:meth:`~.Scheduler.add_new_optimizer_suggestions`) until the
        stopping condition (:py:meth:`~.Scheduler.not_done`) is reached.
        Must be called inside the Scheduler context.
        """
        assert self.experiment is not None
        _LOG.info(
            "START: Experiment: %s Env: %s Optimizer: %s",
            self._experiment,
            self.root_environment,
            self.optimizer,
        )
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.root_environment.pprint())

        if self._config_id > 0:
            tunables = self.load_tunable_config(self._config_id)
            # If a config_id is provided, assume it is expected to be run immediately.
            self.add_trial_to_queue(tunables, ts_start=datetime.now(UTC))

        is_warm_up: bool = self.optimizer.supports_preload
        if not is_warm_up:
            _LOG.warning("Skip pending trials and warm-up: %s", self.optimizer)

        not_done: bool = True
        while not_done:
            _LOG.info("Optimization loop: Last trial ID: %d", self._last_trial_id)
            # Run whatever is already scheduled, then ask the optimizer for more.
            self.run_schedule(is_warm_up)
            not_done = self.add_new_optimizer_suggestions()
            # Assign runners to any newly queued (pending, unassigned) trials.
            self.assign_trial_runners(
                self.experiment.pending_trials(
                    datetime.now(UTC),
                    running=False,
                    trial_runner_assigned=False,
                )
            )
            # Warm-up (resuming already-running trials) only applies to the first pass.
            is_warm_up = False
280 def teardown(self) -> None:
281 """
282 Tear down the TrialRunners/Environment(s).
284 Call it after the completion of the :py:meth:`Scheduler.start` in the
285 Scheduler context.
286 """
287 assert self.experiment is not None
288 if self._do_teardown:
289 for trial_runner in self._trial_runners.values():
290 assert not trial_runner.is_running
291 with trial_runner:
292 trial_runner.teardown()
294 def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
295 """Get the best observation from the optimizer."""
296 (best_score, best_config) = self.optimizer.get_best_observation()
297 _LOG.info("Env: %s best score: %s", self.root_environment, best_score)
298 return (best_score, best_config)
    def load_tunable_config(self, config_id: int) -> TunableGroups:
        """
        Load the existing tunable configuration from the storage.

        Parameters
        ----------
        config_id : int
            The id of the stored tunable configuration to load.

        Returns
        -------
        TunableGroups
            A copy of the tunables with the stored values assigned.
        """
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = TunableGroups()
        # NOTE(review): each iteration overwrites `tunables` with the result of
        # that environment's assign(), so the empty TunableGroups above only
        # matters when there are no environments. This looks intentional —
        # presumably assign() applies the values to every environment in place
        # and returns that environment's params — but confirm against
        # TunableGroups.assign's contract.
        for environment in self.environments:
            tunables = environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables.copy()
    def add_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop.

        Load the results of the executed trials into the
        :py:class:`~.Optimizer`, suggest new configurations, and add them to the
        queue.

        Returns
        -------
        bool
            The return value indicates whether the optimization process should
            continue to get suggestions from the Optimizer or not.
            See Also: :py:meth:`~.Scheduler.not_done`.
        """
        assert self.experiment is not None
        # Only load results newer than what we've already registered.
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        # Advance the high-water mark; keep it unchanged if nothing new loaded.
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        # Check if the optimizer has converged or not.
        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.add_trial_to_queue(tunables)
        return not_done
340 def add_trial_to_queue(
341 self,
342 tunables: TunableGroups,
343 ts_start: datetime | None = None,
344 ) -> None:
345 """
346 Add a configuration to the queue of trials 1 or more times.
348 (e.g., according to the :py:attr:`~.Scheduler.trial_config_repeat_count`)
350 Parameters
351 ----------
352 tunables : TunableGroups
353 The tunable configuration to add to the queue.
355 ts_start : datetime.datetime | None
356 Optional timestamp to use to start the trial.
358 Notes
359 -----
360 Alternative scheduling policies may prefer to expand repeats over
361 time as well as space, or adjust the number of repeats (budget) of a given
362 trial based on whether initial results are promising.
363 """
364 for repeat_i in range(1, self._trial_config_repeat_count + 1):
365 self._add_trial_to_queue(
366 tunables,
367 ts_start=ts_start,
368 config=self._augment_trial_config_metadata(tunables, repeat_i),
369 )
371 def _augment_trial_config_metadata(
372 self,
373 tunables: TunableGroups,
374 repeat_i: int,
375 ) -> dict[str, Any]:
376 return {
377 # Add some additional metadata to track for the trial such as the
378 # optimizer config used.
379 # Note: these values are unfortunately mutable at the moment.
380 # Consider them as hints of what the config was the trial *started*.
381 # It is possible that the experiment configs were changed
382 # between resuming the experiment (since that is not currently
383 # prevented).
384 "optimizer": self.optimizer.name,
385 "repeat_i": repeat_i,
386 "is_defaults": tunables.is_defaults(),
387 **{
388 f"opt_{key}_{i}": val
389 for (i, opt_target) in enumerate(self.optimizer.targets.items())
390 for (key, val) in zip(["target", "direction"], opt_target)
391 },
392 }
394 def _add_trial_to_queue(
395 self,
396 tunables: TunableGroups,
397 ts_start: datetime | None = None,
398 config: dict[str, Any] | None = None,
399 ) -> None:
400 """
401 Add a configuration to the queue of trials in the Storage backend.
403 A wrapper for the `Experiment.new_trial` method.
404 """
405 assert self.experiment is not None
406 trial = self.experiment.new_trial(tunables, ts_start, config)
407 _LOG.info("QUEUE: Added new trial: %s", trial)
    def assign_trial_runners(self, trials: Iterable[Storage.Trial]) -> None:
        """
        Assigns a :py:class:`~.TrialRunner` to each :py:class:`~.Storage.Trial` in the
        batch.

        The base class implements a simple round-robin scheduling algorithm for
        each Trial in sequence.

        Subclasses can override this method to implement a more sophisticated policy.
        For instance::

            def assign_trial_runners(
                self,
                trials: Iterable[Storage.Trial],
            ) -> TrialRunner:
                trial_runners_map = {}
                # Implement a more sophisticated policy here.
                # For example, to assign the Trial to the TrialRunner with the least
                # number of running Trials.
                # Or assign the Trial to the TrialRunner that hasn't executed this
                # TunableValues Config yet.
                for (trial, trial_runner) in trial_runners_map.items():
                    # Call the base class method to assign the TrialRunner in the Trial's metadata.
                    trial.set_trial_runner(trial_runner)
                ...

        Notes
        -----
        Subclasses are *not* required to assign a TrialRunner to the Trial
        (e.g., if the Trial should be deferred to a later time).

        Parameters
        ----------
        trials : Iterable[Storage.Trial]
            The trials to assign a TrialRunner to.
        """
        for trial in trials:
            if trial.trial_runner_id is not None:
                _LOG.info(
                    "Trial %s already has a TrialRunner assigned: %s",
                    trial,
                    trial.trial_runner_id,
                )
                continue

            # Basic round-robin trial runner assignment policy:
            # fetch and increment the current TrialRunner index.
            # Override in the subclass for a more sophisticated policy.
            trial_runner_idx = self._current_trial_runner_idx
            self._current_trial_runner_idx += 1
            self._current_trial_runner_idx %= len(self._trial_runner_ids)
            trial_runner = self._trial_runners[self._trial_runner_ids[trial_runner_idx]]
            assert trial_runner
            _LOG.info(
                "Assigning TrialRunner %s to Trial %s via basic round-robin policy.",
                trial_runner,
                trial,
            )
            # set_trial_runner returns the id actually recorded; mismatch means
            # the Storage backend rejected or changed the assignment.
            assigned_trial_runner_id = trial.set_trial_runner(trial_runner.trial_runner_id)
            if assigned_trial_runner_id != trial_runner.trial_runner_id:
                raise ValueError(
                    f"Failed to assign TrialRunner {trial_runner} to Trial {trial}: "
                    f"{assigned_trial_runner_id}"
                )
474 def get_trial_runner(self, trial: Storage.Trial) -> TrialRunner:
475 """
476 Gets the :py:class:`~.TrialRunner` associated with the given
477 :py:class:`~.Storage.Trial`.
479 Parameters
480 ----------
481 trial : Storage.Trial
482 The trial to get the associated TrialRunner for.
484 Returns
485 -------
486 TrialRunner
487 """
488 if trial.trial_runner_id is None:
489 self.assign_trial_runners([trial])
490 assert trial.trial_runner_id is not None
491 trial_runner = self._trial_runners.get(trial.trial_runner_id)
492 if trial_runner is None:
493 raise ValueError(
494 f"TrialRunner {trial.trial_runner_id} for Trial {trial} "
495 f"not found: {self._trial_runners}"
496 )
497 assert trial_runner.trial_runner_id == trial.trial_runner_id
498 return trial_runner
    def run_schedule(self, running: bool = False) -> None:
        """
        Runs the current schedule of trials.

        Check for :py:class:`~.Storage.Trial` instances with `:py:attr:`.Status.PENDING`
        and an assigned :py:attr:`~.Storage.Trial.trial_runner_id` in the queue and run
        them with :py:meth:`~.Scheduler.run_trial`.

        Subclasses can override this method to implement a more sophisticated
        scheduling policy.

        Parameters
        ----------
        running : bool
            If True, run the trials that are already in a "running" state (e.g., to resume them).
            If False (default), run the trials that are pending.
        """
        assert self.experiment is not None
        # Materialize the batch up front so the set of trials to run is fixed
        # before run_trial starts mutating experiment state.
        pending_trials = list(
            self.experiment.pending_trials(
                datetime.now(UTC),
                running=running,
                trial_runner_assigned=True,
            )
        )
        for trial in pending_trials:
            assert (
                trial.trial_runner_id is not None
            ), f"Trial {trial} has no TrialRunner assigned yet."
            self.run_trial(trial)
531 def not_done(self) -> bool:
532 """
533 Check the stopping conditions.
535 By default, stop when the :py:class:`.Optimizer` converges or the limit
536 of :py:attr:`~.Scheduler.max_trials` is reached.
537 """
538 # TODO: Add more stopping conditions: https://github.com/microsoft/MLOS/issues/427
539 return self.optimizer.not_converged() and (
540 self._trial_count < self._max_trials or self._max_trials <= 0
541 )
    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial.

        Save the results in the storage.

        Subclasses must call this base implementation (e.g., via ``super()``)
        to keep the trial counter and the ran-trials list up to date.

        Parameters
        ----------
        trial : Storage.Trial
            The trial to set up and run.
        """
        assert self._in_context
        assert self.experiment is not None
        self._trial_count += 1
        self._ran_trials.append(trial)
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)
556 @property
557 def ran_trials(self) -> list[Storage.Trial]:
558 """Get the list of trials that were run."""
559 return self._ran_trials