Coverage for mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py: 99%

113 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench.services.remote.azure.azure_fileshare 

7""" 

8 

9import os 

10from unittest.mock import MagicMock, Mock, patch, call 

11 

12from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareService 

13 

14# pylint: disable=missing-function-docstring 

15# pylint: disable=too-many-arguments 

16# pylint: disable=unused-argument 

17 

18 

@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs")
def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
    """
    Download a single remote file and check the file client, the local parent
    folder creation, and that the local target is opened in ``wb`` mode.
    """
    file_name = "test.csv"
    remote_dir = "a/remote/folder"
    local_dir = "some/local/folder"
    remote_file = f"{remote_dir}/{file_name}"
    local_file = f"{local_dir}/{file_name}"
    share_client = azure_fileshare._share_client  # pylint: disable=protected-access

    # A directory client whose exists() is False: the remote path is not a folder.
    not_a_dir = Mock(exists=Mock(return_value=False))
    with patch.object(share_client, "get_file_client") as mock_get_file_client:
        with patch.object(share_client, "get_directory_client") as mock_get_directory_client:
            mock_get_directory_client.return_value = not_a_dir
            azure_fileshare.download({}, remote_file, local_file)

    mock_get_file_client.assert_called_with(remote_file)
    mock_makedirs.assert_called_with(local_dir, exist_ok=True)

    opened_path, opened_mode = mock_open.call_args.args
    assert os.path.abspath(opened_path) == os.path.abspath(local_file)
    assert opened_mode == "wb"

44 

45 

def make_dir_client_returns(remote_folder: str) -> dict:
    """
    Build a fake remote share tree as a mapping of path -> directory-client mock.

    The tree looks like::

        <remote_folder>/
            a_folder/
                a_file_2.csv
            a_file_1.csv

    Folder entries answer ``exists()`` with True and list their children via
    ``list_directories_and_files()``; file entries answer ``exists()`` with False.
    """
    def folder(*children: dict) -> Mock:
        # A directory client that exists and lists the given child entries.
        return Mock(
            exists=Mock(return_value=True),
            list_directories_and_files=Mock(return_value=list(children)),
        )

    def not_a_folder() -> Mock:
        # A directory client that does not exist (the path is a file).
        return Mock(exists=Mock(return_value=False))

    sub_folder = f"{remote_folder}/a_folder"
    return {
        remote_folder: folder(
            {"name": "a_folder", "is_directory": True},
            {"name": "a_file_1.csv", "is_directory": False},
        ),
        sub_folder: folder(
            {"name": "a_file_2.csv", "is_directory": False},
        ),
        f"{remote_folder}/a_file_1.csv": not_a_folder(),
        f"{sub_folder}/a_file_2.csv": not_a_folder(),
    }

68 

69 

@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs")
def test_download_folder_non_recursive(mock_makedirs: MagicMock,
                                       mock_open: MagicMock,
                                       azure_fileshare: AzureFileShareService) -> None:
    """
    A non-recursive folder download must fetch only the top-level file
    (``a_file_1.csv``) and not the one nested in ``a_folder``.
    """
    remote_folder = "a/remote/folder"
    local_folder = "some/local/folder"
    dir_client_returns = make_dir_client_returns(remote_folder)
    mock_share_client = azure_fileshare._share_client  # pylint: disable=protected-access
    config: dict = {}
    with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \
            patch.object(mock_share_client, "get_file_client") as mock_get_file_client:

        # Look every requested path up in the fake tree; unknown paths raise KeyError.
        mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]

        azure_fileshare.download(config, remote_folder, local_folder, recursive=False)

    # The fake listing returns "a_folder" before "a_file_1.csv", so a download that
    # wrongly recursed could still make its *last* get_file_client call for
    # a_file_1.csv; assert_called_with() alone would not catch that.
    mock_get_file_client.assert_called_once_with(
        f"{remote_folder}/a_file_1.csv",
    )
    mock_get_directory_client.assert_has_calls([
        call(remote_folder),
        call(f"{remote_folder}/a_file_1.csv"),
    ], any_order=True)

94 

95 

@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs")
def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
    """
    A recursive folder download visits the nested folder and fetches every
    file in the fake tree.
    """
    base_remote = "a/remote/folder"
    base_local = "some/local/folder"
    clients_by_path = make_dir_client_returns(base_remote)
    share_client = azure_fileshare._share_client  # pylint: disable=protected-access

    top_file = f"{base_remote}/a_file_1.csv"
    sub_dir = f"{base_remote}/a_folder"
    nested_file = f"{sub_dir}/a_file_2.csv"

    with patch.object(share_client, "get_directory_client") as mock_get_directory_client:
        # Look every requested path up in the fake tree; unknown paths raise KeyError.
        mock_get_directory_client.side_effect = clients_by_path.__getitem__
        with patch.object(share_client, "get_file_client") as mock_get_file_client:
            azure_fileshare.download({}, base_remote, base_local, recursive=True)

    expected_file_calls = [call(top_file), call(nested_file)]
    mock_get_file_client.assert_has_calls(expected_file_calls, any_order=True)

    expected_dir_calls = [call(base_remote), call(top_file), call(sub_dir), call(nested_file)]
    mock_get_directory_client.assert_has_calls(expected_dir_calls, any_order=True)

120 

121 

@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir")
def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
    """
    Upload a single local file and check the remote file client target and
    that the local file is opened in ``rb`` mode.
    """
    name = "test.csv"
    remote_file = "/".join(("a/remote/folder", name))
    local_file = "/".join(("some/local/folder", name))
    share_client = azure_fileshare._share_client  # pylint: disable=protected-access
    # The mocked os.path.isdir never reports the local path as a directory.
    mock_isdir.return_value = False

    with patch.object(share_client, "get_file_client") as mock_get_file_client:
        azure_fileshare.upload({}, local_file, remote_file)

    mock_get_file_client.assert_called_with(remote_file)

    opened_path, opened_mode = mock_open.call_args.args
    assert os.path.abspath(opened_path) == os.path.abspath(local_file)
    assert opened_mode == "rb"

141 

142 

class MyDirEntry:
    # pylint: disable=too-few-public-methods
    """
    Minimal stand-in for ``os.DirEntry`` objects as yielded by ``os.scandir``.

    Only ``name`` and ``is_dir()`` are provided.
    """

    def __init__(self, name: str, is_a_dir: bool):
        # Base name of the entry, like os.DirEntry.name.
        self.name = name
        # Canned answer returned by is_dir().
        self.is_a_dir = is_a_dir

    def is_dir(self, *, follow_symlinks: bool = True) -> bool:
        """
        Return the canned "is a directory" flag.

        Accepts (and ignores) the keyword-only ``follow_symlinks`` flag so the
        signature matches ``os.DirEntry.is_dir``.
        """
        return self.is_a_dir

152 

153 

def make_scandir_returns(local_folder: str) -> dict:
    """
    Map each fake local folder path to the list of entries a mocked
    ``os.scandir`` should yield for it.

    ``<local_folder>`` contains the directory ``a_folder`` and the file
    ``a_file_1.csv``; ``<local_folder>/a_folder`` contains ``a_file_2.csv``.
    """
    nested_folder = f"{local_folder}/a_folder"
    top_level_entries = [
        MyDirEntry("a_folder", True),
        MyDirEntry("a_file_1.csv", False),
    ]
    nested_entries = [
        MyDirEntry("a_file_2.csv", False),
    ]
    return {
        local_folder: top_level_entries,
        nested_folder: nested_entries,
    }

164 

165 

def make_isdir_returns(local_folder: str) -> dict:
    """
    Map each fake local path to the value a mocked ``os.path.isdir`` should
    return for it: True for the two folders, False for the two files.
    """
    sub_folder = local_folder + "/a_folder"
    return {
        local_folder: True,
        local_folder + "/a_file_1.csv": False,
        sub_folder: True,
        sub_folder + "/a_file_2.csv": False,
    }

173 

174 

def process_paths(input_path: str) -> str:
    """
    Convert a path received by a mocked ``os`` function back into the
    cwd-relative, ``/``-separated form used as keys by the fixture helpers.

    Paths may arrive already passed through ``os.path.abspath`` (i.e. prefixed
    with the current working directory); those are made relative to the cwd.
    Relative paths are kept as-is apart from separator normalization.
    """
    if os.path.isabs(input_path):
        # relpath() is correct even when the cwd is a filesystem root, where
        # blindly slicing off len(cwd) + 1 characters would also drop the first
        # character of the remaining path; it also normalizes mixed separators.
        try:
            result = os.path.relpath(input_path, os.getcwd())
        except ValueError:
            # On Windows relpath() raises when path and cwd are on different drives.
            result = input_path
    else:
        result = input_path
    # Change file seps to unix-style
    return result.replace("\\", "/")

184 

185 

@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir")
def test_upload_directory_non_recursive(mock_scandir: MagicMock,
                                        mock_isdir: MagicMock,
                                        mock_open: MagicMock,
                                        azure_fileshare: AzureFileShareService) -> None:
    """
    A non-recursive directory upload must send only the top-level file
    (``a_file_1.csv``) and not the one nested in ``a_folder``.
    """
    remote_folder = "a/remote/folder"
    local_folder = "some/local/folder"
    scandir_returns = make_scandir_returns(local_folder)
    isdir_returns = make_isdir_returns(local_folder)
    # Paths may reach the mocks cwd-absolute; process_paths maps them back to fixture keys.
    mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
    mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
    mock_share_client = azure_fileshare._share_client  # pylint: disable=protected-access
    config: dict = {}

    with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
        azure_fileshare.upload(config, local_folder, remote_folder, recursive=False)

    # scandir yields "a_folder" before "a_file_1.csv", so an upload that wrongly
    # recursed could still make its *last* get_file_client call for a_file_1.csv;
    # assert_called_with() alone would not catch that.
    mock_get_file_client.assert_called_once_with(f"{remote_folder}/a_file_1.csv")

206 

207 

@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir")
def test_upload_directory_recursive(mock_scandir: MagicMock,
                                    mock_isdir: MagicMock,
                                    mock_open: MagicMock,
                                    azure_fileshare: AzureFileShareService) -> None:
    """
    A recursive directory upload sends every file in the fake local tree,
    including the one nested in ``a_folder``.
    """
    remote_root = "a/remote/folder"
    local_root = "some/local/folder"
    entries_by_dir = make_scandir_returns(local_root)
    is_dir_by_path = make_isdir_returns(local_root)

    def fake_scandir(path: str) -> list:
        # Paths may arrive cwd-absolute; process_paths maps them back to fixture keys.
        return entries_by_dir[process_paths(path)]

    def fake_isdir(path: str) -> bool:
        return is_dir_by_path[process_paths(path)]

    mock_scandir.side_effect = fake_scandir
    mock_isdir.side_effect = fake_isdir
    share_client = azure_fileshare._share_client  # pylint: disable=protected-access

    with patch.object(share_client, "get_file_client") as mock_get_file_client:
        azure_fileshare.upload({}, local_root, remote_root, recursive=True)

    expected_calls = [call(f"{remote_root}/{rel}") for rel in ("a_file_1.csv", "a_folder/a_file_2.csv")]
    mock_get_file_client.assert_has_calls(expected_calls, any_order=True)