Coverage for mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for optimizer schema validation.""" 

6 

7from os import path 

8from typing import Optional 

9 

10import pytest 

11 

12from mlos_bench.config.schemas import ConfigSchema 

13from mlos_bench.optimizers.base_optimizer import Optimizer 

14from mlos_bench.tests import try_resolve_class_name 

15from mlos_bench.tests.config.schemas import ( 

16 check_test_case_against_schema, 

17 check_test_case_config_with_extra_param, 

18 get_schema_test_cases, 

19) 

20from mlos_core.optimizers import OptimizerType 

21from mlos_core.spaces.adapters import SpaceAdapterType 

22from mlos_core.tests import get_all_concrete_subclasses 

23 

# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
# - enumerate and try to check that we've covered all the cases
# - for each config, load and validate against expected schema

TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases"))


# Dynamically enumerate the cases we want to be sure the hand-written configs cover.

# Fully qualified ("module.ClassName") names of every concrete Optimizer subclass
# found in the mlos_bench package.
expected_mlos_bench_optimizer_class_names = [
    f"{subclass.__module__}.{subclass.__name__}"
    for subclass in get_all_concrete_subclasses(
        Optimizer,  # type: ignore[type-abstract]
        pkg_name="mlos_bench",
    )
]
assert expected_mlos_bench_optimizer_class_names

# The trailing None entry stands for configs that leave optimizer_type /
# space_adapter_type unspecified, so those get coverage checks too.

expected_mlos_core_optimizer_types = [*OptimizerType, None]
assert expected_mlos_core_optimizer_types

expected_mlos_core_space_adapter_types = [*SpaceAdapterType, None]
assert expected_mlos_core_space_adapter_types

51 

52 

# Cross every test case subtype with every concrete mlos_bench Optimizer class, so each
# class must be exercised by at least one config of each subtype.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names)
def test_case_coverage_mlos_bench_optimizer_type(
    test_case_subtype: str,
    mlos_bench_optimizer_type: str,
) -> None:
    """Ensure some test case of the given subtype targets the given mlos_bench Optimizer
    class.

    A test case matches when its resolved ``"class"`` name equals
    ``mlos_bench_optimizer_type``. Raises ``NotImplementedError`` when none match.
    """
    subtype_cases = TEST_CASES.by_subtype[test_case_subtype]
    covered = any(
        try_resolve_class_name(case.config.get("class")) == mlos_bench_optimizer_type
        for case in subtype_cases.values()
    )
    if not covered:
        raise NotImplementedError(
            f"Missing test case for subtype {test_case_subtype} "
            f"for Optimizer class {mlos_bench_optimizer_type}"
        )

70 

71 

# Being a little lazy for the moment and relaxing the requirement that we have a
# subtype test case for each optimizer and space adapter combo: the two checks below
# only require one matching test case per test case *type*. Parametrizing them on
# test_case_subtype as well is deliberately left disabled for now.

# Resolved class name that identifies configs for the mlos_core-backed optimizer.
_MLOS_CORE_OPTIMIZER_CLASS_NAME = "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer"


def _has_mlos_core_test_case(
    test_case_type: str,
    config_key: str,
    expected_name: Optional[str],
) -> bool:
    """Return whether any test case of ``test_case_type`` configures the MlosCoreOptimizer
    with ``config[config_key] == expected_name``.

    A missing or empty ``"config"`` section, or a missing ``config_key`` inside it,
    is treated as ``None`` (the setting was left unspecified).
    """
    for test_case in TEST_CASES.by_type[test_case_type].values():
        class_name = try_resolve_class_name(test_case.config.get("class"))
        if class_name != _MLOS_CORE_OPTIMIZER_CLASS_NAME:
            continue
        # An absent/empty "config" section means the setting was left unspecified.
        optimizer_config = test_case.config.get("config") or {}
        if optimizer_config.get(config_key) == expected_name:
            return True
    return False


@pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type))
@pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types)
def test_case_coverage_mlos_core_optimizer_type(
    test_case_type: str,
    mlos_core_optimizer_type: Optional[OptimizerType],
) -> None:
    """Checks to see if there is a given type of test case for the given mlos_core
    optimizer type.

    The match is on the enum member's ``name``; ``None`` requires a test case that
    leaves ``optimizer_type`` unspecified. Raises ``NotImplementedError`` if none exists.
    """
    optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name
    if not _has_mlos_core_test_case(test_case_type, "optimizer_type", optimizer_name):
        raise NotImplementedError(
            f"Missing test case for type {test_case_type} "
            f"for MlosCore Optimizer type {mlos_core_optimizer_type}"
        )


@pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type))
@pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types)
def test_case_coverage_mlos_core_space_adapter_type(
    test_case_type: str,
    mlos_core_space_adapter_type: Optional[SpaceAdapterType],
) -> None:
    """Checks to see if there is a given type of test case for the given mlos_core space
    adapter type.

    The match is on the enum member's ``name``; ``None`` requires a test case that
    leaves ``space_adapter_type`` unspecified. Raises ``NotImplementedError`` if none
    exists.
    """
    space_adapter_name = (
        None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name
    )
    if not _has_mlos_core_test_case(test_case_type, "space_adapter_type", space_adapter_name):
        raise NotImplementedError(
            f"Missing test case for type {test_case_type} "
            f"for SpaceAdapter type {mlos_core_space_adapter_type}"
        )

130 

131 

# With coverage of the test case inventory checked, run the actual schema validation.


@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_optimizer_configs_against_schema(test_case_name: str) -> None:
    """Check the named optimizer test case against the OPTIMIZER schema, then against
    the UNIFIED schema.
    """
    test_case = TEST_CASES.by_path[test_case_name]
    for schema in (ConfigSchema.OPTIMIZER, ConfigSchema.UNIFIED):
        check_test_case_against_schema(test_case, schema)

140 

141 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_optimizer_configs_with_extra_param(test_case_name: str) -> None:
    """Check that a known-good optimizer config fails to validate once extra params are
    present in certain places, first for the OPTIMIZER schema and then for UNIFIED.
    """
    good_case = TEST_CASES.by_type["good"][test_case_name]
    for schema in (ConfigSchema.OPTIMIZER, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(good_case, schema)