Coverage for mlos_bench/mlos_bench/tests/config/schemas/__init__.py: 94%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Common tests for config schemas and their validation and test cases.""" 

6 

7import os 

8from copy import deepcopy 

9from dataclasses import dataclass 

10from typing import Any, Dict, Set 

11 

12import json5 

13import jsonschema 

14import pytest 

15 

16from mlos_bench.config.schemas.config_schemas import ConfigSchema 

17from mlos_bench.tests.config import locate_config_examples 

18 

19 

# Using a dataclass (rather than a bare tuple or dict) keeps pylint happy.
@dataclass
class SchemaTestType:
    """Describes one top-level category of schema test cases.

    Attributes
    ----------
    test_case_type : str
        Name of the category directory (e.g., ``"good"`` or ``"bad"``).
    test_case_subtypes : Set[str]
        Names of the subdirectories allowed under that category directory.
    """

    test_case_type: str
    test_case_subtypes: Set[str]

    def __hash__(self) -> int:
        # Only the category name feeds the hash: a set is unhashable, so the
        # subtypes cannot take part in it directly.
        type_name = self.test_case_type
        return hash(type_name)

30 

31 

# Registry of the schema test case types we expect to find on disk, keyed by
# type name; each entry also lists the subtype directories allowed under it.
_SCHEMA_TEST_TYPES: Dict[str, SchemaTestType] = {
    known_type.test_case_type: known_type
    for known_type in [
        SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}),
        SchemaTestType(test_case_type="bad", test_case_subtypes={"invalid", "unhandled"}),
    ]
}

40 

41 

@dataclass
class SchemaTestCaseInfo:
    """A single schema test case and where it came from.

    Attributes
    ----------
    config : Dict[str, Any]
        The parsed contents of the test case file.
    test_case_file : str
        Path of the file the config was loaded from.
    test_case_type : str
        Top-level category of the test case (e.g., ``"good"`` or ``"bad"``).
    test_case_subtype : str
        Subcategory of the test case within its type.
    """

    config: Dict[str, Any]
    test_case_file: str
    test_case_type: str
    test_case_subtype: str

    def __hash__(self) -> int:
        # The config dict is unhashable; the source file path is what
        # identifies a test case, so hash on that alone.
        source_path = self.test_case_file
        return hash(source_path)

53 

54 

def check_schema_dir_layout(test_cases_root: str) -> None:
    """Verify the test case directory tree contains only the layout we expect.

    Every entry directly under ``test_cases_root`` (other than ``README.md``)
    must be a known test case type, and every entry inside each of those
    (again other than ``README.md``) must be one of that type's known
    subtypes. This guards against configs or test cases being silently
    skipped because they live somewhere unexpected.

    Parameters
    ----------
    test_cases_root : str
        Root directory containing the test case type directories.

    Raises
    ------
    NotImplementedError
        If an unexpected type or subtype entry is found.
    """
    # README files are documentation, not test case directories.
    for type_name in (n for n in os.listdir(test_cases_root) if n != "README.md"):
        known_type = _SCHEMA_TEST_TYPES.get(type_name)
        if known_type is None:
            raise NotImplementedError(f"Unhandled test case type: {type_name}")
        type_dir = os.path.join(test_cases_root, type_name)
        for subtype_name in (n for n in os.listdir(type_dir) if n != "README.md"):
            if subtype_name in known_type.test_case_subtypes:
                continue
            raise NotImplementedError(
                f"Unhandled test case subtype {subtype_name} "
                f"for test case type {type_name}"
            )

72 

73 

@dataclass
class TestCases:
    """Index of loaded schema test cases, grouped three different ways.

    Attributes
    ----------
    by_path : Dict[str, SchemaTestCaseInfo]
        Every test case, keyed by its test case file path.
    by_type : Dict[str, Dict[str, SchemaTestCaseInfo]]
        Test cases keyed first by test case type, then by file path.
    by_subtype : Dict[str, Dict[str, SchemaTestCaseInfo]]
        Test cases keyed first by test case subtype, then by file path.
    """

    by_path: Dict[str, SchemaTestCaseInfo]
    by_type: Dict[str, Dict[str, SchemaTestCaseInfo]]
    by_subtype: Dict[str, Dict[str, SchemaTestCaseInfo]]

81 

82 

def get_schema_test_cases(test_cases_root: str) -> TestCases:
    """Load every schema test case found under the given root directory.

    The directory layout is validated first; then every config file under
    ``<root>/<type>/<subtype>/`` is parsed with ``json5`` and indexed by file
    path, by test case type, and by test case subtype.

    Parameters
    ----------
    test_cases_root : str
        Root directory containing the test case type directories.

    Returns
    -------
    TestCases
        The loaded test cases, indexed three ways.

    Raises
    ------
    NotImplementedError
        If the directory layout contains an unexpected type or subtype.
    RuntimeError
        If a test case file cannot be parsed.
    """
    test_cases = TestCases(
        by_path={},
        by_type={x: {} for x in _SCHEMA_TEST_TYPES},
        by_subtype={
            y: {}
            for x in _SCHEMA_TEST_TYPES
            for y in sorted(_SCHEMA_TEST_TYPES[x].test_case_subtypes)
        },
    )
    check_schema_dir_layout(test_cases_root)
    # Note: we sort the test cases so that we can deterministically test them in parallel.
    # The subtypes are stored in a set of strings, whose iteration order can differ
    # between interpreter processes (string hash randomization), so without sorting the
    # order of the indexes below could differ between parallel test workers.
    for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items():
        for test_case_subtype in sorted(schema_test_type.test_case_subtypes):
            test_case_files = sorted(
                locate_config_examples(
                    test_cases_root, os.path.join(test_case_type, test_case_subtype)
                )
            )
            for test_case_file in test_case_files:
                with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh:
                    # Only the parse can realistically fail; keep the try body minimal.
                    try:
                        config = json5.load(test_case_fh)
                    except Exception as ex:
                        raise RuntimeError("Failed to load test case: " + test_case_file) from ex
                test_case_info = SchemaTestCaseInfo(
                    config=config,
                    test_case_file=test_case_file,
                    test_case_type=test_case_type,
                    test_case_subtype=test_case_subtype,
                )
                test_cases.by_path[test_case_info.test_case_file] = test_case_info
                test_cases.by_type[test_case_info.test_case_type][
                    test_case_info.test_case_file
                ] = test_case_info
                test_cases.by_subtype[test_case_info.test_case_subtype][
                    test_case_info.test_case_file
                ] = test_case_info
    # A dataclass instance is always truthy, so check the actual contents instead.
    assert test_cases.by_path

    assert len(test_cases.by_type["good"]) > 0
    assert len(test_cases.by_type["bad"]) > 0
    assert len(test_cases.by_subtype) > 2

    return test_cases

123 

124 

def check_test_case_against_schema(
    test_case: SchemaTestCaseInfo,
    schema_type: ConfigSchema,
) -> None:
    """
    Validate a schema test case and assert the outcome matches its type.

    "good" test cases must validate cleanly against the schema; "bad" test
    cases must be rejected with a ``jsonschema.ValidationError``.

    Parameters
    ----------
    test_case : SchemaTestCaseInfo
        Schema test case to check.
    schema_type : ConfigSchema
        Schema to check against, e.g., ENVIRONMENT or SERVICE.

    Raises
    ------
    NotImplementedError
        If the test case type is neither "good" nor "bad".
    """
    expectation = test_case.test_case_type
    if expectation not in ("good", "bad"):
        raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}")
    if expectation == "bad":
        # Bad configs must be rejected by the schema.
        with pytest.raises(jsonschema.ValidationError):
            schema_type.validate(test_case.config)
        return
    # Good configs must validate without raising.
    schema_type.validate(test_case.config)

151 

152 

def check_test_case_config_with_extra_param(
    test_case: SchemaTestCaseInfo,
    schema_type: ConfigSchema,
) -> None:
    """Check that the schema rejects unexpected extra properties.

    Working on a deep copy (the test case's own config is left untouched),
    this first confirms the unmodified config validates. It then asserts that
    validation fails in two cases:

    - an unknown attribute is added at the top level;
    - an unknown attribute is added inside the nested ``"config"`` section,
      which is created if it is missing or empty.

    Parameters
    ----------
    test_case : SchemaTestCaseInfo
        A test case whose config is expected to be valid as-is.
    schema_type : ConfigSchema
        Schema to validate against.
    """
    candidate = deepcopy(test_case.config)
    # Sanity check: the baseline must be valid, or the failures below prove nothing.
    schema_type.validate(candidate)

    outer_key = "extra_outer_attr"
    candidate[outer_key] = "should not be here"
    with pytest.raises(jsonschema.ValidationError):
        schema_type.validate(candidate)
    candidate.pop(outer_key)

    # NOTE(review): if the schema has no top-level "config" property at all,
    # the check below fails validation for that reason rather than because of
    # the extra attribute.
    candidate["config"] = candidate.get("config") or {}
    inner_key = "extra_config_attr"
    candidate["config"][inner_key] = "should not be here"
    with pytest.raises(jsonschema.ValidationError):
        schema_type.validate(candidate)