Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Common SQL methods for accessing the stored benchmark data.""" 

6from typing import Dict, Optional 

7 

8import pandas 

9from sqlalchemy import Integer, and_, func, select 

10from sqlalchemy.engine import Engine 

11 

12from mlos_bench.environments.status import Status 

13from mlos_bench.storage.base_experiment_data import ExperimentData 

14from mlos_bench.storage.base_trial_data import TrialData 

15from mlos_bench.storage.sql.schema import DbSchema 

16from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp 

17 

18 

def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
    """
    Load the trials of an experiment as :py:class:`~.TrialData` objects keyed by
    their trial id.

    Parameters
    ----------
    engine : sqlalchemy.engine.Engine
        Engine used to open the database connection; it is also handed to each
        returned trial object.
    schema : DbSchema
        Schema object providing the ``trial`` table definition.
    experiment_id : str
        Only trials whose ``exp_id`` equals this value are returned.
    tunable_config_id : Optional[int]
        When not None, additionally keep only the trials whose ``config_id``
        equals this value.

    Returns
    -------
    trials : Dict[int, TrialData]
        Mapping of ``trial_id`` to its trial data object, inserted in ascending
        ``trial_id`` order (the order in which the query returns the rows).

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    trial_table = schema.trial
    with engine.connect() as conn:
        # Build up a SQL statement for fetching the experiment's trials.
        query = (
            trial_table.select()
            .where(trial_table.c.exp_id == experiment_id)
            .order_by(
                trial_table.c.exp_id.asc(),
                trial_table.c.trial_id.asc(),
            )
        )
        # Optionally narrow the selection to a single tunable config.
        if tunable_config_id is not None:
            query = query.where(trial_table.c.config_id == tunable_config_id)

        rows = conn.execute(query).fetchall()

        trials_by_id: Dict[int, TrialData] = {}
        for row in rows:
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                # Stored timestamps are passed through the utcify helpers with a
                # "utc" origin; ts_end goes through the nullable variant.
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The stored status is looked up by name in the Status enum.
                status=Status[row.status],
            )
        return trials_by_id

68 

69 

def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
    """
    Build a wide-format DataFrame of the trials for the given ``experiment_id``,
    optionally restricted to a single ``tunable_config_id``.

    The frame starts from each trial's metadata (``trial_id``, ``ts_start``,
    ``ts_end``, ``tunable_config_id``, ``tunable_config_trial_group_id``,
    ``status``) and left-joins onto it the trial's config parameters and results,
    one column per parameter / metric. Config columns are prefixed with
    :py:attr:`.ExperimentData.CONFIG_COLUMN_PREFIX` and result columns with
    :py:attr:`.ExperimentData.RESULT_COLUMN_PREFIX`.

    Config and result values are converted with :py:func:`pandas.to_numeric`
    where possible; values that cannot be converted are kept as stored.

    Parameters
    ----------
    engine : sqlalchemy.engine.Engine
        Engine used to open the database connection.
    schema : DbSchema
        Schema object providing the ``trial``, ``config_param`` and
        ``trial_result`` table definitions.
    experiment_id : str
        Only data whose ``exp_id`` equals this value is included.
    tunable_config_id : Optional[int]
        When not None, only trials whose ``config_id`` equals this value are
        included.

    Returns
    -------
    results : pandas.DataFrame
        Trial metadata merged with the pivoted configs and results.

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # pylint: disable=too-many-locals
    trial_tbl = schema.trial
    param_tbl = schema.config_param
    result_tbl = schema.trial_result

    with engine.connect() as conn:
        # Subquery computing, per (exp_id, config_id), the smallest trial_id:
        # that value serves as the tunable_config_trial_group_id.
        group_id_query = (
            trial_tbl.select()
            .with_only_columns(
                trial_tbl.c.exp_id,
                trial_tbl.c.config_id,
                func.min(trial_tbl.c.trial_id)
                .cast(Integer)
                .label("tunable_config_trial_group_id"),
            )
            .where(trial_tbl.c.exp_id == experiment_id)
            .group_by(
                trial_tbl.c.exp_id,
                trial_tbl.c.config_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            group_id_query = group_id_query.where(
                trial_tbl.c.config_id == tunable_config_id,
            )
        group_ids = group_id_query.subquery()

        # Trial metadata, matched with its group id via (exp_id, config_id).
        trials_query = (
            select(trial_tbl, group_ids)
            .where(
                trial_tbl.c.exp_id == experiment_id,
                and_(
                    group_ids.c.exp_id == trial_tbl.c.exp_id,
                    group_ids.c.config_id == trial_tbl.c.config_id,
                ),
            )
            .order_by(
                trial_tbl.c.exp_id.asc(),
                trial_tbl.c.trial_id.asc(),
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            trials_query = trials_query.where(
                trial_tbl.c.config_id == tunable_config_id,
            )
        trial_records = [
            (
                row.trial_id,
                utcify_timestamp(row.ts_start, origin="utc"),
                utcify_nullable_timestamp(row.ts_end, origin="utc"),
                row.config_id,
                row.tunable_config_trial_group_id,
                row.status,
            )
            for row in conn.execute(trials_query).fetchall()
        ]
        trials_df = pandas.DataFrame(
            trial_records,
            columns=[
                "trial_id",
                "ts_start",
                "ts_end",
                "tunable_config_id",
                "tunable_config_trial_group_id",
                "status",
            ],
        )

        # Config parameters of each trial: fetched in long format
        # (one row per trial/parameter), then pivoted to wide format.
        configs_query = (
            trial_tbl.select()
            .with_only_columns(
                trial_tbl.c.trial_id,
                trial_tbl.c.config_id,
                param_tbl.c.param_id,
                param_tbl.c.param_value,
            )
            .where(trial_tbl.c.exp_id == experiment_id)
            .join(
                param_tbl,
                param_tbl.c.config_id == trial_tbl.c.config_id,
            )
            .order_by(
                trial_tbl.c.trial_id,
                param_tbl.c.param_id,
            )
        )
        if tunable_config_id is not None:
            configs_query = configs_query.where(
                trial_tbl.c.config_id == tunable_config_id,
            )
        config_records = [
            (
                row.trial_id,
                row.config_id,
                ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
                row.param_value,
            )
            for row in conn.execute(configs_query).fetchall()
        ]
        configs_long = pandas.DataFrame(
            config_records,
            columns=["trial_id", "tunable_config_id", "param", "value"],
        )
        configs_wide = configs_long.pivot(
            index=["trial_id", "tunable_config_id"],
            columns="param",
            values="value",
        )
        # Convert to numbers where possible; entries that fail to convert come
        # back as NaN and are filled back in from the unconverted frame.
        configs_df = configs_wide.apply(
            pandas.to_numeric,
            errors="coerce",
        ).fillna(configs_wide)

        # Results of each trial: same long-to-wide treatment as the configs.
        results_query = (
            result_tbl.select()
            .with_only_columns(
                result_tbl.c.trial_id,
                result_tbl.c.metric_id,
                result_tbl.c.metric_value,
            )
            .where(result_tbl.c.exp_id == experiment_id)
            .order_by(
                result_tbl.c.trial_id,
                result_tbl.c.metric_id,
            )
        )
        if tunable_config_id is not None:
            # The result table is not filtered on config_id directly here;
            # the restriction is applied by joining against the trial table.
            results_query = results_query.join(
                trial_tbl,
                and_(
                    trial_tbl.c.exp_id == result_tbl.c.exp_id,
                    trial_tbl.c.trial_id == result_tbl.c.trial_id,
                    trial_tbl.c.config_id == tunable_config_id,
                ),
            )
        result_records = [
            (
                row.trial_id,
                ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
                row.metric_value,
            )
            for row in conn.execute(results_query).fetchall()
        ]
        results_long = pandas.DataFrame(
            result_records,
            columns=["trial_id", "metric", "value"],
        )
        results_wide = results_long.pivot(
            index="trial_id",
            columns="metric",
            values="value",
        )
        results_df = results_wide.apply(
            pandas.to_numeric,
            errors="coerce",
        ).fillna(results_wide)

        # Combine: trial metadata <- configs <- results, keeping every trial row.
        with_configs = trials_df.merge(
            configs_df,
            on=["trial_id", "tunable_config_id"],
            how="left",
        )
        return with_configs.merge(
            results_df,
            on="trial_id",
            how="left",
        )