Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py: 100%

84 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench.services.remote.ssh.ssh_fileshare

7""" 

8 

9from contextlib import contextmanager 

10from os.path import basename 

11from pathlib import Path 

12from tempfile import _TemporaryFileWrapper # pylint: disable=import-private-name 

13from typing import Any, Dict, Generator, List 

14 

15import os 

16import tempfile 

17 

18import pytest 

19 

20from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService 

21from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService 

22from mlos_bench.util import path_join 

23 

24from mlos_bench.tests import are_dir_trees_equal, requires_docker 

25from mlos_bench.tests.services.remote.ssh import SshTestServerInfo 

26 

27 

@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
    """
    Provides a context manager for a temporary file that can be closed and
    still unlinked.

    Since Windows doesn't allow us to reopen the file while it's still open we
    need to handle deletion ourselves separately.

    Parameters
    ----------
    kwargs: dict
        Args to pass to NamedTemporaryFile constructor.
        ``delete`` must not be passed: it is always set to False here so the
        file survives being closed, and removal is done on context exit instead.

    Yields
    ------
    _TemporaryFileWrapper
        The open temporary file. Callers may close it (e.g., so it can be
        reopened by name); the file is still removed from disk when the
        context exits, even if the body already deleted or replaced it.
    """
    fname = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, **kwargs) as temp_file:
            fname = temp_file.name
            yield temp_file
    finally:
        if fname is not None:
            # Tolerate the file having already been removed by the body, so a
            # FileNotFoundError raised here cannot mask the body's own exception.
            Path(fname).unlink(missing_ok=True)

54 

55 

@requires_docker
def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo,
                                   ssh_fileshare_service: SshFileShareService) -> None:
    """
    Round-trip a single small text file through the SshFileShareService.

    A local temp file is uploaded, then downloaded back into a fresh temp file,
    and the downloaded contents must match what was written.
    """
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()
        remote_path = "/tmp/test_ssh_fileshare_single_file"
        expected_lines = [f"{word}\n" for word in ("foo", "bar")]

        # Upload: write the content locally and close the handle so the file
        # can be opened again by name, then push it to the remote.
        with closeable_temp_file(mode='w+t', encoding='utf-8') as src_file:
            src_file.writelines(expected_lines)
            src_file.flush()
            src_file.close()
            ssh_fileshare_service.upload(
                params=ssh_config,
                local_path=src_file.name,
                remote_path=remote_path,
            )

        # Download into a fresh (already closed) temp file name and read it back.
        with closeable_temp_file(mode='w+t', encoding='utf-8') as dst_file:
            dst_file.close()
            ssh_fileshare_service.download(
                params=ssh_config,
                remote_path=remote_path,
                local_path=dst_file.name,
            )
            # The download replaces the file at that name, so reopen it rather
            # than reusing the old handle.
            with open(dst_file.name, mode='r', encoding='utf-8') as reader:
                actual_lines = reader.readlines()
            assert actual_lines == expected_lines

94 

95 

@requires_docker
def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo,
                                 ssh_fileshare_service: SshFileShareService) -> None:
    """
    Round-trip a small directory tree through the SshFileShareService.

    The tree (two top-level files plus one file in a nested subdirectory) is
    uploaded and downloaded recursively, and the downloaded copy must be equal
    to the original.
    """
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()
        remote_dir = "/tmp/test_ssh_fileshare_recursive_dir"

        raw_contents: Dict[str, List[str]] = {
            "file-a.txt": ["a", "1"],
            "file-b.txt": ["b", "2"],
            "subdir/foo.txt": ["foo", "bar"],
        }
        # Terminate every line with a newline.
        tree_contents: Dict[str, List[str]] = {}
        for rel_path, raw_lines in raw_contents.items():
            tree_contents[rel_path] = [f"{raw}\n" for raw in raw_lines]

        with tempfile.TemporaryDirectory() as src_root, tempfile.TemporaryDirectory() as dst_root:
            # Materialize the tree locally under src_root.
            for rel_path, content in tree_contents.items():
                local_file = Path(src_root, rel_path)
                local_file.parent.mkdir(parents=True, exist_ok=True)
                with open(local_file, mode='w+t', encoding='utf-8') as writer:
                    writer.writelines(content)
                    writer.flush()
                assert os.path.getsize(local_file) > 0

            # Push the whole tree to the remote, then pull it back into dst_root.
            ssh_fileshare_service.upload(
                params=ssh_config,
                local_path=src_root,
                remote_path=remote_dir,
                recursive=True,
            )
            ssh_fileshare_service.download(
                params=ssh_config,
                remote_path=remote_dir,
                local_path=dst_root,
                recursive=True,
            )

            # The remote directory name is appended to the download target,
            # so compare against that subdirectory of dst_root.
            assert are_dir_trees_equal(src_root, path_join(dst_root, basename(remote_dir)))

149 

150 

@requires_docker
def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo,
                                         ssh_fileshare_service: SshFileShareService) -> None:
    """
    Downloading a remote file that doesn't exist must raise FileNotFoundError.

    The pre-existing local target must be left with its original contents.
    """
    with ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()
        canary = "canary"

        with closeable_temp_file(mode='w+t', encoding='utf-8') as local_file:
            # Seed the local target with known content so clobbering is detectable.
            local_file.writelines([canary])
            local_file.flush()
            local_file.close()

            with pytest.raises(FileNotFoundError):
                ssh_fileshare_service.download(
                    params=ssh_config,
                    remote_path="/tmp/file-dne.txt",
                    local_path=local_file.name,
                )

            with open(local_file.name, mode='r', encoding='utf-8') as reader:
                surviving_lines = reader.readlines()
            assert surviving_lines == [canary]

174 

175 

@requires_docker
def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo,
                                       ssh_host_service: SshHostService,
                                       ssh_fileshare_service: SshFileShareService) -> None:
    """
    Uploading a local file that doesn't exist must raise an OSError.

    Nothing must be created at the remote path.
    """
    with ssh_host_service, ssh_fileshare_service:
        ssh_config = ssh_test_server.to_ssh_service_config()
        # NOTE(review): assumes this path does not exist on the local machine either.
        missing_path = '/tmp/upload-file-src-dne.txt'

        with pytest.raises(OSError):
            ssh_fileshare_service.upload(
                params=ssh_config,
                remote_path=missing_path,
                local_path=missing_path,
            )

        # Ask the remote host whether the path exists: `[[ ! -e ... ]]` exits 0
        # (and `echo $?` prints "0") only when nothing is at that path.
        (_, pending) = ssh_host_service.remote_exec(
            script=[f"[[ ! -e {missing_path} ]]; echo $?"],
            config=ssh_config,
            env_params={},
        )
        (status, results) = ssh_host_service.get_remote_exec_results(pending)
        assert status.is_succeeded()
        assert str(results["stdout"]).strip() == "0"