Coverage for mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py: 100%
21 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
Unit tests for RemoteEnv benchmark environment via local SSH test services.
7"""
9from typing import Dict
11import os
12import sys
14import numpy as np
16import pytest
18from mlos_bench.services.config_persistence import ConfigPersistenceService
19from mlos_bench.tunables.tunable import TunableValue
20from mlos_bench.tunables.tunable_groups import TunableGroups
22from mlos_bench.tests import requires_docker
23from mlos_bench.tests.environments import check_env_success
24from mlos_bench.tests.services.remote.ssh import SshTestServerInfo
26if sys.version_info < (3, 10):
27 from importlib_resources import files
28else:
29 from importlib.resources import files
@requires_docker
def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None:
    """
    Run the remote SSH environment test config against the SSH test server.

    Loads ``environments/remote/test_ssh_env.jsonc`` from the
    ``mlos_bench.tests.config`` package, passing the connection settings of
    ``ssh_test_server`` as global config. The environment is run through
    ``check_env_success``, which expects the listed results and no telemetry.
    Finally, the test checks that ``output-downloaded.csv`` was not left behind
    in the current working directory.
    """
    # Connection settings for the SSH test server. They are handed to the
    # environment loader as global config (presumably consumed by the
    # .jsonc environment config -- TODO confirm).
    global_config: Dict[str, TunableValue] = {
        "ssh_hostname": ssh_test_server.hostname,
        "ssh_port": ssh_test_server.get_port(),
        "ssh_username": ssh_test_server.username,
        "ssh_priv_key_path": ssh_test_server.id_rsa_path,
    }

    service = ConfigPersistenceService(
        config={"config_path": [str(files("mlos_bench.tests.config"))]}
    )
    config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc")
    env = service.load_environment(
        config_path, TunableGroups(), global_config=global_config, service=service
    )

    check_env_success(
        env,
        env.tunable_params,
        expected_results={
            "hostname": ssh_test_server.service_name,
            "username": ssh_test_server.username,
            "score": 0.9,
            "ssh_priv_key_path": np.nan,  # empty strings are returned as "not a number"
            "test_param": "unset",
            "FOO": "unset",
            "ssh_username": "unset",
        },
        expected_telemetry=[],
    )
    # The downloaded results file must not leak into the working directory.
    assert not os.path.exists(os.path.join(os.getcwd(), "output-downloaded.csv")), \
        "output-downloaded.csv should have been cleaned up by temp_dir context"
if __name__ == "__main__":
    # Run only this test module through pytest, with a single
    # pytest-xdist worker ("-n1").
    pytest.main(args=["-n1", __file__])