Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py: 100%
84 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for mlos_bench.services.remote.ssh.ssh_services
7"""
9from contextlib import contextmanager
10from os.path import basename
11from pathlib import Path
12from tempfile import _TemporaryFileWrapper # pylint: disable=import-private-name
13from typing import Any, Dict, Generator, List
15import os
16import tempfile
18import pytest
20from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
21from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
22from mlos_bench.util import path_join
24from mlos_bench.tests import are_dir_trees_equal, requires_docker
25from mlos_bench.tests.services.remote.ssh import SshTestServerInfo
@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
    """
    Provides a context manager for a temporary file that can be closed and
    still unlinked.

    Since Windows doesn't allow us to reopen the file while it's still open we
    need to handle deletion ourselves separately.

    Parameters
    ----------
    kwargs: dict
        Args to pass to NamedTemporaryFile constructor.
        ``delete`` is always forced to False here, so it must not be passed.

    Yields
    ------
    The (open) temporary file object. The caller may close it, reopen it by
    name, or even remove it; the path is removed on exit if it still exists.
    """
    fname = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, **kwargs) as temp_file:
            fname = temp_file.name
            yield temp_file
    finally:
        if fname is not None:
            # Tolerate the file already being gone (e.g. the caller removed it)
            # so cleanup never masks the outcome of the code inside the `with`.
            Path(fname).unlink(missing_ok=True)
@requires_docker
def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo,
                                   ssh_fileshare_service: SshFileShareService) -> None:
    """Test the SshFileShareService single file download/upload."""
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()

        remote_path = "/tmp/test_ssh_fileshare_single_file"
        expected_lines = [f"{text}\n" for text in ("foo", "bar")]

        # Step 1: write a local file, then push it to the remote host.
        with closeable_temp_file(mode='w+t', encoding='utf-8') as upload_src:
            upload_src.writelines(expected_lines)
            upload_src.flush()
            # Close before handing the path to the service (needed on Windows).
            upload_src.close()
            ssh_fileshare_service.upload(
                params=ssh_config,
                local_path=upload_src.name,
                remote_path=remote_path,
            )

        # Step 2: pull the remote file back down and check its contents.
        with closeable_temp_file(mode='w+t', encoding='utf-8') as download_dst:
            download_dst.close()
            ssh_fileshare_service.download(
                params=ssh_config,
                remote_path=remote_path,
                local_path=download_dst.name,
            )
            # The download replaces the file at that name, so reopen it by path
            # instead of reusing the old handle.
            with open(download_dst.name, mode='r', encoding='utf-8') as reader:
                actual_lines = reader.readlines()
            assert actual_lines == expected_lines
@requires_docker
def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo,
                                 ssh_fileshare_service: SshFileShareService) -> None:
    """Test the SshFileShareService recursive download/upload."""
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()

        remote_dir = "/tmp/test_ssh_fileshare_recursive_dir"
        raw_contents: Dict[str, List[str]] = {
            "file-a.txt": ["a", "1"],
            "file-b.txt": ["b", "2"],
            "subdir/foo.txt": ["foo", "bar"],
        }
        # Every line is newline-terminated when written to disk.
        tree_contents: Dict[str, List[str]] = {
            rel_path: [f"{entry}\n" for entry in entries]
            for rel_path, entries in raw_contents.items()
        }

        with tempfile.TemporaryDirectory() as src_dir, tempfile.TemporaryDirectory() as dst_dir:
            # Materialize the local tree, creating subdirectories as needed.
            for rel_path, file_lines in tree_contents.items():
                local_file = Path(src_dir, rel_path)
                local_file.parent.mkdir(parents=True, exist_ok=True)
                with open(local_file, mode='w+t', encoding='utf-8') as writer:
                    writer.writelines(file_lines)
                    writer.flush()
                assert os.path.getsize(local_file) > 0

            # Push the whole local tree to the remote host ...
            ssh_fileshare_service.upload(
                params=ssh_config,
                local_path=src_dir,
                remote_path=remote_dir,
                recursive=True,
            )

            # ... then pull it back into a second, empty local directory.
            ssh_fileshare_service.download(
                params=ssh_config,
                remote_path=remote_dir,
                local_path=dst_dir,
                recursive=True,
            )

            # The remote directory's basename ends up nested under dst_dir.
            assert are_dir_trees_equal(src_dir, path_join(dst_dir, basename(remote_dir)))
@requires_docker
def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo,
                                         ssh_fileshare_service: SshFileShareService) -> None:
    """Test the SshFileShareService single file download that doesn't exist."""
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()

        sentinel = "canary"

        with closeable_temp_file(mode='w+t', encoding='utf-8') as local_file:
            # Seed the local target with known content so clobbering is detectable.
            local_file.write(sentinel)
            local_file.flush()
            local_file.close()

            with pytest.raises(FileNotFoundError):
                ssh_fileshare_service.download(
                    params=ssh_config,
                    remote_path="/tmp/file-dne.txt",
                    local_path=local_file.name,
                )
            # The failed download is expected to leave the local file as it was.
            with open(local_file.name, mode='r', encoding='utf-8') as reader:
                contents = reader.readlines()
            assert contents == [sentinel]
@requires_docker
def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo,
                                       ssh_host_service: SshHostService,
                                       ssh_fileshare_service: SshFileShareService) -> None:
    """Test the SshFileShareService single file upload that doesn't exist."""
    with ssh_host_service, ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()

        missing_path = '/tmp/upload-file-src-dne.txt'

        # Uploading a nonexistent local source should fail ...
        with pytest.raises(OSError):
            ssh_fileshare_service.upload(
                params=ssh_config,
                remote_path=missing_path,
                local_path=missing_path,
            )

        # ... and nothing should exist at the remote destination afterwards.
        # The script echoes the exit code of the "does not exist" check, so an
        # output of "0" means the remote path is absent.
        (_, pending) = ssh_host_service.remote_exec(
            script=[f"[[ ! -e {missing_path} ]]; echo $?"],
            config=ssh_config,
            env_params={},
        )
        (status, results) = ssh_host_service.get_remote_exec_results(pending)
        assert status.is_succeeded()
        assert str(results["stdout"]).strip() == "0"