Application Programming Interface#

Batch#

class aurora.Batch(surf_vars: dict[str, Tensor], static_vars: dict[str, Tensor], atmos_vars: dict[str, Tensor], metadata: Metadata)[source]#

A batch of data.

Parameters:
  • surf_vars (dict[str, torch.Tensor]) – Surface-level variables with shape (b, t, h, w).

  • static_vars (dict[str, torch.Tensor]) – Static variables with shape (h, w).

  • atmos_vars (dict[str, torch.Tensor]) – Atmospheric variables with shape (b, t, c, h, w).

  • metadata (Metadata) – Metadata associated with this batch.
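
For illustration, a minimal sketch of constructing a batch of random data at very low resolution. The shapes follow the conventions above; Metadata is documented below.

    from datetime import datetime

    import torch

    from aurora import Batch, Metadata

    batch = Batch(
        # Surface-level variables: (b, t, h, w) = (1, 2, 17, 32).
        surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "msl")},
        # Static variables: (h, w).
        static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt")},
        # Atmospheric variables: (b, t, c, h, w) with c matching atmos_levels.
        atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
        metadata=Metadata(
            lat=torch.linspace(90, -90, 17),
            lon=torch.linspace(0, 360, 32 + 1)[:-1],
            # One time per batch element; here the batch size is one.
            time=(datetime(2020, 6, 1, 12, 0),),
            atmos_levels=(100, 250, 500, 850),
        ),
    )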

crop(patch_size: int) Batch[source]#

Crop the variables in the batch to patch size patch_size.

classmethod from_netcdf(path: str | Path) Batch[source]#

Load a batch from a file.

normalise(surf_stats: dict[str, tuple[float, float]]) Batch[source]#

Normalise all variables in the batch.

Parameters:

surf_stats (dict[str, tuple[float, float]]) – For these surface-level variables, replace the default normalisation location and scale with the given tuple.

Returns:

Normalised batch.

Return type:

Batch

regrid(res: float) Batch[source]#

Regrid the batch to a resolution of res degrees.

This results in float32 data on the CPU.

This function is not optimised for either speed or accuracy. Use at your own risk.
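
For example (the target resolution is illustrative):

    batch_regridded = batch.regrid(res=0.25)  # Returns float32 data on the CPU.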

property spatial_shape: tuple[int, int]#

Get the spatial shape from an arbitrary surface-level variable.

to(device: str | device) Batch[source]#

Move the batch to another device.

to_netcdf(path: str | Path) None[source]#

Write the batch to a file.

This requires xarray and netcdf4 to be installed.
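
For example, a round trip through a file (the path is illustrative):

    batch.to_netcdf("batch.nc")                   # Requires xarray and netcdf4.
    batch_loaded = Batch.from_netcdf("batch.nc")  # Load it back into a Batch.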

type(t: type) Batch[source]#

Convert all tensors in the batch to type t.

unnormalise(surf_stats: dict[str, tuple[float, float]]) Batch[source]#

Unnormalise all variables in the batch.

Parameters:

surf_stats (dict[str, tuple[float, float]]) – For these surface-level variables, replace the default normalisation location and scale with the given tuple.

Returns:

Unnormalised batch.

Return type:

Batch

class aurora.Metadata(lat: Tensor, lon: Tensor, time: tuple[datetime, ...], atmos_levels: tuple[int | float, ...], rollout_step: int = 0)[source]#

Metadata in a batch.

Parameters:
  • lat (torch.Tensor) – Latitudes.

  • lon (torch.Tensor) – Longitudes.

  • time (tuple[datetime, ...]) – The time for every batch element.

  • atmos_levels (tuple[int | float, ...]) – Pressure levels for the atmospheric variables in hPa.

  • rollout_step (int, optional) – How many roll-out steps were used to produce this prediction. If equal to 0, the default, this is not a prediction but actual data. This field is automatically populated by the model and is used to select a separate LoRA for every roll-out step. You can generally safely ignore this field.

Roll-Outs#

aurora.rollout(model: Aurora, batch: Batch, steps: int)[source]#

Perform a roll-out to make long-term predictions.

Parameters:
  • model (Aurora) – The model to roll out.

  • batch (Batch) – The batch to start the roll-out from.

  • steps (int) – The number of roll-out steps.

Yields:

aurora.batch.Batch – The prediction after every step.
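
A minimal inference sketch, assuming batch was constructed as in the Batch example above:

    import torch

    from aurora import Aurora, rollout

    model = Aurora()
    model.load_checkpoint()  # Load the default checkpoint from HuggingFace.
    model.eval()

    with torch.inference_mode():
        # With the default 6 h timestep, 4 steps predict 24 h ahead.
        preds = [pred.to("cpu") for pred in rollout(model, batch, steps=4)]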

Tropical Cyclone Tracking#

class aurora.Tracker(init_lat: float, init_lon: float, init_time: datetime)[source]#

Simple tropical cyclone tracker.

This algorithm was originally designed and implemented by Anna Allen. This particular implementation is by Wessel Bruinsma and features various improvements over the original design.

results() DataFrame[source]#

Assemble the track into a convenient DataFrame.

step(batch: Batch) None[source]#

Track the next step.

Parameters:

batch (aurora.batch.Batch) – Prediction.
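
A hedged usage sketch, assuming model and batch as in the roll-out example above; the initial position and time are illustrative:

    from datetime import datetime

    import torch

    from aurora import Tracker, rollout

    tracker = Tracker(init_lat=18.0, init_lon=130.5, init_time=datetime(2020, 8, 1, 12, 0))

    with torch.inference_mode():
        for pred in rollout(model, batch, steps=8):
            tracker.step(pred)  # Feed every prediction to the tracker.

    track = tracker.results()  # pandas DataFrame with the assembled track.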

Models#

class aurora.Aurora(*, surf_vars: tuple[str, ...] = ('2t', '10u', '10v', 'msl'), static_vars: tuple[str, ...] = ('lsm', 'z', 'slt'), atmos_vars: tuple[str, ...] = ('z', 'u', 'v', 't', 'q'), window_size: tuple[int, int, int] = (2, 6, 12), encoder_depths: tuple[int, ...] = (6, 10, 8), encoder_num_heads: tuple[int, ...] = (8, 16, 32), decoder_depths: tuple[int, ...] = (8, 10, 6), decoder_num_heads: tuple[int, ...] = (32, 16, 8), latent_levels: int = 4, patch_size: int = 4, embed_dim: int = 512, num_heads: int = 16, mlp_ratio: float = 4.0, drop_path: float = 0.0, drop_rate: float = 0.0, enc_depth: int = 1, dec_depth: int = 1, dec_mlp_ratio: float = 2.0, perceiver_ln_eps: float = 1e-05, max_history_size: int = 2, timestep: timedelta = datetime.timedelta(seconds=21600), stabilise_level_agg: bool = False, use_lora: bool = True, lora_steps: int = 40, lora_mode: Literal['single', 'from_second', 'all'] = 'single', surf_stats: dict[str, tuple[float, float]] | None = None, autocast: bool = False, bf16_mode: bool = False, level_condition: tuple[int | float, ...] | None = None, dynamic_vars: bool = False, atmos_static_vars: bool = False, separate_perceiver: tuple[str, ...] = (), modulation_heads: tuple[str, ...] = (), positive_surf_vars: tuple[str, ...] = (), positive_atmos_vars: tuple[str, ...] = (), clamp_at_first_step: bool = False, simulate_indexing_bug: bool = False)[source]#

The Aurora model.

Defaults to the 1.3 B parameter configuration.

__init__(*, surf_vars: tuple[str, ...] = ('2t', '10u', '10v', 'msl'), static_vars: tuple[str, ...] = ('lsm', 'z', 'slt'), atmos_vars: tuple[str, ...] = ('z', 'u', 'v', 't', 'q'), window_size: tuple[int, int, int] = (2, 6, 12), encoder_depths: tuple[int, ...] = (6, 10, 8), encoder_num_heads: tuple[int, ...] = (8, 16, 32), decoder_depths: tuple[int, ...] = (8, 10, 6), decoder_num_heads: tuple[int, ...] = (32, 16, 8), latent_levels: int = 4, patch_size: int = 4, embed_dim: int = 512, num_heads: int = 16, mlp_ratio: float = 4.0, drop_path: float = 0.0, drop_rate: float = 0.0, enc_depth: int = 1, dec_depth: int = 1, dec_mlp_ratio: float = 2.0, perceiver_ln_eps: float = 1e-05, max_history_size: int = 2, timestep: timedelta = datetime.timedelta(seconds=21600), stabilise_level_agg: bool = False, use_lora: bool = True, lora_steps: int = 40, lora_mode: Literal['single', 'from_second', 'all'] = 'single', surf_stats: dict[str, tuple[float, float]] | None = None, autocast: bool = False, bf16_mode: bool = False, level_condition: tuple[int | float, ...] | None = None, dynamic_vars: bool = False, atmos_static_vars: bool = False, separate_perceiver: tuple[str, ...] = (), modulation_heads: tuple[str, ...] = (), positive_surf_vars: tuple[str, ...] = (), positive_atmos_vars: tuple[str, ...] = (), clamp_at_first_step: bool = False, simulate_indexing_bug: bool = False) None[source]#

Construct an instance of the model.

Parameters:
  • surf_vars (tuple[str, ...], optional) – All surface-level variables supported by the model.

  • static_vars (tuple[str, ...], optional) – All static variables supported by the model.

  • atmos_vars (tuple[str, ...], optional) – All atmospheric variables supported by the model.

  • window_size (tuple[int, int, int], optional) – Vertical (level), height, and width dimensions of the window of the underlying Swin transformer.

  • encoder_depths (tuple[int, ...], optional) – Number of blocks in each encoder layer.

  • encoder_num_heads (tuple[int, ...], optional) – Number of attention heads in each encoder layer. The dimensionality doubles after every layer, so to keep the dimensionality of every head constant, double the number of heads after every layer as well. The dimensionality of each attention head in the first layer is embed_dim divided by the value here. For all configurations except one, this is equal to 64.

  • decoder_depths (tuple[int, ...], optional) – Number of blocks in each decoder layer. Generally, you want this to be the reverse of encoder_depths.

  • decoder_num_heads (tuple[int, ...], optional) – Number of attention heads in each decoder layer. Generally, you want this to be the reverse of encoder_num_heads.

  • latent_levels (int, optional) – Number of latent pressure levels.

  • patch_size (int, optional) – Patch size.

  • embed_dim (int, optional) – Patch embedding dimension.

  • num_heads (int, optional) – Number of attention heads in the aggregation and deaggregation blocks. The dimensionality of these attention heads will be equal to embed_dim divided by this value.

  • mlp_ratio (float, optional) – Hidden dim. to embedding dim. ratio for MLPs.

  • drop_rate (float, optional) – Drop-out rate.

  • drop_path (float, optional) – Drop-path rate.

  • enc_depth (int, optional) – Number of Perceiver blocks in the encoder.

  • dec_depth (int, optional) – Number of Perceiver blocks in the decoder.

  • dec_mlp_ratio (float, optional) – Hidden dim. to embedding dim. ratio for MLPs in the decoder. The embedding dimensionality here is different, which is why this is a separate parameter.

  • perceiver_ln_eps (float, optional) – Epsilon in the Perceiver layer normalisation layers. Used to stabilise the model.

  • max_history_size (int, optional) – Maximum number of history steps. You can load checkpoints with a smaller max_history_size, but you cannot load checkpoints with a larger max_history_size.

  • timestep (timedelta, optional) – Timestep of the model. Defaults to 6 hours.

  • stabilise_level_agg (bool, optional) – Stabilise the level aggregation by inserting an additional layer normalisation. Defaults to False.

  • use_lora (bool, optional) – Use LoRA adaptation.

  • lora_steps (int, optional) – Use different LoRA adaptation for the first so-many roll-out steps.

  • lora_mode (str, optional) – LoRA mode. “single” uses the same LoRA for all roll-out steps, “from_second” uses the same LoRA from the second roll-out step on, and “all” uses a different LoRA for every roll-out step. Defaults to “single”.

  • surf_stats (dict[str, tuple[float, float]], optional) – For these surface-level variables, replace the default normalisation location and scale with the given tuple.

  • bf16_mode (bool, optional) – To reduce memory usage, convert the tokens to BF16, run the backbone in pure BF16, and run the decoder in FP16 AMP. This should enable gradient computation. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE DEVELOPMENT OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING.

  • level_condition (tuple[int | float, ...], optional) – Make the patch embeddings dependent on pressure level. If you want to enable this feature, provide a tuple of all possible pressure levels.

  • dynamic_vars (bool, optional) – Use dynamically generated static variables, like time of day. Defaults to False.

  • atmos_static_vars (bool, optional) – Also concatenate the static variables to the atmospheric variables. Defaults to False.

  • separate_perceiver (tuple[str, ...], optional) – In the decoder, use a separate Perceiver for specific atmospheric variables. This can be helpful at fine-tuning time to deal with variables that behave significantly differently. If you want to enable this feature, set this to the collection of variables that should be run through a separate Perceiver.

  • modulation_heads (tuple[str, ...], optional) – Names of all variables for which to enable an additional head, the so-called modulation head, which can be used to predict the difference.

  • positive_surf_vars (tuple[str, ...], optional) – Mark these surface-level variables as positive. Clamp them before running them through the encoder, and also clamp them when autoregressively rolling out the model. The variables are not clamped for the first roll-out step.

  • positive_atmos_vars (tuple[str, ...], optional) – Mark these atmospheric variables as positive. Clamp them before running them through the encoder, and also clamp them when autoregressively rolling out the model. The variables are not clamped for the first roll-out step.

  • clamp_at_first_step (bool, optional) – Clamp the positive variables for the first roll-out step. Should only be used for inference. Defaults to False.

  • simulate_indexing_bug (bool, optional) – Simulate an indexing bug that’s present for the air pollution version of Aurora. This is necessary to obtain numerical equivalence to the original implementation. Defaults to False.
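
For example, constructing the default configuration and loading its default checkpoint:

    from aurora import Aurora

    model = Aurora()         # The default 1.3 B parameter configuration.
    model.load_checkpoint()  # Uses the default checkpoint repo, name, and revision.
    model.eval()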

adapt_checkpoint_max_history_size(checkpoint: dict[str, Tensor]) None[source]#

Adapt a checkpoint trained with a smaller max_history_size to a model with a larger max_history_size.

If a checkpoint was trained with a larger max_history_size than the current model, this function will raise an assertion error to prevent loading the checkpoint, since loading it would likely degrade performance.

This implementation copies the weights from the checkpoint and fills the new history positions with zeros. It mutates checkpoint in place.
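
A hedged sketch of manual use. That the checkpoint file is a plain state dict loadable with torch.load is an assumption, and the load_checkpoint methods may already handle this for you:

    import torch

    from aurora import Aurora

    model = Aurora(max_history_size=4)  # Larger history than the checkpoint's.

    ckpt = torch.load("aurora-0.25-finetuned.ckpt", map_location="cpu")
    model.adapt_checkpoint_max_history_size(ckpt)  # Mutates `ckpt` in place.
    model.load_state_dict(ckpt, strict=False)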

batch_transform_hook(batch: Batch) Batch[source]#

Transform the batch right after receiving it and before normalisation.

This function should be idempotent.

configure_activation_checkpointing()[source]#

Configure activation checkpointing.

This is required in order to compute gradients without running out of memory.
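
A fine-tuning sketch; the target batch and the loss are placeholder assumptions:

    model = Aurora()
    model.load_checkpoint()
    model.configure_activation_checkpointing()
    model.train()

    pred = model(batch)
    # Placeholder loss: mean absolute error on a single surface-level variable.
    loss = (pred.surf_vars["2t"] - target.surf_vars["2t"]).abs().mean()
    loss.backward()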

default_checkpoint_name = 'aurora-0.25-finetuned.ckpt'#

Name of the default checkpoint.

Type:

str

default_checkpoint_repo = 'microsoft/aurora'#

Name of the HuggingFace repository to load the default checkpoint from.

Type:

str

default_checkpoint_revision = '0be7e57c685dac86b78c4a19a3ab149d13c6a3dd'#

Commit hash of the default checkpoint.

Type:

str

forward(batch: Batch) Batch[source]#

Forward pass.

Parameters:

batch (Batch) – Batch to run the model on.

Returns:

Prediction for the batch.

Return type:

Batch
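
For inference, wrap the forward pass in torch.inference_mode():

    model.eval()
    with torch.inference_mode():
        pred = model(batch)  # Equivalent to model.forward(batch).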

load_checkpoint(repo: str | None = None, name: str | None = None, revision: str | None = None, strict: bool = True) None[source]#

Load a checkpoint from HuggingFace.

Parameters:
  • repo (str, optional) – Name of the repository of the form user/repo.

  • name (str, optional) – Path to the checkpoint relative to the root of the repository, e.g. checkpoint.ckpt.

  • revision (str, optional) – Commit hash of the HuggingFace git repository.

  • strict (bool, optional) – Error if the model parameters are not exactly equal to the parameters in the checkpoint. Defaults to True.

load_checkpoint_local(path: str, strict: bool = True) None[source]#

Load a checkpoint directly from a file.

Parameters:
  • path (str) – Path to the checkpoint.

  • strict (bool, optional) – Error if the model parameters are not exactly equal to the parameters in the checkpoint. Defaults to True.
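
For example (the path is illustrative):

    model.load_checkpoint_local("/path/to/aurora-0.25-finetuned.ckpt", strict=True)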

class aurora.AuroraPretrained(*, use_lora: bool = False, **kw_args)[source]#

Pretrained version of Aurora.

default_checkpoint_name = 'aurora-0.25-pretrained.ckpt'#

Name of the default checkpoint.

Type:

str

default_checkpoint_revision = '0be7e57c685dac86b78c4a19a3ab149d13c6a3dd'#

Commit hash of the default checkpoint.

Type:

str

class aurora.AuroraSmallPretrained(*, encoder_depths: tuple[int, ...] = (2, 6, 2), encoder_num_heads: tuple[int, ...] = (4, 8, 16), decoder_depths: tuple[int, ...] = (2, 6, 2), decoder_num_heads: tuple[int, ...] = (16, 8, 4), embed_dim: int = 256, num_heads: int = 8, use_lora: bool = False, **kw_args)[source]#

Small pretrained version of Aurora.

Should only be used for debugging.
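
For example, a quick smoke test, with batch as in the Batch example above:

    import torch

    from aurora import AuroraSmallPretrained

    model = AuroraSmallPretrained()
    model.load_checkpoint()
    model.eval()

    with torch.inference_mode():
        pred = model(batch)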

default_checkpoint_name = 'aurora-0.25-small-pretrained.ckpt'#

Name of the default checkpoint.

Type:

str

default_checkpoint_revision = '0be7e57c685dac86b78c4a19a3ab149d13c6a3dd'#

Commit hash of the default checkpoint.

Type:

str

class aurora.Aurora12hPretrained(*, timestep: timedelta = datetime.timedelta(seconds=43200), use_lora: bool = False, **kw_args)[source]#

Pretrained version of Aurora with time step 12 hours.

default_checkpoint_name = 'aurora-0.25-12h-pretrained.ckpt'#

Name of the default checkpoint.

Type:

str

default_checkpoint_revision = '15e76e47b65bf4b28fd2246b7b5b951d6e2443b9'#

Commit hash of the default checkpoint.

Type:

str

class aurora.AuroraHighRes(*, patch_size: int = 10, encoder_depths: tuple[int, ...] = (6, 8, 8), decoder_depths: tuple[int, ...] = (8, 8, 6), **kw_args)[source]#

High-resolution version of Aurora.

default_checkpoint_name = 'aurora-0.1-finetuned.ckpt'#

Name of the default checkpoint.

Type:

str

default_checkpoint_revision = '0be7e57c685dac86b78c4a19a3ab149d13c6a3dd'#

Commit hash of the default checkpoint.

Type:

str

class aurora.AuroraAirPollution(*, surf_vars: tuple[str, ...] = ('2t', '10u', '10v', 'msl', 'pm1', 'pm2p5', 'pm10', 'tcco', 'tc_no', 'tcno2', 'gtco3', 'tcso2'), static_vars: tuple[str, ...] = ('lsm', 'z', 'slt', 'static_ammonia', 'static_ammonia_log', 'static_co', 'static_co_log', 'static_nox', 'static_nox_log', 'static_so2', 'static_so2_log'), atmos_vars: tuple[str, ...] = ('z', 'u', 'v', 't', 'q', 'co', 'no', 'no2', 'go3', 'so2'), patch_size: int = 3, timestep: timedelta = datetime.timedelta(seconds=43200), level_condition: tuple[int | float, ...] | None = (50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000), dynamic_vars: bool = True, atmos_static_vars: bool = True, separate_perceiver: tuple[str, ...] = ('co', 'no', 'no2', 'go3', 'so2'), modulation_heads: tuple[str, ...] = ('pm1', 'pm2p5', 'pm10', 'co', 'tcco', 'no', 'tc_no', 'no2', 'tcno2', 'so2', 'tcso2', 'go3', 'gtco3'), positive_surf_vars: tuple[str, ...] = ('pm1', 'pm2p5', 'pm10', 'tcco', 'tc_no', 'tcno2', 'gtco3', 'tcso2'), positive_atmos_vars: tuple[str, ...] = ('co', 'no', 'no2', 'go3', 'so2'), simulate_indexing_bug: bool = True, **kw_args)[source]#

Fine-tuned version of Aurora for air pollution.

default_checkpoint_name = 'aurora-0.4-air-pollution.ckpt'#

Name of the default checkpoint.

Type:

str

default_checkpoint_revision = '1764d5630a53d3d7a7d169ca335236fc343e4bfc'#

Commit hash of the default checkpoint.

Type:

str

class aurora.AuroraWave(*, surf_vars: tuple[str, ...] = ('2t', '10u', '10v', 'msl', 'swh', 'mwd', 'mwp', 'pp1d', 'shww', 'mdww', 'mpww', 'shts', 'mdts', 'mpts', 'swh1', 'mwd1', 'mwp1', 'swh2', 'mwd2', 'mwp2', 'wind', '10u_wave', '10v_wave'), static_vars: tuple[str, ...] = ('lsm', 'z', 'slt', 'wmb', 'lat_mask'), lora_mode: Literal['single', 'from_second', 'all'] = 'from_second', stabilise_level_agg: bool = True, density_channel_surf_vars: tuple[str, ...] = ('swh', 'mwd', 'mwp', 'pp1d', 'shww', 'mdww', 'mpww', 'shts', 'mdts', 'mpts', 'swh1', 'mwd1', 'mwp1', 'swh2', 'mwd2', 'mwp2', 'wind', '10u_wave', '10v_wave'), angle_surf_vars: tuple[str, ...] = ('mwd', 'mdww', 'mdts', 'mwd1', 'mwd2'), **kw_args)[source]#

Version of Aurora fine-tuned on HRES-WAM ocean wave data.

batch_transform_hook(batch: Batch) Batch[source]#

Transform the batch right after receiving it and before normalisation.

This function should be idempotent.

default_checkpoint_name = 'aurora-0.25-wave.ckpt'#

Name of the default checkpoint.

Type:

str

default_checkpoint_revision = '74598e8c65d53a96077c08bb91acdfa5525340c9'#

Commit hash of the default checkpoint.

Type:

str