Coverage for mlos_bench/mlos_bench/storage/sql/trial_data.py: 100%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""An interface to access the benchmark trial data stored in SQL DB using the
6:py:class:`.TrialData` interface.
7"""
8from datetime import datetime
9from typing import TYPE_CHECKING, Optional
11import pandas
12from sqlalchemy.engine import Engine
14from mlos_bench.environments.status import Status
15from mlos_bench.storage.base_trial_data import TrialData
16from mlos_bench.storage.base_tunable_config_data import TunableConfigData
17from mlos_bench.storage.sql.schema import DbSchema
18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
19from mlos_bench.util import utcify_timestamp
21if TYPE_CHECKING:
22 from mlos_bench.storage.base_tunable_config_trial_group_data import (
23 TunableConfigTrialGroupData,
24 )
class TrialSqlData(TrialData):
    """SQL-backed implementation of :py:class:`.TrialData` for a single trial.

    Each data property opens a connection on the supplied SQLAlchemy engine and
    reads the rows belonging to this (experiment id, trial id) pair.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        ts_start: datetime,
        ts_end: Optional[datetime],
        status: Status,
    ):
        # The storage layer names it ``config_id``; the base class expects it
        # as ``tunable_config_id``.
        super().__init__(
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            ts_start=ts_start,
            ts_end=ts_end,
            status=status,
        )
        self._engine = engine
        self._schema = schema

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Return a storage-backed accessor for this trial's tunable configuration.

        Note: this corresponds to the Trial object's "tunables" property.
        """
        config_id = self._tunable_config_id
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=config_id,
        )

    @property
    def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
        """Return a storage-backed accessor for the group of trials in this
        experiment that share this trial's tunable configuration.
        """
        # Imported here rather than at module level (presumably to avoid an
        # import cycle between the SQL storage modules -- verify before moving).
        # pylint: disable=import-outside-toplevel
        from mlos_bench.storage.sql.tunable_config_trial_group_data import (
            TunableConfigTrialGroupSqlData,
        )

        group_kwargs = {
            "engine": self._engine,
            "schema": self._schema,
            "experiment_id": self._experiment_id,
            "tunable_config_id": self._tunable_config_id,
        }
        return TunableConfigTrialGroupSqlData(**group_kwargs)

    @property
    def results_df(self) -> pandas.DataFrame:
        """Return this trial's results as a ``metric``/``value`` DataFrame,
        ordered by ``metric_id``.
        """
        table = self._schema.trial_result
        query = (
            table.select()
            .where(
                table.c.exp_id == self._experiment_id,
                table.c.trial_id == self._trial_id,
            )
            .order_by(table.c.metric_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        records = [(row.metric_id, row.metric_value) for row in rows]
        return pandas.DataFrame(records, columns=["metric", "value"])

    @property
    def telemetry_df(self) -> pandas.DataFrame:
        """Return this trial's telemetry as a ``ts``/``metric``/``value``
        DataFrame, ordered by timestamp and then by ``metric_id``.
        """
        table = self._schema.trial_telemetry
        query = (
            table.select()
            .where(
                table.c.exp_id == self._experiment_id,
                table.c.trial_id == self._trial_id,
            )
            .order_by(table.c.ts, table.c.metric_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        # Not every storage backend keeps the original zone info, so the stored
        # timestamps are treated as UTC and re-tagged as such on the way out.
        records = [
            (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
            for row in rows
        ]
        return pandas.DataFrame(records, columns=["ts", "metric", "value"])

    @property
    def metadata_df(self) -> pandas.DataFrame:
        """
        Return this trial's metadata parameters as a ``parameter``/``value``
        DataFrame, ordered by ``param_id``.

        Note: this corresponds to the Trial object's "config" property.
        """
        table = self._schema.trial_param
        query = (
            table.select()
            .where(
                table.c.exp_id == self._experiment_id,
                table.c.trial_id == self._trial_id,
            )
            .order_by(table.c.param_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        records = [(row.param_id, row.param_value) for row in rows]
        return pandas.DataFrame(records, columns=["parameter", "value"])
148 )