Source code for onnxrt_backend_dev.ext_test_case

import os
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy
from numpy.testing import assert_allclose


def is_azure() -> bool:
    """
    Tells if the job is running on Azure DevOps.

    The detection relies on the environment variable
    ``AZURE_HTTP_USER_AGENT``: the function returns True when it is
    defined with any value other than the string ``"undefined"``.
    """
    agent = os.getenv("AZURE_HTTP_USER_AGENT")
    return agent is not None and agent != "undefined"


def is_windows() -> bool:
    """Tells if ``sys.platform`` is ``"win32"``, i.e. the code runs on Windows."""
    current = sys.platform
    return current == "win32"


def is_apple() -> bool:
    """Tells if ``sys.platform`` is ``"darwin"``, i.e. the code runs on macOS."""
    current = sys.platform
    return current == "darwin"


def skipif_ci_windows(msg) -> Callable:
    """
    Returns a decorator skipping a unit test when it runs
    on :epkg:`azure pipeline` on :epkg:`Windows`.

    :param msg: explanation appended to the skip message
    :return: ``unittest.skip(...)`` if both conditions hold,
        otherwise a decorator returning the test unchanged
    """
    on_ci_windows = is_windows() and is_azure()
    if not on_ci_windows:
        # Nothing to skip: hand the decorated object back untouched.
        return lambda x: x
    return unittest.skip(f"Test does not work on azure pipeline (Windows). {msg}")


def skipif_ci_apple(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS
    (``sys.platform == "darwin"``).

    :param msg: explanation appended to the skip message
    :return: ``unittest.skip(...)`` if both conditions hold,
        otherwise a decorator returning the test unchanged
    """
    if is_apple() and is_azure():
        msg = f"Test does not work on azure pipeline (Apple). {msg}"
        return unittest.skip(msg)
    # Nothing to skip: hand the decorated object back untouched.
    return lambda x: x


def with_path_append(path_to_add: Union[str, List[str]]) -> Callable:
    """
    Returns a decorator which temporarily appends one or several paths
    to ``sys.path`` while the decorated test method runs.

    ``sys.path`` is restored to its previous content once the method
    returns, even if it raises an exception.

    :param path_to_add: a path or a list of paths to append, None to add nothing
    :return: decorator for a method taking only *self*
    """

    def wraps(f, path_to_add=path_to_add):
        def wrapped(self, path_to_add=path_to_add):
            saved = sys.path.copy()
            if path_to_add is not None:
                if isinstance(path_to_add, str):
                    path_to_add = [path_to_add]
                sys.path.extend(path_to_add)
            try:
                # Propagates the result of the decorated method.
                return f(self)
            finally:
                # Restores in place so that every reference to the list
                # sees the original content, whether *f* failed or not.
                sys.path[:] = saved

        return wrapped

    return wraps


def unit_test_going():
    """
    Tells if the script is running as part of the unit tests.

    The flag is read from the environment variable ``UNITTEST_GOING``,
    converted into an integer; the function returns True only when that
    integer is 1. It lets examples shorten their execution when tested.
    """
    flag = os.environ.get("UNITTEST_GOING", 0)
    return int(flag) == 1


def ignore_warnings(warns: List[Warning]) -> Callable:
    """
    Returns a decorator which silences warnings raised while
    the decorated test method runs.

    :param warns: warnings to ignore, a single warning class
        or a list, tuple or set of warning classes
    :return: decorator for a method taking only *self*
    :raises AssertionError: when the decorator is applied and *warns* is None
    """

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")

        # A list cannot be used as a filter category: matching relies on
        # ``issubclass`` which only accepts a class or a tuple of classes.
        # One filter is registered per category instead.
        if isinstance(warns, (list, tuple, set, frozenset)):
            categories = tuple(warns)
        else:
            categories = (warns,)

        def call_f(self):
            # Filters are restored when leaving the context manager.
            with warnings.catch_warnings():
                for category in categories:
                    warnings.simplefilter("ignore", category)
                return fct(self)

        return call_f

    return wrapper


class ExtTestCase(unittest.TestCase):
    """
    Extends :class:`unittest.TestCase` with assertions on files,
    arrays, exceptions, containers and helpers capturing the
    standard output or calling a function safely.
    """

    # Entries are unpacked as (name, line, warning) by tearDownClass.
    # The list is a class attribute, shared by every subclass which
    # does not redefine it.
    _warns: List[Tuple[str, int, Warning]] = []

    def assertExists(self, name):
        """Fails if the file or folder *name* does not exist."""
        if not os.path.exists(name):
            raise AssertionError(f"File or folder {name!r} does not exists.")

    def assertGreaterOrEqual(self, a, b):
        """Fails if *a* is strictly lower than *b*."""
        if a < b:
            # The exception must be raised, returning it never fails a test.
            raise AssertionError(f"{a} < {b}, a not greater or equal than b.")

    def assertEqualArray(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
        msg: Optional[str] = None,
    ):
        """
        Checks two arrays share the same dtype, the same shape
        and the same values up to a tolerance.

        :param expected: expected array
        :param value: array to verify
        :param atol: absolute tolerance
        :param rtol: relative tolerance
        :param msg: message replacing the default one if the values differ
        :raises AssertionError: if dtype, shape or values differ
        """
        self.assertEqual(expected.dtype, value.dtype)
        self.assertEqual(expected.shape, value.shape)
        try:
            assert_allclose(expected, value, atol=atol, rtol=rtol)
        except AssertionError as e:
            if msg is None:
                # Keeps the detailed report built by numpy instead of
                # replacing it with the useless message "None".
                raise
            raise AssertionError(msg) from e

    def assertAlmostEqual(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """
        Compares two arrays or two values converted into arrays.
        This method overrides :meth:`unittest.TestCase.assertAlmostEqual`
        and does not support parameters *places*, *msg*, *delta*.

        :param expected: expected array or value
        :param value: array or value to verify, it is cast to the
            expected dtype when it is not already an array
        :param atol: absolute tolerance
        :param rtol: relative tolerance
        """
        if not isinstance(expected, numpy.ndarray):
            expected = numpy.array(expected)
        if not isinstance(value, numpy.ndarray):
            value = numpy.array(value).astype(expected.dtype)
        self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

    def assertRaise(self, fct: Callable, exc_type: type[Exception]):
        """
        Checks that calling *fct* raises an exception of type *exc_type*.
        Any other exception propagates unchanged.

        :param fct: function to call, it takes no argument
        :param exc_type: expected exception type
        :raises AssertionError: if no exception was raised
        """
        try:
            fct()
        except exc_type:
            return
        raise AssertionError("No exception was raised.")

    def assertEmpty(self, value: Any):
        """Fails if *value* is neither None nor falsy."""
        if value is None:
            return
        if not value:
            return
        raise AssertionError(f"value is not empty: {value!r}.")

    def assertNotEmpty(self, value: Any):
        """
        Fails if *value* is None or an empty list, dict, tuple or set.
        Other types are only compared to None.
        """
        if value is None:
            raise AssertionError(f"value is empty: {value!r}.")
        if isinstance(value, (list, dict, tuple, set)):
            if not value:
                raise AssertionError(f"value is empty: {value!r}.")

    def assertStartsWith(self, prefix: str, full: str):
        """Fails if string *full* does not start with *prefix*."""
        if not full.startswith(prefix):
            raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")

    @classmethod
    def tearDownClass(cls):
        """Emits one warning for every entry stored in ``_warns``."""
        for name, line, w in cls._warns:
            warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}")

    def capture(self, fct: Callable):
        """
        Runs a function and capture standard output and error.

        :param fct: function to run
        :return: result of *fct*, output, error
        :raises AssertionError: if *fct* raises an exception, the message
            contains what was written on both streams so far
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout), redirect_stderr(serr):
            try:
                res = fct()
            except Exception as e:
                raise AssertionError(
                    f"function {fct} failed, stdout="
                    f"\n{sout.getvalue()}\n---\nstderr=\n{serr.getvalue()}"
                ) from e
        return res, sout.getvalue(), serr.getvalue()

    def tryCall(
        self, fct: Callable, msg: Optional[str] = None, none_if: Optional[str] = None
    ) -> Optional[Any]:
        """
        Calls the function, catch any error.

        :param fct: function to call
        :param msg: error message to display if failing
        :param none_if: returns None if this substring is found in the error message
        :return: output of *fct*
        :raises AssertionError: if *fct* fails and *msg* is specified,
            otherwise the original exception is raised again
        """
        try:
            return fct()
        except Exception as e:
            if none_if is not None and none_if in str(e):
                return None
            if msg is None:
                # Bare raise keeps the original traceback untouched.
                raise
            raise AssertionError(msg) from e
def get_figure(ax: Any) -> Any:
    """
    Returns the figure of a matplotlib axis or of an array of axes.

    :param ax: an object exposing method ``get_figure`` or an array
        with 0, 1 or 2 dimensions whose first element exposes it
    :return: the result of ``get_figure()`` called on *ax* or on
        its first element
    :raises RuntimeError: if *ax* is an array with more than two dimensions
    """
    if hasattr(ax, "get_figure"):
        return ax.get_figure()
    ndim = len(ax.shape)
    if ndim == 0:
        # The container itself has no method get_figure (checked above),
        # the unique element must be extracted first.
        return ax[()].get_figure()
    if ndim == 1:
        return ax[0].get_figure()
    if ndim == 2:
        return ax[0, 0].get_figure()
    raise RuntimeError(f"Unexpected shape {ax.shape} for axis.")