Coverage for mlos_bench/mlos_bench/tests/storage/exp_data_test.py: 100%

54 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for loading the experiment metadata. 

7""" 

8 

9from mlos_bench.storage.base_storage import Storage 

10from mlos_bench.storage.base_experiment_data import ExperimentData 

11from mlos_bench.tunables.tunable_groups import TunableGroups 

12 

13from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT 

14 

15 

def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) -> None:
    """
    Read the experiment metadata back out of an otherwise empty storage and
    check that it matches the experiment that was created.
    """
    loaded = storage.experiments[exp_storage.experiment_id]
    assert loaded.experiment_id == exp_storage.experiment_id
    assert loaded.description == exp_storage.description
    # A single objective is all that is supported at the moment.
    expected_objectives = {exp_storage.opt_target: exp_storage.opt_direction}
    assert loaded.objectives == expected_objectives

25 

26 

def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None:
    """
    Check that ExperimentData.root_env_config reports the
    (root env config, git repo, git commit) triple recorded by the experiment.
    """
    actual = exp_data.root_env_config
    # pylint: disable=protected-access
    expected = (
        exp_storage._root_env_config,
        exp_storage._git_repo,
        exp_storage._git_commit,
    )
    assert actual == expected

31 

32 

def test_exp_trial_data_objectives(storage: Storage,
                                   exp_storage: Storage.Experiment,
                                   tunable_groups: TunableGroups) -> None:
    """
    Start trials that carry their own objective settings and check that the
    experiment data loaded back from storage reflects them.
    """
    # Trial that specifies both the optimization target and its direction.
    trial_full = exp_storage.new_trial(
        tunable_groups,
        config={
            "opt_target": "some-other-target",
            "opt_direction": "max",
        },
    )
    expected_full_config = {
        "experiment_id": exp_storage.experiment_id,
        "trial_id": trial_full.trial_id,
        "opt_target": "some-other-target",
        "opt_direction": "max",
    }
    assert trial_full.config() == expected_full_config

    # Backwards-compatible trial: "opt_direction" is deliberately left out.
    trial_legacy = exp_storage.new_trial(
        tunable_groups,
        config={
            "opt_target": "back-compat",
        },
    )
    expected_legacy_config = {
        "experiment_id": exp_storage.experiment_id,
        "trial_id": trial_legacy.trial_id,
        "opt_target": "back-compat",
    }
    assert trial_legacy.config() == expected_legacy_config

    loaded = storage.experiments[exp_storage.experiment_id]
    # The reported objectives merge each trial's objective with the experiment's own.
    expected_objectives = {
        "back-compat": None,
        "some-other-target": "max",
        exp_storage.opt_target: exp_storage.opt_direction,
    }
    assert loaded.objectives == expected_objectives

    full_trial_data = loaded.trials[trial_full.trial_id]
    assert full_trial_data.metadata_dict == {
        "opt_target": "some-other-target",
        "opt_direction": "max",
    }

74 

75 

def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None:
    """Check the row count and key columns of ExperimentData.results_df"""
    df = exp_data.results_df
    n_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT

    # One row per trial, CONFIG_COUNT distinct configs, unique trial ids.
    assert len(df) == n_trials
    assert len(df["tunable_config_id"].unique()) == CONFIG_COUNT
    assert len(df["trial_id"].unique()) == n_trials

    # A result column exists for the first objective ...
    first_objective = next(iter(exp_data.objectives))
    result_col = ExperimentData.RESULT_COLUMN_PREFIX + first_objective
    assert len(df[result_col]) == n_trials

    # ... and a config column exists for the first tunable.
    first_tunable, _ = next(iter(tunable_groups))
    config_col = ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name
    assert len(df[config_col]) == n_trials

87 

88 

def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None:
    """
    Tests the tunable_config_trial_group_id column of ExperimentData.results_df

    Each trial row should carry the id of its tunable config and the
    tunable_config_trial_group_id, i.e. the trial_id of the first trial that
    used that config.

    See Also: test_exp_trial_data_tunable_config_trial_group_id()
    """
    results_df = exp_data.results_df

    # (trial_id, expected tunable_config_id, expected tunable_config_trial_group_id)
    expectations = [
        # The first three trials should all use the same config, so they share
        # the group id of the first trial in that group.
        (1, 1, 1),
        (2, 1, 1),
        (3, 1, 1),
        # The fourth should be a new config, which starts a new group.
        (4, 2, 4),
        # And so on ...
    ]
    for (trial_id, config_id, group_id) in expectations:
        trial_df = results_df.loc[results_df["trial_id"] == trial_id]
        assert len(trial_df) == 1, f"expected exactly one row for trial {trial_id}"
        assert trial_df["tunable_config_id"].iloc[0] == config_id
        assert trial_df["tunable_config_trial_group_id"].iloc[0] == group_id

115 

116 

def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None:
    """
    Tests the tunable_config_trial_groups property of ExperimentData

    This tests bulk loading of the tunable_config_trial_groups.
    """
    expected_config_ids = list(range(1, CONFIG_COUNT + 1))
    # Should be keyed by config_id.
    assert list(exp_data.tunable_config_trial_groups.keys()) == expected_config_ids
    # Which should match the objects.
    assert [config_trial_group.tunable_config_id
            for config_trial_group in exp_data.tunable_config_trial_groups.values()
            ] == expected_config_ids
    # And the tunable_config_trial_group_id should also match the minimum trial_id,
    # i.e. the first trial of each run of CONFIG_TRIAL_REPEAT_COUNT repeats.
    # The "+ 1" keeps the last trial id inside the range so this still yields
    # CONFIG_COUNT entries when CONFIG_TRIAL_REPEAT_COUNT == 1.
    expected_group_ids = list(range(1,
                                    CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT + 1,
                                    CONFIG_TRIAL_REPEAT_COUNT))
    assert [config_trial_group.tunable_config_trial_group_id
            for config_trial_group in exp_data.tunable_config_trial_groups.values()
            ] == expected_group_ids

133 

134 

def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None:
    """Check that ExperimentData.tunable_configs is keyed by, and agrees with, tunable config ids"""
    expected_ids = list(range(1, CONFIG_COUNT + 1))
    # The mapping's keys are the tunable config ids ...
    assert list(exp_data.tunable_configs.keys()) == expected_ids
    # ... and the stored objects report those same ids, in the same order.
    ids_from_values = [cfg.tunable_config_id for cfg in exp_data.tunable_configs.values()]
    assert ids_from_values == expected_ids

143 

144 

def test_exp_data_default_config_id(exp_data: ExperimentData) -> None:
    """Check that ExperimentData.default_tunable_config_id is 1"""
    expected_default_id = 1
    assert exp_data.default_tunable_config_id == expected_default_id