Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Common SQL methods for accessing the stored benchmark data.""" 

6from typing import Dict, Optional 

7 

8import pandas 

9from sqlalchemy import Integer, and_, func, select 

10from sqlalchemy.engine import Engine 

11 

12from mlos_bench.environments.status import Status 

13from mlos_bench.storage.base_experiment_data import ExperimentData 

14from mlos_bench.storage.base_trial_data import TrialData 

15from mlos_bench.storage.sql.schema import DbSchema 

16from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp 

17 

18 

def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
    """
    Fetch the trials of an experiment as TrialData objects keyed by trial_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection.
    schema : DbSchema
        Storage schema providing the ``trial`` table.
    experiment_id : str
        Experiment whose trials to fetch.
    tunable_config_id : Optional[int]
        If given, only trials using this tunable config are returned.

    Returns
    -------
    Dict[int, TrialData]
        Mapping from trial_id to a TrialSqlData instance, inserted in ascending
        trial_id order. Timestamps are normalized via the utcify helpers.
    """
    # Imported lazily to avoid a cyclic import between the storage modules.
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    trial_table = schema.trial

    # Collect every filter first; the optional tunable config restriction is
    # ANDed with the experiment filter.
    filters = [trial_table.c.exp_id == experiment_id]
    if tunable_config_id is not None:
        filters.append(trial_table.c.config_id == tunable_config_id)

    query = (
        trial_table.select()
        .where(*filters)
        .order_by(
            trial_table.c.exp_id.asc(),
            trial_table.c.trial_id.asc(),
        )
    )

    with engine.connect() as conn:
        trials_by_id: Dict[int, TrialData] = {}
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                status=Status[row.status],
            )
        return trials_by_id

65 

66 

def _coerce_numeric_columns(df: pandas.DataFrame) -> pandas.DataFrame:
    """
    Convert the values of ``df`` to numbers where possible.

    Each column is passed through :func:`pandas.to_numeric` with
    ``errors="coerce"``. Any value that could not be converted (and so became NaN)
    is then restored from ``df``, so non-numeric values such as categorical
    strings are preserved as-is.
    """
    return df.apply(pandas.to_numeric, errors="coerce").fillna(df)


def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
    """
    Gets a wide-format DataFrame of the trials, their tunable configs, and their
    results for the given experiment_id, optionally additionally restricted by
    tunable_config_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection.
    schema : DbSchema
        Storage schema providing the ``trial``, ``config_param`` and
        ``trial_result`` tables.
    experiment_id : str
        Experiment whose trials to fetch.
    tunable_config_id : Optional[int]
        If given, only trials using this tunable config are returned.

    Returns
    -------
    pandas.DataFrame
        One row per trial, ordered by trial_id, with the columns ``trial_id``,
        ``ts_start``, ``ts_end``, ``tunable_config_id``,
        ``tunable_config_trial_group_id`` and ``status``. These are followed by
        one ``ExperimentData.CONFIG_COLUMN_PREFIX``-prefixed column per tunable
        parameter and one ``ExperimentData.RESULT_COLUMN_PREFIX``-prefixed column
        per result metric. Config and result values are converted to numbers
        where possible. Trials without config parameters or results get NaN in
        those columns.
    """
    # pylint: disable=too-many-locals
    with engine.connect() as conn:
        # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config:
        # the smallest trial_id among this experiment's trials that share the same config_id.
        tunable_config_group_id_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
                func.min(schema.trial.c.trial_id)
                .cast(Integer)
                .label("tunable_config_trial_group_id"),
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .group_by(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            tunable_config_group_id_stmt = tunable_config_group_id_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()

        # Get each trial's metadata.
        cur_trials_stmt = (
            select(
                schema.trial,
                tunable_config_trial_group_id_subquery,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
                and_(
                    tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
                    tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
                ),
            )
            .order_by(
                schema.trial.c.exp_id.asc(),
                schema.trial.c.trial_id.asc(),
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            cur_trials_stmt = cur_trials_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        cur_trials = conn.execute(cur_trials_stmt)
        trials_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    utcify_timestamp(row.ts_start, origin="utc"),
                    utcify_nullable_timestamp(row.ts_end, origin="utc"),
                    row.config_id,
                    row.tunable_config_trial_group_id,
                    row.status,
                )
                for row in cur_trials.fetchall()
            ],
            columns=[
                "trial_id",
                "ts_start",
                "ts_end",
                "tunable_config_id",
                "tunable_config_trial_group_id",
                "status",
            ],
        )

        # Get each trial's config in wide format.
        configs_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.trial_id,
                schema.trial.c.config_id,
                schema.config_param.c.param_id,
                schema.config_param.c.param_value,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .join(
                schema.config_param,
                schema.config_param.c.config_id == schema.trial.c.config_id,
                isouter=True,
            )
            .order_by(
                schema.trial.c.trial_id,
                schema.config_param.c.param_id,
            )
        )
        if tunable_config_id is not None:
            configs_stmt = configs_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        configs = conn.execute(configs_stmt)
        configs_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    row.config_id,
                    ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
                    row.param_value,
                )
                for row in configs.fetchall()
                # The outer join yields a row with NULL param_id for a config that has
                # no parameters. Skip it instead of failing on the string concatenation;
                # the trial itself is still kept by the left merge below.
                if row.param_id is not None
            ],
            columns=["trial_id", "tunable_config_id", "param", "value"],
        ).pivot(
            index=["trial_id", "tunable_config_id"],
            columns="param",
            values="value",
        )
        configs_df = _coerce_numeric_columns(configs_df)

        # Get each trial's results in wide format.
        results_stmt = (
            schema.trial_result.select()
            .with_only_columns(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
                schema.trial_result.c.metric_value,
            )
            .where(
                schema.trial_result.c.exp_id == experiment_id,
            )
            .order_by(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
            )
        )
        if tunable_config_id is not None:
            # Results carry no config_id themselves, so filter via their trial.
            results_stmt = results_stmt.join(
                schema.trial,
                and_(
                    schema.trial.c.exp_id == schema.trial_result.c.exp_id,
                    schema.trial.c.trial_id == schema.trial_result.c.trial_id,
                    schema.trial.c.config_id == tunable_config_id,
                ),
            )
        results = conn.execute(results_stmt)
        results_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
                    row.metric_value,
                )
                for row in results.fetchall()
            ],
            columns=["trial_id", "metric", "value"],
        ).pivot(
            index="trial_id",
            columns="metric",
            values="value",
        )
        results_df = _coerce_numeric_columns(results_df)

        # Concat the trials, configs, and results. Left merges keep every trial, even
        # those without config parameters or results. The merge keys are index levels
        # on the pivoted frames and regular columns on trials_df.
        return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge(
            results_df,
            on="trial_id",
            how="left",
        )