Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 99%

72 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""":py:class:`.Storage.Trial` interface implementation for saving and restoring 

6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend. 

7""" 

8 

9 

10import logging 

11from collections.abc import Mapping 

12from datetime import datetime 

13from typing import Any, Literal 

14 

15from sqlalchemy import or_ 

16from sqlalchemy.engine import Connection, Engine 

17from sqlalchemy.exc import IntegrityError 

18 

19from mlos_bench.environments.status import Status 

20from mlos_bench.storage.base_storage import Storage 

21from mlos_bench.storage.sql.common import save_params 

22from mlos_bench.storage.sql.schema import DbSchema 

23from mlos_bench.tunables.tunable_groups import TunableGroups 

24from mlos_bench.util import nullable, utcify_timestamp 

25 

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

27 

28 

class Trial(Storage.Trial):
    """
    SQL implementation of :py:class:`.Storage.Trial`.

    Persists one trial's status history, start/end timestamps, final metrics,
    telemetry records, and extra config parameters, using the SQLAlchemy
    ``engine`` and table ``schema`` supplied at construction time.
    """

    # Statuses that the completion UPDATE in ``update()`` refuses to overwrite;
    # the start UPDATE additionally refuses to overwrite RUNNING.
    _FINAL_STATUS_NAMES = (
        Status.SUCCEEDED.name,
        Status.CANCELED.name,
        Status.FAILED.name,
        Status.TIMED_OUT.name,
    )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        trial_runner_id: int | None,
        opt_targets: dict[str, Literal["min", "max"]],
        status: Status,
        restoring: bool,
        config: dict[str, Any] | None = None,
    ):
        """
        Create a SQL-backed trial handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine; each storage operation opens its own transaction on it.
        schema : DbSchema
            Table definitions used for all reads and writes.
        config_id : int
            Forwarded to the base class as ``tunable_config_id``.

        All other arguments are forwarded unchanged to :py:class:`.Storage.Trial`.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            trial_runner_id=trial_runner_id,
            opt_targets=opt_targets,
            status=status,
            restoring=restoring,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def _where_this_trial(self, table: Any) -> tuple[Any, ...]:
        """Return the WHERE conditions selecting this trial's rows in ``table``."""
        return (
            table.c.exp_id == self._experiment_id,
            table.c.trial_id == self._trial_id,
        )

    def set_trial_runner(self, trial_runner_id: int) -> int:
        """
        Try to assign ``trial_runner_id`` to this trial in the database.

        The ``trial`` row is only updated when it has no runner yet or its
        status is still PENDING (in which case any earlier runner is replaced).
        The value actually stored is then read back in a separate transaction
        and adopted as this object's runner id. The returned id can therefore
        differ from the requested one when the update did not apply.

        Returns
        -------
        int
            The trial runner id stored in the database.
        """
        trial_runner_id = super().set_trial_runner(trial_runner_id)
        trial = self._schema.trial
        claimable = or_(
            trial.c.trial_runner_id.is_(None),
            trial.c.status == Status.PENDING.name,
        )
        with self._engine.begin() as conn:
            conn.execute(
                trial.update()
                .where(*self._where_this_trial(trial), claimable)
                .values(trial_runner_id=trial_runner_id)
            )
        # Guard against concurrent updates: trust what the database holds,
        # not what was requested.
        with self._engine.begin() as conn:
            stored = conn.execute(
                trial.select()
                .with_only_columns(trial.c.trial_runner_id)
                .where(*self._where_this_trial(trial))
            ).fetchone()
            assert stored
            self._trial_runner_id = stored.trial_runner_id
            assert isinstance(self._trial_runner_id, int)
            return self._trial_runner_id

    def _save_new_config_data(self, new_config_data: Mapping[str, int | float | str]) -> None:
        """Write ``new_config_data`` into ``trial_param``, keyed by this trial's ids."""
        trial_key = {"exp_id": self._experiment_id, "trial_id": self._trial_id}
        with self._engine.begin() as conn:
            save_params(conn, self._schema.trial_param, new_config_data, **trial_key)

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Record a status change and, for completed trials, the final metrics.

        A ``trial_status`` row is inserted (a duplicate timestamp is logged and
        skipped).

        - If ``status.is_completed()``, the ``trial`` row receives the final
          status and ``ts_end``, and any ``metrics`` are inserted into
          ``trial_result``.
        - Otherwise the row receives the new status and ``ts_start``. If the
          row is already running or has ended, that update does not apply and
          is only logged.

        Everything runs in a single transaction that is rolled back on error.

        Returns
        -------
        dict[str, Any] | None
            The metrics as returned by the base class ``update()``.

        Raises
        ------
        RuntimeError
            If the completion UPDATE reports a rowcount other than 1 (or -1).
        """
        # Normalize to UTC before the base class or the database sees the timestamp.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        with self._engine.begin() as conn:
            # NOTE(review): `_update_status` swallows IntegrityError on this shared
            # connection; on backends that abort the transaction after a failed
            # statement (e.g. PostgreSQL) the statements below would then fail.
            self._update_status(conn, status, timestamp)
            try:
                if status.is_completed():
                    self._record_trial_end(conn, status, timestamp, metrics)
                else:
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    self._record_trial_start(conn, status, timestamp)
            except Exception:
                conn.rollback()
                raise
        return metrics

    def _record_trial_end(
        self,
        conn: Connection,
        status: Status,
        timestamp: datetime,
        metrics: dict[str, Any] | None,
    ) -> None:
        """Set the final status and ``ts_end``, then insert any final metrics."""
        trial = self._schema.trial
        outcome = conn.execute(
            trial.update()
            .where(
                *self._where_this_trial(trial),
                trial.c.ts_end.is_(None),
                trial.c.status.notin_(list(self._FINAL_STATUS_NAMES)),
            )
            .values(status=status.name, ts_end=timestamp)
        )
        if outcome.rowcount not in {1, -1}:
            _LOG.warning("Trial %s :: update failed: %s", self, status)
            raise RuntimeError(
                f"Failed to update the status of the trial {self} to {status}. "
                f"({outcome.rowcount} rows)"
            )
        if not metrics:
            return
        result_rows = [
            {
                "exp_id": self._experiment_id,
                "trial_id": self._trial_id,
                "metric_id": metric_id,
                "metric_value": nullable(str, metric_value),
            }
            for metric_id, metric_value in metrics.items()
        ]
        conn.execute(self._schema.trial_result.insert().values(result_rows))

    def _record_trial_start(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """Set the new status and ``ts_start`` unless the trial is running or has ended."""
        trial = self._schema.trial
        outcome = conn.execute(
            trial.update()
            .where(
                *self._where_this_trial(trial),
                trial.c.ts_end.is_(None),
                trial.c.status.notin_([Status.RUNNING.name, *self._FINAL_STATUS_NAMES]),
            )
            .values(status=status.name, ts_start=timestamp)
        )
        if outcome.rowcount not in {1, -1}:
            # Keep the existing status and timestamp; just note that nothing changed.
            _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: list[tuple[datetime, str, Any]],
    ) -> None:
        """
        Record a status row plus ``(timestamp, metric_id, value)`` telemetry records.

        Each telemetry record is inserted in its own transaction, so a record
        that already exists (IntegrityError) is logged and skipped without
        affecting the others.
        """
        super().update_telemetry(status, timestamp, metrics)
        # Store all timestamps in UTC.
        timestamp = utcify_timestamp(timestamp, origin="local")
        utc_metrics = [
            (utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics
        ]
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        telemetry = self._schema.trial_telemetry
        for metric_ts, key, val in utc_metrics:
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        telemetry.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=key,
                            metric_value=nullable(str, val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a ``trial_status`` row for ``status`` at ``timestamp`` (converted to UTC).

        Repeating the call with the same timestamp is tolerated: the resulting
        IntegrityError is logged as a warning instead of being raised.
        NOTE(review): the failed INSERT runs on the caller's ``conn``; on
        backends that abort a transaction after any failed statement, the
        caller's transaction may be unusable afterwards. Confirm for
        non-SQLite deployments.
        """
        ts_utc = utcify_timestamp(timestamp, origin="local")
        insert_stmt = self._schema.trial_status.insert().values(
            exp_id=self._experiment_id,
            trial_id=self._trial_id,
            ts=ts_utc,
            status=status.name,
        )
        try:
            conn.execute(insert_stmt)
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                ts_utc,
                ex,
            )