Skip to content

Available PDE Surrogate Modules¤

SpectralConv2d ¤

Bases: nn.Module

2D Fourier layer. Does FFT, linear transform, and Inverse FFT.

Implemented in a way to allow multi-gpu training.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
modes1 int

Number of Fourier modes to keep in the first spatial direction

required
modes2 int

Number of Fourier modes to keep in the second spatial direction

required

paper

Source code in pdearena/modules/fourier.py
class SpectralConv2d(nn.Module):
    """Spectral convolution over the last two dimensions of the input.

    The input is taken to Fourier space with a real 2D FFT, a learned complex weight is
    applied to a truncated set of modes, and the result is brought back with the inverse FFT.

    The complex weights are stored as real tensors with a trailing axis of size 2 and are
    viewed as complex numbers in `forward`; this layout is meant to allow multi-gpu training.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        modes1 (int): Number of Fourier modes to keep in the first spatial direction
        modes2 (int): Number of Fourier modes to keep in the second spatial direction

    [paper](https://arxiv.org/abs/2010.08895)
    """

    def __init__(self, in_channels: int, out_channels: int, modes1: int, modes2: int):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2

        # Uniform [0, 1) samples scaled down by the channel product.
        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2, 2)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.float32))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.float32))

    def forward(self, x, x_dim=None, y_dim=None):
        """Apply the spectral convolution.

        `x_dim` and `y_dim` are accepted but not used; the output keeps the spatial size
        of `x`.
        """
        height, width = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2

        # Fourier coefficients of the input (up to a constant factor)
        spectrum = torch.fft.rfft2(x)

        # Only the selected modes are filled in; every other coefficient stays zero.
        filtered = torch.zeros(
            x.shape[0],
            self.out_channels,
            height,
            width // 2 + 1,
            dtype=torch.cfloat,
            device=x.device,
        )
        # First `m1` rows of the spectrum use weights1, last `m1` rows use weights2.
        filtered[:, :, :m1, :m2] = batchmul2d(spectrum[:, :, :m1, :m2], torch.view_as_complex(self.weights1))
        filtered[:, :, -m1:, :m2] = batchmul2d(spectrum[:, :, -m1:, :m2], torch.view_as_complex(self.weights2))

        # Back to physical space at the input resolution
        return torch.fft.irfft2(filtered, s=(height, width))

DilatedBasicBlock ¤

Bases: nn.Module

Basic block for Dilated ResNet

Parameters:

Name Type Description Default
in_planes int

number of input channels

required
planes int

number of output channels

required
stride int

stride of the convolution. Defaults to 1.

1
activation str

activation function. Defaults to "relu".

'relu'
norm bool

whether to use group normalization. Defaults to True.

True
num_groups int

number of groups for group normalization. Defaults to 1.

1
Source code in pdearena/modules/twod_resnet.py
class DilatedBasicBlock(nn.Module):
    """Residual block made of a fixed stack of dilated 3x3 convolutions.

    Seven convolutions with dilation rates 1, 2, 4, 8, 4, 2, 1 are applied in sequence,
    each preceded by a normalization layer and followed by the activation. The block
    input is added to the result.

    Args:
        in_planes (int): number of input channels
        planes (int): number of output channels
        stride (int, optional): stride of the convolution. Defaults to 1.
        activation (str, optional): activation function. Defaults to "relu".
        norm (bool, optional): whether to use group normalization. Defaults to True.
        num_groups (int, optional): number of groups for group normalization. Defaults to 1.

    Raises:
        NotImplementedError: if `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    expansion = 1

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int = 1,
        activation: str = "relu",
        norm: bool = True,
        num_groups: int = 1,
    ):
        super().__init__()

        self.dilation = [1, 2, 4, 8, 4, 2, 1]
        # padding == dilation, so a 3x3 kernel with stride 1 keeps the spatial size
        self.dilation_layers = nn.ModuleList(
            [
                nn.Conv2d(
                    in_planes,
                    planes,
                    kernel_size=3,
                    stride=stride,
                    dilation=rate,
                    padding=rate,
                    bias=True,
                )
                for rate in self.dilation
            ]
        )
        # One normalization (or pass-through) layer per convolution
        self.norm_layers = nn.ModuleList(
            [nn.GroupNorm(num_groups, num_channels=planes) if norm else nn.Identity() for _ in self.dilation]
        )
        self.activation: nn.Module = ACTIVATION_REGISTRY.get(activation, None)
        if self.activation is None:
            raise NotImplementedError(f"Activation {activation} not implemented")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run norm -> conv -> activation for every layer pair and add the block input."""
        residual = x
        for conv, normalize in zip(self.dilation_layers, self.norm_layers):
            x = self.activation(conv(normalize(x)))
        return x + residual

FourierBasicBlock ¤

Bases: nn.Module

Basic block for Fourier Neural Operators

Parameters:

Name Type Description Default
in_planes int

number of input channels

required
planes int

number of output channels

required
stride int

stride of the convolution. Defaults to 1.

1
modes1 int

number of modes for the first spatial dimension. Defaults to 16.

16
modes2 int

number of modes for the second spatial dimension. Defaults to 16.

16
activation str

activation function. Defaults to "gelu".

'gelu'
norm bool

whether to use group normalization. Defaults to False.

False
Source code in pdearena/modules/twod_resnet.py
class FourierBasicBlock(nn.Module):
    """Basic block for Fourier Neural Operators

    Two stages are applied in sequence. Each stage sums a spectral convolution and a
    1x1 convolution of the same input and passes the sum through the activation.

    Args:
        in_planes (int): number of input channels
        planes (int): number of output channels
        stride (int, optional): stride of the convolution. Defaults to 1. Not used by this block.
        modes1 (int, optional): number of modes for the first spatial dimension. Defaults to 16.
        modes2 (int, optional): number of modes for the second spatial dimension. Defaults to 16.
        activation (str, optional): activation function. Defaults to "gelu".
        norm (bool, optional): whether to use group normalization. Defaults to False.
            Must be False; normalization is not supported here.

    Raises:
        AssertionError: if `norm` is truthy.
        NotImplementedError: if `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    expansion: int = 1

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int = 1,
        modes1: int = 16,
        modes2: int = 16,
        activation: str = "gelu",
        norm: bool = False,
    ):
        super().__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        assert not norm

        # Stage 1: in_planes -> planes
        self.fourier1 = SpectralConv2d(in_planes, planes, modes1=self.modes1, modes2=self.modes2)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, padding=0, padding_mode="zeros", bias=True)
        # Stage 2: planes -> planes
        self.fourier2 = SpectralConv2d(planes, planes, modes1=self.modes1, modes2=self.modes2)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=1, padding=0, padding_mode="zeros", bias=True)

        # NOTE: no shortcut connection is used; so far shortcut connections are not helping.

        self.activation: nn.Module = ACTIVATION_REGISTRY.get(activation, None)
        if self.activation is None:
            raise NotImplementedError(f"Activation {activation} not implemented")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both (spectral + pointwise) stages, each followed by the activation."""
        hidden = self.activation(self.fourier1(x) + self.conv1(x))
        return self.activation(self.fourier2(hidden) + self.conv2(hidden))

ResNet ¤

Bases: nn.Module

Class to support ResNet like feedforward architectures

Parameters:

Name Type Description Default
n_input_scalar_components int

Number of input scalar components in the model

required
n_input_vector_components int

Number of input vector components in the model

required
n_output_scalar_components int

Number of output scalar components in the model

required
n_output_vector_components int

Number of output vector components in the model

required
block Callable

BasicBlock or DilatedBasicBlock or FourierBasicBlock

required
num_blocks List[int]

Number of blocks in each layer

required
time_history int

Number of time steps to use in the input

required
time_future int

Number of time steps to predict in the output

required
hidden_channels int

Number of channels in the hidden layers

64
activation str

Activation function to use

'gelu'
norm bool

Whether to use normalization

True
Source code in pdearena/modules/twod_resnet.py
class ResNet(nn.Module):
    """ResNet-style feed-forward model: 1x1 input convs, a stack of block layers, 1x1 output convs.

    Time and component axes of the input are merged into channels, lifted to
    `hidden_channels`, padded on the trailing edge of both spatial axes, passed through
    the block layers, cropped back and projected to the output channels.

    Args:
        n_input_scalar_components (int): Number of input scalar components in the model
        n_input_vector_components (int): Number of input vector components in the model
        n_output_scalar_components (int): Number of output scalar components in the model
        n_output_vector_components (int): Number of output vector components in the model
        block (Callable): BasicBlock or DilatedBasicBlock or FourierBasicBlock
        num_blocks (List[int]): Number of blocks in each layer
        time_history (int): Number of time steps to use in the input
        time_future (int): Number of time steps to predict in the output
        hidden_channels (int): Number of channels in the hidden layers
        activation (str): Activation function to use
        norm (bool): Whether to use normalization
        diffmode (bool): Not implemented; `forward` raises `NotImplementedError` when set.
        usegrid (bool): When set, the first convolution expects two additional input channels.
            NOTE(review): `forward` does not append any grid channels itself - confirm the
            caller is expected to provide them.

    Raises:
        NotImplementedError: if `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    # Number of cells appended to the end of each spatial axis around the block layers
    padding = 9

    def __init__(
        self,
        n_input_scalar_components: int,
        n_input_vector_components: int,
        n_output_scalar_components: int,
        n_output_vector_components: int,
        block,
        num_blocks: list,
        time_history: int,
        time_future: int,
        hidden_channels: int = 64,
        activation: str = "gelu",
        norm: bool = True,
        diffmode: bool = False,
        usegrid: bool = False,
    ):
        super().__init__()
        self.n_input_scalar_components = n_input_scalar_components
        self.n_input_vector_components = n_input_vector_components
        self.n_output_scalar_components = n_output_scalar_components
        self.n_output_vector_components = n_output_vector_components
        self.diffmode = diffmode
        self.usegrid = usegrid
        self.in_planes = hidden_channels

        # Every vector component contributes two channels.
        insize = time_history * (self.n_input_scalar_components + self.n_input_vector_components * 2)
        if self.usegrid:
            insize += 2
        outsize = time_future * (self.n_output_scalar_components + self.n_output_vector_components * 2)

        self.conv_in1 = nn.Conv2d(insize, self.in_planes, kernel_size=1, bias=True)
        self.conv_in2 = nn.Conv2d(self.in_planes, self.in_planes, kernel_size=1, bias=True)
        self.conv_out1 = nn.Conv2d(self.in_planes, self.in_planes, kernel_size=1, bias=True)
        self.conv_out2 = nn.Conv2d(self.in_planes, outsize, kernel_size=1, bias=True)

        # `self.in_planes` is re-read for every layer because `_make_layer` updates it.
        self.layers = nn.ModuleList(
            [
                self._make_layer(
                    block,
                    self.in_planes,
                    depth,
                    stride=1,
                    activation=activation,
                    norm=norm,
                )
                for depth in num_blocks
            ]
        )
        self.activation: nn.Module = ACTIVATION_REGISTRY.get(activation, None)
        if self.activation is None:
            raise NotImplementedError(f"Activation {activation} not implemented")

    def _make_layer(
        self,
        block: Callable,
        planes: int,
        num_blocks: int,
        stride: int,
        activation: str,
        norm: bool = True,
    ) -> nn.Sequential:
        """Build `num_blocks` blocks in sequence; only the first one receives `stride`.

        Updates `self.in_planes` to `planes * block.expansion` after each block.
        """
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(
                block(
                    self.in_planes,
                    planes,
                    block_stride,
                    activation=activation,
                    norm=norm,
                )
            )
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def __repr__(self):
        return "ResNet"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 5-dimensional input to a 5-dimensional prediction.

        The input's second and third axes are merged into channels; the output is
        reshaped to (batch, -1, output components, *spatial dims of the input).
        """
        assert x.dim() == 5
        batch, spatial = x.shape[0], x.shape[3:]
        n_out_components = self.n_output_scalar_components + self.n_output_vector_components * 2

        h = x.reshape(x.size(0), -1, *spatial)  # collapse T,C
        h = self.activation(self.conv_in1(h.float()))
        h = self.activation(self.conv_in2(h.float()))

        pad = self.padding
        if pad > 0:
            # pad only the trailing side of the last two axes
            h = F.pad(h, [0, pad, 0, pad])

        for layer in self.layers:
            h = layer(h)

        if pad > 0:
            h = h[..., :-pad, :-pad]

        h = self.conv_out2(self.activation(self.conv_out1(h)))

        if self.diffmode:
            raise NotImplementedError("diffmode")
        return h.reshape(batch, -1, n_out_components, *spatial)

OperatorBlock_2D ¤

Bases: nn.Module

Parameters:

Name Type Description Default
in_codim int

Input co-domain dimension

required
out_codim int

output co-domain dimension

required
dim1 int

Default output grid size along x (or 1st dimension)

required
dim2 int

Default output grid size along y ( or 2nd dimension)

required
modes1 int

Number of fourier modes to consider along 1st dimension

required
modes2 int

Number of fourier modes to consider along 2nd dimension

required
norm bool

Whether to use normalization (torch.nn.InstanceNorm2d)

True
nonlin bool

Whether to use non-linearity (torch.nn.GELU)

True

All variables are consistent with the SpectralConv2d_Uno.

Source code in pdearena/modules/twod_uno.py
class OperatorBlock_2D(nn.Module):
    """Sum of a spectral convolution and a pointwise operator, optionally normalized and activated.

    Args:
        in_codim (int): Input co-domain dimension
        out_codim (int): output co-domain dimension
        dim1 (int):  Default output grid size along x (or 1st dimension)
        dim2 (int): Default output grid size along y (or 2nd dimension)
        modes1 (int): Number of fourier modes to consider along 1st dimension
        modes2 (int): Number of fourier modes to consider along 2nd dimension
        norm (bool): Whether to use normalization ([torch.nn.InstanceNorm2d][])
        nonlin (bool): Whether to use non-linearity ([torch.nn.GELU][])


    All variables are consistent with the [`SpectralConv2d_Uno`][pdearena.modules.twod_uno.SpectralConv2d_Uno].
    """

    def __init__(self, in_codim, out_codim, dim1, dim2, modes1, modes2, norm=True, nonlin=True):
        super().__init__()
        # Both branches map in_codim -> out_codim and produce the same output grid size.
        self.conv = SpectralConv2d_Uno(in_codim, out_codim, dim1, dim2, modes1, modes2)
        self.w = Pointwise_op_2D(in_codim, out_codim, dim1, dim2)
        self.norm = norm
        self.non_lin = nonlin
        if norm:
            # The normalization layer only exists when normalization is requested.
            self.normalize_layer = nn.InstanceNorm2d(int(out_codim), affine=True)

    def forward(self, x, dim1=None, dim2=None):
        """Apply the block.

        input shape = (batch, in_codim, input_dim1, input_dim2)
        output shape = (batch, out_codim, dim1, dim2)
        """
        combined = self.conv(x, dim1, dim2) + self.w(x, dim1, dim2)
        if self.norm:
            combined = self.normalize_layer(combined)
        return F.gelu(combined) if self.non_lin else combined

Pointwise_op_2D ¤

Bases: nn.Module

Parameters:

Name Type Description Default
in_codim int

Input co-domain dimension

required
out_codim int

output co-domain dimension

required
dim1 int

Default output grid size along x (or 1st dimension)

required
dim2 int

Default output grid size along y ( or 2nd dimension)

required
Source code in pdearena/modules/twod_uno.py
class Pointwise_op_2D(nn.Module):
    """Pointwise (1x1 convolution) operator followed by a resize to the output grid.

    Args:
        in_codim (int): Input co-domain dimension
        out_codim (int): output co-domain dimension

        dim1 (int):  Default output grid size along x (or 1st dimension)
        dim2 (int): Default output grid size along y (or 2nd dimension)
    """

    def __init__(self, in_codim: int, out_codim: int, dim1: int, dim2: int):
        super().__init__()
        self.conv = nn.Conv2d(int(in_codim), int(out_codim), 1)
        self.dim1 = int(dim1)
        self.dim2 = int(dim2)

    def forward(self, x, dim1=None, dim2=None):
        """Apply the 1x1 convolution and resample the result to `(dim1, dim2)`.

        input shape = (batch, in_codim, input_dim1, input_dim2)
        output shape = (batch, out_codim, dim1, dim2)

        Args:
            x: input tensor.
            dim1: output grid size along the 1st dimension. When omitted, both stored
                defaults are used (any `dim2` passed alongside is ignored).
            dim2: output grid size along the 2nd dimension. Falls back to the stored
                default when omitted.
        """
        if dim1 is None:
            dim1 = self.dim1
            dim2 = self.dim2
        elif dim2 is None:
            # Only the first size was given: fall back to the default for the second
            # instead of handing `None` to the interpolation.
            dim2 = self.dim2
        x_out = self.conv(x)

        x_out = F.interpolate(x_out, size=(dim1, dim2), mode="bicubic", align_corners=True, antialias=True)
        return x_out

SpectralConv2d_Uno ¤

Bases: nn.Module

2D Fourier layer. It does FFT, linear transform, and Inverse FFT.

Modified to support multi-gpu training.

Parameters:

Name Type Description Default
in_codim int

Input co-domain dimension

required
out_codim int

output co-domain dimension

required
dim1 int

Default output grid size along x (or 1st dimension of output domain)

required
dim2 int

Default output grid size along y (or 2nd dimension of output domain). The ratio of the input grid size to the output grid size implicitly sets the expansion or contraction factor along each dimension.

required
modes1, modes2 int

Number of Fourier modes to consider for the integral operator. The number of modes must be compatible with the input grid size and the desired output grid size, i.e., modes1 <= min(dim1/2, input_dim1/2). Here "input_dim1" is the grid size along the x axis (or first dimension) of the input domain. The same constraint applies to modes2.

None
Source code in pdearena/modules/twod_uno.py
class SpectralConv2d_Uno(nn.Module):
    """2D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Modified to support multi-gpu training.

    Args:
        in_codim (int): Input co-domain dimension
        out_codim (int): output co-domain dimension

        dim1 (int): Default output grid size along x (or 1st dimension of output domain)
        dim2 (int): Default output grid size along y (or 2nd dimension of output domain)
                    The ratio of the input grid size to the output grid size implicitly
                    sets the expansion or contraction factor along each dimension.
        modes1 (int), modes2 (int):  Number of fourier modes to consider for the integral operator.
                    The number of modes must be compatible with the input grid size
                    and the desired output grid size,
                    i.e., modes1 <= min( dim1/2, input_dim1/2).
                    Here "input_dim1" is the grid size along x axis (or first dimension) of the input domain.
                    The same constraint applies to modes2.
                    When `modes1` is omitted, both are derived from `dim1` and `dim2`.
    """

    def __init__(self, in_codim, out_codim, dim1, dim2, modes1=None, modes2=None):
        super().__init__()

        in_codim = int(in_codim)
        out_codim = int(out_codim)
        self.in_channels = in_codim
        self.out_channels = out_codim
        self.dim1 = dim1
        self.dim2 = dim2
        if modes1 is not None:
            self.modes1 = modes1
            # `modes2` may be omitted on its own; use the same fallback as the all-default case
            # instead of passing `None` into the weight shape.
            self.modes2 = modes2 if modes2 is not None else dim2 // 2
        else:
            self.modes1 = dim1 // 2 - 1
            self.modes2 = dim2 // 2
        self.scale = (1 / (2 * in_codim)) ** (1.0 / 2.0)
        # Complex weights stored as real tensors with a trailing axis of size 2.
        self.weights1 = nn.Parameter(
            self.scale * (torch.randn(in_codim, out_codim, self.modes1, self.modes2, 2, dtype=torch.float32))
        )
        self.weights2 = nn.Parameter(
            self.scale * (torch.randn(in_codim, out_codim, self.modes1, self.modes2, 2, dtype=torch.float32))
        )

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        """Contract the input-channel axis: (b, i, x, y) x (i, o, x, y) -> (b, o, x, y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x, dim1=None, dim2=None):
        """Apply the Fourier layer and return a tensor on a `(dim1, dim2)` grid.

        Args:
            x: input tensor; the FFT is taken over its last two dimensions.
            dim1: output grid size along the 1st dimension for this call. When omitted,
                both stored defaults are used (any `dim2` passed alongside is ignored).
            dim2: output grid size along the 2nd dimension for this call. Falls back to
                the stored default when omitted.

        The requested sizes apply to this call only; the defaults stored on the module
        are left untouched, so a later call without sizes uses the constructor values.
        """
        if dim1 is None:
            out_dim1, out_dim2 = self.dim1, self.dim2
        else:
            out_dim1 = dim1
            out_dim2 = self.dim2 if dim2 is None else dim2
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft2(x, norm="forward")

        # Multiply relevant Fourier modes; every other coefficient stays zero
        out_ft = torch.zeros(
            batchsize,
            self.out_channels,
            out_dim1,
            out_dim2 // 2 + 1,
            dtype=torch.cfloat,
            device=x.device,
        )
        out_ft[:, :, : self.modes1, : self.modes2] = self.compl_mul2d(
            x_ft[:, :, : self.modes1, : self.modes2], torch.view_as_complex(self.weights1)
        )
        out_ft[:, :, -self.modes1 :, : self.modes2] = self.compl_mul2d(
            x_ft[:, :, -self.modes1 :, : self.modes2], torch.view_as_complex(self.weights2)
        )

        # Return to physical space on the requested output grid
        x = torch.fft.irfft2(out_ft, s=(out_dim1, out_dim2), norm="forward")
        return x

UNO ¤

Bases: nn.Module

UNO model

Parameters:

Name Type Description Default
n_input_scalar_components int

Number of scalar components in the model

required
n_input_vector_components int

Number of vector components in the model

required
n_output_scalar_components int

Number of output scalar components in the model

required
n_output_vector_components int

Number of output vector components in the model

required
time_history int

Number of time steps to include in the model

required
time_future int

Number of time steps to predict in the model

required
hidden_channels int

Number of hidden channels in the model

required
pad int

Padding to use in the model

0
factor float

Scaling factor to use in the model

3 / 4
activation str

Activation function to use in the model

'gelu'
Source code in pdearena/modules/twod_uno.py
class UNO(nn.Module):
    """UNO model

    The input is lifted pointwise to `hidden_channels`, passed through seven operator
    blocks that first shrink and then grow the grid (with skip connections between
    blocks of matching grid size), and projected pointwise to the output channels.

    Args:
        n_input_scalar_components (int): Number of scalar components in the model
        n_input_vector_components (int): Number of vector components in the model
        n_output_scalar_components (int): Number of output scalar components in the model
        n_output_vector_components (int): Number of output vector components in the model
        time_history (int): Number of time steps to include in the model
        time_future (int): Number of time steps to predict in the model
        hidden_channels (int): Number of hidden channels in the model
        pad (int): Number of cells padded on each side of both spatial axes before the
            operator blocks; the same cells are cropped again afterwards
        factor (float): Scaling factor to use in the model
        activation (str): Activation function to use in the model

    Raises:
        NotImplementedError: if `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    def __init__(
        self,
        n_input_scalar_components: int,
        n_input_vector_components: int,
        n_output_scalar_components: int,
        n_output_vector_components: int,
        time_history: int,
        time_future: int,
        hidden_channels: int,
        pad=0,
        factor=3 / 4,
        activation="gelu",
    ):
        super().__init__()

        self.n_input_scalar_components = n_input_scalar_components
        self.n_input_vector_components = n_input_vector_components
        self.n_output_scalar_components = n_output_scalar_components
        self.n_output_vector_components = n_output_vector_components

        self.width = hidden_channels
        self.factor = factor
        self.padding = pad
        self.activation: nn.Module = ACTIVATION_REGISTRY.get(activation, None)
        if self.activation is None:
            raise NotImplementedError(f"Activation {activation} not implemented")

        # Every vector component contributes two channels.
        in_width = time_history * (self.n_input_scalar_components + self.n_input_vector_components * 2)
        out_width = time_future * (self.n_output_scalar_components + self.n_output_vector_components * 2)

        # Pointwise lifting: in_width -> width // 2 -> width
        self.fc = nn.Linear(in_width, self.width // 2)

        self.fc0 = nn.Linear(self.width // 2, self.width)

        # The grid sizes given here are only defaults; `forward` passes sizes derived from the input.
        self.L0 = OperatorBlock_2D(self.width, 2 * factor * self.width, 48, 48, 18, 18)

        self.L1 = OperatorBlock_2D(2 * factor * self.width, 4 * factor * self.width, 32, 32, 14, 14)

        self.L2 = OperatorBlock_2D(4 * factor * self.width, 8 * factor * self.width, 16, 16, 6, 6)

        self.L3 = OperatorBlock_2D(8 * factor * self.width, 8 * factor * self.width, 16, 16, 6, 6)

        self.L4 = OperatorBlock_2D(8 * factor * self.width, 4 * factor * self.width, 32, 32, 6, 6)

        # L5 and L6 take inputs widened by the skip connections concatenated in `forward`.
        self.L5 = OperatorBlock_2D(8 * factor * self.width, 2 * factor * self.width, 48, 48, 14, 14)

        self.L6 = OperatorBlock_2D(4 * factor * self.width, self.width, 64, 64, 18, 18)  # will be reshaped

        # Pointwise projection: 2 * width (L6 output + lifted input) -> 4 * width -> out_width
        self.fc1 = nn.Linear(2 * self.width, 4 * self.width)
        self.fc2 = nn.Linear(4 * self.width, out_width)

    def forward(self, x):
        """Map a 5-dimensional input to a 5-dimensional prediction.

        The input's second and third axes are merged into channels; the output is
        reshaped to (batch, -1, output components, *spatial dims of the input).
        """
        assert x.dim() == 5
        orig_shape = x.shape
        x = x.reshape(x.size(0), -1, *x.shape[3:])  # collapse T,C

        # Linear layers act on the last axis, so move channels last for the lifting.
        x = x.permute(0, 2, 3, 1)
        x_fc = self.fc(x)
        x_fc = self.activation(x_fc)

        x_fc0 = self.fc0(x_fc)
        x_fc0 = self.activation(x_fc0)

        x_fc0 = x_fc0.permute(0, 3, 1, 2)

        # Pads `padding` cells on both sides of the last two axes.
        x_fc0 = F.pad(x_fc0, [self.padding, self.padding, self.padding, self.padding])

        D1, D2 = x_fc0.shape[-2], x_fc0.shape[-1]

        # Contracting path
        x_c0 = self.L0(x_fc0, int(D1 * self.factor), int(D2 * self.factor))
        x_c1 = self.L1(x_c0, D1 // 2, D2 // 2)

        x_c2 = self.L2(x_c1, D1 // 4, D2 // 4)
        x_c3 = self.L3(x_c2, D1 // 4, D2 // 4)
        # Expanding path with skip connections along the channel axis
        x_c4 = self.L4(x_c3, D1 // 2, D2 // 2)
        x_c4 = torch.cat([x_c4, x_c1], dim=1)
        x_c5 = self.L5(x_c4, int(D1 * self.factor), int(D2 * self.factor))
        x_c5 = torch.cat([x_c5, x_c0], dim=1)
        x_c6 = self.L6(x_c5, D1, D2)
        x_c6 = torch.cat([x_c6, x_fc0], dim=1)

        if self.padding != 0:
            # Remove the padding from both sides of each spatial axis, mirroring the
            # symmetric F.pad above, so the output lines up with the input grid.
            x_c6 = x_c6[..., self.padding : -self.padding, self.padding : -self.padding]

        x_c6 = x_c6.permute(0, 2, 3, 1)

        x_fc1 = self.fc1(x_c6)
        x_fc1 = self.activation(x_fc1)

        x_out = self.fc2(x_fc1)
        x_out = x_out.permute(0, 3, 1, 2)

        return x_out.reshape(
            orig_shape[0], -1, (self.n_output_scalar_components + self.n_output_vector_components * 2), *orig_shape[3:]
        )

Unet2015 ¤

Bases: nn.Module

Two-dimensional UNet based on original architecture.

Parameters:

Name Type Description Default
n_input_scalar_components int

Number of scalar components in the model

required
n_input_vector_components int

Number of vector components in the model

required
n_output_scalar_components int

Number of output scalar components in the model

required
n_output_vector_components int

Number of output vector components in the model

required
time_history int

Number of time steps in the input.

required
time_future int

Number of time steps in the output.

required
hidden_channels int

Number of hidden channels.

required
activation str

Activation function.

required
Source code in pdearena/modules/twod_unet2015.py
class Unet2015(nn.Module):
    """Two-dimensional UNet based on original architecture.

    Args:
        n_input_scalar_components (int): Number of scalar components in the model
        n_input_vector_components (int): Number of vector components in the model
        n_output_scalar_components (int): Number of output scalar components in the model
        n_output_vector_components (int): Number of output vector components in the model
        time_history (int): Number of time steps in the input.
        time_future (int): Number of time steps in the output.
        hidden_channels (int): Number of hidden channels.
        activation (str): Activation function.
    """

    def __init__(
        self,
        n_input_scalar_components: int,
        n_input_vector_components: int,
        n_output_scalar_components: int,
        n_output_vector_components: int,
        time_history: int,
        time_future: int,
        hidden_channels: int,
        activation: str,
    ) -> None:
        super().__init__()
        self.n_input_scalar_components = n_input_scalar_components
        self.n_input_vector_components = n_input_vector_components
        self.n_output_scalar_components = n_output_scalar_components
        self.n_output_vector_components = n_output_vector_components
        self.time_history = time_history
        self.time_future = time_future
        self.hidden_channels = hidden_channels
        self.activation = ACTIVATION_REGISTRY.get(activation, None)
        if self.activation is None:
            raise NotImplementedError(f"Activation {activation} not implemented")

        in_channels = time_history * (self.n_input_scalar_components + self.n_input_vector_components * 2)
        out_channels = time_future * (self.n_output_scalar_components + self.n_output_vector_components * 2)

        features = hidden_channels
        self.encoder1 = Unet2015._block(in_channels, features, name="enc1", activation=self.activation)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = Unet2015._block(features, features * 2, name="enc2", activation=self.activation)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = Unet2015._block(features * 2, features * 4, name="enc3", activation=self.activation)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = Unet2015._block(features * 4, features * 8, name="enc4", activation=self.activation)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = Unet2015._block(features * 8, features * 16, name="bottleneck", activation=self.activation)

        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
        self.decoder4 = Unet2015._block((features * 8) * 2, features * 8, name="dec4", activation=self.activation)
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
        self.decoder3 = Unet2015._block((features * 4) * 2, features * 4, name="dec3", activation=self.activation)
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.decoder2 = Unet2015._block((features * 2) * 2, features * 2, name="dec2", activation=self.activation)
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
        self.decoder1 = Unet2015._block(features * 2, features, name="dec1", activation=self.activation)

        self.conv = nn.Conv2d(in_channels=features, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        assert x.dim() == 5
        orig_shape = x.shape
        x = x.reshape(x.size(0), -1, *x.shape[3:])
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        out = self.conv(dec1)
        return out.reshape(orig_shape[0], -1, *orig_shape[2:])

    @staticmethod
    def _block(in_channels, features, name, activation=nn.Tanh()):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "act1", activation),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "act2", activation),
                ]
            )
        )

Unetbase ¤

Bases: nn.Module

Our interpretation of the original U-Net architecture.

Uses torch.nn.GroupNorm instead of torch.nn.BatchNorm2d. Also there is no BottleNeck block.

Parameters:

Name Type Description Default
n_input_scalar_components int

Number of scalar components in the model

required
n_input_vector_components int

Number of vector components in the model

required
n_output_scalar_components int

Number of output scalar components in the model

required
n_output_vector_components int

Number of output vector components in the model

required
time_history int

Number of time steps in the input.

required
time_future int

Number of time steps in the output.

required
hidden_channels int

Number of channels in the hidden layers.

required
activation str

Activation function to use. One of ["gelu", "relu", "silu"].

'gelu'
Source code in pdearena/modules/twod_unetbase.py
class Unetbase(nn.Module):
    """Our interpretation of the original U-Net architecture.

    Uses [torch.nn.GroupNorm][] instead of [torch.nn.BatchNorm2d][]. Also there is no `BottleNeck` block.

    The encoder is four `Down` stages that double the channel count each time
    (``hidden_channels`` up to ``16 * hidden_channels``); the decoder is four
    `Up` stages that mirror them and each receive the matching encoder output.

    Args:
        n_input_scalar_components (int): Number of scalar components in the model
        n_input_vector_components (int): Number of vector components in the model
        n_output_scalar_components (int): Number of output scalar components in the model
        n_output_vector_components (int): Number of output vector components in the model
        time_history (int): Number of time steps in the input.
        time_future (int): Number of time steps in the output.
        hidden_channels (int): Number of channels in the hidden layers.
        activation (str): Activation function to use. One of ["gelu", "relu", "silu"].

    Raises:
        NotImplementedError: If `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    def __init__(
        self,
        n_input_scalar_components: int,
        n_input_vector_components: int,
        n_output_scalar_components: int,
        n_output_vector_components: int,
        time_history: int,
        time_future: int,
        hidden_channels: int,
        activation="gelu",
    ) -> None:
        super().__init__()
        act_module = ACTIVATION_REGISTRY.get(activation)
        if act_module is None:
            raise NotImplementedError(f"Activation {activation} not implemented")

        self.n_input_scalar_components = n_input_scalar_components
        self.n_input_vector_components = n_input_vector_components
        self.n_output_scalar_components = n_output_scalar_components
        self.n_output_vector_components = n_output_vector_components
        self.time_history = time_history
        self.time_future = time_future
        self.hidden_channels = hidden_channels
        self.activation = act_module

        # Every vector component contributes two channels; time steps are stacked on the channel axis.
        components_in = n_input_scalar_components + n_input_vector_components * 2
        components_out = n_output_scalar_components + n_output_vector_components * 2

        # Channel widths of the five feature levels: c, 2c, 4c, 8c, 16c.
        widths = [hidden_channels * 2**level for level in range(5)]

        self.image_proj = ConvBlock(time_history * components_in, widths[0], activation=activation)
        self.down = nn.ModuleList(
            [Down(narrow, wide, activation=activation) for narrow, wide in zip(widths[:-1], widths[1:])]
        )
        self.up = nn.ModuleList(
            [
                Up(wide, narrow, activation=activation)
                for wide, narrow in zip(reversed(widths[1:]), reversed(widths[:-1]))
            ]
        )
        # should there be a final norm too? but we aren't doing "prenorm" in the original
        self.final = nn.Conv2d(widths[0], time_future * components_out, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x):
        """Run the U-Net on a 5-D input.

        Dimensions 1 and 2 of `x` are merged before the convolutions, and the
        result is reshaped so that dimension 2 holds the output components.
        """
        assert x.dim() == 5
        batch, spatial = x.shape[0], x.shape[3:]
        x = x.reshape(x.size(0), -1, *spatial)

        # Encoder: keep every level's output so the decoder can consume them in reverse.
        skips = [self.image_proj(x)]
        for stage in self.down:
            skips.append(stage(skips[-1]))

        # Decoder: start from the deepest features and pair each stage with the next-shallower skip.
        x = skips.pop()
        for stage in self.up:
            x = stage(x, skips.pop())

        x = self.final(x)
        n_components = self.n_output_scalar_components + self.n_output_vector_components * 2
        return x.reshape(batch, -1, n_components, *spatial)

AttentionBlock ¤

Bases: nn.Module

Attention block. This is similar to transformer multi-head attention.

Parameters:

Name Type Description Default
n_channels int

the number of channels in the input

required
n_heads int

the number of heads in multi-head attention

1
d_k Optional[int]

the number of dimensions in each head

None
n_groups int

the number of groups for group normalization.

1
Source code in pdearena/modules/twod_unet.py
class AttentionBlock(nn.Module):
    """Attention block. This is similar to transformer multi-head attention.

    The spatial positions of the feature map are flattened into a sequence and
    every position attends to every other position. The input is added back to
    the attention output (skip connection).

    Note:
        `self.norm` is constructed in `__init__` but `forward` does not apply it.
        It is kept unchanged here so the set of parameters of this module stays
        the same.

    Args:
        n_channels (int): the number of channels in the input
        n_heads (int): the number of heads in multi-head attention
        d_k: the number of dimensions in each head. Defaults to `n_channels` when `None`.
        n_groups (int): the number of groups for [group normalization][torch.nn.GroupNorm].

    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: Optional[int] = None, n_groups: int = 1):
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer (created but not used in `forward`, see class note)
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k**-0.5
        # Kept for reshaping in `forward`
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        """Apply self-attention over the spatial positions of `x`.

        Args:
            x (torch.Tensor): input of shape `[batch_size, n_channels, height, width]`.

        Returns:
            torch.Tensor: tensor of the same shape as `x`.
        """
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`.
        # `reshape` (rather than `view`) also accepts non-contiguous inputs.
        x = x.reshape(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$.
        # Result layout is `[batch_size, query (i), key (j), n_heads]`.
        attn = torch.einsum("bihd,bjhd->bijh", q, k) * self.scale
        # Softmax over the key axis (`j`, dim 2): the contraction below sums over `j`, so the
        # weights each query applies to the values must sum to one along that axis.
        attn = attn.softmax(dim=2)
        # Multiply by values
        res = torch.einsum("bijh,bjhd->bihd", attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.reshape(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add skip connection
        res = res + x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).reshape(batch_size, n_channels, height, width)
        return res

DownBlock ¤

Bases: nn.Module

Down block. This combines ResidualBlock and AttentionBlock.

These are used in the first half of U-Net at each resolution.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
has_attn bool

Whether to use attention block

False
activation str

Activation function

'gelu'
norm bool

Whether to use normalization

False
Source code in pdearena/modules/twod_unet.py
class DownBlock(nn.Module):
    """Down block. This combines [`ResidualBlock`][pdearena.modules.twod_unet.ResidualBlock] and [`AttentionBlock`][pdearena.modules.twod_unet.AttentionBlock].

    These are used in the first half of U-Net at each resolution.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        has_attn (bool): Whether to use attention block
        activation (str): Name of the activation function, forwarded to the residual block
        norm (bool): Whether to use normalization
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        has_attn: bool = False,
        activation: str = "gelu",
        norm: bool = False,
    ):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, activation=activation, norm=norm)
        # Without attention the second stage is a pass-through.
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Apply the residual block, then the (optional) attention block."""
        return self.attn(self.res(x))

Downsample ¤

Bases: nn.Module

Scale down the feature map by \(\frac{1}{2} \times\)

Parameters:

Name Type Description Default
n_channels int

Number of channels in the input and output.

required
Source code in pdearena/modules/twod_unet.py
class Downsample(nn.Module):
    r"""Scale down the feature map by $\frac{1}{2} \times$

    Implemented as a single 3x3 convolution with stride 2 and padding 1 that
    keeps the channel count unchanged.

    Args:
        n_channels (int): Number of channels in the input and output.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
        )

    def forward(self, x: torch.Tensor):
        """Return the strided-convolution output for `x`."""
        downsampled = self.conv(x)
        return downsampled

FourierDownBlock ¤

Bases: nn.Module

Down block. This combines FourierResidualBlock and AttentionBlock.

These are used in the first half of U-Net at each resolution.

Source code in pdearena/modules/twod_unet.py
class FourierDownBlock(nn.Module):
    """Down block. This combines [`FourierResidualBlock`][pdearena.modules.twod_unet.FourierResidualBlock] and [`AttentionBlock`][pdearena.modules.twod_unet.AttentionBlock].

    These are used in the first half of U-Net at each resolution.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        modes1 (int): Number of Fourier modes, forwarded to the Fourier residual block
        modes2 (int): Number of Fourier modes, forwarded to the Fourier residual block
        has_attn (bool): Whether to use attention block
        activation (str): Name of the activation function, forwarded to the Fourier residual block
        norm (bool): Whether to use normalization
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int = 16,
        modes2: int = 16,
        has_attn: bool = False,
        activation: str = "gelu",
        norm: bool = False,
    ):
        super().__init__()
        residual_kwargs = {"modes1": modes1, "modes2": modes2, "activation": activation, "norm": norm}
        self.res = FourierResidualBlock(in_channels, out_channels, **residual_kwargs)
        # Without attention the second stage is a pass-through.
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Apply the Fourier residual block, then the (optional) attention block."""
        return self.attn(self.res(x))

FourierResidualBlock ¤

Bases: nn.Module

Fourier Residual Block to be used in modern Unet architectures.

Parameters:

Name Type Description Default
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
modes1 int

Number of modes in the first dimension.

16
modes2 int

Number of modes in the second dimension.

16
activation str

Activation function to use.

'gelu'
norm bool

Whether to use normalization.

False
n_groups int

Number of groups for group normalization.

1
Source code in pdearena/modules/twod_unet.py
class FourierResidualBlock(nn.Module):
    """Fourier Residual Block to be used in modern Unet architectures.

    Each of the two stages sums a spectral convolution and a 1x1 convolution
    of the same (pre-normalized, activated) input. The block input is added to
    the second stage's output through a shortcut.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Number of modes in the first dimension.
        modes2 (int): Number of modes in the second dimension.
        activation (str): Activation function to use.
        norm (bool): Whether to use normalization.
        n_groups (int): Number of groups for group normalization.

    Raises:
        NotImplementedError: If `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int = 16,
        modes2: int = 16,
        activation: str = "gelu",
        norm: bool = False,
        n_groups: int = 1,
    ):
        super().__init__()
        act_module = ACTIVATION_REGISTRY.get(activation)
        if act_module is None:
            raise NotImplementedError(f"Activation {activation} not implemented")
        self.activation: nn.Module = act_module

        self.modes1 = modes1
        self.modes2 = modes2

        # Stage 1: in_channels -> out_channels (spectral path + pointwise path)
        self.fourier1 = SpectralConv2d(in_channels, out_channels, modes1=modes1, modes2=modes2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, padding_mode="zeros")
        # Stage 2: out_channels -> out_channels
        self.fourier2 = SpectralConv2d(out_channels, out_channels, modes1=modes1, modes2=modes2)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0, padding_mode="zeros")

        # The shortcut needs a 1x1 projection only when the channel count changes.
        channels_match = in_channels == out_channels
        self.shortcut = nn.Identity() if channels_match else nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

        # Pre-norms: norm1 sees the block input, norm2 the output of stage 1.
        self.norm1 = nn.GroupNorm(n_groups, in_channels) if norm else nn.Identity()
        self.norm2 = nn.GroupNorm(n_groups, out_channels) if norm else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Apply both stages and add the shortcut of the block input."""
        # using pre-norms
        stage_in = self.activation(self.norm1(x))
        hidden = self.fourier1(stage_in) + self.conv1(stage_in)

        stage_in = self.activation(self.norm2(hidden))
        return self.fourier2(stage_in) + self.conv2(stage_in) + self.shortcut(x)

FourierUnet ¤

Bases: nn.Module

Unet with Fourier layers in early downsampling blocks.

Parameters:

Name Type Description Default
n_input_scalar_components int

Number of scalar components in the model

required
n_input_vector_components int

Number of vector components in the model

required
n_output_scalar_components int

Number of output scalar components in the model

required
n_output_vector_components int

Number of output vector components in the model

required
time_history int

Number of time steps in the input.

required
time_future int

Number of time steps in the output.

required
hidden_channels int

Number of channels in the first layer.

required
activation str

Activation function to use.

required
modes1 int

Number of Fourier modes to use in the first spatial dimension.

12
modes2 int

Number of Fourier modes to use in the second spatial dimension.

12
norm bool

Whether to use normalization.

False
ch_mults list

List of integers to multiply the number of channels by at each resolution.

(1, 2, 2, 4)
is_attn list

List of booleans indicating whether to use attention at each resolution.

(False, False, False, False)
mid_attn bool

Whether to use attention in the middle block.

False
n_blocks int

Number of blocks to use at each resolution.

2
n_fourier_layers int

Number of early downsampling layers to use Fourier layers in.

2
mode_scaling bool

Whether to scale the number of modes with resolution.

True
use1x1 bool

Whether to use 1x1 convolutions in the initial and final layer.

False
Source code in pdearena/modules/twod_unet.py
class FourierUnet(nn.Module):
    """Unet with Fourier layers in early downsampling blocks.

    The first `n_fourier_layers` resolutions of the encoder are built from
    `FourierDownBlock`s, the remaining ones from `DownBlock`s. The decoder is
    built from `UpBlock`s only.

    `forward` expects a 5-D tensor. Dimensions 1 and 2 are merged before the
    first convolution, and the result is reshaped so that dimension 2 holds
    `n_output_scalar_components + 2 * n_output_vector_components` entries.

    Args:
        n_input_scalar_components (int): Number of scalar components in the model
        n_input_vector_components (int): Number of vector components in the model
        n_output_scalar_components (int): Number of output scalar components in the model
        n_output_vector_components (int): Number of output vector components in the model
        time_history (int): Number of time steps in the input.
        time_future (int): Number of time steps in the output.
        hidden_channels (int): Number of channels in the first layer.
        activation (str): Activation function to use.
        modes1 (int): Number of Fourier modes to use in the first spatial dimension.
        modes2 (int): Number of Fourier modes to use in the second spatial dimension.
        norm (bool): Whether to use normalization.
        ch_mults (list): List of integers to multiply the number of channels by at each resolution.
            The multipliers compound: each one is applied to the channel count of the previous resolution.
        is_attn (list): List of booleans indicating whether to use attention at each resolution.
            Indexed with the same index as `ch_mults`.
        mid_attn (bool): Whether to use attention in the middle block.
        n_blocks (int): Number of blocks to use at each resolution.
        n_fourier_layers (int): Number of early downsampling layers to use Fourier layers in.
        mode_scaling (bool): Whether to scale the number of modes with resolution.
            When enabled the modes are halved at each resolution, with a lower bound of 4.
        use1x1 (bool): Whether to use 1x1 convolutions in the initial and final layer.

    Raises:
        NotImplementedError: If `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    def __init__(
        self,
        n_input_scalar_components: int,
        n_input_vector_components: int,
        n_output_scalar_components: int,
        n_output_vector_components: int,
        time_history: int,
        time_future: int,
        hidden_channels: int,
        activation: str,
        modes1: int = 12,
        modes2: int = 12,
        norm: bool = False,
        ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
        is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, False, False),
        mid_attn: bool = False,
        n_blocks: int = 2,
        n_fourier_layers: int = 2,
        mode_scaling: bool = True,
        use1x1: bool = False,
    ) -> None:
        super().__init__()
        self.n_input_scalar_components = n_input_scalar_components
        self.n_input_vector_components = n_input_vector_components
        self.n_output_scalar_components = n_output_scalar_components
        self.n_output_vector_components = n_output_vector_components
        self.time_history = time_history
        self.time_future = time_future
        self.hidden_channels = hidden_channels
        # Look the activation up by name; unknown names are rejected below.
        self.activation: nn.Module = ACTIVATION_REGISTRY.get(activation, None)
        if self.activation is None:
            raise NotImplementedError(f"Activation {activation} not implemented")
        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Every vector component contributes two channels; time steps are stacked on the channel axis.
        insize = time_history * (self.n_input_scalar_components + self.n_input_vector_components * 2)
        n_channels = hidden_channels
        # Project image into feature map
        if use1x1:
            self.image_proj = nn.Conv2d(insize, n_channels, kernel_size=1)
        else:
            self.image_proj = nn.Conv2d(insize, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            # (`in_channels` carries over from the previous resolution, so multipliers compound)
            out_channels = in_channels * ch_mults[i]
            if i < n_fourier_layers:
                # Early resolutions: Fourier blocks
                for _ in range(n_blocks):
                    down.append(
                        FourierDownBlock(
                            in_channels,
                            out_channels,
                            # With mode scaling, halve the modes per resolution but keep at least 4
                            modes1=max(modes1 // 2**i, 4) if mode_scaling else modes1,
                            modes2=max(modes2 // 2**i, 4) if mode_scaling else modes2,
                            has_attn=is_attn[i],
                            activation=activation,
                            norm=norm,
                        )
                    )
                    # Only the first block of a resolution changes the channel count
                    in_channels = out_channels
            else:
                # Add `n_blocks`
                for _ in range(n_blocks):
                    down.append(
                        DownBlock(
                            in_channels,
                            out_channels,
                            has_attn=is_attn[i],
                            activation=activation,
                            norm=norm,
                        )
                    )
                    in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, has_attn=mid_attn, activation=activation, norm=norm)

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(
                    UpBlock(
                        in_channels,
                        out_channels,
                        has_attn=is_attn[i],
                        activation=activation,
                        norm=norm,
                    )
                )
            # Final block to reduce the number of channels
            # (undoes the multiplier that the encoder applied at this resolution)
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, has_attn=is_attn[i], activation=activation, norm=norm))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Optional normalization before the final convolution; the group count is fixed to 8 here
        if norm:
            self.norm = nn.GroupNorm(8, n_channels)
        else:
            self.norm = nn.Identity()
        out_channels = time_future * (self.n_output_scalar_components + self.n_output_vector_components * 2)
        # NOTE(review): the 1x1 branch uses `n_channels` and the 3x3 branch uses `in_channels`;
        # after the decoder loop divides out every multiplier these should be equal -- confirm.
        if use1x1:
            self.final = nn.Conv2d(n_channels, out_channels, kernel_size=1)
        else:
            self.final = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor):
        # Input must be 5-D; dimensions 1 and 2 are merged into one channel axis below
        assert x.dim() == 5
        orig_shape = x.shape
        x = x.reshape(x.size(0), -1, *x.shape[3:])  # collapse T,C
        x = self.image_proj(x)

        # Outputs of the projection and of every encoder module, consumed in reverse by the decoder
        h = [x]
        for m in self.down:
            x = m(x)
            h.append(x)

        x = self.middle(x)

        for m in self.up:
            if isinstance(m, Upsample):
                # Upsampling modules take no skip connection
                x = m(x)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                # Run the block on the channel-wise concatenation
                x = m(x)

        x = self.final(self.activation(self.norm(x)))
        # Split the channel axis back into (inferred size, output components)
        return x.reshape(
            orig_shape[0], -1, (self.n_output_scalar_components + self.n_output_vector_components * 2), *orig_shape[3:]
        )

FourierUpBlock ¤

Bases: nn.Module

Up block that combines FourierResidualBlock and AttentionBlock.

These are used in the second half of U-Net at each resolution.

Note

We currently don't recommend using this block.

Source code in pdearena/modules/twod_unet.py
class FourierUpBlock(nn.Module):
    """Up block that combines [`FourierResidualBlock`][pdearena.modules.twod_unet.FourierResidualBlock] and [`AttentionBlock`][pdearena.modules.twod_unet.AttentionBlock].

    These are used in the second half of U-Net at each resolution.

    Note:
        We currently don't recommend using this block.

    Args:
        in_channels (int): Number of channels coming from the previous decoder stage
        out_channels (int): Number of output channels
        modes1 (int): Number of Fourier modes, forwarded to the Fourier residual block
        modes2 (int): Number of Fourier modes, forwarded to the Fourier residual block
        has_attn (bool): Whether to use attention block
        activation (str): Name of the activation function, forwarded to the Fourier residual block
        norm (bool): Whether to use normalization
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int = 16,
        modes2: int = 16,
        has_attn: bool = False,
        activation: str = "gelu",
        norm: bool = False,
    ):
        super().__init__()
        # The input has `in_channels + out_channels` because we concatenate the output of the same resolution
        # from the first half of the U-Net
        concat_channels = in_channels + out_channels
        residual_kwargs = {"modes1": modes1, "modes2": modes2, "activation": activation, "norm": norm}
        self.res = FourierResidualBlock(concat_channels, out_channels, **residual_kwargs)
        # Without attention the second stage is a pass-through.
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Apply the Fourier residual block, then the (optional) attention block."""
        return self.attn(self.res(x))

MiddleBlock ¤

Bases: nn.Module

Middle block

It combines a ResidualBlock, AttentionBlock, followed by another ResidualBlock.

This block is applied at the lowest resolution of the U-Net.

Parameters:

Name Type Description Default
n_channels int

Number of channels in the input and output.

required
has_attn bool

Whether to use attention block. Defaults to False.

False
activation str

Activation function to use. Defaults to "gelu".

'gelu'
norm bool

Whether to use normalization. Defaults to False.

False
Source code in pdearena/modules/twod_unet.py
class MiddleBlock(nn.Module):
    """Middle block

    It combines a `ResidualBlock`, `AttentionBlock`, followed by another
    `ResidualBlock`.

    This block is applied at the lowest resolution of the U-Net.

    Args:
        n_channels (int): Number of channels in the input and output.
        has_attn (bool, optional): Whether to use attention block. Defaults to False.
        activation (str): Activation function to use. Defaults to "gelu".
        norm (bool, optional): Whether to use normalization. Defaults to False.
    """

    def __init__(self, n_channels: int, has_attn: bool = False, activation: str = "gelu", norm: bool = False):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, activation=activation, norm=norm)
        # The attention stage degrades to a pass-through when it is disabled.
        if has_attn:
            self.attn = AttentionBlock(n_channels)
        else:
            self.attn = nn.Identity()
        self.res2 = ResidualBlock(n_channels, n_channels, activation=activation, norm=norm)

    def forward(self, x: torch.Tensor):
        """Apply residual block, (optional) attention, residual block, in that order."""
        return self.res2(self.attn(self.res1(x)))

ResidualBlock ¤

Bases: nn.Module

Wide Residual Blocks used in modern Unet architectures.

Parameters:

Name Type Description Default
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
activation str

Activation function to use.

'gelu'
norm bool

Whether to use normalization.

False
n_groups int

Number of groups for group normalization.

1
Source code in pdearena/modules/twod_unet.py
class ResidualBlock(nn.Module):
    """Wide Residual Blocks used in modern Unet architectures.

    Two 3x3 convolutions, each preceded by (optional) group normalization and
    the activation, plus a shortcut from the block input to the block output.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        activation (str): Activation function to use.
        norm (bool): Whether to use normalization.
        n_groups (int): Number of groups for group normalization.

    Raises:
        NotImplementedError: If `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: str = "gelu",
        norm: bool = False,
        n_groups: int = 1,
    ):
        super().__init__()
        act_module = ACTIVATION_REGISTRY.get(activation)
        if act_module is None:
            raise NotImplementedError(f"Activation {activation} not implemented")
        self.activation: nn.Module = act_module

        conv_kwargs = {"kernel_size": (3, 3), "padding": (1, 1)}
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)

        # The shortcut needs a 1x1 projection only when the channel count changes.
        channels_match = in_channels == out_channels
        self.shortcut = nn.Identity() if channels_match else nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

        # Pre-norms: norm1 sees the block input, norm2 the output of the first convolution.
        self.norm1 = nn.GroupNorm(n_groups, in_channels) if norm else nn.Identity()
        self.norm2 = nn.GroupNorm(n_groups, out_channels) if norm else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Apply both convolution stages and add the shortcut of the block input."""
        # First convolution layer
        out = self.activation(self.norm1(x))
        out = self.conv1(out)
        # Second convolution layer
        out = self.activation(self.norm2(out))
        out = self.conv2(out)
        # Add the shortcut connection and return
        return out + self.shortcut(x)

Unet ¤

Bases: nn.Module

Modern U-Net architecture

This is a modern U-Net architecture with wide-residual blocks and spatial attention blocks

Parameters:

Name Type Description Default
n_input_scalar_components int

Number of scalar components in the model

required
n_input_vector_components int

Number of vector components in the model

required
n_output_scalar_components int

Number of output scalar components in the model

required
n_output_vector_components int

Number of output vector components in the model

required
time_history int

Number of time steps in the input

required
time_future int

Number of time steps in the output

required
hidden_channels int

Number of channels in the hidden layers

required
activation str

Activation function to use

required
norm bool

Whether to use normalization

False
ch_mults list

List of channel multipliers for each resolution

(1, 2, 2, 4)
is_attn list

List of booleans indicating whether to use attention blocks

(False, False, False, False)
mid_attn bool

Whether to use attention block in the middle block

False
n_blocks int

Number of residual blocks in each resolution

2
use1x1 bool

Whether to use 1x1 convolutions in the initial and final layers

False
Source code in pdearena/modules/twod_unet.py
class Unet(nn.Module):
    """Modern U-Net architecture

    A U-Net assembled from wide-residual blocks with optional spatial attention blocks.
    The network consists of an input projection, a down-sampling path, a middle block,
    an up-sampling path fed by skip connections, and a final output projection.

    Args:
        n_input_scalar_components (int): Number of scalar components in the input
        n_input_vector_components (int): Number of vector components in the input (each counts as two channels)
        n_output_scalar_components (int): Number of scalar components in the output
        n_output_vector_components (int): Number of vector components in the output (each counts as two channels)
        time_history (int): Number of time steps in the input
        time_future (int): Number of time steps in the output
        hidden_channels (int): Number of channels in the hidden layers
        activation (str): Name of the activation function, looked up in `ACTIVATION_REGISTRY`
        norm (bool): Whether to use normalization
        ch_mults (list): Channel multiplier for each resolution
        is_attn (list): For each resolution, whether its blocks use attention
        mid_attn (bool): Whether the middle block uses attention
        n_blocks (int): Number of residual blocks in each resolution
        use1x1 (bool): Whether to use 1x1 convolutions in the initial and final layers

    Raises:
        NotImplementedError: If `activation` is not a key of `ACTIVATION_REGISTRY`.
    """

    def __init__(
        self,
        n_input_scalar_components: int,
        n_input_vector_components: int,
        n_output_scalar_components: int,
        n_output_vector_components: int,
        time_history: int,
        time_future: int,
        hidden_channels: int,
        activation: str,
        norm: bool = False,
        ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
        is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, False, False),
        mid_attn: bool = False,
        n_blocks: int = 2,
        use1x1: bool = False,
    ) -> None:
        super().__init__()
        self.n_input_scalar_components = n_input_scalar_components
        self.n_input_vector_components = n_input_vector_components
        self.n_output_scalar_components = n_output_scalar_components
        self.n_output_vector_components = n_output_vector_components
        self.time_history = time_history
        self.time_future = time_future
        self.hidden_channels = hidden_channels

        # Resolve the activation by name; unknown names are rejected up front.
        act_module = ACTIVATION_REGISTRY.get(activation, None)
        if act_module is None:
            raise NotImplementedError(f"Activation {activation} not implemented")
        self.activation: nn.Module = act_module

        n_levels = len(ch_mults)
        base_width = hidden_channels
        # Each vector component occupies two channels; time steps are stacked into channels.
        components_in = self.n_input_scalar_components + self.n_input_vector_components * 2
        stem_in = time_history * components_in

        # The input and output projections share the same kernel configuration.
        if use1x1:
            proj_kwargs = {"kernel_size": 1}
        else:
            proj_kwargs = {"kernel_size": (3, 3), "padding": (1, 1)}

        # Lift the input into the hidden feature space
        self.image_proj = nn.Conv2d(stem_in, base_width, **proj_kwargs)

        # ---- Encoder: resolution decreases level by level
        encoder = []
        width_out = width_in = base_width
        for level, mult in enumerate(ch_mults):
            width_out = width_in * mult
            for _ in range(n_blocks):
                encoder.append(
                    DownBlock(
                        width_in,
                        width_out,
                        has_attn=is_attn[level],
                        activation=activation,
                        norm=norm,
                    )
                )
                # Only the first block of a level changes the channel count
                width_in = width_out
            # Every level except the last is followed by a down-sampling layer
            if level < n_levels - 1:
                encoder.append(Downsample(width_in))
        self.down = nn.ModuleList(encoder)

        # ---- Bottleneck
        self.middle = MiddleBlock(width_out, has_attn=mid_attn, activation=activation, norm=norm)

        # ---- Decoder: resolution increases level by level
        decoder = []
        width_in = width_out
        for level in range(n_levels - 1, -1, -1):
            attn_here = is_attn[level]
            # `n_blocks` blocks that keep the channel count of this level
            width_out = width_in
            for _ in range(n_blocks):
                decoder.append(
                    UpBlock(
                        width_in,
                        width_out,
                        has_attn=attn_here,
                        activation=activation,
                        norm=norm,
                    )
                )
            # One extra block that shrinks the channel count by this level's multiplier
            width_out = width_in // ch_mults[level]
            decoder.append(UpBlock(width_in, width_out, has_attn=attn_here, activation=activation, norm=norm))
            width_in = width_out
            # Every level except the outermost is followed by an up-sampling layer
            if level > 0:
                decoder.append(Upsample(width_in))
        self.up = nn.ModuleList(decoder)

        self.norm = nn.GroupNorm(8, base_width) if norm else nn.Identity()

        # Project back to the requested number of output channels
        components_out = self.n_output_scalar_components + self.n_output_vector_components * 2
        head_out = time_future * components_out
        self.final = nn.Conv2d(width_in, head_out, **proj_kwargs)

    def forward(self, x: torch.Tensor):
        """Run the U-Net on a 5-dimensional input.

        Dimensions 1 and 2 of `x` are merged into a single channel dimension before the
        input projection. The result is reshaped so that dimension 2 has
        `n_output_scalar_components + 2 * n_output_vector_components` entries, dimension 1
        is inferred, and the trailing dimensions match those of the input.

        Args:
            x (torch.Tensor): Input tensor with exactly five dimensions.

        Returns:
            torch.Tensor: Output tensor with five dimensions.
        """
        assert x.dim() == 5
        input_shape = x.shape
        batch = input_shape[0]
        spatial = input_shape[3:]

        # Merge dims 1 and 2 into channels, then lift to the hidden width
        x = self.image_proj(x.reshape(batch, -1, *spatial))

        # Encoder: remember every intermediate output for the skip connections
        skips = [x]
        for layer in self.down:
            x = layer(x)
            skips.append(x)

        x = self.middle(x)

        # Decoder: every layer except `Upsample` first consumes one stored skip tensor
        for layer in self.up:
            if not isinstance(layer, Upsample):
                x = torch.cat((x, skips.pop()), dim=1)
            x = layer(x)

        x = self.final(self.activation(self.norm(x)))
        components_out = self.n_output_scalar_components + self.n_output_vector_components * 2
        return x.reshape(batch, -1, components_out, *spatial)

UpBlock ¤

Bases: nn.Module

Up block that combines ResidualBlock and AttentionBlock.

These are used in the second half of U-Net at each resolution.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
has_attn bool

Whether to use attention block

False
activation str

Activation function

'gelu'
norm bool

Whether to use normalization

False
Source code in pdearena/modules/twod_unet.py
class UpBlock(nn.Module):
    """Up block that combines [`ResidualBlock`][pdearena.modules.twod_unet.ResidualBlock] and [`AttentionBlock`][pdearena.modules.twod_unet.AttentionBlock].

    These are used in the second half of U-Net at each resolution. The residual block is
    built for `in_channels + out_channels` input channels, so the tensor passed to
    `forward` is expected to already contain the concatenated skip connection.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        has_attn (bool): Whether to use attention block
        activation (str): Activation function
        norm (bool): Whether to use normalization
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        has_attn: bool = False,
        activation: str = "gelu",
        norm: bool = False,
    ):
        super().__init__()
        # The skip tensor from the first half of the U-Net is concatenated onto the input,
        # hence the widened input channel count of the residual block.
        merged_channels = in_channels + out_channels
        self.res = ResidualBlock(merged_channels, out_channels, activation=activation, norm=norm)
        # Attention is optional; the identity keeps `forward` branch-free.
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Apply the residual block followed by the attention block (or identity)."""
        return self.attn(self.res(x))

Upsample ¤

Bases: nn.Module

Scale up the feature map by \(2 \times\)

Parameters:

Name Type Description Default
n_channels int

Number of channels in the input and output.

required
Source code in pdearena/modules/twod_unet.py
class Upsample(nn.Module):
    r"""Scale up the feature map by $2 \times$

    Implemented as a single transposed convolution with a 4x4 kernel, stride 2 and
    padding 1, which keeps the channel count unchanged.

    Args:
        n_channels (int): Number of channels in the input and output.
    """

    def __init__(self, n_channels: int):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=(4, 4),
            stride=(2, 2),
            padding=(1, 1),
        )

    def forward(self, x: torch.Tensor):
        """Return the transposed-convolution output for `x`."""
        upsampled = self.conv(x)
        return upsampled