Coverage for mlos_bench/mlos_bench/tests/__init__.py: 88%

74 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench. 

7 

8Used to make mypy happy about multiple conftest.py modules. 

9""" 

10import filecmp 

11import json 

12import os 

13import shutil 

14import socket 

15from datetime import tzinfo 

16from logging import debug, warning 

17from subprocess import run 

18 

19import pytest 

20import pytz 

21from pytest_docker.plugin import Services as DockerServices 

22 

23from mlos_bench.util import get_class_from_name, nullable 

24 

# Time zones used by the tests: a few explicit IANA zone names, plus None,
# which stands for the implicit local time zone.
ZONE_NAMES = [
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    None,  # implicit local time zone
]
# pytz tzinfo objects aligned index-for-index with ZONE_NAMES.
# NOTE(review): relies on nullable() passing the None entry through instead of
# calling pytz.timezone(None) -- confirm against mlos_bench.util.nullable.
ZONE_INFO: list[tzinfo | None] = [nullable(pytz.timezone, name) for name in ZONE_NAMES]

34 

# Built-in environment variable names, each mapped to a default value of None.
BUILT_IN_ENV_VAR_DEFAULTS = dict.fromkeys(
    ("experiment_id", "trial_id", "trial_runner_id"),
)

40 

# Marker that skips a test unless docker is installed AND its default builder
# reports a linux target platform.
# Apply it as @requires_docker above a test_...() function.
DOCKER = shutil.which("docker")
if DOCKER:
    # Inspect the default builder (falling back to `docker buildx` when
    # `docker builder` fails) and look for an output line that mentions both
    # "Platform" and "linux".
    cmd = run(
        "docker builder inspect default || docker buildx inspect default",
        shell=True,
        check=False,
        capture_output=True,
    )
    stdout = cmd.stdout.decode()
    if not (
        cmd.returncode == 0
        and any("Platform" in line and "linux" in line for line in stdout.splitlines())
    ):
        debug("Docker is available but missing support for targeting linux platform.")
        DOCKER = None
requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason="Docker with Linux support is not available on this system.",
)

61 

# Marker that skips a test when no `ssh` executable is found on PATH.
# Apply it as @requires_ssh above a test_...() function.
SSH = shutil.which("ssh")
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)

66 

# One fixed seed shared by all tests. Using the same value everywhere avoids
# chasing race conditions and cross-test seed interference when tests run in
# parallel, in a non-deterministic order.
SEED = 42

# This module does not seed numpy's global RNG. A test that needs numpy
# determinism can do so itself, e.g.:
#     np.random.seed(SEED)

73 

74 

75def try_resolve_class_name(class_name: str | None) -> str | None: 

76 """Gets the full class name from the given name or None on error.""" 

77 if class_name is None: 

78 return None 

79 try: 

80 the_class = get_class_from_name(class_name) 

81 return the_class.__module__ + "." + the_class.__name__ 

82 except (ValueError, AttributeError, ModuleNotFoundError, ImportError): 

83 return None 

84 

85 

def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether the exact class of `obj` matches `expected_class_name`.

    The comparison uses ``"<module>.<ClassName>"`` of ``obj.__class__``, so
    subclasses do not match. The expected name is first resolved with
    ``try_resolve_class_name``; if it cannot be resolved, the result is False.

    Parameters
    ----------
    obj : object
        Object whose class is inspected.
    expected_class_name : str
        Class name to resolve and compare against.

    Returns
    -------
    bool
        True if the fully qualified names are equal.
    """
    cls = obj.__class__
    actual_name = cls.__module__ + "." + cls.__name__
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name

90 

91 

def is_docker_service_healthy(
    compose_project_name: str,
    service_name: str,
) -> bool:
    """
    Check if a docker compose service is running and healthy.

    Parameters
    ----------
    compose_project_name : str
        The ``docker compose -p`` project name.
    service_name : str
        The compose service to inspect.

    Returns
    -------
    bool
        True if ``docker compose ps`` reports at least one container for the
        service and every reported container has State "running" and Health
        "healthy". Returns False otherwise, including when nothing is reported
        yet, so that a polling caller keeps waiting instead of crashing.

    Raises
    ------
    subprocess.CalledProcessError
        If the ``docker compose ps`` command exits non-zero.
    """
    # Pass an argument list (no shell) so the project and service names are
    # never subject to shell word splitting or interpretation.
    docker_ps_out = run(
        [
            "docker",
            "compose",
            "-p",
            compose_project_name,
            "ps",
            "--format",
            "json",
            service_name,
        ],
        check=True,
        capture_output=True,
    )
    containers = _parse_docker_compose_ps_json(docker_ps_out.stdout.decode())
    if not containers:
        return False
    for container in containers:
        state = container["State"]
        assert isinstance(state, str)
        health = container["Health"]
        assert isinstance(health, str)
        if state != "running" or health != "healthy":
            return False
    return True


def _parse_docker_compose_ps_json(output: str) -> list[dict[str, object]]:
    """Parse ``docker compose ps --format json`` output into a list of container records."""
    output = output.strip()
    if not output:
        return []
    if output.startswith("["):
        # Some compose releases emit a single JSON array of containers.
        parsed = json.loads(output)
        assert isinstance(parsed, list)
        return parsed
    # Others emit one JSON object per line.
    return [json.loads(line) for line in output.splitlines() if line.strip()]

109 

110 

def wait_docker_service_healthy(
    docker_services: DockerServices,
    project_name: str,
    service_name: str,
    timeout: float = 30.0,
) -> None:
    """
    Block until a docker compose service reports itself healthy.

    Polls ``is_docker_service_healthy`` every 0.5 seconds through
    ``docker_services.wait_until_responsive``.

    Parameters
    ----------
    docker_services : DockerServices
        The pytest-docker services object used for polling.
    project_name : str
        The docker compose project name.
    service_name : str
        The compose service to wait for.
    timeout : float
        Polling timeout in seconds, passed to ``wait_until_responsive``.
        What happens when it expires is decided by pytest-docker.
    """

    def _service_is_healthy() -> bool:
        return is_docker_service_healthy(project_name, service_name)

    docker_services.wait_until_responsive(
        check=_service_is_healthy,
        timeout=timeout,
        pause=0.5,
    )

123 

124 

def wait_docker_service_socket(
    docker_services: DockerServices,
    hostname: str,
    port: int,
    timeout: float = 30.0,
) -> None:
    """
    Wait until a docker service accepts TCP connections on `hostname`:`port`.

    Polls ``check_socket`` every 0.5 seconds through
    ``docker_services.wait_until_responsive``.

    Parameters
    ----------
    docker_services : DockerServices
        The pytest-docker services object used for polling.
    hostname : str
        Host to connect to.
    port : int
        TCP port to connect to.
    timeout : float
        Polling timeout in seconds, passed to ``wait_until_responsive``.
        Defaults to 30 seconds.
    """
    docker_services.wait_until_responsive(
        check=lambda: check_socket(hostname, port),
        timeout=timeout,
        pause=0.5,
    )

132 

133 

def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a TCP socket is accepting connections.

    Parameters
    ----------
    host : str
        Host name or IPv4 address to connect to (an AF_INET socket is used).
    port : int
        TCP port to connect to.
    timeout : float
        Socket timeout in seconds applied to the connection attempt.

    Returns
    -------
    bool
        True if connecting to ``(host, port)`` succeeded. False otherwise,
        including when the host name cannot be resolved.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)  # seconds
        try:
            # connect_ex reports connect() failures as an error code, but
            # address-resolution failures still raise socket.gaierror.
            result = sock.connect_ex((host, port))
        except socket.gaierror:
            return False
        return result == 0

152 

153 

154def resolve_host_name(host: str) -> str | None: 

155 """ 

156 Resolves the host name to an IP address. 

157 

158 Parameters 

159 ---------- 

160 host : str 

161 

162 Returns 

163 ------- 

164 str 

165 """ 

166 try: 

167 return socket.gethostbyname(host) 

168 except socket.gaierror: 

169 return None 

170 

171 

def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and
    contents are equal.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if the directory trees are the same and there were no errors
        while accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny holds names present in both trees whose type differs
    # (e.g. a file on one side, a directory on the other) or that could not be
    # stat'ed. They appear in neither common_files nor common_dirs, so they
    # must be treated as a difference here or they would be silently skipped.
    if (
        dirs_cmp.left_only
        or dirs_cmp.right_only
        or dirs_cmp.common_funny
        or dirs_cmp.funny_files
    ):
        warning(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left_only={dirs_cmp.left_only}\n"
            f"right_only={dirs_cmp.right_only}\n"
            f"common_funny={dirs_cmp.common_funny}\n"
            f"funny_files={dirs_cmp.funny_files}"
        )
        return False
    # dircmp's own diff_files is based on shallow comparison, so compare the
    # common files' contents explicitly.
    _, mismatch, errors = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning(f"Found differences in files:\n{mismatch}\n{errors}")
        return False
    # Recurse into shared subdirectories; all() stops at the first mismatch.
    return all(
        are_dir_trees_equal(os.path.join(dir1, sub_dir), os.path.join(dir2, sub_dir))
        for sub_dir in dirs_cmp.common_dirs
    )