Coverage for mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py: 100%
47 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Unit tests for loading the TunableConfigTrialGroupData.
7"""
9from mlos_bench.tunables.tunable_groups import TunableGroups
10from mlos_bench.storage.base_experiment_data import ExperimentData
12from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT
def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None:
    """Test basic TunableConfigTrialGroupData properties."""
    # Use trial 1 as the reference trial and look up the group it belongs to.
    reference_trial = exp_data.trials[1]
    group = reference_trial.tunable_config_trial_group

    # The group, the experiment and the trial must all agree on the experiment id.
    assert group.experiment_id == exp_data.experiment_id == reference_trial.experiment_id
    # The group must describe the same tunable config as the trial it came from.
    assert group.tunable_config_id == reference_trial.tunable_config_id
    assert group.tunable_config == reference_trial.tunable_config

    # Round trip: the first member trial reported by the group maps back to an equal group.
    first_member = next(iter(group.trials.values()))
    assert group == first_member.tunable_config_trial_group
def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None:
    """
    Test the TunableConfigTrialGroupData property of TrialData.

    See Also:
    - test_exp_data_tunable_config_trial_group_id_in_results_df()
    - test_exp_data_tunable_config_trial_groups()

    This tests individual fetching.
    """
    # First three trials should use the same config, and therefore the same trial group.
    # Check every one of them (previously trial 3 was skipped despite this comment).
    # Note: the expected group ids used in this test (1 and 4) equal the id of the first
    # trial that uses each config.
    for trial_id in (1, 2, 3):
        trial = exp_data.trials[trial_id]
        assert trial.tunable_config_id == 1, f"trial {trial_id}"
        assert trial.tunable_config_trial_group.tunable_config_trial_group_id == 1, \
            f"trial {trial_id}"

    # The fourth should be a new config, and so belong to a new trial group.
    trial_4 = exp_data.trials[4]
    assert trial_4.tunable_config_id == 2
    assert trial_4.tunable_config_trial_group.tunable_config_trial_group_id == 4

    # And so on ...
def test_tunable_config_trial_group_results_df(
    exp_data: ExperimentData,
    tunable_groups: TunableGroups,
) -> None:
    """Tests the results_df property of the TunableConfigTrialGroup."""
    config_id = 2
    group_id = 4
    group = exp_data.tunable_config_trial_groups[config_id]
    df = group.results_df
    n_rows = CONFIG_TRIAL_REPEAT_COUNT

    # Exactly one row per repeated trial of this config, and no rows from other configs.
    assert len(df) == n_rows

    in_config = df["tunable_config_id"] == config_id
    assert len(df[in_config]) == n_rows
    outside_config = df["tunable_config_id"] != config_id
    assert len(df[outside_config]) == 0

    in_group = df["tunable_config_trial_group_id"] == group_id
    assert len(df[in_group]) == n_rows
    outside_group = df["tunable_config_trial_group_id"] != group_id
    assert len(df[outside_group]) == 0

    # Every row comes from a distinct trial.
    assert len(df["trial_id"].unique()) == n_rows

    # The prefixed result column for the first objective must be present.
    first_objective = next(iter(exp_data.objectives))
    assert len(df[ExperimentData.RESULT_COLUMN_PREFIX + first_objective]) == n_rows

    # The prefixed config column for the first tunable must be present.
    (first_tunable, _covariant_group) = next(iter(tunable_groups))
    assert len(df[ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name]) == n_rows
def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None:
    """
    Tests the trials property of the TunableConfigTrialGroup.

    The ``trials`` mapping is read once and reused for every check, rather than
    re-reading the ``trials`` property for each assertion.
    """
    tunable_config_id = 2
    expected_group_id = 4
    tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id]
    trials = tunable_config_trial_group.trials
    assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT

    # Check each member individually so a failure names the offending trial
    # (``assert all(...)`` over a generator would not).
    for trial_id, trial in trials.items():
        assert trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id, \
            f"trial {trial_id}"
        assert trial.tunable_config_id == tunable_config_id, f"trial {trial_id}"

    # The group id is also a key in the group's trial mapping; that trial, fetched via the
    # group, must equal the same trial fetched directly from the experiment.
    assert exp_data.trials[expected_group_id] == trials[expected_group_id]