SciStanPy: Intuitive Bayesian Modeling for Scientists
Welcome to SciStanPy, a Python framework that makes Bayesian statistical modeling accessible to scientists and researchers. Whether you’re analyzing experimental data, building predictive models, or exploring uncertainty in your research, SciStanPy provides an intuitive interface that bridges the gap between scientific thinking and advanced statistical computation. Express your scientific models naturally in Python while SciStanPy automatically handles the computational complexity of model fitting and sampling via PyTorch and Stan.
Why SciStanPy?
Bayesian modeling often requires deep statistical programming knowledge. SciStanPy changes this by letting you express your scientific models naturally in Python, while automatically handling the complex computational details behind the scenes.
- 🔬 Scientific Focus: designed specifically for scientific applications and research workflows
- 🐍 Intuitive Python Interface: express hypotheses using familiar Python syntax and scientific thinking
- 🎯 Uncertainty Quantification: natural handling of measurement uncertainty and parameter estimation
- 📊 Rich Diagnostics: comprehensive model checking and convergence diagnostics
- ⚡ Multi-Backend Performance: automatic compilation to Stan code and PyTorch modules
- 🛡️ Built-in Validation: comprehensive error checking catches modeling mistakes early
Key Concepts
Parameters vs Observations: In SciStanPy, parameters represent unknown quantities you want to learn about (“infer”), while observations are your measured data.
Distributions: These encode your knowledge and uncertainty. Use them to express:
- Prior knowledge about parameters
- Expected relationships between variables
- Measurement uncertainty and noise
In SciStanPy, parameters and observations are defined using distributions.
Model Building: Parameters can depend on one another and/or be transformed using standard mathematical operations. Combine and transform parameters to build a SciStanPy model.
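Operator overloading is one common way for a library to let you combine and transform parameters with ordinary arithmetic, and the API reference in this document lists methods such as TransformableParameter.__mul__() that point in that direction. The toy sketch below illustrates the general idea only: the Node class and the exp() and leaf_names() functions are hypothetical and are not SciStanPy's API. Each operation returns a new node that records its inputs, so writing an expression also builds the model's dependency graph. A single temperature, 30.0, stands in for the data array.

```python
# Toy illustration only: Node, exp, and leaf_names are hypothetical, not SciStanPy's API.


class Node:
    """A named quantity that records the nodes it was computed from."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = tuple(parents)

    def __mul__(self, other):
        # node * other -> new node whose parents are the node inputs
        return Node(f"({self.name} * {_label(other)})", (self, *_node_inputs(other)))

    def __rmul__(self, other):
        # other * node, e.g. 2.0 * temperature_effect
        return Node(f"({_label(other)} * {self.name})", (*_node_inputs(other), self))


def _label(x):
    """Display name for a node or a plain number."""
    return x.name if isinstance(x, Node) else repr(x)


def _node_inputs(x):
    """Only nodes become parents; plain numbers are treated as constants."""
    return (x,) if isinstance(x, Node) else ()


def exp(node):
    """Exponential of a node, recorded as a new node."""
    return Node(f"exp({node.name})", (node,))


def leaf_names(node):
    """Names of the leaf nodes (here, the priors) a node ultimately depends on."""
    if not node.parents:
        return {node.name}
    return set().union(*(leaf_names(p) for p in node.parents))


baseline_rate = Node("baseline_rate")
temperature_effect = Node("temperature_effect")
mu = baseline_rate * exp(temperature_effect * 30.0)
print(mu.name)                 # (baseline_rate * exp((temperature_effect * 30.0)))
print(sorted(leaf_names(mu)))  # ['baseline_rate', 'temperature_effect']
```

Walking the recorded parents from any node back to the leaves recovers every quantity it depends on, which is the kind of information a framework needs before it can generate code for a backend.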
Inference: Once your model is built, SciStanPy handles compilation to Stan code or PyTorch modules, allowing for inference of parameter values.
Note
SciStanPy is designed for scientists with basic Python knowledge who want to incorporate uncertainty quantification and Bayesian analysis into their research without needing deep statistical programming expertise. For an introduction to Bayesian modeling, the BE/Bi 103b class developed and taught by Justin Bois at Caltech is an excellent resource.
Quick Start
This section assumes that you have followed the installation instructions found on GitHub. If not, return to the GitHub repository for detailed setup instructions.
Comprehensive documentation for the SciStanPy API can be found here. To get started quickly, however, the most important modules and objects are model, parameters, and operations. Less frequently used, but also important, is the constants module. Follow the links for detailed API references.
The Model object forms the backbone of all SciStanPy models: distributions defined in parameters define a model's variables, operations define transformations of those variables, and constants provide fixed values used throughout the model. If you work with PyTorch, defining models in SciStanPy will feel familiar:
import scistanpy as ssp
import numpy as np
# Define/import experimental data
temperatures = np.array([20, 25, 30, 35, 40])  # °C
reaction_rates = np.array([0.1, 0.3, 0.8, 2.1, 5.2])  # units/sec
# Define the model. Here we're modeling the effect of temperature on reaction
# rates
class MyModel(ssp.Model):
    def __init__(self, temperatures, reaction_rates):
        # Record default data
        super().__init__(default_data={"reaction_rates": reaction_rates})
        # Define priors
        self.baseline_rate = ssp.parameters.LogNormal(mu=0.0, sigma=1.0)  # Rates are always positive
        self.temperature_effect = ssp.parameters.Normal(mu=0.0, sigma=0.5)  # Effect of temperature
        # Model the relationship (Arrhenius-like behavior with noise)
        self.reaction_rates = ssp.parameters.Normal(
            mu=self.baseline_rate * ssp.operations.exp(self.temperature_effect * temperatures),
            sigma=0.1,
        )
# Build the model
model = MyModel(temperatures, reaction_rates)
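Written out in conventional statistical notation, the model above reads as follows. This is our transcription of the code, with α for baseline_rate, β for temperature_effect, T_i for temperatures, and y_i for reaction_rates; it assumes, as in Stan, that LogNormal's mu and sigma describe the logarithm of the variable and that the second argument of each Normal is a standard deviation, matching the sigma keyword.

```latex
\begin{aligned}
\alpha &\sim \operatorname{LogNormal}(0,\ 1) \\
\beta &\sim \operatorname{Normal}(0,\ 0.5) \\
y_i &\sim \operatorname{Normal}\!\left(\alpha\, e^{\beta T_i},\ 0.1\right), \qquad i = 1, \dots, 5
\end{aligned}
```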
Once a model is defined, there are multiple ways to use it. First, you may wish to explore the effects of different hyperparameter values. This can be done through an interactive interface as below:
# If you are in a Jupyter notebook, you can create an interactive dashboard to
# explore prior predictive distributions. Just run:
model.prior_predictive()
Additional information on the dashboard can be found here.
You can also draw directly from the prior distribution:
# Draw 100 samples from the prior distribution
prior_samples = model.draw(100)
The returned dictionary will contain draws from all model parameters, each conditioned on one another as appropriate. See draw() for more details.
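Drawing from the prior in a model like this one is usually done by ancestral sampling: draw the parameters from their priors, then draw each dependent variable given the values just drawn for its parents. The numpy sketch below does this by hand for MyModel to show what such a dictionary of draws contains. It is an illustration, not how draw() is implemented; it assumes that LogNormal's mu and sigma refer to the logarithm of the variable (as they do in Stan and in numpy), and the dictionary keys simply mirror the attribute names of MyModel.

```python
import numpy as np

temperatures = np.array([20, 25, 30, 35, 40])  # °C, as in the example above


def draw_prior(n, rng):
    """Ancestral sampling: draw each variable given the draws of its parents."""
    # Priors on the parameters (numpy's lognormal takes the mean/sd of the log)
    baseline_rate = rng.lognormal(mean=0.0, sigma=1.0, size=n)
    temperature_effect = rng.normal(loc=0.0, scale=0.5, size=n)
    # Observations depend on the parameter draws; shape (n, 5) via broadcasting
    mu = baseline_rate[:, None] * np.exp(temperature_effect[:, None] * temperatures)
    reaction_rates = rng.normal(loc=mu, scale=0.1)
    return {
        "baseline_rate": baseline_rate,
        "temperature_effect": temperature_effect,
        "reaction_rates": reaction_rates,
    }


samples = draw_prior(100, np.random.default_rng(0))
```

Each row of samples["reaction_rates"] is one simulated data set, generated from the parameter values in the same position of the other two arrays.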
SciStanPy models are also fully compatible with PyTorch and can be compiled to a PyTorch Module using the to_pytorch() method:
# Convert to PyTorch module
torch_model = model.to_pytorch()
The returned PyTorch Module will have a forward method that takes observed data as input (as keyword arguments named after the observables defined in the SciStanPy model, i.e., reaction_rates in the example above) and returns the log-likelihood of the data given the model. Parameter values for the returned model will initially be random, but can be optimized using standard PyTorch techniques. Compiling to a PyTorch module like this also allows SciStanPy models to be dropped into other frameworks that rely on or extend PyTorch.
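The quantity that forward returns can be written down directly for the example model: the sum, over observations, of the Normal log-density of each measured rate given the mean baseline_rate * exp(temperature_effect * T) and standard deviation 0.1. The numpy function below is a hypothetical stand-in for that computation, not SciStanPy's PyTorch module; its keyword-only reaction_rates argument only mimics the keyword naming described above.

```python
import numpy as np

temperatures = np.array([20, 25, 30, 35, 40])  # °C
observed_rates = np.array([0.1, 0.3, 0.8, 2.1, 5.2])  # units/sec


def log_likelihood(baseline_rate, temperature_effect, *, reaction_rates, sigma=0.1):
    """Sum over observations of log Normal(y | mu, sigma) for the example model."""
    mu = baseline_rate * np.exp(temperature_effect * temperatures)
    z = (reaction_rates - mu) / sigma
    return float(np.sum(-0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)))


# Two hand-picked parameter settings: one close to the data, one flat rate of 1.0
ll_close = log_likelihood(0.0039, 0.18, reaction_rates=observed_rates)
ll_flat = log_likelihood(1.0, 0.0, reaction_rates=observed_rates)  # about -1002.58
```

Optimizing the parameters amounts to searching for values that make this number as large as possible.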
As a convenience, SciStanPy models can be converted to PyTorch Modules and optimized in a single step using the mle() method:
# Perform maximum likelihood estimation (MLE) using PyTorch backend
mle_result = model.mle(device='cuda')
Note that if default data was provided during model initialization (as above), then no data need be provided to this method call; the registered defaults will be used automatically. Also note that, because this is a PyTorch-based method, GPU acceleration can be used, which is particularly useful for larger models.
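To make "maximum likelihood" concrete for the example model, the sketch below searches a coarse grid of parameter values for the setting with the highest log-likelihood. The grid search is a deliberately simple stand-in chosen to avoid a PyTorch dependency; mle() uses PyTorch optimization rather than a grid, and this sketch considers the likelihood only, without the priors. The grid bounds are arbitrary choices for this data set.

```python
import numpy as np

temperatures = np.array([20, 25, 30, 35, 40])  # °C
reaction_rates = np.array([0.1, 0.3, 0.8, 2.1, 5.2])  # units/sec
SIGMA = 0.1  # fixed noise scale, as in MyModel


def log_likelihood(log_baseline, effect):
    """Normal log-likelihood, vectorized over arrays of parameter values."""
    log_baseline, effect = np.asarray(log_baseline), np.asarray(effect)
    # mu = baseline_rate * exp(effect * T), with the baseline on the log scale
    mu = np.exp(log_baseline[..., None] + effect[..., None] * temperatures)
    z = (reaction_rates - mu) / SIGMA
    return np.sum(-0.5 * z**2 - np.log(SIGMA) - 0.5 * np.log(2.0 * np.pi), axis=-1)


# Evaluate on a coarse grid over log(baseline_rate) and temperature_effect
log_b_grid, effect_grid = np.meshgrid(
    np.linspace(-10.0, 1.0, 441), np.linspace(0.0, 0.5, 501), indexing="ij"
)
ll = log_likelihood(log_b_grid, effect_grid)
i, j = np.unravel_index(np.argmax(ll), ll.shape)
baseline_hat = float(np.exp(log_b_grid[i, j]))
effect_hat = float(effect_grid[i, j])
print(f"baseline_rate ~ {baseline_hat:.4f}, temperature_effect ~ {effect_hat:.3f}")
```

The number of grid points grows exponentially with the number of parameters, which is why gradient-based optimizers, such as those PyTorch provides, are used in practice.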
The MLE object returned by the mle() method contains the optimized parameter values (accessible as instance variables or via the model_to_varname dictionary) and utility methods for evaluating the fit. Extensive details can be found in the MLE documentation, but one method to highlight is get_inference_obj(), which bootstraps samples from the optimized model, providing a cheap (relative to full MCMC sampling) alternative for uncertainty quantification. It can also be particularly helpful during the early stages of model development for assessing model validity before committing to MCMC sampling, especially for large models. Evaluating model fit with the returned MLEInferenceRes instance is straightforward:
# Get inference object from MLE result
mle_inference = mle_result.get_inference_obj(n_samples=1000)
# Evaluate model fit with posterior predictive checks
mle_inference.run_ppc()
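A posterior predictive check asks whether data simulated from the fitted model resemble the data that were actually observed. As a hand-rolled illustration of that idea (not what run_ppc() computes or displays), the sketch below simulates replicate data sets from MyModel at fixed, hand-picked parameter values and reports, for each temperature, the fraction of replicates at or below the observed rate. A full check would also propagate parameter uncertainty rather than relying on a single point.

```python
import numpy as np

temperatures = np.array([20, 25, 30, 35, 40])  # °C
reaction_rates = np.array([0.1, 0.3, 0.8, 2.1, 5.2])  # units/sec

# Hand-picked placeholder values, not output from SciStanPy
baseline_hat, effect_hat = 0.0039, 0.18


def predictive_tail_probs(n_reps, rng):
    """Per observation, fraction of simulated values <= the observed value."""
    mu = baseline_hat * np.exp(effect_hat * temperatures)
    replicates = rng.normal(loc=mu, scale=0.1, size=(n_reps, temperatures.size))
    return (replicates <= reaction_rates).mean(axis=0)


probs = predictive_tail_probs(1000, np.random.default_rng(1))
print(probs)
```

Fractions close to 0 or 1 flag observations that the model rarely reproduces.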
Additional methods exposed by the MLEInferenceRes class can be found in its associated documentation. Note that the MLEInferenceRes class also exposes the inference_obj instance variable, which is an arviz.InferenceData object containing the samples drawn from the model. This can be used to interface directly with the ArviZ ecosystem for analysis of Bayesian models.
The most notable feature of SciStanPy (and, indeed, the inspiration for its name) is its ability to automatically write and execute Stan programs. As with SciStanPy’s PyTorch integration, Stan functionality can be accessed at two levels of granularity. At the lower level, to access an object representing a Stan model, run the following:
# Convert to a StanModel
stan_model = model.to_stan()
The returned StanModel object is an extension of cmdstanpy’s CmdStanModel class. When instantiated, it will automatically write and compile a Stan program to sample from the model defined in SciStanPy. The resulting StanModel instance will use this Stan program for subsequent operations. Currently, the sample() method has full support, while other CmdStanModel methods have experimental support. See the StanModel documentation for more details.
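For orientation, here is a hand-written Stan program that roughly corresponds to MyModel, stored as a Python string. It is our assumption of what an equivalent program looks like, not SciStanPy's generated output, which will differ in naming, declarations, and how constants such as temperatures are passed in.

```python
# Hand-written for illustration; SciStanPy generates its own program, which
# will differ in variable declarations, naming, and structure.
STAN_PROGRAM = """
data {
  int<lower=1> N;
  vector[N] temperatures;
  vector[N] reaction_rates;
}
parameters {
  real<lower=0> baseline_rate;
  real temperature_effect;
}
model {
  baseline_rate ~ lognormal(0, 1);
  temperature_effect ~ normal(0, 0.5);
  reaction_rates ~ normal(baseline_rate * exp(temperature_effect * temperatures), 0.1);
}
"""

if __name__ == "__main__":
    print(STAN_PROGRAM)
```

Saved to a .stan file, a program like this could be compiled by hand with cmdstanpy.CmdStanModel(stan_file=...); the StanModel class described above automates that writing and compiling step.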
To write, compile, and sample from a SciStanPy model in a single step, run the following:
# Run Hamiltonian Monte Carlo sampling
mcmc_result = model.mcmc()
As with the PyTorch-based mle() method, if the default_data argument was provided to the parent class on initialization, that data will be used for sampling. Otherwise, it should be provided as the data keyword argument to mcmc().
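Stan's default sampler is an adaptive Hamiltonian Monte Carlo method (NUTS) that uses gradients of the log posterior. To show what drawing samples from the posterior means without any Stan dependency, the sketch below uses the much simpler random-walk Metropolis algorithm instead; it is a swapped-in technique for illustration, not what mcmc() runs, and it has no warmup adaptation or convergence checks. It samples MyModel's posterior in shifted coordinates (c, effect), with c = log(baseline_rate) + effect * T_REF; the reference temperature T_REF and the step sizes are our own choices for this data set.

```python
import numpy as np

temperatures = np.array([20, 25, 30, 35, 40])  # °C
reaction_rates = np.array([0.1, 0.3, 0.8, 2.1, 5.2])  # units/sec
T_REF = 38.0  # reference temperature; reduces correlation between c and effect here


def log_posterior(c, effect):
    """Unnormalized log posterior of MyModel in the coordinates (c, effect).

    mu = exp(c + effect * (T - T_REF)). A LogNormal(0, 1) prior on baseline_rate
    is a Normal(0, 1) prior on its log, and the linear map
    (log baseline, effect) -> (c, effect) has unit Jacobian.
    """
    log_baseline = c - effect * T_REF
    log_prior = -0.5 * log_baseline**2 - 0.5 * (effect / 0.5) ** 2
    mu = np.exp(c + effect * (temperatures - T_REF))
    log_lik = -0.5 * np.sum(((reaction_rates - mu) / 0.1) ** 2)
    return log_prior + log_lik


def metropolis(n_draws, step=(0.02, 0.006), seed=0):
    """Random-walk Metropolis; returns draws of the original parameters and acceptance rate."""
    rng = np.random.default_rng(seed)
    current = np.zeros(2)  # (c, effect)
    current_lp = log_posterior(*current)
    draws = np.empty((n_draws, 2))
    n_accept = 0
    for t in range(n_draws):
        proposal = current + rng.normal(scale=step)
        proposal_lp = log_posterior(*proposal)
        # Accept with probability min(1, exp(proposal_lp - current_lp))
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
            n_accept += 1
        draws[t] = current
    samples = {
        "baseline_rate": np.exp(draws[:, 0] - draws[:, 1] * T_REF),
        "temperature_effect": draws[:, 1],
    }
    return samples, n_accept / n_draws


samples, acceptance = metropolis(4000)
```

Consecutive Metropolis draws are correlated and the earliest ones reflect the arbitrary starting point, so in practice a warmup period is discarded and convergence is checked across several independent chains.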
The SampleResults instance returned by the mcmc() method contains the samples drawn during sampling. The SampleResults class is an extension of the MLEInferenceRes class introduced above: it shares the same methods and properties, plus some others. Most notable among the additional methods are:
- diagnose(), which runs diagnostic checks on MCMC samples (see Stan’s documentation for a review of the different diagnostics);
- plot_sample_failure_quantile_traces(), which provides an interactive dashboard for visualizing samples that failed diagnostic checks;
- plot_variable_failure_quantile_traces(), which provides an interactive dashboard for visualizing variables (i.e., parameters) that failed diagnostic checks.
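One of the best-known diagnostics covered in Stan's documentation is R-hat, which compares the variance between chains to the variance within them; values near 1 suggest the chains agree. The numpy sketch below computes the classic split R-hat as an illustration of that statistic. It is not the code diagnose() runs, it omits the rank normalization that current Stan and ArviZ apply, and the chains are synthetic data made up for the demonstration.

```python
import numpy as np


def split_rhat(chains):
    """Classic split R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split every chain into two halves (the middle draw is dropped if the length is odd)
    split = np.concatenate([chains[:, :half], chains[:, -half:]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B = n * split.mean(axis=1).var(ddof=1)      # between-chain variance
    var_hat = (n - 1) / n * W + B / n           # pooled variance estimate
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))  # four chains exploring the same distribution
stuck = mixed.copy()
stuck[0] += 3.0                     # one chain offset from the others
print(split_rhat(mixed), split_rhat(stuck))
```

With one chain shifted away from the others, the between-chain variance dominates and R-hat rises well above 1; commonly cited thresholds for concern range from about 1.01 to 1.1.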
For additional examples covering specific use cases, check the Examples section of the documentation.
API Documentation
Essential Components
The table below links to the documentation for the most commonly used components of the SciStanPy API:
Component | Description |
---|---|
Model | Parent class for all SciStanPy models, handling data management, model building, and inference. |
constants | Describes constant values used in SciStanPy models. |
parameters | Contains objects that describe latent and observed parameters in SciStanPy models as probability distributions. |
operations | Contains operations for transforming parameters and building relationships between them. |
All Submodules
- SciStanPy API Reference
- Custom Types API Reference
- Defaults API Reference
- Exceptions API Reference
- Model SubPackage API Reference
- Submodules
- Model Components API Reference
- Model Components Submodule Overview
- Abstract Model Component API Reference
- AbstractModelComponent
AbstractModelComponent
AbstractModelComponent.BASE_STAN_DTYPE
AbstractModelComponent.FORCE_LOOP_RESET
AbstractModelComponent.FORCE_PARENT_NAME
AbstractModelComponent.IS_LOG_SIMPLEX
AbstractModelComponent.IS_SIMPLEX
AbstractModelComponent.LOG_SIMPLEX_PARAMS
AbstractModelComponent.LOWER_BOUND
AbstractModelComponent.NEGATIVE_PARAMS
AbstractModelComponent.POSITIVE_PARAMS
AbstractModelComponent.SIMPLEX_PARAMS
AbstractModelComponent.UPPER_BOUND
AbstractModelComponent.assign_depth
AbstractModelComponent.children
AbstractModelComponent.constants
AbstractModelComponent.declare_stan_variable()
AbstractModelComponent.draw()
AbstractModelComponent.force_name
AbstractModelComponent.get_assign_depth()
AbstractModelComponent.get_child_paramnames()
AbstractModelComponent.get_index_offset()
AbstractModelComponent.get_indexed_varname()
AbstractModelComponent.get_right_side()
AbstractModelComponent.get_shared_leading()
AbstractModelComponent.get_stan_dtype()
AbstractModelComponent.get_stan_parameter_declaration()
AbstractModelComponent.get_supporting_functions()
AbstractModelComponent.get_target_incrementation()
AbstractModelComponent.get_transformation_assignment()
AbstractModelComponent.is_named
AbstractModelComponent.model_varname
AbstractModelComponent.ndim
AbstractModelComponent.observable
AbstractModelComponent.parents
AbstractModelComponent.shape
AbstractModelComponent.stan_bounds
AbstractModelComponent.stan_model_varname
AbstractModelComponent.torch_parametrization
AbstractModelComponent.walk_tree()
- AbstractModelComponent
- Constants API Reference
- Custom Distributions API Reference
- Custom PyTorch Distributions API Reference
- Custom SciPy Distributions API Reference
- Custom-Built SciPy Distributions
- Transforms of Existing SciPy Distributions
TransformedScipyDist
TransformedScipyDist.cdf()
TransformedScipyDist.inverse_transform()
TransformedScipyDist.isf()
TransformedScipyDist.log_jacobian_correction()
TransformedScipyDist.logpdf()
TransformedScipyDist.logsf()
TransformedScipyDist.pdf()
TransformedScipyDist.ppf()
TransformedScipyDist.rvs()
TransformedScipyDist.sf()
TransformedScipyDist.transform()
LogUnivariateScipyTransform
- Distribution Instances
- Parameters API Reference
- Continuous Univariate
- Continuous Multivariate
- Discrete Univariate
- Discrete Multivariate
- Base Classes
ParameterMeta
Parameter
Parameter.CDF
Parameter.HAS_RAW_VARNAME
Parameter.LOG_CDF
Parameter.LOG_SF
Parameter.SCIPY_DIST
Parameter.SF
Parameter.STAN_DIST
Parameter.STAN_TO_SCIPY_NAMES
Parameter.STAN_TO_SCIPY_TRANSFORMS
Parameter.STAN_TO_TORCH_NAMES
Parameter.TORCH_DIST
Parameter.as_observable()
Parameter.ccdf()
Parameter.cdf()
Parameter.generated_varname
Parameter.get_generated_quantities()
Parameter.get_generated_quantity_declaration()
Parameter.get_raw_stan_parameter_declaration()
Parameter.get_right_side()
Parameter.get_rng()
Parameter.get_target_incrementation()
Parameter.get_torch_logprob()
Parameter.get_transformed_data_declaration()
Parameter.init_pytorch()
Parameter.is_hyperparameter
Parameter.log_ccdf()
Parameter.log_cdf()
Parameter.observable
Parameter.raw_varname
Parameter.torch_dist_instance
Parameter.torch_parametrization
Parameter.write_dist_args()
ContinuousDistribution
DiscreteDistribution
- Univariate Continuous Distributions
Normal
Normal.CDF
Normal.HAS_RAW_VARNAME
Normal.LOG_CDF
Normal.LOG_SF
Normal.POSITIVE_PARAMS
Normal.SCIPY_DIST
Normal.SF
Normal.STAN_DIST
Normal.STAN_TO_SCIPY_NAMES
Normal.STAN_TO_TORCH_NAMES
Normal.TORCH_DIST
Normal.get_right_side()
Normal.get_target_incrementation()
Normal.get_transformation_assignment()
Normal.is_noncentered
HalfNormal
UnitNormal
LogNormal
Beta
Gamma
InverseGamma
Exponential
ExpExponential
Lomax
ExpLomax
- Univariate Discrete Distributions
- Multivariate Continuous Distributions
- Multivariate Discrete Distributions
- Utilities
- Transformations API Reference
- CDFs API Reference
- Transformed Data API Reference
- Transformed Parameters API Reference
- Base Classes
TransformableParameter
TransformableParameter.__add__()
TransformableParameter.__mul__()
TransformableParameter.__neg__()
TransformableParameter.__pow__()
TransformableParameter.__radd__()
TransformableParameter.__rmul__()
TransformableParameter.__rpow__()
TransformableParameter.__rsub__()
TransformableParameter.__rtruediv__()
TransformableParameter.__sub__()
TransformableParameter.__truediv__()
Transformation
TransformedParameter
UnaryTransformedParameter
BinaryTransformedParameter
- Basic Arithmetic Operations
- Standard Mathematical Functions
- Normalization Transformations
- Reduction Operations
- Growth Model Transformations
- Special Functions
- Indexing and Array Operations
- Base Classes
- Abstract Model Component API Reference
- Key Design Principles
- Additional Usage Patterns
- Performance Considerations
- Other Notable Features
- Model Components Submodule Overview
- Model API Reference
- Core Model Class
Model
Model.all_model_components
Model.all_model_components_dict
Model.constant_dict
Model.constants
Model.default_data
Model.draw()
Model.get_dimname_map()
Model.has_default_data
Model.hyperparameter_dict
Model.hyperparameters
Model.mcmc()
Model.mle()
Model.named_model_components
Model.named_model_components_dict
Model.observable_dict
Model.observables
Model.parameter_dict
Model.parameters
Model.prior_predictive()
Model.simulate_mcmc()
Model.simulate_mle()
Model.to_pytorch()
Model.to_stan()
Model.transformed_parameter_dict
Model.transformed_parameters
- Utility Methods
- Core Model Class
- Neural Network Module API Reference
- Model Results API Reference
- Maximum Likelihood Estimation Results API Reference
- Hamiltonian Monte Carlo Results API Reference
- Sample Results Analysis
SampleResults
SampleResults.calculate_diagnostics()
SampleResults.calculate_summaries()
SampleResults.diagnose()
SampleResults.evaluate_sample_stats()
SampleResults.evaluate_variable_diagnostic_stats()
SampleResults.from_disk()
SampleResults.identify_failed_diagnostics()
SampleResults.plot_sample_failure_quantile_traces()
SampleResults.plot_variable_failure_quantile_traces()
- Variable Failure Analyzer
- CSV to NetCDF Conversion
- Utility Functions
- Sample Results Analysis
- Stan Submodule API Reference
- Model Components API Reference
- Submodules
- Plotting Subpackage API Reference
- Plotting API Reference
- Prior Predictive
PriorPredictiveCheck
PriorPredictiveCheck.display()
PriorPredictiveCheck.get_ecdf_kwargs()
PriorPredictiveCheck.get_kde_kwargs()
PriorPredictiveCheck.get_relationship_kwargs()
PriorPredictiveCheck.get_violin_kwargs()
PriorPredictiveCheck.set_group_dim_options()
PriorPredictiveCheck.set_independent_var_options()
PriorPredictiveCheck.set_plot_type_options()
- Operations API Reference
- Utils API Reference
- Examples
Copyright and License
Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com.
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft’s Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.