Coverage for mlos_bench/mlos_bench/util.py: 92%
161 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Various helper functions for mlos_bench."""
7# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports.
9import importlib
10import json
11import logging
12import os
13import subprocess
14from collections.abc import Callable, Iterable, Mapping
15from datetime import datetime
16from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union
18import pandas
19import pytz
21_LOG = logging.getLogger(__name__)
23if TYPE_CHECKING:
24 from mlos_bench.environments.base_environment import Environment
25 from mlos_bench.optimizers.base_optimizer import Optimizer
26 from mlos_bench.schedulers.base_scheduler import Scheduler
27 from mlos_bench.services.base_service import Service
28 from mlos_bench.storage.base_storage import Storage
# The constraints are given as strings (forward references) because the classes
# themselves are only imported under TYPE_CHECKING above.
BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
"""BaseTypeVar is a generic with a constraint of the main base classes (e.g.,
:py:class:`~mlos_bench.environments.base_environment.Environment`,
:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer`,
:py:class:`~mlos_bench.schedulers.base_scheduler.Scheduler`,
:py:class:`~mlos_bench.services.base_service.Service`,
:py:class:`~mlos_bench.storage.base_storage.Storage`, etc.).
"""

# Same five classes as the constraints of BaseTypeVar, again as forward references.
BaseTypes = Union[  # pylint: disable=consider-alternative-union-syntax
    "Environment", "Optimizer", "Scheduler", "Service", "Storage"
]
"""Similar to :py:data:`.BaseTypeVar`, BaseTypes is a Union of the main base classes."""
# Adjusted from https://github.com/python/cpython/blob/v3.11.10/Lib/distutils/util.py#L308
# See Also: https://github.com/microsoft/MLOS/issues/865
def strtobool(val: str) -> bool:
    """
    Interpret a human-friendly truth string as a Python bool.

    The comparison is case-insensitive.

    Parameters
    ----------
    val : str
        Truthy spellings are 'y', 'yes', 't', 'true', 'on', and '1';
        falsy spellings are 'n', 'no', 'f', 'false', 'off', and '0'.

    Returns
    -------
    bool
        True for a truthy spelling, False for a falsy one.

    Raises
    ------
    ValueError
        If 'val' is none of the recognized spellings.
    """
    lowered = val.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    # Note: the message reports the lower-cased form of the input.
    raise ValueError(f"Invalid Boolean value: '{lowered}'")
71def preprocess_dynamic_configs(*, dest: dict, source: dict | None = None) -> dict:
72 """
73 Replaces all ``$name`` values in the destination config with the corresponding value
74 from the source config.
76 Parameters
77 ----------
78 dest : dict
79 Destination config.
80 source : dict | None
81 Source config.
83 Returns
84 -------
85 dest : dict
86 A reference to the destination config after the preprocessing.
87 """
88 if source is None:
89 source = {}
90 for key, val in dest.items():
91 if isinstance(val, str) and val.startswith("$") and val[1:] in source:
92 dest[key] = source[val[1:]]
93 return dest
96def merge_parameters(
97 *,
98 dest: dict,
99 source: dict | None = None,
100 required_keys: Iterable[str] | None = None,
101) -> dict:
102 """
103 Merge the source config dict into the destination config. Pick from the source
104 configs *ONLY* the keys that are already present in the destination config.
106 Parameters
107 ----------
108 dest : dict
109 Destination config.
110 source : dict | None
111 Source config.
112 required_keys : Optional[Iterable[str]]
113 An optional list of keys that must be present in the destination config.
115 Returns
116 -------
117 dest : dict
118 A reference to the destination config after the merge.
119 """
120 if source is None:
121 source = {}
123 for key in set(dest).intersection(source):
124 dest[key] = source[key]
126 for key in required_keys or []:
127 if key in dest:
128 continue
129 if key in source:
130 dest[key] = source[key]
131 else:
132 raise ValueError("Missing required parameter: " + key)
134 return dest
def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components into one normalized path that uses forward slashes.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the joined path is resolved with :py:func:`os.path.realpath`
        before normalization.

    Returns
    -------
    str
        Joined path, with every backslash replaced by a forward slash.
    """
    joined = os.path.join(*args)
    resolved = os.path.realpath(joined) if abs_path else joined
    return os.path.normpath(resolved).replace("\\", "/")
def prepare_class_load(
    config: dict,
    global_config: dict[str, Any] | None = None,
) -> tuple[str, dict[str, Any]]:
    """
    Pull the class name and its constructor configuration out of a config dict.

    The ``"config"`` entry is created as an empty dict inside ``config`` when it is
    missing, and values from ``global_config`` are merged into it for the keys it
    already has.

    Parameters
    ----------
    config : dict
        Configuration with a mandatory ``"class"`` entry and an optional
        ``"config"`` entry.
    global_config : dict
        Global configuration parameters (optional).

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.
    """
    name = config["class"]
    params = config.setdefault("config", {})

    merge_parameters(dest=params, source=global_config)

    # Serialize the config only when the message will actually be emitted.
    if _LOG.isEnabledFor(logging.DEBUG):
        pretty = json.dumps(params, indent=2)
        _LOG.debug("Instantiating: %s with config:\n%s", name, pretty)

    return name, params
def get_class_from_name(class_name: str) -> type:
    """
    Resolve a fully qualified class name to the class object.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, i.e., ``<module path>.<class identifier>``.

    Returns
    -------
    type
        Class object.
    """
    # We need to import mlos_bench to make the factory methods work.
    # Everything before the last dot is the module; the remainder is the class.
    module_name, _, class_id = class_name.rpartition(".")

    resolved = getattr(importlib.import_module(module_name), class_id)
    assert isinstance(resolved, type)
    return resolved
# FIXME: Technically, this should return a type "class_name" derived from "base_class".
def instantiate_from_config(
    base_class: type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Build an instance of the class named in a config.

    Parameters
    ----------
    base_class : type
        Base type that the instantiated class must derive from
        (one of the classes allowed by ``BaseTypeVar``).
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : BaseTypeVar
        An instance of the `class_name` class.
    """
    cls = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, cls)

    # Check the class before constructing it, and the resulting object afterwards.
    assert issubclass(cls, base_class)
    instance: BaseTypeVar = cls(*args, **kwargs)
    assert isinstance(instance, base_class)
    return instance
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that the configuration contains every required parameter.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        A collection of identifiers of the parameters that must be present
        in the configuration.

    Raises
    ------
    ValueError
        If at least one required parameter is missing from the configuration.
    """
    missing_params = {param for param in required_params if param not in config}
    if not missing_params:
        return
    raise ValueError(
        "The following parameters must be provided in the configuration"
        + f" or as command line arguments: {missing_params}"
    )
def get_git_root(path: str = __file__) -> str:
    """
    Locate the top-level directory of the git repository that contains a path.

    Parameters
    ----------
    path : Optional[str]
        Path to a file or directory in the git repository.

    Raises
    ------
    subprocess.CalledProcessError
        If the path is not a git repository or the command fails.

    Returns
    -------
    str
        The absolute path to the root directory of the git repository.
    """
    abspath = path_join(path, abs_path=True)
    # git -C needs a directory: use the path itself if it is one, else its parent.
    # (os.path.isdir() is already False for paths that do not exist.)
    dirname = abspath if os.path.isdir(abspath) else os.path.dirname(abspath)
    toplevel_cmd = ["git", "-C", dirname, "rev-parse", "--show-toplevel"]
    git_root = subprocess.check_output(toplevel_cmd, text=True).strip()
    return path_join(git_root, abs_path=True)
def get_git_remote_info(path: str, remote: str) -> str:
    """
    Look up the URL configured for a named remote of a git repository.

    Parameters
    ----------
    path : str
        Path passed to ``git -C`` (a location inside the git repository).
    remote : str
        The name of the remote (e.g., "origin").

    Raises
    ------
    subprocess.CalledProcessError
        If the command fails or the remote does not exist.

    Returns
    -------
    str
        The URL of the remote repository, stripped of surrounding whitespace.
    """
    get_url_cmd = ["git", "-C", path, "remote", "get-url", remote]
    output = subprocess.check_output(get_url_cmd, text=True)
    return output.strip()
def get_git_repo_info(path: str) -> str:
    """
    Get the git repository URL for the given git repo.

    Tries to get the URL of the remote that the current branch's upstream belongs
    to, falling back to the "origin" remote if the upstream is not set or cannot be
    resolved. If that also fails, it returns a file URL pointing to the local path.

    Parameters
    ----------
    path : str
        Path to the git repository.

    Returns
    -------
    str
        The upstream remote URL, else the "origin" remote URL, else a ``file://``
        URL built from the absolute local path.

    Notes
    -----
    :py:class:`subprocess.CalledProcessError` raised by the git commands is handled
    internally by the fallbacks described above and is not propagated.
    """
    # In case "origin" remote is not set, or this branch has a different
    # upstream, we should handle it gracefully.
    # (e.g., fallback to the first one we find?)
    path = path_join(path, abs_path=True)
    cmd = ["git", "-C", path, "rev-parse", "--abbrev-ref", "--symbolic-full-name", "HEAD@{u}"]
    try:
        git_remote = subprocess.check_output(cmd, text=True).strip()
        # Keep the part before the first "/" and use it as the remote name.
        git_remote = git_remote.split("/", 1)[0]
        git_repo = get_git_remote_info(path, git_remote)
    except subprocess.CalledProcessError:
        git_remote = "origin"
        _LOG.warning(
            "Failed to get the upstream branch for %s. Falling back to '%s' remote.",
            path,
            git_remote,
        )
        try:
            git_repo = get_git_remote_info(path, git_remote)
        except subprocess.CalledProcessError:
            git_repo = "file://" + path
            # At this point it is the fallback remote lookup that failed (not the
            # upstream branch lookup), so say so in the log.
            _LOG.warning(
                "Failed to get the URL of the '%s' remote for %s. Falling back to '%s'.",
                git_remote,
                path,
                git_repo,
            )
    return git_repo
def get_git_info(path: str = __file__) -> tuple[str, str, str, str]:
    """
    Collect the git repository URL, the HEAD commit hash, and the repo-relative and
    absolute locations of the given path.

    Parameters
    ----------
    path : str
        Path to a file or directory in the git repository.

    Raises
    ------
    subprocess.CalledProcessError
        If the path is not a git repository or the command fails.

    Returns
    -------
    (git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str]
        Git repository URL, last commit hash, file path relative to the repository
        root (with forward slashes), and the current absolute path.
    """
    abspath = path_join(path, abs_path=True)
    # git -C needs a directory: use the path itself if it is one, else its parent.
    # (os.path.isdir() is already False for paths that do not exist.)
    dirname = abspath if os.path.isdir(abspath) else os.path.dirname(abspath)

    git_root = get_git_root(path=abspath)
    git_repo = get_git_repo_info(git_root)

    head_cmd = ["git", "-C", dirname, "rev-parse", "HEAD"]
    git_commit = subprocess.check_output(head_cmd, text=True).strip()
    _LOG.debug("Current git branch for %s: %s %s", git_root, git_repo, git_commit)

    rel_path = os.path.relpath(abspath, os.path.abspath(git_root)).replace("\\", "/")
    # TODO: return the branch too?
    return (git_repo, git_commit, rel_path, abspath)
420# TODO: Add support for checking out the branch locally.
423# Note: to avoid circular imports, we don't specify TunableValue here.
424def try_parse_val(val: str | None) -> int | float | str | None:
425 """
426 Try to parse the value as an int or float, otherwise return the string.
428 This can help with config schema validation to make sure early on that
429 the args we're expecting are the right type.
431 Parameters
432 ----------
433 val : str
434 The initial cmd line arg value.
436 Returns
437 -------
438 TunableValue
439 The parsed value.
440 """
441 if val is None:
442 return val
443 try:
444 val_float = float(val)
445 try:
446 val_int = int(val)
447 return val_int if val_int == val_float else val_float
448 except (ValueError, OverflowError):
449 return val_float
450 except ValueError:
451 return str(val)
454NullableT = TypeVar("NullableT")
455"""A generic type variable for :py:func:`nullable` return types."""
458def nullable(func: Callable[..., NullableT], value: Any | None) -> NullableT | None:
459 """
460 Poor man's Maybe monad: apply the function to the value if it's not None.
462 Parameters
463 ----------
464 func : Callable
465 Function to apply to the value.
466 value : Any | None
467 Value to apply the function to.
469 Returns
470 -------
471 value : NullableT | None
472 The result of the function application or None if the value is None.
474 Examples
475 --------
476 >>> nullable(int, "1")
477 1
478 >>> nullable(int, None)
479 ...
480 >>> nullable(str, 1)
481 '1'
482 """
483 return None if value is None else func(value)
def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Return the timestamp as a timezone-aware datetime in UTC, attaching zone info
    first when it is missing.

    Parameters
    ----------
    timestamp : datetime.datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        Whether the source timestamp is considered to be in UTC or local time.
        In the case of loading data from storage, where we intentionally convert all
        timestamps to UTC, this can help us retrieve the original timezone when the
        storage backend doesn't explicitly store it.
        In the case of receiving data from a client or other source, this can help us
        convert the timestamp to UTC if it's not already.

    Returns
    -------
    datetime.datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If the timestamp has no tzinfo and ``origin`` is neither "utc" nor "local".
    """
    if timestamp.tzinfo is None and origin != "local":
        if origin != "utc":
            raise ValueError(f"Invalid origin: {origin}")
        # The naive timestamp is already in UTC, so we only attach the zoneinfo.
        # Converting with astimezone() when the local time is *not* UTC would cause
        # a timestamp conversion which we don't want.
        return timestamp.replace(tzinfo=pytz.UTC)
    # A timestamp with no zoneinfo is interpreted as "local" time
    # (e.g., according to the TZ environment variable).
    # That could be UTC or some other timezone, but either way we convert it to
    # be explicitly UTC with zone info.
    return timestamp.astimezone(pytz.UTC)
524def utcify_nullable_timestamp(
525 timestamp: datetime | None,
526 *,
527 origin: Literal["utc", "local"],
528) -> datetime | None:
529 """A nullable version of utcify_timestamp."""
530 return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None
# All timestamps in the telemetry data must be greater than this date
# (a very rough approximation for the start of this feature).
# datetime_parser() raises a ValueError for any timestamp at or before it.
_MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Parse a pandas column into timezone-aware datetimes expressed in UTC.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.

    Returns
    -------
    pandas.Series
        The converted datetime column.

    Raises
    ------
    ValueError
        On parse errors, on an invalid ``origin`` for naive timestamps, or when
        any timestamp is missing or not later than ``_MIN_TS``.
    """
    parsed = pandas.to_datetime(datetime_col, utc=False)
    # If timezone data is missing, assume the provided origin timezone.
    if parsed.dt.tz is None:
        if origin not in ("local", "utc"):
            raise ValueError(f"Invalid timezone origin: {origin}")
        tzinfo = pytz.UTC if origin == "utc" else datetime.now().astimezone().tzinfo
        parsed = parsed.dt.tz_localize(tzinfo)
    assert parsed.dt.tz is not None
    # And convert it to UTC.
    parsed = parsed.dt.tz_convert("UTC")
    if parsed.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if parsed.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return parsed
def sanitize_config(config: dict[str, Any]) -> dict[str, Any]:
    """
    Sanitize a configuration dictionary by obfuscating potentially sensitive keys.

    The values of the keys "password", "secret", "token", and "api_key" are replaced
    with "[REDACTED]" at any nesting depth, including inside dictionaries that are
    elements of (possibly nested) lists. The input is not modified.

    Parameters
    ----------
    config : dict
        Configuration dictionary to sanitize.

    Returns
    -------
    dict
        Sanitized copy of the configuration dictionary.
    """
    sanitize_keys = {"password", "secret", "token", "api_key"}

    def recursive_sanitize(value: Any) -> Any:
        """Recursively sanitize dictionaries, descending into lists as well."""
        if isinstance(value, dict):
            return {
                k: "[REDACTED]" if k in sanitize_keys else recursive_sanitize(v)
                for k, v in value.items()
            }
        if isinstance(value, list):
            # Lists can hold nested configs (dicts), which must be sanitized too.
            return [recursive_sanitize(item) for item in value]
        return value

    sanitized: dict[str, Any] = recursive_sanitize(config)
    return sanitized
612 return recursive_sanitize(config)