Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py: 89%

57 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

#
"""A collection of Service functions for managing hosts via SSH."""

6 

import logging
import shlex
from collections.abc import Callable, Iterable
from concurrent.futures import Future
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Any

from asyncssh import ConnectionLost, DisconnectError, ProcessError, SSHCompletedProcess

from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.remote.ssh.ssh_service import SshService
from mlos_bench.services.types.os_ops_type import SupportsOSOps
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.util import merge_parameters

20 

21_LOG = logging.getLogger(__name__) 

22 

23 

class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
    """Helper methods to manage machines via SSH."""

    # pylint: disable=too-many-ancestors
    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of an SSH Service.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. The optional "ssh_shell" key selects the shell
            (default: "/bin/bash").
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        # Same methods are also provided by the AzureVMService class
        # pylint: disable=duplicate-code
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    self.shutdown,
                    self.reboot,
                    self.wait_os_operation,
                    self.remote_exec,
                    self.get_remote_exec_results,
                ],
            ),
        )
        self._shell = self.config.get("ssh_shell", "/bin/bash")

    async def _run_cmd(
        self,
        params: dict,
        script: Iterable[str],
        env_params: dict,
    ) -> SSHCompletedProcess:
        """
        Runs a command asynchronously on a host via SSH.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for
            establishing the connection).
        script : Iterable[str]
            Lines of the script to run via shell. A single string is also
            accepted and is treated as one (possibly multi-line) script.
        env_params : dict
            Environment variables to make available to the script. They are
            prepended to the script as `export` statements.

        Returns
        -------
        SSHCompletedProcess
            Returns the result of the command.
        """
        if isinstance(script, str):
            # Script should be an iterable of lines, not an iterable string.
            script = [script]
        connection, _ = await self._get_client_connection(params)
        # Note: passing environment variables to SSH servers is typically restricted
        # to just some LC_* values.
        # Handle transferring environment variables by making a script to set them.
        # Values are shell-quoted so that embedded quotes, spaces or
        # metacharacters cannot break (or inject into) the generated script.
        env_script_lines = [
            f"export {name}={shlex.quote(str(value))}" for (name, value) in env_params.items()
        ]
        script_lines = env_script_lines + [
            line_split for line in script for line_split in line.splitlines()
        ]
        # Note: connection.run() uses "exec" with a shell by default.
        script_str = "\n".join(script_lines)
        _LOG.debug("Running script on %s:\n%s", connection, script_str)
        return await connection.run(
            script_str,
            check=False,
            timeout=self._request_timeout,
            env=env_params,
        )

    def remote_exec(
        self,
        script: Iterable[str],
        config: dict,
        env_params: dict,
    ) -> tuple[Status, dict]:
        """
        Start running a command on remote host OS.

        Parameters
        ----------
        script : Iterable[str]
            A list of lines to execute as a script on a remote VM.
        config : dict
            Flat dictionary of (key, value) pairs of parameters.
            They usually come from `const_args` and `tunable_params`
            properties of the Environment.
            Merged over a copy of the service config; "ssh_hostname" is required.
        env_params : dict
            Parameters to pass as *shell* environment variables into the script.
            This is usually a subset of `config` with some possible conversions.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is always PENDING here; the result is the merged config with
            the "asyncRemoteExecResultsFuture" key holding the pending operation.
        """
        config = merge_parameters(
            dest=self.config.copy(),
            source=config,
            required_keys=[
                "ssh_hostname",
            ],
        )
        # Schedule the command and hand the future back to the caller so that
        # get_remote_exec_results() can collect the outcome later.
        config["asyncRemoteExecResultsFuture"] = self._run_coroutine(
            self._run_cmd(
                config,
                script,
                env_params,
            )
        )
        return (Status.PENDING, config)

    def get_remote_exec_results(self, config: dict) -> tuple[Status, dict]:
        """
        Get the results of the asynchronously running command.

        Blocks for up to the service request timeout waiting for the command
        to complete.

        Parameters
        ----------
        config : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {SUCCEEDED, FAILED}.
            On completion the result holds "stdout", "stderr" and
            "ssh_completed_process_result". Connection errors and timeouts
            are reported as FAILED with a {"result": ...} dict.

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing or empty.
        """
        future = config.get("asyncRemoteExecResultsFuture")
        if not future:
            raise ValueError("Missing 'asyncRemoteExecResultsFuture'.")
        assert isinstance(future, Future)
        result = None
        try:
            result = future.result(timeout=self._request_timeout)
            assert isinstance(result, SSHCompletedProcess)
            stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
            stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
            return (
                (
                    Status.SUCCEEDED
                    if result.exit_status == 0 and result.returncode == 0
                    else Status.FAILED
                ),
                {
                    "stdout": stdout,
                    "stderr": stderr,
                    "ssh_completed_process_result": result,
                },
            )
        except (
            ConnectionLost,
            DisconnectError,
            ProcessError,
            TimeoutError,
            # Future.result() raises concurrent.futures.TimeoutError, which is
            # only an alias of the builtin TimeoutError since Python 3.11.
            FutureTimeoutError,
        ) as ex:
            _LOG.error("Failed to get remote exec results: %s", ex)
            return (Status.FAILED, {"result": result})

    def _exec_os_op(self, cmd_opts_list: list[str], params: dict) -> tuple[Status, dict]:
        """
        Try a list of alternative OS commands on the host until one succeeds.

        The commands are run via `sudo -n` when the remote user is not root
        and `sudo` is available.

        Parameters
        ----------
        cmd_opts_list : list[str]
            List of commands to try to execute.
        params : dict
            The params used to connect to the host.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by `remote_exec`.
        """
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )
        # Each alternative becomes a single shell word in the `for` loop below.
        cmd_opts = " ".join(shlex.quote(cmd) for cmd in cmd_opts_list)
        # Note: literal shell braces must be doubled inside the f-string.
        script = rf"""
            if [[ $EUID -ne 0 ]]; then
                sudo=$(command -v sudo)
                sudo=${{sudo:+$sudo -n}}
            fi

            set -x
            for cmd in {cmd_opts}; do
                $sudo /bin/bash -c "$cmd" && exit 0
            done

            echo 'ERROR: Failed to shutdown/reboot the system.'
            exit 1
        """
        return self.remote_exec(script, config, env_params={})

    def shutdown(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Initiates a (graceful) shutdown of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            If True, force stop the Host/VM.
            Note: currently ignored; the same commands are tried either way.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        cmd_opts_list = [
            "shutdown -h now",
            "poweroff",
            "halt -p",
            "systemctl poweroff",
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def reboot(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Initiates a (graceful) reboot of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            If True, force restart the Host/VM.
            Only affects the last-resort fallback, which sends KILL instead
            of TERM signals.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        cmd_opts_list = [
            "shutdown -r now",
            "reboot",
            "halt --reboot",
            "systemctl reboot",
            "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1",
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def wait_os_operation(self, params: dict) -> tuple[Status, dict]:
        """
        Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.

        This is a thin wrapper around `get_remote_exec_results`.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by
            `get_remote_exec_results`.

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing or empty.
        """
        return self.get_remote_exec_results(params)