Coverage for mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py: 100%

36 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for saving and restoring the telemetry data. 

7""" 

8from datetime import datetime, timedelta, tzinfo 

9from typing import Any, List, Optional, Tuple 

10 

11from pytz import UTC 

12 

13import pytest 

14 

15from mlos_bench.environments.status import Status 

16from mlos_bench.tunables.tunable_groups import TunableGroups 

17from mlos_bench.storage.base_storage import Storage 

18from mlos_bench.util import nullable 

19 

20from mlos_bench.tests import ZONE_INFO 

21 

22# pylint: disable=redefined-outer-name 

23 

24 

def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]:
    """
    Mock telemetry data for the trial.

    Parameters
    ----------
    zone_info : Optional[tzinfo]
        Time zone used to generate the timestamps. ``None`` produces naive
        (local-time) timestamps, as ``datetime.now(None)`` does.

    Returns
    -------
    List[Tuple[datetime, str, Any]]
        A list of (timestamp, metric_id, metric_value) records at two
        timestamps one second apart, sorted by (timestamp, metric_id).
        Metric values are of mixed types (float, int, str).
    """
    timestamp1 = datetime.now(zone_info)
    timestamp2 = timestamp1 + timedelta(seconds=1)
    # Sort only on (timestamp, metric_id) so that metric values of different
    # types (float / int / str) are never compared against each other.
    return sorted([
        (timestamp1, "cpu_load", 10.1),
        (timestamp1, "memory", 20),
        (timestamp1, "setup", "prod"),
        (timestamp2, "cpu_load", 30.1),
        (timestamp2, "memory", 40),
        (timestamp2, "setup", "prod"),
    ], key=lambda record: record[:2])

44 

45 

def _telemetry_str(data: List[Tuple[datetime, str, Any]]
                   ) -> List[Tuple[datetime, str, Optional[str]]]:
    """
    Normalize telemetry records into the form expected back from storage.

    Every timestamp is converted to UTC. Every metric value is passed through
    ``nullable(str, ...)``; the Optional[str] return type suggests that None
    values stay None. Metric ids are kept unchanged and record order is
    preserved.
    """
    normalized: List[Tuple[datetime, str, Optional[str]]] = []
    for (timestamp, metric_id, metric_value) in data:
        # Timestamps read back from storage are expected in UTC, so the
        # reference data is brought into UTC as well before comparing.
        utc_timestamp = timestamp.astimezone(UTC)
        normalized.append((utc_timestamp, metric_id, nullable(str, metric_value)))
    return normalized

53 

54 

@pytest.mark.parametrize("origin_zone_info", ZONE_INFO)
def test_update_telemetry(storage: Storage,
                          exp_storage: Storage.Experiment,
                          tunable_groups: TunableGroups,
                          origin_zone_info: Optional[tzinfo]) -> None:
    """
    Check that telemetry written via update_telemetry() reads back correctly.

    The data is checked through load_telemetry() and through the TrialData
    ``telemetry_df`` view.
    """
    expected = zoned_telemetry_data(origin_zone_info)
    trial = exp_storage.new_trial(tunable_groups)
    trial_id = trial.trial_id

    # A freshly created trial has no telemetry yet.
    assert exp_storage.load_telemetry(trial_id) == []

    trial.update_telemetry(Status.RUNNING, datetime.now(origin_zone_info), expected)
    assert exp_storage.load_telemetry(trial_id) == _telemetry_str(expected)

    # The TrialData telemetry DataFrame must agree with what was stored.
    experiment_data = storage.experiments[exp_storage.experiment_id]
    telemetry_df = experiment_data.trials[trial_id].telemetry_df
    rows = [tuple(row) for row in telemetry_df.to_numpy()]
    assert _telemetry_str(rows) == _telemetry_str(expected)

75 

76 

@pytest.mark.parametrize("origin_zone_info", ZONE_INFO)
def test_update_telemetry_twice(exp_storage: Storage.Experiment,
                                tunable_groups: TunableGroups,
                                origin_zone_info: Optional[tzinfo]) -> None:
    """
    Check that repeating update_telemetry() with identical arguments is idempotent.

    After the same batch is written several times, the stored telemetry must
    still equal that single batch, with no duplicated records.
    """
    expected = zoned_telemetry_data(origin_zone_info)
    trial = exp_storage.new_trial(tunable_groups)
    timestamp = datetime.now(origin_zone_info)
    # Write the very same batch (same status and timestamp) three times.
    for _ in range(3):
        trial.update_telemetry(Status.RUNNING, timestamp, expected)
    assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(expected)