Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Common SQL methods for accessing the stored benchmark data."""
6from typing import Dict, Optional
8import pandas
9from sqlalchemy import Integer, and_, func, select
10from sqlalchemy.engine import Engine
12from mlos_bench.environments.status import Status
13from mlos_bench.storage.base_experiment_data import ExperimentData
14from mlos_bench.storage.base_trial_data import TrialData
15from mlos_bench.storage.sql.schema import DbSchema
16from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp
def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
    """
    Fetch the trials of an experiment as :py:class:`~.TrialData` objects keyed by
    their trial ID.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection for the query. It is also
        passed on to every trial object that is created.
    schema : DbSchema
        Schema object that provides the ``trial`` table.
    experiment_id : str
        ID of the experiment whose trials are fetched.
    tunable_config_id : Optional[int]
        When not None, only trials whose ``config_id`` equals this value are kept.

    Returns
    -------
    trials : Dict[int, TrialData]
        Mapping of ``trial_id`` to the corresponding trial data object, filled in
        the order the rows are returned by the query (ascending ``trial_id``).

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # Imported locally rather than at module level; judging by the pylint pragma
    # this is meant to sidestep a cyclic import.
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    with engine.connect() as conn:
        trial_tbl = schema.trial

        # Select every trial of the experiment in a stable order.
        query = (
            trial_tbl.select()
            .where(trial_tbl.c.exp_id == experiment_id)
            .order_by(trial_tbl.c.exp_id.asc(), trial_tbl.c.trial_id.asc())
        )
        # Narrow down to a single tunable config when one was requested.
        if tunable_config_id is not None:
            query = query.where(trial_tbl.c.config_id == tunable_config_id)

        trials_by_id: Dict[int, TrialData] = {}
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                # ts_end goes through the nullable variant of the helper.
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The stored status string is looked up by name in the Status enum.
                status=Status[row.status],
            )
        return trials_by_id
def _numeric_where_possible(frame: pandas.DataFrame) -> pandas.DataFrame:
    """Convert values to numbers where they parse as such and keep the rest as is."""
    # Coercion turns every non-numeric value into NaN; fillna then puts the
    # original value back into exactly those cells.
    return frame.apply(pandas.to_numeric, errors="coerce").fillna(frame)


def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
    """
    Build a wide-format DataFrame with one row per trial of the given
    ``experiment_id``, optionally restricted to a single ``tunable_config_id``.

    Each row holds the trial's metadata, its config parameters, and its results.
    Config parameter columns are prefixed with
    :py:attr:`.ExperimentData.CONFIG_COLUMN_PREFIX` and result columns with
    :py:attr:`.ExperimentData.RESULT_COLUMN_PREFIX`.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection for all queries.
    schema : DbSchema
        Schema object that provides the ``trial``, ``config_param`` and
        ``trial_result`` tables.
    experiment_id : str
        ID of the experiment to fetch the data for.
    tunable_config_id : Optional[int]
        When not None, only trials whose ``config_id`` equals this value are kept.

    Returns
    -------
    results : pandas.DataFrame
        Metadata columns ``trial_id``, ``ts_start``, ``ts_end``,
        ``tunable_config_id``, ``tunable_config_trial_group_id`` (the smallest
        ``trial_id`` among the experiment's trials sharing the same config) and
        ``status``, followed by the prefixed config and result columns. Configs
        and results are left-joined onto the trials, so a trial without them is
        still present, with missing values in those columns.

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # pylint: disable=too-many-locals
    with engine.connect() as conn:
        trial_tbl = schema.trial

        # Subquery: for every (experiment, config) pair the lowest trial_id,
        # which serves as the tunable_config_trial_group_id.
        group_id_col = (
            func.min(trial_tbl.c.trial_id).cast(Integer).label("tunable_config_trial_group_id")
        )
        group_query = (
            trial_tbl.select()
            .with_only_columns(trial_tbl.c.exp_id, trial_tbl.c.config_id, group_id_col)
            .where(trial_tbl.c.exp_id == experiment_id)
            .group_by(trial_tbl.c.exp_id, trial_tbl.c.config_id)
        )
        if tunable_config_id is not None:
            group_query = group_query.where(trial_tbl.c.config_id == tunable_config_id)
        groups = group_query.subquery()

        # Trial metadata, matched with its group id via (exp_id, config_id).
        trials_query = (
            select(trial_tbl, groups)
            .where(
                trial_tbl.c.exp_id == experiment_id,
                and_(
                    groups.c.exp_id == trial_tbl.c.exp_id,
                    groups.c.config_id == trial_tbl.c.config_id,
                ),
            )
            .order_by(trial_tbl.c.exp_id.asc(), trial_tbl.c.trial_id.asc())
        )
        if tunable_config_id is not None:
            trials_query = trials_query.where(trial_tbl.c.config_id == tunable_config_id)
        trial_records = [
            (
                rec.trial_id,
                utcify_timestamp(rec.ts_start, origin="utc"),
                utcify_nullable_timestamp(rec.ts_end, origin="utc"),
                rec.config_id,
                rec.tunable_config_trial_group_id,
                rec.status,
            )
            for rec in conn.execute(trials_query).fetchall()
        ]
        trials_df = pandas.DataFrame(
            trial_records,
            columns=[
                "trial_id",
                "ts_start",
                "ts_end",
                "tunable_config_id",
                "tunable_config_trial_group_id",
                "status",
            ],
        )

        # Config parameters: fetched as one row per (trial, param), then pivoted
        # so that every parameter becomes its own prefixed column.
        param_tbl = schema.config_param
        configs_query = (
            trial_tbl.select()
            .with_only_columns(
                trial_tbl.c.trial_id,
                trial_tbl.c.config_id,
                param_tbl.c.param_id,
                param_tbl.c.param_value,
            )
            .where(trial_tbl.c.exp_id == experiment_id)
            .join(param_tbl, param_tbl.c.config_id == trial_tbl.c.config_id)
            .order_by(trial_tbl.c.trial_id, param_tbl.c.param_id)
        )
        if tunable_config_id is not None:
            configs_query = configs_query.where(trial_tbl.c.config_id == tunable_config_id)
        config_records = [
            (
                rec.trial_id,
                rec.config_id,
                ExperimentData.CONFIG_COLUMN_PREFIX + rec.param_id,
                rec.param_value,
            )
            for rec in conn.execute(configs_query).fetchall()
        ]
        configs_long = pandas.DataFrame(
            config_records,
            columns=["trial_id", "tunable_config_id", "param", "value"],
        )
        configs_df = _numeric_where_possible(
            configs_long.pivot(
                index=["trial_id", "tunable_config_id"],
                columns="param",
                values="value",
            )
        )

        # Results: same long-to-wide treatment, one prefixed column per metric.
        result_tbl = schema.trial_result
        results_query = (
            result_tbl.select()
            .with_only_columns(
                result_tbl.c.trial_id,
                result_tbl.c.metric_id,
                result_tbl.c.metric_value,
            )
            .where(result_tbl.c.exp_id == experiment_id)
            .order_by(result_tbl.c.trial_id, result_tbl.c.metric_id)
        )
        if tunable_config_id is not None:
            # The result table is not filtered on a config id directly here; the
            # restriction is applied by joining against the matching trials.
            results_query = results_query.join(
                trial_tbl,
                and_(
                    trial_tbl.c.exp_id == result_tbl.c.exp_id,
                    trial_tbl.c.trial_id == result_tbl.c.trial_id,
                    trial_tbl.c.config_id == tunable_config_id,
                ),
            )
        result_records = [
            (
                rec.trial_id,
                ExperimentData.RESULT_COLUMN_PREFIX + rec.metric_id,
                rec.metric_value,
            )
            for rec in conn.execute(results_query).fetchall()
        ]
        results_long = pandas.DataFrame(
            result_records,
            columns=["trial_id", "metric", "value"],
        )
        results_df = _numeric_where_possible(
            results_long.pivot(index="trial_id", columns="metric", values="value")
        )

        # Attach configs, then results, to the trial metadata. Left joins keep
        # trials that have no config params or no results.
        with_configs = trials_df.merge(
            configs_df,
            on=["trial_id", "tunable_config_id"],
            how="left",
        )
        return with_configs.merge(results_df, on="trial_id", how="left")