Coverage for mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py: 100%
37 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for service schema validation.
7"""
9from os import path
10from typing import Any, Dict, List
12import pytest
14from mlos_core.tests import get_all_concrete_subclasses
16from mlos_bench.config.schemas import ConfigSchema
17from mlos_bench.services.base_service import Service
18from mlos_bench.services.config_persistence import ConfigPersistenceService
19from mlos_bench.services.local.temp_dir_context import TempDirContextService
20from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
21from mlos_bench.services.remote.ssh.ssh_service import SshService
23from mlos_bench.tests import try_resolve_class_name
24from mlos_bench.tests.config.schemas import (get_schema_test_cases,
25 check_test_case_against_schema,
26 check_test_case_config_with_extra_param)
29# General testing strategy:
30# - hand code a set of good/bad configs (useful to test editor schema checking)
31# - enumerate and try to check that we've covered all the cases
32# - for each config, load and validate against expected schema
# Load every hand-written test-case config from the adjacent "test-cases" directory,
# indexed by path/type/subtype for the parametrized tests below.
34TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases"))
37# Dynamically enumerate some of the cases we want to make sure we cover.
# Service subclasses that deliberately have no dedicated config test case,
# with the reason noted per entry; they are excluded from the coverage check below.
39NON_CONFIG_SERVICE_CLASSES = {
40    ConfigPersistenceService,   # configured thru the launcher cli args
41    TempDirContextService,      # ABCMeta abstract class, but no good way to test that dynamically in Python.
42    AzureDeploymentService,     # ABCMeta abstract base class
43    SshService,                 # ABCMeta abstract base class
44}
# Build the list of fully-qualified class names for every concrete Service
# subclass in mlos_bench that we expect to find a config test case for.
expected_service_class_names = [
    f"{cls.__module__}.{cls.__name__}"
    for cls in get_all_concrete_subclasses(Service, pkg_name='mlos_bench')
    if cls not in NON_CONFIG_SERVICE_CLASSES
]
# Sanity check: subclass discovery should never come back empty.
assert expected_service_class_names
# Do the full cross product of all the test cases and all the Service types.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("service_class", expected_service_class_names)
def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_class: str) -> None:
    """
    Checks to see if there is a given type of test case for the given mlos_bench Service type.

    Fails with NotImplementedError if no test case of the given subtype
    references the given Service class.
    """
    for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
        if not isinstance(test_case.config, dict):
            continue    # type: ignore[unreachable]
        # A config without a top-level "class" is a bundle of service configs;
        # otherwise treat the config itself as the single service config.
        configs: List[Dict[str, Any]] = (
            [test_case.config] if "class" in test_case.config
            else test_case.config["services"]
        )
        # Found at least one test case covering this service class - done.
        if any(try_resolve_class_name(cfg.get("class")) == service_class for cfg in configs):
            return
    raise NotImplementedError(
        f"Missing test case for subtype {test_case_subtype} for service class {service_class}")
# Now we actually perform all of those validation tests.
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_service_configs_against_schema(test_case_name: str) -> None:
    """
    Checks that the service config validates against the schema.

    Each config must validate against both the SERVICE schema and the
    combined UNIFIED schema.
    """
    test_case = TEST_CASES.by_path[test_case_name]
    for schema in (ConfigSchema.SERVICE, ConfigSchema.UNIFIED):
        check_test_case_against_schema(test_case, schema)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_service_configs_with_extra_param(test_case_name: str) -> None:
    """
    Checks that the service config fails to validate if extra params are present in certain places.

    Only "good" (i.e., otherwise valid) configs are checked, against both the
    SERVICE schema and the combined UNIFIED schema.
    """
    test_case = TEST_CASES.by_type["good"][test_case_name]
    for schema in (ConfigSchema.SERVICE, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(test_case, schema)