Coverage for mlos_bench/mlos_bench/tests/__init__.py: 85%

61 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench. 

7Used to make mypy happy about multiple conftest.py modules. 

8""" 

9from datetime import tzinfo 

10from logging import debug, warning 

11from subprocess import run 

12from typing import List, Optional 

13 

14import filecmp 

15import os 

16import socket 

17import shutil 

18 

19import pytz 

20import pytest 

21 

22from mlos_bench.util import get_class_from_name, nullable 

23 

24 

# Time zone names exercised by time-zone-sensitive tests.
ZONE_NAMES: List[Optional[str]] = [
    # Explicit time zones.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # Implicit local time zone.
    None,
]
# Parallel to ZONE_NAMES: each entry is the result of
# nullable(pytz.timezone, <name>) for the corresponding name.
ZONE_INFO: List[Optional[tzinfo]] = [nullable(pytz.timezone, tz_name) for tz_name in ZONE_NAMES]

37 

38 

# A decorator for tests that require docker.
# Use with @requires_docker above a test_...() function.
DOCKER = shutil.which('docker')


def _docker_supports_linux() -> bool:
    """
    Check whether the local docker install can target the linux platform.

    Runs ``docker builder inspect default`` (falling back to
    ``docker buildx inspect default``) and looks for a line of its output that
    mentions both ``Platform`` and ``linux``.

    Returns
    -------
    bool
        True if linux platform support was reported; False if the probe could
        not be run, timed out, exited non-zero, or reported no linux platform.
    """
    # Local import: only this one-off probe needs it.
    from subprocess import TimeoutExpired  # pylint: disable=import-outside-toplevel

    try:
        # Bounded so that an unresponsive docker daemon cannot stall test collection.
        cmd = run(
            "docker builder inspect default || docker buildx inspect default",
            shell=True,
            check=False,
            capture_output=True,
            timeout=60,
        )
    except (OSError, TimeoutExpired) as ex:
        debug(f"Failed to inspect the default docker builder: {ex}")
        return False
    if cmd.returncode != 0:
        debug("Docker is available but inspecting the default builder failed.")
        return False
    stdout = cmd.stdout.decode(errors="replace")
    if not any('Platform' in line and 'linux' in line for line in stdout.splitlines()):
        debug("Docker is available but missing support for targeting linux platform.")
        return False
    return True


if DOCKER and not _docker_supports_linux():
    DOCKER = None
requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.')

49 

# Marker for tests that need an ssh client on PATH.
# Apply as @requires_ssh directly above a test_...() function.
SSH = shutil.which('ssh')
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason='ssh is not available on this system.',
)

54 

# A common seed to use to avoid tracking down race conditions and intermingling
# issues of seeds across tests that run in non-deterministic parallel orders.
# This module does not seed any global random number generator itself; tests that
# need reproducible randomness should pass SEED to the generators they use.
SEED = 42

61 

62 

def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]:
    """
    Resolve a class name to its fully qualified "<module>.<ClassName>" form.

    Returns None when no name is given, or when resolving it via
    get_class_from_name raises ImportError (incl. ModuleNotFoundError),
    ValueError, or AttributeError.
    """
    if class_name is None:
        return None
    try:
        resolved = get_class_from_name(class_name)
        qualified = ".".join((resolved.__module__, resolved.__name__))
    except (ImportError, ValueError, AttributeError):
        # ModuleNotFoundError is a subclass of ImportError, so it is covered here.
        return None
    return qualified

74 

75 

def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether obj's class has the given (resolvable) name.

    The fully qualified name of obj's class is compared against the result of
    try_resolve_class_name(expected_class_name); an unresolvable expected name
    therefore never matches.
    """
    actual_cls = obj.__class__
    actual_name = actual_cls.__module__ + "." + actual_cls.__name__
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name

82 

83 

def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a TCP socket is open (i.e., accepting connections).

    Parameters
    ----------
    host : str
        Host name or IPv4 address to connect to.
    port : int
        TCP port to connect to.
    timeout: float
        Connection timeout, in seconds.

    Returns
    -------
    bool
        True if a TCP connection to (host, port) succeeded, False otherwise
        (including when the host name cannot be resolved).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)  # seconds
        try:
            result = sock.connect_ex((host, port))
        except socket.gaierror:
            # connect_ex() only turns errors from the underlying connect() call into
            # a return code; address resolution failures are still raised.
            return False
        return result == 0

102 

103 

def resolve_host_name(host: str) -> Optional[str]:
    """
    Look up the IPv4 address of a host name.

    Parameters
    ----------
    host : str
        Host name (or address literal) to resolve.

    Returns
    -------
    Optional[str]
        The address returned by socket.gethostbyname, or None if the lookup
        raised socket.gaierror.
    """
    try:
        ip_addr: Optional[str] = socket.gethostbyname(host)
    except socket.gaierror:
        ip_addr = None
    return ip_addr

120 

121 

def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and contents
    are equal.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if the directory trees are the same and there were no errors while
        accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny holds names present on both sides whose types differ (e.g. a
    # file in one tree and a directory in the other) or that could not be stat'ed.
    # Such names are in neither common_files nor common_dirs, so without this check
    # the trees would wrongly compare equal.
    if dirs_cmp.left_only or dirs_cmp.right_only or dirs_cmp.common_funny or dirs_cmp.funny_files:
        warning(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left_only={dirs_cmp.left_only}\n"
            f"right_only={dirs_cmp.right_only}\n"
            f"common_funny={dirs_cmp.common_funny}\n"
            f"funny_files={dirs_cmp.funny_files}"
        )
        return False
    # shallow=False: compare file contents, not just os.stat() signatures.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning(f"Found differences in files:\n{mismatch}\n{errors}")
        return False
    # all() short-circuits on the first differing subdirectory.
    return all(
        are_dir_trees_equal(os.path.join(dir1, common_dir), os.path.join(dir2, common_dir))
        for common_dir in dirs_cmp.common_dirs
    )