Coverage for mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py: 100%
20 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
"""Unit tests for RemoteEnv benchmark environment via local SSH test services."""
7import os
8import sys
9from typing import Dict
11import numpy as np
12import pytest
14from mlos_bench.services.config_persistence import ConfigPersistenceService
15from mlos_bench.tests import requires_docker
16from mlos_bench.tests.environments import check_env_success
17from mlos_bench.tests.services.remote.ssh import SshTestServerInfo
18from mlos_bench.tunables.tunable import TunableValue
19from mlos_bench.tunables.tunable_groups import TunableGroups
21if sys.version_info < (3, 10):
22 from importlib_resources import files
23else:
24 from importlib.resources import files
@requires_docker
def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None:
    """
    Run the SSH-based test environment against the SSH test server and check its results.

    The environment config ``environments/remote/test_ssh_env.jsonc`` is loaded
    with SSH connection settings taken from the ``ssh_test_server`` fixture.
    It is then executed through ``check_env_success``, which compares the reported
    results against ``expected_results`` and expects no telemetry. The reported
    ``hostname`` must match the SSH server's service name, which shows the
    script ran on that host rather than locally. Finally, the test asserts that
    ``output-downloaded.csv`` was not left in the current working directory.
    """
    # SSH connection parameters for the test server, passed as global config
    # so the environment (and its SSH service) can pick them up.
    global_config: Dict[str, TunableValue] = {
        "ssh_hostname": ssh_test_server.hostname,
        "ssh_port": ssh_test_server.get_port(),
        "ssh_username": ssh_test_server.username,
        "ssh_priv_key_path": ssh_test_server.id_rsa_path,
    }

    # Resolve configs relative to the test config package shipped with mlos_bench.
    service = ConfigPersistenceService(
        config={"config_path": [str(files("mlos_bench.tests.config"))]}
    )
    config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc")
    env = service.load_environment(
        config_path,
        TunableGroups(),
        global_config=global_config,
        service=service,
    )

    check_env_success(
        env,
        env.tunable_params,
        expected_results={
            "hostname": ssh_test_server.service_name,
            "username": ssh_test_server.username,
            "score": 0.9,
            "ssh_priv_key_path": np.nan,  # empty strings are returned as "not a number"
            # NOTE(review): presumably the remote script reports "unset" for
            # variables not exported to it -- confirm against test_ssh_env.jsonc.
            "test_param": "unset",
            "FOO": "unset",
            "ssh_username": "unset",
        },
        expected_telemetry=[],
    )
    # The results file is downloaded into a temporary directory; it must not
    # leak into the working directory after the run.
    assert not os.path.exists(
        os.path.join(os.getcwd(), "output-downloaded.csv")
    ), "output-downloaded.csv should have been cleaned up by temp_dir context"
if __name__ == "__main__":
    # Allow running this test module directly; "-n1" is forwarded to pytest
    # (presumably pytest-xdist's worker-count option -- confirm plugin is installed).
    _pytest_args = ["-n1", __file__]
    pytest.main(_pytest_args)