Coverage for mlos_bench/mlos_bench/tests/config/schemas/__init__.py: 94%

78 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Common tests for config schemas and their validation and test cases. 

7""" 

8 

9from copy import deepcopy 

10from dataclasses import dataclass 

11from typing import Any, Dict, Set 

12 

13import os 

14 

15import json5 

16import jsonschema 

17import pytest 

18 

19from mlos_bench.config.schemas.config_schemas import ConfigSchema 

20from mlos_bench.tests.config import locate_config_examples 

21 

22 

23# A dataclass to make pylint happy. 

24@dataclass 

25class SchemaTestType: 

26 """ 

27 The different type of schema test cases we expect to have. 

28 """ 

29 

30 test_case_type: str 

31 test_case_subtypes: Set[str] 

32 

33 def __hash__(self) -> int: 

34 return hash(self.test_case_type) 

35 

36 

37# The different type of schema test cases we expect to have. 

38_SCHEMA_TEST_TYPES = {x.test_case_type: x for x in ( 

39 SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}), 

40 SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}), 

41)} 

42 

43 

@dataclass
class SchemaTestCaseInfo:
    """
    Basic information about a single schema test case: the parsed config
    and where (file, type, subtype) it came from.
    """

    # The parsed contents of the test case file.
    config: Dict[str, Any]
    # Path of the test case file the config was loaded from.
    test_case_file: str
    # Test case type, e.g., "good" or "bad".
    test_case_type: str
    # Test case subtype, e.g., "full", "partial", "invalid", or "unhandled".
    test_case_subtype: str

    def __hash__(self) -> int:
        # The file path identifies the test case; the config dict itself is unhashable.
        return hash(self.test_case_file)

57 

58 

def check_schema_dir_layout(test_cases_root: str) -> None:
    """
    Verifies that the directory layout under ``test_cases_root`` is exactly
    ``<type>/<subtype>/`` with known types and subtypes, so that no configs or
    test cases are silently skipped.

    ``README.md`` entries are ignored at both levels.

    Parameters
    ----------
    test_cases_root : str
        Root directory of the schema test cases.

    Raises
    ------
    NotImplementedError
        If an unknown test case type or subtype directory is found.
    """
    for type_dir_name in os.listdir(test_cases_root):
        if type_dir_name == 'README.md':
            continue
        schema_test_type = _SCHEMA_TEST_TYPES.get(type_dir_name)
        if schema_test_type is None:
            raise NotImplementedError(f"Unhandled test case type: {type_dir_name}")

        type_dir_path = os.path.join(test_cases_root, type_dir_name)
        for subtype_dir_name in os.listdir(type_dir_path):
            if subtype_dir_name == 'README.md':
                continue
            if subtype_dir_name not in schema_test_type.test_case_subtypes:
                raise NotImplementedError(
                    f"Unhandled test case subtype {subtype_dir_name} for test case type {type_dir_name}")

74 

75 

@dataclass
class TestCases:
    """
    A container for the loaded schema test cases, indexed three ways.
    """

    # The "Test" name prefix would otherwise make pytest try (and fail, due to the
    # dataclass __init__) to collect this helper as a test class if it is imported
    # into a test module, producing a PytestCollectionWarning.
    __test__ = False

    # test case file path -> test case info
    by_path: Dict[str, SchemaTestCaseInfo]
    # test case type (e.g., "good") -> {test case file path -> test case info}
    by_type: Dict[str, Dict[str, SchemaTestCaseInfo]]
    # test case subtype (e.g., "full") -> {test case file path -> test case info}
    by_subtype: Dict[str, Dict[str, SchemaTestCaseInfo]]

85 

86 

def get_schema_test_cases(test_cases_root: str) -> TestCases:
    """
    Gets the schema test cases found under the given root directory.

    Test case files are located under ``<test_cases_root>/<type>/<subtype>/``
    (e.g., ``good/full``), loaded with json5, and indexed by path, by type,
    and by subtype.

    Parameters
    ----------
    test_cases_root : str
        Root directory of the schema test cases.

    Returns
    -------
    TestCases
        The loaded test cases, inserted in a deterministic (sorted) order.

    Raises
    ------
    NotImplementedError
        If the directory layout contains an unexpected type or subtype.
    RuntimeError
        If a test case file cannot be parsed or indexed.
    """
    test_cases = TestCases(
        by_path={},
        by_type={test_case_type: {} for test_case_type in _SCHEMA_TEST_TYPES},
        by_subtype={
            test_case_subtype: {}
            for schema_test_type in _SCHEMA_TEST_TYPES.values()
            for test_case_subtype in sorted(schema_test_type.test_case_subtypes)
        },
    )
    check_schema_dir_layout(test_cases_root)
    # Note: we sort the test cases so that we can deterministically test them in parallel.
    # Parallel runners (e.g., pytest-xdist) require every worker process to collect the same
    # tests in the same order, but the iteration order of a set of strings can differ between
    # processes due to hash randomization, so the subtype sets (and files) are sorted explicitly.
    for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items():
        for test_case_subtype in sorted(schema_test_type.test_case_subtypes):
            test_case_files = locate_config_examples(
                test_cases_root, os.path.join(test_case_type, test_case_subtype))
            for test_case_file in sorted(test_case_files):
                with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh:
                    try:
                        test_case_info = SchemaTestCaseInfo(
                            config=json5.load(test_case_fh),
                            test_case_file=test_case_file,
                            test_case_type=test_case_type,
                            test_case_subtype=test_case_subtype,
                        )
                        test_cases.by_path[test_case_info.test_case_file] = test_case_info
                        test_cases.by_type[test_case_info.test_case_type][test_case_info.test_case_file] = test_case_info
                        test_cases.by_subtype[test_case_info.test_case_subtype][test_case_info.test_case_file] = test_case_info
                    except Exception as ex:
                        raise RuntimeError("Failed to load test case: " + test_case_file) from ex
    # A dataclass instance is always truthy, so check that something was actually loaded.
    assert test_cases.by_path

    assert len(test_cases.by_type["good"]) > 0
    assert len(test_cases.by_type["bad"]) > 0
    # NOTE(review): by_subtype is pre-populated above with a key for every known subtype,
    # so this only checks the expected-layout table, not that each subtype has test cases.
    assert len(test_cases.by_subtype) > 2

    return test_cases

119 

120 

def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None:
    """
    Validates a single schema test case against a schema.

    A "good" test case must validate cleanly. A "bad" test case must be
    rejected with a ``jsonschema.ValidationError``; otherwise ``pytest.raises``
    fails the test.

    Parameters
    ----------
    test_case : SchemaTestCaseInfo
        Schema test case to check.
    schema_type : ConfigSchema
        Schema to check against, e.g., ENVIRONMENT or SERVICE.

    Raises
    ------
    NotImplementedError
        If the test case type is neither "good" nor "bad".
    """
    case_type = test_case.test_case_type
    if case_type not in ("good", "bad"):
        raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}")

    if case_type == "good":
        schema_type.validate(test_case.config)
        return

    # Remaining case: "bad" -- validation must fail.
    with pytest.raises(jsonschema.ValidationError):
        schema_type.validate(test_case.config)

144 

145 

def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None:
    """
    Checks that an otherwise valid config fails to validate once an unexpected
    extra parameter is added, first at the top level and then inside its
    "config" section.

    The test case's config is deep-copied first, so the test case itself is
    not modified.

    Parameters
    ----------
    test_case : SchemaTestCaseInfo
        A test case whose config is expected to be valid as-is.
    schema_type : ConfigSchema
        Schema to check against.
    """
    modified_config = deepcopy(test_case.config)
    # Sanity check: the unmodified config must be valid to begin with.
    schema_type.validate(modified_config)

    # Step 1: an unknown top-level attribute must be rejected.
    outer_key = "extra_outer_attr"
    modified_config[outer_key] = "should not be here"
    with pytest.raises(jsonschema.ValidationError):
        schema_type.validate(modified_config)
    modified_config.pop(outer_key)

    # Step 2: an unknown attribute inside the "config" section must be rejected too.
    # A missing or empty "config" section is replaced with a fresh dict first.
    if not modified_config.get("config"):
        modified_config["config"] = {}
    inner_key = "extra_config_attr"
    modified_config["config"][inner_key] = "should not be here"
    with pytest.raises(jsonschema.ValidationError):
        schema_type.validate(modified_config)