Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89%
157 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6A collection functions for interacting with SSH servers as file shares.
7"""
9from abc import ABCMeta
10from asyncio import Event as CoroEvent, Lock as CoroLock
11from warnings import warn
12from types import TracebackType
13from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, Union
14from threading import current_thread
16import logging
17import os
19import asyncssh
20from asyncssh.connection import SSHClientConnection
22from mlos_bench.services.base_service import Service
23from mlos_bench.event_loop_context import EventLoopContext, CoroReturnType, FutureReturnType
24from mlos_bench.util import nullable
26_LOG = logging.getLogger(__name__)
class SshClient(asyncssh.SSHClient):
    """
    asyncssh.SSHClient subclass that tracks the state of its connection.

    SshService (through SshClientCache) keeps one of these per remote host and
    runs commands over its connection instead of reconnecting for every command.
    The tracked state lets the cache notice a lost connection and reconnect.
    """

    # Placeholder ids used before the connection is made / after it is lost.
    _CONNECTION_PENDING = 'INIT'
    _CONNECTION_LOST = 'LOST'

    def __init__(self, *args: tuple, **kwargs: dict):
        # Until connection_made() fires we only know the connection is pending.
        self._connection_id: str = SshClient._CONNECTION_PENDING
        self._connection: Optional[SSHClientConnection] = None
        # Signalled once the connection is made or lost, so waiters in
        # connection() can observe the settled state.
        self._conn_event: CoroEvent = CoroEvent()
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """Builds the "user@host:port" id of an established connection."""
        # pylint: disable=protected-access
        user = connection._username
        host = connection._host
        port = connection._port
        return f"{user}@{host}:{port}"

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """Builds the "user@host:port" id from asyncssh connection parameters."""
        # 'host' is mandatory (KeyError if missing); username/port may be absent.
        user = connect_params.get('username')
        host = connect_params['host']
        port = connect_params.get('port')
        return f"{user}@{host}:{port}"

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        asyncssh.SSHClient hook invoked once the connection is established.

        Replaces the pending id with the real "user@host:port" id, remembers
        the connection and wakes up anyone waiting in connection().
        """
        self._conn_event.clear()
        _LOG.debug("%s: Connection made by %s: %s",
                   current_thread().name,
                   conn._options.env,  # pylint: disable=protected-access
                   conn)
        self._connection_id = SshClient.id_from_connection(conn)
        self._connection = conn
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        asyncssh.SSHClient hook invoked when the connection closes.

        `exc` is None for a graceful disconnect, otherwise the error that
        terminated the connection. Marks the client as lost, drops the
        connection reference and wakes up anyone waiting in connection().
        """
        self._conn_event.clear()
        thread_name = current_thread().name
        _LOG.debug("%s: %s", thread_name, "connection_lost")
        if exc is None:
            template = "%s: gracefully disconnected ssh from %s: %s"
        else:
            template = "%s: ssh connection lost on %s: %s"
        _LOG.debug(template, thread_name, self._connection_id, exc)
        self._connection_id = SshClient._CONNECTION_LOST
        self._connection = None
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> Optional[SSHClientConnection]:
        """
        Waits until the connection has been made or lost and returns it.

        Returns None when the connection has been lost.
        """
        thread_name = current_thread().name
        _LOG.debug("%s: Waiting for connection to be available.", thread_name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", thread_name, self._connection_id)
        return self._connection
class SshClientCache:
    """
    Manages a cache of SshClient connections.

    Note: Only one per event loop thread supported.
    See additional details in SshService comments.
    """

    def __init__(self) -> None:
        # connection id (see SshClient.id_from_params) -> (connection, client)
        # pair as returned by asyncssh.create_connection.
        self._cache: Dict[str, Tuple[SSHClientConnection, SshClient]] = {}
        # Serializes cache lookups/inserts between coroutines.
        self._cache_lock = CoroLock()
        # Number of enter() calls not yet matched by an exit().
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __enter__ method of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __exit__ method of a caller's context manager.
        When the last user exits, all cached connections are closed.

        An unmatched call (more exit() than enter() calls) emits a
        RuntimeWarning and the count is clamped to zero. Letting it go
        negative would make a later enter() leave the count at zero while a
        user is active, so the next exit() would close connections that are
        still in use.
        """
        self._refcnt -= 1
        if self._refcnt < 0:
            warn(RuntimeWarning("SshClientCache.exit() called without a matching enter()."))
            self._refcnt = 0
        if self._refcnt == 0:
            self.cleanup()
            if self._cache_lock.locked():
                warn(RuntimeWarning("SshClientCache lock was still held on exit."))
                # NOTE(review): this releases an asyncio.Lock possibly held by a
                # coroutine on the event loop thread; best-effort recovery only.
                self._cache_lock.release()

    async def get_client_connection(self, connect_params: dict) -> "Tuple[SSHClientConnection, SshClient]":
        """
        Gets a (possibly cached) client connection.

        A cached client whose connection has been lost is evicted and a new
        connection is established in its place.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            A tuple of (SSHClientConnection, SshClient).
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            connection_id = SshClient.id_from_params(connect_params)
            client: Union[None, SshClient, asyncssh.SSHClient]
            _, client = self._cache.get(connection_id, (None, None))
            if client:
                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
                # Blocks until the client's connection is either made or lost.
                connection = await client.connection()
                if not connection:
                    _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id)
                    self._cache.pop(connection_id)
                    # Try to reconnect next.
                else:
                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
            if connection_id not in self._cache:
                _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id)
                connection, client = await asyncssh.create_connection(SshClient, **connect_params)
                assert isinstance(client, SshClient)
                self._cache[connection_id] = (connection, client)
                _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)
            return self._cache[connection_id]

    def cleanup(self) -> None:
        """
        Closes all cached connections and empties the cache.
        """
        for (connection, _) in self._cache.values():
            connection.close()
        self._cache = {}
class SshService(Service, metaclass=ABCMeta):
    """
    Base class for SSH services.

    Holds the SSH connection settings shared by all hosts this service talks
    to, and runs asyncssh coroutines on a background event loop shared by all
    SshService instances.
    """

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just a single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain just one event loop thread for all SshService instances,
    # but only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    _REQUEST_TIMEOUT: Optional[float] = None  # seconds

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        super().__init__(config, global_config, parent, methods)

        # Make sure that the values we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault('ssh_port', None)
        assert isinstance(self.config['ssh_port'], (int, type(None)))
        self.config.setdefault('ssh_username', None)
        assert isinstance(self.config['ssh_username'], (str, type(None)))
        self.config.setdefault('ssh_priv_key_path', None)
        assert isinstance(self.config['ssh_priv_key_path'], (str, type(None)))

        # None can be used to disable the request timeout.
        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
        self._request_timeout = nullable(float, self._request_timeout)

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            'request_pty': False,
            # By default disable known_hosts checking (since most VMs expected to be dynamically created).
            'known_hosts': None,
        }

        if 'ssh_known_hosts_file' in self.config:
            self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None)
            if isinstance(self._connect_params['known_hosts'], str):
                known_hosts_file = os.path.expanduser(self._connect_params['known_hosts'])
                if not os.path.exists(known_hosts_file):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
                self._connect_params['known_hosts'] = known_hosts_file
        if self._connect_params['known_hosts'] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if 'ssh_keepalive_interval' in self.config:
            keepalive_interval = self.config.get('ssh_keepalive_interval')
            self._connect_params['keepalive_interval'] = nullable(int, keepalive_interval)

    def _enter_context(self) -> "SshService":
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(self, ex_type: Optional[Type[BaseException]],
                      ex_val: Optional[BaseException],
                      ex_tb: Optional[TracebackType]) -> Literal[False]:
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.

        Note: This may cause in flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        return self._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Per-host values in `params` take precedence over the service config.
        When neither provides a non-empty value, the key is omitted. asyncssh
        then falls back to its own defaults or to ~/.ssh/config.
        `params` itself is not modified.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.
            Must contain 'ssh_hostname'.

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.

        Raises
        ------
        KeyError
            If `params` has no 'ssh_hostname'.
        ValueError
            If the resolved private key file does not exist.
        """
        # Start with the base config params.
        connect_params = self._connect_params.copy()

        connect_params['host'] = params['ssh_hostname']  # required

        # Read (don't pop) the per-host overrides. Popping mutated the caller's
        # dict, so reusing it silently fell back to the service defaults.
        # Empty/None overrides fall back to the service config. This matters
        # for the username in particular: a None value used to become the
        # literal user name "None".
        port = params.get('ssh_port') or self.config['ssh_port']
        if port:
            connect_params['port'] = int(port)

        username = params.get('ssh_username') or self.config['ssh_username']
        if username:
            connect_params['username'] = str(username)

        priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path'])
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params['client_keys'] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Parameters
        ----------
        params : dict
            Per-host connection parameters. Must contain 'ssh_hostname'.
            'ssh_port', 'ssh_username' and 'ssh_priv_key_path' optionally
            override the service config.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
            self._get_connect_params(params))