Custom PyTorch Distributions API Reference

Custom PyTorch distribution implementations for SciStanPy models.

This module provides specialized PyTorch distribution classes that extend or modify the standard PyTorch distributions to meet specific requirements of SciStanPy modeling. These distributions handle edge cases, provide numerical stability improvements, and enable functionality not available in the standard PyTorch distribution library.

Key Features:
  • Extended Multinomial: Support for inhomogeneous total counts

  • Numerical Stability: Improved log-space probability computations

  • Custom Distributions: Implementations of distributions not in PyTorch

  • SciStanPy Integration: Designed for compatibility with SciStanPy parameter types

The distributions in this submodule can be broadly broken down into the following categories:

Multinomial Extensions: Enhanced multinomial distributions

Numerically Stable Distributions: Improved standard distributions
  • Normal: Enhanced with stable log-CDF and log-survival functions

  • LogNormal: Enhanced with stable log-space probability functions

Custom Distribution Implementations: New distribution types

The distributions in this module are designed to work within PyTorch’s distribution framework while providing the specific functionality required for probabilistic modeling in SciStanPy.

Base Classes

All custom distributions inherit from scistanpy.model.components.custom_distributions.custom_torch_dists.CustomDistribution, a marker class that provides no functionality of its own but serves as a common ancestor for type checking and future extensions.

class scistanpy.model.components.custom_distributions.custom_torch_dists.CustomDistribution

Bases: object

Base marker class for custom SciStanPy distributions.

This class serves as a marker interface for custom distribution implementations in SciStanPy. It doesn’t provide any functionality but is useful for type hinting and identifying custom distributions in the codebase.

All custom distribution classes should inherit from this class to maintain consistency and enable type checking.
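As a brief illustration of the marker pattern, a minimal sketch of identifying custom distributions with isinstance (the Multinomial class used here is documented below):

>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import (
...     CustomDistribution,
...     Multinomial,
... )
>>> dist = Multinomial(total_count=10, probs=torch.tensor([0.2, 0.3, 0.5]))
>>> isinstance(dist, CustomDistribution)
True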

Multinomial Extensions

The multinomial distributions extend PyTorch’s built-in multinomial capabilities to support inhomogeneous total counts and various parameterizations.

class scistanpy.model.components.custom_distributions.custom_torch_dists.Multinomial(
total_count: 'custom_types.Integer' | torch.Tensor = 1,
probs: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
validate_args: bool | None = None,
)

Bases: CustomDistribution

Extended multinomial distribution supporting inhomogeneous total counts.

This class extends the functionality of PyTorch’s standard multinomial distribution to support different total counts across batch dimensions. The standard PyTorch implementation requires all trials to have the same total count, but this implementation allows each batch element to have its own total count.

Parameters:
  • total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.

  • probs (Optional[torch.Tensor]) – Event probabilities (mutually exclusive with logits)

  • logits (Optional[torch.Tensor]) – Event log-odds (mutually exclusive with probs)

  • validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.

Raises:
  • ValueError – If neither or both probs and logits are provided

Key Features:
  • Supports different total counts per batch element

  • Maintains PyTorch distribution interface compatibility

  • Efficient batched computation through internal distribution creation

  • Proper shape handling for multi-dimensional batch operations

The implementation creates individual multinomial distributions for each batch element, allowing for flexible modeling scenarios where trial counts vary across observations.

Example:
>>> # Different total counts for each batch element
>>> total_counts = torch.tensor([[10], [20], [15]])
>>> probs = torch.tensor([[0.3, 0.4, 0.3],
...                       [0.2, 0.5, 0.3],
...                       [0.4, 0.3, 0.3]])
>>> dist = Multinomial(total_count=total_counts, probs=probs)
>>> samples = dist.sample()

log_prob(value: torch.Tensor) → torch.Tensor

Compute log-probabilities for observed multinomial outcomes.

Parameters:
  • value (torch.Tensor) – Observed counts for each category

Returns:
  Log-probabilities for the observed outcomes

Return type:
  torch.Tensor

Raises:
  • ValueError – If value shape doesn’t match expected dimensions

The method validates that the input tensor has the correct shape and computes log-probabilities by calling the appropriate distribution for each batch element.
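For instance, continuing the class example above, a brief sketch of evaluating observed counts (the counts tensor is illustrative; each row must sum to the corresponding total count):

>>> counts = torch.tensor([[3., 4., 3.],
...                        [5., 10., 5.],
...                        [6., 5., 4.]])  # Rows sum to 10, 20, and 15
>>> log_probs = dist.log_prob(counts)  # One log-probability per batch element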

sample(sample_shape: torch.Size = torch.Size([])) → torch.Tensor

Generate samples from the multinomial distribution.

Parameters:
  • sample_shape (torch.Size) – Shape of samples to generate. Defaults to empty.

Returns:
  Sampled multinomial outcomes

Return type:
  torch.Tensor

Generates samples by calling the sample method of each individual distribution and properly reshaping the results to maintain the expected batch and sample dimensions.
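Continuing the same example, drawing several samples at once might look as follows (the stated output shape assumes standard sample/batch/event conventions):

>>> draws = dist.sample(torch.Size([5]))
>>> # Expected shape (5, 3, 3): sample shape, then batch elements, then categories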

class scistanpy.model.components.custom_distributions.custom_torch_dists.MultinomialProb(
total_count: 'custom_types.Integer' | torch.Tensor = 1,
probs: torch.Tensor | None = None,
validate_args: bool | None = None,
)

Bases: Multinomial, CustomDistribution

Multinomial distribution parameterized by probabilities.

This class provides a specialized interface for multinomial distributions where the parameters are specified as probabilities rather than logits. It’s a convenience wrapper around the base Multinomial class.

Parameters:
  • total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.

  • probs (Optional[torch.Tensor]) – Event probabilities (must sum to 1)

  • validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.

This parameterization is natural when working with probability vectors that are already normalized, such as output from softmax functions or empirical frequency estimates.

Example:
>>> # Probability parameterization
>>> probs = torch.softmax(torch.randn(3, 4), dim=-1)
>>> total_counts = torch.tensor([[100], [200], [150]])
>>> dist = MultinomialProb(total_count=total_counts, probs=probs)

class scistanpy.model.components.custom_distributions.custom_torch_dists.MultinomialLogit(
total_count: 'custom_types.Integer' | torch.Tensor = 1,
logits: torch.Tensor | None = None,
validate_args: bool | None = None,
)

Bases: Multinomial, CustomDistribution

Multinomial distribution parameterized by logits.

This class provides a specialized interface for multinomial distributions where the parameters are specified as logits (log-odds) rather than probabilities. It’s a convenience wrapper around the base Multinomial class.

Parameters:
  • total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.

  • logits (Optional[torch.Tensor]) – Event logits (log-odds)

  • validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.

Example:
>>> # Logit parameterization
>>> logits = torch.randn(3, 4)  # No normalization needed
>>> total_counts = torch.tensor([[50], [75], [100]])
>>> dist = MultinomialLogit(total_count=total_counts, logits=logits)

class scistanpy.model.components.custom_distributions.custom_torch_dists.MultinomialLogTheta(
total_count: 'custom_types.Integer' | torch.Tensor = 1,
log_probs: torch.Tensor | None = None,
validate_args: bool | None = None,
)

Bases: MultinomialLogit

Multinomial distribution with normalized log-probabilities.

This class extends MultinomialLogit with the additional constraint that the input log-probabilities must already be normalized (i.e., their exponentials sum to 1). This is useful when working with log-probability vectors that are guaranteed to be valid probability distributions.

Parameters:
  • total_count (Union[custom_types.Integer, torch.Tensor]) – Total number of trials for each batch element. Defaults to 1.

  • log_probs (Optional[torch.Tensor]) – Normalized log-probabilities (exp(log_probs) must sum to 1)

  • validate_args (Optional[bool]) – Whether to validate arguments. Defaults to None.

Raises:
  • AssertionError – If log_probs is None

  • AssertionError – If log_probs are not properly normalized

This parameterization is particularly useful when:
  • Working with log-space normalized probability vectors

  • Ensuring numerical precision in log-space computations

  • Interfacing with other log-space probability calculations

The normalization constraint is enforced at initialization to prevent invalid probability distributions.

Example:
>>> # Normalized log-probabilities
>>> logits = torch.randn(3, 4)
>>> log_probs = torch.log_softmax(logits, dim=-1)
>>> total_counts = torch.tensor([[100], [200], [150]])
>>> dist = MultinomialLogTheta(total_count=total_counts, log_probs=log_probs)

Numerically Stable Distributions

PyTorch does not provide built-in numerically stable log-CDF or log-survival functions. SciStanPy therefore supplies enhanced versions of the torch.distributions.Normal and torch.distributions.LogNormal distributions that include such functions.
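To see why this matters, consider a short sketch contrasting the naive approach (taking the log of the CDF) with a log-space computation. Deep in the lower tail the CDF underflows to zero, so its logarithm becomes -inf, while torch.special.log_ndtr stays finite:

>>> import torch
>>> base = torch.distributions.Normal(0.0, 1.0)
>>> x = torch.tensor(-40.0)
>>> base.cdf(x).log()          # tensor(-inf): the CDF underflows to zero
>>> torch.special.log_ndtr(x)  # roughly tensor(-804.61): finite and accurate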

class scistanpy.model.components.custom_distributions.custom_torch_dists.Normal(*args: Any, **kwargs: Any)

Bases: torch.distributions.Normal

Enhanced normal distribution with numerically stable log-space functions.

This class extends PyTorch’s standard Normal distribution with improved implementations of log-CDF and log-survival functions that provide better numerical stability, particularly in the tails of the distribution.

The enhanced methods use PyTorch’s special functions that are specifically designed for numerical stability in extreme value computations.

Key Improvements:
  • Numerically stable log-CDF computation using log_ndtr

  • Stable log-survival function using symmetry properties

  • Maintains full compatibility with PyTorch’s Normal interface

  • Better precision for extreme tail probabilities

These improvements are particularly important for:
  • Extreme value analysis

  • Tail probability computations

  • Log-likelihood calculations with extreme parameter values

Example:
>>> # Enhanced normal distribution
>>> normal = Normal(loc=0.0, scale=1.0)
>>> # Stable computation of very small tail probabilities
>>> extreme_value = torch.tensor(-10.0)
>>> log_tail_prob = normal.log_cdf(extreme_value)  # Numerically stable

log_cdf(value: torch.Tensor) → torch.Tensor

Compute logarithm of cumulative distribution function.

Parameters:
  • value (torch.Tensor) – Values at which to evaluate log-CDF

Returns:
  Log-CDF values

Return type:
  torch.Tensor

Uses PyTorch’s special.log_ndtr function for numerical stability, which is specifically designed to handle extreme values without overflow or underflow issues.
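A minimal sketch of the computation described above (standardize, then apply log_ndtr); it mirrors the description rather than reproducing the literal implementation:

>>> import torch
>>> def stable_normal_log_cdf(value, loc, scale):
...     # log P(X <= value) for X ~ Normal(loc, scale)
...     return torch.special.log_ndtr((value - loc) / scale)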

log_sf(value: torch.Tensor) → torch.Tensor

Compute logarithm of survival function (1 - CDF).

Parameters:
  • value (torch.Tensor) – Values at which to evaluate log-survival function

Returns:
  Log-survival function values

Return type:
  torch.Tensor

Leverages the symmetry of the normal distribution to compute the survival function as the CDF evaluated at the reflection about the mean. This approach maintains numerical stability while avoiding direct computation of 1 - CDF.
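The reflection identity can be sketched directly: P(X > x) equals the CDF at the point reflected about the mean, so the log-survival function reduces to log_ndtr of the negated standardized value:

>>> import torch
>>> def stable_normal_log_sf(value, loc, scale):
...     # log P(X > value) = log Phi((loc - value) / scale) by symmetry
...     return torch.special.log_ndtr((loc - value) / scale)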

class scistanpy.model.components.custom_distributions.custom_torch_dists.LogNormal(*args: Any, **kwargs: Any)

Bases: torch.distributions.LogNormal

Enhanced log-normal distribution with numerically stable log-space functions.

This class extends PyTorch’s standard LogNormal distribution with improved implementations of log-CDF and log-survival functions for better numerical stability, particularly important given the log-normal’s heavy tail behavior.

Key Improvements:
  • Stable log-CDF computation using log_ndtr

  • Numerically stable log-survival function

  • Maintains compatibility with PyTorch’s LogNormal interface

  • Better handling of extreme values in both tails

The log-normal distribution is particularly sensitive to numerical issues because of its relationship to the normal distribution through logarithmic transformation and its heavy-tailed nature.

Example:
>>> # Enhanced log-normal distribution
>>> lognormal = LogNormal(loc=0.0, scale=1.0)
>>> # Stable computation for extreme values
>>> large_value = torch.tensor(1000.0)
>>> log_tail_prob = lognormal.log_sf(large_value)  # Numerically stable

log_cdf(value: torch.Tensor) → torch.Tensor

Compute logarithm of cumulative distribution function.

Parameters:
  • value (torch.Tensor) – Values at which to evaluate log-CDF

Returns:
  Log-CDF values

Return type:
  torch.Tensor

Transforms the problem to the underlying normal distribution for stable computation using log_ndtr.

log_sf(value: torch.Tensor) → torch.Tensor

Compute logarithm of survival function.

Parameters:
  • value (torch.Tensor) – Values at which to evaluate log-survival function

Returns:
  Log-survival function values

Return type:
  torch.Tensor

Uses the relationship between log-normal and normal distributions to compute stable log-survival probabilities.
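Both methods reduce to the underlying normal distribution through the logarithmic transform. A hedged sketch of the two relationships (illustrative, not necessarily the literal implementation):

>>> import torch
>>> def lognormal_log_cdf(value, loc, scale):
...     # P(X <= v) for log-normal X equals P(log X <= log v) under the latent normal
...     return torch.special.log_ndtr((torch.log(value) - loc) / scale)
>>> def lognormal_log_sf(value, loc, scale):
...     # P(X > v) via the reflected latent normal
...     return torch.special.log_ndtr((loc - torch.log(value)) / scale)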

Custom Distribution Implementations

SciStanPy includes several custom distributions not available in PyTorch, implemented by extending or transforming existing PyTorch distributions.

class scistanpy.model.components.custom_distributions.custom_torch_dists.Lomax(*args: Any, **kwargs: Any)

Bases: TransformedDistribution, CustomDistribution

Lomax distribution implementation (shifted Pareto distribution).

The Lomax distribution is a shifted version of the Pareto distribution, also known as the Pareto Type II distribution. It’s implemented as a transformed Pareto distribution with an affine transformation.

Parameters:
  • lambda_ (torch.Tensor) – Scale parameter (must be positive)

  • alpha (torch.Tensor) – Shape parameter (must be positive)

  • args – Additional arguments for the base distribution

  • kwargs – Additional keyword arguments for the base distribution

Mathematical Definition:
\[\begin{split}\begin{align*} \text{If } X &\sim \text{Pareto}(\lambda, \alpha), \text{then } \\ \\ Y &\sim \text{Lomax}(\lambda, \alpha), \text{where } \\ \\ Y &= X - \lambda \end{align*}\end{split}\]

The distribution is implemented using PyTorch’s TransformedDistribution framework with a Pareto base distribution and an affine transformation.
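A minimal sketch of this construction from PyTorch primitives (variable names are illustrative; the class’s actual signature is shown above):

>>> import torch
>>> from torch.distributions import Pareto, TransformedDistribution
>>> from torch.distributions.transforms import AffineTransform
>>> lambda_, alpha = torch.tensor(1.0), torch.tensor(2.0)
>>> # Pareto support is [lambda_, inf); shifting by -lambda_ moves it to [0, inf)
>>> lomax_like = TransformedDistribution(
...     Pareto(lambda_, alpha), AffineTransform(loc=-lambda_, scale=1.0)
... )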

Example:
>>> # Lomax distribution for modeling heavy-tailed phenomena
>>> lambda_param = torch.tensor(1.0)
>>> alpha_param = torch.tensor(2.0)
>>> lomax = Lomax(lambda_=lambda_param, alpha=alpha_param)
>>> samples = lomax.sample((1000,))

class scistanpy.model.components.custom_distributions.custom_torch_dists.ExpLomax(*args: Any, **kwargs: Any)

Bases: TransformedDistribution, CustomDistribution

Exponential-Lomax distribution implementation.

This distribution is created by taking the logarithm of a Lomax-distributed random variable. It’s useful for modeling log-scale phenomena that exhibit heavy-tailed behavior.

Parameters:
  • lambda_ (torch.Tensor) – Scale parameter for the underlying Lomax distribution

  • alpha (torch.Tensor) – Shape parameter for the underlying Lomax distribution

  • args – Additional arguments for the base distribution

  • kwargs – Additional keyword arguments for the base distribution

Mathematical Definition:
\[\begin{split}\begin{align*} \text{If } X &\sim \text{Lomax}(\lambda, \alpha), \text{then } \\ \\ Y &= \log(X) \sim \text{ExpLomax}(\lambda, \alpha) \end{align*}\end{split}\]
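Continuing the Lomax sketch above, a hedged illustration of expressing the log transform in PyTorch’s transform framework, where the inverse of ExpTransform acts as a logarithm:

>>> from torch.distributions.transforms import ExpTransform
>>> explomax_like = TransformedDistribution(lomax_like, ExpTransform().inv)
>>> # Samples are log(X) for X drawn from the Lomax distribution sketched earlier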

class scistanpy.model.components.custom_distributions.custom_torch_dists.ExpExponential(*args: Any, **kwargs: Any)

Bases: TransformedDistribution, CustomDistribution

Exponential-Exponential distribution implementation.

This distribution is created by taking the logarithm of an exponentially distributed random variable. Its negation follows a Gumbel distribution, which makes it useful for extreme value modeling.

Parameters:
  • rate (torch.Tensor) – Rate parameter for the underlying exponential distribution

  • args – Additional arguments for the base distribution

  • kwargs – Additional keyword arguments for the base distribution

Mathematical Definition:
\[\begin{split}\begin{align*} \text{If } X &\sim \text{Exponential}(\text{rate}), \text{then } \\ \\ Y &= \log(X) \sim \text{ExpExponential}(\text{rate}) \end{align*}\end{split}\]
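A hedged sketch of the construction and of the Gumbel connection noted above: with rate parameter lambda, the negation of log(X) follows Gumbel(log lambda, 1):

>>> import torch
>>> from torch.distributions import Exponential, Gumbel, TransformedDistribution
>>> from torch.distributions.transforms import ExpTransform
>>> rate = torch.tensor(2.0)
>>> exp_exp_like = TransformedDistribution(Exponential(rate), ExpTransform().inv)
>>> # Negated samples should match this reference distribution
>>> reference = Gumbel(torch.log(rate), torch.tensor(1.0))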

class scistanpy.model.components.custom_distributions.custom_torch_dists.ExpDirichlet(*args: Any, **kwargs: Any)

Bases: TransformedDistribution, CustomDistribution

Exponential-Dirichlet distribution implementation.

This distribution is created by taking the element-wise logarithm of a Dirichlet-distributed random vector. It’s useful for modeling log-scale compositional data and log-probability vectors.

Parameters:
  • concentration (torch.Tensor) – Concentration parameters for the underlying Dirichlet

  • args – Additional arguments for the base distribution

  • kwargs – Additional keyword arguments for the base distribution

Mathematical Definition:
\[\begin{split}\begin{align*} \text{If } X &\sim \text{Dirichlet}(\alpha), \text{then } \\ \\ Y &= \log(X) \sim \text{ExpDirichlet}(\alpha) \end{align*}\end{split}\]

This distribution is particularly valuable when working with probability vectors in log-space, where it maintains the simplex constraint through the exponential transformation.
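A small sketch of the log-simplex property (assuming the concentration keyword from the parameter list above): exponentiated samples sum to one, so the logsumexp of any sample should be numerically zero:

>>> import torch
>>> from scistanpy.model.components.custom_distributions.custom_torch_dists import ExpDirichlet
>>> log_theta = ExpDirichlet(concentration=torch.ones(4)).sample()
>>> torch.logsumexp(log_theta, dim=-1)  # approximately tensor(0.)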

log_prob(value: torch.Tensor) → torch.Tensor

Compute log-probabilities for Exponential-Dirichlet outcomes. PyTorch’s TransformedDistribution applies the Jacobian correction element-wise, which neglects the simplex constraint on the underlying Dirichlet variable. This method adjusts the element-wise log-probability to account for that constraint.

See the related discussion on the Stan forums for the derivation.