Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89%

157 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A collection of functions for interacting with SSH servers as file shares. 

7""" 

8 

9from abc import ABCMeta 

10from asyncio import Event as CoroEvent, Lock as CoroLock 

11from warnings import warn 

12from types import TracebackType 

13from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, Union 

14from threading import current_thread 

15 

16import logging 

17import os 

18 

19import asyncssh 

20from asyncssh.connection import SSHClientConnection 

21 

22from mlos_bench.services.base_service import Service 

23from mlos_bench.event_loop_context import EventLoopContext, CoroReturnType, FutureReturnType 

24from mlos_bench.util import nullable 

25 

26_LOG = logging.getLogger(__name__) 

27 

28 

class SshClient(asyncssh.SSHClient):
    """
    asyncssh.SSHClient subclass that keeps track of its own connection state.

    SshService uses it to hold on to a single connection per host, notice
    when that connection goes away so it can be re-established, and reuse it
    for running commands instead of reconnecting for every command.
    """

    # Placeholder ids reported by __repr__ while no live connection is known.
    _CONNECTION_PENDING = 'INIT'
    _CONNECTION_LOST = 'LOST'

    def __init__(self, *args: tuple, **kwargs: dict):
        # Set whenever the connection state changes (made *or* lost), so that
        # connection() can wait on it.
        self._conn_event: CoroEvent = CoroEvent()
        self._connection: Optional[SSHClientConnection] = None
        self._connection_id: str = SshClient._CONNECTION_PENDING
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """Builds the "user@host:port" id for an established connection."""
        # pylint: disable=protected-access
        return "{}@{}:{}".format(connection._username, connection._host, connection._port)

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """Builds the "user@host:port" id from a dict of connect parameters."""
        return "{}@{}:{}".format(
            connect_params.get('username'),
            connect_params['host'],     # the only required key
            connect_params.get('port'),
        )

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        Hook called by asyncssh.SSHClient once the connection is established.

        Replaces the _CONNECTION_PENDING id with the id derived from the
        connection, stores the connection and wakes up connection() waiters.
        """
        self._conn_event.clear()
        thread_name = current_thread().name
        _LOG.debug("%s: Connection made by %s: %s", thread_name, conn._options.env, conn)   # pylint: disable=protected-access
        self._connection = conn
        self._connection_id = SshClient.id_from_connection(conn)
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Hook called by asyncssh.SSHClient when the connection is closed or lost.

        Logs the event, drops the stored connection, marks the id as
        _CONNECTION_LOST and wakes up connection() waiters.
        """
        self._conn_event.clear()
        thread_name = current_thread().name
        _LOG.debug("%s: %s", thread_name, "connection_lost")
        # No exception means the connection was closed on purpose.
        message = (
            "%s: gracefully disconnected ssh from %s: %s" if exc is None
            else "%s: ssh connection lost on %s: %s"
        )
        _LOG.debug(message, thread_name, self._connection_id, exc)
        self._connection = None
        self._connection_id = SshClient._CONNECTION_LOST
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> Optional[SSHClientConnection]:
        """
        Waits until the connection has either been established or lost.

        Returns
        -------
        Optional[SSHClientConnection]
            The live connection, or None if it has been lost.
        """
        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
        return self._connection

94 

95 

class SshClientCache:
    """
    Reference-counted cache of SshClient connections keyed by connection id.

    Note: Only one per event loop thread supported.
    See additional details in SshService comments.
    """

    def __init__(self) -> None:
        # Maps the id produced by SshClient.id_from_params to the
        # (connection, client) pair returned by asyncssh.create_connection.
        self._cache: Dict[str, Tuple[SSHClientConnection, SshClient]] = {}
        # Held by get_client_connection for its whole lookup/(re)connect sequence.
        self._cache_lock = CoroLock()
        # Number of enter() calls that have not been matched by an exit() yet.
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Registers one more user of the cache.

        To be used in the __enter__ method of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Unregisters one user of the cache and cleans up after the last one.

        To be used in the __exit__ method of a caller's context manager.

        Once no users are left, all cached connections are closed and the
        cache lock is force-released if something still holds it.
        The reference count never drops below zero: an exit() without a
        matching enter() is reported with a RuntimeWarning instead of
        silently skewing the count for later callers.
        """
        unbalanced = self._refcnt <= 0
        self._refcnt = max(self._refcnt - 1, 0)
        if self._refcnt == 0:
            self.cleanup()
            if self._cache_lock.locked():
                # Release before warning so that the lock is freed even when
                # warnings are configured to be raised as errors.
                self._cache_lock.release()
                warn(RuntimeWarning("SshClientCache lock was still held on exit."))
        if unbalanced:
            warn(RuntimeWarning("SshClientCache.exit() called without a matching enter()."))

    async def get_client_connection(self, connect_params: dict) -> "Tuple[SSHClientConnection, SshClient]":
        """
        Gets a (possibly cached) client connection.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            A tuple of (SSHClientConnection, SshClient).
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            connection_id = SshClient.id_from_params(connect_params)
            client: Union[None, SshClient, asyncssh.SSHClient]
            _, client = self._cache.get(connection_id, (None, None))
            if client:
                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
                # Resolves to None once the client has reported connection_lost.
                connection = await client.connection()
                if not connection:
                    _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id)
                    self._cache.pop(connection_id)
                    # Try to reconnect next.
                else:
                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
            if connection_id not in self._cache:
                _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id)
                connection, client = await asyncssh.create_connection(SshClient, **connect_params)
                assert isinstance(client, SshClient)
                self._cache[connection_id] = (connection, client)
                _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)
            return self._cache[connection_id]

    def cleanup(self) -> None:
        """
        Closes all cached connections and empties the cache.
        """
        for (connection, _) in self._cache.values():
            connection.close()
        self._cache = {}

176 

177 

class SshService(Service, metaclass=ABCMeta):
    """
    Base class for SSH services.

    Collects the connection settings from the service config and provides
    helpers for running coroutines on the shared background event loop and
    for obtaining (possibly cached) SSH client connections.
    """

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just a single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain just one event loop thread for all SshService instances.
    # But only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    _REQUEST_TIMEOUT: Optional[float] = None    # seconds

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Creates the service and prepares the base SSH connection parameters.

        All four arguments are forwarded unchanged to the Service base class.
        The following optional config keys are read here: ssh_port,
        ssh_username, ssh_priv_key_path, ssh_request_timeout,
        ssh_known_hosts_file and ssh_keepalive_interval.

        Raises
        ------
        ValueError
            If ssh_known_hosts_file is a path that does not exist.
        """
        super().__init__(config, global_config, parent, methods)

        # Make sure that the values we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault('ssh_port', None)
        assert isinstance(self.config['ssh_port'], (int, type(None)))
        self.config.setdefault('ssh_username', None)
        assert isinstance(self.config['ssh_username'], (str, type(None)))
        self.config.setdefault('ssh_priv_key_path', None)
        assert isinstance(self.config['ssh_priv_key_path'], (str, type(None)))

        # None can be used to disable the request timeout.
        # NOTE(review): nullable() presumably applies float() only to non-None
        # values - confirm against mlos_bench.util.
        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
        self._request_timeout = nullable(float, self._request_timeout)

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            'request_pty': False,
            # By default disable known_hosts checking (since most VMs expected to be dynamically created).
            'known_hosts': None,
        }

        if 'ssh_known_hosts_file' in self.config:
            self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None)
            if isinstance(self._connect_params['known_hosts'], str):
                # Fail early on a bad path rather than at connect time.
                known_hosts_file = os.path.expanduser(self._connect_params['known_hosts'])
                if not os.path.exists(known_hosts_file):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
                self._connect_params['known_hosts'] = known_hosts_file
        if self._connect_params['known_hosts'] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if 'ssh_keepalive_interval' in self.config:
            keepalive_interval = self.config.get('ssh_keepalive_interval')
            self._connect_params['keepalive_interval'] = nullable(int, keepalive_interval)

    def _enter_context(self) -> "SshService":
        """
        Enters the shared event loop and client cache contexts, then the
        base Service context.
        """
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(self, ex_type: Optional[Type[BaseException]],
                      ex_val: Optional[BaseException],
                      ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Leaves the shared client cache and event loop contexts (in the
        reverse order of _enter_context), then the base Service context.
        """
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.
        Note: This may cause in flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        return self._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Per-host values in `params` take precedence over the service config.
        Note: ssh_port and ssh_username are popped from `params` when used.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.
            Must contain ssh_hostname.

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.

        Raises
        ------
        KeyError
            If ssh_hostname is missing from params.
        ValueError
            If the selected ssh_priv_key_path does not exist.
        """
        # Setup default connect_params dict for all SshClients we might need to create.

        # Note: None is an acceptable value for several of these, in which case
        # reasonable defaults or values from ~/.ssh/config will take effect.

        # Start with the base config params.
        connect_params = self._connect_params.copy()

        connect_params['host'] = params['ssh_hostname']     # required

        if params.get('ssh_port'):
            connect_params['port'] = int(params.pop('ssh_port'))
        elif self.config['ssh_port']:
            connect_params['port'] = int(self.config['ssh_port'])

        # Same truthiness check as for the port: a merged params dict can carry
        # ssh_username=None (the default seeded in __init__), which must not be
        # turned into the literal username "None".
        if params.get('ssh_username'):
            connect_params['username'] = str(params.pop('ssh_username'))
        elif self.config['ssh_username']:
            connect_params['username'] = str(self.config['ssh_username'])

        priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path'])
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params['client_keys'] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Parameters
        ----------
        params : dict
            Optional override connection parameters.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params))