Neural Network Module API Reference¶
PyTorch integration utilities for SciStanPy models.
This module provides integration between SciStanPy probabilistic models and PyTorch’s automatic differentiation and optimization framework. It enables maximum likelihood estimation, variational inference, and other gradient-based learning procedures on SciStanPy models.
The module’s core functionality centers on converting SciStanPy models into PyTorch nn.Module instances that preserve the probabilistic structure while enabling efficient gradient computation and optimization. This allows users to leverage PyTorch’s ecosystem of optimizers, learning rate schedulers, and other training utilities.
- Key Features:
Automatic conversion of SciStanPy models to PyTorch modules
Gradient-based parameter optimization with various optimizers
Mixed precision training support for improved performance
Early stopping and convergence monitoring
GPU acceleration and device management
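A rough end-to-end sketch of the intended workflow (the observable name 'y', the variable observed_y, and the settings below are illustrative placeholders rather than part of the API):
>>> pytorch_model = model.to_pytorch(seed=0)  # SciStanPy Model -> torch.nn.Module
>>> loss_history = pytorch_model.fit(data={'y': observed_y}, lr=0.01, early_stop=25)
>>> estimates = pytorch_model.export_params()  # fitted parameter values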
Core PyTorch Integration¶
- class scistanpy.model.nn_module.PyTorchModel(*args: Any, **kwargs: Any)[source]¶
Bases: torch.nn.Module
PyTorch-trainable version of a SciStanPy Model.
This class converts SciStanPy probabilistic models into PyTorch nn.Module instances that can be optimized using standard PyTorch training procedures. It preserves the probabilistic structure while enabling gradient-based parameter estimation and other machine learning techniques.
- Parameters:
model (ssp_model.Model) – SciStanPy model to convert to PyTorch
seed (Optional[custom_types.Integer]) – Random seed for reproducible parameter initialization. Defaults to None.
- Variables:
model – Reference to the original SciStanPy model
learnable_params – PyTorch ParameterList containing optimizable parameters
- The conversion process:
Initializes all model parameters for PyTorch optimization
Sets up proper gradient computation graphs
Configures device placement and memory management
Preserves probabilistic model structure and relationships
The resulting PyTorch model can be treated like any other nn.Module.
- Example:
>>> pytorch_model = model.to_pytorch(seed=42)
>>> optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=0.01)
>>> loss = -pytorch_model(**observed_data)
>>> loss.backward()
>>> optimizer.step()
Note
This class should not be instantiated directly. Instead, use the to_pytorch() method on a SciStanPy Model instance.
- cpu(*args, **kwargs)[source]¶
Move model to CPU device.
This method transfers the entire model (including SciStanPy constants) to CPU memory, which is useful for inference or when GPU memory is limited.
- Parameters:
args – Arguments passed to torch.nn.Module.cpu()
kwargs – Keyword arguments passed to torch.nn.Module.cpu()
- Returns:
Self reference for method chaining
- Return type:
PyTorchModel
- Example:
>>> pytorch_model = pytorch_model.cpu() # Move to CPU
- cuda(*args, **kwargs)[source]¶
Move model to CUDA device.
This method transfers the entire model (including SciStanPy constants) to a CUDA-enabled GPU device for accelerated computation.
- Parameters:
args – Arguments passed to torch.nn.Module.cuda()
kwargs – Keyword arguments passed to torch.nn.Module.cuda()
- Returns:
Self reference for method chaining
- Return type:
PyTorchModel
- Example:
>>> pytorch_model = pytorch_model.cuda()  # Move to default GPU
>>> pytorch_model = pytorch_model.cuda(1)  # Move to GPU 1
- export_distributions() → dict[str, torch.distributions.Distribution][source]¶
Export fitted probability distributions for all model components.
This method returns the complete set of probability distributions from the fitted model, including both parameter distributions (priors) and observable distributions (likelihoods) with their current parameter values.
- Returns:
Dictionary mapping component names to their distribution objects
- Return type:
dict[str, torch.distributions.Distribution]
- The exported distributions include:
Parameter distributions with updated hyperparameter values
Observable distributions with fitted parameter values
All distributions in their PyTorch format for further computation
- This is useful for:
Posterior predictive sampling
Model diagnostics and validation
Uncertainty quantification
Distribution comparison and analysis
- Example:
>>> distributions = pytorch_model.export_distributions()
>>> fitted_normal = distributions['mu']  # torch.distributions.Normal
>>> samples = fitted_normal.sample((1000,))  # Sample from fitted distribution
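Draws can also be taken from an observable's fitted distribution to approximate posterior predictive samples (the observable name 'y' below is a placeholder):
>>> y_dist = distributions['y']  # fitted observable (likelihood) distribution
>>> y_rep = y_dist.sample((500,))  # predictive-style draws at the fitted parameters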
- export_params() → dict[str, torch.Tensor][source]¶
Export optimized parameter values from the fitted model.
This method extracts the current parameter values after optimization, providing access to the maximum likelihood estimates or other fitted parameter values. It excludes observable parameters (which represent data) and focuses on the learnable model parameters.
- Returns:
Dictionary mapping parameter names to their current tensor values
- Return type:
dict[str, torch.Tensor]
- Excluded from export:
Observable parameters (representing data, not learnable parameters)
Unnamed parameters
Intermediate computational results from transformations
This is typically used after model fitting to extract the estimated parameter values for further analysis or model comparison.
- Example:
>>> fitted_params = pytorch_model.export_params()
>>> mu_estimate = fitted_params['mu']
>>> sigma_estimate = fitted_params['sigma']
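Because the exported values are ordinary tensors, they can be detached and converted for downstream analysis, for example:
>>> mu_array = fitted_params['mu'].detach().cpu().numpy()  # NumPy copy for plotting or comparison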
- fit(*, epochs: custom_types.Integer = 100000, early_stop: custom_types.Integer = 10, lr: custom_types.Float = 0.001, data: dict[str, torch.Tensor | ndarray[tuple[int, ...], dtype[_ScalarType_co]] | custom_types.Float | custom_types.Integer], mixed_precision: bool = False) → torch.Tensor[source]¶
Optimize model parameters using gradient-based maximum likelihood estimation.
This method performs complete model training using the Adam optimizer with configurable early stopping, learning rate, and mixed precision support. It automatically handles device placement, gradient computation, and convergence monitoring.
- Parameters:
epochs (custom_types.Integer) – Maximum number of training epochs. Defaults to 100000.
early_stop (custom_types.Integer) – Epochs without improvement before stopping. Defaults to 10.
lr (custom_types.Float) – Learning rate for Adam optimizer. Defaults to 0.001.
data (dict[str, Union[torch.Tensor, npt.NDArray, custom_types.Float, custom_types.Integer]]) – Observed data for model observables
mixed_precision (bool) – Whether to use automatic mixed precision. Defaults to False.
- Returns:
Tensor containing loss trajectory throughout training
- Return type:
torch.Tensor
- Raises:
UserWarning – If early stopping is not triggered within the epoch limit
- The training loop:
Converts input data to appropriate tensor format
Validates data compatibility with model observables
Iteratively optimizes parameters using gradient descent
Monitors convergence and applies early stopping
Returns complete loss trajectory for analysis
- Example:
>>> loss_history = pytorch_model.fit(
...     data={'y': observed_data},
...     epochs=5000,
...     lr=0.01,
...     early_stop=50,
...     mixed_precision=True
... )
>>> final_loss = loss_history[-1]
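For intuition, the following is a simplified sketch of the kind of loop fit() performs internally (the exact implementation may differ; the observable name 'y' and the tolerances are placeholders):
>>> optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=0.01)
>>> best, stall, losses = float('inf'), 0, []
>>> for epoch in range(5000):
...     optimizer.zero_grad()
...     loss = -pytorch_model(y=observed_data)  # negative log probability
...     loss.backward()
...     optimizer.step()
...     losses.append(loss.item())
...     if loss.item() < best:  # improvement resets the early-stopping counter
...         best, stall = loss.item(), 0
...     else:
...         stall += 1
...         if stall >= 50:  # stop after 50 epochs without improvement
...             break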
- forward(**data: torch.Tensor) → torch.Tensor[source]¶
Compute log probability of observed data given current parameters.
This method calculates the total log probability (log-likelihood) of the observed data under the current model parameters. It forms the core objective function for maximum likelihood estimation and other gradient-based inference procedures.
- Parameters:
data (dict[str, torch.Tensor]) – Observed data tensors keyed by observable parameter names
- Returns:
Total log probability of the observed data
- Return type:
torch.Tensor
Important
This returns log probability, not log loss (negative log probability). For optimization, negate the result to get the loss function.
- Example:
>>> log_prob = pytorch_model(y=observed_y, x=observed_x)
>>> loss = -log_prob  # Negative for minimization
>>> loss.backward()
- to(*args, **kwargs)[source]¶
Move model to specified device or data type.
This method provides flexible device and dtype conversion for the entire model, including both PyTorch parameters and SciStanPy constant tensors.
- Parameters:
args – Arguments passed to torch.nn.Module.to()
kwargs – Keyword arguments passed to torch.nn.Module.to()
- Returns:
Self reference for method chaining
- Return type:
PyTorchModel
- Example:
>>> pytorch_model = pytorch_model.to('cuda:0')  # Move to specific GPU
>>> pytorch_model = pytorch_model.to(torch.float64)  # Change precision
>>> pytorch_model = pytorch_model.to('cpu', dtype=torch.float32)