Coverage for mlos_bench/mlos_bench/tests/__init__.py: 88%
74 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for mlos_bench.
8Used to make mypy happy about multiple conftest.py modules.
9"""
10import filecmp
11import json
12import os
13import shutil
14import socket
15from datetime import tzinfo
16from logging import debug, warning
17from subprocess import run
19import pytest
20import pytz
21from pytest_docker.plugin import Services as DockerServices
23from mlos_bench.util import get_class_from_name, nullable
# Time zones the tests are parametrized over: a few explicit IANA zone names,
# followed by ``None`` standing in for the implicit local time zone.
ZONE_NAMES = [
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    None,
]
# One entry per ZONE_NAMES entry, in the same order; ``nullable`` handles the
# ``None`` entry (presumably passing it through as ``None`` -- see the
# ``tzinfo | None`` element type).
ZONE_INFO: list[tzinfo | None] = [nullable(pytz.timezone, zone) for zone in ZONE_NAMES]
# Built-in environment variables, each mapped to a ``None`` default.
BUILT_IN_ENV_VAR_DEFAULTS = dict.fromkeys(
    (
        "experiment_id",
        "trial_id",
        "trial_runner_id",
    )
)
# A decorator for tests that require docker with support for targeting the
# linux platform.
# Use with @requires_docker above a test_...() function.
DOCKER = shutil.which("docker")


def _docker_supports_linux_platform() -> bool:
    """
    Check whether the default docker builder reports a linux target platform.

    Runs ``docker builder inspect default``, falling back to
    ``docker buildx inspect default``, and looks for an output line that
    mentions both "Platform" and "linux".

    Returns
    -------
    bool
        True if the inspect command succeeded and reported a linux platform,
        False otherwise.
    """
    # Kept as a single shell command so that the ``||`` fallback to buildx
    # works.  Done inside a function so the probe's temporaries don't leak
    # into this package's namespace.
    result = run(
        "docker builder inspect default || docker buildx inspect default",
        shell=True,
        check=False,
        capture_output=True,
    )
    if result.returncode != 0:
        return False
    return any(
        "Platform" in line and "linux" in line for line in result.stdout.decode().splitlines()
    )


if DOCKER and not _docker_supports_linux_platform():
    debug("Docker is available but missing support for targeting linux platform.")
    DOCKER = None

requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason="Docker with Linux support is not available on this system.",
)
# Path to the ssh client executable, or None when it is not on the PATH.
SSH = shutil.which("ssh")

# Skip marker for tests that need an ssh client; apply it as @requires_ssh
# directly above a test_...() function.
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)
# A common seed to use to avoid tracking down race conditions and intermingling
# issues of seeds across tests that run in non-deterministic parallel orders.
# NOTE: no global random number generator (e.g., numpy's) is seeded here;
# tests that need determinism must pass SEED to their generators explicitly.
SEED = 42
def try_resolve_class_name(class_name: str | None) -> str | None:
    """
    Resolve a class name to its fully qualified ``module.ClassName`` form.

    Parameters
    ----------
    class_name : str | None
        Name to look up via ``get_class_from_name``.

    Returns
    -------
    str | None
        The resolved class's ``__module__`` and ``__name__`` joined by a dot,
        or None if `class_name` is None or the lookup fails with ValueError,
        AttributeError, ModuleNotFoundError, or ImportError.
    """
    qualified_name: str | None = None
    if class_name is not None:
        try:
            resolved = get_class_from_name(class_name)
            qualified_name = ".".join((resolved.__module__, resolved.__name__))
        except (ValueError, AttributeError, ModuleNotFoundError, ImportError):
            # Unresolvable name: report None rather than propagate.
            pass
    return qualified_name
def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether `obj`'s own class has the given fully qualified name.

    Parameters
    ----------
    obj : object
        The object whose class is inspected (via ``obj.__class__``).
    expected_class_name : str
        Class name, resolved with ``try_resolve_class_name`` before comparing.

    Returns
    -------
    bool
        True if ``module.ClassName`` of `obj`'s class equals the resolved name;
        False otherwise, including when `expected_class_name` cannot be
        resolved.  Subclasses do not match since only exact names are compared.
    """
    obj_class = obj.__class__
    actual_name = ".".join((obj_class.__module__, obj_class.__name__))
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name
def is_docker_service_healthy(
    compose_project_name: str,
    service_name: str,
) -> bool:
    """
    Check if a docker compose service is running and healthy.

    Parameters
    ----------
    compose_project_name : str
        The docker compose project name (passed to ``docker compose -p``).
    service_name : str
        The service within that project to inspect.

    Returns
    -------
    bool
        True if every container reported for the service has State "running"
        and Health "healthy"; False otherwise, including when no container is
        reported (yet).

    Raises
    ------
    subprocess.CalledProcessError
        If ``docker compose ps`` exits with a non-zero status.
    OSError
        If the docker executable cannot be launched.
    json.JSONDecodeError
        If the command output is not valid JSON.
    """
    # Pass the command as an argument list (no shell) so the project and
    # service names are never subject to shell parsing or quoting issues.
    docker_ps_out = run(
        [
            "docker",
            "compose",
            "-p",
            compose_project_name,
            "ps",
            "--format",
            "json",
            service_name,
        ],
        check=True,
        capture_output=True,
    )
    output = docker_ps_out.stdout.decode().strip()
    if not output:
        # Nothing reported, so the service can't be healthy yet.
        return False
    # Older docker compose releases print a single JSON array; newer ones print
    # one JSON object per line (one line per container).
    if output.startswith("["):
        containers = json.loads(output)
    else:
        containers = [json.loads(line) for line in output.splitlines() if line.strip()]
    if not containers:
        return False
    for container in containers:
        state = container["State"]
        assert isinstance(state, str)
        health = container["Health"]
        assert isinstance(health, str)
        if state != "running" or health != "healthy":
            return False
    return True
def wait_docker_service_healthy(
    docker_services: DockerServices,
    project_name: str,
    service_name: str,
    timeout: float = 30.0,
) -> None:
    """
    Block until a docker compose service reports as running and healthy.

    Parameters
    ----------
    docker_services : DockerServices
        The pytest-docker services fixture whose ``wait_until_responsive``
        performs the polling.
    project_name : str
        The docker compose project name.
    service_name : str
        The service within that project.
    timeout : float
        How long to keep polling before giving up (passed through to
        ``wait_until_responsive``; presumably seconds).
    """

    def _service_is_healthy() -> bool:
        # Re-queried on every poll attempt.
        return is_docker_service_healthy(project_name, service_name)

    docker_services.wait_until_responsive(
        timeout=timeout,
        pause=0.5,
        check=_service_is_healthy,
    )
def wait_docker_service_socket(
    docker_services: DockerServices,
    hostname: str,
    port: int,
    timeout: float = 30.0,
) -> None:
    """
    Wait until a docker service accepts TCP connections on ``hostname:port``.

    Parameters
    ----------
    docker_services : DockerServices
        The pytest-docker services fixture whose ``wait_until_responsive``
        performs the polling.
    hostname : str
        Host to connect to.
    port : int
        TCP port to connect to.
    timeout : float
        How long to keep polling before giving up (passed through to
        ``wait_until_responsive``; presumably seconds).  Defaults to the
        previously hard-coded 30.0, matching ``wait_docker_service_healthy``.
    """
    docker_services.wait_until_responsive(
        check=lambda: check_socket(hostname, port),
        timeout=timeout,
        pause=0.5,
    )
def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if an IPv4 TCP socket at ``host:port`` accepts connections.

    Parameters
    ----------
    host : str
        Host name or IPv4 address to connect to.
    port : int
        TCP port to connect to.
    timeout : float
        Socket timeout, in seconds, for the connection attempt.

    Returns
    -------
    bool
        True if the connection succeeded; False otherwise, including when
        `host` cannot be resolved or the attempt fails with another socket
        error.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(timeout)  # seconds
            # connect_ex() returns connection failures as an error code, but
            # name-resolution failures (socket.gaierror) and other socket
            # errors are raised instead, so they are handled below.
            return sock.connect_ex((host, port)) == 0
    except OSError:
        # E.g. the host name is not resolvable (yet).  Report "not open" so
        # that callers polling this check (wait_docker_service_socket) keep
        # retrying instead of aborting.
        return False
154def resolve_host_name(host: str) -> str | None:
155 """
156 Resolves the host name to an IP address.
158 Parameters
159 ----------
160 host : str
162 Returns
163 -------
164 str
165 """
166 try:
167 return socket.gethostbyname(host)
168 except socket.gaierror:
169 return None
def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and contents
    are equal.  Names that ``filecmp.dircmp`` ignores by default (e.g. ``.git``,
    ``__pycache__``) are not compared.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if the directory trees are the same and there were no errors while
        accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny holds names present on both sides whose types differ (e.g. a
    # file in one tree and a directory in the other) or that could not be
    # stat'ed.  They land in neither common_files nor common_dirs, so they must
    # be rejected here or they would be silently skipped.
    if (
        dirs_cmp.left_only
        or dirs_cmp.right_only
        or dirs_cmp.common_funny
        or dirs_cmp.funny_files
    ):
        warning(
            "Found differences in dir trees %s, %s:\n"
            "left only: %s\nright only: %s\ncommon funny: %s\nfunny files: %s",
            dir1,
            dir2,
            dirs_cmp.left_only,
            dirs_cmp.right_only,
            dirs_cmp.common_funny,
            dirs_cmp.funny_files,
        )
        return False
    # Compare file contents (shallow=False), not just os.stat() signatures.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning("Found differences in files:\n%s\n%s", mismatch, errors)
        return False
    return all(
        are_dir_trees_equal(os.path.join(dir1, sub_dir), os.path.join(dir2, sub_dir))
        for sub_dir in dirs_cmp.common_dirs
    )