Fine-Tuning
Generally, if you wish to fine-tune Aurora for a specific application, you should build on the pretrained version:
```python
from aurora import Aurora

model = Aurora(use_lora=False)  # Model is not fine-tuned.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
```
Computing Gradients
To compute gradients, you will need an A100 GPU with 80 GB of memory. You will also need to enable PyTorch automatic mixed precision (AMP) and activation checkpointing (also known as gradient checkpointing). You can do this as follows:
```python
from aurora import Aurora

model = Aurora(
    use_lora=False,  # Model was not fine-tuned.
    autocast=True,  # Use AMP.
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

batch = ...  # Load some data.

model = model.cuda()
model.train()
# Recompute activations during the backward pass instead of storing them, to save memory.
model.configure_activation_checkpointing()

pred = model.forward(batch)
loss = ...  # Compute a loss from `pred`.
loss.backward()
```
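To show what the two memory-saving techniques do, here is a generic sketch on a toy PyTorch model. It is not Aurora's internal implementation: `Block` and `TinyModel` are hypothetical, the mean absolute error loss and the AdamW optimiser are placeholders, and a recent PyTorch 2.x is assumed. `torch.autocast` runs eligible operations (such as matrix multiplies) in bfloat16, and `torch.utils.checkpoint.checkpoint` discards each block's intermediate activations in the forward pass and recomputes them in the backward pass, trading extra compute for lower memory.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical residual MLP block standing in for a backbone layer."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


class TinyModel(nn.Module):
    """Hypothetical stack of blocks with optional activation checkpointing."""

    def __init__(self, dim: int = 16, depth: int = 4, use_checkpointing: bool = True) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Do not keep this block's activations; recompute them during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyModel().to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 16, device=device)  # Stand-in for an input batch.
target = torch.randn(8, 16, device=device)  # Stand-in for the target.

# Mixed precision: eligible ops inside this region run in bfloat16.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    pred = model(x)
# Compute the loss in float32, outside the autocast region.
loss = (pred.float() - target).abs().mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```

bfloat16 is used here because it usually trains without loss scaling; with float16 instead, you would normally also wrap the backward pass in `torch.amp.GradScaler`.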
Extending Aurora with New Variables
Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`, `static_vars`, and `atmos_vars`.
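When you add a new variable, you also need to set its normalisation statistics. Because new variables introduce parameters that the pretrained checkpoint does not contain, load the checkpoint with `strict=False`.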
```python
from aurora import Aurora
from aurora.normalisation import locations, scales

model = Aurora(
    use_lora=False,
    surf_vars=("2t", "10u", "10v", "msl", "new_surf_var"),
    static_vars=("lsm", "z", "slt", "new_static_var"),
    atmos_vars=("z", "u", "v", "t", "q", "new_atmos_var"),
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)

# Normalisation means (placeholders: replace with the mean of each new variable):
locations["new_surf_var"] = 0.0
locations["new_static_var"] = 0.0
locations["new_atmos_var"] = 0.0

# Normalisation standard deviations (placeholders: replace with each variable's standard deviation):
scales["new_surf_var"] = 1.0
scales["new_static_var"] = 1.0
scales["new_atmos_var"] = 1.0
```
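The values 0.0 and 1.0 above are placeholders. A minimal sketch of estimating real statistics from training data, assuming the usual standardisation (x − location) / scale, with `locations` holding the mean and `scales` the standard deviation as the comments above indicate. The helpers `compute_stats` and `normalise` and the synthetic `samples` are hypothetical; the Aurora assignments are shown as comments so the block runs without Aurora installed.

```python
import torch


def compute_stats(samples: list[torch.Tensor]) -> tuple[float, float]:
    """Return the mean and standard deviation over every value in `samples`."""
    # Accumulate in float64 to limit rounding error over many grid points.
    values = torch.cat([s.flatten().double() for s in samples])
    return values.mean().item(), values.std().item()


def normalise(x: torch.Tensor, location: float, scale: float) -> torch.Tensor:
    """Standardise `x` with the given mean (location) and standard deviation (scale)."""
    return (x - location) / scale


torch.manual_seed(0)
# Hypothetical training fields for a new surface variable, each of shape (time, lat, lon).
samples = [280.0 + 10.0 * torch.randn(2, 32, 64) for _ in range(5)]
location, scale = compute_stats(samples)
print(f"location={location:.2f}, scale={scale:.2f}")

# With Aurora installed, these values would replace the placeholders:
# locations["new_surf_var"] = location
# scales["new_surf_var"] = scale
```

For datasets too large to concatenate in memory, the same statistics can be accumulated incrementally, for example from running sums of x and x² over the training files.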
Other Model Extensions
It is possible to extend the model in any way you like. If you do this, you will likely add or remove parameters. `Aurora.load_checkpoint` will then raise an error, because the checkpoint no longer matches the model's parameters. Set `Aurora.load_checkpoint(..., strict=False)` to ignore parameters that are missing from the checkpoint or that exist only in the checkpoint. Any newly added parameters keep their initial values, so they will need to be trained:
```python
from aurora import Aurora

model = Aurora(...)
...  # Modify `model`.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```
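The effect of `strict=False` can be seen with plain PyTorch. This sketch assumes that `Aurora.load_checkpoint` passes `strict` through to PyTorch's `nn.Module.load_state_dict`, as the parameter name suggests; `Base`, `Extended`, and `Resized` are hypothetical toy models, and `state` plays the role of the pretrained checkpoint. It also shows a caveat of PyTorch's loader: `strict=False` tolerates missing and unexpected parameter names, but a parameter whose shape changed still raises an error.

```python
import torch
from torch import nn


class Base(nn.Module):
    """Hypothetical stand-in for the original architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)


class Extended(Base):
    """The same model with a newly added head (extra parameters)."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(8, 2)


class Resized(nn.Module):
    """A variant whose encoder changed shape."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 16)


state = Base().state_dict()  # Plays the role of the pretrained checkpoint.
model = Extended()

# strict=True (the default) raises because `head.*` is missing from the checkpoint.
try:
    model.load_state_dict(state)
    strict_raised = False
except RuntimeError:
    strict_raised = True

# strict=False loads the matching entries and reports the rest instead of raising.
result = model.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)  # New parameters, left at their initial values.
print("unexpected:", result.unexpected_keys)  # Checkpoint entries with no matching parameter.

# A shape mismatch is still an error, even with strict=False.
try:
    Resized().load_state_dict(state, strict=False)
    size_mismatch_raised = False
except RuntimeError:
    size_mismatch_raised = True
print("strict raised:", strict_raised, "| size mismatch raised:", size_mismatch_raised)
```

If you change the shape of an existing parameter, you therefore need to drop or adapt the corresponding checkpoint entries yourself before loading.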