Coverage for mlos_bench/mlos_bench/tests/storage/exp_data_test.py: 100%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for loading the experiment metadata."""
7from mlos_bench.storage.base_experiment_data import ExperimentData
8from mlos_bench.storage.base_storage import Storage
9from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT
10from mlos_bench.tunables.tunable_groups import TunableGroups
def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) -> None:
    """Fetch the experiment back from storage by its ID and compare its metadata."""
    loaded_exp = storage.experiments[exp_storage.experiment_id]

    # The loaded record must mirror the experiment object it was looked up from.
    assert loaded_exp.experiment_id == exp_storage.experiment_id
    assert loaded_exp.description == exp_storage.description
    assert loaded_exp.objectives == exp_storage.opt_targets
def test_exp_data_root_env_config(
    exp_storage: Storage.Experiment,
    exp_data: ExperimentData,
) -> None:
    """Check ExperimentData.root_env_config against the experiment's private fields."""
    # pylint: disable=protected-access
    actual = exp_data.root_env_config
    # The property is compared to a (root env config, git repo, git commit) triple.
    expected = (
        exp_storage._root_env_config,
        exp_storage._git_repo,
        exp_storage._git_commit,
    )
    assert actual == expected
def test_exp_trial_data_objectives(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    """Create trials with and without `opt_direction` and verify what is reported back."""

    # A trial whose config carries both the target and the direction.
    new_style_trial = exp_storage.new_trial(
        tunable_groups,
        config={
            "opt_target": "some-other-target",
            "opt_direction": "max",
        },
    )
    new_style_config = new_style_trial.config()
    assert new_style_config == {
        "experiment_id": exp_storage.experiment_id,
        "trial_id": new_style_trial.trial_id,
        "opt_target": "some-other-target",
        "opt_direction": "max",
    }

    # A trial whose config omits "opt_direction" (the back-compat case).
    old_style_trial = exp_storage.new_trial(
        tunable_groups,
        config={
            "opt_target": "back-compat",
            # "opt_direction": "max",  # missing
        },
    )
    old_style_config = old_style_trial.config()
    assert old_style_config == {
        "experiment_id": exp_storage.experiment_id,
        "trial_id": old_style_trial.trial_id,
        "opt_target": "back-compat",
    }

    # Creating trials must leave the experiment-level objectives as they were.
    loaded_exp = storage.experiments[exp_storage.experiment_id]
    assert loaded_exp.objectives == exp_storage.opt_targets

    # The metadata read back for the first trial holds only the keys passed via `config`.
    new_style_trial_data = loaded_exp.trials[new_style_trial.trial_id]
    assert new_style_trial_data.metadata_dict == {
        "opt_target": "some-other-target",
        "opt_direction": "max",
    }
def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None:
    """Check the row count and key columns of ExperimentData.results_df."""
    df = exp_data.results_df
    total_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT

    # One row per trial, each trial id distinct, and CONFIG_COUNT distinct configs.
    assert len(df) == total_trials
    assert len(df["tunable_config_id"].unique()) == CONFIG_COUNT
    assert len(df["trial_id"].unique()) == total_trials

    # A result column must exist for the first objective.
    first_objective = next(iter(exp_data.objectives))
    result_column = ExperimentData.RESULT_COLUMN_PREFIX + first_objective
    assert len(df[result_column]) == total_trials

    # A config column must exist for the first tunable.
    (first_tunable, _covariant_group) = next(iter(tunable_groups))
    config_column = ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name
    assert len(df[config_column]) == total_trials
def test_exp_no_tunables_data_results_df(exp_no_tunables_data: ExperimentData) -> None:
    """Check ExperimentData.results_df for an experiment that has no tunables."""
    df = exp_no_tunables_data.results_df
    total_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT

    # One row per trial, each with a distinct trial id.
    assert len(df) == total_trials
    assert len(df["trial_id"].unique()) == total_trials

    # A result column must exist for the first objective.
    first_objective = next(iter(exp_no_tunables_data.objectives))
    result_column = ExperimentData.RESULT_COLUMN_PREFIX + first_objective
    assert len(df[result_column]) == total_trials

    # With no tunables, no column may carry the config prefix.
    has_config_prefix = df.columns.str.startswith(ExperimentData.CONFIG_COLUMN_PREFIX)
    assert not has_config_prefix.any()
def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None:
    """
    Tests the tunable_config_trial_group_id property of ExperimentData.results_df.

    See Also: test_exp_trial_data_tunable_config_trial_group_id()
    """
    df = exp_data.results_df

    # (trial_id, expected tunable_config_id, expected tunable_config_trial_group_id)
    expectations = (
        # Trials 1 and 2 are expected to share the first config (and its group).
        (1, 1, 1),
        (2, 1, 1),
        # The fourth trial is expected to start a new config and a new group.
        (4, 2, 4),
        # And so on ...
    )
    for trial_id, config_id, trial_group_id in expectations:
        trial_rows = df.loc[(df["trial_id"] == trial_id)]
        assert len(trial_rows) == 1
        assert trial_rows["tunable_config_id"].iloc[0] == config_id
        assert trial_rows["tunable_config_trial_group_id"].iloc[0] == trial_group_id
def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None:
    """
    Tests the tunable_config_trial_groups property of ExperimentData.

    This tests bulk loading of the tunable_config_trial_groups.
    """
    expected_config_ids = list(range(1, CONFIG_COUNT + 1))
    # Should be keyed by config_id.
    assert list(exp_data.tunable_config_trial_groups.keys()) == expected_config_ids
    # Which should match the objects.
    assert [
        config_trial_group.tunable_config_id
        for config_trial_group in exp_data.tunable_config_trial_groups.values()
    ] == expected_config_ids
    # And the tunable_config_trial_group_id should also match the minimum trial_id,
    # i.e. 1, 1 + REPEAT, 1 + 2 * REPEAT, ... (one entry per config).
    # NOTE: range() excludes its stop value, so the stop needs the "+ 1"; without it
    # the last group id is dropped when CONFIG_TRIAL_REPEAT_COUNT == 1.
    expected_trial_group_ids = list(
        range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT + 1, CONFIG_TRIAL_REPEAT_COUNT)
    )
    assert [
        config_trial_group.tunable_config_trial_group_id
        for config_trial_group in exp_data.tunable_config_trial_groups.values()
    ] == expected_trial_group_ids
def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None:
    """Check the keys and objects of ExperimentData.tunable_configs."""
    expected_config_ids = list(range(1, CONFIG_COUNT + 1))

    # The mapping should be keyed by config_id ...
    config_keys = list(exp_data.tunable_configs.keys())
    assert config_keys == expected_config_ids

    # ... and each value should report the same id.
    object_ids = [cfg.tunable_config_id for cfg in exp_data.tunable_configs.values()]
    assert object_ids == expected_config_ids
def test_exp_data_default_config_id(exp_data: ExperimentData) -> None:
    """Check that ExperimentData.default_tunable_config_id equals 1."""
    default_config_id = exp_data.default_tunable_config_id
    assert default_config_id == 1