# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Core Model class for SciStanPy Bayesian modeling framework.
This module contains the fundamental :py:class:`~scistanpy.model.model.Model` class
that serves as the primary interface for building, compiling, and executing Bayesian
models in SciStanPy. The :py:class:`~scistanpy.model.model.Model` class orchestrates
the composition of model components (:py:mod:`Parameters <scistanpy.model.components.parameters>`,
:py:class:`Constants <scistanpy.model.components.constants.Constant>`, and
:py:mod:`Transformed Parameters <scistanpy.model.components.transformations.transformed_parameters>`)
that define the probabilistic structure of the model.
The :py:class:`~scistanpy.model.model.Model` class uses a metaclass pattern to automatically
register model components defined as instance attributes, enabling intuitive model
construction through simple attribute assignment. It supports multiple backends
including Stan for Hamiltonian Monte Carlo sampling and PyTorch for maximum likelihood
estimation.
"""
# pylint: disable=too-many-lines, line-too-long
from __future__ import annotations
import os.path
import pickle
import weakref
from tempfile import TemporaryDirectory
from typing import Any, Iterable, Literal, Optional, overload, TYPE_CHECKING, Union
import numpy as np
import numpy.typing as npt
import panel as pn
import torch
import xarray as xr
from scistanpy import utils
from scistanpy.defaults import (
DEFAULT_CPP_OPTIONS,
DEFAULT_DIM_NAMES,
DEFAULT_EARLY_STOP,
DEFAULT_FORCE_COMPILE,
DEFAULT_LR,
DEFAULT_MODEL_NAME,
DEFAULT_N_EPOCHS,
DEFAULT_STANC_OPTIONS,
DEFAULT_USER_HEADER,
)
from scistanpy.model.components import abstract_model_component
from scistanpy.model.components import (
constants as constants_module,
parameters as parameters_module,
)
from scistanpy.model.components.transformations import (
transformed_data,
transformed_parameters as transformed_parameters_module,
)
if TYPE_CHECKING:
from scistanpy import custom_types
from scistanpy.model.results import hmc as hmc_results
mle_module = utils.lazy_import("scistanpy.model.results.mle")
nn_module = utils.lazy_import("scistanpy.model.nn_module")
prior_predictive_module = utils.lazy_import("scistanpy.plotting.prior_predictive")
stan_model = utils.lazy_import("scistanpy.model.stan.stan_model")
def model_comps_to_dict(
model_comps: Iterable[abstract_model_component.AbstractModelComponent],
) -> dict[str, abstract_model_component.AbstractModelComponent]:
"""Convert an iterable of model components to a dictionary keyed by variable names.
This utility function creates a dictionary mapping from model variable names
to their corresponding component objects, facilitating easy lookup and
access to model components by name.
:param model_comps: Iterable of model components to convert
:type model_comps: Iterable[abstract_model_component.AbstractModelComponent]
:returns: Dictionary mapping variable names to components
:rtype: dict[str, abstract_model_component.AbstractModelComponent]
Example:
>>> components = [param1, param2, observable1]
>>> comp_dict = model_comps_to_dict(components)
>>> # Access by name: comp_dict['param1']
"""
return {comp.model_varname: comp for comp in model_comps}
def run_delayed_mcmc(filepath: str) -> "hmc_results.SampleResults":
"""Execute a delayed MCMC run from a pickled configuration file.
This function loads and executes MCMC sampling that was previously
configured with `Model.mcmc(delay_run=True)`. It's useful for
running computationally intensive sampling jobs in batch systems
or separate processes.
:param filepath: Path to the pickled MCMC configuration file
:type filepath: str
:returns: MCMC sampling results with posterior draws and diagnostics
:rtype: hmc_results.SampleResults
The function automatically enables console output to provide progress
feedback during the potentially long-running sampling process.
Example:
>>> # First, create delayed run
>>> model.mcmc(output_dir=".", delay_run=True, chains=4, iter_sampling=2000)
>>> # Later, execute the run
>>> results = run_delayed_mcmc(f"{model.stan_executable_path}-delay.pkl")
"""
# Load the pickled object
with open(filepath, "rb") as f:
obj = pickle.load(f)
# We will be printing to the console
obj["sample_kwargs"]["show_console"] = True
# Run sampling and return the results
return obj["stan_model"].sample(
inits=obj["inits"],
data=obj["data"],
**obj["sample_kwargs"],
)
class Model:
"""Primary interface for Bayesian model construction and analysis in SciStanPy.
The Model class provides a declarative interface for building Bayesian models
by composing :py:mod:`Parameters <scistanpy.model.components.parameters>`,
:py:class:`Constants <scistanpy.model.components.constants.Constant>`, and
:py:mod:`Transformed Parameters <scistanpy.model.components.transformations.transformed_parameters>`.
It automatically handles component registration and validation, and provides
methods for sampling, compilation, and analysis across multiple backends.
:param default_data: Default observed data for model observables. When provided,
any instance method requiring data will use these if not otherwise provided.
Defaults to None.
:type default_data: Optional[dict[str, npt.NDArray]]
The class uses a metaclass pattern that automatically registers model
components defined as instance attributes. Components are validated
for naming conventions and automatically assigned model variable names.
Models are built by subclassing Model and defining components as
instance attributes in the __init__ method. The metaclass automatically
discovers and registers these components.
Example:
>>> class MyModel(Model):
... def __init__(self):
... super().__init__()
... self.mu = ssp.parameters.Normal(0.0, 1.0)
... self.sigma = ssp.parameters.HalfNormal(1.0)
... self.y = ssp.parameters.Normal(self.mu, self.sigma, observable=True)
>>>
>>> model = MyModel()
>>> prior_samples = model.draw(n=1000)
>>> model.prior_predictive() # Interactive dashboard for prior exploration
>>> mle_result = model.mle() # Maximum likelihood estimation
>>> mcmc_result = model.mcmc() # Hamiltonian Monte Carlo sampling in Stan
.. note::
All :py:mod:`Parameters <scistanpy.model.components.parameters>` have a
``shape`` attribute that defines their dimensionality. For example, you can
define a 2D array of parameters like this:
>>> import scistanpy as ssp
>>> self.beta = ssp.parameters.Normal(mu=0.0, sigma=1.0, shape=(3, 4))
This will create 12 independent Normal parameters arranged in a 3x4 array.
All SciStanPy parameters and operations support broadcasting using the same
rules as NumPy, so you can easily define complex hierarchical models with
minimal code. For example, you can define a child parameter that depends on
``beta`` like this:
>>> self.alpha = ssp.parameters.Normal(mu=self.beta, sigma=1.0, shape=(10, 3, 4))
This will create a 10x3x4 array of Normal parameters where each slice along
the first dimension represents a set of parameters drawn from a Normal distribution
whose mean is given by a single element of ``beta``.
.. hint::
Follow these best practices when building models:
1. Use descriptive component names that reflect their scientific meaning
2. Set ``default_data`` for models with fixed datasets to streamline workflows
3. Start with prior predictive checks before fitting to real data
4. Use simulation methods to validate model implementation
5. Choose appropriate backends for different tasks (PyTorch for MLE, Stan for MCMC)
6. Validate models incrementally by building from simple to complex
"""
def __init__(
self,
*args, # pylint: disable=unused-argument
default_data: dict[str, npt.NDArray] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""This should be overridden by the subclass."""
# Set the default values for the model
self._default_data: dict[str, npt.NDArray] | None = default_data
self._named_model_components: tuple[
abstract_model_component.AbstractModelComponent, ...
] = getattr(self, "_named_model_components", ())
self._model_varname_to_object: dict[
str, abstract_model_component.AbstractModelComponent
] = getattr(self, "_model_varname_to_object", {})
self._init_complete: bool = getattr(self, "_init_complete", False)
def __init_subclass__(cls, **kwargs):
"""Configure automatic component registration for Model subclasses.
This method is called when a class inherits from Model and sets up
the metaclass behavior that automatically discovers and registers
model components defined as instance attributes.
:param kwargs: Keyword arguments passed to the subclass
:raises ValueError: If forbidden attribute names are used or naming
conventions are violated
The method wraps the subclass __init__ to add component discovery
and registration logic while preserving the original initialization
behavior.
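Example (illustrative; component names containing double underscores, or
starting with an underscore, are rejected when the model is built):
>>> class BadModel(Model):
...     def __init__(self):
...         super().__init__()
...         self.bad__name = ssp.parameters.Normal(0.0, 1.0)
>>> BadModel()
Traceback (most recent call last):
    ...
ValueError: Model component names cannot include double underscores: bad__name is invalid.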
"""
# The old __init__ method of the class is renamed to '_wrapped_init'
if "_wrapped_init" in cls.__dict__:
raise ValueError(
"The attribute `_wrapped_init` cannot be defined in `Model` subclasses"
)
# Redefine the __init__ method of the class
def __init__(
self: "Model",
*init_args,
**init_kwargs,
):
# Initialization is incomplete at this stage
self._init_complete = False
# Run the init method that was defined in the class.
cls._wrapped_init(self, *init_args, **init_kwargs)
# That's it if we are not the last subclass to be initialized
if cls is not self.__class__:
return
# If model components were already defined in a parent class, update
# them with the new model components. This situation occurs when a child
# class inherits from a parent class that is also a model.
named_model_components = {}
# Now we need to find all the model components that are defined in the
# class.
for attr in vars(self).keys():
if not isinstance(
retrieved := getattr(self, attr),
abstract_model_component.AbstractModelComponent,
):
continue
# Double-underscore attributes are forbidden, as this will clash
# with how we handle unnamed parameters in Stan code.
if "__" in attr:
raise ValueError(
"Model component names cannot include double underscores: "
f"{attr} is invalid."
)
# Check if the variable name starts with an underscore. This is
# forbidden in Stan code.
if attr.startswith("_"):
raise ValueError(
"Model variable names cannot start with an underscore: "
f"{attr} is invalid."
)
# Set the model variable name and record the model component
retrieved.model_varname = attr
named_model_components[attr] = retrieved
# Set the named parameters attribute
self._named_model_components = tuple(named_model_components.values())
# Build the mapping between model variable names and parameter objects
self._model_varname_to_object = self._build_model_varname_to_object()
# Initialization is complete
self._init_complete = True
# Reassign the default data to itself. This triggers the setter method,
# which validates the data now that component names are known.
if self.has_default_data:
self.default_data = self.default_data
# Add the new __init__ method
cls._wrapped_init = cls.__init__
cls.__init__ = __init__
def _build_model_varname_to_object(
self,
) -> dict[str, abstract_model_component.AbstractModelComponent]:
"""Build comprehensive mapping from variable names to model components.
This method constructs a complete dictionary mapping model variable
names to their corresponding component objects. It walks the component
dependency tree starting from observables to ensure all referenced
components are included.
:returns: Dictionary mapping variable names to components
:rtype: dict[str, abstract_model_component.AbstractModelComponent]
The mapping includes:
- All observable parameters and their dependencies
- Transformed data components
- Constants and hyperparameters
- Transformed parameters used in the model
The method ensures no duplicate variable names exist and validates
the integrity of the component dependency graph.
"""
def build_initial_mapping() -> (
dict[str, abstract_model_component.AbstractModelComponent]
):
"""Builds an initial mapping of model varnames to objects."""
# Start from each observable and walk up the tree to the root
model_varname_to_object: dict[
str, abstract_model_component.AbstractModelComponent
] = {}
for observable in self.observables:
# Add the observable to the mapping
assert observable.model_varname not in model_varname_to_object
model_varname_to_object[observable.model_varname] = observable
# Add all parents to the mapping and make sure
# `parameters_module.Parameter` instances are explicitly defined.
for *_, parent in observable.walk_tree(walk_down=False):
# If the parent is already in the mapping, make sure it is the
# same
if parent.model_varname in model_varname_to_object:
assert model_varname_to_object[parent.model_varname] == parent
else:
model_varname_to_object[parent.model_varname] = parent
return model_varname_to_object
def record_transformed_data() -> None:
"""Updates the mapping with all transformed data components."""
# Add all TransformedData instances to the mapping
for component in list(model_varname_to_object.values()):
for child in component._children: # pylint: disable=protected-access
if isinstance(
child,
transformed_data.TransformedData,
):
assert child.model_varname not in model_varname_to_object
model_varname_to_object[child.model_varname] = child
# Run the steps
model_varname_to_object = build_initial_mapping()
record_transformed_data()
# There can be no duplicate values in the mapping
assert len(model_varname_to_object) == len(
set(model_varname_to_object.values())
)
return model_varname_to_object
def get_dimname_map(
self,
) -> dict[tuple["custom_types.Integer", "custom_types.Integer"], str]:
"""Generate mapping from dimension specifications to dimension names.
This method creates a dictionary that maps dimension level and size
tuples to appropriate dimension names for xarray dataset construction.
It ensures dimension names don't conflict with model variable names.
:returns: Dictionary mapping (level, size) tuples to dimension names
:rtype: dict[tuple[custom_types.Integer, custom_types.Integer], str]
The mapping is used to create consistent dimension naming across
all xarray datasets generated from model samples, ensuring proper
coordinate alignment and data structure.
Only dimensions with size > 1 are assigned names, as singleton
dimensions are typically squeezed during xarray construction.
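Example (illustrative; assumes a single observable of shape ``(10, 3)``
and that ``DEFAULT_DIM_NAMES`` supplies names such as ``'dim_0'`` and
``'dim_1'`` -- actual names depend on the configured defaults):
>>> model.get_dimname_map()
{(0, 3): 'dim_0', (1, 10): 'dim_1'}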
"""
# Set up variables
dims: dict[tuple["custom_types.Integer", "custom_types.Integer"], str] = {}
# The list of dimension options cannot overlap with variable names
allowed_dim_names = [
name
for name in DEFAULT_DIM_NAMES
if name not in self.named_model_components_dict
]
# Check the shapes of all named model components and record the dimension
# names. Dimensions of size '1' are not named, as we do not need to
# distinguish them in an xarray.
for component in self.named_model_components:
for level, size in enumerate(component.shape[::-1]):
if size > 1 and (level, size) not in dims:
dims[(level, size)] = allowed_dim_names[len(dims)]
return dims
def _compress_for_xarray(
self,
*arrays: npt.NDArray,
include_sample_dim: bool = False,
) -> list[tuple[tuple[str, ...], npt.NDArray]]:
"""Process arrays for xarray dataset construction with proper dimension naming.
This method transforms numpy arrays into the format required for xarray
dataset construction, including appropriate dimension naming and
singleton dimension handling.
:param arrays: Arrays to process for xarray construction
:type arrays: npt.NDArray
:param include_sample_dim: Whether to include sample dimension in naming.
Defaults to False.
:type include_sample_dim: bool
:returns: List of (dimension_names, processed_array) tuples
:rtype: list[tuple[tuple[str, ...], npt.NDArray]]
:raises ValueError: If array dimensions don't match expected model structure
The method:
- Identifies and removes singleton dimensions
- Assigns appropriate dimension names based on model structure
- Handles sample dimensions for drawn data (e.g., from prior predictive checks)
- Ensures dimensional consistency across all processed arrays
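Example (illustrative sketch; assumes :py:meth:`get_dimname_map` returns
``{(0, 3): 'dim_0'}`` for this model):
>>> draws = np.random.rand(100, 1, 3)  # (samples, singleton, dim)
>>> [(names, arr)] = model._compress_for_xarray(draws, include_sample_dim=True)
>>> names, arr.shape
(('n', 'dim_0'), (100, 3))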
"""
# Get a mapping from dimension keys to dimension names
dims = self.get_dimname_map()
# Set the start index into each array's shape (skip the sample dimension)
start_ind = int(include_sample_dim)
# Process each input array
processed: list[tuple[tuple[str, ...], npt.NDArray]] = []
for array_ind, array in enumerate(arrays):
# Identify singleton dimensions and named dimensions
singleton_axes, dimnames = [], []
effective_shape = array.shape[start_ind:]
for dimind, dimsize in enumerate(effective_shape[::-1]):
forward_ind = len(effective_shape) - 1 - dimind + start_ind
if dimsize == 1:
singleton_axes.append(forward_ind)
else:
try:
dimnames.append(dims[(dimind, dimsize)])
except KeyError as error:
raise ValueError(
f"There is no dimension index of {forward_ind - start_ind} "
f"with size {dimsize} in this model. Error triggered "
f"by array {array_ind}. Options are: {dims}"
) from error
# Append "n" to the dimension names if we are including the sample dimension
if include_sample_dim:
dimnames.append("n")
# Squeeze the array
processed.append(
(tuple(dimnames[::-1]), np.squeeze(array, axis=tuple(singleton_axes)))
)
return processed
def _dict_to_xarray(
self,
draws: dict[abstract_model_component.AbstractModelComponent, npt.NDArray],
) -> xr.Dataset:
"""Convert model component draws dictionary to structured xarray Dataset.
This method transforms a dictionary of model component draws into a
properly structured xarray Dataset with appropriate coordinates
and dimension names for analysis and visualization.
:param draws: Dictionary mapping components to their sampled values
:type draws: dict[abstract_model_component.AbstractModelComponent, npt.NDArray]
:returns: Structured dataset with draws and coordinates
:rtype: xr.Dataset
The resulting dataset includes:
- Data variables for all non-constant components
- Coordinates for multi-dimensional constants
- Proper dimension naming and alignment
"""
# Split the draws into parallel tuples of components and drawn values
model_comps, unpacked_draws = zip(
*[
[comp, draw]
for comp, draw in draws.items()
if not isinstance(comp, constants_module.Constant)
]
)
coordinates = list(
zip(
*[
[parent, parent.value]
for component in self.all_model_components
for parent in component.parents
if isinstance(parent, constants_module.Constant)
and np.prod(parent.shape) > 1
]
)
)
if len(coordinates) == 0:
parents, values = [], []
else:
parents, values = coordinates
# Process the draws and values for xarray. Note that because constants
# have no sample dimension, we do not add one to them when calling
# `_compress_for_xarray` (i.e., `include_sample_dim` stays False).
compressed_draws = self._compress_for_xarray(
*unpacked_draws, include_sample_dim=True
)
compressed_values = self._compress_for_xarray(*values)
# Build the xarray Dataset
return xr.Dataset(
data_vars={
component.model_varname: compressed_draw
for component, compressed_draw in zip(model_comps, compressed_draws)
},
coords={
parent.model_varname: compressed_value
for parent, compressed_value in zip(parents, compressed_values)
},
)
@overload
def draw(
self,
n: "custom_types.Integer",
*,
named_only: Literal[True],
as_xarray: Literal[False],
seed: Optional["custom_types.Integer"],
) -> dict[str, npt.NDArray]: ...
@overload
def draw(
self,
n: "custom_types.Integer",
*,
named_only: Literal[False],
as_xarray: Literal[False],
seed: Optional["custom_types.Integer"],
) -> dict[abstract_model_component.AbstractModelComponent, npt.NDArray]: ...
@overload
def draw(
self,
n: "custom_types.Integer",
*,
named_only: Literal[True],
as_xarray: Literal[True],
seed: Optional["custom_types.Integer"],
) -> xr.Dataset: ...
@overload
def draw(
self,
n: "custom_types.Integer",
*,
named_only: Literal[False],
as_xarray: Literal[True],
seed: Optional["custom_types.Integer"],
) -> xr.Dataset: ...
def draw(self, n, *, named_only=True, as_xarray=False, seed=None):
"""Draw samples from the model's prior distribution.
This method generates samples from all elements of the
model by traversing the dependency graph and sampling from each
component.
:param n: Number of samples to draw from each component.
:type n: custom_types.Integer
:param named_only: Whether to return only named components. Defaults to True.
:type named_only: bool
:param as_xarray: Whether to return results as xarray Dataset. Defaults to False.
:type as_xarray: bool
:param seed: Random seed for reproducible sampling. Defaults to None.
:type seed: Optional[custom_types.Integer]
:returns: Sampled values in requested format
:rtype: Union[dict[str, npt.NDArray], dict[AbstractModelComponent, npt.NDArray], xr.Dataset]
Example:
>>> # Draw 1000 samples as dictionary
>>> samples = model.draw(1000)
>>> # Draw 1000 samples as an xarray Dataset
>>> dataset = model.draw(1000, as_xarray=True)
"""
# Draw from all observables
draws: dict[abstract_model_component.AbstractModelComponent, npt.NDArray] = {}
for observable in self.observables:
_, draws = observable.draw(n, _drawn=draws, seed=seed)
# Filter down to just named parameters if requested
if named_only:
draws = {k: v for k, v in draws.items() if k.is_named}
# Convert to an xarray dataset if requested
if as_xarray:
return self._dict_to_xarray(draws)
# If we are returning only named parameters, then we need to update the
# dictionary keys to be the model variable names.
if named_only:
return {k.model_varname: v for k, v in draws.items()}
return draws
def to_pytorch(
self, seed: Optional["custom_types.Integer"] = None
) -> "nn_module.PyTorchModel":
"""Compile the model to a trainable PyTorch module.
This method converts the SciStanPy model into a PyTorch module that
can be optimized using standard PyTorch training procedures for
maximum likelihood estimation or variational inference. The inputs to this
module (i.e., keyword arguments to its ``forward`` method) are all observed
data; the output is the likelihood of that data given the current model
parameters.
:param seed: Random seed for reproducible compilation. Defaults to None.
:type seed: Optional[custom_types.Integer]
:returns: Compiled PyTorch model ready for training
:rtype: nn_module.PyTorchModel
The compiled model preserves the probabilistic structure while
enabling gradient-based optimization of model parameters. It's
particularly useful for maximum likelihood estimation and can
leverage GPU acceleration for large models.
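Example (hypothetical sketch; assumes an observable named ``y`` with
observed values ``y_obs``, passed as keyword arguments to the module's
``forward`` method):
>>> pt_model = model.to_pytorch(seed=0)
>>> log_likelihood = pt_model(y=torch.as_tensor(y_obs))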
"""
return nn_module.PyTorchModel(self, seed=seed)
def to_stan(self, **kwargs) -> "stan_model.StanModel":
"""Compile the model to Stan code for MCMC sampling.
This method automatically generates Stan probabilistic programming
language code from the SciStanPy model specification and compiles
it for Hamiltonian Monte Carlo sampling.
:param kwargs: Additional compilation options passed to StanModel
:returns: Compiled Stan model ready for MCMC sampling
:rtype: stan_model.StanModel
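Example (hypothetical usage; ``observed_data`` is assumed to map
observable names to arrays):
>>> sm = model.to_stan()
>>> results = sm.sample(data=observed_data)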
"""
return stan_model.StanModel(self, **kwargs)
def mle(
self,
epochs: "custom_types.Integer" = DEFAULT_N_EPOCHS,
early_stop: "custom_types.Integer" = DEFAULT_EARLY_STOP,
lr: "custom_types.Float" = DEFAULT_LR,
data: Optional[dict[str, Union[torch.Tensor, npt.NDArray]]] = None,
device: "custom_types.Integer" | str = "cpu",
seed: Optional["custom_types.Integer"] = None,
mixed_precision: bool = False,
) -> "mle_module.MLE":
"""Compute maximum likelihood estimates of model parameters.
This method fits a PyTorch model to observed data by minimizing
the negative log-likelihood, providing point estimates of all
model parameters along with optimization diagnostics.
:param epochs: Maximum number of training epochs. Note that one optimization
step equals one epoch, as the model must be evaluated over all observed
data to calculate the loss. Defaults to 100000.
:type epochs: custom_types.Integer
:param early_stop: Epochs without improvement before stopping. Defaults to 10.
:type early_stop: custom_types.Integer
:param lr: Learning rate for optimization. Defaults to 0.001.
:type lr: custom_types.Float
:param data: Observed data for observables. Uses default_data provided at
initialization if not provided.
:type data: Optional[dict[str, Union[torch.Tensor, npt.NDArray]]]
:param device: Computation device ('cpu', 'cuda', or device index). Defaults to 'cpu'.
:type device: Union[custom_types.Integer, str]
:param seed: Random seed for reproducible optimization. Defaults to None.
:type seed: Optional[custom_types.Integer]
:param mixed_precision: Whether to use mixed precision training. Defaults to False.
:type mixed_precision: bool
:returns: MLE results with parameter estimates and diagnostics
:rtype: mle_module.MLE
The optimization process:
- Converts model to PyTorch and moves to specified device
- Trains for up to `epochs` epochs, or until there has been no improvement
for `early_stop` consecutive epochs
- Tracks loss trajectory for convergence assessment
- Returns parameter estimates and fitted distributions
Example:
>>> # Basic MLE with default settings
>>> mle_result = model.mle(data=observed_data)
>>> # GPU-accelerated with custom settings
>>> mle_result = model.mle(data=obs, device='cuda', epochs=50000, lr=0.01)
.. note::
The ``mle`` method is much cheaper to run than ``mcmc`` and can be a
useful first step for model validation and debugging. Also note that,
via the :py:class:`scistanpy.model.results.mle.MLE` object returned,
you can bootstrap observed data to obtain uncertainty estimates about
the observed data.
"""
# Set the default value for observed data
data = self.default_data if data is None else data
# Convert observed data to tensors and move to the appropriate device
data = {
k: (
v.to(device=device)
if isinstance(v, torch.Tensor)
else torch.tensor(v).to(device=device)
)
for k, v in data.items()
}
# Check observed data
nn_module.check_observable_data(self, data)
# Fit the model
pytorch_model = self.to_pytorch(seed=seed).to(device=device)
loss_trajectory = pytorch_model.fit(
epochs=epochs,
early_stop=early_stop,
lr=lr,
data=data,
mixed_precision=mixed_precision,
)
# Get the MLE estimate for all model parameters
max_likelihood = {
k: v.detach().cpu().numpy()
for k, v in pytorch_model.export_params().items()
}
# Get the distributions of the parameters
distributions = pytorch_model.export_distributions()
# Return the MLE estimate, the distributions, and the loss trajectory
return mle_module.MLE(
model=self,
mle_estimate=max_likelihood,
distributions=distributions,
losses=loss_trajectory.detach().cpu().numpy(),
data={k: v.detach().cpu().numpy() for k, v in data.items()},
)
def _get_simulation_data(
self, seed: Optional["custom_types.Integer"]
) -> dict[str, npt.NDArray]:
"""Generate simulated observable data from model prior.
This internal method draws a single realization from each observable
parameter in the model using the current prior specification. It is
used by simulation methods to generate synthetic datasets.
:param seed: Random seed for reproducible simulation
:type seed: Optional[custom_types.Integer]
:returns: Dictionary mapping observable names to simulated values
:rtype: dict[str, npt.NDArray]
"""
data = self.draw(1, named_only=True, as_xarray=False, seed=seed)
return {
observable.model_varname: data[observable.model_varname][0]
for observable in self.observables
}
def simulate_mle(self, **kwargs) -> tuple[dict[str, npt.NDArray], "mle_module.MLE"]:
"""Simulate data from model prior and fit via maximum likelihood.
This method performs a complete simulation study by first generating
synthetic data from the model's prior distribution, then fitting
the model to this simulated data using maximum likelihood estimation.
:param kwargs: Keyword arguments passed to mle() method (except 'data')
:returns: Tuple of (simulated_data, mle_results)
:rtype: tuple[dict[str, npt.NDArray], mle_module.MLE]
This is particularly useful for:
- Model validation and debugging
- Assessing parameter identifiability (e.g. by running multiple simulations)
- Verifying implementation correctness
The simulated data is automatically passed to the MLE fitting
procedure, overriding any data specification in kwargs.
Example:
>>> # Simulate and fit with custom settings
>>> sim_data, mle_fit = model.simulate_mle(epochs=10000, lr=0.01)
"""
# TODO: This and the other simulate method should include non-observables
# in the returned MLE as well.
# Get the data
kwargs["data"] = self._get_simulation_data(seed=kwargs.get("seed"))
# Fit the model
return kwargs["data"], self.mle(**kwargs)
@overload
def mcmc(
self,
*,
output_dir: Optional[str],
force_compile: bool,
stanc_options: Optional[dict[str, Any]],
cpp_options: Optional[dict[str, Any]],
user_header: Optional[str],
model_name: Optional[str],
inits: Optional[str],
data: Optional[dict[str, npt.NDArray]],
delay_run: Literal[False],
**sample_kwargs,
) -> "hmc_results.SampleResults": ...
@overload
def mcmc(
self,
*,
output_dir: Optional[str],
force_compile: bool,
stanc_options: Optional[dict[str, Any]],
cpp_options: Optional[dict[str, Any]],
model_name: Optional[str],
user_header: Optional[str],
inits: Optional[str],
data: Optional[dict[str, npt.NDArray]],
delay_run: Literal[True] | str,
**sample_kwargs,
) -> None: ...
def mcmc(
self,
*,
output_dir=None,
force_compile=DEFAULT_FORCE_COMPILE,
stanc_options=None,
cpp_options=None,
user_header=DEFAULT_USER_HEADER,
model_name=DEFAULT_MODEL_NAME,
inits=None,
data=None,
delay_run=False,
**sample_kwargs,
):
"""Perform Hamiltonia Monte Carlo sampling using Stan backend.
This method compiles the model to Stan and executes Hamiltonian
Monte Carlo sampling to generate posterior samples. It supports
both immediate execution and delayed runs for batch processing.
:param output_dir: Directory for compilation and output files. Defaults to None,
in which case all raw outputs will be saved to a temporary directory and
be accessible only for the lifetime of this object.
:type output_dir: Optional[str]
:param force_compile: Whether to force recompilation of Stan model. Defaults
to False.
:type force_compile: bool
:param stanc_options: Options for Stan compiler. Defaults to None (uses
DEFAULT_STANC_OPTIONS).
:type stanc_options: Optional[dict[str, Any]]
:param cpp_options: Options for C++ compilation. Defaults to None (uses
DEFAULT_CPP_OPTIONS).
:type cpp_options: Optional[dict[str, Any]]
:param user_header: Custom C++ header code. Defaults to None.
:type user_header: Optional[str]
:param model_name: Name for compiled model. Defaults to 'model'.
:type model_name: Optional[str]
:param inits: Initialization strategy. See `stan_model.StanModel`
for options. Defaults to None.
:type inits: Optional[str]
:param data: Observed data for observables. Uses default_data defined during
initialization if not provided.
:type data: Optional[dict[str, npt.NDArray]]
:param delay_run: Whether to delay execution. If `True`, a pickle file that
can be used for delayed execution will be saved to `output_dir`. A string
can also be provided to save the pickle file to an alternate location.
Defaults to False (meaning immediate execution).
:type delay_run: Union[bool, str]
:param sample_kwargs: Additional arguments passed to Stan sampling. See
`cmdstanpy.CmdStanModel.sample` for options.
:returns: MCMC results if delay_run=False, None if delayed
:rtype: Union[hmc_results.SampleResults, None]
:raises ValueError: If delay_run is True but output_dir is None
.. note::
When delay_run=True, the method saves sampling configuration to a
pickle file instead of executing immediately. This enables batch
processing and distributed computing workflows.
Example:
>>> # Immediate MCMC sampling
>>> results = model.mcmc(chains=4, iter_sampling=2000)
>>> # Delayed execution for batch processing
>>> model.mcmc(delay_run='batch_job.pkl', chains=8, iter_sampling=5000)
"""
# Get the default observed data and compiler options
data = self.default_data if data is None else data
stanc_options = stanc_options or DEFAULT_STANC_OPTIONS
cpp_options = cpp_options or DEFAULT_CPP_OPTIONS
# An output directory must be provided if we are delaying the run
if delay_run and output_dir is None:
raise ValueError(
"An output directory must be provided if `delay_run` is True."
)
# Create a temporary output directory if none was provided
if output_dir is None:
tempdir = TemporaryDirectory()
weakref.finalize(self, tempdir.cleanup)
output_dir = tempdir.name
# Full path of output directory
output_dir = os.path.abspath(output_dir)
sample_kwargs["output_dir"] = output_dir
# Build the Stan model
model = self.to_stan(
output_dir=output_dir,
force_compile=force_compile,
stanc_options=stanc_options,
cpp_options=cpp_options,
user_header=user_header,
model_name=model_name,
)
# If delaying, then we save the data needed for sampling and return
if delay_run:
with open(
(
delay_run
if isinstance(delay_run, str)
else f"{model.stan_executable_path}-delay.pkl"
),
"wb",
) as f:
pickle.dump(
{
"stan_model": model,
"sample_kwargs": sample_kwargs,
"data": data,
"inits": inits,
},
f,
)
return
# Sample from the model
return model.sample(inits=inits, data=data, **sample_kwargs)
@overload
def simulate_mcmc(
self, delay_run: Literal[False], **kwargs
) -> tuple[dict[str, npt.NDArray], "hmc_results.SampleResults"]: ...
@overload
def simulate_mcmc(self, delay_run: Literal[True], **kwargs) -> None: ...
def simulate_mcmc(self, delay_run=False, **kwargs):
"""Simulate data from model prior and perform Hamiltonian Monte Carlo sampling.
This method generates synthetic data from the model's prior
distribution and then performs full Bayesian inference via MCMC.
It's extremely helpful for model validation and posterior recovery testing.
:param delay_run: Whether to delay MCMC execution. Defaults to False.
:type delay_run: bool
:param kwargs: Additional keyword arguments passed to
:py:meth:`~scistanpy.model.model.Model.mcmc` method.
:returns: Tuple of (simulated_data, mcmc_results) if delay_run=False,
None if delay_run=True
:rtype: Union[tuple[dict[str, npt.NDArray], hmc_results.SampleResults], None]
The method automatically updates the model name to indicate
simulation when using the default name, helping distinguish
simulated from real data analyses.
This is crucial for:
- Validating MCMC implementation correctness
- Testing posterior recovery in known-truth scenarios
- Assessing sampler efficiency and convergence
- Debugging model specification issues
Example:
>>> # Simulate and sample with immediate execution
>>> sim_data, mcmc_results = model.simulate_mcmc(chains=4, iter_sampling=1000)
"""
# Update the model name to flag the run as a simulation
if kwargs.get("model_name", DEFAULT_MODEL_NAME) == DEFAULT_MODEL_NAME:
kwargs["model_name"] = f"{DEFAULT_MODEL_NAME}-simulated"
# Get the data
kwargs["data"] = self._get_simulation_data(seed=kwargs.get("seed"))
# Run MCMC
return kwargs["data"], self.mcmc(delay_run=delay_run, **kwargs)
def prior_predictive(self, *, copy_model: bool = False) -> pn.Row:
"""Create interactive prior predictive check visualization.
This method generates an interactive dashboard for exploring the
model's prior predictive distribution. Users can adjust model
hyperparameters via sliders and immediately see how changes
affect prior predictions.
:param copy_model: Whether to copy model to avoid modifying original. Defaults
to False, meaning the calling model is updated in place by changing
slider values and clicking "update model".
:type copy_model: bool
:returns: Panel dashboard with interactive prior predictive visualization
:rtype: pn.Row
The dashboard includes:
- Sliders for all adjustable model hyperparameters
- Multiple visualization modes (ECDF, KDE, violin, relationship plots)
- Real-time updates as parameters are modified
- Options for different grouping and display configurations
This is useful for:
- Prior specification and calibration
- Understanding model behavior before data fitting
- Identifying unrealistic prior assumptions
Example:
>>> # Create interactive dashboard
>>> dashboard = model.prior_predictive()
>>> dashboard.servable() # For web deployment
>>> # Or display in Jupyter notebook
>>> dashboard
"""
# Create the prior predictive object
pp = prior_predictive_module.PriorPredictiveCheck(self, copy_model=copy_model)
# Return the dashboard
return pp.display()
def __str__(self) -> str:
"""Return comprehensive string representation of the model.
:returns: Formatted string showing all model components organized by type
:rtype: str
The representation includes organized sections for:
- Constants and hyperparameters
- Transformed parameters
- Regular parameters
- Observable parameters
Each section lists components with their specifications and
current values, providing a complete overview of model structure.
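Example (illustrative layout only; the sections and entries shown depend
on the model's components):
>>> print(model)
Parameters
==========
...
Observables
===========
...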
"""
# Get all model components
model_comps = {
"Constants": [
el
for el in self.all_model_components
if isinstance(el, constants_module.Constant)
],
"Transformed Parameters": self.transformed_parameters,
"Parameters": self.parameters,
"Observables": self.observables,
}
# Combine representations from all model components
return "\n\n".join(
key + "\n" + "=" * len(key) + "\n" + "\n".join(str(el) for el in complist)
for key, complist in model_comps.items()
if len(complist) > 0
)
def __contains__(self, paramname: str) -> bool:
"""Check if model contains a component with the given name.
:param paramname: Name of the model component to check
:type paramname: str
:returns: True if component exists, False otherwise
:rtype: bool
Example:
>>> 'mu' in model # Check if parameter 'mu' exists
True
"""
return paramname in self._model_varname_to_object
def __getitem__(
self, paramname: str
) -> abstract_model_component.AbstractModelComponent:
"""Retrieve model component by name.
:param paramname: Name of the model component to retrieve
:type paramname: str
:returns: The requested model component
:rtype: abstract_model_component.AbstractModelComponent
:raises KeyError: If component name doesn't exist
Example:
>>> mu_param = model['mu'] # Get parameter named 'mu'
>>> print(mu_param.distribution)
"""
return self._model_varname_to_object[paramname]
def __setattr__(self, name: str, value: Any) -> None:
"""Set model attribute with protection for model components.
:param name: Attribute name to set
:type name: str
:param value: Value to assign to the attribute
:type value: Any
:raises AttributeError: If attempting to modify existing model component
or add a new model component after initialization.
This method prevents modification of model components after
initialization to maintain model integrity and prevent
accidental corruption of the dependency graph.
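Example (illustrative; reassigning a registered component raises):
>>> model.mu = ssp.parameters.Normal(1.0, 2.0)
Traceback (most recent call last):
    ...
AttributeError: Model components can only be set during initialization.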
"""
# We cannot set attributes that are model components
if (
hasattr(self, "_model_varname_to_object")
and name in self._model_varname_to_object
):
raise AttributeError(
"Model components can only be set during initialization."
)
# Otherwise, set the attribute
super().__setattr__(name, value)
@property
def default_data(self) -> dict[str, npt.NDArray] | None:
"""Get the default observed data for model observables.
:returns: Dictionary mapping observable names to their default data
:rtype: dict[str, npt.NDArray] | None
:raises ValueError: If default data has not been set
Default data is used automatically by methods like mle() and mcmc()
when no explicit data is provided, streamlining common workflows.
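Example (hypothetical; assumes a single observable named ``y`` with
observed values ``y_obs``):
>>> model.default_data = {"y": y_obs}
>>> mle_result = model.mle()  # uses the default data automatically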
"""
if getattr(self, "_default_data", None) is None:
raise ValueError(
"Default data is not set. Please set the default data using "
"`model.default_data = data`."
)
return self._default_data
@default_data.setter
def default_data(self, data: dict[str, npt.NDArray] | None) -> None:
"""Set default observed data for model observables.
:param data: Dictionary mapping observable names to their data, or None to clear
:type data: dict[str, npt.NDArray] | None
:raises ValueError: If data is missing required observable keys or contains extra keys
The data dictionary must contain entries for all observable parameters
in the model. Setting to None clears the default data.
"""
# Reset the default data if `None` is passed
if data is None:
self._default_data = None
return
# If initialization is complete, the data must be a dictionary and we must
# have all the appropriate keys. We skip this check if the model is not
# initialized yet, as we do not know the model variable names yet.
if self._init_complete:
expected_keys = {comp.model_varname for comp in self.observables}
if missing_keys := expected_keys - data.keys():
raise ValueError(
f"The following keys are missing from the default data: {missing_keys}"
)
if extra_keys := data.keys() - expected_keys:
raise ValueError(
f"The following keys are not expected in the data: {extra_keys}"
)
# Set the default data
self._default_data = data
@property
def has_default_data(self) -> bool:
"""Check whether the model has default data configured.
:returns: True if default data is set, False otherwise
:rtype: bool
This property is useful for conditional logic that depends on
whether default data is available for automatic use in methods.
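Example (illustrative):
>>> if model.has_default_data:
...     results = model.mcmc(chains=4)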
"""
return getattr(self, "_default_data", None) is not None
@property
def named_model_components(
self,
) -> tuple[abstract_model_component.AbstractModelComponent, ...]:
"""Get all named model components.
:returns: Tuple of named components
:rtype: tuple[abstract_model_component.AbstractModelComponent, ...]
Named components are those explicitly assigned as instance attributes
during model construction, as opposed to intermediate components
created automatically during dependency resolution.
"""
return self._named_model_components
@property
def named_model_components_dict(
self,
) -> dict[str, abstract_model_component.AbstractModelComponent]:
"""Get named model components as a dictionary.
:returns: Dictionary mapping variable names to named components
:rtype: dict[str, abstract_model_component.AbstractModelComponent]
This provides convenient access to named components by their
string names for programmatic model inspection and manipulation.
"""
return model_comps_to_dict(self.named_model_components)
@property
def all_model_components(
self,
) -> tuple[abstract_model_component.AbstractModelComponent, ...]:
"""Get all model components including unnamed intermediate components.
:returns: Tuple of all components sorted by variable name
:rtype: tuple[abstract_model_component.AbstractModelComponent, ...]
This includes both explicitly named components and any intermediate
components created during dependency resolution, providing complete
visibility into the model's computational graph.
"""
return tuple(
sorted(
self._model_varname_to_object.values(), key=lambda x: x.model_varname
)
)
@property
def all_model_components_dict(
self,
) -> dict[str, abstract_model_component.AbstractModelComponent]:
"""Get all model components as a dictionary.
:returns: Dictionary mapping variable names to all components
:rtype: dict[str, abstract_model_component.AbstractModelComponent]
This comprehensive mapping includes both named and intermediate
components, enabling full programmatic access to the model structure.
"""
return self._model_varname_to_object
@property
def parameters(self) -> tuple[parameters_module.Parameter, ...]:
"""Get all non-observable parameters in the model.
:returns: Tuple of parameter components that are not observables
:rtype: tuple[parameters_module.Parameter, ...]
These are the latent variables and hyperparameters that will be
inferred during MCMC sampling or optimized during MLE fitting.
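Example (illustrative; for the ``MyModel`` example above, ``mu`` and
``sigma`` are latent parameters while ``y`` is observable):
>>> sorted(p.model_varname for p in model.parameters)
['mu', 'sigma']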
"""
return tuple(
filter(
lambda x: isinstance(x, parameters_module.Parameter)
and not x.observable,
self.all_model_components,
)
)
@property
def parameter_dict(self) -> dict[str, parameters_module.Parameter]:
"""Get non-observable parameters as a dictionary.
:returns: Dictionary mapping names to non-observable parameters
:rtype: dict[str, parameters_module.Parameter]
Provides convenient named access to the model's latent parameters
for inspection and programmatic manipulation.
"""
return model_comps_to_dict(self.parameters)
@property
def hyperparameters(self) -> tuple[parameters_module.Parameter, ...]:
"""Get hyperparameters (parameters with only constant parents).
:returns: Tuple of parameters that depend only on constants
:rtype: tuple[parameters_module.Parameter, ...]
Hyperparameters are the highest-level parameters in the model
hierarchy, typically representing prior distribution parameters
that are not derived from other random variables.
"""
return tuple(filter(lambda x: x.is_hyperparameter, self.parameters))
@property
def hyperparameter_dict(self) -> dict[str, parameters_module.Parameter]:
"""Get hyperparameters as a dictionary.
:returns: Dictionary mapping names to hyperparameters
:rtype: dict[str, parameters_module.Parameter]
Provides convenient access to the model's hyperparameters by name
for prior specification and sensitivity analysis.
"""
return model_comps_to_dict(self.hyperparameters)
@property
def transformed_parameters(
self,
) -> tuple[transformed_parameters_module.TransformedParameter, ...]:
"""Get all named transformed parameters in the model.
:returns: Tuple of transformed parameter components
:rtype: tuple[transformed_parameters_module.TransformedParameter, ...]
Transformed parameters are deterministic functions of other model
components, representing computed quantities like sums, products,
or other mathematical transformations.
"""
return tuple(
filter(
lambda x: isinstance(
x, transformed_parameters_module.TransformedParameter
),
self.named_model_components,
)
)
@property
def transformed_parameter_dict(
self,
) -> dict[str, transformed_parameters_module.TransformedParameter]:
"""Get named transformed parameters as a dictionary.
:returns: Dictionary mapping names to transformed parameters
:rtype: dict[str, transformed_parameters_module.TransformedParameter]
Enables convenient access to transformed parameters for model
inspection and derived quantity analysis.
"""
return model_comps_to_dict(self.transformed_parameters)
@property
def constants(self) -> tuple[constants_module.Constant, ...]:
"""Get all named constants in the model.
:returns: Tuple of constant components
:rtype: tuple[constants_module.Constant, ...]
Constants represent fixed values and hyperparameter specifications
that do not change during inference or optimization procedures.
"""
return tuple(
filter(
lambda x: isinstance(x, constants_module.Constant),
self.named_model_components,
)
)
@property
def constant_dict(self) -> dict[str, constants_module.Constant]:
"""Get named constants as a dictionary.
:returns: Dictionary mapping names to constant components
:rtype: dict[str, constants_module.Constant]
Provides convenient access to model constants for hyperparameter
inspection and sensitivity analysis workflows.
"""
return model_comps_to_dict(self.constants)
@property
def observables(self) -> tuple[parameters_module.Parameter, ...]:
"""Get all observable parameters in the model (observables are always named).
:returns: Tuple of parameters marked as observable
:rtype: tuple[parameters_module.Parameter, ...]
Observable parameters represent the data-generating components
of the model - the variables for which observed data will be
provided during inference procedures.
"""
return tuple(
filter(
lambda x: isinstance(x, parameters_module.Parameter) and x.observable,
self.named_model_components,
)
)
@property
def observable_dict(self) -> dict[str, parameters_module.Parameter]:
"""Get observable parameters as a dictionary.
:returns: Dictionary mapping names to observable parameters
:rtype: dict[str, parameters_module.Parameter]
Enables convenient access to observable parameters for data
specification and model validation workflows.
"""
return model_comps_to_dict(self.observables)