Coverage for mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py: 100%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for service schema validation.""" 

6 

7from os import path 

8from typing import Any, Dict, List 

9 

10import pytest 

11 

12from mlos_bench.config.schemas import ConfigSchema 

13from mlos_bench.services.base_service import Service 

14from mlos_bench.services.config_persistence import ConfigPersistenceService 

15from mlos_bench.services.local.temp_dir_context import TempDirContextService 

16from mlos_bench.services.remote.azure.azure_deployment_services import ( 

17 AzureDeploymentService, 

18) 

19from mlos_bench.services.remote.ssh.ssh_service import SshService 

20from mlos_bench.tests import try_resolve_class_name 

21from mlos_bench.tests.config.schemas import ( 

22 check_test_case_against_schema, 

23 check_test_case_config_with_extra_param, 

24 get_schema_test_cases, 

25) 

26from mlos_core.tests import get_all_concrete_subclasses 

27 

28# General testing strategy: 

29# - hand code a set of good/bad configs (useful to test editor schema checking) 

30# - enumerate and try to check that we've covered all the cases 

31# - for each config, load and validate against expected schema 

32 

# Load the hand-coded test cases from the "test-cases" directory next to this file.
TEST_CASES = get_schema_test_cases(
    path.join(path.dirname(__file__), "test-cases"),
)

34 

35 

36# Dynamically enumerate some of the cases we want to make sure we cover. 

37 

# Service classes that are deliberately excluded from the coverage enumeration below.
NON_CONFIG_SERVICE_CLASSES = {
    # Configured through the launcher CLI args.
    ConfigPersistenceService,
    # ABCMeta abstract class, but no good way to test that dynamically in Python.
    TempDirContextService,
    # ABCMeta abstract base class.
    AzureDeploymentService,
    # ABCMeta abstract base class.
    SshService,
}

# Fully qualified ("module.ClassName") names of every concrete Service subclass
# found in mlos_bench, minus the exclusions listed above.
expected_service_class_names = [
    f"{service_cls.__module__}.{service_cls.__name__}"
    for service_cls in get_all_concrete_subclasses(Service, pkg_name="mlos_bench")
    if service_cls not in NON_CONFIG_SERVICE_CLASSES
]
# Guard against the enumeration silently coming back empty.
assert expected_service_class_names

55 

56 

57# Do the full cross product of all the test cases and all the Service types. 

@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("service_class", expected_service_class_names)
def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_class: str) -> None:
    """Verify that the given test case subtype includes at least one test case for
    the given mlos_bench Service class.

    Raises
    ------
    NotImplementedError
        If no test case of that subtype references the service class.
    """
    for candidate in TEST_CASES.by_subtype[test_case_subtype].values():
        cfg = candidate.config
        if not isinstance(cfg, dict):
            # Only dict configs are inspected for a service class.
            continue  # type: ignore[unreachable]
        # A config carrying its own "class" key is treated as a single service;
        # otherwise the service configs are read from its "services" entry.
        service_configs: List[Dict[str, Any]] = (
            [cfg] if "class" in cfg else cfg["services"]
        )
        # Stop at the first config whose resolved class name matches.
        if any(
            try_resolve_class_name(service_config.get("class")) == service_class
            for service_config in service_configs
        ):
            return
    raise NotImplementedError(
        f"Missing test case for subtype {test_case_subtype} for service class {service_class}"
    )

78 

79 

80# Now we actually perform all of those validation tests. 

81 

82 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_service_configs_against_schema(test_case_name: str) -> None:
    """Check the named service test case against the schema."""
    test_case = TEST_CASES.by_path[test_case_name]
    # Each test case is checked against the service schema first, then the unified one.
    for schema in (ConfigSchema.SERVICE, ConfigSchema.UNIFIED):
        check_test_case_against_schema(test_case, schema)

88 

89 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_service_configs_with_extra_param(test_case_name: str) -> None:
    """Check that the service config fails to validate if extra params are present
    in certain places.
    """
    good_case = TEST_CASES.by_type["good"][test_case_name]
    # Run the extra-param check against the service schema first, then the unified one.
    for schema in (ConfigSchema.SERVICE, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(good_case, schema)