Coverage for mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py: 100%
71 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for loading environment config examples.
7"""
8import logging
9from typing import List
11import pytest
13from mlos_bench.tests.config import locate_config_examples
15from mlos_bench.config.schemas.config_schemas import ConfigSchema
16from mlos_bench.environments.base_environment import Environment
17from mlos_bench.environments.composite_env import CompositeEnv
18from mlos_bench.services.config_persistence import ConfigPersistenceService
19from mlos_bench.tunables.tunable_groups import TunableGroups
# Module-level logger for this test file, turned all the way up to DEBUG.
_LOG = logging.getLogger(__name__)
_LOG.setLevel(level=logging.DEBUG)

# Get the set of configs to test.
CONFIG_TYPE = "environments"
def filter_configs(configs_to_filter: List[str]) -> List[str]:
    """
    Drop config paths that are not for the module under test.

    Tunable definition files (those whose path ends in ``-tunables.jsonc``) live
    alongside the environment configs but are not environment configs themselves,
    so they are excluded. Order of the remaining paths is preserved.

    Parameters
    ----------
    configs_to_filter : List[str]
        Candidate config file paths.

    Returns
    -------
    List[str]
        A new list containing only the paths that do not end in ``-tunables.jsonc``.
    """
    tunables_suffix = "-tunables.jsonc"
    kept_paths = []
    for path in configs_to_filter:
        if path.endswith(tunables_suffix):
            continue
        kept_paths.append(path)
    return kept_paths
# Environment config examples to parametrize over, with "-tunables.jsonc"
# files dropped via filter_configs.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    CONFIG_TYPE,
    filter_configs,
)
assert configs  # Fail at import rather than silently parametrizing over nothing.
@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
    """
    Tests loading an environment config example.

    Every object produced from ``config_path`` must be a non-None ``Environment``.
    """
    for loaded_env in load_environment_config_examples(config_loader_service, config_path):
        assert loaded_env is not None
        assert isinstance(loaded_env, Environment)
def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]:
    """
    Load the environment(s) defined by an example config, with test scaffolding in place.

    Before loading, this:
      * loads the global test config (``experiments/experiment_test_config.jsonc``)
        so that any "required_args" can be satisfied, and defaults ``trial_id`` to 1;
      * builds each mock service from its config and registers what it exports
        with ``config_loader_service``, so environments that need those services
        can be constructed.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Service used to load configs and build services/environments.
        Note that the mock services are registered on it (it is mutated).
    config_path : str
        Path of the environment config example to load.

    Returns
    -------
    List[Environment]
        Whatever ``load_environment_list`` returns for ``config_path``.
    """
    # Make sure that any "required_args" are provided.
    globals_config_path = "experiments/experiment_test_config.jsonc"
    global_config = config_loader_service.load_config(globals_config_path, ConfigSchema.GLOBALS)
    global_config.setdefault('trial_id', 1)  # normally populated by Launcher

    # Mock versions of the services the example envs may depend on.
    mock_service_config_paths = (
        "services/local/mock/mock_local_exec_service.jsonc",
        "services/remote/mock/mock_fileshare_service.jsonc",
        "services/remote/mock/mock_network_service.jsonc",
        "services/remote/mock/mock_vm_service.jsonc",
        "services/remote/mock/mock_remote_exec_service.jsonc",
        "services/remote/mock/mock_auth_service.jsonc",
    )
    for service_config_path in mock_service_config_paths:
        service_config = config_loader_service.load_config(service_config_path, ConfigSchema.SERVICE)
        mock_service = config_loader_service.build_service(config=service_config, parent=config_loader_service)
        config_loader_service.register(mock_service.export())

    # Base (empty) tunable groups that all others get built on.
    base_tunables = TunableGroups()
    return config_loader_service.load_environment_list(
        config_path, base_tunables, global_config, service=config_loader_service)
# Root (composite) environment config examples.
composite_configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    "environments/root/",
)
assert composite_configs  # Make sure there is at least one composite config to test.
@pytest.mark.parametrize("config_path", composite_configs)
def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
    """
    Tests loading a composite env config example.

    Checks that the root config yields exactly one ``CompositeEnv``. For every child
    environment, it then checks that:

      * each child tunable (and its group) is the *same object* as the one held by
        the composite env;
      * a change made to a child tunable is visible through both the child's and the
        composite's ``tunable_params``, even when the children were loaded from
        separate configs.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    assert len(envs) == 1
    assert isinstance(envs[0], CompositeEnv)
    composite_env: CompositeEnv = envs[0]

    for child_env in composite_env.children:
        assert child_env is not None
        assert isinstance(child_env, Environment)
        assert child_env.tunable_params is not None

        # Group identity only needs checking once per group name within this child.
        checked_child_env_groups = set()
        for (child_tunable, child_group) in child_env.tunable_params:
            # Lookup that tunable in the composite env.
            assert child_tunable in composite_env.tunable_params
            (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable)
            assert child_tunable is composite_tunable  # Check that the tunables are the same object.
            if child_group.name not in checked_child_env_groups:
                assert child_group is composite_group
                checked_child_env_groups.add(child_group.name)

            # Check that when we change a child env, its value is reflected in the composite env as well.
            # That is to say, they refer to the same objects, despite having potentially been loaded from separate configs.
            if child_tunable.is_categorical:
                old_cat_value = child_tunable.category
                assert child_tunable.value == old_cat_value
                assert child_group[child_tunable] == old_cat_value
                assert composite_env.tunable_params[child_tunable] == old_cat_value
                # A tunable with a single category has nothing to switch to.
                # Indexing [0] unconditionally would raise IndexError there.
                alt_cat_values = [x for x in child_tunable.categories if x != old_cat_value]
                if alt_cat_values:
                    new_cat_value = alt_cat_values[0]
                    child_tunable.category = new_cat_value
                    assert child_env.tunable_params[child_tunable] == new_cat_value
                    assert composite_env.tunable_params[child_tunable] == child_tunable.category
            elif child_tunable.is_numerical:
                old_num_value = child_tunable.numerical_value
                assert child_tunable.value == old_num_value
                assert child_group[child_tunable] == old_num_value
                assert composite_env.tunable_params[child_tunable] == old_num_value
                # Blindly adding 1 can push a value sitting at the top of its range out of range.
                # The Tunable setter presumably rejects that, so pick a range endpoint that differs
                # from the current value instead. (For a degenerate one-point range this is a no-op.)
                (range_min, range_max) = child_tunable.range
                new_num_value = range_min if old_num_value != range_min else range_max
                child_tunable.numerical_value = new_num_value
                assert child_env.tunable_params[child_tunable] == new_num_value
                assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value