Functions

clifford_batch_norm(x, n_blades, running_mean=None, running_cov=None, weight=None, bias=None, training=True, momentum=0.1, eps=1e-05)

Clifford batch normalization for each channel across a batch of data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor of shape `(B, C, *D, I)` where `I` is the number of blades of the algebra. | *required* |
| `n_blades` | `int` | Number of blades of the Clifford algebra. | *required* |
| `running_mean` | `Tensor` | The tensor with running mean statistics having shape `(I, C)`. | `None` |
| `running_cov` | `Tensor` | The tensor with running covariance statistics having shape `(I, I, C)`. | `None` |
| `weight` | `Union[Tensor, Parameter]` | Additional weight tensor which is applied post normalization, and has the shape `(I, I, C)`. | `None` |
| `bias` | `Union[Tensor, Parameter]` | Additional bias tensor which is applied post normalization, and has the shape `(I, C)`. | `None` |
| `training` | `bool` | Whether to use the running mean and variance. | `True` |
| `momentum` | `float` | Momentum for the running mean and variance. | `0.1` |
| `eps` | `float` | Epsilon for the running mean and variance. | `1e-05` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Normalized input of shape `(B, C, *D, I)`. |

Source code in cliffordlayers/nn/functional/batchnorm.py
def clifford_batch_norm(
    x: torch.Tensor,
    n_blades: int,
    running_mean: Optional[torch.Tensor] = None,
    running_cov: Optional[torch.Tensor] = None,
    weight: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    bias: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    training: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-05,
) -> torch.Tensor:
    """Clifford batch normalization for each channel across a batch of data.

    Args:
        x (torch.Tensor): Input tensor of shape `(B, C, *D, I)` where I is the number of blades of the algebra.
        n_blades (int): Number of blades of the Clifford algebra.
        running_mean (torch.Tensor, optional): The tensor with running mean statistics having shape `(I, C)`.
        running_cov (torch.Tensor, optional): The tensor with running covariance statistics having shape `(I, I, C)`.
        weight (Union[torch.Tensor, nn.Parameter], optional): Additional weight tensor which is applied post normalization, and has the shape `(I, I, C)`.
        bias (Union[torch.Tensor, nn.Parameter], optional): Additional bias tensor which is applied post normalization, and has the shape `(I, C)`.
        training (bool, optional): Whether to use the running mean and variance. Defaults to True.
        momentum (float, optional): Momentum for the running mean and variance. Defaults to 0.1.
        eps (float, optional): Epsilon for the running mean and variance. Defaults to 1e-05.

    Returns:
        (torch.Tensor): Normalized input of shape `(B, C, *D, I)`
    """

    # Check arguments.
    assert (running_mean is None and running_cov is None) or (running_mean is not None and running_cov is not None)
    assert (weight is None and bias is None) or (weight is not None and bias is not None)

    # Whiten and apply affine transformation
    _, C, *_, I = x.shape
    assert I == n_blades
    x_norm = whiten_data(
        x,
        training=training,
        running_mean=running_mean,
        running_cov=running_cov,
        momentum=momentum,
        eps=eps,
    )
    if weight is not None and bias is not None:
        # Check if weight and bias tensors are of correct shape.
        assert weight.shape == (I, I, C)
        assert bias.shape == (I, C)
        # Unsqueeze weight and bias for each dimension except the channel dimension.
        shape = 1, C, *([1] * (x.dim() - 3))
        weight = weight.reshape(I, I, *shape)
        # Apply additional affine transformation post normalization.
        weight_idx = list(range(weight.dim()))
        # TODO: weight multiplication should be changed to geometric product.
        weight = weight.permute(*weight_idx[2:], *weight_idx[:2])
        x_norm = weight.matmul(x_norm[..., None]).squeeze(-1) + bias.reshape(*shape, I)

    return x_norm
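
A minimal usage sketch (the import path `cliffordlayers.nn.functional.batchnorm` is assumed from the source path above; shapes are illustrative):

```python
import torch
from cliffordlayers.nn.functional.batchnorm import clifford_batch_norm

B, C, I = 8, 4, 4                    # batch, channels, number of blades
x = torch.randn(B, C, 32, 32, I)     # input of shape (B, C, *D, I)

# Training-mode call: statistics are computed from the batch.
x_norm = clifford_batch_norm(x, n_blades=I)

# With the optional affine transformation applied post normalization.
weight = torch.eye(I).unsqueeze(-1).repeat(1, 1, C)  # (I, I, C)
bias = torch.zeros(I, C)                             # (I, C)
x_affine = clifford_batch_norm(x, n_blades=I, weight=weight, bias=bias)
print(x_affine.shape)                # torch.Size([8, 4, 32, 32, 4])
```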

complex_batch_norm(x, running_mean=None, running_cov=None, weight=None, bias=None, training=True, momentum=0.1, eps=1e-05)

Applies complex-valued Batch Normalization as described in (Trabelsi et al., 2018) for each channel across a batch of data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The input complex-valued data, expected to be at least 2d, with shape `(B, C, *D)`, where `B` is the batch dimension, `C` the channels/features, and `*D` the remaining dimensions (if present). | *required* |
| `running_mean` | `Union[Tensor, Parameter]` | The tensor with running mean statistics having shape `(2, C)`. | `None` |
| `running_cov` | `Union[Tensor, Parameter]` | The tensor with running real-imaginary covariance statistics having shape `(2, 2, C)`. | `None` |
| `weight` | `Tensor` | Additional weight tensor which is applied post normalization, and has the shape `(2, 2, C)`. | `None` |
| `bias` | `Tensor` | Additional bias tensor which is applied post normalization, and has the shape `(2, C)`. | `None` |
| `training` | `bool` | Whether to use the running mean and variance. | `True` |
| `momentum` | `float` | Momentum for the running mean and variance. | `0.1` |
| `eps` | `float` | Epsilon for the running mean and variance. | `1e-05` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Normalized input as complex tensor of shape `(B, C, *D)`. |

Source code in cliffordlayers/nn/functional/batchnorm.py
def complex_batch_norm(
    x: torch.Tensor,
    running_mean: Optional[torch.Tensor] = None,
    running_cov: Optional[torch.Tensor] = None,
    weight: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    bias: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    training: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-05,
) -> torch.Tensor:
    """Applies complex-valued Batch Normalization as described in
    (Trabelsi et al., 2018) for each channel across a batch of data.

    Args:
        x (torch.Tensor): The input complex-valued data is expected to be at least 2d, with shape `(B, C, *D)`, where `B` is the batch dimension, `C` the channels/features, and *D the remaining dimensions (if present).

        running_mean (Union[torch.Tensor, nn.Parameter], optional): The tensor with running mean statistics having shape `(2, C)`.
        running_cov (Union[torch.Tensor, nn.Parameter], optional): The tensor with running real-imaginary covariance statistics having shape `(2, 2, C)`.
        weight (torch.Tensor, optional): Additional weight tensor which is applied post normalization, and has the shape `(2, 2, C)`.
        bias (torch.Tensor, optional): Additional bias tensor which is applied post normalization, and has the shape `(2, C)`.
        training (bool, optional): Whether to use the running mean and variance. Defaults to `True`.
        momentum (float, optional): Momentum for the running mean and variance. Defaults to `0.1`.
        eps (float, optional): Epsilon for the running mean and variance. Defaults to `1e-05`.

    Returns:
        (torch.Tensor): Normalized input as complex tensor of shape `(B, C, *D)`.
    """

    # Check arguments.
    assert (running_mean is None and running_cov is None) or (running_mean is not None and running_cov is not None)
    assert (weight is None and bias is None) or (weight is not None and bias is not None)
    x = torch.view_as_real(x)
    _, C, *_, I = x.shape
    assert I == 2

    # Whiten and apply affine transformation.
    x_norm = whiten_data(
        x,
        training,
        running_mean,
        running_cov,
        momentum,
        eps,
    )
    if weight is not None and bias is not None:
        # Check if weight and bias tensors are of correct shape.
        assert weight.shape == (2, 2, C)
        assert bias.shape == (2, C)
        # Unsqueeze weight and bias for each dimension except the channel dimension.
        shape = 1, C, *([1] * (x.dim() - 3))
        weight = weight.reshape(2, 2, *shape)
        # Apply additional affine transformation post normalization.
        weight_idx = list(range(weight.dim()))
        # TODO weight multiplication should be changed to complex product.
        weight = weight.permute(*weight_idx[2:], *weight_idx[:2])
        x_norm = weight.matmul(x_norm[..., None]).squeeze(-1) + bias.reshape(*shape, 2)

    return torch.view_as_complex(x_norm)
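
A short sketch of calling it on a complex tensor (import path assumed from the source path above):

```python
import torch
from cliffordlayers.nn.functional.batchnorm import complex_batch_norm

x = torch.randn(8, 3, 16, 16, dtype=torch.cfloat)   # (B, C, *D) complex input
x_norm = complex_batch_norm(x)                      # training mode, batch statistics
print(x_norm.dtype, x_norm.shape)                   # torch.complex64 torch.Size([8, 3, 16, 16])
```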

whiten_data(x, training=True, running_mean=None, running_cov=None, momentum=0.1, eps=1e-05)

Jointly whiten features in tensors (B, C, *D, I): take the I-dimensional vectors and whiten them individually for each channel dimension C over (B, *D). I is the number of blades in the respective Clifford algebra, e.g. I = 2 for complex numbers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The tensor to whiten. | *required* |
| `training` | `bool` | Whether to update the running mean and covariance. | `True` |
| `running_mean` | `Tensor` | The running mean of shape `(I, C)`. | `None` |
| `running_cov` | `Tensor` | The running covariance of shape `(I, I, C)`. | `None` |
| `momentum` | `float` | The momentum to use for the running mean and covariance. | `0.1` |
| `eps` | `float` | A small number to add to the covariance. | `1e-05` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Whitened data of shape `(B, C, *D, I)`. |

Source code in cliffordlayers/nn/functional/batchnorm.py
def whiten_data(
    x: torch.Tensor,
    training: bool = True,
    running_mean: Optional[torch.Tensor] = None,
    running_cov: Optional[torch.Tensor] = None,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> torch.Tensor:
    """Jointly whiten features in tensors `(B, C, *D, I)`: take n_blades(I)-dim vectors
    and whiten individually for each channel dimension C over `(B, *D)`.
    I is the number of blades in the respective Clifford algebra, e.g. I = 2 for complex numbers.

    Args:
        x (torch.Tensor): The tensor to whiten.
        training (bool, optional): Whether to update the running mean and covariance. Defaults to `True`.
        running_mean (torch.Tensor, optional): The running mean of shape `(I, C)`. Defaults to `None`.
        running_cov (torch.Tensor, optional): The running covariance of shape `(I, I, C)`. Defaults to `None`.
        momentum (float, optional): The momentum to use for the running mean and covariance. Defaults to `0.1`.
        eps (float, optional): A small number to add to the covariance. Defaults to `1e-5`.

    Returns:
        (torch.Tensor): Whitened data of shape `(B, C, *D, I)`.
    """

    assert x.dim() >= 3
    # Get whitening shape of [1, C, ...]
    _, C, *_, I = x.shape
    B_dim, C_dim, *D_dims, I_dim = range(len(x.shape))
    shape = 1, C, *([1] * (x.dim() - 3))

    # Get feature mean.
    if not (running_mean is None or running_mean.shape == (I, C)):
        raise ValueError(f"Running_mean expected to be none, or of shape ({I}, {C}).")
    if training or running_mean is None:
        mean = x.mean(dim=(B_dim, *D_dims))
        if running_mean is not None:
            running_mean += momentum * (mean.data.permute(1, 0) - running_mean)
    else:
        mean = running_mean.permute(1, 0)

    # Get feature covariance.
    x = x - mean.reshape(*shape, I)
    if not (running_cov is None or running_cov.shape == (I, I, C)):
        raise ValueError(f"Running_cov expected to be none, or of shape ({I}, {I}, {C}).")
    if training or running_cov is None:
        # B, C, *D, I -> C, I, B, *D
        X = x.permute(C_dim, I_dim, B_dim, *D_dims).flatten(2, -1)
        # Covariance XX^T matrix of shape C x I x I
        cov = torch.matmul(X, X.transpose(-1, -2)) / X.shape[-1]
        if running_cov is not None:
            running_cov += momentum * (cov.data.permute(1, 2, 0) - running_cov)

    else:
        cov = running_cov.permute(2, 0, 1)

    # Upper triangle Cholesky decomposition of covariance matrix: U^T U = Cov
    # eye = eps * torch.eye(I, device=cov.device, dtype=cov.dtype).unsqueeze(0)
    # Modified the scale of eps to help prevent the occurrence of negative-definite matrices
    # 1e-5 may not fit the scale of matrices with large numbers
    max_values = torch.amax(cov, dim=(1, 2))
    A = torch.eye(cov.shape[-1], device=cov.device, dtype=cov.dtype)
    eye = eps * torch.einsum('ij,k->kij', A, max_values)
    U = torch.linalg.cholesky(cov + eye).mH
    # Invert Cholesky decomposition, returns tensor of shape [B, C, *D, I]
    x_whiten = torch.linalg.solve_triangular(
        U.reshape(*shape, I, I),
        x.unsqueeze(-1),
        upper=True,
    ).squeeze(-1)
    return x_whiten
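
A sketch that checks the whitening empirically: after the call, the per-channel covariance of the I-dimensional blade vectors should be close to the identity (import path assumed from the source path above):

```python
import torch
from cliffordlayers.nn.functional.batchnorm import whiten_data

x = torch.randn(64, 2, 128, 3)                 # (B, C, *D, I) with I = 3
x_w = whiten_data(x)                           # training mode, batch statistics

X = x_w.permute(1, 3, 0, 2).flatten(2)         # reshape to (C, I, B * D)
cov = X @ X.transpose(-1, -2) / X.shape[-1]    # empirical covariance, C x I x I
print(cov[0])                                  # approximately torch.eye(3)
```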

clifford_group_norm(x, n_blades, num_groups=1, running_mean=None, running_cov=None, weight=None, bias=None, training=True, momentum=0.1, eps=1e-05)

Clifford group normalization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor of shape `(B, C, *D, I)` where `I` is the number of blades of the algebra. | *required* |
| `n_blades` | `int` | Number of blades of the Clifford algebra. | *required* |
| `num_groups` | `int` | Number of groups for which normalization is calculated. For `num_groups == 1`, it effectively applies Clifford layer normalization; for `num_groups == C`, it effectively applies Clifford instance normalization. | `1` |
| `running_mean` | `Tensor` | The tensor with running mean statistics having shape `(I, C / num_groups)`. | `None` |
| `running_cov` | `Tensor` | The tensor with running real-imaginary covariance statistics having shape `(I, I, C / num_groups)`. | `None` |
| `weight` | `Union[Tensor, Parameter]` | Additional weight tensor which is applied post normalization, and has the shape `(I, I, C / num_groups)`. | `None` |
| `bias` | `Union[Tensor, Parameter]` | Additional bias tensor which is applied post normalization, and has the shape `(I, C / num_groups)`. | `None` |
| `training` | `bool` | Whether to use the running mean and variance. | `True` |
| `momentum` | `float` | Momentum for the running mean and variance. | `0.1` |
| `eps` | `float` | Epsilon for the running mean and variance. | `1e-05` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Group normalized input of shape `(B, C, *D, I)`. |

Source code in cliffordlayers/nn/functional/groupnorm.py
def clifford_group_norm(
    x: torch.Tensor,
    n_blades: int,
    num_groups: int = 1,
    running_mean: Optional[torch.Tensor] = None,
    running_cov: Optional[torch.Tensor] = None,
    weight: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    bias: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    training: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-05,
) -> torch.Tensor:
    """Clifford group normalization

    Args:
        x (torch.Tensor): Input tensor of shape `(B, C, *D, I)` where I is the blade of the algebra.

        n_blades (int): Number of blades of the Clifford algebra.

        num_groups (int): Number of groups for which normalization is calculated. Defaults to 1.
                          For `num_groups == 1`, it effectively applies Clifford layer normalization, for `num_groups == C`, it effectively applies Clifford instance normalization.

        running_mean (torch.Tensor, optional): The tensor with running mean statistics having shape `(I, C / num_groups)`. Defaults to None.
        running_cov (torch.Tensor, optional): The tensor with running real-imaginary covariance statistics having shape `(I, I, C / num_groups)`. Defaults to None.

        weight (Union[torch.Tensor, nn.Parameter], optional): Additional weight tensor which is applied post normalization, and has the shape `(I, I, C / num_groups)`. Defaults to None.

        bias (Union[torch.Tensor, nn.Parameter], optional): Additional bias tensor which is applied post normalization, and has the shape `(I, C / num_groups)`. Defaults to None.

        training (bool, optional): Whether to use the running mean and variance. Defaults to True.
        momentum (float, optional): Momentum for the running mean and variance. Defaults to 0.1.
        eps (float, optional): Epsilon for the running mean and variance. Defaults to 1e-05.

    Returns:
        (torch.Tensor): Group normalized input of shape `(B, C, *D, I)`.
    """

    # Check arguments.
    assert (running_mean is None and running_cov is None) or (running_mean is not None and running_cov is not None)
    assert (weight is None and bias is None) or (weight is not None and bias is not None)

    B, C, *D, I = x.shape
    assert num_groups <= C
    assert C % num_groups == 0, "Number of channels should be evenly divisible by the number of groups."
    assert I == n_blades
    if weight is not None and bias is not None:
        # Check if weight and bias tensors are of correct shape.
        assert weight.shape == (I, I, int(C / num_groups))
        assert bias.shape == (I, int(C / num_groups))
        weight = weight.repeat(1, 1, B)
        bias = bias.repeat(1, B)

    def _instance_norm(
        x,
        num_groups,
        running_mean,
        running_cov,
        weight,
        bias,
        training,
        momentum,
        eps,
    ):
        if running_mean is not None and running_cov is not None:
            assert running_mean.shape == (I, int(C / num_groups))
            running_mean_orig = running_mean
            running_mean = running_mean_orig.repeat(1, B)
            assert running_cov.shape == (I, I, int(C / num_groups))
            running_cov_orig = running_cov
            running_cov = running_cov_orig.repeat(1, 1, B)

        # Reshape such that batch normalization can be applied.
        # For num_groups == 1, it defaults to layer normalization,
        # for num_groups == C, it defaults to instance normalization.
        x_reshaped = x.reshape(1, int(B * C / num_groups), num_groups, *D, I)

        x_norm = clifford_batch_norm(
            x_reshaped,
            n_blades,
            running_mean,
            running_cov,
            weight,
            bias,
            training,
            momentum,
            eps,
        )

        # Reshape back running mean and running var.
        if running_mean is not None:
            running_mean_orig.copy_(running_mean.view(I, B, int(C / num_groups)).mean(1, keepdim=False))
        if running_cov is not None:
            # Average the repeated batch copies back over the batch dimension (dim 2).
            running_cov_orig.copy_(running_cov.view(I, I, B, int(C / num_groups)).mean(2, keepdim=False))

        return x_norm.view(B, C, *D, I)

    return _instance_norm(
        x,
        num_groups,
        running_mean,
        running_cov,
        weight,
        bias,
        training,
        momentum,
        eps,
    )
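
A minimal sketch of both regimes (import path `cliffordlayers.nn.functional.groupnorm` assumed from the source path above):

```python
import torch
from cliffordlayers.nn.functional.groupnorm import clifford_group_norm

x = torch.randn(4, 8, 32, 32, 4)                            # (B, C, *D, I)
x_layer = clifford_group_norm(x, n_blades=4, num_groups=1)  # layer-norm-like
x_inst = clifford_group_norm(x, n_blades=4, num_groups=8)   # instance-norm-like (num_groups == C)
print(x_layer.shape, x_inst.shape)
```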

complex_group_norm(x, num_groups=1, running_mean=None, running_cov=None, weight=None, bias=None, training=True, momentum=0.1, eps=1e-05)

Group normalization for complex-valued tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The input complex-valued data, expected to be at least 2d, with shape `(B, C, *D)`, where `B` is the batch dimension, `C` the channels/features, and `*D` the remaining dimensions (if present). | *required* |
| `num_groups` | `int` | Number of groups for which normalization is calculated. For `num_groups == 1`, it effectively applies complex-valued layer normalization; for `num_groups == C`, it effectively applies complex-valued instance normalization. | `1` |
| `running_mean` | `Tensor` | The tensor with running mean statistics having shape `(2, C / num_groups)`. | `None` |
| `running_cov` | `Tensor` | The tensor with running real-imaginary covariance statistics having shape `(2, 2, C / num_groups)`. | `None` |
| `weight` | `Union[Tensor, Parameter]` | Additional weight tensor which is applied post normalization, and has the shape `(2, 2, C / num_groups)`. | `None` |
| `bias` | `Union[Tensor, Parameter]` | Additional bias tensor which is applied post normalization, and has the shape `(2, C / num_groups)`. | `None` |
| `training` | `bool` | Whether to use the running mean and variance. | `True` |
| `momentum` | `float` | Momentum for the running mean and variance. | `0.1` |
| `eps` | `float` | Epsilon for the running mean and variance. | `1e-05` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Normalized input as complex tensor of shape `(B, C, *D)`. |

Source code in cliffordlayers/nn/functional/groupnorm.py
def complex_group_norm(
    x: torch.Tensor,
    num_groups: int = 1,
    running_mean: Optional[torch.Tensor] = None,
    running_cov: Optional[torch.Tensor] = None,
    weight: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    bias: Optional[Union[torch.Tensor, nn.Parameter]] = None,
    training: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-05,
):
    """Group normalization for complex-valued tensors.

    Args:
        x (torch.Tensor): The input complex-valued data is expected to be at least 2d, with
                          shape `(B, C, *D)`, where `B` is the batch dimension, `C` the
                          channels/features, and *D the remaining dimensions (if present).

        num_groups (int): Number of groups for which normalization is calculated. Defaults to 1.
                          For `num_groups == 1`, it effectively applies complex-valued layer normalization;
                          for `num_groups == C`, it effectively applies complex-valued instance normalization.

        running_mean (torch.Tensor, optional): The tensor with running mean statistics having shape `(2, C / num_groups)`. Defaults to None.
        running_cov (torch.Tensor, optional): The tensor with running real-imaginary covariance statistics having shape `(2, 2, C / num_groups)`. Defaults to None.

        weight (Union[torch.Tensor, nn.Parameter], optional): Additional weight tensor which is applied post normalization, and has the shape `(2, 2, C / num_groups)`. Defaults to None.

        bias (Union[torch.Tensor, nn.Parameter], optional): Additional bias tensor which is applied post normalization, and has the shape `(2, C / num_groups)`. Defaults to None.

        training (bool, optional): Whether to use the running mean and variance. Defaults to True.
        momentum (float, optional): Momentum for the running mean and variance. Defaults to 0.1.
        eps (float, optional): Epsilon for the running mean and variance. Defaults to 1e-05.

    Returns:
        (torch.Tensor): Normalized input as complex tensor of shape `(B, C, *D)`.
    """

    # Check arguments.
    assert (running_mean is None and running_cov is None) or (running_mean is not None and running_cov is not None)
    assert (weight is None and bias is None) or (weight is not None and bias is not None)

    B, C, *D = x.shape
    assert C % num_groups == 0, "Number of channels should be evenly divisible by the number of groups."
    assert num_groups <= C
    if weight is not None and bias is not None:
        # Check if weight and bias tensors are of correct shape.
        assert weight.shape == (2, 2, int(C / num_groups))
        assert bias.shape == (2, int(C / num_groups))
        weight = weight.repeat(1, 1, B)
        bias = bias.repeat(1, B)

    def _instance_norm(
        x,
        num_groups,
        running_mean,
        running_cov,
        weight,
        bias,
        training,
        momentum,
        eps,
    ):
        if running_mean is not None and running_cov is not None:
            assert running_mean.shape == (2, int(C / num_groups))
            running_mean_orig = running_mean
            running_mean = running_mean_orig.repeat(1, B)
            assert running_cov.shape == (2, 2, int(C / num_groups))
            running_cov_orig = running_cov
            running_cov = running_cov_orig.repeat(1, 1, B)

        # Reshape such that batch normalization can be applied.
        # For num_groups == 1, it defaults to layer normalization,
        # for num_groups == C, it defaults to instance normalization.
        x_reshaped = x.view(1, int(B * C / num_groups), num_groups, *D)

        x_norm = complex_batch_norm(
            x_reshaped,
            running_mean,
            running_cov,
            weight=weight,
            bias=bias,
            training=training,
            momentum=momentum,
            eps=eps,
        )

        # Reshape back running mean and running var.
        if running_mean is not None:
            running_mean_orig.copy_(running_mean.view(2, B, int(C / num_groups)).mean(1, keepdim=False))
        if running_cov is not None:
            running_cov_orig.copy_(running_cov.view(2, 2, B, int(C / num_groups)).mean(2, keepdim=False))

        return x_norm.view(B, C, *D)

    return _instance_norm(
        x,
        num_groups,
        running_mean,
        running_cov,
        weight=weight,
        bias=bias,
        training=training,
        momentum=momentum,
        eps=eps,
    )
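
A short sketch on complex input (import path assumed from the source path above):

```python
import torch
from cliffordlayers.nn.functional.groupnorm import complex_group_norm

x = torch.randn(4, 8, 32, 32, dtype=torch.cfloat)   # (B, C, *D) complex input
x_layer = complex_group_norm(x, num_groups=1)       # complex layer normalization
x_inst = complex_group_norm(x, num_groups=8)        # complex instance normalization (num_groups == C)
print(x_layer.shape)                                # torch.Size([4, 8, 32, 32])
```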