Coverage for mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for service schema validation. 

7""" 

8 

9from os import path 

10from typing import Any, Dict, List 

11 

12import pytest 

13 

14from mlos_core.tests import get_all_concrete_subclasses 

15 

16from mlos_bench.config.schemas import ConfigSchema 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.services.config_persistence import ConfigPersistenceService 

19from mlos_bench.services.local.temp_dir_context import TempDirContextService 

20from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService 

21from mlos_bench.services.remote.ssh.ssh_service import SshService 

22 

23from mlos_bench.tests import try_resolve_class_name 

24from mlos_bench.tests.config.schemas import (get_schema_test_cases, 

25 check_test_case_against_schema, 

26 check_test_case_config_with_extra_param) 

27 

28 

# Overall approach taken by this module:
# - keep a hand-written collection of good/bad configs (these are also handy
#   for exercising schema checking in editors)
# - enumerate the cases we expect and verify each one is actually covered
# - load every config and validate it against the schema it is meant for

TEST_CASES = get_schema_test_cases(
    path.join(path.dirname(__file__), "test-cases")
)


# Discover, at import time, the Service classes we expect test cases for.

NON_CONFIG_SERVICE_CLASSES = {
    # Configured through the launcher CLI arguments rather than a config file.
    ConfigPersistenceService,
    # An ABCMeta abstract class, but there's no good way to test that dynamically in Python.
    TempDirContextService,
    # ABCMeta abstract base class.
    AzureDeploymentService,
    # ABCMeta abstract base class.
    SshService,
}

expected_service_class_names = [
    f"{service_cls.__module__}.{service_cls.__name__}"
    for service_cls in get_all_concrete_subclasses(Service, pkg_name='mlos_bench')
    if service_cls not in NON_CONFIG_SERVICE_CLASSES
]
# Guard against silently parametrizing the coverage test with zero classes.
assert expected_service_class_names

51 

52 

# Do the full cross product of all the test cases and all the Service types.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("service_class", expected_service_class_names)
def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_class: str) -> None:
    """
    Checks to see if there is a given type of test case for the given mlos_bench Service type.

    A test case config is treated either as a single service config (a dict
    with a "class" key) or as a wrapper dict whose "services" entry lists
    service configs. Configs matching neither shape are skipped instead of
    crashing the coverage check. Such configs may appear among the hand-coded
    "bad" test cases.

    Raises
    ------
    NotImplementedError
        If no test case of ``test_case_subtype`` references ``service_class``.
    """
    for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
        if not isinstance(test_case.config, dict):
            continue  # type: ignore[unreachable]
        # Single service config vs. a {"services": [...]} wrapper.
        # Use .get() so a config lacking both keys is skipped, not a KeyError.
        raw_configs: Any = (
            [test_case.config] if "class" in test_case.config
            else test_case.config.get("services", [])
        )
        if not isinstance(raw_configs, list):
            # Malformed "services" value (e.g. a string or dict); nothing to match.
            continue
        # Drop non-dict entries so config.get() below cannot raise AttributeError.
        config_list: List[Dict[str, Any]] = [cfg for cfg in raw_configs if isinstance(cfg, dict)]
        for config in config_list:
            if try_resolve_class_name(config.get("class")) == service_class:
                return
    raise NotImplementedError(
        f"Missing test case for subtype {test_case_subtype} for service class {service_class}")

73 

74 

# With coverage established above, run the actual validation checks.

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_service_configs_against_schema(test_case_name: str) -> None:
    """
    Runs the schema check for the named service test case, first against the
    service-specific schema and then against the unified schema.
    """
    test_case = TEST_CASES.by_path[test_case_name]
    for schema in (ConfigSchema.SERVICE, ConfigSchema.UNIFIED):
        check_test_case_against_schema(test_case, schema)

84 

85 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_service_configs_with_extra_param(test_case_name: str) -> None:
    """
    Verifies that a known-good service config stops validating once extra
    params are injected in certain places. This is checked against the
    service-specific schema and then against the unified schema.
    """
    good_case = TEST_CASES.by_type["good"][test_case_name]
    for schema in (ConfigSchema.SERVICE, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(good_case, schema)