# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""PyTorch integration utilities for SciStanPy models.
This module provides integration between SciStanPy probabilistic models
and PyTorch's automatic differentiation and optimization framework. It enables
maximum likelihood estimation, variational inference, and other gradient-based
learning procedures on SciStanPy models.
The module's core functionality centers on converting SciStanPy models into
PyTorch ``nn.Module`` instances that preserve the probabilistic structure while
enabling efficient gradient computation and optimization. This allows users to
leverage PyTorch's ecosystem of optimizers, learning rate schedulers, and other
training utilities.
Key Features:
- Automatic conversion of SciStanPy models to PyTorch modules
- Gradient-based parameter optimization with various optimizers
- Mixed precision training support for improved performance
- Early stopping and convergence monitoring
- GPU acceleration and device management
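Example:
A minimal end-to-end sketch (assuming an existing SciStanPy model ``model``
with a single observable named ``y`` and matching observed data ``observed_y``):
>>> pytorch_model = model.to_pytorch(seed=0)
>>> losses = pytorch_model.fit(data={'y': observed_y}, lr=0.01)
>>> estimates = pytorch_model.export_params()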
"""
import warnings
from typing import Optional, TYPE_CHECKING, Union
import itertools
import numpy.typing as npt
import torch
import torch.nn as nn
from tqdm import tqdm
from scistanpy.defaults import DEFAULT_EARLY_STOP, DEFAULT_LR, DEFAULT_N_EPOCHS
from scistanpy.model.components import constants, parameters
if TYPE_CHECKING:
from scistanpy import custom_types
from scistanpy import model as ssp_model
def check_observable_data(model: "ssp_model.Model", data: dict[str, torch.Tensor]):
"""Validate that provided data matches model observable specifications.
This function validates that the observed data dictionary contains exactly
the expected observables and that each entry has the expected shape. It
prevents common errors during model fitting by catching data mismatches early.
:param model: SciStanPy model containing observable specifications
:type model: ssp_model.Model
:param data: Dictionary mapping observable names to their tensor data
:type data: dict[str, torch.Tensor]
:raises ValueError: If observable names don't match expected set
:raises ValueError: If data shapes don't match observable shapes
The validation checks:
- Perfect correspondence between provided and expected observable names
- Exact shape matching between data tensors and observable specifications
Example:
>>> data = {'y': torch.randn(100), 'x': torch.randn(100, 5)}
>>> check_observable_data(model, data) # Validates or raises error
"""
# There must be perfect overlap between the keys of the provided data and the
# expected observations
expected_set = set(model.observable_dict.keys())
provided_set = set(data.keys())
missing = expected_set - provided_set
extra = provided_set - expected_set
# If there are missing or extra, raise an error
if missing:
raise ValueError(
"The provided data must match the observable distribution names."
f"The following observables are missing: {', '.join(missing)}"
)
if extra:
raise ValueError(
"The provided data must match the observable distribution names. The "
"following observables were provided in addition to the expected: "
f"{', '.join(extra)}"
)
# Shapes must match
for name, param in model.observable_dict.items():
if data[name].shape != param.shape:
raise ValueError(
f"The shape of the provided data for observable {name} does not match "
f"the expected shape. Expected: {param.shape}, provided: "
f"{data[name].shape}"
)
class PyTorchModel(nn.Module):
"""PyTorch-trainable version of a SciStanPy Model.
This class converts SciStanPy probabilistic models into PyTorch nn.Module
instances that can be optimized using standard PyTorch training procedures.
It preserves the probabilistic structure while enabling gradient-based
parameter estimation and other machine learning techniques.
:param model: SciStanPy model to convert to PyTorch
:type model: ssp_model.Model
:param seed: Random seed for reproducible parameter initialization. Defaults to None.
:type seed: Optional[custom_types.Integer]
:ivar model: Reference to the original SciStanPy model
:ivar learnable_params: PyTorch ParameterList containing optimizable parameters
The conversion process:
- Initializes all model parameters for PyTorch optimization
- Sets up proper gradient computation graphs
- Configures device placement and memory management
- Preserves probabilistic model structure and relationships
The resulting PyTorch model can be treated like any other nn.Module.
Example:
>>> pytorch_model = model.to_pytorch(seed=42)
>>> optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=0.01)
>>> loss = -pytorch_model(**observed_data)
>>> loss.backward()
>>> optimizer.step()
.. note::
This class should not be instantiated directly. Instead, use the
``to_pytorch()`` method on a SciStanPy Model instance.
"""
def __init__(
self, model: "ssp_model.Model", seed: Optional["custom_types.Integer"] = None
):
"""Initialize the PyTorch model from a SciStanPy model."""
super().__init__()
# Record the model
self.model = model
# Initialize all parameters for pytorch optimization
learnable_params = []
for param_num, param in enumerate(self.model.parameters):
param.init_pytorch(seed=None if seed is None else seed + param_num)
learnable_params.append(param._torch_parametrization)
# Record learnable parameters such that they can be recognized by PyTorch
self.learnable_params = nn.ParameterList(learnable_params)
def forward(self, **data: torch.Tensor) -> torch.Tensor:
"""Compute log probability of observed data given current parameters.
This method calculates the total log probability (log-likelihood) of
the observed data under the current model parameters. It forms the
core objective function for maximum likelihood estimation and other
gradient-based inference procedures.
:param data: Observed data tensors keyed by observable parameter names
:type data: dict[str, torch.Tensor]
:returns: Total log probability of the observed data
:rtype: torch.Tensor
.. important::
This returns log probability, *not* log loss (negative log probability).
For optimization, negate the result to get the loss function.
Example:
>>> log_prob = pytorch_model(y=observed_y, x=observed_x)
>>> loss = -log_prob # Negative for minimization
>>> loss.backward()
"""
# Sum the log-probs of the observables and parameters
log_prob = 0.0
for name, param in itertools.chain(
self.model.parameter_dict.items(), self.model.observable_dict.items()
):
# Calculate the log probability of the observed data given the parameters
temp_log_prob = param.get_torch_logprob(observed=data.get(name))
# The log probability should be 0-dimensional for anything except a Multinomial
assert temp_log_prob.ndim == 0 or isinstance(param, parameters.Multinomial)
# Add to the total log probability
log_prob += temp_log_prob.sum()
return log_prob
def fit(
self,
*,
epochs: "custom_types.Integer" = DEFAULT_N_EPOCHS,
early_stop: "custom_types.Integer" = DEFAULT_EARLY_STOP,
lr: "custom_types.Float" = DEFAULT_LR,
data: dict[
str,
Union[
torch.Tensor, npt.NDArray, "custom_types.Float", "custom_types.Integer"
],
],
mixed_precision: bool = False,
) -> torch.Tensor:
"""Optimize model parameters using gradient-based maximum likelihood estimation.
This method performs complete model training using the Adam optimizer
with configurable early stopping, learning rate, and mixed precision
support. It automatically handles device placement, gradient computation,
and convergence monitoring.
:param epochs: Maximum number of training epochs. Defaults to 100000.
:type epochs: custom_types.Integer
:param early_stop: Epochs without improvement before stopping. Defaults to 10.
:type early_stop: custom_types.Integer
:param lr: Learning rate for Adam optimizer. Defaults to 0.001.
:type lr: custom_types.Float
:param data: Observed data for model observables
:type data: dict[str, Union[torch.Tensor, npt.NDArray, custom_types.Float,
custom_types.Integer]]
:param mixed_precision: Whether to use automatic mixed precision. Defaults to False.
:type mixed_precision: bool
:returns: Tensor containing loss trajectory throughout training
:rtype: torch.Tensor
:raises UserWarning: If early stopping is not triggered within the epoch limit
The training loop:
1. Converts input data to appropriate tensor format
2. Validates data compatibility with model observables
3. Iteratively optimizes parameters using gradient descent
4. Monitors convergence and applies early stopping
5. Returns complete loss trajectory for analysis
Example:
>>> loss_history = pytorch_model.fit(
... data={'y': observed_data},
... epochs=5000,
... lr=0.01,
... early_stop=50,
... mixed_precision=True
... )
>>> final_loss = loss_history[-1]
"""
# Any observed data that is not a tensor is converted to a tensor
data = {
k: torch.tensor(v) if not isinstance(v, torch.Tensor) else v
for k, v in data.items()
}
# Note the device
device = self.learnable_params[0].device
# Check the observed data
check_observable_data(self.model, data)
# Switch to train mode. This should be a no-op.
self.train()
# Build the optimizer
optim = torch.optim.Adam(self.parameters(), lr=lr)
# If using mixed precision, we also need a scaler
if mixed_precision:
scaler = torch.amp.GradScaler()
# Set up for optimization
best_loss = float("inf") # Records the best loss
loss_trajectory = [None] * (epochs + 1) # Records all losses
n_without_improvement = 0 # Epochs without improvement
# Run optimization
with tqdm(total=epochs, desc="Epochs", postfix={"-log pdf/pmf": "N/A"}) as pbar:
for epoch in range(epochs):
# Get the loss
with torch.autocast(device_type=device.type, enabled=mixed_precision):
log_loss = -1 * self(**data)
# Step the optimizer
optim.zero_grad()
if mixed_precision:
scaler.scale(log_loss).backward()
scaler.step(optim)
scaler.update()
else:
log_loss.backward()
optim.step()
# Record loss
log_loss = log_loss.item()
loss_trajectory[epoch] = log_loss
# Update best loss
if log_loss < best_loss:
n_without_improvement = 0
best_loss = log_loss
else:
n_without_improvement += 1
# Update progress bar
pbar.update(1)
pbar.set_postfix({"-log pdf/pmf": f"{log_loss:.2f}"})
# Check for early stopping
if early_stop > 0 and n_without_improvement >= early_stop:
break
# If the loop completes without breaking, early stopping was never triggered
else:
if early_stop > 0:
warnings.warn("Early stopping not triggered.")
# Back to eval mode
self.eval()
# Get a final loss
with torch.no_grad():
loss_trajectory[epoch + 1] = -1 * self(**data).item()
# Trim off the None values of the loss trajectory and convert to a tensor
return torch.tensor(loss_trajectory[: epoch + 2], dtype=torch.float32)
def export_params(self) -> dict[str, torch.Tensor]:
"""Export optimized parameter values from the fitted model.
This method extracts the current parameter values after optimization,
providing access to the maximum likelihood estimates or other fitted
parameter values. It excludes observable parameters (which represent
data) and focuses on the learnable model parameters.
:returns: Dictionary mapping parameter names to their current tensor values
:rtype: dict[str, torch.Tensor]
Excluded from export:
- Observable parameters (representing data, not learnable parameters)
- Unnamed parameters
- Intermediate computational results from transformations
This is typically used after model fitting to extract the estimated
parameter values for further analysis or model comparison.
Example:
>>> fitted_params = pytorch_model.export_params()
>>> mu_estimate = fitted_params['mu']
>>> sigma_estimate = fitted_params['sigma']
"""
return {
name: param.torch_parametrization
for name, param in self.model.parameter_dict.items()
}
def export_distributions(self) -> dict[str, torch.distributions.Distribution]:
"""Export fitted probability distributions for all model components.
This method returns the complete set of probability distributions
from the fitted model, including both parameter distributions (priors)
and observable distributions (likelihoods) with their current
parameter values.
:returns: Dictionary mapping component names to their distribution objects
:rtype: dict[str, torch.distributions.Distribution]
The exported distributions include:
- Parameter distributions with updated hyperparameter values
- Observable distributions with fitted parameter values
- All distributions in their PyTorch format for further computation
This is useful for:
- Posterior predictive sampling
- Model diagnostics and validation
- Uncertainty quantification
- Distribution comparison and analysis
Example:
>>> distributions = pytorch_model.export_distributions()
>>> fitted_normal = distributions['mu'] # torch.distributions.Normal
>>> samples = fitted_normal.sample((1000,)) # Sample from fitted distribution
"""
return {
name: param.torch_dist_instance
for name, param in itertools.chain(
self.model.parameter_dict.items(), self.model.observable_dict.items()
)
}
def _move_model(self, funcname: str, *args, **kwargs):
"""Internal method for device placement operations.
This method handles the task of moving both PyTorch parameters
and SciStanPy constant tensors to different devices or data types.
It ensures that all model components remain synchronized during
device transfers.
:param funcname: Name of the PyTorch method to apply ('cuda', 'cpu', 'to')
:type funcname: str
:param args: Positional arguments for the device operation
:param kwargs: Keyword arguments for the device operation
:returns: Self reference for method chaining
:rtype: PyTorchModel
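Example:
An illustrative call (this method is normally invoked indirectly through
``cuda()``, ``cpu()``, and ``to()`` rather than called directly):
>>> pytorch_model._move_model('to', torch.float64)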
"""
# Apply to the model
getattr(super(), funcname)(*args, **kwargs)
# Apply to additional torch tensors in the model (i.e., the ones that are
# constants and not parameters)
# pylint: disable=protected-access
for constant in filter(
lambda x: isinstance(x, constants.Constant),
self.model.all_model_components,
):
constant._torch_parametrization = getattr(
constant._torch_parametrization, funcname
)(*args, **kwargs)
return self
def cuda(self, *args, **kwargs):
"""Move model to CUDA device.
This method transfers the entire model (including SciStanPy constants)
to a CUDA-enabled GPU device for accelerated computation.
:param args: Arguments passed to torch.nn.Module.cuda()
:param kwargs: Keyword arguments passed to torch.nn.Module.cuda()
:returns: Self reference for method chaining
:rtype: PyTorchModel
Example:
>>> pytorch_model = pytorch_model.cuda() # Move to default GPU
>>> pytorch_model = pytorch_model.cuda(1) # Move to GPU 1
"""
return self._move_model("cuda", *args, **kwargs)
def cpu(self, *args, **kwargs):
"""Move model to CPU device.
This method transfers the entire model (including SciStanPy constants)
to CPU memory, which is useful for inference or when GPU memory is limited.
:param args: Arguments passed to torch.nn.Module.cpu()
:param kwargs: Keyword arguments passed to torch.nn.Module.cpu()
:returns: Self reference for method chaining
:rtype: PyTorchModel
Example:
>>> pytorch_model = pytorch_model.cpu() # Move to CPU
"""
return self._move_model("cpu", *args, **kwargs)
def to(self, *args, **kwargs):
"""Move model to specified device or data type.
This method provides flexible device and dtype conversion for the
entire model, including both PyTorch parameters and SciStanPy
constant tensors.
:param args: Arguments passed to torch.nn.Module.to()
:param kwargs: Keyword arguments passed to torch.nn.Module.to()
:returns: Self reference for method chaining
:rtype: PyTorchModel
Example:
>>> pytorch_model = pytorch_model.to('cuda:0') # Move to specific GPU
>>> pytorch_model = pytorch_model.to(torch.float64) # Change precision
>>> pytorch_model = pytorch_model.to('cpu', dtype=torch.float32)
"""
return self._move_model("to", *args, **kwargs)