Coverage for mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""An interface to access the tunable config trial group data stored in a SQL DB using 

6the :py:class:`.TunableConfigTrialGroupData` interface. 

7""" 

8 

9from typing import TYPE_CHECKING, Dict, Optional 

10 

11import pandas 

12from sqlalchemy import Integer, func 

13from sqlalchemy.engine import Engine 

14 

15from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

16from mlos_bench.storage.base_tunable_config_trial_group_data import ( 

17 TunableConfigTrialGroupData, 

18) 

19from mlos_bench.storage.sql import common 

20from mlos_bench.storage.sql.schema import DbSchema 

21from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

22 

23if TYPE_CHECKING: 

24 from mlos_bench.storage.base_trial_data import TrialData 

25 

26 

class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
    """
    SQL-backed accessor for the stored data of a (tunable) config trial group.

    A (tunable) config assigns concrete values to a set of tunable parameters
    for a given experiment. One or more trials may use the same config (e.g.,
    for repeats); that set of trials is called a (tunable) config trial group.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        tunable_config_id: int,
        tunable_config_trial_group_id: Optional[int] = None,
    ):
        """
        Create the accessor; all arguments are keyword-only.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections to the storage.
        schema : DbSchema
            Schema object whose ``trial`` table is queried by this class.
        experiment_id : str
            Experiment id; forwarded to the base class.
        tunable_config_id : int
            Tunable config id; forwarded to the base class.
        tunable_config_trial_group_id : Optional[int]
            Trial group id; forwarded to the base class as is (default None).
            NOTE(review): presumably the base class resolves a missing id via
            ``_get_tunable_config_trial_group_id`` - confirm in the base class.
        """
        super().__init__(
            experiment_id=experiment_id,
            tunable_config_id=tunable_config_id,
            tunable_config_trial_group_id=tunable_config_trial_group_id,
        )
        self._engine = engine
        self._schema = schema

    def _get_tunable_config_trial_group_id(self) -> int:
        """
        Look up this config's tunable_config_trial_group_id in the storage.

        The id is the smallest ``trial_id`` among the rows of the trial table
        that match both this experiment id and this tunable config id.

        Returns
        -------
        tunable_config_trial_group_id : int
            The minimum matching trial id, cast to an integer by the query.

        Raises
        ------
        AssertionError
            If the query yields no row (and assertions are enabled).
        """
        with self._engine.connect() as conn:
            trial_table = self._schema.trial
            group_id_column = (
                func.min(trial_table.c.trial_id)  # pylint: disable=not-callable
                .cast(Integer)
                .label("tunable_config_trial_group_id")
            )
            query = (
                trial_table.select()
                .with_only_columns(group_id_column)
                .where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.config_id == self._tunable_config_id,
                )
                .group_by(
                    trial_table.c.exp_id,
                    trial_table.c.config_id,
                )
            )
            row = conn.execute(query).fetchone()
            assert row is not None
            # pylint: disable=protected-access  # following DeprecationWarning in sqlalchemy
            return row._tuple()[0]

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Build an accessor for the tunable config used by this trial group.

        Returns
        -------
        tunable_config : TunableConfigData
            A new TunableConfigSqlData created with this object's engine,
            schema and tunable config id.
        """
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=self.tunable_config_id,
        )

    @property
    def trials(self) -> Dict[int, "TrialData"]:
        """
        Fetch the data of the trials in this (tunable) config trial group from
        the storage.

        Returns
        -------
        trials : Dict[int, TrialData]
            A dictionary of the trials' data, keyed by trial id.
        """
        return common.get_trials(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Fetch the results of this (tunable) config trial group as a DataFrame.

        Delegates to ``common.get_results_df`` with this object's engine,
        schema, experiment id and tunable config id.

        Returns
        -------
        results : pandas.DataFrame
            The DataFrame returned by ``common.get_results_df``.
            NOTE(review): the column layout is defined by that helper and is
            not visible here.
        """
        return common.get_results_df(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )

111 )