Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 93%
115 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5""":py:class:`.Storage.Experiment` interface implementation for saving and restoring
6the benchmark experiment data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
7"""
9import hashlib
10import logging
11from collections.abc import Iterator
12from datetime import datetime
13from typing import Any, Literal
15from pytz import UTC
16from sqlalchemy import Connection, CursorResult, Table, column, func, select
17from sqlalchemy.engine import Engine
19from mlos_bench.environments.status import Status
20from mlos_bench.storage.base_storage import Storage
21from mlos_bench.storage.sql.common import save_params
22from mlos_bench.storage.sql.schema import DbSchema
23from mlos_bench.storage.sql.trial import Trial
24from mlos_bench.tunables.tunable_groups import TunableGroups
25from mlos_bench.util import utcify_timestamp
27_LOG = logging.getLogger(__name__)
class Experiment(Storage.Experiment):
    """
    Logic for retrieving and storing the results of a single experiment.

    All data is read from and written to the tables described by the
    :py:class:`.DbSchema` object using the given SQLAlchemy engine.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        opt_targets: dict[str, Literal["min", "max"]],
    ):
        """
        Create a new SQL-backed experiment object.

        Parameters
        ----------
        engine : sqlalchemy.engine.Engine
            SQLAlchemy engine used to open connections to the database.
        schema : DbSchema
            Table definitions of the storage database.
        tunables : TunableGroups
            Passed through to the base class.
        experiment_id : str
            Passed through to the base class.
        trial_id : int
            Passed through to the base class.
        root_env_config : str
            Passed through to the base class.
        description : str
            Passed through to the base class.
        opt_targets : dict[str, Literal["min", "max"]]
            Passed through to the base class.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Create the experiment record in the database or resume an existing one.

        If no record exists for the experiment ID, insert one row in the
        experiment table and one row per optimization target in the objectives
        table. Otherwise, continue numbering the trials after the largest
        trial ID stored so far, and log a warning if the stored git commit
        differs from the current one.
        """
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # The OUTER join keeps the experiment row even when it has no trials yet;
            # in that case `trial_id` (the max) comes back as NULL/None.
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select()
                .with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                )
                .join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True,
                )
                .where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(
                    self._schema.experiment.insert().values(
                        exp_id=self._experiment_id,
                        description=self._description,
                        git_repo=self._git_repo,
                        git_commit=self._git_commit,
                        root_env_config=self._root_env_config,
                    )
                )
                conn.execute(
                    self._schema.objectives.insert().values(
                        [
                            {
                                "exp_id": self._experiment_id,
                                "optimization_target": opt_target,
                                "optimization_direction": opt_dir,
                            }
                            for (opt_target, opt_dir) in self.opt_targets.items()
                        ]
                    )
                )
            else:
                # Existing experiment: only bump the counter if it already has trials.
                if exp_info.trial_id is not None:
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info(
                    "Continue experiment: %s last trial: %s resume from: %d",
                    self._experiment_id,
                    exp_info.trial_id,
                    self._trial_id,
                )
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning(
                        "Experiment %s git expected: %s %s",
                        self,
                        exp_info.git_repo,
                        exp_info.git_commit,
                    )

    def merge(self, experiment_ids: list[str]) -> None:
        """
        Merge other experiments into this one (not implemented yet).

        Parameters
        ----------
        experiment_ids : list[str]
            IDs of the experiments to merge into the current one.

        Raises
        ------
        NotImplementedError
            Always.
        """
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO: Merging experiments not implemented yet.")

    def load_tunable_config(self, config_id: int) -> dict[str, Any]:
        """
        Load the parameter values stored for the given config ID.

        Parameters
        ----------
        config_id : int
            ID of the configuration in the `config_param` table.

        Returns
        -------
        dict[str, Any]
            Mapping of `param_id` to `param_value` for that configuration.
        """
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]:
        """
        Load the telemetry records of one trial of this experiment.

        Parameters
        ----------
        trial_id : int
            ID of the trial to fetch the telemetry for.

        Returns
        -------
        list[tuple[datetime, str, Any]]
            (timestamp, metric ID, metric value) triplets ordered by the
            timestamp and then by the metric ID.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select()
                .where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id,
                )
                .order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in cur_telemetry.fetchall()
            ]

    def load(
        self,
        last_trial_id: int = -1,
    ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
        """
        Load the finished trials of this experiment.

        Parameters
        ----------
        last_trial_id : int
            Only trials with the ID strictly greater than this value are returned.

        Returns
        -------
        tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]
            Four parallel lists, ordered by ascending trial ID: trial IDs,
            parameter values of each trial's config, metric results (`None`
            for trials that did not succeed), and trial statuses.
        """
        trial_ids: list[int] = []
        configs: list[dict[str, Any]] = []
        scores: list[dict[str, Any] | None] = []
        status: list[Status] = []

        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    # Only the trials that have reached one of these final states:
                    self._schema.trial.c.status.in_(
                        [
                            Status.SUCCEEDED.name,
                            Status.FAILED.name,
                            Status.TIMED_OUT.name,
                        ]
                    ),
                )
                .order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            for trial in cur_trials.fetchall():
                stat = Status.parse(trial.status)
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    )
                )
                if stat.is_succeeded():
                    scores.append(
                        self._get_key_val(
                            conn,
                            self._schema.trial_result,
                            "metric",
                            exp_id=self._experiment_id,
                            trial_id=trial.trial_id,
                        )
                    )
                else:
                    # Results are looked up for the succeeded trials only.
                    scores.append(None)

        return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        (E.g., configurations, results, and telemetry).

        Parameters
        ----------
        conn : sqlalchemy.Connection
            Open connection to run the query on.
        table : sqlalchemy.Table
            Table to select the data from.
        field : str
            Common prefix of the key and value columns: the query selects
            `{field}_id` as keys and `{field}_value` as values.
        kwargs : Any
            Column name / value pairs; each one becomes an equality filter.

        Returns
        -------
        dict[str, Any]
            Mapping of the `{field}_id` values to the `{field}_value` values.
        """
        cur_result: CursorResult[tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            )
            .select_from(table)
            .where(*[column(key) == val for (key, val) in kwargs.items()])
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
        # avoid naming conflicts.
        return dict(
            row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
        )

    def _trial_from_row(self, conn: Connection, trial: Any) -> Storage.Trial:
        """
        Restore a Trial object from a row of the trial table.

        Parameters
        ----------
        conn : sqlalchemy.Connection
            Open connection used to load the parameters of the trial.
        trial : Any
            A row of the trial table; its `trial_id`, `config_id`,
            `trial_runner_id`, and `status` fields are used.

        Returns
        -------
        Storage.Trial
            Trial object created with `restoring=True`.
        """
        tunables = self._get_key_val(
            conn,
            self._schema.config_param,
            "param",
            config_id=trial.config_id,
        )
        config = self._get_key_val(
            conn,
            self._schema.trial_param,
            "param",
            exp_id=self._experiment_id,
            trial_id=trial.trial_id,
        )
        return Trial(
            engine=self._engine,
            schema=self._schema,
            # Reset .is_updated flag after the assignment:
            tunables=self._tunables.copy().assign(tunables).reset(),
            experiment_id=self._experiment_id,
            trial_id=trial.trial_id,
            config_id=trial.config_id,
            trial_runner_id=trial.trial_runner_id,
            opt_targets=self._opt_targets,
            status=Status.parse(trial.status),
            restoring=True,
            config=config,
        )

    def get_trial_by_id(
        self,
        trial_id: int,
    ) -> Storage.Trial | None:
        """
        Restore a single trial of this experiment by its ID.

        Parameters
        ----------
        trial_id : int
            ID of the trial to look up.

        Returns
        -------
        Storage.Trial | None
            The restored trial, or `None` if there is no such trial.
        """
        with self._engine.connect() as conn:
            trial = conn.execute(
                self._schema.trial.select().where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id == trial_id,
                )
            ).fetchone()
            if trial is None:
                return None
            return self._trial_from_row(conn, trial)

    def pending_trials(
        self,
        timestamp: datetime,
        *,
        running: bool = False,
        trial_runner_assigned: bool | None = None,
    ) -> Iterator[Storage.Trial]:
        """
        Iterate over the trials of this experiment that have not finished yet.

        A trial is selected if it has no end timestamp, its start timestamp is
        either not set or not later than `timestamp`, and its status is one of
        the requested ones.

        Parameters
        ----------
        timestamp : datetime.datetime
            Upper bound (inclusive) for the start timestamp of the trials.
        running : bool
            If True, select PENDING, READY, and RUNNING trials;
            otherwise, PENDING trials only.
        trial_runner_assigned : bool | None
            If True, select only the trials that have a trial runner ID;
            if False, only the trials that do not have one;
            if None, do not filter by the trial runner ID.

        Yields
        ------
        Storage.Trial
            Trial objects restored from the database.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            statuses = [Status.PENDING, Status.READY, Status.RUNNING]
        else:
            statuses = [Status.PENDING]
        with self._engine.connect() as conn:
            stmt = self._schema.trial.select().where(
                self._schema.trial.c.exp_id == self._experiment_id,
                (
                    self._schema.trial.c.ts_start.is_(None)
                    | (self._schema.trial.c.ts_start <= timestamp)
                ),
                self._schema.trial.c.ts_end.is_(None),
                self._schema.trial.c.status.in_([s.name for s in statuses]),
            )
            if trial_runner_assigned:
                # `is_not()` is the current name of the legacy `isnot()` operator.
                stmt = stmt.where(self._schema.trial.c.trial_runner_id.is_not(None))
            elif trial_runner_assigned is False:
                stmt = stmt.where(self._schema.trial.c.trial_runner_id.is_(None))
            # else:  # No filtering by trial_runner_id
            cur_trials = conn.execute(stmt)
            # Fetch all rows first, as the same connection is reused for the lookups below.
            for trial in cur_trials.fetchall():
                yield self._trial_from_row(conn, trial)

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables.

        If the config does not exist, create a new record for it.

        Parameters
        ----------
        conn : sqlalchemy.Connection
            Open connection to run the queries on.
        tunables : TunableGroups
            Tunables to look up; the SHA-256 hex digest of their string
            representation is the lookup key.

        Returns
        -------
        int
            ID of the existing or the newly created config record.
        """
        config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
        cur_config = conn.execute(
            self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
        ).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        new_config_result = conn.execute(
            self._schema.config.insert().values(config_hash=config_hash)
        ).inserted_primary_key
        assert new_config_result
        config_id: int = new_config_result[0]
        save_params(
            conn,
            self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id,
        )
        return config_id

    def _new_trial(
        self,
        tunables: TunableGroups,
        ts_start: datetime | None = None,
        config: dict[str, Any] | None = None,
    ) -> Storage.Trial:
        """
        Create a new PENDING trial in the database and return its Trial object.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable parameters of the trial; stored (or looked up) as a config record.
        ts_start : datetime.datetime | None
            Start timestamp of the trial; the current UTC time if not specified.
        config : dict[str, Any] | None
            Optional key/value parameters to save in the `trial_param` table
            and pass on to the Trial object.

        Returns
        -------
        Storage.Trial
            The new trial, not assigned to any trial runner.
        """
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local")
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        with self._engine.begin() as conn:
            try:
                new_trial_status = Status.PENDING
                config_id = self._get_config_id(conn, tunables)
                conn.execute(
                    self._schema.trial.insert().values(
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                        config_id=config_id,
                        ts_start=ts_start,
                        status=new_trial_status.name,
                    )
                )

                # Note: config here is the framework config, not the target
                # environment config (i.e., tunables).
                if config is not None:
                    save_params(
                        conn,
                        self._schema.trial_param,
                        config,
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                    )

                trial = Trial(
                    engine=self._engine,
                    schema=self._schema,
                    tunables=tunables,
                    experiment_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    trial_runner_id=None,  # initially, Trials are not assigned to a TrialRunner
                    opt_targets=self._opt_targets,
                    status=new_trial_status,
                    restoring=False,
                    config=config,
                )
                # Advance the counter so that the next trial gets a fresh ID.
                self._trial_id += 1
                return trial
            except Exception:
                # NOTE(review): `Engine.begin()` is documented to roll back on error
                # as well; the explicit rollback is kept as in the original code.
                conn.rollback()
                raise