Coverage for mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py: 100%
26 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tests for loading optimizer config examples."""
6import logging
7from typing import List
9import pytest
11from mlos_bench.config.schemas import ConfigSchema
12from mlos_bench.optimizers.base_optimizer import Optimizer
13from mlos_bench.services.config_persistence import ConfigPersistenceService
14from mlos_bench.tests.config import locate_config_examples
15from mlos_bench.tunables.tunable_groups import TunableGroups
16from mlos_bench.util import get_class_from_name
# Module logger, turned all the way up to DEBUG for this test module.
_LOG = logging.getLogger(name=__name__)
_LOG.setLevel(level=logging.DEBUG)

# Config subdirectory whose bundled examples this module exercises.
CONFIG_TYPE: str = "optimizers"
def filter_configs(configs_to_filter: List[str]) -> List[str]:
    """Narrow the discovered example config paths to the ones this module tests.

    No filtering is needed for optimizer examples at present, so the very same
    list object that was passed in is handed back unchanged (not a copy).
    """
    selected = configs_to_filter
    return selected
# Discover every bundled example config for CONFIG_TYPE under the built-in config
# path, passed through filter_configs.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    CONFIG_TYPE,
    filter_configs,
)
# Fail loudly at import/collection time if discovery finds nothing; otherwise the
# parametrized test would have no cases to run. The message names what was searched
# so the failure is diagnosable without a debugger.
assert configs, (
    f"No {CONFIG_TYPE} config examples found under "
    f"{ConfigPersistenceService.BUILTIN_CONFIG_PATH}"
)
@pytest.mark.parametrize("config_path", configs)
def test_load_optimizer_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """Load one optimizer example config and instantiate the optimizer it names.

    The example is loaded with ``ConfigSchema.OPTIMIZER`` and must come back as a
    dict. Its ``"class"`` entry must resolve to an ``Optimizer`` subclass. Finally,
    ``build_optimizer`` is called with an empty ``TunableGroups`` and must return a
    non-None instance of that subclass.
    """
    loaded = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER)
    assert isinstance(loaded, dict)

    optimizer_cls = get_class_from_name(loaded["class"])
    assert issubclass(optimizer_cls, Optimizer)

    # Build the optimizer from the loaded config with no tunables, passing the
    # loader service itself as the ``service`` argument.
    instance = config_loader_service.build_optimizer(
        tunables=TunableGroups(),
        config=loaded,
        service=config_loader_service,
    )
    assert instance is not None
    assert isinstance(instance, optimizer_cls)