Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py: 89%
57 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A collection of Service functions for managing hosts via SSH."""
import logging
import shlex
from collections.abc import Callable, Iterable
from concurrent.futures import Future
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Any

from asyncssh import ConnectionLost, DisconnectError, ProcessError, SSHCompletedProcess

from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.remote.ssh.ssh_service import SshService
from mlos_bench.services.types.os_ops_type import SupportsOSOps
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.util import merge_parameters
21_LOG = logging.getLogger(__name__)
class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
    """
    Helper methods to manage machines via SSH.

    Registers the OS operations (shutdown, reboot, wait) and the remote script
    execution methods with the Service machinery. Scripts are started
    asynchronously; their results are collected later through the Future stored
    in the returned config under the "asyncRemoteExecResultsFuture" key.
    """

    # pylint: disable=too-many-ancestors
    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of an SSH Service.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. The optional "ssh_shell" key selects the shell
            (defaults to "/bin/bash").
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        # Same methods are also provided by the AzureVMService class
        # pylint: disable=duplicate-code
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    self.shutdown,
                    self.reboot,
                    self.wait_os_operation,
                    self.remote_exec,
                    self.get_remote_exec_results,
                ],
            ),
        )
        self._shell = self.config.get("ssh_shell", "/bin/bash")

    async def _run_cmd(
        self,
        params: dict,
        script: Iterable[str],
        env_params: dict,
    ) -> SSHCompletedProcess:
        """
        Runs a script asynchronously on a host via SSH.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for
            establishing the connection).
        script : Iterable[str]
            Lines of the script to run via shell. A single string is accepted
            too and is treated as one (possibly multi-line) script.
        env_params : dict
            Environment variables to make available to the script. They are
            emitted as `export` statements ahead of the script lines.

        Returns
        -------
        SSHCompletedProcess
            Returns the result of the command.
        """
        if isinstance(script, str):
            # Script should be an iterable of lines, not an iterable string.
            script = [script]
        connection, _ = await self._get_client_connection(params)
        # Note: passing environment variables to SSH servers is typically restricted
        # to just some LC_* values.
        # Handle transferring environment variables by making a script to set them.
        # Values are shell-quoted so that embedded quotes, spaces, or shell
        # metacharacters cannot break (or inject commands into) the script.
        env_script_lines = [
            f"export {name}={shlex.quote(str(value))}" for (name, value) in env_params.items()
        ]
        script_lines = env_script_lines + [
            line_split for line in script for line_split in line.splitlines()
        ]
        # Note: connection.run() uses "exec" with a shell by default.
        script_str = "\n".join(script_lines)
        _LOG.debug("Running script on %s:\n%s", connection, script_str)
        return await connection.run(
            script_str,
            check=False,
            timeout=self._request_timeout,
            env=env_params,
        )

    def remote_exec(
        self,
        script: Iterable[str],
        config: dict,
        env_params: dict,
    ) -> tuple["Status", dict]:
        """
        Start running a command on remote host OS.

        Parameters
        ----------
        script : Iterable[str]
            A list of lines to execute as a script on a remote VM.
        config : dict
            Flat dictionary of (key, value) pairs of parameters.
            They usually come from `const_args` and `tunable_params`
            properties of the Environment.
            Must (together with the service config) provide "ssh_hostname".
        env_params : dict
            Parameters to pass as *shell* environment variables into the script.
            This is usually a subset of `config` with some possible conversions.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is always PENDING here; the result is the merged config with
            the Future of the running script stored under the
            "asyncRemoteExecResultsFuture" key.
        """
        config = merge_parameters(
            dest=self.config.copy(),
            source=config,
            required_keys=[
                "ssh_hostname",
            ],
        )
        config["asyncRemoteExecResultsFuture"] = self._run_coroutine(
            self._run_cmd(
                config,
                script,
                env_params,
            )
        )
        return (Status.PENDING, config)

    def get_remote_exec_results(self, config: dict) -> tuple["Status", dict]:
        """
        Get the results of the asynchronously running command.

        Parameters
        ----------
        config : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is SUCCEEDED when the script exited with 0, FAILED otherwise.
            Connection errors and timeouts are logged and reported as FAILED.
            On completion the result holds "stdout", "stderr", and
            "ssh_completed_process_result".

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing from `config`.
        """
        future = config.get("asyncRemoteExecResultsFuture")
        if not future:
            raise ValueError("Missing 'asyncRemoteExecResultsFuture'.")
        assert isinstance(future, Future)
        result = None
        try:
            result = future.result(timeout=self._request_timeout)
            assert isinstance(result, SSHCompletedProcess)
            stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
            stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
            return (
                (
                    Status.SUCCEEDED
                    if result.exit_status == 0 and result.returncode == 0
                    else Status.FAILED
                ),
                {
                    "stdout": stdout,
                    "stderr": stderr,
                    "ssh_completed_process_result": result,
                },
            )
        except (
            ConnectionLost,
            DisconnectError,
            ProcessError,
            TimeoutError,
            # Future.result() raises concurrent.futures.TimeoutError, which is
            # only an alias of the builtin TimeoutError since Python 3.11.
            FutureTimeoutError,
        ) as ex:
            _LOG.error("Failed to get remote exec results: %s", ex)
            return (Status.FAILED, {"result": result})

    def _exec_os_op(self, cmd_opts_list: list[str], params: dict) -> tuple[Status, dict]:
        """
        Start an OS-level operation by trying a list of alternative commands.

        The commands are tried in order (via `sudo -n` when not already root and
        sudo is available) until one of them succeeds.

        Parameters
        ----------
        cmd_opts_list : list[str]
            List of commands to try to execute.
        params : dict
            The params used to connect to the host.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by `remote_exec`.
        """
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )
        # Quote each alternative so that it stays a single word of the `for` list.
        cmd_opts = " ".join(shlex.quote(cmd) for cmd in cmd_opts_list)
        # Note: literal shell braces must be doubled inside the f-string.
        script = rf"""
            if [[ $EUID -ne 0 ]]; then
                sudo=$(command -v sudo)
                sudo=${{sudo:+$sudo -n}}
            fi

            set -x
            for cmd in {cmd_opts}; do
                $sudo /bin/bash -c "$cmd" && exit 0
            done

            echo 'ERROR: Failed to shutdown/reboot the system.'
            exit 1
        """
        return self.remote_exec(script, config, env_params={})

    def shutdown(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Initiates a (graceful) shutdown of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            Accepted for interface compatibility; currently not used by this
            implementation (the same commands are tried either way).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        cmd_opts_list = [
            "shutdown -h now",
            "poweroff",
            "halt -p",
            "systemctl poweroff",
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def reboot(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Initiates a (graceful) reboot of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            If True, the last-resort fallback sends KILL instead of TERM to
            PID 1 and to all other processes.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        cmd_opts_list = [
            "shutdown -r now",
            "reboot",
            "halt --reboot",
            "systemctl reboot",
            "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1",
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def wait_os_operation(self, params: dict) -> tuple[Status, dict]:
        """
        Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.

        This simply delegates to `get_remote_exec_results`.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by
            `get_remote_exec_results`.

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing from `params`.
        """
        return self.get_remote_exec_results(params)