Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 99%

86 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6DB schema definition for the :py:class:`~mlos_bench.storage.sql.storage.SqlStorage` 

7backend. 

8 

9Notes 

10----- 

11The SQL statements are generated by SQLAlchemy, but can be obtained using 

``repr`` or ``str`` (e.g., via ``print()``) on a :py:class:`DbSchema` instance.

13The ``mlos_bench`` CLI will do this automatically if the logging level is set to 

14``DEBUG``. 

15 

16Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__ for 

17details on how to invoke only the schema creation/update routines. 

18""" 

19 

20import logging 

21from importlib.resources import files 

22from typing import Any 

23 

24from alembic import command, config 

25from sqlalchemy import ( 

26 Column, 

27 Connection, 

28 DateTime, 

29 Dialect, 

30 Float, 

31 ForeignKeyConstraint, 

32 Integer, 

33 MetaData, 

34 PrimaryKeyConstraint, 

35 Sequence, 

36 String, 

37 Table, 

38 UniqueConstraint, 

39 create_mock_engine, 

40 inspect, 

41) 

42from sqlalchemy.dialects import mysql 

43from sqlalchemy.engine import Engine 

44 

45from mlos_bench.util import path_join 

46 

# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)

48 

49 

def _mysql_datetime_with_fsp() -> mysql.DATETIME:
    """
    Build the MySQL ``DATETIME`` column type with microsecond precision.

    Notes
    -----
    The ``fsp`` (fractional seconds precision) argument is set to 6 digits.
    This lives in its own function so that a single mypy ``type: ignore``
    covers the untyped constructor call.
    See <https://github.com/sqlalchemy/sqlalchemy/pull/12164> for details.
    """
    fsp_digits = 6
    return mysql.DATETIME(fsp=fsp_digits)  # type: ignore[no-untyped-call]

60 

61 

class _DDL:
    """
    Collect the DDL statements SQLAlchemy emits through a mock engine.

    An instance is used as the ``executor`` callback of
    ``create_mock_engine()`` in ``DbSchema.__repr__()``; each statement it
    receives is compiled for the given dialect and stored as a string.
    """

    def __init__(self, dialect: Dialect):
        self._dialect = dialect
        self.statements: list[str] = []

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        compiled = sql.compile(dialect=self._dialect)
        self.statements.append(str(compiled))

    def __repr__(self) -> str:
        # Statements are separated by ";\n" and the last one is terminated
        # with ";"; an empty result stays an empty string.
        joined = ";\n".join(self.statements)
        if not joined:
            return ""
        return f"{joined};"

79 

80 

class DbSchema:
    """A class to define and create the DB schema."""

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        Parameters
        ----------
        engine : sqlalchemy.engine.Engine
            The engine the schema is declared for. Its dialect name selects
            shorter ID column sizes (MySQL/MariaDB) and the ``config_id``
            autoincrement workaround (DuckDB).
        """
        assert engine, "Error: can't create schema without engine."
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        self._meta = MetaData()

        # Common string column sizes.
        self._exp_id_len = 512
        self._param_id_len = 512
        self._param_value_len = 1024
        self._metric_id_len = 512
        self._metric_value_len = 255
        self._status_len = 16

        # Some overrides for certain DB engines:
        if engine and engine.dialect.name in {"mysql", "mariadb"}:
            self._exp_id_len = 255
            self._param_id_len = 255
            self._metric_id_len = 255

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),
            # For backwards compatibility, we allow NULL for ts_start.
            Column(
                "ts_start",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
            ),
            Column(
                "ts_end",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
            ),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            # For backwards compatibility, we allow NULL for status.
            Column("status", String(self._status_len)),
            # There may be more than one mlos_benchd_service running on different hosts.
            # This column stores the host/container name of the driver that
            # picked up the experiment.
            # They should use a transaction to update it to their own hostname when
            # they start if and only if it's NULL.
            Column("driver_name", String(40), comment="Driver Host/Container Name"),
            Column("driver_pid", Integer, comment="Driver Process ID"),
            PrimaryKeyConstraint("exp_id"),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_experiment_data.ExperimentData` info.
        """

        self.objectives = Table(
            "objectives",
            self._meta,
            # Typed explicitly for consistency with the other exp_id columns
            # (previously inferred from the foreign key; the DDL is the same).
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("optimization_target", String(self._metric_id_len), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_storage.Storage.Experiment` optimization
        objectives info.
        """

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        if engine and engine.dialect.name == "duckdb":
            seq_config_id = Sequence("seq_config_id")
            col_config_id = Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        else:
            col_config_id = Column(
                "config_id",
                Integer,
                nullable=False,
                primary_key=True,
                autoincrement=True,
            )

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("trial_runner_id", Integer, nullable=True, default=None),
            Column(
                "ts_start",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=False,
            ),
            Column(
                "ts_end",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=True,
            ),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._status_len), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        info.
        """

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._param_id_len), nullable=False),
            Column("param_value", String(self._param_value_len)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._param_id_len), nullable=False),
            Column("param_value", String(self._param_value_len)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`metadata <mlos_bench.storage.base_trial_data.TrialData.metadata_dict>`
        info.
        """

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column(
                "ts",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=False,
                # NOTE(review): this client-side default binds the literal string
                # "now"; some backends' DateTime types (e.g., SQLAlchemy's SQLite
                # DATETIME) reject non-datetime values, so callers should always
                # supply ``ts`` explicitly.
                default="now",
            ),
            Column("status", String(self._status_len), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:class:`~mlos_bench.environments.status.Status` info.
        """

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._metric_id_len), nullable=False),
            Column("metric_value", String(self._metric_value_len)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`results <mlos_bench.storage.base_trial_data.TrialData.results_dict>`
        info.
        """

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column(
                "ts",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=False,
                # NOTE(review): same literal "now" client-side default as in
                # trial_status.ts; callers should always supply ``ts``.
                default="now",
            ),
            Column("metric_id", String(self._metric_id_len), nullable=False),
            Column("metric_value", String(self._metric_value_len)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`telemetry <mlos_bench.storage.base_trial_data.TrialData.telemetry_df>`
        info.
        """

        _LOG.debug("Schema: %s", self._meta)

    @property
    def meta(self) -> MetaData:
        """Return the SQLAlchemy MetaData object."""
        return self._meta

    def _get_alembic_cfg(self, conn: Connection) -> config.Config:
        """
        Build the Alembic configuration used to stamp or upgrade the schema.

        Parameters
        ----------
        conn : sqlalchemy.Connection
            Open connection, stored in ``Config.attributes["connection"]``
            (presumably reused by the Alembic env script, which is not shown here).

        Returns
        -------
        alembic_cfg : alembic.config.Config
            Config loaded from the packaged ``alembic.ini`` with ``sqlalchemy.url``
            pointing at this schema's engine.
        """
        alembic_cfg = config.Config(
            path_join(str(files("mlos_bench.storage.sql")), "alembic.ini", abs_path=True)
        )
        assert self._engine is not None
        db_url = self._engine.url.render_as_string(hide_password=False)
        # Alembic keeps options in a ConfigParser with "%"-style interpolation, so a
        # literal "%" (e.g., from URL-encoded special characters in the password)
        # must be escaped as "%%", otherwise set_main_option() raises ValueError.
        alembic_cfg.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))
        alembic_cfg.attributes["connection"] = conn
        return alembic_cfg

    def drop_all_tables(self, *, force: bool = False) -> None:
        """
        Helper method used in testing to reset the DB schema.

        Notes
        -----
        This method is not intended for production use, as it will drop all tables
        in the database. Use with caution.

        Parameters
        ----------
        force : bool
            If True, drop all tables in the target database.
            If False, this method will not drop any tables and will log a warning.
        """
        assert self._engine
        if not force:
            _LOG.warning(
                "Resetting the schema without force is not implemented. "
                "Use force=True to drop all tables."
            )
            return
        # Reflect into a throwaway MetaData rather than self._meta: otherwise
        # tables that are not part of this schema definition (e.g.,
        # "alembic_version") would leak into self._meta and be re-created
        # (empty) by a later create() on this same instance.
        reflected = MetaData()
        reflected.reflect(bind=self._engine)
        reflected.drop_all(bind=self._engine)

    def create(self) -> "DbSchema":
        """
        Create the DB schema.

        Creates any tables missing from the database and, for a freshly created
        schema, stamps it with the current Alembic heads.

        Returns
        -------
        self : DbSchema
        """
        _LOG.info("Create the DB schema")
        assert self._engine
        self._meta.create_all(self._engine)
        with self._engine.begin() as conn:
            # If the trial table has the trial_runner_id column but no
            # "alembic_version" table, then the schema is up to date as of initial
            # create and we should mark it as such to avoid trying to run the
            # (non-idempotent) upgrade scripts.
            # Otherwise, either we already have an alembic_version table and can
            # safely run the necessary upgrades or we are missing the
            # trial_runner_id column (the first to introduce schema updates) and
            # should run the upgrades.
            inspector = inspect(conn)
            has_trial_runner_id = any(
                column["name"] == "trial_runner_id"
                for column in inspector.get_columns(self.trial.name)
            )
            if has_trial_runner_id and not inspector.has_table("alembic_version"):
                # Mark the schema as up to date.
                alembic_cfg = self._get_alembic_cfg(conn)
                command.stamp(alembic_cfg, "heads")
        return self

    def update(self) -> "DbSchema":
        """
        Updates the DB schema to the latest version.

        Notes
        -----
        Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__
        for details on how to invoke only the schema creation/update routines.

        Returns
        -------
        self : DbSchema
        """
        assert self._engine
        with self._engine.connect() as conn:
            alembic_cfg = self._get_alembic_cfg(conn)
            command.upgrade(alembic_cfg, "head")
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and if the logging level is set to DEBUG.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        assert self._engine
        ddl = _DDL(self._engine.dialect)
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        # checkfirst=False: the mock engine cannot query the DB, so emit every CREATE.
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)