Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py: 91%
69 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6A collection of FileShare functions for interacting with Azure File Shares.
7"""
9import os
10import logging
12from typing import Any, Callable, Dict, List, Optional, Set, Union
14from azure.storage.fileshare import ShareClient
15from azure.core.exceptions import ResourceNotFoundError
17from mlos_bench.services.base_service import Service
18from mlos_bench.services.base_fileshare import FileShareService
19from mlos_bench.util import check_required_params
21_LOG = logging.getLogger(__name__)
class AzureFileShareService(FileShareService):
    """
    Helper methods for interacting with Azure File Share
    """

    _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new file share Service for Azure environments with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the file share configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
            Must contain "storageAccountName", "storageFileShareName",
            and "storageAccountKey".
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config, global_config, parent,
            self.merge_methods(methods, [self.upload, self.download])
        )

        check_required_params(
            self.config, {
                "storageAccountName",
                "storageFileShareName",
                "storageAccountKey",
            }
        )

        self._share_client = ShareClient.from_share_url(
            AzureFileShareService._SHARE_URL.format(
                account_name=self.config["storageAccountName"],
                fs_name=self.config["storageFileShareName"],
            ),
            credential=self.config["storageAccountKey"],
        )

    def download(self, params: dict, remote_path: str, local_path: str,
                 recursive: bool = True) -> None:
        """
        Download contents from an Azure file share to a local path.

        The base class implementation is invoked first. If `remote_path` is an
        existing remote directory, its entries are downloaded one by one;
        otherwise `remote_path` is treated as a single file.

        Parameters
        ----------
        params : dict
            Parameters forwarded to the base class implementation.
        remote_path : str
            Path in the remote file share, either a file or a directory.
        local_path : str
            Local path to store the downloaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), download the entire directory tree.

        Raises
        ------
        FileNotFoundError
            If Azure reports that the remote file does not exist.
        """
        super().download(params, remote_path, local_path, recursive)
        dir_client = self._share_client.get_directory_client(remote_path)
        if dir_client.exists():
            os.makedirs(local_path, exist_ok=True)
            for content in dir_client.list_directories_and_files():
                name = content["name"]
                local_target = f"{local_path}/{name}"
                remote_target = f"{remote_path}/{name}"
                if recursive or not content["is_directory"]:
                    self.download(params, remote_target, local_target, recursive)
        else:  # Must be a file
            # Ensure parent folders exist. A bare file name has an empty
            # folder part, and `os.makedirs("")` would raise, so skip it then.
            folder, _ = os.path.split(local_path)
            if folder:
                os.makedirs(folder, exist_ok=True)
            file_client = self._share_client.get_file_client(remote_path)
            try:
                data = file_client.download_file()
                with open(local_path, "wb") as output_file:
                    _LOG.debug("Download file: %s -> %s", remote_path, local_path)
                    data.readinto(output_file)  # type: ignore[no-untyped-call]
            except ResourceNotFoundError as ex:
                # Translate into non-Azure exception:
                raise FileNotFoundError(f"Cannot download: {remote_path}") from ex

    def upload(self, params: dict, local_path: str, remote_path: str,
               recursive: bool = True) -> None:
        """
        Upload contents from a local path to an Azure file share.

        The base class implementation is invoked first, then the actual
        transfer is delegated to `._upload()`.

        Parameters
        ----------
        params : dict
            Parameters forwarded to the base class implementation.
        local_path : str
            Local path to upload contents from, either a file or a directory.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        """
        super().upload(params, local_path, remote_path, recursive)
        self._upload(local_path, remote_path, recursive, set())

    def _upload(self, local_path: str, remote_path: str, recursive: bool,
                seen: Set[str]) -> None:
        """
        Upload contents from a local path to an Azure file share.
        This method is called from `.upload()` above. We need it to avoid exposing
        the `seen` parameter and to make `.upload()` match the base class' virtual
        method.

        Parameters
        ----------
        local_path : str
            Path to the local directory to upload contents from, either a file or directory.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        seen: Set[str]
            Helper set for keeping track of visited directories to break circular paths.
            It holds the symlink-resolved paths of the directories that are
            currently being traversed (i.e., the ancestors of `local_path`).
        """
        local_path = os.path.abspath(local_path)

        if not os.path.isdir(local_path):
            # Ensure parent folders exist
            folder, _ = os.path.split(remote_path)
            self._remote_makedirs(folder)
            file_client = self._share_client.get_file_client(remote_path)
            with open(local_path, "rb") as file_data:
                _LOG.debug("Upload file: %s -> %s", local_path, remote_path)
                file_client.upload_file(file_data)
            return

        # `abspath` does not resolve symlinks, so a link that points back to one
        # of its ancestors produces a new absolute path on every level.
        # Compare the resolved paths to actually detect such cycles.
        real_dir = os.path.realpath(local_path)
        if real_dir in seen:
            _LOG.warning("Loop in directories, skipping '%s'", local_path)
            return
        seen.add(real_dir)
        try:
            self._remote_makedirs(remote_path)
            for entry in os.scandir(local_path):
                name = entry.name
                local_target = f"{local_path}/{name}"
                remote_target = f"{remote_path}/{name}"
                if recursive or not entry.is_dir():
                    self._upload(local_target, remote_target, recursive, seen)
        finally:
            # Only directories on the current recursion chain can form a loop;
            # forget this one so that it can still be uploaded via another link.
            seen.discard(real_dir)

    def _remote_makedirs(self, remote_path: str) -> None:
        """
        Create remote directories for the entire path.
        Succeeds even if some or all directories along the path already exist.

        Parameters
        ----------
        remote_path : str
            Path in the remote file share to create.
            Both "/" and "\\" are accepted as separators; empty path
            components are ignored.
        """
        path = ""
        for folder in remote_path.replace("\\", "/").split("/"):
            if not folder:
                continue
            path += folder + "/"
            dir_client = self._share_client.get_directory_client(path)
            if not dir_client.exists():
                dir_client.create_directory()