Coverage for mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py: 100%
29 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6An interface to access the tunable config trial group data stored in SQL DB.
7"""
9from typing import Dict, Optional, TYPE_CHECKING
11import pandas
12from sqlalchemy import Engine, Integer, func
14from mlos_bench.storage.base_tunable_config_data import TunableConfigData
15from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
16from mlos_bench.storage.sql import common
17from mlos_bench.storage.sql.schema import DbSchema
18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
20if TYPE_CHECKING:
21 from mlos_bench.storage.base_trial_data import TrialData
class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
    """
    SQL interface for accessing the stored experiment benchmark tunable config
    trial group data.

    A (tunable) config is used to define an instance of values for a set of tunable
    parameters for a given experiment and can be used by one or more trial instances
    (e.g., for repeats), which we call a (tunable) config trial group.
    """

    def __init__(self, *,
                 engine: Engine,
                 schema: DbSchema,
                 experiment_id: str,
                 tunable_config_id: int,
                 tunable_config_trial_group_id: Optional[int] = None):
        """
        Create a SQL-backed view of a (tunable) config trial group.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a connection for each query.
        schema : DbSchema
            Schema object providing the ``trial`` table (and whatever the
            ``common`` helpers read).
        experiment_id : str
            Id of the experiment the trial group belongs to.
        tunable_config_id : int
            Id of the (tunable) config shared by the trials in the group.
        tunable_config_trial_group_id : Optional[int]
            Id of the trial group, if already known. Passed through to the base
            class; presumably resolved lazily via
            ``_get_tunable_config_trial_group_id`` when omitted
            (TODO confirm against the base class).
        """
        super().__init__(
            experiment_id=experiment_id,
            tunable_config_id=tunable_config_id,
            tunable_config_trial_group_id=tunable_config_trial_group_id,
        )
        self._engine = engine
        self._schema = schema

    def _get_tunable_config_trial_group_id(self) -> int:
        """
        Retrieve the trial's tunable_config_trial_group_id from the storage.

        The group id is the smallest trial_id among all trials of this
        experiment that use this (tunable) config.

        Returns
        -------
        tunable_config_trial_group_id : int

        Raises
        ------
        ValueError
            If no trial of this experiment uses this (tunable) config.
        """
        trial_table = self._schema.trial
        with self._engine.connect() as conn:
            group_id = conn.execute(
                trial_table.select().with_only_columns(
                    func.min(trial_table.c.trial_id).cast(Integer).label(  # pylint: disable=not-callable
                        'tunable_config_trial_group_id'),
                ).where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.config_id == self._tunable_config_id,
                ).group_by(
                    # With GROUP BY, an unmatched (experiment, config) pair yields
                    # no row at all (rather than a single NULL aggregate row).
                    trial_table.c.exp_id,
                    trial_table.c.config_id,
                )
            ).scalar()  # first column of the first row, or None when there are no rows
        if group_id is None:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            raise ValueError(
                f"No trials found for experiment {self._experiment_id!r} "
                f"and tunable config {self._tunable_config_id}")
        return group_id

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Retrieve the (tunable) config data shared by the trials in this group.

        Returns
        -------
        tunable_config : TunableConfigData
            A new ``TunableConfigSqlData`` built on every access, bound to this
            object's engine, schema and tunable_config_id.
        """
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=self.tunable_config_id,
        )

    @property
    def trials(self) -> Dict[int, "TrialData"]:
        """
        Retrieve the trials' data for this (tunable) config trial group from the storage.

        Returns
        -------
        trials : Dict[int, TrialData]
            A dictionary of the trials' data, keyed by trial id.
        """
        return common.get_trials(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Retrieve the results of this (tunable) config trial group's trials from the storage.

        Returns
        -------
        results : pandas.DataFrame
            Whatever ``common.get_results_df`` returns for this experiment id and
            (tunable) config id; see that helper for the column layout.
        """
        return common.get_results_df(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )