Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 95%

81 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Saving and restoring the benchmark data in SQL database.""" 

6 

7import logging 

8from types import TracebackType 

9from typing import Literal 

10 

11from sqlalchemy import URL, Engine, create_engine 

12 

13from mlos_bench.services.base_service import Service 

14from mlos_bench.storage.base_experiment_data import ExperimentData 

15from mlos_bench.storage.base_storage import Storage 

16from mlos_bench.storage.sql.experiment import Experiment 

17from mlos_bench.storage.sql.experiment_data import ExperimentSqlData 

18from mlos_bench.storage.sql.schema import DbSchema 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20 

# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)

22 

23 

class SqlStorage(Storage):
    """An implementation of the :py:class:`~.Storage` interface using SQLAlchemy
    backend.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create the SQL storage backend and connect to the database.

        Parameters
        ----------
        config : dict
            Storage configuration. The ``lazy_schema_create`` and ``log_sql``
            keys are removed from ``self._config`` here. All remaining keys are
            passed to :py:meth:`sqlalchemy.URL.create` to build the connection URL.
        global_config : dict | None
            Global configuration, forwarded to the base class.
        service : Service | None
            Parent service, forwarded to the base class.
        """
        super().__init__(config, global_config, service)
        # Pop our own options so that only URL components remain for URL.create().
        self._lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        # Declared only; both are assigned in _init_engine(), which also runs
        # again after unpickling (see __setstate__).
        self._engine: Engine
        self._db_schema: DbSchema
        self._schema_created = False
        self._schema_updated = False
        self._init_engine()

    def _init_engine(self) -> None:
        """
        Create the SQLAlchemy engine and the schema wrapper bound to it.

        Called from ``__init__``, and again from ``__setstate__`` because neither
        the engine nor the schema is pickled. Unless ``lazy_schema_create`` was
        configured, the schema is created and updated immediately. Otherwise
        creation is deferred until the first read of :py:attr:`_schema`.
        """
        _LOG.info("Connect to the database: %s", self)
        self._engine = create_engine(self._url, echo=self._log_sql)
        self._db_schema = DbSchema(self._engine)
        if self._lazy_schema_create:
            _LOG.info("Using lazy schema create for database: %s", self)
        else:
            # update_schema() goes through the ``_schema`` property, which creates
            # the tables first if needed. Do not use ``assert`` to trigger that
            # side effect: assert statements are stripped under ``python -O``.
            self.update_schema()

    # Make the object picklable.

    def __getstate__(self) -> dict:
        """Return the state of the object for pickling."""
        state = self.__dict__.copy()
        # Don't pickle the engine, as it cannot be pickled.
        state.pop("_engine", None)
        state.pop("_db_schema", None)
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore the state of the object from pickling."""
        self.__dict__.update(state)
        # Recreate the engine and schema.
        self._init_engine()

    def dispose(self) -> None:
        """
        Close the database connection pool.

        Does nothing if the engine was never assigned. ``_engine`` is only
        declared, not initialized, in ``__init__``, so this can happen when
        engine creation did not complete.
        """
        engine: Engine | None = getattr(self, "_engine", None)
        if engine is not None:
            engine.dispose()
            _LOG.info("Closed the database connection: %s", self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,  # pylint: disable=unused-argument
        exc_val: BaseException | None,  # pylint: disable=unused-argument
        exc_tb: TracebackType | None,  # pylint: disable=unused-argument
    ) -> Literal[False]:
        """Close the engine connection when exiting the context.

        Always returns ``False``, so exceptions raised in the context propagate.
        """
        self.dispose()
        return False

    @property
    def _schema(self) -> DbSchema:
        """
        Return the schema, creating it on first access.

        ``DbSchema.create()`` runs only the first time this property is read,
        or the first time after :py:meth:`_reset_schema`.
        """
        if not self._schema_created:
            self._db_schema.create()
            self._schema_created = True
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def _reset_schema(self, *, force: bool = False) -> None:
        """
        Helper method used in testing to reset the DB schema.

        Notes
        -----
        This method is not intended for production use, as it will drop all tables
        in the database. Use with caution.

        Parameters
        ----------
        force : bool
            If True, drop all tables in the target database.
            If False, this method will not drop any tables and will log a warning.
        """
        assert self._engine
        if force:
            self._schema.drop_all_tables(force=force)
            # Start over with a fresh schema object and clear both flags so the
            # next access of ``_schema`` and ``update_schema()`` redo their work.
            self._db_schema = DbSchema(self._engine)
            self._schema_created = False
            self._schema_updated = False
        else:
            _LOG.warning(
                "Resetting the schema without force is not implemented. "
                "Use force=True to drop all tables."
            )

    def update_schema(self) -> None:
        """
        Update the database schema.

        Calls ``DbSchema.update()`` at most once per schema object. The
        ``_schema`` property creates the schema first if it does not exist yet.
        """
        if not self._schema_updated:
            self._schema.update()
            self._schema_updated = True

    def __repr__(self) -> str:
        """Return ``"<backend>:<database>"`` for the connection URL."""
        return self._repr

    def get_experiment_by_id(
        self,
        experiment_id: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment | None:
        """
        Look up an existing experiment by its id.

        Parameters
        ----------
        experiment_id : str
            Id of the experiment to load.
        tunables : TunableGroups
            Tunables to attach to the returned experiment.
        opt_targets : dict[str, Literal["min", "max"]]
            Optimization targets and their directions.

        Returns
        -------
        Storage.Experiment | None
            An experiment object built from the stored row, or ``None`` if no
            experiment with that id exists.
        """
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().where(
                    self._schema.experiment.c.exp_id == experiment_id,
                )
            )
            exp = cur_exp.fetchone()
            if exp is None:
                return None
            return Experiment(
                engine=self._engine,
                schema=self._schema,
                experiment_id=exp.exp_id,
                trial_id=-1,  # will be loaded upon __enter__ which calls _setup()
                description=exp.description,
                root_env_config=exp.root_env_config,
                tunables=tunables,
                opt_targets=opt_targets,
            )

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """
        Build an experiment object bound to this storage's engine and schema.

        All keyword arguments are passed through unchanged to the SQL
        :py:class:`Experiment` constructor. This method does not query the
        database itself.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> dict[str, ExperimentData]:
        """
        Return read-only data for every stored experiment.

        The dict is keyed by experiment id, and its insertion order follows the
        ascending ``exp_id`` order of the query.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().order_by(
                    self._schema.experiment.c.exp_id.asc(),
                )
            )
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }