Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%
43 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Unit tests for scheduling trials for some future time.
7"""
8from datetime import datetime, timedelta
10from typing import Iterator, Set
12from pytz import UTC
14from mlos_bench.environments.status import Status
15from mlos_bench.storage.base_storage import Storage
16from mlos_bench.tunables.tunable_groups import TunableGroups
def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]:
    """
    Collect the IDs of the given trials into a set.

    Parameters
    ----------
    trials : Iterator[Storage.Trial]
        Trials to read the `trial_id` attribute from.
        The iterator is fully consumed by this call.

    Returns
    -------
    trial_ids : Set[int]
        The distinct trial IDs; empty if `trials` yields nothing.
    """
    # A set comprehension builds the set directly, with no intermediate generator.
    return {trial.trial_id for trial in trials}
def test_schedule_trial(exp_storage: Storage.Experiment,
                        tunable_groups: TunableGroups) -> None:
    """
    Create trials scheduled for now and for later, then check which of them
    the storage reports as pending or completed at various points in time.
    """
    now = datetime.now(UTC)
    one_minute = timedelta(minutes=1)
    one_hour = timedelta(hours=1)
    trial_config = {"location": "westus2", "num_repeats": 10}

    def pending_at(when: datetime, include_running: bool) -> Set[int]:
        """
        IDs of the trials that the storage returns as pending at time `when`.
        """
        return _trial_ids(exp_storage.pending_trials(when, running=include_running))

    # Two trials for "now": one relying on the default timestamp,
    # the other passing the current timestamp explicitly.
    first_now = exp_storage.new_trial(tunable_groups, config=trial_config)
    second_now = exp_storage.new_trial(tunable_groups, now, trial_config)
    # Two trials for later: one hour and two hours from now.
    in_one_hour = exp_storage.new_trial(tunable_groups, now + one_hour, trial_config)
    in_two_hours = exp_storage.new_trial(tunable_groups, now + 2 * one_hour, trial_config)

    id_now1 = first_now.trial_id
    id_now2 = second_now.trial_id
    id_1h = in_one_hour.trial_id
    id_2h = in_two_hours.trial_id

    # --- Scheduler's view: which trials are ready to run at a given time? ---

    # One minute in: only the two immediate trials.
    assert pending_at(now + one_minute, False) == {id_now1, id_now2}

    # One hour in: the 1-hour trial joins them.
    assert pending_at(now + one_hour, False) == {id_now1, id_now2, id_1h}

    # Three hours in: everything scheduled so far.
    assert pending_at(now + 3 * one_hour, False) == {id_now1, id_now2, id_1h, id_2h}

    # --- Optimizer's view: which trials have completed? ---

    # Nothing has finished yet.
    assert exp_storage.load() == ([], [], [], [])

    # Start the two immediate trials.
    first_now.update(Status.RUNNING, now + one_minute)
    second_now.update(Status.RUNNING, now + one_minute)

    # Running is not completed, so there is still nothing to load.
    assert exp_storage.load() == ([], [], [], [])

    # Excluding running trials, only the two future ones remain pending.
    assert pending_at(now + 3 * one_hour, False) == {id_1h, id_2h}

    # Including running trials brings the two started ones back.
    assert pending_at(now + 3 * one_hour, True) == {id_now1, id_now2, id_1h, id_2h}

    # Two minutes in: one immediate trial succeeds, the other fails.
    first_now.update(Status.SUCCEEDED, now + 2 * one_minute, metrics={"score": 1.0})
    second_now.update(Status.FAILED, now + 2 * one_minute)

    # Two hours in: the 1-hour trial succeeds as well.
    in_one_hour.update(Status.SUCCEEDED, now + 2 * one_hour, metrics={"score": 1.0})

    # All three finished trials are now loadable, in this order.
    (ids, configs, scores, statuses) = exp_storage.load()
    assert ids == [id_now1, id_now2, id_1h]
    assert len(configs) == len(scores) == 3
    assert statuses == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Restricting to trials after the second immediate one leaves only the 1-hour trial.
    (ids, configs, scores, statuses) = exp_storage.load(last_trial_id=id_now2)
    assert ids == [id_1h]
    assert len(configs) == len(scores) == 1
    assert statuses == [Status.SUCCEEDED]