Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 92%
101 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5""":py:class:`.Storage.Experiment` interface implementation for saving and restoring
6the benchmark experiment data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
7"""
9import hashlib
10import logging
11from datetime import datetime
12from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple
14from pytz import UTC
15from sqlalchemy import Connection, CursorResult, Table, column, func, select
16from sqlalchemy.engine import Engine
18from mlos_bench.environments.status import Status
19from mlos_bench.storage.base_storage import Storage
20from mlos_bench.storage.sql.schema import DbSchema
21from mlos_bench.storage.sql.trial import Trial
22from mlos_bench.tunables.tunable_groups import TunableGroups
23from mlos_bench.util import nullable, utcify_timestamp
25_LOG = logging.getLogger(__name__)
class Experiment(Storage.Experiment):
    """Logic for retrieving and storing the results of a single experiment."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        opt_targets: Dict[str, Literal["min", "max"]],
    ):
        """
        Create an SQL-backed experiment handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections and transactions.
        schema : DbSchema
            Table definitions of the storage backend.
        tunables, experiment_id, trial_id, root_env_config, description, opt_targets
            Forwarded unchanged to the :py:class:`.Storage.Experiment` base class.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Create the experiment record, or resume an existing one.

        For a new experiment, insert its ``experiment`` row and one ``objectives``
        row per optimization target. For an existing experiment, continue trial
        numbering after the largest stored trial ID (if any) and log a warning
        when the stored git commit differs from the current one.
        """
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # The outer join keeps the experiment row even when it has no trials,
            # in which case the aggregated trial_id is NULL.
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select()
                .with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                )
                .join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True,
                )
                .where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(
                    self._schema.experiment.insert().values(
                        exp_id=self._experiment_id,
                        description=self._description,
                        git_repo=self._git_repo,
                        git_commit=self._git_commit,
                        root_env_config=self._root_env_config,
                    )
                )
                objectives = [
                    {
                        "exp_id": self._experiment_id,
                        "optimization_target": opt_target,
                        "optimization_direction": opt_dir,
                    }
                    for (opt_target, opt_dir) in self.opt_targets.items()
                ]
                # Same guard as `_save_params`: never issue a multi-row INSERT
                # with an empty VALUES list.
                if objectives:
                    conn.execute(self._schema.objectives.insert().values(objectives))
            else:
                if exp_info.trial_id is not None:
                    # Resume numbering right after the last stored trial.
                    self._trial_id = exp_info.trial_id + 1
                    _LOG.info(
                        "Continue experiment: %s last trial: %s resume from: %d",
                        self._experiment_id,
                        exp_info.trial_id,
                        self._trial_id,
                    )
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning(
                        "Experiment %s git expected: %s %s",
                        self,
                        exp_info.git_repo,
                        exp_info.git_commit,
                    )

    def merge(self, experiment_ids: List[str]) -> None:
        """Merge other experiments into this one (not implemented yet; always raises)."""
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO")

    def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
        """Return the ``param_id -> param_value`` mapping stored for `config_id`."""
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
        """
        Load the telemetry records of one trial of this experiment.

        Returns
        -------
        List[Tuple[datetime, str, Any]]
            ``(timestamp, metric_id, metric_value)`` tuples ordered by timestamp
            and then metric ID; timestamps are returned as UTC-aware datetimes.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select()
                .where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id,
                )
                .order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in cur_telemetry.fetchall()
            ]

    def load(
        self,
        last_trial_id: int = -1,
    ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
        """
        Load the finished trials of this experiment newer than `last_trial_id`.

        Only trials with status SUCCEEDED, FAILED, or TIMED_OUT are returned,
        in ascending trial ID order.

        Returns
        -------
        Tuple of four parallel lists:
            trial IDs; tunable configs (``param_id -> param_value``); scores
            (``metric_id -> metric_value`` for succeeded trials, ``None``
            otherwise); and trial statuses.
        """
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
                )
                .order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: List[int] = []
            configs: List[Dict[str, Any]] = []
            scores: List[Optional[Dict[str, Any]]] = []
            status: List[Status] = []

            for trial in cur_trials.fetchall():
                stat = Status[trial.status]
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    )
                )
                if stat.is_succeeded():
                    scores.append(
                        self._get_key_val(
                            conn,
                            self._schema.trial_result,
                            "metric",
                            exp_id=self._experiment_id,
                            trial_id=trial.trial_id,
                        )
                    )
                else:
                    # Non-successful trials have no usable results.
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        (E.g., configurations, results, and telemetry).

        Selects the ``{field}_id`` and ``{field}_value`` columns of `table`,
        filtered by equality on every ``column=value`` pair in `kwargs`, and
        returns them as an ``id -> value`` dict. Values come back as stored
        (``_save_params`` writes them via ``nullable(str, ...)``).
        """
        cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            )
            .select_from(table)
            .where(*[column(key) == val for (key, val) in kwargs.items()])
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
        # avoid naming conflicts.
        return dict(
            row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
        )

    @staticmethod
    def _save_params(
        conn: Connection,
        table: Table,
        params: Dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """
        Insert one ``(param_id, param_value)`` row per entry of `params` into `table`.

        Every row also receives the extra column values given in `kwargs`
        (e.g. ``config_id`` or ``exp_id``/``trial_id``). Values are converted
        with ``nullable(str, ...)`` (presumably keeping ``None`` as NULL).
        Nothing is executed when `params` is empty.
        """
        if not params:
            return
        conn.execute(
            table.insert(),
            [
                {**kwargs, "param_id": key, "param_value": nullable(str, val)}
                for (key, val) in params.items()
            ],
        )

    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
        """
        Yield the trials of this experiment that are due to run.

        A trial qualifies when it has no end time, its start time is unset or
        not later than `timestamp`, and its status is PENDING (or PENDING,
        READY, or RUNNING when `running` is True). Trials are yielded in
        ascending trial ID order.

        All matching rows and their parameters are read before the first trial
        is yielded, so the database connection is released before the caller
        starts working on the trials.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            pending_status = ["PENDING", "READY", "RUNNING"]
        else:
            pending_status = ["PENDING"]
        # (trial_id, config_id, tunable values, framework config) per pending trial.
        pending: List[Tuple[Any, Any, Dict[str, Any], Dict[str, Any]]] = []
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    (
                        self._schema.trial.c.ts_start.is_(None)
                        | (self._schema.trial.c.ts_start <= timestamp)
                    ),
                    self._schema.trial.c.ts_end.is_(None),
                    self._schema.trial.c.status.in_(pending_status),
                )
                # Deterministic FIFO order instead of whatever the backend returns.
                .order_by(self._schema.trial.c.trial_id.asc())
            )
            for trial in cur_trials.fetchall():
                tunables = self._get_key_val(
                    conn,
                    self._schema.config_param,
                    "param",
                    config_id=trial.config_id,
                )
                config = self._get_key_val(
                    conn,
                    self._schema.trial_param,
                    "param",
                    exp_id=self._experiment_id,
                    trial_id=trial.trial_id,
                )
                pending.append((trial.trial_id, trial.config_id, tunables, config))
        # The connection is closed here; yielding no longer pins it.
        for trial_id, config_id, tunables, config in pending:
            yield Trial(
                engine=self._engine,
                schema=self._schema,
                # Reset .is_updated flag after the assignment:
                tunables=self._tunables.copy().assign(tunables).reset(),
                experiment_id=self._experiment_id,
                trial_id=trial_id,
                config_id=config_id,
                opt_targets=self._opt_targets,
                config=config,
            )

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables.

        If the config does not exist, create a new record for it.

        Configs are looked up by the SHA-256 hex digest of ``str(tunables)``
        only (not by experiment), so identical configs share one ID. A new
        config also gets its ``config_param`` rows (tunable name -> value).
        """
        config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
        cur_config = conn.execute(
            self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
        ).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        config_id: int = conn.execute(
            self._schema.config.insert().values(config_hash=config_hash)
        ).inserted_primary_key[0]
        self._save_params(
            conn,
            self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id,
        )
        return config_id

    def _new_trial(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Storage.Trial:
        """
        Persist a new PENDING trial with the next trial ID and return its handle.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable values of the trial; mapped to a (possibly shared) config ID.
        ts_start : Optional[datetime]
            Scheduled start time; defaults to now. Stored in UTC with the
            microseconds truncated.
        config : Optional[Dict[str, Any]]
            Framework (not target-environment) parameters stored as trial params.

        Notes
        -----
        The trial ID counter only advances after the transaction has committed,
        so a failed insert or commit does not skip an ID.
        """
        # MySQL can round microseconds into the future causing scheduler to skip trials.
        # Truncate microseconds to avoid this issue.
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
            microsecond=0
        )
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        # `Engine.begin()` commits when the block exits normally and rolls back
        # (re-raising the error) if anything inside it raises, so no explicit
        # rollback handler is needed.
        with self._engine.begin() as conn:
            config_id = self._get_config_id(conn, tunables)
            conn.execute(
                self._schema.trial.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    ts_start=ts_start,
                    status="PENDING",
                )
            )

            # Note: config here is the framework config, not the target
            # environment config (i.e., tunables).
            if config is not None:
                self._save_params(
                    conn,
                    self._schema.trial_param,
                    config,
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                )

        # The transaction has committed at this point.
        trial = Trial(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            config_id=config_id,
            opt_targets=self._opt_targets,
            config=config,
        )
        self._trial_id += 1
        return trial