Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 98%
56 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Saving and updating benchmark data using SQLAlchemy backend.
7"""
9import logging
10from datetime import datetime
11from typing import List, Optional, Tuple, Union, Dict, Any
13from sqlalchemy import Engine, Connection
14from sqlalchemy.exc import IntegrityError
16from mlos_bench.environments.status import Status
17from mlos_bench.tunables.tunable_groups import TunableGroups
18from mlos_bench.storage.base_storage import Storage
19from mlos_bench.storage.sql.schema import DbSchema
20from mlos_bench.util import nullable, utcify_timestamp
22_LOG = logging.getLogger(__name__)
class Trial(Storage.Trial):
    """
    Store the results of a single run of the experiment in SQL database.

    Status changes, final results, and telemetry of the trial are written
    through the SQLAlchemy engine into the tables of the given schema.
    """

    def __init__(
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        opt_target: str,
        opt_direction: Optional[str],
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a trial object backed by the SQL storage.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open the database transactions.
        schema : DbSchema
            Schema object that provides the tables this class writes to.
        tunables : TunableGroups
            Forwarded to the base class.
        experiment_id : str
            Forwarded to the base class.
        trial_id : int
            Forwarded to the base class.
        config_id : int
            Forwarded to the base class as `tunable_config_id`.
        opt_target : str
            Forwarded to the base class.
        opt_direction : Optional[str]
            Forwarded to the base class.
        config : Optional[Dict[str, Any]]
            Forwarded to the base class.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            opt_target=opt_target,
            opt_direction=opt_direction,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: Optional[Union[Dict[str, Any], float]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Record a new status of the trial and, for a completed trial, its results.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change; converted to UTC before it is stored.
        metrics : Optional[Union[Dict[str, Any], float]]
            Results of the trial. Passed to the base class `update()`; whatever
            the base class returns is stored only when the status is completed.

        Returns
        -------
        metrics : Optional[Dict[str, Any]]
            The value returned by the base class `update()`.

        Raises
        ------
        RuntimeError
            If a completed status cannot be written to the trial record.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)

        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
            try:
                is_final = status.is_completed()
                if is_final:
                    # Final update of the status and ts_end:
                    # a trial that has already finished must not be overwritten.
                    ts_column = "ts_end"
                    locked_statuses = ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']
                else:
                    # Update of the status and ts_start when starting the trial:
                    # a trial that is already running must not be restarted either.
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    ts_column = "ts_start"
                    locked_statuses = ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']

                trial_tbl = self._schema.trial
                update_stmt = trial_tbl.update().where(
                    trial_tbl.c.exp_id == self._experiment_id,
                    trial_tbl.c.trial_id == self._trial_id,
                    trial_tbl.c.ts_end.is_(None),
                    trial_tbl.c.status.notin_(locked_statuses),
                ).values(**{"status": status.name, ts_column: timestamp})
                n_rows = conn.execute(update_stmt).rowcount

                # Exactly one row is expected; -1 is accepted as well
                # (presumably for drivers that do not report a row count).
                if n_rows not in (1, -1):
                    if is_final:
                        _LOG.warning("Trial %s :: update failed: %s", self, status)
                        raise RuntimeError(
                            f"Failed to update the status of the trial {self} to {status}." +
                            f" ({n_rows} rows)")
                    # Keep the old status and timestamp if already running, but log it.
                    _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)

                if is_final and metrics:
                    result_rows = [
                        {
                            "exp_id": self._experiment_id,
                            "trial_id": self._trial_id,
                            "metric_id": metric_name,
                            "metric_value": nullable(str, metric_val),
                        }
                        for (metric_name, metric_val) in metrics.items()
                    ]
                    conn.execute(self._schema.trial_result.insert().values(result_rows))
            except Exception:
                conn.rollback()
                raise

        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: List[Tuple[datetime, str, Any]],
    ) -> None:
        """
        Record the current status of the trial and save its telemetry.

        Parameters
        ----------
        status : Status
            Current status of the trial.
        timestamp : datetime
            Time of the status record; converted to UTC before it is stored.
        metrics : List[Tuple[datetime, str, Any]]
            Telemetry records as (timestamp, metric id, value) tuples.
            The timestamps are converted to UTC before they are stored.
        """
        super().update_telemetry(status, timestamp, metrics)

        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        utc_records = [
            (utcify_timestamp(record_ts, origin="local"), metric_id, metric_val)
            for (record_ts, metric_id, metric_val) in metrics
        ]

        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)

        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert. Each record is inserted in a transaction of its own.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        for record in utc_records:
            (metric_ts, key, val) = record
            with self._engine.begin() as conn:
                try:
                    conn.execute(self._schema.trial_telemetry.insert().values(
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                        ts=metric_ts,
                        metric_id=key,
                        metric_value=nullable(str, val),
                    ))
                except IntegrityError as ex:
                    # A duplicate record is not an error: log it and carry on.
                    _LOG.warning("Record already exists: %s :: %s", record, ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the database.
        This call is idempotent.

        Parameters
        ----------
        conn : Connection
            Open database connection used to execute the insert.
        status : Status
            Status of the trial to record.
        timestamp : datetime
            Time of the status record; converted to UTC before it is stored.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        utc_timestamp = utcify_timestamp(timestamp, origin="local")
        insert_stmt = self._schema.trial_status.insert().values(
            exp_id=self._experiment_id,
            trial_id=self._trial_id,
            ts=utc_timestamp,
            status=status.name,
        )
        try:
            conn.execute(insert_stmt)
        except IntegrityError as ex:
            # The same status record is already stored: log it and keep going.
            _LOG.warning("Status with that timestamp already exists: %s %s :: %s",
                         self, utc_timestamp, ex)