Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 95%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6DB schema definition for the :py:class:`~mlos_bench.storage.sql.storage.SqlStorage` 

7backend. 

8 

9Notes 

10----- 

11The SQL statements are generated by SQLAlchemy, but can be obtained using 

12``repr`` or ``str`` (e.g., via ``print()``) on a :py:class:`.DbSchema` instance. 

13The ``mlos_bench`` CLI will do this automatically if the logging level is set to 

14``DEBUG``. 

15""" 

16 

import logging
from typing import Any, List

from sqlalchemy import (
    Column,
    DateTime,
    Dialect,
    Float,
    ForeignKeyConstraint,
    Integer,
    MetaData,
    PrimaryKeyConstraint,
    Sequence,
    String,
    Table,
    UniqueConstraint,
    create_mock_engine,
    func,
)
from sqlalchemy.engine import Engine

36 

# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)

38 

39 

class _DDL:
    """
    Collect the DDL statements that SQLAlchemy emits through a mock engine.

    An instance is meant to be passed as the ``executor`` of
    :py:func:`sqlalchemy.create_mock_engine`: each statement it receives is
    compiled for the configured dialect and kept as a string. The collected
    statements are rendered by ``repr()``. It is used in the
    ``DbSchema.__repr__()`` method below.
    """

    def __init__(self, dialect: Dialect):
        self._dialect = dialect
        # Compiled SQL text, in the order the statements were received.
        self.statements: List[str] = []

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        # Extra positional/keyword arguments from the mock engine are ignored.
        compiled = sql.compile(dialect=self._dialect)
        self.statements.append(str(compiled))

    def __repr__(self) -> str:
        joined = ";\n".join(self.statements)
        # Terminate the last statement too, unless there is nothing to show.
        if not joined:
            return ""
        return f"{joined};"

57 

58 

class DbSchema:
    """
    A class to define and create the DB schema.

    All tables are declared on a private :py:class:`sqlalchemy.MetaData`
    instance and exposed as public :py:class:`sqlalchemy.Table` attributes
    (``experiment``, ``objectives``, ``config``, ``trial``, ``config_param``,
    ``trial_param``, ``trial_status``, ``trial_result``, ``trial_telemetry``).
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        No DDL is executed here; call :py:meth:`create` to create the tables.

        Parameters
        ----------
        engine : sqlalchemy.engine.Engine
            Engine the schema is created on. Its dialect name also selects
            how the auto-incrementing ``config.config_id`` column is declared.
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        # TODO: bind for automatic schema updates? (#649)
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),
            PrimaryKeyConstraint("exp_id"),
        )

        self.objectives = Table(
            "objectives",
            self._meta,
            # Typed explicitly (instead of relying on inference from the
            # foreign key) to match the `exp_id` column of every other table.
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )

        self.config = Table(
            "config",
            self._meta,
            self._config_id_column(engine.dialect.name),
            Column("config_hash", String(64), nullable=False, unique=True),
        )

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # SQL-expression default rendered into INSERTs that omit `ts`
            # (not part of the DDL). A plain "now" string is not a valid
            # DateTime bind value on backends such as SQLite.
            Column("ts", DateTime(timezone=True), nullable=False, default=func.now()),
            Column("status", String(self._STATUS_LEN), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # Same portable SQL-expression default as `trial_status.ts`.
            Column("ts", DateTime(timezone=True), nullable=False, default=func.now()),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        # Log the declared table names; the MetaData object's own repr does
        # not list its tables. Full DDL is available via `repr(self)`.
        _LOG.debug("Schema tables: %s", list(self._meta.tables))

    @staticmethod
    def _config_id_column(dialect_name: str) -> Column:
        """
        Build the auto-incrementing ``config_id`` primary key of the ``config`` table.

        Parameters
        ----------
        dialect_name : str
            Name of the SQLAlchemy dialect in use (e.g., ``"duckdb"``).

        Returns
        -------
        column : sqlalchemy.Column
            The ``config_id`` integer primary key column.
        """
        if dialect_name == "duckdb":
            # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
            # use an explicit sequence as the server-side default.
            seq_config_id = Sequence("seq_config_id")
            return Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        return Column(
            "config_id",
            Integer,
            nullable=False,
            primary_key=True,
            autoincrement=True,
        )

    def create(self) -> "DbSchema":
        """
        Create the DB schema.

        Returns
        -------
        self : DbSchema
            This object, to allow call chaining.
        """
        _LOG.info("Create the DB schema")
        self._meta.create_all(self._engine)
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and if the logging level is set to DEBUG.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        ddl = _DDL(self._engine.dialect)
        # The mock engine routes each DDL statement to `ddl` instead of a database.
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        # checkfirst=False: emit every CREATE, without probing for existing tables.
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)