Coverage for mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py: 100%
36 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Unit tests for saving and restoring the telemetry data.
7"""
8from datetime import datetime, timedelta, tzinfo
9from typing import Any, List, Optional, Tuple
11from pytz import UTC
13import pytest
15from mlos_bench.environments.status import Status
16from mlos_bench.tunables.tunable_groups import TunableGroups
17from mlos_bench.storage.base_storage import Storage
18from mlos_bench.util import nullable
20from mlos_bench.tests import ZONE_INFO
22# pylint: disable=redefined-outer-name
def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]:
    """
    Mock telemetry data for the trial.

    Parameters
    ----------
    zone_info : Optional[tzinfo]
        Time zone used for the generated timestamps. ``None`` yields naive
        (local-time) timestamps, as with ``datetime.now(None)``.

    Returns
    -------
    List[Tuple[datetime, str, Any]]
        A sorted list of (timestamp, metric_id, metric_value) tuples for two
        timestamps one second apart. Metric values are intentionally of mixed
        types (float, int and str) so that callers exercise value stringification.
    """
    timestamp1 = datetime.now(zone_info)
    timestamp2 = timestamp1 + timedelta(seconds=1)
    # Sorting compares (timestamp, metric_id) first; that pair is unique per
    # record, so the mixed-type values are never compared against each other.
    # NOTE(review): presumably sorted to match the order storage returns - confirm.
    return sorted([
        (timestamp1, "cpu_load", 10.1),
        (timestamp1, "memory", 20),
        (timestamp1, "setup", "prod"),
        (timestamp2, "cpu_load", 30.1),
        (timestamp2, "memory", 40),
        (timestamp2, "setup", "prod"),
    ])
def _telemetry_str(data: List[Tuple[datetime, str, Any]]
                   ) -> List[Tuple[datetime, str, Optional[str]]]:
    """
    Normalize telemetry records into the form storage hands back.

    Each timestamp is converted to UTC and each metric value is passed through
    ``nullable(str, ...)`` (presumably keeping ``None`` as ``None`` and
    stringifying everything else - see ``mlos_bench.util.nullable``).
    """
    normalized: List[Tuple[datetime, str, Optional[str]]] = []
    for stamp, metric_id, metric_value in data:
        # Timestamps read back from storage are expected to be in UTC,
        # so the reference data is converted the same way before comparing.
        normalized.append((stamp.astimezone(UTC), metric_id, nullable(str, metric_value)))
    return normalized
@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO)
def test_update_telemetry(storage: Storage,
                          exp_storage: Storage.Experiment,
                          tunable_groups: TunableGroups,
                          origin_zone_info: Optional[tzinfo]) -> None:
    """
    Check that telemetry written via update_telemetry() round-trips through
    both load_telemetry() and the TrialData ``telemetry_df`` view.
    """
    expected = zoned_telemetry_data(origin_zone_info)
    trial = exp_storage.new_trial(tunable_groups)
    # A freshly created trial has no telemetry yet.
    assert exp_storage.load_telemetry(trial.trial_id) == []

    trial.update_telemetry(Status.RUNNING, datetime.now(origin_zone_info), expected)
    expected_str = _telemetry_str(expected)
    assert exp_storage.load_telemetry(trial.trial_id) == expected_str

    # Cross-check the same records through the TrialData API.
    experiment_data = storage.experiments[exp_storage.experiment_id]
    telemetry_df = experiment_data.trials[trial.trial_id].telemetry_df
    rows = [tuple(row) for row in telemetry_df.to_numpy()]
    assert _telemetry_str(rows) == expected_str
@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO)
def test_update_telemetry_twice(exp_storage: Storage.Experiment,
                                tunable_groups: TunableGroups,
                                origin_zone_info: Optional[tzinfo]) -> None:
    """
    Make sure repeated update_telemetry() calls with the same timestamp and
    payload are idempotent: the stored telemetry equals a single copy.
    """
    payload = zoned_telemetry_data(origin_zone_info)
    trial = exp_storage.new_trial(tunable_groups)
    reported_at = datetime.now(origin_zone_info)
    # Re-send the identical update three times (more than the name suggests).
    for _ in range(3):
        trial.update_telemetry(Status.RUNNING, reported_at, payload)
    assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(payload)