Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%
41 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Common SQL methods for accessing the stored benchmark data."""
6from typing import Dict, Optional
8import pandas
9from sqlalchemy import Integer, and_, func, select
10from sqlalchemy.engine import Engine
12from mlos_bench.environments.status import Status
13from mlos_bench.storage.base_experiment_data import ExperimentData
14from mlos_bench.storage.base_trial_data import TrialData
15from mlos_bench.storage.sql.schema import DbSchema
16from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp
def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
    """
    Look up the trials of an experiment, keyed by their trial_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection and handed to each TrialSqlData.
    schema : DbSchema
        Schema object providing the ``trial`` table.
    experiment_id : str
        Only trials whose ``exp_id`` equals this value are returned.
    tunable_config_id : Optional[int]
        When not None, only trials whose ``config_id`` equals this value are returned.

    Returns
    -------
    Dict[int, TrialData]
        TrialSqlData objects keyed by trial_id, inserted in ascending trial_id order.
    """
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    trial_table = schema.trial

    # Always filter by experiment; add the tunable config filter only on request.
    criteria = [trial_table.c.exp_id == experiment_id]
    if tunable_config_id is not None:
        criteria.append(trial_table.c.config_id == tunable_config_id)

    # Build up a SQL statement for fetching the trials.
    query = (
        trial_table.select()
        .where(*criteria)
        .order_by(
            trial_table.c.exp_id.asc(),
            trial_table.c.trial_id.asc(),
        )
    )

    trials_by_id: Dict[int, TrialData] = {}
    with engine.connect() as conn:
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                # Timestamps are passed through the utcify helpers with origin="utc".
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The stored status is the name of a Status enum member.
                status=Status[row.status],
            )
    return trials_by_id
def _coerce_numeric_columns(frame: pandas.DataFrame) -> pandas.DataFrame:
    """Convert values to numbers where possible and keep the original value otherwise."""
    # errors="coerce" turns unparsable values into NaN; fillna() then restores the
    # original (non-numeric) value in those cells.
    return frame.apply(
        pandas.to_numeric,
        errors="coerce",
    ).fillna(frame)


def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
    """
    Gets the trials of the given experiment together with their configs and results
    as a single wide-format DataFrame, optionally additionally restricted by
    tunable_config_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the (single) connection for all queries.
    schema : DbSchema
        Schema object providing the ``trial``, ``config_param`` and ``trial_result``
        tables.
    experiment_id : str
        Only data whose ``exp_id`` equals this value is returned.
    tunable_config_id : Optional[int]
        When not None, only trials whose ``config_id`` equals this value are returned.

    Returns
    -------
    pandas.DataFrame
        The trial metadata columns ``trial_id``, ``ts_start``, ``ts_end``,
        ``tunable_config_id``, ``tunable_config_trial_group_id`` and ``status``,
        left-merged with one column per config parameter (named
        ``ExperimentData.CONFIG_COLUMN_PREFIX + param_id``) and one column per result
        metric (named ``ExperimentData.RESULT_COLUMN_PREFIX + metric_id``).
        Config and result values are converted to numbers where possible.
    """
    # pylint: disable=too-many-locals
    with engine.connect() as conn:
        # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable
        # config: the smallest trial_id that used that config within the experiment.
        tunable_config_group_id_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
                func.min(schema.trial.c.trial_id)
                .cast(Integer)
                .label("tunable_config_trial_group_id"),
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .group_by(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            tunable_config_group_id_stmt = tunable_config_group_id_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()

        # Get each trial's metadata.
        cur_trials_stmt = (
            select(
                schema.trial,
                tunable_config_trial_group_id_subquery,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
                and_(
                    tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
                    tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
                ),
            )
            .order_by(
                schema.trial.c.exp_id.asc(),
                schema.trial.c.trial_id.asc(),
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            cur_trials_stmt = cur_trials_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        cur_trials = conn.execute(cur_trials_stmt)
        trials_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    utcify_timestamp(row.ts_start, origin="utc"),
                    utcify_nullable_timestamp(row.ts_end, origin="utc"),
                    row.config_id,
                    row.tunable_config_trial_group_id,
                    row.status,
                )
                for row in cur_trials.fetchall()
            ],
            columns=[
                "trial_id",
                "ts_start",
                "ts_end",
                "tunable_config_id",
                "tunable_config_trial_group_id",
                "status",
            ],
        )

        # Get each trial's config in wide format.
        configs_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.trial_id,
                schema.trial.c.config_id,
                schema.config_param.c.param_id,
                schema.config_param.c.param_value,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .join(
                schema.config_param,
                schema.config_param.c.config_id == schema.trial.c.config_id,
                isouter=True,
            )
            .order_by(
                schema.trial.c.trial_id,
                schema.config_param.c.param_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            configs_stmt = configs_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        configs = conn.execute(configs_stmt)
        configs_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    row.config_id,
                    ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
                    row.param_value,
                )
                for row in configs.fetchall()
                # The outer join yields a NULL param_id for a trial whose config has no
                # rows in config_param; prefixing None would raise a TypeError.
                # Such trials are still part of the output via the left merge below.
                if row.param_id is not None
            ],
            columns=["trial_id", "tunable_config_id", "param", "value"],
        ).pivot(
            index=["trial_id", "tunable_config_id"],
            columns="param",
            values="value",
        )
        configs_df = _coerce_numeric_columns(configs_df)

        # Get each trial's results in wide format.
        results_stmt = (
            schema.trial_result.select()
            .with_only_columns(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
                schema.trial_result.c.metric_value,
            )
            .where(
                schema.trial_result.c.exp_id == experiment_id,
            )
            .order_by(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        # The results table is joined to the trial table to filter on config_id.
        if tunable_config_id is not None:
            results_stmt = results_stmt.join(
                schema.trial,
                and_(
                    schema.trial.c.exp_id == schema.trial_result.c.exp_id,
                    schema.trial.c.trial_id == schema.trial_result.c.trial_id,
                    schema.trial.c.config_id == tunable_config_id,
                ),
            )
        results = conn.execute(results_stmt)
        results_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
                    row.metric_value,
                )
                for row in results.fetchall()
            ],
            columns=["trial_id", "metric", "value"],
        ).pivot(
            index="trial_id",
            columns="metric",
            values="value",
        )
        results_df = _coerce_numeric_columns(results_df)

        # Concat the trials, configs, and results.
        # Left merges keep every trial, even one without config params or results.
        return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge(
            results_df,
            on="trial_id",
            how="left",
        )