Skip to content

Modules

We provide linear Clifford layers; 1D, 2D, 3D Clifford convolution layers, and 2D, 3D Clifford Fourier transform layers. Additionally, Clifford normalization schemes are provided.

All these modules are available for different algebras.

CliffordLinear

Bases: nn.Module

Clifford linear layer.

Parameters:

Name Type Description Default
g Union[List, Tuple]

Clifford signature tensor.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
bias bool

If True, adds a learnable bias to the output. Defaults to True.

True
Source code in cliffordlayers/nn/modules/cliffordlinear.py
class CliffordLinear(nn.Module):
    """Clifford linear layer.

    The learnable weight of shape ``(n_blades, out_channels, in_channels)`` is expanded
    into a full Clifford kernel by ``get_{1,2,3}d_clifford_kernel`` (chosen from the
    algebra dimension) and applied with ``F.linear``.

    Args:
        g (Union[List, Tuple]): Clifford signature tensor.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        bias (bool, optional): If True, adds a learnable bias to the output. Defaults to True.

    Shapes:
        Input is expected as ``(B, in_channels, n_blades)``; the output has layout
        ``(B, C_out, n_blades)`` (``C_out`` is ``out_channels`` for the kernels used here).

    Raises:
        NotImplementedError: If the signature does not describe a 1d, 2d or 3d algebra.
    """

    def __init__(
        self,
        g,
        in_channels: int,
        out_channels: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        sig = CliffordSignature(g)

        # Registered as a buffer so the signature moves with the module (.to / .cuda).
        self.register_buffer("g", sig.g)
        self.dim = sig.dim
        self.n_blades = sig.n_blades

        if self.dim == 1:
            self._get_kernel = get_1d_clifford_kernel
        elif self.dim == 2:
            self._get_kernel = get_2d_clifford_kernel
        elif self.dim == 3:
            self._get_kernel = get_3d_clifford_kernel
        else:
            raise NotImplementedError(
                f"Clifford linear layers are not implemented for {self.dim} dimensions. Wrong Clifford signature."
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = nn.Parameter(torch.empty(self.n_blades, out_channels, in_channels))

        if bias:
            self.bias = nn.Parameter(torch.empty(self.n_blades, out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        # Initialization of the Clifford linear weight and bias tensors.
        # The number of blades is taken into account when calculating the bounds of Kaiming uniform.
        nn.init.kaiming_uniform_(
            self.weight.view(self.out_channels, self.in_channels * self.n_blades),
            a=math.sqrt(5),
        )
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                self.weight.view(self.out_channels, self.in_channels * self.n_blades)
            )
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Clifford linear map.

        Args:
            x (torch.Tensor): Input of shape ``(B, C, I)`` with ``I == n_blades``.

        Returns:
            torch.Tensor: Output with the blade axis last, ``(B, C_out, I)``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``n_blades``.
        """
        # Reshape x such that the Clifford kernel can be applied.
        B, _, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")
        B_dim, C_dim, I_dim = range(len(x.shape))
        # (B, C, I) -> (B, I, C) -> (B, I * C): blade-major flattening.
        x = x.permute(B_dim, I_dim, C_dim)
        x = x.reshape(B, -1)
        # Get Clifford kernel, apply it.
        _, weight = self._get_kernel(self.weight, self.g)
        # bias=False registers the bias as None, so it must not be reshaped unconditionally.
        bias = None if self.bias is None else self.bias.view(-1)
        output = F.linear(x, weight, bias)
        # Reshape back: (B, I * C_out) -> (B, I, C_out) -> (B, C_out, I).
        output = output.view(B, I, -1)
        B_dim, I_dim, C_dim = range(len(output.shape))
        output = output.permute(B_dim, C_dim, I_dim)
        return output

CliffordConv1d

Bases: _CliffordConvNd

1d Clifford convolution.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Clifford signature.

required
in_channels int

Number of channels in the input tensor.

required
out_channels int

Number of channels produced by the convolution.

required
kernel_size int

Size of the convolving kernel.

3
stride int

Stride of the convolution.

1
padding int

padding added to both sides of the input.

0
dilation int

Spacing between kernel elements.

1
groups int

Number of blocked connections from input channels to output channels.

1
bias bool

If True, adds a learnable bias to the output.

True
padding_mode str

Padding to use.

'zeros'
Source code in cliffordlayers/nn/modules/cliffordconv.py
class CliffordConv1d(_CliffordConvNd):
    """1d Clifford convolution.

    Args:
        g (Union[tuple, list, torch.Tensor]): Clifford signature.
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution.
        padding (Union[int, str]): Padding added to both sides of the input. String values
            (e.g. ``"same"``) are forwarded unchanged, as in CliffordConv2d/CliffordConv3d.
        dilation (int): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to output channels.
        bias (bool): If True, adds a learnable bias to the output.
        padding_mode (str): Padding to use.

    Raises:
        NotImplementedError: If ``g`` does not describe a 1d Clifford algebra.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ) -> None:

        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        # `_single("same")` would explode the string into a tuple of characters,
        # so string paddings are passed through untouched (as the 2d/3d layers do).
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)

        super().__init__(
            g,
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            groups,
            bias,
            padding_mode,
        )
        if not self.dim == 1:
            raise NotImplementedError(f"Wrong Clifford signature for CliffordConv1d.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` whose last dimension holds the ``n_blades`` Clifford components.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``n_blades``.
        """
        *_, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")
        return super().forward(x, F.conv1d)

CliffordConv2d

Bases: _CliffordConvNd

2d Clifford convolution.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Clifford signature.

required
in_channels int

Number of channels in the input tensor.

required
out_channels int

Number of channels produced by the convolution.

required
kernel_size Union[int, Tuple[int, int]]

Size of the convolving kernel.

3
stride Union[int, Tuple[int, int]]

Stride of the convolution.

1
padding Union[int, Tuple[int, int]]

padding added to both sides of the input.

0
dilation Union[int, Tuple[int, int]]

Spacing between kernel elements.

1
groups int

Number of blocked connections from input channels to output channels.

1
bias bool

If True, adds a learnable bias to the output.

True
padding_mode str

Padding to use.

'zeros'
rotation bool

If True, enables the rotation kernel for Clifford convolution.

False
Source code in cliffordlayers/nn/modules/cliffordconv.py
class CliffordConv2d(_CliffordConvNd):
    """2d Clifford convolution.

    Integer (or tuple) geometry arguments are normalized to 2-tuples before they
    reach `_CliffordConvNd`; a string padding is handed over unchanged.

    Args:
        g (Union[tuple, list, torch.Tensor]): Clifford signature.
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (Union[int, Tuple[int, int]]): Size of the convolving kernel.
        stride (Union[int, Tuple[int, int]]): Stride of the convolution.
        padding (Union[int, Tuple[int, int]]): padding added to both sides of the input.
        dilation (Union[int, Tuple[int, int]]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to output channels.
        bias (bool): If True, adds a learnable bias to the output.
        padding_mode (str): Padding to use.
        rotation (bool): If True, enables the rotation kernel for Clifford convolution.

    Raises:
        NotImplementedError: If ``g`` does not describe a 2d Clifford algebra.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        rotation: bool = False,
    ):
        if isinstance(padding, str):
            padding_pair = padding
        else:
            padding_pair = _pair(padding)

        super().__init__(
            g,
            in_channels,
            out_channels,
            _pair(kernel_size),
            _pair(stride),
            padding_pair,
            _pair(dilation),
            groups,
            bias,
            padding_mode,
            rotation,
        )
        # Only a 2d signature is meaningful for this layer.
        if self.dim != 2:
            raise NotImplementedError(f"Wrong Clifford signature for CliffordConv2d.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Check the blade axis (last dimension) and run the shared conv path with ``F.conv2d``."""
        *_, blade_count = x.shape
        if blade_count != self.n_blades:
            raise ValueError(f"Input has {blade_count} blades, but Clifford layer expects {self.n_blades}.")
        return super().forward(x, F.conv2d)

CliffordConv3d

Bases: _CliffordConvNd

3d Clifford convolution.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Clifford signature.

required
in_channels int

Number of channels in the input tensor.

required
out_channels int

Number of channels produced by the convolution.

required
kernel_size Union[int, Tuple[int, int, int]]

Size of the convolving kernel.

3
stride Union[int, Tuple[int, int, int]]

Stride of the convolution.

1
padding Union[int, Tuple[int, int, int]]

padding added to all sides of the input.

0
dilation Union[int, Tuple[int, int, int]]

Spacing between kernel elements.

1
groups int

Number of blocked connections from input channels to output channels.

1
bias bool

If True, adds a learnable bias to the output.

True
padding_mode str

Padding to use.

'zeros'
Source code in cliffordlayers/nn/modules/cliffordconv.py
class CliffordConv3d(_CliffordConvNd):
    """3d Clifford convolution.

    Integer (or tuple) geometry arguments are normalized to 3-tuples before they
    reach `_CliffordConvNd`; a string padding is handed over unchanged.

    Args:
        g (Union[tuple, list, torch.Tensor]): Clifford signature.
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolving kernel.
        stride (Union[int, Tuple[int, int, int]]): Stride of the convolution.
        padding (Union[int, Tuple[int, int, int]]): padding added to all sides of the input.
        dilation (Union[int, Tuple[int, int, int]]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to output channels.
        bias (bool): If True, adds a learnable bias to the output.
        padding_mode (str): Padding to use.

    Raises:
        NotImplementedError: If ``g`` does not describe a 3d Clifford algebra.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ):
        if isinstance(padding, str):
            padding_triple = padding
        else:
            padding_triple = _triple(padding)

        super().__init__(
            g,
            in_channels,
            out_channels,
            _triple(kernel_size),
            _triple(stride),
            padding_triple,
            _triple(dilation),
            groups,
            bias,
            padding_mode,
        )
        # Only a 3d signature is meaningful for this layer.
        if self.dim != 3:
            raise NotImplementedError(f"Wrong Clifford signature for CliffordConv3d.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Check the blade axis (last dimension) and run the shared conv path with ``F.conv3d``."""
        *_, blade_count = x.shape
        if blade_count != self.n_blades:
            raise ValueError(f"Input has {blade_count} blades, but Clifford layer expects {self.n_blades}.")
        return super().forward(x, F.conv3d)

CliffordSpectralConv2d

Bases: nn.Module

2d Clifford Fourier layer.

Performs the following three steps:
  1. Clifford Fourier transform over the multivector of 2d Clifford algebras, based on complex Fourier transforms using pytorch.fft.fft2.
  2. Weight multiplication in the Clifford Fourier space using the geometric product.
  3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using pytorch.fft.ifft2.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Signature of Clifford algebra.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
modes1 int

Number of non-zero Fourier modes in the first dimension.

required
modes2 int

Number of non-zero Fourier modes in the second dimension.

required
multiply bool

Multiplication in the Fourier space. If set to False, this class only crops high-frequency modes.

True
Source code in cliffordlayers/nn/modules/cliffordfourier.py
class CliffordSpectralConv2d(nn.Module):
    """2d Clifford Fourier layer.

    Performs the following three steps:
        1. Clifford Fourier transform over the multivector of 2d Clifford algebras, based on complex Fourier transforms using [pytorch.fft.fft2](https://pytorch.org/docs/stable/generated/torch.fft.fft2.html#torch.fft.fft2).
        2. Weight multiplication in the Clifford Fourier space using the geometric product.
        3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using [pytorch.fft.ifft2](https://pytorch.org/docs/stable/generated/torch.fft.ifft2.html#torch.fft.ifft2).

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Number of non-zero Fourier modes in the first dimension.
        modes2 (int): Number of non-zero Fourier modes in the second dimension.
        multiply (bool): Multiplication in the Fourier space. If set to False this class only crops high-frequency modes.

    Shapes:
        Input ``(B, in_channels, D1, D2, I)`` with ``I == n_blades``; blades 0..3 are read.
        Output ``(B, out_channels, D1, D2, 4)``. With ``multiply=False`` there is no channel
        mixing, so ``in_channels`` must equal ``out_channels``.

    Raises:
        ValueError: If ``g`` is not a 2d signature, or the input blade count is wrong.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        multiply: bool = True,
    ) -> None:
        super().__init__()
        sig = CliffordSignature(g)
        # To allow move to same device as module.
        self.register_buffer("g", sig.g)
        self.dim = sig.dim
        if self.dim != 2:
            raise ValueError("g must be a 2D Clifford algebra")

        self.n_blades = sig.n_blades
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.multiply = multiply

        # Initialize weight parameters (one weight per blade and kept low/high mode).
        if multiply:
            scale = 1 / (in_channels * out_channels)
            self.weights = nn.Parameter(
                scale * torch.rand(4, out_channels, in_channels, self.modes1 * 2, self.modes2 * 2, dtype=torch.float32)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, *D, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")

        # Blades (0, 3) and (1, 2) form the dual pairs; each pair is packed into one
        # complex field (real, imag) so that a standard complex FFT can be applied.
        dual_1 = torch.view_as_complex(torch.stack((x[..., 0], x[..., 3]), dim=-1))
        dual_2 = torch.view_as_complex(torch.stack((x[..., 1], x[..., 2]), dim=-1))
        dual_1_ft = torch.fft.fft2(dual_1)
        dual_2_ft = torch.fft.fft2(dual_2)

        # Add dual pairs again to multivector in the Fourier space (blade-major along dim 1).
        multivector_ft = torch.cat(
            (
                dual_1_ft.real,
                dual_2_ft.real,
                dual_2_ft.imag,
                dual_1_ft.imag,
            ),
            dim=1,
        )

        # Reserve Clifford output Fourier modes.
        out_ft = torch.zeros(
            B,
            self.out_channels * self.n_blades,
            *D,
            dtype=torch.float,
            device=multivector_ft.device,
        )

        # Lowest positive (`:modes`) and negative (`-modes:`) frequencies per axis.
        rows = (slice(None, self.modes1), slice(-self.modes1, None))
        cols = (slice(None, self.modes2), slice(-self.modes2, None))

        # Concatenate positive and negative modes, such that the geometric product can be applied in one go.
        input_mul = torch.cat(
            [torch.cat([multivector_ft[:, :, r, c] for c in cols], -1) for r in rows],
            -2,
        )

        # Get Clifford weight tensor and apply the geometric product in the Fourier space.
        if self.multiply:
            _, kernel = get_2d_clifford_kernel(self.weights, self.g)
            output_mul = batchmul2d(input_mul, kernel)
        else:
            output_mul = input_mul

        # Fill the output modes, i.e. cut away high-frequency modes. Write order is
        # (+,+), (-,+), (+,-), (-,-); it only matters if the blocks overlap (modes > D/2).
        for c in cols:
            for r in rows:
                out_ft[:, :, r, c] = output_mul[:, :, r, c]

        # Reshape output such that inverse FFTs can be applied to the dual pairs.
        out_ft = out_ft.reshape(B, I, -1, *out_ft.shape[-2:])
        B_dim, I_dim, C_dim, *D_dims = range(len(out_ft.shape))
        out_ft = out_ft.permute(B_dim, C_dim, *D_dims, I_dim)
        out_dual_1 = torch.view_as_complex(torch.stack((out_ft[..., 0], out_ft[..., 3]), dim=-1))
        out_dual_2 = torch.view_as_complex(torch.stack((out_ft[..., 1], out_ft[..., 2]), dim=-1))
        dual_1_ifft = torch.fft.ifft2(out_dual_1, s=(out_dual_1.size(-2), out_dual_1.size(-1)))
        dual_2_ifft = torch.fft.ifft2(out_dual_2, s=(out_dual_2.size(-2), out_dual_2.size(-1)))

        # Finally, return to the multivector in the spatial domain.
        output = torch.stack(
            (
                dual_1_ifft.real,
                dual_2_ifft.real,
                dual_2_ifft.imag,
                dual_1_ifft.imag,
            ),
            dim=-1,
        )

        return output

CliffordSpectralConv3d

Bases: nn.Module

3d Clifford Fourier layer.

Performs the following three steps:
  1. Clifford Fourier transform over the multivector of 3d Clifford algebras, based on complex Fourier transforms using pytorch.fft.fftn.
  2. Weight multiplication in the Clifford Fourier space using the geometric product.
  3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using pytorch.fft.ifftn.

Parameters:

Name Type Description Default
g Union[tuple, list, torch.Tensor]

Signature of Clifford algebra.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
modes1 int

Number of non-zero Fourier modes in the first dimension.

required
modes2 int

Number of non-zero Fourier modes in the second dimension.

required
modes3 int

Number of non-zero Fourier modes in the third dimension.

required
multiply bool

Multiplication in the Fourier space. If set to False, this class only crops high-frequency modes.

True
Source code in cliffordlayers/nn/modules/cliffordfourier.py
class CliffordSpectralConv3d(nn.Module):
    """3d Clifford Fourier layer.

    Performs the following three steps:
        1. Clifford Fourier transform over the multivector of 3d Clifford algebras, based on complex Fourier transforms using [pytorch.fft.fftn](https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn).
        2. Weight multiplication in the Clifford Fourier space using the geometric product.
        3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using [pytorch.fft.ifftn](https://pytorch.org/docs/stable/generated/torch.fft.ifftn.html#torch.fft.ifftn).

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Number of non-zero Fourier modes in the first dimension.
        modes2 (int): Number of non-zero Fourier modes in the second dimension.
        modes3 (int): Number of non-zero Fourier modes in the third dimension.
        multiply (bool): Multiplication in the Fourier space. If set to False this class only crops high-frequency modes.

    Shapes:
        Input ``(B, in_channels, D1, D2, D3, I)`` with ``I == n_blades``; blades 0..7 are read.
        Output ``(B, out_channels, D1, D2, D3, 8)``. With ``multiply=False`` there is no channel
        mixing, so ``in_channels`` must equal ``out_channels``.

    Raises:
        ValueError: If ``g`` is not a 3d signature, or the input blade count is wrong.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        modes3: int,
        multiply: bool = True,
    ) -> None:
        super().__init__()
        sig = CliffordSignature(g)
        # Buffer so the signature follows the module across devices (as in the 2d layer);
        # non-persistent so existing state_dicts keep loading unchanged.
        self.register_buffer("g", sig.g, persistent=False)
        self.dim = sig.dim
        if self.dim != 3:
            raise ValueError("g must be a 3D Clifford algebra")
        self.n_blades = sig.n_blades

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.multiply = multiply

        # Initialize weight parameters (one weight per blade and kept low/high mode).
        if self.multiply:
            scale = 1 / (in_channels * out_channels)
            self.weights = nn.Parameter(
                scale
                * torch.rand(
                    8,
                    out_channels,
                    in_channels,
                    self.modes1 * 2,
                    self.modes2 * 2,
                    self.modes3 * 2,
                    dtype=torch.float32,
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, *D, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")

        # Blades (0, 7), (1, 6), (2, 5), (3, 4) form the dual pairs; each pair is packed
        # into one complex field (real, imag) so that a standard complex FFT can be applied.
        dual_1 = torch.view_as_complex(torch.stack((x[..., 0], x[..., 7]), dim=-1))
        dual_2 = torch.view_as_complex(torch.stack((x[..., 1], x[..., 6]), dim=-1))
        dual_3 = torch.view_as_complex(torch.stack((x[..., 2], x[..., 5]), dim=-1))
        dual_4 = torch.view_as_complex(torch.stack((x[..., 3], x[..., 4]), dim=-1))
        dual_1_ft = torch.fft.fftn(dual_1, dim=[-3, -2, -1])
        dual_2_ft = torch.fft.fftn(dual_2, dim=[-3, -2, -1])
        dual_3_ft = torch.fft.fftn(dual_3, dim=[-3, -2, -1])
        dual_4_ft = torch.fft.fftn(dual_4, dim=[-3, -2, -1])

        # Add dual pairs again to multivector in the Fourier space (blade-major along dim 1).
        multivector_ft = torch.cat(
            (
                dual_1_ft.real,
                dual_2_ft.real,
                dual_3_ft.real,
                dual_4_ft.real,
                dual_4_ft.imag,
                dual_3_ft.imag,
                dual_2_ft.imag,
                dual_1_ft.imag,
            ),
            dim=1,
        )

        # Reserve Clifford output Fourier modes.
        out_ft = torch.zeros(
            B,
            self.out_channels * self.n_blades,
            *D,
            dtype=torch.float,
            device=multivector_ft.device,
        )

        # Lowest positive (`:modes`) and negative (`-modes:`) frequencies per axis. Using the
        # same slice objects for gathering and scattering keeps the octants consistent.
        sl1 = (slice(None, self.modes1), slice(-self.modes1, None))
        sl2 = (slice(None, self.modes2), slice(-self.modes2, None))
        sl3 = (slice(None, self.modes3), slice(-self.modes3, None))

        # Concatenate positive and negative modes, such that the geometric product can be applied in one go.
        input_mul = torch.cat(
            [
                torch.cat(
                    [torch.cat([multivector_ft[:, :, s1, s2, s3] for s3 in sl3], -1) for s2 in sl2],
                    -2,
                )
                for s1 in sl1
            ],
            -3,
        )

        # Get Clifford weight tensor and apply the geometric product in the Fourier space.
        if self.multiply:
            _, kernel = get_3d_clifford_kernel(self.weights, self.g)
            output_mul = batchmul3d(input_mul, kernel)
        else:
            output_mul = input_mul

        # Fill the output modes, i.e. cut away high-frequency modes. Each octant is copied
        # from the matching octant of output_mul (write order (+,+,+), (+,+,-), ..., (-,-,-)).
        for s1 in sl1:
            for s2 in sl2:
                for s3 in sl3:
                    out_ft[:, :, s1, s2, s3] = output_mul[:, :, s1, s2, s3]

        # Reshape output such that inverse FFTs can be applied to the dual pairs.
        out_ft = out_ft.reshape(B, I, -1, *out_ft.shape[-3:])
        B_dim, I_dim, C_dim, *D_dims = range(len(out_ft.shape))
        out_ft = out_ft.permute(B_dim, C_dim, *D_dims, I_dim)

        out_dual_1 = torch.view_as_complex(torch.stack((out_ft[..., 0], out_ft[..., 7]), dim=-1))
        out_dual_2 = torch.view_as_complex(torch.stack((out_ft[..., 1], out_ft[..., 6]), dim=-1))
        out_dual_3 = torch.view_as_complex(torch.stack((out_ft[..., 2], out_ft[..., 5]), dim=-1))
        out_dual_4 = torch.view_as_complex(torch.stack((out_ft[..., 3], out_ft[..., 4]), dim=-1))
        dual_1_ifft = torch.fft.ifftn(out_dual_1, s=(out_dual_1.size(-3), out_dual_1.size(-2), out_dual_1.size(-1)))
        dual_2_ifft = torch.fft.ifftn(out_dual_2, s=(out_dual_2.size(-3), out_dual_2.size(-2), out_dual_2.size(-1)))
        dual_3_ifft = torch.fft.ifftn(out_dual_3, s=(out_dual_3.size(-3), out_dual_3.size(-2), out_dual_3.size(-1)))
        dual_4_ifft = torch.fft.ifftn(out_dual_4, s=(out_dual_4.size(-3), out_dual_4.size(-2), out_dual_4.size(-1)))

        # Finally, return to the multivector in the spatial domain.
        output = torch.stack(
            (
                dual_1_ifft.real,
                dual_2_ifft.real,
                dual_3_ifft.real,
                dual_4_ifft.real,
                dual_4_ifft.imag,
                dual_3_ifft.imag,
                dual_2_ifft.imag,
                dual_1_ifft.imag,
            ),
            dim=-1,
        )

        return output

CliffordBatchNorm1d

Bases: _CliffordBatchNorm

Clifford batch normalization for 2D or 3D data.

The input data is expected to be 3D or 4D, with shape (B, C, I) or (B, C, D, I), where B is the batch dimension, C the channels/features, D the remaining dimension (if present), and I the Clifford blades. See torch.nn.BatchNorm1d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class CliffordBatchNorm1d(_CliffordBatchNorm):
    """Clifford batch normalization for 2D or 3D data.

    Accepts tensors of shape `(B, C, I)` or `(B, C, D, I)`: `B` is the batch dimension,
    `C` the channels/features, `D` the optional remaining dimension and `I` the Clifford blades.
    See [torch.nn.BatchNorm1d] for details.
    """

    def _check_input_dim(self, x):
        # The blade axis is validated first, then the overall rank.
        *_, blades = x.shape
        if blades != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {blades} were given.")
        if x.dim() not in (3, 4):
            raise ValueError(f"Expected 3D or 4D input (got {x.dim()}D input).")

CliffordBatchNorm2d

Bases: _CliffordBatchNorm

Clifford batch normalization for 4D data.

The input data is expected to be a 5D tensor with shape (B, C, *D, I), where B is the batch dimension, C the channels/features, *D the remaining 2 dimensions, and I the Clifford blades. See torch.nn.BatchNorm2d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class CliffordBatchNorm2d(_CliffordBatchNorm):
    """Clifford batch normalization for 4D data.

    The input data is expected to be a 5D tensor with shape `(B, C, *D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `*D` the remaining 2 dimensions,
    and `I` the Clifford blades (must equal `n_blades`).
    See [torch.nn.BatchNorm2d][] for details.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Raises:
            ValueError: If the last dimension is not ``n_blades`` or ``x`` is not 5D.
        """
        *_, I = x.shape
        if not I == self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {I} were given.")
        if x.dim() != 5:
            # The message previously claimed "3D or 4D", contradicting the check.
            raise ValueError(f"Expected 5D input (got {x.dim()}D input).")

CliffordBatchNorm3d

Bases: _CliffordBatchNorm

Clifford batch normalization for 5D data. The input data is expected to be a 6D tensor with shape (B, C, *D, I), where B is the batch dimension, C the channels/features, *D the remaining 3 dimensions, and I the Clifford blades. See torch.nn.BatchNorm3d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class CliffordBatchNorm3d(_CliffordBatchNorm):
    """Clifford batch normalization for 5D data.

    The input data is expected to be a 6D tensor with shape `(B, C, *D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `*D` the remaining 3 dimensions,
    and `I` the Clifford blades (must equal `n_blades`).
    See [torch.nn.BatchNorm3d][] for details.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Raises:
            ValueError: If the last dimension is not ``n_blades`` or ``x`` is not 6D.
        """
        *_, I = x.shape
        if not I == self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {I} were given.")
        if x.dim() != 6:
            # The message previously claimed "3D or 4D", contradicting the check.
            raise ValueError(f"Expected 6D input (got {x.dim()}D input).")

ComplexBatchNorm1d

Bases: _ComplexBatchNorm

Complex-valued batch normalization for 2D or 3D data.

The input complex-valued data is expected to be at least 2d, with shape (B, C, D), where B is the batch dimension, C the channels/features, and D the remaining dimension (if present). See torch.nn.BatchNorm1d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class ComplexBatchNorm1d(_ComplexBatchNorm):
    """Complex-valued batch normalization for 2D or 3D data.

    Accepts complex-valued tensors shaped `(B, C)` or `(B, C, D)`: `B` is the batch
    axis, `C` the channels/features axis, and the optional `D` one further axis.
    Refer to [torch.nn.BatchNorm1d][] for the real-valued counterpart.
    """

    def _check_input_dim(self, x):
        # Only rank-2 or rank-3 tensors are accepted.
        rank = x.dim()
        if rank not in (2, 3):
            raise ValueError(f"Expected 2D or 3D input (got {rank}D input).")

ComplexBatchNorm2d ¤

Bases: _ComplexBatchNorm

Complex-valued batch normalization for 4D data.

The input complex-valued data is expected to be 4d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 2 dimensions. See torch.nn.BatchNorm2d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class ComplexBatchNorm2d(_ComplexBatchNorm):
    """Complex-valued batch normalization for 4D data.

    Accepts complex-valued tensors shaped `(B, C, *D)`: `B` is the batch axis,
    `C` the channels/features axis, and `D` the two remaining axes.
    Refer to [torch.nn.BatchNorm2d][] for the real-valued counterpart.
    """

    def _check_input_dim(self, x):
        # Exactly four axes are required; anything else is rejected.
        rank = x.dim()
        if rank == 4:
            return
        raise ValueError(f"Expected 4D input (got {rank}D input).")

ComplexBatchNorm3d ¤

Bases: _ComplexBatchNorm

Complex-valued batch normalization for 5D data.

The input complex-valued data is expected to be 5d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 3 dimensions. See torch.nn.BatchNorm3d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class ComplexBatchNorm3d(_ComplexBatchNorm):
    """Complex-valued batch normalization for 5D data.

    Accepts complex-valued tensors shaped `(B, C, *D)`: `B` is the batch axis,
    `C` the channels/features axis, and `D` the three remaining axes.
    Refer to [torch.nn.BatchNorm3d][] for the real-valued counterpart.
    """

    def _check_input_dim(self, x):
        # Exactly five axes are required; anything else is rejected.
        rank = x.dim()
        if rank == 5:
            return
        raise ValueError(f"Expected 5D input (got {rank}D input).")

CliffordGroupNorm1d ¤

Bases: _CliffordGroupNorm

Clifford group normalization for 2D or 3D data.

The input data is expected to be at least 3d, with shape (B, C, D, I), where B is the batch dimension, C the channels/features, and D the remaining dimension (if present).

Source code in cliffordlayers/nn/modules/groupnorm.py
class CliffordGroupNorm1d(_CliffordGroupNorm):
    """Clifford group normalization for 2D or 3D data.

    Accepts tensors shaped `(B, C, I)` or `(B, C, D, I)`: `B` is the batch axis,
    `C` the channels/features axis, the optional `D` one further axis, and `I`
    the trailing Clifford blade axis.
    """

    def _check_input_dim(self, x):
        # The last axis must carry one entry per Clifford blade.
        *_, given_blades = x.shape
        if given_blades != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {given_blades} were given.")
        # With the blade axis included, only rank 3 or rank 4 is valid.
        if x.dim() not in (3, 4):
            raise ValueError(f"Expected 3D or 4D input (got {x.dim()}D input).")

CliffordGroupNorm2d ¤

Bases: _CliffordGroupNorm

Clifford group normalization for 4D data.

The input data is expected to have shape (B, C, *D, I), where B is the batch dimension, C the channels/features, D the remaining 2 dimensions, and I the Clifford blades; including the blade axis, the tensor is 5-dimensional.

Source code in cliffordlayers/nn/modules/groupnorm.py
class CliffordGroupNorm2d(_CliffordGroupNorm):
    """Clifford group normalization for 4D data.

    The input data is expected to have shape `(B, C, *D, I)`, where `B` is the batch
    dimension, `C` the channels/features, `D` the remaining 2 dimensions, and `I` the
    Clifford blades. Including the trailing blade axis, the tensor is 5-dimensional.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Args:
            x: Input tensor, expected shape ``(B, C, D1, D2, I)``.

        Raises:
            ValueError: If the last axis does not have ``self.n_blades`` entries,
                or if ``x`` is not 5-dimensional.
        """
        # The trailing axis holds the multivector blade components.
        *_, num_blades = x.shape
        if num_blades != self.n_blades:
            raise ValueError(
                f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {num_blades} were given."
            )
        # Batch, channel, two spatial axes and the blade axis -> 5 axes in total.
        if x.dim() != 5:
            raise ValueError(f"Expected 5D input (got {x.dim()}D input).")

CliffordGroupNorm3d ¤

Bases: _CliffordGroupNorm

Clifford group normalization for 5D data.

The input data is expected to have shape (B, C, *D, I), where B is the batch dimension, C the channels/features, D the remaining 3 dimensions, and I the Clifford blades; including the blade axis, the tensor is 6-dimensional.

Source code in cliffordlayers/nn/modules/groupnorm.py
class CliffordGroupNorm3d(_CliffordGroupNorm):
    """Clifford group normalization for 5D data.

    The input data is expected to have shape `(B, C, *D, I)`, where `B` is the batch
    dimension, `C` the channels/features, `D` the remaining 3 dimensions, and `I` the
    Clifford blades. Including the trailing blade axis, the tensor is 6-dimensional.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Args:
            x: Input tensor, expected shape ``(B, C, D1, D2, D3, I)``.

        Raises:
            ValueError: If the last axis does not have ``self.n_blades`` entries,
                or if ``x`` is not 6-dimensional.
        """
        # The trailing axis holds the multivector blade components.
        *_, num_blades = x.shape
        if num_blades != self.n_blades:
            raise ValueError(
                f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {num_blades} were given."
            )
        # Batch, channel, three spatial axes and the blade axis -> 6 axes in total.
        if x.dim() != 6:
            raise ValueError(f"Expected 6D input (got {x.dim()}D input).")

ComplexGroupNorm1d ¤

Bases: _ComplexGroupNorm

Complex-valued group normalization for 2D or 3D data.

The input complex-valued data is expected to be at least 2d, with shape (B, C, D), where B is the batch dimension, C the channels/features, and D the remaining dimension (if present).

Source code in cliffordlayers/nn/modules/groupnorm.py
class ComplexGroupNorm1d(_ComplexGroupNorm):
    """Complex-valued group normalization for 2D or 3D data.

    Accepts complex-valued tensors shaped `(B, C)` or `(B, C, D)`: `B` is the batch
    axis, `C` the channels/features axis, and the optional `D` one further axis.
    """

    def _check_input_dim(self, x):
        # Only rank-2 or rank-3 tensors are accepted.
        rank = x.dim()
        if rank in (2, 3):
            return
        raise ValueError(f"Expected 2D or 3D input (got {rank}D input).")

ComplexGroupNorm2d ¤

Bases: _ComplexGroupNorm

Complex-valued group normalization for 4D data.

The input complex-valued data is expected to be 4d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 2 dimensions.

Source code in cliffordlayers/nn/modules/groupnorm.py
class ComplexGroupNorm2d(_ComplexGroupNorm):
    """Complex-valued group normalization for 4D data.

    Accepts complex-valued tensors shaped `(B, C, *D)`: `B` is the batch axis,
    `C` the channels/features axis, and `D` the two remaining axes.
    """

    def _check_input_dim(self, x):
        # Exactly four axes are required; anything else is rejected.
        rank = x.dim()
        if rank == 4:
            return
        raise ValueError(f"Expected 4D input (got {rank}D input).")

ComplexGroupNorm3d ¤

Bases: _ComplexGroupNorm

Complex-valued group normalization for 5D data.

The input complex-valued data is expected to be 5d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 3 dimensions.

Source code in cliffordlayers/nn/modules/groupnorm.py
class ComplexGroupNorm3d(_ComplexGroupNorm):
    """Complex-valued group normalization for 5D data.

    The input complex-valued data is expected to be 5d, with shape `(B, C, *D)`,
    where `B` is the batch dimension, `C` the channels/features, and `D` the
    remaining 3 dimensions.
    """

    def _check_input_dim(self, x):
        """Validate the rank of ``x``.

        Args:
            x: Input tensor, expected shape ``(B, C, D1, D2, D3)``.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        # The message must report the rank actually required (5), not 4.
        if x.dim() != 5:
            raise ValueError(f"Expected 5D input (got {x.dim()}D input).")