Coverage for mlos_core/mlos_core/tests/__init__.py: 95%
21 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Common functions for mlos_core Optimizer tests.
7"""
9import sys
11from importlib import import_module
12from pkgutil import walk_packages
13from typing import List, Optional, Set, Type, TypeVar
15# A common seed to use to avoid tracking down race conditions and intermingling
16# issues of seeds across tests that run in non-deterministic parallel orders.
17SEED = 42
19if sys.version_info >= (3, 10):
20 from typing import TypeAlias
21else:
22 from typing_extensions import TypeAlias
25T = TypeVar('T')
28def get_all_submodules(pkg: TypeAlias) -> List[str]:
29 """
30 Imports all submodules for a package and returns their names.
31 Useful for dynamically enumerating subclasses.
32 """
33 submodules = []
34 for _, submodule_name, _ in walk_packages(pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None):
35 submodules.append(submodule_name)
36 return submodules
def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]:
    """
    Collects every direct and indirect subclass of the given class.

    Useful for dynamically enumerating expected test cases.

    Only classes that are currently reachable through ``__subclasses__()``
    are found; ``cls`` itself is not part of the result.

    Parameters
    ----------
    cls
        The root class whose descendants are wanted.

    Returns
    -------
    Set[Type[T]]
        The set of all descendants of ``cls`` (empty if it has none).
    """
    found: Set[Type[T]] = set()
    # Worklist of classes whose own subclasses still need to be expanded.
    pending = list(cls.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate in found:
            # Already reached via another base (multiple inheritance).
            continue
        found.add(candidate)
        pending.extend(candidate.__subclasses__())
    return found
def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]:
    """
    Returns the concrete subclasses of the given class as a sorted list.

    Useful for dynamically enumerating expected test cases.

    A subclass counts as concrete when its ``__abstractmethods__`` attribute
    is missing or empty. The result is ordered by ``(__module__, __name__)``.

    Note: For abstract types, mypy will complain at the call site.
    Use "# type: ignore[type-abstract]" to suppress the warning.
    See Also: https://github.com/python/mypy/issues/4717

    Parameters
    ----------
    cls
        The base class whose descendants are searched.
    pkg_name
        Optional dotted name of a package to import and walk before the
        search. The walk must report at least one submodule.

    Returns
    -------
    List[Type[T]]
        The concrete direct and indirect subclasses of ``cls``.
    """
    if pkg_name is not None:
        # Import and walk the package before looking for subclasses
        # (presumably so that classes defined inside it are registered).
        package = import_module(pkg_name)
        discovered_names = get_all_submodules(package)
        assert discovered_names
    concrete = [
        candidate
        for candidate in _get_all_subclasses(cls)
        if not getattr(candidate, "__abstractmethods__", None)
    ]
    # Sort for a deterministic order regardless of set iteration order.
    concrete.sort(key=lambda c: (c.__module__, c.__name__))
    return concrete