Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 91%

129 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for the optimization loop scheduling policies.""" 

6 

7import json 

8import logging 

9from abc import ABCMeta, abstractmethod 

10from datetime import datetime 

11from types import TracebackType 

12from typing import Any, Dict, List, Literal, Optional, Tuple, Type 

13 

14from pytz import UTC 

15 

16from mlos_bench.config.schemas import ConfigSchema 

17from mlos_bench.environments.base_environment import Environment 

18from mlos_bench.optimizers.base_optimizer import Optimizer 

19from mlos_bench.storage.base_storage import Storage 

20from mlos_bench.tunables.tunable_groups import TunableGroups 

21from mlos_bench.util import merge_parameters 

22 

23_LOG = logging.getLogger(__name__) 

24 

25 

class Scheduler(metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """Base class for the optimization loop scheduling policies."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        config: Dict[str, Any],
        global_config: Dict[str, Any],
        environment: Environment,
        optimizer: Optimizer,
        storage: Storage,
        root_env_config: str,
    ):
        """
        Create a new instance of the scheduler. The constructor of this and the derived
        classes is called by the persistence service after reading the class JSON
        configuration. Other objects like the Environment and Optimizer are provided by
        the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the scheduler.
        global_config : dict
            The global configuration for the experiment.
        environment : Environment
            The environment to benchmark/optimize.
        optimizer : Optimizer
            The optimizer to use.
        storage : Storage
            The storage to use.
        root_env_config : str
            Path to the root environment configuration.

        Raises
        ------
        ValueError
            If ``trial_config_repeat_count`` is not a positive integer.
        """
        self.global_config = global_config
        # Merge global parameters into a *copy* of the local config so the caller's
        # dict is left untouched; `required_keys` must be available after the merge.
        config = merge_parameters(
            dest=config.copy(),
            source=global_config,
            required_keys=["experiment_id", "trial_id"],
        )
        self._validate_json_config(config)

        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        # -1 is a placeholder for "no stored config to load"; `start()` only loads
        # a config when this is positive.
        self._config_id = int(config.get("config_id", -1))
        # A non-positive value means "no limit on the number of trials" (see `not_done()`).
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(
                f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
            )

        self._do_teardown = bool(config.get("teardown", True))

        # Only set while inside the scheduler's context (see `__enter__` / `__exit__`).
        self.experiment: Optional[Storage.Experiment] = None
        self.environment = environment
        self.optimizer = optimizer
        self.storage = storage
        self._root_env_config = root_env_config
        # Highest trial id returned by the most recent `experiment.load()` call;
        # used to fetch only newer results (see `_schedule_new_optimizer_suggestions`).
        self._last_trial_id = -1
        self._ran_trials: List[Storage.Trial] = []

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)

    def _validate_json_config(self, config: dict) -> None:
        """
        Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.

        Parameters
        ----------
        config : dict
            The (merged) scheduler configuration to validate. It is not modified.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config.copy()
            # The json schema does not allow for -1 as a valid value for config_id.
            # As it is just a default placeholder value, and not required, we can
            # remove it from the config copy prior to validation safely.
            config_id = json_config["config"].get("config_id")
            if isinstance(config_id, int) and config_id < 0:
                json_config["config"].pop("config_id")
        ConfigSchema.SCHEDULER.validate(json_config)

    @property
    def trial_config_repeat_count(self) -> int:
        """Gets the number of trials to run for a given config."""
        return self._trial_config_repeat_count

    @property
    def trial_count(self) -> int:
        """Gets the current number of trials run for the experiment."""
        return self._trial_count

    @property
    def max_trials(self) -> int:
        """Gets the maximum number of trials to run for a given experiment, or -1 for no
        limit.
        """
        return self._max_trials

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Scheduler (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Scheduler.
        """
        return self.__class__.__name__

    def __enter__(self) -> "Scheduler":
        """
        Enter the scheduler's context.

        Enters the environment and the optimizer contexts, then starts a new (or
        resumes the existing) storage experiment. If any of these steps fails, the
        components entered so far are exited again (in reverse order) before the
        exception propagates, because Python does not call `__exit__` when
        `__enter__` raises.

        Returns
        -------
        Scheduler
            This scheduler instance.
        """
        _LOG.debug("Scheduler START :: %s", self)
        assert self.experiment is None
        self.environment.__enter__()
        try:
            self.optimizer.__enter__()
            try:
                # Start new or resume the existing experiment. Verify that the
                # experiment configuration is compatible with the previous runs.
                # If the `merge` config parameter is present, merge in the data
                # from other experiments and check for compatibility.
                self.experiment = self.storage.experiment(
                    experiment_id=self._experiment_id,
                    trial_id=self._trial_id,
                    root_env_config=self._root_env_config,
                    description=self.environment.name,
                    tunables=self.environment.tunable_params,
                    opt_targets=self.optimizer.targets,
                ).__enter__()
            except BaseException as ex:
                # Undo the optimizer's context before propagating the failure.
                self.optimizer.__exit__(type(ex), ex, ex.__traceback__)
                raise
        except BaseException as ex:
            # Undo the environment's context before propagating the failure.
            self.environment.__exit__(type(ex), ex, ex.__traceback__)
            raise
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of the scheduler.

        Exits the storage experiment, the optimizer, and the environment (the reverse
        of the order in which `__enter__` entered them), forwarding the exception
        information to each. The remaining components are still exited, and
        `self.experiment` is still reset, even if one of the earlier `__exit__`
        calls raises.

        Returns
        -------
        Literal[False]
            Always False, so exceptions raised inside the context are not suppressed.
        """
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self.experiment is not None
        try:
            self.experiment.__exit__(ex_type, ex_val, ex_tb)
        finally:
            try:
                self.optimizer.__exit__(ex_type, ex_val, ex_tb)
            finally:
                try:
                    self.environment.__exit__(ex_type, ex_val, ex_tb)
                finally:
                    self.experiment = None
        return False  # Do not suppress exceptions

    @abstractmethod
    def start(self) -> None:
        """
        Start the optimization loop.

        This base implementation logs the experiment setup and, when a positive
        ``config_id`` was configured, loads that configuration from storage and
        schedules it. Must be called inside the scheduler's context.
        NOTE(review): subclasses presumably call this via ``super().start()`` before
        running their own loop -- confirm in the derived schedulers.
        """
        assert self.experiment is not None
        _LOG.info(
            "START: Experiment: %s Env: %s Optimizer: %s",
            self.experiment,
            self.environment,
            self.optimizer,
        )
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.environment.pprint())

        if self._config_id > 0:
            tunables = self.load_config(self._config_id)
            self.schedule_trial(tunables)

    def teardown(self) -> None:
        """
        Tear down the environment.

        Call it after the completion of the `.start()` in the scheduler context.
        Does nothing when the ``teardown`` config option is false.
        """
        assert self.experiment is not None
        if self._do_teardown:
            self.environment.teardown()

    def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
        """
        Get the best observation from the optimizer.

        Returns
        -------
        (best_score, best_config) : Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]
            The values reported by the optimizer's `get_best_observation()`; either
            may be None.
        """
        (best_score, best_config) = self.optimizer.get_best_observation()
        _LOG.info("Env: %s best score: %s", self.environment, best_score)
        return (best_score, best_config)

    def load_config(self, config_id: int) -> TunableGroups:
        """
        Load the existing tunable configuration from the storage.

        Parameters
        ----------
        config_id : int
            Id of the stored tunable configuration to load.

        Returns
        -------
        TunableGroups
            The result of assigning the stored values to the environment's tunables.
        """
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = self.environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables

    def _schedule_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop.

        Load the results of the executed trials into the optimizer, suggest new
        configurations, and add them to the queue. Return True if optimization is not
        over, False otherwise.
        """
        assert self.experiment is not None
        # Only fetch trials newer than the last batch already registered.
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.schedule_trial(tunables)

        return not_done

    def schedule_trial(self, tunables: TunableGroups) -> None:
        """
        Add a configuration to the queue of trials.

        The configuration is queued ``trial_config_repeat_count`` times, each copy
        tagged with its 1-based ``repeat_i`` and the optimizer metadata.
        """
        for repeat_i in range(1, self._trial_config_repeat_count + 1):
            self._add_trial_to_queue(
                tunables,
                config={
                    # Add some additional metadata to track for the trial such as the
                    # optimizer config used.
                    # Note: these values are unfortunately mutable at the moment.
                    # Consider them as hints of what the config was when the trial
                    # *started*. It is possible that the experiment configs were
                    # changed when resuming the experiment (since that is not
                    # currently prevented).
                    "optimizer": self.optimizer.name,
                    "repeat_i": repeat_i,
                    "is_defaults": tunables.is_defaults(),
                    # One `opt_target_<i>` / `opt_direction_<i>` pair per optimization target.
                    **{
                        f"opt_{key}_{i}": val
                        for (i, opt_target) in enumerate(self.optimizer.targets.items())
                        for (key, val) in zip(["target", "direction"], opt_target)
                    },
                },
            )

    def _add_trial_to_queue(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add a configuration to the queue of trials.

        A wrapper for the `Experiment.new_trial` method.
        """
        assert self.experiment is not None
        trial = self.experiment.new_trial(tunables, ts_start, config)
        _LOG.info("QUEUE: Add new trial: %s", trial)

    def _run_schedule(self, running: bool = False) -> None:
        """
        Scheduler part of the loop.

        Check for pending trials in the queue and run them.
        ``running`` is passed through to `Experiment.pending_trials`.
        """
        assert self.experiment is not None
        for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
            self.run_trial(trial)

    def not_done(self) -> bool:
        """
        Check the stopping conditions.

        By default, stop when the optimizer converges or max limit of trials reached.
        A non-positive ``max_trials`` disables the trial limit.
        """
        return self.optimizer.not_converged() and (
            self._trial_count < self._max_trials or self._max_trials <= 0
        )

    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial.

        Save the results in the storage. This base implementation only does the
        bookkeeping (increments the trial count, records the trial, logs it);
        subclasses must override it.
        """
        assert self.experiment is not None
        self._trial_count += 1
        self._ran_trials.append(trial)
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)

    @property
    def ran_trials(self) -> List[Storage.Trial]:
        """Get the list of trials that were run."""
        return self._ran_trials