Coverage for mlos_bench/mlos_bench/launcher.py: 96%

185 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A helper class to load the configuration files, parse the command line parameters, and 

7instantiate the main components of mlos_bench system. 

8 

9It is used in the :py:mod:`mlos_bench.run` module to run the benchmark/optimizer 

10from the command line. 

11""" 

12 

13import argparse 

14import logging 

15import sys 

16from typing import Any, Dict, Iterable, List, Optional, Tuple 

17 

18from mlos_bench.config.schemas import ConfigSchema 

19from mlos_bench.dict_templater import DictTemplater 

20from mlos_bench.environments.base_environment import Environment 

21from mlos_bench.optimizers.base_optimizer import Optimizer 

22from mlos_bench.optimizers.mock_optimizer import MockOptimizer 

23from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer 

24from mlos_bench.schedulers.base_scheduler import Scheduler 

25from mlos_bench.services.base_service import Service 

26from mlos_bench.services.config_persistence import ConfigPersistenceService 

27from mlos_bench.services.local.local_exec import LocalExecService 

28from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

29from mlos_bench.storage.base_storage import Storage 

30from mlos_bench.tunables.tunable import TunableValue 

31from mlos_bench.tunables.tunable_groups import TunableGroups 

32from mlos_bench.util import try_parse_val 

33 

# Default logging setup applied at import time; the Launcher may later change
# the root logger level and attach a file handler based on CLI/config values.
_LOG_LEVEL = logging.INFO
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=_LOG_LEVEL)

# Module-level logger used throughout this file.
_LOG = logging.getLogger(__name__)

39 

40 

41class Launcher: 

42 # pylint: disable=too-few-public-methods,too-many-instance-attributes 

43 """Command line launcher for mlos_bench and mlos_core.""" 

44 

def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
    """
    Parse the command line and config files and instantiate the main components.

    Sets up logging, the config loader, the parent service, the global config,
    the Environment, tunables, Optimizer, Storage, and Scheduler (in that order).

    Parameters
    ----------
    description : str
        Short description of the program; logged at startup and used in the
        argument parser description.
    long_text : str
        Additional text appended to the argument parser description.
    argv : Optional[List[str]]
        Command line arguments to parse. If None, `sys.argv[1:]` is used.
    """
    # pylint: disable=too-many-statements
    # pylint: disable=too-many-locals
    _LOG.info("Launch: %s", description)
    epilog = """
        Additional --key=value pairs can be specified to augment or override
        values listed in --globals.
        Other required_args values can also be pulled from shell environment
        variables.

        For additional details, please see the website or the README.md files in
        the source tree:
        <https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
        """
    parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
    (args, path_args, args_rest) = self._parse_args(parser, argv)

    # Bootstrap config loader: command line takes priority.
    # Work on a copy: `args.config_path` is combined with the config file's
    # paths again further below, so mutating it in place here would make the
    # config file's entries show up twice in the global config.
    config_path: List[str] = list(args.config_path or [])
    self._config_loader = ConfigPersistenceService({"config_path": config_path})
    if args.config:
        config = self._config_loader.load_config(args.config, ConfigSchema.CLI)
        assert isinstance(config, Dict)
        # Merge the args paths for the config loader with the paths from JSON file.
        config_path = config_path + config.get("config_path", [])
        self._config_loader = ConfigPersistenceService({"config_path": config_path})
    else:
        config = {}

    log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
    try:
        log_level = int(log_level)
    except ValueError:
        # failed to parse as an int - leave it as a string and let logging
        # module handle whether it's an appropriate log name or not
        log_level = logging.getLevelName(log_level)
    logging.root.setLevel(log_level)
    log_file = args.log_file or config.get("log_file")
    if log_file:
        log_handler = logging.FileHandler(log_file)
        log_handler.setLevel(log_level)
        log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
        logging.root.addHandler(log_handler)

    self._parent_service: Service = LocalExecService(parent=self._config_loader)

    # Prepare global_config from a combination of global config files, cli
    # configs, and cli args.
    args_dict = vars(args)
    # teardown (bool) conflicts with Environment configs that use it for shell
    # commands (list), so we exclude it from copying over
    excluded_cli_args = path_args + ["teardown"]
    # Include (almost) any item from the cli config file that either isn't in
    # the cli args at all or whose cli arg is missing.
    cli_config_args = {
        key: val
        for (key, val) in config.items()
        if (args_dict.get(key) is None) and key not in excluded_cli_args
    }

    self.global_config = self._load_config(
        args_globals=config.get("globals", []) + (args.globals or []),
        config_path=(args.config_path or []) + config.get("config_path", []),
        args_rest=args_rest,
        global_config=cli_config_args,
    )
    # experiment_id is generally taken from --globals files, but we also allow
    # overriding it on the CLI.
    # It's useful to keep it there explicitly mostly for the --help output.
    if args.experiment_id:
        self.global_config["experiment_id"] = args.experiment_id
    # trial_config_repeat_count is a scheduler property but it's convenient to
    # set it via command line
    if args.trial_config_repeat_count:
        self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
    # Ensure that the trial_id is present since it gets used by some other
    # configs but is typically controlled by the run optimize loop.
    self.global_config.setdefault("trial_id", 1)

    self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
    assert isinstance(self.global_config, dict)

    # --service cli args should override the config file values.
    service_files: List[str] = config.get("services", []) + (args.service or [])
    assert isinstance(self._parent_service, SupportsConfigLoading)
    self._parent_service = self._parent_service.load_services(
        service_files,
        self.global_config,
        self._parent_service,
    )

    env_path = args.environment or config.get("environment")
    if not env_path:
        _LOG.error("No environment config specified.")
        # NOTE: ArgumentParser.error() prints the message and exits the process.
        parser.error(
            "At least the Environment config must be specified."
            + " Run `mlos_bench --help` and consult `README.md` for more info."
        )
    self.root_env_config = self._config_loader.resolve_path(env_path)

    self.environment: Environment = self._config_loader.load_environment(
        self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
    )
    _LOG.info("Init environment: %s", self.environment)

    # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
    self.tunables = self._init_tunable_values(
        args.random_init or config.get("random_init", False),
        config.get("random_seed") if args.random_seed is None else args.random_seed,
        config.get("tunable_values", []) + (args.tunable_values or []),
    )
    _LOG.info("Init tunables: %s", self.tunables)

    self.optimizer = self._load_optimizer(args.optimizer or config.get("optimizer"))
    _LOG.info("Init optimizer: %s", self.optimizer)

    self.storage = self._load_storage(args.storage or config.get("storage"))
    _LOG.info("Init storage: %s", self.storage)

    # The CLI flag (when given) wins over the config file; default is True.
    self.teardown: bool = (
        bool(args.teardown)
        if args.teardown is not None
        else bool(config.get("teardown", True))
    )
    self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
    _LOG.info("Init scheduler: %s", self.scheduler)

171 

@property
def config_loader(self) -> ConfigPersistenceService:
    """Return the config loader service created by this launcher."""
    return self._config_loader

176 

@property
def service(self) -> Service:
    """Return the parent service that the launcher's components share."""
    return self._parent_service

181 

@staticmethod
def _parse_args(
    parser: argparse.ArgumentParser,
    argv: Optional[List[str]],
) -> Tuple[argparse.Namespace, List[str], List[str]]:
    """
    Register all known command line options on `parser` and parse `argv`.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The parser to add the options to.
    argv : Optional[List[str]]
        Arguments to parse. If None, `sys.argv[1:]` is used.

    Returns
    -------
    (args, path_args, args_rest) : Tuple[argparse.Namespace, List[str], List[str]]
        The parsed known arguments, the `dest` names of the options that were
        registered as paths, and the remaining unrecognized arguments.
    """
    path_dests: List[str] = []

    def add_path_argument(*flags: Any, **options: Any) -> None:
        """Add an option to the parser and record its `dest` as a path argument."""
        action = parser.add_argument(*flags, **options)
        path_dests.append(action.dest)

    add_path_argument(
        "--config",
        required=False,
        help=(
            "Main JSON5 configuration file. Its keys are the same as the "
            "command line options and can be overridden by the latter.\n"
            "\n"
            "See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ "
            "for additional config examples for this and other arguments."
        ),
    )

    add_path_argument(
        "--log_file",
        "--log-file",
        required=False,
        help="Path to the log file. Use stdout if omitted.",
    )

    parser.add_argument(
        "--log_level",
        "--log-level",
        required=False,
        type=str,
        help=(
            f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}. "
            "Set to DEBUG for debug, WARNING for warnings only."
        ),
    )

    add_path_argument(
        "--config_path",
        "--config-path",
        "--config-paths",
        "--config_paths",
        nargs="+",
        action="extend",
        required=False,
        help="One or more locations of JSON config files.",
    )

    add_path_argument(
        "--service",
        "--services",
        nargs="+",
        action="extend",
        required=False,
        help=(
            "Path to JSON file with the configuration "
            "of the service(s) for environment(s) to use."
        ),
    )

    add_path_argument(
        "--environment",
        required=False,
        help="Path to JSON file with the configuration of the benchmarking environment(s).",
    )

    add_path_argument(
        "--optimizer",
        required=False,
        help=(
            "Path to the optimizer configuration file. If omitted, run "
            "a single trial with default (or specified in --tunable_values)."
        ),
    )

    parser.add_argument(
        "--trial_config_repeat_count",
        "--trial-config-repeat-count",
        required=False,
        type=int,
        help=(
            "Number of times to repeat each config. "
            "Default is 1 trial per config, though more may be advised."
        ),
    )

    add_path_argument(
        "--scheduler",
        required=False,
        help=(
            "Path to the scheduler configuration file. By default, use "
            "a single worker synchronous scheduler."
        ),
    )

    add_path_argument(
        "--storage",
        required=False,
        help=(
            "Path to the storage configuration file. "
            "If omitted, use the ephemeral in-memory SQL storage."
        ),
    )

    parser.add_argument(
        "--random_init",
        "--random-init",
        required=False,
        default=False,
        dest="random_init",
        action="store_true",
        help="Initialize tunables with random values. (Before applying --tunable_values).",
    )

    parser.add_argument(
        "--random_seed",
        "--random-seed",
        required=False,
        type=int,
        help="Seed to use with --random_init",
    )

    add_path_argument(
        "--tunable_values",
        "--tunable-values",
        nargs="+",
        action="extend",
        required=False,
        help=(
            "Path to one or more JSON files that contain values of the tunable "
            "parameters. This can be used for a single trial (when no --optimizer "
            "is specified) or as default values for the first run in optimization."
        ),
    )

    add_path_argument(
        "--globals",
        nargs="+",
        action="extend",
        required=False,
        help=(
            "Path to one or more JSON files that contain additional "
            "[private] parameters of the benchmarking environment."
        ),
    )

    # `default=None` lets the caller tell "flag not given" apart from an
    # explicit request to disable teardown (which stores False).
    parser.add_argument(
        "--no_teardown",
        "--no-teardown",
        required=False,
        default=None,
        dest="teardown",
        action="store_false",
        help="Disable teardown of the environment after the benchmark.",
    )

    parser.add_argument(
        "--experiment_id",
        "--experiment-id",
        required=False,
        default=None,
        help="""
            Experiment ID to use for the benchmark.
            If omitted, the value from the --cli config or --globals is used.

            This is used to store and reload trial results from the storage.
            NOTE: It is **important** to change this value when incompatible
            changes are made to config files, scripts, versions, etc.
            This is left as a manual operation as detection of what is
            "incompatible" is not easily automatable across systems.
            """,
    )

    # Use the process command line unless the caller supplied the arguments
    # explicitly (e.g., for testing purposes).
    cmdline = sys.argv[1:] if argv is None else argv
    (args, args_rest) = parser.parse_known_args(cmdline)

    return (args, path_dests, args_rest)

375 

@staticmethod
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
    """
    Parse extra `--key=value` / `--key value` pairs from the command line.

    Parameters
    ----------
    cmdline : Iterable[str]
        The command line elements that were not recognized by the main parser.

    Returns
    -------
    config : Dict[str, TunableValue]
        Parsed key/value pairs, with dashes in the keys replaced by underscores.

    Raises
    ------
    ValueError
        If a `--key` has no value, or a value has no preceding `--key`.
    """
    _LOG.debug("Extra args: %s", cmdline)

    parsed: Dict[str, TunableValue] = {}
    pending_key: Optional[str] = None
    for token in cmdline:
        if not token.startswith("--"):
            # A bare token is the value of the `--key` seen just before it.
            if pending_key is None:
                raise ValueError("Command line argument has no key: " + token)
            parsed[pending_key.strip()] = try_parse_val(token)
            pending_key = None
            continue

        if pending_key is not None:
            raise ValueError("Command line argument has no value: " + pending_key)
        (name, equals, raw_value) = token[2:].partition("=")
        if equals:
            # `--key=value` form: complete in a single token.
            parsed[name.strip()] = try_parse_val(raw_value)
        else:
            # `--key value` form: wait for the value in the next token.
            pending_key = name

    if pending_key is not None:
        # Handles missing trailing elem from last --key arg.
        raise ValueError("Command line argument has no value: " + pending_key)

    # Convert "max-suggestions" to "max_suggestions" for compatibility with
    # other CLI options to use as common python/json variable replacements.
    config = {name.replace("-", "_"): value for (name, value) in parsed.items()}

    _LOG.debug("Parsed config: %s", config)
    return config

408 

def _load_config(
    self,
    *,
    args_globals: Iterable[str],
    config_path: Iterable[str],
    args_rest: Iterable[str],
    global_config: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Assemble the global configuration parameters.

    Values are applied to `global_config` in this order (later ones win):
    the globals config files, then the extra command line key/value pairs,
    then `config_path` (only if it is non-empty).

    Parameters
    ----------
    args_globals : Iterable[str]
        Globals config files to load.
    config_path : Iterable[str]
        Config search paths to store under the "config_path" key.
    args_rest : Iterable[str]
        Extra command line arguments to parse as key/value pairs.
    global_config : Dict[str, Any]
        Initial values; NOTE: this dict is updated in place.

    Returns
    -------
    global_config : Dict[str, Any]
        The same (updated) `global_config` object.
    """
    for globals_file in args_globals or []:
        loaded = self._config_loader.load_config(globals_file, ConfigSchema.GLOBALS)
        assert isinstance(loaded, dict)
        global_config.update(loaded)

    cli_overrides = Launcher._try_parse_extra_args(args_rest)
    global_config.update(cli_overrides)

    if config_path:
        global_config["config_path"] = config_path
    return global_config

428 

def _init_tunable_values(
    self,
    random_init: bool,
    seed: Optional[int],
    args_tunables: Optional[Iterable[str]],
) -> TunableGroups:
    """
    Initialize the tunables and load key/value pairs of the tunable values from
    given JSON files, if specified.

    Parameters
    ----------
    random_init : bool
        If True, start from a suggestion of a `MockOptimizer` configured with
        `start_with_defaults=False` instead of the environment's tunables.
    seed : Optional[int]
        Seed passed to the `MockOptimizer` when `random_init` is set.
    args_tunables : Optional[Iterable[str]]
        Paths to the tunable values config files, applied in order.
        A single path given as a plain string is accepted as well.

    Returns
    -------
    tunables : TunableGroups
        The initialized tunables.
    """
    tunables = self.environment.tunable_params
    _LOG.debug("Init tunables: default = %s", tunables)

    if random_init:
        tunables = MockOptimizer(
            tunables=tunables,
            service=None,
            config={"start_with_defaults": False, "seed": seed},
        ).suggest()
        _LOG.debug("Init tunables: random = %s", tunables)

    # A bare string is a single file path; without this it would be iterated
    # character by character below.
    if isinstance(args_tunables, str):
        args_tunables = [args_tunables]

    for data_file in args_tunables or []:
        values = self._config_loader.load_config(data_file, ConfigSchema.TUNABLE_VALUES)
        assert isinstance(values, Dict)
        tunables.assign(values)
        _LOG.debug("Init tunables: load %s = %s", data_file, tunables)

    return tunables

457 

def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer:
    """
    Create the Optimizer from the JSON config file given in `--optimizer`.

    When no config file is given, fall back to a one-shot optimizer that
    runs a single benchmark trial.

    Parameters
    ----------
    args_optimizer : Optional[str]
        Path to the optimizer config file, or None.

    Returns
    -------
    optimizer : Optimizer
        The instantiated optimizer.
    """
    if args_optimizer is not None:
        class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
        assert isinstance(class_config, Dict)
        return self._config_loader.build_optimizer(
            tunables=self.tunables,
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )

    # global_config may contain additional properties, so keep only the ones
    # the basic one-shot optimizer supports.
    supported_props = OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS
    one_shot_config = {
        name: value for (name, value) in self.global_config.items() if name in supported_props
    }
    return OneShotOptimizer(self.tunables, config=one_shot_config, service=self._parent_service)

484 

def _load_storage(self, args_storage: Optional[str]) -> Storage:
    """
    Create the Storage from the JSON config file given in `--storage`.

    When no config file is given, create an ephemeral in-memory SQL storage
    instead.

    Parameters
    ----------
    args_storage : Optional[str]
        Path to the storage config file, or None.

    Returns
    -------
    storage : Storage
        The instantiated storage.
    """
    if args_storage is not None:
        class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
        assert isinstance(class_config, Dict)
        return self._config_loader.build_storage(
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )

    # Only import the SQL backend when the default storage is actually needed.
    # pylint: disable=import-outside-toplevel
    from mlos_bench.storage.sql.storage import SqlStorage

    in_memory_config = {
        "drivername": "sqlite",
        "database": ":memory:",
        "lazy_schema_create": True,
    }
    return SqlStorage(service=self._parent_service, config=in_memory_config)

512 

def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler:
    """
    Instantiate the Scheduler object from JSON file provided in the --scheduler
    command line parameter.

    Create a simple synchronous single-threaded scheduler if omitted.

    Parameters
    ----------
    args_scheduler : Optional[str]
        Path to the scheduler config file, or None.

    Returns
    -------
    scheduler : Scheduler
        The instantiated scheduler.
    """
    # Set `teardown` for scheduler only to prevent conflicts with other configs.
    # The copy keeps `self.global_config` itself free of the `teardown` key.
    global_config = self.global_config.copy()
    global_config.setdefault("teardown", self.teardown)
    if args_scheduler is None:
        # pylint: disable=import-outside-toplevel
        from mlos_bench.schedulers.sync_scheduler import SyncScheduler

        return SyncScheduler(
            # All config values can be overridden from global config
            config={
                "experiment_id": "UNDEFINED - override from global config",
                "trial_id": 0,
                "config_id": -1,
                "trial_config_repeat_count": 1,
                "teardown": self.teardown,
            },
            # Pass the scheduler-specific copy prepared above (previously it
            # was built but never used).
            # NOTE(review): assumes the Scheduler does not forward this dict
            # to Environment config loading - confirm.
            global_config=global_config,
            environment=self.environment,
            optimizer=self.optimizer,
            storage=self.storage,
            root_env_config=self.root_env_config,
        )
    class_config = self._config_loader.load_config(args_scheduler, ConfigSchema.SCHEDULER)
    assert isinstance(class_config, Dict)
    return self._config_loader.build_scheduler(
        config=class_config,
        global_config=global_config,
        environment=self.environment,
        optimizer=self.optimizer,
        storage=self.storage,
        root_env_config=self.root_env_config,
    )