Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 100%
38 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Saving and restoring the benchmark data in SQL database."""
7import logging
8from typing import Dict, Literal, Optional
10from sqlalchemy import URL, create_engine
12from mlos_bench.services.base_service import Service
13from mlos_bench.storage.base_experiment_data import ExperimentData
14from mlos_bench.storage.base_storage import Storage
15from mlos_bench.storage.sql.experiment import Experiment
16from mlos_bench.storage.sql.experiment_data import ExperimentSqlData
17from mlos_bench.storage.sql.schema import DbSchema
18from mlos_bench.tunables.tunable_groups import TunableGroups
20_LOG = logging.getLogger(__name__)
class SqlStorage(Storage):
    """An implementation of the :py:class:`~.Storage` interface using a SQLAlchemy
    backend.
    """

    def __init__(
        self,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new SQL-backed storage instance.

        Parameters
        ----------
        config : dict
            Storage configuration. The optional ``lazy_schema_create`` and
            ``log_sql`` keys (both default to ``False``) are popped from
            ``self._config``; all remaining keys are passed as keyword
            arguments to :py:meth:`sqlalchemy.URL.create`.
        global_config : Optional[dict]
            Global configuration, forwarded to the base class.
        service : Optional[Service]
            Parent service, forwarded to the base class.
        """
        super().__init__(config, global_config, service)
        lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        # Human-readable identifier, e.g. "<backend>:<database>".
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        _LOG.info("Connect to the database: %s", self)
        # ``log_sql`` turns on SQLAlchemy's statement echoing.
        self._engine = create_engine(self._url, echo=self._log_sql)
        # Annotated but intentionally NOT assigned: the ``_schema`` property
        # uses ``hasattr`` on this name to decide whether to create the schema.
        self._db_schema: DbSchema
        if not lazy_schema_create:
            # Access the property to force eager schema creation. This must not
            # be written as ``assert self._schema``: assert statements are
            # stripped under ``python -O``, which would skip creation entirely.
            _ = self._schema
        else:
            _LOG.info("Using lazy schema create for database: %s", self)

    @property
    def _schema(self) -> DbSchema:
        """Lazily create schema upon first access.

        On the first call the schema is built with ``DbSchema(engine).create()``
        and cached in ``self._db_schema``; later calls return the cached object.
        """
        if not hasattr(self, "_db_schema"):
            self._db_schema = DbSchema(self._engine).create()
            if _LOG.isEnabledFor(logging.DEBUG):
                # Use the freshly cached attribute rather than re-entering
                # this property.
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def __repr__(self) -> str:
        return self._repr

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: Dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """
        Return an :py:class:`Experiment` bound to this storage's engine and schema.

        All keyword arguments are forwarded unchanged to the ``Experiment``
        constructor. Accessing ``self._schema`` here creates the schema first
        if it has not been created yet (lazy mode).
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> Dict[str, ExperimentData]:
        """
        Load every experiment record from the database.

        Returns
        -------
        Dict[str, ExperimentData]
            Mapping of ``exp_id`` to an :py:class:`ExperimentSqlData` view, with
            insertion order following ``exp_id`` ascending.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().order_by(
                    self._schema.experiment.c.exp_id.asc(),
                )
            )
            # Rows are materialized while the connection is still open.
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }