Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py: 98%

92 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for mlos_bench.services.remote.ssh.ssh_host_service.""" 

6 

7import logging 

8import time 

9from subprocess import CalledProcessError, run 

10 

11from pytest_docker.plugin import Services as DockerServices 

12 

13from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService 

14from mlos_bench.services.remote.ssh.ssh_service import SshClient 

15from mlos_bench.tests import requires_docker, wait_docker_service_socket 

16from mlos_bench.tests.services.remote.ssh import ( 

17 ALT_TEST_SERVER_NAME, 

18 REBOOT_TEST_SERVER_NAME, 

19 SSH_TEST_SERVER_NAME, 

20 SshTestServerInfo, 

21) 

22 

23_LOG = logging.getLogger(__name__) 

24 

25 

@requires_docker
def test_ssh_service_remote_exec(
    ssh_test_server: SshTestServerInfo,
    alt_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
) -> None:
    """
    Exercise SshHostService.remote_exec against the SSH test servers.

    Commands are issued one after another against two servers, with the cached connection
    being inspected, reused, closed, and re-established along the way, so that the
    internal client cache bookkeeping is verified together with the command results.
    """
    # pylint: disable=protected-access
    with ssh_host_service:
        primary_config = ssh_test_server.to_ssh_service_config()
        primary_conn_id = SshClient.id_from_params(ssh_test_server.to_connect_params())

        # Nothing has been executed yet, so no client should be cached for this server.
        assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None
        assert (
            ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(primary_conn_id)
            is None
        )

        # First command: submitted asynchronously, then awaited.
        exec_status, pending_info = ssh_host_service.remote_exec(
            script=["hostname"],
            config=primary_config,
            env_params={},
        )
        assert exec_status.is_pending()
        assert "asyncRemoteExecResultsFuture" in pending_info
        exec_status, exec_results = ssh_host_service.get_remote_exec_results(pending_info)
        assert exec_status.is_succeeded()
        assert exec_results["stdout"].strip() == SSH_TEST_SERVER_NAME

        # The cache should now hold a live (connection, client) pair for this server.
        cached_conn, cached_client = (
            ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[primary_conn_id]
        )
        assert cached_conn is not None
        assert cached_conn._username == ssh_test_server.username
        assert cached_conn._host == ssh_test_server.hostname
        assert cached_conn._port == ssh_test_server.get_port()
        # Remembered so that later steps can tell a reused connection from a new one.
        first_local_port = cached_conn._local_port
        assert first_local_port
        assert cached_client is not None
        assert cached_client._conn_event.is_set()

        # Second command: goes to the alternate server.
        exec_status, pending_info = ssh_host_service.remote_exec(
            script=["hostname"],
            config=alt_test_server.to_ssh_service_config(),
            env_params={
                # Never referenced by the script; the next step checks that it does
                # not leak into commands run over a cached connection.
                "UNUSED": "unused",
            },
        )
        assert exec_status.is_pending()
        assert "asyncRemoteExecResultsFuture" in pending_info
        exec_status, exec_results = ssh_host_service.get_remote_exec_results(pending_info)
        assert exec_status.is_succeeded()
        assert exec_results["stdout"].strip() == ALT_TEST_SERVER_NAME

        # Third command: back on the primary server, which should reuse its connection.
        # Environment variable handling is checked here as well.
        exec_status, pending_info = ssh_host_service.remote_exec(
            script=["echo BAR=$BAR && echo UNUSED=$UNUSED && false"],
            config=primary_config,
            env_params={
                "BAR": "bar",
            },
        )
        exec_status, exec_results = ssh_host_service.get_remote_exec_results(pending_info)
        # The trailing "false" makes the script exit non-zero.
        assert exec_status.is_failed()
        assert str(exec_results["stdout"]).splitlines() == [
            "BAR=bar",
            "UNUSED=",
        ]
        cached_conn, cached_client = (
            ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[primary_conn_id]
        )
        # Same local port as before => the cached connection was reused.
        assert cached_conn._local_port == first_local_port

        # Gracefully drop the connection behind the cache's back.
        cached_conn.close()

        # Fourth command: the closed connection should be detected and replaced.
        exec_status, pending_info = ssh_host_service.remote_exec(
            script=[
                # A script made of several strings ...
                "echo FOO=$FOO\n",
                # ... one of which spans several lines.
                "echo BAR=$BAR\necho BAZ=$BAZ",
            ],
            config=primary_config,
            # Only FOO is provided this time; BAR from the previous call must be gone.
            env_params={
                "FOO": "foo",
            },
        )
        exec_status, exec_results = ssh_host_service.get_remote_exec_results(pending_info)
        assert exec_status.is_succeeded()
        output_lines = str(exec_results["stdout"]).splitlines()
        assert output_lines == [
            "FOO=foo",
            "BAR=",
            "BAZ=",
        ]
        # A different local port indicates that a fresh connection was made.
        cached_conn, cached_client = (
            ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[primary_conn_id]
        )
        assert cached_conn._local_port != first_local_port

    # Leaving the service context should have emptied the client cache.
    assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0

143 

144 

def check_ssh_service_reboot(
    docker_services: DockerServices,
    reboot_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
    graceful: bool,
) -> None:
    """
    Check the SshHostService reboot operation.

    A long-running command is started on the reboot test server, the server is
    rebooted while that command is still running, and the interrupted command is
    expected to be reported as failed. The test then polls until the server's SSH
    port changes, waits for the new port to accept connections, and runs a final
    command against it.

    Parameters
    ----------
    docker_services : DockerServices
        Handle passed through to ``wait_docker_service_socket``.
    reboot_test_server : SshTestServerInfo
        Connection info for the server that gets rebooted.
    ssh_host_service : SshHostService
        The service under test; entered as a context manager for the whole check.
    graceful : bool
        Whether to request a graceful reboot; ``reboot`` is called with
        ``force=not graceful``.
    """
    # Note: rebooting changes the port number unfortunately, but makes it
    # easier to check for success.
    # Also, it may cause issues with other parallel unit tests, so we run it as
    # a part of the same unit test for now.
    with ssh_host_service:
        reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True)
        (status, results_info) = ssh_host_service.remote_exec(
            script=['echo "sleeping..."', "sleep 30", 'echo "should not reach this point"'],
            config=reboot_test_srv_ssh_svc_conf,
            env_params={},
        )
        assert status.is_pending()
        # Wait a moment for that to start in the background thread.
        time.sleep(1)

        # Now try to restart the server.
        (status, reboot_results_info) = ssh_host_service.reboot(
            params=reboot_test_srv_ssh_svc_conf,
            force=not graceful,
        )
        assert status.is_pending()

        (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info)
        # NOTE: reboot/shutdown ops mostly return FAILED, even though the reboot succeeds,
        # so the status is only logged here rather than asserted on.
        _LOG.debug("reboot status: %s: %s", status, reboot_results_info)

        # Check for decent error handling on disconnects: the sleeping command
        # should have been cut off before reaching its last echo.
        status, results = ssh_host_service.get_remote_exec_results(results_info)
        assert status.is_failed()
        stdout = str(results["stdout"])
        assert "sleeping" in stdout
        assert "should not reach this point" not in stdout

        original_port = reboot_test_srv_ssh_svc_conf["ssh_port"]
        reboot_test_srv_ssh_svc_conf_new: dict = {}
        for _ in range(3):
            # Give docker some time to restart the service after the "reboot".
            # Note: this relies on having a `restart_policy` in the docker-compose.yml file.
            time.sleep(1)
            # try to reconnect and see if the port changed
            try:
                # Diagnostic output only: check=False means this call itself
                # does not raise on a non-zero exit status.
                run_res = run(
                    "docker ps | grep mlos_bench-test- | grep reboot",
                    shell=True,
                    capture_output=True,
                    check=False,
                )
                print(run_res.stdout.decode())
                print(run_res.stderr.decode())
                # NOTE(review): presumably the uncached port lookup is what can
                # raise CalledProcessError while the container restarts - confirm.
                reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(
                    uncached=True
                )
                if reboot_test_srv_ssh_svc_conf_new["ssh_port"] != original_port:
                    break
            except CalledProcessError as ex:
                _LOG.info("Failed to check port for reboot test server: %s", ex)

        # If every attempt above raised, the new config is still the empty dict.
        # Fail with a clear assertion in that case instead of a KeyError.
        new_port = reboot_test_srv_ssh_svc_conf_new.get("ssh_port")
        assert (
            new_port is not None
        ), "Could not determine the reboot test server's SSH port after the reboot."
        assert (
            new_port != original_port
        ), f"SSH port did not change after reboot (still {original_port})."

        wait_docker_service_socket(
            docker_services,
            reboot_test_server.hostname,
            reboot_test_srv_ssh_svc_conf_new["ssh_port"],
        )

        # The server should be reachable again on its new port.
        (status, results_info) = ssh_host_service.remote_exec(
            script=["hostname"],
            config=reboot_test_srv_ssh_svc_conf_new,
            env_params={},
        )
        status, results = ssh_host_service.get_remote_exec_results(results_info)
        assert status.is_succeeded()
        assert results["stdout"].strip() == REBOOT_TEST_SERVER_NAME

229 

230 

@requires_docker
def test_ssh_service_reboot(
    locked_docker_services: DockerServices,
    reboot_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
) -> None:
    """Run the SshHostService reboot check twice: graceful first, then non-graceful."""
    # Both variants live in a single test to avoid parallel runner interactions.
    for graceful in (True, False):
        check_ssh_service_reboot(
            locked_docker_services,
            reboot_test_server,
            ssh_host_service,
            graceful=graceful,
        )