# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Abstract base classes for SciStanPy model components.
This module defines the foundational abstract class that forms the core
architecture of SciStanPy model components, including parameters, constants, and
transformations. Users typically do not interact with this module directly; instead,
they use the concrete implementations provided in the
:py:mod:`scistanpy.model.components.constants` and
:py:mod:`scistanpy.model.components.parameters` submodules.
The module establishes the common functionality that all model components must
implement.
Core Abstractions:
- **Component Hierarchy**: Parent-child relationships between model elements
- **Stan Code Generation**: Automatic translation to Stan programming language
- **Shape Broadcasting**: Automatic handling of multi-dimensional parameters
- **Dependency Management**: Tracking and validation of component relationships
Key Responsibilities:
- Define abstract interfaces for model component behavior
- Implement common functionality for shape handling and validation
- Provide Stan code generation template
- Manage component relationships and dependency graphs
- Handle parameter bounds and constraints
- Support sampling and drawing from component distributions
Stan Integration: The abstract base provides core Stan code generation capabilities including:
- Variable declarations with appropriate types and constraints
- Index management for multi-dimensional arrays
- Target increment and transformation assignment generation
- Support function inclusion for custom distributions
Component Relationships: The hierarchy system enables complex model construction through:
- Parent-child linkage for dependency tracking
- Parameter name resolution and validation
- Automatic shape broadcasting across related components
- Tree traversal for model analysis and code generation
This foundational layer enables the construction of sophisticated probabilistic
models while maintaining type safety and automatic Stan code generation.
"""
# pylint: disable=too-many-lines
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Literal, Optional, overload, TYPE_CHECKING, Union
import numpy as np
import numpy.typing as npt
import torch
from scistanpy.exceptions import NumpySampleError
from scistanpy import utils
# Lazy imports for performance and to avoid circular imports
constants_module = utils.lazy_import("scistanpy.model.components.constants")
parameters = utils.lazy_import("scistanpy.model.components.parameters")
transformed_data = utils.lazy_import(
"scistanpy.model.components.transformations.transformed_data"
)
transformed_parameters = utils.lazy_import(
"scistanpy.model.components.transformations.transformed_parameters"
)
if TYPE_CHECKING:
from scistanpy import custom_types
[docs]
class AbstractModelComponent(ABC):
"""Abstract base class for all SciStanPy model components.
This class defines the fundamental interface and common functionality for
all elements in a SciStanPy probabilistic model. It provides the foundation
for :py:mod:`~scistanpy.model.components.parameters`,
:py:mod:`~scistanpy.model.components.constants`,
:py:mod:`~scistanpy.model.components.transformations.transformed_parameters`,
and other model components.
:param shape: Shape of the component array. Defaults to scalar ().
:type shape: Union[tuple[custom_types.Integer, ...], custom_types.Integer]
:param model_params: Named parameters that this component depends on
:type model_params: custom_types.CombinableParameterType
:cvar POSITIVE_PARAMS: Set of parameter names that must be positive
:cvar NEGATIVE_PARAMS: Set of parameter names that must be negative
:cvar SIMPLEX_PARAMS: Set of parameter names that must be simplexes
:cvar LOG_SIMPLEX_PARAMS: Set of parameter names that must be log-simplexes
:cvar BASE_STAN_DTYPE: Base Stan data type for this component
:cvar LOWER_BOUND: Lower bound constraint for component values
:cvar UPPER_BOUND: Upper bound constraint for component values
:cvar IS_SIMPLEX: Whether this component represents a simplex
:cvar IS_LOG_SIMPLEX: Whether this component represents a log-simplex
:cvar FORCE_PARENT_NAME: Whether to force naming of parent variables in Stan code
:cvar FORCE_LOOP_RESET: Whether to force loop reset in Stan code
The class provides core functionality for:
- Component relationship management (parents and children)
- Shape validation and broadcasting
- Stan code generation for variable declarations and operations
- Sampling and drawing from component distributions
- Tree traversal for model analysis
All model components must implement the abstract methods for drawing samples
and generating Stan code appropriate to their type.
"""
# Define allowed ranges for the parameters used to define this one
POSITIVE_PARAMS: set[str] = set()
"""Class variable giving the set of parent parameter names that must be positive."""
NEGATIVE_PARAMS: set[str] = set()
"""Class variable giving the set of parent parameter names that must be negative."""
SIMPLEX_PARAMS: set[str] = set()
"""
Class variable giving the set of parent parameter names that must be simplexes.
The sum over the last dimension for any parent parameter named here must equal 1.
"""
LOG_SIMPLEX_PARAMS: set[str] = set()
"""
Class variable giving the set of parent parameter names that must be log-simplexes.
The sum of exponentials over the last dimension for any parent parameter named
here must equal 1
"""
# Define the stan data type
BASE_STAN_DTYPE: Literal["real", "int", "simplex"] = "real"
"""Class variable giving the base Stan data type for this component."""
# Define the bounds for this parameter
LOWER_BOUND: Optional["custom_types.Float" | "custom_types.Integer"] = None
"""
Class variable giving the lower bound constraint for component values. None
if unbounded.
"""
UPPER_BOUND: Optional["custom_types.Float" | "custom_types.Integer"] = None
"""
Class variable giving the upper bound constraint for component values. None
if unbounded.
"""
IS_SIMPLEX: bool = False
"""
Class variable giving whether this component represents a simplex (elements
over last dim sum to 1).
"""
IS_LOG_SIMPLEX: bool = False
"""
Class variable giving whether this component represents a log-simplex (exponentials
over last dim sum to 1).
"""
# Do we force parents of this parameter to be named in Stan code?
FORCE_PARENT_NAME: bool = False
"""
Class variable noting whether to force naming of parent variables in Stan code.
If ``True`` and not provided by the user, parent parameters will be assigned
names automatically in the Stan code.
"""
# Do we want the loop to be reset in Stan code?
FORCE_LOOP_RESET: bool = False
"""
Class variable noting whether to force loop reset in Stan code. This is useful
where this parameter's shape makes it appear nestable with another inside the
same loop, but it actually is not.
"""
def __init__( # pylint: disable=unused-argument
self,
*,
shape: tuple["custom_types.Integer", ...] | "custom_types.Integer" = (),
**model_params: "custom_types.CombinableParameterType",
):
"""Initialize a model component with specified shape and parameters.
:param shape: Shape of the component.
:type shape: Union[tuple[custom_types.Integer, ...], custom_types.Integer]
:param model_params: Named parameters this component depends on
:type model_params: custom_types.CombinableParameterType
The initialization process:
1. Normalizes shape specification to tuple format (i.e., integer to 1-element tuple)
2. Validates parameter constraints and bounds
3. Converts non-component parameters to constants
4. Establishes parent-child relationships
5. Validates and sets component shape through broadcasting
"""
# Convert shape to the appropriate type
try:
len(shape)
except TypeError:
shape = (shape,)
# Define placeholder variables
self._model_varname: str = "" # SciStanPy Model variable name
self._parents: dict[str, AbstractModelComponent]
self._component_to_paramname: dict[AbstractModelComponent, str]
self._shape: tuple["custom_types.Integer", ...] = shape
self._children: list[AbstractModelComponent] = [] # Children of the component
# Validate incoming parameters
self._validate_parameters(model_params)
# Set parents
self._set_parents(model_params)
# Link parent and child parameters
for val in self._parents.values():
val._record_child(self)
# Set the shape
self._set_shape()
def _validate_parameters(
self,
model_params: dict[str, "custom_types.CombinableParameterType"],
) -> None:
"""Validate parameter constraints and bounds.
:param model_params: Dictionary of parameter names to values
:type model_params: dict[str, custom_types.CombinableParameterType]
:raises ValueError: If required bounded parameters are missing
:raises ValueError: If bounds are invalid (lower >= upper)
:raises ValueError: If simplex parameters have incompatible bounds
:raises ValueError: If log-simplex parameters have incompatible bounds
This method ensures that:
- All bounded parameters are present in the parameter dictionary
- Lower bounds are less than upper bounds
- Simplex and log-simplex parameters have appropriate bound constraints
"""
# All bounded parameters must be named in the parameter dictionary
if missing_names := (
self.POSITIVE_PARAMS
| self.NEGATIVE_PARAMS
| self.SIMPLEX_PARAMS
| self.LOG_SIMPLEX_PARAMS
) - set(model_params.keys()):
raise ValueError(
f"{', '.join(missing_names)} are bounded parameters but are missing "
"from those defined"
)
# Lower bounds must be below upper bounds
if (
self.LOWER_BOUND is not None
and self.UPPER_BOUND is not None
and self.LOWER_BOUND >= self.UPPER_BOUND
):
raise ValueError("Lower bound must be less than upper bound")
# If THIS parameter is a simplex, then the upper and lower bounds cannot
# be set
if self.IS_SIMPLEX:
if self.LOWER_BOUND is not None and self.LOWER_BOUND != 0.0:
raise ValueError("Simplex parameters cannot have lower bounds")
if self.UPPER_BOUND is not None and self.UPPER_BOUND != 1.0:
raise ValueError("Simplex parameters cannot have upper bounds")
# If THIS parameter is a log simplex, then the upper and lower bounds cannot
# be set
if self.IS_LOG_SIMPLEX:
if self.LOWER_BOUND is not None and self.LOWER_BOUND != -np.inf:
raise ValueError("Log simplex parameters cannot have lower bounds")
if self.UPPER_BOUND is not None and self.UPPER_BOUND != 0.0:
raise ValueError("Log simplex parameters cannot have upper bounds")
def _set_parents(
self,
model_params: dict[str, "custom_types.CombinableParameterType"],
) -> None:
"""Establish parent component relationships with automatic constant creation.
:param model_params: Dictionary of parameter names to values/components
:type model_params: dict[str, custom_types.CombinableParameterType]
This method processes the input parameters and:
1. Preserves existing AbstractModelComponent instances as parents
2. Converts non-component values to Constant instances with appropriate bounds
3. Creates bidirectional mapping between components and parameter names
4. Applies parameter type constraints (positive, negative, simplex, etc.)
"""
# Convert any non-model components to model components, making sure to
# propagate any restrictions on
self._parents = {}
for name, val in model_params.items():
# Just the value if an AbstractModelComponent
if isinstance(val, AbstractModelComponent):
self._parents[name] = val
continue
# Otherwise, convert to a constant model component with the appropriate
# bounds
if name in self.POSITIVE_PARAMS:
lower_bound = 0
upper_bound = None
elif name in self.NEGATIVE_PARAMS or name in self.LOG_SIMPLEX_PARAMS:
lower_bound = None
upper_bound = 0
elif name in self.SIMPLEX_PARAMS:
lower_bound = 0
upper_bound = 1
else:
lower_bound = None
upper_bound = None
self._parents[name] = constants_module.Constant(
value=val,
lower_bound=lower_bound,
upper_bound=upper_bound,
)
# Map components to param names
self._component_to_paramname = {v: k for k, v in self._parents.items()}
assert len(self._component_to_paramname) == len(self._parents)
def _record_child(self, child: "AbstractModelComponent") -> None:
"""Record a child component in the dependency graph.
:param child: Child component that depends on this component
:type child: AbstractModelComponent
:raises AssertionError: If child is already recorded
This method maintains the bidirectional parent-child relationships
that form the model dependency graph. Each child is recorded only
once to prevent duplicate dependencies.
"""
# If the child is already in the list of children, then we don't need to
# add it again
assert child not in self._children, "Child already recorded"
# Record the child
self._children.append(child)
def _set_shape(self) -> None:
"""Validate and set component shape through broadcasting with parents.
:raises ValueError: If shape is not broadcastable with parent shapes
:raises ValueError: If provided shape conflicts with broadcasted shape
This method:
1. Collects shapes from all parent components
2. Attempts to broadcast the component shape with parent shapes
3. Validates that the final shape is consistent
4. Sets the component's shape to the broadcasted result
Shape broadcasting follows NumPy broadcasting rules, ensuring that
multi-dimensional model components can interact properly.
"""
# The shape must be broadcastable to the shapes of the parameters.
try:
parent_shapes = [param.shape for param in self.parents]
broadcasted_shape = np.broadcast_shapes(self._shape, *parent_shapes)
except ValueError as error:
raise ValueError(
f"Shape is not broadcastable to parent shapes while initializing instance "
f"of {self.__class__.__name__}: {','.join(map(str, parent_shapes))} "
f"not broadcastable to {self._shape}"
) from error
# The broadcasted shape must be the same as the shape of the parameter if
# it is not 0-dimensional.
if broadcasted_shape != self._shape and self._shape != ():
raise ValueError(
"Provided shape does not match broadcasted shapes of parents while "
f"initializing instance of {self.__class__.__name__}. {self._shape} "
f"!= {broadcasted_shape}"
)
# Set the shape
self._shape = broadcasted_shape
[docs]
def get_child_paramnames(self) -> dict["AbstractModelComponent", str]:
"""Get parameter names that this component defines in its children.
:returns: Mapping from child components to parameter names they use for this component
:rtype: dict[AbstractModelComponent, str]
:raises AssertionError: If a child references this component with multiple parameter names
This method analyzes the dependency graph to determine how child
components reference this component. Each child should reference
this component through exactly one parameter name.
"""
# Set up the dictionary to return
child_paramnames = {}
# Process all children
for child in self._children:
# Make sure the child is not already recorded
assert child not in child_paramnames
# Get the name of the parameter in the child that the bound parameter
# defines. There should only be one parameter in the child that the
# bound parameter defines.
names = [
paramname
for paramname, param in child._parents.items() # pylint: disable=protected-access
if param is self
]
assert len(names) == 1
# Record the name
child_paramnames[child] = names[0]
return child_paramnames
[docs]
def get_indexed_varname(
self,
index_opts: tuple[str, ...] | None,
offset: "custom_types.Integer" = 0,
start_dim: "custom_types.Integer" = 0,
end_dim: Optional["custom_types.Integer"] = -1,
_name_override: str = "",
) -> str:
"""Generate Stan variable name with appropriate indexing.
:param index_opts: Index variable names to choose from.
:type index_opts: Optional[tuple[str, ...]]
:param offset: Number of leading indices to skip. Defaults to 0.
:type offset: custom_types.Integer
:param start_dim: First dimension to include in indexing. Defaults to 0.
:type start_dim: custom_types.Integer
:param end_dim: Last dimension to include in indexing. Defaults to -1.
:type end_dim: Optional[custom_types.Integer]
:param _name_override: Override for base variable name. Defaults to "".
Internal use only.
:type _name_override: str
:returns: Stan variable name with proper indexing
:rtype: str
This method generates proper Stan variable names for multi-dimensional
components, handling:
- Singleton dimension skipping
- Index offset management for broadcasting
- Dimension range selection
- Automatic vectorization of the last dimension
The offset parameter accounts for implicit singleton dimensions prepended
during broadcasting between parent and child components.
"""
# Offset must be >= 0
assert offset >= 0, self.model_varname
# If end dim is not provided, set it to the number of dimensions
if end_dim is None:
end_dim = self.ndim
# If the name override is provided, then we use that name
base_name = _name_override or self.stan_model_varname
# If there are no indices, then we just return the variable name
if self.ndim == 0 or index_opts is None:
return base_name
# First and last dim must be positive integers
start_dim = start_dim if start_dim >= 0 else self.ndim + start_dim
end_dim = end_dim if end_dim >= 0 else self.ndim + end_dim
assert end_dim >= start_dim
# Skip singleton dimensions. All others get the index options.
indices = [
index_opts[index_ind]
for index_ind, dimsize in enumerate(self.shape[start_dim:end_dim], offset)
if dimsize != 1
]
# If there are no indices, then we just return the variable name
if len(indices) == 0:
return base_name
# Build the indexed variable name
indexed_varname = f"{base_name}[{','.join(indices)}]"
return indexed_varname
@abstractmethod
def _draw(
self,
level_draws: dict[str, Union[npt.NDArray, "custom_types.Float"]],
seed: Optional["custom_types.Integer"],
) -> Union[npt.NDArray, "custom_types.Float", "custom_types.Integer"]:
"""Draw a single sample from this component's distribution.
:param level_draws: Samples from parent components for this draw
:type level_draws: dict[str, Union[npt.NDArray, custom_types.Float]]
:param seed: Random seed for reproducible sampling
:type seed: Optional[custom_types.Integer]
:returns: Sample(s) from this component's distribution
:rtype: Union[npt.NDArray, custom_types.Float, custom_types.Integer]
This abstract method must be implemented by all concrete component
classes to define how samples are drawn from their specific
distribution or deterministic function.
"""
[docs]
def draw(
self,
n: "custom_types.Integer",
*,
_drawn: Optional[dict["AbstractModelComponent", npt.NDArray]] = None,
seed: Optional["custom_types.Integer"] = None,
) -> tuple[npt.NDArray, dict["AbstractModelComponent", npt.NDArray]]:
"""Recursively draw samples from this component and its dependency tree.
:param n: Number of samples to draw
:type n: custom_types.Integer
:param _drawn: Cache of previously drawn samples. Auto-created if None. Defaults to None.
Internal use only. Used to cache draws throughout recursion.
:type _drawn: Optional[dict[AbstractModelComponent, npt.NDArray]]
:param seed: Random seed for reproducible sampling. Defaults to None.
:type seed: Optional[custom_types.Integer]
:returns: Tuple of (samples_from_this_component, all_drawn_samples)
:rtype: tuple[npt.NDArray, dict[AbstractModelComponent, npt.NDArray]]
:raises NumpySampleError: If sampling fails due to parameter issues
:raises AssertionError: If drawn values violate component bounds or constraints
This method implements the complete sampling workflow:
1. Recursively draws from parent components if not already drawn
2. Collects parent samples for the current level
3. Draws n samples from this component using parent values
4. Validates drawn samples against bounds and constraints
5. Returns samples and updates the global draw cache
The method enforces constraint validation including:
- Lower and upper bound checking
- Simplex sum-to-one validation
- Parameter type constraint validation
"""
# Build the _drawn dictionary if it is not already built
if _drawn is None:
_drawn = {}
# Loop over the parents and draw from them if we haven't already
level_draws: dict[str, npt.NDArray] = {}
for paramname, parent in self._parents.items():
# If the parent has been drawn, use the already drawn value. Otherwise,
# draw from the parent.
if parent in _drawn:
parent_draw = _drawn[parent]
else:
parent_draw, _ = parent.draw(n, _drawn=_drawn, seed=seed)
_drawn[parent] = parent_draw
# Record the parent draw for this level
level_draws[paramname] = parent_draw
# Now draw from the current parameter
try:
draws = np.stack(
[
self._draw(
{k: v[i] for k, v in level_draws.items()},
seed=(None if seed is None else seed + i),
)
for i in range(n)
]
)
except ValueError as error:
raise NumpySampleError(
f"Error encountered when trying to sample from {self.model_varname}: {error}"
) from error
assert self.LOWER_BOUND is None or np.all(draws >= self.LOWER_BOUND), (
f"Draw from `{self}` must be greater than or equal to {self.LOWER_BOUND} "
f"but got {draws.min()}"
)
assert self.UPPER_BOUND is None or np.all(draws <= self.UPPER_BOUND)
assert not self.IS_SIMPLEX or np.allclose(np.sum(draws, axis=-1), 1)
# Update the _drawn dictionary
assert self not in _drawn
_drawn[self] = draws
# Test the ranges of the draws
for paramname, param in self._parents.items():
if paramname in self.POSITIVE_PARAMS:
assert np.all(_drawn[param] >= 0)
elif paramname in self.NEGATIVE_PARAMS:
assert np.all(_drawn[param] <= 0)
elif paramname in self.SIMPLEX_PARAMS:
assert np.all((_drawn[param] >= 0) & (_drawn[param] <= 1))
assert np.allclose(np.sum(_drawn[param], axis=-1), 1)
elif paramname in self.LOG_SIMPLEX_PARAMS:
assert np.all(_drawn[param] < 0)
assert np.allclose(np.sum(np.exp(_drawn[param]), axis=-1), 1)
return draws, _drawn
[docs]
def walk_tree(
self, walk_down: bool = True, _recursion_depth: "custom_types.Integer" = 1
) -> list[tuple[int, "AbstractModelComponent", "AbstractModelComponent"]]:
"""Traverse the model component dependency tree.
:param walk_down: Whether to walk toward children (True) or parents (False).
Defaults to True.
:type walk_down: bool
:param _recursion_depth: Current recursion depth (internal parameter).
Defaults to 1.
:type _recursion_depth: custom_types.Integer
:returns: List of (depth, current_component, relative_component) tuples
:rtype: list[tuple[int, AbstractModelComponent, AbstractModelComponent]]
This method enables systematic traversal of the model dependency graph
in either direction. Each tuple contains:
- Recursion depth relative to the starting component
- The current component in the traversal
- The relative component (child if walking down, parent if walking up)
Tree traversal is useful for:
- Model structure analysis and visualization
- Dependency validation and cycle detection
- Code generation ordering
- Model component discovery
"""
# Get the variables to loop over
relatives = self.children if walk_down else self.parents
# Recurse
to_return = []
for relative in relatives:
# Add the current parameter and the relative parameter to the list
to_return.append((_recursion_depth, self, relative))
# Recurse on the relative parameter
to_return.extend(
relative.walk_tree(
walk_down=walk_down, _recursion_depth=_recursion_depth + 1
)
)
return to_return
[docs]
def get_shared_leading(
self, other: "AbstractModelComponent"
) -> "custom_types.Integer":
"""Determine the number of compatible leading dimensions with another component.
:param other: Component to compare shapes with
:type other: AbstractModelComponent
:returns: Number of shared leading dimensions
:rtype: custom_types.Integer
This method analyzes shape compatibility between components by counting
the number of leading dimensions that are compatible according to
broadcasting rules. Dimensions are compatible if they are equal or
if at least one of them is 1 (singleton).
Shape compatibility is important for:
- Determining indexing strategies
- Validating broadcasting operations
- Optimizing Stan code generation
"""
# Define the compatibility level
compat_level = 0
# Get the extent to which the shapes this and the other components are compatible.
for i, (prev_dimsize, current_dimsize) in enumerate(
zip(self.shape, other.shape)
):
# If the dimensions are equal or at least one is 1, they are compatible
# at this level of indentation. Otherwise, we break the loop, as we
# have reached a position where the shapes are not compatible.
if (
prev_dimsize == current_dimsize
or prev_dimsize == 1
or current_dimsize == 1
):
compat_level = i
else:
break
return compat_level
[docs]
def get_supporting_functions(self) -> list[str]:
"""Get Stan function definitions required by this component.
:returns: List of Stan function definition strings
:rtype: list[str]
This method returns Stan function definitions or include statements
that must be added to the Stan program to support this component.
The default implementation returns an empty list.
Custom components may override this method to include:
- Custom distribution definitions
- Helper function implementations
- Include statements for external function libraries
"""
# The default is no supporting functions
return []
[docs]
def get_target_incrementation(
self, index_opts: tuple[str, ...] # pylint: disable=unused-argument
) -> str:
"""Generate Stan code for log-probability target increments. This is the
statement that appears in the ``model`` block of the Stan program.
:param index_opts: Index variable names for multi-dimensional access
:type index_opts: tuple[str, ...]
:returns: Stan code for target increment (empty by default)
:rtype: str
This method generates Stan code for the model block, where
log-probability contributions are added to the target density.
The default implementation returns an empty string.
Probabilistic components should override this method to provide
appropriate target increment statements for their distributions.
"""
return ""
[docs]
def get_index_offset(
self, query: Union[str, "AbstractModelComponent"], offset_adjustment: int = 0
) -> int:
"""Calculate index offset for multi-dimensional variable access.
:param query: Component or parameter name to calculate offset for
:type query: Union[str, AbstractModelComponent]
:param offset_adjustment: Additional offset to apply. Defaults to 0.
:type offset_adjustment: int
:returns: Number of leading indices to skip
:rtype: int
:raises KeyError: If query string doesn't match any parent parameter
This method calculates the appropriate index offset when accessing
parent components that have different numbers of dimensions. The
offset accounts for implicit singleton dimensions that are added
during broadcasting operations (e.g., the offset between ``x`` with shape
(10,) and ``y`` with shape (2, 10) would be "1").
Index offsets are essential for:
- Proper multi-dimensional array indexing in Stan code
- Handling broadcasting between components of different shapes
- Maintaining correct dimension alignment in generated code
"""
# Get the query if we need to
if isinstance(query, str):
query = self._parents[query]
# Calculate offset
return self.ndim - query.ndim + offset_adjustment
[docs]
@abstractmethod
def get_right_side(
self,
index_opts: tuple[str, ...] | None,
start_dims: dict[str, "custom_types.Integer"] | None = None,
end_dims: dict[str, "custom_types.Integer"] | None = None,
offset_adjustment: int = 0,
) -> dict[str, str]:
"""Generate Stan code for the right-hand side of statements.
:param index_opts: Index variable names for multi-dimensional access
:type index_opts: Optional[tuple[str, ...]]
:param start_dims: First indexable dimension for each parent parameter.
Defaults to None.
:type start_dims: Optional[dict[str, custom_types.Integer]]
:param end_dims: Last indexable dimension for each parent parameter. Defaults
to None.
:type end_dims: Optional[dict[str, custom_types.Integer]]
:param offset_adjustment: Index offset adjustment. Defaults to 0.
:type offset_adjustment: int
:returns: Dictionary mapping parameter names to Stan code strings
:rtype: dict[str, str]
This abstract method must be implemented by all model components
to generate appropriate Stan code for probability statements, transformations,
and assignments. The method processes parent components and returns
properly formatted Stan expressions.
The base implementation in this abstract class provides common
functionality for:
- Processing parent component relationships
- Handling index offsets and dimension slicing
- Determining when to use variable names vs. inline expressions
- Managing transformed parameter code generation
Subclasses extend this foundation to generate component-specific
Stan code patterns.
"""
# Get default values for start and end dims
start_dims = start_dims or {}
end_dims = end_dims or {}
# Get variables that make up the right side of the statement. These will
# be the parent parameters of the current parameter.
model_components: dict[str, str] = {}
for name, param in self._parents.items():
# What's the offset between the current parameter in the loop and THIS
# parameter
current_offset = self.get_index_offset(
param, offset_adjustment=offset_adjustment
)
# If the parameter is a constant or another parameter OR it is a named
# transformed parameter OR the assignment depth changes, we get its
# indexed variable name offset by the appropriate amount to account
# for implicit singleton dimensions
if (
isinstance(
param,
(constants_module.Constant, parameters.Parameter),
)
or param.is_named
or param.force_name
):
model_components[name] = param.get_indexed_varname(
index_opts,
offset=current_offset,
start_dim=start_dims.get(name, 0),
end_dim=end_dims.get(name, -1),
)
# Otherwise, we need to get the thread of operations that make up the
# transformation for the parameter. This is equivalent to calling the
# get_right_side method of the parameter.
elif isinstance(param, transformed_parameters.TransformedParameter):
# We need to propogate the offset of the current parameter in the
# loop to ITS parents
model_components[name] = param.get_right_side(
index_opts, offset_adjustment=current_offset
)
# Otherwise, raise an error
else:
raise TypeError(f"Unknown model component type {type(param)}")
return model_components
[docs]
def declare_stan_variable(self, varname: str, force_basetype: bool = False) -> str:
"""Generate Stan variable declaration with appropriate type and bounds.
:param varname: Variable name to declare
:type varname: str
:param force_basetype: Whether to force array[...] basetype format. Defaults to False.
:type force_basetype: bool
:returns: Complete Stan variable declaration
:rtype: str
This method combines the Stan data type (from
:py:meth:`~scistanpy.model.components.abstract_model_component.AbstractModelComponent.get_stan_dtype`)
with the variable name to create a complete variable declaration suitable
for use in Stan data, parameters, or other blocks.
"""
return f"{self.get_stan_dtype(force_basetype)} {varname}"
[docs]
def get_assign_depth(self) -> int:
"""Calculate the assignment depth for Stan loop structure.
:returns: Loop nesting level for this component's assignment
:rtype: int
This method determines the appropriate loop nesting level for
defining this component in Stan code. The depth is calculated as:
- Number of dimensions minus one (last dimension is vectorized)
- Minus trailing singleton dimensions (except the last)
- Clipped to a minimum of zero
Assignment depth affects:
- Loop structure in generated Stan code
- Index variable management
- Vectorization opportunities
"""
# Default is number of dimensions minus one. We subtract one because the
# last dimension is always vectorized.
level = self.ndim - 1
# Subtract off trailing singleton dimensions ignoring the last dimension
# (again, always vectorized)
for dimsize in reversed(self.shape[:-1]):
if dimsize > 1:
break
level -= 1
# Clip at zero
return max(level, 0)
[docs]
def get_stan_dtype(self, force_basetype: bool = False) -> str:
"""Generate Stan data type declaration for this component.
:param force_basetype: Whether to force array[...] format instead of vectors.
Defaults to False.
:type force_basetype: bool
:returns: Stan data type string with bounds
:rtype: str
:raises AssertionError: If unknown data type is encountered
This method generates appropriate Stan data type declarations based on:
- Base data type (real, int, simplex)
- Component dimensionality
- Bound constraints
- Whether vector/array format is preferred
"""
# Get the base datatype
dtype = self.BASE_STAN_DTYPE
# Convert shape to strings
string_shape = [str(dim) for dim in self.shape if dim > 1]
ndim = len(string_shape)
# Base data type for 0-dimensional parameters. If the parameter is 0-dimensional,
# then we can only have real or int as the data type.
if ndim == 0:
assert dtype in {"real", "int"}
return f"{dtype}{self.stan_bounds}"
# Handle different data types for different dimensions
if dtype == "int" or force_basetype:
return (
f"array[{','.join(string_shape)}] "
+ ("int" if dtype == "int" else "real")
+ self.stan_bounds
)
if dtype == "real": # Becomes vector or array of vectors
dtype = f"vector{self.stan_bounds}[{string_shape[-1]}]"
elif dtype == "simplex": # Becomes array of simplexes
dtype = f"simplex[{string_shape[-1]}]"
else:
raise AssertionError(f"Unknown data type {dtype}")
# Convert to an array of vectors or simplexes if necessary
if ndim > 1:
return f"array[{','.join(string_shape[:-1])}] {dtype}"
return dtype
[docs]
def get_stan_parameter_declaration(self, force_basetype: bool = False) -> str:
"""Generate Stan parameter declaration for this component.
:param force_basetype: Whether to force array[...] format. Defaults to False.
:type force_basetype: bool
:returns: Complete Stan parameter declaration
:rtype: str
This convenience method generates a parameter declaration using the
component's Stan model variable name and appropriate data type.
"""
return self.declare_stan_variable(
self.stan_model_varname, force_basetype=force_basetype
)
@abstractmethod
def __str__(self) -> str:
"""Return human-readable string representation of the component.
:returns: String representation showing component structure
:rtype: str
This abstract method must be implemented by all concrete components
to provide meaningful string representations for debugging and
model inspection.
"""
def __repr__(self) -> str:
"""Return detailed string representation (identical to __str__).
:returns: String representation of the component
:rtype: str
"""
return str(self)
def __contains__(self, key: str) -> bool:
"""Check if the component has a parent with the given parameter name.
:param key: Parameter name to check
:type key: str
:returns: True if parameter name exists in parents
:rtype: bool
"""
return key in self._parents
@overload
def __getitem__(self, key: str) -> "AbstractModelParameter": ...
@overload
def __getitem__(self, key: "custom_types.IndexType") -> "IndexParameter": ...
def __getitem__(self, key):
"""Access parent parameters by name or create indexed subcomponents.
:param key: Parameter name (string) or array indexing specification
:type key: Union[str, custom_types.IndexType]
:returns: Parent component or indexed subcomponent
:rtype: Union[AbstractModelComponent, IndexParameter]
This method provides two access patterns:
1. String keys return parent components by parameter name
2. Index specifications create IndexParameter subcomponents for array slicing
"""
# If a string, check the parents
if isinstance(key, str):
return self._parents[key]
# Anything else, we must be slicing or selecting the underlying distribution
if not isinstance(key, tuple):
key = (key,)
return transformed_parameters.IndexParameter(self, *key)
def __getattr__(self, key: str) -> "AbstractModelComponent":
"""Access parent parameters as attributes.
:param key: Parameter name to access
:type key: str
:returns: Parent component with the given parameter name
:rtype: AbstractModelComponent
:raises AttributeError: If parameter name not found in parents
This method enables convenient attribute-style access to parent
components using dot notation instead of bracket notation.
"""
# Make sure we don't have a circular reference between `__getattr__` and
# `__getitem__`
if key == "_parents":
raise AttributeError("No attribute '_parents' in this object")
# If the key is not in the parents, then we raise an error
try:
return self[key]
except KeyError as error:
try:
string_repr = repr(self)
except Exception as error2: # pylint: disable=broad-except
string_repr = super().__repr__()
raise AttributeError(
f"Attribute '{key}' not found in {string_repr}. "
"Encountered an error while trying to get the string representation."
) from error2
# Otherwise, just raise the original error
raise AttributeError(
f"Attribute '{key}' not found in {string_repr}"
) from error
@property
def shape(self) -> tuple[int, ...]:
"""Get the shape of this component.
:returns: Shape tuple for this component
:rtype: tuple[int, ...]
"""
return self._shape
@property
def ndim(self) -> "custom_types.Integer":
"""Get the number of dimensions of this component.
:returns: Number of dimensions
:rtype: custom_types.Integer
"""
return len(self.shape)
@property
def is_named(self) -> bool:
"""Check whether this component has an assigned variable name.
:returns: True if component has been assigned a model variable name
:rtype: bool
Examples:
.. code-block:: python
# Example model with named and unnamed parameters
class MyModel(Model):
def __init__(self):
super().__init__()
self.param1 = Parameter(...) # is_named is True
param2 = Parameter(...) # is_named is False
# Outside of a model, is_named is always False
param = Parameter(...)
assert not param.is_named
"""
return self._model_varname != ""
@property
def stan_bounds(self) -> str:
"""Generate Stan bounds specification string.
:returns: Stan bounds specification (empty if no bounds)
:rtype: str
This property formats
:py:attr:`~scistanpy.model.components.abstract_model_component.AbstractModelComponent.LOWER_BOUND`
and :py:attr:`~scistanpy.model.components.abstract_model_component.AbstractModelComponent.UPPER_BOUND`
into Stan's bound format. Returns an empty string if no bounds are specified.
"""
# Format the lower and upper bounds
lower = "" if self.LOWER_BOUND is None else f"lower={self.LOWER_BOUND}"
upper = "" if self.UPPER_BOUND is None else f"upper={self.UPPER_BOUND}"
# Combine the bounds
if lower and upper:
bounds = f"{lower}, {upper}"
elif lower:
bounds = lower
elif upper:
bounds = upper
else:
return ""
# Return the bounds
return f"<{bounds}>"
@property
def assign_depth(self) -> "custom_types.Integer":
"""Get the assignment depth for Stan loop nesting.
:returns: Loop nesting level for this component
:rtype: custom_types.Integer
This is the property form of
:py:meth:`~scistanpy.model.components.abstract_model_component.AbstractModelComponent.get_assign_depth`
and provides convenient access to the assignment depth. It determines how
deeply nested this component should be in Stan's loop structure.
"""
return self.get_assign_depth()
@property
def model_varname(self) -> str:
"""Get or generate the SciStanPy variable name for this component. Only
valid within a SciStanPy Model.
:returns: Variable name for this component
:rtype: str
If a variable name has been explicitly assigned, returns that name.
Otherwise, automatically generates a name based on child component
relationships using dot notation for hierarchical names.
Examples:
.. code-block:: python
class MyModel(Model):
def __init__(self):
super().__init__()
# Parameter without explicit name has auto-generated name
# Name is "param2.mu" based on child relationship
param1 = Parameter(...) # model_varname is "param2.mu"
# Explicitly named parameter
param2 = Parameter(mu = self.param1) # model_varname is "param2"
"""
# If the _model_varname variable is set, then we return it
if self.is_named:
return self._model_varname
# Otherwise, we automatically create the name. This is the name of the
# child components and the name of this component as defined in that child
# component separated by underscores.
return ".".join(
[
f"{child.model_varname}.{name}"
for child, name in self.get_child_paramnames().items()
if not isinstance(child, transformed_data.TransformedData)
]
)
@model_varname.setter
def model_varname(self, name: str) -> None:
"""Set the SciStanPy variable name for this component.
:param name: Variable name to assign
:type name: str
:raises ValueError: If attempting to rename an already-named component
Variable names can only be set once to prevent accidental overwrites
that could break model consistency. These are set automatically when parameters
are defined inside the `__init__` method of a SciStanPy Model subclass.
"""
# If the name is not set, then we set it
if self._model_varname == "":
self._model_varname = name
else:
raise ValueError(
"Cannot set model variable name more than once. Trying to rename "
f"{self._model_varname} to {name}"
)
@property
def stan_model_varname(self) -> str:
"""Get the Stan-compatible variable name for this component. This is identical
to :py:attr:`~scistanpy.model.components.abstract_model_component.AbstractModelComponent.model_varname`,
but with dots replaced by double underscores.
:returns: Stan variable name with dots replaced by double underscores
:rtype: str
"""
return self.model_varname.replace(".", "__")
@property
def constants(self):
"""Get all constant-valued parent components.
:returns: Dictionary mapping parameter names to constant components
:rtype: dict[str, Constant]
This property filters parent components to return only those that
are :py:class:`~scistanpy.model.components.constants.Constant` instances,
useful for identifying fixed values in the model hierarchy.
"""
return {
name: component
for name, component in self._parents.items()
if isinstance(component, constants_module.Constant)
}
@property
def parents(self) -> list["AbstractModelComponent"]:
"""Get all parent components that this component depends on.
:returns: List of parent components
:rtype: list[AbstractModelComponent]
"""
return list(self._parents.values())
@property
def children(self) -> list["AbstractModelComponent"]:
"""Get all child components that depend on this component.
:returns: Copy of the children list
:rtype: list[AbstractModelComponent]
Returns a shallow copy to prevent external modification of the internal
``_children`` list while allowing iteration and inspection.
"""
return self._children.copy()
@property
@abstractmethod
def torch_parametrization(self) -> torch.Tensor:
"""Get PyTorch tensor representation with appropriate transformations.
:returns: PyTorch tensor for this component
:rtype: torch.Tensor
This abstract property must be implemented by all concrete components
to provide PyTorch tensor representations suitable for gradient-based
computation and optimization.
"""
@property
def observable(self) -> bool:
"""Check whether this component represents observed data.
:returns: False by default (most components are not observable)
:rtype: bool
Observable components represent known data values rather than
parameters to be inferred. The default implementation returns
``False``; subclasses may override for specific behavior.
"""
return False
@property
def force_name(self) -> bool:
"""Check whether this component should be explicitly named in Stan code.
:returns: ``True`` if any child forces parent naming
:rtype: bool
This property returns ``True`` if any child component has
:py:attr:`~scistanpy.model.components.abstract_model_component.AbstractModelComponent.FORCE_PARENT_NAME`
set to ``True``, indicating that this component should be given an explicit
variable name in the generated Stan code rather than being inlined.
"""
return any(child.FORCE_PARENT_NAME for child in self._children)