Coverage for mlos_bench/mlos_bench/storage/sql/experiment_data.py: 93%
75 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6An interface to access the experiment benchmark data stored in SQL DB.
7"""
8from typing import Dict, Literal, Optional
10import logging
12import pandas
13from sqlalchemy import Engine, Integer, String, func
15from mlos_bench.storage.base_experiment_data import ExperimentData
16from mlos_bench.storage.base_trial_data import TrialData
17from mlos_bench.storage.base_tunable_config_data import TunableConfigData
18from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
19from mlos_bench.storage.sql import common
20from mlos_bench.storage.sql.schema import DbSchema
21from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
22from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData
24_LOG = logging.getLogger(__name__)
class ExperimentSqlData(ExperimentData):
    """
    SQL-backed accessor for the benchmark data recorded for one experiment.

    An experiment is the collection of trials that were executed with one
    particular set of scripts and mlos_bench configuration files.
    """

    def __init__(
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        description: str,
        root_env_config: str,
        git_repo: str,
        git_commit: str,
    ):
        """
        Store the DB handles and forward the experiment metadata to the base class.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections for every query.
        schema : DbSchema
            Schema object exposing the tables queried by this class.
        experiment_id : str
            Forwarded to the base class constructor.
        description : str
            Forwarded to the base class constructor.
        root_env_config : str
            Forwarded to the base class constructor.
        git_repo : str
            Forwarded to the base class constructor.
        git_commit : str
            Forwarded to the base class constructor.
        """
        base_kwargs = {
            "experiment_id": experiment_id,
            "description": description,
            "root_env_config": root_env_config,
            "git_repo": git_repo,
            "git_commit": git_commit,
        }
        super().__init__(**base_kwargs)
        self._engine = engine
        self._schema = schema

    @property
    def objectives(self) -> Dict[str, Literal["min", "max"]]:
        """
        Collect the optimization targets and their directions for this experiment.

        The schema's ``objectives`` table is consulted first (when the schema has
        one); the ``opt_target`` / ``opt_direction`` entries found in each trial's
        metadata are then merged in for backwards compatibility.

        Returns
        -------
        Dict[str, Literal["min", "max"]]
            Mapping of optimization target name to optimization direction.
            A target that only appears in trial metadata without a direction is
            mapped to None.
        """
        merged: Dict[str, Literal["min", "max"]] = {}

        # Preferred source: the experiment-level objectives kept by the storage layer.
        if hasattr(self._schema, "objectives"):
            with self._engine.connect() as conn:
                objectives_table = self._schema.objectives
                stmt = objectives_table.select().where(
                    objectives_table.c.exp_id == self._experiment_id,
                ).order_by(
                    # TODO: return weight as well
                    objectives_table.c.weight.desc(),
                    objectives_table.c.optimization_target.asc(),
                )
                merged = {}
                for row in conn.execute(stmt).fetchall():
                    merged[row.optimization_target] = row.optimization_direction

        # Legacy source: opt_target/opt_direction stored per trial in its metadata.
        # NOTE: That older format has no notion of multiple objectives, and it
        # gives no easy way to notice that a config change flipped the
        # opt_direction of an opt_target between run.py executions of the same
        # Experiment. For the time being, possibly inconsistent data is only
        # reported through a warning.
        for trial in self.trials.values():
            param_filter = trial.metadata_df["parameter"].isin(("opt_target", "opt_direction"))
            trial_objs = trial.metadata_df[param_filter][["parameter", "value"]]

            try:
                target_rows = trial_objs[trial_objs["parameter"] == "opt_target"]
                assert len(target_rows) == 1, \
                    "Should only be a single opt_target in the metadata params."
                opt_target = target_rows["value"].iloc[0]
            except KeyError:
                continue

            try:
                direction_rows = trial_objs[trial_objs["parameter"] == "opt_direction"]
                assert len(direction_rows) <= 1, \
                    "Should only be a single opt_direction in the metadata params."
                opt_direction = direction_rows["value"].iloc[0]
            except (KeyError, IndexError):
                # No direction recorded for this trial.
                opt_direction = None

            if opt_target in merged:
                if opt_direction != merged[opt_target]:
                    _LOG.warning("Experiment %s has multiple trial optimization directions for optimization_target %s=%s",
                                 self, opt_target, merged[opt_target])
            else:
                merged[opt_target] = opt_direction

        # Sanity check on everything gathered from both sources.
        for opt_tgt, opt_dir in merged.items():
            assert opt_dir in {None, "min", "max"}, f"Unexpected opt_dir {opt_dir} for opt_tgt {opt_tgt}."
        return merged

    # TODO: offer access to individual items so that callers needing only a small
    # amount of data do not trigger repeated bulk fetches.
    # Alternatively, have the TrialData object populate itself lazily.

    @property
    def trials(self) -> Dict[int, TrialData]:
        """
        Fetch the trials of this experiment via ``common.get_trials``.

        Returns
        -------
        Dict[int, TrialData]
        """
        engine, schema, exp_id = self._engine, self._schema, self._experiment_id
        return common.get_trials(engine, schema, exp_id)

    @property
    def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
        """
        Build one trial-group accessor per tunable config used by this experiment.

        The group id of a config is the smallest trial_id recorded with that
        config in this experiment.

        Returns
        -------
        Dict[int, TunableConfigTrialGroupData]
            Trial-group accessors keyed by config_id.
        """
        with self._engine.connect() as conn:
            trial_table = self._schema.trial
            stmt = trial_table.select().with_only_columns(
                trial_table.c.config_id,
                func.min(trial_table.c.trial_id).cast(Integer).label(  # pylint: disable=not-callable
                    'tunable_config_trial_group_id'),
            ).where(
                trial_table.c.exp_id == self._experiment_id,
            ).group_by(
                trial_table.c.exp_id,
                trial_table.c.config_id,
            )
            groups: Dict[int, TunableConfigTrialGroupData] = {}
            for row in conn.execute(stmt).fetchall():
                groups[row.config_id] = TunableConfigTrialGroupSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=self._experiment_id,
                    tunable_config_id=row.config_id,
                    tunable_config_trial_group_id=row.tunable_config_trial_group_id,
                )
            return groups

    @property
    def tunable_configs(self) -> Dict[int, TunableConfigData]:
        """
        Build one tunable-config accessor per distinct config used by this experiment.

        Returns
        -------
        Dict[int, TunableConfigData]
            Config accessors keyed by config_id.
        """
        with self._engine.connect() as conn:
            trial_table = self._schema.trial
            stmt = trial_table.select().with_only_columns(
                trial_table.c.config_id.cast(Integer).label('config_id'),
            ).where(
                trial_table.c.exp_id == self._experiment_id,
            ).group_by(
                trial_table.c.exp_id,
                trial_table.c.config_id,
            )
            configs: Dict[int, TunableConfigData] = {}
            for row in conn.execute(stmt).fetchall():
                configs[row.config_id] = TunableConfigSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    tunable_config_id=row.config_id,
                )
            return configs

    @property
    def default_tunable_config_id(self) -> Optional[int]:
        """
        Look up the (tunable) config id holding the default tunable values.

        Note: by *default* this is the config of the first trial executed for the
        experiment. It is, however, currently possible for a user to have changed
        the tunables config between resumptions of an experiment.

        The config of the lowest-numbered trial flagged with ``is_defaults`` is
        preferred; when no such trial exists, the config of the lowest-numbered
        trial of the experiment is used instead.

        Returns
        -------
        Optional[int]
            The config id, or None when neither lookup yields a row.
        """
        trial_table = self._schema.trial
        param_table = self._schema.trial_param

        # Smallest trial_id among the trials explicitly marked as using defaults.
        first_defaults_trial = param_table.select().with_only_columns(
            func.min(param_table.c.trial_id).cast(Integer).label(  # pylint: disable=not-callable
                "first_trial_id_with_defaults"),
        ).where(
            param_table.c.exp_id == self._experiment_id,
            param_table.c.param_id == "is_defaults",
            func.lower(param_table.c.param_value, type_=String).in_(["1", "true"]),
        ).scalar_subquery()

        # Fallback: smallest trial_id of the experiment overall.
        first_trial = trial_table.select().with_only_columns(
            func.min(trial_table.c.trial_id).cast(Integer).label(  # pylint: disable=not-callable
                "first_trial_id"),
        ).where(
            trial_table.c.exp_id == self._experiment_id,
        ).scalar_subquery()

        def config_id_query(trial_id_subquery):
            """Select the config_id of the experiment's trial picked by the subquery."""
            return trial_table.select().with_only_columns(
                trial_table.c.config_id.cast(Integer).label('config_id'),
            ).where(
                trial_table.c.exp_id == self._experiment_id,
                trial_table.c.trial_id.in_(trial_id_subquery),
            )

        with self._engine.connect() as conn:
            # Try the candidates in priority order; the first one with a row wins.
            for trial_id_subquery in (first_defaults_trial, first_trial):
                row = conn.execute(config_id_query(trial_id_subquery)).fetchone()
                if row is not None:
                    # pylint: disable=protected-access  # following DeprecationWarning in sqlalchemy
                    return row._tuple()[0]
        return None

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Fetch the results of this experiment via ``common.get_results_df``.

        Returns
        -------
        pandas.DataFrame
        """
        engine, schema, exp_id = self._engine, self._schema, self._experiment_id
        return common.get_results_df(engine, schema, exp_id)