Coverage for mlos_bench/mlos_bench/launcher.py: 96%
185 statements
coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A helper class to load the configuration files, parse the command line parameters,
and instantiate the main components of the mlos_bench system.

It is used in the :py:mod:`mlos_bench.run` module to run the benchmark/optimizer
from the command line.
"""

import argparse
import logging
import sys
from typing import Any, Dict, Iterable, List, Optional, Tuple

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.services.base_service import Service
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import try_parse_val

_LOG_LEVEL = logging.INFO
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT)

_LOG = logging.getLogger(__name__)


class Launcher:
    # pylint: disable=too-few-public-methods,too-many-instance-attributes
    """Command line launcher for mlos_bench and mlos_core."""

    def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
        # pylint: disable=too-many-statements
        # pylint: disable=too-many-locals
        _LOG.info("Launch: %s", description)
        epilog = """
        Additional --key=value pairs can be specified to augment or override
        values listed in --globals.
        Other required_args values can also be pulled from shell environment
        variables.

        For additional details, please see the website or the README.md files in
        the source tree:
        <https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
        """
        parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
        (args, path_args, args_rest) = self._parse_args(parser, argv)

        # Bootstrap config loader: command line takes priority.
        config_path = args.config_path or []
        self._config_loader = ConfigPersistenceService({"config_path": config_path})
        if args.config:
            config = self._config_loader.load_config(args.config, ConfigSchema.CLI)
            assert isinstance(config, Dict)
            # Merge the args paths for the config loader with the paths from JSON file.
            config_path += config.get("config_path", [])
            self._config_loader = ConfigPersistenceService({"config_path": config_path})
        else:
            config = {}

        log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
        try:
            log_level = int(log_level)
        except ValueError:
            # failed to parse as an int - leave it as a string and let logging
            # module handle whether it's an appropriate log name or not
            log_level = logging.getLevelName(log_level)
        logging.root.setLevel(log_level)
        log_file = args.log_file or config.get("log_file")
        if log_file:
            log_handler = logging.FileHandler(log_file)
            log_handler.setLevel(log_level)
            log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
            logging.root.addHandler(log_handler)

        self._parent_service: Service = LocalExecService(parent=self._config_loader)

        # Prepare global_config from a combination of global config files, cli
        # configs, and cli args.
        args_dict = vars(args)
        # teardown (bool) conflicts with Environment configs that use it for shell
        # commands (list), so we exclude it from copying over
        excluded_cli_args = path_args + ["teardown"]
        # Include (almost) any item from the cli config file that either isn't in
        # the cli args at all or whose cli arg is missing.
        cli_config_args = {
            key: val
            for (key, val) in config.items()
            if (args_dict.get(key) is None) and key not in excluded_cli_args
        }

        self.global_config = self._load_config(
            args_globals=config.get("globals", []) + (args.globals or []),
            config_path=(args.config_path or []) + config.get("config_path", []),
            args_rest=args_rest,
            global_config=cli_config_args,
        )
        # experiment_id is generally taken from --globals files, but we also allow
        # overriding it on the CLI.
        # It's useful to keep it there explicitly mostly for the --help output.
        if args.experiment_id:
            self.global_config["experiment_id"] = args.experiment_id
        # trial_config_repeat_count is a scheduler property but it's convenient to
        # set it via command line
        if args.trial_config_repeat_count:
            self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
        # Ensure that the trial_id is present since it gets used by some other
        # configs but is typically controlled by the run optimize loop.
        self.global_config.setdefault("trial_id", 1)

        self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
        assert isinstance(self.global_config, dict)

        # --service cli args should override the config file values.
        service_files: List[str] = config.get("services", []) + (args.service or [])
        assert isinstance(self._parent_service, SupportsConfigLoading)
        self._parent_service = self._parent_service.load_services(
            service_files,
            self.global_config,
            self._parent_service,
        )

        env_path = args.environment or config.get("environment")
        if not env_path:
            _LOG.error("No environment config specified.")
            parser.error(
                "At least the Environment config must be specified."
                + " Run `mlos_bench --help` and consult `README.md` for more info."
            )
        self.root_env_config = self._config_loader.resolve_path(env_path)

        self.environment: Environment = self._config_loader.load_environment(
            self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
        )
        _LOG.info("Init environment: %s", self.environment)

        # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
        self.tunables = self._init_tunable_values(
            args.random_init or config.get("random_init", False),
            config.get("random_seed") if args.random_seed is None else args.random_seed,
            config.get("tunable_values", []) + (args.tunable_values or []),
        )
        _LOG.info("Init tunables: %s", self.tunables)

        self.optimizer = self._load_optimizer(args.optimizer or config.get("optimizer"))
        _LOG.info("Init optimizer: %s", self.optimizer)

        self.storage = self._load_storage(args.storage or config.get("storage"))
        _LOG.info("Init storage: %s", self.storage)

        self.teardown: bool = (
            bool(args.teardown)
            if args.teardown is not None
            else bool(config.get("teardown", True))
        )
        self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
        _LOG.info("Init scheduler: %s", self.scheduler)

    @property
    def config_loader(self) -> ConfigPersistenceService:
        """Get the config loader service."""
        return self._config_loader

    @property
    def service(self) -> Service:
        """Get the parent service."""
        return self._parent_service

    @staticmethod
    def _parse_args(
        parser: argparse.ArgumentParser,
        argv: Optional[List[str]],
    ) -> Tuple[argparse.Namespace, List[str], List[str]]:
        """Parse the command line arguments."""

        class PathArgsTracker:
            """Simple class to help track which arguments are paths."""

            def __init__(self, parser: argparse.ArgumentParser):
                self._parser = parser
                self.path_args: List[str] = []

            def add_argument(self, *args: Any, **kwargs: Any) -> None:
                """Add an argument to the parser and track its destination."""
                self.path_args.append(self._parser.add_argument(*args, **kwargs).dest)

        path_args_tracker = PathArgsTracker(parser)

        path_args_tracker.add_argument(
            "--config",
            required=False,
            help=(
                "Main JSON5 configuration file. Its keys are the same as the "
                "command line options and can be overridden by the latter.\n"
                "\n"
                "See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ "
                "for additional config examples for this and other arguments."
            ),
        )

        path_args_tracker.add_argument(
            "--log_file",
            "--log-file",
            required=False,
            help="Path to the log file. Use stdout if omitted.",
        )

        parser.add_argument(
            "--log_level",
            "--log-level",
            required=False,
            type=str,
            help=(
                f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}. "
                "Set to DEBUG for debug, WARNING for warnings only."
            ),
        )

        path_args_tracker.add_argument(
            "--config_path",
            "--config-path",
            "--config-paths",
            "--config_paths",
            nargs="+",
            action="extend",
            required=False,
            help="One or more locations of JSON config files.",
        )

        path_args_tracker.add_argument(
            "--service",
            "--services",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to JSON file with the configuration "
                "of the service(s) for environment(s) to use."
            ),
        )

        path_args_tracker.add_argument(
            "--environment",
            required=False,
            help="Path to JSON file with the configuration of the benchmarking environment(s).",
        )

        path_args_tracker.add_argument(
            "--optimizer",
            required=False,
            help=(
                "Path to the optimizer configuration file. If omitted, run "
                "a single trial with the default tunable values (or those "
                "specified in --tunable_values)."
            ),
        )

        parser.add_argument(
            "--trial_config_repeat_count",
            "--trial-config-repeat-count",
            required=False,
            type=int,
            help=(
                "Number of times to repeat each config. "
                "Default is 1 trial per config, though more may be advised."
            ),
        )

        path_args_tracker.add_argument(
            "--scheduler",
            required=False,
            help=(
                "Path to the scheduler configuration file. By default, use "
                "a single worker synchronous scheduler."
            ),
        )

        path_args_tracker.add_argument(
            "--storage",
            required=False,
            help=(
                "Path to the storage configuration file. "
                "If omitted, use the ephemeral in-memory SQL storage."
            ),
        )

        parser.add_argument(
            "--random_init",
            "--random-init",
            required=False,
            default=False,
            dest="random_init",
            action="store_true",
            help="Initialize tunables with random values. (Before applying --tunable_values).",
        )

        parser.add_argument(
            "--random_seed",
            "--random-seed",
            required=False,
            type=int,
            help="Seed to use with --random_init",
        )

        path_args_tracker.add_argument(
            "--tunable_values",
            "--tunable-values",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to one or more JSON files that contain values of the tunable "
                "parameters. This can be used for a single trial (when no --optimizer "
                "is specified) or as default values for the first run in optimization."
            ),
        )

        path_args_tracker.add_argument(
            "--globals",
            nargs="+",
            action="extend",
            required=False,
            help=(
                "Path to one or more JSON files that contain additional "
                "[private] parameters of the benchmarking environment."
            ),
        )

        parser.add_argument(
            "--no_teardown",
            "--no-teardown",
            required=False,
            default=None,
            dest="teardown",
            action="store_false",
            help="Disable teardown of the environment after the benchmark.",
        )

        parser.add_argument(
            "--experiment_id",
            "--experiment-id",
            required=False,
            default=None,
            help="""
                Experiment ID to use for the benchmark.
                If omitted, the value from the --cli config or --globals is used.

                This is used to store and reload trial results from the storage.
                NOTE: It is **important** to change this value when incompatible
                changes are made to config files, scripts, versions, etc.
                This is left as a manual operation as detection of what is
                "incompatible" is not easily automatable across systems.
            """,
        )

        # By default we use the command line arguments, but allow the caller to
        # provide some explicitly for testing purposes.
        if argv is None:
            argv = sys.argv[1:].copy()
        (args, args_rest) = parser.parse_known_args(argv)

        return (args, path_args_tracker.path_args, args_rest)

    @staticmethod
    def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
        """Helper function to parse global key/value pairs from the command line."""
        _LOG.debug("Extra args: %s", cmdline)

        config: Dict[str, TunableValue] = {}
        key = None
        for elem in cmdline:
            if elem.startswith("--"):
                if key is not None:
                    raise ValueError("Command line argument has no value: " + key)
                key = elem[2:]
                kv_split = key.split("=", 1)
                if len(kv_split) == 2:
                    config[kv_split[0].strip()] = try_parse_val(kv_split[1])
                    key = None
            else:
                if key is None:
                    raise ValueError("Command line argument has no key: " + elem)
                config[key.strip()] = try_parse_val(elem)
                key = None

        if key is not None:
            # Handles missing trailing elem from last --key arg.
            raise ValueError("Command line argument has no value: " + key)

        # Convert "max-suggestions" to "max_suggestions" for compatibility with
        # other CLI options to use as common python/json variable replacements.
        config = {k.replace("-", "_"): v for k, v in config.items()}

        _LOG.debug("Parsed config: %s", config)
        return config
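
    # For illustration only (not part of the original module): extra command line
    # arguments of the form `--key value` or `--key=value` become entries in the
    # global config, with dashes normalized to underscores.  For example, the
    # (hypothetical) input
    #   ["--experiment_id", "my_experiment", "--max-suggestions=10"]
    # would parse to
    #   {"experiment_id": "my_experiment", "max_suggestions": 10}
    # (values go through try_parse_val, so "10" is converted to the int 10).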

    def _load_config(
        self,
        *,
        args_globals: Iterable[str],
        config_path: Iterable[str],
        args_rest: Iterable[str],
        global_config: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Get key/value pairs of the global configuration parameters from the
        specified config files (if any) and command line arguments.
        """
        for config_file in args_globals or []:
            conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS)
            assert isinstance(conf, dict)
            global_config.update(conf)
        global_config.update(Launcher._try_parse_extra_args(args_rest))
        if config_path:
            global_config["config_path"] = config_path
        return global_config

    def _init_tunable_values(
        self,
        random_init: bool,
        seed: Optional[int],
        args_tunables: Optional[List[str]],
    ) -> TunableGroups:
        """Initialize the tunables and load key/value pairs of the tunable values
        from the given JSON files, if specified.
        """
        tunables = self.environment.tunable_params
        _LOG.debug("Init tunables: default = %s", tunables)

        if random_init:
            tunables = MockOptimizer(
                tunables=tunables,
                service=None,
                config={"start_with_defaults": False, "seed": seed},
            ).suggest()
            _LOG.debug("Init tunables: random = %s", tunables)

        if args_tunables is not None:
            for data_file in args_tunables:
                values = self._config_loader.load_config(data_file, ConfigSchema.TUNABLE_VALUES)
                assert isinstance(values, Dict)
                tunables.assign(values)
                _LOG.debug("Init tunables: load %s = %s", data_file, tunables)

        return tunables
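
    # For illustration only (hypothetical parameter names): each --tunable_values
    # file is a flat JSON mapping of tunable parameter names to values, e.g.
    #   {
    #       "vmSize": "Standard_B2s",
    #       "kernel_sched_latency_ns": 6000000
    #   }
    # Files listed on the command line are applied in order via tunables.assign().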

    def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer:
        """
        Instantiate the Optimizer object from the JSON config file, if specified in
        the --optimizer command line option.

        If the config file is not specified, create a one-shot optimizer to run a
        single benchmark trial.
        """
        if args_optimizer is None:
            # global_config may contain additional properties, so we need to
            # strip those out before instantiating the basic oneshot optimizer.
            config = {
                key: val
                for key, val in self.global_config.items()
                if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS
            }
            return OneShotOptimizer(self.tunables, config=config, service=self._parent_service)
        class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
        assert isinstance(class_config, Dict)
        optimizer = self._config_loader.build_optimizer(
            tunables=self.tunables,
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )
        return optimizer

    def _load_storage(self, args_storage: Optional[str]) -> Storage:
        """
        Instantiate the Storage object from the JSON file provided in the --storage
        command line parameter.

        If omitted, create an ephemeral in-memory SQL storage instead.
        """
        if args_storage is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.storage.sql.storage import SqlStorage

            return SqlStorage(
                service=self._parent_service,
                config={
                    "drivername": "sqlite",
                    "database": ":memory:",
                    "lazy_schema_create": True,
                },
            )
        class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
        assert isinstance(class_config, Dict)
        storage = self._config_loader.build_storage(
            service=self._parent_service,
            config=class_config,
            global_config=self.global_config,
        )
        return storage
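
    # For illustration only (a sketch; the exact layout is governed by
    # ConfigSchema.STORAGE): a --storage config file typically names a Storage
    # class and its parameters, e.g. a file-backed SQLite database instead of the
    # in-memory default above:
    #   {
    #       "class": "mlos_bench.storage.sql.storage.SqlStorage",
    #       "config": {
    #           "drivername": "sqlite",
    #           "database": "mlos_bench.sqlite"
    #       }
    #   }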

    def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler:
        """
        Instantiate the Scheduler object from the JSON file provided in the
        --scheduler command line parameter.

        Create a simple synchronous single-threaded scheduler if omitted.
        """
        # Set `teardown` for the scheduler only, to prevent conflicts with other configs.
        global_config = self.global_config.copy()
        global_config.setdefault("teardown", self.teardown)
        if args_scheduler is None:
            # pylint: disable=import-outside-toplevel
            from mlos_bench.schedulers.sync_scheduler import SyncScheduler

            return SyncScheduler(
                # All config values can be overridden from global config.
                config={
                    "experiment_id": "UNDEFINED - override from global config",
                    "trial_id": 0,
                    "config_id": -1,
                    "trial_config_repeat_count": 1,
                    "teardown": self.teardown,
                },
                global_config=global_config,
                environment=self.environment,
                optimizer=self.optimizer,
                storage=self.storage,
                root_env_config=self.root_env_config,
            )
        class_config = self._config_loader.load_config(args_scheduler, ConfigSchema.SCHEDULER)
        assert isinstance(class_config, Dict)
        return self._config_loader.build_scheduler(
            config=class_config,
            global_config=global_config,
            environment=self.environment,
            optimizer=self.optimizer,
            storage=self.storage,
            root_env_config=self.root_env_config,
        )