Coverage for mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for loading the TunableConfigTrialGroupData. 

7""" 

8 

9from mlos_bench.tunables.tunable_groups import TunableGroups 

10from mlos_bench.storage.base_experiment_data import ExperimentData 

11 

12from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT 

13 

14 

def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None:
    """
    Check that a trial's TunableConfigTrialGroupData mirrors the trial itself.

    The group reached through trial 1 must report the same experiment id,
    tunable config id and tunable config as that trial. It must also compare
    equal to the group reached through the first of its own member trials.
    """
    first_trial = exp_data.trials[1]
    group = first_trial.tunable_config_trial_group

    # Identity fields shared by the experiment, the trial, and the trial's group.
    assert group.experiment_id == exp_data.experiment_id == first_trial.experiment_id
    assert group.tunable_config_id == first_trial.tunable_config_id
    assert group.tunable_config == first_trial.tunable_config

    # Round trip: group -> first member trial -> group must yield an equal object.
    member_trial = next(iter(group.trials.values()))
    assert group == member_trial.tunable_config_trial_group

24 

25 

def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None:
    """
    Test the TunableConfigTrialGroupData property of TrialData.

    See Also:
     - test_exp_data_tunable_config_trial_group_id_in_results_df()
     - test_exp_data_tunable_config_trial_groups()

    This tests individual fetching.
    """
    # The first three trials should all use the same config (tunable_config_id 1),
    # and therefore share the trial group with id 1.
    # NOTE(review): three presumably corresponds to CONFIG_TRIAL_REPEAT_COUNT in the
    # storage fixtures; confirm if that constant changes.
    for trial_id in (1, 2, 3):
        trial = exp_data.trials[trial_id]
        # The message names the trial, so a failure pinpoints which one diverged.
        assert trial.tunable_config_id == 1, f"trial {trial_id}"
        assert trial.tunable_config_trial_group.tunable_config_trial_group_id == 1, f"trial {trial_id}"

    # The fourth trial should start a new config (tunable_config_id 2), whose
    # trial group id is 4.
    trial_4 = exp_data.trials[4]
    assert trial_4.tunable_config_id == 2
    assert trial_4.tunable_config_trial_group.tunable_config_trial_group_id == 4

    # And so on ...

51 

52 

def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None:
    """
    Check the results_df property of a TunableConfigTrialGroup.

    The frame for tunable config 2 (trial group 4) must hold exactly
    CONFIG_TRIAL_REPEAT_COUNT rows, every one belonging to that config and that
    group, each with a distinct trial id. It must also expose a result column
    for the first objective and a config column for the first tunable.
    """
    config_id = 2
    group_id = 4
    n_rows = CONFIG_TRIAL_REPEAT_COUNT
    config_col = "tunable_config_id"
    group_col = "tunable_config_trial_group_id"

    df = exp_data.tunable_config_trial_groups[config_id].results_df

    # Only this config's rows belong in the frame, never rows from other configs.
    assert len(df) == n_rows
    assert len(df[df[config_col] == config_id]) == n_rows
    assert len(df[df[config_col] != config_id]) == 0
    assert len(df[df[group_col] == group_id]) == n_rows
    assert len(df[df[group_col] != group_id]) == 0

    # One row per trial: no trial id repeats.
    assert len(df["trial_id"].unique()) == n_rows

    # Spot-check that result and config columns exist and span the whole frame.
    objective_name = next(iter(exp_data.objectives))
    assert len(df[ExperimentData.RESULT_COLUMN_PREFIX + objective_name]) == n_rows
    first_tunable, _covariant_group = next(iter(tunable_groups))
    assert len(df[ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name]) == n_rows

71 

72 

def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None:
    """
    Tests the trials property of the TunableConfigTrialGroup.

    Fetches the group for tunable config 2 (expected group id 4) and checks that:
     - it holds CONFIG_TRIAL_REPEAT_COUNT trials,
     - every member trial points back to that group and that config, and
     - trial 4, fetched through the group, equals trial 4 fetched through the
       experiment.
    """
    tunable_config_id = 2
    expected_group_id = 4
    tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id]
    # Read the member trials once and reuse this mapping for every check below,
    # instead of re-reading the `trials` property for each assertion.
    # The indexing below treats it as keyed by trial id.
    trials = tunable_config_trial_group.trials
    assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT
    assert all(trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id
               for trial in trials.values())
    assert all(trial.tunable_config_id == tunable_config_id
               for trial in trials.values())
    # Check membership explicitly so a missing trial fails as an assertion,
    # not as a KeyError on the lookup that follows.
    assert expected_group_id in trials
    assert exp_data.trials[expected_group_id] == trials[expected_group_id]