Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 95%

38 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6DB schema definition. 

7""" 

8 

import logging
from datetime import datetime, timezone
from typing import List, Any

from sqlalchemy import (
    Engine, MetaData, Dialect, create_mock_engine,
    Table, Column, Sequence, Integer, Float, String, DateTime,
    PrimaryKeyConstraint, ForeignKeyConstraint, UniqueConstraint,
)

17 

18_LOG = logging.getLogger(__name__) 

19 

20 

class _DDL:
    """
    Collector of the DDL statements emitted by SQLAlchemy.

    An instance is passed as the `executor` callback of the mock engine
    created in `DbSchema.__repr__()` below; every statement handed to the
    callback is compiled for the given dialect and stored as text.
    """

    def __init__(self, dialect: Dialect):
        # Compiled statements, in the order they were received.
        self.statements: List[str] = []
        # Dialect used to render each statement into SQL text.
        self._dialect = dialect

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        # Extra positional/keyword arguments are accepted and ignored.
        compiled = sql.compile(dialect=self._dialect)
        self.statements.append(str(compiled))

    def __repr__(self) -> str:
        # Statements are separated by ";\n", with one trailing ";" at the end.
        # An empty result stays empty (no lone ";").
        joined = ";\n".join(self.statements)
        if not joined:
            return ""
        return f"{joined};"

38 

39 

class DbSchema:
    """
    A class to define and create the DB schema.

    All tables are declared in the constructor and exposed as instance
    attributes (`experiment`, `objectives`, `config`, `trial`, `config_param`,
    `trial_param`, `trial_status`, `trial_result`, `trial_telemetry`).
    Nothing is written to the database until `create()` is called.
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    @staticmethod
    def _utc_now() -> datetime:
        """
        Client-side default for the `ts` columns.

        Returns
        -------
        ts : datetime
            Current time as a timezone-aware datetime in UTC.
        """
        return datetime.now(timezone.utc)

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        Parameters
        ----------
        engine : sqlalchemy.Engine
            Engine of the target database. It is stored for later use in
            `create()` and `__repr__()`; its dialect name is also checked
            to select the DuckDB-specific definition of `config.config_id`.
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        # TODO: bind for automatic schema updates? (#649)
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),

            PrimaryKeyConstraint("exp_id"),
        )

        self.objectives = Table(
            "objectives",
            self._meta,
            # Declared explicitly (rather than inferred from the foreign key)
            # to match the `exp_id` columns of all other tables.
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),

            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        if engine.dialect.name == "duckdb":
            seq_config_id = Sequence('seq_config_id')
            col_config_id = Column("config_id", Integer, seq_config_id,
                                   server_default=seq_config_id.next_value(),
                                   nullable=False, primary_key=True)
        else:
            col_config_id = Column("config_id", Integer, nullable=False,
                                   primary_key=True, autoincrement=True)

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),

            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),

            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),

            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # The default must be a callable that yields a datetime: a plain
            # string default would be bound to the DateTime column as-is.
            Column("ts", DateTime(timezone=True), nullable=False, default=self._utc_now),
            Column("status", String(self._STATUS_LEN), nullable=False),

            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),

            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # Same callable default as in `trial_status.ts` (see the note there).
            Column("ts", DateTime(timezone=True), nullable=False, default=self._utc_now),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),

            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(["exp_id", "trial_id"],
                                 [self.trial.c.exp_id, self.trial.c.trial_id]),
        )

        _LOG.debug("Schema: %s", self._meta)

    def create(self) -> 'DbSchema':
        """
        Create the DB schema.

        Returns
        -------
        schema : DbSchema
            This object, to allow call chaining.
        """
        _LOG.info("Create the DB schema")
        self._meta.create_all(self._engine)
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema
        from scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and if the logging level is set to DEBUG.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        ddl = _DDL(self._engine.dialect)
        # The mock engine does not execute anything: it hands every DDL
        # statement to `ddl`, which stores its compiled text.
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        # checkfirst=False: emit all statements without checking for existing tables.
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)