Coverage for mlos_bench/mlos_bench/util.py: 92%

161 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-30 00:51 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Various helper functions for mlos_bench.""" 

6 

7# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports. 

8 

9import importlib 

10import json 

11import logging 

12import os 

13import subprocess 

14from collections.abc import Callable, Iterable, Mapping 

15from datetime import datetime 

16from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union 

17 

18import pandas 

19import pytz 

20 

# Module-level logger shared by the helpers in this module.
_LOG = logging.getLogger(__name__)

# TYPE_CHECKING is False at runtime, so these imports are only seen by static type
# checkers. At runtime the classes are referenced below only as string forward
# references, which keeps this module from importing them (see the circular-import
# NOTE at the top of this module).
if TYPE_CHECKING:
    from mlos_bench.environments.base_environment import Environment
    from mlos_bench.optimizers.base_optimizer import Optimizer
    from mlos_bench.schedulers.base_scheduler import Scheduler
    from mlos_bench.services.base_service import Service
    from mlos_bench.storage.base_storage import Storage

# A *constrained* TypeVar (constraints listed positionally): type checkers resolve
# it to exactly one of the listed classes at each use site.
BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
"""BaseTypeVar is a generic with a constraint of the main base classes (e.g.,
:py:class:`~mlos_bench.environments.base_environment.Environment`,
:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer`,
:py:class:`~mlos_bench.schedulers.base_scheduler.Scheduler`,
:py:class:`~mlos_bench.services.base_service.Service`,
:py:class:`~mlos_bench.storage.base_storage.Storage`, etc.).
"""

# typing.Union is used instead of the ``X | Y`` operator because the members are
# string forward references, and ``"A" | "B"`` is not a valid runtime expression.
BaseTypes = Union[  # pylint: disable=consider-alternative-union-syntax
    "Environment", "Optimizer", "Scheduler", "Service", "Storage"
]
"""Similar to :py:data:`.BaseTypeVar`, BaseTypes is a Union of the main base classes."""

43 

44 

# Adjusted from https://github.com/python/cpython/blob/v3.11.10/Lib/distutils/util.py#L308
# See Also: https://github.com/microsoft/MLOS/issues/865
def strtobool(val: str) -> bool:
    """
    Interpret a textual yes/no flag as a boolean (case-insensitive).

    Parameters
    ----------
    val : str
        Recognized "true" spellings: 'y', 'yes', 't', 'true', 'on', '1'.
        Recognized "false" spellings: 'n', 'no', 'f', 'false', 'off', '0'.

    Returns
    -------
    bool
        True or False according to the spelling of ``val``.

    Raises
    ------
    ValueError
        If ``val`` matches none of the recognized spellings.
    """
    lowered = val.lower()
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    # The error message reports the lower-cased value, not the raw input.
    raise ValueError(f"Invalid Boolean value: '{lowered}'")

69 

70 

71def preprocess_dynamic_configs(*, dest: dict, source: dict | None = None) -> dict: 

72 """ 

73 Replaces all ``$name`` values in the destination config with the corresponding value 

74 from the source config. 

75 

76 Parameters 

77 ---------- 

78 dest : dict 

79 Destination config. 

80 source : dict | None 

81 Source config. 

82 

83 Returns 

84 ------- 

85 dest : dict 

86 A reference to the destination config after the preprocessing. 

87 """ 

88 if source is None: 

89 source = {} 

90 for key, val in dest.items(): 

91 if isinstance(val, str) and val.startswith("$") and val[1:] in source: 

92 dest[key] = source[val[1:]] 

93 return dest 

94 

95 

96def merge_parameters( 

97 *, 

98 dest: dict, 

99 source: dict | None = None, 

100 required_keys: Iterable[str] | None = None, 

101) -> dict: 

102 """ 

103 Merge the source config dict into the destination config. Pick from the source 

104 configs *ONLY* the keys that are already present in the destination config. 

105 

106 Parameters 

107 ---------- 

108 dest : dict 

109 Destination config. 

110 source : dict | None 

111 Source config. 

112 required_keys : Optional[Iterable[str]] 

113 An optional list of keys that must be present in the destination config. 

114 

115 Returns 

116 ------- 

117 dest : dict 

118 A reference to the destination config after the merge. 

119 """ 

120 if source is None: 

121 source = {} 

122 

123 for key in set(dest).intersection(source): 

124 dest[key] = source[key] 

125 

126 for key in required_keys or []: 

127 if key in dest: 

128 continue 

129 if key in source: 

130 dest[key] = source[key] 

131 else: 

132 raise ValueError("Missing required parameter: " + key) 

133 

134 return dest 

135 

136 

def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components into one normalized path that uses forward slashes.

    Parameters
    ----------
    args : str
        Path components, joined with ``os.path.join``.

    abs_path : bool
        When True, resolve the joined path with ``os.path.realpath`` first
        (making it absolute and resolving symlinks).

    Returns
    -------
    str
        The normalized path, with any backslashes converted to ``/``.
    """
    joined = os.path.join(*args)
    resolved = os.path.realpath(joined) if abs_path else joined
    return os.path.normpath(resolved).replace("\\", "/")

158 

159 

def prepare_class_load(
    config: dict,
    global_config: dict[str, Any] | None = None,
) -> tuple[str, dict[str, Any]]:
    """
    Split a ``{"class": ..., "config": {...}}`` entry into a class name and its parameters.

    If ``config`` has no ``"config"`` sub-dict, an empty one is inserted. Values
    from ``global_config`` are then merged into that sub-dict via
    :py:func:`merge_parameters`, but only for keys it already declares.
    Both ``config`` and the sub-dict are modified in place.

    Parameters
    ----------
    config : dict
        Configuration of the object to instantiate (must contain ``"class"``).
    global_config : dict
        Global configuration parameters (optional).

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.

    Raises
    ------
    KeyError
        If ``config`` has no ``"class"`` entry (raised before anything is modified).
    """
    name = config["class"]
    params = config.setdefault("config", {})
    merge_parameters(dest=params, source=global_config)

    debug_enabled = _LOG.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        # Only pay for JSON rendering when debug logging is actually on.
        rendered = json.dumps(params, indent=2)
        _LOG.debug("Instantiating: %s with config:\n%s", name, rendered)

    return name, params

190 

191 

def get_class_from_name(class_name: str) -> type:
    """
    Gets the class from the fully qualified name.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, e.g., ``"package.module.ClassName"``.

    Returns
    -------
    type
        Class object.

    Raises
    ------
    ValueError
        If ``class_name`` has no module part (i.e., is not dot-qualified).
    ImportError
        If the module part cannot be imported.
    AttributeError
        If the module has no attribute with the given name.
    TypeError
        If the resolved attribute is not a class.
    """
    # Everything before the last dot is the module path; the rest is the attribute.
    module_name, _, class_id = class_name.rpartition(".")
    if not module_name:
        # importlib.import_module("") would raise a cryptic "Empty module name".
        raise ValueError(f"Expected a fully qualified class name, got: '{class_name}'")

    module = importlib.import_module(module_name)
    cls = getattr(module, class_id)
    # Explicit check rather than `assert` so it is still enforced under `python -O`.
    if not isinstance(cls, type):
        raise TypeError(f"'{class_name}' does not refer to a class: {cls!r}")
    return cls

215 

216 

# FIXME: Technically, this should return a type "class_name" derived from "base_class".
def instantiate_from_config(
    base_class: type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Factory method for a new class instantiated from config.

    Parameters
    ----------
    base_class : type
        Base type of the class to instantiate.
        One of {Environment, Optimizer, Scheduler, Service, Storage}
        (the constraints of ``BaseTypeVar``).
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : Union[Environment, Optimizer, Scheduler, Service, Storage]
        An instance of the `class_name` class.

    Raises
    ------
    TypeError
        If `class_name` does not refer to a subclass of `base_class`.
    """
    impl = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, impl)

    # The class name comes from user config, so validate it with a real exception
    # (an `assert` would be skipped under `python -O`).
    if not issubclass(impl, base_class):
        raise TypeError(f"Class '{class_name}' is not derived from {base_class.__name__}")
    ret: BaseTypeVar = impl(*args, **kwargs)
    # Kept as an assert: it mainly narrows the type for static checkers.
    assert isinstance(ret, base_class)
    return ret

253 

254 

def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that every required parameter appears as a key of the configuration.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        Identifiers of the parameters that the configuration must contain.

    Raises
    ------
    ValueError
        Listing the set of required parameters that are absent from ``config``.
    """
    missing_params = set(required_params) - set(config)
    if not missing_params:
        return
    raise ValueError(
        "The following parameters must be provided in the configuration"
        + f" or as command line arguments: {missing_params}"
    )

275 

276 

def get_git_root(path: str = __file__) -> str:
    """
    Locate the top-level directory of the git repository containing ``path``.

    Parameters
    ----------
    path : str, optional
        A file or directory inside the repository (defaults to this module's file).
        Anything that is not an existing directory is treated as a file, and the
        lookup starts from its parent directory.

    Raises
    ------
    subprocess.CalledProcessError
        If the path is not inside a git repository or the git command fails.

    Returns
    -------
    str
        The absolute path to the root directory of the git repository.
    """
    resolved = path_join(path, abs_path=True)
    start_dir = resolved if os.path.isdir(resolved) else os.path.dirname(resolved)
    toplevel = subprocess.check_output(
        ["git", "-C", start_dir, "rev-parse", "--show-toplevel"],
        text=True,
    )
    return path_join(toplevel.strip(), abs_path=True)

305 

306 

def get_git_remote_info(path: str, remote: str) -> str:
    """
    Look up the URL configured for a named remote of a git repository.

    Parameters
    ----------
    path : str
        Path inside the git repository (passed to ``git -C``).
    remote : str
        The name of the remote (e.g., "origin").

    Raises
    ------
    subprocess.CalledProcessError
        If the command fails or the remote does not exist.

    Returns
    -------
    str
        The remote's URL, with surrounding whitespace removed.
    """
    cmd = ["git", "-C", path, "remote", "get-url", remote]
    url = subprocess.check_output(cmd, text=True)
    return url.strip()

331 

332 

def get_git_repo_info(path: str) -> str:
    """
    Get the git repository URL for the given git repo.

    Tries the remote that the current branch's upstream belongs to first,
    falling back to the "origin" remote if the upstream branch is not set or
    cannot be resolved. If that also fails, it returns a file URL pointing to
    the local path.

    Parameters
    ----------
    path : str
        Path to the git repository.

    Raises
    ------
    OSError
        If the ``git`` executable cannot be run (e.g., FileNotFoundError).
        Git command failures (subprocess.CalledProcessError) are handled by the
        fallbacks described above and are not propagated.

    Returns
    -------
    str
        The remote URL of the git repository, or a ``file://`` URL of the
        local path if no remote URL could be determined.
    """
    # In case "origin" remote is not set, or this branch has a different
    # upstream, we should handle it gracefully.
    # (e.g., fallback to the first one we find?)
    path = path_join(path, abs_path=True)
    cmd = ["git", "-C", path, "rev-parse", "--abbrev-ref", "--symbolic-full-name", "HEAD@{u}"]
    try:
        git_remote = subprocess.check_output(cmd, text=True).strip()
        # The upstream is reported as "<remote>/<branch>"; keep only the remote name.
        git_remote = git_remote.split("/", 1)[0]
        git_repo = get_git_remote_info(path, git_remote)
    except subprocess.CalledProcessError:
        git_remote = "origin"
        _LOG.warning(
            "Failed to get the upstream branch for %s. Falling back to '%s' remote.",
            path,
            git_remote,
        )
        try:
            git_repo = get_git_remote_info(path, git_remote)
        except subprocess.CalledProcessError:
            git_repo = "file://" + path
            # This branch means the *remote* lookup failed, not the upstream branch.
            _LOG.warning(
                "Failed to get the '%s' remote for %s. Falling back to '%s'.",
                git_remote,
                path,
                git_repo,
            )
    return git_repo

382 

383 

def get_git_info(path: str = __file__) -> tuple[str, str, str, str]:
    """
    Describe where a file lives in git: repo URL, HEAD commit, and its paths.

    Parameters
    ----------
    path : str
        Path to a file (or directory) in the git repository.

    Raises
    ------
    subprocess.CalledProcessError
        If the path is not a git repository or the command fails.

    Returns
    -------
    (git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str]
        Git repository URL, the HEAD commit hash, the path relative to the
        repository root (with forward slashes), and the resolved absolute path.
    """
    abspath = path_join(path, abs_path=True)
    # Run git from the directory itself, or from the file's parent directory.
    workdir = abspath if os.path.isdir(abspath) else os.path.dirname(abspath)

    git_root = get_git_root(path=abspath)
    git_repo = get_git_repo_info(git_root)
    head_cmd = ["git", "-C", workdir, "rev-parse", "HEAD"]
    git_commit = subprocess.check_output(head_cmd, text=True).strip()
    _LOG.debug("Current git branch for %s: %s %s", git_root, git_repo, git_commit)

    rel_path = os.path.relpath(abspath, os.path.abspath(git_root)).replace("\\", "/")
    # TODO: return the branch too?
    return (git_repo, git_commit, rel_path, abspath)

418 

419 

420# TODO: Add support for checking out the branch locally. 

421 

422 

423# Note: to avoid circular imports, we don't specify TunableValue here. 

424def try_parse_val(val: str | None) -> int | float | str | None: 

425 """ 

426 Try to parse the value as an int or float, otherwise return the string. 

427 

428 This can help with config schema validation to make sure early on that 

429 the args we're expecting are the right type. 

430 

431 Parameters 

432 ---------- 

433 val : str 

434 The initial cmd line arg value. 

435 

436 Returns 

437 ------- 

438 TunableValue 

439 The parsed value. 

440 """ 

441 if val is None: 

442 return val 

443 try: 

444 val_float = float(val) 

445 try: 

446 val_int = int(val) 

447 return val_int if val_int == val_float else val_float 

448 except (ValueError, OverflowError): 

449 return val_float 

450 except ValueError: 

451 return str(val) 

452 

453 

454NullableT = TypeVar("NullableT") 

455"""A generic type variable for :py:func:`nullable` return types.""" 

456 

457 

458def nullable(func: Callable[..., NullableT], value: Any | None) -> NullableT | None: 

459 """ 

460 Poor man's Maybe monad: apply the function to the value if it's not None. 

461 

462 Parameters 

463 ---------- 

464 func : Callable 

465 Function to apply to the value. 

466 value : Any | None 

467 Value to apply the function to. 

468 

469 Returns 

470 ------- 

471 value : NullableT | None 

472 The result of the function application or None if the value is None. 

473 

474 Examples 

475 -------- 

476 >>> nullable(int, "1") 

477 1 

478 >>> nullable(int, None) 

479 ... 

480 >>> nullable(str, 1) 

481 '1' 

482 """ 

483 return None if value is None else func(value) 

484 

485 

def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Return ``timestamp`` as a timezone-aware datetime expressed in UTC.

    Parameters
    ----------
    timestamp : datetime.datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        How to interpret a naive ``timestamp``: as already being UTC, or as the
        system's local time. Timestamps loaded back from storage (where they were
        normalized to UTC without zone info) should use "utc"; timestamps received
        from a client or other source typically use "local".

    Returns
    -------
    datetime.datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If ``timestamp`` is naive and ``origin`` is neither "utc" nor "local".
    """
    if timestamp.tzinfo is None and origin == "utc":
        # Naive but known to be UTC: attach the zone without shifting the clock.
        # Using astimezone() here would apply the local offset when local != UTC.
        return timestamp.replace(tzinfo=pytz.UTC)
    if timestamp.tzinfo is not None or origin == "local":
        # Aware timestamps convert directly; naive ones are interpreted as
        # local time (e.g., per the TZ environment variable) and then converted.
        return timestamp.astimezone(pytz.UTC)
    raise ValueError(f"Invalid origin: {origin}")

522 

523 

def utcify_nullable_timestamp(
    timestamp: datetime | None,
    *,
    origin: Literal["utc", "local"],
) -> datetime | None:
    """Like :py:func:`utcify_timestamp`, but passes ``None`` through unchanged."""
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)

531 

532 

# All timestamps in the telemetry data must be greater than this date
# (a very rough approximation for the start of this feature).
_MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)


def _naive_local_to_utc(naive_col: pandas.Series) -> pandas.Series:
    """Localize naive local-time timestamps to UTC using each row's own (DST-aware) offset."""
    # datetime.astimezone() on a naive datetime uses the system local-time rules for
    # that particular instant, so rows on either side of a DST change get the right
    # offset. NaT rows get a zero offset and stay NaT.
    offsets = pandas.to_timedelta(
        [
            (
                pandas.Timedelta(0)
                if pandas.isna(ts)
                else ts.to_pydatetime(warn=False).astimezone().utcoffset()
            )
            for ts in naive_col
        ]
    )
    # Subtract via a plain array to avoid index alignment; nanoseconds are preserved.
    return (naive_col - offsets.to_numpy()).dt.tz_localize("UTC")


def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Attempt to convert a pandas column to a datetime format.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.
        Local times are converted using the UTC offset in effect at each timestamp
        (DST-aware), not the offset in effect when this function is called.

    Returns
    -------
    pandas.Series
        The converted datetime column, timezone-aware in UTC.

    Raises
    ------
    ValueError
        On parse errors, an invalid ``origin`` for naive data, missing values,
        or timestamps not after ``_MIN_TS``.
    """
    new_datetime_col = pandas.to_datetime(datetime_col, utc=False)
    # If timezone data is missing, assume the provided origin timezone.
    if new_datetime_col.dt.tz is None:
        if origin == "local":
            # Applying today's offset to every row (as a fixed tzinfo would) is wrong
            # across DST changes; resolve per-row, consistent with utcify_timestamp().
            new_datetime_col = _naive_local_to_utc(new_datetime_col)
        elif origin == "utc":
            new_datetime_col = new_datetime_col.dt.tz_localize(pytz.UTC)
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
    assert new_datetime_col.dt.tz is not None
    # And convert it to UTC.
    new_datetime_col = new_datetime_col.dt.tz_convert("UTC")
    if new_datetime_col.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if new_datetime_col.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return new_datetime_col

582 

583 

def sanitize_config(config: dict[str, Any]) -> dict[str, Any]:
    """
    Sanitize a configuration dictionary by obfuscating potentially sensitive keys.

    Values stored under a sensitive key (``password``, ``secret``, ``token``,
    ``api_key``; matched case-insensitively) are replaced with ``"[REDACTED]"``.
    Nested dicts are sanitized recursively, including dicts nested inside lists.
    The input is not modified: dicts and lists along the way are copied.

    Parameters
    ----------
    config : dict
        Configuration dictionary to sanitize.

    Returns
    -------
    dict
        Sanitized configuration dictionary.
    """
    sanitize_keys = {"password", "secret", "token", "api_key"}

    def is_sensitive(key: Any) -> bool:
        """Check whether a key names a value that must be redacted."""
        # Non-string keys (e.g., ints) can never be sensitive names.
        return isinstance(key, str) and key.lower() in sanitize_keys

    def recursive_sanitize(value: Any) -> Any:
        """Recursively sanitize dicts (including dicts inside lists) within a value."""
        if isinstance(value, dict):
            return {
                k: "[REDACTED]" if is_sensitive(k) else recursive_sanitize(v)
                for k, v in value.items()
            }
        if isinstance(value, list):
            # Config lists (e.g., of service definitions) may also carry secrets.
            return [recursive_sanitize(item) for item in value]
        return value

    sanitized: dict[str, Any] = recursive_sanitize(config)
    return sanitized