Maximum Likelihood Estimation Results API Reference

Maximum likelihood estimation results analysis and visualization for SciStanPy models.

This module provides analysis tools for maximum likelihood estimation results from SciStanPy models. It offers diagnostic plots, calibration checks, and posterior predictive analysis tools designed specifically for MLE-based inference workflows.

The module centers around three main classes: MLEParam for individual parameter estimates, MLE for complete model results, and MLEInferenceRes, which wraps ArviZ InferenceData objects with specialized methods for analyzing MLE results. Together, these classes expose the estimated parameter values and the fitted probability distributions produced by MLE analysis, and they support downstream analysis including uncertainty quantification and posterior predictive sampling. The module provides both individual diagnostic tools and analysis workflows that combine multiple checks into unified reporting interfaces.

Key Features:
  • Individual parameter MLE estimates with associated distributions

  • Complete model MLE results with loss tracking and diagnostics

  • Posterior predictive checking workflows

  • Model calibration analysis with quantitative metrics

  • Interactive visualization with customizable display options

  • Integration with ArviZ for standardized Bayesian workflows

  • Memory-efficient handling of large posterior predictive samples

  • Flexible output formats for different analysis needs

Visualization Capabilities:
  • Posterior predictive sample plotting with confidence intervals

  • Calibration plots with deviation metrics

  • Quantile-quantile plots for model validation

  • Interactive layouts with customizable dimensions

Performance Considerations:
  • Batch sampling prevents memory overflow for large sample requests

  • GPU acceleration is preserved through PyTorch distribution objects

The module is designed to work with SciStanPy’s MLE estimation workflow, providing immediate access to model diagnostics and validation tools once MLE fitting is complete. The MLE results can be used for various purposes including model comparison, uncertainty quantification, and as initialization for more sophisticated inference procedures like MCMC sampling.
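
A typical end-to-end workflow looks like the following sketch, which assumes a SciStanPy model and its observed data are already defined (variable names are illustrative):

# Fit by maximum likelihood and inspect convergence
mle_result = model.mle(data=observed_data)
loss_plot = mle_result.plot_loss_curve(logy=True)

# Bootstrap samples from the fitted distributions for downstream analysis
mle_analysis = mle_result.get_inference_obj(n=2000, seed=0)

# Run the combined posterior predictive checking dashboard
dashboard = mle_analysis.run_ppc()

# Persist results for later sessions
mle_analysis.save_netcdf('mle_results.nc')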

In combination, the classes documented below provide a comprehensive framework for analyzing maximum likelihood estimation results.

Maximum Likelihood Estimation Results

class scistanpy.model.results.mle.MLE(
model: ssp_model.Model,
mle_estimate: dict[str, npt.NDArray],
distributions: dict[str, torch.distributions.Distribution],
losses: npt.NDArray,
data: dict[str, npt.NDArray],
)[source]

Bases: object

Complete maximum likelihood estimation results for a SciStanPy model.

This class encapsulates the full results of a call to Model.mle(), including parameter estimates, fitted distributions, optimization diagnostics, and utilities for further analysis. It provides a comprehensive interface for working with MLE results.

Parameters:
  • model (ssp_model.Model) – Original SciStanPy model

  • mle_estimate (dict[str, npt.NDArray]) – Dictionary of parameter names to their MLE values

  • distributions (dict[str, torch.distributions.Distribution]) – Dictionary of parameter names to fitted distributions

  • losses (npt.NDArray) – Array of loss values throughout optimization

  • data (dict[str, npt.NDArray]) – Observed data used for parameter estimation

Variables:
  • model – Reference to the original model

  • data – Observed data used for fitting

  • model_varname_to_mle – Mapping from parameter names to MLEParam objects

  • losses – DataFrame containing loss trajectory and diagnostics

Raises:
  • ValueError – If MLE estimate keys are not a subset of distribution keys

  • ValueError – If parameter names conflict with existing attributes

The class automatically creates attributes for each parameter, allowing, e.g., direct access to a parameter named mu via the syntax mle_result.mu. It also exposes a method for bootstrapping samples from the fitted model, providing a relatively cheap way to quantify uncertainty around MLE estimates.

Key Features:
  • Direct attribute access to individual parameter results

  • Comprehensive loss trajectory tracking and visualization

  • Efficient sampling from fitted parameter distributions

  • Integration with ArviZ for Bayesian workflow compatibility

  • Memory-efficient batch processing for large sample requests

Example:
# Run MLE fitting
mle_result = model.mle(data=observed_data)

# Access optimization diagnostics
loss_plot = mle_result.plot_loss_curve(logy=True)

# Sample from all fitted distributions
parameter_samples = mle_result.draw(n=1000, as_xarray=True)

# Sample from a specific parameter
mu_samples = mle_result.mu.draw(1000)

# Create inference object for detailed analysis
inference_obj = mle_result.get_inference_obj(n=2000)
draw(
n: custom_types.Integer,
*,
seed: custom_types.Integer | None,
as_xarray: Literal[True],
as_inference_data: Literal[False],
batch_size: custom_types.Integer | None = None,
) → Dataset[source]
draw(
n: custom_types.Integer,
*,
seed: custom_types.Integer | None,
as_xarray: Literal[False],
batch_size: custom_types.Integer | None = None,
) → dict[str, ndarray[tuple[int, ...], dtype[_ScalarType_co]]]

Generate samples from all fitted parameter distributions.

This method draws samples from the fitted distributions of all model parameters. It supports multiple output formats for integration with different analysis workflows.

Parameters:
  • n (custom_types.Integer) – Number of samples to draw from each parameter distribution

  • seed (Optional[custom_types.Integer]) – Random seed for reproducible sampling. Defaults to None.

  • as_xarray (bool) – Whether to return results as xarray Dataset. Defaults to False.

  • batch_size (Optional[custom_types.Integer]) – Batch size for memory-efficient sampling. Defaults to None.

Returns:

Sampled parameter values in requested format

Return type:

Union[dict[str, npt.NDArray], xr.Dataset]

Output Formats:
  • Dictionary (default): Keys are parameter names, values are sample arrays

  • xarray Dataset: Structured dataset with proper dimension labels and coordinates

This is particularly useful for:
  • Uncertainty propagation through model predictions

  • Bayesian model comparison and validation

  • Posterior predictive checking with MLE-based approximations

  • Sensitivity analysis of parameter estimates

Example:
>>> # Draw samples as dictionary
>>> samples = mle_result.draw(1000, seed=42)
>>> # Draw as structured xarray Dataset
>>> dataset = mle_result.draw(1000, as_xarray=True, batch_size=100)
get_inference_obj(
n: custom_types.Integer = 1000,
*,
seed: custom_types.Integer | None = None,
batch_size: custom_types.Integer | None = None,
) → MLEInferenceRes[source]

Create ArviZ-compatible inference data object from MLE results.

This method constructs a comprehensive inference data structure that integrates MLE results with the ArviZ ecosystem for Bayesian analysis. Samples are bootstrapped from the fitted parameter distributions to approximate posterior distributions. It organizes parameter samples, observed data, and posterior predictive samples into a standardized format.

Parameters:
  • n (custom_types.Integer) – Number of samples to generate for the inference object. Defaults to 1000.

  • seed (Optional[custom_types.Integer]) – Random seed for reproducible sample generation. Defaults to None.

  • batch_size (Optional[custom_types.Integer]) – Batch size for memory-efficient sampling. Defaults to None.

Returns:

Structured inference data object with all MLE results

Return type:

results.MLEInferenceRes

The resulting inference object contains:
  • Posterior samples: Draws from fitted parameter distributions

  • Observed data: Original data used for parameter estimation

  • Posterior predictive: Samples from observable distributions

Data Organization:
  • Latent parameters are stored in the main posterior group

  • Observable parameters become posterior predictive samples

  • Observed data is stored separately for comparison

  • All data maintains proper dimensional structure and labeling

This enables:
  • Integration with ArviZ plotting and diagnostic functions (see the sketch following this list)

  • Model comparison

  • Posterior predictive checking workflows

  • Standardized reporting and visualization
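
Because the wrapped object is a standard ArviZ InferenceData instance (stored on the returned object’s inference_obj attribute), it can be handed directly to ArviZ functions. A minimal sketch, assuming an existing mle_result and a model parameter named mu (the parameter name is illustrative):

import arviz as az

mle_inference = mle_result.get_inference_obj(n=2000, seed=0)
idata = mle_inference.inference_obj

# The standard groups described above are present
print(idata.groups())  # e.g. ['posterior', 'posterior_predictive', 'observed_data']

# Any ArviZ function that accepts InferenceData can be used directly
az.plot_posterior(idata, var_names=['mu'])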

Important

Samples are drawn using the optimized values of their parent parameters. For example, if a parameter y is defined in the model as y ~ Normal(mu, sigma), where mu and sigma are also parameters in the model, then samples of y will be drawn using the MLE values of mu and sigma. This means that uncertainty in mu and sigma is not propagated to y. This is a limitation of the MLE-based approach and should be considered when interpreting results.

Important

Related to the above, for root-level parameters whose parent parameters are constants, sampling from the fitted distribution is identical to sampling from the prior distribution. For example, for a parameter y defined in the model as y ~ Normal(mu = 0.0, sigma = 1.0), the values of mu and sigma do not change during fitting, so the distribution of y remains Normal(0.0, 1.0).
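
To make the first point above concrete, the following sketch uses plain torch.distributions (not the SciStanPy API) with made-up MLE values to show that bootstrapped samples of y reflect only the spread of Normal(mu_hat, sigma_hat), carrying no uncertainty about mu_hat or sigma_hat themselves:

import torch

# Hypothetical MLE point estimates for mu and sigma (illustrative values only)
mu_hat = torch.tensor(1.7)
sigma_hat = torch.tensor(0.4)

# Samples of y are drawn with mu and sigma frozen at their MLE values, so the
# spread of these samples reflects only Normal(mu_hat, sigma_hat)
y_samples = torch.distributions.Normal(mu_hat, sigma_hat).sample((1000,))
print(y_samples.std())  # approximately sigma_hat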

Example:
>>> # Create inference object with default settings
>>> inference_obj = mle_result.get_inference_obj()
>>> # Generate larger sample with custom batch size
>>> inference_obj = mle_result.get_inference_obj(
...     n=5000, batch_size=500, seed=42
... )
plot_loss_curve(logy: bool = True)[source]

Generate interactive plot of the optimization loss trajectory.

This method creates a visualization of how the loss function evolved during the optimization process, providing insights into convergence behavior and optimization effectiveness.

Parameters:

logy (bool) – Whether to use logarithmic y-axis scaling. Defaults to True.

Returns:

Interactive HoloViews plot of the loss curve

The plot automatically handles:
  • Logarithmic scaling with proper handling of negative/zero values

  • Appropriate axis labels and titles based on scaling choice

  • Interactive features for detailed examination of convergence

  • Warning messages for problematic loss trajectories

For logarithmic scaling with non-positive loss values, the method automatically switches to a shifted logarithmic scale to maintain visualization quality while issuing appropriate warnings.

Example:
>>> # Standard logarithmic loss plot
>>> loss_plot = mle_result.plot_loss_curve()
>>> # Linear scale loss plot
>>> linear_plot = mle_result.plot_loss_curve(logy=False)
class scistanpy.model.results.mle.MLEParam(name: str, value: npt.NDArray | None, distribution: custom_types.SciStanPyDistribution)[source]

Bases: object

Container for maximum likelihood estimate of a single model parameter.

This class encapsulates the MLE result for an individual parameter, including the estimated value and the corresponding fitted probability distribution. It provides methods for sampling from the fitted distribution and accessing parameter properties.

Parameters:
  • name (str) – Name of the parameter in the model

  • value (Optional[npt.NDArray]) – Maximum likelihood estimate of the parameter value. Can be None for some distribution types.

  • distribution (custom_types.SciStanPyDistribution) – Fitted probability distribution object

Variables:
  • name – Parameter name identifier

  • mle – Stored maximum likelihood estimate

  • distribution – Fitted distribution for sampling and analysis

The class maintains both point estimates and distributional representations, enabling both point-based analysis and uncertainty quantification through sampling from the fitted distribution.

Example:
# Run MLE fitting
mle_result = model.mle(data=observed_data)

# Access a specific parameter (an instance of `MLEParam`) describing
# the MLE results for that parameter
mle_param = mle_result.mu
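
Continuing the example, the point estimate and fitted distribution described under Variables above are available directly as attributes (the commented values are hypothetical):

# Inspect the stored estimate and its fitted distribution
print(mle_param.name)          # 'mu'
print(mle_param.mle)           # MLE value as a NumPy array
print(mle_param.distribution)  # e.g., a torch.distributions.Normal instance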
draw(
n: int,
*,
seed: custom_types.Integer | None = None,
batch_size: custom_types.Integer | None = None,
) → npt.NDArray[source]

Sample from the fitted parameter distribution.

This method generates samples from the parameter’s fitted probability distribution using batch processing to handle large sample requests.

Parameters:
  • n (int) – Total number of samples to generate

  • seed (Optional[custom_types.Integer]) – Random seed for reproducible sampling. Defaults to None.

  • batch_size (Optional[custom_types.Integer]) – Size of batches for memory-efficient sampling. Defaults to None (uses n as batch size).

Returns:

Array of samples from the fitted distribution

Return type:

npt.NDArray

Batch processing prevents memory overflow when requesting large numbers of samples from complex distributions, particularly important when working with GPU-based computations.
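
The batching idea can be sketched as follows; this is an illustrative loop over a torch distribution, not the library’s internal implementation:

import numpy as np
import torch

def draw_in_batches(distribution, n, batch_size):
    """Draw n samples in chunks of batch_size to bound peak memory usage."""
    batches = []
    remaining = n
    while remaining > 0:
        current = min(batch_size, remaining)
        # Only `current` draws are materialized at a time
        batches.append(distribution.sample((current,)).cpu().numpy())
        remaining -= current
    return np.concatenate(batches, axis=0)

samples = draw_in_batches(torch.distributions.Normal(0.0, 1.0), n=10_000, batch_size=1_000)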

Example:
>>> # Generate 10000 samples in batches of 1000
>>> samples = param.draw(10000, batch_size=1000, seed=42)
>>> print(f"Sample mean: {samples.mean()}")

Bootstrapped Parameter Values and Analysis

class scistanpy.model.results.mle.MLEInferenceRes(inference_obj: InferenceData | str)[source]

Bases: object

Analysis interface for bootstrapped samples from MLE instances.

This class provides tools for analyzing and visualizing MLE results from SciStanPy models. It wraps ArviZ InferenceData objects with specialized methods for posterior predictive checking, calibration analysis, and model validation.

Parameters:

inference_obj (Union[az.InferenceData, str]) – ArviZ InferenceData object or path to saved results

Variables:

inference_obj – Stored ArviZ InferenceData object with all results

Raises:
  • ValueError – If inference_obj is neither string nor InferenceData

  • ValueError – If required groups (posterior, posterior_predictive) are missing

The class expects the InferenceData object to contain:
  • posterior: Samples from fitted parameter distributions

  • posterior_predictive: Samples from observable distributions

  • observed_data: Original observed data used for fitting

Key Capabilities:
  • Posterior predictive checking with multiple visualization modes

  • Quantitative model calibration assessment

  • Interactive diagnostic dashboards

  • Summary statistics computation and caching

Example:
import scistanpy as ssp
import numpy as np

# Get MLE results
mle_result = model.mle(data=observed_data)

# Create inference analysis object
mle_analysis = mle_result.get_inference_obj()

# Run comprehensive posterior predictive checking
dashboard = mle_analysis.run_ppc()

# Save results for later analysis
mle_analysis.save_netcdf('mle_analysis.nc')
calculate_summaries(
var_names: list[str] | None = None,
filter_vars: Literal[None, 'like', 'regex'] = None,
kind: Literal['all', 'stats', 'diagnostics'] = 'stats',
round_to: custom_types.Integer = 2,
circ_var_names: list[str] | None = None,
stat_focus: str = 'mean',
stat_funcs: dict[str, callable] | callable | None = None,
extend: bool = True,
hdi_prob: custom_types.Float = 0.94,
skipna: bool = False,
) → xr.Dataset[source]

Compute summary statistics for MLE results.

This method wraps ArviZ’s summary functionality while adding the computed statistics to the InferenceData object for persistence and reuse. See az.summary for detailed descriptions of arguments.

Parameters:
  • var_names (Optional[list[str]]) – Variable names to include in summary. Defaults to None (all variables).

  • filter_vars (Optional[Literal[None, "like", "regex"]]) – Variable filtering method. Defaults to None.

  • kind (Literal["all", "stats", "diagnostics"]) – Type of statistics to compute. Defaults to “stats”.

  • round_to (custom_types.Integer) – Number of decimal places for rounding. Defaults to 2.

  • circ_var_names (Optional[list[str]]) – Names of circular variables. Defaults to None.

  • stat_focus (str) – Primary statistic for focus. Defaults to “mean”.

  • stat_funcs (Optional[Union[dict[str, callable], callable]]) – Custom statistic functions. Defaults to None.

  • extend (bool) – Whether to compute the statistics from stat_funcs in addition to, rather than in place of, the default statistics. Defaults to True. Only meaningful when stat_funcs is provided.

  • hdi_prob (custom_types.Float) – Probability for highest density interval. Defaults to 0.94.

  • skipna (bool) – Whether to skip NaN values. Defaults to False.

Returns:

Dataset containing computed summary statistics

Return type:

xr.Dataset

Raises:
  • ValueError – If diagnostics are requested but no chain dimension exists in self.inference_obj.posterior.dims

  • ValueError – If diagnostics are requested for results with only a single chain.

The computed statistics are automatically added to the InferenceData object under the variable_summary_stats group for persistence.

Example:
>>> # Compute basic statistics
>>> stats = mle_analysis.calculate_summaries()
>>> # Compute diagnostics for multi-chain results
>>> diag = mle_analysis.calculate_summaries(kind="diagnostics")
check_calibration(
*,
return_deviance: Literal[False],
display: Literal[True],
width: custom_types.Integer,
height: custom_types.Integer,
) → Layout[source]
check_calibration(
*,
return_deviance: Literal[False],
display: Literal[False],
width: custom_types.Integer,
height: custom_types.Integer,
) → dict[str, Overlay]
check_calibration(
*,
return_deviance: Literal[True],
display: Literal[False],
width: custom_types.Integer,
height: custom_types.Integer,
) → tuple[dict[str, Overlay], dict[str, float]]

Assess model calibration through posterior predictive quantile analysis.

This method evaluates how well the model’s posterior predictive distribution matches the observed data by analyzing the distribution of observed-data quantiles. For a well-calibrated model, the quantiles of the observed data within the posterior predictive distribution should be uniformly distributed.
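
The underlying idea can be approximated with NumPy alone; this sketch uses synthetic arrays and is not the implementation behind ssp.plotting.plot_calibration:

import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins: posterior predictive draws (n_draws x n_observations)
# and the corresponding observed data
pp_samples = rng.normal(size=(2000, 500))
observed = rng.normal(size=500)

# Quantile of each observation within its posterior predictive distribution
quantiles = (pp_samples < observed).mean(axis=0)

# For a well-calibrated model these quantiles are roughly Uniform(0, 1); one
# simple deviance measure is the largest gap between their empirical CDF and
# the uniform CDF
sorted_q = np.sort(quantiles)
ecdf = np.arange(1, sorted_q.size + 1) / sorted_q.size
print(f"Maximum calibration deviation: {np.abs(ecdf - sorted_q).max():.3f}")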

Parameters:
  • return_deviance (bool) – Whether to return quantitative deviance metrics. Defaults to False.

  • display (bool) – Whether to return formatted layout for display. Defaults to True.

  • width (custom_types.Integer) – Width of individual plots in pixels. Defaults to 600.

  • height (custom_types.Integer) – Height of individual plots in pixels. Defaults to 600.

Returns:

Calibration plots and optionally deviance metrics

Return type:

Union[hv.Layout, dict[str, hv.Overlay], tuple[dict[str, hv.Overlay], dict[str, float]]]

Raises:

ValueError – If both display and return_deviance are True

Internally, this method is just a wrapper around ssp.plotting.plot_calibration. See that function for a detailed description of the calibration assessment method and returned plots.

Example:
>>> # Visual assessment
>>> cal_layout = mle_analysis.check_calibration()
>>> # Quantitative assessment
>>> plots, deviances = mle_analysis.check_calibration(
...     return_deviance=True, display=False
... )
>>> print(f"Mean deviance: {np.mean(list(deviances.values())):.3f}")
classmethod from_disk(path: str) → MLEInferenceRes[source]

Load MLEInferenceRes object from saved NetCDF file.

Parameters:

path (str) – Path to NetCDF file containing saved InferenceData

Returns:

Reconstructed MLEInferenceRes object with all analysis capabilities

Return type:

MLEInferenceRes

This class method enables loading of previously saved analysis results, preserving all computed statistics and enabling continued analysis from where previous sessions left off.

Example:
>>> # Load previously saved results
>>> mle_analysis = MLEInferenceRes.from_disk('saved_results.nc')
>>> # Continue analysis with full functionality
>>> dashboard = mle_analysis.run_ppc()
plot_observed_quantiles(
*,
use_ranks: bool,
display: Literal[True],
width: custom_types.Integer,
height: custom_types.Integer,
windowsize: 'custom_types.Integer' | None,
) → Layout[source]
plot_observed_quantiles(
*,
use_ranks: bool,
display: Literal[False],
width: custom_types.Integer,
height: custom_types.Integer,
windowsize: 'custom_types.Integer' | None,
) → dict[str, Overlay]

Visualize systematic patterns in observed data quantiles.

This method creates hexagonal density plots showing the relationship between observed data values (or their ranks) and their corresponding quantiles within the posterior predictive distribution. A rolling mean overlay highlights systematic trends.

Parameters:
  • use_ranks (bool) – Whether to use ranks instead of raw values for x-axis. Defaults to True.

  • display (bool) – Whether to return formatted layout for display. Defaults to True.

  • width (custom_types.Integer) – Width of individual plots in pixels. Defaults to 600.

  • height (custom_types.Integer) – Height of individual plots in pixels. Defaults to 400.

  • windowsize (Optional[custom_types.Integer]) – Size of rolling window for trend line. Defaults to None (automatic).

Returns:

Quantile plots in requested format

Return type:

Union[hv.Layout, dict[str, hv.Overlay]]

Visualization Components:
  • Hexagonal binning showing density of (value, quantile) pairs

  • Rolling mean trend line highlighting systematic patterns

  • Colormap indicating point density for pattern identification

Pattern Interpretation:
  • Horizontal trend line around 0.5 with uniformly distributed points indicates good calibration

  • Systematic deviations suggest model bias or miscalibration

The hexagonal binning is particularly effective for visualizing large datasets where individual points would create overplotting issues.

Example:
>>> # Standard quantile analysis
>>> quant_layout = mle_analysis.plot_observed_quantiles()
>>> # Custom window size for trend analysis
>>> quant_plots = mle_analysis.plot_observed_quantiles(
...     windowsize=50, use_ranks=False, display=False
... )
plot_posterior_predictive_samples(
*,
quantiles: Sequence['custom_types.Float'],
use_ranks: bool,
logy: bool,
display: Literal[True],
width: custom_types.Integer,
height: custom_types.Integer,
) → Layout[source]
plot_posterior_predictive_samples(
*,
quantiles: Sequence['custom_types.Float'],
use_ranks: bool,
logy: bool,
display: Literal[False],
width: custom_types.Integer,
height: custom_types.Integer,
) → dict[str, Overlay]

Visualize observed data against posterior predictive uncertainty intervals.

This method creates plots showing how observed data relates to the uncertainty quantified by posterior predictive samples. The posterior predictive samples are displayed as confidence intervals, with observed data overlaid as points.

Parameters:
  • quantiles (Sequence[custom_types.Float]) – Quantiles defining confidence intervals. Defaults to (0.025, 0.25, 0.5). Note: quantiles are automatically symmetrized and the median is always included.

  • use_ranks (bool) – Whether to use ranks instead of raw values for x-axis. Defaults to True.

  • logy (bool) – Whether to use logarithmic y-axis scaling. Defaults to False.

  • display (bool) – Whether to return formatted layout for display. Defaults to True.

  • width (custom_types.Integer) – Width of individual plots in pixels. Defaults to 600.

  • height (custom_types.Integer) – Height of individual plots in pixels. Defaults to 400.

Returns:

Posterior predictive plots in requested format

Return type:

Union[hv.Layout, dict[str, hv.Overlay]]

Visualization Features:
  • Confidence intervals shown as nested colored regions

  • Observed data displayed as scatter points

  • Optional rank transformation for better visualization of skewed data

  • Logarithmic scaling with automatic shifting for non-positive values

  • Interactive hover labels showing data point identifiers

The rank transformation is particularly useful when observed values have highly skewed distributions, as it emphasizes the ordering rather than the absolute magnitudes.

Example:
>>> # Standard posterior predictive plot
>>> pp_layout = mle_analysis.plot_posterior_predictive_samples()
>>> # Custom quantiles with logarithmic scaling
>>> pp_plots = mle_analysis.plot_posterior_predictive_samples(
...     quantiles=(0.05, 0.5, 0.95), logy=True, display=False
... )
run_ppc(
*,
use_ranks: bool,
display: Literal[True],
square_ecdf: bool,
windowsize: 'custom_types.Integer' | None,
quantiles: Sequence['custom_types.Float'],
logy_ppc_samples: bool,
subplot_width: custom_types.Integer,
subplot_height: custom_types.Integer,
) → Column[source]
run_ppc(
*,
use_ranks: bool,
display: Literal[False],
square_ecdf: bool,
windowsize: 'custom_types.Integer' | None,
quantiles: Sequence['custom_types.Float'],
logy_ppc_samples: bool,
subplot_width: custom_types.Integer,
subplot_height: custom_types.Integer,
) → list[dict[str, Overlay]]

Execute comprehensive posterior predictive checking analysis.

This method provides a complete posterior predictive checking workflow by combining multiple diagnostic approaches into a unified analysis. It runs the methods plot_posterior_predictive_samples(), plot_observed_quantiles(), and check_calibration(), combining their outputs into either an interactive dashboard or a list of individual plot dictionaries.

Parameters:
  • use_ranks (bool) – Whether to use ranks instead of raw values for x-axes. Defaults to True.

  • display (bool) – Whether to return interactive dashboard layout. Defaults to True.

  • square_ecdf (bool) – Whether to make ECDF plots square (width=height). Defaults to True.

  • windowsize (Optional[custom_types.Integer]) – Size of rolling window for trend analysis. Defaults to None (automatic).

  • quantiles (Sequence[custom_types.Float]) – Quantiles for confidence intervals. Defaults to (0.025, 0.25, 0.5).

  • logy_ppc_samples (bool) – Whether to use log scale for posterior predictive plots. Defaults to False.

  • subplot_width (custom_types.Integer) – Width of individual subplots in pixels. Defaults to 600.

  • subplot_height (custom_types.Integer) – Height of individual subplots in pixels. Defaults to 400.

Returns:

Interactive dashboard or list of plot dictionaries

Return type:

Union[pn.Column, list[dict[str, hv.Overlay]]]

Dashboard Features:
  • Interactive variable selection across all diagnostic types

  • Consistent formatting and scaling across related plots

  • Automatic layout optimization for comparison and analysis

  • Widget-based navigation for multi-variable models

Together, the three sets of plots generated by this method provide a holistic view of model performance in terms of:

  • Predictive accuracy: How well do predictions match observations?

  • Calibration quality: Are prediction intervals properly calibrated?

  • Systematic bias: Are there patterns indicating model inadequacy?

Example:
>>> # Complete interactive analysis
>>> dashboard = mle_analysis.run_ppc()
>>> dashboard  # Display in notebook
>>>
>>> # Programmatic access to individual components
>>> ppc_plots, quant_plots, cal_plots = mle_analysis.run_ppc(display=False)
save_netcdf(filename: str) → None[source]

Save the ArviZ InferenceData object to NetCDF format.

Parameters:

filename (str) – Path where the NetCDF file will be saved

This method provides persistent storage of analysis results.

Example:
>>> mle_analysis.save_netcdf('my_mle_results.nc')
>>> # Later: reload with MLEInferenceRes('my_mle_results.nc')