Source code for aurora.foundry.client.foundry

"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import json
import logging

import requests

__all__ = ["FoundryClient"]


logger = logging.getLogger(__name__)


class FoundryClient:
    """Client that sends JSON requests to a single endpoint via HTTP POST.

    Every request carries the token as a bearer `Authorization` header and is
    sent to `endpoint` as-is (no path is appended).
    """

    # Class-level fallback so that subclasses which override `__init__` without
    # calling `super().__init__` still get the historical "no timeout" behaviour.
    timeout: float | tuple[float, float] | None = None

    def __init__(
        self,
        endpoint: str,
        token: str,
        *,
        timeout: float | tuple[float, float] | None = None,
    ) -> None:
        """Initialise.

        Args:
            endpoint (str): URL to the endpoint.
            token (str): Authorisation token.
            timeout (float | tuple[float, float] | None, optional): Timeout in seconds
                that is passed to every HTTP request: either one value or a
                `(connect, read)` pair. Defaults to `None`, which applies no timeout,
                so a request to an unresponsive endpoint can block indefinitely.
        """
        self.endpoint = endpoint
        self.token = token
        self.timeout = timeout

    def _req(
        self,
        data: dict | None = None,
    ) -> requests.Response:
        """POST `data` to the endpoint and return the raw response.

        `data` is first serialised to a JSON string and then nested, so the request
        body is `{"input_data": {"data": "<JSON string>"}}`. The status code is not
        checked here; see `_unwrap`.

        Args:
            data (dict | None, optional): Payload to send. `None` is serialised as
                the JSON string `"null"`.

        Returns:
            requests.Response: Response of the endpoint.
        """
        # The payload is intentionally a JSON *string* inside the JSON body.
        wrapped = {"data": json.dumps(data)}
        return requests.request(
            "POST",
            self.endpoint,
            headers={
                "Authorization": f"Bearer {self.token}",
                "Content-Type": "application/json",
            },
            json={"input_data": wrapped},
            timeout=self.timeout,
        )

    def _unwrap(self, response: requests.Response) -> dict:
        """Check the status of `response` and decode its JSON body.

        Args:
            response (requests.Response): Response to decode.

        Raises:
            requests.HTTPError: If the response has an error status. The response
                text is logged at error level before raising.

        Returns:
            dict: Decoded JSON body.
        """
        if not response.ok:
            # Log the body first: it is the only place where the server-side error
            # message is visible once the exception propagates.
            logger.error(response.text)
            response.raise_for_status()
        return response.json()

    def submit_task(self, data: dict) -> dict:
        """Send `data` to the scoring path.

        Args:
            data (dict): Data to send.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.

        Returns:
            dict: Submission information.
        """
        answer = self._req({"type": "submission", "msg": data})
        return self._unwrap(answer)

    def get_progress(self, task_id: str) -> dict:
        """Get the progress of the task.

        Args:
            task_id (str): Task ID to get progress info for.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.

        Returns:
            dict: Progress information.
        """
        answer = self._req({"type": "task_info", "msg": {"task_id": task_id}})
        return self._unwrap(answer)