Coverage for mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for saving and restoring the telemetry data.""" 

6from datetime import datetime, timedelta, tzinfo 

7from typing import Any, List, Optional, Tuple 

8 

9import pytest 

10from pytz import UTC 

11 

12from mlos_bench.environments.status import Status 

13from mlos_bench.storage.base_storage import Storage 

14from mlos_bench.tests import ZONE_INFO 

15from mlos_bench.tunables.tunable_groups import TunableGroups 

16from mlos_bench.util import nullable 

17 

18# pylint: disable=redefined-outer-name 

19 

20 

21def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]: 

22 """ 

23 Mock telemetry data for the trial. 

24 

25 Returns 

26 ------- 

27 List[Tuple[datetime, str, str]] 

28 A list of (timestamp, metric_id, metric_value) 

29 """ 

30 timestamp1 = datetime.now(zone_info) 

31 timestamp2 = timestamp1 + timedelta(seconds=1) 

32 return sorted( 

33 [ 

34 (timestamp1, "cpu_load", 10.1), 

35 (timestamp1, "memory", 20), 

36 (timestamp1, "setup", "prod"), 

37 (timestamp2, "cpu_load", 30.1), 

38 (timestamp2, "memory", 40), 

39 (timestamp2, "setup", "prod"), 

40 ] 

41 ) 

42 

43 

44def _telemetry_str( 

45 data: List[Tuple[datetime, str, Any]], 

46) -> List[Tuple[datetime, str, Optional[str]]]: 

47 """Convert telemetry values to strings.""" 

48 # All retrieved timestamps should have been converted to UTC. 

49 return [(ts.astimezone(UTC), key, nullable(str, val)) for (ts, key, val) in data] 

50 

51 

52@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) 

53def test_update_telemetry( 

54 storage: Storage, 

55 exp_storage: Storage.Experiment, 

56 tunable_groups: TunableGroups, 

57 origin_zone_info: Optional[tzinfo], 

58) -> None: 

59 """Make sure update_telemetry() and load_telemetry() methods work.""" 

60 telemetry_data = zoned_telemetry_data(origin_zone_info) 

61 trial = exp_storage.new_trial(tunable_groups) 

62 assert exp_storage.load_telemetry(trial.trial_id) == [] 

63 

64 trial.update_telemetry(Status.RUNNING, datetime.now(origin_zone_info), telemetry_data) 

65 assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(telemetry_data) 

66 

67 # Also check that the TrialData telemetry looks right. 

68 trial_data = storage.experiments[exp_storage.experiment_id].trials[trial.trial_id] 

69 trial_telemetry_df = trial_data.telemetry_df 

70 trial_telemetry_data = [tuple(r) for r in trial_telemetry_df.to_numpy()] 

71 assert _telemetry_str(trial_telemetry_data) == _telemetry_str(telemetry_data) 

72 

73 

74@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) 

75def test_update_telemetry_twice( 

76 exp_storage: Storage.Experiment, 

77 tunable_groups: TunableGroups, 

78 origin_zone_info: Optional[tzinfo], 

79) -> None: 

80 """Make sure update_telemetry() call is idempotent.""" 

81 telemetry_data = zoned_telemetry_data(origin_zone_info) 

82 trial = exp_storage.new_trial(tunable_groups) 

83 timestamp = datetime.now(origin_zone_info) 

84 trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data) 

85 trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data) 

86 trial.update_telemetry(Status.RUNNING, timestamp, telemetry_data) 

87 assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(telemetry_data)