Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py: 91%

69 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A collection of FileShare functions for interacting with Azure File Shares. 

7""" 

8 

9import os 

10import logging 

11 

12from typing import Any, Callable, Dict, List, Optional, Set, Union 

13 

14from azure.storage.fileshare import ShareClient 

15from azure.core.exceptions import ResourceNotFoundError 

16 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.services.base_fileshare import FileShareService 

19from mlos_bench.util import check_required_params 

20 

# Logger for this module; emits debug traces of file transfers and loop warnings.
_LOG: logging.Logger = logging.getLogger(name=__name__)

22 

23 

class AzureFileShareService(FileShareService):
    """
    Helper methods for interacting with Azure File Share.

    Transfers files and directory trees between the local file system and the
    single Azure file share selected by the service config.
    """

    # URL template of the file share; filled in with the storage account name
    # and the file share name from the config.
    _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new file share Service for Azure environments with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the file share configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
            Must provide "storageAccountName", "storageFileShareName", and
            "storageAccountKey" (verified via `check_required_params`).
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config, global_config, parent,
            self.merge_methods(methods, [self.upload, self.download])
        )

        check_required_params(
            self.config, {
                "storageAccountName",
                "storageFileShareName",
                "storageAccountKey",
            }
        )

        # The storage account key is used directly as the client credential.
        self._share_client = ShareClient.from_share_url(
            AzureFileShareService._SHARE_URL.format(
                account_name=self.config["storageAccountName"],
                fs_name=self.config["storageFileShareName"],
            ),
            credential=self.config["storageAccountKey"],
        )

    def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
        """
        Download a file or a directory tree from the Azure file share.

        Parameters
        ----------
        params : dict
            Call parameters; passed unchanged to the base class `download()`
            (and to the recursive calls), not used directly here.
        remote_path : str
            Path of the file or directory in the remote file share.
        local_path : str
            Local path to store the downloaded content to.
        recursive : bool
            If False, only download the files directly inside `remote_path`
            and skip its subdirectories;
            if True (the default), download the entire directory tree.

        Raises
        ------
        FileNotFoundError
            If `remote_path` is neither an existing remote directory nor an
            existing remote file (translated from Azure's `ResourceNotFoundError`).
        """
        super().download(params, remote_path, local_path, recursive)
        dir_client = self._share_client.get_directory_client(remote_path)
        if dir_client.exists():
            os.makedirs(local_path, exist_ok=True)
            for content in dir_client.list_directories_and_files():
                name = content["name"]
                # Subdirectories are only descended into for recursive downloads.
                if content["is_directory"] and not recursive:
                    continue
                local_target = os.path.join(local_path, name)
                remote_target = f"{remote_path}/{name}"
                self.download(params, remote_target, local_target, recursive)
        else:  # Must be a file
            # Ensure the parent folder exists. The parent is empty for a bare
            # file name (i.e., the current directory), and `os.makedirs("")`
            # raises FileNotFoundError, so only create a non-empty parent.
            folder = os.path.dirname(local_path)
            if folder:
                os.makedirs(folder, exist_ok=True)
            file_client = self._share_client.get_file_client(remote_path)
            try:
                data = file_client.download_file()
                with open(local_path, "wb") as output_file:
                    _LOG.debug("Download file: %s -> %s", remote_path, local_path)
                    data.readinto(output_file)  # type: ignore[no-untyped-call]
            except ResourceNotFoundError as ex:
                # Translate into non-Azure exception:
                raise FileNotFoundError(f"Cannot download: {remote_path}") from ex

    def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
        """
        Upload a local file or directory tree to the Azure file share.

        Parameters
        ----------
        params : dict
            Call parameters; passed unchanged to the base class `upload()`,
            not used directly here.
        local_path : str
            Path to the local file or directory to upload.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        """
        super().upload(params, local_path, remote_path, recursive)
        self._upload(local_path, remote_path, recursive, set())

    def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
        """
        Upload contents from a local path to an Azure file share.
        This method is called from `.upload()` above. We need it to avoid exposing
        the `seen` parameter and to make `.upload()` match the base class' virtual
        method.

        Parameters
        ----------
        local_path : str
            Path to the local content to upload, either a file or a directory.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        seen : Set[str]
            Resolved (symlink-free) paths of the local directories on the
            current branch of the walk; used to break circular paths.
        """
        local_path = os.path.abspath(local_path)

        if os.path.isdir(local_path):
            # `abspath` does not resolve symlinks: a link pointing back to an
            # ancestor yields an ever-growing path that never repeats. Compare
            # resolved paths instead to detect such loops.
            real_path = os.path.realpath(local_path)
            if real_path in seen:
                _LOG.warning("Loop in directories, skipping '%s'", local_path)
                return
            seen.add(real_path)
            try:
                self._remote_makedirs(remote_path)
                with os.scandir(local_path) as entries:
                    for entry in entries:
                        # Subdirectories are only descended into for recursive uploads.
                        if entry.is_dir() and not recursive:
                            continue
                        self._upload(entry.path, f"{remote_path}/{entry.name}", recursive, seen)
            finally:
                # Only ancestors of the current entry indicate a loop; sibling
                # links to the same directory are still uploaded.
                seen.discard(real_path)
        else:
            # Ensure parent folders exist
            folder, _ = os.path.split(remote_path)
            self._remote_makedirs(folder)
            file_client = self._share_client.get_file_client(remote_path)
            with open(local_path, "rb") as file_data:
                _LOG.debug("Upload file: %s -> %s", local_path, remote_path)
                file_client.upload_file(file_data)

    def _remote_makedirs(self, remote_path: str) -> None:
        """
        Create remote directories for the entire path.
        Succeeds even if some or all directories along the path already exist.

        Backslashes are treated as path separators, and empty path components
        (from leading, trailing, or doubled separators) are ignored.

        Parameters
        ----------
        remote_path : str
            Path in the remote file share to create.
        """
        path = ""
        for folder in remote_path.replace("\\", "/").split("/"):
            if not folder:
                continue
            path += folder + "/"
            dir_client = self._share_client.get_directory_client(path)
            # NOTE(review): check-then-create is not atomic; a concurrent
            # creator could make `create_directory()` fail here.
            if not dir_client.exists():
                dir_client.create_directory()