Coverage for mlos_core/mlos_core/tests/__init__.py: 95%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Common functions for mlos_core Optimizer tests.""" 

6 

import sys
from importlib import import_module
from pkgutil import walk_packages
from types import ModuleType
from typing import List, Optional, Set, Type, TypeVar

11 

# One seed shared by every test. Pinning a single value means tests that run in
# non-deterministic, parallel orders cannot intermingle seeds, which makes
# race-condition-like failures easier to track down.
SEED: int = 42

15 

16if sys.version_info >= (3, 10): 

17 from typing import TypeAlias 

18else: 

19 from typing_extensions import TypeAlias 

20 

21 

# Generic type variable used by the subclass-enumeration helpers in this module
# (``Type[T]`` in, ``Set``/``List`` of ``Type[T]`` out), so callers get results
# typed as subclasses of the class they asked about.
T = TypeVar("T")

23 

24 

def get_all_submodules(pkg: ModuleType) -> List[str]:
    """
    Returns the fully qualified names of all modules found beneath a package.

    The package tree is walked recursively with :py:func:`pkgutil.walk_packages`.
    That function imports every *subpackage* it encounters, so their ``__path__``
    can be searched, but it does NOT import plain (non-package) modules.
    Any exception raised while importing a subpackage during the walk is
    ignored, so the walk is best-effort.

    Useful for dynamically enumerating subclasses.

    Parameters
    ----------
    pkg : ModuleType
        An already-imported package. It must expose a ``__path__`` attribute.

    Returns
    -------
    List[str]
        Dotted module names prefixed with ``pkg.__name__``, in the order
        ``walk_packages`` yields them.
    """
    # onerror is supplied so that *any* exception raised while importing a
    # subpackage is swallowed. Without it, non-ImportError exceptions would
    # propagate and abort the walk.
    return [
        module_info.name
        for module_info in walk_packages(
            pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda _: None
        )
    ]

37 

38 

def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]:
    """
    Returns every class that inherits from ``cls``, directly or transitively.

    The hierarchy is explored with an explicit work-list rather than recursion.
    A class reachable through several parents (e.g. diamond inheritance) is
    reported only once, and ``cls`` itself is not included.

    Useful for dynamically enumerating expected test cases.
    """
    seen: Set[Type[T]] = set()
    to_visit: List[Type[T]] = list(cls.__subclasses__())
    while to_visit:
        klass = to_visit.pop()
        if klass in seen:
            # Already expanded via another parent; skip re-walking its subtree.
            continue
        seen.add(klass)
        to_visit.extend(klass.__subclasses__())
    return seen

48 

49 

def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]:
    """
    Returns the non-abstract subclasses of ``cls``, sorted by module then name.

    Useful for dynamically enumerating expected test cases.

    Parameters
    ----------
    cls : Type[T]
        The base class whose descendants should be collected.
    pkg_name : Optional[str]
        If given, this package is imported and walked with
        ``get_all_submodules`` first, so subclasses defined in the modules that
        walk imports are registered before enumeration. The walk must find at
        least one module (checked with ``assert``). Results are NOT restricted
        to this package.

    Returns
    -------
    List[Type[T]]
        Every transitive subclass with no remaining ``__abstractmethods__``,
        ordered by ``(__module__, __name__)``.

    Note: For abstract types, mypy will complain at the call site.
    Use "# type: ignore[type-abstract]" to suppress the warning.
    See Also: https://github.com/python/mypy/issues/4717
    """
    if pkg_name is not None:
        found_modules = get_all_submodules(import_module(pkg_name))
        assert found_modules

    def _is_concrete(klass: Type[T]) -> bool:
        # Classes with unimplemented abstract methods have a non-empty
        # __abstractmethods__; plain classes lack the attribute entirely.
        return not getattr(klass, "__abstractmethods__", None)

    concrete = [klass for klass in _get_all_subclasses(cls) if _is_concrete(klass)]
    concrete.sort(key=lambda klass: (klass.__module__, klass.__name__))
    return concrete