Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py: 98%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for mlos_bench.services.remote.ssh.ssh_host_service.""" 

6 

7import logging 

8import time 

9from subprocess import CalledProcessError, run 

10 

11from pytest_docker.plugin import Services as DockerServices 

12 

13from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService 

14from mlos_bench.services.remote.ssh.ssh_service import SshClient 

15from mlos_bench.tests import requires_docker 

16from mlos_bench.tests.services.remote.ssh import ( 

17 ALT_TEST_SERVER_NAME, 

18 REBOOT_TEST_SERVER_NAME, 

19 SSH_TEST_SERVER_NAME, 

20 SshTestServerInfo, 

21 wait_docker_service_socket, 

22) 

23 

24_LOG = logging.getLogger(__name__) 

25 

26 

@requires_docker
def test_ssh_service_remote_exec(
    ssh_test_server: SshTestServerInfo,
    alt_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
) -> None:
    """
    Exercise SshHostService.remote_exec against the test SSH servers.

    Several commands are run back to back on two different servers, and the
    service's private client cache is inspected between calls to confirm that
    connections are created, reused, and re-established after being closed.
    """
    # pylint: disable=protected-access
    with ssh_host_service:
        primary_config = ssh_test_server.to_ssh_service_config()
        cache_key = SshClient.id_from_params(ssh_test_server.to_connect_params())

        # Nothing should be cached for the primary server before the first call.
        assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None
        assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(cache_key) is None

        # First command: submitted asynchronously, then awaited.
        exec_status, exec_info = ssh_host_service.remote_exec(
            script=["hostname"],
            config=primary_config,
            env_params={},
        )
        assert exec_status.is_pending()
        assert "asyncRemoteExecResultsFuture" in exec_info
        exec_status, exec_output = ssh_host_service.get_remote_exec_results(exec_info)
        assert exec_status.is_succeeded()
        assert exec_output["stdout"].strip() == SSH_TEST_SERVER_NAME

        # The first call should have populated the cache with a live connection.
        cached_conn, cached_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[
            cache_key
        ]
        assert cached_conn is not None
        assert cached_conn._username == ssh_test_server.username
        assert cached_conn._host == ssh_test_server.hostname
        assert cached_conn._port == ssh_test_server.get_port()
        # Remember the local port so later steps can tell reuse from reconnect.
        first_local_port = cached_conn._local_port
        assert first_local_port
        assert cached_client is not None
        assert cached_client._conn_event.is_set()

        # Second command goes to the alternate server.
        exec_status, exec_info = ssh_host_service.remote_exec(
            script=["hostname"],
            config=alt_test_server.to_ssh_service_config(),
            env_params={
                # unused, making sure it doesn't carry over with cached connections
                "UNUSED": "unused",
            },
        )
        assert exec_status.is_pending()
        assert "asyncRemoteExecResultsFuture" in exec_info
        exec_status, exec_output = ssh_host_service.get_remote_exec_results(exec_info)
        assert exec_status.is_succeeded()
        assert exec_output["stdout"].strip() == ALT_TEST_SERVER_NAME

        # Third command returns to the primary server, which should reuse the
        # cached connection; it also passes an environment variable through.
        exec_status, exec_info = ssh_host_service.remote_exec(
            script=["echo BAR=$BAR && echo UNUSED=$UNUSED && false"],
            config=primary_config,
            env_params={
                "BAR": "bar",
            },
        )
        exec_status, exec_output = ssh_host_service.get_remote_exec_results(exec_info)
        # The trailing "false" makes the script exit non-zero.
        assert exec_status.is_failed()
        assert str(exec_output["stdout"]).splitlines() == ["BAR=bar", "UNUSED="]
        cached_conn, _ = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[cache_key]
        # Same local port means the same underlying connection was reused.
        assert cached_conn._local_port == first_local_port

        # Gracefully close the cached connection behind the service's back.
        cached_conn.close()

        # Fourth command: the service should notice the closed connection and
        # open a fresh one.
        exec_status, exec_info = ssh_host_service.remote_exec(
            script=[
                # One script entry per list item ...
                "echo FOO=$FOO\n",
                # ... and one entry that itself spans multiple lines.
                "echo BAR=$BAR\necho BAZ=$BAZ",
            ],
            config=primary_config,
            env_params={
                "FOO": "foo",
            },
        )
        exec_status, exec_output = ssh_host_service.get_remote_exec_results(exec_info)
        assert exec_status.is_succeeded()
        # Only FOO was passed this time; BAR from the earlier call must not leak.
        assert str(exec_output["stdout"]).splitlines() == ["FOO=foo", "BAR=", "BAZ="]
        cached_conn, _ = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[cache_key]
        # A different local port indicates a new connection was made.
        assert cached_conn._local_port != first_local_port

    # Leaving the service context should have emptied the shared client cache.
    assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0

144 

145 

def check_ssh_service_reboot(
    docker_services: DockerServices,
    reboot_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
    graceful: bool,
) -> None:
    """
    Check the SshHostService reboot operation.

    Starts a long-running remote command, reboots the server underneath it,
    verifies the command is reported as failed, then waits for the server to
    come back (on a new port) and runs a command against it.

    Parameters
    ----------
    docker_services : DockerServices
        Passed to ``wait_docker_service_socket`` to wait for the restarted
        server's SSH port.
    reboot_test_server : SshTestServerInfo
        Description of the SSH test server that gets rebooted.
    ssh_host_service : SshHostService
        The service under test; entered as a context manager here.
    graceful : bool
        When False the reboot is requested with ``force=True``.
    """
    # Note: rebooting changes the port number unfortunately, but makes it
    # easier to check for success.
    # Also, it may cause issues with other parallel unit tests, so we run it as
    # a part of the same unit test for now.
    with ssh_host_service:
        reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True)
        (status, results_info) = ssh_host_service.remote_exec(
            script=['echo "sleeping..."', "sleep 30", 'echo "should not reach this point"'],
            config=reboot_test_srv_ssh_svc_conf,
            env_params={},
        )
        assert status.is_pending()
        # Wait a moment for that to start in the background thread.
        time.sleep(1)

        # Now try to restart the server.
        (status, reboot_results_info) = ssh_host_service.reboot(
            params=reboot_test_srv_ssh_svc_conf,
            force=not graceful,
        )
        assert status.is_pending()

        (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info)
        # NOTE: reboot/shutdown ops mostly return FAILED, even though the reboot succeeds.
        _LOG.debug("reboot status: %s: %s", status, reboot_results_info)

        # Check for decent error handling on disconnects.
        status, results = ssh_host_service.get_remote_exec_results(results_info)
        assert status.is_failed()
        stdout = str(results["stdout"])
        assert "sleeping" in stdout
        assert "should not reach this point" not in stdout

        reboot_test_srv_ssh_svc_conf_new: dict = {}
        for _ in range(3):
            # Give docker some time to restart the service after the "reboot".
            # Note: this relies on having a `restart_policy` in the docker-compose.yml file.
            time.sleep(1)
            # try to reconnect and see if the port changed
            try:
                # Diagnostic output only: with check=False this call does not
                # raise CalledProcessError on a non-zero exit status.
                run_res = run(
                    "docker ps | grep mlos_bench-test- | grep reboot",
                    shell=True,
                    capture_output=True,
                    check=False,
                )
                print(run_res.stdout.decode())
                print(run_res.stderr.decode())
                # NOTE(review): presumably the source of the CalledProcessError
                # handled below while the container is restarting - confirm.
                reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(
                    uncached=True
                )
                if (
                    reboot_test_srv_ssh_svc_conf_new["ssh_port"]
                    != reboot_test_srv_ssh_svc_conf["ssh_port"]
                ):
                    break
            except CalledProcessError as ex:
                _LOG.info("Failed to check port for reboot test server: %s", ex)
        # If every attempt above raised, the new config is still empty; fail
        # with a clear message instead of a KeyError on the lookup below.
        assert (
            "ssh_port" in reboot_test_srv_ssh_svc_conf_new
        ), "Failed to look up the ssh port of the reboot test server after the reboot."
        assert (
            reboot_test_srv_ssh_svc_conf_new["ssh_port"]
            != reboot_test_srv_ssh_svc_conf["ssh_port"]
        )

        wait_docker_service_socket(
            docker_services,
            reboot_test_server.hostname,
            reboot_test_srv_ssh_svc_conf_new["ssh_port"],
        )

        # The rebooted server should accept commands again on its new port.
        (status, results_info) = ssh_host_service.remote_exec(
            script=["hostname"],
            config=reboot_test_srv_ssh_svc_conf_new,
            env_params={},
        )
        status, results = ssh_host_service.get_remote_exec_results(results_info)
        assert status.is_succeeded()
        assert results["stdout"].strip() == REBOOT_TEST_SERVER_NAME

230 

231 

@requires_docker
def test_ssh_service_reboot(
    locked_docker_services: DockerServices,
    reboot_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
) -> None:
    """Run the SshHostService reboot check gracefully first, then forced."""
    # Both variants share one test to avoid parallel runner interactions.
    for graceful in (True, False):
        check_ssh_service_reboot(
            locked_docker_services,
            reboot_test_server,
            ssh_host_service,
            graceful=graceful,
        )