Coverage for mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for schedulers schema validation.""" 

6 

7from os import path 

8 

9import pytest 

10 

11from mlos_bench.config.schemas import ConfigSchema 

12from mlos_bench.schedulers.base_scheduler import Scheduler 

13from mlos_bench.tests import try_resolve_class_name 

14from mlos_bench.tests.config.schemas import ( 

15 check_test_case_against_schema, 

16 check_test_case_config_with_extra_param, 

17 get_schema_test_cases, 

18) 

19from mlos_core.tests import get_all_concrete_subclasses 

20 

# Overall approach of this module:
# - keep a hand-written corpus of good/bad scheduler configs (also handy for
#   exercising editor-side schema checking)
# - enumerate the known cases and verify that the corpus covers all of them
# - load every config and validate it against the expected schema(s)

TEST_CASES = get_schema_test_cases(
    path.join(path.dirname(__file__), "test-cases"),
)


# Discover, at import time, the concrete Scheduler subclasses in mlos_bench
# whose coverage by the test-case corpus we want to enforce.

expected_mlos_bench_scheduler_class_names = [
    f"{scheduler_cls.__module__}.{scheduler_cls.__name__}"
    for scheduler_cls in get_all_concrete_subclasses(
        Scheduler,  # type: ignore[type-abstract]
        pkg_name="mlos_bench",
    )
]
# Guard against silently parametrizing the coverage test with an empty list.
assert expected_mlos_bench_scheduler_class_names

39 

# Parametrize over every (test case subtype, scheduler class) pairing.


@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names)
def test_case_coverage_mlos_bench_scheduler_type(
    test_case_subtype: str,
    mlos_bench_scheduler_type: str,
) -> None:
    """Verify that the given test case subtype contains at least one config for
    the given mlos_bench Scheduler class.

    Raises NotImplementedError when no config of that subtype resolves to the
    requested scheduler class name.
    """
    candidates = TEST_CASES.by_subtype[test_case_subtype].values()
    # any() stops at the first match, mirroring an early return from a loop.
    has_case = any(
        try_resolve_class_name(candidate.config.get("class")) == mlos_bench_scheduler_type
        for candidate in candidates
    )
    if not has_case:
        raise NotImplementedError(
            f"Missing test case for subtype {test_case_subtype} "
            f"for Scheduler class {mlos_bench_scheduler_type}"
        )

59 

60 

# Validate every collected test case against the relevant schemas.


@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_scheduler_configs_against_schema(test_case_name: str) -> None:
    """Validate one scheduler test-case config against both the dedicated
    scheduler schema and the unified schema.
    """
    case = TEST_CASES.by_path[test_case_name]
    for schema in (ConfigSchema.SCHEDULER, ConfigSchema.UNIFIED):
        check_test_case_against_schema(case, schema)

69 

70 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_scheduler_configs_with_extra_param(test_case_name: str) -> None:
    """Ensure that a known-good scheduler config stops validating once extra
    params are injected in certain places.

    The check is applied against the scheduler schema first, then the unified
    schema.
    """
    good_case = TEST_CASES.by_type["good"][test_case_name]
    for schema in (ConfigSchema.SCHEDULER, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(
            good_case,
            schema,
        )

84 

85 

if __name__ == "__main__":
    # Run only this module's tests; "-n0" presumably turns off pytest-xdist
    # worker processes so everything runs in-process (confirm against xdist docs).
    pytest.main(args=[__file__, "-n0"])