Coverage for mlos_bench/mlos_bench/util.py: 89%
110 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Various helper functions for mlos_bench.
7"""
9# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports.
11from datetime import datetime
12import os
13import json
14import logging
15import importlib
16import subprocess
18from typing import (
19 Any, Callable, Dict, Iterable, Literal, Mapping, Optional,
20 Tuple, Type, TypeVar, TYPE_CHECKING, Union,
21)
23import pandas
24import pytz
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(__name__)
29if TYPE_CHECKING:
30 from mlos_bench.environments.base_environment import Environment
31 from mlos_bench.optimizers.base_optimizer import Optimizer
32 from mlos_bench.schedulers.base_scheduler import Scheduler
33 from mlos_bench.services.base_service import Service
34 from mlos_bench.storage.base_storage import Storage
# BaseTypeVar is a generic constrained to the five mlos_bench base classes:
# Environment, Optimizer, Scheduler, Service, and Storage.
# (Names are given as strings because the classes are only imported under TYPE_CHECKING.)
BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
# BaseTypes is the union of the same five base classes.
BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"]
def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict:
    """
    Substitute "$name" references in the destination config in place.

    Every top-level string value of `dest` that starts with "$" and whose
    remainder is a key of `source` is replaced by that source value. Other
    values (including "$name" strings with no matching source key) are kept.

    Parameters
    ----------
    dest : dict
        Destination config; updated in place.
    source : Optional[dict]
        Source config to resolve "$name" references against.

    Returns
    -------
    dest : dict
        The same `dest` object, after substitution.
    """
    lookup = {} if source is None else source
    # Collect the substitutions first, then apply them in one go.
    substitutions = {
        key: lookup[val[1:]]
        for key, val in dest.items()
        if isinstance(val, str) and val.startswith("$") and val[1:] in lookup
    }
    dest.update(substitutions)
    return dest
def merge_parameters(*, dest: dict, source: Optional[dict] = None,
                     required_keys: Optional[Iterable[str]] = None) -> dict:
    """
    Merge the source config dict into the destination config, in place.

    Only keys that already exist in `dest` are overwritten from `source`;
    any other source keys are ignored unless they are listed in `required_keys`.

    Parameters
    ----------
    dest : dict
        Destination config; updated in place.
    source : Optional[dict]
        Source config.
    required_keys : Optional[Iterable[str]]
        Keys that must end up in the destination config. A required key that is
        missing from `dest` is copied from `source` when available.

    Returns
    -------
    dest : dict
        The same `dest` object, after the merge.

    Raises
    ------
    ValueError
        If a required key is present in neither `dest` nor `source`.
    """
    lookup = {} if source is None else source

    # Overwrite only the keys both configs share.
    for shared_key in dest.keys() & lookup:
        dest[shared_key] = lookup[shared_key]

    for req_key in required_keys or ():
        if req_key not in dest:
            if req_key not in lookup:
                raise ValueError("Missing required parameter: " + req_key)
            dest[req_key] = lookup[req_key]

    return dest
def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components, normalize the result, and use forward slashes.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the joined path is made absolute before normalizing.

    Returns
    -------
    str
        Joined, normalized path with "/" as the separator.
    """
    joined = os.path.join(*args)
    resolved = os.path.abspath(joined) if abs_path else joined
    return os.path.normpath(resolved).replace("\\", "/")
def prepare_class_load(config: dict,
                       global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]:
    """
    Extract the class name and its instantiation parameters from a config.

    Parameters
    ----------
    config : dict
        Configuration of the object to instantiate. It must have a "class" key;
        its "config" sub-dict is created (empty) if missing.
    global_config : dict
        Global configuration parameters (optional). Values for keys already
        present in the class config override them.

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.
    """
    name = config["class"]
    # setdefault() guarantees config["config"] exists, so the merge below
    # updates the caller's nested dict in place.
    params = config.setdefault("config", {})
    merge_parameters(dest=params, source=global_config)

    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug("Instantiating: %s with config:\n%s",
                   name, json.dumps(params, indent=2))

    return name, params
def get_class_from_name(class_name: str) -> type:
    """
    Gets the class from the fully qualified name.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, e.g., "mlos_bench.environments.remote.HostEnv".

    Returns
    -------
    type
        Class object.

    Raises
    ------
    ValueError
        If `class_name` has no module component (e.g., just "HostEnv").
    ImportError
        If the module cannot be imported.
    AttributeError
        If the module has no attribute with the given class name.
    """
    # Split on the last dot: everything before it is the module to import,
    # the remainder is the attribute (class) to look up in that module.
    module_name, _, class_id = class_name.rpartition(".")
    if not module_name:
        # importlib would otherwise fail with an opaque "Empty module name" error.
        raise ValueError(
            f"Expected a fully qualified class name (module.ClassName), got: {class_name!r}")

    module = importlib.import_module(module_name)
    cls = getattr(module, class_id)
    assert isinstance(cls, type)
    return cls
# FIXME: Technically, this should return a type "class_name" derived from "base_class".
def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str,
                            *args: Any, **kwargs: Any) -> BaseTypeVar:
    """
    Factory method: look up `class_name` and instantiate it.

    Parameters
    ----------
    base_class : type
        Base type the instantiated class must derive from.
        Currently one of {Environment, Optimizer, Scheduler, Service, Storage}.
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments forwarded to the constructor.
    kwargs : dict
        Keyword arguments forwarded to the constructor.

    Returns
    -------
    inst : Union[Environment, Optimizer, Scheduler, Service, Storage]
        An instance of the `class_name` class.
    """
    cls = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, cls)

    # Sanity-check the class before and the object after construction.
    assert issubclass(cls, base_class)
    instance: BaseTypeVar = cls(*args, **kwargs)
    assert isinstance(instance, base_class)
    return instance
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that every required parameter is present in the configuration.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        Identifiers of the parameters that must be present in the configuration.

    Raises
    ------
    ValueError
        If one or more required parameters are missing; the message lists them.
    """
    absent = set(required_params).difference(config)
    if not absent:
        return
    raise ValueError(
        "The following parameters must be provided in the configuration"
        + f" or as command line arguments: {absent}")
def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
    """
    Look up the git origin URL, HEAD commit, and repo-relative path of a file.

    Parameters
    ----------
    path : str
        Path to a file inside a git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : Tuple[str, str, str]
        URL of the "origin" remote, the HEAD commit hash, and the path of
        `path` relative to the repository root (with "/" separators).
    """
    repo_dir = os.path.dirname(path)

    def _git(*git_args: str) -> str:
        # Run a git subcommand from the directory containing `path`.
        return subprocess.check_output(
            ["git", "-C", repo_dir, *git_args], text=True).strip()

    git_repo = _git("remote", "get-url", "origin")
    git_commit = _git("rev-parse", "HEAD")
    git_root = _git("rev-parse", "--show-toplevel")
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)

    rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
    return (git_repo, git_commit, rel_path.replace("\\", "/"))
# Note: to avoid circular imports, we don't specify TunableValue here.
def try_parse_val(val: Optional[str]) -> Optional[Union[int, float, str]]:
    """
    Try to parse the value as an int or float, otherwise return the string.

    This can help with config schema validation to make sure early on that
    the args we're expecting are the right type.

    Parameters
    ----------
    val : str
        The initial cmd line arg value.

    Returns
    -------
    TunableValue
        The parsed value: an int if `val` is an integer literal, a float if it is
        any other numeric literal (e.g., "1.5", "1e3", "inf"), `str(val)` if it is
        not numeric, or None if `val` is None.
    """
    if val is None:
        return val
    try:
        val_float = float(val)
    except ValueError:
        return str(val)
    try:
        val_int = int(val)
    except (ValueError, OverflowError):
        # E.g., "1.5", "1e3", "inf", "nan": valid floats, but not int literals.
        return val_float
    # For string input, a successful int() means the text is an integer literal:
    # return the exact int. Comparing against the float instead would misclassify
    # integers beyond 2**53 (e.g., "9007199254740993"), where the float is lossy.
    if isinstance(val, str) or val_int == val_float:
        return val_int
    return val_float
def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]:
    """
    Apply `func` to `value` unless `value` is None (a poor man's Maybe monad).

    Parameters
    ----------
    func : Callable
        Function to apply.
    value : Optional[Any]
        Argument for `func`; None short-circuits.

    Returns
    -------
    value : Optional[Any]
        `func(value)`, or None when `value` is None.
    """
    if value is None:
        return None
    return func(value)
def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Augment a timestamp with zoneinfo if missing and convert it to UTC.

    Parameters
    ----------
    timestamp : datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        Whether a naive source timestamp is considered to be in UTC or local time.
        It does not affect timestamps that already carry tzinfo, but it is still validated.
        In the case of loading data from storage, where we intentionally convert all
        timestamps to UTC, this can help us retrieve the original timezone when the
        storage backend doesn't explicitly store it.
        In the case of receiving data from a client or other source, this can help us
        convert the timestamp to UTC if it's not already.

    Returns
    -------
    datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If `origin` is not "utc" or "local".
    """
    # Validate up front so a bad origin is rejected even for tz-aware timestamps.
    if origin not in ("utc", "local"):
        raise ValueError(f"Invalid origin: {origin}")
    if timestamp.tzinfo is None and origin == "utc":
        # A naive timestamp that is already in UTC: just attach the zoneinfo.
        # Converting with astimezone() would treat the naive value as local time
        # and shift it whenever the local time zone is *not* UTC.
        return timestamp.replace(tzinfo=pytz.UTC)
    # Either the timestamp already has tzinfo (so it is simply converted), or it
    # is naive and interpreted as local time (e.g., according to the TZ
    # environment variable); astimezone() handles both cases.
    return timestamp.astimezone(pytz.UTC)
def utcify_nullable_timestamp(timestamp: Optional[datetime], *,
                              origin: Literal["utc", "local"]) -> Optional[datetime]:
    """
    A nullable version of utcify_timestamp: a None timestamp yields None;
    anything else is passed to utcify_timestamp with the same `origin`.
    """
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)
# Lower bound for telemetry timestamps: all of them are expected to be strictly
# later than this instant (a very rough approximation for the start of this feature).
_MIN_TS = datetime(2024, 1, 1, tzinfo=pytz.UTC)
def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series:
    """
    Attempt to convert a pandas column to a datetime format.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.
        It does not affect timezone-aware data, but it is still validated.

    Returns
    -------
    pandas.Series
        The converted datetime column, in UTC.

    Raises
    ------
    ValueError
        On parse errors, missing values (NaT), timestamps at or before the
        minimum telemetry date, or an invalid `origin`.
    """
    # Validate up front so a bad origin is rejected even when the data already
    # carries timezone information.
    if origin not in ("utc", "local"):
        raise ValueError(f"Invalid timezone origin: {origin}")
    new_datetime_col = pandas.to_datetime(datetime_col, utc=False)
    # If timezone data is missing, assume the provided origin timezone.
    if new_datetime_col.dt.tz is None:
        if origin == "utc":
            tzinfo = pytz.UTC
        else:
            # NOTE(review): datetime.now().astimezone().tzinfo is a fixed-offset
            # tzinfo for the *current* local UTC offset, so naive timestamps from
            # a different DST period are localized with the wrong offset.
            # Confirm whether that matters for the data being parsed.
            tzinfo = datetime.now().astimezone().tzinfo
        new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo)
    assert new_datetime_col.dt.tz is not None
    # And convert it to UTC.
    new_datetime_col = new_datetime_col.dt.tz_convert('UTC')
    if new_datetime_col.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if new_datetime_col.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return new_datetime_col