Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 93%

115 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""":py:class:`.Storage.Experiment` interface implementation for saving and restoring 

6the benchmark experiment data using `SQLAlchemy <https://sqlalchemy.org>`_ backend. 

7""" 

8 

9import hashlib 

10import logging 

11from collections.abc import Iterator 

12from datetime import datetime 

13from typing import Any, Literal 

14 

15from pytz import UTC 

16from sqlalchemy import Connection, CursorResult, Table, column, func, select 

17from sqlalchemy.engine import Engine 

18 

19from mlos_bench.environments.status import Status 

20from mlos_bench.storage.base_storage import Storage 

21from mlos_bench.storage.sql.common import save_params 

22from mlos_bench.storage.sql.schema import DbSchema 

23from mlos_bench.storage.sql.trial import Trial 

24from mlos_bench.tunables.tunable_groups import TunableGroups 

25from mlos_bench.util import utcify_timestamp 

26 

27_LOG = logging.getLogger(__name__) 

28 

29 

class Experiment(Storage.Experiment):
    """Logic for retrieving and storing the results of a single experiment."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        opt_targets: dict[str, Literal["min", "max"]],
    ):
        """
        Create a SQL-backed experiment handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections and transactions.
        schema : DbSchema
            Table definitions of the storage backend.
        tunables : TunableGroups
            Tunable parameters of the experiment.
        experiment_id : str
            ID of the experiment record in the database.
        trial_id : int
            ID to use for the next new trial. `_setup` advances it past the
            largest stored trial ID when resuming an existing experiment.
        root_env_config : str
            Root environment config; stored in the experiment record.
        description : str
            Description of the experiment; stored in the experiment record.
        opt_targets : dict[str, Literal["min", "max"]]
            Optimization target names mapped to their directions.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Create or resume the experiment record in the database.

        If no record exists for this experiment ID, insert one together with
        its optimization objectives. Otherwise, set the next trial ID to one
        past the largest trial ID already stored (if any), and log a warning
        when the stored git commit differs from the current one.
        """
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # The outer join keeps the experiment row even if it has no trials,
            # in which case the MAX() aggregate yields NULL (None).
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select()
                .with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                )
                .join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True,
                )
                .where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(
                    self._schema.experiment.insert().values(
                        exp_id=self._experiment_id,
                        description=self._description,
                        git_repo=self._git_repo,
                        git_commit=self._git_commit,
                        root_env_config=self._root_env_config,
                    )
                )
                # Only insert objectives if there are any: passing an empty list
                # to `.values()` is not a no-op in SQLAlchemy (it does not mean
                # "insert zero rows"), so the statement must be skipped instead.
                if self.opt_targets:
                    conn.execute(
                        self._schema.objectives.insert().values(
                            [
                                {
                                    "exp_id": self._experiment_id,
                                    "optimization_target": opt_target,
                                    "optimization_direction": opt_dir,
                                }
                                for (opt_target, opt_dir) in self.opt_targets.items()
                            ]
                        )
                    )
            else:
                if exp_info.trial_id is not None:
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info(
                    "Continue experiment: %s last trial: %s resume from: %d",
                    self._experiment_id,
                    exp_info.trial_id,
                    self._trial_id,
                )
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning(
                        "Experiment %s git expected: %s %s",
                        self,
                        exp_info.git_repo,
                        exp_info.git_commit,
                    )

    def merge(self, experiment_ids: list[str]) -> None:
        """
        Merge other experiments into this one (not implemented yet).

        Raises
        ------
        NotImplementedError
            Always.
        """
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO: Merging experiments not implemented yet.")

    def load_tunable_config(self, config_id: int) -> dict[str, Any]:
        """Return the tunable parameter values stored for the given config ID."""
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]:
        """
        Return the telemetry of a trial of this experiment.

        Rows are returned as ``(timestamp, metric_id, metric_value)`` tuples,
        ordered by timestamp and then by metric ID.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select()
                .where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id,
                )
                .order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in cur_telemetry.fetchall()
            ]

    def load(
        self,
        last_trial_id: int = -1,
    ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
        """
        Load the completed trials of this experiment.

        Only trials with ID greater than `last_trial_id` and with status
        SUCCEEDED, FAILED, or TIMED_OUT are returned, in ascending trial ID
        order.

        Returns
        -------
        (trial_ids, configs, scores, status) : tuple
            Parallel lists. ``scores[i]`` holds the stored metrics of the
            trial if it succeeded, and ``None`` otherwise.
        """
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_(
                        [
                            Status.SUCCEEDED.name,
                            Status.FAILED.name,
                            Status.TIMED_OUT.name,
                        ]
                    ),
                )
                .order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: list[int] = []
            configs: list[dict[str, Any]] = []
            scores: list[dict[str, Any] | None] = []
            status: list[Status] = []

            for trial in cur_trials.fetchall():
                stat = Status.parse(trial.status)
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    )
                )
                if stat.is_succeeded():
                    scores.append(
                        self._get_key_val(
                            conn,
                            self._schema.trial_result,
                            "metric",
                            exp_id=self._experiment_id,
                            trial_id=trial.trial_id,
                        )
                    )
                else:
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        (E.g., configurations, results, and telemetry).

        Selects the ``{field}_id`` and ``{field}_value`` columns of `table`,
        filtered by equality on every keyword argument, and returns them as a
        dict mapping id to value.
        """
        cur_result: CursorResult[tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            )
            .select_from(table)
            .where(*[column(key) == val for (key, val) in kwargs.items()])
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
        # avoid naming conflicts.
        return dict(
            row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
        )

    def _restore_trial(self, conn: Connection, trial: Any) -> Trial:
        """
        Build a restored :py:class:`.Trial` object from a row of the trial table.

        Shared by :py:meth:`.get_trial_by_id` and :py:meth:`.pending_trials`.

        Parameters
        ----------
        conn : Connection
            Open connection used to load the trial's parameters.
        trial : Row
            A row selected from ``self._schema.trial``. It must expose
            ``trial_id``, ``config_id``, ``trial_runner_id`` and ``status``.
        """
        # Tunable values of the trial's configuration.
        tunables = self._get_key_val(
            conn,
            self._schema.config_param,
            "param",
            config_id=trial.config_id,
        )
        # Per-trial framework config (not the target environment tunables).
        config = self._get_key_val(
            conn,
            self._schema.trial_param,
            "param",
            exp_id=self._experiment_id,
            trial_id=trial.trial_id,
        )
        return Trial(
            engine=self._engine,
            schema=self._schema,
            # Reset .is_updated flag after the assignment:
            tunables=self._tunables.copy().assign(tunables).reset(),
            experiment_id=self._experiment_id,
            trial_id=trial.trial_id,
            config_id=trial.config_id,
            trial_runner_id=trial.trial_runner_id,
            opt_targets=self._opt_targets,
            status=Status.parse(trial.status),
            restoring=True,
            config=config,
        )

    def get_trial_by_id(
        self,
        trial_id: int,
    ) -> Storage.Trial | None:
        """
        Restore the trial with the given ID from this experiment.

        Returns ``None`` if no such trial exists.
        """
        with self._engine.connect() as conn:
            trial = conn.execute(
                self._schema.trial.select().where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id == trial_id,
                )
            ).fetchone()
            if trial is None:
                return None
            return self._restore_trial(conn, trial)

    def pending_trials(
        self,
        timestamp: datetime,
        *,
        running: bool = False,
        trial_runner_assigned: bool | None = None,
    ) -> Iterator[Storage.Trial]:
        """
        Yield the not-yet-finished trials of this experiment that are due.

        A trial is selected when its ``ts_end`` is NULL, its ``ts_start`` is
        NULL or not later than `timestamp`, and its status is PENDING (or
        PENDING/READY/RUNNING when `running` is True).

        Parameters
        ----------
        timestamp : datetime
            Cut-off time. It is normalized with
            ``utcify_timestamp(..., origin="local")`` before comparison.
        running : bool
            Also include READY and RUNNING trials.
        trial_runner_assigned : bool | None
            True: only trials with a TrialRunner assigned. False: only trials
            without one. None: no filtering on the TrialRunner.

        Notes
        -----
        This is a generator: the database connection stays open until the
        iteration is exhausted or the generator is closed.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            statuses = [Status.PENDING, Status.READY, Status.RUNNING]
        else:
            statuses = [Status.PENDING]
        with self._engine.connect() as conn:
            stmt = self._schema.trial.select().where(
                self._schema.trial.c.exp_id == self._experiment_id,
                (
                    self._schema.trial.c.ts_start.is_(None)
                    | (self._schema.trial.c.ts_start <= timestamp)
                ),
                self._schema.trial.c.ts_end.is_(None),
                self._schema.trial.c.status.in_([s.name for s in statuses]),
            )
            if trial_runner_assigned:
                stmt = stmt.where(self._schema.trial.c.trial_runner_id.isnot(None))
            elif trial_runner_assigned is False:
                stmt = stmt.where(self._schema.trial.c.trial_runner_id.is_(None))
            # else: # No filtering by trial_runner_id
            cur_trials = conn.execute(stmt)
            for trial in cur_trials.fetchall():
                yield self._restore_trial(conn, trial)

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables.

        If the config does not exist, create a new record for it.

        The config is looked up by the SHA-256 hash of ``str(tunables)``; a new
        record also stores each tunable's current value as a config parameter.
        """
        config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
        cur_config = conn.execute(
            self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
        ).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        new_config_result = conn.execute(
            self._schema.config.insert().values(config_hash=config_hash)
        ).inserted_primary_key
        assert new_config_result
        config_id: int = new_config_result[0]
        save_params(
            conn,
            self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id,
        )
        return config_id

    def _new_trial(
        self,
        tunables: TunableGroups,
        ts_start: datetime | None = None,
        config: dict[str, Any] | None = None,
    ) -> Storage.Trial:
        """
        Create and store a new PENDING trial with the next trial ID.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable values of the trial; mapped to a (possibly new) config ID.
        ts_start : datetime | None
            Scheduled start time. Defaults to now (UTC) and is normalized with
            ``utcify_timestamp(..., origin="local")``.
        config : dict[str, Any] | None
            Optional per-trial framework config to store alongside the trial.

        Returns
        -------
        Storage.Trial
            The new trial, not yet assigned to a TrialRunner. On success the
            experiment's next trial ID is incremented.
        """
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local")
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        with self._engine.begin() as conn:
            try:
                new_trial_status = Status.PENDING
                config_id = self._get_config_id(conn, tunables)
                conn.execute(
                    self._schema.trial.insert().values(
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                        config_id=config_id,
                        ts_start=ts_start,
                        status=new_trial_status.name,
                    )
                )

                # Note: config here is the framework config, not the target
                # environment config (i.e., tunables).
                if config is not None:
                    save_params(
                        conn,
                        self._schema.trial_param,
                        config,
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                    )

                trial = Trial(
                    engine=self._engine,
                    schema=self._schema,
                    tunables=tunables,
                    experiment_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    trial_runner_id=None,  # initially, Trials are not assigned to a TrialRunner
                    opt_targets=self._opt_targets,
                    status=new_trial_status,
                    restoring=False,
                    config=config,
                )
                self._trial_id += 1
                return trial
            except Exception:
                conn.rollback()
                raise