Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 92% (101 statements)

coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
""":py:class:`.Storage.Experiment` interface implementation for saving and restoring
the benchmark experiment data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
"""

import hashlib
import logging
from datetime import datetime
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple

from pytz import UTC
from sqlalchemy import Connection, CursorResult, Table, column, func, select
from sqlalchemy.engine import Engine

from mlos_bench.environments.status import Status
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.trial import Trial
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import nullable, utcify_timestamp

_LOG = logging.getLogger(__name__)


class Experiment(Storage.Experiment):
    """Logic for retrieving and storing the results of a single experiment."""
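    # Note: instances are typically obtained via the storage backend's
    # `Storage.experiment()` factory rather than constructed directly
    # (assumed usage based on the `Storage` base interface).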

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        opt_targets: Dict[str, Literal["min", "max"]],
    ):
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select()
                .with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                )
                .join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True,
                )
                .where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(
                    self._schema.experiment.insert().values(
                        exp_id=self._experiment_id,
                        description=self._description,
                        git_repo=self._git_repo,
                        git_commit=self._git_commit,
                        root_env_config=self._root_env_config,
                    )
                )
                conn.execute(
                    self._schema.objectives.insert().values(
                        [
                            {
                                "exp_id": self._experiment_id,
                                "optimization_target": opt_target,
                                "optimization_direction": opt_dir,
                            }
                            for (opt_target, opt_dir) in self.opt_targets.items()
                        ]
                    )
                )
            else:
                if exp_info.trial_id is not None:
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info(
                    "Continue experiment: %s last trial: %s resume from: %d",
                    self._experiment_id,
                    exp_info.trial_id,
                    self._trial_id,
                )
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning(
                        "Experiment %s git expected: %s %s",
                        self,
                        exp_info.git_repo,
                        exp_info.git_commit,
                    )

    def merge(self, experiment_ids: List[str]) -> None:
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO")

    def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select()
                .where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id,
                )
                .order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in cur_telemetry.fetchall()
            ]

    def load(
        self,
        last_trial_id: int = -1,
    ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
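        # Fetch only trials that have reached a terminal state (succeeded, failed,
        # or timed out) and are newer than `last_trial_id`, in ascending trial order.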
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
                )
                .order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: List[int] = []
            configs: List[Dict[str, Any]] = []
            scores: List[Optional[Dict[str, Any]]] = []
            status: List[Status] = []

            for trial in cur_trials.fetchall():
                stat = Status[trial.status]
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    )
                )
                if stat.is_succeeded():
                    scores.append(
                        self._get_key_val(
                            conn,
                            self._schema.trial_result,
                            "metric",
                            exp_id=self._experiment_id,
                            trial_id=trial.trial_id,
                        )
                    )
                else:
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        (E.g., configurations, results, and telemetry).
        """
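        # Rows come back as (<field>_id, <field>_value) pairs (e.g., param_id/param_value
        # or metric_id/metric_value) and are pivoted into a flat {id: value} dict below.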
        cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            )
            .select_from(table)
            .where(*[column(key) == val for (key, val) in kwargs.items()])
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
        # avoid naming conflicts.
        return dict(
            row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
        )

    @staticmethod
    def _save_params(
        conn: Connection,
        table: Table,
        params: Dict[str, Any],
        **kwargs: Any,
    ) -> None:
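        # Bulk-insert one (param_id, param_value) row per entry; values are stringified
        # (None is preserved via `nullable`) and tagged with the extra key columns in kwargs.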
        if not params:
            return
        conn.execute(
            table.insert(),
            [
                {**kwargs, "param_id": key, "param_value": nullable(str, val)}
                for (key, val) in params.items()
            ],
        )

    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            pending_status = ["PENDING", "READY", "RUNNING"]
        else:
            pending_status = ["PENDING"]
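        # A trial counts as pending if it has not finished (no end timestamp), is not
        # scheduled to start after the given timestamp, and its status is in the set above.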
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select().where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    (
                        self._schema.trial.c.ts_start.is_(None)
                        | (self._schema.trial.c.ts_start <= timestamp)
                    ),
                    self._schema.trial.c.ts_end.is_(None),
                    self._schema.trial.c.status.in_(pending_status),
                )
            )
            for trial in cur_trials.fetchall():
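                # Reconstruct both the trial's tunable values (from the shared config)
                # and its per-trial framework parameters before yielding the trial.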
                tunables = self._get_key_val(
                    conn,
                    self._schema.config_param,
                    "param",
                    config_id=trial.config_id,
                )
                config = self._get_key_val(
                    conn,
                    self._schema.trial_param,
                    "param",
                    exp_id=self._experiment_id,
                    trial_id=trial.trial_id,
                )
                yield Trial(
                    engine=self._engine,
                    schema=self._schema,
                    # Reset .is_updated flag after the assignment:
                    tunables=self._tunables.copy().assign(tunables).reset(),
                    experiment_id=self._experiment_id,
                    trial_id=trial.trial_id,
                    config_id=trial.config_id,
                    opt_targets=self._opt_targets,
                    config=config,
                )

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables.

        If the config does not exist, create a new record for it.
        """
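        # Configs are deduplicated by a SHA-256 hash of their string representation:
        # identical tunable values always map back to the same config_id.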
        config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
        cur_config = conn.execute(
            self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
        ).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        config_id: int = conn.execute(
            self._schema.config.insert().values(config_hash=config_hash)
        ).inserted_primary_key[0]
        self._save_params(
            conn,
            self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id,
        )
        return config_id

    def _new_trial(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Storage.Trial:
        # MySQL can round microseconds into the future, causing the scheduler to skip trials.
        # Truncate microseconds to avoid this issue.
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
            microsecond=0
        )
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
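        # The new trial row and its framework config are inserted in a single
        # transaction; on any error the work is rolled back and the exception re-raised.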
        with self._engine.begin() as conn:
            try:
                config_id = self._get_config_id(conn, tunables)
                conn.execute(
                    self._schema.trial.insert().values(
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                        config_id=config_id,
                        ts_start=ts_start,
                        status="PENDING",
                    )
                )

                # Note: config here is the framework config, not the target
                # environment config (i.e., tunables).
                if config is not None:
                    self._save_params(
                        conn,
                        self._schema.trial_param,
                        config,
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                    )

                trial = Trial(
                    engine=self._engine,
                    schema=self._schema,
                    tunables=tunables,
                    experiment_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    opt_targets=self._opt_targets,
                    config=config,
                )
                self._trial_id += 1
                return trial
            except Exception:
                conn.rollback()
                raise