Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%

43 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for scheduling trials for some future time. 

7""" 

8from datetime import datetime, timedelta 

9 

10from typing import Iterator, Set 

11 

12from pytz import UTC 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.storage.base_storage import Storage 

16from mlos_bench.tunables.tunable_groups import TunableGroups 

17 

18 

def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]:
    """
    Collect the IDs of the given trials into a set.

    Parameters
    ----------
    trials : Iterator[Storage.Trial]
        Trials to read the `trial_id` attribute from. The iterator is
        traversed once.

    Returns
    -------
    trial_ids : Set[int]
        The distinct `trial_id` values; an empty set if `trials` yields nothing.
    """
    # A set comprehension builds the set directly instead of feeding a
    # generator expression to `set()`.
    return {trial.trial_id for trial in trials}

24 

25 

def test_schedule_trial(exp_storage: Storage.Experiment,
                        tunable_groups: TunableGroups) -> None:
    """
    Create trials with different start times, then check which of them are
    reported as pending or loaded as completed while their statuses change.
    """
    start = datetime.now(UTC)
    one_minute = timedelta(minutes=1)
    one_hour = timedelta(hours=1)
    trial_config = {"location": "westus2", "num_repeats": 10}

    def pending_at(when: datetime, running: bool) -> Set[int]:
        """
        IDs of the trials returned by `pending_trials` for the given timestamp.
        """
        return _trial_ids(exp_storage.pending_trials(when, running=running))

    # Two trials due right away: one relying on the default timestamp,
    # one passing the current timestamp explicitly.
    immediate_default = exp_storage.new_trial(tunable_groups, config=trial_config)
    immediate_explicit = exp_storage.new_trial(tunable_groups, start, trial_config)
    # Two trials due later: one hour and two hours after `start`.
    in_one_hour = exp_storage.new_trial(tunable_groups, start + one_hour, trial_config)
    in_two_hours = exp_storage.new_trial(tunable_groups, start + one_hour * 2, trial_config)

    # --- Scheduler view: which trials are ready at a given timestamp ---

    # One minute after `start`: only the two immediate trials.
    assert pending_at(start + one_minute, running=False) == {
        immediate_default.trial_id,
        immediate_explicit.trial_id,
    }

    # One hour after `start`: the one-hour trial joins them.
    assert pending_at(start + one_hour, running=False) == {
        immediate_default.trial_id,
        immediate_explicit.trial_id,
        in_one_hour.trial_id,
    }

    # Three hours after `start`: all four trials.
    assert pending_at(start + one_hour * 3, running=False) == {
        immediate_default.trial_id,
        immediate_explicit.trial_id,
        in_one_hour.trial_id,
        in_two_hours.trial_id,
    }

    # --- Optimizer view: which trials have completed ---

    # Nothing has completed so far.
    assert exp_storage.load() == ([], [], [], [])

    # Move the two immediate trials to RUNNING.
    immediate_default.update(Status.RUNNING, start + one_minute)
    immediate_explicit.update(Status.RUNNING, start + one_minute)

    # RUNNING trials do not count as completed.
    assert exp_storage.load() == ([], [], [], [])

    # With running=False the two RUNNING trials are left out ...
    assert pending_at(start + one_hour * 3, running=False) == {
        in_one_hour.trial_id,
        in_two_hours.trial_id,
    }

    # ... and with running=True they are included again.
    assert pending_at(start + one_hour * 3, running=True) == {
        immediate_default.trial_id,
        immediate_explicit.trial_id,
        in_one_hour.trial_id,
        in_two_hours.trial_id,
    }

    # The immediate trials finish two minutes after `start`: one succeeds, one fails.
    immediate_default.update(Status.SUCCEEDED, start + one_minute * 2, metrics={"score": 1.0})
    immediate_explicit.update(Status.FAILED, start + one_minute * 2)

    # The one-hour trial finishes two hours after `start`.
    in_one_hour.update(Status.SUCCEEDED, start + one_hour * 2, metrics={"score": 1.0})

    # Three trials are now reported as completed.
    (done_ids, done_configs, done_scores, done_status) = exp_storage.load()
    assert done_ids == [
        immediate_default.trial_id,
        immediate_explicit.trial_id,
        in_one_hour.trial_id,
    ]
    assert len(done_configs) == len(done_scores) == 3
    assert done_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Restricting the load to trials after `immediate_explicit` leaves only the one-hour trial.
    (done_ids, done_configs, done_scores, done_status) = exp_storage.load(
        last_trial_id=immediate_explicit.trial_id)
    assert done_ids == [in_one_hour.trial_id]
    assert len(done_configs) == len(done_scores) == 1
    assert done_status == [Status.SUCCEEDED]

119 assert trial_ids == [trial_1h.trial_id] 

120 assert len(trial_configs) == len(trial_scores) == 1 

121 assert trial_status == [Status.SUCCEEDED]