Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 89%
107 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Base class for the optimization loop scheduling policies.
7"""
9import json
10import logging
11from datetime import datetime
13from abc import ABCMeta, abstractmethod
14from types import TracebackType
15from typing import Any, Dict, Optional, Tuple, Type
16from typing_extensions import Literal
18from pytz import UTC
20from mlos_bench.environments.base_environment import Environment
21from mlos_bench.optimizers.base_optimizer import Optimizer
22from mlos_bench.storage.base_storage import Storage
23from mlos_bench.tunables.tunable_groups import TunableGroups
24from mlos_bench.util import merge_parameters
# Module-level logger, named after this module via `__name__`.
_LOG = logging.getLogger(__name__)
29class Scheduler(metaclass=ABCMeta):
30 # pylint: disable=too-many-instance-attributes
31 """
32 Base class for the optimization loop scheduling policies.
33 """
35 def __init__(self, *,
36 config: Dict[str, Any],
37 global_config: Dict[str, Any],
38 environment: Environment,
39 optimizer: Optimizer,
40 storage: Storage,
41 root_env_config: str):
42 """
43 Create a new instance of the scheduler. The constructor of this
44 and the derived classes is called by the persistence service
45 after reading the class JSON configuration. Other objects like
46 the Environment and Optimizer are provided by the Launcher.
48 Parameters
49 ----------
50 config : dict
51 The configuration for the scheduler.
52 global_config : dict
53 he global configuration for the experiment.
54 environment : Environment
55 The environment to benchmark/optimize.
56 optimizer : Optimizer
57 The optimizer to use.
58 storage : Storage
59 The storage to use.
60 root_env_config : str
61 Path to the root environment configuration.
62 """
63 self.global_config = global_config
64 config = merge_parameters(dest=config.copy(), source=global_config,
65 required_keys=["experiment_id", "trial_id"])
67 self._experiment_id = config["experiment_id"].strip()
68 self._trial_id = int(config["trial_id"])
69 self._config_id = int(config.get("config_id", -1))
70 self._max_trials = int(config.get("max_trials", -1))
71 self._trial_count = 0
73 self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
74 if self._trial_config_repeat_count <= 0:
75 raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}")
77 self._do_teardown = bool(config.get("teardown", True))
79 self.experiment: Optional[Storage.Experiment] = None
80 self.environment = environment
81 self.optimizer = optimizer
82 self.storage = storage
83 self._root_env_config = root_env_config
84 self._last_trial_id = -1
86 _LOG.debug("Scheduler instantiated: %s :: %s", self, config)
88 def __repr__(self) -> str:
89 """
90 Produce a human-readable version of the Scheduler (mostly for logging).
92 Returns
93 -------
94 string : str
95 A human-readable version of the Scheduler.
96 """
97 return self.__class__.__name__
99 def __enter__(self) -> 'Scheduler':
100 """
101 Enter the scheduler's context.
102 """
103 _LOG.debug("Scheduler START :: %s", self)
104 assert self.experiment is None
105 self.environment.__enter__()
106 self.optimizer.__enter__()
107 # Start new or resume the existing experiment. Verify that the
108 # experiment configuration is compatible with the previous runs.
109 # If the `merge` config parameter is present, merge in the data
110 # from other experiments and check for compatibility.
111 self.experiment = self.storage.experiment(
112 experiment_id=self._experiment_id,
113 trial_id=self._trial_id,
114 root_env_config=self._root_env_config,
115 description=self.environment.name,
116 tunables=self.environment.tunable_params,
117 opt_target=self.optimizer.target,
118 opt_direction=self.optimizer.direction,
119 ).__enter__()
120 return self
122 def __exit__(self,
123 ex_type: Optional[Type[BaseException]],
124 ex_val: Optional[BaseException],
125 ex_tb: Optional[TracebackType]) -> Literal[False]:
126 """
127 Exit the context of the scheduler.
128 """
129 if ex_val is None:
130 _LOG.debug("Scheduler END :: %s", self)
131 else:
132 assert ex_type and ex_val
133 _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
134 assert self.experiment is not None
135 self.experiment.__exit__(ex_type, ex_val, ex_tb)
136 self.optimizer.__exit__(ex_type, ex_val, ex_tb)
137 self.environment.__exit__(ex_type, ex_val, ex_tb)
138 self.experiment = None
139 return False # Do not suppress exceptions
141 @abstractmethod
142 def start(self) -> None:
143 """
144 Start the optimization loop.
145 """
146 assert self.experiment is not None
147 _LOG.info("START: Experiment: %s Env: %s Optimizer: %s",
148 self.experiment, self.environment, self.optimizer)
149 if _LOG.isEnabledFor(logging.INFO):
150 _LOG.info("Root Environment:\n%s", self.environment.pprint())
152 if self._config_id > 0:
153 tunables = self.load_config(self._config_id)
154 self.schedule_trial(tunables)
156 def teardown(self) -> None:
157 """
158 Tear down the environment.
159 Call it after the completion of the `.start()` in the scheduler context.
160 """
161 assert self.experiment is not None
162 if self._do_teardown:
163 self.environment.teardown()
165 def get_best_observation(self) -> Tuple[Optional[float], Optional[TunableGroups]]:
166 """
167 Get the best observation from the optimizer.
168 """
169 (best_score, best_config) = self.optimizer.get_best_observation()
170 _LOG.info("Env: %s best score: %s", self.environment, best_score)
171 return (best_score, best_config)
173 def load_config(self, config_id: int) -> TunableGroups:
174 """
175 Load the existing tunable configuration from the storage.
176 """
177 assert self.experiment is not None
178 tunable_values = self.experiment.load_tunable_config(config_id)
179 tunables = self.environment.tunable_params.assign(tunable_values)
180 _LOG.info("Load config from storage: %d", config_id)
181 if _LOG.isEnabledFor(logging.DEBUG):
182 _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
183 return tunables
185 def _schedule_new_optimizer_suggestions(self) -> bool:
186 """
187 Optimizer part of the loop. Load the results of the executed trials
188 into the optimizer, suggest new configurations, and add them to the queue.
189 Return True if optimization is not over, False otherwise.
190 """
191 assert self.experiment is not None
192 (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
193 _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
194 self.optimizer.bulk_register(configs, scores, status)
195 self._last_trial_id = max(trial_ids, default=self._last_trial_id)
197 not_done = self.not_done()
198 if not_done:
199 tunables = self.optimizer.suggest()
200 self.schedule_trial(tunables)
202 return not_done
204 def schedule_trial(self, tunables: TunableGroups) -> None:
205 """
206 Add a configuration to the queue of trials.
207 """
208 for repeat_i in range(1, self._trial_config_repeat_count + 1):
209 self._add_trial_to_queue(tunables, config={
210 # Add some additional metadata to track for the trial such as the
211 # optimizer config used.
212 # Note: these values are unfortunately mutable at the moment.
213 # Consider them as hints of what the config was the trial *started*.
214 # It is possible that the experiment configs were changed
215 # between resuming the experiment (since that is not currently
216 # prevented).
217 # TODO: Improve for supporting multi-objective
218 # (e.g., opt_target_1, opt_target_2, ... and opt_direction_1, opt_direction_2, ...)
219 "optimizer": self.optimizer.name,
220 "opt_target": self.optimizer.target,
221 "opt_direction": self.optimizer.direction,
222 "repeat_i": repeat_i,
223 "is_defaults": tunables.is_defaults,
224 })
226 def _add_trial_to_queue(self, tunables: TunableGroups,
227 ts_start: Optional[datetime] = None,
228 config: Optional[Dict[str, Any]] = None) -> None:
229 """
230 Add a configuration to the queue of trials.
231 A wrapper for the `Experiment.new_trial` method.
232 """
233 assert self.experiment is not None
234 trial = self.experiment.new_trial(tunables, ts_start, config)
235 _LOG.info("QUEUE: Add new trial: %s", trial)
237 def _run_schedule(self, running: bool = False) -> None:
238 """
239 Scheduler part of the loop. Check for pending trials in the queue and run them.
240 """
241 assert self.experiment is not None
242 for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
243 self.run_trial(trial)
245 def not_done(self) -> bool:
246 """
247 Check the stopping conditions.
248 By default, stop when the optimizer converges or max limit of trials reached.
249 """
250 return self.optimizer.not_converged() and (
251 self._trial_count < self._max_trials or self._max_trials <= 0
252 )
254 @abstractmethod
255 def run_trial(self, trial: Storage.Trial) -> None:
256 """
257 Set up and run a single trial. Save the results in the storage.
258 """
259 assert self.experiment is not None
260 self._trial_count += 1
261 _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)