Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%
71 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for scheduling trials for some future time."""
6from collections.abc import Iterator
7from datetime import datetime, timedelta
9from pytz import UTC
11from mlos_bench.environments.status import Status
12from mlos_bench.storage.base_experiment_data import ExperimentData
13from mlos_bench.storage.base_storage import Storage
14from mlos_bench.tests.storage import (
15 CONFIG_COUNT,
16 CONFIG_TRIAL_REPEAT_COUNT,
17 TRIAL_RUNNER_COUNT,
18)
19from mlos_bench.tunables.tunable_groups import TunableGroups
def _trial_ids(trials: Iterator[Storage.Trial]) -> set[int]:
    """Collect the ``trial_id`` of every trial yielded by *trials*.

    The iterator is consumed. The result is a set, so iteration order and any
    duplicate ids are discarded.
    """
    ids: set[int] = set()
    for trial in trials:
        ids.add(trial.trial_id)
    return ids
def test_schedule_trial(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    # pylint: disable=too-many-locals,too-many-statements
    """Schedule several trials for future execution and retrieve them later at certain
    timestamps.

    Exercises:

    * ``pending_trials`` filtering by scheduled timestamp, by the ``running`` flag,
      and by the ``trial_runner_assigned`` flag;
    * that an assigned TrialRunner id and the trial status read back from
      ``storage.experiments`` match the in-memory trial objects;
    * ``load`` returning only completed trials, optionally only those after a
      given ``last_trial_id``.
    """
    timestamp = datetime.now(UTC)
    timedelta_1min = timedelta(minutes=1)
    timedelta_1hr = timedelta(hours=1)
    # Points in time (relative to ``timestamp``) that are reused throughout the test.
    in_1min = timestamp + timedelta_1min
    in_2min = timestamp + timedelta_1min * 2
    in_1hr = timestamp + timedelta_1hr
    in_2hr = timestamp + timedelta_1hr * 2
    in_3hr = timestamp + timedelta_1hr * 3
    config = {"location": "westus2", "num_repeats": 10}

    # Default, schedule now:
    trial_now1 = exp_storage.new_trial(tunable_groups, config=config)
    # Schedule with explicit current timestamp:
    trial_now2 = exp_storage.new_trial(tunable_groups, timestamp, config)
    # Schedule 1 hour in the future:
    trial_1h = exp_storage.new_trial(tunable_groups, in_1hr, config)
    # Schedule 2 hours in the future:
    trial_2h = exp_storage.new_trial(tunable_groups, in_2hr, config)

    # Check that if we assign a TrialRunner, that value is still available on restore.
    trial_now2.set_trial_runner(1)
    # Pin the exact id: a bare truthiness check would accept any non-zero runner id.
    assert trial_now2.trial_runner_id == 1

    exp_data = storage.experiments[exp_storage.experiment_id]
    trial_now1_data = exp_data.trials[trial_now1.trial_id]
    assert trial_now1_data.trial_runner_id is None
    assert trial_now1_data.status == Status.PENDING
    # Check that Status matches in object vs. backend storage.
    assert trial_now1.status == trial_now1_data.status

    trial_now2_data = exp_data.trials[trial_now2.trial_id]
    assert trial_now2_data.trial_runner_id == trial_now2.trial_runner_id

    # --- Test the trial_runner_assigned parameter ---
    # At this point:
    # - trial_now1: no trial_runner assigned
    # - trial_now2: trial_runner assigned
    # - trial_1h, trial_2h: no trial_runner assigned

    # All pending trials (should include all 4)
    all_pending = _trial_ids(
        exp_storage.pending_trials(
            in_3hr,
            running=False,
            trial_runner_assigned=None,
        )
    )
    assert all_pending == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }, f"Expected all pending trials, got {all_pending}"

    # Only those with a trial_runner assigned
    assigned_pending = _trial_ids(
        exp_storage.pending_trials(
            in_3hr,
            running=False,
            trial_runner_assigned=True,
        )
    )
    assert assigned_pending == {
        trial_now2.trial_id
    }, f"Expected only trials with a runner assigned, got {assigned_pending}"

    # Only those without a trial_runner assigned
    unassigned_pending = _trial_ids(
        exp_storage.pending_trials(
            in_3hr,
            running=False,
            trial_runner_assigned=False,
        )
    )
    assert unassigned_pending == {
        trial_now1.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }, f"Expected only trials without a runner assigned, got {unassigned_pending}"

    # Scheduler side: get trials ready to run at certain timestamps:

    # Pretend 1 minute has passed, get trials scheduled to run:
    pending_ids = _trial_ids(exp_storage.pending_trials(in_1min, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
    }

    # Make sure that the pending trials and trial_runner_ids match.
    pending_trial_runner_ids = {
        pending_trial.trial_id: pending_trial.trial_runner_id
        for pending_trial in exp_storage.pending_trials(in_1min, running=False)
    }
    assert pending_trial_runner_ids == {
        trial_now1.trial_id: trial_now1.trial_runner_id,
        trial_now2.trial_id: trial_now2.trial_runner_id,
    }

    # Get trials scheduled to run within the next 1 hour:
    pending_ids = _trial_ids(exp_storage.pending_trials(in_1hr, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
    }

    # Get trials scheduled to run within the next 3 hours:
    pending_ids = _trial_ids(exp_storage.pending_trials(in_3hr, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Optimizer side: get trials completed after some known trial:

    # No completed trials yet:
    assert exp_storage.load() == ([], [], [], [])

    # Update the status of some trials:
    trial_now1.update(Status.RUNNING, in_1min)
    trial_now2.update(Status.RUNNING, in_1min)

    # Still no completed trials:
    assert exp_storage.load() == ([], [], [], [])

    # Get trials scheduled to run within the next 3 hours (RUNNING ones are excluded):
    pending_ids = _trial_ids(exp_storage.pending_trials(in_3hr, running=False))
    assert pending_ids == {
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Get trials scheduled to run OR running within the next 3 hours:
    pending_ids = _trial_ids(exp_storage.pending_trials(in_3hr, running=True))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Mark some trials completed after 2 minutes:
    trial_now1.update(Status.SUCCEEDED, in_2min, metrics={"score": 1.0})
    trial_now2.update(Status.FAILED, in_2min)

    # Another one completes after 2 hours:
    trial_1h.update(Status.SUCCEEDED, in_2hr, metrics={"score": 1.0})

    # Check that three trials have completed so far:
    (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load()
    assert trial_ids == [trial_now1.trial_id, trial_now2.trial_id, trial_1h.trial_id]
    assert len(trial_configs) == len(trial_scores) == 3
    assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Get only trials completed after trial_now2:
    (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(
        last_trial_id=trial_now2.trial_id
    )
    assert trial_ids == [trial_1h.trial_id]
    assert len(trial_configs) == len(trial_scores) == 1
    assert trial_status == [Status.SUCCEEDED]
def test_rr_scheduling(exp_data: ExperimentData) -> None:
    """Verify that trials were distributed across TrialRunners in round-robin order.

    Trials are expected to be laid out config-major: the first
    ``CONFIG_TRIAL_REPEAT_COUNT`` trials all use config 1, the next batch uses
    config 2, and so on. Runner ids cycle through 1..``TRIAL_RUNNER_COUNT``
    trial by trial. All user-visible ids (trial, config, repeat, runner) are 1-based.
    """
    total_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT
    for offset in range(total_trials):
        # ``offset`` is 0-based; user-visible ids start from 1.
        trial_id = offset + 1
        config_offset, repeat_offset = divmod(offset, CONFIG_TRIAL_REPEAT_COUNT)
        expected_config_id = config_offset + 1
        expected_repeat_num = repeat_offset + 1
        expected_runner_id = offset % TRIAL_RUNNER_COUNT + 1

        trial = exp_data.trials[trial_id]
        assert trial.trial_id == trial_id, f"Expected trial_id {trial_id} for {trial}"
        assert trial.tunable_config_id == expected_config_id, (
            f"Expected tunable_config_id {expected_config_id} for {trial}"
        )
        assert trial.metadata_dict["repeat_i"] == expected_repeat_num, (
            f"Expected repeat_i {expected_repeat_num} for {trial}"
        )
        assert trial.trial_runner_id == expected_runner_id, (
            f"Expected trial_runner_id {expected_runner_id} for {trial}"
        )