Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py: 95%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection of functions for interacting with SSH servers as file shares."""

6 

7import logging 

8from enum import Enum 

9from typing import Tuple, Union 

10 

11from asyncssh import SFTPError, SFTPFailure, SFTPNoSuchFile, SSHClientConnection, scp 

12 

13from mlos_bench.services.base_fileshare import FileShareService 

14from mlos_bench.services.remote.ssh.ssh_service import SshService 

15from mlos_bench.util import merge_parameters 

16 

17_LOG = logging.getLogger(__name__) 

18 

19 

class CopyMode(Enum):
    """Direction of a file copy between the local host and the remote SSH host."""

    # Remote path is the source, local path is the destination.
    DOWNLOAD = 1
    # Local path is the source, remote path is the destination.
    UPLOAD = 2

25 

26 

class SshFileShareService(FileShareService, SshService):
    """A collection of functions for interacting with SSH servers as file shares."""

    async def _start_file_copy(
        self,
        params: dict,
        mode: CopyMode,
        local_path: str,
        remote_path: str,
        recursive: bool = True,
    ) -> None:
        # pylint: disable=too-many-arguments,too-many-positional-arguments
        """
        Starts a file copy operation.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for
            establishing the connection).
        mode : CopyMode
            Whether to download or upload the file.
        local_path : str
            Local path to the file/dir.
        remote_path : str
            Remote path to the file/dir.
        recursive : bool
            Whether to copy recursively. By default True.

        Raises
        ------
        ValueError
            If `mode` is neither CopyMode.DOWNLOAD nor CopyMode.UPLOAD.
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns an error. Note that this method does not
            translate the error; `download` is the one that converts a missing
            remote file into a FileNotFoundError.
        """
        connection, _ = await self._get_client_connection(params)
        srcpaths: Union[str, Tuple[SSHClientConnection, str]]
        dstpath: Union[str, Tuple[SSHClientConnection, str]]
        # A (connection, path) tuple designates a path on the remote host,
        # whereas a bare string designates a local path.
        if mode == CopyMode.DOWNLOAD:
            srcpaths = (connection, remote_path)
            dstpath = local_path
        elif mode == CopyMode.UPLOAD:
            srcpaths = local_path
            dstpath = (connection, remote_path)
        else:
            raise ValueError(f"Unknown copy mode: {mode}")
        return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True)

    @staticmethod
    def _is_remote_file_missing(ex: Union[OSError, SFTPError]) -> bool:
        """
        Checks whether a copy error indicates that the remote path does not exist.

        Parameters
        ----------
        ex : Union[OSError, SFTPError]
            The exception raised by the file copy operation.

        Returns
        -------
        bool
            True if the error is an SFTPNoSuchFile, or an SFTPFailure with
            code 4 whose reason mentions a missing file; False otherwise.
        """
        if isinstance(ex, SFTPNoSuchFile):
            return True
        # Code 4 is the generic SFTP failure status, so the reason text has to be
        # inspected as well (presumably some servers report a missing file this
        # way rather than with the dedicated status -- confirm against targets).
        if not isinstance(ex, SFTPFailure) or ex.code != 4:
            return False
        reason = ex.reason.lower()
        return any(msg in reason for msg in ("file not found", "no such file or directory"))

    def download(
        self,
        params: dict,
        remote_path: str,
        local_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Downloads a file/dir from the remote SSH host and waits for completion.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters. They are
            combined with a copy of the service config via `merge_parameters`;
            "ssh_hostname" is a required key.
        remote_path : str
            Remote path to the file/dir to download.
        local_path : str
            Local path to store the file/dir at.
        recursive : bool
            Whether to copy recursively. By default True.

        Raises
        ------
        FileNotFoundError
            If the error reported by the copy indicates that the remote path
            does not exist (chained from the original SFTP error).
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns any other error.
        """
        params = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )
        super().download(params, remote_path, local_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(
                params,
                CopyMode.DOWNLOAD,
                local_path,
                remote_path,
                recursive,
            )
        )
        try:
            # Block until the copy coroutine finishes (or fails).
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            _LOG.error(
                "Failed to download %s to %s from %s: %s",
                remote_path,
                local_path,
                params,
                ex,
            )
            if self._is_remote_file_missing(ex):
                _LOG.warning("File %s does not exist on %s", remote_path, params)
                raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex
            raise

    def upload(
        self,
        params: dict,
        local_path: str,
        remote_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Uploads a local file/dir to the remote SSH host and waits for completion.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters. They are
            combined with a copy of the service config via `merge_parameters`;
            "ssh_hostname" is a required key.
        local_path : str
            Local path to the file/dir to upload.
        remote_path : str
            Remote path to store the file/dir at.
        recursive : bool
            Whether to copy recursively. By default True.

        Raises
        ------
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns an error.
        """
        params = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )
        super().upload(params, local_path, remote_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)
        )
        try:
            # Block until the copy coroutine finishes (or fails).
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            _LOG.error("Failed to upload %s to %s on %s: %s", local_path, remote_path, params, ex)
            raise