Coverage for mlos_bench/mlos_bench/storage/sql/experiment_data.py: 90%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""An interface to access the benchmark experiment data stored in SQL DB using the 

6:py:class:`.ExperimentData` interface. 

7""" 

8import logging 

9from typing import Dict, Literal, Optional 

10 

11import pandas 

12from sqlalchemy import Integer, String, func 

13from sqlalchemy.engine import Engine 

14 

15from mlos_bench.storage.base_experiment_data import ExperimentData 

16from mlos_bench.storage.base_trial_data import TrialData 

17from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

18from mlos_bench.storage.base_tunable_config_trial_group_data import ( 

19 TunableConfigTrialGroupData, 

20) 

21from mlos_bench.storage.sql import common 

22from mlos_bench.storage.sql.schema import DbSchema 

23from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

24from mlos_bench.storage.sql.tunable_config_trial_group_data import ( 

25 TunableConfigTrialGroupSqlData, 

26) 

27 

# Logger for this module, named after the module itself.
_LOG: logging.Logger = logging.getLogger(name=__name__)

29 

30 

class ExperimentSqlData(ExperimentData):
    """
    SQL interface for accessing the stored experiment benchmark data.

    An experiment groups together a set of trials that are run with a given set of
    scripts and mlos_bench configuration files.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        description: str,
        root_env_config: str,
        git_repo: str,
        git_commit: str,
    ):
        """
        Create an SQL-backed view of a single experiment.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections for every query below.
        schema : DbSchema
            Table definitions used to build the queries.
        experiment_id : str
            Id of the experiment this object reads data for.
        description : str
            Experiment description (forwarded to the base class).
        root_env_config : str
            Root environment config (forwarded to the base class).
        git_repo : str
            Git repository (forwarded to the base class).
        git_commit : str
            Git commit (forwarded to the base class).
        """
        super().__init__(
            experiment_id=experiment_id,
            description=description,
            root_env_config=root_env_config,
            git_repo=git_repo,
            git_commit=git_commit,
        )
        # Kept so that each property can open its own connection on demand.
        self._engine = engine
        self._schema = schema

    @property
    def objectives(self) -> Dict[str, Literal["min", "max"]]:
        """
        Retrieve the optimization objectives registered for this experiment.

        Returns
        -------
        Dict[str, Literal["min", "max"]]
            Mapping of optimization target name to its optimization direction,
            in order of descending weight, then ascending target name.
        """
        with self._engine.connect() as conn:
            objectives_db_data = conn.execute(
                self._schema.objectives.select()
                .where(
                    self._schema.objectives.c.exp_id == self._experiment_id,
                )
                .order_by(
                    self._schema.objectives.c.weight.desc(),
                    self._schema.objectives.c.optimization_target.asc(),
                )
            )
            # dicts preserve insertion order, so the ORDER BY above is kept.
            return {
                objective.optimization_target: objective.optimization_direction
                for objective in objectives_db_data.fetchall()
            }

    # TODO: provide a way to get individual data to avoid repeated bulk fetches
    # where only small amounts of data is accessed.
    # Or else make the TrialData object lazily populate.

    @property
    def trials(self) -> Dict[int, TrialData]:
        """
        Retrieve the trials of this experiment.

        Delegates to ``common.get_trials`` for this experiment id.

        Returns
        -------
        Dict[int, TrialData]
            Trial data objects (presumably keyed by trial_id -- see
            ``common.get_trials``).
        """
        return common.get_trials(self._engine, self._schema, self._experiment_id)

    @property
    def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
        """
        Retrieve the groups of trials in this experiment that share a tunable config.

        The id of each group is the minimum trial_id among the experiment's trials
        that used that config.

        Returns
        -------
        Dict[int, TunableConfigTrialGroupData]
            Mapping of config_id to the trial group object for that config.
        """
        with self._engine.connect() as conn:
            tunable_config_trial_groups = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.config_id,
                    func.min(self._schema.trial.c.trial_id)
                    .cast(Integer)
                    .label("tunable_config_trial_group_id"),  # pylint: disable=not-callable
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.trial.c.exp_id,
                    self._schema.trial.c.config_id,
                )
            )
            return {
                tunable_config_trial_group.config_id: TunableConfigTrialGroupSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=self._experiment_id,
                    tunable_config_id=tunable_config_trial_group.config_id,
                    tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id,  # pylint:disable=line-too-long # noqa
                )
                for tunable_config_trial_group in tunable_config_trial_groups.fetchall()
            }

    @property
    def tunable_configs(self) -> Dict[int, TunableConfigData]:
        """
        Retrieve the distinct tunable configs used by this experiment's trials.

        Returns
        -------
        Dict[int, TunableConfigData]
            Mapping of config_id to the tunable config data object for that config.
        """
        with self._engine.connect() as conn:
            tunable_configs = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.config_id.cast(Integer).label("config_id"),
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.trial.c.exp_id,
                    self._schema.trial.c.config_id,
                )
            )
            return {
                tunable_config.config_id: TunableConfigSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    tunable_config_id=tunable_config.config_id,
                )
                for tunable_config in tunable_configs.fetchall()
            }

    @property
    def default_tunable_config_id(self) -> Optional[int]:
        """
        Retrieves the (tunable) config id for the default tunable values for this
        experiment.

        This is the config of the lowest trial_id whose ``is_defaults`` trial
        parameter is ``"1"`` or ``"true"`` (compared case-insensitively). If no such
        trial exists, it falls back to the config of the lowest trial_id in the
        experiment.

        Note: this is by *default* the first trial executed for this experiment.
        However, it is currently possible that the user changed the tunables config
        in between resumptions of an experiment.

        Returns
        -------
        Optional[int]
            The config id, or None if no matching trial row is found
            (e.g., the experiment has no trials yet).
        """
        with self._engine.connect() as conn:
            query_results = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.config_id.cast(Integer).label("config_id"),
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id.in_(
                        self._schema.trial_param.select()
                        .with_only_columns(
                            func.min(self._schema.trial_param.c.trial_id)
                            .cast(Integer)
                            .label("first_trial_id_with_defaults"),  # pylint: disable=not-callable
                        )
                        .where(
                            self._schema.trial_param.c.exp_id == self._experiment_id,
                            self._schema.trial_param.c.param_id == "is_defaults",
                            func.lower(self._schema.trial_param.c.param_value, type_=String).in_(
                                ["1", "true"]
                            ),
                        )
                        .scalar_subquery()
                    ),
                )
            )
            min_default_trial_row = query_results.fetchone()
            if min_default_trial_row is not None:
                # Read the single labeled column by name instead of going through
                # the underscore-prefixed Row._tuple() API.
                return min_default_trial_row.config_id
            # fallback logic - assume minimum trial_id for experiment
            query_results = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.config_id.cast(Integer).label("config_id"),
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id.in_(
                        self._schema.trial.select()
                        .with_only_columns(
                            func.min(self._schema.trial.c.trial_id)
                            .cast(Integer)
                            .label("first_trial_id"),  # pylint: disable=not-callable
                        )
                        .where(
                            self._schema.trial.c.exp_id == self._experiment_id,
                        )
                        .scalar_subquery()
                    ),
                )
            )
            min_trial_row = query_results.fetchone()
            if min_trial_row is not None:
                return min_trial_row.config_id
            return None

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Retrieve the results of this experiment as a DataFrame.

        Delegates to ``common.get_results_df`` for this experiment id.

        Returns
        -------
        pandas.DataFrame
            The experiment results as built by ``common.get_results_df``.
        """
        return common.get_results_df(self._engine, self._schema, self._experiment_id)