Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py: 100%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for mlos_bench.services.remote.ssh.ssh_fileshare."""

6 

7import os 

8import tempfile 

9from contextlib import contextmanager 

10from os.path import basename 

11from pathlib import Path 

12from tempfile import _TemporaryFileWrapper # pylint: disable=import-private-name 

13from typing import Any, Dict, Generator, List 

14 

15import pytest 

16 

17from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService 

18from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService 

19from mlos_bench.tests import are_dir_trees_equal, requires_docker 

20from mlos_bench.tests.services.remote.ssh import SshTestServerInfo 

21from mlos_bench.util import path_join 

22 

23 

@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
    """
    Provide a named temporary file that can be closed early and is still unlinked.

    The file is created with ``delete=False`` so that closing the handle does not
    remove it. Since Windows doesn't allow us to reopen the file while it's still
    open, deletion is handled here separately, once the context exits.

    Parameters
    ----------
    kwargs : dict
        Args to pass to NamedTemporaryFile constructor.

    Yields
    ------
    temp_file : _TemporaryFileWrapper
        The open temporary file. The path in ``temp_file.name`` is removed when
        the context exits.
    """
    fname = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, **kwargs) as temp_file:
            fname = temp_file.name
            yield temp_file
    finally:
        # fname stays None only when NamedTemporaryFile itself failed.
        if fname is not None:
            try:
                os.unlink(fname)
            except FileNotFoundError:
                # The body already removed (or moved) the file. Raising here
                # would replace any exception propagating from the body, so a
                # missing file is treated as "already cleaned up".
                pass

50 

51 

@requires_docker
def test_ssh_fileshare_single_file(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Upload one local file, download it again, and check the contents match."""
    remote_file_path = "/tmp/test_ssh_fileshare_single_file"
    expected_lines = [f"{text}\n" for text in ("foo", "bar")]

    with ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        # Step 1: stage the payload in a local temp file and push it to the server.
        with closeable_temp_file(mode="w+t", encoding="utf-8") as upload_src:
            upload_src.writelines(expected_lines)
            upload_src.flush()
            # Release our handle before the service reads the file by name.
            upload_src.close()
            ssh_fileshare_service.upload(
                params=config,
                local_path=upload_src.name,
                remote_path=remote_file_path,
            )

        # Step 2: pull the remote copy into a second temp file and compare.
        with closeable_temp_file(mode="w+t", encoding="utf-8") as download_dst:
            download_dst.close()
            ssh_fileshare_service.download(
                params=config,
                remote_path=remote_file_path,
                local_path=download_dst.name,
            )
            # The download replaces the inode behind that name, so the contents
            # have to be read through a freshly opened handle.
            with open(download_dst.name, mode="r", encoding="utf-8") as reopened:
                actual_lines = reopened.readlines()
            assert actual_lines == expected_lines

92 

93 

@requires_docker
def test_ssh_fileshare_recursive(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Upload a directory tree recursively, download it again, and compare both trees."""
    remote_file_path = "/tmp/test_ssh_fileshare_recursive_dir"
    raw_contents: Dict[str, List[str]] = {
        "file-a.txt": ["a", "1"],
        "file-b.txt": ["b", "2"],
        "subdir/foo.txt": ["foo", "bar"],
    }
    # Terminate every entry with a newline so each one lands on its own line.
    files_lines: Dict[str, List[str]] = {
        rel_path: [f"{text}\n" for text in texts] for rel_path, texts in raw_contents.items()
    }

    with ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        with tempfile.TemporaryDirectory() as tempdir1:
            with tempfile.TemporaryDirectory() as tempdir2:
                # Materialize the source tree under the first temp directory.
                for rel_path, file_lines in files_lines.items():
                    local_file = Path(tempdir1, rel_path)
                    local_file.parent.mkdir(parents=True, exist_ok=True)
                    local_file.write_text("".join(file_lines), encoding="utf-8")
                    assert local_file.stat().st_size > 0

                # Push the whole tree to the remote server.
                ssh_fileshare_service.upload(
                    params=config,
                    local_path=tempdir1,
                    remote_path=remote_file_path,
                    recursive=True,
                )

                # Pull the remote tree back into the second temp directory.
                ssh_fileshare_service.download(
                    params=config,
                    remote_path=remote_file_path,
                    local_path=tempdir2,
                    recursive=True,
                )

                # The remote dir name is appended to the download target, so the
                # round-tripped tree lives one level below tempdir2.
                downloaded_root = path_join(tempdir2, basename(remote_file_path))
                assert are_dir_trees_equal(tempdir1, downloaded_root)

151 

152 

@requires_docker
def test_ssh_fileshare_download_file_dne(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Check that downloading a missing remote file raises and leaves the local file intact."""
    canary_str = "canary"

    with ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        with closeable_temp_file(mode="w+t", encoding="utf-8") as local_target:
            # Seed the local target with a marker to detect an unwanted overwrite.
            local_target.write(canary_str)
            local_target.flush()
            local_target.close()

            with pytest.raises(FileNotFoundError):
                ssh_fileshare_service.download(
                    params=config,
                    remote_path="/tmp/file-dne.txt",
                    local_path=local_target.name,
                )

            # The failed download must not have touched the local file.
            with open(local_target.name, mode="r", encoding="utf-8") as reopened:
                remaining_lines = reopened.readlines()
            assert remaining_lines == [canary_str]

178 

179 

@requires_docker
def test_ssh_fileshare_upload_file_dne(
    ssh_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Check that uploading a missing local file raises and creates nothing remotely."""
    path = "/tmp/upload-file-src-dne.txt"

    with ssh_host_service:
        with ssh_fileshare_service:
            config = ssh_test_server.to_ssh_service_config()

            # The same (nonexistent) path serves as both source and destination.
            with pytest.raises(OSError):
                ssh_fileshare_service.upload(
                    params=config,
                    remote_path=path,
                    local_path=path,
                )

            # Ask the remote shell whether the destination exists; the script
            # echoes the exit status of the "does not exist" test.
            _, pending = ssh_host_service.remote_exec(
                script=[f"[[ ! -e {path} ]]; echo $?"],
                config=config,
                env_params={},
            )
            status, output = ssh_host_service.get_remote_exec_results(pending)
            assert status.is_succeeded()
            assert str(output["stdout"]).strip() == "0"