Coverage for mlos_bench/mlos_bench/tests/__init__.py: 84%
107 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for mlos_bench.
8Used to make mypy happy about multiple conftest.py modules.
9"""
10import filecmp
11import json
12import os
13import shutil
14import socket
15import stat
16import sys
17from datetime import tzinfo
18from subprocess import run
19from warnings import warn
21import pytest
22import pytz
23from pytest_docker.plugin import Services as DockerServices
25from mlos_bench.util import get_class_from_name, nullable
# Time zone names that tests can be parametrized over.
ZONE_NAMES: list[str | None] = [
    # Explicit time zones.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # Implicit local time zone.
    None,
]
# tzinfo objects matching ZONE_NAMES entry-by-entry; built via `nullable`, which
# presumably passes the `None` (local time zone) entry through unchanged.
ZONE_INFO: list[tzinfo | None] = [nullable(pytz.timezone, name) for name in ZONE_NAMES]
# Built-in variable names, each defaulting to None.
BUILT_IN_ENV_VAR_DEFAULTS = dict.fromkeys(
    (
        "experiment_id",
        "trial_id",
        "trial_runner_id",
    )
)
# Path to the docker CLI, or None when docker is missing or cannot target the
# linux platform (checked below).
DOCKER = shutil.which("docker")
if DOCKER:
    # Gathering info about Github CI docker.sock permissions for debugging purposes.
    DOCKER_SOCK_PATH: str
    if sys.platform == "win32":
        DOCKER_SOCK_PATH = "//./pipe/docker_engine"
    else:
        DOCKER_SOCK_PATH = "/var/run/docker.sock"

    mode: str | None = None
    uid: int | None = None
    gid: int | None = None
    current_uid: int | None = None
    current_gid: int | None = None
    gids: list[int] | None = None
    try:
        st = os.stat(DOCKER_SOCK_PATH)
        mode = stat.filemode(st.st_mode)
        uid = st.st_uid
        gid = st.st_gid
    except Exception as e:  # pylint: disable=broad-except
        # Best effort: this info is only used for diagnostics.
        warn(f"Could not stat {DOCKER_SOCK_PATH}: {e}", UserWarning)
    try:
        # getuid/getgid/getgroups are POSIX-only.
        if sys.platform != "win32":
            current_uid = os.getuid()
            current_gid = os.getgid()
            gids = os.getgroups()
        if not os.access(DOCKER_SOCK_PATH, os.W_OK):
            warn(f"Docker socket {DOCKER_SOCK_PATH} is not writable.", UserWarning)
    except Exception as e:  # pylint: disable=broad-except
        warn(f"Could not get current user info: {e}", UserWarning)

    # Try the default builder (newer CLIs) and fall back to buildx (older CLIs).
    cmd = run(
        "docker builder inspect default || docker buildx inspect default",
        shell=True,
        check=False,
        capture_output=True,
    )
    # errors="replace": unexpected (non-UTF-8) output must not break test collection.
    stdout = cmd.stdout.decode(errors="replace")
    stderr = cmd.stderr.decode(errors="replace")
    # Docker is only usable if the builder reports a line mentioning both
    # "Platform" and "linux".
    if cmd.returncode != 0 or not any(
        "Platform" in line and "linux" in line for line in stdout.splitlines()
    ):
        DOCKER = None
        warn(
            "Docker is available but missing buildx support for targeting linux platform:\n"
            + f"stdout:\n{stdout}\n"
            + f"stderr:\n{stderr}\n"
            + f"sock_path: {DOCKER_SOCK_PATH} sock mode: {mode} sock uid: {uid} gid: {gid}\n"
            + f"current_uid: {current_uid} current_gid: {current_gid} groups: {gids}\n",
            UserWarning,
        )

if not DOCKER:
    warn("Docker is not available on this system. Some tests will be skipped.", UserWarning)
# Skip marker for tests that need docker capable of building/running linux images.
# Usage: place @requires_docker directly above a test_...() function.
requires_docker = pytest.mark.skipif(not DOCKER, reason="Docker with Linux support is not available on this system.")
# Path to the ssh client, or None if it is not on PATH.
SSH = shutil.which("ssh")
if not SSH:
    warn("ssh is not available on this system. Some tests will be skipped.", UserWarning)

# Skip marker for tests that need an ssh client.
# Usage: place @requires_ssh directly above a test_...() function.
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)
# One shared random seed for all tests. Tests may run in parallel and in a
# non-deterministic order; a single common seed avoids having to untangle race
# conditions caused by different seeds intermingling across tests.
SEED = 42

# Global numpy seeding is not enabled here; it would look like:
# import numpy as np
# np.random.seed(SEED)
122def try_resolve_class_name(class_name: str | None) -> str | None:
123 """Gets the full class name from the given name or None on error."""
124 if class_name is None:
125 return None
126 try:
127 the_class = get_class_from_name(class_name)
128 return the_class.__module__ + "." + the_class.__name__
129 except (ValueError, AttributeError, ModuleNotFoundError, ImportError):
130 return None
def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether `obj`'s fully qualified class name matches `expected_class_name`.

    The expected name is resolved via `try_resolve_class_name`; if it cannot be
    resolved, the comparison is against None and therefore False.
    """
    klass = obj.__class__
    actual_name = ".".join((klass.__module__, klass.__name__))
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name
def is_docker_service_healthy(
    compose_project_name: str,
    service_name: str,
) -> bool:
    """
    Check if a docker compose service is running and healthy.

    Parameters
    ----------
    compose_project_name : str
        The docker compose project name (passed as ``docker compose -p``).
    service_name : str
        The service name within the compose project.

    Returns
    -------
    bool
        True if at least one container is reported for the service and every
        reported container has State "running" and Health "healthy";
        False otherwise.

    Raises
    ------
    subprocess.CalledProcessError
        If the ``docker compose ps`` command fails.
    """
    # Pass an argv list (no shell) so project/service names are never shell-interpreted.
    docker_ps_out = run(
        ["docker", "compose", "-p", compose_project_name, "ps", "--format", "json", service_name],
        check=True,
        capture_output=True,
    )
    output = docker_ps_out.stdout.decode().strip()
    # Depending on the docker compose version, the output is either a single JSON
    # value (an array of containers, or one object) or one JSON object per line.
    try:
        parsed = json.loads(output)
        entries = parsed if isinstance(parsed, list) else [parsed]
    except json.JSONDecodeError:
        # Also covers empty output, which yields no entries.
        entries = [json.loads(line) for line in output.splitlines() if line.strip()]
    if not entries:
        return False
    for entry in entries:
        state = entry["State"]
        assert isinstance(state, str)
        # Containers without a healthcheck may report an empty/missing Health.
        health = entry.get("Health", "")
        assert isinstance(health, str)
        if state != "running" or health != "healthy":
            return False
    return True
def wait_docker_service_healthy(
    docker_services: DockerServices,
    project_name: str,
    service_name: str,
    timeout: float = 60.0,
) -> None:
    """
    Block until `service_name` in compose project `project_name` is healthy.

    Polls `is_docker_service_healthy` every 0.5 seconds via
    ``docker_services.wait_until_responsive`` for up to `timeout` seconds.
    """

    def _service_is_healthy() -> bool:
        return is_docker_service_healthy(project_name, service_name)

    docker_services.wait_until_responsive(check=_service_is_healthy, timeout=timeout, pause=0.5)
def wait_docker_service_socket(
    docker_services: DockerServices,
    hostname: str,
    port: int,
    timeout: float = 60.0,
) -> None:
    """
    Wait until a docker service accepts TCP connections on `hostname`:`port`.

    Parameters
    ----------
    docker_services : DockerServices
        The pytest-docker services fixture used for polling.
    hostname : str
        Host name or IP address of the service.
    port : int
        TCP port of the service.
    timeout : float
        Maximum number of seconds to wait (default 60, matching
        `wait_docker_service_healthy`). The socket is polled every 0.5 seconds.
    """
    docker_services.wait_until_responsive(
        check=lambda: check_socket(hostname, port),
        timeout=timeout,
        pause=0.5,
    )
def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a TCP socket is open (i.e., accepting connections).

    Parameters
    ----------
    host : str
        Host name or IP address to connect to.
    port : int
        TCP port to connect to.
    timeout: float
        Connection timeout, in seconds.

    Returns
    -------
    bool
        True if a TCP connection could be established; False otherwise, including
        when the connection is refused, times out, or the host name cannot be
        resolved.
    """
    try:
        # create_connection tries every address the host resolves to (IPv4 and
        # IPv6) and closes the socket when the `with` block exits.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # OSError covers refused connections, timeouts (socket.timeout) and
        # name-resolution failures (socket.gaierror), so callers that poll this
        # function never see an exception for an unavailable service.
        return False
201def resolve_host_name(host: str) -> str | None:
202 """
203 Resolves the host name to an IP address.
205 Parameters
206 ----------
207 host : str
209 Returns
210 -------
211 str
212 """
213 try:
214 return socket.gethostbyname(host)
215 except socket.gaierror:
216 return None
def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and contents
    are equal. Entries hidden by `filecmp.dircmp`'s default ignore list are skipped.

    @param dir1: First directory path
    @param dir2: Second directory path

    @return: True if the directory trees are the same and there were no errors while
        accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny: names present in both trees whose types differ (e.g. a file in
    # one and a directory in the other) or that could not be stat'ed. Such names
    # appear in neither common_files nor common_dirs, so they must be caught here.
    if dirs_cmp.left_only or dirs_cmp.right_only or dirs_cmp.common_funny or dirs_cmp.funny_files:
        warn(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left_only: {dirs_cmp.left_only}\n"
            f"right_only: {dirs_cmp.right_only}\n"
            f"common_funny: {dirs_cmp.common_funny}\n"
            f"funny_files: {dirs_cmp.funny_files}",
            UserWarning,
        )
        return False
    # shallow=False: compare file contents, not just os.stat() signatures.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warn(f"Found differences in files:\n{mismatch}\n{errors}", UserWarning)
        return False
    # all() short-circuits on the first differing subdirectory.
    return all(
        are_dir_trees_equal(os.path.join(dir1, common_dir), os.path.join(dir2, common_dir))
        for common_dir in dirs_cmp.common_dirs
    )