Coverage for mlos_bench/mlos_bench/storage/sql/trial_data.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6An interface to access the benchmark trial data stored in SQL DB. 

7""" 

8from datetime import datetime 

9from typing import Optional, TYPE_CHECKING 

10 

11import pandas 

12from sqlalchemy import Engine 

13 

14from mlos_bench.storage.base_trial_data import TrialData 

15from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

16from mlos_bench.environments.status import Status 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

19from mlos_bench.util import utcify_timestamp 

20 

21if TYPE_CHECKING: 

22 from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData 

23 

24 

class TrialSqlData(TrialData):
    """
    SQL-backed view of a single benchmark trial.

    Each property issues its own query against the DB on access; nothing is
    cached on the instance beyond the engine and schema handles.
    """

    def __init__(self, *,
                 engine: Engine,
                 schema: DbSchema,
                 experiment_id: str,
                 trial_id: int,
                 config_id: int,
                 ts_start: datetime,
                 ts_end: Optional[datetime],
                 status: Status):
        # The base class stores the trial identity and timing; note that the
        # caller's `config_id` is forwarded as `tunable_config_id`.
        super().__init__(
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            ts_start=ts_start,
            ts_end=ts_end,
            status=status,
        )
        # Handles used by the query properties below.
        self._engine = engine
        self._schema = schema

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Load the tunable configuration this trial was run with.

        Note: this mirrors the "tunables" property of the Trial object.
        """
        config_id = self._tunable_config_id
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=config_id,
        )

    @property
    def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
        """
        Load the group of trials in this experiment that share this trial's
        tunable configuration.
        """
        # Imported lazily; the top-level TYPE_CHECKING guard for the matching
        # base type suggests this avoids an import cycle (presumably).
        # pylint: disable=import-outside-toplevel
        from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData
        group_kwargs = {
            "engine": self._engine,
            "schema": self._schema,
            "experiment_id": self._experiment_id,
            "tunable_config_id": self._tunable_config_id,
        }
        return TunableConfigTrialGroupSqlData(**group_kwargs)

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Fetch this trial's result metrics from storage.

        Returns
        -------
        pandas.DataFrame
            Columns ``metric`` and ``value``, one row per stored result,
            ordered by ``metric_id``.
        """
        table = self._schema.trial_result
        stmt = table.select().where(
            table.c.exp_id == self._experiment_id,
            table.c.trial_id == self._trial_id,
        ).order_by(table.c.metric_id)
        with self._engine.connect() as conn:
            records = [
                (row.metric_id, row.metric_value)
                for row in conn.execute(stmt).fetchall()
            ]
        return pandas.DataFrame(records, columns=['metric', 'value'])

    @property
    def telemetry_df(self) -> pandas.DataFrame:
        """
        Fetch this trial's telemetry samples from storage.

        Returns
        -------
        pandas.DataFrame
            Columns ``ts``, ``metric`` and ``value``, ordered by timestamp
            and then by ``metric_id``. Timestamps are tagged as UTC.
        """
        table = self._schema.trial_telemetry
        stmt = table.select().where(
            table.c.exp_id == self._experiment_id,
            table.c.trial_id == self._trial_id,
        ).order_by(table.c.ts, table.c.metric_id)
        with self._engine.connect() as conn:
            # Some storage backends drop the original time zone, so the data
            # is assumed to have been written in UTC and re-labelled as such.
            records = [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in conn.execute(stmt).fetchall()
            ]
        return pandas.DataFrame(records, columns=['ts', 'metric', 'value'])

    @property
    def metadata_df(self) -> pandas.DataFrame:
        """
        Fetch this trial's metadata parameters from storage.

        Note: this mirrors the "config" property of the Trial object.

        Returns
        -------
        pandas.DataFrame
            Columns ``parameter`` and ``value``, ordered by ``param_id``.
        """
        table = self._schema.trial_param
        stmt = table.select().where(
            table.c.exp_id == self._experiment_id,
            table.c.trial_id == self._trial_id,
        ).order_by(table.c.param_id)
        with self._engine.connect() as conn:
            records = [
                (row.param_id, row.param_value)
                for row in conn.execute(stmt).fetchall()
            ]
        return pandas.DataFrame(records, columns=['parameter', 'value'])