Custom PyTorch Distributions API Reference
Custom PyTorch distribution implementations for SciStanPy models.
This module provides specialized PyTorch distribution classes that extend or modify the standard PyTorch distributions to meet specific requirements of SciStanPy modeling. These distributions handle edge cases, provide numerical stability improvements, and enable functionality not available in the standard PyTorch distribution library.
- Key Features:
  - Extended Multinomial: Support for inhomogeneous total counts
  - Numerical Stability: Improved log-space probability computations
  - Custom Distributions: Implementations of distributions not in PyTorch
  - SciStanPy Integration: Designed for compatibility with SciStanPy parameter types
The distributions in this submodule can be broadly broken down into the following categories:
- Multinomial Extensions: Enhanced multinomial distributions
  - Multinomial: Base class with inhomogeneous total count support
  - MultinomialProb: Probability-parameterized multinomial
  - MultinomialLogit: Logit-parameterized multinomial
  - MultinomialLogTheta: Normalized log-probability multinomial
- Numerically Stable Distributions: Improved standard distributions
  - Normal: Normal distribution with stable log-CDF and log-survival functions
  - LogNormal: Log-normal distribution with stable log-CDF and log-survival functions
- Custom Distribution Implementations: New distribution types
  - Lomax: Shifted Pareto distribution
  - ExpLomax: Exponential-Lomax distribution
  - ExpExponential: Exponential-Exponential distribution
  - ExpDirichlet: Exponential-Dirichlet distribution
The distributions in this module are designed to work within PyTorch’s distribution framework while providing the specific functionality required for probabilistic modeling in SciStanPy.
Base Classes
The custom distributions in this module inherit from scistanpy.model.components.custom_distributions.custom_torch_dists.CustomDistribution, a marker class that derives directly from object and adds no functionality of its own; it serves as a common ancestor for type checking and future extensions. The enhanced Normal and LogNormal classes are the exception: they subclass the corresponding torch.distributions classes directly.
- class scistanpy.model.components.custom_distributions.custom_torch_dists.CustomDistribution
Bases: object
Base marker class for custom SciStanPy distributions.
This class serves as a marker interface for custom distribution implementations in SciStanPy. It doesn’t provide any functionality but is useful for type hinting and identifying custom distributions in the codebase.
All custom distribution classes should inherit from this class to maintain consistency and enable type checking.
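The marker-class pattern can be sketched as follows. This is a minimal illustration, not SciStanPy's source: `CustomDistribution` here is a local stand-in defined in the snippet, and `ShiftedPareto` and `is_custom` are hypothetical names. Because the marker carries no state, it can be mixed into a PyTorch distribution subclass without changing that subclass's behavior.

```python
import torch
from torch.distributions import AffineTransform, Distribution, Pareto, TransformedDistribution


class CustomDistribution:
    """Local stand-in for the SciStanPy marker class: no attributes, no methods."""


class ShiftedPareto(TransformedDistribution, CustomDistribution):
    """Hypothetical custom distribution: a PyTorch base class plus the marker mixin."""

    def __init__(self, scale: torch.Tensor, alpha: torch.Tensor):
        # Shift a Pareto(scale, alpha) variable left by `scale` so its support starts at 0.
        super().__init__(Pareto(scale, alpha), [AffineTransform(loc=-scale, scale=1.0)])


def is_custom(dist: object) -> bool:
    # The marker lets callers tell custom distributions apart from stock PyTorch ones.
    return isinstance(dist, CustomDistribution)


custom = ShiftedPareto(torch.tensor(1.0), torch.tensor(2.0))
stock = Pareto(torch.tensor(1.0), torch.tensor(2.0))
```

Both objects behave as ordinary PyTorch distributions; only the `isinstance` check against the marker distinguishes them.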
Multinomial Extensions¶
The multinomial distributions extend PyTorch’s built-in multinomial capabilities to support inhomogeneous total counts and various parameterizations.
- class scistanpy.model.components.custom_distributions.custom_torch_dists.Multinomial(
    total_count: 'custom_types.Integer' | torch.Tensor = 1,
    probs: torch.Tensor | None = None,
    logits: torch.Tensor | None = None,
    validate_args: bool | None = None,
)
Bases: CustomDistribution
Extended multinomial distribution supporting inhomogeneous total counts.
This class extends the functionality of PyTorch’s standard multinomial distribution to support different total counts across batch dimensions. The standard PyTorch implementation requires all trials to have the same total count, but this implementation allows each batch element to have its own total count.
- Parameters:
total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.
probs (Optional[torch.Tensor]) – Event probabilities (mutually exclusive with logits)
logits (Optional[torch.Tensor]) – Unnormalized event log-probabilities (mutually exclusive with probs)
validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.
- Raises:
ValueError – If neither or both probs and logits are provided
- Key Features:
  - Supports different total counts per batch element
  - Maintains compatibility with the PyTorch distribution interface (log_prob, sample)
  - Batched computation built on internally created per-element distributions
  - Proper shape handling for multi-dimensional batch operations
The implementation creates individual multinomial distributions for each batch element, allowing for flexible modeling scenarios where trial counts vary across observations.
- Example:
>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import Multinomial
>>> # Different total counts for each batch element
>>> total_counts = torch.tensor([[10], [20], [15]])
>>> probs = torch.tensor([[0.3, 0.4, 0.3],
...                       [0.2, 0.5, 0.3],
...                       [0.4, 0.3, 0.3]])
>>> dist = Multinomial(total_count=total_counts, probs=probs)
>>> samples = dist.sample()
- log_prob(value: torch.Tensor) → torch.Tensor
Compute log-probabilities for observed multinomial outcomes.
- Parameters:
value (torch.Tensor) – Observed counts for each category
- Returns:
Log-probabilities for the observed outcomes
- Return type:
torch.Tensor
- Raises:
ValueError – If value shape doesn’t match expected dimensions
The method validates that the input tensor has the correct shape and computes log-probabilities by calling the appropriate distribution for each batch element.
- sample(sample_shape: torch.Size = torch.Size([])) → torch.Tensor
Generate samples from the multinomial distribution.
- Parameters:
sample_shape (torch.Size) – Shape of samples to generate. Defaults to empty.
- Returns:
Sampled multinomial outcomes
- Return type:
torch.Tensor
Generates samples by calling the sample method of each individual distribution and properly reshaping the results to maintain the expected batch and sample dimensions.
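The per-element strategy described above can be sketched with stock `torch.distributions.Multinomial`, which accepts only a single integer `total_count`. This is a minimal sketch, not SciStanPy's implementation: the helper names `_per_element`, `sample_inhomogeneous`, and `log_prob_inhomogeneous` are hypothetical, and the code assumes `total_count` has shape `(batch, 1)` and `probs` has shape `(batch, K)`, as in the example above.

```python
import torch
from torch.distributions import Multinomial as TorchMultinomial


def _per_element(total_count: torch.Tensor, probs: torch.Tensor) -> list:
    # One stock PyTorch Multinomial per batch row, each with its own integer count.
    counts = total_count.reshape(-1).tolist()
    return [TorchMultinomial(int(n), probs=p) for n, p in zip(counts, probs)]


def sample_inhomogeneous(total_count, probs, sample_shape=torch.Size([])):
    # Each draw has shape sample_shape + (K,); stacking on dim=-2 restores the batch axis.
    draws = [d.sample(sample_shape) for d in _per_element(total_count, probs)]
    return torch.stack(draws, dim=-2)


def log_prob_inhomogeneous(value, total_count, probs):
    # `value` must end in (batch, K); any leading dimensions are treated as sample dims.
    if value.shape[-2:] != probs.shape:
        raise ValueError(f"expected trailing shape {tuple(probs.shape)}, got {tuple(value.shape)}")
    dists = _per_element(total_count, probs)
    return torch.stack([d.log_prob(value[..., i, :]) for i, d in enumerate(dists)], dim=-1)


total_counts = torch.tensor([[10], [20], [15]])
probs = torch.tensor([[0.3, 0.4, 0.3], [0.2, 0.5, 0.3], [0.4, 0.3, 0.3]])
samples = sample_inhomogeneous(total_counts, probs, torch.Size([5]))  # shape (5, 3, 3)
log_p = log_prob_inhomogeneous(samples, total_counts, probs)          # shape (5, 3)
```

Building one distribution per row keeps the logic simple at the cost of a Python loop over the batch; when every row has the same count, the result matches a single batched PyTorch `Multinomial`.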
- class scistanpy.model.components.custom_distributions.custom_torch_dists.MultinomialProb(
    total_count: 'custom_types.Integer' | torch.Tensor = 1,
    probs: torch.Tensor | None = None,
    validate_args: bool | None = None,
)
Bases: Multinomial, CustomDistribution
Multinomial distribution parameterized by probabilities.
This class provides a specialized interface for multinomial distributions where the parameters are specified as probabilities rather than logits. It’s a convenience wrapper around the base Multinomial class.
- Parameters:
total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.
probs (Optional[torch.Tensor]) – Event probabilities (must sum to 1 along the last dimension)
validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.
This parameterization is natural when working with probability vectors that are already normalized, such as output from softmax functions or empirical frequency estimates.
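- Example:
>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import MultinomialProb
>>> # Probability parameterization
>>> probs = torch.softmax(torch.randn(3, 4), dim=-1)
>>> total_counts = torch.tensor([[100], [200], [150]])
>>> dist = MultinomialProb(total_count=total_counts, probs=probs)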
- class scistanpy.model.components.custom_distributions.custom_torch_dists.MultinomialLogit(
    total_count: 'custom_types.Integer' | torch.Tensor = 1,
    logits: torch.Tensor | None = None,
    validate_args: bool | None = None,
)
Bases: Multinomial, CustomDistribution
Multinomial distribution parameterized by logits.
This class provides a specialized interface for multinomial distributions where the parameters are specified as logits (unnormalized log-probabilities) rather than probabilities. It’s a convenience wrapper around the base Multinomial class.
- Parameters:
total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.
logits (Optional[torch.Tensor]) – Event logits (unnormalized log-probabilities)
validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.
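- Example:
>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import MultinomialLogit
>>> # Logit parameterization
>>> logits = torch.randn(3, 4)  # No normalization needed
>>> total_counts = torch.tensor([[50], [75], [100]])
>>> dist = MultinomialLogit(total_count=total_counts, logits=logits)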
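Because logits are only defined up to an additive constant per row, the logit and probability parameterizations describe the same distribution once the logits are passed through a softmax. The check below uses stock `torch.distributions.Multinomial` with a shared count of 5 (PyTorch's class does not accept per-row counts), so it illustrates the parameterization only, not SciStanPy's class.

```python
import torch
from torch.distributions import Multinomial

torch.manual_seed(0)
logits = torch.randn(3, 4)  # unnormalized; rows need not sum to anything
counts = torch.tensor([[2.0, 1.0, 0.0, 2.0],
                       [0.0, 0.0, 3.0, 2.0],
                       [1.0, 1.0, 1.0, 2.0]])  # every row sums to 5

by_logits = Multinomial(5, logits=logits)
by_shifted = Multinomial(5, logits=logits + 7.0)                # constant shift per row
by_probs = Multinomial(5, probs=torch.softmax(logits, dim=-1))  # softmax maps logits to probs

lp = by_logits.log_prob(counts)
```

All three objects assign the same log-probability to `counts`, up to floating-point error.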
- class scistanpy.model.components.custom_distributions.custom_torch_dists.MultinomialLogTheta(
    total_count: 'custom_types.Integer' | torch.Tensor = 1,
    log_probs: torch.Tensor | None = None,
    validate_args: bool | None = None,
)
Bases: MultinomialLogit
Multinomial distribution with normalized log-probabilities.
This class extends MultinomialLogit with the additional constraint that the input log-probabilities must already be normalized (i.e., their exponentials sum to 1). This is useful when working with log-probability vectors that are guaranteed to be valid probability distributions.
- Parameters:
total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.
log_probs (Optional[torch.Tensor]) – Normalized log-probabilities (exp(log_probs) must sum to 1)
validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.
- Raises:
AssertionError – If log_probs is None
AssertionError – If log_probs are not properly normalized
- This parameterization is particularly useful when:
  - Working with log-space normalized probability vectors
  - Ensuring numerical precision in log-space computations
  - Interfacing with other log-space probability calculations
The normalization constraint is enforced at initialization to prevent invalid probability distributions.
- Example:
>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import MultinomialLogTheta
>>> # Normalized log-probabilities
>>> logits = torch.randn(3, 4)
>>> log_probs = torch.log_softmax(logits, dim=-1)
>>> total_counts = torch.tensor([[100], [200], [150]])
>>> dist = MultinomialLogTheta(total_count=total_counts, log_probs=log_probs)
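The normalization requirement can be checked in log space with `torch.logsumexp`, which returns 0 for every row exactly when `exp(log_probs)` sums to 1. The helper `check_log_theta` below is hypothetical and only mirrors the documented behavior (assertions on a missing or unnormalized input); its tolerance `atol` is an assumption. Because normalized log-probabilities are valid logits, they can be passed unchanged to a logit-parameterized multinomial, here the stock PyTorch class with a shared count.

```python
import torch
from torch.distributions import Multinomial


def check_log_theta(log_probs, atol: float = 1e-5) -> torch.Tensor:
    # Hypothetical check mirroring the documented AssertionErrors.
    assert log_probs is not None, "log_probs must be provided"
    # logsumexp computes log(sum(exp(log_probs))) without overflow;
    # it is 0 when each row of exp(log_probs) sums to 1.
    log_norm = torch.logsumexp(log_probs, dim=-1)
    assert torch.allclose(log_norm, torch.zeros_like(log_norm), atol=atol), "log_probs are not normalized"
    return log_probs


torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(3, 4), dim=-1)
dist = Multinomial(100, logits=check_log_theta(log_probs))  # shared count, stock PyTorch
```

Checking in log space avoids exponentiating very negative entries, which is the same motivation as the class itself.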
Numerically Stable Distributions
PyTorch does not provide built-in, numerically stable log-CDF and log-survival functions. SciStanPy provides enhanced versions of the torch.distributions.Normal and torch.distributions.LogNormal distributions that include such functions.
- class scistanpy.model.components.custom_distributions.custom_torch_dists.Normal(*args: Any, **kwargs: Any)
Bases: torch.distributions.Normal
Enhanced normal distribution with numerically stable log-space functions.
This class extends PyTorch’s standard Normal distribution with improved implementations of log-CDF and log-survival functions that provide better numerical stability, particularly in the tails of the distribution.
The enhanced methods use functions from torch.special that are designed to remain numerically accurate far into the tails of the distribution.
- Key Improvements:
  - Numerically stable log-CDF computation using log_ndtr
  - Stable log-survival function using symmetry properties
  - Maintains full compatibility with PyTorch’s Normal interface
  - Better precision for extreme tail probabilities
- These improvements are particularly important for:
  - Extreme value analysis
  - Tail probability computations
  - Log-likelihood calculations with extreme parameter values
- Example:
>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import Normal
>>> # Enhanced normal distribution
>>> normal = Normal(loc=0.0, scale=1.0)
>>> # Stable computation of a very small lower-tail probability
>>> extreme_value = torch.tensor(-10.0)
>>> log_tail_prob = normal.log_cdf(extreme_value)  # Numerically stable
- log_cdf(value: torch.Tensor) → torch.Tensor
Compute logarithm of cumulative distribution function.
- Parameters:
value (torch.Tensor) – Values at which to evaluate log-CDF
- Returns:
Log-CDF values
- Return type:
torch.Tensor
Uses PyTorch’s torch.special.log_ndtr function for numerical stability; log_ndtr is specifically designed to handle extreme values without overflow or underflow issues.
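The difference is easy to see by comparing a naive `log(cdf(x))` with `torch.special.log_ndtr`, the log of the standard normal CDF, applied to the standardized value. This sketch uses stock PyTorch only and assumes the method standardizes its input as `(value - loc) / scale` before calling `log_ndtr`.

```python
import torch
from torch.distributions import Normal

loc, scale = torch.tensor(0.0), torch.tensor(1.0)
x = torch.tensor([-40.0, -2.0, 0.0, 2.0])

# Naive route: cdf(-40) underflows to exactly 0 in float32, so its log is -inf.
naive = torch.log(Normal(loc, scale).cdf(x))
# Stable route: standardize, then evaluate log Phi(z) directly.
stable = torch.special.log_ndtr((x - loc) / scale)
```

Away from the tails both routes agree; at `x = -40` only the stable route returns a finite value.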
- log_sf(value: torch.Tensor) → torch.Tensor
Compute logarithm of survival function (1 - CDF).
- Parameters:
value (torch.Tensor) – Values at which to evaluate log-survival function
- Returns:
Log-survival function values
- Return type:
torch.Tensor
Leverages the symmetry of the normal distribution to compute the survival function as the CDF evaluated at the reflection about the mean, i.e., SF(x) = CDF(2μ - x), where μ is loc. This approach maintains numerical stability while avoiding direct computation of 1 - CDF.
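Both methods can be sketched as a small subclass of `torch.distributions.Normal`. `StableNormal` is a hypothetical name, and the method bodies illustrate the approach described above rather than reproduce SciStanPy's source. The survival function reuses the stable log-CDF through the reflection `2 * loc - x` instead of computing `1 - cdf`.

```python
import torch
from torch.distributions import Normal


class StableNormal(Normal):
    """Hypothetical Normal subclass with log-space CDF and survival functions."""

    def log_cdf(self, value: torch.Tensor) -> torch.Tensor:
        # log Phi of the standardized value.
        return torch.special.log_ndtr((value - self.loc) / self.scale)

    def log_sf(self, value: torch.Tensor) -> torch.Tensor:
        # Reflect about the mean: P(X > x) = P(X < 2 * loc - x).
        return self.log_cdf(2 * self.loc - value)


dist = StableNormal(loc=torch.tensor(3.0), scale=torch.tensor(2.0))
x = torch.tensor([3.0, 5.0, 100.0])
```

At `x = 100` the naive `log1p(-cdf(x))` is `-inf` because `cdf(x)` rounds to 1, while `log_sf` stays finite.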
- class scistanpy.model.components.custom_distributions.custom_torch_dists.LogNormal(*args: Any, **kwargs: Any)
Bases: torch.distributions.LogNormal
Enhanced log-normal distribution with numerically stable log-space functions.
This class extends PyTorch’s standard LogNormal distribution with improved implementations of log-CDF and log-survival functions for better numerical stability, particularly important given the log-normal’s heavy tail behavior.
- Key Improvements:
  - Stable log-CDF computation using log_ndtr
  - Numerically stable log-survival function
  - Maintains compatibility with PyTorch’s LogNormal interface
  - Better handling of extreme values in both tails
The log-normal distribution is particularly sensitive to numerical issues because of its relationship to the normal distribution through logarithmic transformation and its heavy-tailed nature.
- Example:
>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import LogNormal
>>> # Enhanced log-normal distribution
>>> lognormal = LogNormal(loc=0.0, scale=1.0)
>>> # Stable computation for extreme values
>>> large_value = torch.tensor(1000.0)
>>> log_tail_prob = lognormal.log_sf(large_value)  # Numerically stable
- log_cdf(value: torch.Tensor) → torch.Tensor
Compute logarithm of cumulative distribution function.
- Parameters:
value (torch.Tensor) – Values at which to evaluate log-CDF
- Returns:
Log-CDF values
- Return type:
torch.Tensor
Transforms the problem to the underlying normal distribution for stable computation using log_ndtr.
- log_sf(value: torch.Tensor) → torch.Tensor
Compute logarithm of survival function.
- Parameters:
value (torch.Tensor) – Values at which to evaluate log-survival function
- Returns:
Log-survival function values
- Return type:
torch.Tensor
Uses the relationship between log-normal and normal distributions to compute stable log-survival probabilities.
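The log-normal case reduces to the normal one: if X follows LogNormal(loc, scale), then log X follows Normal(loc, scale). The functions below are a hypothetical sketch of that reduction (not SciStanPy's source), built on the `loc` and `scale` properties of stock `torch.distributions.LogNormal`.

```python
import torch
from torch.distributions import LogNormal


def _standardize(dist: LogNormal, value: torch.Tensor) -> torch.Tensor:
    # z = (log x - loc) / scale is standard normal when x is log-normal.
    return (torch.log(value) - dist.loc) / dist.scale


def lognormal_log_cdf(dist: LogNormal, value: torch.Tensor) -> torch.Tensor:
    return torch.special.log_ndtr(_standardize(dist, value))


def lognormal_log_sf(dist: LogNormal, value: torch.Tensor) -> torch.Tensor:
    # P(X > x) = P(Z > z) = P(Z < -z) by symmetry of the standard normal.
    return torch.special.log_ndtr(-_standardize(dist, value))


lognormal = LogNormal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
x = torch.tensor([0.5, 1.0, 3.0, 1e30])
```

Working on the log scale first also keeps very large inputs such as `1e30` manageable, since only `log(x)` enters the standard normal functions.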
Custom Distribution Implementations
SciStanPy includes several custom distributions not available in PyTorch, implemented by extending or transforming existing PyTorch distributions.
- class scistanpy.model.components.custom_distributions.custom_torch_dists.Lomax(*args: Any, **kwargs: Any)
Bases: TransformedDistribution, CustomDistribution
Lomax distribution implementation (shifted Pareto distribution).
The Lomax distribution, also known as the Pareto Type II distribution, is a Pareto distribution shifted so that its support starts at zero.
- Parameters:
lambda_ (torch.Tensor) – Scale parameter (must be positive)
alpha (torch.Tensor) – Shape parameter (must be positive)
args – Additional arguments for the base distribution
kwargs – Additional keyword arguments for the base distribution
- Mathematical Definition:
\[
\begin{aligned}
\text{If } X &\sim \text{Pareto}(\lambda, \alpha), \text{ then} \\
Y &= X - \lambda \sim \text{Lomax}(\lambda, \alpha)
\end{aligned}
\]
The distribution is implemented using PyTorch’s TransformedDistribution framework with a Pareto base distribution and an affine transformation.
- Example:
>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import Lomax
>>> # Lomax distribution for modeling heavy-tailed phenomena
>>> lambda_param = torch.tensor(1.0)
>>> alpha_param = torch.tensor(2.0)
>>> lomax = Lomax(lambda_=lambda_param, alpha=alpha_param)
>>> samples = lomax.sample((1000,))
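The construction described above can be reproduced with stock PyTorch pieces. `make_lomax` and `lomax_log_pdf` are hypothetical helpers, not SciStanPy's API; the argument order `Pareto(scale, alpha)` follows `torch.distributions.Pareto`, and the check uses the standard Lomax density (alpha / lambda) * (1 + y / lambda)^-(alpha + 1).

```python
import torch
from torch.distributions import AffineTransform, Pareto, TransformedDistribution


def make_lomax(lambda_: torch.Tensor, alpha: torch.Tensor) -> TransformedDistribution:
    # Pareto(lambda_, alpha) has support [lambda_, inf); shifting by -lambda_ moves it to [0, inf).
    return TransformedDistribution(Pareto(lambda_, alpha), [AffineTransform(loc=-lambda_, scale=1.0)])


def lomax_log_pdf(y: torch.Tensor, lambda_: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # log[(alpha / lambda_) * (1 + y / lambda_) ** -(alpha + 1)]
    return torch.log(alpha / lambda_) - (alpha + 1) * torch.log1p(y / lambda_)


lambda_param, alpha_param = torch.tensor(1.5), torch.tensor(2.0)
lomax = make_lomax(lambda_param, alpha_param)
y = torch.tensor([0.1, 1.0, 10.0])
```

The affine shift has unit slope, so it contributes no Jacobian term and the transformed log-density equals the Pareto log-density at `y + lambda_`.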
- class scistanpy.model.components.custom_distributions.custom_torch_dists.ExpLomax(*args: Any, **kwargs: Any)
Bases: TransformedDistribution, CustomDistribution
Exponential-Lomax distribution implementation.
This distribution is created by taking the logarithm of a Lomax-distributed random variable. It’s useful for modeling log-scale phenomena that exhibit heavy-tailed behavior.
- Parameters:
lambda_ (torch.Tensor) – Scale parameter for the underlying Lomax distribution
alpha (torch.Tensor) – Shape parameter for the underlying Lomax distribution
args – Additional arguments for the base distribution
kwargs – Additional keyword arguments for the base distribution
- Mathematical Definition:
\[
\begin{aligned}
\text{If } X &\sim \text{Lomax}(\lambda, \alpha), \text{ then} \\
Y &= \log(X) \sim \text{ExpLomax}(\lambda, \alpha)
\end{aligned}
\]
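Appending an inverse exponential transform to the Lomax construction yields this distribution. The helper names are hypothetical and this is a sketch rather than SciStanPy's code; the check uses the change-of-variables identity log p_Y(y) = log p_Lomax(exp(y)) + y.

```python
import torch
from torch.distributions import AffineTransform, ExpTransform, Pareto, TransformedDistribution


def make_exp_lomax(lambda_: torch.Tensor, alpha: torch.Tensor) -> TransformedDistribution:
    # Pareto -> shift by -lambda_ (Lomax) -> element-wise log (ExpTransform().inv).
    transforms = [AffineTransform(loc=-lambda_, scale=1.0), ExpTransform().inv]
    return TransformedDistribution(Pareto(lambda_, alpha), transforms)


def exp_lomax_log_pdf(y: torch.Tensor, lambda_: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # Lomax log-density at x = exp(y), plus log|dx/dy| = y.
    return torch.log(alpha / lambda_) - (alpha + 1) * torch.log1p(torch.exp(y) / lambda_) + y


lambda_param, alpha_param = torch.tensor(1.5), torch.tensor(2.0)
exp_lomax = make_exp_lomax(lambda_param, alpha_param)
y = torch.tensor([-3.0, 0.0, 2.0])
```
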
- class scistanpy.model.components.custom_distributions.custom_torch_dists.ExpExponential(*args: Any, **kwargs: Any)
Bases: TransformedDistribution, CustomDistribution
Exponential-Exponential distribution implementation.
This distribution is created by taking the logarithm of an exponentially distributed random variable. It is a Gumbel distribution for minima (equivalently, -Y follows a Gumbel distribution with location log(rate) and scale 1) and is useful for extreme value modeling.
- Parameters:
rate (torch.Tensor) – Rate parameter for the underlying exponential distribution
args – Additional arguments for the base distribution
kwargs – Additional keyword arguments for the base distribution
- Mathematical Definition:
\[
\begin{aligned}
\text{If } X &\sim \text{Exponential}(\text{rate}), \text{ then} \\
Y &= \log(X) \sim \text{ExpExponential}(\text{rate})
\end{aligned}
\]
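The relationship to the Gumbel family can be checked numerically. If X ~ Exponential(rate) and Y = log(X), then the density of Y is rate * exp(y) * exp(-rate * exp(y)), and -Y follows `torch.distributions.Gumbel` with location log(rate) and scale 1. The sketch builds the log-transformed exponential from stock PyTorch classes; it illustrates the definition above and is not SciStanPy's source.

```python
import torch
from torch.distributions import Exponential, ExpTransform, Gumbel, TransformedDistribution

rate = torch.tensor(2.5)
# Y = log(X) with X ~ Exponential(rate).
exp_exponential = TransformedDistribution(Exponential(rate), [ExpTransform().inv])
# -Y ~ Gumbel(loc=log(rate), scale=1).
gumbel = Gumbel(loc=torch.log(rate), scale=torch.tensor(1.0))

y = torch.tensor([-4.0, -1.0, 0.0, 1.5])
closed_form = torch.log(rate) + y - rate * torch.exp(y)  # log[rate * exp(y) * exp(-rate * exp(y))]
```

Reflecting the argument (`gumbel.log_prob(-y)`) is what turns PyTorch's Gumbel, which models maxima, into the minimum-type distribution of Y.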
- class scistanpy.model.components.custom_distributions.custom_torch_dists.ExpDirichlet(*args: Any, **kwargs: Any)
Bases: TransformedDistribution, CustomDistribution
Exponential-Dirichlet distribution implementation.
This distribution is created by taking the element-wise logarithm of a Dirichlet-distributed random vector. It’s useful for modeling log-scale compositional data and log-probability vectors.
- Parameters:
concentration (torch.Tensor) – Concentration parameters for the underlying Dirichlet
args – Additional arguments for the base distribution
kwargs – Additional keyword arguments for the base distribution
- Mathematical Definition:
\[
\begin{aligned}
\text{If } X &\sim \text{Dirichlet}(\alpha), \text{ then} \\
Y &= \log(X) \sim \text{ExpDirichlet}(\alpha)
\end{aligned}
\]
This distribution is particularly valuable when working with probability vectors in log space: exponentiating a sample always recovers a point on the probability simplex.
- log_prob(value: torch.Tensor) → torch.Tensor
Compute log-probabilities for Exponential-Dirichlet outcomes. PyTorch’s TransformedDistribution applies the Jacobian correction element-wise to all K components, neglecting the simplex constraint (a K-component Dirichlet vector has only K - 1 free coordinates). This method adjusts the element-wise log-probability to correct for this.
See discussion on the Stan forums.
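The effect of the element-wise Jacobian can be seen numerically for K = 2, where the log-simplex has a single free coordinate y1 and y2 = log(1 - exp(y1)). A valid density over y1 must integrate to 1, but PyTorch's element-wise log-Jacobian adds y1 + y2, which for concentration (2, 3) yields total mass E[X2] = 3/5. Subtracting the last component's term, y2, is one correction consistent with parameterizing the simplex by its first K - 1 coordinates. Treat it as an illustrative assumption; it is not necessarily the exact adjustment SciStanPy makes.

```python
import torch
from torch.distributions import Dirichlet, ExpTransform, TransformedDistribution

concentration = torch.tensor([2.0, 3.0], dtype=torch.float64)
# Element-wise log of a Dirichlet draw, with PyTorch's default per-element Jacobian.
naive = TransformedDistribution(Dirichlet(concentration), [ExpTransform().inv])

# Grid over the free coordinate y1 = log(x1); y2 = log(1 - exp(y1)) is then fixed.
y1 = torch.linspace(-40.0, -1e-8, 200_001, dtype=torch.float64)
y2 = torch.log(-torch.expm1(y1))
y = torch.stack([y1, y2], dim=-1)

naive_lp = naive.log_prob(y)    # Dirichlet log-density + y1 + y2
corrected_lp = naive_lp - y2    # hypothetical correction: keep only the y1 Jacobian term

naive_mass = torch.trapezoid(naive_lp.exp(), y1).item()
corrected_mass = torch.trapezoid(corrected_lp.exp(), y1).item()
```

Only the corrected version integrates to 1 over the free coordinate.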