Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for scheduling trials for some future time.""" 

6from collections.abc import Iterator 

7from datetime import datetime, timedelta 

8 

9from pytz import UTC 

10 

11from mlos_bench.environments.status import Status 

12from mlos_bench.storage.base_experiment_data import ExperimentData 

13from mlos_bench.storage.base_storage import Storage 

14from mlos_bench.tests.storage import ( 

15 CONFIG_COUNT, 

16 CONFIG_TRIAL_REPEAT_COUNT, 

17 TRIAL_RUNNER_COUNT, 

18) 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20 

21 

def _trial_ids(trials: Iterator[Storage.Trial]) -> set[int]:
    """Collect the ``trial_id`` of every trial yielded by *trials* into a set.

    Order and duplicates are discarded, so callers can compare the result
    directly against an expected set of IDs.
    """
    ids: set[int] = set()
    for trial in trials:
        ids.add(trial.trial_id)
    return ids

25 

26 

def test_schedule_trial(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    # pylint: disable=too-many-locals,too-many-statements
    """Schedule several trials for future execution and retrieve them later at certain
    timestamps.

    Exercises three views of the same experiment:

    * the storage round-trip of a trial's status and TrialRunner assignment,
    * the scheduler side: ``pending_trials`` filtered by time horizon, by whether
      running trials are included, and by whether a TrialRunner is assigned,
    * the optimizer side: ``load`` returning only completed trials, optionally
      only those after a given trial id.
    """
    timestamp = datetime.now(UTC)
    timedelta_1min = timedelta(minutes=1)
    timedelta_1hr = timedelta(hours=1)
    # A horizon late enough to include every trial scheduled below.
    horizon_3hr = timestamp + timedelta_1hr * 3
    config = {"location": "westus2", "num_repeats": 10}

    # Default, schedule now:
    trial_now1 = exp_storage.new_trial(tunable_groups, config=config)
    # Schedule with explicit current timestamp:
    trial_now2 = exp_storage.new_trial(tunable_groups, timestamp, config)
    # Schedule 1 hour in the future:
    trial_1h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr, config)
    # Schedule 2 hours in the future:
    trial_2h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr * 2, config)

    # Check that a TrialRunner assigned to a trial is still available on restore.
    trial_runner_id = 1
    trial_now2.set_trial_runner(trial_runner_id)
    # Compare against the exact id: a bare truthiness check would also accept
    # a wrong (but non-zero) runner id.
    assert trial_now2.trial_runner_id == trial_runner_id

    exp_data = storage.experiments[exp_storage.experiment_id]
    trial_now1_data = exp_data.trials[trial_now1.trial_id]
    assert trial_now1_data.trial_runner_id is None
    assert trial_now1_data.status == Status.PENDING
    # Check that Status matches in object vs. backend storage.
    assert trial_now1.status == trial_now1_data.status

    trial_now2_data = exp_data.trials[trial_now2.trial_id]
    assert trial_now2_data.trial_runner_id == trial_now2.trial_runner_id

    # --- Test the trial_runner_assigned parameter ---
    # At this point:
    # - trial_now1: no trial_runner assigned
    # - trial_now2: trial_runner assigned
    # - trial_1h, trial_2h: no trial_runner assigned

    # All pending trials (should include all 4)
    all_pending = _trial_ids(
        exp_storage.pending_trials(
            horizon_3hr,
            running=False,
            trial_runner_assigned=None,
        )
    )
    assert all_pending == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }, f"Expected all pending trials, got {all_pending}"

    # Only those with a trial_runner assigned
    assigned_pending = _trial_ids(
        exp_storage.pending_trials(
            horizon_3hr,
            running=False,
            trial_runner_assigned=True,
        )
    )
    assert assigned_pending == {
        trial_now2.trial_id
    }, f"Expected only trials with a runner assigned, got {assigned_pending}"

    # Only those without a trial_runner assigned
    unassigned_pending = _trial_ids(
        exp_storage.pending_trials(
            horizon_3hr,
            running=False,
            trial_runner_assigned=False,
        )
    )
    assert unassigned_pending == {
        trial_now1.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }, f"Expected only trials without a runner assigned, got {unassigned_pending}"

    # Scheduler side: get trials ready to run at certain timestamps:

    # Pretend 1 minute has passed, get trials scheduled to run:
    pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
    }

    # Make sure that the pending trials and trial_runner_ids match.
    pending_trial_runner_ids = {
        pending_trial.trial_id: pending_trial.trial_runner_id
        for pending_trial in exp_storage.pending_trials(timestamp + timedelta_1min, running=False)
    }
    assert pending_trial_runner_ids == {
        trial_now1.trial_id: trial_now1.trial_runner_id,
        trial_now2.trial_id: trial_now2.trial_runner_id,
    }

    # Get trials scheduled to run within the next 1 hour:
    pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
    }

    # Get trials scheduled to run within the next 3 hours:
    pending_ids = _trial_ids(exp_storage.pending_trials(horizon_3hr, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Optimizer side: get trials completed after some known trial:

    # No completed trials yet:
    assert exp_storage.load() == ([], [], [], [])

    # Update the status of some trials:
    trial_now1.update(Status.RUNNING, timestamp + timedelta_1min)
    trial_now2.update(Status.RUNNING, timestamp + timedelta_1min)

    # Still no completed trials:
    assert exp_storage.load() == ([], [], [], [])

    # Get trials scheduled to run within the next 3 hours; with running=False
    # the two trials just marked RUNNING are expected to drop out:
    pending_ids = _trial_ids(exp_storage.pending_trials(horizon_3hr, running=False))
    assert pending_ids == {
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Get trials scheduled to run OR running within the next 3 hours:
    pending_ids = _trial_ids(exp_storage.pending_trials(horizon_3hr, running=True))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Mark some trials completed after 2 minutes:
    trial_now1.update(Status.SUCCEEDED, timestamp + timedelta_1min * 2, metrics={"score": 1.0})
    trial_now2.update(Status.FAILED, timestamp + timedelta_1min * 2)

    # Another one completes after 2 hours:
    trial_1h.update(Status.SUCCEEDED, timestamp + timedelta_1hr * 2, metrics={"score": 1.0})

    # Check that three trials have completed so far:
    (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load()
    assert trial_ids == [trial_now1.trial_id, trial_now2.trial_id, trial_1h.trial_id]
    assert len(trial_configs) == len(trial_scores) == 3
    assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Get only trials completed after trial_now2:
    (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(
        last_trial_id=trial_now2.trial_id
    )
    assert trial_ids == [trial_1h.trial_id]
    assert len(trial_configs) == len(trial_scores) == 1
    assert trial_status == [Status.SUCCEEDED]

201 

202 

def test_rr_scheduling(exp_data: ExperimentData) -> None:
    """Verify that the scheduler spread Trials across TrialRunners round-robin.

    Consecutive trials are expected to share a tunable config, with each config
    repeated ``CONFIG_TRIAL_REPEAT_COUNT`` times in order. The TrialRunner
    assignment is expected to cycle through all ``TRIAL_RUNNER_COUNT`` runners.
    """
    total_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT
    for offset in range(total_trials):
        # User-visible IDs are 1-based; `offset` is the 0-based position.
        trial_id = offset + 1
        config_offset, repeat_offset = divmod(offset, CONFIG_TRIAL_REPEAT_COUNT)
        expected_config_id = config_offset + 1
        expected_repeat_num = repeat_offset + 1
        expected_runner_id = offset % TRIAL_RUNNER_COUNT + 1

        trial = exp_data.trials[trial_id]
        assert trial.trial_id == trial_id, f"Expected trial_id {trial_id} for {trial}"

        config_msg = f"Expected tunable_config_id {expected_config_id} for {trial}"
        assert trial.tunable_config_id == expected_config_id, config_msg

        repeat_msg = f"Expected repeat_i {expected_repeat_num} for {trial}"
        assert trial.metadata_dict["repeat_i"] == expected_repeat_num, repeat_msg

        runner_msg = f"Expected trial_runner_id {expected_runner_id} for {trial}"
        assert trial.trial_runner_id == expected_runner_id, runner_msg