Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 99%
72 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5""":py:class:`.Storage.Trial` interface implementation for saving and restoring
6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
7"""
10import logging
11from collections.abc import Mapping
12from datetime import datetime
13from typing import Any, Literal
15from sqlalchemy import or_
16from sqlalchemy.engine import Connection, Engine
17from sqlalchemy.exc import IntegrityError
19from mlos_bench.environments.status import Status
20from mlos_bench.storage.base_storage import Storage
21from mlos_bench.storage.sql.common import save_params
22from mlos_bench.storage.sql.schema import DbSchema
23from mlos_bench.tunables.tunable_groups import TunableGroups
24from mlos_bench.util import nullable, utcify_timestamp
# Module-level logger, named after this module (``__name__``).
_LOG = logging.getLogger(name=__name__)
class Trial(Storage.Trial):
    """
    Store the results of a single run of the experiment in SQL database.

    Every public method below opens its own transaction(s) through
    ``self._engine.begin()``.
    """

    # Status names that `update()` treats as final: a trial row that is already in
    # one of these states is never overwritten by a later `update()` call.
    _FINAL_STATUS_NAMES: tuple[str, ...] = (
        Status.SUCCEEDED.name,
        Status.CANCELED.name,
        Status.FAILED.name,
        Status.TIMED_OUT.name,
    )
    # Status names that block the "start" transition in `update()`: a trial that is
    # already running (or already finished) keeps its original status and `ts_start`.
    _RUNNING_OR_FINAL_STATUS_NAMES: tuple[str, ...] = (
        Status.RUNNING.name,
        *_FINAL_STATUS_NAMES,
    )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        trial_runner_id: int | None,
        opt_targets: dict[str, Literal["min", "max"]],
        status: Status,
        restoring: bool,
        config: dict[str, Any] | None = None,
    ):
        """
        Create a SQL-backed trial handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a connection/transaction per operation.
        schema : DbSchema
            Table definitions (``trial``, ``trial_param``, ``trial_result``,
            ``trial_telemetry``, ``trial_status``) used by this class.
        tunables, experiment_id, trial_id, trial_runner_id, opt_targets,
        status, restoring, config
            Forwarded unchanged to the base ``Storage.Trial`` constructor.
        config_id : int
            Forwarded to the base class as ``tunable_config_id``.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            trial_runner_id=trial_runner_id,
            opt_targets=opt_targets,
            status=status,
            restoring=restoring,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def set_trial_runner(self, trial_runner_id: int) -> int:
        """
        Assign this trial to a TrialRunner and persist the assignment.

        The ``trial`` row is only updated if it has no runner yet or is still
        PENDING. The stored value is then re-read, so the returned id is whatever
        the database holds, which may differ from the requested one (e.g., if the
        trial was already assigned elsewhere).

        Parameters
        ----------
        trial_runner_id : int
            The requested runner id (first passed through the base class).

        Returns
        -------
        int
            The trial runner id actually stored in the database for this trial.
        """
        trial_runner_id = super().set_trial_runner(trial_runner_id)
        with self._engine.begin() as conn:
            conn.execute(
                self._schema.trial.update()
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id == self._trial_id,
                    (
                        or_(
                            self._schema.trial.c.trial_runner_id.is_(None),
                            self._schema.trial.c.status == Status.PENDING.name,
                        )
                    ),
                )
                .values(
                    trial_runner_id=trial_runner_id,
                )
            )
        # Guard against concurrent updates: the UPDATE above may have matched no
        # rows, so read back the value that is actually stored.
        with self._engine.begin() as conn:
            trial_runner_rs = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_runner_id,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id == self._trial_id,
                )
            )
            trial_runner_row = trial_runner_rs.fetchone()
            assert trial_runner_row
            self._trial_runner_id = trial_runner_row.trial_runner_id
            assert isinstance(self._trial_runner_id, int)
            return self._trial_runner_id

    def _save_new_config_data(self, new_config_data: Mapping[str, int | float | str]) -> None:
        """
        Persist additional config parameter values for this trial.

        The values are written to the ``trial_param`` table via ``save_params``,
        keyed by this trial's experiment id and trial id.
        """
        with self._engine.begin() as conn:
            save_params(
                conn,
                self._schema.trial_param,
                new_config_data,
                exp_id=self._experiment_id,
                trial_id=self._trial_id,
            )

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Record a status change of the trial and, if completed, its final metrics.

        A status-history record is always appended. For a completed ``status``,
        this method sets ``status`` and ``ts_end`` on the ``trial`` row and inserts
        one ``trial_result`` row per metric. It raises if the row could not be
        updated, e.g. because it was already finalized. For any other
        ``status``, it sets ``status`` and ``ts_start``. If the trial is already
        running or finished, the row is left unchanged and a warning is logged.

        All writes happen in a single transaction, which is rolled back on error.

        Parameters
        ----------
        status : Status
            The new status of the trial.
        timestamp : datetime
            Time of the status change; converted to UTC before storing.
        metrics : dict[str, Any] | None
            Final metrics. Must be None unless ``status.is_completed()``.

        Returns
        -------
        dict[str, Any] | None
            The metrics as returned by the base class ``update()``.

        Raises
        ------
        RuntimeError
            If a completed status could not be written to the ``trial`` row.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        with self._engine.begin() as conn:
            # NOTE(review): `_update_status` swallows IntegrityError inside this
            # transaction; on backends where a failed statement aborts the whole
            # transaction (e.g., PostgreSQL) the statements below may then fail.
            # Confirm the supported backends.
            self._update_status(conn, status, timestamp)
            try:
                if status.is_completed():
                    # Final update of the status and ts_end:
                    cur_status = conn.execute(
                        self._schema.trial.update()
                        .where(
                            self._schema.trial.c.exp_id == self._experiment_id,
                            self._schema.trial.c.trial_id == self._trial_id,
                            self._schema.trial.c.ts_end.is_(None),
                            self._schema.trial.c.status.notin_(
                                list(self._FINAL_STATUS_NAMES)
                            ),
                        )
                        .values(
                            status=status.name,
                            ts_end=timestamp,
                        )
                    )
                    # Per PEP 249, rowcount is -1 when the driver cannot determine it.
                    if cur_status.rowcount not in {1, -1}:
                        _LOG.warning("Trial %s :: update failed: %s", self, status)
                        raise RuntimeError(
                            f"Failed to update the status of the trial {self} to {status}. "
                            f"({cur_status.rowcount} rows)"
                        )
                    if metrics:
                        conn.execute(
                            self._schema.trial_result.insert().values(
                                [
                                    {
                                        "exp_id": self._experiment_id,
                                        "trial_id": self._trial_id,
                                        "metric_id": key,
                                        "metric_value": nullable(str, val),
                                    }
                                    for (key, val) in metrics.items()
                                ]
                            )
                        )
                else:
                    # Update of the status and ts_start when starting the trial:
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    cur_status = conn.execute(
                        self._schema.trial.update()
                        .where(
                            self._schema.trial.c.exp_id == self._experiment_id,
                            self._schema.trial.c.trial_id == self._trial_id,
                            self._schema.trial.c.ts_end.is_(None),
                            self._schema.trial.c.status.notin_(
                                list(self._RUNNING_OR_FINAL_STATUS_NAMES)
                            ),
                        )
                        .values(
                            status=status.name,
                            ts_start=timestamp,
                        )
                    )
                    if cur_status.rowcount not in {1, -1}:
                        # Keep the old status and timestamp if already running, but log it.
                        _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
            except Exception:
                conn.rollback()
                raise
        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: list[tuple[datetime, str, Any]],
    ) -> None:
        """
        Append a status record and a batch of telemetry samples for this trial.

        Parameters
        ----------
        status : Status
            Current status of the trial; appended to the status history.
        timestamp : datetime
            Time of the status record; converted to UTC before storing.
        metrics : list[tuple[datetime, str, Any]]
            ``(timestamp, metric_id, value)`` samples. Each timestamp is converted
            to UTC. Each value is stored via ``nullable(str, value)``, presumably
            as its string form or NULL.

        Notes
        -----
        Each sample is inserted in its own transaction. An IntegrityError
        (presumably a duplicate sample) only logs a warning, so the call can be
        repeated with overlapping data.
        """
        # Convert all timestamps to UTC up front (as `update()` does), so that the
        # base class and the database both see the same normalized values.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
        super().update_telemetry(status, timestamp, metrics)
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        for metric_ts, key, val in metrics:
            # One transaction per sample, so that a rejected duplicate does not
            # discard the other samples.
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        self._schema.trial_telemetry.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=key,
                            metric_value=nullable(str, val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the database.

        This call is idempotent: if the insert raises IntegrityError (a record
        with that timestamp already exists), a warning is logged instead.

        Parameters
        ----------
        conn : Connection
            An open connection; the insert runs in the caller's transaction.
        status : Status
            The status to record.
        timestamp : datetime
            Time of the status record; converted to UTC before storing.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        try:
            conn.execute(
                self._schema.trial_status.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    ts=timestamp,
                    status=status.name,
                )
            )
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                timestamp,
                ex,
            )