Coverage for mlos_core/mlos_core/tests/__init__.py: 95%
20 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Common functions for mlos_core Optimizer tests."""
import sys
from importlib import import_module
from pkgutil import walk_packages
from types import ModuleType
from typing import List, Optional, Set, Type, TypeVar
# One fixed seed shared by all tests: tests may run in a non-deterministic,
# parallel order, and using a single common seed avoids seeds intermingling
# between tests, which would make race conditions hard to track down.
SEED = 42

# `TypeAlias` lives in `typing` only from Python 3.10 onward; older
# interpreters get it from the `typing_extensions` backport.
if sys.version_info < (3, 10):
    from typing_extensions import TypeAlias
else:
    from typing import TypeAlias

# Generic type variable for the class-enumeration helpers below.
T = TypeVar("T")
def get_all_submodules(pkg: ModuleType) -> List[str]:
    """
    Enumerates all submodules of a package and returns their fully-qualified names.

    Useful for dynamically enumerating subclasses: :func:`pkgutil.walk_packages`
    imports every *subpackage* it finds (it needs their ``__path__`` to recurse),
    which registers any classes defined in those packages' ``__init__`` modules.
    Plain (leaf) modules are only listed, not imported, so subclasses defined in
    a leaf module are only discoverable if something else imports that module.

    Parameters
    ----------
    pkg : ModuleType
        An already-imported package (it must have a ``__path__`` attribute).

    Returns
    -------
    List[str]
        Dotted names (``"<pkg>.<sub>"``, ``"<pkg>.<sub>.<leaf>"``, ...) of every
        module and subpackage found below ``pkg``, recursively.
    """
    # The no-op ``onerror`` deliberately ignores errors raised while importing a
    # subpackage, so one broken/optional subpackage does not abort the walk.
    return [
        submodule_name
        for _, submodule_name, _ in walk_packages(
            pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda _name: None
        )
    ]
def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]:
    """
    Collects every direct and indirect subclass of ``cls`` (excluding ``cls`` itself).

    The hierarchy below ``cls`` is walked iteratively with an explicit stack
    rather than by recursion.

    Useful for dynamically enumerating expected test cases.
    """
    found: Set[Type[T]] = set()
    pending = list(cls.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate in found:
            # Reached again through another base class (multiple inheritance);
            # its own subclasses have already been queued.
            continue
        found.add(candidate)
        pending.extend(candidate.__subclasses__())
    return found
def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]:
    """
    Gets a sorted list of all of the concrete subclasses of the given class.

    Useful for dynamically enumerating expected test cases.

    Note: For abstract types, mypy will complain at the call site.
    Use "# type: ignore[type-abstract]" to suppress the warning.
    See Also: https://github.com/python/mypy/issues/4717

    Parameters
    ----------
    cls : Type[T]
        The base class whose (transitive) subclasses are enumerated.
    pkg_name : Optional[str]
        If given, this package is imported and its submodules are walked first,
        which imports its subpackages so that subclasses defined there are loaded
        before enumeration. Leaf modules are only picked up if something else
        (e.g. a package ``__init__``) imports them. The result is NOT restricted
        to this package.

    Returns
    -------
    List[Type[T]]
        All currently loaded subclasses of ``cls`` that have no unimplemented
        abstract methods, sorted by module name, then class name, then
        qualified name.
    """
    if pkg_name is not None:
        pkg = import_module(pkg_name)
        submodules = get_all_submodules(pkg)
        # Guard against a misspelled or empty package silently loading nothing.
        assert submodules, f"No submodules found in package {pkg_name!r}"
    concrete = [
        subclass
        for subclass in _get_all_subclasses(cls)
        # ABCs list their unimplemented abstract methods here; classes without
        # the attribute (non-ABCs) are treated as concrete.
        if not getattr(subclass, "__abstractmethods__", None)
    ]
    # The candidates come from a set, whose iteration order is arbitrary, so the
    # sort key alone must determine a reproducible order. ``__qualname__`` only
    # breaks ties between classes sharing a module and ``__name__`` (e.g. classes
    # defined inside functions); the order of all other entries is unchanged.
    return sorted(concrete, key=lambda c: (c.__module__, c.__name__, c.__qualname__))