"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import abc
import logging
from typing import Generator

import torch

import aurora
from aurora import Batch, rollout

__all__ = ["Model", "models"]

logger = logging.getLogger(__name__)

# A dictionary containing ``<name, artifact_path>`` entries, where ``artifact_path`` is an
# absolute filesystem path to the artifact.
MLFLOW_ARTIFACTS: dict[str, str] = dict()


class Model(metaclass=abc.ABCMeta):
    """A model that can run predictions."""

    def __init__(self):
        """Initialise.

        This creates and loads the model, puts it in evaluation mode, and determines the device
        to run the model on: CUDA if available, otherwise the CPU.
        """
        self.model = self.create_model()
        self.model.eval()

        if torch.cuda.is_available():
            logger.info("GPU detected. Running on GPU.")
            self.target_device = torch.device("cuda")
        else:
            logger.warning("No GPU available. Running on CPU.")
            self.target_device = torch.device("cpu")

    @abc.abstractmethod
    def create_model(self) -> aurora.Aurora:
        """Create the model.

        Returns:
            :class:`aurora.Aurora`: Model.
        """

    @torch.inference_mode
    def run(self, batch: Batch, num_steps: int) -> Generator[Batch, None, None]:
        """Perform predictions on the target device.

        The model is moved to the target device for the duration of the rollout and moved back
        to the CPU afterwards, also when the caller stops iterating early or an error occurs.

        Args:
            batch (:class:`aurora.Batch`): Initial condition.
            num_steps (int): Number of prediction steps.

        Yields:
            :class:`aurora.Batch`: Prediction for each step, moved to the CPU.
        """
        # Move the model to the target device. Modifies in-place!
        self.model.to(self.target_device)
        try:
            batch = batch.to(self.target_device)
            # Perform predictions, immediately moving each output to the CPU.
            for pred in rollout(self.model, batch, steps=num_steps):
                yield pred.to("cpu")
        finally:
            # Runs on normal exhaustion, on `close()` of a partially consumed generator, and on
            # errors, so the model never stays resident on the target device. Modifies in-place!
            self.model.cpu()
class AuroraSmall(Model):
    """Small pretrained Aurora, loaded from the ``aurora-0.25-small-pretrained`` artifact."""

    name = "aurora-0.25-small-pretrained"
    """str: Name of the model."""

    def create_model(self) -> aurora.Aurora:
        """Instantiate :class:`aurora.AuroraSmall` and load its local checkpoint.

        The checkpoint path is looked up in ``MLFLOW_ARTIFACTS`` under this model's ``name``.

        Returns:
            :class:`aurora.Aurora`: Model with the checkpoint loaded.
        """
        network = aurora.AuroraSmall()
        checkpoint_path = MLFLOW_ARTIFACTS[self.name]
        network.load_checkpoint_local(checkpoint_path)
        return network
class AuroraFineTuned(Model):
    """Fine-tuned Aurora, loaded from the ``aurora-0.25-finetuned`` artifact."""

    name = "aurora-0.25-finetuned"
    """str: Name of the model."""

    def create_model(self) -> aurora.Aurora:
        """Instantiate :class:`aurora.Aurora` and load its local checkpoint.

        The checkpoint path is looked up in ``MLFLOW_ARTIFACTS`` under this model's ``name``.

        Returns:
            :class:`aurora.Aurora`: Model with the checkpoint loaded.
        """
        finetuned = aurora.Aurora()
        artifact_path = MLFLOW_ARTIFACTS[self.name]
        finetuned.load_checkpoint_local(artifact_path)
        return finetuned
models: dict[str, type[Model]] = {}
"""dict[str, type[Model]]: A dictionary that lists all available models by their name."""

# Register every direct subclass of `Model` under its `name`. `__subclasses__()` is evaluated
# when this module is imported, so only subclasses that exist at that point are registered.
for model_class in Model.__subclasses__():
    # Raise explicitly instead of using `assert`, so the check survives `python -O`.
    if not hasattr(model_class, "name"):
        raise AttributeError(f"`{model_class.__name__}` is missing `name`.")
    # `mypy` will complain, because `Model` is abstract, so it cannot be passed to `type`.
    models[model_class.name] = model_class  # type: ignore[type-abstract]