Coverage for mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for environment schema validation. 

7""" 

8 

9from os import path 

10 

11import pytest 

12 

13from mlos_core.tests import get_all_concrete_subclasses 

14 

15from mlos_bench.config.schemas import ConfigSchema 

16from mlos_bench.environments.base_environment import Environment 

17from mlos_bench.environments.composite_env import CompositeEnv 

18from mlos_bench.environments.script_env import ScriptEnv 

19 

20from mlos_bench.tests import try_resolve_class_name 

21from mlos_bench.tests.config.schemas import (get_schema_test_cases, 

22 check_test_case_against_schema, 

23 check_test_case_config_with_extra_param) 

24 

25 

# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
# - enumerate and try to check that we've covered all the cases
# - for each config, load and validate against expected schema

# Test cases are loaded from the "test-cases" directory next to this file.
TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases"))


# Dynamically enumerate some of the cases we want to make sure we cover.

# Environment classes that are deliberately excluded from the coverage check.
NON_CONFIG_ENV_CLASSES = {
    ScriptEnv,  # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python.
}

# Fully qualified ("module.ClassName") names of every concrete Environment
# subclass found in the mlos_bench package, minus the exclusions above.
expected_environment_class_names = [
    f"{env_cls.__module__}.{env_cls.__name__}"
    for env_cls in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench')
    if env_cls not in NON_CONFIG_ENV_CLASSES
]
# An empty list would make the parametrized coverage test below vacuous.
assert expected_environment_class_names

COMPOSITE_ENV_CLASS_NAME = f"{CompositeEnv.__module__}.{CompositeEnv.__name__}"

# Same list as above, with the CompositeEnv entry filtered out.
expected_leaf_environment_class_names = [
    class_name
    for class_name in expected_environment_class_names
    if class_name != COMPOSITE_ENV_CLASS_NAME
]

48 

49 

# Do the full cross product of all the test cases and all the Environment types.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("env_class", expected_environment_class_names)
def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_class: str) -> None:
    """
    Checks that at least one test case of the given subtype targets the given
    mlos_bench Environment class.

    Parameters
    ----------
    test_case_subtype : str
        Key into TEST_CASES.by_subtype selecting the group of test cases to scan.
    env_class : str
        Fully qualified Environment class name that must appear in that group.

    Raises
    ------
    NotImplementedError
        If no test case in the subtype group resolves to ``env_class``.
    """
    candidates = TEST_CASES.by_subtype[test_case_subtype].values()
    # any() stops at the first match, just like an early return from a loop.
    is_covered = any(
        try_resolve_class_name(candidate.config.get("class")) == env_class
        for candidate in candidates
    )
    if not is_covered:
        raise NotImplementedError(
            f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}")

62 

63 

# Now we actually perform all of those validation tests.

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_environment_configs_against_schema(test_case_name: str) -> None:
    """
    Checks the named environment test case against both the ENVIRONMENT
    schema and the UNIFIED schema.

    Parameters
    ----------
    test_case_name : str
        Key into TEST_CASES.by_path identifying the test case to check.
    """
    test_case = TEST_CASES.by_path[test_case_name]
    # Same order as before: the environment-specific schema first, then the unified one.
    for schema_type in (ConfigSchema.ENVIRONMENT, ConfigSchema.UNIFIED):
        check_test_case_against_schema(test_case, schema_type)

73 

74 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_environment_configs_with_extra_param(test_case_name: str) -> None:
    """
    Runs the extra-parameter check on the named "good" environment test case
    against both the ENVIRONMENT schema and the UNIFIED schema.

    Parameters
    ----------
    test_case_name : str
        Key into TEST_CASES.by_type["good"] identifying the test case to check.
    """
    good_case = TEST_CASES.by_type["good"][test_case_name]
    # Same order as before: the environment-specific schema first, then the unified one.
    for schema_type in (ConfigSchema.ENVIRONMENT, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(good_case, schema_type)