Source code for genalog.ocr.blob_client

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# ---------------------------------------------------------

"""Uses the python sdk to make operation on Azure Blob storage.
see: https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
"""
import asyncio
import base64
import hashlib
import json
import os
import random

import aiofiles
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.blob import BlobServiceClient
from azure.storage.blob.aio import BlobServiceClient as asyncBlobServiceClient
from tqdm import tqdm

from .common import DEFAULT_PROJECTIONS_CONTAINER_NAME

# maximum number of simultaneous requests
REQUEST_SEMAPHORE = asyncio.Semaphore(50)

# maximum number of simultaneous open files
FILE_SEMAPHORE = asyncio.Semaphore(500)

MAX_RETRIES = 5


class GrokBlobClient:
    """This class is a client used to upload files to and delete files from Azure Blob storage
    https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python
    """

    def __init__(
        self,
        datasource_container_name,
        blob_account_name,
        blob_key,
        projections_container_name=DEFAULT_PROJECTIONS_CONTAINER_NAME,
    ):
        """Creates the blob storage client given the storage account name and key

        Args:
            datasource_container_name (str): container name. This container does not need to exist yet
            blob_account_name (str): storage account name
            blob_key (str): storage account key
            projections_container_name (str, optional): container used to store OCR projections.
                This container does not need to exist yet
        """
        self.DATASOURCE_CONTAINER_NAME = datasource_container_name
        self.PROJECTIONS_CONTAINER_NAME = projections_container_name
        self.BLOB_NAME = blob_account_name
        self.BLOB_KEY = blob_key
        self.BLOB_CONNECTION_STRING = (
            f"DefaultEndpointsProtocol=https;AccountName={self.BLOB_NAME};"
            f"AccountKey={self.BLOB_KEY};EndpointSuffix=core.windows.net"
        )

    @staticmethod
    def create_from_env_var():
        """Creates the blob client from values in the environment variables

        Returns:
            GrokBlobClient: the new blob client
        """
        DATASOURCE_CONTAINER_NAME = os.environ["DATASOURCE_CONTAINER_NAME"]
        BLOB_NAME = os.environ["BLOB_NAME"]
        BLOB_KEY = os.environ["BLOB_KEY"]
        PROJECTIONS_CONTAINER_NAME = os.environ.get(
            "PROJECTIONS_CONTAINER_NAME", DEFAULT_PROJECTIONS_CONTAINER_NAME
        )
        client = GrokBlobClient(
            DATASOURCE_CONTAINER_NAME,
            BLOB_NAME,
            BLOB_KEY,
            projections_container_name=PROJECTIONS_CONTAINER_NAME,
        )
        return client

    def upload_images_to_blob(
        self,
        src_folder_path,
        dest_folder_name=None,
        check_existing_cache=True,
        use_async=True,
    ):
        """Uploads images from src_folder_path to blob storage at the destination folder.
        The destination folder is created if it doesn't exist. If a destination folder is
        not given, a folder named by the MD5 hash of the filenames is created.

        Args:
            src_folder_path (str): path to the local folder that has the images
            dest_folder_name (str, optional): destination folder name. Defaults to None.
            check_existing_cache (bool, optional): skip files already present in the destination folder. Defaults to True.
            use_async (bool, optional): upload the files asynchronously. Defaults to True.

        Returns:
            tuple: the destination folder name and the upload task (None when uploading synchronously)
        """
        self._create_container()
        blob_service_client = BlobServiceClient.from_connection_string(
            self.BLOB_CONNECTION_STRING
        )
        if dest_folder_name is None:
            dest_folder_name = self.get_folder_hash(src_folder_path)

        files_to_upload = []
        blob_names = []
        for folder, _, files in os.walk(src_folder_path):
            for f in files:
                upload_file_path = os.path.join(folder, f)
                subfolder = folder.replace(src_folder_path, "")
                # Replace any "double //" to avoid creating an empty folder in the blob
                blob_name = f"{dest_folder_name}/{subfolder}/{f}".replace("//", "/")
                files_to_upload.append(upload_file_path)
                blob_names.append(blob_name)

        def get_job_args(upload_file_path, blob_name):
            return (upload_file_path, blob_name)

        if check_existing_cache:
            existing_blobs, _ = self.list_blobs(dest_folder_name or "")
            existing_blobs = list(map(lambda blob: blob["name"], existing_blobs))
            file_blob_names = filter(
                lambda file_blob_name: file_blob_name[1] not in existing_blobs,
                zip(files_to_upload, blob_names),
            )
            job_args = [
                get_job_args(file_path, blob_name)
                for file_path, blob_name in file_blob_names
            ]
        else:
            job_args = [
                get_job_args(file_path, blob_name)
                for file_path, blob_name in zip(files_to_upload, blob_names)
            ]

        print("uploading", len(job_args), "files")
        if not use_async:
            blob_container_client = blob_service_client.get_container_client(
                self.DATASOURCE_CONTAINER_NAME
            )
            jobs = [(blob_container_client,) + x for x in job_args]
            for _ in tqdm(map(_upload_worker_sync, jobs), total=len(jobs)):
                pass
        else:
            async_blob_service_client = asyncBlobServiceClient.from_connection_string(
                self.BLOB_CONNECTION_STRING
            )

            async def async_upload():
                async with async_blob_service_client:
                    async_blob_container_client = (
                        async_blob_service_client.get_container_client(
                            self.DATASOURCE_CONTAINER_NAME
                        )
                    )
                    jobs = [(async_blob_container_client,) + x for x in job_args]
                    for f in tqdm(
                        asyncio.as_completed(map(_upload_worker_async, jobs)),
                        total=len(jobs),
                    ):
                        await f

            loop = asyncio.get_event_loop()
            if loop.is_running():
                result = loop.create_task(async_upload())
            else:
                result = loop.run_until_complete(async_upload())
            return dest_folder_name, result
        return dest_folder_name, None

    def get_folder_hash(self, folder_name):
        """Creates an MD5 hash over all filenames in a folder. This hash can be used to name blob folders.

        Args:
            folder_name (str): path to folder

        Returns:
            str: MD5 hash of all filenames in the folder
        """
        hasher = hashlib.md5()
        for root, _, files in os.walk(folder_name):
            for f in files:
                fname = os.path.join(root, f)
                hasher.update(fname.encode())
        folder_hash = hasher.hexdigest()
        return folder_hash

    def delete_blobs_folder(self, folder_name):
        """Deletes all blobs in a folder

        Args:
            folder_name (str): folder to delete
        """
        blobs_list, blob_service_client = self.list_blobs(folder_name)
        for blob in blobs_list:
            blob_client = blob_service_client.get_blob_client(
                container=self.DATASOURCE_CONTAINER_NAME, blob=blob
            )
            blob_client.delete_blob()

    def list_blobs(self, folder_name):
        """Lists blobs under a folder prefix in the datasource container.

        Args:
            folder_name (str): folder (blob name prefix) to list

        Returns:
            tuple: an iterator over the matching blobs and the BlobServiceClient used to list them
        """
        blob_service_client = BlobServiceClient.from_connection_string(
            self.BLOB_CONNECTION_STRING
        )
        container_client = blob_service_client.get_container_client(
            self.DATASOURCE_CONTAINER_NAME
        )
        return (
            container_client.list_blobs(name_starts_with=folder_name),
            blob_service_client,
        )

    def _create_container(self):
        """Creates the containers {self.DATASOURCE_CONTAINER_NAME} and
        {self.PROJECTIONS_CONTAINER_NAME} if they don't exist."""
        # Create the BlobServiceClient object which will be used to create the container clients
        blob_service_client = BlobServiceClient.from_connection_string(
            self.BLOB_CONNECTION_STRING
        )
        try:
            blob_service_client.create_container(self.DATASOURCE_CONTAINER_NAME)
        except ResourceExistsError:
            print("container already exists:", self.DATASOURCE_CONTAINER_NAME)

        # create the container for storing OCR projections
        try:
            print(
                "creating projections storage container:",
                self.PROJECTIONS_CONTAINER_NAME,
            )
            blob_service_client.create_container(self.PROJECTIONS_CONTAINER_NAME)
        except ResourceExistsError:
            print("container already exists:", self.PROJECTIONS_CONTAINER_NAME)

    def get_ocr_json(self, remote_path, output_folder, use_async=True):
        """Downloads the OCR projection (document.json) for every blob under remote_path
        into output_folder.

        Args:
            remote_path (str): blob folder whose OCR projections to download
            output_folder (str): local folder to write the JSON files to
            use_async (bool, optional): download the files asynchronously. Defaults to True.
        """
        blob_service_client = BlobServiceClient.from_connection_string(
            self.BLOB_CONNECTION_STRING
        )
        container_client = blob_service_client.get_container_client(
            self.DATASOURCE_CONTAINER_NAME
        )
        blobs_list = list(container_client.list_blobs(name_starts_with=remote_path))
        container_uri = f"https://{self.BLOB_NAME}.blob.core.windows.net/{self.DATASOURCE_CONTAINER_NAME}"
        if use_async:
            async_blob_service_client = asyncBlobServiceClient.from_connection_string(
                self.BLOB_CONNECTION_STRING
            )

            async def async_download():
                async with async_blob_service_client:
                    async_projection_container_client = (
                        async_blob_service_client.get_container_client(
                            self.PROJECTIONS_CONTAINER_NAME
                        )
                    )
                    jobs = list(
                        map(
                            lambda blob: (
                                blob,
                                async_projection_container_client,
                                container_uri,
                                output_folder,
                            ),
                            blobs_list,
                        )
                    )
                    for f in tqdm(
                        asyncio.as_completed(map(_download_worker_async, jobs)),
                        total=len(jobs),
                    ):
                        await f

            loop = asyncio.get_event_loop()
            if loop.is_running():
                result = loop.create_task(async_download())
            else:
                result = loop.run_until_complete(async_download())
            return result
        else:
            projection_container_client = blob_service_client.get_container_client(
                self.PROJECTIONS_CONTAINER_NAME
            )
            jobs = list(
                map(
                    lambda blob: (
                        blob,
                        projection_container_client,
                        container_uri,
                        output_folder,
                    ),
                    blobs_list,
                )
            )
            print("downloading", len(jobs), "files")
            for _ in tqdm(map(_download_worker_sync, jobs), total=len(jobs)):
                pass


def _get_projection_path(container_uri, blob):
    """Computes the knowledge-store projection path (base64-encoded document id) for a blob."""
    blob_uri = f"{container_uri}/{blob.name}"
    # projections use a base64 doc id as a key to store results in folders
    # see File Projection in https://docs.microsoft.com/en-us/azure/search/knowledge-store-projection-overview
    # hopefully this doesn't change soon, otherwise we will have to do a linear search over all docs
    # to find the projections we want
    projection_path = base64.b64encode(blob_uri.encode()).decode()
    projection_path = projection_path.replace("=", "") + str(projection_path.count("="))
    return projection_path


def _download_worker_sync(args):
    blob, projection_container_client, container_uri, output_folder = args
    projection_path = _get_projection_path(container_uri, blob)
    blob_client = projection_container_client.get_blob_client(
        blob=f"{projection_path}/document.json"
    )
    doc = json.loads(blob_client.download_blob().readall())
    file_name = os.path.basename(blob.name)
    base_name, ext = os.path.splitext(file_name)
    output_file = f"{output_folder}/{base_name}.json"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    text = doc["ocrLayoutText"]
    with open(output_file, "w", encoding="utf-8") as out:
        json.dump(text, out, ensure_ascii=False)
    return output_file


async def _download_worker_async(args):
    blob, async_projection_container_client, container_uri, output_folder = args
    projection_path = _get_projection_path(container_uri, blob)
    async_blob_client = async_projection_container_client.get_blob_client(
        blob=f"{projection_path}/document.json"
    )
    file_name = os.path.basename(blob.name)
    base_name, ext = os.path.splitext(file_name)
    for retry in range(MAX_RETRIES):
        async with REQUEST_SEMAPHORE:
            try:
                blob_task = await async_blob_client.download_blob()
                doc = json.loads(await blob_task.readall())
                output_file = f"{output_folder}/{base_name}.json".replace("//", "/")
                async with FILE_SEMAPHORE:
                    os.makedirs(os.path.dirname(output_file), exist_ok=True)
                    text = doc["ocrLayoutText"]
                    with open(output_file, "w", encoding="utf-8") as out:
                        json.dump(text, out, ensure_ascii=False)
                return output_file
            except ResourceNotFoundError:
                print(
                    f"Blob '{blob.name}' doesn't exist in OCR projection. Try rerunning OCR."
                )
                return
            except Exception as e:
                print("error getting blob OCR projection", blob.name, e)
                # sleep for a bit, then retry
                await asyncio.sleep(2 * random.random())


async def _upload_worker_async(args):
    async_blob_container_client, upload_file_path, blob_name = args
    async with FILE_SEMAPHORE:
        async with aiofiles.open(upload_file_path, "rb") as f:
            data = await f.read()
    for retry in range(MAX_RETRIES):
        async with REQUEST_SEMAPHORE:
            try:
                await async_blob_container_client.upload_blob(
                    name=blob_name, max_concurrency=8, data=data
                )
                return blob_name
            except ResourceExistsError:
                print("blob already exists:", blob_name)
                return
            except Exception as e:
                print(
                    f"blob upload error. retry count: {retry}/{MAX_RETRIES} :",
                    blob_name,
                    e,
                )
                # sleep for a bit, then retry
                await asyncio.sleep(2 * random.random())
    return blob_name


def _upload_worker_sync(args):
    blob_container_client, upload_file_path, blob_name = args
    with open(upload_file_path, "rb") as data:
        try:
            blob_container_client.upload_blob(
                name=blob_name, max_concurrency=8, data=data
            )
        except ResourceExistsError:
            print("blob already exists:", blob_name)
        except Exception as e:
            print("blob upload error:", blob_name, e)
    return blob_name
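

# Hypothetical end-to-end sketch: upload a local folder of images, then pull the
# OCR projections back down once the enrichment pipeline has produced them.
# DATASOURCE_CONTAINER_NAME, BLOB_NAME, and BLOB_KEY must be set in the environment;
# the folder paths below are placeholders, not part of genalog.
if __name__ == "__main__":
    client = GrokBlobClient.create_from_env_var()
    dest_folder, _ = client.upload_images_to_blob("sample_images", use_async=False)
    print("uploaded to blob folder:", dest_folder)
    # After OCR projections exist in the projections container, fetch them locally.
    client.get_ocr_json(dest_folder, "ocr_output", use_async=False)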