Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%
41 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Common SQL methods for accessing the stored benchmark data.
7"""
8from typing import Dict, Optional
10import pandas
11from sqlalchemy import Engine, Integer, func, and_, select
13from mlos_bench.environments.status import Status
14from mlos_bench.storage.base_experiment_data import ExperimentData
15from mlos_bench.storage.base_trial_data import TrialData
16from mlos_bench.storage.sql.schema import DbSchema
17from mlos_bench.util import utcify_timestamp, utcify_nullable_timestamp
def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
    """
    Load the trials of one experiment as TrialData objects keyed by trial ID.

    Shared by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection for the query.
    schema : DbSchema
        Schema object providing the `trial` table.
    experiment_id : str
        Only trials whose `exp_id` equals this value are returned.
    tunable_config_id : Optional[int]
        When not None, additionally keep only the trials whose `config_id`
        equals this value.

    Returns
    -------
    trials : Dict[int, TrialData]
        TrialSqlData instances keyed by `trial_id`, inserted in ascending
        (`exp_id`, `trial_id`) order.
    """
    # Deferred import; presumably avoids an import cycle with trial_data (see the pylint tags).
    from mlos_bench.storage.sql.trial_data import TrialSqlData  # pylint: disable=import-outside-toplevel,cyclic-import

    trial_table = schema.trial

    # Collect the row filters first: always the experiment, optionally the config.
    filters = [trial_table.c.exp_id == experiment_id]
    if tunable_config_id is not None:
        filters.append(trial_table.c.config_id == tunable_config_id)

    query = trial_table.select().where(*filters).order_by(
        trial_table.c.exp_id.asc(),
        trial_table.c.trial_id.asc(),
    )

    trials_by_id: Dict[int, TrialData] = {}
    with engine.connect() as conn:
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                # ts_end goes through the nullable variant of the helper.
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The status column is used as a member name of the Status enum.
                status=Status[row.status],
            )
    return trials_by_id
def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
    """
    Gets a wide-format DataFrame of the trials, their configs, and their results
    for the given experiment, optionally additionally restricted by tunable_config_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection for the queries.
    schema : DbSchema
        Schema object providing the `trial`, `config_param`, and `trial_result` tables.
    experiment_id : str
        Only data whose `exp_id` equals this value is returned.
    tunable_config_id : Optional[int]
        When not None, additionally keep only the trials whose `config_id`
        equals this value.

    Returns
    -------
    results : pandas.DataFrame
        One row per trial with the columns `trial_id`, `ts_start`, `ts_end`,
        `tunable_config_id`, `tunable_config_trial_group_id`, and `status`,
        followed by one column per config parameter (named with
        ExperimentData.CONFIG_COLUMN_PREFIX) and one column per result metric
        (named with ExperimentData.RESULT_COLUMN_PREFIX).
        Config and result values are converted to numbers where possible and
        otherwise keep their stored value.
    """
    # pylint: disable=too-many-locals
    with engine.connect() as conn:
        # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config.
        # The group id is the smallest trial_id that used a given (exp_id, config_id).
        tunable_config_group_id_stmt = schema.trial.select().with_only_columns(
            schema.trial.c.exp_id,
            schema.trial.c.config_id,
            func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'),
        ).where(
            schema.trial.c.exp_id == experiment_id,
        ).group_by(
            schema.trial.c.exp_id,
            schema.trial.c.config_id,
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            tunable_config_group_id_stmt = tunable_config_group_id_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()

        # Get each trial's metadata, matched to its config's trial group id.
        cur_trials_stmt = select(
            schema.trial,
            tunable_config_trial_group_id_subquery,
        ).where(
            schema.trial.c.exp_id == experiment_id,
            and_(
                tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
                tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
            ),
        ).order_by(
            schema.trial.c.exp_id.asc(),
            schema.trial.c.trial_id.asc(),
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            cur_trials_stmt = cur_trials_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        cur_trials = conn.execute(cur_trials_stmt)
        trials_df = pandas.DataFrame(
            [(
                row.trial_id,
                utcify_timestamp(row.ts_start, origin="utc"),
                utcify_nullable_timestamp(row.ts_end, origin="utc"),
                row.config_id,
                row.tunable_config_trial_group_id,
                row.status,
            ) for row in cur_trials.fetchall()],
            columns=[
                'trial_id',
                'ts_start',
                'ts_end',
                'tunable_config_id',
                'tunable_config_trial_group_id',
                'status',
            ]
        )

        # Get each trial's config in wide format.
        configs_stmt = schema.trial.select().with_only_columns(
            schema.trial.c.trial_id,
            schema.trial.c.config_id,
            schema.config_param.c.param_id,
            schema.config_param.c.param_value,
        ).where(
            schema.trial.c.exp_id == experiment_id,
        ).join(
            schema.config_param,
            schema.config_param.c.config_id == schema.trial.c.config_id,
            isouter=True
        ).order_by(
            schema.trial.c.trial_id,
            schema.config_param.c.param_id,
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            configs_stmt = configs_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        configs = conn.execute(configs_stmt)
        configs_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    row.config_id,
                    ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
                    row.param_value,
                )
                for row in configs.fetchall()
                # The outer join emits a placeholder row with a NULL param_id for a
                # trial whose config has no parameters.  Skip it: prefixing None
                # would raise a TypeError, and the trial itself is still kept by
                # the left merge from trials_df below.
                if row.param_id is not None
            ],
            columns=['trial_id', 'tunable_config_id', 'param', 'value']
        ).pivot(
            index=["trial_id", "tunable_config_id"], columns="param", values="value",
        )
        # Convert to numbers where possible; values that fail to parse keep their stored value.
        configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df)    # type: ignore[assignment]  # (fp)

        # Get each trial's results in wide format.
        results_stmt = schema.trial_result.select().with_only_columns(
            schema.trial_result.c.trial_id,
            schema.trial_result.c.metric_id,
            schema.trial_result.c.metric_value,
        ).where(
            schema.trial_result.c.exp_id == experiment_id,
        ).order_by(
            schema.trial_result.c.trial_id,
            schema.trial_result.c.metric_id,
        )
        # The result table is filtered by config through a join to the trial table.
        if tunable_config_id is not None:
            results_stmt = results_stmt.join(schema.trial, and_(
                schema.trial.c.exp_id == schema.trial_result.c.exp_id,
                schema.trial.c.trial_id == schema.trial_result.c.trial_id,
                schema.trial.c.config_id == tunable_config_id,
            ))
        results = conn.execute(results_stmt)
        results_df = pandas.DataFrame(
            [(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value)
             for row in results.fetchall()],
            columns=['trial_id', 'metric', 'value']
        ).pivot(
            index="trial_id", columns="metric", values="value",
        )
        # Same numeric conversion as for the config values above.
        results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df)    # type: ignore[assignment]  # (fp)

        # Concat the trials, configs, and results.
        # Left merges keep every trial even when it has no config params or results.
        return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \
                        .merge(results_df, on="trial_id", how="left")