Application Programming Interface

Batch

class aurora.Batch(surf_vars: dict[str, Tensor], static_vars: dict[str, Tensor], atmos_vars: dict[str, Tensor], metadata: Metadata)

A batch of data.

Parameters:
  • surf_vars (dict[str, torch.Tensor]) – Surface-level variables with shape (b, t, h, w).

  • static_vars (dict[str, torch.Tensor]) – Static variables with shape (h, w).

  • atmos_vars (dict[str, torch.Tensor]) – Atmospheric variables with shape (b, t, c, h, w).

  • metadata (Metadata) – Metadata associated with this batch.
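The shape conventions above can be checked before a Batch is constructed. The helper below is a hypothetical sketch, not part of aurora. It only reads the .shape attribute, so it accepts torch.Tensor values. It assumes that all variables share one spatial grid and that the c axis of the atmospheric variables indexes the pressure levels in metadata.atmos_levels.

```python
from collections.abc import Mapping
from typing import Any


def check_batch_shapes(
    surf_vars: Mapping[str, Any],
    static_vars: Mapping[str, Any],
    atmos_vars: Mapping[str, Any],
    num_levels: int,
) -> tuple[int, int]:
    """Check the documented shapes and return the shared spatial shape (h, w).

    Hypothetical helper, not part of aurora. It only reads `.shape`, so it
    works with torch.Tensor values or any other array-like object.
    """
    batch_time: set[tuple[int, ...]] = set()  # (b, t) of every time-varying variable
    spatial: set[tuple[int, ...]] = set()  # (h, w) of every variable

    for name, x in surf_vars.items():
        if len(x.shape) != 4:
            raise ValueError(f"surface variable {name!r} must have shape (b, t, h, w)")
        batch_time.add(tuple(x.shape[:2]))
        spatial.add(tuple(x.shape[2:]))

    for name, x in static_vars.items():
        if len(x.shape) != 2:
            raise ValueError(f"static variable {name!r} must have shape (h, w)")
        spatial.add(tuple(x.shape))

    for name, x in atmos_vars.items():
        if len(x.shape) != 5:
            raise ValueError(f"atmospheric variable {name!r} must have shape (b, t, c, h, w)")
        # Assumption: the c axis indexes the pressure levels in metadata.atmos_levels.
        if x.shape[2] != num_levels:
            raise ValueError(
                f"atmospheric variable {name!r} has {x.shape[2]} levels, expected {num_levels}"
            )
        batch_time.add(tuple(x.shape[:2]))
        spatial.add(tuple(x.shape[3:]))

    if len(batch_time) > 1:
        raise ValueError(f"inconsistent (b, t) across variables: {sorted(batch_time)}")
    if len(spatial) != 1:
        raise ValueError(f"expected one shared (h, w), found {sorted(spatial)}")
    return spatial.pop()
```

For example, surface variables of shape (1, 2, 17, 32), a static variable of shape (17, 32), and atmospheric variables of shape (1, 2, 4, 17, 32) with four pressure levels pass the check and return the shared (h, w) of (17, 32).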

crop(patch_size: int) → Batch

Crop the variables in the batch to patch size patch_size.

classmethod from_netcdf(path: str | Path) → Batch

Load a batch from a netCDF file.

normalise(surf_stats: dict[str, tuple[float, float]]) → Batch

Normalise all variables in the batch.

Parameters:

surf_stats (dict[str, tuple[float, float]]) – For these surface-level variables, replace the normalisation statistics with the given tuple of a new location and scale.

Returns:

Normalised batch.

Return type:

Batch
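Normalisation of this kind is typically a per-variable affine transform. The sketch below assumes the form (x - location) / scale; the default statistics are made up for illustration, and only the override step mirrors the surf_stats parameter described above. The function names are hypothetical.

```python
from __future__ import annotations


def effective_stats(
    default_stats: dict[str, tuple[float, float]],
    surf_stats: dict[str, tuple[float, float]] | None = None,
) -> dict[str, tuple[float, float]]:
    """Merge default (location, scale) pairs with per-variable overrides."""
    stats = dict(default_stats)  # copy, so the defaults are left untouched
    stats.update(surf_stats or {})  # overrides win for the variables they name
    return stats


def normalise(values: list[float], location: float, scale: float) -> list[float]:
    """Assumed affine normalisation: subtract the location, divide by the scale."""
    return [(v - location) / scale for v in values]


# Made-up statistics, for illustration only.
defaults = {"2t": (280.0, 20.0), "msl": (101_000.0, 1_000.0)}
stats = effective_stats(defaults, surf_stats={"2t": (290.0, 10.0)})
normalised_2t = normalise([300.0, 310.0], *stats["2t"])  # [1.0, 2.0]
```

Variables not named in surf_stats (here msl) keep their default statistics.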

regrid(res: float) → Batch

Regrid the batch to a resolution of res degrees.

This results in float32 data on the CPU.

This function is not optimised for either speed or accuracy. Use at your own risk.
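To give a feel for what res means, the sketch below counts grid points on a regular global latitude-longitude grid. It assumes latitudes from 90 to -90 inclusive and longitudes from 0 up to, but excluding, 360, a common layout for global reanalysis data; this is an assumption, not a documented property of regrid's output.

```python
def grid_shape(res: float) -> tuple[int, int]:
    """Number of (latitude, longitude) points of a regular global grid at res degrees.

    Assumes latitudes from 90 to -90 inclusive and longitudes from 0 up to,
    but excluding, 360.
    """
    if res <= 0:
        raise ValueError("res must be positive")
    n_lat = round(180 / res) + 1  # both poles are grid points
    n_lon = round(360 / res)  # longitude 360 coincides with 0, so it is not repeated
    return n_lat, n_lon
```

Under these assumptions, 0.25 degrees gives (721, 1440) and 1 degree gives (181, 360), so regridding from 0.25 to 1 degree reduces the number of grid points roughly 16-fold.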

property spatial_shape: tuple[int, int]

Get the spatial shape from an arbitrary surface-level variable.

to(device: str | device) → Batch

Move the batch to another device.

to_netcdf(path: str | Path) → None

Write the batch to a file.

This requires xarray and netcdf4 to be installed.

type(t: type) → Batch

Convert everything to type t.

unnormalise(surf_stats: dict[str, tuple[float, float]]) → Batch

Unnormalise all variables in the batch.

Parameters:

surf_stats (dict[str, tuple[float, float]]) – For these surface-level variables, replace the normalisation statistics with the given tuple of a new location and scale.

Returns:

Unnormalised batch.

Return type:

Batch
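If normalisation is the affine map (x - location) / scale (an assumption; the exact form is not documented here), unnormalisation is its inverse, z * scale + location, and it only recovers the original values when it uses the same (location, scale) pair, including any surf_stats override. A minimal sketch with hypothetical function names:

```python
import math


def normalise(values: list[float], location: float, scale: float) -> list[float]:
    """Assumed affine normalisation: (x - location) / scale."""
    return [(v - location) / scale for v in values]


def unnormalise(values: list[float], location: float, scale: float) -> list[float]:
    """Inverse of the assumed normalisation: z * scale + location."""
    return [z * scale + location for z in values]


def round_trips(values: list[float], location: float, scale: float) -> bool:
    """Check that unnormalising a normalised series recovers it up to rounding."""
    restored = unnormalise(normalise(values, location, scale), location, scale)
    return all(math.isclose(a, b) for a, b in zip(values, restored))
```

With matching statistics the round trip holds, for instance round_trips([285.5, 300.0, 312.25], 290.0, 10.0) is True. A value of 310 normalised with (290, 10) but unnormalised with (280, 20) comes back as 320, which is why the same surf_stats should be passed to both methods.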

class aurora.Metadata(lat: Tensor, lon: Tensor, time: tuple[datetime, ...], atmos_levels: tuple[int | float, ...], rollout_step: int = 0)

Metadata in a batch.

Parameters:
  • lat (torch.Tensor) – Latitudes.

  • lon (torch.Tensor) – Longitudes.

  • time (tuple[datetime, ...]) – The time of every batch element.

  • atmos_levels (tuple[int | float, ...]) – Pressure levels for the atmospheric variables in hPa.

  • rollout_step (int, optional) – Number of roll-out steps used to produce this prediction. The default, 0, means that this is not a prediction but actual data. The model populates this field automatically and uses it to select a separate LoRA for every roll-out step. Generally, you can safely ignore this field.
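How rollout_step might interact with the model's use_lora, lora_steps, and lora_mode parameters can be sketched as an index-selection rule. This is a hypothetical reading of the parameter descriptions, not aurora's code; in particular, applying no LoRA once the step reaches lora_steps is an assumption.

```python
from typing import Literal, Optional


def select_lora(
    rollout_step: int,
    use_lora: bool = True,
    lora_steps: int = 40,
    lora_mode: Literal["single", "all"] = "single",
) -> Optional[int]:
    """Index of the LoRA to apply at a given roll-out step, or None for no LoRA.

    Hypothetical rule inferred from the parameter descriptions, not aurora's code.
    """
    if rollout_step < 0:
        raise ValueError("rollout_step must be non-negative")
    # Assumption: no LoRA is applied once the step reaches lora_steps.
    if not use_lora or rollout_step >= lora_steps:
        return None
    if lora_mode == "single":
        return 0  # one shared LoRA for every step
    if lora_mode == "all":
        return rollout_step  # a separate LoRA per step
    raise ValueError(f"unknown lora_mode: {lora_mode!r}")
```

Under this reading, the defaults make steps 0 to 39 share LoRA 0, while lora_mode "all" gives each of those steps its own LoRA.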

Roll-Outs

aurora.rollout(model: Aurora, batch: Batch, steps: int)

Perform a roll-out to make long-term predictions.

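Parameters:
  • model (Aurora) – Model to roll out.

  • batch (Batch) – Batch to start the roll-out from.

  • steps (int) – Number of roll-out steps.

Yields: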

aurora.batch.Batch – The prediction after every step.
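Conceptually, a roll-out is autoregressive: each prediction is fed back in as the input for the next step. The generator below is a minimal sketch of that pattern with a hypothetical step function and a toy state (a valid time plus a value). It is not aurora's implementation, which also has to manage the history of the batch; the 6-hour step only mirrors the model's default timestep.

```python
from collections.abc import Callable, Iterator
from datetime import datetime, timedelta
from typing import TypeVar

S = TypeVar("S")


def rollout(step: Callable[[S], S], state: S, steps: int) -> Iterator[S]:
    """Autoregressive roll-out: yield the prediction after every step and feed it back in."""
    for _ in range(steps):
        state = step(state)  # the prediction becomes the next input
        yield state


def toy_step(state: tuple[datetime, float]) -> tuple[datetime, float]:
    """Hypothetical stand-in for a model: advance the time by 6 h and halve the value."""
    time, value = state
    return time + timedelta(hours=6), value / 2
```

Rolling out toy_step for three steps from (datetime(2023, 1, 1), 8.0) yields states valid at 06:00, 12:00, and 18:00 with values 4.0, 2.0, and 1.0. The documented aurora.rollout(model, batch, steps) follows the same calling pattern and yields a Batch after every step; because it is a generator, each prediction is produced only when it is requested.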

Models

class aurora.Aurora(surf_vars: tuple[str, ...] = ('2t', '10u', '10v', 'msl'), static_vars: tuple[str, ...] = ('lsm', 'z', 'slt'), atmos_vars: tuple[str, ...] = ('z', 'u', 'v', 't', 'q'), window_size: tuple[int, int, int] = (2, 6, 12), encoder_depths: tuple[int, ...] = (6, 10, 8), encoder_num_heads: tuple[int, ...] = (8, 16, 32), decoder_depths: tuple[int, ...] = (8, 10, 6), decoder_num_heads: tuple[int, ...] = (32, 16, 8), latent_levels: int = 4, patch_size: int = 4, embed_dim: int = 512, num_heads: int = 16, mlp_ratio: float = 4.0, drop_path: float = 0.0, drop_rate: float = 0.0, enc_depth: int = 1, dec_depth: int = 1, dec_mlp_ratio: float = 2.0, perceiver_ln_eps: float = 1e-05, max_history_size: int = 2, timestep: timedelta = datetime.timedelta(seconds=21600), stabilise_level_agg: bool = False, use_lora: bool = True, lora_steps: int = 40, lora_mode: Literal['single', 'all'] = 'single', surf_stats: dict[str, tuple[float, float]] | None = None, autocast: bool = False)

The Aurora model.

Defaults to the 1.3B-parameter configuration.

__init__(surf_vars: tuple[str, ...] = ('2t', '10u', '10v', 'msl'), static_vars: tuple[str, ...] = ('lsm', 'z', 'slt'), atmos_vars: tuple[str, ...] = ('z', 'u', 'v', 't', 'q'), window_size: tuple[int, int, int] = (2, 6, 12), encoder_depths: tuple[int, ...] = (6, 10, 8), encoder_num_heads: tuple[int, ...] = (8, 16, 32), decoder_depths: tuple[int, ...] = (8, 10, 6), decoder_num_heads: tuple[int, ...] = (32, 16, 8), latent_levels: int = 4, patch_size: int = 4, embed_dim: int = 512, num_heads: int = 16, mlp_ratio: float = 4.0, drop_path: float = 0.0, drop_rate: float = 0.0, enc_depth: int = 1, dec_depth: int = 1, dec_mlp_ratio: float = 2.0, perceiver_ln_eps: float = 1e-05, max_history_size: int = 2, timestep: timedelta = datetime.timedelta(seconds=21600), stabilise_level_agg: bool = False, use_lora: bool = True, lora_steps: int = 40, lora_mode: Literal['single', 'all'] = 'single', surf_stats: dict[str, tuple[float, float]] | None = None, autocast: bool = False) → None

Construct an instance of the model.

Parameters:
  • surf_vars (tuple[str, ...], optional) – All surface-level variables supported by the model.

  • static_vars (tuple[str, ...], optional) – All static variables supported by the model.

  • atmos_vars (tuple[str, ...], optional) – All atmospheric variables supported by the model.

  • window_size (tuple[int, int, int], optional) – Vertical height, height, and width of the window of the underlying Swin transformer.

  • encoder_depths (tuple[int, ...], optional) – Number of blocks in each encoder layer.

  • encoder_num_heads (tuple[int, ...], optional) – Number of attention heads in each encoder layer. The dimensionality doubles after every layer, so to keep the dimensionality of every head constant, you want to double the number of heads after every layer. The dimensionality of each attention head in the first layer is embed_dim divided by the value here. For all cases except one, this is equal to 64.

  • decoder_depths (tuple[int, ...], optional) – Number of blocks in each decoder layer. Generally, you want this to be the reversal of encoder_depths.

  • decoder_num_heads (tuple[int, ...], optional) – Number of attention heads in each decoder layer. Generally, you want this to be the reversal of encoder_num_heads.

  • latent_levels (int, optional) – Number of latent pressure levels.

  • patch_size (int, optional) – Patch size.

  • embed_dim (int, optional) – Patch embedding dimension.

  • num_heads (int, optional) – Number of attention heads in the aggregation and deaggregation blocks. The dimensionality of these attention heads will be equal to embed_dim divided by this value.

  • mlp_ratio (float, optional) – Hidden dim. to embedding dim. ratio for MLPs.

  • drop_rate (float, optional) – Drop-out rate.

  • drop_path (float, optional) – Drop-path rate.

  • enc_depth (int, optional) – Number of Perceiver blocks in the encoder.

  • dec_depth (int, optional) – Number of Perceiver blocks in the decoder.

  • dec_mlp_ratio (float, optional) – Hidden dim. to embedding dim. ratio for MLPs in the decoder. The embedding dimensionality here is different, which is why this is a separate parameter.

  • perceiver_ln_eps (float, optional) – Epsilon in the Perceiver layer normalisation layers. Used to stabilise the model.

  • max_history_size (int, optional) – Maximum number of history steps. You can load checkpoints with a smaller max_history_size, but you cannot load checkpoints with a larger max_history_size.

  • timestep (timedelta, optional) – Timestep of the model. Defaults to 6 hours.

  • stabilise_level_agg (bool, optional) – Stabilise the level aggregation by inserting an additional layer normalisation. Defaults to False.

  • use_lora (bool, optional) – Use LoRA adaptation.

  • lora_steps (int, optional) – Use a different LoRA adaptation for each of the first lora_steps roll-out steps.

  • lora_mode (str, optional) – LoRA mode. “single” uses the same LoRA for all roll-out steps, and “all” uses a different LoRA for every roll-out step. Defaults to “single”.

  • surf_stats (dict[str, tuple[float, float]], optional) – For these surface-level variables, replace the normalisation statistics with the given tuple of a new location and scale.

  • autocast (bool, optional) – Use torch.autocast to reduce memory usage. Defaults to False.
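The head-dimension bookkeeping described for encoder_num_heads and num_heads can be checked with a few lines of arithmetic. The sketch assumes, as the encoder_num_heads description states, that the width doubles after every encoder layer. The numbers are the defaults above and the aurora.AuroraSmall overrides; the helper name is hypothetical.

```python
def encoder_head_dims(embed_dim: int, encoder_num_heads: tuple[int, ...]) -> list[int]:
    """Per-head dimensionality in every encoder layer.

    Assumes, as described for encoder_num_heads, that the width doubles after every layer.
    """
    head_dims = []
    for layer, heads in enumerate(encoder_num_heads):
        width = embed_dim * 2**layer
        if width % heads != 0:
            raise ValueError(f"layer {layer}: width {width} is not divisible by {heads} heads")
        head_dims.append(width // heads)
    return head_dims


# Default (1.3B-parameter) configuration.
default_dims = encoder_head_dims(512, (8, 16, 32))  # [64, 64, 64]
default_agg_dim = 512 // 16  # aggregation heads: embed_dim // num_heads = 32
# aurora.AuroraSmall overrides: embed_dim=256, encoder_num_heads=(4, 8, 16), num_heads=8.
small_dims = encoder_head_dims(256, (4, 8, 16))  # [64, 64, 64]
small_agg_dim = 256 // 8  # 32
```

Doubling the head count per layer is what keeps every encoder head at 64 dimensions in both configurations, while the aggregation and deaggregation heads come out at 32.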

adapt_checkpoint_max_history_size(checkpoint: dict[str, Tensor]) → None

Adapt a checkpoint with a smaller max_history_size to the current model, which has a larger max_history_size.

If a checkpoint was trained with a larger max_history_size than the current model, this function fails an assertion and does not load the checkpoint, since loading it would likely degrade performance.

This implementation copies the weights from the checkpoint to the model and fills the new entries along the history dimension with zeros. It mutates checkpoint in place.
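The copy-and-zero-fill step can be sketched in PyTorch as follows. The helper names, the list of affected entries, and the position of the history axis (dimension 2) are assumptions made for illustration; the description above does not say which checkpoint entries are adapted or how they are laid out.

```python
import torch


def pad_history(weight: torch.Tensor, max_history_size: int, dim: int = 2) -> torch.Tensor:
    """Return a copy of weight whose history axis is zero-padded to max_history_size.

    Hypothetical helper: existing history slices are copied to the front, new slices are zero.
    """
    old = weight.shape[dim]
    if old > max_history_size:
        # Mirrors the documented behaviour: a larger checkpoint history is rejected.
        raise AssertionError(f"checkpoint history size {old} exceeds {max_history_size}")
    new_shape = list(weight.shape)
    new_shape[dim] = max_history_size
    padded = torch.zeros(new_shape, dtype=weight.dtype, device=weight.device)
    padded.narrow(dim, 0, old).copy_(weight)  # copy the old history into the first slots
    return padded


def adapt_history(
    checkpoint: dict[str, torch.Tensor], names: list[str], max_history_size: int
) -> None:
    """Pad the named entries in place; like the method above, this mutates checkpoint."""
    for name in names:
        checkpoint[name] = pad_history(checkpoint[name], max_history_size)
```

Zero-filling means the new history slots initially contribute nothing to the output, so the padded weights reproduce the original computation on the history steps the checkpoint was trained with.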

configure_activation_checkpointing()

Configure activation checkpointing.

This is required in order to compute gradients without running out of memory.

forward(batch: Batch) → Batch

Forward pass.

Parameters:

batch (Batch) – Batch to run the model on.

Returns:

Prediction for the batch.

Return type:

Batch

load_checkpoint(repo: str, name: str, strict: bool = True) → None

Load a checkpoint from HuggingFace.

Parameters:
  • repo (str) – Name of the repository of the form user/repo.

  • name (str) – Path to the checkpoint relative to the root of the repository, e.g. checkpoint.ckpt.

  • strict (bool, optional) – Error if the model parameters are not exactly equal to the parameters in the checkpoint. Defaults to True.

load_checkpoint_local(path: str, strict: bool = True) → None

Load a checkpoint directly from a file.

Parameters:
  • path (str) – Path to the checkpoint.

  • strict (bool, optional) – Error if the model parameters are not exactly equal to the parameters in the checkpoint. Defaults to True.

aurora.AuroraSmall

alias of functools.partial(<class 'aurora.model.aurora.Aurora'>, encoder_depths=(2, 6, 2), encoder_num_heads=(4, 8, 16), decoder_depths=(2, 6, 2), decoder_num_heads=(16, 8, 4), embed_dim=256, num_heads=8, use_lora=False)

aurora.AuroraHighRes

alias of functools.partial(<class 'aurora.model.aurora.Aurora'>, patch_size=10, encoder_depths=(6, 8, 8), decoder_depths=(8, 8, 6))
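Both aliases are functools.partial objects, so calling one is the same as calling Aurora with the stored keyword arguments, and a keyword passed at call time overrides the stored value. The stand-in class below is hypothetical and only records its arguments, so the mechanism can be shown without building the model; it mirrors a subset of the AuroraSmall overrides.

```python
import functools


class FakeAurora:
    """Hypothetical stand-in for aurora.Aurora that only stores part of its configuration."""

    def __init__(self, patch_size: int = 4, embed_dim: int = 512, use_lora: bool = True) -> None:
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.use_lora = use_lora


# Same mechanism as aurora.AuroraSmall, mirroring only some of its overrides.
FakeAuroraSmall = functools.partial(FakeAurora, embed_dim=256, use_lora=False)

small = FakeAuroraSmall()  # embed_dim=256, use_lora=False, patch_size keeps its default
small_with_lora = FakeAuroraSmall(use_lora=True)  # the call-time keyword wins
```

The same holds for the real aliases: aurora.AuroraSmall.keywords lists its stored overrides, and objects created through either alias are ordinary Aurora instances.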