Source code for archai.discrete_search.evaluators.remote_azure_benchmark

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import datetime
import uuid
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from overrides import overrides

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import AsyncModelEvaluator
from archai.common.store import ArchaiStore


class RemoteAzureBenchmarkEvaluator(AsyncModelEvaluator):
    """Simple adapter for benchmarking architectures asynchronously on Azure.

    This adapter uploads an ONNX model to an Azure Blob storage container and
    records the model entry on the respective Azure Table.
    """

    def __init__(
        self,
        input_shape: Union[Tuple, List[Tuple]],
        store: ArchaiStore,
        experiment_name: str,
        metric_key: str,
        overwrite: Optional[bool] = True,
        max_retries: Optional[int] = 5,
        retry_interval: Optional[int] = 120,
        onnx_export_kwargs: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
        benchmark_only: bool = True,
    ) -> None:
        """Initialize the evaluator.

        Args:
            input_shape: Model input shape, or list of input shapes, used to build
                the sample input(s) for `torch.onnx.export`.
            store: Azure storage helper providing the blob container and status table.
            experiment_name: Blob-folder prefix under which models are uploaded.
            metric_key: Status-table column that should be used as the result value.
            overwrite: Whether to overwrite existing models.
            max_retries: Maximum number of no-progress polling rounds in `fetch_all`.
            retry_interval: Interval (seconds) between polling rounds in `fetch_all`.
            onnx_export_kwargs: Dictionary containing key-value arguments for
                `torch.onnx.export`.
            verbose: Whether to print debug messages.
            benchmark_only: Flag recorded on the table entry; presumably tells the
                remote runner to only benchmark the model — TODO confirm against runner.
        """
        # TODO: Make this class more general / less pipeline-specific
        self.store = store
        input_shapes = [input_shape] if isinstance(input_shape, tuple) else input_shape
        self.sample_input = tuple(torch.rand(*shape) for shape in input_shapes)
        self.experiment_name = experiment_name
        self.metric_key = metric_key
        self.overwrite = overwrite
        self.max_retries = max_retries
        self.retry_interval = retry_interval
        self.onnx_export_kwargs = onnx_export_kwargs or dict()
        self.verbose = verbose
        self.results = {}
        self.benchmark_only = benchmark_only

        # Architecture list accumulated by `send` and consumed by `fetch_all`
        self.archids = []

        # Test that the store connection works (these calls hit the service but
        # are expected to return empty results for a random id)
        unknown_id = str(uuid.uuid4())
        _ = self.store.get_existing_status(unknown_id)
        _ = self.store.list_blobs(unknown_id)

    def _reset(self, entity) -> None:
        """Remove stale benchmark metrics from a table entity so it can be re-run."""
        changed = False
        for k in ['mean', 'macs', 'params', 'stdev', 'total_inference_avg', 'error']:
            if k in entity:
                del entity[k]
                changed = True
        if changed:
            self.store.update_status_entity(entity)

    @overrides
    def send(self, arch: ArchaiModel, budget: Optional[float] = None) -> None:
        """Export `arch` to ONNX, upload it, and queue it for remote benchmarking."""
        # bug in azure ml sdk requires blob store folder names not begin with digits,
        # so we prefix with 'id_'
        archid = f'id_{arch.archid}'

        # Checks if architecture was already benchmarked
        entity = self.store.get_existing_status(archid)
        if entity is not None:
            if entity["status"] == "complete":
                if self.metric_key in entity:
                    if self.verbose:
                        value = entity[self.metric_key]
                        print(f"Entry for {archid} already exists with {self.metric_key} = {value}")
                    self.archids.append(archid)
                    return
                else:
                    # complete but missing the mean, so reset the benchmark metrics
                    # so we can try again.
                    self._reset(entity)
            else:
                # job is still running, let it continue
                if self.verbose:
                    print(f"Job for {archid} is running...")
                self.archids.append(archid)
                return

        entity = self.store.get_status(archid)  # this is a get or create operation.
        if self.benchmark_only:
            entity["benchmark_only"] = 1
        entity["model_date"] = self.store.get_utc_date()
        entity["model_name"] = "model.onnx"
        self.store.update_status_entity(entity)  # must be an update, not a merge.
        self.store.lock_entity(entity, "uploading")

        if arch.arch is not None:
            try:
                with TemporaryDirectory() as tmp_dir:
                    tmp_dir = Path(tmp_dir)

                    # Uploads ONNX file to blob storage and updates the table entry
                    arch.arch.to("cpu")
                    file_name = str(tmp_dir / "model.onnx")

                    # Exports model to ONNX
                    torch.onnx.export(
                        arch.arch,
                        self.sample_input,
                        file_name,
                        input_names=[f"input_{i}" for i in range(len(self.sample_input))],
                        **self.onnx_export_kwargs,
                    )

                    self.store.upload_blob(f'{self.experiment_name}/{archid}', file_name, "model.onnx")
                    entity["status"] = "new"
            except Exception as e:
                # Record the failure on the entity rather than aborting the search.
                entity["error"] = str(e)
        else:
            # then the blob store must already have a model.onnx file!
            blobs = self.store.list_blobs(f'{self.experiment_name}/{archid}')
            if 'model.onnx' not in blobs:
                entity["error"] = "model.onnx is missing"

        self.store.unlock_entity(entity)
        self.archids.append(archid)
        if self.verbose:
            print(f"Sent {archid} to Remote Benchmark")

    @overrides
    def fetch_all(self) -> List[Union[float, None]]:
        """Poll the status table until all sent architectures complete or retries run out.

        Returns:
            One result per sent archid (in send order); `None` for entries that
            never produced the metric. Clears the internal archid list on return.

        Raises:
            Exception: If no architecture completed at all, suggesting the remote
                runner is not processing uploads.
        """
        results = [None] * len(self.archids)
        completed = [False] * len(self.archids)

        # retries defines how long we wait for progress; as soon as we see something
        # complete we reset this counter because we are making progress.
        retries = self.max_retries
        start = time.time()
        count = 0
        while retries > 0:
            for i, archid in enumerate(self.archids):
                if not completed[i]:
                    entity = self.store.get_existing_status(archid)
                    if entity is not None:
                        if self.metric_key in entity and entity[self.metric_key]:
                            results[i] = entity[self.metric_key]
                        if "error" in entity:
                            error = entity["error"]
                            print(f"Skipping architecture {archid} because of remote error: {error}")
                            completed[i] = True
                            retries = self.max_retries
                        elif entity["status"] == "complete":
                            print(f"Architecture {archid} is complete with {self.metric_key}={results[i]}")
                            completed[i] = True
                            retries = self.max_retries

            if all(completed):
                break

            count = sum(1 for c in completed if c)
            if self.verbose:
                remaining = len(self.archids) - count
                estimate = (time.time() - start) / (count + 1) * remaining
                status_dict = {
                    "complete": count,
                    "total": len(results),
                    "time_remaining": str(datetime.timedelta(seconds=estimate))
                }
                print(
                    f"Waiting for results. Current status: {status_dict}\n"
                    f"Pending Archids: {[archid for archid, status in zip(self.archids, results) if status is None]}"
                )

            time.sleep(self.retry_interval)
            # BUGFIX: this must decrement — incrementing made `while retries > 0`
            # an infinite loop whenever the remote runner stalled, and made the
            # no-progress exception below unreachable.
            retries -= 1

        count = sum(1 for c in completed if c)
        if count == 0:
            raise Exception("Something is wrong, the uploaded models are not being processed. Please check your SNPE remote runner setup.")

        # Resets state
        self.archids = []
        return results