Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 95%
38 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6DB schema definition.
7"""
9import logging
10from typing import List, Any
12from sqlalchemy import (
13 Engine, MetaData, Dialect, create_mock_engine,
14 Table, Column, Sequence, Integer, Float, String, DateTime,
15 PrimaryKeyConstraint, ForeignKeyConstraint, UniqueConstraint,
16)
18_LOG = logging.getLogger(__name__)
class _DDL:
    """
    Collect, as strings, the DDL statements SQLAlchemy asks to execute.

    An instance is a callable: `DbSchema.__repr__()` below passes it to
    `create_mock_engine()` as the `executor`, so every statement is compiled
    for the given dialect and recorded in `statements` in call order.
    """

    def __init__(self, dialect: Dialect):
        # Compiled statements accumulate here, one string per executor call.
        self.statements: List[str] = []
        # Dialect used to render each statement into SQL text.
        self._dialect = dialect

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        # Any extra positional / keyword arguments are accepted and ignored.
        compiled = sql.compile(dialect=self._dialect)
        self.statements.append(str(compiled))

    def __repr__(self) -> str:
        # Statements are separated by ";\n" and the script ends with ";".
        # An empty script is rendered as an empty string.
        script = ";\n".join(self.statements)
        if not script:
            return ""
        return script + ";"
class DbSchema:
    """
    A class to define and create the DB schema.

    Every table is exposed as a public attribute holding a SQLAlchemy `Table`
    (`experiment`, `objectives`, `config`, `trial`, `config_param`,
    `trial_param`, `trial_status`, `trial_result`, `trial_telemetry`).
    All of them are registered in a single private `MetaData` collection.
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        Only the table definitions are built here; the tables themselves
        are created by a separate call to `create()`.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine of the target database. Its dialect name is
            inspected to pick the `config_id` auto-increment strategy, and the
            engine is kept for `create()` and `__repr__()`.
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        # TODO: bind for automatic schema updates? (#649)
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),

            PrimaryKeyConstraint("exp_id"),
        )

        self.objectives = Table(
            "objectives",
            self._meta,
            # Declared explicitly (instead of leaving the type for SQLAlchemy
            # to infer from the foreign key) so that it is consistent with
            # the `exp_id` column of every other table in this schema.
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),

            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        # use an explicit sequence as the server-side default there.
        if engine.dialect.name == "duckdb":
            seq_config_id = Sequence('seq_config_id')
            col_config_id = Column("config_id", Integer, seq_config_id,
                                   server_default=seq_config_id.next_value(),
                                   nullable=False, primary_key=True)
        else:
            col_config_id = Column("config_id", Integer, nullable=False,
                                   primary_key=True, autoincrement=True)

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),

            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),

            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),

            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # NOTE(review): `default="now"` is a Python-side scalar default,
            # i.e., the literal string "now" is bound when `ts` is omitted
            # from an INSERT. Presumably callers always pass `ts` explicitly;
            # confirm before relying on this default.
            Column("ts", DateTime(timezone=True), nullable=False, default="now"),
            Column("status", String(self._STATUS_LEN), nullable=False),

            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),

            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # NOTE(review): same caveat about `default="now"` as for
            # `trial_status.ts` above.
            Column("ts", DateTime(timezone=True), nullable=False, default="now"),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),

            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        _LOG.debug("Schema: %s", self._meta)

    def create(self) -> 'DbSchema':
        """
        Create the DB schema.

        Issues `MetaData.create_all()` against the engine passed to the
        constructor.

        Returns
        -------
        schema : DbSchema
            This object, to allow chaining, e.g., `DbSchema(engine).create()`.
        """
        _LOG.info("Create the DB schema")
        self._meta.create_all(self._engine)
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema
        from scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and if the logging level is set to DEBUG.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        # The mock engine hands every DDL statement to `ddl` instead of
        # executing it.
        ddl = _DDL(self._engine.dialect)
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        # `checkfirst=False`: emit the statements for all tables
        # without checking which ones already exist.
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)