Source code for aurora.model.aurora

"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import contextlib
import dataclasses
import warnings
from datetime import timedelta
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

from aurora.batch import Batch
from aurora.model.compat import (
    _adapt_checkpoint_air_pollution,
    _adapt_checkpoint_pretrained,
    _adapt_checkpoint_wave,
)
from aurora.model.decoder import Perceiver3DDecoder
from aurora.model.encoder import Perceiver3DEncoder
from aurora.model.lora import LoRAMode
from aurora.model.swin3d import Swin3DTransformerBackbone

__all__ = [
    "Aurora",
    "AuroraPretrained",
    "AuroraSmallPretrained",
    "AuroraSmall",
    "Aurora12hPretrained",
    "AuroraHighRes",
    "AuroraAirPollution",
    "AuroraWave",
]


class Aurora(torch.nn.Module):
    """The Aurora model.

    Defaults to the 1.3 B parameter configuration.
    """

    default_checkpoint_repo = "microsoft/aurora"
    """str: Name of the HuggingFace repository to load the default checkpoint from."""

    default_checkpoint_name = "aurora-0.25-finetuned.ckpt"
    """str: Name of the default checkpoint."""

    default_checkpoint_revision = "0be7e57c685dac86b78c4a19a3ab149d13c6a3dd"
    """str: Commit hash of the default checkpoint."""

    def __init__(
        self,
        *,
        surf_vars: tuple[str, ...] = ("2t", "10u", "10v", "msl"),
        static_vars: tuple[str, ...] = ("lsm", "z", "slt"),
        atmos_vars: tuple[str, ...] = ("z", "u", "v", "t", "q"),
        window_size: tuple[int, int, int] = (2, 6, 12),
        encoder_depths: tuple[int, ...] = (6, 10, 8),
        encoder_num_heads: tuple[int, ...] = (8, 16, 32),
        decoder_depths: tuple[int, ...] = (8, 10, 6),
        decoder_num_heads: tuple[int, ...] = (32, 16, 8),
        latent_levels: int = 4,
        patch_size: int = 4,
        embed_dim: int = 512,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        drop_rate: float = 0.0,
        enc_depth: int = 1,
        dec_depth: int = 1,
        dec_mlp_ratio: float = 2.0,
        perceiver_ln_eps: float = 1e-5,
        max_history_size: int = 2,
        timestep: timedelta = timedelta(hours=6),
        stabilise_level_agg: bool = False,
        use_lora: bool = True,
        lora_steps: int = 40,
        lora_mode: LoRAMode = "single",
        surf_stats: Optional[dict[str, tuple[float, float]]] = None,
        autocast: bool = False,
        bf16_mode: bool = False,
        level_condition: Optional[tuple[int | float, ...]] = None,
        dynamic_vars: bool = False,
        atmos_static_vars: bool = False,
        separate_perceiver: tuple[str, ...] = (),
        modulation_heads: tuple[str, ...] = (),
        positive_surf_vars: tuple[str, ...] = (),
        positive_atmos_vars: tuple[str, ...] = (),
        clamp_at_first_step: bool = False,
        simulate_indexing_bug: bool = False,
    ) -> None:
        """Construct an instance of the model.

        Args:
            surf_vars (tuple[str, ...], optional): All surface-level variables supported by the
                model.
            static_vars (tuple[str, ...], optional): All static variables supported by the model.
            atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the
                model.
            window_size (tuple[int, int, int], optional): Vertical height, height, and width of
                the window of the underlying Swin transformer.
            encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
            encoder_num_heads (tuple[int, ...], optional): Number of attention heads in each
                encoder layer. The dimensionality doubles after every layer. To keep the
                dimensionality of every head constant, you want to double the number of heads
                after every layer. The dimensionality of the attention heads of the first layer
                is determined by `embed_dim` divided by the value here. For all cases except one,
                this is equal to `64`.
            decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer.
                Generally, you want this to be the reversal of `encoder_depths`.
            decoder_num_heads (tuple[int, ...], optional): Number of attention heads in each
                decoder layer. Generally, you want this to be the reversal of
                `encoder_num_heads`.
            latent_levels (int, optional): Number of latent pressure levels.
            patch_size (int, optional): Patch size.
            embed_dim (int, optional): Patch embedding dimension.
            num_heads (int, optional): Number of attention heads in the aggregation and
                deaggregation blocks. The dimensionality of these attention heads will be equal
                to `embed_dim` divided by this value.
            mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs.
            drop_rate (float, optional): Drop-out rate.
            drop_path (float, optional): Drop-path rate.
            enc_depth (int, optional): Number of Perceiver blocks in the encoder.
            dec_depth (int, optional): Number of Perceiver blocks in the decoder.
            dec_mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs in the
                decoder. The embedding dimensionality here is different, which is why this is a
                separate parameter.
            perceiver_ln_eps (float, optional): Epsilon in the Perceiver layer normalisation
                layers. Used to stabilise the model.
            max_history_size (int, optional): Maximum number of history steps. You can load
                checkpoints with a smaller `max_history_size`, but you cannot load checkpoints
                with a larger `max_history_size`.
            timestep (timedelta, optional): Timestep of the model. Defaults to 6 hours.
            stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an
                additional layer normalisation. Defaults to `False`.
            use_lora (bool, optional): Use LoRA adaptation.
            lora_steps (int, optional): Use a different LoRA adaptation for the first so-many
                roll-out steps.
            lora_mode (str, optional): LoRA mode. `"single"` uses the same LoRA for all roll-out
                steps, `"from_second"` uses the same LoRA from the second roll-out step on, and
                `"all"` uses a different LoRA for every roll-out step. Defaults to `"single"`.
            surf_stats (dict[str, tuple[float, float]], optional): For these surface-level
                variables, adjust the normalisation to the given tuple consisting of a new
                location and scale.
            bf16_mode (bool, optional): To reduce memory usage, convert the tokens to BF16, run
                the backbone in pure BF16, and run the decoder in FP16 AMP. This should enable
                gradient computation. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE
                DEVELOPMENT OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING.
            level_condition (tuple[int | float, ...], optional): Make the patch embeddings
                dependent on the pressure level. If you want to enable this feature, provide a
                tuple of all possible pressure levels.
            dynamic_vars (bool, optional): Use dynamically generated static variables, like time
                of day. Defaults to `False`.
            atmos_static_vars (bool, optional): Also concatenate the static variables to the
                atmospheric variables. Defaults to `False`.
            separate_perceiver (tuple[str, ...], optional): In the decoder, use a separate
                Perceiver for specific atmospheric variables. This can be helpful at fine-tuning
                time to deal with variables that have a significantly different behaviour. If you
                want to enable this feature, set this to the collection of variables that should
                be run on a separate Perceiver.
            modulation_heads (tuple[str, ...], optional): Names of every variable for which to
                enable an additional head, the so-called modulation head, that can be used to
                predict the difference.
            positive_surf_vars (tuple[str, ...], optional): Mark these surface-level variables as
                positive. Clamp them before running them through the encoder, and also clamp them
                when autoregressively rolling out the model. The variables are not clamped for
                the first roll-out step.
            positive_atmos_vars (tuple[str, ...], optional): Mark these atmospheric variables as
                positive. Clamp them before running them through the encoder, and also clamp them
                when autoregressively rolling out the model. The variables are not clamped for
                the first roll-out step.
            clamp_at_first_step (bool, optional): Clamp the positive variables for the first
                roll-out step. Should only be used for inference. Defaults to `False`.
            simulate_indexing_bug (bool, optional): Simulate an indexing bug that's present in
                the air pollution version of Aurora. This is necessary to obtain numerical
                equivalence with the original implementation. Defaults to `False`.
""" super().__init__() self.surf_vars = surf_vars self.atmos_vars = atmos_vars self.patch_size = patch_size self.surf_stats = surf_stats or dict() self.max_history_size = max_history_size self.timestep = timestep self.use_lora = use_lora self.positive_surf_vars = positive_surf_vars self.positive_atmos_vars = positive_atmos_vars self.clamp_at_first_step = clamp_at_first_step if self.surf_stats: warnings.warn( f"The normalisation statics for the following surface-level variables are manually " f"adjusted: {', '.join(sorted(self.surf_stats.keys()))}. " f"Please ensure that this is right!", stacklevel=2, ) self.encoder = Perceiver3DEncoder( surf_vars=surf_vars, static_vars=static_vars, atmos_vars=atmos_vars, patch_size=patch_size, embed_dim=embed_dim, num_heads=num_heads, drop_rate=drop_rate, mlp_ratio=mlp_ratio, head_dim=embed_dim // num_heads, depth=enc_depth, latent_levels=latent_levels, max_history_size=max_history_size, perceiver_ln_eps=perceiver_ln_eps, stabilise_level_agg=stabilise_level_agg, level_condition=level_condition, dynamic_vars=dynamic_vars, atmos_static_vars=atmos_static_vars, simulate_indexing_bug=simulate_indexing_bug, ) self.backbone = Swin3DTransformerBackbone( window_size=window_size, encoder_depths=encoder_depths, encoder_num_heads=encoder_num_heads, decoder_depths=decoder_depths, decoder_num_heads=decoder_num_heads, embed_dim=embed_dim, mlp_ratio=mlp_ratio, drop_path_rate=drop_path, drop_rate=drop_rate, use_lora=use_lora, lora_steps=lora_steps, lora_mode=lora_mode, ) self.decoder = Perceiver3DDecoder( surf_vars=surf_vars, atmos_vars=atmos_vars, patch_size=patch_size, # Concatenation at the backbone end doubles the dim. embed_dim=embed_dim * 2, head_dim=embed_dim * 2 // num_heads, num_heads=num_heads, depth=dec_depth, # Because of the concatenation, high ratios are expensive. # We use a lower ratio here to keep the memory in check. mlp_ratio=dec_mlp_ratio, perceiver_ln_eps=perceiver_ln_eps, level_condition=level_condition, separate_perceiver=separate_perceiver, modulation_heads=modulation_heads, ) if autocast and not bf16_mode: warnings.warn( "The argument `autocast` no longer does anything due to limited utility. " "Consider instead using `bf16_mode`.", stacklevel=2, ) self.bf16_mode = bf16_mode if self.bf16_mode: # We run the backbone in pure BF16. self.backbone.to(torch.bfloat16)

    def forward(self, batch: Batch) -> Batch:
        """Forward pass.

        Args:
            batch (:class:`Batch`): Batch to run the model on.

        Returns:
            :class:`Batch`: Prediction for the batch.
        """
        batch = self.batch_transform_hook(batch)

        # Get the first parameter. We'll derive the data type and device from this parameter.
        p = next(self.parameters())
        batch = batch.type(p.dtype)
        batch = batch.normalise(surf_stats=self.surf_stats)
        batch = batch.crop(patch_size=self.patch_size)
        batch = batch.to(p.device)

        H, W = batch.spatial_shape
        patch_res = (
            self.encoder.latent_levels,
            H // self.encoder.patch_size,
            W // self.encoder.patch_size,
        )

        # Insert batch and history dimension for static variables.
        B, T = next(iter(batch.surf_vars.values())).shape[:2]
        batch = dataclasses.replace(
            batch,
            static_vars={
                k: v[None, None].repeat(B, T, 1, 1) for k, v in batch.static_vars.items()
            },
        )

        # Apply some transformations before feeding `batch` to the encoder. We'll later want to
        # refer to the original batch too, so rename the variable.
        transformed_batch = batch

        # Clamp positive variables.
        if self.positive_surf_vars:
            transformed_batch = dataclasses.replace(
                transformed_batch,
                surf_vars={
                    k: v.clamp(min=0) if k in self.positive_surf_vars else v
                    for k, v in batch.surf_vars.items()
                },
            )
        if self.positive_atmos_vars:
            transformed_batch = dataclasses.replace(
                transformed_batch,
                atmos_vars={
                    k: v.clamp(min=0) if k in self.positive_atmos_vars else v
                    for k, v in batch.atmos_vars.items()
                },
            )

        transformed_batch = self._pre_encoder_hook(transformed_batch)

        # The encoder is always just run.
        x = self.encoder(
            transformed_batch,
            lead_time=self.timestep,
        )

        # In BF16 mode, the backbone is run in pure BF16.
        if self.bf16_mode:
            x = x.to(torch.bfloat16)
        x = self.backbone(
            x,
            lead_time=self.timestep,
            patch_res=patch_res,
            rollout_step=batch.metadata.rollout_step,
        )

        # In BF16 mode, the decoder is run in FP16 AMP, and the output is converted back to FP32.
        # We run in FP16 as opposed to BF16 for improved relative precision.
        if self.bf16_mode:
            device_type = (
                "cuda"
                if torch.cuda.is_available()
                else "xpu"
                if torch.xpu.is_available()
                else "cpu"
            )
            context = torch.autocast(device_type=device_type, dtype=torch.float16)
            x = x.to(torch.float16)
        else:
            context = contextlib.nullcontext()
        with context:
            pred = self.decoder(
                x,
                batch,
                lead_time=self.timestep,
                patch_res=patch_res,
            )
        if self.bf16_mode:
            pred = dataclasses.replace(
                pred,
                surf_vars={k: v.float() for k, v in pred.surf_vars.items()},
                static_vars={k: v.float() for k, v in pred.static_vars.items()},
                atmos_vars={k: v.float() for k, v in pred.atmos_vars.items()},
            )

        # Remove batch and history dimension from static variables.
        pred = dataclasses.replace(
            pred,
            static_vars={k: v[0, 0] for k, v in batch.static_vars.items()},
        )

        # Insert history dimension in prediction. The time should already be right.
        pred = dataclasses.replace(
            pred,
            surf_vars={k: v[:, None] for k, v in pred.surf_vars.items()},
            atmos_vars={k: v[:, None] for k, v in pred.atmos_vars.items()},
        )

        pred = self._post_decoder_hook(batch, pred)

        # Clamp positive variables.
        clamp_at_rollout_step = (
            pred.metadata.rollout_step >= 1
            if self.clamp_at_first_step
            else pred.metadata.rollout_step > 1
        )
        if self.positive_surf_vars and clamp_at_rollout_step:
            pred = dataclasses.replace(
                pred,
                surf_vars={
                    k: v.clamp(min=0) if k in self.positive_surf_vars else v
                    for k, v in pred.surf_vars.items()
                },
            )
        if self.positive_atmos_vars and clamp_at_rollout_step:
            pred = dataclasses.replace(
                pred,
                atmos_vars={
                    k: v.clamp(min=0) if k in self.positive_atmos_vars else v
                    for k, v in pred.atmos_vars.items()
                },
            )

        pred = pred.unnormalise(surf_stats=self.surf_stats)

        return pred

    def batch_transform_hook(self, batch: Batch) -> Batch:
        """Transform the batch right after receiving it and before normalisation.

        This function should be idempotent.
        """
        return batch

    def _pre_encoder_hook(self, batch: Batch) -> Batch:
        """Transform the batch before it goes through the encoder."""
        return batch

    def _post_decoder_hook(self, batch: Batch, pred: Batch) -> Batch:
        """Transform the prediction right after the decoder."""
        return pred

    def load_checkpoint(
        self,
        repo: Optional[str] = None,
        name: Optional[str] = None,
        revision: Optional[str] = None,
        strict: bool = True,
    ) -> None:
        """Load a checkpoint from HuggingFace.

        Args:
            repo (str, optional): Name of the repository of the form `user/repo`.
            name (str, optional): Path to the checkpoint relative to the root of the repository,
                e.g. `checkpoint.ckpt`.
            revision (str, optional): Commit hash of the HuggingFace git repository to download
                from.
            strict (bool, optional): Error if the model parameters are not exactly equal to the
                parameters in the checkpoint. Defaults to `True`.
        """
        repo = repo or self.default_checkpoint_repo
        name = name or self.default_checkpoint_name
        revision = revision or self.default_checkpoint_revision

        path = hf_hub_download(repo_id=repo, filename=name, revision=revision)
        self.load_checkpoint_local(path, strict=strict)

    def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
        """Load a checkpoint directly from a file.

        Args:
            path (str): Path to the checkpoint.
            strict (bool, optional): Error if the model parameters are not exactly equal to the
                parameters in the checkpoint. Defaults to `True`.
        """
        # Assume that all parameters are either on the CPU or on the GPU.
        device = next(self.parameters()).device
        d = torch.load(path, map_location=device, weights_only=True)
        d = self._adapt_checkpoint(d)

        # Check if the history size is compatible and adjust weights if necessary.
        current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2]
        if self.max_history_size > current_history_size:
            self.adapt_checkpoint_max_history_size(d)
        elif self.max_history_size < current_history_size:
            raise AssertionError(
                f"Cannot load checkpoint with `max_history_size` {current_history_size} "
                f"into model with `max_history_size` {self.max_history_size}."
            )

        self.load_state_dict(d, strict=strict)

    def _adapt_checkpoint(self, d: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Adapt an existing checkpoint to make it compatible with the current version of the
        model.

        Args:
            d (dict[str, torch.Tensor]): Checkpoint.

        Returns:
            dict[str, torch.Tensor]: Adapted checkpoint.
        """
        return _adapt_checkpoint_pretrained(self.patch_size, d)

    def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor]) -> None:
        """Adapt a checkpoint with a smaller `max_history_size` to a model with a larger
        `max_history_size`.

        If a checkpoint was trained with a larger `max_history_size` than the current model, this
        function raises an assertion error to prevent loading the checkpoint, since loading it
        would likely degrade its performance.

        This implementation copies weights from the checkpoint to the model and fills zeros for
        the new history width dimension. It mutates `checkpoint`.
        """
        for name, weight in list(checkpoint.items()):
            # We only need to adapt the patch embedding in the encoder.
            enc_surf_embedding = name.startswith("encoder.surf_token_embeds.weights.")
            enc_atmos_embedding = name.startswith("encoder.atmos_token_embeds.weights.")
            if enc_surf_embedding or enc_atmos_embedding:
                # This shouldn't get called with the current logic, but we leave it here for
                # future proofing and for cases where it's called outside the current context.
                if not (weight.shape[2] <= self.max_history_size):
                    raise AssertionError(
                        f"Cannot load checkpoint with `max_history_size` {weight.shape[2]} "
                        f"into model with `max_history_size` {self.max_history_size}."
                    )

                # Initialise the new weight tensor.
                new_weight = torch.zeros(
                    (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
                    device=weight.device,
                    dtype=weight.dtype,
                )
                # Copy the existing weights to the new tensor by duplicating the histories
                # provided into any new history dimensions. The rest remains at zero.
                new_weight[:, :, : weight.shape[2]] = weight
                checkpoint[name] = new_weight
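
    # Illustrative example of the adaptation above (shapes are hypothetical): a checkpoint
    # embedding of shape (C, 1, 2, P, P) loaded into a model with `max_history_size=4` is
    # replaced by a tensor of shape (C, 1, 4, P, P) whose first two history slices hold the
    # original weights and whose remaining slices stay zero.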

    def configure_activation_checkpointing(self):
        """Configure activation checkpointing.

        This is required in order to compute gradients without running out of memory.
        """
        # Checkpoint these modules:
        module_names = (
            "Perceiver3DEncoder",
            "Swin3DTransformerBackbone",
            "Basic3DEncoderLayer",
            "Basic3DDecoderLayer",
            "Perceiver3DDecoder",
            "LinearPatchReconstruction",
        )

        def check(x: torch.nn.Module) -> bool:
            name = x.__class__.__name__
            return name in module_names

        apply_activation_checkpointing(self, check_fn=check)
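
# A minimal usage sketch (hedged: the exact construction of `batch` depends on your data
# pipeline and is not shown here; see the package documentation for building a `Batch` with two
# history steps of the variables configured above):
#
#     from aurora import Aurora
#
#     model = Aurora()
#     model.load_checkpoint()  # Downloads the default checkpoint from HuggingFace.
#     model.eval()
#     with torch.inference_mode():
#         pred = model(batch)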

class AuroraPretrained(Aurora):
    """Pretrained version of Aurora."""

    default_checkpoint_name = "aurora-0.25-pretrained.ckpt"
    default_checkpoint_revision = "0be7e57c685dac86b78c4a19a3ab149d13c6a3dd"

    def __init__(
        self,
        *,
        use_lora: bool = False,
        **kw_args,
    ) -> None:
        super().__init__(
            use_lora=use_lora,
            **kw_args,
        )

class AuroraSmallPretrained(Aurora):
    """Small pretrained version of Aurora.

    Should only be used for debugging.
    """

    default_checkpoint_name = "aurora-0.25-small-pretrained.ckpt"
    default_checkpoint_revision = "0be7e57c685dac86b78c4a19a3ab149d13c6a3dd"

    def __init__(
        self,
        *,
        encoder_depths: tuple[int, ...] = (2, 6, 2),
        encoder_num_heads: tuple[int, ...] = (4, 8, 16),
        decoder_depths: tuple[int, ...] = (2, 6, 2),
        decoder_num_heads: tuple[int, ...] = (16, 8, 4),
        embed_dim: int = 256,
        num_heads: int = 8,
        use_lora: bool = False,
        **kw_args,
    ) -> None:
        super().__init__(
            encoder_depths=encoder_depths,
            encoder_num_heads=encoder_num_heads,
            decoder_depths=decoder_depths,
            decoder_num_heads=decoder_num_heads,
            embed_dim=embed_dim,
            num_heads=num_heads,
            use_lora=use_lora,
            **kw_args,
        )
AuroraSmall = AuroraSmallPretrained #: Alias for backwards compatibility
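
# Debugging sketch (illustrative only): the small configuration is cheap enough for quick smoke
# tests of a data pipeline before moving to the full model:
#
#     model = AuroraSmallPretrained()
#     model.load_checkpoint()  # Downloads "aurora-0.25-small-pretrained.ckpt".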

class Aurora12hPretrained(Aurora):
    """Pretrained version of Aurora with a 12-hour time step."""

    default_checkpoint_name = "aurora-0.25-12h-pretrained.ckpt"
    default_checkpoint_revision = "15e76e47b65bf4b28fd2246b7b5b951d6e2443b9"

    def __init__(
        self,
        *,
        timestep: timedelta = timedelta(hours=12),
        use_lora: bool = False,
        **kw_args,
    ) -> None:
        super().__init__(
            timestep=timestep,
            use_lora=use_lora,
            **kw_args,
        )

class AuroraHighRes(Aurora):
    """High-resolution version of Aurora."""

    default_checkpoint_name = "aurora-0.1-finetuned.ckpt"
    default_checkpoint_revision = "0be7e57c685dac86b78c4a19a3ab149d13c6a3dd"

    def __init__(
        self,
        *,
        patch_size: int = 10,
        encoder_depths: tuple[int, ...] = (6, 8, 8),
        decoder_depths: tuple[int, ...] = (8, 8, 6),
        **kw_args,
    ) -> None:
        super().__init__(
            patch_size=patch_size,
            encoder_depths=encoder_depths,
            decoder_depths=decoder_depths,
            **kw_args,
        )

class AuroraAirPollution(Aurora):
    """Fine-tuned version of Aurora for air pollution."""

    default_checkpoint_name = "aurora-0.4-air-pollution.ckpt"
    default_checkpoint_revision = "1764d5630a53d3d7a7d169ca335236fc343e4bfc"

    _predict_difference_history_dim_lookup = {
        "pm1": 0,
        "pm2p5": 0,
        "pm10": 0,
        "co": 1,
        "tcco": 1,
        "no": 0,
        "tc_no": 0,
        "no2": 0,
        "tcno2": 0,
        "so2": 1,
        "tcso2": 1,
        "go3": 1,
        "gtco3": 1,
    }
    """dict[str, int]: For every variable that we want to predict the difference for, the index
    into the history dimension that should be used when predicting the difference."""

    def __init__(
        self,
        *,
        surf_vars: tuple[str, ...] = (
            ("2t", "10u", "10v", "msl")
            + ("pm1", "pm2p5", "pm10", "tcco", "tc_no", "tcno2", "gtco3", "tcso2")
        ),
        static_vars: tuple[str, ...] = (
            ("lsm", "z", "slt")
            + ("static_ammonia", "static_ammonia_log", "static_co", "static_co_log")
            + ("static_nox", "static_nox_log", "static_so2", "static_so2_log")
        ),
        atmos_vars: tuple[str, ...] = ("z", "u", "v", "t", "q", "co", "no", "no2", "go3", "so2"),
        patch_size: int = 3,
        timestep: timedelta = timedelta(hours=12),
        level_condition: Optional[tuple[int | float, ...]] = (
            (50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)
        ),
        dynamic_vars: bool = True,
        atmos_static_vars: bool = True,
        separate_perceiver: tuple[str, ...] = ("co", "no", "no2", "go3", "so2"),
        modulation_heads: tuple[str, ...] = tuple(_predict_difference_history_dim_lookup.keys()),
        positive_surf_vars: tuple[str, ...] = (
            ("pm1", "pm2p5", "pm10", "tcco", "tc_no", "tcno2", "gtco3", "tcso2")
        ),
        positive_atmos_vars: tuple[str, ...] = ("co", "no", "no2", "go3", "so2"),
        simulate_indexing_bug: bool = True,
        **kw_args,
    ) -> None:
        super().__init__(
            surf_vars=surf_vars,
            static_vars=static_vars,
            atmos_vars=atmos_vars,
            patch_size=patch_size,
            timestep=timestep,
            level_condition=level_condition,
            dynamic_vars=dynamic_vars,
            atmos_static_vars=atmos_static_vars,
            separate_perceiver=separate_perceiver,
            modulation_heads=modulation_heads,
            positive_surf_vars=positive_surf_vars,
            positive_atmos_vars=positive_atmos_vars,
            simulate_indexing_bug=simulate_indexing_bug,
            **kw_args,
        )

        self.surf_feature_combiner = torch.nn.ParameterDict(
            {v: nn.Linear(2, 1, bias=True) for v in self.positive_surf_vars}
        )
        self.atmos_feature_combiner = torch.nn.ParameterDict(
            {v: nn.Linear(2, 1, bias=True) for v in self.positive_atmos_vars}
        )
        for p in (*self.surf_feature_combiner.values(), *self.atmos_feature_combiner.values()):
            nn.init.constant_(p.weight, 0.5)
            nn.init.zeros_(p.bias)

    def _pre_encoder_hook(self, batch: Batch) -> Batch:
        # Transform the spiky variables with a specific log-transform before feeding them to the
        # encoder. See the paper for a motivation for the precise form of the transform.
        eps = 1e-4
        divisor = -np.log(eps)

        def _transform(z: torch.Tensor, feature_combiner: nn.Module) -> torch.Tensor:
            return feature_combiner(
                torch.stack(
                    [
                        z.clamp(min=0, max=2.5),
                        (torch.log(z.clamp(min=eps)) - np.log(eps)) / divisor,
                    ],
                    dim=-1,
                )
            )[..., 0]

        return dataclasses.replace(
            batch,
            surf_vars={
                k: _transform(v, self.surf_feature_combiner[k])
                if k in self.surf_feature_combiner
                else v
                for k, v in batch.surf_vars.items()
            },
            atmos_vars={
                k: _transform(v, self.atmos_feature_combiner[k])
                if k in self.atmos_feature_combiner
                else v
                for k, v in batch.atmos_vars.items()
            },
        )

    def _post_decoder_hook(self, batch: Batch, pred: Batch) -> Batch:
        # For this version of the model, we predict the difference. Which previous time step the
        # difference is taken with respect to (12 hours ago or 24 hours ago) is given by
        # `AuroraAirPollution._predict_difference_history_dim_lookup`.
        dim_lookup = AuroraAirPollution._predict_difference_history_dim_lookup

        def _transform(
            prev: dict[str, torch.Tensor],
            model: dict[str, torch.Tensor],
            name: str,
        ) -> torch.Tensor:
            if name in dim_lookup:
                return model[name] + (1 + model[f"{name}_mod"]) * prev[name][:, dim_lookup[name]]
            else:
                return model[name]

        pred = dataclasses.replace(
            pred,
            surf_vars={
                k: _transform(batch.surf_vars, pred.surf_vars, k) for k in batch.surf_vars
            },
            atmos_vars={
                k: _transform(batch.atmos_vars, pred.atmos_vars, k) for k in batch.atmos_vars
            },
        )

        # When using LoRA, the lower-atmospheric levels of SO2 can be problematic and blow up.
        # We attempt to fix that by some very aggressive output clipping.
        if self.use_lora:
            parts: list[torch.Tensor] = []
            for i, level in enumerate(pred.metadata.atmos_levels):
                section = pred.atmos_vars["so2"][..., i, :, :]
                if level >= 850:
                    section = section.clamp(max=1)
                parts.append(section)
            pred.atmos_vars["so2"] = torch.stack(parts, dim=-3)

        return pred

    def _adapt_checkpoint(self, d: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        d = Aurora._adapt_checkpoint(self, d)
        d = _adapt_checkpoint_air_pollution(self.patch_size, d)
        return d
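
# Sketch of the encoder-side transform used in `AuroraAirPollution._pre_encoder_hook` above
# (values are illustrative): each spiky variable `z` is mapped to a learned combination of a
# clipped linear channel and a scaled log channel, with the combiner initialised to average the
# two:
#
#     z_lin = z.clamp(min=0, max=2.5)
#     z_log = (torch.log(z.clamp(min=1e-4)) - np.log(1e-4)) / (-np.log(1e-4))
#     out = 0.5 * z_lin + 0.5 * z_log  # At initialisation; the weights are then learned.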

class AuroraWave(Aurora):
    """Version of Aurora fine-tuned on HRES-WAM ocean wave data."""

    default_checkpoint_name = "aurora-0.25-wave.ckpt"
    default_checkpoint_revision = "74598e8c65d53a96077c08bb91acdfa5525340c9"

    def __init__(
        self,
        *,
        surf_vars: tuple[str, ...] = (
            ("2t", "10u", "10v", "msl")
            + ("swh", "mwd", "mwp", "pp1d", "shww", "mdww", "mpww", "shts", "mdts", "mpts")
            + ("swh1", "mwd1", "mwp1", "swh2", "mwd2", "mwp2", "wind", "10u_wave", "10v_wave")
        ),
        static_vars: tuple[str, ...] = ("lsm", "z", "slt", "wmb", "lat_mask"),
        lora_mode: LoRAMode = "from_second",
        stabilise_level_agg: bool = True,
        density_channel_surf_vars: tuple[str, ...] = (
            ("swh", "mwd", "mwp", "pp1d", "shww", "mdww", "mpww", "shts", "mdts", "mpts")
            + ("swh1", "mwd1", "mwp1", "swh2", "mwd2", "mwp2", "wind", "10u_wave", "10v_wave")
        ),
        angle_surf_vars: tuple[str, ...] = ("mwd", "mdww", "mdts", "mwd1", "mwd2"),
        **kw_args,
    ) -> None:
        # Model the density, sine, and cosine versions of the variables.
        supplemented_surf_vars: tuple[str, ...] = ()
        for name in surf_vars:
            if name in angle_surf_vars:
                supplemented_surf_vars += (f"{name}_sin", f"{name}_cos")
            else:
                supplemented_surf_vars += (name,)
            if name in density_channel_surf_vars:
                supplemented_surf_vars += (f"{name}_density",)

        super().__init__(
            surf_vars=supplemented_surf_vars,
            static_vars=static_vars,
            lora_mode=lora_mode,
            stabilise_level_agg=stabilise_level_agg,
            **kw_args,
        )

        self.density_channel_surf_vars = density_channel_surf_vars
        self.angle_surf_vars = angle_surf_vars

    def _adapt_checkpoint(self, d: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        d = Aurora._adapt_checkpoint(self, d)
        d = _adapt_checkpoint_wave(self.patch_size, d)
        return d

    def batch_transform_hook(self, batch: Batch) -> Batch:
        # Below we mutate `batch`, so make a copy here.
        batch = dataclasses.replace(batch, surf_vars=dict(batch.surf_vars))

        # It is important that these components are split off _before_ normalisation, as they
        # have specific normalisation statistics.
        if "dwi" in batch.surf_vars and "wind" in batch.surf_vars:
            # Split into u-component and v-component.
            u_wave = -batch.surf_vars["wind"] * torch.sin(torch.deg2rad(batch.surf_vars["dwi"]))
            v_wave = -batch.surf_vars["wind"] * torch.cos(torch.deg2rad(batch.surf_vars["dwi"]))

            # Update batch and remove `dwi`.
            batch.surf_vars["10u_wave"] = u_wave
            batch.surf_vars["10v_wave"] = v_wave
            del batch.surf_vars["dwi"]

        # If the magnitude of a wave is zero (or practically zero), it is absent, so indicate
        # that with NaNs. Only do this when data is given to the model and not when it is rolled
        # out.
        if batch.metadata.rollout_step == 0:
            for name_sh, other_wave_components in [
                ("swh", ("mwd", "mwp", "pp1d")),
                ("shww", ("mdww", "mpww")),
                ("shts", ("mdts", "mpts")),
                ("swh1", ("mwd1", "mwp1")),
                ("swh2", ("mwd2", "mwp2")),
            ]:
                mask = batch.surf_vars[name_sh] < 1e-4
                if mask.sum() > 0:
                    for name in (name_sh,) + other_wave_components:
                        x = batch.surf_vars[name].clone()  # Clone to safely mutate.
                        x[mask] = np.nan
                        batch.surf_vars[name] = x

                        # There should be no small values left, except for in wave directions.
                        if name not in {"mwd", "mdww", "mdts", "mwd1", "mwd2"}:
                            assert (batch.surf_vars[name] < 1e-4).sum() == 0

        return batch

    def _pre_encoder_hook(self, batch: Batch) -> Batch:
        for name in list(batch.surf_vars):
            x = batch.surf_vars[name]

            # Create a density channel.
            if (
                name in self.density_channel_surf_vars
                and f"{name}_density" not in batch.surf_vars
            ):
                batch.surf_vars[f"{name}_density"] = (~torch.isnan(x)).float()
                batch.surf_vars[name] = x.nan_to_num(0)

            # Add sine and cosine values of the angle and remove the original angle variable.
            sin_cos_present = (
                f"{name}_sin" in batch.surf_vars and f"{name}_cos" in batch.surf_vars
            )
            if name in self.angle_surf_vars and not sin_cos_present:
                batch.surf_vars[f"{name}_sin"] = torch.sin(torch.deg2rad(x)).nan_to_num(0)
                batch.surf_vars[f"{name}_cos"] = torch.cos(torch.deg2rad(x)).nan_to_num(0)
                del batch.surf_vars[name]

        return batch

    def _post_decoder_hook(self, batch: Batch, pred: Batch) -> Batch:
        wmb_mask = pred.static_vars["wmb"] > 0

        # Undo the sine and cosine components.
        for name in self.angle_surf_vars:
            if f"{name}_sin" in pred.surf_vars and f"{name}_cos" in pred.surf_vars:
                sin = pred.surf_vars[f"{name}_sin"]
                cos = pred.surf_vars[f"{name}_cos"]
                pred.surf_vars[name] = torch.rad2deg(torch.atan2(sin, cos)) % 360
                del pred.surf_vars[f"{name}_sin"]
                del pred.surf_vars[f"{name}_cos"]

        # Undo the density channels. First transform by a sigmoid to get the actual value of the
        # density channel.
        for name in self.density_channel_surf_vars:
            if name in pred.surf_vars:
                density = torch.sigmoid(pred.surf_vars[f"{name}_density"]) * wmb_mask
                data = pred.surf_vars[name] * wmb_mask
                data[density < 0.5] = np.nan
                pred.surf_vars[name] = data
                del pred.surf_vars[f"{name}_density"]

        return pred
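
# Round-trip sketch for the angle handling in `AuroraWave` (illustrative values): wave directions
# are encoded as sine/cosine pairs before the encoder and decoded back to degrees in [0, 360)
# after the decoder:
#
#     angle = torch.tensor([350.0])  # Degrees.
#     s, c = torch.sin(torch.deg2rad(angle)), torch.cos(torch.deg2rad(angle))
#     recovered = torch.rad2deg(torch.atan2(s, c)) % 360  # tensor([350.])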