Skip to content

Modules

We provide Clifford linear layers; 1D, 2D, and 3D Clifford convolution layers; and 2D and 3D Clifford Fourier transform layers. Additionally, Clifford normalization schemes are provided.

All these modules are available for different algebras.

CliffordLinear

Bases: Module

Clifford linear layer.

Parameters:

Name Type Description Default
g Union[List, Tuple]

Clifford signature tensor.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
bias bool

If True, adds a learnable bias to the output. Defaults to True.

True
Source code in cliffordlayers/nn/modules/cliffordlinear.py
class CliffordLinear(nn.Module):
    """Clifford linear layer.

    Args:
        g (Union[List, Tuple]): Clifford signature tensor.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        bias (bool, optional): If True, adds a learnable bias to the output. Defaults to True.

    Shapes:
        The input ``x`` of ``forward`` is a 3-d tensor ``(B, C_in, n_blades)`` with the blades on
        the last axis. The output is ``(B, C_out, n_blades)``, where ``C_out`` is presumably
        ``out_channels`` (it is inferred from the Clifford kernel returned by ``_get_kernel``).

    Raises:
        NotImplementedError: If the signature does not describe a 1d, 2d or 3d algebra.
    """

    def __init__(
        self,
        g,
        in_channels: int,
        out_channels: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        sig = CliffordSignature(g)

        # Registered as a buffer so the signature moves with the module across devices.
        self.register_buffer("g", sig.g)
        self.dim = sig.dim
        self.n_blades = sig.n_blades

        # Pick the kernel builder that matches the algebra's dimension.
        if self.dim == 1:
            self._get_kernel = get_1d_clifford_kernel
        elif self.dim == 2:
            self._get_kernel = get_2d_clifford_kernel
        elif self.dim == 3:
            self._get_kernel = get_3d_clifford_kernel
        else:
            raise NotImplementedError(
                f"Clifford linear layers are not implemented for {self.dim} dimensions. Wrong Clifford signature."
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = nn.Parameter(torch.empty(self.n_blades, out_channels, in_channels))

        if bias:
            self.bias = nn.Parameter(torch.empty(self.n_blades, out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and (optional) bias with Kaiming-uniform style bounds."""
        # The number of blades is taken into account when calculating the bounds of Kaiming uniform:
        # the weight is viewed as (out_channels, in_channels * n_blades) so fan_in includes the blades.
        nn.init.kaiming_uniform_(
            self.weight.view(self.out_channels, self.in_channels * self.n_blades),
            a=math.sqrt(5),
        )
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                self.weight.view(self.out_channels, self.in_channels * self.n_blades)
            )
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Clifford linear map to ``x`` of shape ``(B, C_in, n_blades)``.

        Raises:
            ValueError: If the last axis of ``x`` does not hold ``n_blades`` blades.
        """
        B, _, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")
        # (B, C, I) -> (B, I, C) -> (B, I * C): blade-major flattening so the dense Clifford kernel can be applied.
        x = x.permute(0, 2, 1).reshape(B, -1)
        # Get Clifford kernel, apply it.
        _, weight = self._get_kernel(self.weight, self.g)
        # The bias is optional (registered as None when bias=False); flatten it blade-major only if present.
        bias = None if self.bias is None else self.bias.view(-1)
        output = F.linear(x, weight, bias)
        # (B, I * C_out) -> (B, I, C_out) -> (B, C_out, I).
        output = output.view(B, I, -1)
        return output.permute(0, 2, 1)

CliffordConv1d

Bases: _CliffordConvNd

1d Clifford convolution.

Parameters:

Name Type Description Default
g Union[tuple, list, Tensor]

Clifford signature.

required
in_channels int

Number of channels in the input tensor.

required
out_channels int

Number of channels produced by the convolution.

required
kernel_size int

Size of the convolving kernel.

3
stride int

Stride of the convolution.

1
padding int

Padding added to both sides of the input.

0
dilation int

Spacing between kernel elements.

1
groups int

Number of blocked connections from input channels to output channels.

1
bias bool

If True, adds a learnable bias to the output.

True
padding_mode str

Padding to use.

'zeros'
Source code in cliffordlayers/nn/modules/cliffordconv.py
class CliffordConv1d(_CliffordConvNd):
    """1d Clifford convolution.

    Args:
        g (Union[tuple, list, torch.Tensor]): Clifford signature.
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution.
        padding (Union[int, str]): Padding added to both sides of the input. A string value
            is forwarded unchanged, as in CliffordConv2d/CliffordConv3d.
        dilation (int): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to output channels.
        bias (bool): If True, adds a learnable bias to the output.
        padding_mode (str): Padding to use.

    Raises:
        NotImplementedError: If ``g`` does not describe a 1d algebra.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Union[int, str] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        # `_single` treats any iterable (including a str) as already expanded and would turn
        # e.g. "same" into ('s', 'a', 'm', 'e'); pass strings through like the 2d/3d layers do.
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)

        super().__init__(
            g,
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            groups,
            bias,
            padding_mode,
        )
        if not self.dim == 1:
            raise NotImplementedError("Wrong Clifford signature for CliffordConv1d.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x``, whose last axis holds the blades.

        Raises:
            ValueError: If the last axis of ``x`` does not hold ``n_blades`` blades.
        """
        *_, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")
        return super().forward(x, F.conv1d)

CliffordConv2d

Bases: _CliffordConvNd

2d Clifford convolution.

Parameters:

Name Type Description Default
g Union[tuple, list, Tensor]

Clifford signature.

required
in_channels int

Number of channels in the input tensor.

required
out_channels int

Number of channels produced by the convolution.

required
kernel_size Union[int, Tuple[int, int]]

Size of the convolving kernel.

3
stride Union[int, Tuple[int, int]]

Stride of the convolution.

1
padding Union[int, Tuple[int, int]]

Padding added to both sides of the input.

0
dilation Union[int, Tuple[int, int]]

Spacing between kernel elements.

1
groups int

Number of blocked connections from input channels to output channels.

1
bias bool

If True, adds a learnable bias to the output.

True
padding_mode str

Padding to use.

'zeros'
rotation bool

If True, enables the rotation kernel for Clifford convolution.

False
Source code in cliffordlayers/nn/modules/cliffordconv.py
class CliffordConv2d(_CliffordConvNd):
    """2d Clifford convolution.

    Args:
        g (Union[tuple, list, torch.Tensor]): Clifford signature.
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (Union[int, Tuple[int, int]]): Size of the convolving kernel.
        stride (Union[int, Tuple[int, int]]): Stride of the convolution.
        padding (Union[int, Tuple[int, int]]): Padding added to both sides of the input.
            A string value is forwarded unchanged.
        dilation (Union[int, Tuple[int, int]]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to output channels.
        bias (bool): If True, adds a learnable bias to the output.
        padding_mode (str): Padding to use.
        rotation (bool): If True, enables the rotation kernel for Clifford convolution.

    Raises:
        NotImplementedError: If ``g`` does not describe a 2d algebra.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        rotation: bool = False,
    ):
        # Numeric padding becomes a 2-tuple; a padding-mode string is handed over as-is.
        if isinstance(padding, str):
            padding_2d = padding
        else:
            padding_2d = _pair(padding)

        super().__init__(
            g,
            in_channels,
            out_channels,
            _pair(kernel_size),
            _pair(stride),
            padding_2d,
            _pair(dilation),
            groups,
            bias,
            padding_mode,
            rotation,
        )
        if self.dim != 2:
            raise NotImplementedError("Wrong Clifford signature for CliffordConv2d.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the 2d Clifford convolution on ``x`` (blades on the last axis).

        Raises:
            ValueError: If the trailing axis size differs from ``n_blades``.
        """
        *_, blades = x.shape
        if blades != self.n_blades:
            raise ValueError(f"Input has {blades} blades, but Clifford layer expects {self.n_blades}.")
        return super().forward(x, F.conv2d)

CliffordConv3d

Bases: _CliffordConvNd

3d Clifford convolution.

Parameters:

Name Type Description Default
g Union[tuple, list, Tensor]

Clifford signature.

required
in_channels int

Number of channels in the input tensor.

required
out_channels int

Number of channels produced by the convolution.

required
kernel_size Union[int, Tuple[int, int, int]]

Size of the convolving kernel.

3
stride Union[int, Tuple[int, int, int]]

Stride of the convolution.

1
padding Union[int, Tuple[int, int, int]]

Padding added to all sides of the input.

0
dilation Union[int, Tuple[int, int, int]]

Spacing between kernel elements.

1
groups int

Number of blocked connections from input channels to output channels.

1
bias bool

If True, adds a learnable bias to the output.

True
padding_mode str

Padding to use.

'zeros'
Source code in cliffordlayers/nn/modules/cliffordconv.py
class CliffordConv3d(_CliffordConvNd):
    """3d Clifford convolution.

    Args:
        g (Union[tuple, list, torch.Tensor]): Clifford signature.
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolving kernel.
        stride (Union[int, Tuple[int, int, int]]): Stride of the convolution.
        padding (Union[int, Tuple[int, int, int]]): Padding added to all sides of the input.
            A string value is forwarded unchanged.
        dilation (Union[int, Tuple[int, int, int]]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to output channels.
        bias (bool): If True, adds a learnable bias to the output.
        padding_mode (str): Padding to use.

    Raises:
        NotImplementedError: If ``g`` does not describe a 3d algebra.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ):
        # Numeric padding becomes a 3-tuple; a padding-mode string is handed over as-is.
        if isinstance(padding, str):
            padding_3d = padding
        else:
            padding_3d = _triple(padding)

        super().__init__(
            g,
            in_channels,
            out_channels,
            _triple(kernel_size),
            _triple(stride),
            padding_3d,
            _triple(dilation),
            groups,
            bias,
            padding_mode,
        )
        if self.dim != 3:
            raise NotImplementedError("Wrong Clifford signature for CliffordConv3d.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the 3d Clifford convolution on ``x`` (blades on the last axis).

        Raises:
            ValueError: If the trailing axis size differs from ``n_blades``.
        """
        *_, blades = x.shape
        if blades != self.n_blades:
            raise ValueError(f"Input has {blades} blades, but Clifford layer expects {self.n_blades}.")
        return super().forward(x, F.conv3d)

CliffordSpectralConv2d

Bases: Module

2d Clifford Fourier layer. Performs the following three steps: 1. Clifford Fourier transform over the multivector of 2d Clifford algebras, based on complex Fourier transforms using torch.fft.fft2. 2. Weight multiplication in the Clifford Fourier space using the geometric product. 3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using torch.fft.ifft2.

Parameters:

Name Type Description Default
g Union[tuple, list, Tensor]

Signature of Clifford algebra.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
modes1 int

Number of non-zero Fourier modes in the first dimension.

required
modes2 int

Number of non-zero Fourier modes in the second dimension.

required
multiply bool

Multiplication in the Fourier space. If set to False, this class only crops high-frequency modes.

True
Source code in cliffordlayers/nn/modules/cliffordfourier.py
class CliffordSpectralConv2d(nn.Module):
    """2d Clifford Fourier layer.
    Performs the following three steps:
        1. Clifford Fourier transform over the multivector of 2d Clifford algebras, based on complex Fourier transforms using [pytorch.fft.fft2](https://pytorch.org/docs/stable/generated/torch.fft.fft2.html#torch.fft.fft2).
        2. Weight multiplication in the Clifford Fourier space using the geometric product.
        3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using [pytorch.fft.ifft2](https://pytorch.org/docs/stable/generated/torch.fft.ifft2.html#torch.fft.ifft2).

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Number of non-zero Fourier modes in the first dimension.
        modes2 (int): Number of non-zero Fourier modes in the second dimension.
        multiply (bool): Multiplication in the Fourier space. If set to False this class only crops
            high-frequency modes (in that case ``in_channels`` must equal ``out_channels``).

    Raises:
        ValueError: If ``g`` is not a 2D Clifford algebra, or if the input's last axis does not
            hold ``n_blades`` blades.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        multiply: bool = True,
    ) -> None:
        super().__init__()
        sig = CliffordSignature(g)
        # To allow move to same device as module.
        self.register_buffer("g", sig.g)
        self.dim = sig.dim
        if self.dim != 2:
            raise ValueError("g must be a 2D Clifford algebra")

        self.n_blades = sig.n_blades
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.multiply = multiply

        # Initialize weight parameters: one (out, in) weight per blade and per retained mode
        # (positive and negative modes, hence the factor 2).
        if multiply:
            scale = 1 / (in_channels * out_channels)
            self.weights = nn.Parameter(
                scale * torch.rand(4, out_channels, in_channels, self.modes1 * 2, self.modes2 * 2, dtype=torch.float32)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the spectral layer to ``x`` laid out as ``(B, C, H, W, n_blades)``."""
        B, _, *D, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")

        # Pair blades (0, 3) and (1, 2) into complex fields so FFT can be applied to the dual pairs.
        dual_1 = torch.view_as_complex(torch.stack((x[..., 0], x[..., 3]), dim=-1))
        dual_2 = torch.view_as_complex(torch.stack((x[..., 1], x[..., 2]), dim=-1))
        dual_1_ft = torch.fft.fft2(dual_1)
        dual_2_ft = torch.fft.fft2(dual_2)

        # Add dual pairs again to multivector in the Fourier space (blade-major along dim 1).
        multivector_ft = torch.cat(
            (
                dual_1_ft.real,
                dual_2_ft.real,
                dual_2_ft.imag,
                dual_1_ft.imag,
            ),
            dim=1,
        )

        # Reserve Clifford output Fourier modes. Follow the spectrum's real dtype rather than forcing
        # float32, so e.g. float64 inputs are not silently truncated on assignment.
        out_ft = torch.zeros(
            B,
            self.out_channels * self.n_blades,
            *D,
            dtype=multivector_ft.dtype,
            device=multivector_ft.device,
        )

        # Concatenate positive and negative modes, such that the geometric product can be applied in one go.
        input_mul = torch.cat(
            (
                torch.cat(
                    (
                        multivector_ft[:, :, : self.modes1, : self.modes2],
                        multivector_ft[:, :, : self.modes1, -self.modes2 :],
                    ),
                    -1,
                ),
                torch.cat(
                    (
                        multivector_ft[:, :, -self.modes1 :, : self.modes2],
                        multivector_ft[:, :, -self.modes1 :, -self.modes2 :],
                    ),
                    -1,
                ),
            ),
            -2,
        )

        # Get Clifford weight tensor and apply the geometric product in the Fourier space.
        if self.multiply:
            _, kernel = get_2d_clifford_kernel(self.weights, self.g)
            output_mul = batchmul2d(input_mul, kernel)
        else:
            output_mul = input_mul

        # Fill the output modes, i.e. cut away high-frequency modes.
        out_ft[:, :, : self.modes1, : self.modes2] = output_mul[:, :, : self.modes1, : self.modes2]
        out_ft[:, :, -self.modes1 :, : self.modes2] = output_mul[:, :, -self.modes1 :, : self.modes2]
        out_ft[:, :, : self.modes1, -self.modes2 :] = output_mul[:, :, : self.modes1, -self.modes2 :]
        out_ft[:, :, -self.modes1 :, -self.modes2 :] = output_mul[:, :, -self.modes1 :, -self.modes2 :]

        # Reshape output such that inverse FFTs can be applied to the dual pairs:
        # (B, I * C, H, W) -> (B, I, C, H, W) -> (B, C, H, W, I).
        out_ft = out_ft.reshape(B, I, -1, *out_ft.shape[-2:])
        B_dim, I_dim, C_dim, *D_dims = range(len(out_ft.shape))
        out_ft = out_ft.permute(B_dim, C_dim, *D_dims, I_dim)
        out_dual_1 = torch.view_as_complex(torch.stack((out_ft[..., 0], out_ft[..., 3]), dim=-1))
        out_dual_2 = torch.view_as_complex(torch.stack((out_ft[..., 1], out_ft[..., 2]), dim=-1))
        dual_1_ifft = torch.fft.ifft2(out_dual_1, s=(out_dual_1.size(-2), out_dual_1.size(-1)))
        dual_2_ifft = torch.fft.ifft2(out_dual_2, s=(out_dual_2.size(-2), out_dual_2.size(-1)))

        # Finally, return to the multivector in the spatial domain.
        output = torch.stack(
            (
                dual_1_ifft.real,
                dual_2_ifft.real,
                dual_2_ifft.imag,
                dual_1_ifft.imag,
            ),
            dim=-1,
        )

        return output

CliffordSpectralConv3d

Bases: Module

3d Clifford Fourier layer. Performs the following three steps: 1. Clifford Fourier transform over the multivector of 3d Clifford algebras, based on complex Fourier transforms using torch.fft.fftn. 2. Weight multiplication in the Clifford Fourier space using the geometric product. 3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using torch.fft.ifftn.

Parameters:

Name Type Description Default
g Union[tuple, list, Tensor]

Signature of Clifford algebra.

required
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
modes1 int

Number of non-zero Fourier modes in the first dimension.

required
modes2 int

Number of non-zero Fourier modes in the second dimension.

required
modes3 int

Number of non-zero Fourier modes in the third dimension.

required
multiply bool

Multiplication in the Fourier space. If set to False, this class only crops high-frequency modes.

True
Source code in cliffordlayers/nn/modules/cliffordfourier.py
class CliffordSpectralConv3d(nn.Module):
    """3d Clifford Fourier layer.
    Performs the following three steps:
        1. Clifford Fourier transform over the multivector of 3d Clifford algebras, based on complex Fourier transforms using [pytorch.fft.fftn](https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn).
        2. Weight multiplication in the Clifford Fourier space using the geometric product.
        3. Inverse Clifford Fourier transform, based on inverse complex Fourier transforms using [pytorch.fft.ifftn](https://pytorch.org/docs/stable/generated/torch.fft.ifftn.html#torch.fft.ifftn).

    Args:
        g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Number of non-zero Fourier modes in the first dimension.
        modes2 (int): Number of non-zero Fourier modes in the second dimension.
        modes3 (int): Number of non-zero Fourier modes in the third dimension.
        multiply (bool): Multiplication in the Fourier space. If set to False this class only crops
            high-frequency modes (in that case ``in_channels`` must equal ``out_channels``).

    Raises:
        ValueError: If ``g`` is not a 3D Clifford algebra, or if the input's last axis does not
            hold ``n_blades`` blades.
    """

    def __init__(
        self,
        g: Union[tuple, list, torch.Tensor],
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        modes3: int,
        multiply: bool = True,
    ) -> None:
        super().__init__()
        sig = CliffordSignature(g)
        # Registered as a buffer (as in CliffordSpectralConv2d) so the signature follows
        # `.to(device)` and matches the device of `self.weights` in the kernel construction.
        self.register_buffer("g", sig.g)
        self.dim = sig.dim
        if self.dim != 3:
            raise ValueError("g must be a 3D Clifford algebra")
        self.n_blades = sig.n_blades

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.multiply = multiply

        # Initialize weight parameters: one (out, in) weight per blade and per retained mode
        # (positive and negative modes, hence the factor 2 per dimension).
        if self.multiply:
            scale = 1 / (in_channels * out_channels)
            self.weights = nn.Parameter(
                scale
                * torch.rand(
                    8,
                    out_channels,
                    in_channels,
                    self.modes1 * 2,
                    self.modes2 * 2,
                    self.modes3 * 2,
                    dtype=torch.float32,
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the spectral layer to ``x`` laid out as ``(B, C, D1, D2, D3, n_blades)``."""
        B, _, *D, I = x.shape
        if not (I == self.n_blades):
            raise ValueError(f"Input has {I} blades, but Clifford layer expects {self.n_blades}.")

        # Pair blades (0, 7), (1, 6), (2, 5), (3, 4) into complex fields and FFT them over the spatial axes.
        dual_1 = torch.view_as_complex(torch.stack((x[..., 0], x[..., 7]), dim=-1))
        dual_2 = torch.view_as_complex(torch.stack((x[..., 1], x[..., 6]), dim=-1))
        dual_3 = torch.view_as_complex(torch.stack((x[..., 2], x[..., 5]), dim=-1))
        dual_4 = torch.view_as_complex(torch.stack((x[..., 3], x[..., 4]), dim=-1))
        dual_1_ft = torch.fft.fftn(dual_1, dim=[-3, -2, -1])
        dual_2_ft = torch.fft.fftn(dual_2, dim=[-3, -2, -1])
        dual_3_ft = torch.fft.fftn(dual_3, dim=[-3, -2, -1])
        dual_4_ft = torch.fft.fftn(dual_4, dim=[-3, -2, -1])

        # Add dual pairs again to multivector in the Fourier space (blade-major along dim 1).
        multivector_ft = torch.cat(
            (
                dual_1_ft.real,
                dual_2_ft.real,
                dual_3_ft.real,
                dual_4_ft.real,
                dual_4_ft.imag,
                dual_3_ft.imag,
                dual_2_ft.imag,
                dual_1_ft.imag,
            ),
            dim=1,
        )

        # Reserve Clifford output Fourier modes. Follow the spectrum's real dtype rather than forcing
        # float32, so e.g. float64 inputs are not silently truncated on assignment.
        out_ft = torch.zeros(
            B,
            self.out_channels * self.n_blades,
            *D,
            dtype=multivector_ft.dtype,
            device=multivector_ft.device,
        )

        m1, m2, m3 = self.modes1, self.modes2, self.modes3

        # Concatenate positive and negative modes, such that the geometric product can be applied in one go.
        # Gathering [low; high] one spatial axis at a time yields the same layout as concatenating
        # the eight corner blocks explicitly.
        gathered = torch.cat((multivector_ft[:, :, :m1], multivector_ft[:, :, -m1:]), dim=-3)
        gathered = torch.cat((gathered[:, :, :, :m2], gathered[:, :, :, -m2:]), dim=-2)
        input_mul = torch.cat((gathered[..., :m3], gathered[..., -m3:]), dim=-1)

        # Get Clifford weight tensor and apply the geometric product in the Fourier space.
        if self.multiply:
            _, kernel = get_3d_clifford_kernel(self.weights, self.g)
            output_mul = batchmul3d(input_mul, kernel)
        else:
            output_mul = input_mul

        # Fill the output modes, i.e. cut away high-frequency modes. The same slice is used on
        # both sides for each of the 8 corners (the low/high halves of output_mul line up with the
        # low/high corners of out_ft); the iteration order matches the original explicit assignments.
        for s1 in (slice(None, m1), slice(-m1, None)):
            for s2 in (slice(None, m2), slice(-m2, None)):
                for s3 in (slice(None, m3), slice(-m3, None)):
                    out_ft[:, :, s1, s2, s3] = output_mul[:, :, s1, s2, s3]

        # Reshape output such that inverse FFTs can be applied to the dual pairs:
        # (B, I * C, D1, D2, D3) -> (B, I, C, D1, D2, D3) -> (B, C, D1, D2, D3, I).
        out_ft = out_ft.reshape(B, I, -1, *out_ft.shape[-3:])
        B_dim, I_dim, C_dim, *D_dims = range(len(out_ft.shape))
        out_ft = out_ft.permute(B_dim, C_dim, *D_dims, I_dim)

        out_dual_1 = torch.view_as_complex(torch.stack((out_ft[..., 0], out_ft[..., 7]), dim=-1))
        out_dual_2 = torch.view_as_complex(torch.stack((out_ft[..., 1], out_ft[..., 6]), dim=-1))
        out_dual_3 = torch.view_as_complex(torch.stack((out_ft[..., 2], out_ft[..., 5]), dim=-1))
        out_dual_4 = torch.view_as_complex(torch.stack((out_ft[..., 3], out_ft[..., 4]), dim=-1))
        dual_1_ifft = torch.fft.ifftn(out_dual_1, s=(out_dual_1.size(-3), out_dual_1.size(-2), out_dual_1.size(-1)))
        dual_2_ifft = torch.fft.ifftn(out_dual_2, s=(out_dual_2.size(-3), out_dual_2.size(-2), out_dual_2.size(-1)))
        dual_3_ifft = torch.fft.ifftn(out_dual_3, s=(out_dual_3.size(-3), out_dual_3.size(-2), out_dual_3.size(-1)))
        dual_4_ifft = torch.fft.ifftn(out_dual_4, s=(out_dual_4.size(-3), out_dual_4.size(-2), out_dual_4.size(-1)))

        # Finally, return to the multivector in the spatial domain.
        output = torch.stack(
            (
                dual_1_ifft.real,
                dual_2_ifft.real,
                dual_3_ifft.real,
                dual_4_ifft.real,
                dual_4_ifft.imag,
                dual_3_ifft.imag,
                dual_2_ifft.imag,
                dual_1_ifft.imag,
            ),
            dim=-1,
        )

        return output

CliffordG3Conv2d

Bases: _CliffordG3ConvNd

2D convolutional layer where the features are vectors in G3.

Parameters:

Name Type Description Default
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
kernel_size int

Size of the convolutional kernel. Defaults to 1.

1
stride int

Stride of the convolution operation. Defaults to 1.

1
padding int or str

Padding added to both sides of the input or padding mode. Defaults to 0.

0
dilation int

Dilation rate of the kernel. Defaults to 1.

1
groups int

Number of blocked connections from input channels to output channels. Defaults to 1.

1
bias bool

If True, adds a bias term to the output. Defaults to False.

False
Source code in cliffordlayers/nn/modules/gcan.py
class CliffordG3Conv2d(_CliffordG3ConvNd):
    """
    2D convolution over feature maps whose entries are vectors in G3.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 1.
        stride (int, optional): Stride of the convolution operation. Defaults to 1.
        padding (int or str, optional): Padding added to both sides of the input, or a padding mode
            string (forwarded unchanged). Defaults to 0.
        dilation (int, optional): Dilation rate of the kernel. Defaults to 1.
        groups (int, optional): Number of blocked connections from input channels to output channels. Defaults to 1.
        bias (bool, optional): If True, adds a bias term to the output. Defaults to False.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
    ):
        # Numeric padding becomes a 2-tuple; a padding-mode string is handed over as-is.
        if isinstance(padding, str):
            padding_2d = padding
        else:
            padding_2d = _pair(padding)
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=_pair(kernel_size),
            stride=_pair(stride),
            padding=padding_2d,
            dilation=_pair(dilation),
            groups=groups,
            transposed=False,
            bias=bias,
        )

    def forward(self, input):
        # Lay the three vector components side by side along the channel axis
        # (presumably input is (B, C, H, W, 3) — confirm against callers).
        stacked = torch.cat([input[..., k] for k in range(3)], dim=1)

        conv_out = clifford_g3convnd(
            stacked,
            self.weights,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

        # Split the result back into its three components and put them on a trailing axis.
        width = self.out_channels
        components = [conv_out[:, k * width : (k + 1) * width, :, :] for k in range(3)]
        return torch.stack(components, dim=-1)

CliffordG3ConvTranspose2d

Bases: _CliffordG3ConvNd

2D transposed convolutional layer where the features are vectors in G3.

Parameters:

Name Type Description Default
in_channels int

Number of input channels.

required
out_channels int

Number of output channels.

required
kernel_size int

Size of the convolutional kernel. Defaults to 1.

1
stride int

Stride of the convolution operation. Defaults to 1.

1
padding int or str

Padding added to both sides of the input or padding mode. Defaults to 0.

0
dilation int

Dilation rate of the kernel. Defaults to 1.

1
groups int

Number of blocked connections from input channels to output channels. Defaults to 1.

1
bias bool

If True, adds a bias term to the output. Defaults to False.

False
Source code in cliffordlayers/nn/modules/gcan.py
class CliffordG3ConvTranspose2d(_CliffordG3ConvNd):
    """
    2D transposed convolutional layer where the features are vectors in G3.

    The input stores the three vector components on its last axis. They are
    folded into the channel axis for the transposed convolution and unfolded
    again afterwards.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 1.
        stride (int, optional): Stride of the convolution operation. Defaults to 1.
        padding (int or str, optional): Padding added to both sides of the input or padding mode. Defaults to 0.
        dilation (int, optional): Dilation rate of the kernel. Defaults to 1.
        groups (int, optional): Number of blocked connections from input channels to output channels. Defaults to 1.
        bias (bool, optional): If True, adds a bias term to the output. Defaults to False.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
    ):
        # Scalars are expanded to 2-tuples; string padding modes pass through unchanged.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=_pair(kernel_size),
            stride=_pair(stride),
            padding=(padding if isinstance(padding, str) else _pair(padding)),
            dilation=_pair(dilation),
            groups=groups,
            transposed=True,
            bias=bias,
        )

    def forward(self, input):
        # Fold the first three components of the last axis into the channel axis.
        stacked = torch.cat([input[..., blade] for blade in range(3)], dim=1)

        out = clifford_g3convnd(
            stacked,
            self.weights,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            transposed=True,
        )

        # Unfold: consecutive chunks of `out_channels` channels are the components.
        width = self.out_channels
        components = [out[:, blade * width : (blade + 1) * width, :, :] for blade in range(3)]
        return torch.stack(components, dim=-1)

CliffordG3GroupNorm ¤

Bases: Module

A module that applies group normalization to vectors in G3.

Parameters:

Name Type Description Default
num_groups int

Number of groups to normalize over.

required
num_features int

Number of features in the input.

required
num_blades int

Number of blades in the input.

required
scale_norm bool

If True, the mean-centered input is divided by its mean vector norm within each group. Defaults to False.

False
Source code in cliffordlayers/nn/modules/gcan.py
class CliffordG3GroupNorm(nn.Module):
    """
    A module that applies group normalization to vectors in G3.

    The input has shape ``(N, C, *D, I)``, where ``I`` is the blade axis. Channels are
    split into ``num_groups`` groups. Each group is centered by subtracting its
    per-blade mean over all channels and spatial positions of the group. Optionally,
    it is then divided by the group's mean vector norm. Finally, a per-channel scale
    and a per-channel, per-blade bias are applied.

    Args:
        num_groups (int): Number of groups to normalize over. Must divide the channel count.
        num_features (int): Number of features in the input (the channel count ``C``).
        num_blades (int): Number of blades in the input (the size of the last axis ``I``).
        scale_norm (bool, optional): If True, the centered input is divided by its mean
            vector norm within each group. Defaults to False.
    """

    def __init__(self, num_groups, num_features, num_blades, scale_norm=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features, num_blades))
        self.num_groups = num_groups
        self.scale_norm = scale_norm
        self.num_blades = num_blades
        self.num_features = num_features

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, *D, I)`` and return a tensor of the same shape.

        Raises:
            ValueError: If the channel count is not divisible by ``num_groups``.
        """
        N, C, *D, I = x.size()
        G = self.num_groups
        # An explicit exception, because `assert` is stripped under `python -O`.
        if C % G != 0:
            raise ValueError(f"Number of channels ({C}) must be divisible by the number of groups ({G}).")

        # `reshape` (not `view`) so that non-contiguous inputs, e.g. permuted tensors, work.
        x = x.reshape(N, G, -1, I)
        mean = x.mean(-2, keepdim=True)
        x = x - mean
        if self.scale_norm:
            # Norm over the blade axis, averaged over all elements of the group.
            norm = x.norm(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
            x = x / norm

        x = x.reshape(N, self.num_features, -1, self.num_blades)
        out = x * self.weight[None, :, None, None] + self.bias[None, :, None]
        return out.reshape(N, C, *D, I)

CliffordG3LinearVSiLU ¤

Bases: Module

A module that applies the vector SiLU using a linear combination to vectors in G3.

Parameters:

Name Type Description Default
channels int

Number of channels in the input.

required
Source code in cliffordlayers/nn/modules/gcan.py
class CliffordG3LinearVSiLU(nn.Module):
    """
    A module that applies the vector SiLU using a linear combination to vectors in G3.

    A depthwise ``Conv3d`` with kernel ``(1, 1, 3)`` learns, per channel, a linear
    combination over the last axis. Its sigmoid gates every component of the input.

    Args:
        channels (int): Number of channels in the input.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=(1, 1, 3),
            groups=channels,
        )

    def forward(self, input):
        gate = torch.sigmoid(self.conv(input))
        return input * gate

CliffordG3MeanVSiLU ¤

Bases: Module

A module that applies the vector SiLU using vector mean to vectors in G3.

Source code in cliffordlayers/nn/modules/gcan.py
class CliffordG3MeanVSiLU(nn.Module):
    """
    A module that applies the vector SiLU using vector mean to vectors in G3.

    Every component is gated by the sigmoid of the mean over the last axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        gate = torch.sigmoid(torch.mean(input, dim=-1, keepdim=True))
        return gate * input

CliffordG3SumVSiLU ¤

Bases: Module

A module that applies the vector SiLU using vector sum to vectors in G3.

Source code in cliffordlayers/nn/modules/gcan.py
class CliffordG3SumVSiLU(nn.Module):
    """
    A module that applies the vector SiLU using vector sum to vectors in G3.

    Every component is gated by the sigmoid of the sum over the last axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        gate = torch.sigmoid(torch.sum(input, dim=-1, keepdim=True))
        return gate * input

MultiVectorAct ¤

Bases: Module

A module to apply multivector activations to the input.

Parameters:

Name Type Description Default
channels int

Number of channels in the input.

required
algebra

The algebra object that defines the geometric product.

required
input_blades (list, tuple)

The nonnegative input blades.

required
kernel_blades (list, tuple)

The blades that will be used to compute the activation. Defaults to all input blades.

None
agg str

The aggregation method to be used. Options include "linear", "sum", and "mean". Defaults to "linear".

'linear'
Source code in cliffordlayers/nn/modules/gcan.py
class MultiVectorAct(nn.Module):
    """
    A module to apply multivector activations to the input.

    The input is embedded into a full multivector via ``algebra.embed`` at
    ``input_blades``. It is gated by the sigmoid of an aggregate of the
    ``kernel_blades`` components, and the ``input_blades`` components are
    extracted again via ``algebra.get``.

    Args:
        channels (int): Number of channels in the input.
        algebra: The algebra object that defines the geometric product.
        input_blades (list, tuple): Blade indices occupied by the input.
        kernel_blades (list, tuple, optional): The blades that will be used to compute the activation. Defaults to all input blades.
        agg (str, optional): The aggregation method to be used. Options include "linear", "sum", and "mean". Defaults to "linear".

    Raises:
        ValueError: If ``agg`` is not one of the supported aggregation methods.
    """

    def __init__(self, channels, algebra, input_blades, kernel_blades=None, agg="linear"):
        super().__init__()
        # Fail fast at construction instead of at the first forward pass.
        if agg not in ("linear", "sum", "mean"):
            raise ValueError(f"Aggregation {agg} not implemented.")

        self.algebra = algebra
        self.input_blades = tuple(input_blades)
        if kernel_blades is not None:
            self.kernel_blades = tuple(kernel_blades)
        else:
            self.kernel_blades = self.input_blades

        if agg == "linear":
            # Depthwise Conv1d: one learned weight vector over the kernel blades per channel.
            self.conv = nn.Conv1d(channels, channels, kernel_size=len(self.kernel_blades), groups=channels)
        self.agg = agg

    def forward(self, input):
        v = self.algebra.embed(input, self.input_blades)
        if self.agg == "linear":
            v = v * torch.sigmoid(self.conv(v[..., self.kernel_blades]))
        elif self.agg == "sum":
            v = v * torch.sigmoid(v[..., self.kernel_blades].sum(dim=-1, keepdim=True))
        elif self.agg == "mean":
            v = v * torch.sigmoid(v[..., self.kernel_blades].mean(dim=-1, keepdim=True))
        else:
            # Still reachable if `agg` is reassigned after construction.
            raise ValueError(f"Aggregation {self.agg} not implemented.")
        v = self.algebra.get(v, self.input_blades)
        return v

PGAConjugateLinear ¤

Bases: Module

Linear layer that applies the PGA conjugation to the input.

Parameters:

Name Type Description Default
in_features int

Number of input features.

required
out_features int

Number of output features.

required
algebra Algebra

Algebra object that defines the geometric product.

required
input_blades tuple

Nonnegative blades of the input multivectors.

required
action_blades tuple

Blades of the action. Defaults to (0, 5, 6, 7, 8, 9, 10, 15), which encodes rotation and translation.

(0, 5, 6, 7, 8, 9, 10, 15)
Source code in cliffordlayers/nn/modules/gcan.py
class PGAConjugateLinear(nn.Module):
    """
    Linear layer that applies the PGA conjugation to the input.

    For every (output, input) channel pair, a learnable multivector action ``k`` is
    applied to the embedded input. The geometric product with ``k`` is taken from
    the left and with ``algebra.reverse(k)`` from the right, i.e. a sandwich
    product. The results are weighted by a real ``(out_features, in_features)``
    matrix and summed over input channels.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        algebra (Algebra): Algebra object that defines the geometric product.
        input_blades (tuple): Nonnegative blades of the input multivectors.
        action_blades (tuple, optional): Blades of the action. Defaults to (0, 5, 6, 7, 8, 9, 10, 15),
                                         which encodes rotation and translation.
    """

    def __init__(
        self,
        in_features,
        out_features,
        algebra,
        input_blades,
        action_blades=(0, 5, 6, 7, 8, 9, 10, 15),
    ):
        super().__init__()
        # Only the projective algebra with metric (0, 1, 1, 1) is supported.
        assert torch.all(algebra.metric == torch.tensor([0, 1, 1, 1]))
        self.input_blades = input_blades
        self.in_features = in_features
        self.out_features = out_features
        self.algebra = algebra
        self.action_blades = action_blades
        self.n_action_blades = len(action_blades)
        # One action multivector (restricted to `action_blades`) per (out, in) channel pair.
        self._action = nn.Parameter(torch.empty(out_features, in_features, self.n_action_blades))
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Learnable per-input-channel value written into blade 14 in `forward`; starts at zero.
        self.embed_e0 = nn.Parameter(torch.zeros(in_features, 1))

        # The reverse is used as the inverse of the action. This presumably relies on the
        # action being normalized, as done in `reset_parameters`.
        self.inverse = algebra.reverse

        self.reset_parameters()

    def reset_parameters(self):
        # NOTE(review): the slice positions below index into `action_blades` and match the
        # default (0, 5, 6, 7, 8, 9, 10, 15): position 0 -> blade 0, 1:4 -> blades 5-7,
        # 4:7 -> blades 8-10, 7 -> blade 15. A custom `action_blades` is not handled
        # consistently here.
        # Init the rotation parts uniformly.
        torch.nn.init.uniform_(self._action[..., 0], -1, 1)
        torch.nn.init.uniform_(self._action[..., 4:7], -1, 1)

        # Init the translation parts with zeros.
        torch.nn.init.zeros_(self._action[..., 1:4])
        torch.nn.init.zeros_(self._action[..., 7])

        # Normalize each action by the scalar part of its norm; the remaining norm
        # components are checked to be (approximately) zero.
        norm = self.algebra.norm(self.algebra.embed(self._action.data, self.action_blades))
        assert torch.allclose(norm[..., 1:], torch.tensor(0.0), atol=1e-3)
        norm = norm[..., :1]
        self._action.data = self._action.data / norm

        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    @property
    def action(self):
        """Actions embedded as full multivectors, one per (out, in) channel pair."""
        return self.algebra.embed(self._action, self.action_blades)

    def forward(self, input):
        M = self.algebra.cayley
        k = self.action
        k_ = self.inverse(k)
        # After embedding, x is indexed as (batch, in_features, blades) by the einsum below.
        x = self.algebra.embed(input, self.input_blades)
        # Overwrite blade 14 with the learnable `embed_e0` (broadcast over the batch).
        # NOTE(review): blade 14 is presumably the homogeneous coordinate of a point;
        # verify against the algebra's blade ordering.
        x[..., 14:15] = self.embed_e0

        # k_l: left product with k; k_r: right product with reverse(k).
        k_l = get_clifford_left_kernel(M, k, flatten=False)
        k_r = get_clifford_right_kernel(M, k_, flatten=False)

        # out[b, o, p] = sum_{i, q, r} weight[o, i] * k_r[p, o, q, i] * k_l[q, o, r, i] * x[b, i, r]
        x = torch.einsum("oi,poqi,qori,bir->bop", self.weight, k_r, k_l, x)

        x = self.algebra.get(x, self.input_blades)

        return x

get_clifford_left_kernel(M, w, flatten=True) ¤

Obtains the matrix that computes the geometric product from the left. When the output is flattened, it can be used to apply a fully connected layer on the multivectors.

Parameters:

Name Type Description Default
M Tensor

Cayley table that defines the geometric relation.

required
w Tensor

Input tensor with shape (o, i, c) where o is the number of output channels, i is the number of input channels, and c is the number of blades.

required
flatten bool

If True, the resulting matrix will be reshaped for subsequent fully connected operations. Defaults to True.

True
Source code in cliffordlayers/nn/modules/gcan.py
def get_clifford_left_kernel(M, w, flatten=True):
    """
    Build the matrix that multiplies multivectors by ``w`` from the left.

    With ``flatten=True`` the result can be used directly as the weight of a fully
    connected layer acting on flattened multivectors.

    Args:
        M (Tensor): Cayley table of shape ``(c, c, c)`` that defines the geometric relation.
        w (Tensor): Weights of shape ``(o, i, c)``: ``o`` output channels, ``i`` input
                    channels, ``c`` blades.
        flatten (bool, optional): If True, reshape the kernel to ``(o * c, i * c)``.
                                  Defaults to True.

    Returns:
        Tensor: The kernel of shape ``(c, o, c, i)``, or its ``(o * c, i * c)`` reshape.
    """
    n_out, n_in, n_blades = w.size()
    # kernel[b, d, c, e] = sum_a M[a, b, c] * w[d, e, a]
    kernel = torch.einsum("abc, dea->bdce", M, w)
    if not flatten:
        return kernel
    return kernel.reshape(n_out * n_blades, n_in * n_blades)

get_clifford_right_kernel(M, w, flatten=True) ¤

Obtains the matrix that computes the geometric product from the right. When the output is flattened, it can be used to apply a fully connected layer on the multivectors.

Parameters:

Name Type Description Default
M Tensor

Cayley table that defines the geometric relation.

required
w Tensor

Input tensor with shape (o, i, c) where o is the number of output channels, i is the number of input channels, and c is the number of blades.

required
flatten bool

If True, the resulting matrix will be reshaped for subsequent fully connected operations. Defaults to True.

True
Source code in cliffordlayers/nn/modules/gcan.py
def get_clifford_right_kernel(M, w, flatten=True):
    """
    Build the matrix that multiplies multivectors by ``w`` from the right.

    With ``flatten=True`` the result can be used directly as the weight of a fully
    connected layer acting on flattened multivectors.

    Args:
        M (Tensor): Cayley table of shape ``(c, c, c)`` that defines the geometric relation.
        w (Tensor): Weights of shape ``(o, i, c)``: ``o`` output channels, ``i`` input
                    channels, ``c`` blades.
        flatten (bool, optional): If True, reshape the kernel to ``(o * c, i * c)``.
                                  Defaults to True.

    Returns:
        Tensor: The kernel of shape ``(c, o, c, i)``, or its ``(o * c, i * c)`` reshape.
    """
    n_out, n_in, n_blades = w.size()
    # kernel[b, d, a, e] = sum_c M[a, b, c] * w[d, e, c]
    kernel = torch.einsum("abc, dec->bdae", M, w)
    if not flatten:
        return kernel
    return kernel.reshape(n_out * n_blades, n_in * n_blades)

CliffordBatchNorm1d ¤

Bases: _CliffordBatchNorm

Clifford batch normalization for 2D or 3D data.

The input data is expected to be 3D or 4D, with shape (B, C, D, I), where B is the batch dimension, C the channels/features, D the remaining dimension (if present), and I the blades. See torch.nn.BatchNorm1d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class CliffordBatchNorm1d(_CliffordBatchNorm):
    """Clifford batch normalization for 2D or 3D data.

    The input is expected to be 3D or 4D, with shape `(B, C, I)` or `(B, C, D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `D` the remaining
    dimension (if present) and `I` the blades.
    See [torch.nn.BatchNorm1d][] for details.
    """

    def _check_input_dim(self, x):
        *_, n_given = x.shape
        if n_given != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {n_given} were given.")
        if x.dim() not in (3, 4):
            raise ValueError(f"Expected 3D or 4D input (got {x.dim()}D input).")

CliffordBatchNorm2d ¤

Bases: _CliffordBatchNorm

Clifford batch normalization for 4D data.

The input data is expected to be 5D, with shape (B, C, *D, I), where B is the batch dimension, C the channels/features, D the remaining 2 dimensions, and I the blades. See torch.nn.BatchNorm2d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class CliffordBatchNorm2d(_CliffordBatchNorm):
    """Clifford batch normalization for 4D data.

    The input is expected to be 5D, with shape `(B, C, *D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `D` the remaining
    2 dimensions and `I` the blades.
    See [torch.nn.BatchNorm2d][] for details.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Raises:
            ValueError: If the last axis does not have ``n_blades`` entries, or ``x`` is not 5D.
        """
        *_, I = x.shape
        if I != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {I} were given.")
        if x.dim() != 5:
            # The message previously said "3D or 4D", which contradicted the check.
            raise ValueError(f"Expected 5D input (got {x.dim()}D input).")

CliffordBatchNorm3d ¤

Bases: _CliffordBatchNorm

Clifford batch normalization for 5D data. The input data is expected to be 6D, with shape (B, C, *D, I), where B is the batch dimension, C the channels/features, D the remaining 3 dimensions, and I the blades. See torch.nn.BatchNorm3d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class CliffordBatchNorm3d(_CliffordBatchNorm):
    """Clifford batch normalization for 5D data.

    The input is expected to be 6D, with shape `(B, C, *D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `D` the remaining
    3 dimensions and `I` the blades.
    See [torch.nn.BatchNorm3d][] for details.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Raises:
            ValueError: If the last axis does not have ``n_blades`` entries, or ``x`` is not 6D.
        """
        *_, I = x.shape
        if I != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {I} were given.")
        if x.dim() != 6:
            # The message previously said "3D or 4D", which contradicted the check.
            raise ValueError(f"Expected 6D input (got {x.dim()}D input).")

ComplexBatchNorm1d ¤

Bases: _ComplexBatchNorm

Complex-valued batch normalization for 2D or 3D data.

The input complex-valued data is expected to be at least 2d, with shape (B, C, D), where B is the batch dimension, C the channels/features, and D the remaining dimension (if present). See torch.nn.BatchNorm1d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class ComplexBatchNorm1d(_ComplexBatchNorm):
    """Complex-valued batch normalization for 2D or 3D data.

    The complex-valued input is expected to be 2D or 3D, with shape `(B, C)` or
    `(B, C, D)`, where `B` is the batch dimension, `C` the channels/features and
    `D` the remaining dimension (if present).
    See [torch.nn.BatchNorm1d][] for details.
    """

    def _check_input_dim(self, x):
        ndim = x.dim()
        if ndim not in (2, 3):
            raise ValueError(f"Expected 2D or 3D input (got {ndim}D input).")

ComplexBatchNorm2d ¤

Bases: _ComplexBatchNorm

Complex-valued batch normalization for 4D data.

The input complex-valued data is expected to be 4d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 2 dimensions. See torch.nn.BatchNorm2d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class ComplexBatchNorm2d(_ComplexBatchNorm):
    """Complex-valued batch normalization for 4D data.

    The complex-valued input is expected to be 4D, with shape `(B, C, *D)`,
    where `B` is the batch dimension, `C` the channels/features and `D` the
    remaining 2 dimensions.
    See [torch.nn.BatchNorm2d][] for details.
    """

    def _check_input_dim(self, x):
        ndim = x.dim()
        if ndim != 4:
            raise ValueError(f"Expected 4D input (got {ndim}D input).")

ComplexBatchNorm3d ¤

Bases: _ComplexBatchNorm

Complex-valued batch normalization for 5D data.

The input complex-valued data is expected to be 5d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 3 dimensions. See torch.nn.BatchNorm3d for details.

Source code in cliffordlayers/nn/modules/batchnorm.py
class ComplexBatchNorm3d(_ComplexBatchNorm):
    """Complex-valued batch normalization for 5D data.

    The complex-valued input is expected to be 5D, with shape `(B, C, *D)`,
    where `B` is the batch dimension, `C` the channels/features and `D` the
    remaining 3 dimensions.
    See [torch.nn.BatchNorm3d][] for details.
    """

    def _check_input_dim(self, x):
        ndim = x.dim()
        if ndim != 5:
            raise ValueError(f"Expected 5D input (got {ndim}D input).")

CliffordGroupNorm1d ¤

Bases: _CliffordGroupNorm

Clifford group normalization for 2D or 3D data.

The input data is expected to be at least 3d, with shape (B, C, D, I), where B is the batch dimension, C the channels/features, and D the remaining dimension (if present).

Source code in cliffordlayers/nn/modules/groupnorm.py
class CliffordGroupNorm1d(_CliffordGroupNorm):
    """Clifford group normalization for 2D or 3D data.

    The input is expected to be 3D or 4D, with shape `(B, C, I)` or `(B, C, D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `D` the remaining
    dimension (if present) and `I` the blades.
    """

    def _check_input_dim(self, x):
        *_, n_given = x.shape
        if n_given != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {n_given} were given.")
        if x.dim() not in (3, 4):
            raise ValueError(f"Expected 3D or 4D input (got {x.dim()}D input).")

CliffordGroupNorm2d ¤

Bases: _CliffordGroupNorm

Clifford group normalization for 4D data.

The input data is expected to be 5D, with shape (B, C, *D, I), where B is the batch dimension, C the channels/features, D the remaining 2 dimensions, and I the blades.

Source code in cliffordlayers/nn/modules/groupnorm.py
class CliffordGroupNorm2d(_CliffordGroupNorm):
    """Clifford group normalization for 4D data.

    The input is expected to be 5D, with shape `(B, C, *D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `D` the remaining
    2 dimensions and `I` the blades.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Raises:
            ValueError: If the last axis does not have ``n_blades`` entries, or ``x`` is not 5D.
        """
        *_, I = x.shape
        if I != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {I} were given.")
        if x.dim() != 5:
            # The message previously said "3D or 4D", which contradicted the check.
            raise ValueError(f"Expected 5D input (got {x.dim()}D input).")

CliffordGroupNorm3d ¤

Bases: _CliffordGroupNorm

Clifford group normalization for 5D data.

The input data is expected to be 6D, with shape (B, C, *D, I), where B is the batch dimension, C the channels/features, D the remaining 3 dimensions, and I the blades.

Source code in cliffordlayers/nn/modules/groupnorm.py
class CliffordGroupNorm3d(_CliffordGroupNorm):
    """Clifford group normalization for 5D data.

    The input is expected to be 6D, with shape `(B, C, *D, I)`,
    where `B` is the batch dimension, `C` the channels/features, `D` the remaining
    3 dimensions and `I` the blades.
    """

    def _check_input_dim(self, x):
        """Validate the blade count and rank of ``x``.

        Raises:
            ValueError: If the last axis does not have ``n_blades`` entries, or ``x`` is not 6D.
        """
        *_, I = x.shape
        if I != self.n_blades:
            raise ValueError(f"Wrong number of Clifford blades. Expected {self.n_blades} blades, but {I} were given.")
        if x.dim() != 6:
            # The message previously said "3D or 4D", which contradicted the check.
            raise ValueError(f"Expected 6D input (got {x.dim()}D input).")

ComplexGroupNorm1d ¤

Bases: _ComplexGroupNorm

Complex-valued group normalization for 2D or 3D data.

The input complex-valued data is expected to be at least 2d, with shape (B, C, D), where B is the batch dimension, C the channels/features, and D the remaining dimension (if present).

Source code in cliffordlayers/nn/modules/groupnorm.py
class ComplexGroupNorm1d(_ComplexGroupNorm):
    """Complex-valued group normalization for 2D or 3D data.

    The complex-valued input is expected to be 2D or 3D, with shape `(B, C)` or
    `(B, C, D)`, where `B` is the batch dimension, `C` the channels/features and
    `D` the remaining dimension (if present).
    """

    def _check_input_dim(self, x):
        ndim = x.dim()
        if ndim not in (2, 3):
            raise ValueError(f"Expected 2D or 3D input (got {ndim}D input).")

ComplexGroupNorm2d ¤

Bases: _ComplexGroupNorm

Complex-valued group normalization for 4D data.

The input complex-valued data is expected to be 4d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 2 dimensions.

Source code in cliffordlayers/nn/modules/groupnorm.py
class ComplexGroupNorm2d(_ComplexGroupNorm):
    """Complex-valued group normalization for 4D data.

    The complex-valued input is expected to be 4D, with shape `(B, C, *D)`,
    where `B` is the batch dimension, `C` the channels/features and `D` the
    remaining 2 dimensions.
    """

    def _check_input_dim(self, x):
        ndim = x.dim()
        if ndim != 4:
            raise ValueError(f"Expected 4D input (got {ndim}D input).")

ComplexGroupNorm3d ¤

Bases: _ComplexGroupNorm

Complex-valued group normalization for 5D data.

The input complex-valued data is expected to be 5d, with shape (B, C, *D), where B is the batch dimension, C the channels/features, and D the remaining 3 dimensions.

Source code in cliffordlayers/nn/modules/groupnorm.py
class ComplexGroupNorm3d(_ComplexGroupNorm):
    """Complex-valued group normalization for 5D data.

    The complex-valued input is expected to be 5D, with shape `(B, C, *D)`,
    where `B` is the batch dimension, `C` the channels/features and `D` the
    remaining 3 dimensions.
    """

    def _check_input_dim(self, x):
        """Validate the rank of ``x``.

        Raises:
            ValueError: If ``x`` is not 5D.
        """
        if x.dim() != 5:
            # The message previously said "4D", which contradicted the check.
            raise ValueError(f"Expected 5D input (got {x.dim()}D input).")