Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 89%

107 statements  

coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base class for the optimization loop scheduling policies.
"""

import json
import logging
from datetime import datetime

from abc import ABCMeta, abstractmethod
from types import TracebackType
from typing import Any, Dict, Optional, Tuple, Type
from typing_extensions import Literal

from pytz import UTC

from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import merge_parameters

_LOG = logging.getLogger(__name__)


class Scheduler(metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """
    Base class for the optimization loop scheduling policies.
    """

    def __init__(self, *,
                 config: Dict[str, Any],
                 global_config: Dict[str, Any],
                 environment: Environment,
                 optimizer: Optimizer,
                 storage: Storage,
                 root_env_config: str):
        """
        Create a new instance of the scheduler. The constructor of this
        and the derived classes is called by the persistence service
        after reading the class JSON configuration. Other objects like
        the Environment and Optimizer are provided by the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the scheduler.
        global_config : dict
            The global configuration for the experiment.
        environment : Environment
            The environment to benchmark/optimize.
        optimizer : Optimizer
            The optimizer to use.
        storage : Storage
            The storage to use.
        root_env_config : str
            Path to the root environment configuration.
        """
        self.global_config = global_config
        config = merge_parameters(dest=config.copy(), source=global_config,
                                  required_keys=["experiment_id", "trial_id"])

        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        self._config_id = int(config.get("config_id", -1))
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}")

        self._do_teardown = bool(config.get("teardown", True))

        self.experiment: Optional[Storage.Experiment] = None
        self.environment = environment
        self.optimizer = optimizer
        self.storage = storage
        self._root_env_config = root_env_config
        self._last_trial_id = -1

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)
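
For illustration, the merged scheduler config consumed by this constructor might look roughly as follows. This is a hypothetical sketch based only on the keys read above; the values are made up:

    scheduler_config = {
        "experiment_id": "MyExperiment",   # required, usually supplied via the global config
        "trial_id": 1,                     # required, usually supplied via the global config
        "config_id": -1,                   # optional: tunable config id to schedule first (> 0 to enable)
        "max_trials": 100,                 # optional: stop after this many trials (-1 = no limit)
        "trial_config_repeat_count": 3,    # optional: repeat each suggested config this many times
        "teardown": True,                  # optional: tear down the environment on completion
    }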

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Scheduler (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Scheduler.
        """
        return self.__class__.__name__

    def __enter__(self) -> 'Scheduler':
        """
        Enter the scheduler's context.
        """
        _LOG.debug("Scheduler START :: %s", self)
        assert self.experiment is None
        self.environment.__enter__()
        self.optimizer.__enter__()
        # Start new or resume the existing experiment. Verify that the
        # experiment configuration is compatible with the previous runs.
        # If the `merge` config parameter is present, merge in the data
        # from other experiments and check for compatibility.
        self.experiment = self.storage.experiment(
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            root_env_config=self._root_env_config,
            description=self.environment.name,
            tunables=self.environment.tunable_params,
            opt_target=self.optimizer.target,
            opt_direction=self.optimizer.direction,
        ).__enter__()
        return self

    def __exit__(self,
                 ex_type: Optional[Type[BaseException]],
                 ex_val: Optional[BaseException],
                 ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Exit the context of the scheduler.
        """
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self.experiment is not None
        self.experiment.__exit__(ex_type, ex_val, ex_tb)
        self.optimizer.__exit__(ex_type, ex_val, ex_tb)
        self.environment.__exit__(ex_type, ex_val, ex_tb)
        self.experiment = None
        return False  # Do not suppress exceptions
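
The paired __enter__/__exit__ methods above make the scheduler a context manager that brackets the Environment, Optimizer, and Storage.Experiment contexts. A rough usage sketch, assuming `scheduler` is an instance of a concrete subclass (the variable names here are illustrative, not part of this module):

    with scheduler as scheduler_context:
        scheduler_context.start()      # run the optimization loop
        scheduler_context.teardown()   # tear down the environment if the "teardown" option is set
        (best_score, best_config) = scheduler_context.get_best_observation()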

    @abstractmethod
    def start(self) -> None:
        """
        Start the optimization loop.
        """
        assert self.experiment is not None
        _LOG.info("START: Experiment: %s Env: %s Optimizer: %s",
                  self.experiment, self.environment, self.optimizer)
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.environment.pprint())

        if self._config_id > 0:
            tunables = self.load_config(self._config_id)
            self.schedule_trial(tunables)

    def teardown(self) -> None:
        """
        Tear down the environment.
        Call it after the completion of the `.start()` in the scheduler context.
        """
        assert self.experiment is not None
        if self._do_teardown:
            self.environment.teardown()

    def get_best_observation(self) -> Tuple[Optional[float], Optional[TunableGroups]]:
        """
        Get the best observation from the optimizer.
        """
        (best_score, best_config) = self.optimizer.get_best_observation()
        _LOG.info("Env: %s best score: %s", self.environment, best_score)
        return (best_score, best_config)

    def load_config(self, config_id: int) -> TunableGroups:
        """
        Load the existing tunable configuration from the storage.
        """
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = self.environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables

    def _schedule_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop. Load the results of the executed trials
        into the optimizer, suggest new configurations, and add them to the queue.
        Return True if optimization is not over, False otherwise.
        """
        assert self.experiment is not None
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.schedule_trial(tunables)

        return not_done

    def schedule_trial(self, tunables: TunableGroups) -> None:
        """
        Add a configuration to the queue of trials.
        """
        for repeat_i in range(1, self._trial_config_repeat_count + 1):
            self._add_trial_to_queue(tunables, config={
                # Add some additional metadata to track for the trial, such as the
                # optimizer config used.
                # Note: these values are unfortunately mutable at the moment.
                # Consider them as hints of what the config was when the trial *started*.
                # It is possible that the experiment configs were changed
                # when the experiment was resumed (since that is not currently
                # prevented).
                # TODO: Improve for supporting multi-objective
                # (e.g., opt_target_1, opt_target_2, ... and opt_direction_1, opt_direction_2, ...)
                "optimizer": self.optimizer.name,
                "opt_target": self.optimizer.target,
                "opt_direction": self.optimizer.direction,
                "repeat_i": repeat_i,
                "is_defaults": tunables.is_defaults,
            })

    def _add_trial_to_queue(self, tunables: TunableGroups,
                            ts_start: Optional[datetime] = None,
                            config: Optional[Dict[str, Any]] = None) -> None:
        """
        Add a configuration to the queue of trials.
        A wrapper for the `Experiment.new_trial` method.
        """
        assert self.experiment is not None
        trial = self.experiment.new_trial(tunables, ts_start, config)
        _LOG.info("QUEUE: Add new trial: %s", trial)

    def _run_schedule(self, running: bool = False) -> None:
        """
        Scheduler part of the loop. Check for pending trials in the queue and run them.
        """
        assert self.experiment is not None
        for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
            self.run_trial(trial)

    def not_done(self) -> bool:
        """
        Check the stopping conditions.
        By default, stop when the optimizer converges or the max number of trials is reached.
        """
        return self.optimizer.not_converged() and (
            self._trial_count < self._max_trials or self._max_trials <= 0
        )

    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial. Save the results in the storage.
        """
        assert self.experiment is not None
        self._trial_count += 1
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)
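
Taken together, the hooks above show how a concrete scheduler can be assembled: `start()` drives the loop, `_run_schedule()` executes pending trials via `run_trial()`, and `_schedule_new_optimizer_suggestions()` feeds results back to the optimizer. Below is a minimal synchronous sketch of a subclass; it is illustrative only and not the concrete scheduler that ships with mlos_bench, and the environment setup and result recording inside `run_trial()` are elided:

    from mlos_bench.schedulers.base_scheduler import Scheduler
    from mlos_bench.storage.base_storage import Storage


    class MySyncScheduler(Scheduler):
        """Hypothetical synchronous scheduler, for illustration only."""

        def start(self) -> None:
            super().start()
            not_done = True
            while not_done:
                # Run all pending trials, then push their results to the optimizer
                # and queue its new suggestions until the stopping condition is met.
                self._run_schedule()
                not_done = self._schedule_new_optimizer_suggestions()

        def run_trial(self, trial: Storage.Trial) -> None:
            super().run_trial(trial)
            # Set up the environment with the trial's tunables, run the benchmark,
            # and save the results to the storage (omitted in this sketch).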