Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 91%
129 statements
coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Base class for the optimization loop scheduling policies."""

import json
import logging
from abc import ABCMeta, abstractmethod
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Tuple, Type

from pytz import UTC

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import merge_parameters

_LOG = logging.getLogger(__name__)

class Scheduler(metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """Base class for the optimization loop scheduling policies."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        config: Dict[str, Any],
        global_config: Dict[str, Any],
        environment: Environment,
        optimizer: Optimizer,
        storage: Storage,
        root_env_config: str,
    ):
        """
        Create a new instance of the scheduler. The constructor of this and the derived
        classes is called by the persistence service after reading the class JSON
        configuration. Other objects like the Environment and Optimizer are provided by
        the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the scheduler.
        global_config : dict
            The global configuration for the experiment.
        environment : Environment
            The environment to benchmark/optimize.
        optimizer : Optimizer
            The optimizer to use.
        storage : Storage
            The storage to use.
        root_env_config : str
            Path to the root environment configuration.
        """
        self.global_config = global_config
        config = merge_parameters(
            dest=config.copy(),
            source=global_config,
            required_keys=["experiment_id", "trial_id"],
        )
        self._validate_json_config(config)

        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        self._config_id = int(config.get("config_id", -1))
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(
                f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
            )

        self._do_teardown = bool(config.get("teardown", True))

        self.experiment: Optional[Storage.Experiment] = None
        self.environment = environment
        self.optimizer = optimizer
        self.storage = storage
        self._root_env_config = root_env_config
        self._last_trial_id = -1
        self._ran_trials: List[Storage.Trial] = []

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)
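
    # A minimal sketch (illustrative values only) of the kind of `config` dict this
    # constructor expects after `global_config` has been merged in; the keys below
    # are exactly the ones read above:
    #
    #   {
    #       "experiment_id": "my_experiment",   # required (often merged from global_config)
    #       "trial_id": 1,                      # required (often merged from global_config)
    #       "config_id": -1,                    # optional; -1 is a "not set" placeholder
    #       "max_trials": 100,                  # optional; -1 means no limit
    #       "trial_config_repeat_count": 3,     # optional; must be a positive integer
    #       "teardown": True,                   # optional; tear down the environment at the end
    #   }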

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from, in order to validate configs provided outside the file loading
        mechanism.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config.copy()
            # The json schema does not allow -1 as a valid value for config_id.
            # As it is just a default placeholder value, and not required, we can
            # safely remove it from the config copy prior to validation.
            config_id = json_config["config"].get("config_id")
            if config_id is not None and isinstance(config_id, int) and config_id < 0:
                json_config["config"].pop("config_id")
        ConfigSchema.SCHEDULER.validate(json_config)
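
    # For example, a hand-constructed config for a hypothetical subclass would be
    # reconstructed and validated roughly as (class path and values are illustrative):
    #
    #   {
    #       "class": "mlos_bench.schedulers.MyScheduler",
    #       "config": {"experiment_id": "my_experiment", "trial_id": 1},
    #   }
    #
    # with any negative placeholder `config_id` stripped out beforehand.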

    @property
    def trial_config_repeat_count(self) -> int:
        """Gets the number of trials to run for a given config."""
        return self._trial_config_repeat_count

    @property
    def trial_count(self) -> int:
        """Gets the current number of trials run for the experiment."""
        return self._trial_count

    @property
    def max_trials(self) -> int:
        """Gets the maximum number of trials to run for a given experiment, or -1 for no
        limit.
        """
        return self._max_trials

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Scheduler (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Scheduler.
        """
        return self.__class__.__name__

    def __enter__(self) -> "Scheduler":
        """Enter the scheduler's context."""
        _LOG.debug("Scheduler START :: %s", self)
        assert self.experiment is None
        self.environment.__enter__()
        self.optimizer.__enter__()
        # Start a new experiment or resume an existing one. Verify that the
        # experiment configuration is compatible with the previous runs.
        # If the `merge` config parameter is present, merge in the data
        # from other experiments and check for compatibility.
        self.experiment = self.storage.experiment(
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            root_env_config=self._root_env_config,
            description=self.environment.name,
            tunables=self.environment.tunable_params,
            opt_targets=self.optimizer.targets,
        ).__enter__()
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """Exit the context of the scheduler."""
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self.experiment is not None
        self.experiment.__exit__(ex_type, ex_val, ex_tb)
        self.optimizer.__exit__(ex_type, ex_val, ex_tb)
        self.environment.__exit__(ex_type, ex_val, ex_tb)
        self.experiment = None
        return False  # Do not suppress exceptions
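
    # A rough usage sketch of the context-manager protocol above, as a concrete
    # subclass might be driven (the class name `MyScheduler` is illustrative):
    #
    #   with MyScheduler(
    #       config=config,
    #       global_config=global_config,
    #       environment=env,
    #       optimizer=opt,
    #       storage=storage,
    #       root_env_config=root_env_config,
    #   ) as scheduler:
    #       scheduler.start()
    #       scheduler.teardown()
    #       (best_score, best_config) = scheduler.get_best_observation()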

    @abstractmethod
    def start(self) -> None:
        """Start the optimization loop."""
        assert self.experiment is not None
        _LOG.info(
            "START: Experiment: %s Env: %s Optimizer: %s",
            self.experiment,
            self.environment,
            self.optimizer,
        )
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.environment.pprint())

        if self._config_id > 0:
            tunables = self.load_config(self._config_id)
            self.schedule_trial(tunables)

    def teardown(self) -> None:
        """
        Tear down the environment.

        Call it after the completion of `.start()` within the scheduler's context.
        """
        assert self.experiment is not None
        if self._do_teardown:
            self.environment.teardown()

    def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
        """Get the best observation from the optimizer."""
        (best_score, best_config) = self.optimizer.get_best_observation()
        _LOG.info("Env: %s best score: %s", self.environment, best_score)
        return (best_score, best_config)

    def load_config(self, config_id: int) -> TunableGroups:
        """Load the existing tunable configuration from the storage."""
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = self.environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables

    def _schedule_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop.

        Load the results of the executed trials into the optimizer, suggest new
        configurations, and add them to the queue. Return True if optimization is not
        over, False otherwise.
        """
        assert self.experiment is not None
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.schedule_trial(tunables)

        return not_done
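
    # A minimal sketch of how a concrete subclass *might* combine this method with
    # `_run_schedule()` in its `start()` override (purely illustrative; actual
    # subclasses may interleave these steps differently):
    #
    #   def start(self) -> None:
    #       super().start()
    #       not_done = True
    #       while not_done:
    #           self._run_schedule()
    #           not_done = self._schedule_new_optimizer_suggestions()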

    def schedule_trial(self, tunables: TunableGroups) -> None:
        """Add a configuration to the queue of trials."""
        for repeat_i in range(1, self._trial_config_repeat_count + 1):
            self._add_trial_to_queue(
                tunables,
                config={
                    # Add some additional metadata to track for the trial, such as the
                    # optimizer config used.
                    # Note: these values are unfortunately mutable at the moment.
                    # Consider them as hints of what the config was when the trial *started*.
                    # It is possible that the experiment configs were changed when
                    # resuming the experiment (since that is not currently prevented).
                    "optimizer": self.optimizer.name,
                    "repeat_i": repeat_i,
                    "is_defaults": tunables.is_defaults(),
                    **{
                        f"opt_{key}_{i}": val
                        for (i, opt_target) in enumerate(self.optimizer.targets.items())
                        for (key, val) in zip(["target", "direction"], opt_target)
                    },
                },
            )
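
    # For instance, with a hypothetical optimizer named "mlos_core" minimizing a
    # single target "score", each queued trial would carry metadata roughly like:
    #
    #   {
    #       "optimizer": "mlos_core",
    #       "repeat_i": 1,
    #       "is_defaults": False,
    #       "opt_target_0": "score",
    #       "opt_direction_0": "min",
    #   }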

    def _add_trial_to_queue(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add a configuration to the queue of trials.

        A wrapper for the `Experiment.new_trial` method.
        """
        assert self.experiment is not None
        trial = self.experiment.new_trial(tunables, ts_start, config)
        _LOG.info("QUEUE: Add new trial: %s", trial)

    def _run_schedule(self, running: bool = False) -> None:
        """
        Scheduler part of the loop.

        Check for pending trials in the queue and run them.
        """
        assert self.experiment is not None
        for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
            self.run_trial(trial)

    def not_done(self) -> bool:
        """
        Check the stopping conditions.

        By default, stop when the optimizer converges or the maximum number of trials
        is reached.
        """
        return self.optimizer.not_converged() and (
            self._trial_count < self._max_trials or self._max_trials <= 0
        )

    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial.

        Save the results in the storage.
        """
        assert self.experiment is not None
        self._trial_count += 1
        self._ran_trials.append(trial)
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)
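
    # A rough sketch of what a concrete override might add on top of this base
    # bookkeeping (the run/update steps named here are illustrative, not the
    # implementation of any particular subclass):
    #
    #   def run_trial(self, trial: Storage.Trial) -> None:
    #       super().run_trial(trial)
    #       # ... set up the environment with the trial's tunable values, run it, and
    #       # record the resulting status and scores through the `trial` object ...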

    @property
    def ran_trials(self) -> List[Storage.Trial]:
        """Get the list of trials that were run."""
        return self._ran_trials