Coverage for mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py: 100%

33 statements  

Generated by coverage.py v7.9.2 on 2025-07-14 00:55 +0000 (page navigation: « prev · ^ index · » next).

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for loading scheduler config examples.""" 

6import logging 

7 

8import pytest 

9 

10import mlos_bench.tests.optimizers.fixtures 

11import mlos_bench.tests.storage.sql.fixtures 

12from mlos_bench.config.schemas.config_schemas import ConfigSchema 

13from mlos_bench.optimizers.mock_optimizer import MockOptimizer 

14from mlos_bench.schedulers.base_scheduler import Scheduler 

15from mlos_bench.schedulers.trial_runner import TrialRunner 

16from mlos_bench.services.config_persistence import ConfigPersistenceService 

17from mlos_bench.storage.base_storage import Storage 

18from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples 

19from mlos_bench.util import get_class_from_name 

20 

_LOG: logging.Logger = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)

# Alias fixtures that live in other test packages into this module's
# namespace, so the test below can request them by parameter name.
storage = mlos_bench.tests.storage.sql.fixtures.storage
mock_opt = mlos_bench.tests.optimizers.fixtures.mock_opt

# The test function reuses the fixture names above as parameter names,
# which pylint would otherwise report as shadowing an outer-scope name.
# pylint: disable=redefined-outer-name

# Config type (subdirectory) whose example configs are collected and tested.
CONFIG_TYPE = "schedulers"

32 

33 

def filter_configs(configs_to_filter: list[str]) -> list[str]:
    """Select which discovered config files belong to the module under test.

    Every discovered scheduler example is relevant here, so no entries are
    dropped: the very same list object that was passed in is returned.

    Args:
        configs_to_filter: Paths of candidate json config files.

    Returns:
        The input list, unmodified.
    """
    selected = configs_to_filter
    return selected

37 

38 

# Collect the builtin (shipped) scheduler config examples.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    CONFIG_TYPE,
    filter_configs,
)
# Fail loudly at collection time if no builtin examples are found, instead of
# silently parametrizing the test below with an empty list.
assert configs, (
    f"No builtin {CONFIG_TYPE} config examples found under "
    f"{ConfigPersistenceService.BUILTIN_CONFIG_PATH}"
)

# Collect any test-only scheduler config examples as well.
test_configs = locate_config_examples(
    BUILTIN_TEST_CONFIG_PATH,
    CONFIG_TYPE,
    filter_configs,
)
# Test-only configs are optional, so an empty result is deliberately not
# asserted against. NOTE(review): presumably there may be no test-only
# scheduler configs yet; add an assertion once some are guaranteed to exist.
configs.extend(test_configs)

53 

54 

@pytest.mark.parametrize("config_path", configs)
def test_load_scheduler_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
    mock_env_config_path: str,
    trial_runners: list[TrialRunner],
    storage: Storage,
    mock_opt: MockOptimizer,
) -> None:
    """Check that a single scheduler config example loads and instantiates.

    The config at ``config_path`` is loaded with ``ConfigSchema.SCHEDULER``
    as its expected schema. Its ``class`` entry is resolved and must name a
    ``Scheduler`` subclass. ``build_scheduler`` is then expected to return an
    instance of that class when wired to the shared fixtures (trial runners,
    mock optimizer, storage, and mock root environment config).
    """
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    loaded_config = config_loader_service.load_config(config_path, ConfigSchema.SCHEDULER)
    assert isinstance(loaded_config, dict)

    scheduler_cls = get_class_from_name(loaded_config["class"])
    assert issubclass(scheduler_cls, Scheduler)

    # Values the Launcher would normally supply; faked for a standalone build.
    launcher_globals = {
        "experiment_id": f"test_experiment_{__name__}",
        "trial_id": 1,
    }

    built_scheduler = config_loader_service.build_scheduler(
        config=loaded_config,
        global_config=launcher_globals,
        trial_runners=trial_runners,
        optimizer=mock_opt,
        storage=storage,
        root_env_config=mock_env_config_path,
    )
    assert built_scheduler is not None
    assert isinstance(built_scheduler, scheduler_cls)