import functools
import os
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy
from numpy.testing import assert_allclose
def is_azure() -> bool:
    """
    Tells if the job is running on Azure DevOps.

    The decision only relies on the environment variable
    ``AZURE_HTTP_USER_AGENT``: any value other than the literal string
    ``"undefined"`` (also used when the variable is missing) means Azure.
    """
    user_agent = os.environ.get("AZURE_HTTP_USER_AGENT", "undefined")
    return user_agent != "undefined"
def is_windows() -> bool:
    """Tells if the interpreter reports ``win32`` as ``sys.platform``."""
    platform_name = sys.platform
    return platform_name == "win32"
def is_apple() -> bool:
    """Tells if the interpreter reports ``darwin`` as ``sys.platform``."""
    platform_name = sys.platform
    return platform_name == "darwin"
def skipif_ci_windows(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.

    :param msg: reason appended to the skip message
    :return: a ``unittest.skip`` decorator when both :func:`is_windows` and
        :func:`is_azure` are true, otherwise a decorator returning the
        test unchanged
    """
    on_windows_ci = is_windows() and is_azure()
    if not on_windows_ci:
        # Nothing to skip: hand the decorated object back untouched.
        return lambda x: x
    reason = f"Test does not work on azure pipeline (Windows). {msg}"
    return unittest.skip(reason)
def skipif_ci_apple(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS.

    :param msg: reason appended to the skip message
    :return: a ``unittest.skip`` decorator when both :func:`is_apple` and
        :func:`is_azure` are true, otherwise a decorator returning the
        test unchanged
    """
    on_apple_ci = is_apple() and is_azure()
    if not on_apple_ci:
        # Nothing to skip: hand the decorated object back untouched.
        return lambda x: x
    reason = f"Test does not work on azure pipeline (Apple). {msg}"
    return unittest.skip(reason)
def with_path_append(path_to_add: Union[str, List[str]]) -> Callable:
    """
    Adds a path to sys.path to check.

    The decorated test method runs with *path_to_add* appended to
    ``sys.path``. The previous content of ``sys.path`` is put back once the
    test returns, including when it raises.

    :param path_to_add: a single path or a list of paths,
        ``None`` leaves ``sys.path`` unchanged
    :return: a decorator for a test method taking only *self*
    """

    def wraps(f, path_to_add=path_to_add):
        @functools.wraps(f)
        def wrapped(self, path_to_add=path_to_add):
            saved = sys.path.copy()
            if path_to_add is not None:
                if isinstance(path_to_add, str):
                    path_to_add = [path_to_add]
                sys.path.extend(path_to_add)
            try:
                return f(self)
            finally:
                # Restore in place so that every reference to the list
                # object held elsewhere sees the original content again.
                sys.path[:] = saved

        return wrapped

    return wraps
def unit_test_going():
    """
    Enables a flag telling the script is running while testing it.
    Avoids unit tests to be very long if used.

    :return: True when the environment variable ``UNITTEST_GOING``
        converts to the integer 1, False otherwise (missing counts as 0)
    """
    raw_flag = os.environ.get("UNITTEST_GOING", 0)
    return int(raw_flag) == 1
def ignore_warnings(warns: List[Warning]) -> Callable:
    """
    Catches warnings.

    :param warns: warnings to ignore, either a single warning class or a
        list / tuple of warning classes
    :return: a decorator for a test method taking only *self*; the
        method runs with the given categories ignored and the previous
        warning filters are restored afterwards
    :raises AssertionError: at decoration time if *warns* is None
    """

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")
        # A filter category must be a class (or a tuple of classes) to be
        # usable by issubclass when a warning is emitted; a list is not.
        # One filter per category works whatever container was given.
        categories = (warns,) if isinstance(warns, type) else tuple(warns)

        @functools.wraps(fct)
        def call_f(self):
            with warnings.catch_warnings():
                for category in categories:
                    warnings.simplefilter("ignore", category)
                return fct(self)

        return call_f

    return wrapper
class ExtTestCase(unittest.TestCase):
    """
    Extends :class:`unittest.TestCase` with assertions on files, arrays
    and containers, and with helpers running a function while capturing
    its output or its exception.
    """

    # Entries are (name, line, warning); tearDownClass emits one warning
    # per entry. Nothing in this class fills the list.
    _warns: List[Tuple[str, int, Warning]] = []

    def assertExists(self, name):
        """Fails if the file or folder *name* does not exist."""
        if not os.path.exists(name):
            raise AssertionError(f"File or folder {name!r} does not exists.")

    def assertGreaterOrEqual(self, a, b):
        """Fails if *a* is strictly lower than *b*."""
        if a < b:
            # The exception must be raised, returning it would make the
            # assertion always succeed.
            raise AssertionError(f"{a} < {b}, a not greater or equal than b.")

    def assertEqualArray(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
        msg: Optional[str] = None,
    ):
        """
        Fails if two arrays differ by their type, shape or content.

        :param expected: expected array
        :param value: array to check
        :param atol: absolute tolerance given to ``assert_allclose``
        :param rtol: relative tolerance given to ``assert_allclose``
        :param msg: message replacing the numpy report if the content differs
        """
        self.assertEqual(expected.dtype, value.dtype)
        self.assertEqual(expected.shape, value.shape)
        try:
            assert_allclose(expected, value, atol=atol, rtol=rtol)
        except AssertionError as e:
            if msg:
                raise AssertionError(msg) from e
            # No custom message: keep the detailed numpy report.
            raise

    def assertAlmostEqual(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """
        Compares two values as arrays with :meth:`assertEqualArray`.
        This replaces the signature of
        ``unittest.TestCase.assertAlmostEqual``.

        :param expected: expected value, converted into an array if needed
        :param value: value to check, converted into an array of the
            expected type if it is not an array already
        :param atol: absolute tolerance
        :param rtol: relative tolerance
        """
        if not isinstance(expected, numpy.ndarray):
            expected = numpy.array(expected)
        if not isinstance(value, numpy.ndarray):
            value = numpy.array(value).astype(expected.dtype)
        self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

    def assertRaise(self, fct: Callable, exc_type: type[Exception]):
        """
        Fails if calling *fct* does not raise *exc_type*.
        Any other exception is left propagating to the caller.

        :param fct: function to call without any argument
        :param exc_type: expected exception type
        """
        try:
            fct()
        except exc_type:
            return
        raise AssertionError("No exception was raised.")

    def assertEmpty(self, value: Any):
        """Fails unless *value* is None or evaluates to False."""
        if value is None:
            return
        if not value:
            return
        raise AssertionError(f"value is not empty: {value!r}.")

    def assertNotEmpty(self, value: Any):
        """Fails if *value* is None or an empty list, dict, tuple or set."""
        if value is None:
            raise AssertionError(f"value is empty: {value!r}.")
        if isinstance(value, (list, dict, tuple, set)):
            if not value:
                raise AssertionError(f"value is empty: {value!r}.")

    def assertStartsWith(self, prefix: str, full: str):
        """Fails if string *full* does not start with *prefix*."""
        if not full.startswith(prefix):
            raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")

    @classmethod
    def tearDownClass(cls):
        """Emits one warning for every entry stored in ``_warns``."""
        for name, line, w in cls._warns:
            warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}")

    def capture(self, fct: Callable):
        """
        Runs a function and capture standard output and error.

        :param fct: function to run
        :return: result of *fct*, output, error
        :raises AssertionError: if *fct* raises, the message contains
            what was captured so far
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout), redirect_stderr(serr):
            try:
                res = fct()
            except Exception as e:
                raise AssertionError(
                    f"function {fct} failed, stdout="
                    f"\n{sout.getvalue()}\n---\nstderr=\n{serr.getvalue()}"
                ) from e
        return res, sout.getvalue(), serr.getvalue()

    def tryCall(
        self, fct: Callable, msg: Optional[str] = None, none_if: Optional[str] = None
    ) -> Optional[Any]:
        """
        Calls the function, catch any error.

        :param fct: function to call
        :param msg: error message to display if failing
        :param none_if: returns None if this substring is found in the error message
        :return: output of *fct*
        """
        try:
            return fct()
        except Exception as e:
            if none_if is not None and none_if in str(e):
                return None
            if msg is None:
                raise
            raise AssertionError(msg) from e
def get_figure(ax: Any) -> Any:
    """
    Returns the figure of a matplotlib figure.

    :param ax: an object exposing ``get_figure``, or a container with a
        ``shape`` of at most two dimensions whose first element exposes
        ``get_figure``
    :return: the result of ``get_figure()``
    :raises RuntimeError: if the container has more than two dimensions
    """
    if hasattr(ax, "get_figure"):
        return ax.get_figure()
    rank = len(ax.shape)
    if rank == 0:
        return ax.get_figure()
    if rank == 1:
        first_axis = ax[0]
    elif rank == 2:
        first_axis = ax[0, 0]
    else:
        raise RuntimeError(f"Unexpected shape {ax.shape} for axis.")
    return first_axis.get_figure()