Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py: 95%

44 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A collection of functions for interacting with SSH servers as file shares. 

7""" 

8 

9from enum import Enum 

10from typing import Tuple, Union 

11 

12import logging 

13 

14from asyncssh import scp, SFTPError, SFTPNoSuchFile, SFTPFailure, SSHClientConnection 

15 

16from mlos_bench.services.base_fileshare import FileShareService 

17from mlos_bench.services.remote.ssh.ssh_service import SshService 

18from mlos_bench.util import merge_parameters 

19 

_LOG: logging.Logger = logging.getLogger(__name__)


class CopyMode(Enum):
    """
    Direction of a file copy between the local machine and the SSH host.

    DOWNLOAD copies from the remote path to the local path;
    UPLOAD copies from the local path to the remote path.
    """

    DOWNLOAD = 1
    UPLOAD = 2

30 

31 

class SshFileShareService(FileShareService, SshService):
    """A collection of functions for interacting with SSH servers as file shares."""

    # SFTP status code SSH_FX_FAILURE: the protocol's generic "failure" status.
    _SFTP_FX_FAILURE = 4

    # Lower-cased substrings of an SFTPFailure reason that `download` treats as
    # "remote file is missing".
    # NOTE(review): presumably some servers report a missing file with the generic
    # failure code plus one of these messages instead of raising SFTPNoSuchFile
    # -- confirm against the servers in use.
    _NOT_FOUND_REASONS = ("file not found", "no such file or directory")

    def _merge_ssh_params(self, params: dict) -> dict:
        """
        Merge caller-supplied params into a copy of this service's config.

        Delegates to `merge_parameters` with "ssh_hostname" as a required key,
        so a missing hostname is reported by that helper. `self.config` itself
        is not modified because a copy is passed as `dest`.
        """
        return merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )

    @classmethod
    def _is_remote_file_not_found(cls, ex: Exception) -> bool:
        """
        Return True if `ex` indicates that the remote path does not exist.

        This is the case for SFTPNoSuchFile, or for an SFTPFailure carrying the
        generic failure code whose reason mentions a missing file
        (case-insensitive).
        """
        if isinstance(ex, SFTPNoSuchFile):
            return True
        if isinstance(ex, SFTPFailure) and ex.code == cls._SFTP_FX_FAILURE:
            # Guard against an empty/None reason so the check itself never raises.
            reason = (ex.reason or "").lower()
            return any(msg in reason for msg in cls._NOT_FOUND_REASONS)
        return False

    async def _start_file_copy(self, params: dict, mode: CopyMode,
                               local_path: str, remote_path: str,
                               recursive: bool = True) -> None:
        # pylint: disable=too-many-arguments
        """
        Run an scp copy between the local machine and the SSH host and wait for it to finish.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
        mode : CopyMode
            Whether to download or upload the file.
        local_path : str
            Local path to the file/dir.
        remote_path : str
            Remote path to the file/dir.
        recursive : bool, optional
            Whether to copy directories recursively (passed to scp as `recurse`), by default True.

        Raises
        ------
        ValueError
            If `mode` is not a known CopyMode.
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns an error.
        """
        connection, _ = await self._get_client_connection(params)
        # scp addresses a remote path as a (connection, path) tuple and a local path as a plain str.
        srcpaths: Union[str, Tuple[SSHClientConnection, str]]
        dstpath: Union[str, Tuple[SSHClientConnection, str]]
        if mode == CopyMode.DOWNLOAD:
            srcpaths = (connection, remote_path)
            dstpath = local_path
        elif mode == CopyMode.UPLOAD:
            srcpaths = local_path
            dstpath = (connection, remote_path)
        else:
            raise ValueError(f"Unknown copy mode: {mode}")
        # preserve=True: per the asyncssh.scp docs, keep the source file attributes.
        await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True)

    def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
        """
        Download `remote_path` from the SSH host to `local_path`, blocking until the copy completes.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters; merged over this service's config.
            "ssh_hostname" must be present after the merge.
        remote_path : str
            Remote path to the file/dir.
        local_path : str
            Local path to the file/dir.
        recursive : bool, optional
            Whether to copy directories recursively, by default True.

        Raises
        ------
        FileNotFoundError
            If the remote file does not exist (converted from the underlying SFTPError).
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote side returns any other error.
        """
        params = self._merge_ssh_params(params)
        super().download(params, remote_path, local_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive))
        try:
            # Block the caller until the copy coroutine finishes (or fails).
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            # NOTE(review): `params` includes everything merged from self.config;
            # confirm it never holds secrets before logging it verbatim.
            _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex)
            if self._is_remote_file_not_found(ex):
                _LOG.warning("File %s does not exist on %s", remote_path, params)
                raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex
            raise

    def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
        """
        Upload `local_path` to `remote_path` on the SSH host, blocking until the copy completes.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters; merged over this service's config.
            "ssh_hostname" must be present after the merge.
        local_path : str
            Local path to the file/dir.
        remote_path : str
            Remote path to the file/dir.
        recursive : bool, optional
            Whether to copy directories recursively, by default True.

        Raises
        ------
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote side returns an error.
        """
        params = self._merge_ssh_params(params)
        super().upload(params, local_path, remote_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive))
        try:
            # Block the caller until the copy coroutine finishes (or fails).
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            # NOTE(review): see download() regarding logging the full params dict.
            _LOG.error("Failed to upload %s to %s on %s: %s", local_path, remote_path, params, ex)
            raise