Source code for scistanpy.model.components.custom_distributions.custom_torch_dists

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Custom PyTorch distribution implementations for SciStanPy models.

This module provides specialized PyTorch distribution classes that extend or
modify the standard PyTorch distributions to meet specific requirements of
SciStanPy modeling. These distributions handle edge cases, provide numerical
stability improvements, and enable functionality not available in the standard
PyTorch distribution library.

Key Features:
    - **Extended Multinomial**: Support for inhomogeneous total counts
    - **Numerical Stability**: Improved log-space probability computations
    - **Custom Distributions**: Implementations of distributions not in PyTorch
    - **SciStanPy Integration**: Designed for compatibility with SciStanPy parameter types
"""

from __future__ import annotations

from typing import Optional, ParamSpec, TYPE_CHECKING

import torch
import torch.distributions as dist

if TYPE_CHECKING:
    from scistanpy import custom_types

# Parameter-specification variable (PEP 612) intended for annotating callables
# that forward a distribution's parameters. Static type checkers only recognise
# the literal ``P = ParamSpec("P")`` form, so keep the assignment exactly as is.
# NOTE(review): ``P`` is not referenced by the code visible in this module;
# presumably it is imported elsewhere -- confirm before removing.
P = ParamSpec("P")


class CustomDistribution:
    """Marker base class shared by every SciStanPy custom distribution.

    The class intentionally carries no behavior. Its only job is to tag the
    distribution implementations that SciStanPy defines itself, so they can be
    recognised with ``isinstance`` checks and referenced in type hints.

    Every custom distribution class should list this class among its bases so
    that the codebase can treat all of them uniformly.
    """
class Multinomial(CustomDistribution):
    """Extended multinomial distribution supporting inhomogeneous total counts.

    This class extends the functionality of PyTorch's standard multinomial
    distribution to support different total counts across batch dimensions.
    The standard PyTorch implementation requires all trials to share a single
    total count, whereas this implementation allows each batch element to have
    its own total count.

    :param total_count: Total number of trials for each batch element. It is
        broadcast against ``probs``/``logits`` and must be constant along the
        last (category) dimension. Defaults to 1.
    :type total_count: Union[custom_types.Integer, torch.Tensor]
    :param probs: Event probabilities (mutually exclusive with logits)
    :type probs: Optional[torch.Tensor]
    :param logits: Event log-odds (mutually exclusive with probs)
    :type logits: Optional[torch.Tensor]
    :param validate_args: Whether to validate arguments. Defaults to None.
    :type validate_args: Optional[bool]

    :raises ValueError: If neither or both probs and logits are provided
    :raises ValueError: If ``total_count`` is not constant along the category
        (last) dimension after broadcasting

    Key Features:
        - Supports different total counts per batch element
        - Mirrors the ``log_prob``/``sample`` interface of PyTorch distributions
        - Delegates to one PyTorch ``Multinomial`` per batch element (a
          Python-level loop, so cost grows linearly with the batch size)
        - Proper shape handling for multi-dimensional batch operations

    The batch shape is the broadcast of ``total_count`` and the parameter
    tensor, with the category dimension removed. ``total_count`` may therefore
    contribute batch dimensions of its own.

    Example:
        >>> # Different total counts for each batch element
        >>> total_counts = torch.tensor([[10], [20], [15]])
        >>> probs = torch.tensor([[0.3, 0.4, 0.3],
        ...                       [0.2, 0.5, 0.3],
        ...                       [0.4, 0.3, 0.3]])
        >>> dist = Multinomial(total_count=total_counts, probs=probs)
        >>> samples = dist.sample()
    """

    def __init__(
        self,
        total_count: "custom_types.Integer" | torch.Tensor = 1,
        probs: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        """Initialize multinomial distribution with inhomogeneous total counts.

        Validates the parameters, broadcasts ``total_count`` against them, and
        builds one PyTorch multinomial per batch element so that each element
        can use its own number of trials.
        """
        # Probs or logits must be provided. Not both.
        if probs is None and logits is None:
            raise ValueError("Either `probs` or `logits` must be provided.")
        if probs is not None and logits is not None:
            raise ValueError("Only one of `probs` or `logits` can be provided.")

        # Are we working with probs or logits?
        key, values = ("probs", probs) if logits is None else ("logits", logits)

        # Broadcast BEFORE recording shapes so that batch dimensions contributed
        # by ``total_count`` are included. The count tensor is created on the
        # parameters' device so the comparison below does not mix devices.
        total_count, values = torch.broadcast_tensors(
            torch.as_tensor(total_count, device=values.device), values
        )
        self._batch_shape = values.shape[:-1]
        self._n_categories = values.shape[-1]

        # Every category of a given batch element must share one total count.
        # This validates user input, so raise rather than ``assert`` (asserts
        # are stripped under ``python -O``).
        if not torch.all(total_count[..., 0:1] == total_count):
            raise ValueError(
                "`total_count` must be constant along the last (category) dimension."
            )

        # Flatten all batch dimensions: one count and one parameter row per
        # batch element.
        total_count = total_count[..., 0].flatten()
        values = values.reshape(-1, self._n_categories)
        # Internal invariant guaranteed by the broadcast above.
        assert len(total_count) == len(values) == self._batch_shape.numel()

        # Build a multinomial distribution for each batch element
        self.distributions = [
            dist.Multinomial(
                **{
                    "total_count": n_trials.item(),
                    key: params,
                    "validate_args": validate_args,
                }
            )
            for n_trials, params in zip(total_count, values)
        ]

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute log-probabilities for observed multinomial outcomes.

        :param value: Observed counts for each category
        :type value: torch.Tensor

        :returns: Log-probabilities for the observed outcomes, shaped like the
            batch shape
        :rtype: torch.Tensor

        :raises ValueError: If value shape doesn't match expected dimensions

        The input must have shape ``(*batch_shape, n_categories)``. Each batch
        element is scored by its own underlying distribution.
        """
        # Ensure that the value of our multinomial is the same shape as the batch
        # dimension of the input values.
        if value.shape[:-1] != self._batch_shape:
            raise ValueError(
                f"Value shape {value.shape[:-1]} does not match batch shape {self._batch_shape}"
            )
        if value.shape[-1] != self._n_categories:
            raise ValueError(
                f"Value shape {value.shape[-1]} does not match number of categories "
                f"{self._n_categories}"
            )

        # Flatten the value tensor at all but the last dimension.
        value = value.reshape(-1, value.size(-1))
        assert len(value) == len(self.distributions)

        return torch.stack(
            [d.log_prob(v) for d, v in zip(self.distributions, value)]
        ).reshape(self._batch_shape)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """Generate samples from the multinomial distribution.

        :param sample_shape: Shape of samples to generate. Defaults to empty.
        :type sample_shape: torch.Size

        :returns: Sampled counts of shape
            ``(*sample_shape, *batch_shape, n_categories)``
        :rtype: torch.Tensor
        """
        # Each per-element sample has shape (*sample_shape, n_categories).
        # Stacking on dim=-2 gives (*sample_shape, n_batch, n_categories), and
        # the reshape restores the original batch dimensions.
        return torch.stack(
            [d.sample(sample_shape=sample_shape) for d in self.distributions], dim=-2
        ).reshape((*sample_shape, *self._batch_shape, self._n_categories))
class MultinomialProb(Multinomial, CustomDistribution):
    """Multinomial distribution parameterized by probabilities.

    A thin convenience wrapper around
    :py:class:`~scistanpy.model.components.custom_distributions.custom_torch_dists.Multinomial`
    that exposes only the probability parameterization (no ``logits``).

    :param total_count: Total number of trials for each batch element. Defaults to 1.
    :type total_count: Union[custom_types.Integer, torch.Tensor]
    :param probs: Event probabilities over the categories in the last dimension
    :type probs: Optional[torch.Tensor]
    :param validate_args: Whether to validate arguments. Defaults to None.
    :type validate_args: Optional[bool]

    This parameterization is natural for probability vectors that are already
    normalized, such as softmax outputs or empirical frequency estimates.

    Example:
        >>> # Probability parameterization
        >>> probs = torch.softmax(torch.randn(3, 4), dim=-1)
        >>> total_counts = torch.tensor([[100], [200], [150]])
        >>> dist = MultinomialProb(total_count=total_counts, probs=probs)
    """

    def __init__(
        self,
        total_count: "custom_types.Integer" | torch.Tensor = 1,
        probs: Optional[torch.Tensor] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        """Build the distribution through the generic ``probs`` code path.

        :param total_count: Total trials per batch element
        :param probs: Probability parameters (must be provided; the parent
            raises ``ValueError`` otherwise)
        :param validate_args: Validation flag
        """
        forwarded = {
            "total_count": total_count,
            "probs": probs,
            "validate_args": validate_args,
        }
        super().__init__(**forwarded)
class MultinomialLogit(Multinomial, CustomDistribution):
    """Multinomial distribution parameterized by logits.

    A thin convenience wrapper around
    :py:class:`~scistanpy.model.components.custom_distributions.custom_torch_dists.Multinomial`
    that exposes only the logit (log-odds) parameterization (no ``probs``).

    :param total_count: Total number of trials for each batch element. Defaults to 1.
    :type total_count: Union[custom_types.Integer, torch.Tensor]
    :param logits: Event logits (log-odds)
    :type logits: Optional[torch.Tensor]
    :param validate_args: Whether to validate arguments. Defaults to None.
    :type validate_args: Optional[bool]

    Example:
        >>> # Logit parameterization
        >>> logits = torch.randn(3, 4)  # No normalization needed
        >>> total_counts = torch.tensor([[50], [75], [100]])
        >>> dist = MultinomialLogit(total_count=total_counts, logits=logits)
    """

    def __init__(
        self,
        total_count: "custom_types.Integer" | torch.Tensor = 1,
        logits: Optional[torch.Tensor] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        """Build the distribution through the generic ``logits`` code path.

        :param total_count: Total trials per batch element
        :param logits: Logit parameters (must be provided; the parent raises
            ``ValueError`` otherwise)
        :param validate_args: Validation flag
        """
        forwarded = {
            "total_count": total_count,
            "logits": logits,
            "validate_args": validate_args,
        }
        super().__init__(**forwarded)
class MultinomialLogTheta(MultinomialLogit):
    """Multinomial distribution with normalized log-probabilities.

    This class extends
    :py:class:`~scistanpy.model.components.custom_distributions.custom_torch_dists.MultinomialLogit`
    with the additional constraint that the input log-probabilities must
    already be normalized (i.e., their exponentials sum to 1 along the last
    dimension).

    :param total_count: Total number of trials for each batch element. Defaults to 1.
    :type total_count: Union[custom_types.Integer, torch.Tensor]
    :param log_probs: Normalized log-probabilities (exp(log_probs) must sum to 1)
    :type log_probs: Optional[torch.Tensor]
    :param validate_args: Whether to validate arguments. Defaults to None.
    :type validate_args: Optional[bool]

    :raises AssertionError: If log_probs is None
    :raises AssertionError: If log_probs are not properly normalized

    This parameterization is particularly useful when:
        - Working with log-space normalized probability vectors
        - Interfacing with other log-space probability calculations

    The normalization constraint is checked at initialization with
    ``torch.allclose`` (default tolerances). The check stays active even when
    Python runs with ``-O``.

    Example:
        >>> # Normalized log-probabilities
        >>> logits = torch.randn(3, 4)
        >>> log_probs = torch.log_softmax(logits, dim=-1)
        >>> total_counts = torch.tensor([[100], [200], [150]])
        >>> dist = MultinomialLogTheta(total_count=total_counts, log_probs=log_probs)
    """

    def __init__(
        self,
        total_count: "custom_types.Integer" | torch.Tensor = 1,
        log_probs: Optional[torch.Tensor] = None,
        validate_args: Optional[bool] = None,
    ) -> None:
        """Initialize normalized log-probability multinomial distribution.

        :param total_count: Total trials per batch element
        :param log_probs: Normalized log-probability parameters
        :param validate_args: Validation flag

        :raises AssertionError: If ``log_probs`` is missing or not normalized
        """
        # Explicit raises instead of ``assert`` statements: asserts are removed
        # under ``python -O``, which would silently skip this input validation.
        # AssertionError is kept because it is the documented exception type.
        if log_probs is None:
            raise AssertionError("log_probs must be provided")
        if not torch.allclose(
            log_probs.exp().sum(dim=-1), torch.ones_like(log_probs[..., 0])
        ):
            raise AssertionError("log_probs must be normalized to sum to 1")

        # Normalized log-probabilities are valid logits, so reuse that path.
        super().__init__(
            total_count=total_count,
            logits=log_probs,
            validate_args=validate_args,
        )
class Normal(dist.Normal):
    """Enhanced normal distribution with numerically stable log-space functions.

    This class extends PyTorch's standard Normal distribution with log-CDF and
    log-survival functions built on ``torch.special.log_ndtr``. That function
    is designed to stay accurate in the tails of the distribution.

    Key Improvements:
        - Log-CDF computed with ``log_ndtr``
        - Log-survival function computed via the symmetry of the normal
        - Keeps PyTorch's Normal interface unchanged
        - Better precision for extreme tail probabilities

    These improvements are particularly important for:
        - Extreme value analysis
        - Tail probability computations
        - Log-likelihood calculations with extreme parameter values

    Example:
        >>> # Enhanced normal distribution
        >>> normal = Normal(loc=0.0, scale=1.0)
        >>> # Stable computation of a very small upper-tail probability. In
        >>> # single precision the naive torch.log(1 - normal.cdf(x)) gives -inf.
        >>> extreme_value = torch.tensor(10.0)
        >>> log_tail_prob = normal.log_sf(extreme_value)  # Numerically stable
    """

    # pylint: disable=not-callable, abstract-method

    def log_cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute logarithm of cumulative distribution function.

        :param value: Values at which to evaluate log-CDF
        :type value: torch.Tensor

        :returns: Log-CDF values
        :rtype: torch.Tensor

        Standardizes ``value`` and evaluates PyTorch's ``special.log_ndtr``,
        which avoids forming the CDF and then taking its logarithm.
        """
        return torch.special.log_ndtr((value - self.loc) / self.scale)

    def log_sf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute logarithm of survival function (1 - CDF).

        :param value: Values at which to evaluate log-survival function
        :type value: torch.Tensor

        :returns: Log-survival function values
        :rtype: torch.Tensor

        Uses the symmetry of the normal distribution: the survival function at
        ``value`` equals the CDF at the reflection of ``value`` about the mean.
        This avoids computing ``1 - CDF`` directly.
        """
        # Reflection about the mean: P(X > v) = Phi((loc - v) / scale).
        return torch.special.log_ndtr((self.loc - value) / self.scale)
class LogNormal(dist.LogNormal):
    """Log-normal distribution with numerically stable log-CDF and log-survival.

    Extends PyTorch's ``LogNormal`` so that both tail functions are evaluated
    with ``torch.special.log_ndtr`` on the underlying normal scale. This
    matters because the log-normal is heavy-tailed, and forming the CDF before
    taking its logarithm loses precision.

    Key Improvements:
        - Log-CDF via ``log_ndtr`` of the standardized log-value
        - Log-survival via the reflected standardized log-value
        - Keeps PyTorch's LogNormal interface unchanged

    Example:
        >>> # Enhanced log-normal distribution
        >>> lognormal = LogNormal(loc=0.0, scale=1.0)
        >>> # Stable computation for extreme values
        >>> large_value = torch.tensor(1000.0)
        >>> log_tail_prob = lognormal.log_sf(large_value)  # Numerically stable
    """

    # pylint: disable=not-callable, abstract-method

    def _log_standardized(self, value: torch.Tensor) -> torch.Tensor:
        """Map ``value`` onto the standard-normal scale: (log(value) - loc) / scale."""
        return (torch.log(value) - self.loc) / self.scale

    def log_cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Return log P(X <= value).

        :param value: Values at which to evaluate log-CDF
        :type value: torch.Tensor

        :returns: Log-CDF values
        :rtype: torch.Tensor
        """
        z = self._log_standardized(value)
        return torch.special.log_ndtr(z)

    def log_sf(self, value: torch.Tensor) -> torch.Tensor:
        """Return log P(X > value).

        :param value: Values at which to evaluate log-survival function
        :type value: torch.Tensor

        :returns: Log-survival function values
        :rtype: torch.Tensor
        """
        # Symmetry of the underlying normal: P(X > x) = Phi(-z).
        z = self._log_standardized(value)
        return torch.special.log_ndtr(-z)
# pylint: disable=abstract-method
class Lomax(dist.transformed_distribution.TransformedDistribution, CustomDistribution):
    r"""Lomax (Pareto Type II) distribution, built as a shifted Pareto.

    A Pareto variable with scale :math:`\lambda` lives on :math:`[\lambda, \infty)`.
    Subtracting :math:`\lambda` moves its support to :math:`[0, \infty)`, which
    gives the Lomax distribution.

    :param lambda_: Scale parameter (must be positive)
    :type lambda_: torch.Tensor
    :param alpha: Shape parameter (must be positive)
    :type alpha: torch.Tensor
    :param args: Additional positional arguments forwarded to ``TransformedDistribution``
    :param kwargs: Additional keyword arguments forwarded to ``TransformedDistribution``

    Mathematical Definition:
        .. math::
            \begin{align*}
            \text{If } X &\sim \text{Pareto}(\lambda, \alpha), \text{then } \\ \\
            Y &\sim \text{Lomax}(\lambda, \alpha), \text{where } \\ \\
            Y &= X - \lambda
            \end{align*}

    Example:
        >>> # Lomax distribution for modeling heavy-tailed phenomena
        >>> lambda_param = torch.tensor(1.0)
        >>> alpha_param = torch.tensor(2.0)
        >>> lomax = Lomax(lambda_=lambda_param, alpha=alpha_param)
        >>> samples = lomax.sample((1000,))
    """

    def __init__(self, lambda_: torch.Tensor, alpha: torch.Tensor, *args, **kwargs):
        """Create the Pareto base and the downward shift by ``lambda_``.

        :param lambda_: Scale parameter
        :param alpha: Shape parameter
        :param args: Extra positional arguments for ``TransformedDistribution``
        :param kwargs: Extra keyword arguments for ``TransformedDistribution``
        """
        pareto = dist.Pareto(scale=lambda_, alpha=alpha)
        # y = x + (-lambda_): pure translation, unit scale.
        shift_down = dist.transforms.AffineTransform(loc=-lambda_, scale=1)
        super().__init__(pareto, [shift_down], *args, **kwargs)
class ExpLomax(
    dist.transformed_distribution.TransformedDistribution, CustomDistribution
):
    r"""Distribution of the logarithm of a Lomax random variable.

    The name reflects that exponentiating a draw recovers a Lomax variable.
    This is useful for heavy-tailed quantities that are modelled on the log
    scale.

    :param lambda_: Scale parameter for the underlying Lomax distribution
    :type lambda_: torch.Tensor
    :param alpha: Shape parameter for the underlying Lomax distribution
    :type alpha: torch.Tensor
    :param args: Additional positional arguments forwarded to ``TransformedDistribution``
    :param kwargs: Additional keyword arguments forwarded to ``TransformedDistribution``

    Mathematical Definition:
        .. math::
            \begin{align*}
            \text{If } X &\sim \text{Lomax}(\lambda, \alpha), \text{then } \\ \\
            Y &= \log(X) \sim \text{ExpLomax}(\lambda, \alpha)
            \end{align*}
    """

    def __init__(self, lambda_: torch.Tensor, alpha: torch.Tensor, *args, **kwargs):
        """Wrap a Lomax base distribution in a logarithm transform.

        :param lambda_: Scale parameter
        :param alpha: Shape parameter
        :param args: Extra positional arguments for ``TransformedDistribution``
        :param kwargs: Extra keyword arguments for ``TransformedDistribution``
        """
        lomax = Lomax(lambda_=lambda_, alpha=alpha)
        # The inverse of exp is the logarithm: y = log(x).
        to_log_scale = dist.transforms.ExpTransform().inv
        super().__init__(lomax, [to_log_scale], *args, **kwargs)
class ExpExponential(
    dist.transformed_distribution.TransformedDistribution, CustomDistribution
):
    r"""Distribution of the logarithm of an exponential random variable.

    If :math:`X \sim \text{Exponential}(\text{rate})`, then :math:`Y = \log X`
    has survival function :math:`P(Y > y) = \exp(-\text{rate}\, e^{y})`.
    This is the *minimum* (left-skewed) Gumbel distribution with location
    :math:`-\log(\text{rate})` and unit scale. Equivalently, :math:`-Y` follows
    the usual (maximum) Gumbel distribution with location :math:`\log(\text{rate})`.
    Note that :math:`Y` itself is not the standard (maximum) Gumbel. The
    distribution is useful for extreme value modelling.

    :param rate: Rate parameter for the underlying exponential distribution
    :type rate: torch.Tensor
    :param args: Additional positional arguments forwarded to ``TransformedDistribution``
    :param kwargs: Additional keyword arguments forwarded to ``TransformedDistribution``

    Mathematical Definition:
        .. math::
            \begin{align*}
            \text{If } X &\sim \text{Exponential}(\text{rate}), \text{then } \\ \\
            Y &= \log(X) \sim \text{ExpExponential}(\text{rate})
            \end{align*}
    """

    def __init__(self, rate: torch.Tensor, *args, **kwargs):
        """Initialize Exponential-Exponential distribution.

        :param rate: Rate parameter for base exponential distribution
        :param args: Extra positional arguments for ``TransformedDistribution``
        :param kwargs: Extra keyword arguments for ``TransformedDistribution``
        """
        # Build the base distribution (Exponential) and transforms (log)
        base_dist = dist.Exponential(rate=rate)
        # The inverse of exp is the logarithm: y = log(x).
        transforms = [dist.transforms.ExpTransform().inv]
        super().__init__(base_dist, transforms, *args, **kwargs)
class ExpDirichlet(
    dist.transformed_distribution.TransformedDistribution, CustomDistribution
):
    r"""Exponential-Dirichlet distribution implementation.

    This distribution is created by taking the element-wise logarithm of a
    Dirichlet-distributed random vector. It's useful for modeling log-scale
    compositional data and log-probability vectors.

    :param concentration: Concentration parameters for the underlying Dirichlet
    :type concentration: torch.Tensor
    :param args: Additional positional arguments forwarded to ``TransformedDistribution``
    :param kwargs: Additional keyword arguments forwarded to ``TransformedDistribution``

    Mathematical Definition:
        .. math::
            \begin{align*}
            \text{If } X &\sim \text{Dirichlet}(\alpha), \text{then } \\ \\
            Y &= \log(X) \sim \text{ExpDirichlet}(\alpha)
            \end{align*}

    Samples are log-probability vectors: exponentiating a sample yields a
    point on the simplex.
    """

    def __init__(self, concentration: torch.Tensor, *args, **kwargs):
        """Initialize Exponential-Dirichlet distribution.

        :param concentration: Concentration parameter vector
        :param args: Extra positional arguments for ``TransformedDistribution``
        :param kwargs: Extra keyword arguments for ``TransformedDistribution``
        """
        # Build the base distribution (Dirichlet) and transforms (log)
        base_dist = dist.Dirichlet(concentration=concentration)
        transforms = [dist.transforms.ExpTransform().inv]
        super().__init__(base_dist, transforms, *args, **kwargs)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute log-probabilities for Exponential-Dirichlet outcomes.

        The PyTorch implementation applies a Jacobian correction element-wise,
        neglecting the simplex constraint. This method adjusts that result by
        adding ``0.5 * log(K)`` (``K`` = size of the last dimension) and
        subtracting the last log-coordinate ``value[..., -1]``, following the
        `discussion <https://discourse.mc-stan.org/t/log-simplex-constraints/39782/5>`_
        on the Stan forums.

        :param value: Log-simplex points; the last dimension indexes categories
        :type value: torch.Tensor

        :returns: Log-densities with the same shape as the parent's
            ``log_prob`` (the category dimension is reduced away)
        :rtype: torch.Tensor
        """
        # Get the base log probability from the parent class
        base_log_prob = super().log_prob(value)

        n_categories = torch.tensor(
            value.size(-1), dtype=value.dtype, device=value.device
        )
        # Plain integer indexing drops the category dimension, so the
        # correction has the same shape as ``base_log_prob``. Indexing with
        # ``torch.tensor([-1])`` kept a trailing singleton dimension: a batch
        # of B log-probabilities broadcast to B x B, and an unbatched input
        # returned shape (1,) instead of ().
        return base_log_prob + 0.5 * torch.log(n_categories) - value[..., -1]
# pylint: enable=abstract-method