Coverage for mlos_bench/mlos_bench/launcher.py: 94%
208 statements
coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A helper class to load the configuration files, parse the command line parameters, and
instantiate the main components of the mlos_bench system.

It is used in the :py:mod:`mlos_bench.run` module to run the benchmark/optimizer
from the command line.
"""
import argparse
import logging
import sys
from collections.abc import Iterable
from typing import Any

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.services.base_service import Service
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue
from mlos_bench.util import try_parse_val

_LOG_LEVEL = logging.INFO
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT)

_LOG = logging.getLogger(__name__)


class Launcher:
    # pylint: disable=too-few-public-methods,too-many-instance-attributes
    """Command line launcher for mlos_bench and mlos_core."""

    def __init__(self, description: str, long_text: str = "", argv: list[str] | None = None):
        # pylint: disable=too-many-statements
        # pylint: disable=too-complex
        # pylint: disable=too-many-locals
        _LOG.info("Launch: %s", description)
        epilog = """
            Additional --key=value pairs can be specified to augment or override
            values listed in --globals.
            Other required_args values can also be pulled from shell environment
            variables.

            For additional details, please see the website or the README.md files in
            the source tree:
            <https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
            """
        parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
        (args, path_args, args_rest) = self._parse_args(parser, argv)

        # Bootstrap config loader: command line takes priority.
        config_path = args.config_path or []
        self._config_loader = ConfigPersistenceService({"config_path": config_path})
        if args.config:
            config = self._config_loader.load_config(args.config, ConfigSchema.CLI)
            assert isinstance(config, dict)
            # Merge the args paths for the config loader with the paths from the JSON file.
            config_path += config.get("config_path", [])
            self._config_loader = ConfigPersistenceService({"config_path": config_path})
        else:
            config = {}

        log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
        try:
            log_level = int(log_level)
        except ValueError:
            # Failed to parse as an int; leave it as a string and let the logging
            # module decide whether it's a valid level name.
            log_level = logging.getLevelName(log_level)
        logging.root.setLevel(log_level)
        log_file = args.log_file or config.get("log_file")
        if log_file:
            log_handler = logging.FileHandler(log_file)
            log_handler.setLevel(log_level)
            log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
            logging.root.addHandler(log_handler)

        # Prepare global_config from a combination of global config files, CLI
        # configs, and CLI args.
        args_dict = vars(args)
        # teardown (bool) conflicts with Environment configs that use it for shell
        # commands (list), so we exclude it from copying over.
        excluded_cli_args = path_args + ["teardown"]
        # Include (almost) any item from the CLI config file whose corresponding
        # CLI arg was not given.
        cli_config_args = {
            key: val
            for (key, val) in config.items()
            if (args_dict.get(key) is None) and key not in excluded_cli_args
        }

        self.global_config = self._load_config(
            args_globals=config.get("globals", []) + (args.globals or []),
            config_path=(args.config_path or []) + config.get("config_path", []),
            args_rest=args_rest,
            global_config=cli_config_args,
        )
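        # Precedence, lowest to highest (a hypothetical illustration): a key set
        # in the --config CLI config file is overridden by the same key in a
        # --globals file, which in turn is overridden by an explicit
        # --key=value argument, since _load_config() applies the globals files
        # and then the extra CLI args on top of cli_config_args.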
        # TODO: Can we generalize these two rules using excluded_cli_args?
        # experiment_id is generally taken from --globals files, but we also allow
        # overriding it on the CLI.
        # It's useful to keep it there explicitly mostly for the --help output.
        if args.experiment_id:
            self.global_config["experiment_id"] = args.experiment_id
        # trial_config_repeat_count is a scheduler property but it's convenient to
        # set it via command line.
        if args.trial_config_repeat_count:
            self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
        self.global_config.setdefault("num_trial_runners", 1)
        if args.num_trial_runners:
            self.global_config["num_trial_runners"] = args.num_trial_runners
        if self.global_config["num_trial_runners"] <= 0:
            raise ValueError(
                f"""Invalid num_trial_runners: {self.global_config["num_trial_runners"]}"""
            )
        # Ensure that the trial_id is present since it gets used by some other
        # configs but is typically controlled by the run optimize loop.
        self.global_config.setdefault("trial_id", 1)

        self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
        assert isinstance(self.global_config, dict)
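        # A minimal sketch of the expansion above (values are hypothetical):
        #
        #   DictTemplater({"experiment_id": "exp1",
        #                  "results_dir": "/tmp/$experiment_id"}).expand_vars()
        #
        # yields {"experiment_id": "exp1", "results_dir": "/tmp/exp1"};
        # use_os_env=True additionally allows $-references to shell environment
        # variables.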
        # --service CLI args should override the config file values.
        service_files: list[str] = config.get("services", []) + (args.service or [])
        # Add a LocalExecService as the parent service for all other services.
        self._parent_service: Service = LocalExecService(parent=self._config_loader)
        assert isinstance(self._parent_service, SupportsConfigLoading)
        self._parent_service = self._parent_service.load_services(
            service_files,
            self.global_config,
            self._parent_service,
        )

        self.storage = self._load_storage(
            args.storage or config.get("storage"),
            lazy_schema_create=False if args.create_update_storage_schema_only else None,
        )
        _LOG.info("Init storage: %s", self.storage)

        if args.create_update_storage_schema_only:
            _LOG.info("Create/update storage schema only.")
            self.storage.update_schema()
            sys.exit(0)

        env_path = args.environment or config.get("environment")
        if not env_path:
            _LOG.error("No environment config specified.")
            parser.error(
                "At least the Environment config must be specified."
                " Run `mlos_bench --help` and consult `README.md` for more info."
            )
        self.root_env_config = self._config_loader.resolve_path(env_path)

        # Create the TrialRunners and their Environments and Services from the JSON files.
        self.trial_runners = TrialRunner.create_from_json(
            config_loader=self._config_loader,
            global_config=self.global_config,
            svcs_json=service_files,
            env_json=self.root_env_config,
            num_trial_runners=self.global_config["num_trial_runners"],
        )

        _LOG.info(
            "Init %d trial runners for environments: %s",
            len(self.trial_runners),
            [trial_runner.environment for trial_runner in self.trial_runners],
        )

        # NOTE: Init tunable values *after* the Environment(s), but *before* the Optimizer.
        # TODO: Should we assign the same or different tunables for all TrialRunner Environments?
        tunable_values: list[str] | str = config.get("tunable_values", [])
        if isinstance(tunable_values, str):
            tunable_values = [tunable_values]
        tunable_values += args.tunable_values or []
        assert isinstance(tunable_values, list)
        self.tunables = self._init_tunable_values(
            self.trial_runners[0].environment,
            args.random_init or config.get("random_init", False),
            config.get("random_seed") if args.random_seed is None else args.random_seed,
            tunable_values,
        )
        _LOG.info("Init tunables: %s", self.tunables)

        self.optimizer = self._load_optimizer(args.optimizer or config.get("optimizer"))
        _LOG.info("Init optimizer: %s", self.optimizer)

        self.teardown: bool = (
            bool(args.teardown)
            if args.teardown is not None
            else bool(config.get("teardown", True))
        )
        self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
        _LOG.info("Init scheduler: %s", self.scheduler)
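    # A minimal usage sketch (config paths here are hypothetical):
    #
    #   launcher = Launcher(
    #       "mlos_bench",
    #       "Run a benchmark experiment.",
    #       argv=["--config", "my-cli-config.jsonc", "--experiment_id", "MyExperiment"],
    #   )
    #   with launcher.scheduler as scheduler_context:
    #       scheduler_context.start()
    #       scheduler_context.teardown()
    #
    # This is roughly how :py:mod:`mlos_bench.run` drives the Launcher from the
    # command line.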
    @property
    def config_loader(self) -> ConfigPersistenceService:
        """Get the config loader service."""
        return self._config_loader

    @property
    def root_environment(self) -> Environment:
        """
        Gets the root (prototypical) Environment from the first TrialRunner.

        Note: All TrialRunners share the same Environment config and are made
        unique by the unique trial_runner_id assigned to each TrialRunner's
        Environment's global_config.

        Notes
        -----
        This is mostly for convenience and backwards compatibility.
        """
        return self.trial_runners[0].environment

    @property
    def service(self) -> Service:
        """Get the parent service."""
        return self._parent_service
    @staticmethod
    def _parse_args(
        parser: argparse.ArgumentParser,
        argv: list[str] | None,
    ) -> tuple[argparse.Namespace, list[str], list[str]]:
        """Parse the command line arguments."""

        class PathArgsTracker:
            """Simple class to help track which arguments are paths."""

            def __init__(self, parser: argparse.ArgumentParser):
                self._parser = parser
                self.path_args: list[str] = []

            def add_argument(self, *args: Any, **kwargs: Any) -> None:
                """Add an argument to the parser and track its destination."""
                self.path_args.append(self._parser.add_argument(*args, **kwargs).dest)

        path_args_tracker = PathArgsTracker(parser)

        path_args_tracker.add_argument(
            "--config",
            required=False,
            help=(
                "Main JSON5 configuration file. Its keys are the same as the "
                "command line options and can be overridden by the latter.\n"
                "\n"
                "See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ "
                "for additional config examples for this and other arguments."
            ),
        )

        path_args_tracker.add_argument(
            "--log_file",
            "--log-file",
            required=False,
            help="Path to the log file. Use stdout if omitted.",
        )

        parser.add_argument(
            "--log_level",
            "--log-level",
            required=False,
            type=str,
            help=(
                f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}. "
                "Set to DEBUG for debug output or WARNING for warnings only."
            ),
        )

        path_args_tracker.add_argument(
            "--config_path",
            "--config-path",
            "--config-paths",
            "--config_paths",
            nargs="+",
            action="extend",
            required=False,
            help="One or more locations of JSON config files.",
        )

        path_args_tracker.add_argument(
            "--service",
            "--services",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to JSON file with the configuration "
                "of the service(s) for environment(s) to use."
            ),
        )

        path_args_tracker.add_argument(
            "--environment",
            required=False,
            help="Path to JSON file with the configuration of the benchmarking environment(s).",
        )

        path_args_tracker.add_argument(
            "--optimizer",
            required=False,
            help=(
                "Path to the optimizer configuration file. If omitted, run "
                "a single trial with the default tunable values (or those "
                "specified in --tunable_values)."
            ),
        )

        parser.add_argument(
            "--trial_config_repeat_count",
            "--trial-config-repeat-count",
            required=False,
            type=int,
            help=(
                "Number of times to repeat each config. "
                "Default is 1 trial per config, though more may be advised."
            ),
        )

        parser.add_argument(
            "--num_trial_runners",
            "--num-trial-runners",
            required=False,
            type=int,
            help=(
                "Number of TrialRunners to use for executing benchmark Environments. "
                "Individual TrialRunners can be identified in configs with $trial_runner_id "
                "and optionally run in parallel."
            ),
        )

        path_args_tracker.add_argument(
            "--scheduler",
            required=False,
            help=(
                "Path to the scheduler configuration file. By default, use "
                "a single worker synchronous scheduler."
            ),
        )

        path_args_tracker.add_argument(
            "--storage",
            required=False,
            help=(
                "Path to the storage configuration file. "
                "If omitted, use the ephemeral in-memory SQL storage."
            ),
        )

        parser.add_argument(
            "--random_init",
            "--random-init",
            required=False,
            default=False,
            dest="random_init",
            action="store_true",
            help="Initialize tunables with random values (before applying --tunable_values).",
        )

        parser.add_argument(
            "--random_seed",
            "--random-seed",
            required=False,
            type=int,
            help="Seed to use with --random_init.",
        )

        path_args_tracker.add_argument(
            "--tunable_values",
            "--tunable-values",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to one or more JSON files that contain values of the tunable "
                "parameters. This can be used for a single trial (when no --optimizer "
                "is specified) or as default values for the first run in optimization."
            ),
        )

        path_args_tracker.add_argument(
            "--globals",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to one or more JSON files that contain additional "
                "[private] parameters of the benchmarking environment."
            ),
        )

        parser.add_argument(
            "--no_teardown",
            "--no-teardown",
            required=False,
            default=None,
            dest="teardown",
            action="store_false",
            help="Disable teardown of the environment after the benchmark.",
        )

        parser.add_argument(
            "--experiment_id",
            "--experiment-id",
            required=False,
            default=None,
            help="""
                Experiment ID to use for the benchmark.
                If omitted, the value from the --config file or --globals is used.

                This is used to store and reload trial results from the storage.
                NOTE: It is **important** to change this value when incompatible
                changes are made to config files, scripts, versions, etc.
                This is left as a manual operation as detection of what is
                "incompatible" is not easily automatable across systems.
                """,
        )

        parser.add_argument(
            "--create-update-storage-schema-only",
            required=False,
            default=False,
            dest="create_update_storage_schema_only",
            action="store_true",
            help=(
                "Make sure the storage schema is up to date "
                "for the current version of mlos_bench, then exit."
            ),
        )

        # By default we use the command line arguments, but allow the caller to
        # provide some explicitly for testing purposes.
        if argv is None:
            argv = sys.argv[1:].copy()
        (args, args_rest) = parser.parse_known_args(argv)

        return (args, path_args_tracker.path_args, args_rest)
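    # A hypothetical CLI config file (--config) covering the options above;
    # its keys mirror the long option names (paths are illustrative only):
    #
    #   {
    #       "config_path": ["mlos_bench/config"],
    #       "environment": "environments/my-env.jsonc",
    #       "globals": ["experiments/my-experiment.jsonc"],
    #       "optimizer": "optimizers/my-optimizer.jsonc",
    #       "storage": "storage/my-storage.jsonc",
    #       "log_level": "INFO",
    #       "trial_config_repeat_count": 3
    #   }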
    @staticmethod
    def _try_parse_extra_args(cmdline: Iterable[str]) -> dict[str, TunableValue]:
        """Helper function to parse global key/value pairs from the command line."""
        _LOG.debug("Extra args: %s", cmdline)

        config: dict[str, TunableValue] = {}
        key = None
        for elem in cmdline:
            if elem.startswith("--"):
                if key is not None:
                    raise ValueError("Command line argument has no value: " + key)
                key = elem[2:]
                kv_split = key.split("=", 1)
                if len(kv_split) == 2:
                    config[kv_split[0].strip()] = try_parse_val(kv_split[1])
                    key = None
            else:
                if key is None:
                    raise ValueError("Command line argument has no key: " + elem)
                config[key.strip()] = try_parse_val(elem)
                key = None

        if key is not None:
            # Handles a missing trailing value for the last --key arg.
            raise ValueError("Command line argument has no value: " + key)

        # Convert "max-suggestions" to "max_suggestions" for compatibility with
        # other CLI options, to use as common python/json variable replacements.
        config = {k.replace("-", "_"): v for k, v in config.items()}

        _LOG.debug("Parsed config: %s", config)
        return config
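    # For example, given both supported spellings (keys here are hypothetical):
    #
    #   Launcher._try_parse_extra_args(["--max-suggestions=10", "--my_seed", "42"])
    #
    # returns {"max_suggestions": 10, "my_seed": 42}, since try_parse_val()
    # coerces numeric strings and dashes are normalized to underscores.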
    def _load_config(
        self,
        *,
        args_globals: Iterable[str],
        config_path: Iterable[str],
        args_rest: Iterable[str],
        global_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Get key/value pairs of the global configuration parameters from the
        specified config files (if any) and command line arguments.
        """
        for config_file in args_globals or []:
            conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS)
            assert isinstance(conf, dict)
            global_config.update(conf)
        global_config.update(Launcher._try_parse_extra_args(args_rest))
        if config_path:
            global_config["config_path"] = config_path
        return global_config
    def _init_tunable_values(
        self,
        env: Environment,
        random_init: bool,
        seed: int | None,
        args_tunables: list[str] | None,
    ) -> TunableGroups:
        """Initialize the tunables and load key/value pairs of the tunable values
        from the given JSON files, if specified.
        """
        tunables = env.tunable_params
        _LOG.debug("Init tunables: default = %s", tunables)

        if random_init:
            tunables = MockOptimizer(
                tunables=tunables,
                service=None,
                config={"start_with_defaults": False, "seed": seed},
            ).suggest()
            _LOG.debug("Init tunables: random = %s", tunables)

        if args_tunables is not None:
            for data_file in args_tunables:
                values = self._config_loader.load_config(data_file, ConfigSchema.TUNABLE_VALUES)
                assert isinstance(values, dict)
                tunables.assign(values)
                _LOG.debug("Init tunables: load %s = %s", data_file, tunables)

        return tunables
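    # A --tunable_values file is a flat JSON dict of parameter names to values,
    # e.g. (hypothetical tunables):
    #
    #   {
    #       "vm_size": "Standard_B2s",
    #       "kernel_sched_latency_ns": 12000000
    #   }
    #
    # Each file is validated against ConfigSchema.TUNABLE_VALUES and applied
    # via tunables.assign() above.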
    def _load_optimizer(self, args_optimizer: str | None) -> Optimizer:
        """
        Instantiate the Optimizer object from the JSON config file, if specified
        via the --optimizer command line option.

        If no config file is specified, create a one-shot optimizer to run a single
        benchmark trial.
        """
        if args_optimizer is None:
            # global_config may contain additional properties, so we need to
            # strip those out before instantiating the basic one-shot optimizer.
            config = {
                key: val
                for key, val in self.global_config.items()
                if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS
            }
            return OneShotOptimizer(self.tunables, config=config, service=self._parent_service)
        class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
        assert isinstance(class_config, dict)
        optimizer = self._config_loader.build_optimizer(
            tunables=self.tunables,
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )
        return optimizer
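    # A hypothetical --optimizer config file sketch; mlos_bench configs name
    # the class to instantiate plus its nested config (exact keys depend on
    # the OPTIMIZER schema):
    #
    #   {
    #       "class": "mlos_bench.optimizers.MlosCoreOptimizer",
    #       "config": {
    #           "max_suggestions": 100
    #       }
    #   }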
    def _load_storage(
        self,
        args_storage: str | None,
        lazy_schema_create: bool | None = None,
    ) -> Storage:
        """
        Instantiate the Storage object from the JSON file provided via the
        --storage command line parameter.

        If omitted, create an ephemeral in-memory SQL storage instead.
        """
        if args_storage is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.storage.sql.storage import SqlStorage

            return SqlStorage(
                service=self._parent_service,
                config={
                    "drivername": "sqlite",
                    "database": ":memory:",
                    "lazy_schema_create": True,
                },
            )
        class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
        assert isinstance(class_config, dict)
        if lazy_schema_create is not None:
            class_config["lazy_schema_create"] = lazy_schema_create
        storage = self._config_loader.build_storage(
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )
        return storage
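    # A hypothetical --storage config file for a persistent sqlite database,
    # mirroring the SqlStorage config keys used in the fallback above:
    #
    #   {
    #       "class": "mlos_bench.storage.sql.storage.SqlStorage",
    #       "config": {
    #           "drivername": "sqlite",
    #           "database": "mlos_bench.sqlite",
    #           "lazy_schema_create": true
    #       }
    #   }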
    def _load_scheduler(self, args_scheduler: str | None) -> Scheduler:
        """
        Instantiate the Scheduler object from the JSON file provided via the
        --scheduler command line parameter.

        Create a simple synchronous single-threaded scheduler if omitted.
        """
        # Set `teardown` in a scheduler-only copy of the global config to prevent
        # conflicts with other configs; pass the copy below so this default
        # actually takes effect (the original code computed it but never used it).
        global_config = self.global_config.copy()
        global_config.setdefault("teardown", self.teardown)
        if args_scheduler is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.schedulers.sync_scheduler import SyncScheduler

            return SyncScheduler(
                # All config values can be overridden from the global config.
                config={
                    "experiment_id": "DEFAULT_EXPERIMENT_ID",
                    "trial_id": 0,
                    "config_id": -1,
                    "trial_config_repeat_count": 1,
                    "teardown": self.teardown,
                },
                global_config=global_config,
                trial_runners=self.trial_runners,
                optimizer=self.optimizer,
                storage=self.storage,
                root_env_config=self.root_env_config,
            )
        class_config = self._config_loader.load_config(args_scheduler, ConfigSchema.SCHEDULER)
        assert isinstance(class_config, dict)
        return self._config_loader.build_scheduler(
            config=class_config,
            global_config=global_config,
            trial_runners=self.trial_runners,
            optimizer=self.optimizer,
            storage=self.storage,
            root_env_config=self.root_env_config,
        )
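    # A hypothetical --scheduler config file sketch, following the same
    # class/config layout as the other mlos_bench configs (exact keys depend
    # on the SCHEDULER schema):
    #
    #   {
    #       "class": "mlos_bench.schedulers.SyncScheduler",
    #       "config": {
    #           "trial_config_repeat_count": 3
    #       }
    #   }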