Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py: 95%
44 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6A collection of functions for interacting with SSH servers as file shares.
7"""
9from enum import Enum
10from typing import Tuple, Union
12import logging
14from asyncssh import scp, SFTPError, SFTPNoSuchFile, SFTPFailure, SSHClientConnection
16from mlos_bench.services.base_fileshare import FileShareService
17from mlos_bench.services.remote.ssh.ssh_service import SshService
18from mlos_bench.util import merge_parameters
# Logger for this module; named after the module via __name__.
_LOG: logging.Logger = logging.getLogger(__name__)
class CopyMode(Enum):
    """
    Direction of an SSH file copy operation.
    """

    # Copy from the remote host to the local machine.
    DOWNLOAD = 1
    # Copy from the local machine to the remote host.
    UPLOAD = 2
class SshFileShareService(FileShareService, SshService):
    """A collection of functions for interacting with SSH servers as file shares."""

    async def _start_file_copy(self, params: dict, mode: CopyMode,
                               local_path: str, remote_path: str,
                               recursive: bool = True) -> None:
        # pylint: disable=too-many-arguments
        """
        Starts a file copy operation between the local host and the remote host using SCP.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
        mode : CopyMode
            Whether to download or upload the file.
        local_path : str
            Local path to the file/dir.
        remote_path : str
            Remote path to the file/dir.
        recursive : bool, optional
            Whether to copy directories recursively, by default True.

        Raises
        ------
        ValueError
            If `mode` is not a known CopyMode.
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns an error.
            (Callers such as `download` may translate "not found" errors into FileNotFoundError.)
        """
        # Validate the mode before acquiring a connection so a bad argument
        # does not cost an SSH round trip.
        if mode not in (CopyMode.DOWNLOAD, CopyMode.UPLOAD):
            raise ValueError(f"Unknown copy mode: {mode}")
        connection, _ = await self._get_client_connection(params)
        srcpaths: Union[str, Tuple[SSHClientConnection, str]]
        dstpath: Union[str, Tuple[SSHClientConnection, str]]
        if mode == CopyMode.DOWNLOAD:
            srcpaths = (connection, remote_path)
            dstpath = local_path
        else:
            srcpaths = local_path
            dstpath = (connection, remote_path)
        await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True)

    def _merge_ssh_params(self, params: dict) -> dict:
        """
        Merge `params` with a copy of the service config via `merge_parameters`,
        requiring the "ssh_hostname" key to be present in the result.
        """
        return merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )

    @staticmethod
    def _is_remote_not_found(ex: Exception) -> bool:
        """
        Return True if `ex` is an SFTP error indicating that the remote path does not exist.
        """
        if isinstance(ex, SFTPNoSuchFile):
            return True
        # A generic failure (code 4) is also treated as "not found" when its reason
        # text says so -- presumably because some servers report missing files that way.
        return (
            isinstance(ex, SFTPFailure) and ex.code == 4
            and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory"))
        )

    def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
        """
        Download a file/dir from the remote host to the local host.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters; merged with the service
            config and must yield an "ssh_hostname".
        remote_path : str
            Remote path to the file/dir to download.
        local_path : str
            Local destination path.
        recursive : bool, optional
            Whether to copy directories recursively, by default True.

        Raises
        ------
        FileNotFoundError
            If the remote file/dir does not exist.
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns any other error.
        """
        params = self._merge_ssh_params(params)
        super().download(params, remote_path, local_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive))
        try:
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex)
            if self._is_remote_not_found(ex):
                _LOG.warning("File %s does not exist on %s", remote_path, params)
                raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex
            # Bare raise re-raises the original exception with its traceback intact.
            raise

    def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
        """
        Upload a local file/dir to the remote host.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters; merged with the service
            config and must yield an "ssh_hostname".
        local_path : str
            Local path to the file/dir to upload.
        remote_path : str
            Remote destination path.
        recursive : bool, optional
            Whether to copy directories recursively, by default True.

        Raises
        ------
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns an error.
        """
        params = self._merge_ssh_params(params)
        super().upload(params, local_path, remote_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive))
        try:
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            _LOG.error("Failed to upload %s to %s on %s: %s", local_path, remote_path, params, ex)
            raise