Coverage for mlos_bench/mlos_bench/storage/sql/experiment_data.py: 93%

75 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6An interface to access the experiment benchmark data stored in SQL DB. 

7""" 

8from typing import Dict, Literal, Optional 

9 

10import logging 

11 

12import pandas 

13from sqlalchemy import Engine, Integer, String, func 

14 

15from mlos_bench.storage.base_experiment_data import ExperimentData 

16from mlos_bench.storage.base_trial_data import TrialData 

17from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

18from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData 

19from mlos_bench.storage.sql import common 

20from mlos_bench.storage.sql.schema import DbSchema 

21from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

22from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData 

23 

24_LOG = logging.getLogger(__name__) 

25 

26 

class ExperimentSqlData(ExperimentData):
    """
    SQL interface for accessing the stored experiment benchmark data.

    An experiment groups together a set of trials that are run with a given set of
    scripts and mlos_bench configuration files.
    """

    def __init__(self, *,
                 engine: Engine,
                 schema: DbSchema,
                 experiment_id: str,
                 description: str,
                 root_env_config: str,
                 git_repo: str,
                 git_commit: str):
        """
        Create a new accessor for a single experiment stored in the SQL DB.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections for every query.
        schema : DbSchema
            Table definitions used to build the queries.
        experiment_id : str
            Passed through to the base class constructor.
        description : str
            Passed through to the base class constructor.
        root_env_config : str
            Passed through to the base class constructor.
        git_repo : str
            Passed through to the base class constructor.
        git_commit : str
            Passed through to the base class constructor.
        """
        super().__init__(
            experiment_id=experiment_id,
            description=description,
            root_env_config=root_env_config,
            git_repo=git_repo,
            git_commit=git_commit,
        )
        self._engine = engine
        self._schema = schema

    @property
    def objectives(self) -> Dict[str, Literal["min", "max"]]:
        """
        Retrieve the optimization objectives of this experiment.

        The objectives are first read from the `objectives` table (when the schema
        has one), and then merged with the legacy `opt_target` / `opt_direction`
        entries found in the metadata of each trial.

        Returns
        -------
        objectives : Dict[str, Literal["min", "max"]]
            Optimization direction keyed by optimization target.
            NOTE: a target that is only known from legacy trial metadata without an
            `opt_direction` entry is mapped to None.
        """
        objectives: Dict[str, Literal["min", "max"]] = {}
        # First try to lookup the objectives from the experiment metadata in the storage layer.
        if hasattr(self._schema, "objectives"):
            with self._engine.connect() as conn:
                objectives_db_data = conn.execute(
                    self._schema.objectives.select().where(
                        self._schema.objectives.c.exp_id == self._experiment_id,
                    ).order_by(
                        # TODO: return weight as well
                        self._schema.objectives.c.weight.desc(),
                        self._schema.objectives.c.optimization_target.asc(),
                    )
                )
                objectives = {
                    objective.optimization_target: objective.optimization_direction
                    for objective in objectives_db_data.fetchall()
                }
        # Backwards compatibility: try and obtain the objectives from the TrialData and merge them in.
        # NOTE: The original format of storing opt_target/opt_direction in the Trial
        # metadata did not support multi-objectives.
        # Nor does it make it easy to detect when a config change caused a switch in
        # opt_direction for a given opt_target between run.py executions of an
        # Experiment.
        # For now, we simply issue a warning about potentially inconsistent data.
        for trial in self.trials.values():
            trial_objs_df = trial.metadata_df[
                trial.metadata_df["parameter"].isin(("opt_target", "opt_direction"))
            ][["parameter", "value"]]

            opt_targets = trial_objs_df[trial_objs_df["parameter"] == "opt_target"]
            if opt_targets.empty:
                # This trial carries no legacy opt_target metadata: nothing to merge.
                # (Previously this tripped the assertion below instead of being skipped.)
                continue
            assert len(opt_targets) == 1, \
                "Should only be a single opt_target in the metadata params."
            opt_target = opt_targets["value"].iloc[0]

            opt_directions = trial_objs_df[trial_objs_df["parameter"] == "opt_direction"]
            assert len(opt_directions) <= 1, \
                "Should only be a single opt_direction in the metadata params."
            # A missing opt_direction is tolerated and recorded as None.
            opt_direction = None if opt_directions.empty else opt_directions["value"].iloc[0]

            if opt_target not in objectives:
                objectives[opt_target] = opt_direction
            elif opt_direction != objectives[opt_target]:
                _LOG.warning("Experiment %s has multiple trial optimization directions for optimization_target %s=%s",
                             self, opt_target, objectives[opt_target])
        for opt_tgt, opt_dir in objectives.items():
            assert opt_dir in {None, "min", "max"}, f"Unexpected opt_dir {opt_dir} for opt_tgt {opt_tgt}."
        return objectives

    # TODO: provide a way to get individual data to avoid repeated bulk fetches where only small amounts of data is accessed.
    # Or else make the TrialData object lazily populate.

    @property
    def trials(self) -> Dict[int, TrialData]:
        """
        Retrieve the trials of this experiment.

        Delegates to `common.get_trials()`; the data is fetched on every access.

        Returns
        -------
        trials : Dict[int, TrialData]
        """
        return common.get_trials(self._engine, self._schema, self._experiment_id)

    @property
    def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
        """
        Retrieve the trial groups of this experiment, one per tunable config.

        The id of each group is the smallest trial_id that used the given config_id
        within this experiment.

        Returns
        -------
        tunable_config_trial_groups : Dict[int, TunableConfigTrialGroupData]
            Trial group accessors keyed by config_id.
        """
        with self._engine.connect() as conn:
            tunable_config_trial_groups = conn.execute(
                self._schema.trial.select().with_only_columns(
                    self._schema.trial.c.config_id,
                    func.min(self._schema.trial.c.trial_id).cast(Integer).label(    # pylint: disable=not-callable
                        'tunable_config_trial_group_id'),
                ).where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                ).group_by(
                    self._schema.trial.c.exp_id,
                    self._schema.trial.c.config_id,
                )
            )
            return {
                tunable_config_trial_group.config_id: TunableConfigTrialGroupSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=self._experiment_id,
                    tunable_config_id=tunable_config_trial_group.config_id,
                    tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id,
                )
                for tunable_config_trial_group in tunable_config_trial_groups.fetchall()
            }

    @property
    def tunable_configs(self) -> Dict[int, TunableConfigData]:
        """
        Retrieve the distinct tunable configs used by the trials of this experiment.

        Returns
        -------
        tunable_configs : Dict[int, TunableConfigData]
            Tunable config accessors keyed by config_id.
        """
        with self._engine.connect() as conn:
            tunable_configs = conn.execute(
                self._schema.trial.select().with_only_columns(
                    self._schema.trial.c.config_id.cast(Integer).label('config_id'),
                ).where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                ).group_by(
                    # Grouping collapses the trials that share a config into a single row.
                    self._schema.trial.c.exp_id,
                    self._schema.trial.c.config_id,
                )
            )
            return {
                tunable_config.config_id: TunableConfigSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    tunable_config_id=tunable_config.config_id,
                )
                for tunable_config in tunable_configs.fetchall()
            }

    @property
    def default_tunable_config_id(self) -> Optional[int]:
        """
        Retrieves the (tunable) config id for the default tunable values for this experiment.

        Note: this is by *default* the first trial executed for this experiment.
        However, it is currently possible that the user changed the tunables config
        in between resumptions of an experiment.

        Returns
        -------
        Optional[int]
            The config id of the first trial flagged with `is_defaults`, else the
            config id of the first trial of the experiment, else None when the
            experiment has no trials.
        """
        with self._engine.connect() as conn:
            # Preferred: the config of the lowest trial_id explicitly marked as defaults.
            query_results = conn.execute(
                self._schema.trial.select().with_only_columns(
                    self._schema.trial.c.config_id.cast(Integer).label('config_id'),
                ).where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id.in_(
                        self._schema.trial_param.select().with_only_columns(
                            func.min(self._schema.trial_param.c.trial_id).cast(Integer).label(  # pylint: disable=not-callable
                                "first_trial_id_with_defaults"),
                        ).where(
                            self._schema.trial_param.c.exp_id == self._experiment_id,
                            self._schema.trial_param.c.param_id == "is_defaults",
                            # The flag is stored as text: accept "1" and any casing of "true".
                            func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]),
                        ).scalar_subquery()
                    )
                )
            )
            min_default_trial_row = query_results.fetchone()
            if min_default_trial_row is not None:
                # pylint: disable=protected-access  # following DeprecationWarning in sqlalchemy
                return min_default_trial_row._tuple()[0]
            # fallback logic - assume minimum trial_id for experiment
            query_results = conn.execute(
                self._schema.trial.select().with_only_columns(
                    self._schema.trial.c.config_id.cast(Integer).label('config_id'),
                ).where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id.in_(
                        self._schema.trial.select().with_only_columns(
                            func.min(self._schema.trial.c.trial_id).cast(Integer).label(    # pylint: disable=not-callable
                                "first_trial_id"),
                        ).where(
                            self._schema.trial.c.exp_id == self._experiment_id,
                        ).scalar_subquery()
                    )
                )
            )
            min_trial_row = query_results.fetchone()
            if min_trial_row is not None:
                # pylint: disable=protected-access  # following DeprecationWarning in sqlalchemy
                return min_trial_row._tuple()[0]
            return None

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Retrieve the results of all trials of this experiment as a single DataFrame.

        Delegates to `common.get_results_df()`.

        Returns
        -------
        results : pandas.DataFrame
        """
        return common.get_results_df(self._engine, self._schema, self._experiment_id)