Skip to content

Clifford models

We provide exemplary 2D and 3D Clifford models as used in the paper.

All of these modules can be used with different Clifford algebras, which are selected via the signature argument `g`.

2D models

The following code snippet initializes a 2D Clifford ResNet.

import torch.nn.functional as F

from cliffordlayers.models.models_2d import (
    CliffordNet2d,
    CliffordBasicBlock2d,
)

# `in_channels` and `out_channels` are not defined in this snippet: set them to
# the number of input and output channels of your data before running it.
model = CliffordNet2d(
    g=[-1, -1],
    block=CliffordBasicBlock2d,
    num_blocks=[2, 2, 2, 2],
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=32,
    activation=F.gelu,
    norm=True,
    rotation=False,
)

The following code snippet initializes a 2D rotational Clifford ResNet.

import torch.nn.functional as F

from cliffordlayers.models.models_2d import (
    CliffordNet2d,
    CliffordBasicBlock2d,
)

# Same ResNet as above, but `rotation=True` makes CliffordNet2d build its
# encoder, decoder and basic blocks with rotational Clifford convolutions.
# `in_channels` / `out_channels` must be defined beforehand.
model = CliffordNet2d(
    g=[-1, -1],
    block=CliffordBasicBlock2d,
    num_blocks=[2, 2, 2, 2],
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=32,
    activation=F.gelu,
    norm=True,
    rotation=True,
)

The following code snippet initializes a 2D Clifford FNO.

import torch.nn.functional as F

from cliffordlayers.models.utils import partialclass
from cliffordlayers.models.models_2d import (
    CliffordNet2d,
    CliffordFourierBasicBlock2d,
)

# `partialclass` wraps the block class so that modes1=32 and modes2=32 are
# supplied to every block CliffordNet2d creates (presumably a
# functools.partial-like helper for classes). `in_channels` / `out_channels`
# must be defined beforehand.
model = CliffordNet2d(
    g=[-1, -1],
    block=partialclass(
        "CliffordFourierBasicBlock2d",
        CliffordFourierBasicBlock2d,
        modes1=32,
        modes2=32,
    ),
    num_blocks=[1, 1, 1, 1],
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=32,
    activation=F.gelu,
    norm=False,
    rotation=False,
)

CliffordBasicBlock2d

Bases: nn.Module

2D building block for Clifford ResNet architectures.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Signature of Clifford algebra.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
activation Callable

Activation function. Defaults to F.gelu.

F.gelu
kernel_size int

Kernel size of Clifford convolution. Defaults to 3.

3
stride int

Stride of Clifford convolution. Defaults to 1.

1
padding int

Padding of Clifford convolution. Defaults to 1.

1
rotation bool

Whether to use rotational Clifford convolution. Defaults to False.

False
norm bool

Whether to use Clifford (group) normalization. Defaults to False.

False
num_groups int

Number of groups when using Clifford (group) normalization. Defaults to 1.

1
Source code in cliffordlayers/models/models_2d.py
class CliffordBasicBlock2d(nn.Module):
    """Pre-activation residual block for 2D Clifford ResNet architectures.

    The input runs through ``norm1 -> activation -> conv1 -> norm2 ->
    activation -> conv2`` and the result is added back onto the input. Because
    of that residual sum, the convolution output must have the same shape as the
    input (e.g. ``in_channels == out_channels`` and a shape-preserving
    kernel/stride/padding combination).

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        activation (Callable, optional): Activation function. Defaults to F.gelu.
        kernel_size (int, optional): Kernel size of both Clifford convolutions. Defaults to 3.
        stride (int, optional): Stride of both Clifford convolutions. Defaults to 1.
        padding (int, optional): Padding of both Clifford convolutions. Defaults to 1.
        rotation (bool, optional): Whether to use rotational Clifford convolutions. Defaults to False.
        norm (bool, optional): Whether to use Clifford (group) normalization. Defaults to False.
        num_groups (int, optional): Number of groups when using Clifford (group) normalization. Defaults to 1.
    """

    expansion: int = 1

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        activation: Callable = F.gelu,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        rotation: bool = False,
        norm: bool = False,
        num_groups: int = 1,
    ) -> None:
        super().__init__()
        # Both convolutions share every hyper-parameter except their channel counts.
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
            rotation=rotation,
        )
        # Creation order (conv1, conv2, norm1, norm2) is deliberate: it fixes the
        # submodule registration order and the order of random initialisation.
        self.conv1 = CliffordConv2d(g, in_channels, out_channels, **conv_options)
        self.conv2 = CliffordConv2d(g, out_channels, out_channels, **conv_options)
        if norm:
            self.norm1 = CliffordGroupNorm2d(g, num_groups, in_channels)
            self.norm2 = CliffordGroupNorm2d(g, num_groups, out_channels)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()
        self.activation = activation

    def __repr__(self):
        # Short fixed name instead of nn.Module's nested submodule listing.
        return "CliffordBasicBlock2d"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two pre-activated convolutions and add the skip connection."""
        shortcut = x
        h = self.norm1(x)
        h = self.activation(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = self.activation(h)
        h = self.conv2(h)
        return h + shortcut

CliffordFourierBasicBlock2d

Bases: nn.Module

2D building block for Clifford FNO architectures.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Signature of Clifford algebra.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
activation Callable

Activation function. Defaults to F.gelu.

F.gelu
kernel_size int

Kernel size of Clifford convolution. Defaults to 1.

1
stride int

Stride of Clifford convolution. Defaults to 1.

1
padding int

Padding of Clifford convolution. Defaults to 0.

0
rotation bool

Whether to use rotational Clifford convolution. Defaults to False.

False
norm bool

Whether to use Clifford (group) normalization. Defaults to False.

False
num_groups int

Number of groups when using Clifford (group) normalization. Defaults to 1.

1
modes1 int

Number of Fourier modes in the first dimension. Defaults to 16.

16
modes2 int

Number of Fourier modes in the second dimension. Defaults to 16.

16
Source code in cliffordlayers/models/models_2d.py
class CliffordFourierBasicBlock2d(nn.Module):
    """2D building block for Clifford FNO architectures.

    Computes ``activation(norm(fourier(x) + conv(x)))``, where ``fourier`` is a
    Clifford spectral convolution and ``conv`` is a Clifford convolution. With
    the default ``kernel_size=1`` and ``padding=0``, the convolution is pointwise.

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        activation (Callable, optional): Activation function. Defaults to F.gelu.
        kernel_size (int, optional): Kernel size of Clifford convolution. Defaults to 1.
        stride (int, optional): Stride of Clifford convolution. Defaults to 1.
        padding (int, optional): Padding of Clifford convolution. Defaults to 0.
        rotation (bool, optional): Whether to use rotational Clifford convolution. Defaults to False.
        norm (bool, optional): Whether to use Clifford (group) normalization. Defaults to False.
        num_groups (int, optional): Number of groups when using Clifford (group) normalization. Defaults to 1.
        modes1 (int, optional): Number of Fourier modes in the first dimension. Defaults to 16.
        modes2 (int, optional): Number of Fourier modes in the second dimension. Defaults to 16.
    """

    expansion: int = 1

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        activation: Callable = F.gelu,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        rotation: bool = False,
        norm: bool = False,
        num_groups: int = 1,
        modes1: int = 16,
        modes2: int = 16,
    ) -> None:
        super().__init__()
        # Spectral (global) branch.
        self.fourier = CliffordSpectralConv2d(
            g,
            in_channels,
            out_channels,
            modes1=modes1,
            modes2=modes2,
        )
        # Local (by default pointwise) branch.
        self.conv = CliffordConv2d(
            g,
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
            rotation=rotation,
        )
        # The normalization acts on the summed branches, which have `out_channels`.
        self.norm = CliffordGroupNorm2d(g, num_groups, out_channels) if norm else nn.Identity()
        self.activation = activation

    def __repr__(self):
        # Short fixed name instead of nn.Module's nested submodule listing.
        return "CliffordFourierBasicBlock2d"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum the spectral and convolutional branches, then normalize and activate."""
        x1 = self.fourier(x)
        x2 = self.conv(x)
        return self.activation(self.norm(x1 + x2))

CliffordNet2d

Bases: nn.Module

2D building block for Clifford architectures with ResNet backbone network.

The backbone network follows these three steps:
  1. Clifford encoding.
  2. Basic blocks as provided.
  3. Decoding.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Signature of Clifford algebra.

required
block nn.Module

Choice of basic blocks.

required
num_blocks list

List of basic blocks in each residual block.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
activation Callable

Activation function. Defaults to F.gelu.

required
rotation bool

Whether to use rotational Clifford convolution. Defaults to False.

required
norm bool

Whether to use Clifford (group) normalization. Defaults to False.

False
num_groups int

Number of groups when using Clifford (group) normalization. Defaults to 1.

1
Source code in cliffordlayers/models/models_2d.py
class CliffordNet2d(nn.Module):
    """2D Clifford architecture with a ResNet-style backbone network.

    The backbone network follows these three steps:
        1. Clifford encoding.
        2. Basic blocks as provided.
        3. Decoding.

    Between encoding and decoding, the two spatial dimensions are zero-padded at
    their end by ``padding`` cells and cropped again afterwards. This embeds
    non-periodic boundaries.

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        block (nn.Module): Basic block class (an ``nn.Module`` subclass such as
            ``CliffordBasicBlock2d``). It is instantiated as
            ``block(g, hidden_channels, hidden_channels, activation=..., rotation=..., norm=..., num_groups=...)``.
        num_blocks (list): Number of basic blocks in each residual stage.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        hidden_channels (int): Number of channels produced by the encoder, used by
            every basic block, and consumed by the decoder.
        activation (Callable, optional): Activation function. Defaults to F.gelu.
        rotation (bool, optional): Whether to use rotational Clifford convolution. Defaults to False.
        norm (bool, optional): Whether to use Clifford (group) normalization. Defaults to False.
        num_groups (int, optional): Number of groups when using Clifford (group) normalization. Defaults to 1.
    """

    # Number of zero cells appended to each spatial dimension in `forward`.
    # For periodic boundary conditions, set padding = 0.
    padding = 9

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        block: nn.Module,
        num_blocks: list,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        activation: Callable = F.gelu,
        rotation: bool = False,
        norm: bool = False,
        num_groups: int = 1,
    ) -> None:
        super().__init__()

        self.activation = activation
        # Encoding and decoding layers (pointwise: kernel_size=1, no padding).
        self.encoder = CliffordConv2dEncoder(
            g,
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
            padding=0,
            rotation=rotation,
        )
        self.decoder = CliffordConv2dDecoder(
            g,
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
            rotation=rotation,
        )

        # Residual stages: one nn.Sequential per entry of `num_blocks`.
        self.layers = nn.ModuleList(
            [
                self._make_basic_block(
                    g,
                    block,
                    hidden_channels,
                    stage_depth,
                    activation=activation,
                    rotation=rotation,
                    norm=norm,
                    num_groups=num_groups,
                )
                for stage_depth in num_blocks
            ]
        )

    def _make_basic_block(
        self,
        g,
        block: nn.Module,
        hidden_channels: int,
        num_blocks: int,
        activation: Callable,
        rotation: bool,
        norm: bool,
        num_groups: int,
    ) -> nn.Sequential:
        """Stack `num_blocks` instances of `block` mapping hidden -> hidden channels."""
        return nn.Sequential(
            *[
                block(
                    g,
                    hidden_channels,
                    hidden_channels,
                    activation=activation,
                    rotation=rotation,
                    norm=norm,
                    num_groups=num_groups,
                )
                for _ in range(num_blocks)
            ]
        )

    def _pad_spatial(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the end of the two spatial dims (just before the trailing axis)."""
        # F.pad takes (left, right) pairs starting from the LAST dim: the trailing
        # axis gets (0, 0), then the two spatial dims get (0, padding) each.
        return F.pad(x, [0, 0, 0, self.padding, 0, self.padding])

    def _crop_spatial(self, x: torch.Tensor) -> torch.Tensor:
        """Undo `_pad_spatial` by dropping the last `padding` cells of both spatial dims."""
        return x[..., : -self.padding, : -self.padding, :]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode, run all residual stages on the padded grid, crop, and decode.

        Args:
            x (torch.Tensor): 5D input, laid out as (batch, channel, spatial,
                spatial, trailing). The trailing axis is presumably the
                multivector-component axis; confirm against the encoder.
        """
        assert x.dim() == 5, f"CliffordNet2d expects a 5D input, got {x.dim()}D"

        # Encoding layer (the activation is applied to the raw input first).
        x = self.encoder(self.activation(x))

        # Embed for non-periodic boundaries (guard: a zero-length crop would empty x).
        if self.padding > 0:
            x = self._pad_spatial(x)

        # Apply residual layers.
        for layer in self.layers:
            x = layer(x)

        # Remove the boundary embedding again.
        if self.padding > 0:
            x = self._crop_spatial(x)

        # Output layer.
        return self.decoder(x)

3D models

The following code snippet initializes a 3D Clifford FNO.

import torch.nn.functional as F

from cliffordlayers.models.models_3d import (
    CliffordNet3d,
    CliffordFourierBasicBlock3d,
)
# 3D Clifford FNO: 4 input channels, 1 output channel, 16 hidden channels,
# and four residual stages with one Fourier block each.
model = CliffordNet3d(
    g=[1, 1, 1],
    block=CliffordFourierBasicBlock3d,
    num_blocks=[1, 1, 1, 1],
    in_channels=4,
    out_channels=1,
    hidden_channels=16,
    activation=F.gelu,
    norm=False,
)

CliffordFourierBasicBlock3d

Bases: nn.Module

3D building block for Clifford FNO architectures.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Signature of Clifford algebra.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
activation Callable

Activation function. Defaults to F.gelu.

F.gelu
kernel_size int

Kernel size of Clifford convolution. Defaults to 1.

1
stride int

Stride of Clifford convolution. Defaults to 1.

1
padding int

Padding of Clifford convolution. Defaults to 0.

0
norm bool

Whether to use Clifford (group) normalization. Defaults to False.

False
num_groups int

Number of groups when using Clifford (group) normalization. Defaults to 1.

1
modes1 int

Number of Fourier modes in the first dimension. Defaults to 8.

8
modes2 int

Number of Fourier modes in the second dimension. Defaults to 8.

8
modes3 int

Number of Fourier modes in the third dimension. Defaults to 8.

8
Source code in cliffordlayers/models/models_3d.py
class CliffordFourierBasicBlock3d(nn.Module):
    """3D building block for Clifford FNO architectures.

    Computes ``activation(norm(fourier(x) + conv(x)))``, where ``fourier`` is a
    Clifford spectral convolution and ``conv`` is a Clifford convolution. With
    the default ``kernel_size=1`` and ``padding=0``, the convolution is pointwise.

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        activation (Callable, optional): Activation function. Defaults to F.gelu.
        kernel_size (int, optional): Kernel size of Clifford convolution. Defaults to 1.
        stride (int, optional): Stride of Clifford convolution. Defaults to 1.
        padding (int, optional): Padding of Clifford convolution. Defaults to 0.
        norm (bool, optional): Whether to use Clifford (group) normalization. Defaults to False.
        num_groups (int, optional): Number of groups when using Clifford (group) normalization. Defaults to 1.
        modes1 (int, optional): Number of Fourier modes in the first dimension. Defaults to 8.
        modes2 (int, optional): Number of Fourier modes in the second dimension. Defaults to 8.
        modes3 (int, optional): Number of Fourier modes in the third dimension. Defaults to 8.
    """

    expansion: int = 1

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        activation: Callable = F.gelu,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        norm: bool = False,
        num_groups: int = 1,
        modes1: int = 8,
        modes2: int = 8,
        modes3: int = 8,
    ) -> None:
        super().__init__()
        # Spectral (global) branch.
        self.fourier = CliffordSpectralConv3d(
            g,
            in_channels,
            out_channels,
            modes1=modes1,
            modes2=modes2,
            modes3=modes3,
        )
        # Local (by default pointwise) branch.
        self.conv = CliffordConv3d(
            g,
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
        )
        # The normalization is applied to fourier(x) + conv(x), which has
        # `out_channels` channels. Sizing it with `in_channels` would break
        # whenever in_channels != out_channels (matches the 2D block).
        self.norm = CliffordGroupNorm3d(g, num_groups, out_channels) if norm else nn.Identity()
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum the spectral and convolutional branches, then normalize and activate."""
        x1 = self.fourier(x)
        x2 = self.conv(x)
        return self.activation(self.norm(x1 + x2))

CliffordNet3d

Bases: nn.Module

3D building block for Clifford architectures with ResNet backbone network.

The backbone network follows these three steps:
  1. Clifford encoding.
  2. Basic blocks as provided.
  3. Decoding.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Signature of Clifford algebra.

required
block nn.Module

Choice of basic blocks.

required
num_blocks list

List of basic blocks in each residual block.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
activation Callable

Activation function. Defaults to F.gelu.

required
norm bool

Whether to use Clifford (group) normalization. Defaults to False.

False
num_groups int

Number of groups when using Clifford (group) normalization. Defaults to 1.

1
Source code in cliffordlayers/models/models_3d.py
class CliffordNet3d(nn.Module):
    """3D Clifford architecture with a ResNet-style backbone network.

    The backbone network follows these three steps:
        1. Clifford encoding.
        2. Basic blocks as provided.
        3. Decoding.

    Between encoding and decoding, the three spatial dimensions are zero-padded
    at their end by ``padding`` cells and cropped again afterwards. This embeds
    non-periodic boundaries.

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        block (nn.Module): Basic block class (an ``nn.Module`` subclass such as
            ``CliffordFourierBasicBlock3d``). It is instantiated as
            ``block(g, hidden_channels, hidden_channels, activation=..., norm=..., num_groups=...)``.
        num_blocks (list): Number of basic blocks in each residual stage.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        hidden_channels (int): Number of channels produced by the encoder, used by
            every basic block, and consumed by the decoder.
        activation (Callable, optional): Activation function. Defaults to F.gelu.
        norm (bool, optional): Whether to use Clifford (group) normalization. Defaults to False.
        num_groups (int, optional): Number of groups when using Clifford (group) normalization. Defaults to 1.
    """

    # Number of zero cells appended to each spatial dimension in `forward`.
    # For periodic boundary conditions, set padding = 0.
    padding = 2

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        block: nn.Module,
        num_blocks: list,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        activation: Callable = F.gelu,
        norm: bool = False,
        num_groups: int = 1,
    ) -> None:
        super().__init__()

        self.activation = activation
        # Encoding and decoding layers (pointwise: kernel_size=1, no padding).
        self.encoder = CliffordConv3dEncoder(
            g,
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
            padding=0,
        )
        self.decoder = CliffordConv3dDecoder(
            g,
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
        )

        # Residual stages: one nn.Sequential per entry of `num_blocks`.
        self.layers = nn.ModuleList(
            [
                self._make_basic_block(
                    g,
                    block,
                    hidden_channels,
                    stage_depth,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                )
                for stage_depth in num_blocks
            ]
        )

    def _make_basic_block(
        self,
        g,
        block: nn.Module,
        hidden_channels: int,
        num_blocks: int,
        activation: Callable,
        norm: bool,
        num_groups: int,
    ) -> nn.Sequential:
        """Stack `num_blocks` instances of `block` mapping hidden -> hidden channels."""
        return nn.Sequential(
            *[
                block(
                    g,
                    hidden_channels,
                    hidden_channels,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                )
                for _ in range(num_blocks)
            ]
        )

    def _pad_spatial(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the end of the three spatial dims (just before the trailing axis)."""
        # F.pad takes (left, right) pairs starting from the LAST dim: the trailing
        # axis gets (0, 0), then the three spatial dims get (0, padding) each.
        return F.pad(x, [0, 0, 0, self.padding, 0, self.padding, 0, self.padding])

    def _crop_spatial(self, x: torch.Tensor) -> torch.Tensor:
        """Undo `_pad_spatial` by dropping the last `padding` cells of all spatial dims."""
        return x[..., : -self.padding, : -self.padding, : -self.padding, :]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode, run all residual stages on the padded grid, crop, and decode.

        Args:
            x (torch.Tensor): 6D input, laid out as (batch, channel, spatial,
                spatial, spatial, trailing). The trailing axis is presumably the
                multivector-component axis; confirm against the encoder.
        """
        assert x.dim() == 6, f"CliffordNet3d expects a 6D input, got {x.dim()}D"

        # Encoding layer (the activation is applied to the raw input first).
        x = self.encoder(self.activation(x))

        # Embed for non-periodic boundaries (guard: a zero-length crop would empty x).
        if self.padding > 0:
            x = self._pad_spatial(x)

        # Apply residual layers.
        for layer in self.layers:
            x = layer(x)

        # Remove the boundary embedding again.
        if self.padding > 0:
            x = self._crop_spatial(x)

        # Output layer.
        return self.decoder(x)