# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Custom SciPy distribution implementations for SciStanPy models.
This module provides extended and custom SciPy distribution classes. Among other
things, these implementations provide:
- **Enhanced Batch Support**: Extended multivariate distributions with variable
batch dimensions
- **Custom Transformations**: Log-transformed distributions with proper Jacobian
corrections
- **Alternative Parameterizations**: Logit and log-probability parameterizations
for multinomial distributions
- **Numerical Stability**: Improved implementations for edge cases and extreme
values
"""
from __future__ import annotations
import functools
import inspect
from abc import ABC, abstractmethod
from typing import Callable, TYPE_CHECKING
import numpy as np
import numpy.typing as npt
from scipy import special
from scipy import stats
if TYPE_CHECKING:
from scistanpy import custom_types
# pylint: disable=line-too-long
def _combine_args_kwargs(function: Callable, args: tuple, kwargs: dict) -> dict:
"""Combine positional and keyword arguments into a single dictionary.
:param function: Function whose signature determines parameter names
:type function: Callable
:param args: Positional arguments to the function
:type args: tuple
:param kwargs: Keyword arguments to the function
:type kwargs: dict
:returns: Combined arguments as a dictionary
:rtype: dict
:raises ValueError: If total arguments don't match function signature
This utility function inspects the function signature and maps positional
arguments to their corresponding parameter names, then merges them with
the provided keyword arguments.
"""
# We will need the function signature to determine the arg and kwarg names
signature = inspect.signature(function)
paramnames = list(signature.parameters.keys())
# Make sure that the number of args and kwargs matches the number of parameters
if len(args) + len(kwargs) != len(paramnames):
raise ValueError(
f"Expected {len(paramnames)} arguments, but got {len(args) + len(kwargs)}."
)
# Combine args and kwargs
combined_kwargs = dict(zip(paramnames, args))
combined_kwargs.update(kwargs)
return combined_kwargs
[docs]
class CustomDirichlet(
stats._multivariate.dirichlet_gen # pylint: disable=protected-access
):
"""Enhanced Dirichlet distribution supporting variable batch dimensions.
This class extends SciPy's standard Dirichlet distribution to support
arbitrary batch dimensions while maintaining compatibility with the
SciPy distribution interface. The standard SciPy implementation has
limitations with batch operations that this class addresses.
Key Enhancements:
- Support for arbitrary batch dimensions in alpha parameters
- Proper broadcasting behavior across batch dimensions
- Consistent output shapes for all distribution methods
- Efficient vectorized operations over batch elements
The implementation uses a decorator pattern to extend existing SciPy
methods with batch dimension handling while preserving the original
mathematical properties of the Dirichlet distribution.
"""
@staticmethod
def _expand_batch(function: Callable, expect_x: bool = False) -> Callable:
"""Decorator for adding batch dimension support to Dirichlet methods.
:param function: SciPy Dirichlet method to wrap
:type function: Callable
:param expect_x: Whether the function expects an 'x' parameter. Defaults to False.
:type expect_x: bool
:returns: Wrapped function with batch dimension support
:rtype: Callable
:raises ValueError: If expected parameters are missing or unexpected
This decorator automatically handles:
- Parameter validation and broadcasting
- Reshaping for batch operations
- Vectorized computation across batch elements
- Proper output shape reconstruction
The decorator distinguishes between functions that operate on data (expect_x=True)
and those that only use distribution parameters (expect_x=False).
"""
@functools.wraps(function)
def inner(*args, **kwargs):
# Combine args and kwargs
combined_kwargs = _combine_args_kwargs(function, args, kwargs)
# Check for 'x'
if expect_x and "x" not in combined_kwargs:
raise ValueError("Expected 'x' parameter in the function signature.")
elif not expect_x and "x" in combined_kwargs:
raise ValueError("Unexpected 'x' parameter in the function signature.")
# Get the alpha parameter and the x, if present
param_kwargs = {"alpha": np.asarray(combined_kwargs.pop("alpha"))}
if expect_x:
param_kwargs["x"] = np.asarray(combined_kwargs.pop("x"))
# Broadcast the arrays in the param kwargs and reshape to be 2D
broadcasted_shape = np.broadcast_shapes(
*[v.shape for v in param_kwargs.values()]
)
param_kwargs = {
k: np.broadcast_to(v, broadcasted_shape).reshape(-1, v.shape[-1])
for k, v in param_kwargs.items()
}
# Run the function and combine the results
res = np.concatenate(
[
function(
**{k: v[i] for k, v in param_kwargs.items()}, **combined_kwargs
)
for i in range(len(param_kwargs["alpha"]))
]
)
# Some sanity checks on the result
assert (
res.ndim == 1
), f"Expected result to be 1D, but got {res.ndim}D with shape {res.shape}."
assert len(res) == len(
param_kwargs["alpha"]
), f"Expected result length {len(param_kwargs['alpha'])}, but got {len(res)}."
# Always applied to a reduction, so we reshape the result to one less
# dimension than the broadcasted shape
return res.reshape(*broadcasted_shape[:-1])
return inner
# Enhanced method implementations with batch support
# pylint: disable=protected-access
logpdf = _expand_batch(stats._multivariate.dirichlet_gen.logpdf, expect_x=True)
pdf = _expand_batch(stats._multivariate.dirichlet_gen.pdf, expect_x=True)
mean = _expand_batch(stats._multivariate.dirichlet_gen.mean)
var = _expand_batch(stats._multivariate.dirichlet_gen.var)
cov = _expand_batch(stats._multivariate.dirichlet_gen.cov)
entropy = _expand_batch(stats._multivariate.dirichlet_gen.entropy)
# pylint: enable=protected-access
[docs]
def rvs(
self,
alpha: npt.NDArray[np.floating],
size: tuple["custom_types.Integer", ...] | "custom_types.Integer" | None = 1,
random_state: "custom_types.Integer" | np.random.Generator | None = None,
) -> npt.NDArray[np.floating]:
"""Generate random samples from the Dirichlet distribution.
:param alpha: Concentration parameters with shape (..., k)
:type alpha: npt.NDArray[np.floating]
:param size: Output shape. Defaults to 1.
:type size: Union[tuple[custom_types.Integer, ...], custom_types.Integer, None]
:param random_state: Random state for reproducible sampling. Defaults to None.
:type random_state: Union[custom_types.Integer, np.random.Generator, None]
:returns: Random samples from Dirichlet distribution
:rtype: npt.NDArray[np.floating]
:raises ValueError: If alpha cannot be broadcast to the specified size
This method supports arbitrary batch dimensions in the alpha parameter
and properly broadcasts to the requested output size while maintaining
the simplex constraint for each sample.
"""
# Set the size
if size is None:
size = alpha.shape
elif isinstance(size, int):
size = (size, *alpha.shape)
else:
size = tuple(size)
# Broadcast alpha to the given size.
try:
alpha = np.broadcast_to(alpha, size)
except ValueError as err:
raise ValueError(
f"Cannot broadcast alpha ({alpha.shape}) to size ({size})"
) from err
# Now that alphas have been broadcasted to the correct size, we can proceed
# by sampling just once from the Dirichlet distribution for each alpha.
# Reshaping at the end will reconstruct the original dimensions.
return np.stack(
[
super().rvs(arr, random_state=random_state)
for arr in alpha.reshape(-1, alpha.shape[-1])
]
).reshape(size)
class CustomMultinomial(stats._multivariate.multinomial_gen):  # pylint: disable=W0212
    """Enhanced multinomial distribution supporting variable batch dimensions.

    This class extends SciPy's standard multinomial distribution to support
    arbitrary batch dimensions in both the trial count (n) and probability
    parameters (p), enabling flexible batch operations for discrete multivariate
    modeling scenarios.

    Key Enhancements:
        - Variable batch dimensions for n and p parameters
        - Proper broadcasting behavior between n and p
        - Support for different trial counts across batch elements
        - Consistent output shapes for sampling operations

    Example:
        >>> # Batch multinomial with different trial counts
        >>> n = np.array([[10], [20], [15]])
        >>> p = np.array([[0.3, 0.4, 0.3],
        ...               [0.2, 0.5, 0.3],
        ...               [0.4, 0.3, 0.3]])
        >>> multinomial = CustomMultinomial()
        >>> samples = multinomial.rvs(n=n, p=p, size=100)  # shape = (100, 3, 3)
    """

    def rvs(
        self,
        n: "custom_types.Integer" | npt.NDArray[np.integer],
        p: npt.NDArray[np.floating],
        size: tuple["custom_types.Integer", ...] | "custom_types.Integer" | None = 1,
        random_state: "custom_types.Integer" | np.random.Generator | None = None,
    ) -> npt.NDArray[np.integer]:
        """Generate random samples from the multinomial distribution.

        :param n: Number of trials (can be scalar or array)
        :type n: Union[custom_types.Integer, npt.NDArray[np.integer]]
        :param p: Event probabilities with shape (..., k)
        :type p: npt.NDArray[np.floating]
        :param size: Output shape. Defaults to 1.
        :type size: Union[tuple[custom_types.Integer, ...], custom_types.Integer, None]
        :param random_state: Random state for reproducible sampling. Defaults to None.
        :type random_state: Union[custom_types.Integer, np.random.Generator, None]

        :returns: Random samples from multinomial distribution
        :rtype: npt.NDArray[np.integer]

        :raises ValueError: If n and p cannot be broadcast to compatible shapes

        This method supports different trial counts for each batch element
        and handles broadcasting between scalar/array n and multi-dimensional p.
        """

        def try_broadcast(x, target_size):
            """Attempts to broadcast and raises an error if not possible"""
            try:
                return np.broadcast_to(x, target_size)
            except ValueError as err:
                raise ValueError(
                    f"Cannot broadcast shape {x.shape} to {target_size}"
                ) from err

        # Accept scalars and array-likes: the original read `p.shape` (and
        # `x.shape` in the error path) directly, which fails for plain
        # Python ints/lists.
        n = np.asarray(n)
        p = np.asarray(p)

        # Resolve the fully-expanded output shape for p
        if size is None:
            p_size = p.shape
        elif isinstance(size, int):
            p_size = (size, *p.shape)
        else:
            p_size = tuple(size)

        # n shares p's batch shape but has a single trailing element
        n_size = list(p_size)
        n_size[-1] = 1

        # n and p must be broadcastable to their respective sizes
        n = try_broadcast(n, n_size)
        p = try_broadcast(p, p_size)

        # Flatten all batch dimensions so we can sample one draw per n-p pair
        n = n.reshape(-1, 1)
        p = p.reshape(-1, p_size[-1])
        assert len(n) == len(p)

        # Bind the parent implementation outside the comprehension: a
        # zero-argument super() inside a comprehension fails before
        # Python 3.12 (the comprehension scope has no __class__ cell).
        draw = super().rvs

        # One draw per n-p pair. `n_el[0]` extracts a scalar trial count
        # (passing the length-1 row relied on NumPy's deprecated implicit
        # array-to-scalar conversion). Reshape to `p_size` — the resolved
        # output shape — NOT the raw `size` argument, which may be an int or
        # None; reshaping to those crashed in the original (including the
        # default `size=1`).
        return np.stack(
            [
                draw(n=n_el[0], p=p_el, random_state=random_state)
                for n_el, p_el in zip(n, p)
            ]
        ).reshape(p_size)
class MultinomialLogit(CustomMultinomial):
    """Multinomial distribution with logit parameterization. This is identical to
    :py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.CustomMultinomial`
    except that the probabilities need not be normalized and are specified as logits.
    """

    @staticmethod
    def softmax_p(function: Callable) -> Callable:
        """Decorator that transforms logits to probabilities using softmax.

        :param function: Function to wrap with logit transformation
        :type function: Callable

        :returns: Wrapped function that accepts logits instead of probabilities
        :rtype: Callable

        The wrapped callable accepts a ``logits`` keyword argument in place of
        ``p``; the softmax of the logits (taken over the final axis) is
        forwarded as ``p`` to the underlying multinomial method.
        """

        @functools.wraps(function)
        def wrapper(self, **kwargs):
            # Swap the 'logits' argument for its softmax, which lies on the
            # probability simplex as the parent class requires.
            raw_logits = kwargs.pop("logits")
            kwargs["p"] = special.softmax(raw_logits, axis=-1)
            return function(self, **kwargs)

        return wrapper

    # Parent methods re-exposed under the logit parameterization
    pmf = softmax_p(CustomMultinomial.pmf)
    logpmf = softmax_p(CustomMultinomial.logpmf)
    rvs = softmax_p(CustomMultinomial.rvs)
    entropy = softmax_p(CustomMultinomial.entropy)
    cov = softmax_p(CustomMultinomial.cov)
class MultinomialLogTheta(CustomMultinomial):
    """Multinomial distribution with normalized log-probability parameterization.

    This is identical to
    :py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.CustomMultinomial`
    except that the probabilities are specified as log-probabilities that must
    already be normalized (i.e., their exponentials sum to 1).
    """

    @staticmethod
    def exp_p(function: Callable) -> Callable:
        """Decorator that transforms log-probabilities to probabilities.

        :param function: Function to wrap with log-probability transformation
        :type function: Callable

        :returns: Wrapped function that accepts log_p instead of probabilities
        :rtype: Callable

        :raises ValueError: If log-probabilities are not properly normalized

        The wrapped callable takes a ``log_p`` keyword argument in place of
        ``p``. The exponentials of ``log_p`` must sum to 1 within tolerance
        along the final axis; they are then renormalized exactly and forwarded
        as ``p`` to the underlying multinomial method.
        """

        @functools.wraps(function)
        def wrapper(self, **kwargs):
            # Map the log-probabilities back to probability space
            probs = np.exp(kwargs.pop("log_p"))

            # Each row must already be normalized within tolerance
            row_totals = probs.sum(axis=-1, keepdims=True)
            if not np.allclose(row_totals, 1, atol=1e-6):
                raise ValueError(f"Rows of `p` must sum to 1, but got {row_totals.max()}")

            # Divide through to remove residual numerical drift
            kwargs["p"] = probs / row_totals
            return function(self, **kwargs)

        return wrapper

    # Parent methods re-exposed under the log-probability parameterization
    pmf = exp_p(CustomMultinomial.pmf)
    logpmf = exp_p(CustomMultinomial.logpmf)
    rvs = exp_p(CustomMultinomial.rvs)
    entropy = exp_p(CustomMultinomial.entropy)
    cov = exp_p(CustomMultinomial.cov)
class ExpDirichlet(CustomDirichlet):
    """Log-transformed Dirichlet distribution (Exponential-Dirichlet).

    This class implements a distribution where the logarithm of a
    Dirichlet-distributed random vector follows this distribution. It is
    useful for modeling log-scale compositional data and log-probability
    vectors with proper Jacobian corrections.
    """

    def logpdf(self, x, alpha):
        """Compute log probability density with Jacobian correction.

        :param x: Log-probability values
        :param alpha: Concentration parameters

        :returns: Log probability density values

        This is the standard Dirichlet log-density evaluated at ``exp(x)``
        plus the log-Jacobian of the exponential transform, simplified
        analytically for efficiency and numerical stability.
        """
        # pylint: disable=no-member
        # sum(alpha * x) - x[..., -1] is the Dirichlet kernel combined with
        # the Jacobian term; the gammaln terms are the log normalizer.
        return (
            np.sum(x * alpha, axis=-1)
            - x[..., -1]
            + special.gammaln(np.sum(alpha, axis=-1))
            - np.sum(special.gammaln(alpha), axis=-1)
        )

    def pdf(self, x, alpha):
        """Compute probability density function.

        :param x: Log-probability values
        :param alpha: Concentration parameters

        :returns: Probability density values

        Evaluated as the exponential of :py:meth:`logpdf` so both methods
        share a single implementation.
        """
        log_density = self.logpdf(x, alpha)
        return np.exp(log_density)

    def rvs(
        self,
        alpha: npt.NDArray[np.floating],
        size: tuple["custom_types.Integer", ...] | "custom_types.Integer" | None = 1,
        random_state: "custom_types.Integer" | np.random.Generator | None = None,
    ) -> npt.NDArray[np.floating]:
        """Generate random samples from the log-transformed Dirichlet distribution.

        :param alpha: Concentration parameters
        :type alpha: npt.NDArray[np.floating]
        :param size: Output shape. Defaults to 1.
        :type size: Union[tuple[custom_types.Integer, ...], custom_types.Integer, None]
        :param random_state: Random state. Defaults to None.
        :type random_state: Union[custom_types.Integer, np.random.Generator, None]

        :returns: Log-probability samples
        :rtype: npt.NDArray[np.floating]

        Samples are drawn from the parent Dirichlet distribution and mapped
        through the natural logarithm.
        """
        simplex_draws = super().rvs(alpha, size=size, random_state=random_state)
        return np.log(simplex_draws)

    def mean(self, alpha):
        """This is not implemented.

        :raises NotImplementedError:
        """
        raise NotImplementedError("Not defined for this custom distribution")

    def var(self, alpha):
        """This is not implemented.

        :raises NotImplementedError:
        """
        raise NotImplementedError("Not defined for this custom distribution")

    def cov(self, alpha):
        """This is not implemented.

        :raises NotImplementedError:
        """
        raise NotImplementedError("Not defined for this custom distribution")

    def entropy(self, alpha):
        """This is not implemented.

        :raises NotImplementedError:
        """
        raise NotImplementedError("Not defined for this custom distribution")
# Pre-configured distribution instances for convenient use.
# These mirror scipy.stats' convention of exposing singleton "frozen-factory"
# instances (e.g. scipy.stats.dirichlet) at module level; the trailing string
# literals are module-attribute docstrings picked up by documentation tools.
dirichlet = CustomDirichlet()
"""
Instance of :py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.CustomDirichlet`.
See that class for details.
"""

# NOTE(review): `LogUnivariateScipyTransform` is not defined anywhere in this
# view of the module — presumably declared elsewhere in the file. Confirm it
# is in scope before these two instantiations run at import time.
expexponential = LogUnivariateScipyTransform(stats.expon)
"""
``scipy.stats.expon`` transformed to the log scale using
:py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.LogUnivariateScipyTransform`.
"""

explomax = LogUnivariateScipyTransform(stats.lomax)
"""
``scipy.stats.lomax`` transformed to the log scale using
:py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.LogUnivariateScipyTransform`.
"""

expdirichlet = ExpDirichlet()
"""
Instance of :py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.ExpDirichlet`.
See that class for details.
"""

multinomial = CustomMultinomial()
"""
Instance of :py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.CustomMultinomial`.
See that class for details.
"""

multinomial_logit = MultinomialLogit()
"""
Instance of :py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.MultinomialLogit`.
See that class for details.
"""

multinomial_log_theta = MultinomialLogTheta()
"""
Instance of :py:class:`~scistanpy.model.components.custom_distributions.custom_scipy_dists.MultinomialLogTheta`.
See that class for details.
"""