Coverage for mlos_bench/mlos_bench/tests/__init__.py: 84%

107 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-30 00:51 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench. 

7 

8Used to make mypy happy about multiple conftest.py modules. 

9""" 

10import filecmp 

11import json 

12import os 

13import shutil 

14import socket 

15import stat 

16import sys 

17from datetime import tzinfo 

18from subprocess import run 

19from warnings import warn 

20 

21import pytest 

22import pytz 

23from pytest_docker.plugin import Services as DockerServices 

24 

25from mlos_bench.util import get_class_from_name, nullable 

26 

# Time zone names exercised by the timestamp-related tests.
ZONE_NAMES: list[str | None] = [
    # Explicit time zones.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # Implicit local time zone.
    None,
]
# tzinfo objects parallel to ZONE_NAMES. ``nullable`` is presumably what keeps the
# ``None`` (local time zone) entry as ``None`` instead of passing it to
# ``pytz.timezone`` -- see mlos_bench.util.nullable to confirm.
ZONE_INFO: list[tzinfo | None] = [nullable(pytz.timezone, name) for name in ZONE_NAMES]

# Built-in environment variable names, each mapped to a default of None.
BUILT_IN_ENV_VAR_DEFAULTS = dict.fromkeys(
    (
        "experiment_id",
        "trial_id",
        "trial_runner_id",
    )
)

42 

43 

# Path to the docker CLI, or None if docker is missing or cannot build for linux.
DOCKER = shutil.which("docker")
if DOCKER:
    # Gathering info about Github CI docker.sock permissions for debugging purposes.
    DOCKER_SOCK_PATH: str
    if sys.platform == "win32":
        DOCKER_SOCK_PATH = "//./pipe/docker_engine"
    else:
        DOCKER_SOCK_PATH = "/var/run/docker.sock"

    mode: str | None = None
    uid: int | None = None
    gid: int | None = None
    current_uid: int | None = None
    current_gid: int | None = None
    gids: list[int] | None = None
    try:
        st = os.stat(DOCKER_SOCK_PATH)
        mode = stat.filemode(st.st_mode)
        uid = st.st_uid
        gid = st.st_gid
    except Exception as e:  # pylint: disable=broad-except
        warn(f"Could not stat {DOCKER_SOCK_PATH}: {e}", UserWarning)
    try:
        if sys.platform != "win32":
            current_uid = os.getuid()
            current_gid = os.getgid()
            gids = os.getgroups()
            # Permission check only for the POSIX socket file; the Windows named
            # pipe path is not checked here.
            if not os.access(DOCKER_SOCK_PATH, os.W_OK):
                warn(f"Docker socket {DOCKER_SOCK_PATH} is not writable.", UserWarning)
    except Exception as e:  # pylint: disable=broad-except
        warn(f"Could not get current user info: {e}", UserWarning)

    # Either builder subcommand may be available depending on the docker version,
    # hence the shell "||" fallback.
    cmd = run(
        "docker builder inspect default || docker buildx inspect default",
        shell=True,
        check=False,
        capture_output=True,
    )
    # Decode leniently: this runs at import time, and a UnicodeDecodeError on
    # non-UTF-8 output would otherwise break test collection entirely.
    stdout = cmd.stdout.decode(errors="replace")
    stderr = cmd.stderr.decode(errors="replace")
    # Require a "Platform" line that mentions linux in the inspect output.
    if cmd.returncode != 0 or not any(
        "Platform" in line and "linux" in line for line in stdout.splitlines()
    ):
        DOCKER = None
        warn(
            "Docker is available but missing buildx support for targeting linux platform:\n"
            + f"stdout:\n{stdout}\n"
            + f"stderr:\n{stderr}\n"
            + f"sock_path: {DOCKER_SOCK_PATH} sock mode: {mode} sock uid: {uid} gid: {gid}\n"
            + f"current_uid: {current_uid} current_gid: {current_gid} groups: {gids}\n",
            UserWarning,
        )

if not DOCKER:
    warn("Docker is not available on this system. Some tests will be skipped.", UserWarning)

# A decorator for tests that require docker.
# Use with @requires_docker above a test_...() function.
requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason="Docker with Linux support is not available on this system.",
)

106 

# Path to the ``ssh`` client executable, or None if it is not on the PATH.
SSH = shutil.which("ssh")
if not SSH:
    warn("ssh is not available on this system. Some tests will be skipped.", UserWarning)

# Skip marker for tests that need an ssh client; apply it as ``@requires_ssh``
# directly above a ``test_...()`` function.
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)

113 

# One seed shared by all tests, so that randomness is reproducible and seeds do
# not get intermingled across tests that run in non-deterministic parallel order.
SEED: int = 42

120 

121 

122def try_resolve_class_name(class_name: str | None) -> str | None: 

123 """Gets the full class name from the given name or None on error.""" 

124 if class_name is None: 

125 return None 

126 try: 

127 the_class = get_class_from_name(class_name) 

128 return the_class.__module__ + "." + the_class.__name__ 

129 except (ValueError, AttributeError, ModuleNotFoundError, ImportError): 

130 return None 

131 

132 

def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether the class of ``obj`` matches ``expected_class_name``.

    Parameters
    ----------
    obj : object
        Object whose class (``obj.__class__``) is inspected.
    expected_class_name : str
        Class name to resolve with ``try_resolve_class_name``.

    Returns
    -------
    bool
        True iff ``"<module>.<ClassName>"`` of ``obj``'s class equals the
        resolved name; False if the expected name cannot be resolved.
    """
    cls = obj.__class__
    actual_name = ".".join((cls.__module__, cls.__name__))
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name

137 

138 

def is_docker_service_healthy(
    compose_project_name: str,
    service_name: str,
) -> bool:
    """
    Check if a docker compose service is running and healthy.

    Parameters
    ----------
    compose_project_name : str
        Project name passed to ``docker compose -p``.
    service_name : str
        Service within the compose project to inspect.

    Returns
    -------
    bool
        True iff at least one container is reported for the service and every
        reported container has ``State == "running"`` and
        ``Health == "healthy"``. False otherwise, including when no containers
        are reported yet.

    Raises
    ------
    subprocess.CalledProcessError
        If ``docker compose ps`` exits with a non-zero status.
    """
    # Pass an argument list (no shell) so project/service names are never
    # re-interpreted by a shell.
    docker_ps_out = run(
        [
            "docker",
            "compose",
            "-p",
            compose_project_name,
            "ps",
            "--format",
            "json",
            service_name,
        ],
        check=True,
        capture_output=True,
    )
    output = docker_ps_out.stdout.decode().strip()
    # Depending on the compose version, ``ps --format json`` prints either one
    # JSON array (older releases) or one JSON object per line (newer releases).
    if output.startswith("["):
        containers = json.loads(output)
    else:
        containers = [json.loads(line) for line in output.splitlines() if line.strip()]
    if not containers:
        # Nothing reported yet: not healthy, but let callers keep polling.
        return False
    for container in containers:
        state = container["State"]
        assert isinstance(state, str)
        health = container["Health"]
        assert isinstance(health, str)
        if not (state == "running" and health == "healthy"):
            return False
    return True

156 

157 

def wait_docker_service_healthy(
    docker_services: DockerServices,
    project_name: str,
    service_name: str,
    timeout: float = 60.0,
) -> None:
    """
    Block until a docker compose service reports as healthy.

    Polls ``is_docker_service_healthy`` through
    ``docker_services.wait_until_responsive`` with a 0.5 second pause.

    Parameters
    ----------
    docker_services : DockerServices
        The pytest-docker services fixture.
    project_name : str
        Compose project name.
    service_name : str
        Service within the compose project.
    timeout : float
        Passed as ``timeout`` to ``wait_until_responsive`` (presumably seconds;
        pytest-docker raises when it expires -- confirm against its docs).
    """

    def _is_healthy() -> bool:
        return is_docker_service_healthy(project_name, service_name)

    docker_services.wait_until_responsive(
        check=_is_healthy,
        timeout=timeout,
        pause=0.5,
    )

170 

171 

def wait_docker_service_socket(
    docker_services: DockerServices,
    hostname: str,
    port: int,
    timeout: float = 60.0,
) -> None:
    """
    Wait until a docker service accepts TCP connections on ``hostname:port``.

    Polls ``check_socket`` through ``docker_services.wait_until_responsive``
    with a 0.5 second pause.

    Parameters
    ----------
    docker_services : DockerServices
        The pytest-docker services fixture.
    hostname : str
        Host to connect to.
    port : int
        TCP port to connect to.
    timeout : float
        Passed as ``timeout`` to ``wait_until_responsive``. Defaults to 60.0,
        the same default as ``wait_docker_service_healthy``.
    """
    docker_services.wait_until_responsive(
        check=lambda: check_socket(hostname, port),
        timeout=timeout,
        pause=0.5,
    )

179 

180 

def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a TCP socket is open (accepting connections).

    Parameters
    ----------
    host : str
        Host name or IP address (IPv4 or IPv6).
    port : int
        TCP port to connect to.
    timeout : float
        Connection timeout in seconds, applied to each address ``host`` resolves to.

    Returns
    -------
    bool
        True if a TCP connection could be established. False otherwise,
        including when ``host`` cannot be resolved or the attempt times out.
    """
    try:
        # create_connection resolves ``host`` and tries every returned address
        # (IPv4 and IPv6), unlike a socket hard-coded to AF_INET.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Refused/unreachable connections, timeouts, and name-resolution
        # failures (socket.gaierror) all mean "not open" -- connect_ex() would
        # raise for the latter and abort any polling loop built on this check.
        return False

199 

200 

201def resolve_host_name(host: str) -> str | None: 

202 """ 

203 Resolves the host name to an IP address. 

204 

205 Parameters 

206 ---------- 

207 host : str 

208 

209 Returns 

210 ------- 

211 str 

212 """ 

213 try: 

214 return socket.gethostbyname(host) 

215 except socket.gaierror: 

216 return None 

217 

218 

def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and contents
    are equal.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if the directory trees are the same and there were no errors while
        accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # ``common_funny`` holds names present on both sides whose types differ
    # (e.g. a file in one tree and a directory in the other) or that could not
    # be stat'ed. Such names are in neither ``common_files`` nor ``common_dirs``,
    # so without this check they would be silently ignored.
    if (
        dirs_cmp.left_only
        or dirs_cmp.right_only
        or dirs_cmp.common_funny
        or dirs_cmp.funny_files
    ):
        warn(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left_only: {dirs_cmp.left_only}\n"
            f"right_only: {dirs_cmp.right_only}\n"
            f"common_funny: {dirs_cmp.common_funny}\n"
            f"funny_files: {dirs_cmp.funny_files}",
            UserWarning,
        )
        return False
    # dircmp's own file comparison is shallow (os.stat based); compare contents.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warn(f"Found differences in files:\n{mismatch}\n{errors}", UserWarning)
        return False
    return all(
        are_dir_trees_equal(os.path.join(dir1, sub_dir), os.path.join(dir2, sub_dir))
        for sub_dir in dirs_cmp.common_dirs
    )