Coverage for mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py: 100%

72 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for loading environment config examples.""" 

6import logging 

7 

8import pytest 

9 

10from mlos_bench.config.schemas.config_schemas import ConfigSchema 

11from mlos_bench.environments.base_environment import Environment 

12from mlos_bench.environments.composite_env import CompositeEnv 

13from mlos_bench.services.config_persistence import ConfigPersistenceService 

14from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples 

15from mlos_bench.tunables.tunable_groups import TunableGroups 

16 

# Module-level logger for this test module, set to DEBUG verbosity.
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)


# Config type handed to locate_config_examples() to pick the examples under test.
CONFIG_TYPE = "environments"

23 

24 

def filter_configs(configs_to_filter: list[str]) -> list[str]:
    """Exclude config files that are not meant for the module under test.

    Parameters
    ----------
    configs_to_filter : list[str]
        Candidate config file paths.

    Returns
    -------
    list[str]
        A new list holding, in the original order, every path that does not
        end with ``-tunables.jsonc``.
    """
    kept_paths: list[str] = []
    for candidate in configs_to_filter:
        # Paths ending in "-tunables.jsonc" are the ones this filter drops.
        if candidate.endswith("-tunables.jsonc"):
            continue
        kept_paths.append(candidate)
    return kept_paths

33 

34 

# Example configs shipped with the package; there must be at least one.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs
)
assert configs

# Example configs from the test config tree; also required to be non-empty.
test_configs = locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs)
assert test_configs

# Both sets are exercised by the same parametrized test.
configs.extend(test_configs)

49 

50 

@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """Check that an example environment config loads into Environment objects.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Config loader used to load the example.
    config_path : str
        Path of the example config, supplied by the parametrization.
    """
    loaded_envs = load_environment_config_examples(config_loader_service, config_path)
    # Every entry that comes back must be a real Environment instance.
    for loaded_env in loaded_envs:
        assert loaded_env is not None
        assert isinstance(loaded_env, Environment)

61 

62 

def load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> list[Environment]:
    """Build the environment(s) described by one example config.

    Loads the test experiment globals, registers a fixed set of mock services
    on ``config_loader_service``, and then loads ``config_path`` through it.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Loader used to read the configs; the mock services get registered on it.
    config_path : str
        Path of the environment config to load.

    Returns
    -------
    list[Environment]
        The result of ``load_environment_list`` for ``config_path``.
    """
    # The globals supply any "required_args" the environments ask for.
    globals_for_test = config_loader_service.load_config(
        "experiments/experiment_test_config.jsonc",
        ConfigSchema.GLOBALS,
    )
    # Normally populated by Launcher; only filled in here when missing.
    globals_for_test.setdefault("trial_id", 1)

    # Mock services required by the environments being loaded.
    mock_service_paths = (
        "services/local/mock/mock_local_exec_service.jsonc",
        "services/remote/mock/mock_fileshare_service.jsonc",
        "services/remote/mock/mock_network_service.jsonc",
        "services/remote/mock/mock_vm_service.jsonc",
        "services/remote/mock/mock_remote_exec_service.jsonc",
        "services/remote/mock/mock_auth_service.jsonc",
    )

    # Base tunable groups that all the others get built on.
    base_tunables = TunableGroups()

    for service_json in mock_service_paths:
        service_config = config_loader_service.load_config(service_json, ConfigSchema.SERVICE)
        mock_service = config_loader_service.build_service(
            config=service_config,
            parent=config_loader_service,
        )
        config_loader_service.register(mock_service.export())

    return config_loader_service.load_environment_list(
        config_path,
        base_tunables,
        globals_for_test,
        service=config_loader_service,
    )

106 

107 

# Built-in example configs found under "environments/root/"; must be non-empty.
composite_configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/"
)
assert composite_configs

113 

114 

@pytest.mark.parametrize("config_path", composite_configs)
def test_load_composite_env_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """Tests loading a composite env config example.

    The config must load into exactly one CompositeEnv. For every child
    environment, each tunable (and its group) must be the very same object as
    the one held by the composite environment, and a value changed through the
    child must be visible through the composite as well.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Config loader used to load the example.
    config_path : str
        Path of the composite environment config, supplied by the parametrization.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    assert len(envs) == 1
    assert isinstance(envs[0], CompositeEnv)
    composite_env: CompositeEnv = envs[0]

    for child_env in composite_env.children:
        assert child_env is not None
        assert isinstance(child_env, Environment)
        assert child_env.tunable_params is not None

        # Group names already compared for this child, so each group is checked once.
        checked_child_env_groups = set()
        for child_tunable, child_group in child_env.tunable_params:
            # Lookup that tunable in the composite env.
            assert child_tunable in composite_env.tunable_params
            (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(
                child_tunable
            )
            # Check that the tunables are the same object.
            assert child_tunable is composite_tunable
            if child_group.name not in checked_child_env_groups:
                assert child_group is composite_group
                checked_child_env_groups.add(child_group.name)

            # Check that when we change a child env, its value is reflected in the
            # composite env as well.
            # That is to say, they refer to the same objects, despite having
            # potentially been loaded from separate configs.
            if child_tunable.is_categorical:
                old_cat_value = child_tunable.category
                assert child_tunable.value == old_cat_value
                assert child_group[child_tunable] == old_cat_value
                assert composite_env.tunable_params[child_tunable] == old_cat_value
                other_cat_values = [x for x in child_tunable.categories if x != old_cat_value]
                if not other_cat_values:
                    # Only one category is available, so there is nothing to
                    # switch to; skip the mutation check rather than fail with
                    # an unrelated IndexError.
                    _LOG.debug(
                        "Tunable %s has no alternative category; skipping update check.",
                        child_tunable,
                    )
                    continue
                new_cat_value = other_cat_values[0]
                child_tunable.category = new_cat_value
                assert child_env.tunable_params[child_tunable] == new_cat_value
                assert composite_env.tunable_params[child_tunable] == child_tunable.category
            elif child_tunable.is_numerical:
                old_num_value = child_tunable.numerical_value
                assert child_tunable.value == old_num_value
                assert child_group[child_tunable] == old_num_value
                assert composite_env.tunable_params[child_tunable] == old_num_value
                child_tunable.numerical_value += 1
                assert child_env.tunable_params[child_tunable] == old_num_value + 1
                assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value

164 assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value