Coverage for mlos_bench/mlos_bench/launcher.py: 96%

175 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A helper class to load the configuration files, parse the command line parameters, 

7and instantiate the main components of mlos_bench system. 

8 

9It is used in `mlos_bench.run` module to run the benchmark/optimizer from the 

10command line. 

11""" 

12 

13import argparse 

14import logging 

15import sys 

16 

17from typing import Any, Dict, Iterable, List, Optional, Tuple 

18 

19from mlos_bench.config.schemas import ConfigSchema 

20from mlos_bench.dict_templater import DictTemplater 

21from mlos_bench.util import try_parse_val 

22 

23from mlos_bench.tunables.tunable import TunableValue 

24from mlos_bench.tunables.tunable_groups import TunableGroups 

25from mlos_bench.environments.base_environment import Environment 

26 

27from mlos_bench.optimizers.base_optimizer import Optimizer 

28from mlos_bench.optimizers.mock_optimizer import MockOptimizer 

29from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer 

30 

31from mlos_bench.storage.base_storage import Storage 

32 

33from mlos_bench.services.base_service import Service 

34from mlos_bench.services.local.local_exec import LocalExecService 

35from mlos_bench.services.config_persistence import ConfigPersistenceService 

36 

37from mlos_bench.schedulers.base_scheduler import Scheduler 

38 

39from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

40 

41 

# Default logging level; used for the root logger below and as the fallback
# when neither --log_level nor the CLI config file provides one.
_LOG_LEVEL = logging.INFO
# Log record layout; reused by the Launcher for the optional --log_file handler.
_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s'
# NOTE: this configures the root logger as a side effect of importing the module.
logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT)

# Module-level logger used throughout this file.
_LOG = logging.getLogger(__name__)

47 

48 

class Launcher:
    # pylint: disable=too-few-public-methods,too-many-instance-attributes
    """
    Command line launcher for mlos_bench and mlos_core.

    On construction it parses the command line (and the optional `--config`
    JSON5 file), sets up logging, and then instantiates, in this order:
    the config loader, the parent service, the global config, the environment,
    the tunables, the optimizer, the storage, and the scheduler.
    """

    def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
        # pylint: disable=too-many-statements
        """
        Parse the arguments and build all the components of the experiment.

        Parameters
        ----------
        description : str
            Short description of the tool; logged and shown in `--help`.
        long_text : str
            Additional text appended to the description in `--help`.
        argv : Optional[List[str]]
            Command line arguments to parse. If None, use `sys.argv[1:]`.
        """
        _LOG.info("Launch: %s", description)
        # NOTE: argparse's default formatter collapses whitespace in the epilog,
        # so the indentation inside this literal does not affect the --help output.
        epilog = """
            Additional --key=value pairs can be specified to augment or override values listed in --globals.
            Other required_args values can also be pulled from shell environment variables.

            For additional details, please see the website or the README.md files in the source tree:
            <https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
            """
        parser = argparse.ArgumentParser(description=f"{description} : {long_text}",
                                         epilog=epilog)
        (args, args_rest) = self._parse_args(parser, argv)

        # Bootstrap config loader: command line takes priority.
        # Copy the list so that extending it below does not mutate `args.config_path`
        # in place (which would make the JSON paths appear twice later on).
        config_path: List[str] = list(args.config_path or [])
        self._config_loader = ConfigPersistenceService({"config_path": config_path})
        if args.config:
            config = self._config_loader.load_config(args.config, ConfigSchema.CLI)
            assert isinstance(config, Dict)
            # Merge the args paths for the config loader with the paths from JSON file.
            config_path += config.get("config_path", [])
            self._config_loader = ConfigPersistenceService({"config_path": config_path})
        else:
            config = {}

        log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
        try:
            log_level = int(log_level)
        except ValueError:
            # failed to parse as an int - leave it as a string and let logging
            # module handle whether it's an appropriate log name or not
            log_level = logging.getLevelName(log_level)
        logging.root.setLevel(log_level)
        log_file = args.log_file or config.get("log_file")
        if log_file:
            log_handler = logging.FileHandler(log_file)
            log_handler.setLevel(log_level)
            log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
            logging.root.addHandler(log_handler)

        self._parent_service: Service = LocalExecService(parent=self._config_loader)

        self.global_config = self._load_config(
            config.get("globals", []) + (args.globals or []),
            # Command line paths first, then the ones from the JSON config
            # (pass a copy so the loader's list is not shared with the globals).
            list(config_path),
            args_rest,
            # Only keep the CLI config keys that are not command line options themselves.
            {key: val for (key, val) in config.items() if key not in vars(args)},
        )
        # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI.
        # It's useful to keep it there explicitly mostly for the --help output.
        if args.experiment_id:
            self.global_config['experiment_id'] = args.experiment_id
        # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line
        if args.trial_config_repeat_count:
            self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
        # Ensure that the trial_id is present since it gets used by some other
        # configs but is typically controlled by the run optimize loop.
        self.global_config.setdefault('trial_id', 1)

        self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
        assert isinstance(self.global_config, dict)

        # --service cli args should override the config file values.
        service_files: List[str] = config.get("services", []) + (args.service or [])
        assert isinstance(self._parent_service, SupportsConfigLoading)
        self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service)

        env_path = args.environment or config.get("environment")
        if not env_path:
            _LOG.error("No environment config specified.")
            # NOTE: parser.error() prints the usage message and exits the process.
            parser.error("At least the Environment config must be specified." +
                         " Run `mlos_bench --help` and consult `README.md` for more info.")
        self.root_env_config = self._config_loader.resolve_path(env_path)

        self.environment: Environment = self._config_loader.load_environment(
            self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service)
        _LOG.info("Init environment: %s", self.environment)

        # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
        self.tunables = self._init_tunable_values(
            args.random_init or config.get("random_init", False),
            config.get("random_seed") if args.random_seed is None else args.random_seed,
            config.get("tunable_values", []) + (args.tunable_values or [])
        )
        _LOG.info("Init tunables: %s", self.tunables)

        self.optimizer = self._load_optimizer(args.optimizer or config.get("optimizer"))
        _LOG.info("Init optimizer: %s", self.optimizer)

        self.storage = self._load_storage(args.storage or config.get("storage"))
        _LOG.info("Init storage: %s", self.storage)

        # --no_teardown on the command line wins; otherwise use the CLI config value (default: True).
        self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True))
        self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
        _LOG.info("Init scheduler: %s", self.scheduler)

    @property
    def config_loader(self) -> ConfigPersistenceService:
        """
        Get the config loader service.
        """
        return self._config_loader

    @property
    def service(self) -> Service:
        """
        Get the parent service.
        """
        return self._parent_service

    @staticmethod
    def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]:
        """
        Parse the command line arguments.

        Parameters
        ----------
        parser : argparse.ArgumentParser
            The parser to register the launcher's options on.
        argv : Optional[List[str]]
            Arguments to parse. If None, use `sys.argv[1:]`.

        Returns
        -------
        (args, args_rest) : Tuple[argparse.Namespace, List[str]]
            The recognized options and the list of remaining (unrecognized)
            arguments, as returned by `parser.parse_known_args()`.
        """
        parser.add_argument(
            '--config', required=False,
            help='Main JSON5 configuration file. Its keys are the same as the' +
                 ' command line options and can be overridden by the latter.\n' +
                 '\n' +
                 ' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' +
                 ' for additional config examples for this and other arguments.')

        parser.add_argument(
            '--log_file', '--log-file', required=False,
            help='Path to the log file. Use stdout if omitted.')

        parser.add_argument(
            '--log_level', '--log-level', required=False, type=str,
            help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' +
                 ' Set to DEBUG for debug, WARNING for warnings only.')

        parser.add_argument(
            '--config_path', '--config-path', '--config-paths', '--config_paths',
            nargs="+", action='extend', required=False,
            help='One or more locations of JSON config files.')

        parser.add_argument(
            '--service', '--services',
            nargs='+', action='extend', required=False,
            help='Path to JSON file with the configuration of the service(s) for environment(s) to use.')

        parser.add_argument(
            '--environment', required=False,
            help='Path to JSON file with the configuration of the benchmarking environment(s).')

        parser.add_argument(
            '--optimizer', required=False,
            help='Path to the optimizer configuration file. If omitted, run' +
                 ' a single trial with default (or specified in --tunable_values).')

        parser.add_argument(
            '--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int,
            help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.')

        parser.add_argument(
            '--scheduler', required=False,
            help='Path to the scheduler configuration file. By default, use' +
                 ' a single worker synchronous scheduler.')

        parser.add_argument(
            '--storage', required=False,
            help='Path to the storage configuration file.' +
                 ' If omitted, use the ephemeral in-memory SQL storage.')

        parser.add_argument(
            '--random_init', '--random-init', required=False, default=False,
            dest='random_init', action='store_true',
            help='Initialize tunables with random values. (Before applying --tunable_values).')

        parser.add_argument(
            '--random_seed', '--random-seed', required=False, type=int,
            help='Seed to use with --random_init')

        parser.add_argument(
            '--tunable_values', '--tunable-values', nargs="+", action='extend', required=False,
            help='Path to one or more JSON files that contain values of the tunable' +
                 ' parameters. This can be used for a single trial (when no --optimizer' +
                 ' is specified) or as default values for the first run in optimization.')

        parser.add_argument(
            '--globals', nargs="+", action='extend', required=False,
            help='Path to one or more JSON files that contain additional' +
                 ' [private] parameters of the benchmarking environment.')

        # NOTE: default=None (rather than True) lets the caller distinguish
        # "flag not given" from an explicit --no_teardown.
        parser.add_argument(
            '--no_teardown', '--no-teardown', required=False, default=None,
            dest='teardown', action='store_false',
            help='Disable teardown of the environment after the benchmark.')

        parser.add_argument(
            '--experiment_id', '--experiment-id', required=False, default=None,
            help="""
                Experiment ID to use for the benchmark.
                If omitted, the value from the --cli config or --globals is used.

                This is used to store and reload trial results from the storage.
                NOTE: It is **important** to change this value when incompatible
                changes are made to config files, scripts, versions, etc.
                This is left as a manual operation as detection of what is
                "incompatible" is not easily automatable across systems.
                """
        )

        # By default we use the command line arguments, but allow the caller to
        # provide some explicitly for testing purposes.
        if argv is None:
            # The slice is already a new list, so `sys.argv` itself is never modified.
            argv = sys.argv[1:]
        (args, args_rest) = parser.parse_known_args(argv)

        return (args, args_rest)

    @staticmethod
    def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
        """
        Helper function to parse global key/value pairs from the command line.

        Both `--key=value` and `--key value` forms are accepted.
        Each value is passed through `try_parse_val()` before being stored.

        Parameters
        ----------
        cmdline : Iterable[str]
            Command line elements that were not recognized by the main parser.

        Returns
        -------
        config : Dict[str, TunableValue]
            Parsed key/value pairs (keys are stripped of the leading `--`).

        Raises
        ------
        ValueError
            If a `--key` is not followed by a value, or a value has no preceding key.
        """
        _LOG.debug("Extra args: %s", cmdline)

        config: Dict[str, TunableValue] = {}
        key = None  # Pending `--key` that is still waiting for its value.
        for elem in cmdline:
            if elem.startswith("--"):
                if key is not None:
                    raise ValueError("Command line argument has no value: " + key)
                key = elem[2:]
                kv_split = key.split("=", 1)
                if len(kv_split) == 2:
                    config[kv_split[0].strip()] = try_parse_val(kv_split[1])
                    key = None
            else:
                if key is None:
                    raise ValueError("Command line argument has no key: " + elem)
                config[key.strip()] = try_parse_val(elem)
                key = None

        if key is not None:
            # Handles missing trailing elem from last --key arg.
            raise ValueError("Command line argument has no value: " + key)

        _LOG.debug("Parsed config: %s", config)
        return config

    def _load_config(self,
                     args_globals: Iterable[str],
                     config_path: Iterable[str],
                     args_rest: Iterable[str],
                     global_config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Get key/value pairs of the global configuration parameters
        from the specified config files (if any) and command line arguments.

        Parameters
        ----------
        args_globals : Iterable[str]
            Paths to the globals config files; later files override earlier ones.
        config_path : Iterable[str]
            Config search paths; stored under the "config_path" key if non-empty.
        args_rest : Iterable[str]
            Extra `--key=value` command line arguments; they override the files.
        global_config : Dict[str, Any]
            Initial values. NOTE: this dict is updated in place and returned.

        Returns
        -------
        global_config : Dict[str, Any]
            The merged global configuration.
        """
        for config_file in (args_globals or []):
            conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS)
            assert isinstance(conf, dict)
            global_config.update(conf)
        global_config.update(Launcher._try_parse_extra_args(args_rest))
        if config_path:
            global_config["config_path"] = config_path
        return global_config

    def _init_tunable_values(self, random_init: bool, seed: Optional[int],
                             args_tunables: Optional[Iterable[str]]) -> TunableGroups:
        """
        Initialize the tunables and load key/value pairs of the tunable values
        from given JSON files, if specified.

        Parameters
        ----------
        random_init : bool
            If True, start from the values suggested by a `MockOptimizer`
            (with `start_with_defaults` disabled) instead of the environment's tunables.
        seed : Optional[int]
            Seed passed to the `MockOptimizer` when `random_init` is set.
        args_tunables : Optional[Iterable[str]]
            Paths to the tunable values files; applied in order after the
            (optional) random initialization.

        Returns
        -------
        tunables : TunableGroups
            The initialized tunables.
        """
        tunables = self.environment.tunable_params
        _LOG.debug("Init tunables: default = %s", tunables)

        if random_init:
            tunables = MockOptimizer(
                tunables=tunables, service=None,
                config={"start_with_defaults": False, "seed": seed}).suggest()
            _LOG.debug("Init tunables: random = %s", tunables)

        if args_tunables is not None:
            for data_file in args_tunables:
                values = self._config_loader.load_config(data_file, ConfigSchema.TUNABLE_VALUES)
                assert isinstance(values, Dict)
                tunables.assign(values)
                _LOG.debug("Init tunables: load %s = %s", data_file, tunables)

        return tunables

    def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer:
        """
        Instantiate the Optimizer object from JSON config file, if specified
        in the --optimizer command line option. If config file not specified,
        create a one-shot optimizer to run a single benchmark trial.

        Parameters
        ----------
        args_optimizer : Optional[str]
            Path to the optimizer config file, or None.

        Returns
        -------
        optimizer : Optimizer
            The optimizer built from the config, or a `OneShotOptimizer`.
        """
        if args_optimizer is None:
            # global_config may contain additional properties, so we need to
            # strip those out before instantiating the basic oneshot optimizer.
            config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS}
            return OneShotOptimizer(
                self.tunables, config=config, service=self._parent_service)
        class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
        assert isinstance(class_config, Dict)
        optimizer = self._config_loader.build_optimizer(tunables=self.tunables,
                                                        service=self._parent_service,
                                                        config=class_config,
                                                        global_config=self.global_config)
        return optimizer

    def _load_storage(self, args_storage: Optional[str]) -> Storage:
        """
        Instantiate the Storage object from JSON file provided in the --storage
        command line parameter. If omitted, create an ephemeral in-memory SQL
        storage instead.

        Parameters
        ----------
        args_storage : Optional[str]
            Path to the storage config file, or None.

        Returns
        -------
        storage : Storage
            The storage built from the config, or an in-memory SQLite `SqlStorage`.
        """
        if args_storage is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.storage.sql.storage import SqlStorage
            return SqlStorage(service=self._parent_service,
                              config={
                                  "drivername": "sqlite",
                                  "database": ":memory:",
                                  "lazy_schema_create": True,
                              })
        class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
        assert isinstance(class_config, Dict)
        storage = self._config_loader.build_storage(service=self._parent_service,
                                                    config=class_config,
                                                    global_config=self.global_config)
        return storage

    def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler:
        """
        Instantiate the Scheduler object from JSON file provided in the --scheduler
        command line parameter.
        Create a simple synchronous single-threaded scheduler if omitted.

        Parameters
        ----------
        args_scheduler : Optional[str]
            Path to the scheduler config file, or None.

        Returns
        -------
        scheduler : Scheduler
            The scheduler built from the config, or a default `SyncScheduler`.
        """
        # NOTE(review): `self.teardown` is only handed to the default SyncScheduler
        # (through its `config` below). A scheduler built from a config file is
        # given just the file's config and `self.global_config`, so --no_teardown
        # presumably has no effect there. Confirm whether `teardown` should be
        # injected into a scheduler-only copy of the global config instead.
        if args_scheduler is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.schedulers.sync_scheduler import SyncScheduler
            return SyncScheduler(
                # All config values can be overridden from global config
                config={
                    "experiment_id": "UNDEFINED - override from global config",
                    "trial_id": 0,
                    "config_id": -1,
                    "trial_config_repeat_count": 1,
                    "teardown": self.teardown,
                },
                global_config=self.global_config,
                environment=self.environment,
                optimizer=self.optimizer,
                storage=self.storage,
                root_env_config=self.root_env_config,
            )
        class_config = self._config_loader.load_config(args_scheduler, ConfigSchema.SCHEDULER)
        assert isinstance(class_config, Dict)
        return self._config_loader.build_scheduler(
            config=class_config,
            global_config=self.global_config,
            environment=self.environment,
            optimizer=self.optimizer,
            storage=self.storage,
            root_env_config=self.root_env_config,
        )