Coverage for mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py: 100%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for saving and restoring the telemetry data."""
6from datetime import datetime, timedelta, tzinfo
7from typing import Any, List, Optional, Tuple
9import pytest
10from pytz import UTC
12from mlos_bench.environments.status import Status
13from mlos_bench.storage.base_storage import Storage
14from mlos_bench.tests import ZONE_INFO
15from mlos_bench.tunables.tunable_groups import TunableGroups
16from mlos_bench.util import nullable
18# pylint: disable=redefined-outer-name
def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]:
    """
    Mock telemetry data for the trial.

    Parameters
    ----------
    zone_info : Optional[tzinfo]
        Time zone passed to ``datetime.now()`` when stamping the mock records.
        When None, the resulting timestamps are naive (no tzinfo attached).

    Returns
    -------
    List[Tuple[datetime, str, Any]]
        A sorted list of (timestamp, metric_id, metric_value) records covering
        two timestamps one second apart. Metric values are a mix of float, int,
        and str.
    """
    timestamp1 = datetime.now(zone_info)
    timestamp2 = timestamp1 + timedelta(seconds=1)
    # Within one timestamp the metric ids are unique, so the sort never has to
    # compare the (mixed-type) metric values themselves.
    return sorted(
        [
            (timestamp1, "cpu_load", 10.1),
            (timestamp1, "memory", 20),
            (timestamp1, "setup", "prod"),
            (timestamp2, "cpu_load", 30.1),
            (timestamp2, "memory", 40),
            (timestamp2, "setup", "prod"),
        ]
    )
def _telemetry_str(
    data: List[Tuple[datetime, str, Any]],
) -> List[Tuple[datetime, str, Optional[str]]]:
    """
    Normalize telemetry records into the form used for comparisons.

    Each timestamp is converted to UTC and each metric value is passed through
    ``nullable(str, ...)``; the metric id is kept as is.
    """
    normalized: List[Tuple[datetime, str, Optional[str]]] = []
    for timestamp, metric_id, metric_value in data:
        # Retrieved timestamps are expected to be in UTC already, so the
        # reference data is shifted to UTC before comparing.
        utc_timestamp = timestamp.astimezone(UTC)
        normalized.append((utc_timestamp, metric_id, nullable(str, metric_value)))
    return normalized
@pytest.mark.parametrize("origin_zone_info", ZONE_INFO)
def test_update_telemetry(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
    origin_zone_info: Optional[tzinfo],
) -> None:
    """Verify that update_telemetry() and load_telemetry() round-trip the data."""
    mock_records = zoned_telemetry_data(origin_zone_info)
    new_trial = exp_storage.new_trial(tunable_groups)
    trial_id = new_trial.trial_id

    # Nothing has been recorded for the new trial yet.
    assert exp_storage.load_telemetry(trial_id) == []

    update_time = datetime.now(origin_zone_info)
    new_trial.update_telemetry(Status.RUNNING, update_time, mock_records)
    stored_records = exp_storage.load_telemetry(trial_id)
    assert stored_records == _telemetry_str(mock_records)

    # The same telemetry should also be visible through the TrialData view.
    experiment_data = storage.experiments[exp_storage.experiment_id]
    telemetry_df = experiment_data.trials[trial_id].telemetry_df
    df_records = [tuple(row) for row in telemetry_df.to_numpy()]
    assert _telemetry_str(df_records) == _telemetry_str(mock_records)
@pytest.mark.parametrize("origin_zone_info", ZONE_INFO)
def test_update_telemetry_twice(
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
    origin_zone_info: Optional[tzinfo],
) -> None:
    """Verify that repeating the same update_telemetry() call is idempotent."""
    mock_records = zoned_telemetry_data(origin_zone_info)
    new_trial = exp_storage.new_trial(tunable_groups)
    update_time = datetime.now(origin_zone_info)

    # Submit the identical update three times in a row.
    for _ in range(3):
        new_trial.update_telemetry(Status.RUNNING, update_time, mock_records)

    assert exp_storage.load_telemetry(new_trial.trial_id) == _telemetry_str(mock_records)