Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py: 89%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection of Service functions for managing hosts via SSH.""" 

6 

import logging
import shlex
from concurrent.futures import Future
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from asyncssh import ConnectionLost, DisconnectError, ProcessError, SSHCompletedProcess

from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.remote.ssh.ssh_service import SshService
from mlos_bench.services.types.os_ops_type import SupportsOSOps
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.util import merge_parameters

19 

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

21 

22 

class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
    """Helper methods to manage machines via SSH."""

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of an SSH Service.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        # Same methods are also provided by the AzureVMService class
        # pylint: disable=duplicate-code
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    self.shutdown,
                    self.reboot,
                    self.wait_os_operation,
                    self.remote_exec,
                    self.get_remote_exec_results,
                ],
            ),
        )
        # Optional "ssh_shell" config entry; defaults to bash.
        self._shell = self.config.get("ssh_shell", "/bin/bash")

    async def _run_cmd(
        self,
        params: dict,
        script: Iterable[str],
        env_params: dict,
    ) -> SSHCompletedProcess:
        """
        Runs a command asynchronously on a host via SSH.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for
            establishing the connection).
        script : Iterable[str]
            Lines of the script to run via shell. A single string is also
            accepted and is treated as one (possibly multi-line) script.
        env_params : dict
            Environment variables to set for the script. They are exported
            at the top of the generated script and also passed to the SSH
            connection as its `env`.

        Returns
        -------
        SSHCompletedProcess
            Returns the result of the command.
        """
        if isinstance(script, str):
            # Script should be an iterable of lines, not an iterable string.
            script = [script]
        connection, _ = await self._get_client_connection(params)
        # Note: passing environment variables to SSH servers is typically restricted
        # to just some LC_* values.
        # Handle transferring environment variables by making a script to set them.
        # Values are shell-quoted so that embedded quotes, spaces, or shell
        # metacharacters cannot break (or inject into) the generated script.
        env_script_lines = [
            f"export {name}={shlex.quote(str(value))}" for (name, value) in env_params.items()
        ]
        script_lines = env_script_lines + [
            line_split for line in script for line_split in line.splitlines()
        ]
        # Note: connection.run() uses "exec" with a shell by default.
        script_str = "\n".join(script_lines)
        _LOG.debug("Running script on %s:\n%s", connection, script_str)
        return await connection.run(
            script_str,
            check=False,
            timeout=self._request_timeout,
            env=env_params,
        )

    def remote_exec(
        self,
        script: Iterable[str],
        config: dict,
        env_params: dict,
    ) -> Tuple["Status", dict]:
        """
        Start running a command on remote host OS.

        Parameters
        ----------
        script : Iterable[str]
            A list of lines to execute as a script on a remote VM.
        config : dict
            Flat dictionary of (key, value) pairs of parameters.
            They usually come from `const_args` and `tunable_params`
            properties of the Environment.
            Merged over a copy of the service config; "ssh_hostname" is required.
        env_params : dict
            Parameters to pass as *shell* environment variables into the script.
            This is usually a subset of `config` with some possible conversions.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is always PENDING here; the result is the merged config
            with the pending operation stored under the
            "asyncRemoteExecResultsFuture" key.
        """
        config = merge_parameters(
            dest=self.config.copy(),
            source=config,
            required_keys=[
                "ssh_hostname",
            ],
        )
        # Schedule the command and hand its handle back to the caller,
        # who polls it via get_remote_exec_results().
        config["asyncRemoteExecResultsFuture"] = self._run_coroutine(
            self._run_cmd(
                config,
                script,
                env_params,
            )
        )
        return (Status.PENDING, config)

    def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]:
        """
        Get the results of the asynchronously running command.

        Blocks for up to the service's request timeout waiting for the command.

        Parameters
        ----------
        config : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is SUCCEEDED if the command exited with code 0, else FAILED.
            On completion the result holds "stdout", "stderr" and
            "ssh_completed_process_result". On a connection error or a timeout
            the status is FAILED and the result is `{"result": None}`.

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing or empty.
        """
        future = config.get("asyncRemoteExecResultsFuture")
        if not future:
            raise ValueError("Missing 'asyncRemoteExecResultsFuture'.")
        assert isinstance(future, Future)
        result = None
        try:
            result = future.result(timeout=self._request_timeout)
            assert isinstance(result, SSHCompletedProcess)
            stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
            stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
            return (
                (
                    Status.SUCCEEDED
                    if result.exit_status == 0 and result.returncode == 0
                    else Status.FAILED
                ),
                {
                    "stdout": stdout,
                    "stderr": stderr,
                    "ssh_completed_process_result": result,
                },
            )
        except (
            ConnectionLost,
            DisconnectError,
            ProcessError,
            TimeoutError,
            # Before Python 3.11 Future.result() raises
            # concurrent.futures.TimeoutError, which is not the builtin
            # TimeoutError, so it has to be listed explicitly.
            FutureTimeoutError,
        ) as ex:
            _LOG.error("Failed to get remote exec results: %s", ex)
            return (Status.FAILED, {"result": result})

    def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]:
        """
        Run the first command from a list of alternatives that succeeds on the host.

        The commands are tried in order (via `sudo -n` when not running as root
        and sudo is available); the remote script exits 0 on the first success
        and exits 1 if none of them succeeds.

        Parameters
        ----------
        cmd_opts_list : List[str]
            List of commands to try to execute.
        params : dict
            The params used to connect to the host.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by `remote_exec`.
        """
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )
        # Quote each alternative so it stays a single word in the shell `for` loop.
        cmd_opts = " ".join(shlex.quote(cmd) for cmd in cmd_opts_list)
        # Note: this is an f-string, so literal shell braces must be doubled.
        script = rf"""
            if [[ $EUID -ne 0 ]]; then
                sudo=$(command -v sudo)
                sudo=${{sudo:+$sudo -n}}
            fi

            set -x
            for cmd in {cmd_opts}; do
                $sudo /bin/bash -c "$cmd" && exit 0
            done

            echo 'ERROR: Failed to shutdown/reboot the system.'
            exit 1
        """
        return self.remote_exec(script, config, env_params={})

    def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
        """
        Initiates a (graceful) shutdown of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            If True, force stop the Host/VM.
            NOTE: currently ignored by this implementation; the same commands
            are tried regardless of its value.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        # Alternatives tried in order until one succeeds.
        cmd_opts_list = [
            "shutdown -h now",
            "poweroff",
            "halt -p",
            "systemctl poweroff",
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
        """
        Initiates a (graceful) reboot of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            If True, force restart the Host/VM.
            Only affects the last-resort fallback, which sends KILL instead
            of TERM signals.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        # Alternatives tried in order until one succeeds.
        cmd_opts_list = [
            "shutdown -r now",
            "reboot",
            "halt --reboot",
            "systemctl reboot",
            "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1",
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def wait_os_operation(self, params: dict) -> Tuple[Status, dict]:
        """
        Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.

        This is a thin wrapper around `get_remote_exec_results`.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, exactly as returned by
            `get_remote_exec_results` (a timeout is reported as FAILED).

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing or empty.
        """
        return self.get_remote_exec_results(params)