Coverage for mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py: 100%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tests for loading environment config examples."""
6import logging
7from typing import List
9import pytest
11from mlos_bench.config.schemas.config_schemas import ConfigSchema
12from mlos_bench.environments.base_environment import Environment
13from mlos_bench.environments.composite_env import CompositeEnv
14from mlos_bench.services.config_persistence import ConfigPersistenceService
15from mlos_bench.tests.config import locate_config_examples
16from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG: logging.Logger = logging.getLogger(name=__name__)
_LOG.setLevel(level=logging.DEBUG)

# Subdirectory of the built-in config tree whose example configs are tested here.
CONFIG_TYPE = "environments"
def filter_configs(configs_to_filter: List[str]) -> List[str]:
    """
    Drop config files that are not environment configs from the candidate list.

    Files ending in ``-tunables.jsonc`` sit alongside the environment configs
    but are not for the module under test, so they are excluded.

    Parameters
    ----------
    configs_to_filter : List[str]
        Candidate config file paths.

    Returns
    -------
    List[str]
        A new list with the remaining paths, in their original order.
    """
    tunables_suffix = "-tunables.jsonc"
    return [path for path in configs_to_filter if not path.endswith(tunables_suffix)]
# Get the set of environment config examples to test (tunables files excluded).
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    CONFIG_TYPE,
    filter_configs,
)
# Fail at import with a useful message if nothing was found, rather than
# parametrizing the test below with an empty list.
assert configs, f"No {CONFIG_TYPE} config examples found"
@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """Check that every env loaded from an example config is an Environment."""
    for loaded_env in load_environment_config_examples(config_loader_service, config_path):
        assert loaded_env is not None
        assert isinstance(loaded_env, Environment)
def load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> List[Environment]:
    """
    Load the environment(s) defined in an example config file.

    Before loading, this reads the test experiment's global config (so that any
    ``required_args`` of the environments are provided) and registers a set of
    mock services with ``config_loader_service`` so that the environments being
    loaded have the services they use.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Config loader used to read configs and build services. The mock
        services are registered on it, so it is modified by this call.
    config_path : str
        Path of the environment config example to load.

    Returns
    -------
    List[Environment]
        The environments returned by ``load_environment_list`` for ``config_path``.
    """
    global_config = config_loader_service.load_config(
        "experiments/experiment_test_config.jsonc",
        ConfigSchema.GLOBALS,
    )
    # The Launcher would normally fill in the trial_id.
    global_config.setdefault("trial_id", 1)

    # Register mock versions of the services the example environments may use.
    for service_config_path in (
        "services/local/mock/mock_local_exec_service.jsonc",
        "services/remote/mock/mock_fileshare_service.jsonc",
        "services/remote/mock/mock_network_service.jsonc",
        "services/remote/mock/mock_vm_service.jsonc",
        "services/remote/mock/mock_remote_exec_service.jsonc",
        "services/remote/mock/mock_auth_service.jsonc",
    ):
        service_config = config_loader_service.load_config(
            service_config_path,
            ConfigSchema.SERVICE,
        )
        mock_service = config_loader_service.build_service(
            config=service_config,
            parent=config_loader_service,
        )
        config_loader_service.register(mock_service.export())

    # Empty base tunable groups that all the loaded envs' tunables get built on.
    base_tunables = TunableGroups()
    return config_loader_service.load_environment_list(
        config_path,
        base_tunables,
        global_config,
        service=config_loader_service,
    )
# Root configs are expected to each define a single composite environment
# (checked by the test below).
composite_configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    "environments/root/",
)
# Fail at import with a useful message if nothing was found.
assert composite_configs, "No composite env config examples found under environments/root/"
@pytest.mark.parametrize("config_path", composite_configs)
def test_load_composite_env_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """
    Tests loading a composite env config example.

    Checks that the config yields exactly one CompositeEnv. It also checks that
    each child env's tunables and tunable groups are the *same objects* as those
    in the composite env, so that changing a child tunable is visible through
    the composite env's ``tunable_params``.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    assert len(envs) == 1
    assert isinstance(envs[0], CompositeEnv)
    composite_env: CompositeEnv = envs[0]

    for child_env in composite_env.children:
        assert child_env is not None
        assert isinstance(child_env, Environment)
        assert child_env.tunable_params is not None

        checked_child_env_groups = set()
        for child_tunable, child_group in child_env.tunable_params:
            # Lookup that tunable in the composite env.
            assert child_tunable in composite_env.tunable_params
            (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(
                child_tunable
            )
            # Check that the tunables are the same object.
            assert child_tunable is composite_tunable
            if child_group.name not in checked_child_env_groups:
                assert child_group is composite_group
                checked_child_env_groups.add(child_group.name)

            # Check that when we change a child env, its value is reflected in the
            # composite env as well.
            # That is to say, they refer to the same objects, despite having
            # potentially been loaded from separate configs.
            if child_tunable.is_categorical:
                old_cat_value = child_tunable.category
                assert child_tunable.value == old_cat_value
                assert child_group[child_tunable] == old_cat_value
                assert composite_env.tunable_params[child_tunable] == old_cat_value
                # A single-category tunable has no other value to switch to
                # (indexing [0] here used to raise IndexError).
                other_cat_values = [x for x in child_tunable.categories if x != old_cat_value]
                if other_cat_values:
                    new_cat_value = other_cat_values[0]
                    child_tunable.category = new_cat_value
                    assert child_env.tunable_params[child_tunable] == new_cat_value
                    assert composite_env.tunable_params[child_tunable] == child_tunable.category
            elif child_tunable.is_numerical:
                old_num_value = child_tunable.numerical_value
                assert child_tunable.value == old_num_value
                assert child_group[child_tunable] == old_num_value
                assert composite_env.tunable_params[child_tunable] == old_num_value
                new_num_value = old_num_value + 1
                if not child_tunable.is_valid(new_num_value):
                    # +1 would fall outside the tunable's range (e.g. the value is
                    # already at the max), and the setter would reject it.
                    # Switch to a range endpoint that differs from the old value instead.
                    (range_min, range_max) = child_tunable.range
                    new_num_value = range_min if old_num_value != range_min else range_max
                child_tunable.numerical_value = new_num_value
                assert child_env.tunable_params[child_tunable] == new_num_value
                assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value