# Source code for aurora.foundry.server.mlflow_wrapper

"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from uuid import uuid4

import mlflow.pyfunc
from pydantic import BaseModel, HttpUrl

from aurora.foundry.common.channel import (
    BlobStorageChannel,
    iterate_prediction_files,
)
from aurora.foundry.common.model import MLFLOW_ARTIFACTS, models

__all__ = ["AuroraModelWrapper"]

# Need to give the name explicitly here, because the script may be run stand-alone.
logger = logging.getLogger("aurora.foundry.server.score")


class Submission(BaseModel):
    """Request body describing a prediction job: where the data lives, which
    model to run, and for how many steps."""

    data_folder_uri: HttpUrl
    model_name: str
    num_steps: int

    class Config:
        json_schema_extra = {
            "example": {
                "data_folder_uri": "https://my.blob.core.windows.net/container/some/path?WRITABLE_SAS",
                "model_name": "aurora-0.25-small-pretrained",
                "num_steps": 5,
            }
        }


class CreationResponse(BaseModel):
    """Response returned to the client when a new task has been created."""

    task_id: str

    class Config:
        json_schema_extra = {"example": {"task_id": "abc-123-def"}}


class TaskInfo(BaseModel):
    """Status record for a prediction task, returned to polling clients."""

    task_id: str
    completed: bool
    progress_percentage: int
    success: bool | None
    submitted: bool
    status: str

    class Config:
        json_schema_extra = dict(
            example=dict(
                task_id="abc-123-def",
                completed=True,
                progress_percentage=100,
                success=None,
                submitted=True,
                # The example previously used `error_info`, which is not a field of
                # this model; the status message is carried by `status`.
                status="Queued",
            )
        )


class Task:
    """A single prediction task tracked by the server.

    Pairs a :class:`Submission` with a :class:`TaskInfo` record that is mutated as
    the task progresses. Instances are callable so they can be handed directly to a
    `ThreadPoolExecutor`.
    """

    def __init__(self, submission: Submission):
        """Initialise the task in the "Unsubmitted" state.

        Args:
            submission (Submission): Description of the prediction job to run.
        """
        self.submission: Submission = submission

        self.task_info = TaskInfo(
            # `uuid4` collisions are astronomically unlikely, so this serves as a
            # unique task identifier.
            task_id=str(uuid4()),
            completed=False,
            progress_percentage=0,
            success=None,
            submitted=False,
            status="Unsubmitted",
        )

    def __call__(self) -> None:
        """Run the prediction task to completion.

        Receives the initial condition from the submission's Blob Storage folder,
        runs the requested model for `num_steps` steps, and sends every prediction
        back through the same channel. Progress and the final outcome are recorded
        in `self.task_info`. Exceptions are captured there rather than propagated,
        so this method never raises.
        """
        self.task_info.status = "Running"

        try:
            submission = self.submission
            channel = BlobStorageChannel(str(submission.data_folder_uri))

            model_class = models[submission.model_name]
            model = model_class()

            batch = channel.receive(self.task_info.task_id, "input.nc")

            logger.info("Running predictions.")
            for i, (pred, path) in enumerate(
                zip(
                    model.run(batch, submission.num_steps),
                    iterate_prediction_files("prediction.nc", submission.num_steps),
                )
            ):
                channel.send(pred, self.task_info.task_id, path)

                self.task_info.progress_percentage = int((100 * (i + 1)) / submission.num_steps)

            self.task_info.success = True
            self.task_info.status = "Successfully completed"

        except Exception as exc:
            # Log the full traceback server-side: the client only ever sees the
            # one-line `status` message below, which is not enough to debug with.
            logger.exception("Task failed with an exception.")
            self.task_info.success = False
            self.task_info.status = f"Exception: {str(exc)}"

        finally:
            # Always mark the task as completed so clients do not poll forever.
            self.task_info.completed = True


class AuroraModelWrapper(mlflow.pyfunc.PythonModel):
    """A wrapper around an async workflow for making predictions with Aurora."""

    def load_context(self, context) -> None:
        """Set up the worker pool and task registry when the model is loaded.

        Args:
            context: MLflow context; `context.artifacts` is merged into the
                module-level `MLFLOW_ARTIFACTS` mapping so models can find their
                checkpoints.
        """
        logging.getLogger("aurora").setLevel(logging.INFO)

        logger.info("Starting `ThreadPoolExecutor`.")
        # A single worker: tasks run one at a time, in submission order.
        self.POOL = ThreadPoolExecutor(max_workers=1)
        # Registry of all tasks ever created, keyed by task ID. NOTE(review):
        # entries are never removed, so this grows for the lifetime of the server.
        self.TASKS: dict[str, Task] = {}
        # Enter the pool's context manager manually because there is no
        # corresponding unload hook to exit it in.
        self.POOL.__enter__()

        MLFLOW_ARTIFACTS.update(context.artifacts)

    def predict(self, context, model_input: dict, params=None) -> dict:
        """Dispatch a JSON message: create a new task or report on an existing one.

        Args:
            model_input (dict): Contains a "data" entry whose `.item()` yields a
                JSON string with keys "type" ("submission" or "task_info") and
                "msg" (the payload for that message type).
            params: Unused; present for the MLflow `predict` signature.

        Returns:
            dict: A dumped `CreationResponse` for a new submission, or a dumped
                `TaskInfo` for a status query.

        Raises:
            Exception: If a "task_info" message is missing a task ID or names an
                unknown one.
            ValueError: If the message "type" is not recognised.
        """
        data = json.loads(model_input["data"].item())

        if data["type"] == "submission":
            logger.info("Creating a new task.")
            task = Task(Submission(**data["msg"]))
            self.TASKS[task.task_info.task_id] = task
            return CreationResponse(task_id=task.task_info.task_id).model_dump()

        elif data["type"] == "task_info":
            logger.info("Processing an existing task.")
            task_id = data["msg"]["task_id"]
            if not task_id:
                raise Exception("Missing `task_id` parameter.")
            if task_id not in self.TASKS:
                raise Exception("Task ID cannot be found.")
            task = self.TASKS[task_id]

            if not task.task_info.submitted:
                # Attempt to submit the task if the initial condition is available.
                channel = BlobStorageChannel(str(task.submission.data_folder_uri))
                if channel.exists(task_id, "input.nc"):
                    logger.info("Initial condition was found. Submitting task.")
                    # Send an acknowledgement back to test that the host can write. The client will
                    # check for this acknowledgement.
                    channel.write(b"Acknowledgement of initial condition", task_id, "input.nc.ack")
                    # Queue the task.
                    task.task_info.submitted = True
                    task.task_info.status = "Queued"
                    self.POOL.submit(task)
                else:
                    logger.info("Initial condition not available. Waiting.")
                    # Wait a little to prevent the client for querying too frequently.
                    time.sleep(3)
            else:
                logger.info("Task still running. Waiting.")
                # Wait a little to prevent the client for querying too frequently. While waiting,
                # do check for the task to be completed.
                for _ in range(3):
                    if task.task_info.completed:
                        break
                    time.sleep(1)

            return task.task_info.model_dump()

        else:
            raise ValueError(f"Unknown data type: `{data['type']}`.")