Coverage for mlos_bench/mlos_bench/launcher.py: 94%

208 statements  

coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A helper class to load the configuration files, parse the command line parameters, and
instantiate the main components of the mlos_bench system.

It is used in the :py:mod:`mlos_bench.run` module to run the benchmark/optimizer
from the command line.
"""

import argparse
import logging
import sys
from collections.abc import Iterable
from typing import Any

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.services.base_service import Service
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue
from mlos_bench.util import try_parse_val

_LOG_LEVEL = logging.INFO
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT)

_LOG = logging.getLogger(__name__)


class Launcher:
    # pylint: disable=too-few-public-methods,too-many-instance-attributes
    """Command line launcher for mlos_bench and mlos_core."""

    def __init__(self, description: str, long_text: str = "", argv: list[str] | None = None):
        # pylint: disable=too-many-statements
        # pylint: disable=too-complex
        # pylint: disable=too-many-locals
        _LOG.info("Launch: %s", description)
        epilog = """
            Additional --key=value pairs can be specified to augment or override
            values listed in --globals.
            Other required_args values can also be pulled from shell environment
            variables.

            For additional details, please see the website or the README.md files in
            the source tree:
            <https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
            """
        parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
        (args, path_args, args_rest) = self._parse_args(parser, argv)

        # Bootstrap config loader: command line takes priority.
        config_path = args.config_path or []
        self._config_loader = ConfigPersistenceService({"config_path": config_path})
        if args.config:
            config = self._config_loader.load_config(args.config, ConfigSchema.CLI)
            assert isinstance(config, dict)
            # Merge the args paths for the config loader with the paths from the JSON file.
            config_path += config.get("config_path", [])
            self._config_loader = ConfigPersistenceService({"config_path": config_path})
        else:
            config = {}

        log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
        try:
            log_level = int(log_level)
        except ValueError:
            # Failed to parse as an int - leave it as a string and let the logging
            # module decide whether it's a valid level name.
            log_level = logging.getLevelName(log_level)
        logging.root.setLevel(log_level)
        log_file = args.log_file or config.get("log_file")
        if log_file:
            log_handler = logging.FileHandler(log_file)
            log_handler.setLevel(log_level)
            log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
            logging.root.addHandler(log_handler)

        # Prepare global_config from a combination of global config files, cli
        # configs, and cli args.
        args_dict = vars(args)
        # teardown (bool) conflicts with Environment configs that use it for shell
        # commands (list), so we exclude it from copying over.
        excluded_cli_args = path_args + ["teardown"]
        # Include (almost) any item from the cli config file that either isn't in
        # the cli args at all or whose cli arg is unset.
        cli_config_args = {
            key: val
            for (key, val) in config.items()
            if (args_dict.get(key) is None) and key not in excluded_cli_args
        }

        self.global_config = self._load_config(
            args_globals=config.get("globals", []) + (args.globals or []),
            config_path=(args.config_path or []) + config.get("config_path", []),
            args_rest=args_rest,
            global_config=cli_config_args,
        )
        # TODO: Can we generalize these two rules using excluded_cli_args?
        # experiment_id is generally taken from --globals files, but we also allow
        # overriding it on the CLI.
        # It's useful to keep it there explicitly, mostly for the --help output.
        if args.experiment_id:
            self.global_config["experiment_id"] = args.experiment_id
        # trial_config_repeat_count is a scheduler property, but it's convenient to
        # set it via the command line.
        if args.trial_config_repeat_count:
            self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
        self.global_config.setdefault("num_trial_runners", 1)
        if args.num_trial_runners:
            self.global_config["num_trial_runners"] = args.num_trial_runners
        if self.global_config["num_trial_runners"] <= 0:
            raise ValueError(
                f"""Invalid num_trial_runners: {self.global_config["num_trial_runners"]}"""
            )
        # Ensure that the trial_id is present since it gets used by some other
        # configs but is typically controlled by the run optimize loop.
        self.global_config.setdefault("trial_id", 1)

        self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
        assert isinstance(self.global_config, dict)

        # --service cli args should override the config file values.
        service_files: list[str] = config.get("services", []) + (args.service or [])
        # Add a LocalExecService as the parent service for all other services.
        self._parent_service: Service = LocalExecService(parent=self._config_loader)
        assert isinstance(self._parent_service, SupportsConfigLoading)
        self._parent_service = self._parent_service.load_services(
            service_files,
            self.global_config,
            self._parent_service,
        )

        self.storage = self._load_storage(
            args.storage or config.get("storage"),
            lazy_schema_create=False if args.create_update_storage_schema_only else None,
        )
        _LOG.info("Init storage: %s", self.storage)

        if args.create_update_storage_schema_only:
            _LOG.info("Create/update storage schema only.")
            self.storage.update_schema()
            sys.exit(0)

        env_path = args.environment or config.get("environment")
        if not env_path:
            _LOG.error("No environment config specified.")
            parser.error(
                "At least the Environment config must be specified."
                " Run `mlos_bench --help` and consult `README.md` for more info."
            )
        self.root_env_config = self._config_loader.resolve_path(env_path)

        # Create the TrialRunners and their Environments and Services from the JSON files.
        self.trial_runners = TrialRunner.create_from_json(
            config_loader=self._config_loader,
            global_config=self.global_config,
            svcs_json=service_files,
            env_json=self.root_env_config,
            num_trial_runners=self.global_config["num_trial_runners"],
        )

        _LOG.info(
            "Init %d trial runners for environments: %s",
            len(self.trial_runners),
            [trial_runner.environment for trial_runner in self.trial_runners],
        )

        # NOTE: Init tunable values *after* the Environment(s), but *before* the Optimizer.
        # TODO: should we assign the same or different tunables for all TrialRunner Environments?
        tunable_values: list[str] | str = config.get("tunable_values", [])
        if isinstance(tunable_values, str):
            tunable_values = [tunable_values]
        tunable_values += args.tunable_values or []
        assert isinstance(tunable_values, list)
        self.tunables = self._init_tunable_values(
            self.trial_runners[0].environment,
            args.random_init or config.get("random_init", False),
            config.get("random_seed") if args.random_seed is None else args.random_seed,
            tunable_values,
        )
        _LOG.info("Init tunables: %s", self.tunables)

        self.optimizer = self._load_optimizer(args.optimizer or config.get("optimizer"))
        _LOG.info("Init optimizer: %s", self.optimizer)

        self.teardown: bool = (
            bool(args.teardown)
            if args.teardown is not None
            else bool(config.get("teardown", True))
        )
        self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
        _LOG.info("Init scheduler: %s", self.scheduler)

    @property
    def config_loader(self) -> ConfigPersistenceService:
        """Get the config loader service."""
        return self._config_loader

    @property
    def root_environment(self) -> Environment:
        """
        Gets the root (prototypical) Environment from the first TrialRunner.

        Note: All TrialRunners have the same Environment config and are made
        unique by the unique trial_runner_id assigned to each TrialRunner's
        Environment's global_config.

        Notes
        -----
        This is mostly for convenience and backwards compatibility.
        """
        return self.trial_runners[0].environment

    @property
    def service(self) -> Service:
        """Get the parent service."""
        return self._parent_service

    @staticmethod
    def _parse_args(
        parser: argparse.ArgumentParser,
        argv: list[str] | None,
    ) -> tuple[argparse.Namespace, list[str], list[str]]:
        """Parse the command line arguments."""

        class PathArgsTracker:
            """Simple class to help track which arguments are paths."""

            def __init__(self, parser: argparse.ArgumentParser):
                self._parser = parser
                self.path_args: list[str] = []

            def add_argument(self, *args: Any, **kwargs: Any) -> None:
                """Add an argument to the parser and track its destination."""
                self.path_args.append(self._parser.add_argument(*args, **kwargs).dest)

        path_args_tracker = PathArgsTracker(parser)

        path_args_tracker.add_argument(
            "--config",
            required=False,
            help=(
                "Main JSON5 configuration file. Its keys are the same as the "
                "command line options and can be overridden by the latter.\n"
                "\n"
                "See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ "
                "for additional config examples for this and other arguments."
            ),
        )

        path_args_tracker.add_argument(
            "--log_file",
            "--log-file",
            required=False,
            help="Path to the log file. Use stdout if omitted.",
        )

        parser.add_argument(
            "--log_level",
            "--log-level",
            required=False,
            type=str,
            help=(
                f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}. "
                "Set to DEBUG for debug, WARNING for warnings only."
            ),
        )

        path_args_tracker.add_argument(
            "--config_path",
            "--config-path",
            "--config-paths",
            "--config_paths",
            nargs="+",
            action="extend",
            required=False,
            help="One or more locations of JSON config files.",
        )

        path_args_tracker.add_argument(
            "--service",
            "--services",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to JSON file with the configuration "
                "of the service(s) for environment(s) to use."
            ),
        )

        path_args_tracker.add_argument(
            "--environment",
            required=False,
            help="Path to JSON file with the configuration of the benchmarking environment(s).",
        )

        path_args_tracker.add_argument(
            "--optimizer",
            required=False,
            help=(
                "Path to the optimizer configuration file. If omitted, run "
                "a single trial with default tunable values (or those specified "
                "via --tunable_values)."
            ),
        )

        parser.add_argument(
            "--trial_config_repeat_count",
            "--trial-config-repeat-count",
            required=False,
            type=int,
            help=(
                "Number of times to repeat each config. "
                "Default is 1 trial per config, though more may be advised."
            ),
        )

        parser.add_argument(
            "--num_trial_runners",
            "--num-trial-runners",
            required=False,
            type=int,
            help=(
                "Number of TrialRunners to use for executing benchmark Environments. "
                "Individual TrialRunners can be identified in configs with $trial_runner_id "
                "and optionally run in parallel."
            ),
        )

        path_args_tracker.add_argument(
            "--scheduler",
            required=False,
            help=(
                "Path to the scheduler configuration file. By default, use "
                "a single worker synchronous scheduler."
            ),
        )

        path_args_tracker.add_argument(
            "--storage",
            required=False,
            help=(
                "Path to the storage configuration file. "
                "If omitted, use the ephemeral in-memory SQL storage."
            ),
        )

        parser.add_argument(
            "--random_init",
            "--random-init",
            required=False,
            default=False,
            dest="random_init",
            action="store_true",
            help="Initialize tunables with random values. (Before applying --tunable_values).",
        )

        parser.add_argument(
            "--random_seed",
            "--random-seed",
            required=False,
            type=int,
            help="Seed to use with --random_init.",
        )

        path_args_tracker.add_argument(
            "--tunable_values",
            "--tunable-values",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to one or more JSON files that contain values of the tunable "
                "parameters. This can be used for a single trial (when no --optimizer "
                "is specified) or as default values for the first run in optimization."
            ),
        )

        path_args_tracker.add_argument(
            "--globals",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to one or more JSON files that contain additional "
                "[private] parameters of the benchmarking environment."
            ),
        )

        parser.add_argument(
            "--no_teardown",
            "--no-teardown",
            required=False,
            default=None,
            dest="teardown",
            action="store_false",
            help="Disable teardown of the environment after the benchmark.",
        )

        parser.add_argument(
            "--experiment_id",
            "--experiment-id",
            required=False,
            default=None,
            help="""
                Experiment ID to use for the benchmark.
                If omitted, the value from the --cli config or --globals is used.

                This is used to store and reload trial results from the storage.
                NOTE: It is **important** to change this value when incompatible
                changes are made to config files, scripts, versions, etc.
                This is left as a manual operation as detection of what is
                "incompatible" is not easily automatable across systems.
                """,
        )

        parser.add_argument(
            "--create-update-storage-schema-only",
            required=False,
            default=False,
            dest="create_update_storage_schema_only",
            action="store_true",
            help=(
                "Makes sure that the storage schema is up to date "
                "for the current version of mlos_bench."
            ),
        )

        # By default we use the command line arguments, but allow the caller to
        # provide some explicitly for testing purposes.
        if argv is None:
            argv = sys.argv[1:].copy()
        (args, args_rest) = parser.parse_known_args(argv)

        return (args, path_args_tracker.path_args, args_rest)
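
    # Illustrative behavior of the parse_known_args() call above (argv values
    # are hypothetical): unrecognized options are returned separately in
    # args_rest, e.g.
    #   _parse_args(parser, ["--environment", "env.jsonc", "--vmSize=Standard_B2s"])
    # yields args.environment == "env.jsonc" and
    # args_rest == ["--vmSize=Standard_B2s"].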

    @staticmethod
    def _try_parse_extra_args(cmdline: Iterable[str]) -> dict[str, TunableValue]:
        """Helper function to parse global key/value pairs from the command line."""
        _LOG.debug("Extra args: %s", cmdline)

        config: dict[str, TunableValue] = {}
        key = None
        for elem in cmdline:
            if elem.startswith("--"):
                if key is not None:
                    raise ValueError("Command line argument has no value: " + key)
                key = elem[2:]
                kv_split = key.split("=", 1)
                if len(kv_split) == 2:
                    config[kv_split[0].strip()] = try_parse_val(kv_split[1])
                    key = None
            else:
                if key is None:
                    raise ValueError("Command line argument has no key: " + elem)
                config[key.strip()] = try_parse_val(elem)
                key = None

        if key is not None:
            # Handles missing trailing elem from last --key arg.
            raise ValueError("Command line argument has no value: " + key)

        # Convert "max-suggestions" to "max_suggestions" for compatibility with
        # other CLI options to use as common python/json variable replacements.
        config = {k.replace("-", "_"): v for k, v in config.items()}

        _LOG.debug("Parsed config: %s", config)
        return config
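
    # Example of the parsing rules above (key names are hypothetical; assumes
    # try_parse_val() coerces numeric strings to numbers, per its use here):
    #   _try_parse_extra_args(["--max-suggestions", "10", "--storage-host=localhost"])
    #   -> {"max_suggestions": 10, "storage_host": "localhost"}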

    def _load_config(
        self,
        *,
        args_globals: Iterable[str],
        config_path: Iterable[str],
        args_rest: Iterable[str],
        global_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Get key/value pairs of the global configuration parameters from the specified
        config files (if any) and command line arguments.
        """
        for config_file in args_globals or []:
            conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS)
            assert isinstance(conf, dict)
            global_config.update(conf)
        global_config.update(Launcher._try_parse_extra_args(args_rest))
        if config_path:
            global_config["config_path"] = config_path
        return global_config
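
    # Precedence, per the update order above: later --globals files override
    # earlier ones, and explicit --key=value command line pairs override them
    # all. E.g. (hypothetical file and key):
    #   globals-a.jsonc: {"experiment_id": "exp-a"}
    #   args_rest:       ["--experiment_id=exp-cli"]
    #   -> global_config["experiment_id"] == "exp-cli"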

    def _init_tunable_values(
        self,
        env: Environment,
        random_init: bool,
        seed: int | None,
        args_tunables: list[str] | None,
    ) -> TunableGroups:
        """Initialize the tunables and load key/value pairs of the tunable values from
        the given JSON files, if specified.
        """
        tunables = env.tunable_params
        _LOG.debug("Init tunables: default = %s", tunables)

        if random_init:
            tunables = MockOptimizer(
                tunables=tunables,
                service=None,
                config={"start_with_defaults": False, "seed": seed},
            ).suggest()
            _LOG.debug("Init tunables: random = %s", tunables)

        if args_tunables is not None:
            for data_file in args_tunables:
                values = self._config_loader.load_config(data_file, ConfigSchema.TUNABLE_VALUES)
                assert isinstance(values, dict)
                tunables.assign(values)
                _LOG.debug("Init tunables: load %s = %s", data_file, tunables)

        return tunables
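
    # A --tunable_values file is a flat JSON dict of tunable name -> value,
    # validated against ConfigSchema.TUNABLE_VALUES. A minimal sketch with
    # hypothetical tunable names:
    #   {
    #     "vm_size": "Standard_B2s",
    #     "kernel_sched_latency_ns": 2000000
    #   }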

    def _load_optimizer(self, args_optimizer: str | None) -> Optimizer:
        """
        Instantiate the Optimizer object from the JSON config file, if specified in
        the --optimizer command line option.

        If no config file is specified, create a one-shot optimizer to run a single
        benchmark trial.
        """
        if args_optimizer is None:
            # global_config may contain additional properties, so we need to
            # strip those out before instantiating the basic oneshot optimizer.
            config = {
                key: val
                for key, val in self.global_config.items()
                if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS
            }
            return OneShotOptimizer(self.tunables, config=config, service=self._parent_service)
        class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
        assert isinstance(class_config, dict)
        optimizer = self._config_loader.build_optimizer(
            tunables=self.tunables,
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )
        return optimizer

    def _load_storage(
        self,
        args_storage: str | None,
        lazy_schema_create: bool | None = None,
    ) -> Storage:
        """
        Instantiate the Storage object from the JSON file provided in the --storage
        command line parameter.

        If omitted, create an ephemeral in-memory SQL storage instead.
        """
        if args_storage is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.storage.sql.storage import SqlStorage

            return SqlStorage(
                service=self._parent_service,
                config={
                    "drivername": "sqlite",
                    "database": ":memory:",
                    "lazy_schema_create": True,
                },
            )
        class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
        assert isinstance(class_config, dict)
        if lazy_schema_create is not None:
            class_config["lazy_schema_create"] = lazy_schema_create
        storage = self._config_loader.build_storage(
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )
        return storage
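
    # Sketch of a --storage config file, assuming the class/config layout used
    # by other mlos_bench configs (file name and database path are hypothetical):
    #   {
    #     "class": "mlos_bench.storage.sql.storage.SqlStorage",
    #     "config": {
    #       "drivername": "sqlite",
    #       "database": "mlos_bench.sqlite",
    #       "lazy_schema_create": true
    #     }
    #   }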

    def _load_scheduler(self, args_scheduler: str | None) -> Scheduler:
        """
        Instantiate the Scheduler object from the JSON file provided in the
        --scheduler command line parameter.

        Create a simple synchronous single-threaded scheduler if omitted.
        """
        # Set `teardown` for the scheduler only, to prevent conflicts with other configs.
        global_config = self.global_config.copy()
        global_config.setdefault("teardown", self.teardown)
        if args_scheduler is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.schedulers.sync_scheduler import SyncScheduler

            return SyncScheduler(
                # All config values can be overridden from the global config.
                config={
                    "experiment_id": "DEFAULT_EXPERIMENT_ID",
                    "trial_id": 0,
                    "config_id": -1,
                    "trial_config_repeat_count": 1,
                    "teardown": self.teardown,
                },
                global_config=global_config,
                trial_runners=self.trial_runners,
                optimizer=self.optimizer,
                storage=self.storage,
                root_env_config=self.root_env_config,
            )
        class_config = self._config_loader.load_config(args_scheduler, ConfigSchema.SCHEDULER)
        assert isinstance(class_config, dict)
        return self._config_loader.build_scheduler(
            config=class_config,
            global_config=global_config,
            trial_runners=self.trial_runners,
            optimizer=self.optimizer,
            storage=self.storage,
            root_env_config=self.root_env_config,
        )
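
    # Sketch of a --scheduler config file, again assuming the class/config
    # layout (the config keys shown are hypothetical values to override):
    #   {
    #     "class": "mlos_bench.schedulers.sync_scheduler.SyncScheduler",
    #     "config": {
    #       "trial_config_repeat_count": 3,
    #       "max_trials": -1
    #     }
    #   }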