Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 95%
81 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Saving and restoring the benchmark data in SQL database."""
7import logging
8from types import TracebackType
9from typing import Literal
11from sqlalchemy import URL, Engine, create_engine
13from mlos_bench.services.base_service import Service
14from mlos_bench.storage.base_experiment_data import ExperimentData
15from mlos_bench.storage.base_storage import Storage
16from mlos_bench.storage.sql.experiment import Experiment
17from mlos_bench.storage.sql.experiment_data import ExperimentSqlData
18from mlos_bench.storage.sql.schema import DbSchema
19from mlos_bench.tunables.tunable_groups import TunableGroups
# Module-level logger, named after this module, used by SqlStorage below.
_LOG = logging.getLogger(__name__)
class SqlStorage(Storage):
    """An implementation of the :py:class:`~.Storage` interface using SQLAlchemy
    backend.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create the storage object and connect it to the database.

        Parameters
        ----------
        config : dict
            Storage configuration. The optional ``lazy_schema_create`` and
            ``log_sql`` entries (both default to False) are popped from
            ``self._config``; every remaining entry is passed as a keyword
            argument to ``sqlalchemy.URL.create``.
        global_config : dict | None
            Forwarded to the base class constructor.
        service : Service | None
            Forwarded to the base class constructor.
        """
        super().__init__(config, global_config, service)
        self._lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        # Short "<backend>:<database>" label returned by __repr__ and used in logs.
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        # Both attributes are assigned in _init_engine().
        self._engine: Engine
        self._db_schema: DbSchema
        self._schema_created = False
        self._schema_updated = False
        self._init_engine()

    def _init_engine(self) -> None:
        """
        Create the SQLAlchemy engine and the schema wrapper bound to it.

        Unless ``lazy_schema_create`` was requested, the schema is also created
        and updated right away (each step runs at most once, guarded by the
        ``_schema_created`` / ``_schema_updated`` flags).

        Called from ``__init__`` and again from ``__setstate__`` after
        unpickling.
        """
        _LOG.info("Connect to the database: %s", self)
        self._engine = create_engine(self._url, echo=self._log_sql)
        self._db_schema = DbSchema(self._engine)
        if not self._lazy_schema_create:
            # Accessing the property creates the schema on first use.
            # Do it explicitly rather than inside an `assert`, so that the
            # side effect does not depend on assertions being enabled.
            _ = self._schema
            self.update_schema()
        else:
            _LOG.info("Using lazy schema create for database: %s", self)

    # Make the object picklable.

    def __getstate__(self) -> dict:
        """
        Return the state of the object for pickling.

        Returns
        -------
        state : dict
            A shallow copy of the instance ``__dict__`` without the ``_engine``
            and ``_db_schema`` entries.
        """
        state = self.__dict__.copy()
        # Don't pickle the engine, as it cannot be pickled.
        state.pop("_engine", None)
        # The schema wrapper is built from the engine, so drop it as well;
        # both are recreated in __setstate__().
        state.pop("_db_schema", None)
        return state

    def __setstate__(self, state: dict) -> None:
        """
        Restore the state of the object from pickling.

        Parameters
        ----------
        state : dict
            The state produced by ``__getstate__``.

        Notes
        -----
        The ``_schema_created`` / ``_schema_updated`` flags are restored from
        ``state``, so schema creation/update is skipped if it was already done
        before the object was pickled.
        """
        self.__dict__.update(state)
        # Recreate the engine and schema.
        self._init_engine()

    def dispose(self) -> None:
        """Closes the database connection pool."""
        if self._engine:
            self._engine.dispose()
            _LOG.info("Closed the database connection: %s", self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,  # pylint: disable=unused-argument
        exc_val: BaseException | None,  # pylint: disable=unused-argument
        exc_tb: TracebackType | None,  # pylint: disable=unused-argument
    ) -> Literal[False]:
        """
        Close the engine connection when exiting the context.

        Returns
        -------
        Literal[False]
            Always False, i.e., an exception raised inside the ``with`` block
            is not suppressed.
        """
        self.dispose()
        return False

    @property
    def _schema(self) -> DbSchema:
        """Lazily create schema upon first access."""
        if not self._schema_created:
            self._db_schema.create()
            self._schema_created = True
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def _reset_schema(self, *, force: bool = False) -> None:
        """
        Helper method used in testing to reset the DB schema.

        Notes
        -----
        This method is not intended for production use, as it will drop all tables
        in the database. Use with caution.

        Parameters
        ----------
        force : bool
            If True, drop all tables in the target database.
            If False, this method will not drop any tables and will log a warning.
        """
        assert self._engine
        if force:
            self._schema.drop_all_tables(force=force)
            # Start over with a fresh schema wrapper and clear the flags so that
            # the next access creates and updates the schema again.
            self._db_schema = DbSchema(self._engine)
            self._schema_created = False
            self._schema_updated = False
        else:
            _LOG.warning(
                "Resetting the schema without force is not implemented. "
                "Use force=True to drop all tables."
            )

    def update_schema(self) -> None:
        """Update the database schema (at most once per schema instance)."""
        if not self._schema_updated:
            self._schema.update()
            self._schema_updated = True

    def __repr__(self) -> str:
        """Return the ``<backend>:<database>`` label computed in ``__init__``."""
        return self._repr

    def get_experiment_by_id(
        self,
        experiment_id: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment | None:
        """
        Look up an existing experiment by its ID.

        Parameters
        ----------
        experiment_id : str
            ID of the experiment to look up.
        tunables : TunableGroups
            Tunables passed to the returned experiment object.
        opt_targets : dict[str, Literal["min", "max"]]
            Optimization targets passed to the returned experiment object.

        Returns
        -------
        experiment : Storage.Experiment | None
            An experiment object built from the stored description and root
            environment config, or None if there is no such experiment.
        """
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().where(
                    self._schema.experiment.c.exp_id == experiment_id,
                )
            )
            exp = cur_exp.fetchone()
            if exp is None:
                return None
            return Experiment(
                engine=self._engine,
                schema=self._schema,
                experiment_id=exp.exp_id,
                trial_id=-1,  # will be loaded upon __enter__ which calls _setup()
                description=exp.description,
                root_env_config=exp.root_env_config,
                tunables=tunables,
                opt_targets=opt_targets,
            )

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """
        Create an experiment object bound to this storage's engine and schema.

        All keyword arguments are forwarded unchanged to the SQL ``Experiment``
        constructor.

        Returns
        -------
        experiment : Storage.Experiment
            The new experiment object.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> dict[str, ExperimentData]:
        """
        Fetch the data of all experiments in the storage.

        Returns
        -------
        experiments : dict[str, ExperimentData]
            Experiment data objects keyed by experiment ID, in ascending order
            of the experiment ID.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().order_by(
                    self._schema.experiment.c.exp_id.asc(),
                )
            )
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }