Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for scheduling trials for some future time."""
6from datetime import datetime, timedelta
7from typing import Iterator, Set
9from pytz import UTC
11from mlos_bench.environments.status import Status
12from mlos_bench.storage.base_storage import Storage
13from mlos_bench.tunables.tunable_groups import TunableGroups
def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]:
    """
    Collect the IDs of the trials produced by an iterator.

    Parameters
    ----------
    trials : Iterator[Storage.Trial]
        Trials to read the `trial_id` from; the iterator is consumed entirely.

    Returns
    -------
    Set[int]
        The distinct `trial_id` values of the given trials.
    """
    # A set comprehension builds the set directly, without an intermediate generator.
    return {trial.trial_id for trial in trials}
def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Create trials with different start times and check which of them the storage
    reports as pending or completed as the query time and the trial statuses change.
    """
    start = datetime.now(UTC)
    one_minute = timedelta(minutes=1)
    one_hour = timedelta(hours=1)

    # Points in time (relative to `start`) used for scheduling, updates, and queries.
    after_1min = start + one_minute
    after_2min = start + one_minute * 2
    after_1hr = start + one_hour
    after_2hr = start + one_hour * 2
    after_3hr = start + one_hour * 3

    trial_config = {"location": "westus2", "num_repeats": 10}

    def pending_at(when: datetime, *, running: bool) -> Set[int]:
        """Return the IDs of the trials the storage lists as pending at `when`."""
        return _trial_ids(exp_storage.pending_trials(when, running=running))

    # No timestamp given: the trial is scheduled with the storage's default.
    trial_now1 = exp_storage.new_trial(tunable_groups, config=trial_config)
    # Same, but the current timestamp is passed explicitly.
    trial_now2 = exp_storage.new_trial(tunable_groups, start, trial_config)
    # One hour ahead.
    trial_1h = exp_storage.new_trial(tunable_groups, after_1hr, trial_config)
    # Two hours ahead.
    trial_2h = exp_storage.new_trial(tunable_groups, after_2hr, trial_config)

    immediate_ids = {trial_now1.trial_id, trial_now2.trial_id}
    future_ids = {trial_1h.trial_id, trial_2h.trial_id}

    # --- Scheduler side: which trials are ready to run at a given moment? ---

    # One minute in: only the two immediate trials are due.
    assert pending_at(after_1min, running=False) == immediate_ids

    # One hour in: the 1-hour trial becomes due as well.
    assert pending_at(after_1hr, running=False) == immediate_ids | {trial_1h.trial_id}

    # Three hours in: all four trials are due.
    assert pending_at(after_3hr, running=False) == immediate_ids | future_ids

    # --- Optimizer side: which trials have completed after a known trial? ---

    # Nothing has completed so far.
    assert exp_storage.load() == ([], [], [], [])

    # Start the two immediate trials.
    for started_trial in (trial_now1, trial_now2):
        started_trial.update(Status.RUNNING, after_1min)

    # Running trials do not count as completed.
    assert exp_storage.load() == ([], [], [], [])

    # With running=False the trials that are already running are left out ...
    assert pending_at(after_3hr, running=False) == future_ids

    # ... and with running=True they are reported together with the scheduled ones.
    assert pending_at(after_3hr, running=True) == immediate_ids | future_ids

    # Two minutes in: one immediate trial succeeds and the other one fails.
    trial_now1.update(Status.SUCCEEDED, after_2min, metrics={"score": 1.0})
    trial_now2.update(Status.FAILED, after_2min)

    # Two hours in: the 1-hour trial succeeds.
    trial_1h.update(Status.SUCCEEDED, after_2hr, metrics={"score": 1.0})

    # Three trials have completed by now.
    (loaded_ids, loaded_configs, loaded_scores, loaded_status) = exp_storage.load()
    assert loaded_ids == [trial_now1.trial_id, trial_now2.trial_id, trial_1h.trial_id]
    assert len(loaded_configs) == 3
    assert len(loaded_scores) == 3
    assert loaded_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Restrict the query to trials that come after trial_now2.
    (loaded_ids, loaded_configs, loaded_scores, loaded_status) = exp_storage.load(
        last_trial_id=trial_now2.trial_id
    )
    assert loaded_ids == [trial_1h.trial_id]
    assert len(loaded_configs) == 1
    assert len(loaded_scores) == 1
    assert loaded_status == [Status.SUCCEEDED]