Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py: 100%
23 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Common data classes for the SSH service tests."""
7from dataclasses import dataclass
8from subprocess import run
# Port that sshd listens on *inside* the SSH test server containers.
# The host-side port it is published on is looked up at runtime with
# `docker compose port` (see SshTestServerInfo.get_port).
# See Also: docker-compose.yml
SSH_TEST_SERVER_PORT: int = 2254

# Names of the SSH test servers (presumably the service names declared in
# docker-compose.yml -- confirm there when adding a new one).
SSH_TEST_SERVER_NAME: str = "ssh-server"
ALT_TEST_SERVER_NAME: str = "alt-server"
REBOOT_TEST_SERVER_NAME: str = "reboot-server"
18@dataclass
19class SshTestServerInfo:
20 """
21 A data class for SshTestServerInfo.
23 See Also
24 --------
25 mlos_bench.tests.storage.sql.SqlTestServerInfo
26 """
28 compose_project_name: str
29 service_name: str
30 hostname: str
31 username: str
32 id_rsa_path: str
33 _port: int | None = None
35 def get_port(self, uncached: bool = False) -> int:
36 """
37 Gets the port that the SSH test server is listening on.
39 Note: this value can change when the service restarts so we can't rely on
40 the DockerServices.
41 """
42 if self._port is None or uncached:
43 port_cmd = run(
44 (
45 f"docker compose -p {self.compose_project_name} "
46 f"port {self.service_name} {SSH_TEST_SERVER_PORT}"
47 ),
48 shell=True,
49 check=True,
50 capture_output=True,
51 )
52 self._port = int(port_cmd.stdout.decode().strip().split(":")[1])
53 return self._port
55 def to_ssh_service_config(self, uncached: bool = False) -> dict:
56 """Convert to a config dict for SshService."""
57 return {
58 "ssh_hostname": self.hostname,
59 "ssh_port": self.get_port(uncached),
60 "ssh_username": self.username,
61 "ssh_priv_key_path": self.id_rsa_path,
62 }
64 def to_connect_params(self, uncached: bool = False) -> dict:
65 """
66 Convert to a connect_params dict for SshClient.
68 See Also: mlos_bench.services.remote.ssh.ssh_service.SshService._get_connect_params()
69 """
70 return {
71 "host": self.hostname,
72 "port": self.get_port(uncached),
73 "username": self.username,
74 }