Coverage for mlos_bench/mlos_bench/tests/config/schemas/__init__.py: 94%
78 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Common tests for config schemas and their validation and test cases.
7"""
9from copy import deepcopy
10from dataclasses import dataclass
11from typing import Any, Dict, Set
13import os
15import json5
16import jsonschema
17import pytest
19from mlos_bench.config.schemas.config_schemas import ConfigSchema
20from mlos_bench.tests.config import locate_config_examples
# A dataclass to make pylint happy.
@dataclass
class SchemaTestType:
    """
    Describes one category of schema test cases (e.g., 'good' or 'bad') and the
    set of subcategory names that are allowed beneath it.
    """

    # Name of the category, e.g. 'good'.
    test_case_type: str
    # Names of the subcategories permitted for this category, e.g. {'full', 'partial'}.
    test_case_subtypes: Set[str]

    def __hash__(self) -> int:
        # A non-frozen dataclass would otherwise be unhashable; only the
        # category name is hashed because the subtype set cannot be hashed.
        type_name = self.test_case_type
        return hash(type_name)
# Registry of the schema test case categories we expect to find, keyed by the
# category name, with the subcategories each one allows.
_SCHEMA_TEST_TYPES = {
    schema_test_type.test_case_type: schema_test_type
    for schema_test_type in [
        SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}),
        SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}),
    ]
}
@dataclass
class SchemaTestCaseInfo:
    """
    Basic information about a single schema test case file.
    """

    # Parsed contents of the test case file.
    config: Dict[str, Any]
    # Path of the test case file; this alone determines the hash.
    test_case_file: str
    # Category of the test case, e.g. 'good' or 'bad'.
    test_case_type: str
    # Subcategory of the test case, e.g. 'full' or 'invalid'.
    test_case_subtype: str

    def __hash__(self) -> int:
        # The config dict is unhashable, so identify the case by its file path.
        file_path = self.test_case_file
        return hash(file_path)
def check_schema_dir_layout(test_cases_root: str) -> None:
    """
    Verify that every entry under ``test_cases_root`` is a known test case type
    whose entries are, in turn, known subtypes of that type, so that no configs
    or test cases are silently left out.

    Entries named ``README.md`` are skipped at both levels.

    Parameters
    ----------
    test_cases_root : str
        Directory holding one subdirectory per test case type.

    Raises
    ------
    NotImplementedError
        If an unknown test case type or subtype entry is found.
    """
    ignored_name = 'README.md'
    for type_name in os.listdir(test_cases_root):
        if type_name == ignored_name:
            continue
        schema_test_type = _SCHEMA_TEST_TYPES.get(type_name)
        if schema_test_type is None:
            raise NotImplementedError(f"Unhandled test case type: {type_name}")
        allowed_subtypes = schema_test_type.test_case_subtypes
        type_dir = os.path.join(test_cases_root, type_name)
        for subtype_name in os.listdir(type_dir):
            if subtype_name != ignored_name and subtype_name not in allowed_subtypes:
                raise NotImplementedError(
                    f"Unhandled test case subtype {subtype_name} "
                    f"for test case type {type_name}"
                )
@dataclass
class TestCases:
    """
    A container for the loaded schema test cases, indexed by file path, by test
    case type, and by test case subtype.
    """

    # Tell pytest this is not a test class. Its name matches pytest's default
    # ``Test*`` class pattern, so any test module that imports it would
    # otherwise get a "cannot collect test class ... because it has a
    # __init__ constructor" warning. Because it is un-annotated, this is a
    # plain class attribute rather than a dataclass field.
    __test__ = False

    # All test cases, keyed by test case file path.
    by_path: Dict[str, SchemaTestCaseInfo]
    # Test case type (e.g. 'good') -> {test case file path -> test case}.
    by_type: Dict[str, Dict[str, SchemaTestCaseInfo]]
    # Test case subtype (e.g. 'full') -> {test case file path -> test case}.
    by_subtype: Dict[str, Dict[str, SchemaTestCaseInfo]]
def get_schema_test_cases(test_cases_root: str) -> TestCases:
    """
    Gets the schema test cases found under the given root, indexed by file
    path, by test case type, and by test case subtype.

    The directory layout is checked first with ``check_schema_dir_layout``.
    Then every file that ``locate_config_examples`` reports for each
    ``<test_case_type>/<test_case_subtype>`` subdirectory is parsed with json5.

    Parameters
    ----------
    test_cases_root : str
        Root directory of the test cases.

    Returns
    -------
    TestCases
        The loaded test cases.

    Raises
    ------
    NotImplementedError
        If the root contains an unexpected test case type or subtype entry.
    RuntimeError
        If a test case file fails to parse.
    """
    # Note: we sort the test case subtypes and files so that we can
    # deterministically test them in parallel. Subtypes are stored in sets,
    # and the iteration order of a set of strings depends on hash
    # randomization (PYTHONHASHSEED). That order can therefore differ between
    # worker processes, which would give each worker a differently ordered set
    # of test cases. The type order comes from the _SCHEMA_TEST_TYPES dict and
    # is already deterministic.
    test_cases = TestCases(
        by_path={},
        by_type={test_case_type: {} for test_case_type in _SCHEMA_TEST_TYPES},
        by_subtype={
            test_case_subtype: {}
            for schema_test_type in _SCHEMA_TEST_TYPES.values()
            for test_case_subtype in sorted(schema_test_type.test_case_subtypes)
        },
    )
    check_schema_dir_layout(test_cases_root)
    for (test_case_type, schema_test_type) in _SCHEMA_TEST_TYPES.items():
        for test_case_subtype in sorted(schema_test_type.test_case_subtypes):
            test_case_files = locate_config_examples(test_cases_root, os.path.join(test_case_type, test_case_subtype))
            for test_case_file in sorted(test_case_files):
                with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh:
                    try:
                        config = json5.load(test_case_fh)
                    except Exception as ex:
                        # Chain the parse error with the path of the offending file.
                        raise RuntimeError("Failed to load test case: " + test_case_file) from ex
                test_case_info = SchemaTestCaseInfo(
                    config=config,
                    test_case_file=test_case_file,
                    test_case_type=test_case_type,
                    test_case_subtype=test_case_subtype,
                )
                test_cases.by_path[test_case_file] = test_case_info
                test_cases.by_type[test_case_type][test_case_file] = test_case_info
                test_cases.by_subtype[test_case_subtype][test_case_file] = test_case_info

    # Sanity checks: both test case types must have at least one test case.
    assert len(test_cases.by_type["good"]) > 0
    assert len(test_cases.by_type["bad"]) > 0
    # NOTE(review): by_subtype is pre-populated with every declared subtype
    # above, so this only confirms that more than two subtypes are declared,
    # not that each subtype actually has test cases.
    assert len(test_cases.by_subtype) > 2

    return test_cases
121def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None:
122 """
123 Checks the given test case against the given schema.
125 Parameters
126 ----------
127 test_case : SchemaTestCaseInfo
128 Schema test case to check.
129 schema_type : ConfigSchema
130 Schema to check against, e.g., ENVIRONMENT or SERVICE.
132 Raises
133 ------
134 NotImplementedError
135 If test case is not known.
136 """
137 if test_case.test_case_type == "good":
138 schema_type.validate(test_case.config)
139 elif test_case.test_case_type == "bad":
140 with pytest.raises(jsonschema.ValidationError):
141 schema_type.validate(test_case.config)
142 else:
143 raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}")
def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None:
    """
    Checks that a valid config stops validating once an unexpected parameter is
    added, either at the top level or inside its ``config`` section.

    The test case's config is first validated unmodified, so it must itself be
    valid. All modifications are made to a deep copy, so ``test_case.config`` is
    left untouched.

    Parameters
    ----------
    test_case : SchemaTestCaseInfo
        Schema test case whose config should be valid as-is.
    schema_type : ConfigSchema
        Schema to check against.
    """
    def _assert_rejected(candidate: Dict[str, Any]) -> None:
        # The schema must reject the candidate with a validation error.
        with pytest.raises(jsonschema.ValidationError):
            schema_type.validate(candidate)

    config = deepcopy(test_case.config)
    # Baseline: the unmodified copy must validate.
    schema_type.validate(config)

    # 1) An unknown top-level key must be rejected.
    outer_key = "extra_outer_attr"
    config[outer_key] = "should not be here"
    _assert_rejected(config)
    config.pop(outer_key)

    # 2) An unknown key inside the "config" section must be rejected.
    # A missing or falsy "config" section is replaced by an empty dict first.
    # NOTE(review): if a schema does not permit a "config" section at all, this
    # check may pass because of the added section rather than the extra key.
    if not config.get("config"):
        config["config"] = {}
    inner_key = "extra_config_attr"
    config["config"][inner_key] = "should not be here"
    _assert_rejected(config)