Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 99%

73 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-30 00:51 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""":py:class:`.Storage.Trial` interface implementation for saving and restoring 

6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend. 

7""" 

8 

9 

10import logging 

11from collections.abc import Mapping 

12from datetime import datetime 

13from typing import Any, Literal 

14 

15from sqlalchemy import or_ 

16from sqlalchemy.engine import Connection, Engine 

17from sqlalchemy.exc import IntegrityError 

18 

19from mlos_bench.environments.status import Status 

20from mlos_bench.storage.base_storage import Storage 

21from mlos_bench.storage.sql.common import save_params 

22from mlos_bench.storage.sql.schema import DbSchema 

23from mlos_bench.tunables.tunable_groups import TunableGroups 

24from mlos_bench.util import nullable, utcify_timestamp 

25 

# Module-level logger, named after this module; the Trial class below uses it
# to emit warnings about rejected or duplicate database updates.
_LOG = logging.getLogger(__name__)

27 

28 

class Trial(Storage.Trial):
    """Persist the state and the results of one experiment trial in a SQL database."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        trial_runner_id: int | None,
        opt_targets: dict[str, Literal["min", "max"]],
        status: Status,
        restoring: bool,
        config: dict[str, Any] | None = None,
    ):
        # Everything except the DB handles is forwarded to the base class;
        # note that `config_id` is passed on under the name `tunable_config_id`.
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            trial_runner_id=trial_runner_id,
            opt_targets=opt_targets,
            status=status,
            restoring=restoring,
            config=config,
        )
        # SQL-specific state: the engine to open transactions on and the table definitions.
        self._schema = schema
        self._engine = engine

    def set_trial_runner(self, trial_runner_id: int) -> int:
        """
        Assign a trial runner to this trial and store the assignment in the database.

        The stored value is overwritten only if the trial has no runner yet or is
        still in the PENDING status. The value found in the database afterwards is
        cached on this object and returned, so it may differ from the requested one.
        """
        requested_id = super().set_trial_runner(trial_runner_id)
        trial_tbl = self._schema.trial
        # Predicates selecting the row of this particular trial.
        this_trial = (
            trial_tbl.c.exp_id == self._experiment_id,
            trial_tbl.c.trial_id == self._trial_id,
        )
        may_reassign = or_(
            trial_tbl.c.trial_runner_id.is_(None),
            trial_tbl.c.status == Status.PENDING.name,
        )
        assign_stmt = (
            trial_tbl.update()
            .where(*this_trial, may_reassign)
            .values(trial_runner_id=requested_id)
        )
        with self._engine.begin() as conn:
            conn.execute(assign_stmt)
        # Guard against concurrent updates: read the stored value back in a
        # separate transaction instead of trusting the value we tried to write.
        lookup_stmt = (
            trial_tbl.select()
            .with_only_columns(trial_tbl.c.trial_runner_id)
            .where(*this_trial)
        )
        with self._engine.begin() as conn:
            stored_row = conn.execute(lookup_stmt).fetchone()
            assert stored_row
            self._trial_runner_id = stored_row.trial_runner_id
        assert isinstance(self._trial_runner_id, int)
        return self._trial_runner_id

    def _save_new_config_data(self, new_config_data: Mapping[str, int | float | str]) -> None:
        """Store additional config parameters of this trial in the `trial_param` table."""
        params_tbl = self._schema.trial_param
        with self._engine.begin() as conn:
            save_params(
                conn,
                params_tbl,
                new_config_data,
                exp_id=self._experiment_id,
                trial_id=self._trial_id,
            )

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Record a new status of the trial and, for a completed trial, its results.

        Parameters
        ----------
        status : Status
            The new status of the trial.
        timestamp : datetime
            Time of the status change; converted to UTC before it is stored.
        metrics : dict[str, Any] | None
            Result metrics. Must be None unless `status.is_completed()` is true.

        Returns
        -------
        metrics : dict[str, Any] | None
            The metrics as returned by the base class `update()`.

        Raises
        ------
        RuntimeError
            If the final status update of a completed trial does not affect
            exactly one row.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        # Use a separate transaction to avoid issues with PostgreSQL's duplicate key
        # constraint handling. (See Issue #999).
        with self._engine.begin() as conn:
            try:
                completed = status.is_completed()
                trial_tbl = self._schema.trial
                finished_states = [
                    Status.SUCCEEDED.name,
                    Status.CANCELED.name,
                    Status.FAILED.name,
                    Status.TIMED_OUT.name,
                ]
                if completed:
                    # Final update of the status and ts_end:
                    locked_states = finished_states
                    new_values = {"status": status.name, "ts_end": timestamp}
                else:
                    # Update of the status and ts_start when starting the trial:
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    locked_states = [Status.RUNNING.name, *finished_states]
                    new_values = {"status": status.name, "ts_start": timestamp}
                # Only touch a trial that has not ended and is not in a locked state.
                update_stmt = (
                    trial_tbl.update()
                    .where(
                        trial_tbl.c.exp_id == self._experiment_id,
                        trial_tbl.c.trial_id == self._trial_id,
                        trial_tbl.c.ts_end.is_(None),
                        trial_tbl.c.status.notin_(locked_states),
                    )
                    .values(**new_values)
                )
                cur_status = conn.execute(update_stmt)
                # A count of -1 is accepted as well (presumably for drivers that
                # cannot report the number of affected rows).
                row_updated = cur_status.rowcount in {1, -1}
                if not completed:
                    if not row_updated:
                        # Keep the old status and timestamp if already running, but log it.
                        _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
                elif not row_updated:
                    _LOG.warning("Trial %s :: update failed: %s", self, status)
                    raise RuntimeError(
                        f"Failed to update the status of the trial {self} to {status}. "
                        f"({cur_status.rowcount} rows)"
                    )
                elif metrics:
                    # One row per metric; values are stored as (nullable) strings.
                    result_rows = [
                        {
                            "exp_id": self._experiment_id,
                            "trial_id": self._trial_id,
                            "metric_id": metric_id,
                            "metric_value": nullable(str, metric_value),
                        }
                        for (metric_id, metric_value) in metrics.items()
                    ]
                    conn.execute(self._schema.trial_result.insert().values(result_rows))
            except Exception:
                conn.rollback()
                raise
        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: list[tuple[datetime, str, Any]],
    ) -> None:
        """
        Record the current status of the trial and save its telemetry records.

        Each telemetry record is a `(timestamp, metric_id, value)` tuple. All
        timestamps are converted to UTC before being stored. A record that
        violates an integrity constraint (e.g., a duplicate) is logged and skipped.
        """
        super().update_telemetry(status, timestamp, metrics)
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        # Convert all telemetry timestamps up front, before any database write.
        utc_records = [
            (utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics
        ]
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        telemetry_tbl = self._schema.trial_telemetry
        for record in utc_records:
            (metric_ts, key, val) = record
            # One transaction per record, so a rejected insert does not affect the others.
            with self._engine.begin() as conn:
                insert_stmt = telemetry_tbl.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    ts=metric_ts,
                    metric_id=key,
                    metric_value=nullable(str, val),
                )
                try:
                    conn.execute(insert_stmt)
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", record, ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Add a record for the given status and timestamp to the `trial_status` table.

        The call is idempotent: an `IntegrityError` raised by the insert (e.g., the
        same record being stored twice) is logged as a warning and not re-raised.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        insert_stmt = self._schema.trial_status.insert().values(
            exp_id=self._experiment_id,
            trial_id=self._trial_id,
            ts=timestamp,
            status=status.name,
        )
        try:
            conn.execute(insert_stmt)
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                timestamp,
                ex,
            )