Coverage for mlos_bench/mlos_bench/tests/storage/exp_data_test.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for loading the experiment metadata.""" 

6 

7from mlos_bench.storage.base_experiment_data import ExperimentData 

8from mlos_bench.storage.base_storage import Storage 

9from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT 

10from mlos_bench.tunables.tunable_groups import TunableGroups 

11 

12 

def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) -> None:
    """
    Look up an experiment in the (empty) storage and compare its metadata.

    The experiment object loaded via ``storage.experiments`` must report the same
    id, description, and objectives (``opt_targets``) as the ``exp_storage``
    fixture it was created from.
    """
    expected = exp_storage
    loaded = storage.experiments[expected.experiment_id]

    assert loaded.experiment_id == expected.experiment_id
    assert loaded.description == expected.description
    assert loaded.objectives == expected.opt_targets

19 

20 

def test_exp_data_root_env_config(
    exp_storage: Storage.Experiment,
    exp_data: ExperimentData,
) -> None:
    """
    Check ExperimentData.root_env_config against the storage experiment.

    The property must equal the ``(root_env_config, git_repo, git_commit)`` tuple
    built from the private attributes of the ``exp_storage`` fixture.
    """
    actual = exp_data.root_env_config
    # pylint: disable=protected-access
    expected = (
        exp_storage._root_env_config,
        exp_storage._git_repo,
        exp_storage._git_commit,
    )
    assert actual == expected

32 

33 

def test_exp_trial_data_objectives(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    """
    Start new trials with per-trial objective metadata and check it in storage.

    Two trials are created. The first passes both ``opt_target`` and
    ``opt_direction``. The second, a back-compat case, passes only
    ``opt_target``. The live trial configs and the trial data loaded back
    from storage must both reflect exactly the metadata that was supplied.
    """
    trial_opt_new = exp_storage.new_trial(
        tunable_groups,
        config={
            "opt_target": "some-other-target",
            "opt_direction": "max",
        },
    )
    # The live trial config is the supplied metadata plus experiment/trial ids.
    assert trial_opt_new.config() == {
        "experiment_id": exp_storage.experiment_id,
        "trial_id": trial_opt_new.trial_id,
        "opt_target": "some-other-target",
        "opt_direction": "max",
    }

    trial_opt_old = exp_storage.new_trial(
        tunable_groups,
        config={
            "opt_target": "back-compat",
            # "opt_direction": "max",  # deliberately missing (back-compat case)
        },
    )
    assert trial_opt_old.config() == {
        "experiment_id": exp_storage.experiment_id,
        "trial_id": trial_opt_old.trial_id,
        "opt_target": "back-compat",
    }

    exp = storage.experiments[exp_storage.experiment_id]
    # Experiment-level objectives stay as configured, regardless of the
    # per-trial opt_target/opt_direction values used above.
    assert exp.objectives == exp_storage.opt_targets

    trial_data_opt_new = exp.trials[trial_opt_new.trial_id]
    assert trial_data_opt_new.metadata_dict == {
        "opt_target": "some-other-target",
        "opt_direction": "max",
    }

    # The back-compat trial must round-trip without an opt_direction entry.
    trial_data_opt_old = exp.trials[trial_opt_old.trial_id]
    assert trial_data_opt_old.metadata_dict == {
        "opt_target": "back-compat",
    }

76 

77 

def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None:
    """
    Check the size and key columns of ExperimentData.results_df.

    Expect ``CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT`` rows with unique trial ids,
    ``CONFIG_COUNT`` distinct tunable config ids, and a result column (for the first
    objective) and a config column (for the first tunable) spanning every row.
    Indexing a missing DataFrame column raises ``KeyError``, so the column checks
    also verify that those columns exist.
    """
    df = exp_data.results_df
    n_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT

    assert len(df) == n_trials
    assert len(df["tunable_config_id"].unique()) == CONFIG_COUNT
    assert len(df["trial_id"].unique()) == n_trials

    # Result columns are named by prefixing the objective name.
    first_objective = next(iter(exp_data.objectives))
    result_column = ExperimentData.RESULT_COLUMN_PREFIX + first_objective
    assert len(df[result_column]) == n_trials

    # Config columns are named by prefixing the tunable name.
    first_tunable, _group = next(iter(tunable_groups))
    config_column = ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name
    assert len(df[config_column]) == n_trials

94 

95 

def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None:
    """
    Tests the tunable_config_trial_group_id property of ExperimentData.results_df.

    The storage fixture is expected to run each tunable config
    ``CONFIG_TRIAL_REPEAT_COUNT`` times in consecutive trials. The first trials
    therefore share config 1 (group id 1), and the next trial starts config 2,
    whose group id is its own trial id.

    See Also: test_exp_trial_data_tunable_config_trial_group_id()
    """
    results_df = exp_data.results_df

    def _check_trial(trial_id: int, config_id: int, group_id: int) -> None:
        """Assert that exactly one results row exists for trial_id, with the given ids."""
        trial_df = results_df.loc[results_df["trial_id"] == trial_id]
        assert len(trial_df) == 1
        assert trial_df["tunable_config_id"].iloc[0] == config_id
        assert trial_df["tunable_config_trial_group_id"].iloc[0] == group_id

    # The first CONFIG_TRIAL_REPEAT_COUNT trials should all use the same config.
    for trial_id in range(1, CONFIG_TRIAL_REPEAT_COUNT + 1):
        _check_trial(trial_id, config_id=1, group_id=1)

    # The next trial should be the first to use a new config, so it starts a new
    # group keyed by its own trial id.
    if CONFIG_COUNT > 1:
        next_group_trial_id = CONFIG_TRIAL_REPEAT_COUNT + 1
        _check_trial(next_group_trial_id, config_id=2, group_id=next_group_trial_id)

    # And so on ...

122 

123 

def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None:
    """
    Tests the tunable_config_trial_groups property of ExperimentData.

    This tests bulk loading of the tunable_config_trial_groups.
    """
    expected_config_ids = list(range(1, CONFIG_COUNT + 1))
    # Each config's trials are expected to be consecutive, so a group's id (the
    # minimum trial_id in it) is 1 + (config_id - 1) * CONFIG_TRIAL_REPEAT_COUNT.
    # Built per config id so the list always has CONFIG_COUNT entries; the old
    # range(1, CONFIG_COUNT * R, R) form came up one short when R == 1.
    expected_group_ids = [
        1 + (config_id - 1) * CONFIG_TRIAL_REPEAT_COUNT for config_id in expected_config_ids
    ]

    # Should be keyed by config_id.
    assert list(exp_data.tunable_config_trial_groups.keys()) == expected_config_ids
    # Which should match the objects.
    assert [
        config_trial_group.tunable_config_id
        for config_trial_group in exp_data.tunable_config_trial_groups.values()
    ] == expected_config_ids
    # And the tunable_config_trial_group_id should also match the minimum trial_id.
    assert [
        config_trial_group.tunable_config_trial_group_id
        for config_trial_group in exp_data.tunable_config_trial_groups.values()
    ] == expected_group_ids

142 

143 

def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None:
    """
    Check that ExperimentData.tunable_configs is keyed by config id.

    The keys must be ``1..CONFIG_COUNT`` in order, and the ``tunable_config_id``
    of each value must match its position (and therefore its key).
    """
    expected_ids = list(range(1, CONFIG_COUNT + 1))

    assert list(exp_data.tunable_configs.keys()) == expected_ids

    ids_from_values = [cfg.tunable_config_id for cfg in exp_data.tunable_configs.values()]
    assert ids_from_values == expected_ids

152 

153 

def test_exp_data_default_config_id(exp_data: ExperimentData) -> None:
    """
    Check ExperimentData.default_tunable_config_id.

    For the fixture experiment data, config id 1 is expected to be reported as
    the default tunable config.
    """
    expected_default_id = 1
    assert exp_data.default_tunable_config_id == expected_default_id