Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection of functions for interacting with SSH servers as file shares.""" 

6 

7import logging 

8import os 

9from abc import ABCMeta 

10from asyncio import Event as CoroEvent 

11from asyncio import Lock as CoroLock 

12from threading import current_thread 

13from types import TracebackType 

14from typing import ( 

15 Any, 

16 Callable, 

17 Coroutine, 

18 Dict, 

19 List, 

20 Literal, 

21 Optional, 

22 Tuple, 

23 Type, 

24 Union, 

25) 

26from warnings import warn 

27 

28import asyncssh 

29from asyncssh.connection import SSHClientConnection 

30 

31from mlos_bench.event_loop_context import ( 

32 CoroReturnType, 

33 EventLoopContext, 

34 FutureReturnType, 

35) 

36from mlos_bench.services.base_service import Service 

37from mlos_bench.util import nullable 

38 

# Module-level logger, keyed on this module's dotted name.
_LOG: logging.Logger = logging.getLogger(name=__name__)

40 

41 

class SshClient(asyncssh.SSHClient):
    """
    Wrapper around SSHClient to help provide connection caching and reconnect logic.

    Used by the SshService to try and maintain a single connection to hosts, handle
    reconnects if possible, and use that to run commands rather than reconnect for each
    command.

    The connection state is guarded by an asyncio Event: it is cleared while
    ``connection_made``/``connection_lost`` update the state and set again once the
    new state has been recorded, so ``connection()`` only ever observes a settled state.
    """

    # Placeholder ids reported by __repr__ before a connection is made / after it is lost.
    _CONNECTION_PENDING = "INIT"
    _CONNECTION_LOST = "LOST"

    def __init__(self, *args: Any, **kwargs: Any):
        # Start out "pending"; connection_made() fills in the real id and connection.
        self._connection_id: str = SshClient._CONNECTION_PENDING
        self._connection: Optional[SSHClientConnection] = None
        self._conn_event: CoroEvent = CoroEvent()
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """Gets a unique id repr (``username@host:port``) for the connection."""
        # pylint: disable=protected-access
        return f"{connection._username}@{connection._host}:{connection._port}"

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """
        Gets a unique id repr (``username@host:port``) from connection parameters.

        ``host`` is required (KeyError if absent); a missing ``username`` or
        ``port`` is rendered as ``None``.
        """
        return (
            f"""{connect_params.get("username")}@{connect_params["host"]}"""
            f""":{connect_params.get("port")}"""
        )

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Changes the connection_id from _CONNECTION_PENDING to a unique id repr
        and records the connection so ``connection()`` can return it.
        """
        self._conn_event.clear()
        _LOG.debug(
            "%s: Connection made by %s: %s",
            current_thread().name,
            conn._options.env,  # pylint: disable=protected-access
            conn,
        )
        self._connection_id = SshClient.id_from_connection(conn)
        self._connection = conn
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Marks the client as _CONNECTION_LOST and drops the connection reference so
        that ``connection()`` returns None. ``exc`` is None for a graceful disconnect.
        """
        self._conn_event.clear()
        _LOG.debug("%s: %s", current_thread().name, "connection_lost")
        if exc is None:
            _LOG.debug(
                "%s: gracefully disconnected ssh from %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        else:
            _LOG.debug(
                "%s: ssh connection lost on %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        self._connection_id = SshClient._CONNECTION_LOST
        self._connection = None
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> Optional[SSHClientConnection]:
        """
        Waits for and returns the asyncssh.connection.SSHClientConnection to be
        established or lost.

        Returns
        -------
        Optional[SSHClientConnection]
            The live connection, or None if it has been lost.
        """
        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
        return self._connection

125 

126 

class SshClientCache:
    """
    Manages a cache of SshClient connections.

    Note: Only one per event loop thread supported.
    See additional details in SshService comments.

    The cache lifetime is reference counted via ``enter()``/``exit()``: when the
    last reference exits, all cached connections are closed.
    """

    def __init__(self) -> None:
        # connection id (see SshClient.id_from_params) -> (connection, client)
        self._cache: Dict[str, Tuple[SSHClientConnection, SshClient]] = {}
        # Held across the lookup-and-create sequence in get_client_connection so
        # concurrent requests for the same host don't open duplicate connections.
        self._cache_lock = CoroLock()
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __enter__ method of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __exit__ method of a caller's context manager.
        When the last reference exits, all cached connections are closed and
        the cache lock is reset.
        """
        self._refcnt -= 1
        if self._refcnt <= 0:
            # Clamp so unbalanced exit() calls can't leave a negative count, which
            # would let a later exit() tear down the cache while another holder is
            # still inside its context.
            self._refcnt = 0
            self.cleanup()
            if self._cache_lock.locked():
                warn(RuntimeWarning("SshClientCache lock was still held on exit."))
            # Swap in a fresh lock rather than release()-ing the held one: asyncio
            # primitives are not thread-safe and this is presumably called from the
            # caller's thread rather than the event loop thread; releasing it here
            # would also make the holder's own release() raise. A fresh lock is
            # likewise not tied to whichever event loop the old one waited on.
            self._cache_lock = CoroLock()

    async def get_client_connection(
        self,
        connect_params: dict,
    ) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) client connection.

        A cached client whose connection has been lost is evicted and a new
        connection is established in its place.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        Tuple[asyncssh.connection.SSHClientConnection, SshClient]
            A tuple of (SSHClientConnection, SshClient).
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            connection_id = SshClient.id_from_params(connect_params)
            cached = self._cache.get(connection_id)
            if cached is not None:
                _, client = cached
                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
                connection = await client.connection()
                if connection:
                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
                    # Return the local entry rather than re-reading the cache, which
                    # cleanup() may have swapped out concurrently.
                    return cached
                _LOG.debug(
                    "%s: Removing stale client connection %s from cache.",
                    current_thread().name,
                    connection_id,
                )
                self._cache.pop(connection_id, None)
                # Try to reconnect next.
            _LOG.debug(
                "%s: Establishing client connection to %s",
                current_thread().name,
                connection_id,
            )
            connection, client = await asyncssh.create_connection(SshClient, **connect_params)
            assert isinstance(client, SshClient)
            entry = (connection, client)
            self._cache[connection_id] = entry
            _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)
            return entry

    def cleanup(self) -> None:
        """Closes all cached connections and empties the cache."""
        # Detach the dict first so we never iterate a dict that the event loop
        # thread may be inserting into concurrently.
        cache, self._cache = self._cache, {}
        for connection, _ in cache.values():
            connection.close()

219 

220 

class SshService(Service, metaclass=ABCMeta):
    """Base class for SSH services."""

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain just one event loop thread for all SshService instances.
    # But only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    _REQUEST_TIMEOUT: Optional[float] = None  # seconds

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new SSH service.

        Parameters
        ----------
        config : dict
            Service config. SSH-related keys read here: ``ssh_port``,
            ``ssh_username``, ``ssh_priv_key_path``, ``ssh_request_timeout``,
            ``ssh_known_hosts_file``, ``ssh_keepalive_interval``.
        global_config : dict
            Passed through to Service.__init__.
        parent : Service
            Passed through to Service.__init__.
        methods : Union[Dict[str, Callable], List[Callable], None]
            Passed through to Service.__init__.

        Raises
        ------
        ValueError
            If ``ssh_known_hosts_file`` is a path that does not exist.
        """
        super().__init__(config, global_config, parent, methods)

        # Make sure that the value we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault("ssh_port", None)
        assert isinstance(self.config["ssh_port"], (int, type(None)))
        self.config.setdefault("ssh_username", None)
        assert isinstance(self.config["ssh_username"], (str, type(None)))
        self.config.setdefault("ssh_priv_key_path", None)
        assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))

        # None can be used to disable the request timeout.
        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
        self._request_timeout = nullable(float, self._request_timeout)

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            "request_pty": False,
            # By default disable known_hosts checking (since most VMs expected to be
            # dynamically created).
            "known_hosts": None,
        }

        if "ssh_known_hosts_file" in self.config:
            self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
            if isinstance(self._connect_params["known_hosts"], str):
                known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
                if not os.path.exists(known_hosts_file):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
                self._connect_params["known_hosts"] = known_hosts_file
        if self._connect_params["known_hosts"] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if "ssh_keepalive_interval" in self.config:
            keepalive_interval = self.config.get("ssh_keepalive_interval")
            self._connect_params["keepalive_interval"] = nullable(int, keepalive_interval)

    def _enter_context(self) -> "SshService":
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.

        Note: This may cause in flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        # Use the same shared context that _enter_context() entered.
        return SshService._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Per-host values in ``params`` take precedence over this service's config;
        falsy per-host values (e.g. None) fall back to the config.
        ``params`` itself is not modified.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.
            Must contain ``ssh_hostname``.

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.

        Raises
        ------
        KeyError
            If ``params`` has no ``ssh_hostname``.
        ValueError
            If the resolved ``ssh_priv_key_path`` does not exist.
        """
        # Note: None is an acceptable value for several of these, in which case
        # reasonable defaults or values from ~/.ssh/config will take effect.

        # Start with the base config params.
        connect_params = self._connect_params.copy()

        connect_params["host"] = params["ssh_hostname"]  # required

        # Read (rather than pop) the per-host overrides so the caller's dict stays
        # intact and can be reused (e.g., on retries). Params merged with this
        # service's config may carry a None default for these keys, so treat falsy
        # values as "not set" instead of stringifying None into "None".
        port = params.get("ssh_port") or self.config["ssh_port"]
        if port:
            connect_params["port"] = int(port)

        username = params.get("ssh_username") or self.config["ssh_username"]
        if username:
            connect_params["username"] = str(username)

        priv_key_file: Optional[str] = params.get(
            "ssh_priv_key_path",
            self.config["ssh_priv_key_path"],
        )
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params["client_keys"] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Parameters
        ----------
        params : dict
            Optional override connection parameters.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
            self._get_connect_params(params)
        )