Neural Network Module API Reference

PyTorch integration utilities for SciStanPy models.

This module provides integration between SciStanPy probabilistic models and PyTorch’s automatic differentiation and optimization framework. It enables maximum likelihood estimation, variational inference, and other gradient-based learning procedures on SciStanPy models.

The module’s core functionality centers on converting SciStanPy models into PyTorch nn.Module instances that preserve the probabilistic structure while enabling efficient gradient computation and optimization. This allows users to leverage PyTorch’s ecosystem of optimizers, learning rate schedulers, and other training utilities.

Key Features:
  • Automatic conversion of SciStanPy models to PyTorch modules

  • Gradient-based parameter optimization with various optimizers

  • Mixed precision training support for improved performance

  • Early stopping and convergence monitoring

  • GPU acceleration and device management
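
A minimal end-to-end sketch tying these features together (model, the observable name 'y', and the array y_obs are illustrative placeholders, not part of the API):
>>> pytorch_model = model.to_pytorch(seed=0)
>>> loss_history = pytorch_model.fit(data={'y': y_obs}, lr=0.01, early_stop=25)
>>> estimates = pytorch_model.export_params()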

Core PyTorch Integration

class scistanpy.model.nn_module.PyTorchModel(*args: Any, **kwargs: Any)

Bases: Module

PyTorch-trainable version of a SciStanPy Model.

This class converts SciStanPy probabilistic models into PyTorch nn.Module instances that can be optimized using standard PyTorch training procedures. It preserves the probabilistic structure while enabling gradient-based parameter estimation and other machine learning techniques.

Parameters:
  • model (ssp_model.Model) – SciStanPy model to convert to PyTorch

  • seed (Optional[custom_types.Integer]) – Random seed for reproducible parameter initialization. Defaults to None.

Variables:
  • model – Reference to the original SciStanPy model

  • learnable_params – PyTorch ParameterList containing optimizable parameters

The conversion process:
  • Initializes all model parameters for PyTorch optimization

  • Sets up proper gradient computation graphs

  • Configures device placement and memory management

  • Preserves probabilistic model structure and relationships

The resulting PyTorch model can be treated like any other nn.Module.

Example:
>>> pytorch_model = model.to_pytorch(seed=42)
>>> optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=0.01)
>>> for _ in range(100):  # manual optimization loop
...     optimizer.zero_grad()  # clear gradients from the previous step
...     loss = -pytorch_model(**observed_data)  # negative log probability
...     loss.backward()
...     optimizer.step()

Note

This class should not be instantiated directly. Instead, use the to_pytorch() method on a SciStanPy Model instance.

cpu(*args, **kwargs)

Move model to CPU device.

This method transfers the entire model (including SciStanPy constants) to CPU memory, which is useful for inference or when GPU memory is limited.

Parameters:
  • args – Arguments passed to torch.nn.Module.cpu()

  • kwargs – Keyword arguments passed to torch.nn.Module.cpu()

Returns:

Self reference for method chaining

Return type:

PyTorchModel

Example:
>>> pytorch_model = pytorch_model.cpu()  # Move to CPU

cuda(*args, **kwargs)

Move model to CUDA device.

This method transfers the entire model (including SciStanPy constants) to a CUDA-enabled GPU device for accelerated computation.

Parameters:
  • args – Arguments passed to torch.nn.Module.cuda()

  • kwargs – Keyword arguments passed to torch.nn.Module.cuda()

Returns:

Self reference for method chaining

Return type:

PyTorchModel

Example:
>>> pytorch_model = pytorch_model.cuda()  # Move to default GPU
>>> pytorch_model = pytorch_model.cuda(1)  # Move to GPU 1
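
When the same script must also run on machines without a GPU, a guarded move keeps it portable (a sketch; the device check is standard PyTorch, not part of this API):
>>> import torch
>>> if torch.cuda.is_available():
...     pytorch_model = pytorch_model.cuda()
... else:
...     pytorch_model = pytorch_model.cpu()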

export_distributions() → dict[str, torch.distributions.Distribution]

Export fitted probability distributions for all model components.

This method returns the complete set of probability distributions from the fitted model, including both parameter distributions (priors) and observable distributions (likelihoods) with their current parameter values.

Returns:

Dictionary mapping component names to their distribution objects

Return type:

dict[str, torch.distributions.Distribution]

The exported distributions include:
  • Parameter distributions with updated hyperparameter values

  • Observable distributions with fitted parameter values

  • All distributions in their PyTorch format for further computation

This is useful for:
  • Posterior predictive sampling

  • Model diagnostics and validation

  • Uncertainty quantification

  • Distribution comparison and analysis

Example:
>>> distributions = pytorch_model.export_distributions()
>>> fitted_normal = distributions['mu']  # torch.distributions.Normal
>>> samples = fitted_normal.sample((1000,))  # Sample from fitted distribution
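
The exported objects are standard torch.distributions instances, so their full API is available. For example, scoring held-out data under the fitted likelihood (the observable name 'y' and the tensor y_holdout are illustrative assumptions, not part of the API):
>>> fitted_likelihood = pytorch_model.export_distributions()['y']
>>> held_out_log_prob = fitted_likelihood.log_prob(y_holdout).sum()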

export_params() → dict[str, torch.Tensor]

Export optimized parameter values from the fitted model.

This method extracts the current parameter values after optimization, providing access to the maximum likelihood estimates or other fitted parameter values. It excludes observable parameters (which represent data) and focuses on the learnable model parameters.

Returns:

Dictionary mapping parameter names to their current tensor values

Return type:

dict[str, torch.Tensor]

Excluded from export:
  • Observable parameters (representing data, not learnable parameters)

  • Unnamed parameters

  • Intermediate computational results from transformations

This is typically used after model fitting to extract the estimated parameter values for further analysis or model comparison.

Example:
>>> fitted_params = pytorch_model.export_params()
>>> mu_estimate = fitted_params['mu']
>>> sigma_estimate = fitted_params['sigma']
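
The exported values are torch tensors that may still live on a GPU and carry gradient history, so a common next step is converting them to NumPy arrays for downstream analysis (a sketch, not part of this API):
>>> numpy_params = {
...     name: value.detach().cpu().numpy()
...     for name, value in fitted_params.items()
... }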

fit(
*,
epochs: custom_types.Integer = 100000,
early_stop: custom_types.Integer = 10,
lr: custom_types.Float = 0.001,
data: dict[str, torch.Tensor | npt.NDArray | custom_types.Float | custom_types.Integer],
mixed_precision: bool = False,
) → torch.Tensor

Optimize model parameters using gradient-based maximum likelihood estimation.

This method performs complete model training using the Adam optimizer with configurable early stopping, learning rate, and mixed precision support. It automatically handles device placement, gradient computation, and convergence monitoring.

Parameters:
  • epochs (custom_types.Integer) – Maximum number of training epochs. Defaults to 100000.

  • early_stop (custom_types.Integer) – Epochs without improvement before stopping. Defaults to 10.

  • lr (custom_types.Float) – Learning rate for Adam optimizer. Defaults to 0.001.

  • data (dict[str, Union[torch.Tensor, npt.NDArray, custom_types.Float, custom_types.Integer]]) – Observed data for model observables

  • mixed_precision (bool) – Whether to use automatic mixed precision. Defaults to False.

Returns:

Tensor containing loss trajectory throughout training

Return type:

torch.Tensor

Raises:

UserWarning – If early stopping is not triggered within the epoch limit

The training loop:
  1. Converts input data to appropriate tensor format

  2. Validates data compatibility with model observables

  3. Iteratively optimizes parameters using gradient descent

  4. Monitors convergence and applies early stopping

  5. Returns complete loss trajectory for analysis

Example:
>>> loss_history = pytorch_model.fit(
...     data={'y': observed_data},
...     epochs=5000,
...     lr=0.01,
...     early_stop=50,
...     mixed_precision=True
... )
>>> final_loss = loss_history[-1]
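
The returned trajectory can be inspected to verify convergence, for example with matplotlib (a sketch, assuming the trajectory is a one-dimensional tensor of per-epoch losses):
>>> import matplotlib.pyplot as plt
>>> plt.plot(loss_history.detach().cpu().numpy())
>>> plt.xlabel('epoch')
>>> plt.ylabel('loss')
>>> plt.show()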

forward(**data: torch.Tensor) → torch.Tensor

Compute log probability of observed data given current parameters.

This method calculates the total log probability (log-likelihood) of the observed data under the current model parameters. It forms the core objective function for maximum likelihood estimation and other gradient-based inference procedures.

Parameters:

data (dict[str, torch.Tensor]) – Observed data tensors keyed by observable parameter names

Returns:

Total log probability of the observed data

Return type:

torch.Tensor

Important

This returns log probability, not log loss (negative log probability). For optimization, negate the result to get the loss function.

Example:
>>> log_prob = pytorch_model(y=observed_y, x=observed_x)
>>> loss = -log_prob  # Negative for minimization
>>> loss.backward()

to(*args, **kwargs)

Move model to specified device or data type.

This method provides flexible device and dtype conversion for the entire model, including both PyTorch parameters and SciStanPy constant tensors.

Parameters:
  • args – Arguments passed to torch.nn.Module.to()

  • kwargs – Keyword arguments passed to torch.nn.Module.to()

Returns:

Self reference for method chaining

Return type:

PyTorchModel

Example:
>>> pytorch_model = pytorch_model.to('cuda:0')  # Move to specific GPU
>>> pytorch_model = pytorch_model.to(torch.float64)  # Change precision
>>> pytorch_model = pytorch_model.to('cpu', dtype=torch.float32)