Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py: 89%

57 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""
A collection of Service functions for managing hosts via SSH.
"""

8 

9from concurrent.futures import Future 

10from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 

11 

12import logging 

13 

14from asyncssh import SSHCompletedProcess, ConnectionLost, DisconnectError, ProcessError 

15 

16from mlos_bench.environments.status import Status 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.services.remote.ssh.ssh_service import SshService 

19from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec 

20from mlos_bench.services.types.os_ops_type import SupportsOSOps 

21from mlos_bench.util import merge_parameters 

22 

# Module-level logger for the SSH host service.
_LOG = logging.getLogger(__name__)

24 

25 

class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
    """
    Helper methods to manage machines via SSH.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new instance of an SSH Service.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        # Same methods are also provided by the AzureVMService class
        # pylint: disable=duplicate-code
        super().__init__(
            config, global_config, parent,
            self.merge_methods(methods, [
                self.shutdown,
                self.reboot,
                self.wait_os_operation,
                self.remote_exec,
                self.get_remote_exec_results,
            ]))
        # NOTE(review): `_shell` is not referenced anywhere in this class;
        # presumably it is consumed by a base class or subclass -- TODO confirm.
        self._shell = self.config.get("ssh_shell", "/bin/bash")

    async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess:
        """
        Runs a script asynchronously on a host via SSH.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
        script : Iterable[str]
            Lines of the script to run via the remote shell.
            A plain string is treated as a single (possibly multi-line) script.
        env_params : dict
            Variables to export into the script's shell environment
            (also passed as `env` to `connection.run()`).

        Returns
        -------
        SSHCompletedProcess
            Returns the result of the command.
        """
        # pylint: disable=import-outside-toplevel
        import shlex

        if isinstance(script, str):
            # Script should be an iterable of lines, not an iterable string.
            script = [script]
        connection, _ = await self._get_client_connection(params)
        # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values.
        # Handle transferring environment variables by making a script to set them.
        # shlex.quote() safely escapes values containing quotes or other shell
        # metacharacters; a bare '...' wrapper breaks on any value containing a "'".
        env_script_lines = [
            f"export {name}={shlex.quote(str(value))}"
            for (name, value) in env_params.items()
        ]
        script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()]
        # Note: connection.run() uses "exec" with a shell by default.
        script_str = '\n'.join(script_lines)
        _LOG.debug("Running script on %s:\n%s", connection, script_str)
        return await connection.run(script_str,
                                    check=False,
                                    timeout=self._request_timeout,
                                    env=env_params)

    def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]:
        """
        Start running a command on remote host OS.

        Parameters
        ----------
        script : Iterable[str]
            A list of lines to execute as a script on a remote VM.
        config : dict
            Flat dictionary of (key, value) pairs of parameters.
            They usually come from `const_args` and `tunable_params`
            properties of the Environment.
            Must contain (or inherit from the service config) "ssh_hostname".
        env_params : dict
            Parameters to pass as *shell* environment variables into the script.
            This is usually a subset of `config` with some possible conversions.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is always PENDING here; the result dict is the merged config
            plus an "asyncRemoteExecResultsFuture" entry to pass to
            `get_remote_exec_results`.
        """
        config = merge_parameters(
            dest=self.config.copy(),
            source=config,
            required_keys=[
                "ssh_hostname",
            ]
        )
        config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params))
        return (Status.PENDING, config)

    def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]:
        """
        Get the results of the asynchronously running command.

        Parameters
        ----------
        config : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key (as returned by
            `remote_exec`) to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is SUCCEEDED if the remote command exited with status 0,
            otherwise FAILED. Connection errors, process errors, and timeouts
            while waiting for the result are also reported as FAILED.

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing or empty.
        """
        # Before Python 3.11, concurrent.futures.TimeoutError (raised by
        # Future.result() on timeout) is NOT the builtin TimeoutError, so both
        # must be caught for the timeout path to work on all supported versions.
        # pylint: disable=import-outside-toplevel
        from concurrent.futures import TimeoutError as FuturesTimeoutError

        future = config.get("asyncRemoteExecResultsFuture")
        if not future:
            raise ValueError("Missing 'asyncRemoteExecResultsFuture'.")
        assert isinstance(future, Future)
        result = None
        try:
            result = future.result(timeout=self._request_timeout)
            assert isinstance(result, SSHCompletedProcess)
            stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
            stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
            return (
                Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED,
                {
                    "stdout": stdout,
                    "stderr": stderr,
                    "ssh_completed_process_result": result,
                },
            )
        except (ConnectionLost, DisconnectError, ProcessError, TimeoutError, FuturesTimeoutError) as ex:
            _LOG.error("Failed to get remote exec results: %s", ex)
            return (Status.FAILED, {"result": result})

    def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]:
        """
        Start a remote script that tries each command in turn until one succeeds.

        Each command is run via `/bin/bash -c`, prefixed with a non-interactive
        `sudo -n` when the remote user is not root and sudo is available.
        The script exits 0 on the first successful command, otherwise prints an
        error and exits 1.

        Parameters
        ----------
        cmd_opts_list : List[str]
            List of alternative commands to try, in order.
        params : dict
            The params used to connect to the host (must provide "ssh_hostname",
            possibly via the service config).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by `remote_exec`
            (Status is PENDING; poll with `wait_os_operation`).
        """
        # pylint: disable=import-outside-toplevel
        import shlex

        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ]
        )
        # Quote each command so it stays a single word in the shell `for` loop.
        cmd_opts = ' '.join(shlex.quote(cmd) for cmd in cmd_opts_list)
        script = rf"""
            if [[ $EUID -ne 0 ]]; then
                sudo=$(command -v sudo)
                sudo=${{sudo:+$sudo -n}}
            fi

            set -x
            for cmd in {cmd_opts}; do
                $sudo /bin/bash -c "$cmd" && exit 0
            done

            echo 'ERROR: Failed to shutdown/reboot the system.'
            exit 1
        """
        return self.remote_exec(script, config, env_params={})

    def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
        """
        Initiates a (graceful) shutdown of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            If True, force stop the Host/VM.
            NOTE: currently ignored; the same command list is used either way.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is PENDING; poll with `wait_os_operation`.
        """
        cmd_opts_list = [
            'shutdown -h now',
            'poweroff',
            'halt -p',
            'systemctl poweroff',
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
        """
        Initiates a (graceful) reboot of the Host/VM OS.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            If True, force restart the Host/VM. This only affects the last-resort
            fallback, which sends SIGKILL instead of SIGTERM to init and all processes.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is PENDING; poll with `wait_os_operation`.
        """
        cmd_opts_list = [
            'shutdown -r now',
            'reboot',
            'halt --reboot',
            'systemctl reboot',
            'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1',
        ]
        return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)

    def wait_os_operation(self, params: dict) -> Tuple[Status, dict]:
        """
        Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncRemoteExecResultsFuture" key to get the results.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by `get_remote_exec_results`.
            Status is SUCCEEDED or FAILED; timeouts are currently reported as FAILED.

        Raises
        ------
        ValueError
            If the "asyncRemoteExecResultsFuture" key is missing or empty.
        """
        return self.get_remote_exec_results(params)