# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
r"""Parameter transformation components for SciStanPy models.
This module provides a comprehensive library of mathematical transformations that
can be applied to model parameters. These transformations enable complex model
construction through composition of simple mathematical operations while maintaining
automatic differentiation capabilities and Stan code generation.
"""
from __future__ import annotations
from abc import abstractmethod
from typing import Callable, Optional, overload, TYPE_CHECKING, Union
import numpy as np
import numpy.typing as npt
import scipy.special as sp
import torch
import torch.nn.functional as F
from scistanpy import utils
from scistanpy.model.components import abstract_model_component, constants
if TYPE_CHECKING:
from scistanpy import custom_types
# pylint: disable=too-many-lines, line-too-long
# Basic arithmetic operations
class AddParameter(BinaryTransformedParameter):
"""Addition transformation for two parameters.
Implements element-wise addition of two parameters: ``result = dist1 + dist2``
Example:
.. code-block:: python
# Binary operations automatically handle broadcasting
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
y = ssp.parameters.Normal(mu=0, sigma=1, shape=(3, 1))
# Result has shape (3, 5) through broadcasting
combined = x + y
"""
STAN_OPERATOR: str = "+"
def run_np_torch_op(self, dist1, dist2):
"""Perform element-wise addition.
:param dist1: First addend
:param dist2: Second addend
:returns: Sum of the two inputs
"""
return dist1 + dist2
class SubtractParameter(BinaryTransformedParameter):
"""Subtraction transformation for two parameters.
Implements element-wise subtraction of two parameters: ``result = dist1 - dist2``
Example:
.. code-block:: python
# Binary operations automatically handle broadcasting
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
y = ssp.parameters.Normal(mu=0, sigma=1, shape=(3, 1))
# Result has shape (3, 5) through broadcasting
difference = x - y
"""
STAN_OPERATOR: str = "-"
def run_np_torch_op(self, dist1, dist2):
"""Perform element-wise subtraction.
:param dist1: Minuend
:param dist2: Subtrahend
:returns: Difference of the two inputs
"""
return dist1 - dist2
class MultiplyParameter(BinaryTransformedParameter):
"""Element-wise multiplication transformation for two parameters.
Implements element-wise multiplication of two parameters: ``result = dist1 * dist2``
This transformation is fundamental for scaling, interaction effects, and
product relationships between parameters.
Example:
.. code-block:: python
# Binary operations automatically handle broadcasting
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
y = ssp.parameters.Normal(mu=0, sigma=1, shape=(3, 1))
# Result has shape (3, 5) through broadcasting
product = x * y
"""
STAN_OPERATOR: str = ".*"
def run_np_torch_op(self, dist1, dist2):
"""Perform element-wise multiplication.
:param dist1: First factor
:param dist2: Second factor
:returns: Product of the two inputs
"""
return dist1 * dist2
class DivideParameter(BinaryTransformedParameter):
"""Element-wise division transformation for two parameters.
Implements element-wise division of two parameters: ``result = dist1 / dist2``
This transformation is used for ratios, rates, normalized quantities,
and relative measures. Care must be taken to avoid division by zero.
Example:
.. code-block:: python
# Binary operations automatically handle broadcasting
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
y = ssp.parameters.Normal(mu=1, sigma=0.5, shape=(3, 1))
# Result has shape (3, 5) through broadcasting
ratio = x / y
"""
STAN_OPERATOR: str = "./"
def run_np_torch_op(self, dist1, dist2):
"""Perform element-wise division.
:param dist1: Dividend
:param dist2: Divisor
:returns: Quotient of the two inputs
"""
return dist1 / dist2
class PowerParameter(BinaryTransformedParameter):
"""Element-wise exponentiation transformation for two parameters.
Implements element-wise exponentiation: ``result = dist1 ^ dist2``
This transformation is used for power relationships, polynomial terms,
and exponential scaling effects.
Example:
.. code-block:: python
# Binary operations automatically handle broadcasting
x = ssp.parameters.Normal(mu=2, sigma=1, shape=(5,))
y = ssp.parameters.Normal(mu=3, sigma=0.5, shape=(3, 1))
# Result has shape (3, 5) through broadcasting
exponentiated = x ** y
"""
STAN_OPERATOR: str = ".^"
def run_np_torch_op(self, dist1, dist2):
"""Perform element-wise exponentiation.
:param dist1: Base
:param dist2: Exponent
:returns: dist1 raised to the power of dist2
"""
return dist1**dist2
class NegateParameter(UnaryTransformedParameter):
"""Unary negation transformation for parameters.
Implements element-wise negation of a parameter: ``result = -dist1``
Example:
.. code-block:: python
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
negated = -x
"""
STAN_OPERATOR: str = "-"
def run_np_torch_op(self, dist1):
"""Perform element-wise negation.
:param dist1: Input parameter
:returns: Negated input
"""
return -dist1
class AbsParameter(UnaryTransformedParameter):
r"""Absolute value transformation for parameters.
Implements element-wise absolute value: ``result = |dist1|``
This transformation ensures non-negative values and is commonly used
for magnitudes, distances, and ensuring positive parameters.
This transformation is accessed through the :py:func:`~scistanpy.operations.abs_`
function.
Example:
.. code-block:: python
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
abs_x = ssp.operations.abs_(x) # Effectively half-normal distribution
"""
LOWER_BOUND: "custom_types.Float" = 0.0
def run_np_torch_op(self, dist1):
"""Compute element-wise absolute value.
:param dist1: Input parameter
:returns: Absolute value of input
"""
return utils.choose_module(dist1).abs(dist1)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan absolute value function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan ``abs()`` function call
:rtype: str
"""
return f"abs({dist1})"
class LogParameter(UnaryTransformedParameter):
"""Natural logarithm transformation for parameters.
Implements element-wise natural logarithm: ``result = ln(dist1)``
This transformation is fundamental for log-scale modeling, multiplicative
effects on additive scales, and ensuring positive-valued parameters.
This transformation is accessed through the :py:func:`~scistanpy.operations.log`
function.
Example:
.. code-block:: python
x = ssp.parameters.LogNormal(mu=0, sigma=1, shape=(5,))
log_x = ssp.operations.log(x) # Normally distributed
"""
POSITIVE_PARAMS = {"dist1"}
def run_np_torch_op(self, dist1):
"""Compute element-wise natural logarithm.
:param dist1: Input parameter (must be positive)
:returns: Natural logarithm of input
"""
return utils.choose_module(dist1).log(dist1)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan logarithm function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan ``log()`` function call
:rtype: str
"""
return f"log({dist1})"
class ExpParameter(UnaryTransformedParameter):
"""Exponential transformation for parameters.
Implements element-wise exponential function: ``result = exp(dist1)``
This transformation is used for ensuring positive values, exponential
growth models, and converting from log-scale to natural scale.
This transformation is accessed through the :py:func:`~scistanpy.operations.exp`
function.
Example:
.. code-block:: python
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
exp_x = ssp.operations.exp(x) # Log-normally distributed
"""
LOWER_BOUND: "custom_types.Float" = 0.0
def run_np_torch_op(self, dist1):
"""Compute element-wise exponential.
:param dist1: Input parameter
:returns: Exponential of input
"""
return utils.choose_module(dist1).exp(dist1)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan exponential function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan exp() function call
:rtype: str
"""
return f"exp({dist1})"
class NormalizeParameter(UnaryTransformedParameter):
"""Normalization transformation that scales values to sum to 1 in the last dimension
of the input.
Implements element-wise normalization where each vector is divided by its sum,
creating probability vectors or normalized weights that sum to unity.
This transformation is essential for creating probability vectors from non-negative
weights.
This transformation is accessed through the :py:func:`~scistanpy.operations.normalize`
function.
.. important::
The normalization is applied along the last dimension only. This cannot
be changed to other dimensions in the current implementation.
Example:
.. code-block:: python
x = ssp.parameters.Exponential(rate=1.0, shape=(4, 5))
# Normalize along last dimension (size 5)
normalized_x = ssp.operations.normalize(x)
# Each of the 4 vectors of length 5 sums to 1
"""
LOWER_BOUND: "custom_types.Float" = 0.0
UPPER_BOUND: "custom_types.Float" = 1.0
def run_np_torch_op(self, dist1):
"""Compute normalization by dividing by sum along last dimension.
:param dist1: Input parameter (must be non-negative)
:returns: Normalized values that sum to 1 along last dimension
"""
if isinstance(dist1, np.ndarray):
return dist1 / np.sum(dist1, keepdims=True, axis=-1)
elif isinstance(dist1, torch.Tensor):
return dist1 / dist1.sum(dim=-1, keepdim=True)
# Error if the type is not supported
else:
raise TypeError(
"Unsupported type for dist1. Expected torch.Tensor or np.ndarray."
)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan normalization code.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan code dividing by sum
:rtype: str
"""
return f"{dist1} / sum({dist1})"
class NormalizeLogParameter(UnaryTransformedParameter):
"""Log-space normalization transformation for log-probability vectors over the
last dimension of the input.
Implements normalization in log-space where log-probabilities are adjusted
so that their exponentiated values sum to 1. This is equivalent to
subtracting the log-sum-exp from each element.
This transformation is crucial for:
- Normalizing log-probabilities without exponentiation
- Stable computation with very small probabilities
- Log-space categorical distributions
The log-sum-exp operation provides numerical stability by avoiding
overflow/underflow issues common with direct exponentiation.
This transformation is accessed through the :py:func:`~scistanpy.operations.normalize_log`
function.
.. important::
As with
:py:class:`~scistanpy.model.components.transformations.transformed_parameters.NormalizeParameter`,
this is performed over the last dimension only.
Example:
.. code-block:: python
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(4, 5))
log_probs = ssp.operations.normalize_log(x)
# Each of the exponentials of the 4 vectors of length 5 sums to 1 in
# probability space
"""
UPPER_BOUND: "custom_types.Float" = 0.0
def run_np_torch_op(self, dist1):
"""Compute log-space normalization using log-sum-exp.
:param dist1: Input log-probabilities
:returns: Normalized log-probabilities
"""
if isinstance(dist1, torch.Tensor):
return dist1 - torch.logsumexp(dist1, dim=-1, keepdim=True)
elif isinstance(dist1, np.ndarray):
return dist1 - sp.logsumexp(dist1, keepdims=True, axis=-1)
else:
raise TypeError(
"Unsupported type for dist1. Expected torch.Tensor or np.ndarray."
)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan log-space normalization code.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan code using log_sum_exp function
:rtype: str
"""
return f"{dist1} - log_sum_exp({dist1})"
class Reduction(UnaryTransformedParameter):
"""Base class for operations that reduce dimensionality.
This abstract base class provides infrastructure for transformations that
reduce the size of the last dimension through operations like ``sum`` or ``log-sum-exp``.
It handles shape management and provides specialized indexing behavior for
reductions.
:param dist1: Parameter to reduce
:type dist1: custom_types.CombinableParameterType
:param keepdims: Whether to keep the reduced dimension as size 1. Defaults to False.
:type keepdims: bool
:param kwargs: Additional arguments passed to parent class
.. important::
Currently, all reductions in SciStanPy are performed along the last dimension
of the input parameter. This cannot be changed to other dimensions in the
current implementation.
"""
SHAPE_CHECK = False
TORCH_FUNC: Callable[[npt.NDArray], npt.NDArray]
"""The PyTorch function to use for the reduction operation."""
NP_FUNC: Callable[[npt.NDArray], npt.NDArray]
"""The NumPy function to use for the reduction operation."""
def __init__(
self,
dist1: "custom_types.CombinableParameterType",
keepdims: bool = False,
**kwargs,
):
"""Initialize reduction with automatic shape calculation. Reduction is always
over the last dimension.
:param dist1: Parameter to reduce
:type dist1: custom_types.CombinableParameterType
:param keepdims: Whether to keep reduced dimension. Defaults to False.
:type keepdims: bool
:param kwargs: Additional arguments for parent initialization
The initialization automatically calculates the output shape by
removing the last dimension (or keeping it as size 1 if ``keepdims=True``).
"""
# Record whether to keep the last dimension
self.keepdims = keepdims
# The shape is the leading dimensions of the input parameter plus a singleton
# dimension if keepdims is True.
if "shape" not in kwargs:
shape = dist1.shape[:-1]
if keepdims:
shape += (1,)
kwargs["shape"] = shape
# Init as normal
super().__init__(dist1=dist1, **kwargs)
def run_np_torch_op(self, dist1, keepdim: bool | None = None):
"""Execute reduction operation with backend-appropriate function.
:param dist1: Input parameter to reduce
:param keepdim: Whether to keep dimensions (static method only)
:type keepdim: Optional[bool]
:returns: Reduced parameter values
:raises ValueError: If keepdim is provided for instance method calls
The method automatically selects between PyTorch and NumPy reduction
functions and applies them along the last dimension.
"""
# Keepdim can only be provided if called as a static method
if self is None:
keepdim = bool(keepdim)
elif keepdim is not None:
raise ValueError(
"The `keepdim` argument can only be provided when calling this method "
"as a static method."
)
else:
keepdim = self.keepdims
if isinstance(dist1, torch.Tensor):
return self.__class__.TORCH_FUNC(dist1, keepdim=keepdim, dim=-1)
else:
return self.__class__.NP_FUNC(dist1, keepdims=keepdim, axis=-1)
def get_index_offset(
self,
query: Union[str, "abstract_model_component.AbstractModelComponent"],
offset_adjustment: int = 0,
) -> int:
"""Return zero offset for all reduction operations.
:param query: Component or parameter name (ignored)
:param offset_adjustment: Offset adjustment (ignored)
:returns: Always returns 0
:rtype: int
Reductions always return zero offset because they operate on the
last dimension and don't require complex indexing adjustments.
"""
return 0
def get_assign_depth(self) -> int:
"""Reductions require one additional level of loop nesting to properly
iterate over the dimension being reduced. Thus, the assignment depth
is always one greater than the parent class.
"""
return super().get_assign_depth() + 1
class LogSumExpParameter(Reduction):
"""Log-sum-exp reduction transformation.
Computes the logarithm of the sum of exponentials along the last dimension.
This transformation is fundamental for:
- Normalizing log-probabilities
- Computing partition functions
- Stable softmax computations
- Log-space mixture models
This transformation is accessed through the :py:func:`~scistanpy.operations.logsumexp`
function.
.. important::
The log-sum-exp is performed over the last dimension only. This cannot
be changed to other dimensions in the current implementation.
Example:
.. code-block:: python
weights = ssp.parameters.Exponential(rate=1.0, shape=(10, 5))
log_weights = ssp.operations.log(weights) # Log-space weights
log_partition = ssp.operations.logsumexp(log_weights) # Shape (10,)
log_partition2 = ssp.operations.logsumexp(log_weights, keepdims=True) # Shape (10, 1)
"""
TORCH_FUNC = torch.logsumexp
NP_FUNC = sp.logsumexp
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan log_sum_exp function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan log_sum_exp function call
:rtype: str
"""
return f"log_sum_exp({dist1})"
class SumParameter(Reduction):
"""Sum reduction transformation.
Computes the sum of values along the last dimension. This is accessed through
the :py:func:`~scistanpy.operations.sum` function.
Example:
.. code-block:: python
rates = ssp.parameters.Exponential(rate=1.0, shape=(10, 5))
summed = ssp.operations.sum(rates, keepdims=False) # Shape (10,)
summed2 = ssp.operations.sum(rates, keepdims=True) # Shape (10, 1)
"""
TORCH_FUNC = torch.sum
NP_FUNC = np.sum
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan sum function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan sum function call
:rtype: str
"""
return f"sum({dist1})"
class Log1pExpParameter(UnaryTransformedParameter):
r"""Calculates :math:`\log(1 + \exp(x))` transformation in a numerically stable
way.
This transformation is accessed through the :py:func:`~scistanpy.operations.log1p_exp`
function.
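Example (an illustrative sketch following the conventions of the other
examples in this module; ``ssp`` is assumed to be the imported SciStanPy
package):
.. code-block:: python
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
softplus_x = ssp.operations.log1p_exp(x)  # Values in (0, inf)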
"""
def run_np_torch_op(self, dist1):
"""Compute log(1 + exp(x)) with numerical stability.
:param dist1: Input parameter
:returns: log(1 + exp(dist1))
Uses logaddexp(0, x) for numerical stability, which handles
both overflow (large positive x) and underflow (large negative x).
"""
# If using torch, use the logaddexp function directly.
if isinstance(dist1, torch.Tensor):
return torch.logaddexp(torch.tensor(0.0, device=dist1.device), dist1)
# If using numpy, use the logaddexp function from scipy.
elif isinstance(dist1, np.ndarray):
return np.logaddexp(0.0, dist1)
# Error if the type is not supported
else:
raise TypeError(
"Unsupported type for dist1. Expected torch.Tensor or np.ndarray."
)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan log1p_exp function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan ``log1p_exp`` function call
:rtype: str
"""
return f"log1p_exp({dist1})"
class SigmoidParameter(UnaryTransformedParameter):
r"""Sigmoid (logistic) transformation for parameters.
Implements the sigmoid function: :math:`\text{result} = \frac{1}{1 + \exp(-\text{dist1})}`
The sigmoid function is essential for:
- Converting unbounded values to probabilities
- Logistic regression and classification
- Smooth transitions between bounds
- Activation functions in neural networks
This implementation uses numerically stable computation methods to
avoid overflow/underflow issues.
This is accessed through the :py:func:`~scistanpy.operations.sigmoid` function.
.. warning::
It is easy to define a model that is not identifiable when using
the sigmoid transformation. Make sure that there are sufficient constraints
on the input parameter to ensure a well-defined posterior distribution.
Example:
>>> logits = Normal(mu=0, sigma=1)
>>> probabilities = SigmoidParameter(logits)
"""
UPPER_BOUND: "custom_types.Float" = 1.0
LOWER_BOUND: "custom_types.Float" = 0.0
def run_np_torch_op(self, dist1):
"""Compute sigmoid function with numerical stability.
:param dist1: Input parameter (logits)
:returns: Sigmoid-transformed values in (0, 1)
:rtype: Union[torch.Tensor, np.ndarray]
:raises TypeError: If input type is not supported
Uses numerically stable implementations:
- PyTorch: Built-in torch.sigmoid function
- NumPy: Custom stable implementation from utils
"""
# If using torch, use the sigmoid function directly.
if isinstance(dist1, torch.Tensor):
return torch.sigmoid(dist1)
# If using numpy, we manually calculate the sigmoid function using a more
# numerically stable approach.
elif isinstance(dist1, np.ndarray):
return utils.stable_sigmoid(dist1)
# If using a different type, raise an error.
else:
raise TypeError(
"Unsupported type for dist1. Expected torch.Tensor or np.ndarray."
)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan inverse logit function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan ``inv_logit()`` function call
:rtype: str
"""
return f"inv_logit({dist1})"
class LogSigmoidParameter(UnaryTransformedParameter):
r"""Performs the sigmoid transformation followed by the natural logarithm in
a numerically stable manner: :math:`\log(\text{sigmoid}(x)) = -\log(1 + \exp(-x))`.
This transformation is useful for converting unbounded values to log-probabilities
in the range (-∞, 0). It is commonly used in logistic regression, binary
classification, and scenarios where log-probabilities are required.
This transformation is accessed through the :py:func:`~scistanpy.operations.logsigmoid`
function.
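Example (an illustrative sketch, assuming ``ssp`` is the imported SciStanPy
package as in the other examples in this module):
.. code-block:: python
x = ssp.parameters.Normal(mu=0, sigma=1, shape=(5,))
log_probs = ssp.operations.logsigmoid(x)  # Values in (-inf, 0)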
"""
UPPER_BOUND: "custom_types.Float" = 0.0
def run_np_torch_op(self, dist1):
"""Compute the log-sigmoid function with numerical stability.
:param dist1: Input parameter (logits)
:returns: Log-sigmoid-transformed values in (-inf, 0)
:raises TypeError: If input type is not supported
"""
if isinstance(dist1, torch.Tensor):
return F.logsigmoid(dist1)  # pylint: disable=not-callable
elif isinstance(dist1, np.ndarray):
return np.log(utils.stable_sigmoid(dist1))
else:
raise TypeError(
"Unsupported type for dist1. Expected torch.Tensor or np.ndarray."
)
def write_stan_operation(self, dist1: str) -> str:
"""Generate Stan log inverse logit function call.
:param dist1: Formatted parameter string
:type dist1: str
:returns: Stan ``log_inv_logit()`` function call
:rtype: str
"""
return f"log_inv_logit({dist1})"
class ExponentialGrowth(ExpParameter):
r"""
A transformed parameter that models exponential growth. Specifically, parameters
``t``, ``A``, and ``r`` are used to calculate the exponential growth model as follows:
.. math::
x = A \text{e}^{rt}
This transformation is accessed through the :py:func:`~scistanpy.operations.exponential_growth`
function.
Example:
>>> time = np.array([[0], [1], [2], [3], [4]])  # Shape (5, 1)
>>> baseline = Dirichlet(alpha=1.0, shape=(10,))
>>> rate = Exponential(beta=1.0)
>>> growth = ssp.operations.exponential_growth(t=time, A=baseline, r=rate)
"""
def __init__(
self,
*,
t: "custom_types.CombinableParameterType",
A: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize exponential growth model.
:param t: Time parameter
:param A: Amplitude parameter
:param r: Rate parameter
:param kwargs: Additional arguments
"""
super(UnaryTransformedParameter, self).__init__(t=t, A=A, r=r, **kwargs)
# pylint: disable=arguments-differ
@overload
def run_np_torch_op(
self, t: torch.Tensor, A: torch.Tensor, r: torch.Tensor
) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
t: "custom_types.SampleType",
A: "custom_types.SampleType",
r: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, *, t, A, r):
r"""Compute exponential growth: :math:`A\text{e}^{rt}`.
:param t: Time values
:param A: Amplitude values
:param r: Rate values
:returns: Exponential growth values
"""
return A * ExpParameter.run_np_torch_op(self, r * t)
# pylint: enable=arguments-differ
def write_stan_operation( # pylint: disable=arguments-differ
self, t: str, A: str, r: str
) -> str:
"""Generate Stan code for exponential growth.
:param t: Formatted time parameter
:param A: Formatted amplitude parameter
:param r: Formatted rate parameter
:returns: Stan code for ``A .* exp(r .* t)``
:rtype: str
"""
par_string = super().write_stan_operation(f"{r} .* {t}")
return f"{A} .* {par_string}"
class BinaryExponentialGrowth(ExpParameter):
r"""Binary exponential growth for two time points, which is a special case of
exponential growth for modeling with only two time points assuming
:math:`t_0 = 0` and :math:`t_1 = 1`. This reduces to
.. math::
x = A\text{e}^r
This transformation is accessed through the
:py:func:`~scistanpy.operations.binary_exponential_growth` function.
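Example (an illustrative sketch, assuming ``ssp`` is the imported SciStanPy
package; the ``Exponential`` rate keyword follows the example for
:py:class:`ExponentialGrowth` above):
.. code-block:: python
A = ssp.parameters.Exponential(beta=1.0)
r = ssp.parameters.Normal(mu=0, sigma=1)
x1 = ssp.operations.binary_exponential_growth(A=A, r=r)  # A * exp(r)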
"""
def __init__(
self,
A: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize binary exponential growth.
:param A: Amplitude parameter
:param r: Rate parameter
:param kwargs: Additional arguments
"""
super(UnaryTransformedParameter, self).__init__(A=A, r=r, **kwargs)
# pylint: disable=arguments-differ
@overload
def run_np_torch_op(self, A: torch.Tensor, r: torch.Tensor) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
A: "custom_types.SampleType",
r: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, *, A, r):
"""Compute binary exponential growth: A * exp(r).
:param A: Amplitude values
:param r: Rate values
:returns: A * exp(r)
"""
return A * ExpParameter.run_np_torch_op(self, r)
def write_stan_operation(self, A: str, r: str) -> str:
"""Generate Stan code for binary exponential growth.
:param A: Formatted amplitude parameter
:type A: str
:param r: Formatted rate parameter
:type r: str
:returns: Stan code for A .* exp(r)
:rtype: str
"""
return f"{A} .* {super().write_stan_operation(r)}"
class LogExponentialGrowth(TransformedParameter):
r"""Log-scale exponential growth model transformation.
Implements the logarithm of exponential growth:
.. math::
\log(x) = \log A + rt
:param t: Time parameter
:type t: custom_types.CombinableParameterType
:param log_A: Log-amplitude parameter (log of initial value)
:type log_A: custom_types.CombinableParameterType
:param r: Growth rate parameter
:type r: custom_types.CombinableParameterType
:param kwargs: Additional arguments passed to parent class
This transformation is particularly useful for:
- Population modeling where values must be positive
- Multiplicative growth processes
- Log-scale regression models
- Ensuring positive-valued outcomes
The log-scale parameterization avoids issues with negative values
and provides numerical stability for extreme growth rates.
This transformation is accessed through the
:py:func:`~scistanpy.operations.log_exponential_growth` function.
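Example (an illustrative sketch, assuming ``ssp`` is the imported SciStanPy
package as in the other examples in this module):
.. code-block:: python
time = np.array([[0], [1], [2], [3], [4]])  # Shape (5, 1)
log_A = ssp.parameters.Normal(mu=0, sigma=1)
rate = ssp.parameters.Normal(mu=0, sigma=1)
log_growth = ssp.operations.log_exponential_growth(t=time, log_A=log_A, r=rate)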
"""
def __init__(
self,
*,
t: "custom_types.CombinableParameterType",
log_A: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize log-exponential growth model.
:param t: Time parameter
:param log_A: Log-amplitude parameter
:param r: Rate parameter
:param kwargs: Additional arguments
"""
super().__init__(t=t, log_A=log_A, r=r, **kwargs)
# pylint: disable=arguments-differ
@overload
def run_np_torch_op(
self, t: torch.Tensor, log_A: torch.Tensor, r: torch.Tensor
) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
t: "custom_types.SampleType",
log_A: "custom_types.SampleType",
r: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, *, t, log_A, r):
"""Compute log-exponential growth: log_A + r * t.
:param t: Time values
:param log_A: Log-amplitude values
:param r: Rate values
:returns: Log-exponential growth values
"""
return log_A + r * t
# pylint: enable=arguments-differ
def write_stan_operation( # pylint: disable=arguments-differ
self, t: str, log_A: str, r: str
) -> str:
"""Generate Stan code for log-exponential growth.
:param t: Formatted time parameter
:param log_A: Formatted log-amplitude parameter
:param r: Formatted rate parameter
:returns: Stan code for ``log_A + r .* t``
:rtype: str
"""
return f"{log_A} + {r} .* {t}"
class BinaryLogExponentialGrowth(TransformedParameter):
r"""Binary log-exponential growth for two time points.
This is identical to
:py:class:`~scistanpy.model.components.transformations.transformed_parameters.BinaryExponentialGrowth`,
but operates in log-space. Mathematically,
.. math::
\log(x) = \log A + r
:param log_A: Log-amplitude parameter (log of initial value)
:type log_A: custom_types.CombinableParameterType
:param r: Growth rate parameter
:type r: custom_types.CombinableParameterType
:param kwargs: Additional arguments passed to parent class
This transformation is useful for:
- Modeling growth between two time points
- Ensuring positive-valued outcomes
- Log-scale regression models
- Multiplicative growth processes
The log-scale parameterization avoids issues with negative values
and provides numerical stability for extreme growth rates.
This transformation is accessed through the
:py:func:`~scistanpy.operations.binary_log_exponential_growth` function.
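Example (an illustrative sketch, assuming ``ssp`` is the imported SciStanPy
package as in the other examples in this module):
.. code-block:: python
log_A = ssp.parameters.Normal(mu=0, sigma=1)
r = ssp.parameters.Normal(mu=0, sigma=1)
log_x1 = ssp.operations.binary_log_exponential_growth(log_A=log_A, r=r)  # log_A + r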
"""
def __init__(
self,
log_A: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize binary log-exponential growth.
:param log_A: Log-amplitude parameter
:param r: Rate parameter
:param kwargs: Additional arguments
"""
super().__init__(log_A=log_A, r=r, **kwargs)
# pylint: disable=arguments-differ
@overload
def run_np_torch_op(self, log_A: torch.Tensor, r: torch.Tensor) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
log_A: "custom_types.SampleType",
r: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, *, log_A, r):
"""Compute binary log-exponential growth: log_A + r.
:param log_A: Log-amplitude values
:param r: Rate values
:returns: log_A + r
"""
return log_A + r
def write_stan_operation(self, log_A: str, r: str) -> str:
"""Generate Stan code for binary log-exponential growth.
:param log_A: Formatted log-amplitude parameter
:type log_A: str
:param r: Formatted rate parameter
:type r: str
:returns: Stan code for log_A + r
:rtype: str
"""
return f"{log_A} + {r}"
class SigmoidGrowth(SigmoidParameter):
r"""Sigmoid growth model transformation.
Implements sigmoid growth:
.. math::
x = \frac{A}{1 + \exp(-r(t - c))}
:param t: Time parameter
:type t: custom_types.CombinableParameterType
:param A: Amplitude parameter (carrying capacity)
:type A: custom_types.CombinableParameterType
:param r: Growth rate parameter
:type r: custom_types.CombinableParameterType
:param c: Inflection point parameter (time of fastest growth)
:type c: custom_types.CombinableParameterType
:param kwargs: Additional arguments passed to parent class
This transformation is essential for:
- Population growth with carrying capacity
- Any growth process with saturation
This transformation is accessed through the
:py:func:`~scistanpy.operations.sigmoid_growth` function.
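Example (an illustrative sketch, assuming ``ssp`` is the imported SciStanPy
package as in the other examples in this module):
.. code-block:: python
time = np.array([[0], [1], [2], [3], [4]])  # Shape (5, 1)
capacity = ssp.parameters.Exponential(beta=1.0)
rate = ssp.parameters.Exponential(beta=1.0)
midpoint = ssp.parameters.Normal(mu=2, sigma=1)
growth = ssp.operations.sigmoid_growth(t=time, A=capacity, r=rate, c=midpoint)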
"""
LOWER_BOUND: "custom_types.Float" = 0.0
UPPER_BOUND: None = None
def __init__(
self,
*,
t: "custom_types.CombinableParameterType",
A: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
c: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize sigmoid growth model.
:param t: Time parameter
:param A: Amplitude/carrying capacity parameter
:param r: Rate parameter
:param c: Inflection point parameter
:param kwargs: Additional arguments
"""
super(UnaryTransformedParameter, self).__init__(t=t, A=A, r=r, c=c, **kwargs)
# pylint: disable=arguments-differ
@overload
def run_np_torch_op(
self, t: torch.Tensor, A: torch.Tensor, r: torch.Tensor, c: torch.Tensor
) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
t: "custom_types.SampleType",
A: "custom_types.SampleType",
r: "custom_types.SampleType",
c: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, *, t, A, r, c):
"""Compute sigmoid growth: A * sigmoid(r * (t - c)).
:param t: Time values
:param A: Amplitude values
:param r: Rate values
:param c: Inflection point values
:returns: Sigmoid growth values
"""
return A * SigmoidParameter.run_np_torch_op(self, r * (t - c))
# pylint: enable=arguments-differ
def write_stan_operation( # pylint: disable=arguments-differ
self, t: str, A: str, r: str, c: str
) -> str:
"""Generate Stan code for sigmoid growth.
:param t: Formatted time parameter
:param A: Formatted amplitude parameter
:param r: Formatted rate parameter
:param c: Formatted inflection point parameter
:returns: Stan code for ``A .* inv_logit(r .* (t - c))``
:rtype: str
"""
par_string = super().write_stan_operation(f"{r} .* ({t} - {c})")
return f"{A} .* {par_string}"
class LogSigmoidGrowth(LogSigmoidParameter):
r"""Log-scale sigmoid growth model transformation.
Implements the logarithm of sigmoid growth:
.. math::
\log(x) = \log A - \log(1 + \exp(-r(t - c)))
Under the hood, the numerically stable log-sigmoid is used to evaluate the
:math:`-\log(1 + \exp(-r(t - c)))` term.
:param t: Time parameter
:type t: custom_types.CombinableParameterType
:param log_A: Log-amplitude parameter (log of carrying capacity)
:type log_A: custom_types.CombinableParameterType
:param r: Growth rate parameter
:type r: custom_types.CombinableParameterType
:param c: Inflection point parameter
:type c: custom_types.CombinableParameterType
:param kwargs: Additional arguments passed to parent class
This parameterization is ideal for:
- Extreme parameter regimes
- Log-scale statistical modeling
- When initial conditions are naturally in log-space
- Maximum numerical precision requirements
This transformation is accessed through the
:py:func:`~scistanpy.operations.log_sigmoid_growth` function.
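Example (an illustrative sketch, assuming ``ssp`` is the imported SciStanPy
package as in the other examples in this module):
.. code-block:: python
time = np.array([[0], [1], [2], [3], [4]])  # Shape (5, 1)
log_A = ssp.parameters.Normal(mu=0, sigma=1)
rate = ssp.parameters.Exponential(beta=1.0)
midpoint = ssp.parameters.Normal(mu=2, sigma=1)
log_growth = ssp.operations.log_sigmoid_growth(t=time, log_A=log_A, r=rate, c=midpoint)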
"""
LOWER_BOUND: None = None
UPPER_BOUND: None = None
def __init__(
self,
*,
t: "custom_types.CombinableParameterType",
log_A: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
c: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize log-sigmoid growth model.
:param t: Time parameter
:param log_A: Log-amplitude parameter
:param r: Rate parameter
:param c: Inflection point parameter
:param kwargs: Additional arguments
"""
super(UnaryTransformedParameter, self).__init__(
t=t, log_A=log_A, r=r, c=c, **kwargs
)
# pylint: disable=arguments-differ
@overload
def run_np_torch_op(
self, t: torch.Tensor, log_A: torch.Tensor, r: torch.Tensor, c: torch.Tensor
) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
t: "custom_types.SampleType",
log_A: "custom_types.SampleType",
r: "custom_types.SampleType",
c: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, *, t, log_A, r, c):
"""Compute log-sigmoid growth: log_A + log_sigmoid(r * (t - c)).
:param t: Time values
:param log_A: Log-amplitude values
:param r: Rate values
:param c: Inflection point values
:returns: Log-sigmoid growth values
"""
return log_A + LogSigmoidParameter.run_np_torch_op(self, r * (t - c))
# pylint: enable=arguments-differ
def write_stan_operation( # pylint: disable=arguments-differ
self, t: str, log_A: str, r: str, c: str
) -> str:
"""Generate Stan code for log-sigmoid growth.
:param t: Formatted time parameter
:param log_A: Formatted log-amplitude parameter
:param r: Formatted rate parameter
:param c: Formatted inflection point parameter
:returns: Stan code for ``log_A + log_inv_logit(r .* (t - c))``
:rtype: str
"""
par_string = super().write_stan_operation(f"{r} .* ({t} - {c})")
return f"{log_A} + {par_string}"
class SigmoidGrowthInitParametrization(TransformedParameter):
"""Sigmoid growth with initial value parameterization.
Alternative parameterization of sigmoid growth in terms of initial abundances
rather than carrying capacity.
:param t: Time parameter
:type t: custom_types.CombinableParameterType
:param x0: Initial abundance parameter
:type x0: custom_types.CombinableParameterType
:param r: Growth rate parameter
:type r: custom_types.CombinableParameterType
:param c: Offset parameter (related to carrying capacity)
:type c: custom_types.CombinableParameterType
:param kwargs: Additional arguments passed to parent class
Mathematical Properties:
- Parameterizes sigmoid growth by initial value x0
- Uses log-add-exp trick for numerical stability
- Avoids direct computation of large exponentials
- Maintains sigmoid growth dynamics
This parameterization is useful when:
- Initial conditions are better known than carrying capacity (e.g., biological systems)
- Numerical stability is crucial
- Working with extreme parameter values
- Modeling relative growth from baseline
This transformation is accessed through the
:py:func:`~scistanpy.operations.sigmoid_growth_init_param` function.
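Example (an illustrative sketch, assuming ``ssp`` is the imported SciStanPy
package as in the other examples in this module):
.. code-block:: python
time = np.array([[0], [1], [2], [3], [4]])  # Shape (5, 1)
x0 = ssp.parameters.Exponential(beta=1.0)
rate = ssp.parameters.Exponential(beta=1.0)
offset = ssp.parameters.Normal(mu=2, sigma=1)
growth = ssp.operations.sigmoid_growth_init_param(t=time, x0=x0, r=rate, c=offset)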
"""
LOWER_BOUND: "custom_types.Float" = 0.0
UPPER_BOUND: None = None
def __init__(
self,
*,
t: "custom_types.CombinableParameterType",
x0: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
c: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize sigmoid growth with initial value parameterization.
:param t: Time parameter
:param x0: Initial abundance parameter
:param r: Growth rate parameter
:param c: Offset parameter
:param kwargs: Additional arguments
"""
super().__init__(t=t, x0=x0, r=r, c=c, **kwargs)
# pylint: disable=arguments-renamed, arguments-differ
@overload
def run_np_torch_op(
self, t: torch.Tensor, x0: torch.Tensor, r: torch.Tensor, c: torch.Tensor
) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
t: "custom_types.SampleType",
x0: "custom_types.SampleType",
r: "custom_types.SampleType",
c: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, t, x0, r, c):
"""Compute sigmoid growth with initial parameterization using log-add-exp.
:param t: Time values
:param x0: Initial abundance values
:param r: Rate values
:param c: Offset values
:returns: Sigmoid growth values
Uses log-add-exp for numerical stability in computing:
x0 * exp(log(1+exp(r*c)) - log(1+exp(r*(c-t))))
"""
# Get the module
mod = utils.choose_module(x0)
# Get the fold-change. We use the log-add-exp function to calculate this
# in a more numerically stable way
zero = 0.0 if mod is np else torch.tensor(0.0, device=x0.device)
foldchange = mod.exp(
mod.logaddexp(zero, r * c) - mod.logaddexp(zero, r * (c - t))
)
# Calculate the abundance
return x0 * foldchange
def write_stan_operation(
self, t: str, x0: str, r: str, c: str # pylint: disable=invalid-name
) -> str:
"""Calculate using Stan's log1p_exp function."""
return f"{x0} .* exp(log1p_exp({r} .* {c}) - log1p_exp({r} .* ({c} - {t})))"
# pylint: enable=arguments-renamed, arguments-differ
class LogSigmoidGrowthInitParametrization(TransformedParameter):
"""Log-scale sigmoid growth with initial value parameterization.
Log-space version of sigmoid growth parameterized by initial values,
providing numerical stability and guaranteed positive outputs.
:param t: Time parameter
:type t: custom_types.CombinableParameterType
:param log_x0: Log of initial abundance parameter
:type log_x0: custom_types.CombinableParameterType
:param r: Growth rate parameter
:type r: custom_types.CombinableParameterType
:param c: Offset parameter
:type c: custom_types.CombinableParameterType
:param kwargs: Additional arguments passed to parent class
Mathematical Properties:
- Fully operates in log-space for numerical stability
- Parameterized by log of initial conditions
- Maintains sigmoid growth dynamics
This parameterization is ideal for:
- Extreme parameter regimes
- Log-scale statistical modeling
- When initial conditions are naturally in log-space
- Maximum numerical precision requirements
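Example (an illustrative sketch constructing the transformation directly,
since no ``ssp.operations`` accessor is documented for this class; ``ssp``
is assumed to be the imported SciStanPy package):
.. code-block:: python
time = np.array([[0], [1], [2], [3], [4]])  # Shape (5, 1)
log_x0 = ssp.parameters.Normal(mu=0, sigma=1)
rate = ssp.parameters.Exponential(beta=1.0)
offset = ssp.parameters.Normal(mu=2, sigma=1)
log_growth = LogSigmoidGrowthInitParametrization(
t=time, log_x0=log_x0, r=rate, c=offset
)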
"""
LOWER_BOUND: None = None
UPPER_BOUND: None = None
def __init__(
self,
*,
t: "custom_types.CombinableParameterType",
log_x0: "custom_types.CombinableParameterType",
r: "custom_types.CombinableParameterType",
c: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize log-sigmoid growth with initial value parameterization.
:param t: Time parameter
:param log_x0: Log of initial abundance parameter
:param r: Growth rate parameter
:param c: Offset parameter
:param kwargs: Additional arguments
"""
super().__init__(t=t, log_x0=log_x0, r=r, c=c, **kwargs)
# pylint: disable=arguments-renamed, arguments-differ
@overload
def run_np_torch_op(
self,
t: torch.Tensor,
log_x0: torch.Tensor,
r: torch.Tensor,
c: torch.Tensor,
) -> torch.Tensor: ...
@overload
def run_np_torch_op(
self,
t: "custom_types.SampleType",
log_x0: "custom_types.SampleType",
r: "custom_types.SampleType",
c: "custom_types.SampleType",
) -> npt.NDArray: ...
def run_np_torch_op(self, t, log_x0, r, c):
"""Compute log-sigmoid growth with initial parameterization.
:param t: Time values
:param log_x0: Log initial abundance values
:param r: Rate values
:param c: Offset values
:returns: Log-sigmoid growth values
Computes: log_x0 + log(1+exp(r*c)) - log(1+exp(r*(c-t)))
using numerically stable log-add-exp operations.
"""
# Get the module
mod = utils.choose_module(log_x0)
# Define zero
zero = 0.0 if mod is np else torch.tensor(0.0, device=log_x0.device)
# Calculate
return log_x0 + mod.logaddexp(zero, r * c) - mod.logaddexp(zero, r * (c - t))
def write_stan_operation(
self, t: str, log_x0: str, r: str, c: str # pylint: disable=invalid-name
) -> str:
"""Calculate using Stan's log1p_exp function."""
return f"{log_x0} + log1p_exp({r} .* {c}) - log1p_exp({r} .* ({c} - {t}))"
class ConvolveSequence(TransformedParameter):
"""Sequence convolution transformation using weight matrices.
Performs convolution operation on ordinally-encoded sequences using provided
weight matrices. This is commonly used for sequence modeling and pattern
recognition in biological sequences or text data.
:param weights: Weight matrix for convolution (at least 2D)
:type weights: custom_types.CombinableParameterType
:param ordinals: Ordinally-encoded sequence array (at least 1D)
:type ordinals: custom_types.CombinableParameterType
:param kwargs: Additional arguments passed to parent class
:raises ValueError: If weights is not at least 2D
:raises ValueError: If ordinals is not at least 1D
:raises ValueError: If shapes are incompatible for broadcasting
Shape Requirements:
- Weights: (..., kernel_size, alphabet_size)
- Ordinals: (..., sequence_length)
- Output: (..., sequence_length - kernel_size + 1)
The transformation applies convolution by:
1. Sliding a kernel of size kernel_size over the sequence
2. Using ordinal values to index into the weight matrix
3. Summing weighted values for each position
This is commonly used for:
- DNA/RNA sequence analysis
- Protein sequence modeling
- Text processing with character-level models
- Pattern recognition in discrete sequences
This transformation is accessed through the
:py:func:`~scistanpy.operations.convolve_sequence` function.
Example:
>>> # DNA sequence convolution
>>> weights = Normal(mu=0, sigma=1, shape=(motif_length, 4)) # 4 nucleotides
>>> dna_sequence = Constant(encoded_dna) # 0,1,2,3 for A,C,G,T
>>> motif_scores = ConvolveSequence(weights=weights, ordinals=dna_sequence)
"""
SHAPE_CHECK = False
FORCE_LOOP_RESET = True
FORCE_PARENT_NAME = True
def __init__(
self,
*,
weights: "custom_types.CombinableParameterType",
ordinals: "custom_types.CombinableParameterType",
**kwargs,
):
"""Initialize sequence convolution with shape validation."""
# Weights must be at least 2D.
if weights.ndim < 2:
raise ValueError("Weights must be at least a 2D parameter.")
# Sequence must be at least 1D
if ordinals.ndim < 1:
raise ValueError("Sequence must be at least a 1D parameter.")
# Record the kernel and alphabet sizes (the last two weight dimensions)
self.kernel_size, self.alphabet_size = weights.shape[-2:]
# All but the last two dimensions of the weights must broadcast with all
# but the last dimension of the ordinals
try:
batch_dims = np.broadcast_shapes(weights.shape[:-2], ordinals.shape[:-1])
except ValueError as err:
raise ValueError(
"Incompatible shapes between weights and ordinals. The shapes must "
"be broadcastable in the batch dimensions (all but last two for "
"the weights and all but the last for the ordinals). Got "
f"weights: {weights.shape}, ordinals: {ordinals.shape}"
) from err
# The final dimension has the size of the sequence length adjusted by the
# convolution
shape = batch_dims + (ordinals.shape[-1] - self.kernel_size + 1,)
# Init using inherited method.
super().__init__(weights=weights, ordinals=ordinals, shape=shape, **kwargs)
def run_np_torch_op(self, weights, ordinals): # pylint: disable=arguments-differ
"""Perform the convolution over the last dimension of the ordinals.
:param weights: Weight matrix values
:param ordinals: Ordinally-encoded sequence values
:returns: Convolved values with shape ``self.shape``
"""
# Sanity-check the input shapes
assert weights.shape == self.weights.shape
assert ordinals.shape == self.ordinals.shape
# Decide on the module for the operation
module = utils.choose_module(weights)
# Determine the number of dimensions to prepend to each array
weights_n_prepends = len(self.shape[:-1]) - len(self.weights.shape[:-2])
ordinal_n_prepends = len(self.shape[:-1]) - len(self.ordinals.shape[:-1])
# Get the padded shapes. This is just aligning the shapes for broadcasting.
padded_weights_shape = (None,) * weights_n_prepends + self.weights.shape[:-2]
padded_ordinals_shape = (None,) * ordinal_n_prepends + self.ordinals.shape[:-1]
assert len(padded_weights_shape) == len(padded_ordinals_shape)
# Set output array and build a set of filter indices
output_arr = module.full(self.shape, np.nan)
filter_indices = module.arange(self.kernel_size)
# If torch, send arrays to appropriate device
if module is torch:
filter_indices = filter_indices.to(weights.device)
output_arr = output_arr.to(weights.device)
# Loop over the different weights
for weights_inds in np.ndindex(weights.shape[:-2]):
# Prepend `None` to the weight indices if needed
weights_inds = (None,) * weights_n_prepends + weights_inds
# Determine the ordinal and output indices. If weights or ordinals
# are a singleton, slice all for the ordinal indices.
ordinal_inds = []
output_inds = []
for dim, (weight_dim_size, ord_dim_size) in enumerate(
zip(padded_weights_shape, padded_ordinals_shape)
):
# We can never have both weight and ord dim sizes be `None`
assert not (weight_dim_size is None and ord_dim_size is None)
# If the ordinal dimension is `None`, then the output dimension is whatever
# the weight dimension is. We do not record an ordinal index.
weight_ind = weights_inds[dim]
if ord_dim_size is None:
output_inds.append(weight_ind)
continue
# If the weight dimension is a singleton we slice all for the ordinal and
# the output
if weight_dim_size == 1 or weight_dim_size is None:
ordinal_inds.append(slice(None))
output_inds.append(slice(None))
# If the ordinal dimension is a singleton, add "0" to the indices for the
# ordinals and the weights ind for the output
elif ord_dim_size == 1:
ordinal_inds.append(0)
output_inds.append(weight_ind)
# Otherwise, identical index to the weights for both
else:
ordinal_inds.append(weight_ind)
output_inds.append(weight_ind)
# Convert indices to tuples
ordinal_inds = tuple(ordinal_inds)
output_inds = tuple(output_inds)
assert len(output_inds) == len(self.shape) - 1
# Get the matrix and set of sequences to which it will be applied
weights_matrix = weights[weights_inds]
ordinal_matrix = ordinals[ordinal_inds]
assert weights_matrix.ndim == 2
# Run convolution for this batch by sliding over the sequence length
for convind, upper_slice in enumerate(
range(self.kernel_size, ordinal_matrix.shape[-1] + 1)
):
# Get the lower bound
lower = upper_slice - self.kernel_size
# Slice the sequence and pull the appropriate weights. Sum the weights.
output_arr[output_inds + (convind,)] = weights_matrix[
filter_indices, ordinal_matrix[..., lower:upper_slice]
].sum(**{"dim" if module is torch else "axis": -1})
# No Nan's in output
assert not module.any(module.isnan(output_arr))
return output_arr
def get_index_offset(
self,
query: Union[str, "abstract_model_component.AbstractModelComponent"],
offset_adjustment: int = 0,
) -> int:
"""Calculate index offset with special handling for weights.
:param query: Component or parameter name to query
:param offset_adjustment: Base offset adjustment
:returns: Index offset (adjusted +1 for weights parameter)
:rtype: int
The weights parameter requires special offset handling because
its last two dimensions are used directly in the convolution
rather than being broadcast.
"""
# Run the inherited method
offset = super().get_index_offset(query, offset_adjustment)
# Adjust if needed
if query == "weights" or query is self.weights:
offset += 1
return offset
def get_right_side(
self,
index_opts: tuple[str, ...] | None,
start_dims: dict[str, "custom_types.Integer"] | None = None,
end_dims: dict[str, "custom_types.Integer"] | None = None,
offset_adjustment: int = 0,
) -> str:
"""Generate right-side code with proper dimension handling.
Sets default end_dims to exclude the last weight dimension (-2),
ensuring both kernel_size and alphabet_size dimensions are preserved.
"""
end_dims = end_dims or {"weights": -2}
# Run the AbstractModelParameter version of the method to get each model
# component formatted appropriately. Note that we ignore the last dimension
# of the weights. This is because we need both dimensions in the Stan code.
return super().get_right_side(
index_opts, end_dims=end_dims, offset_adjustment=offset_adjustment
)
def write_stan_operation( # pylint: disable=arguments-differ
self, weights: str, ordinals: str
) -> str:
"""Generate Stan convolution function call.
:param weights: Formatted weights parameter name
:type weights: str
:param ordinals: Formatted ordinals parameter name
:type ordinals: str
:returns: Stan function call for sequence convolution
:rtype: str
"""
# This runs a custom function
return f"convolve_sequence({weights}, {ordinals})"
def get_supporting_functions(self) -> list[str]:
"""Return required Stan function includes.
:returns: List including pssm.stanfunctions for convolution support
:rtype: list[str]
"""
return super().get_supporting_functions() + ["#include pssm.stanfunctions"]
class IndexParameter(TransformedParameter):
"""Array indexing transformation with NumPy-compatible semantics.
Creates indexed subsets of parameters using slicing, scalar indexing,
and array indexing. Follows NumPy indexing conventions rather than
Stan conventions for consistency with Python data manipulation.
:param dist: Parameter to index
:type dist: custom_types.CombinableParameterType
:param indices: Indexing specifications (slices, integers, arrays)
:type indices: custom_types.IndexType
Supported Index Types:
- **slice**: Standard Python slicing with start:stop (step=1 only)
- **int**: Single element selection
- **np.ndarray**: Advanced indexing with integer arrays. 1D only.
Follows numpy convention.
- **Ellipsis**: Automatic dimension filling
- **None**: New axis insertion
Important Differences from Stan:
- Uses 0-based indexing (Python convention)
- Advanced indexing follows NumPy broadcasting rules, not Stan rules
- Negative indices are supported and converted appropriately
This transformation is never applied directly. Instead, index a parameter as
you would in normal Python/NumPy:
Example:
.. code-block:: python
# Define any parameter
param = Normal(mu=0, sigma=1, shape=(10, 5))
# Standard indexing (last element of first dimension)
last_elem = param[-1]
# Slice first 5 rows
subset = param[:5]
# Select specific elements with NumPy-style advanced indexing
selected = param[np.array([0, 2, 4])]
# Use Ellipsis to fill in dimensions
first_col = param[..., 0] # All leading dimensions, first of last
# Insert new axis
new_axis = param[:, None, :]
"""
SHAPE_CHECK = False
FORCE_PARENT_NAME = True
def __init__(
self,
dist: "custom_types.CombinableParameterType",
*indices: "custom_types.IndexType",
):
"""Initialize indexing transformation.
:param dist: Parameter to index
:param indices: Indexing specifications
The initialization processes all index types, converts negative
indices to positive, validates array dimensions, and creates
appropriate constant parameters for array indices.
"""
# We need the shape of what we're indexing to prep for parent init
self._dist_shape = dist.shape
# We need the input indices for torch and numpy operations
self._python_indices = indices
# Process and unify the different index types
shape, self._stan_indices, parents = self._process_indices(indices)
# Init using parent method. Provide the shape with `None` values removed --
# these are the dimensions that are removed by indexing
super().__init__(dist=dist, shape=shape, **parents)
@overload
def neg_to_pos(
self, neg_ind: "custom_types.Integer", dim: "custom_types.Integer"
) -> "custom_types.Integer": ...
@overload
def neg_to_pos(
self, neg_ind: npt.NDArray[np.int64], dim: "custom_types.Integer"
) -> npt.NDArray[np.int64]: ...
def neg_to_pos(self, neg_ind, dim):
"""Convert negative indices to positive indices.
:param neg_ind: Negative index or array of indices
:param dim: Dimension size for conversion
:returns: Positive indices
:raises ValueError: If indices are out of bounds
Handles both scalar and array indices, performing bounds checking
and conversion from Python's negative indexing convention.
"""
# If a numpy array, we update negative positions only
if isinstance(neg_ind, np.ndarray):
out = neg_ind.copy()
out[out < 0] += self._dist_shape[dim]
# There should be no negatives
if np.any(out < 0):
raise ValueError(
f"Negative indices {neg_ind} cannot be converted to positive "
f"indices for dimension {dim} with shape {self._dist_shape[dim]}."
)
# The max should be less than the dimension size
if np.any(out >= self._dist_shape[dim]):
raise ValueError(
f"Indices {neg_ind} exceed the size of dimension {dim} "
f"with shape {self._dist_shape[dim]}."
)
return out
# If a single integer, we convert it directly.
elif isinstance(neg_ind, int):
out = neg_ind + self._dist_shape[dim] if neg_ind < 0 else neg_ind
# Check that the index is within bounds
if out < 0:
raise ValueError(
f"Negative index {neg_ind} cannot be converted to positive "
f"index for dimension {dim} with shape {self._dist_shape[dim]}."
)
if out >= self._dist_shape[dim]:
raise ValueError(
f"Index {neg_ind} exceeds the size of dimension {dim} "
f"with shape {self._dist_shape[dim]}."
)
return out
# Error if the type is not supported
raise TypeError(
f"Unsupported index type {type(neg_ind)}. Expected int or numpy array."
)
def _process_indices(
self,
indices: tuple["custom_types.IndexType", ...],
) -> tuple[
tuple[int, ...],
tuple["custom_types.IndexType", ...],
dict[str, constants.Constant],
]:
"""Process and validate all indexing specifications.
:param indices: Raw indexing specifications
:returns: Tuple of (output_shape, processed_indices, constant_parents)
This method handles the complex logic of:
- Processing different index types (slices, integers, arrays, ellipsis)
- Calculating output shapes
- Converting to Stan-compatible 1-based indexing
- Creating constant parameters for array indices
- Validating consistency across multiple array indices
"""
def process_ellipsis() -> "custom_types.Integer":
"""Helper function to process Ellipses"""
# We can only have one ellipsis
if sum(1 for ind in indices if ind is Ellipsis) > 1:
raise ValueError("Only one ellipsis is allowed in indexing.")
# Add slices to the processed dimensions
n_real_dims = sum(
1 for ind in indices if ind is not Ellipsis and ind is not None
)
n_to_add = len(self._dist_shape) - n_real_dims
processed_inds.extend([slice(None) for _ in range(n_to_add)])
# The shape is extended by the number added
shape.extend(self._dist_shape[shape_ind : shape_ind + n_to_add])
# Return the number of added dimensions
return n_to_add
def process_slice() -> None:
"""Helper function to process slices."""
# Step cannot be set
if ind.step is not None and ind.step != 1:
raise ValueError(
f"Step size {ind.step} is not supported in IndexParameter transformation."
)
# Get the size of the output shape (stop - start after converting negatives
# to positives)
start = 0 if ind.start is None else self.neg_to_pos(ind.start, shape_ind)
stop = (
self._dist_shape[shape_ind]
if ind.stop is None
else self.neg_to_pos(ind.stop, shape_ind)
)
# Update outputs. Note that processed outputs are a new slice and that
# we do not add 1 to stop because Stan slices are inclusive while Python
# are exclusive
shape.append(stop - start)
processed_inds.append(slice(start + 1, stop))
def process_array() -> "custom_types.Integer":
"""Helper function to process numpy arrays and constants."""
# Must be a 1D array
if ind.ndim > 1:
raise IndexError(
"Cannot index with numpy array with more than 1 dimension"
)
elif ind.ndim == 0:
raise AssertionError("Should not get here")
# Ensure the array contains integers
if ind.dtype != np.int64:
raise TypeError(
f"Indexing with non-integer arrays is not supported. Got dtype "
f"{ind.dtype}."
)
# Must be the same as previous 1-d arrays
arrlen = len(ind)
if int_arr_len > 0 and int_arr_len != arrlen:
raise ValueError(
f"All 1-dimensional integer arrays must have the same length. "
f"Got lengths {int_arr_len} and {arrlen}."
)
# Build a constant for the index. This involves adjusting the indices
# to be Stan-compatible (1-indexed, no negative indices).
constant_arr = constants.Constant(
self.neg_to_pos(ind, shape_ind) + 1, togglable=False
)
# Record
shape.append(arrlen)
parents[f"idx_{len(parents)}"] = constant_arr
processed_inds.append(constant_arr)
return arrlen
# Set up for recording
shape = [] # This parameter's shape
processed_inds = [] # Indices processed for use in Stan
shape_ind = 0 # Current dimension in the indexed parameter
parents: dict[str, constants.Constant] = {} # Constants for arrays
int_arr_len = 0 # Length of integer arrays
# Process indices
for ind in indices:
# Process ellipses
if ind is Ellipsis:
shape_ind += process_ellipsis()
# Process slices
elif isinstance(ind, slice):
process_slice()
shape_ind += 1
# Numpy arrays are also processed by their own function
elif isinstance(ind, np.ndarray):
int_arr_len = max(int_arr_len, process_array())
shape_ind += 1
# Integers must be made positive
elif isinstance(ind, int):
processed_inds.append(self.neg_to_pos(ind, shape_ind) + 1)
shape_ind += 1
# `None` values add a new dimension to the output.
elif ind is None:
shape.append(1)
# Nothing else is legal
else:
raise TypeError(
"Indexing supported by slicing, numpy arrays, and integers only."
f"Got {type(ind)}"
)
# Remove None values from the shape
return tuple(shape), tuple(processed_inds), parents
# Note that parents are ignored here as their indices have been adjusted to
# reflect Stan's 1-indexing and no negative indices. We use the Python indices
# stored earlier as a result. The parents kwargs is included for compatibility
def run_np_torch_op( # pylint: disable=arguments-differ, unused-argument
self, dist, **parents
):
# If torch, numpy arrays must go to torch
module = utils.choose_module(dist)
if module is torch:
inds = tuple(
(
torch.from_numpy(ind).to(dist.device)
if isinstance(ind, np.ndarray)
else ind
)
for ind in self._python_indices
)
else:
inds = self._python_indices
# Index and check shape
indexed = dist[inds]
assert indexed.shape == self.shape
return indexed
def get_right_side(
self,
index_opts: tuple[str, ...] | None,
start_dims: dict[str, "custom_types.Integer"] | None = None,
end_dims: dict[str, "custom_types.Integer"] | None = None,
offset_adjustment: int = 0,
) -> str:
"""Generate Stan indexing code.
:returns: Stan indexing expression
:rtype: str
Gets the name of the variable that is being indexed, then passes it to
the `write_stan_operation` method to get the full Stan code for the transformation
"""
return self.write_stan_operation(dist=self.dist.model_varname)
def write_stan_operation( # pylint: disable=arguments-differ
self, dist: str
) -> str:
"""Generate complete Stan indexing expression.
:param dist: Variable name to index
:type dist: str
:returns: Stan indexing expression with bracket notation
:rtype: str
Handles complex indexing patterns including:
- Multiple array indices (creates separate bracket groups)
- Mixed slicing and array indexing
- Proper 1-based index conversion for Stan
- Colon notation for full dimension slicing
"""
# Compile all indices. Every time we encounter an array index, we start
# a new indexing operation. This allows us to mimic numpy behavior in Stan.
components = []
current_component = []
index_pos = 0
array_counter = 0
for ind in self._stan_indices:
# Handle slices
if isinstance(ind, slice):
start = "" if ind.start is None else str(ind.start)
end = "" if ind.stop is None else str(ind.stop)
index_pos += 1 # We keep a dimension with this operation
current_component.append(f"{start}:{end}")
# Handle integers
elif isinstance(ind, int):
current_component.append(str(ind))
# If an array, we need to use the constant that we defined
elif isinstance(ind, constants.Constant):
# Must be a 1D array
assert isinstance(ind.value, np.ndarray)
assert ind.value.ndim == 1
# If we have already encountered an array, start a new component,
# padding out the current component with colons.
if array_counter > 0:
components.append(current_component)
current_component = [":"] * (index_pos + 1)
# Record the array as a component
current_component.append(
self._parents[f"idx_{array_counter}"].get_indexed_varname(None)
)
# Update counters
index_pos += 1 # We keep a dimension with this operation
array_counter += 1 # Note finding another array
# Error with anything else
else:
raise ValueError(f"Unsupported index type: {type(ind)}")
# Record the last component
components.append(current_component)
# Join all components
return dist + "[" + "][".join(",".join(c) for c in components) + "]"
def get_transformation_assignment( # pylint: disable=unused-argument
self, index_opts: tuple[str, ...] | None
) -> str:
"""Generate transformation assignment without index options.
:param index_opts: Index options (ignored for direct assignment)
:returns: Stan assignment statement
:rtype: str
Indexing parameters are assigned directly without loop indexing
since they represent specific element selection rather than
computed transformations.
"""
# pylint: disable=no-value-for-parameter
return super().get_transformation_assignment(None) # Assigned directly
# The definition depth is always 0 for this transformation
def get_assign_depth(self) -> "custom_types.Integer": # pylint: disable=C0116
return 0
@property
def force_name(self) -> bool:
"""Force explicit naming for indexed parameters.
:returns: Always True
:rtype: bool
Indexed parameters must be explicitly named in Stan code to
enable proper variable reference and assignment.
"""
return True