Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 95%
38 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6DB schema definition for the :py:class:`~mlos_bench.storage.sql.storage.SqlStorage`
7backend.
9Notes
10-----
The SQL statements are generated by SQLAlchemy, but can be obtained by calling
``repr`` or ``str`` (e.g., via ``print()``) on a :py:class:`DbSchema` instance.
The ``mlos_bench`` CLI will do this automatically if the logging level is set to
``DEBUG``.
15"""
import logging
from typing import Any, List

from sqlalchemy import (
    Column,
    DateTime,
    Dialect,
    Float,
    ForeignKeyConstraint,
    Integer,
    MetaData,
    PrimaryKeyConstraint,
    Sequence,
    String,
    Table,
    UniqueConstraint,
    create_mock_engine,
    func,
)
from sqlalchemy.engine import Engine
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class _DDL:
    """
    A helper class to capture the DDL statements from SQLAlchemy.

    An instance is passed as the ``executor`` callback of a mock engine
    (see ``sqlalchemy.create_mock_engine``). Every DDL element it receives is
    compiled for the configured dialect and recorded as a string.
    It is used by the ``DbSchema.__repr__()`` method below (which ``str()``
    also falls back to, since ``DbSchema`` defines no ``__str__``).
    """

    def __init__(self, dialect: "Dialect"):
        """
        Parameters
        ----------
        dialect : Dialect
            The SQLAlchemy dialect used to compile every captured statement.
        """
        self._dialect = dialect
        # Compiled SQL text of every statement received, in arrival order
        # (kept verbatim, including any surrounding whitespace).
        self.statements: List[str] = []

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        """
        Compile one DDL element for the configured dialect and record it.

        Any extra positional/keyword arguments supplied by the caller are ignored.
        """
        self.statements.append(str(sql.compile(dialect=self._dialect)))

    def __repr__(self) -> str:
        """
        Return the captured statements as one script, one statement per line.

        Each statement is stripped of surrounding whitespace first: SQLAlchemy's
        compiled CREATE TABLE text starts and ends with newlines, which would
        otherwise push the terminating ``;`` onto its own line. Statements that
        are empty after stripping are skipped. Returns ``""`` when nothing was
        captured.
        """
        stripped = (stmt.strip() for stmt in self.statements)
        res = ";\n".join(stmt for stmt in stripped if stmt)
        return res + ";" if res else ""
class DbSchema:
    """
    A class to define and create the DB schema.

    Each table is exposed as a public :py:class:`~sqlalchemy.Table` attribute:
    ``experiment``, ``objectives``, ``config``, ``trial``, ``config_param``,
    ``trial_param``, ``trial_status``, ``trial_result``, and ``trial_telemetry``.
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        Only the table metadata is declared here; nothing is sent to the
        database until :py:meth:`create` is called.

        Parameters
        ----------
        engine : Engine
            The SQLAlchemy engine the schema is for. Its dialect name decides
            how the auto-incrementing ``config.config_id`` column is declared.
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        # TODO: bind for automatic schema updates? (#649)
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),
            PrimaryKeyConstraint("exp_id"),
        )

        self.objectives = Table(
            "objectives",
            self._meta,
            # Typed explicitly (instead of relying on inference from the
            # foreign key) to match the `exp_id` column of every other table.
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        # draw the IDs from an explicit sequence via a server-side default.
        if engine.dialect.name == "duckdb":
            seq_config_id = Sequence("seq_config_id")
            col_config_id = Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        else:
            col_config_id = Column(
                "config_id",
                Integer,
                nullable=False,
                primary_key=True,
                autoincrement=True,
            )

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # Used only when an INSERT omits `ts`. A SQL expression (rendered by
            # each dialect as its own current-timestamp function) rather than a
            # literal string, which DateTime would receive as a plain str bind
            # value. Client-side default: the emitted DDL is unaffected.
            Column("ts", DateTime(timezone=True), nullable=False, default=func.now()),
            Column("status", String(self._STATUS_LEN), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # Same SQL-expression default as `trial_status.ts` above.
            Column("ts", DateTime(timezone=True), nullable=False, default=func.now()),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        # Log the declared table names; the repr of the MetaData object itself
        # does not list its tables. (The full DDL is available via repr(self).)
        _LOG.debug("Schema: %s", sorted(self._meta.tables))

    def create(self) -> "DbSchema":
        """
        Create the DB schema.

        Runs ``MetaData.create_all`` against the engine given to the
        constructor, with its default ``checkfirst`` behavior.

        Returns
        -------
        self : DbSchema
            This same object, so that construction and creation can be chained.
        """
        _LOG.info("Create the DB schema")
        self._meta.create_all(self._engine)
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is relatively heavy (it compiles every statement through
        a mock engine), so it is meant for occasional debug logging only.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        ddl = _DDL(self._engine.dialect)
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        # The mock engine only records statements and cannot answer "does this
        # table exist?" queries, so the existence check must be disabled.
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)