Coverage for mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for loading environment config examples. 

7""" 

8import logging 

9from typing import List 

10 

11import pytest 

12 

13from mlos_bench.tests.config import locate_config_examples 

14 

15from mlos_bench.config.schemas.config_schemas import ConfigSchema 

16from mlos_bench.environments.base_environment import Environment 

17from mlos_bench.environments.composite_env import CompositeEnv 

18from mlos_bench.services.config_persistence import ConfigPersistenceService 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20 

21 

# Module-level logger for these tests, forced to DEBUG verbosity.
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)


# Get the set of configs to test:
# CONFIG_TYPE is the subdirectory (under the built-in config path) searched for example configs.
CONFIG_TYPE = "environments"

28 

29 

def filter_configs(configs_to_filter: List[str]) -> List[str]:
    """
    Exclude ``*-tunables.jsonc`` files from the located config paths.

    Those files aren't environment configs for the module under test, so they
    must not be handed to the environment-loading tests.

    Parameters
    ----------
    configs_to_filter : List[str]
        Candidate config file paths.

    Returns
    -------
    List[str]
        A new list with the paths that do not end in ``-tunables.jsonc``,
        in their original order.
    """
    excluded_suffix = "-tunables.jsonc"
    kept_paths: List[str] = []
    for path in configs_to_filter:
        if path.endswith(excluded_suffix):
            continue
        kept_paths.append(path)
    return kept_paths

34 

35 

# Collect every example environment config (minus the *-tunables.jsonc files);
# the module-level assert fails at import time if nothing was found.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    CONFIG_TYPE,
    filter_configs,
)
assert configs

38 

39 

@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
    """
    Load one environment config example and check each result is an Environment.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Fixture-provided config loader (mock services get registered on it).
    config_path : str
        The example config file under test.
    """
    loaded_envs = load_environment_config_examples(config_loader_service, config_path)
    for loaded_env in loaded_envs:
        assert loaded_env is not None
        assert isinstance(loaded_env, Environment)

47 

48 

def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]:
    """
    Load the environment(s) defined in the example config at ``config_path``.

    Before loading, this loads the test global config and registers a set of
    mock local/remote services on ``config_loader_service``. This ensures that
    the example environments can resolve their "required_args" and services.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Service used to load configs. It is mutated: the mock services are
        registered on it.
    config_path : str
        Path of the environment config to load.

    Returns
    -------
    List[Environment]
        The environments built from the config.
    """
    # Globals from the test experiment config satisfy any "required_args".
    global_config = config_loader_service.load_config("experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS)
    # trial_id is normally populated by the Launcher.
    global_config.setdefault('trial_id', 1)

    # Base (empty) tunable groups that all the loaded environments build on.
    base_tunables = TunableGroups()

    # Mock service configs covering the services the example envs may require.
    mock_service_config_paths = (
        "services/local/mock/mock_local_exec_service.jsonc",
        "services/remote/mock/mock_fileshare_service.jsonc",
        "services/remote/mock/mock_network_service.jsonc",
        "services/remote/mock/mock_vm_service.jsonc",
        "services/remote/mock/mock_remote_exec_service.jsonc",
        "services/remote/mock/mock_auth_service.jsonc",
    )
    for service_config_path in mock_service_config_paths:
        service_config = config_loader_service.load_config(service_config_path, ConfigSchema.SERVICE)
        mock_service = config_loader_service.build_service(config=service_config, parent=config_loader_service)
        config_loader_service.register(mock_service.export())

    return config_loader_service.load_environment_list(
        config_path, base_tunables, global_config, service=config_loader_service)

75 

76 

# Root-level environment configs; each is expected to load as a single CompositeEnv
# (see test_load_composite_env_config_examples). Fail at import time if none exist.
composite_configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    "environments/root/",
)
assert composite_configs

79 

80 

@pytest.mark.parametrize("config_path", composite_configs)
def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
    """
    Tests loading a composite env config example.

    The config must load as exactly one CompositeEnv. For every child environment,
    each of its tunables (and covariant groups) must be the very same object that the
    composite environment exposes. This is verified by identity checks and by changing
    a value through the child and observing it through the composite.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Fixture-provided config loader (mock services get registered on it).
    config_path : str
        The root (composite) environment config file under test.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    assert len(envs) == 1
    assert isinstance(envs[0], CompositeEnv)
    composite_env: CompositeEnv = envs[0]

    for child_env in composite_env.children:
        assert child_env is not None
        assert isinstance(child_env, Environment)
        assert child_env.tunable_params is not None

        # Group identity only needs to be checked once per group name within a child env.
        checked_child_env_groups = set()
        for (child_tunable, child_group) in child_env.tunable_params:
            # Look up that tunable in the composite env.
            assert child_tunable in composite_env.tunable_params
            (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable)
            assert child_tunable is composite_tunable  # Check that the tunables are the same object.
            if child_group.name not in checked_child_env_groups:
                assert child_group is composite_group
                checked_child_env_groups.add(child_group.name)

            # Check that when we change a child env, its value is reflected in the composite env as well.
            # That is to say, they refer to the same objects, despite having potentially been loaded from separate configs.
            if child_tunable.is_categorical:
                old_cat_value = child_tunable.category
                assert child_tunable.value == old_cat_value
                assert child_group[child_tunable] == old_cat_value
                assert composite_env.tunable_params[child_tunable] == old_cat_value
                other_cat_values = [x for x in child_tunable.categories if x != old_cat_value]
                # A single-category tunable has no other value to switch to
                # (indexing [0] here used to raise IndexError).
                if other_cat_values:
                    new_cat_value = other_cat_values[0]
                    child_tunable.category = new_cat_value
                    assert child_env.tunable_params[child_tunable] == new_cat_value
                    assert composite_env.tunable_params[child_tunable] == child_tunable.category
            elif child_tunable.is_numerical:
                old_num_value = child_tunable.numerical_value
                assert child_tunable.value == old_num_value
                assert child_group[child_tunable] == old_num_value
                assert composite_env.tunable_params[child_tunable] == old_num_value
                # Prefer stepping up by one, but stay inside the tunable's range. A value at
                # (or within 1 of) the max, or a shared tunable already bumped by an earlier
                # child, would otherwise be pushed out of range.
                # NOTE(review): assumes out-of-range assignments are rejected by Tunable validation.
                (range_min, range_max) = child_tunable.range
                new_num_value = old_num_value + 1
                if new_num_value > range_max:
                    new_num_value = range_min
                # A single-valued range has nothing to change to.
                if new_num_value != old_num_value:
                    child_tunable.numerical_value = new_num_value
                    assert child_env.tunable_params[child_tunable] == new_num_value
                    assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value