Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 99%
73 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5""":py:class:`.Storage.Trial` interface implementation for saving and restoring
6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
7"""
10import logging
11from collections.abc import Mapping
12from datetime import datetime
13from typing import Any, Literal
15from sqlalchemy import or_
16from sqlalchemy.engine import Connection, Engine
17from sqlalchemy.exc import IntegrityError
19from mlos_bench.environments.status import Status
20from mlos_bench.storage.base_storage import Storage
21from mlos_bench.storage.sql.common import save_params
22from mlos_bench.storage.sql.schema import DbSchema
23from mlos_bench.tunables.tunable_groups import TunableGroups
24from mlos_bench.util import nullable, utcify_timestamp
# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)
class Trial(Storage.Trial):
    """
    SQL database-backed storage for the results of a single trial.

    Status changes, final metrics, and intermediate telemetry of one trial are
    written through a SQLAlchemy ``Engine`` into the tables of a ``DbSchema``.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        trial_runner_id: int | None,
        opt_targets: dict[str, Literal["min", "max"]],
        status: Status,
        restoring: bool,
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the trial handle.

        Everything except ``engine`` and ``schema`` is handed to the base class
        (``config_id`` is passed on as its ``tunable_config_id``). The engine and
        schema are kept for the database writes done by the other methods.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            trial_runner_id=trial_runner_id,
            opt_targets=opt_targets,
            status=status,
            restoring=restoring,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def set_trial_runner(self, trial_runner_id: int) -> int:
        """
        Assign this trial to a trial runner and persist the assignment.

        The ``trial`` row is only changed when it has no runner yet or is still
        PENDING. Afterwards the value actually stored in the database is read
        back in a fresh transaction, so ``self._trial_runner_id`` reflects
        whatever assignment won if several happened concurrently.

        Parameters
        ----------
        trial_runner_id : int
            The requested trial runner id (first passed through the base class).

        Returns
        -------
        int
            The trial runner id recorded in the database.
        """
        trial_runner_id = super().set_trial_runner(trial_runner_id)
        trial_tbl = self._schema.trial
        # Criteria selecting exactly this trial's row.
        this_trial = (
            trial_tbl.c.exp_id == self._experiment_id,
            trial_tbl.c.trial_id == self._trial_id,
        )
        # Only (re)assign trials that are unassigned or have not started yet.
        may_assign = or_(
            trial_tbl.c.trial_runner_id.is_(None),
            trial_tbl.c.status == Status.PENDING.name,
        )
        with self._engine.begin() as conn:
            conn.execute(
                trial_tbl.update()
                .where(*this_trial, may_assign)
                .values(trial_runner_id=trial_runner_id)
            )
        # Guard against concurrent updates: trust the database, not our request.
        with self._engine.begin() as conn:
            row = conn.execute(
                trial_tbl.select()
                .with_only_columns(trial_tbl.c.trial_runner_id)
                .where(*this_trial)
            ).fetchone()
        assert row
        self._trial_runner_id = row.trial_runner_id
        assert isinstance(self._trial_runner_id, int)
        return self._trial_runner_id

    def _save_new_config_data(self, new_config_data: Mapping[str, int | float | str]) -> None:
        """Store additional per-trial parameters in the ``trial_param`` table."""
        with self._engine.begin() as conn:
            save_params(
                conn,
                self._schema.trial_param,
                new_config_data,
                trial_id=self._trial_id,
                exp_id=self._experiment_id,
            )

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Persist a status transition of the trial, plus its metrics when it ends.

        The ``trial_status`` history row is written in its own transaction. The
        ``trial`` row (and any ``trial_result`` rows) are updated in a second
        transaction (see Issue #999 on PostgreSQL duplicate key handling).

        For a completed ``status``, the ``trial`` row gets ``status`` and
        ``ts_end``, but only while it has no ``ts_end`` and is not already in a
        final state. A non-empty ``metrics`` dict is then inserted into
        ``trial_result``. For any other status, ``metrics`` must be ``None``.
        The row then gets ``status`` and ``ts_start`` unless it already has
        ``ts_end`` or is RUNNING or in a final state; such a rejected transition
        is only logged.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change. It is normalized with
            ``utcify_timestamp(..., origin="local")`` before being stored.
        metrics : dict[str, Any] | None
            Final metrics; only allowed for completed statuses.

        Returns
        -------
        dict[str, Any] | None
            The metrics as returned by the base class ``update()``.

        Raises
        ------
        RuntimeError
            If, for a completed status, the driver reports that a number of rows
            other than 1 was updated.
        """
        # Normalize to UTC before anything reaches the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)

        trial_tbl = self._schema.trial
        final_states = [
            Status.SUCCEEDED.name,
            Status.CANCELED.name,
            Status.FAILED.name,
            Status.TIMED_OUT.name,
        ]
        # This trial's row, and only while it has not ended yet.
        open_trial = (
            trial_tbl.c.exp_id == self._experiment_id,
            trial_tbl.c.trial_id == self._trial_id,
            trial_tbl.c.ts_end.is_(None),
        )
        # A separate transaction avoids issues with PostgreSQL's duplicate key
        # constraint handling (see Issue #999).
        with self._engine.begin() as conn:
            try:
                if not status.is_completed():
                    # Starting the trial: record the status and ts_start.
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    result = conn.execute(
                        trial_tbl.update()
                        .where(
                            *open_trial,
                            trial_tbl.c.status.notin_([Status.RUNNING.name, *final_states]),
                        )
                        .values(status=status.name, ts_start=timestamp)
                    )
                    # A rowcount of -1 means the DBAPI driver could not tell.
                    if result.rowcount not in {1, -1}:
                        # Keep the stored status and timestamp if already running;
                        # just report it.
                        _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
                else:
                    # Final transition: record the status and ts_end exactly once.
                    result = conn.execute(
                        trial_tbl.update()
                        .where(*open_trial, trial_tbl.c.status.notin_(final_states))
                        .values(status=status.name, ts_end=timestamp)
                    )
                    if result.rowcount not in {1, -1}:
                        _LOG.warning("Trial %s :: update failed: %s", self, status)
                        raise RuntimeError(
                            f"Failed to update the status of the trial {self} to {status}. "
                            f"({result.rowcount} rows)"
                        )
                    if metrics:
                        result_rows = [
                            {
                                "exp_id": self._experiment_id,
                                "trial_id": self._trial_id,
                                "metric_id": key,
                                "metric_value": nullable(str, val),
                            }
                            for (key, val) in metrics.items()
                        ]
                        conn.execute(self._schema.trial_result.insert().values(result_rows))
            except Exception:
                conn.rollback()
                raise
        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: list[tuple[datetime, str, Any]],
    ) -> None:
        """
        Persist an intermediate status and a batch of telemetry samples.

        All timestamps are normalized with ``utcify_timestamp(..., origin="local")``.
        The status row is written first. Then each ``(timestamp, metric, value)``
        sample is inserted in its own transaction. A sample whose insert raises
        ``IntegrityError`` is logged as already existing and skipped, which keeps
        repeated calls harmless.
        """
        super().update_telemetry(status, timestamp, metrics)
        # Make sure every timestamp is stored in UTC.
        timestamp = utcify_timestamp(timestamp, origin="local")
        utc_metrics = [
            (utcify_timestamp(metric_ts, origin="local"), key, val)
            for (metric_ts, key, val) in metrics
        ]
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`,
        # so a per-record loop (rather than a bulk upsert) keeps this idempotent.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        telemetry_tbl = self._schema.trial_telemetry
        for metric_ts, key, val in utc_metrics:
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        telemetry_tbl.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=key,
                            metric_value=nullable(str, val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning(
                        "Record already exists: %s :: %s", (metric_ts, key, val), ex
                    )

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Append a ``(timestamp, status)`` row for this trial to ``trial_status``.

        An ``IntegrityError`` on insert is reported as a warning instead of
        being raised. A repeated call with the same timestamp is therefore
        logged rather than failing.
        """
        ts_utc = utcify_timestamp(timestamp, origin="local")
        insert_stmt = self._schema.trial_status.insert().values(
            exp_id=self._experiment_id,
            trial_id=self._trial_id,
            ts=ts_utc,
            status=status.name,
        )
        try:
            conn.execute(insert_stmt)
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                ts_utc,
                ex,
            )