Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 99%
86 statements
coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
DB schema definition for the :py:class:`~mlos_bench.storage.sql.storage.SqlStorage`
backend.

Notes
-----
The SQL statements are generated by SQLAlchemy, but can be obtained using
``repr`` or ``str`` (e.g., via ``print()``) on this object.
The ``mlos_bench`` CLI will do this automatically if the logging level is set to
``DEBUG``.

Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__ for
details on how to invoke only the schema creation/update routines.
"""

import logging
from importlib.resources import files
from typing import Any

from alembic import command, config
from sqlalchemy import (
    Column,
    Connection,
    DateTime,
    Dialect,
    Float,
    ForeignKeyConstraint,
    Integer,
    MetaData,
    PrimaryKeyConstraint,
    Sequence,
    String,
    Table,
    UniqueConstraint,
    create_mock_engine,
    inspect,
)
from sqlalchemy.dialects import mysql
from sqlalchemy.engine import Engine

from mlos_bench.util import path_join

_LOG = logging.getLogger(__name__)


def _mysql_datetime_with_fsp() -> mysql.DATETIME:
    """
    Return a MySQL DATETIME type with fractional seconds precision (fsp=6).

    Notes
    -----
    Split out to allow a single mypy ignore.
    See <https://github.com/sqlalchemy/sqlalchemy/pull/12164> for details.
    """
    return mysql.DATETIME(fsp=6)  # type: ignore[no-untyped-call]


class _DDL:
    """
    A helper class to capture the DDL statements from SQLAlchemy.

    It is used in the `DbSchema.__repr__()` method below.
    """

    def __init__(self, dialect: Dialect):
        self._dialect = dialect
        self.statements: list[str] = []

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        self.statements.append(str(sql.compile(dialect=self._dialect)))

    def __repr__(self) -> str:
        res = ";\n".join(self.statements)
        return res + ";" if res else ""


class DbSchema:
    """A class to define and create the DB schema."""

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        Parameters
        ----------
        engine : sqlalchemy.engine.Engine
        """
        assert engine, "Error: can't create schema without engine."
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        self._meta = MetaData()

        # Common string column sizes.
        self._exp_id_len = 512
        self._param_id_len = 512
        self._param_value_len = 1024
        self._metric_id_len = 512
        self._metric_value_len = 255
        self._status_len = 16

        # Some overrides for certain DB engines:
        if engine and engine.dialect.name in {"mysql", "mariadb"}:
            self._exp_id_len = 255
            self._param_id_len = 255
            self._metric_id_len = 255

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),
            # For backwards compatibility, we allow NULL for ts_start.
            Column(
                "ts_start",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
            ),
            Column(
                "ts_end",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
            ),
            # Should match the text IDs of the `mlos_bench.environments.Status` enum.
            # For backwards compatibility, we allow NULL for status.
            Column("status", String(self._status_len)),
            # There may be more than one mlos_benchd_service running on different hosts.
            # This column stores the host/container name of the driver that
            # picked up the experiment.
            # Drivers should use a transaction to update it to their own hostname
            # when they start, if and only if it is still NULL
            # (see the sketch after this table definition).
            Column("driver_name", String(40), comment="Driver Host/Container Name"),
            Column("driver_pid", Integer, comment="Driver Process ID"),
            PrimaryKeyConstraint("exp_id"),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_experiment_data.ExperimentData` info.
        """

        self.objectives = Table(
            "objectives",
            self._meta,
            Column("exp_id"),
            Column("optimization_target", String(self._metric_id_len), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet, as currently
            # multi-objective optimization is expected to explore each objective
            # equally. Will need to adjust the insert and return values to
            # support this eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_storage.Storage.Experiment` optimization
        objectives info.
        """

        # A workaround for a SQLAlchemy issue with autoincrement in DuckDB:
        if engine and engine.dialect.name == "duckdb":
            seq_config_id = Sequence("seq_config_id")
            col_config_id = Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        else:
            col_config_id = Column(
                "config_id",
                Integer,
                nullable=False,
                primary_key=True,
                autoincrement=True,
            )

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("trial_runner_id", Integer, nullable=True, default=None),
            Column(
                "ts_start",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=False,
            ),
            Column(
                "ts_end",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=True,
            ),
            # Should match the text IDs of the `mlos_bench.environments.Status` enum.
            Column("status", String(self._status_len), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        info.
        """

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._param_id_len), nullable=False),
            Column("param_value", String(self._param_value_len)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._param_id_len), nullable=False),
            Column("param_value", String(self._param_value_len)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`metadata <mlos_bench.storage.base_trial_data.TrialData.metadata_dict>`
        info.
        """

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column(
                "ts",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=False,
                default="now",
            ),
            Column("status", String(self._status_len), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:class:`~mlos_bench.environments.status.Status` info.
        """

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._metric_id_len), nullable=False),
            Column("metric_value", String(self._metric_value_len)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`results <mlos_bench.storage.base_trial_data.TrialData.results_dict>`
        info.
        """

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._exp_id_len), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column(
                "ts",
                DateTime(timezone=True).with_variant(
                    _mysql_datetime_with_fsp(),
                    "mysql",
                ),
                nullable=False,
                default="now",
            ),
            Column("metric_id", String(self._metric_id_len), nullable=False),
            Column("metric_value", String(self._metric_value_len)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`telemetry <mlos_bench.storage.base_trial_data.TrialData.telemetry_df>`
        info.
        """

        _LOG.debug("Schema: %s", self._meta)
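
        # Illustrative sketch (not part of the original module): the foreign key
        # constraints above let callers join trial data back to its tunable
        # config values, e.g. (assuming `engine` and a populated database):
        #
        #   from sqlalchemy import select
        #   schema = DbSchema(engine)
        #   stmt = (
        #       select(
        #           schema.trial.c.exp_id,
        #           schema.trial.c.trial_id,
        #           schema.config_param.c.param_id,
        #           schema.config_param.c.param_value,
        #       )
        #       .join_from(
        #           schema.trial,
        #           schema.config_param,
        #           schema.trial.c.config_id == schema.config_param.c.config_id,
        #       )
        #   )
        #   with engine.connect() as conn:
        #       rows = conn.execute(stmt).fetchall()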

    @property
    def meta(self) -> MetaData:
        """Return the SQLAlchemy MetaData object."""
        return self._meta

    def _get_alembic_cfg(self, conn: Connection) -> config.Config:
        alembic_cfg = config.Config(
            path_join(str(files("mlos_bench.storage.sql")), "alembic.ini", abs_path=True)
        )
        assert self._engine is not None
        alembic_cfg.set_main_option(
            "sqlalchemy.url",
            self._engine.url.render_as_string(
                hide_password=False,
            ),
        )
        alembic_cfg.attributes["connection"] = conn
        return alembic_cfg

    def drop_all_tables(self, *, force: bool = False) -> None:
        """
        Helper method used in testing to reset the DB schema.

        Notes
        -----
        This method is not intended for production use, as it will drop all tables
        in the database. Use with caution.

        Parameters
        ----------
        force : bool
            If True, drop all tables in the target database.
            If False, this method will not drop any tables and will log a warning.
        """
        assert self._engine
        self.meta.reflect(bind=self._engine)
        if force:
            self.meta.drop_all(bind=self._engine)
        else:
            _LOG.warning(
                "Resetting the schema without force is not implemented. "
                "Use force=True to drop all tables."
            )
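
    # Illustrative usage sketch (not part of the original module), e.g. from a
    # test fixture; `engine` is assumed to point at a disposable test database:
    #
    #   schema = DbSchema(engine).create()
    #   ...  # run tests against the schema
    #   schema.drop_all_tables(force=True)  # reset the test database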

    def create(self) -> "DbSchema":
        """Create the DB schema."""
        _LOG.info("Create the DB schema")
        assert self._engine
        self._meta.create_all(self._engine)
        with self._engine.begin() as conn:
            # If the trial table has the trial_runner_id column but no
            # "alembic_version" table, then the schema is up to date as of the
            # initial create, and we should mark it as such to avoid trying to
            # run the (non-idempotent) upgrade scripts.
            # Otherwise, either we already have an alembic_version table and can
            # safely run the necessary upgrades, or we are missing the
            # trial_runner_id column (the first column to be introduced by a
            # schema update) and should run the upgrades.
            if any(
                column["name"] == "trial_runner_id"
                for column in inspect(conn).get_columns(self.trial.name)
            ) and not inspect(conn).has_table("alembic_version"):
                # Mark the schema as up to date.
                alembic_cfg = self._get_alembic_cfg(conn)
                command.stamp(alembic_cfg, "heads")
                # command.current(alembic_cfg)
        return self

    def update(self) -> "DbSchema":
        """
        Update the DB schema to the latest version.

        Notes
        -----
        Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__
        for details on how to invoke only the schema creation/update routines.
        """
        assert self._engine
        with self._engine.connect() as conn:
            alembic_cfg = self._get_alembic_cfg(conn)
            command.upgrade(alembic_cfg, "head")
        return self
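
    # Illustrative usage sketch (not part of the original module): upgrading an
    # existing database in place (the programmatic equivalent of running
    # `alembic upgrade head` against the engine's URL); `engine` is assumed:
    #
    #   DbSchema(engine).update()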

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in the current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.

        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and only if the logging level is set to DEBUG.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        assert self._engine
        ddl = _DDL(self._engine.dialect)
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)