Maximum Likelihood Estimation Results API Reference¶
Maximum likelihood estimation results analysis and visualization for SciStanPy models.
This module provides analysis tools for maximum likelihood estimation results from SciStanPy models. It offers diagnostic plots, calibration checks, and posterior predictive analysis tools designed specifically for MLE-based inference workflows.
The module centers on three main classes: MLEParam for individual parameter estimates, MLE for complete model results, and MLEInferenceRes, which wraps ArviZ InferenceData objects with specialized methods for analyzing MLE results. Together, these classes expose the estimated parameter values and the fitted probability distributions produced by MLE analysis, and support downstream analysis including uncertainty quantification and posterior predictive sampling. The module offers both individual diagnostic tools and analysis workflows that combine multiple checks into unified reporting interfaces.
- Key Features:
Individual parameter MLE estimates with associated distributions
Complete model MLE results with loss tracking and diagnostics
Posterior predictive checking workflows
Model calibration analysis with quantitative metrics
Interactive visualization with customizable display options
Integration with ArviZ for standardized Bayesian workflows
Memory-efficient handling of large posterior predictive samples
Flexible output formats for different analysis needs
- Visualization Capabilities:
Posterior predictive sample plotting with confidence intervals
Calibration plots with deviation metrics
Quantile-quantile plots for model validation
Interactive layouts with customizable dimensions
- Performance Considerations:
Batch sampling prevents memory overflow for large sample requests
GPU acceleration is preserved through PyTorch distribution objects
The module is designed to work with SciStanPy’s MLE estimation workflow, providing immediate access to model diagnostics and validation tools once MLE fitting is complete. The MLE results can be used for various purposes including model comparison, uncertainty quantification, and as initialization for more sophisticated inference procedures like MCMC sampling.
In combination, the following provide a comprehensive framework for analyzing maximum likelihood estimation results. Quickly navigate to specific sections using the links below:
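The sketch below ties these pieces together into a typical end-to-end workflow using only the interfaces documented on this page. It assumes a SciStanPy model instance named model and an observed-data dictionary observed_data, both of which are created outside this module:

# Fit the model by maximum likelihood (assumes `model` and `observed_data` exist)
mle_result = model.mle(data=observed_data)

# Inspect optimizer convergence
loss_plot = mle_result.plot_loss_curve(logy=True)

# Bootstrap samples from the fitted distributions and wrap them for analysis
mle_analysis = mle_result.get_inference_obj(n=2000, seed=0)

# Run the combined posterior predictive checking workflow
dashboard = mle_analysis.run_ppc()

# Persist results for a later session
mle_analysis.save_netcdf('mle_analysis.nc')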
Maximum Likelihood Estimation Results¶
- class scistanpy.model.results.mle.MLE(
- model: ssp_model.Model,
- mle_estimate: dict[str, npt.NDArray],
- distributions: dict[str, torch.distributions.Distribution],
- losses: npt.NDArray,
- data: dict[str, npt.NDArray],
- )
Bases:
object
Complete maximum likelihood estimation results for a SciStanPy model.
This class encapsulates the full results of a call to Model.mle() for MLE parameter estimation, including parameter estimates, fitted distributions, optimization diagnostics, and utilities for further analysis. It provides a comprehensive interface for working with MLE results.
- Parameters:
model (ssp_model.Model) – Original SciStanPy model
mle_estimate (dict[str, npt.NDArray]) – Dictionary of parameter names to their MLE values
distributions (dict[str, torch.distributions.Distribution]) – Dictionary of parameter names to fitted distributions
losses (npt.NDArray) – Array of loss values throughout optimization
data (dict[str, npt.NDArray]) – Observed data used for parameter estimation
- Variables:
model – Reference to the original model
data – Observed data used for fitting
model_varname_to_mle – Mapping from parameter names to MLEParam objects
losses – DataFrame containing loss trajectory and diagnostics
- Raises:
ValueError – If MLE estimate keys are not a subset of distribution keys
ValueError – If parameter names conflict with existing attributes
The class automatically creates attributes for each parameter, allowing, e.g., direct access to a parameter named mu using the syntax mle_result.mu. It also exposes a method for bootstrapping samples from the fitted model, providing a relatively cheap way to quantify uncertainty around MLE estimates.
- Key Features:
Direct attribute access to individual parameter results
Comprehensive loss trajectory tracking and visualization
Efficient sampling from fitted parameter distributions
Integration with ArviZ for Bayesian workflow compatibility
Memory-efficient batch processing for large sample requests
- Example:
# Run MLE fitting
mle_result = model.mle(data=observed_data)

# Access optimization diagnostics
loss_plot = mle_result.plot_loss_curve(logy=True)

# Sample from all fitted distributions
parameter_samples = mle_result.draw(n=1000, as_xarray=True)

# Sample from a specific parameter
mu_samples = mle_result.mu.draw(1000)

# Create inference object for detailed analysis
inference_obj = mle_result.get_inference_obj(n=2000)
- draw(
- n: custom_types.Integer,
- *,
- seed: custom_types.Integer | None,
- as_xarray: Literal[True],
- as_inference_data: Literal[False],
- batch_size: custom_types.Integer | None = None,
- )
- draw(
- n: custom_types.Integer,
- *,
- seed: custom_types.Integer | None,
- as_xarray: Literal[False],
- batch_size: custom_types.Integer | None = None,
- )
Generate samples from all fitted parameter distributions.
This method draws samples from the fitted distributions of all model parameters. It supports multiple output formats for integration with different analysis workflows.
- Parameters:
n (custom_types.Integer) – Number of samples to draw from each parameter distribution
seed (Optional[custom_types.Integer]) – Random seed for reproducible sampling. Defaults to None.
as_xarray (bool) – Whether to return results as xarray Dataset. Defaults to False.
batch_size (Optional[custom_types.Integer]) – Batch size for memory-efficient sampling. Defaults to None.
- Returns:
Sampled parameter values in requested format
- Return type:
Union[dict[str, npt.NDArray], xr.Dataset]
- Output Formats:
Dictionary (default): Keys are parameter names, values are sample arrays
xarray Dataset: Structured dataset with proper dimension labels and coordinates
- This is particularly useful for:
Uncertainty propagation through model predictions
Bayesian model comparison and validation
Posterior predictive checking with MLE-based approximations
Sensitivity analysis of parameter estimates
- Example:
>>> # Draw samples as dictionary
>>> samples = mle_result.draw(1000, seed=42)
>>> # Draw as structured xarray Dataset
>>> dataset = mle_result.draw(1000, as_xarray=True, batch_size=100)
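As a sketch of the uncertainty-propagation use case listed above, the dictionary returned by draw() can be fed into any downstream computation. The parameter names mu and sigma below are hypothetical placeholders for whatever parameters a given model defines:

import numpy as np

# Bootstrap samples from every fitted distribution (dictionary output)
samples = mle_result.draw(1000, seed=42)

# Propagate sampling variability through a derived quantity.
# 'mu' and 'sigma' are hypothetical parameter names used only for illustration.
derived = samples["mu"] + 2.0 * samples["sigma"]
lower, upper = np.percentile(derived, [2.5, 97.5])
print(f"95% interval for the derived quantity: [{lower:.3f}, {upper:.3f}]")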
- get_inference_obj(
- n: custom_types.Integer = 1000,
- *,
- seed: custom_types.Integer | None = None,
- batch_size: custom_types.Integer | None = None,
- )
Create ArviZ-compatible inference data object from MLE results.
This method constructs a comprehensive inference data structure that integrates MLE results with the ArviZ ecosystem for Bayesian analysis. Samples are bootstrapped from the fitted parameter distributions to approximate posterior distributions. It organizes parameter samples, observed data, and posterior predictive samples into a standardized format.
- Parameters:
n (custom_types.Integer) – Number of samples to generate for the inference object. Defaults to 1000.
seed (Optional[custom_types.Integer]) – Random seed for reproducible sample generation. Defaults to None.
batch_size (Optional[custom_types.Integer]) – Batch size for memory-efficient sampling. Defaults to None.
- Returns:
Structured inference data object with all MLE results
- Return type:
results.MLEInferenceRes
- The resulting inference object contains:
Posterior samples: Draws from fitted parameter distributions
Observed data: Original data used for parameter estimation
Posterior predictive: Samples from observable distributions
- Data Organization:
Latent parameters are stored in the main posterior group
Observable parameters become posterior predictive samples
Observed data is stored separately for comparison
All data maintains proper dimensional structure and labeling
- This enables:
Integration with ArviZ plotting and diagnostic functions
Model comparison
Posterior predictive checking workflows
Standardized reporting and visualization
Important
Samples are drawn using the optimized values of their parent parameters. For example, if a parameter y is defined in the model as y ~ Normal(mu, sigma), where mu and sigma are also parameters in the model, then samples of y will be drawn using the MLE values of mu and sigma. This means that uncertainty in mu and sigma is not propagated to y. This is a limitation of the MLE-based approach and should be considered when interpreting results.
Important
Related to the above, for root-level parameters whose parent parameters have constant values, sampling from the fitted distribution is identical to sampling from the prior distribution. For example, for a parameter y defined in the model as y ~ Normal(mu=0.0, sigma=1.0), the values of mu and sigma do not change during fitting, so the distribution of y remains Normal(0.0, 1.0).
- Example:
>>> # Create inference object with default settings
>>> inference_obj = mle_result.get_inference_obj()
>>> # Generate larger sample with custom batch size
>>> inference_obj = mle_result.get_inference_obj(
...     n=5000, batch_size=500, seed=42
... )
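Because the returned MLEInferenceRes stores its results in the inference_obj attribute as an ArviZ InferenceData object, standard ArviZ functions can be applied to it directly. A minimal sketch (the specific ArviZ calls shown are illustrative, not part of SciStanPy’s API):

import arviz as az

# Bootstrap samples and wrap them in an MLEInferenceRes object
mle_analysis = mle_result.get_inference_obj(n=2000, seed=42)

# The wrapped ArviZ InferenceData object is exposed as `inference_obj`
idata = mle_analysis.inference_obj

# Standard ArviZ summaries and plots operate on the wrapped object
print(az.summary(idata, kind="stats"))
az.plot_posterior(idata)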
- plot_loss_curve(logy: bool = True)[source]¶
Generate interactive plot of the optimization loss trajectory.
This method creates a visualization of how the loss function evolved during the optimization process, providing insights into convergence behavior and optimization effectiveness.
- Parameters:
logy (bool) – Whether to use logarithmic y-axis scaling. Defaults to True.
- Returns:
Interactive HoloViews plot of the loss curve
- The plot automatically handles:
Logarithmic scaling with proper handling of negative/zero values
Appropriate axis labels and titles based on scaling choice
Interactive features for detailed examination of convergence
Warning messages for problematic loss trajectories
For logarithmic scaling with non-positive loss values, the method automatically switches to a shifted logarithmic scale to maintain visualization quality while issuing appropriate warnings.
- Example:
>>> # Standard logarithmic loss plot
>>> loss_plot = mle_result.plot_loss_curve()
>>> # Linear scale loss plot
>>> linear_plot = mle_result.plot_loss_curve(logy=False)
- class scistanpy.model.results.mle.MLEParam(name: str, value: npt.NDArray | None, distribution: custom_types.SciStanPyDistribution)[source]¶
Bases:
object
Container for maximum likelihood estimate of a single model parameter.
This class encapsulates the MLE result for an individual parameter, including the estimated value and the corresponding fitted probability distribution. It provides methods for sampling from the fitted distribution and accessing parameter properties.
- Parameters:
name (str) – Name of the parameter in the model
value (Optional[npt.NDArray]) – Maximum likelihood estimate of the parameter value. Can be None for some distribution types.
distribution (custom_types.SciStanPyDistribution) – Fitted probability distribution object
- Variables:
name – Parameter name identifier
mle – Stored maximum likelihood estimate
distribution – Fitted distribution for sampling and analysis
The class maintains both point estimates and distributional representations, enabling both point-based analysis and uncertainty quantification through sampling from the fitted distribution.
- Example:
# Run MLE fitting
mle_result = model.mle(data=observed_data)

# Access a specific parameter (an instance of `MLEParam`) describing
# the MLE results for that parameter
mle_param = mle_result.mu
- draw(
- n: int,
- *,
- seed: custom_types.Integer | None = None,
- batch_size: custom_types.Integer | None = None,
- )
Sample from the fitted parameter distribution.
This method generates samples from the parameter’s fitted probability distribution using batch processing to handle large sample requests.
- Parameters:
n (int) – Total number of samples to generate
seed (Optional[custom_types.Integer]) – Random seed for reproducible sampling. Defaults to None.
batch_size (Optional[custom_types.Integer]) – Size of batches for memory-efficient sampling. Defaults to None (uses n as batch size).
- Returns:
Array of samples from the fitted distribution
- Return type:
npt.NDArray
Batch processing prevents memory overflow when requesting large numbers of samples from complex distributions, particularly important when working with GPU-based computations.
- Example:
>>> # Generate 10000 samples in batches of 1000
>>> samples = param.draw(10000, batch_size=1000, seed=42)
>>> print(f"Sample mean: {samples.mean()}")
Bootstrapped Parameter Values and Analysis¶
- class scistanpy.model.results.mle.MLEInferenceRes(inference_obj: InferenceData | str)[source]¶
Bases:
object
Analysis interface for bootstrapped samples from MLE instances.
This class provides tools for analyzing and visualizing MLE results from SciStanPy models. It wraps ArviZ InferenceData objects with specialized methods for posterior predictive checking, calibration analysis, and model validation.
- Parameters:
inference_obj (Union[az.InferenceData, str]) – ArviZ InferenceData object or path to saved results
- Variables:
inference_obj – Stored ArviZ InferenceData object with all results
- Raises:
ValueError – If inference_obj is neither string nor InferenceData
ValueError – If required groups (posterior, posterior_predictive) are missing
- The class expects the InferenceData object to contain:
posterior: Samples from fitted parameter distributions
posterior_predictive: Samples from observable distributions
observed_data: Original observed data used for fitting
- Key Capabilities:
Posterior predictive checking with multiple visualization modes
Quantitative model calibration assessment
Interactive diagnostic dashboards
Summary statistics computation and caching
- Example:
import scistanpy as ssp
import numpy as np

# Get MLE results
mle_result = model.mle(data=observed_data)

# Create inference analysis object
mle_analysis = mle_result.get_inference_obj()

# Run comprehensive posterior predictive checking
dashboard = mle_analysis.run_ppc()

# Save results for later analysis
mle_analysis.save_netcdf('mle_analysis.nc')
- calculate_summaries(
- var_names: list[str] | None = None,
- filter_vars: Literal[None, 'like', 'regex'] = None,
- kind: Literal['all', 'stats', 'diagnostics'] = 'stats',
- round_to: custom_types.Integer = 2,
- circ_var_names: list[str] | None = None,
- stat_focus: str = 'mean',
- stat_funcs: dict[str, callable] | callable | None = None,
- extend: bool = True,
- hdi_prob: custom_types.Float = 0.94,
- skipna: bool = False,
- )
Compute summary statistics for MLE results.
This method wraps ArviZ’s summary functionality while adding the computed statistics to the InferenceData object for persistence and reuse. See az.summary for detailed descriptions of arguments.
- Parameters:
var_names (Optional[list[str]]) – Variable names to include in summary. Defaults to None (all variables).
filter_vars (Optional[Literal[None, "like", "regex"]]) – Variable filtering method. Defaults to None.
kind (Literal["all", "stats", "diagnostics"]) – Type of statistics to compute. Defaults to “stats”.
round_to (custom_types.Integer) – Number of decimal places for rounding. Defaults to 2.
circ_var_names (Optional[list[str]]) – Names of circular variables. Defaults to None.
stat_focus (str) – Primary statistic for focus. Defaults to “mean”.
stat_funcs (Optional[Union[dict[str, callable], callable]]) – Custom statistic functions. Defaults to None.
extend (bool) – Use functions provided by stat_funcs. Defaults to True. Only meaningful when stat_funcs is provided.
hdi_prob (custom_types.Float) – Probability for highest density interval. Defaults to 0.94.
skipna (bool) – Whether to skip NaN values. Defaults to False.
- Returns:
Dataset containing computed summary statistics
- Return type:
xr.Dataset
- Raises:
ValueError – If diagnostics are requested without a chain dimension in self.inference_obj.posterior.dims
ValueError – If diagnostics are requested but only a single chain is present
The computed statistics are automatically added to the InferenceData object under the variable_summary_stats group for persistence.
- Example:
>>> # Compute basic statistics
>>> stats = mle_analysis.calculate_summaries()
>>> # Compute diagnostics for multi-chain results
>>> diag = mle_analysis.calculate_summaries(kind="diagnostics")
- check_calibration(
- *,
- return_deviance: Literal[False],
- display: Literal[True],
- width: custom_types.Integer,
- height: custom_types.Integer,
- )
- check_calibration(
- *,
- return_deviance: Literal[False],
- display: Literal[False],
- width: custom_types.Integer,
- height: custom_types.Integer,
- )
- check_calibration(
- *,
- return_deviance: Literal[True],
- display: Literal[False],
- width: custom_types.Integer,
- height: custom_types.Integer,
- )
Assess model calibration through posterior predictive quantile analysis.
This method evaluates how well the model’s posterior predictive distribution matches the observed data by analyzing the distribution of observed-data quantiles. Well-calibrated models should produce observed data that are uniformly distributed across the quantiles of the posterior predictive distribution.
- Parameters:
return_deviance (bool) – Whether to return quantitative deviance metrics. Defaults to False.
display (bool) – Whether to return formatted layout for display. Defaults to True.
width (custom_types.Integer) – Width of individual plots in pixels. Defaults to 600.
height (custom_types.Integer) – Height of individual plots in pixels. Defaults to 600.
- Returns:
Calibration plots and optionally deviance metrics
- Return type:
Union[hv.Layout, dict[str, hv.Overlay], tuple[dict[str, hv.Overlay], dict[str, float]]]
- Raises:
ValueError – If both display and return_deviance are True
Internally, this method is just a wrapper around ssp.plotting.plot_calibration. See that function for a detailed description of the calibration assessment method and returned plots.
- Example:
>>> # Visual assessment
>>> cal_layout = mle_analysis.check_calibration()
>>> # Quantitative assessment
>>> plots, deviances = mle_analysis.check_calibration(
...     return_deviance=True, display=False
... )
>>> print(f"Mean deviance: {np.mean(list(deviances.values())):.3f}")
- classmethod from_disk(path: str) → MLEInferenceRes [source]¶
Load an MLEInferenceRes object from a saved NetCDF file.
- Parameters:
path (str) – Path to NetCDF file containing saved InferenceData
- Returns:
Reconstructed MLEInferenceRes object with all analysis capabilities
- Return type:
MLEInferenceRes
This class method enables loading of previously saved analysis results, preserving all computed statistics and enabling continued analysis from where previous sessions left off.
- Example:
>>> # Load previously saved results
>>> mle_analysis = MLEInferenceRes.from_disk('saved_results.nc')
>>> # Continue analysis with full functionality
>>> dashboard = mle_analysis.run_ppc()
- plot_observed_quantiles(
- *,
- use_ranks: bool,
- display: Literal[True],
- width: custom_types.Integer,
- height: custom_types.Integer,
- windowsize: 'custom_types.Integer' | None,
- )
- plot_observed_quantiles(
- *,
- use_ranks: bool,
- display: Literal[False],
- width: custom_types.Integer,
- height: custom_types.Integer,
- windowsize: 'custom_types.Integer' | None,
- )
Visualize systematic patterns in observed data quantiles.
This method creates hexagonal density plots showing the relationship between observed data values (or their ranks) and their corresponding quantiles within the posterior predictive distribution. A rolling mean overlay highlights systematic trends.
- Parameters:
use_ranks (bool) – Whether to use ranks instead of raw values for x-axis. Defaults to True.
display (bool) – Whether to return formatted layout for display. Defaults to True.
width (custom_types.Integer) – Width of individual plots in pixels. Defaults to 600.
height (custom_types.Integer) – Height of individual plots in pixels. Defaults to 400.
windowsize (Optional[custom_types.Integer]) – Size of rolling window for trend line. Defaults to None (automatic).
- Returns:
Quantile plots in requested format
- Return type:
Union[hv.Layout, dict[str, hv.Overlay]]
- Visualization Components:
Hexagonal binning showing density of (value, quantile) pairs
Rolling mean trend line highlighting systematic patterns
Colormap indicating point density for pattern identification
- Pattern Interpretation:
Horizontal trend line around 0.5 with uniformly distributed points indicates good calibration
Systematic deviations suggest model bias or miscalibration
The hexagonal binning is particularly effective for visualizing large datasets where individual points would create overplotting issues.
- Example:
>>> # Standard quantile analysis
>>> quant_layout = mle_analysis.plot_observed_quantiles()
>>> # Custom window size for trend analysis
>>> quant_plots = mle_analysis.plot_observed_quantiles(
...     windowsize=50, use_ranks=False, display=False
... )
- plot_posterior_predictive_samples(
- *,
- quantiles: Sequence['custom_types.Float'],
- use_ranks: bool,
- logy: bool,
- display: Literal[True],
- width: custom_types.Integer,
- height: custom_types.Integer,
- )
- plot_posterior_predictive_samples(
- *,
- quantiles: Sequence['custom_types.Float'],
- use_ranks: bool,
- logy: bool,
- display: Literal[False],
- width: custom_types.Integer,
- height: custom_types.Integer,
- )
Visualize observed data against posterior predictive uncertainty intervals.
This method creates plots showing how observed data relates to the uncertainty quantified by posterior predictive samples. The posterior predictive samples are displayed as confidence intervals, with observed data overlaid as points.
- Parameters:
quantiles (Sequence[custom_types.Float]) – Quantiles defining confidence intervals. Defaults to (0.025, 0.25, 0.5). Note: quantiles are automatically symmetrized and the median is always included.
use_ranks (bool) – Whether to use ranks instead of raw values for x-axis. Defaults to True.
logy (bool) – Whether to use logarithmic y-axis scaling. Defaults to False.
display (bool) – Whether to return formatted layout for display. Defaults to True.
width (custom_types.Integer) – Width of individual plots in pixels. Defaults to 600.
height (custom_types.Integer) – Height of individual plots in pixels. Defaults to 400.
- Returns:
Posterior predictive plots in requested format
- Return type:
Union[hv.Layout, dict[str, hv.Overlay]]
- Visualization Features:
Confidence intervals shown as nested colored regions
Observed data displayed as scatter points
Optional rank transformation for better visualization of skewed data
Logarithmic scaling with automatic shifting for non-positive values
Interactive hover labels showing data point identifiers
The rank transformation is particularly useful when observed values have highly skewed distributions, as it emphasizes the ordering rather than the absolute magnitudes.
- Example:
>>> # Standard posterior predictive plot
>>> pp_layout = mle_analysis.plot_posterior_predictive_samples()
>>> # Custom quantiles with logarithmic scaling
>>> pp_plots = mle_analysis.plot_posterior_predictive_samples(
...     quantiles=(0.05, 0.5, 0.95), logy=True, display=False
... )
- run_ppc(
- *,
- use_ranks: bool,
- display: Literal[True],
- square_ecdf: bool,
- windowsize: 'custom_types.Integer' | None,
- quantiles: Sequence['custom_types.Float'],
- logy_ppc_samples: bool,
- subplot_width: custom_types.Integer,
- subplot_height: custom_types.Integer,
- )
- run_ppc(
- *,
- use_ranks: bool,
- display: Literal[False],
- square_ecdf: bool,
- windowsize: 'custom_types.Integer' | None,
- quantiles: Sequence['custom_types.Float'],
- logy_ppc_samples: bool,
- subplot_width: custom_types.Integer,
- subplot_height: custom_types.Integer,
- )
Execute comprehensive posterior predictive checking analysis.
This method provides a complete posterior predictive checking workflow by combining multiple diagnostic approaches into a unified analysis. It runs plot_posterior_predictive_samples(), plot_observed_quantiles(), and check_calibration(), combining their outputs into either an interactive dashboard or a list of individual plot dictionaries.
- Parameters:
use_ranks (bool) – Whether to use ranks instead of raw values for x-axes. Defaults to True.
display (bool) – Whether to return interactive dashboard layout. Defaults to True.
square_ecdf (bool) – Whether to make ECDF plots square (width=height). Defaults to True.
windowsize (Optional[custom_types.Integer]) – Size of rolling window for trend analysis. Defaults to None (automatic).
quantiles (Sequence[custom_types.Float]) – Quantiles for confidence intervals. Defaults to (0.025, 0.25, 0.5).
logy_ppc_samples (bool) – Whether to use log scale for posterior predictive plots. Defaults to False.
subplot_width (custom_types.Integer) – Width of individual subplots in pixels. Defaults to 600.
subplot_height (custom_types.Integer) – Height of individual subplots in pixels. Defaults to 400.
- Returns:
Interactive dashboard or list of plot dictionaries
- Return type:
Union[pn.Column, list[dict[str, hv.Overlay]]]
- Dashboard Features:
Interactive variable selection across all diagnostic types
Consistent formatting and scaling across related plots
Automatic layout optimization for comparison and analysis
Widget-based navigation for multi-variable models
Together, the three generated plots provide a holistic view of model performance in terms of:
Predictive accuracy: How well do predictions match observations?
Calibration quality: Are prediction intervals properly calibrated?
Systematic bias: Are there patterns indicating model inadequacy?
- Example:
>>> # Complete interactive analysis
>>> dashboard = mle_analysis.run_ppc()
>>> dashboard  # Display in notebook
>>>
>>> # Programmatic access to individual components
>>> ppc_plots, quant_plots, cal_plots = mle_analysis.run_ppc(display=False)
- save_netcdf(filename: str) → None [source]¶
Save the ArviZ InferenceData object to NetCDF format.
- Parameters:
filename (str) – Path at which to save the NetCDF file
This method provides persistent storage of analysis results.
- Example:
>>> mle_analysis.save_netcdf('my_mle_results.nc')
>>> # Later: reload with MLEInferenceRes('my_mle_results.nc')