Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 89%

101 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Saving and restoring the benchmark data using SQLAlchemy. 

7""" 

8 

9import logging 

10import hashlib 

11from datetime import datetime 

12from typing import Optional, Tuple, List, Dict, Iterator, Any 

13 

14from pytz import UTC 

15 

16from sqlalchemy import Engine, Connection, CursorResult, Table, column, func, select 

17 

18from mlos_bench.environments.status import Status 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20from mlos_bench.storage.base_storage import Storage 

21from mlos_bench.storage.sql.schema import DbSchema 

22from mlos_bench.storage.sql.trial import Trial 

23from mlos_bench.util import nullable, utcify_timestamp 

24 

# Module-level logger for the SQL experiment storage (named after this module).
_LOG = logging.getLogger(name=__name__)

26 

27 

class Experiment(Storage.Experiment):
    """
    Logic for retrieving and storing the results of a single experiment.
    """

    def __init__(self, *,
                 engine: Engine,
                 schema: DbSchema,
                 tunables: TunableGroups,
                 experiment_id: str,
                 trial_id: int,
                 root_env_config: str,
                 description: str,
                 opt_target: str,
                 opt_direction: Optional[str]):
        """
        Create a SQL-backed experiment handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections and transactions.
        schema : DbSchema
            Table definitions of the storage database.
        tunables, experiment_id, trial_id, root_env_config, description,
        opt_target, opt_direction :
            Forwarded unchanged to ``Storage.Experiment.__init__``.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_target=opt_target,
            opt_direction=opt_direction)
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Register a new experiment in the database, or resume an existing one.

        For a new experiment, this inserts the ``experiment`` row and a single
        ``objectives`` row. For an existing experiment, it sets
        ``self._trial_id`` to one past the largest stored trial ID (if any
        trials exist). It also logs a warning when the stored git commit
        differs from the current one.
        """
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # The LEFT OUTER JOIN keeps the experiment row even when it has no
            # trials yet; in that case MAX(trial_id) comes back as NULL/None.
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select().with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                ).join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True
                ).where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                ).group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(self._schema.experiment.insert().values(
                    exp_id=self._experiment_id,
                    description=self._description,
                    git_repo=self._git_repo,
                    git_commit=self._git_commit,
                    root_env_config=self._root_env_config,
                ))
                # TODO: Expand for multiple objectives.
                conn.execute(self._schema.objectives.insert().values(
                    exp_id=self._experiment_id,
                    optimization_target=self._opt_target,
                    optimization_direction=self._opt_direction,
                ))
            else:
                if exp_info.trial_id is not None:
                    # Resume numbering right after the last recorded trial.
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info("Continue experiment: %s last trial: %s resume from: %d",
                          self._experiment_id, exp_info.trial_id, self._trial_id)
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning("Experiment %s git expected: %s %s",
                                 self, exp_info.git_repo, exp_info.git_commit)

    def merge(self, experiment_ids: List[str]) -> None:
        """
        Merge the results of other experiments into this one (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO")

    def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
        """
        Load the tunable parameter values stored for the given config ID.

        Returns
        -------
        Dict[str, Any]
            Mapping of ``param_id`` to ``param_value`` from ``config_param``.
        """
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
        """
        Load the telemetry records of one trial of this experiment.

        Returns
        -------
        List[Tuple[datetime, str, Any]]
            ``(timestamp, metric_id, metric_value)`` tuples ordered by timestamp,
            then by metric ID. Timestamps are interpreted as UTC.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select().where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id
                ).order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                    for row in cur_telemetry.fetchall()]

    def load(self, last_trial_id: int = -1,
             ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
        """
        Load the completed trials of this experiment.

        Only trials with ``trial_id > last_trial_id`` and a final status of
        SUCCEEDED, FAILED, or TIMED_OUT are returned, in ascending trial ID order.

        Returns
        -------
        (trial_ids, configs, scores, status) : parallel lists
            ``scores[i]`` holds the stored metrics for successful trials and
            ``None`` for every other trial.
        """
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select().with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                ).where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_([
                        Status.SUCCEEDED.name,
                        Status.FAILED.name,
                        Status.TIMED_OUT.name,
                    ]),
                ).order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: List[int] = []
            configs: List[Dict[str, Any]] = []
            scores: List[Optional[Dict[str, Any]]] = []
            status: List[Status] = []

            for trial in cur_trials.fetchall():
                # Statuses are stored as the enum member names.
                stat = Status[trial.status]
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(self._get_key_val(
                    conn, self._schema.config_param, "param", config_id=trial.config_id))
                if stat.is_succeeded():
                    scores.append(self._get_key_val(
                        conn, self._schema.trial_result, "metric",
                        exp_id=self._experiment_id, trial_id=trial.trial_id))
                else:
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        Used for tables with ``<field>_id`` / ``<field>_value`` columns.
        Examples in this class are tunable/trial parameters (``"param"``) and
        trial results (``"metric"``).

        Parameters
        ----------
        conn : Connection
            Open database connection.
        table : Table
            Table to read from.
        field : str
            Column prefix; selects ``{field}_id`` and ``{field}_value``.
        kwargs :
            Equality filters, as ``column_name=value``, combined with AND.

        Returns
        -------
        Dict[str, Any]
            Mapping of ``{field}_id`` to ``{field}_value``.
        """
        cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            ).select_from(table).where(
                *[column(key) == val for (key, val) in kwargs.items()]
            )
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts.
        return dict(row._tuple() for row in cur_result.fetchall())  # pylint: disable=protected-access

    @staticmethod
    def _save_params(conn: Connection, table: Table,
                     params: Dict[str, Any], **kwargs: Any) -> None:
        """
        Bulk-insert ``params`` as ``param_id``/``param_value`` rows into ``table``.

        Every row also gets the extra column values passed in ``kwargs``.
        Values are stored via ``nullable(str, val)``. Does nothing if ``params``
        is empty.
        """
        if not params:
            return
        conn.execute(table.insert(), [
            {
                **kwargs,
                "param_id": key,
                "param_value": nullable(str, val)
            }
            for (key, val) in params.items()
        ])

    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
        """
        Iterate over the trials of this experiment that are ready to be run.

        A trial qualifies when it has no end timestamp, its start timestamp is
        unset or not later than ``timestamp``, and its status is PENDING. When
        ``running`` is True, READY and RUNNING trials also qualify.

        All matching rows and their parameters are fetched up front. The
        database connection is closed *before* the first trial is yielded.
        Callers usually execute each trial, which may take a long time and
        writes to the DB through its own connections, between iterations. The
        connection and any implicit transaction should not stay open that long.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            pending_status = [Status.PENDING.name, Status.READY.name, Status.RUNNING.name]
        else:
            pending_status = [Status.PENDING.name]

        # (trial_id, config_id, tunable values, framework config) per pending trial.
        pending: List[Tuple[int, int, Dict[str, Any], Dict[str, Any]]] = []
        with self._engine.connect() as conn:
            cur_trials = conn.execute(self._schema.trial.select().where(
                self._schema.trial.c.exp_id == self._experiment_id,
                (self._schema.trial.c.ts_start.is_(None) |
                 (self._schema.trial.c.ts_start <= timestamp)),
                self._schema.trial.c.ts_end.is_(None),
                self._schema.trial.c.status.in_(pending_status),
            ))
            for trial in cur_trials.fetchall():
                tunable_values = self._get_key_val(
                    conn, self._schema.config_param, "param",
                    config_id=trial.config_id)
                config = self._get_key_val(
                    conn, self._schema.trial_param, "param",
                    exp_id=self._experiment_id, trial_id=trial.trial_id)
                pending.append((trial.trial_id, trial.config_id, tunable_values, config))

        for (trial_id, config_id, tunable_values, config) in pending:
            yield Trial(
                engine=self._engine,
                schema=self._schema,
                # Reset .is_updated flag after the assignment:
                tunables=self._tunables.copy().assign(tunable_values).reset(),
                experiment_id=self._experiment_id,
                trial_id=trial_id,
                config_id=config_id,
                opt_target=self._opt_target,
                opt_direction=self._opt_direction,
                config=config,
            )

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables. If the config does not exist,
        create a new record for it.

        Configs are identified by the SHA-256 hex digest of ``str(tunables)``.
        For a new config, the tunable values are saved into ``config_param``.
        """
        config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest()
        cur_config = conn.execute(self._schema.config.select().where(
            self._schema.config.c.config_hash == config_hash
        )).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        config_id: int = conn.execute(self._schema.config.insert().values(
            config_hash=config_hash)).inserted_primary_key[0]
        self._save_params(
            conn, self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id)
        return config_id

    def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
                  config: Optional[Dict[str, Any]] = None) -> Storage.Trial:
        """
        Register a new PENDING trial for the given tunables and return it.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable values of the trial; mapped to a (possibly shared) config ID.
        ts_start : Optional[datetime]
            Scheduled start time; defaults to the current time (UTC).
        config : Optional[Dict[str, Any]]
            Optional framework (not target-environment) parameters, stored in
            ``trial_param``.

        Returns
        -------
        Storage.Trial
            The new trial. ``self._trial_id`` is incremented only after all
            inserts succeed.
        """
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local")
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        with self._engine.begin() as conn:
            try:
                config_id = self._get_config_id(conn, tunables)
                conn.execute(self._schema.trial.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    ts_start=ts_start,
                    status='PENDING',
                ))

                # Note: config here is the framework config, not the target
                # environment config (i.e., tunables).
                if config is not None:
                    self._save_params(
                        conn, self._schema.trial_param, config,
                        exp_id=self._experiment_id, trial_id=self._trial_id)

                trial = Trial(
                    engine=self._engine,
                    schema=self._schema,
                    tunables=tunables,
                    experiment_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    opt_target=self._opt_target,
                    opt_direction=self._opt_direction,
                    config=config,
                )
                self._trial_id += 1
                return trial
            except Exception:
                # Engine.begin() also rolls back on error; the explicit
                # rollback is a belt-and-braces measure before re-raising.
                conn.rollback()
                raise