Model Components API Reference

Core model components for SciStanPy probabilistic modeling framework.

This submodule contains the fundamental building blocks for constructing probabilistic models in SciStanPy. It provides a comprehensive set of components that enable users to define complex probabilistic models through composition of simple, well-defined elements.

Model Components Submodule Overview

The model components submodule is itself broken into additional submodules:

Most relevant to the typical end user are the parameters, constants, and transformed_parameters submodules, which provide concrete implementations of various types of model components. Note that the transformed_parameters submodule is not typically imported directly, but rather its functionality is accessed through mathematical operations on other component types (e.g., inbuilt Python operators like +, -, *, /, and functions in scistanpy.operations).

The remainder of this page highlights the key design principles, architectural features, and usage patterns of the model components framework.

Key Design Principles

Compositional Design:

All components follow a compositional design that enables complex model construction through combination of simple elements:

import scistanpy as ssp
import numpy as np

class MyModel(ssp.Model):
    def __init__(self, x_data, observations):

      # Record default data
      super().__init__(default_data = {"observed": observations})

      # Basic components
      self.intercept = ssp.parameters.Normal(mu=0.0, sigma=5.0)
      self.slope = ssp.parameters.Normal(mu=0.0, sigma=2.0)
      self.noise = ssp.parameters.LogNormal(mu=0.0, sigma=1.0)

      # Composition through mathematical operations
      linear_predictor = self.intercept + self.slope * x_data

      # Further composition
      self.observed = ssp.parameters.Normal(mu=linear_predictor, sigma=self.noise)

model_instance = MyModel(x_data=np.linspace(0, 10, 50), observations=np.random.randn(50))

In the above example, simple components (intercept, slope, noise) are combined through arithmetic operations to create a more complex component (linear_predictor), which is then used as the mean of the observed data distribution.

Automatic Dependency Tracking:

Components automatically track their relationships to enable proper Stan code generation and sampling:

# Dependencies are tracked automatically
print(f"Linear predictor depends on: {[p.model_varname for p in  model_instance.observed.parents]}")
print(f"Intercept is used by: {[c.model_varname for c in model_instance.intercept.children]}")

Shape Broadcasting:

Components automatically handle shape broadcasting following NumPy conventions:

# Automatic shape inference and broadcasting
scalar_param = ssp.parameters.Normal(mu=0.0, sigma=1.0)          # Shape: ()
vector_param = ssp.parameters.Normal(mu=0.0, sigma=1.0, shape=(5,)) # Shape: (5,)

# Broadcasting in operations
broadcasted = scalar_param + vector_param  # Result shape: (5,)

# Multi-dimensional broadcasting
matrix_param = ssp.parameters.Normal(mu=0.0, sigma=1.0, shape=(3, 5))
result = vector_param + matrix_param  # Result shape: (3, 5)

Additional Usage Patterns

Parent-Child Architecture:

# Building hierarchical relationships
global_mean = ssp.parameters.Normal(mu=0, sigma=5)
group_means = ssp.parameters.Normal(
    mu=global_mean,          # Parent relationship
    sigma=1.0,
    shape=(10,)              # 10 groups
)
observations = ssp.parameters.Normal(
    mu=group_means,          # Another parent relationship
    sigma=0.5,
    observable=True
)

# Explore relationships
print(f"Global mean children: {len(global_mean.children)}")
print(f"Group means parents: {[p.model_varname for p in group_means.parents]}")

Dependency Graph Navigation:

# Walk up the dependency tree
def show_dependencies(component, level=0):
    indent = "  " * level
    print(f"{indent}{component.model_varname} ({component.__class__.__name__})")
    for parent in component.parents:
        show_dependencies(parent, level + 1)

show_dependencies(observations)

Stan Code Generation Framework:

# Automatic Stan variable declarations
stan_dtype = y.get_stan_dtype()  # "real" for scalar normal
declaration = y.get_stan_parameter_declaration()

# Multi-dimensional declarations
matrix_param = ssp.parameters.Normal(mu=0, sigma=1, shape=(5, 3))
matrix_decl = matrix_param.get_stan_dtype()  # "array[5] vector[3]"

Sampling and Drawing Interface:

# Hierarchical sampling
samples, all_draws = y.draw(n=1000)

# Access all component draws
for component, draws in all_draws.items():
    print(f"{component.model_varname}: shape {draws.shape}")

Multi-dimensional Indexing:

# Advanced indexing support
matrix_param = ssp.parameters.Normal(mu=0, sigma=1, shape=(10, 5))

# Index into subcomponents
row_slice = matrix_param[2, :]    # Third row
column_slice = matrix_param[:, 1] # Second column
element = matrix_param[3, 4]      # Single element

Model Structure Analysis:

def analyze_model_structure(component):
    """Analyze the structure of a model component tree."""

    # Find all components in the tree
    components = set([component])
    for _, current, relative in component.walk_tree(walk_down=False):
        components.add(current)
        components.add(relative)

    # Categorize components
    parameters = [c for c in components if isinstance(c, ssp.parameters.Parameter)]
    constants = [c for c in components if isinstance(c, ssp.constants.Constant)]
    transforms = [c for c in components if hasattr(c, '_transformation')]

    print(f"Model structure analysis:")
    print(f"  Total components: {len(components)}")
    print(f"  Parameters: {len(parameters)}")
    print(f"  Constants: {len(constants)}")
    print(f"  Transformations: {len(transforms)}")

    return {
        'components': components,
        'parameters': parameters,
        'constants': constants,
        'transformations': transforms
    }

Performance Considerations

Efficient Construction:

# Efficient: Single multi-dimensional parameter
efficient = ssp.parameters.Normal(mu=0.0, sigma=1.0, shape=(100, 50))

# Less efficient: Many individual parameters
# inefficient = [[ssp.parameters.Normal(mu=0.0, sigma=1.0)
#                 for j in range(50)] for i in range(100)]

Other Notable Features

Automatic Validation:

Components perform comprehensive validation during construction:

# Automatic parameter validation
try:
    invalid_param = ssp.parameters.Beta(
        alpha=-1,  # Must be positive
        beta=2
    )
except ValueError as e:
    print(f"Parameter validation failed: {e}")

# Shape compatibility validation
try:
    incompatible = ssp.parameters.Normal(
        mu=np.zeros((3, 4)),
        sigma=np.ones((5, 2)),  # Incompatible shape
    )
except ValueError as e:
    print(f"Shape compatibility failed: {e}")

Constraint Enforcement:

Components automatically enforce distributional constraints:

# Constraint checking during sampling
positive_param = ssp.parameters.Gamma(alpha=2, beta=1)
samples, _ = positive_param.draw(n=100)
assert np.all(samples >= 0)  # Automatic constraint enforcement

# Simplex constraint enforcement
simplex_param = ssp.parameters.Dirichlet(alpha=[1, 1, 1])
simplex_samples, _ = simplex_param.draw(n=100)
assert np.allclose(simplex_samples.sum(axis=-1), 1)

Automatic Shape Inference:

# Broadcasting follows NumPy rules
a = ssp.parameters.Normal(mu=0, sigma=1, shape=(5, 1))
b = ssp.parameters.Normal(mu=0, sigma=1, shape=(3,))

# Combination automatically broadcasts to (5, 3)
combined = a + b
print(f"Broadcasted shape: {combined.shape}")

Shape Validation:

try:
    # Incompatible shapes raise clear errors
    incompatible = ssp.parameters.Normal(
        mu=np.zeros((3, 4)),
        sigma=np.ones((2, 5)),  # Incompatible shape
    )
except ValueError as e:
    print(f"Shape error: {e}")

Variable Declaration System:

# Automatic Stan type inference
real_param = ssp.parameters.Normal(mu=0, sigma=1)
int_param = ssp.parameters.Poisson(lambda_=5)
simplex_param = ssp.parameters.Dirichlet(alpha=[1, 1, 1])

print(real_param.get_stan_dtype())     # "real"
print(int_param.get_stan_dtype())      # "int<lower=0>"
print(simplex_param.get_stan_dtype())  # "simplex[3]"

Bound Constraint Handling:

# Automatic bound detection
positive_param = ssp.parameters.Gamma(alpha=2, beta=1)
bounded_param = ssp.parameters.Beta(alpha=2, beta=3)

print(positive_param.get_stan_dtype())  # "real<lower=0.0>"
print(bounded_param.get_stan_dtype())   # "real<lower=0.0, upper=1.0>"

Index Management for Multi-dimensional Arrays:

# Automatic indexing for Stan loops
param_3d = ssp.parameters.Normal(mu=0, sigma=1, shape=(4, 5, 3))

# Get indexed variable name for Stan code
index_opts = ('i', 'j', 'k')
indexed_name = param_3d.get_indexed_varname(index_opts)
# Result: "param_3d[i,j]" (last dimension vectorized)

Shape Compatibility Checking:

# Shape compatibility validation
try:
    incompatible = ssp.parameters.Normal(
        mu=np.zeros((3, 4)),
        sigma=np.ones((5, 2)),  # Incompatible
        shape=(2, 2)            # Also incompatible
    )
except ValueError as e:
    print(f"Shape compatibility: {e}")

Bound Violation Detection:

# Runtime bound checking during sampling
param = ssp.parameters.Beta(alpha=1, beta=1)
try:
    # This would violate Beta bounds during internal validation
    samples, _ = param.draw(n=100)
    # Automatic validation ensures samples ∈ (0, 1)
except Exception as e:
    print(f"Bound violation: {e}")