Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for scheduling trials for some future time.""" 

6from datetime import datetime, timedelta 

7from typing import Iterator, Set 

8 

9from pytz import UTC 

10 

11from mlos_bench.environments.status import Status 

12from mlos_bench.storage.base_storage import Storage 

13from mlos_bench.tunables.tunable_groups import TunableGroups 

14 

15 

def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]:
    """Collect the ``trial_id`` of every trial yielded by ``trials`` into a set.

    Parameters
    ----------
    trials : Iterator[Storage.Trial]
        The trials to inspect; the iterator is consumed by this call.

    Returns
    -------
    Set[int]
        The distinct trial IDs seen (duplicates collapse, order is not kept).
    """
    # A set comprehension builds the set directly instead of feeding a
    # generator expression through the `set()` constructor.
    return {trial.trial_id for trial in trials}

19 

20 

def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Check that trials scheduled for different times show up as pending, and
    later as loadable results, at the expected moments.

    Four trials are created: two due immediately, one due in 1 hour and one due
    in 2 hours. The test then queries ``pending_trials`` at several cut-off
    timestamps, moves some trials through RUNNING to a final status, and checks
    what ``load`` returns along the way.
    """
    start = datetime.now(UTC)
    one_minute = timedelta(minutes=1)
    one_hour = timedelta(hours=1)
    trial_config = {"location": "westus2", "num_repeats": 10}

    def pending_at(cutoff: datetime, *, running: bool) -> Set[int]:
        """Return the IDs of the trials reported as pending at ``cutoff``."""
        return _trial_ids(exp_storage.pending_trials(cutoff, running=running))

    # --- Create the trials -------------------------------------------------
    # No timestamp given: rely on the storage default.
    trial_now1 = exp_storage.new_trial(tunable_groups, config=trial_config)
    # Explicitly scheduled at the test's start time.
    trial_now2 = exp_storage.new_trial(tunable_groups, start, trial_config)
    # Scheduled 1 hour after the start.
    trial_1h = exp_storage.new_trial(tunable_groups, start + one_hour, trial_config)
    # Scheduled 2 hours after the start.
    trial_2h = exp_storage.new_trial(tunable_groups, start + 2 * one_hour, trial_config)

    # Cut-off timestamps reused by the queries and status updates below.
    after_1min = start + one_minute
    after_2min = start + 2 * one_minute
    after_1h = start + one_hour
    after_2h = start + 2 * one_hour
    after_3h = start + 3 * one_hour

    # --- Scheduler side: which trials are ready at a given time? -----------

    # One minute in: only the two immediate trials are due.
    assert pending_at(after_1min, running=False) == {
        trial_now1.trial_id,
        trial_now2.trial_id,
    }

    # One hour in: the 1-hour trial is due as well.
    assert pending_at(after_1h, running=False) == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
    }

    # Three hours in: every trial is due.
    assert pending_at(after_3h, running=False) == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # --- Optimizer side: which trials have finished? -----------------------

    # Nothing has completed yet, so all four returned lists are empty.
    no_results = ([], [], [], [])
    assert exp_storage.load() == no_results

    # Start the two immediate trials.
    for started in (trial_now1, trial_now2):
        started.update(Status.RUNNING, after_1min)

    # Running trials do not count as completed.
    assert exp_storage.load() == no_results

    # With running=False, the trials that are now running are left out ...
    assert pending_at(after_3h, running=False) == {
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # ... while running=True brings them back alongside the scheduled ones.
    assert pending_at(after_3h, running=True) == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Two minutes in, the immediate trials finish: one succeeds, one fails.
    trial_now1.update(Status.SUCCEEDED, after_2min, metrics={"score": 1.0})
    trial_now2.update(Status.FAILED, after_2min)

    # Two hours in, the 1-hour trial succeeds too.
    trial_1h.update(Status.SUCCEEDED, after_2h, metrics={"score": 1.0})

    # Three trials are complete at this point.
    ids, configs, scores, statuses = exp_storage.load()
    assert ids == [trial_now1.trial_id, trial_now2.trial_id, trial_1h.trial_id]
    assert len(configs) == 3
    assert len(scores) == 3
    assert statuses == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Restricting the load to trials after trial_now2 leaves only trial_1h.
    ids, configs, scores, statuses = exp_storage.load(last_trial_id=trial_now2.trial_id)
    assert ids == [trial_1h.trial_id]
    assert len(configs) == 1
    assert len(scores) == 1
    assert statuses == [Status.SUCCEEDED]