Coverage for mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py: 100%
30 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for environment schema validation.
7"""
9from os import path
11import pytest
13from mlos_core.tests import get_all_concrete_subclasses
15from mlos_bench.config.schemas import ConfigSchema
16from mlos_bench.environments.base_environment import Environment
17from mlos_bench.environments.composite_env import CompositeEnv
18from mlos_bench.environments.script_env import ScriptEnv
20from mlos_bench.tests import try_resolve_class_name
21from mlos_bench.tests.config.schemas import (get_schema_test_cases,
22 check_test_case_against_schema,
23 check_test_case_config_with_extra_param)
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
# - enumerate and try to check that we've covered all the cases
# - for each config, load and validate against expected schema

# All hand-coded environment config test cases, loaded from the "test-cases" directory next to this file.
TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases"))


# Dynamically enumerate some of the cases we want to make sure we cover.

# Classes reported by get_all_concrete_subclasses() that must be excluded from the expected list.
NON_CONFIG_ENV_CLASSES = {
    ScriptEnv  # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python.
}

# Fully qualified "module.ClassName" names of every concrete mlos_bench Environment subclass
# that is expected to have schema test cases.
expected_environment_class_names = [
    f"{subclass.__module__}.{subclass.__name__}"
    for subclass in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench')
    if subclass not in NON_CONFIG_ENV_CLASSES
]
# An empty list would make the parametrized coverage test below generate no cases at all,
# so fail loudly (with an explanation) instead of silently checking nothing.
assert expected_environment_class_names, (
    "No concrete mlos_bench Environment subclasses were found; "
    "the environment schema coverage tests would be vacuous.")

COMPOSITE_ENV_CLASS_NAME = f"{CompositeEnv.__module__}.{CompositeEnv.__name__}"
# The expected class names above, minus CompositeEnv itself.
expected_leaf_environment_class_names = [
    subclass_name
    for subclass_name in expected_environment_class_names
    if subclass_name != COMPOSITE_ENV_CLASS_NAME
]
50# Do the full cross product of all the test cases and all the Environment types.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("env_class", expected_environment_class_names)
def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_class: str) -> None:
    """
    Ensure at least one test case of the given subtype targets the given
    mlos_bench Environment class.

    Raises NotImplementedError when no test case of that subtype resolves
    to the given Environment class.
    """
    cases_of_subtype = TEST_CASES.by_subtype[test_case_subtype]
    # any() stops at the first matching case, just like an early return would.
    covered = any(
        try_resolve_class_name(case.config.get("class")) == env_class
        for case in cases_of_subtype.values()
    )
    if not covered:
        raise NotImplementedError(
            f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}")
64# Now we actually perform all of those validation tests.
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_environment_configs_against_schema(test_case_name: str) -> None:
    """
    Validate a single environment config test case against both the
    environment-specific schema and the unified schema.
    """
    test_case = TEST_CASES.by_path[test_case_name]
    # Environment schema first, then the unified schema (same order as before).
    for schema in (ConfigSchema.ENVIRONMENT, ConfigSchema.UNIFIED):
        check_test_case_against_schema(test_case, schema)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_environment_configs_with_extra_param(test_case_name: str) -> None:
    """
    Verify that a known-good environment config stops validating once extra
    params are injected in certain places, for both the environment schema
    and the unified schema.
    """
    good_case = TEST_CASES.by_type["good"][test_case_name]
    for schema in (ConfigSchema.ENVIRONMENT, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(good_case, schema)