Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py: 100%
83 statements
coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Tests for mlos_bench.services.remote.ssh.ssh_fileshare (SshFileShareService upload/download)."""

7import os
8import tempfile
9from contextlib import contextmanager
10from os.path import basename
11from pathlib import Path
12from tempfile import _TemporaryFileWrapper # pylint: disable=import-private-name
13from typing import Any, Dict, Generator, List
15import pytest
17from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
18from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
19from mlos_bench.tests import are_dir_trees_equal, requires_docker
20from mlos_bench.tests.services.remote.ssh import SshTestServerInfo
21from mlos_bench.util import path_join
@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
    """
    Provides a context manager for a temporary file that can be closed early and
    is still removed from disk afterwards.

    Since Windows doesn't allow us to reopen the file while it's still open, the
    file is created with ``delete=False`` and removed by this context manager on
    exit instead.

    Parameters
    ----------
    kwargs : dict
        Args to pass to the ``tempfile.NamedTemporaryFile`` constructor.
        ``delete`` must not be passed; it is always set to ``False`` here.

    Yields
    ------
    _TemporaryFileWrapper
        The open temporary file. Its ``name`` remains usable after ``close()``
        until the context exits.

    Notes
    -----
    Cleanup tolerates the file already being gone (e.g. the caller deleted it),
    so removal never raises and never masks an exception from the ``with`` body.
    """
    fname = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, **kwargs) as temp_file:
            fname = temp_file.name
            yield temp_file
    finally:
        if fname is not None:
            # missing_ok: the caller may already have removed the file; a
            # FileNotFoundError here would otherwise replace any in-flight exception.
            Path(fname).unlink(missing_ok=True)


@requires_docker
def test_ssh_fileshare_single_file(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """
    Test the SshFileShareService single file upload/download round trip.

    Writes two lines to a local temp file, uploads it to the SSH test server,
    downloads it back into a second local temp file and compares the contents.
    """
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()
        remote_path = "/tmp/test_ssh_fileshare_single_file"
        expected = [f"{word}\n" for word in ("foo", "bar")]

        # Step 1: write the expected lines locally, close the handle, and upload.
        with closeable_temp_file(mode="w+t", encoding="utf-8") as src:
            src.writelines(expected)
            src.flush()
            src.close()
            ssh_fileshare_service.upload(
                params=ssh_config,
                local_path=src.name,
                remote_path=remote_path,
            )

        # Step 2: download into a second (already closed) temp file and compare.
        with closeable_temp_file(mode="w+t", encoding="utf-8") as dst:
            dst.close()
            ssh_fileshare_service.download(
                params=ssh_config,
                remote_path=remote_path,
                local_path=dst.name,
            )
            # Download will replace the inode at that name, so read the result
            # through a freshly opened handle rather than the stale one.
            with open(dst.name, mode="r", encoding="utf-8") as reader:
                actual = reader.readlines()
            assert actual == expected


@requires_docker
def test_ssh_fileshare_recursive(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """
    Test the SshFileShareService recursive upload/download.

    Builds a small local directory tree (including a subdirectory), uploads it
    recursively to the SSH test server, downloads it back into a second local
    directory, and checks that both trees are identical.
    """
    with ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        remote_file_path = "/tmp/test_ssh_fileshare_recursive_dir"
        files_lines: Dict[str, List[str]] = {
            "file-a.txt": ["a", "1"],
            "file-b.txt": ["b", "2"],
            "subdir/foo.txt": ["foo", "bar"],
        }
        # Newline-terminate each entry so writelines() emits one line per item.
        files_lines = {
            path: [line + "\n" for line in lines] for path, lines in files_lines.items()
        }

        with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2:
            # Set up the local directory structure.
            for file_path, lines in files_lines.items():
                path = Path(tempdir1, file_path)
                path.parent.mkdir(parents=True, exist_ok=True)
                with open(path, mode="w+t", encoding="utf-8") as file_h:
                    file_h.writelines(lines)
                # Checked after the handle is closed so all written data is on disk.
                assert path.stat().st_size > 0

            # Copy that structure over to the remote server.
            ssh_fileshare_service.upload(
                params=config,
                local_path=tempdir1,
                remote_path=remote_file_path,
                recursive=True,
            )

            # Copy the remote structure back to the local machine.
            ssh_fileshare_service.download(
                params=config,
                remote_path=remote_file_path,
                local_path=tempdir2,
                recursive=True,
            )

            # Compare both.
            # Note: the remote dir name is appended to the download target.
            assert are_dir_trees_equal(tempdir1, path_join(tempdir2, basename(remote_file_path)))


@requires_docker
def test_ssh_fileshare_download_file_dne(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """
    Test the SshFileShareService single file download that doesn't exist.

    The download must raise FileNotFoundError and leave the pre-existing local
    target file's contents unchanged.
    """
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()
        canary = "canary"

        with closeable_temp_file(mode="w+t", encoding="utf-8") as local_file:
            # Seed the local target with known content, then close it.
            local_file.write(canary)
            local_file.flush()
            local_file.close()

            # Downloading a missing remote path must fail ...
            with pytest.raises(FileNotFoundError):
                ssh_fileshare_service.download(
                    params=ssh_config,
                    remote_path="/tmp/file-dne.txt",
                    local_path=local_file.name,
                )

            # ... and must not have touched the existing local file.
            with open(local_file.name, mode="r", encoding="utf-8") as reader:
                assert reader.readlines() == [canary]


@requires_docker
def test_ssh_fileshare_upload_file_dne(
    ssh_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """
    Test the SshFileShareService single file upload that doesn't exist.

    Uploading a missing local file must raise OSError, and afterwards nothing
    may exist at the corresponding remote path.
    """
    with ssh_host_service, ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()
        missing_path = "/tmp/upload-file-src-dne.txt"

        with pytest.raises(OSError):
            ssh_fileshare_service.upload(
                params=ssh_config,
                remote_path=missing_path,
                local_path=missing_path,
            )

        # Ask the remote host whether the path exists: `[[ ! -e ... ]]` exits 0
        # when it does not, and `echo $?` prints that exit code.
        _, exec_info = ssh_host_service.remote_exec(
            script=[f"[[ ! -e {missing_path} ]]; echo $?"],
            config=ssh_config,
            env_params={},
        )
        status, exec_output = ssh_host_service.get_remote_exec_results(exec_info)
        assert status.is_succeeded()
        assert str(exec_output["stdout"]).strip() == "0"