Source code for aurora.rollout

"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import dataclasses
from typing import Generator

import torch

from aurora.batch import Batch
from aurora.model.aurora import Aurora

__all__ = ["rollout"]


[docs] def rollout(model: Aurora, batch: Batch, steps: int) -> Generator[Batch, None, None]: """Perform a roll-out to make long-term predictions. Args: model (:class:`aurora.Aurora`): The model to roll out. batch (:class:`aurora.Batch`): The batch to start the roll-out from. steps (int): The number of roll-out steps. Yields: :class:`aurora.Batch`: The prediction after every step. """ # We will need to concatenate data, so ensure that everything is already of the right form. batch = model.batch_transform_hook(batch) # This might modify the available variables. # Use an arbitary parameter of the model to derive the data type and device. p = next(model.parameters()) batch = batch.type(p.dtype) batch = batch.crop(model.patch_size) batch = batch.to(p.device) for _ in range(steps): pred = model.forward(batch) yield pred # Add the appropriate history so the model can be run on the prediction. batch = dataclasses.replace( pred, surf_vars={ k: torch.cat([batch.surf_vars[k][:, 1:], v], dim=1) for k, v in pred.surf_vars.items() }, atmos_vars={ k: torch.cat([batch.atmos_vars[k][:, 1:], v], dim=1) for k, v in pred.atmos_vars.items() }, )