Coverage for mlos_bench/mlos_bench/tests/__init__.py: 85%
61 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for mlos_bench.
7Used to make mypy happy about multiple conftest.py modules.
8"""
9from datetime import tzinfo
10from logging import debug, warning
11from subprocess import run
12from typing import List, Optional
14import filecmp
15import os
16import socket
17import shutil
19import pytz
20import pytest
22from mlos_bench.util import get_class_from_name, nullable
# Time zone names that tests are parameterized over.
ZONE_NAMES: List[Optional[str]] = [
    # Zones spelled out explicitly by name.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # None stands for the implicit local time zone.
    None,
]
# tzinfo objects matching ZONE_NAMES, in the same order; each entry is built
# with nullable(pytz.timezone, name) (presumably None stays None for the
# local-zone entry, per the Optional[tzinfo] element type).
ZONE_INFO: List[Optional[tzinfo]] = [nullable(pytz.timezone, name) for name in ZONE_NAMES]
# Decorator for tests that need docker: put @requires_docker above a test_...() function.
# DOCKER holds the path of the docker CLI, and is reset to None unless the
# default builder's description contains a "Platform" line mentioning linux.
DOCKER = shutil.which('docker')
if DOCKER:
    cmd = run(
        "docker builder inspect default || docker buildx inspect default",
        shell=True,
        check=False,
        capture_output=True,
    )
    stdout = cmd.stdout.decode()
    # Docker is only usable if the inspect command succeeded AND it reported
    # a linux target platform.
    if not (cmd.returncode == 0
            and any('Platform' in line and 'linux' in line for line in stdout.splitlines())):
        debug("Docker is available but missing support for targeting linux platform.")
        DOCKER = None
requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason='Docker with Linux support is not available on this system.',
)
# Decorator for tests that need an ssh client on PATH.
# Usage: put @requires_ssh above a test_...() function.
SSH = shutil.which('ssh')
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason='ssh is not available on this system.',
)
# One shared random seed for all tests. Using a single fixed value avoids
# chasing race conditions and seed interference between tests that run in
# parallel in a non-deterministic order.
SEED = 42
# To also seed numpy's global RNG with it, one could use:
#   import numpy as np
#   np.random.seed(SEED)
def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]:
    """
    Look up the class named by `class_name` (via get_class_from_name) and
    return its fully-qualified "<module>.<ClassName>" string.

    Returns None when `class_name` is None, or when the lookup raises
    ValueError, AttributeError, ModuleNotFoundError or ImportError.
    """
    resolved_name: Optional[str] = None
    if class_name is not None:
        try:
            cls = get_class_from_name(class_name)
            resolved_name = ".".join((cls.__module__, cls.__name__))
        except (ValueError, AttributeError, ModuleNotFoundError, ImportError):
            resolved_name = None
    return resolved_name
def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Return True if the fully-qualified class name of `obj` equals the name
    that try_resolve_class_name() produces for `expected_class_name`.
    """
    klass = obj.__class__
    actual_name = ".".join((klass.__module__, klass.__name__))
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name
def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a TCP socket is open (accepting connections).

    Parameters
    ----------
    host : str
        Host name or IPv4 address to connect to.
    port : int
        TCP port to connect to.
    timeout : float
        Connection timeout, in seconds.

    Returns
    -------
    bool
        True if connecting to (host, port) succeeded, False otherwise --
        including when `host` cannot be resolved to an IPv4 address.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)  # seconds
        try:
            # connect_ex() returns errors from the C-level connect() call as an
            # error code, but address-resolution problems are still raised.
            result = sock.connect_ex((host, port))
        except socket.gaierror:
            return False
        return result == 0
def resolve_host_name(host: str) -> Optional[str]:
    """
    Resolves the host name to an IPv4 address.

    Parameters
    ----------
    host : str
        Host name (or IPv4 address string) to look up.

    Returns
    -------
    Optional[str]
        The address as a dotted-quad string (from socket.gethostbyname),
        or None if the lookup raises socket.gaierror.
    """
    address: Optional[str]
    try:
        address = socket.gethostbyname(host)
    except socket.gaierror:
        address = None
    return address
def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively. Files in each directory are
    assumed to be equal if their names and contents are equal.

    @param dir1: First directory path
    @param dir2: Second directory path

    @return: True if the directory trees are the same and
        there were no errors while accessing the directories or files,
        False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # Entries present on only one side are differences.
    # `common_funny` lists names present on both sides whose types differ
    # (e.g. a file in one tree and a directory in the other) or that could not
    # be stat'ed. Such names are in neither `common_files` nor `common_dirs`,
    # so they would be silently skipped unless checked here.
    # `funny_files` lists common files that could not be compared.
    if dirs_cmp.left_only or dirs_cmp.right_only or dirs_cmp.common_funny or dirs_cmp.funny_files:
        warning(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left_only: {dirs_cmp.left_only}\n"
            f"right_only: {dirs_cmp.right_only}\n"
            f"common_funny: {dirs_cmp.common_funny}\n"
            f"funny_files: {dirs_cmp.funny_files}"
        )
        return False
    # dircmp's own file comparison is shallow (stat-based), so re-compare the
    # common files by content.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning(f"Found differences in files:\n{mismatch}\n{errors}")
        return False
    # Recurse into subdirectories present in both trees (stops at the first difference).
    return all(
        are_dir_trees_equal(os.path.join(dir1, sub_dir), os.path.join(dir2, sub_dir))
        for sub_dir in dirs_cmp.common_dirs
    )