Coverage for mlos_core/mlos_core/tests/__init__.py: 95%

21 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Common functions for mlos_core Optimizer tests. 

7""" 

8 

import sys

from importlib import import_module
from pkgutil import walk_packages
from types import ModuleType
from typing import List, Optional, Set, Type, TypeVar

14 

# Every test shares this one seed. Tests run in nondeterministic parallel
# orders, and a single fixed seed keeps seeds from intermingling across tests,
# which would otherwise make race conditions hard to track down.
SEED = 42

if sys.version_info < (3, 10):
    from typing_extensions import TypeAlias
else:
    from typing import TypeAlias


# Generic type variable for the class-enumeration helpers in this module.
T = TypeVar("T")

26 

27 

def get_all_submodules(pkg: ModuleType) -> List[str]:
    """
    Imports all submodules for a package and returns their names.
    Useful for dynamically enumerating subclasses.

    Parameters
    ----------
    pkg : ModuleType
        An already imported package (it must have a ``__path__``).

    Returns
    -------
    List[str]
        Fully qualified names of every module and subpackage found beneath
        ``pkg``, in the order ``pkgutil.walk_packages`` yields them.
        Names of modules that failed to import are still included.
    """
    submodules = []
    # Errors raised while walk_packages imports subpackages are deliberately
    # ignored so that one broken subpackage does not abort the walk.
    for _, submodule_name, _ in walk_packages(
        pkg.__path__,
        prefix=f"{pkg.__name__}.",
        onerror=lambda _name: None,
    ):
        # walk_packages only imports *packages* (to read their __path__);
        # leaf modules must be imported explicitly so that the classes they
        # define are registered (e.g. visible via __subclasses__()).
        try:
            import_module(submodule_name)
        except ImportError:
            # Same best-effort policy as the walk itself: skip modules that
            # cannot be imported.
            pass
        submodules.append(submodule_name)
    return submodules

37 

38 

def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]:
    """
    Collects every direct and indirect subclass of the given class.
    Useful for dynamically enumerating expected test cases.

    The class itself is not included in the result.
    """
    seen: Set[Type[T]] = set()
    worklist = list(cls.__subclasses__())
    while worklist:
        candidate = worklist.pop()
        if candidate in seen:
            # Reached again through another base (diamond inheritance).
            continue
        seen.add(candidate)
        worklist.extend(candidate.__subclasses__())
    return seen

46 

47 

def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]:
    """
    Gets a sorted list of all of the concrete subclasses of the given class.
    Useful for dynamically enumerating expected test cases.

    Parameters
    ----------
    cls : Type[T]
        The base class whose (transitive) subclasses are searched.
    pkg_name : Optional[str]
        If given, this package is imported and its submodules are walked via
        ``get_all_submodules`` before searching. The walk is expected to find
        at least one submodule (checked with an ``assert``).

    Returns
    -------
    List[Type[T]]
        Subclasses with no remaining ``__abstractmethods__``, ordered by
        ``(__module__, __name__)``.

    Note: For abstract types, mypy will complain at the call site.
    Use "# type: ignore[type-abstract]" to suppress the warning.
    See Also: https://github.com/python/mypy/issues/4717
    """
    if pkg_name is not None:
        found_modules = get_all_submodules(import_module(pkg_name))
        assert found_modules

    # A class still carrying abstract methods records them (non-empty) in
    # __abstractmethods__; a missing or empty value marks it as concrete.
    concrete = [
        klass
        for klass in _get_all_subclasses(cls)
        if not getattr(klass, "__abstractmethods__", None)
    ]
    concrete.sort(key=lambda klass: (klass.__module__, klass.__name__))
    return concrete