Coverage for mlos_bench/mlos_bench/storage/base_storage.py: 95%

130 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base interface for saving and restoring the benchmark data. 

7""" 

8 

9import logging 

10from abc import ABCMeta, abstractmethod 

11from datetime import datetime 

12from types import TracebackType 

13from typing import Optional, Union, List, Tuple, Dict, Iterator, Type, Any 

14from typing_extensions import Literal 

15 

16from mlos_bench.config.schemas import ConfigSchema 

17from mlos_bench.environments.status import Status 

18from mlos_bench.services.base_service import Service 

19from mlos_bench.storage.base_experiment_data import ExperimentData 

20from mlos_bench.tunables.tunable_groups import TunableGroups 

21from mlos_bench.util import get_git_info 

22 

23_LOG = logging.getLogger(__name__) 

24 

25 

class Storage(metaclass=ABCMeta):
    """
    An abstract interface between the benchmarking framework
    and storage systems (e.g., SQLite or MLFLow).
    """

    def __init__(self,
                 config: Dict[str, Any],
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        """
        Create a new storage object.

        Parameters
        ----------
        config : dict
            Free-format key/value pairs of configuration parameters.
        global_config : Optional[dict]
            Global configuration parameters shared across the framework.
        service : Optional[Service]
            Parent service (if any) this storage object can delegate to.
        """
        _LOG.debug("Storage config: %s", config)
        # Validate before storing anything so a bad config fails fast.
        self._validate_json_config(config)
        self._service = service
        self._config = config.copy()   # defensive copy: caller may mutate theirs
        self._global_config = global_config or {}

    def _validate_json_config(self, config: dict) -> None:
        """
        Rebuild a minimal JSON config that this class could have been
        instantiated from, and run it through the schema validator.
        This lets us validate configs that were provided outside of the
        regular config-file loading mechanism.
        """
        cls = self.__class__
        json_config: dict = {
            "class": f"{cls.__module__}.{cls.__name__}",
        }
        if config:
            json_config["config"] = config
        ConfigSchema.STORAGE.validate(json_config)

62 

63 @property 

64 @abstractmethod 

65 def experiments(self) -> Dict[str, ExperimentData]: 

66 """ 

67 Retrieve the experiments' data from the storage. 

68 

69 Returns 

70 ------- 

71 experiments : Dict[str, ExperimentData] 

72 A dictionary of the experiments' data, keyed by experiment id. 

73 """ 

74 

75 @abstractmethod 

76 def experiment(self, *, 

77 experiment_id: str, 

78 trial_id: int, 

79 root_env_config: str, 

80 description: str, 

81 tunables: TunableGroups, 

82 opt_target: str, 

83 opt_direction: Optional[str]) -> 'Storage.Experiment': 

84 """ 

85 Create a new experiment in the storage. 

86 

87 We need the `opt_target` parameter here to know what metric to retrieve 

88 when we load the data from previous trials. Later we will replace it with 

89 full metadata about the optimization direction, multiple objectives, etc. 

90 

91 Parameters 

92 ---------- 

93 experiment_id : str 

94 Unique identifier of the experiment. 

95 trial_id : int 

96 Starting number of the trial. 

97 root_env_config : str 

98 A path to the root JSON configuration file of the benchmarking environment. 

99 description : str 

100 Human-readable description of the experiment. 

101 tunables : TunableGroups 

102 opt_target : str 

103 Name of metric we're optimizing for. 

104 opt_direction: Optional[str] 

105 Direction to optimize the metric (e.g., min or max) 

106 

107 Returns 

108 ------- 

109 experiment : Storage.Experiment 

110 An object that allows to update the storage with 

111 the results of the experiment and related data. 

112 """ 

113 

114 class Experiment(metaclass=ABCMeta): 

115 # pylint: disable=too-many-instance-attributes 

116 """ 

117 Base interface for storing the results of the experiment. 

118 This class is instantiated in the `Storage.experiment()` method. 

119 """ 

120 

121 def __init__(self, 

122 *, 

123 tunables: TunableGroups, 

124 experiment_id: str, 

125 trial_id: int, 

126 root_env_config: str, 

127 description: str, 

128 opt_target: str, 

129 opt_direction: Optional[str]): 

130 self._tunables = tunables.copy() 

131 self._trial_id = trial_id 

132 self._experiment_id = experiment_id 

133 (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config) 

134 self._description = description 

135 self._opt_target = opt_target 

136 assert opt_direction in {None, "min", "max"} 

137 self._opt_direction = opt_direction 

138 self._in_context = False 

139 

140 def __enter__(self) -> 'Storage.Experiment': 

141 """ 

142 Enter the context of the experiment. 

143 

144 Override the `_setup` method to add custom context initialization. 

145 """ 

146 _LOG.debug("Starting experiment: %s", self) 

147 assert not self._in_context 

148 self._setup() 

149 self._in_context = True 

150 return self 

151 

152 def __exit__(self, exc_type: Optional[Type[BaseException]], 

153 exc_val: Optional[BaseException], 

154 exc_tb: Optional[TracebackType]) -> Literal[False]: 

155 """ 

156 End the context of the experiment. 

157 

158 Override the `_teardown` method to add custom context teardown logic. 

159 """ 

160 is_ok = exc_val is None 

161 if is_ok: 

162 _LOG.debug("Finishing experiment: %s", self) 

163 else: 

164 assert exc_type and exc_val 

165 _LOG.warning("Finishing experiment: %s", self, 

166 exc_info=(exc_type, exc_val, exc_tb)) 

167 assert self._in_context 

168 self._teardown(is_ok) 

169 self._in_context = False 

170 return False # Do not suppress exceptions 

171 

172 def __repr__(self) -> str: 

173 return self._experiment_id 

174 

175 def _setup(self) -> None: 

176 """ 

177 Create a record of the new experiment or find an existing one in the storage. 

178 

179 This method is called by `Storage.Experiment.__enter__()`. 

180 """ 

181 

182 def _teardown(self, is_ok: bool) -> None: 

183 """ 

184 Finalize the experiment in the storage. 

185 

186 This method is called by `Storage.Experiment.__exit__()`. 

187 

188 Parameters 

189 ---------- 

190 is_ok : bool 

191 True if there were no exceptions during the experiment, False otherwise. 

192 """ 

193 

194 @property 

195 def experiment_id(self) -> str: 

196 """Get the Experiment's ID""" 

197 return self._experiment_id 

198 

199 @property 

200 def trial_id(self) -> int: 

201 """Get the current Trial ID""" 

202 return self._trial_id 

203 

204 @property 

205 def description(self) -> str: 

206 """Get the Experiment's description""" 

207 return self._description 

208 

209 @property 

210 def tunables(self) -> TunableGroups: 

211 """Get the Experiment's tunables""" 

212 return self._tunables 

213 

214 @property 

215 def opt_target(self) -> str: 

216 """Get the Experiment's optimization target""" 

217 return self._opt_target 

218 

219 @property 

220 def opt_direction(self) -> Optional[str]: 

221 """Get the Experiment's optimization target""" 

222 return self._opt_direction 

223 

224 @abstractmethod 

225 def merge(self, experiment_ids: List[str]) -> None: 

226 """ 

227 Merge in the results of other (compatible) experiments trials. 

228 Used to help warm up the optimizer for this experiment. 

229 

230 Parameters 

231 ---------- 

232 experiment_ids : List[str] 

233 List of IDs of the experiments to merge in. 

234 """ 

235 

236 @abstractmethod 

237 def load_tunable_config(self, config_id: int) -> Dict[str, Any]: 

238 """ 

239 Load tunable values for a given config ID. 

240 """ 

241 

242 @abstractmethod 

243 def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: 

244 """ 

245 Retrieve the telemetry data for a given trial. 

246 

247 Parameters 

248 ---------- 

249 trial_id : int 

250 Trial ID. 

251 

252 Returns 

253 ------- 

254 metrics : List[Tuple[datetime, str, Any]] 

255 Telemetry data. 

256 """ 

257 

258 @abstractmethod 

259 def load(self, last_trial_id: int = -1, 

260 ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: 

261 """ 

262 Load (tunable values, benchmark scores, status) to warm-up the optimizer. 

263 

264 If `last_trial_id` is present, load only the data from the (completed) trials 

265 that were scheduled *after* the given trial ID. Otherwise, return data from ALL 

266 merged-in experiments and attempt to impute the missing tunable values. 

267 

268 Parameters 

269 ---------- 

270 last_trial_id : int 

271 (Optional) Trial ID to start from. 

272 

273 Returns 

274 ------- 

275 (trial_ids, configs, scores, status) : ([int], [dict], [Optional[dict]], [Status]) 

276 Trial ids, Tunable values, benchmark scores, and status of the trials. 

277 """ 

278 

279 @abstractmethod 

280 def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']: 

281 """ 

282 Return an iterator over the pending trials that are scheduled to run 

283 on or before the specified timestamp. 

284 

285 Parameters 

286 ---------- 

287 timestamp : datetime 

288 The time in UTC to check for scheduled trials. 

289 running : bool 

290 If True, include the trials that are already running. 

291 Otherwise, return only the scheduled trials. 

292 

293 Returns 

294 ------- 

295 trials : Iterator[Storage.Trial] 

296 An iterator over the scheduled (and maybe running) trials. 

297 """ 

298 

299 @abstractmethod 

300 def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, 

301 config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial': 

302 """ 

303 Create a new experiment run in the storage. 

304 

305 Parameters 

306 ---------- 

307 tunables : TunableGroups 

308 Tunable parameters to use for the trial. 

309 ts_start : Optional[datetime] 

310 Timestamp of the trial start (can be in the future). 

311 config : dict 

312 Key/value pairs of additional non-tunable parameters of the trial. 

313 

314 Returns 

315 ------- 

316 trial : Storage.Trial 

317 An object that allows to update the storage with 

318 the results of the experiment trial run. 

319 """ 

320 

321 class Trial(metaclass=ABCMeta): 

322 # pylint: disable=too-many-instance-attributes 

323 """ 

324 Base interface for storing the results of a single run of the experiment. 

325 This class is instantiated in the `Storage.Experiment.trial()` method. 

326 """ 

327 

328 def __init__(self, *, 

329 tunables: TunableGroups, experiment_id: str, trial_id: int, 

330 tunable_config_id: int, opt_target: str, opt_direction: Optional[str], 

331 config: Optional[Dict[str, Any]] = None): 

332 self._tunables = tunables 

333 self._experiment_id = experiment_id 

334 self._trial_id = trial_id 

335 self._tunable_config_id = tunable_config_id 

336 self._opt_target = opt_target 

337 assert opt_direction in {None, "min", "max"} 

338 self._opt_direction = opt_direction 

339 self._config = config or {} 

340 

341 def __repr__(self) -> str: 

342 return f"{self._experiment_id}:{self._trial_id}:{self._tunable_config_id}" 

343 

344 @property 

345 def trial_id(self) -> int: 

346 """ 

347 ID of the current trial. 

348 """ 

349 return self._trial_id 

350 

351 @property 

352 def tunable_config_id(self) -> int: 

353 """ 

354 ID of the current trial (tunable) configuration. 

355 """ 

356 return self._tunable_config_id 

357 

358 @property 

359 def opt_target(self) -> str: 

360 """ 

361 Get the Trial's optimization target. 

362 """ 

363 return self._opt_target 

364 

365 @property 

366 def opt_direction(self) -> Optional[str]: 

367 """ 

368 Get the Trial's optimization direction (e.g., min or max) 

369 """ 

370 return self._opt_direction 

371 

372 @property 

373 def tunables(self) -> TunableGroups: 

374 """ 

375 Tunable parameters of the current trial 

376 

377 (e.g., application Environment's "config") 

378 """ 

379 return self._tunables 

380 

381 def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 

382 """ 

383 Produce a copy of the global configuration updated 

384 with the parameters of the current trial. 

385 

386 Note: this is not the target Environment's "config" (i.e., tunable 

387 params), but rather the internal "config" which consists of a 

388 combination of somewhat more static variables defined in the json config 

389 files. 

390 """ 

391 config = self._config.copy() 

392 config.update(global_config or {}) 

393 config["experiment_id"] = self._experiment_id 

394 config["trial_id"] = self._trial_id 

395 return config 

396 

397 @abstractmethod 

398 def update(self, status: Status, timestamp: datetime, 

399 metrics: Optional[Union[Dict[str, Any], float]] = None 

400 ) -> Optional[Dict[str, Any]]: 

401 """ 

402 Update the storage with the results of the experiment. 

403 

404 Parameters 

405 ---------- 

406 status : Status 

407 Status of the experiment run. 

408 timestamp: datetime 

409 Timestamp of the status and metrics. 

410 metrics : Optional[Union[Dict[str, Any], float]] 

411 One or several metrics of the experiment run. 

412 Must contain the (float) optimization target if the status is SUCCEEDED. 

413 

414 Returns 

415 ------- 

416 metrics : Optional[Dict[str, Any]] 

417 Same as `metrics`, but always in the dict format. 

418 """ 

419 _LOG.info("Store trial: %s :: %s %s", self, status, metrics) 

420 if isinstance(metrics, dict) and self._opt_target not in metrics: 

421 _LOG.warning("Trial %s :: opt.target missing: %s", self, self._opt_target) 

422 # raise ValueError( 

423 # f"Optimization target '{self._opt_target}' is missing from {metrics}") 

424 return {self._opt_target: metrics} if isinstance(metrics, (float, int)) else metrics 

425 

426 @abstractmethod 

427 def update_telemetry(self, status: Status, timestamp: datetime, 

428 metrics: List[Tuple[datetime, str, Any]]) -> None: 

429 """ 

430 Save the experiment's telemetry data and intermediate status. 

431 

432 Parameters 

433 ---------- 

434 status : Status 

435 Current status of the trial. 

436 timestamp: datetime 

437 Timestamp of the status (but not the metrics). 

438 metrics : List[Tuple[datetime, str, Any]] 

439 Telemetry data. 

440 """ 

441 _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics))