# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Hamiltonian Monte Carlo (HMC) sampling results analysis and diagnostics.

This module provides tools for analyzing and diagnosing HMC sampling results from
Stan models. It offers specialized classes and functions for processing
MCMC output, conducting diagnostic tests, and creating interactive visualizations
for model validation and troubleshooting.

The module centers around the :py:class:`~scistanpy.model.results.hmc.SampleResults`
class, which extends :py:class:`~scistanpy.model.results.mle.MLEInferenceRes` to provide
HMC-specific functionality including convergence diagnostics, sample quality assessment,
and specialized visualization tools for identifying problematic parameters and sampling
behavior.

Key Features:
    - MCMC diagnostic test suites
    - Interactive visualization tools for failed diagnostics
    - Efficient CSV to NetCDF conversion for large datasets
    - Dask-enabled processing for memory-intensive operations
    - Specialized trace plot analysis for problematic variables
    - Automated detection and reporting of sampling issues

Diagnostic Capabilities:
    - R-hat convergence assessment
    - Effective sample size (ESS) evaluation
    - Energy-based fraction of missing information (E-BFMI) analysis
    - Divergence detection and analysis
    - Tree depth saturation monitoring
    - Variable-specific failure pattern identification

The module is designed to handle both small-scale interactive analysis and
large-scale batch processing of MCMC results, with particular attention to
memory efficiency and computational performance for complex models.
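
Example:
    A typical end-to-end workflow (illustrative only; ``model`` and
    ``observed_data`` are placeholders for a user's model and data):

    >>> mcmc_results = model.mcmc(data=observed_data, chains=4, iter_sampling=2000)
    >>> sample_failures, variable_failures = mcmc_results.diagnose()
    >>> mcmc_results.plot_variable_failure_quantile_traces()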
"""

from __future__ import annotations

import itertools
import os.path
import re
import warnings

from glob import glob
from typing import (
    Any,
    Generator,
    Literal,
    Optional,
    overload,
    Sequence,
    TYPE_CHECKING,
    Union,
)

import arviz as az
import dask
import holoviews as hv
import h5netcdf
import numpy as np
import numpy.typing as npt
import panel as pn
import xarray as xr

from cmdstanpy.cmdstan_args import CmdStanArgs, SamplerArgs
from cmdstanpy.stanfit import CmdStanMCMC, RunSet
from cmdstanpy.utils import check_sampler_csv, scan_config
from tqdm import tqdm

from scistanpy import plotting, utils
from scistanpy.defaults import (
    DEFAULT_EBFMI_THRESH,
    DEFAULT_ESS_THRESH,
    DEFAULT_RHAT_THRESH,
)
from scistanpy.model.components import parameters
from scistanpy.model.components.transformations import transformed_parameters
from scistanpy.model.results import mle

if TYPE_CHECKING:
    from scistanpy import custom_types, Model

# pylint: disable=too-many-lines

# Maps between the precision of the data and the numpy types
_NP_TYPE_MAP = {
    "double": {"float": np.float64, "int": np.int64},
    "single": {"float": np.float32, "int": np.int32},
    "half": {"float": np.float16, "int": np.int16},
}
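# For example, _NP_TYPE_MAP["single"]["float"] is np.float32 and
# _NP_TYPE_MAP["half"]["int"] is np.int16.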


def _symmetrize_quantiles(
    quantiles: Sequence[custom_types.Float],
) -> list[custom_types.Float]:
    """Symmetrize and validate quantile sequences for plotting.

    This utility function takes a sequence of quantiles and creates a symmetric
    set by adding complementary quantiles and ensuring the median is included.
    It also validates that all quantiles are properly bounded.

    :param quantiles: Sequence of quantile values between 0 and 1
    :type quantiles: Sequence[custom_types.Float]

    :returns: Symmetrized and sorted list of quantiles including median
    :rtype: list[custom_types.Float]

    :raises ValueError: If quantiles are not between 0 and 1
    :raises AssertionError: If result doesn't have odd length or include median

    The function ensures:
        - All quantiles are between 0 and 1 (exclusive)
        - Complementary quantiles are added (e.g., 0.1 → 0.1, 0.9)
        - Median (0.5) is always included
        - Result has odd length for symmetric confidence intervals

    Example:
        >>> quantiles = _symmetrize_quantiles([0.1, 0.2])
        >>> # Returns [0.1, 0.2, 0.5, 0.8, 0.9]
    """
    # Get the quantiles
    quantiles = sorted(set(quantiles) | {1 - q for q in quantiles} | {0.5})

    # Check that the quantiles are between 0 and 1
    if not all(0 < q < 1 for q in quantiles):
        raise ValueError(
            "Quantiles must be between 0 and 1. Please provide a valid list of "
            "quantiles."
        )

    # Check that the quantiles are odd in number and include 0.5
    assert len(quantiles) % 2 == 1, "Quantiles must be odd in number"
    median_ind = len(quantiles) // 2
    assert quantiles[median_ind] == 0.5, "Quantiles must include 0.5"

    return quantiles


class VariableAnalyzer:
    """Interactive analysis tool for variables that fail MCMC diagnostic tests.

    This class provides an interactive interface for analyzing individual
    variables that have failed diagnostic tests during MCMC sampling. It
    creates a dashboard with widgets for selecting variables, metrics, and
    specific array indices, along with trace plots showing the problematic
    sampling behavior.

    :param sample_results: SampleResults object containing MCMC diagnostics
    :type sample_results: SampleResults
    :param plot_width: Width of plots in pixels. Defaults to 800.
    :type plot_width: custom_types.Integer
    :param plot_height: Height of plots in pixels. Defaults to 400.
    :type plot_height: custom_types.Integer
    :param plot_quantiles: Whether to plot quantiles vs raw values. Defaults
        to False.
    :type plot_quantiles: bool

    :ivar sample_results: Reference to source sampling results
    :ivar plot_quantiles: Flag controlling plot content type
    :ivar n_chains: Number of MCMC chains in the results
    :ivar x: Array of step indices for x-axis
    :ivar failed_vars: Dictionary mapping variable names to failure information
    :ivar varchoice: Widget for selecting variables to analyze
    :ivar metricchoice: Widget for selecting diagnostic metrics
    :ivar indexchoice: Widget for selecting array indices
    :ivar plot_width: Recorded width of plots
    :ivar plot_height: Recorded height of plots
    :ivar fig: HoloViews pane containing the current plot
    :ivar layout: Panel layout containing all interface elements

    The analyzer automatically identifies variables that have failed diagnostic
    tests and organizes them by failure type. It provides trace plots that can
    show either raw parameter values or their quantiles relative to passing
    samples, helping identify the nature of sampling problems.

    Key Features:
        - Automatic identification of failed variables and metrics
        - Interactive widget-based navigation
        - Trace plots with chain-specific coloring
        - Quantile-based analysis for identifying sampling bias
        - Real-time plot updates based on widget selections

    .. note::
        This class should not be instantiated directly. Use the
        :py:meth:`~scistanpy.model.results.hmc.SampleResults.plot_variable_failure_quantile_traces`
        method of :py:class:`~scistanpy.model.results.hmc.SampleResults` instead.
    """

    # pylint: disable=attribute-defined-outside-init
    def __init__(
        self,
        sample_results: "SampleResults",
        plot_width: custom_types.Integer = 800,
        plot_height: custom_types.Integer = 400,
        plot_quantiles: bool = False,
    ):
        # Hold a reference to the sample results
        self.sample_results = sample_results
        self.plot_quantiles = plot_quantiles

        # A placeholder used to decide whether or not we bother updating the plot
        self._previous_vals = None

        # Record the number of chains and an array for the steps
        self.n_chains = sample_results.inference_obj.posterior.sizes["chain"]
        self.x = np.arange(sample_results.inference_obj.posterior.sizes["draw"])

        # Identify failed variables
        self.failed_vars = {}
        self._identify_failed_vars()

        # Set up the holoviews plot
        self.plot_width = plot_width
        self.plot_height = plot_height
        self.fig = pn.pane.HoloViews(
            hv.Curve([]).opts(width=self.plot_width, height=self.plot_height),
            name="Plot",
            align="center",
        )

        # Set up widgets
        self.varchoice = pn.widgets.Select(
            name="Variable", options=list(self.failed_vars.keys())
        )
        self.metricchoice = pn.widgets.Select(name="Metric", options=[])
        self.indexchoice = pn.widgets.Select(name="Index", options=[])
        self.varchoice.param.watch(self._get_var_data, "value")
        self.varchoice.param.watch(self._update_metric_selector, "value")
        self.varchoice.param.watch(self._get_metric_data, "value")
        self.varchoice.param.watch(self._update_index_selector, "value")
        self.varchoice.param.watch(self._update_plot, "value")
        self.metricchoice.param.watch(self._get_metric_data, "value")
        self.metricchoice.param.watch(self._update_index_selector, "value")
        self.metricchoice.param.watch(self._update_plot, "value")
        self.indexchoice.param.watch(self._update_plot, "value")

        # Package widgets and figure into a layout
        self.layout = pn.Column(
            self.varchoice,
            self.metricchoice,
            self.indexchoice,
            self.fig,
        )

        # Trigger the initial data retrieval and plotting
        self.varchoice.param.trigger("value")

    def _identify_failed_vars(self):
        """Identify variables that failed diagnostic tests.

        This method analyzes the variable diagnostic tests in the SampleResults
        object to identify which variables failed which tests, organizing the
        information for easy access by the widget interface.

        The method populates the failed_vars dictionary with variable names as
        keys and tuples containing dimension information and failure details
        as values.
        """
        # Identify both the variables that fail and their indices
        for (
            varname,
            vartests,
        ) in self.sample_results.inference_obj.variable_diagnostic_tests.items():

            # The first dimension is the metric
            assert vartests.dims[0] == "metric"

            # Process each metric for this variable
            metric_test_summaries = {}
            for metric in vartests.metric:

                # Get the test results for the metric
                metrictests = vartests.sel(metric=metric).to_numpy()

                # Get the indices of the failing tests
                if metrictests.ndim > 0:
                    failing_inds = [
                        ".".join(map(str, indices))
                        for indices in zip(*np.nonzero(metrictests))
                    ]
                else:
                    failing_inds = [""] if metrictests else []  # Just variable name

                # Record the failing tests
                if len(failing_inds) > 0:
                    metric_test_summaries[metric.item()] = failing_inds

            # If there are any failing tests, add them to the dictionary
            if len(metric_test_summaries) > 0:
                self.failed_vars[varname] = (vartests.dims[1:], metric_test_summaries)

    def _update_metric_selector(self, event):  # pylint: disable=unused-argument
        """Update metric selection options based on selected variable.

        :param event: Panel event object (required for callback interface)

        This callback method updates the available metric options when a new
        variable is selected, ensuring only relevant diagnostic metrics are
        shown for the current variable.
        """
        # Get the currently selected metric name
        current_name = self.metricchoice.value

        # Update the metric choice options
        self.metricchoice.options = list(
            self.failed_vars[self.varchoice.value][1].keys()
        )

        # If the currently selected metric is not in the new options, set it to
        # the first one
        if current_name not in self.metricchoice.options:
            self.metricchoice.value = self.metricchoice.options[0]

    def _update_index_selector(self, event):  # pylint: disable=unused-argument
        """Update index selection options based on selected variable and metric.

        :param event: Panel event object (required for callback interface)

        This callback method updates the available index options when a new
        variable or metric is selected, showing only the array indices that
        failed the selected diagnostic test.
        """
        # Get the currently selected index name
        current_name = self.indexchoice.value

        # Update the index choice options
        opts = self.failed_vars[self.varchoice.value][1]
        if len(opts) == 0:
            self.indexchoice.options = []
            self.indexchoice.value = None
            return
        self.indexchoice.options = opts[self.metricchoice.value]

        # If the currently selected index is not in the new options, set it to
        # the first one
        if current_name not in self.indexchoice.options:
            self.indexchoice.value = self.indexchoice.options[0]

    def _get_var_data(self, event):  # pylint: disable=unused-argument
        """Retrieve and prepare data for the selected variable.

        :param event: Panel event object (required for callback interface)

        This callback method loads the sample data for the currently selected
        variable and prepares it for analysis and visualization, including
        dimension validation and array reshaping.
        """
        # Get the samples for the selected variable
        self._samples = getattr(
            self.sample_results.inference_obj.posterior, self.varchoice.value
        )

        # Check the dimensions of the variable
        assert self._samples.dims[:2] == ("chain", "draw")
        assert self._samples.dims[2:] == self.failed_vars[self.varchoice.value][0]

        # We also want the samples as a numpy array with the draw and chain
        # dimensions last
        self._np_samples = np.moveaxis(self._samples.to_numpy(), [0, 1], [-2, -1])

    def _get_metric_data(self, event):  # pylint: disable=unused-argument
        """Retrieve and prepare diagnostic test data for the selected metric.

        :param event: Panel event object (required for callback interface)

        This callback method processes the diagnostic test results for the
        currently selected metric, separating failing and passing samples for
        comparative analysis.
        """
        # Get the tests for the selected variable and metric
        tests = self.sample_results.inference_obj.variable_diagnostic_tests[
            self.varchoice.value
        ].sel(metric=self.metricchoice.value)

        # Make sure the dimensions are correct
        assert tests.dims == self._samples.dims[2:]

        # Tests as numpy array
        tests = tests.to_numpy()

        # Map between index name and failing index. Separate the failing and
        # passing samples. Note the different approach for scalar variables
        # (no indices).
        if tests.ndim == 0:
            self._index_map = {"": ...}  # No indices, no map
            if tests:
                self._failing_samples = self._np_samples.copy()
                self._passing_samples = np.array([], dtype=self._np_samples.dtype)
            else:
                self._failing_samples = np.array([], dtype=self._np_samples.dtype)
                self._passing_samples = self._np_samples.copy()
        else:
            self._index_map = {
                ".".join(map(str, indices)): i
                for i, indices in enumerate(zip(*np.nonzero(tests)))
            }
            self._failing_samples, self._passing_samples = (
                self._np_samples[tests],
                self._np_samples[~tests],
            )

    def _update_plot(self, event):  # pylint: disable=unused-argument
        """Update the trace plot based on current widget selections.

        :param event: Panel event object (required for callback interface)

        This callback method generates new trace plots when widget selections
        change, creating overlays that show sampling traces for each chain with
        appropriate styling and hover information.
        """
        # Skip the update if the values haven't changed
        if self._previous_vals == (
            new_vals := (
                self.varchoice.value,
                self.metricchoice.value,
                self.indexchoice.value,
            )
        ):
            return
        self._previous_vals = new_vals

        # Get the variable name
        varname = self.varchoice.value
        if self.indexchoice.value is not None:
            varname += "." + self.indexchoice.value

        # Calculate the relative quantiles for the selected failing index
        # relative to the passing samples
        if self.indexchoice.value == "":  # For scalars
            failing_samples = self._failing_samples
        else:
            failing_samples = self._failing_samples[
                self._index_map[self.indexchoice.value]
            ]
        n_failing, n_passing = len(self._failing_samples), len(self._passing_samples)

        # We always calculate quantiles. If there are no reference samples,
        # however, quantiles are undefined. We issue a warning if there are
        # more failing samples than passing samples and the user is trying to
        # plot quantiles.
        if n_passing == 0:
            if self.plot_quantiles:
                raise ValueError(
                    f"No passing samples found for {self.varchoice.value}. Cannot "
                    "calculate quantiles."
                )
            failing_quantiles = np.full_like(failing_samples, np.nan)
        else:
            if n_failing > n_passing and self.plot_quantiles:
                warnings.warn(
                    "There are more failing samples than passing samples for "
                    f"{self.varchoice.value}. Consider plotting true values instead."
                )
            failing_quantiles = plotting.calculate_relative_quantiles(
                reference=self._passing_samples,
                observed=failing_samples[None],
            )[0]

        # We should have shape (n_chains, n_draws)
        assert (
            failing_samples.shape
            == failing_quantiles.shape
            == (self.n_chains, self.x.size)
        )

        # Build an overlay for the failing quantiles and use it to update the plot
        overlay_dict = {}
        for i in range(self.n_chains):
            if self.plot_quantiles:
                order = (failing_quantiles[i], failing_samples[i])
                order_vdim = ["Sample Quantile", "Sample Value"]
            else:
                order = (failing_samples[i], failing_quantiles[i])
                order_vdim = ["Sample Value", "Sample Quantile"]
            overlay_dict[i] = hv.Curve(
                (self.x, *order, n_passing),
                kdims=["Step"],
                vdims=[*order_vdim, "N Reference Variables"],
            ).opts(
                line_alpha=0.5,
                **({"ylim": (0, 1)} if self.plot_quantiles else {}),
            )

        # Create the overlay
        self.fig.object = hv.NdOverlay(overlay_dict, kdims="Chain").opts(
            title=f"{self.metricchoice.value}: {varname}",
            tools=["hover"],
            width=self.plot_width,
            height=self.plot_height,
        )

    def display(self):
        """Display the complete interactive analysis interface.

        :returns: Panel layout containing all widgets and plots
        :rtype: pn.Layout

        This method returns the complete interactive interface for display in
        Jupyter notebooks or Panel applications.
        """
        return self.layout
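

# Illustrative use of ``VariableAnalyzer`` (a sketch): instances are normally
# created for you by ``SampleResults.plot_variable_failure_quantile_traces()``.
# Assuming ``results`` is a diagnosed ``SampleResults`` object, direct
# construction would look like:
#
#   analyzer = VariableAnalyzer(results, plot_quantiles=True)
#   analyzer.display()  # returns the Panel layout for notebook display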


class CmdStanMCMCToNetCDFConverter:
    """Object responsible for converting CmdStan CSV output to NetCDF format.

    This class is used internally by the
    :py:func:`~scistanpy.model.results.hmc.cmdstan_csv_to_netcdf` function and
    should not be instantiated directly in most use cases.

    This class handles the conversion of CmdStan CSV output files to NetCDF
    format, providing efficient storage and access for large MCMC datasets. It
    properly organizes data into appropriate groups and handles dimension
    naming and chunking strategies.

    :param fit: CmdStanMCMC object or path to CSV files
    :type fit: Union[CmdStanMCMC, str, list[str], os.PathLike]
    :param model: SciStanPy model object for metadata extraction
    :type model: Model
    :param data: Optional observed data dictionary. Defaults to None.
    :type data: Optional[dict[str, Any]]

    :ivar fit: CmdStanMCMC object containing sampling results
    :ivar model: Reference to the original SciStanPy model
    :ivar data: Observed data used for model fitting
    :ivar config: Configuration dictionary from Stan sampling
    :ivar num_draws: Total number of draws including warmup if saved
    :ivar varname_to_column_order: Mapping from variables to csv column indices

    The converter handles:
        - Automatic detection of variable types and dimensions
        - Proper NetCDF group organization
        - Chunking strategies for large datasets
        - Data type optimization based on precision requirements
    """

    def __init__(
        self,
        fit: CmdStanMCMC | str | list[str] | os.PathLike,
        model: "Model",
        data: dict[str, Any] | None = None,
    ):
        """
        Initialization involves collecting information about the different
        variables in the fit object. This includes the names of the variables,
        their shapes, and their types. This information is used to create the
        HDF5 file.
        """
        # If `fit` is a string, we assume we need to load it from disk
        if isinstance(fit, str):
            fit = fit_from_csv_noload(fit)

        # The fit and model are stored as attributes
        self.fit = fit
        self.model = model
        self.data = data

        # Record the config object
        self.config = fit.metadata.cmdstan_config

        # The number of chains is per thread. We want the total number of chains.
        self.config["total_chains"] = len(fit.runset.csv_files)

        # How many samples are we expecting?
        self.num_draws = self.config["num_samples"] + (
            self.config["num_warmup"] if self.config["save_warmup"] else 0
        )

        # Argsort the columns for each variable such that the columns are in
        # row-major order. This is important for efficiently saving to the
        # HDF5 file.
        self.varname_to_column_order = self._get_c_order()

    def _get_c_order(self) -> dict[str, npt.NDArray[np.int64]]:
        """Determine optimal column ordering for efficient NetCDF storage.

        :returns: Dictionary mapping variable names to column order arrays
        :rtype: dict[str, npt.NDArray[np.int64]]

        This method analyzes the CSV column structure to determine the optimal
        ordering for writing multi-dimensional arrays to NetCDF format, ensuring
        efficient memory access patterns and proper array reconstruction.
        """
        # We need a regular expression for parsing indices out of variable names
        ind_re = re.compile(r"\[([0-9,]+)\]")

        # Get the indices of each column in the csv files. If there is no match,
        # then there are no indices and we return an empty tuple.
        column_indices = [
            (
                tuple(int(ind) - 1 for ind in match_obj.group(1).split(","))
                if (match_obj := ind_re.search(col))
                else ()
            )
            for col in self.fit.column_names
        ]

        # Now we assign the row-major argsort of indices to each variable
        varname_to_column_order = {}
        for varname, var in itertools.chain(
            self.fit.metadata.method_vars.items(), self.fit.metadata.stan_vars.items()
        ):

            # Slice out the indices for this variable
            var_inds = column_indices[var.start_idx : var.end_idx]

            # All indices must be unique
            assert len(set(var_inds)) == len(var_inds)

            # All indices should have the appropriate number of dimensions
            assert all(len(ind) == len(var.dimensions) for ind in var_inds)

            # All indices should fit within the dimensions of the variable
            for dimind, dimsize in enumerate(var.dimensions):
                assert all(ind[dimind] < dimsize for ind in var_inds)

            # Argsort the indices such that the last dimension changes fastest
            # (c-major)
            varname_to_column_order[varname] = np.array(
                sorted(range(len(var_inds)), key=var_inds.__getitem__)
            )

        return varname_to_column_order
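
    # For illustration: CmdStan writes the columns of a multi-dimensional
    # variable in column-major order (e.g., ``m[1,1], m[2,1], m[1,2], ...``
    # for a matrix ``m``), so the argsort above recovers row-major order.
    # With 0-based index tuples for a hypothetical 2x3 matrix:
    #
    #   >>> inds = [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
    #   >>> sorted(range(len(inds)), key=inds.__getitem__)
    #   [0, 2, 4, 1, 3, 5]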

    def write_netcdf(
        self,
        filename: str | None = None,
        precision: Literal["double", "single", "half"] = "single",
        mib_per_chunk: custom_types.Integer | None = None,
    ) -> str:
        """Write the converted data to NetCDF format.

        :param filename: Output filename. Auto-generated if None. Defaults to None.
        :type filename: Optional[str]
        :param precision: Numerical precision for arrays. Defaults to "single".
        :type precision: Literal["double", "single", "half"]
        :param mib_per_chunk: Memory limit per chunk in MiB. Defaults to None,
            meaning use Dask default.
        :type mib_per_chunk: Optional[custom_types.Integer]

        :returns: Path to the created NetCDF file
        :rtype: str

        This method orchestrates the complete conversion process:

        1. Creates NetCDF file with appropriate structure
        2. Sets up dimensions based on model and data characteristics
        3. Creates variables with optimal chunking strategies
        4. Populates data from CSV files with progress tracking

        The resulting NetCDF file contains properly organized groups for
        posterior samples, posterior predictive samples, sample statistics,
        and observed data.
        """
        # If no filename is provided, we create one based on the csv files
        filename = (
            filename
            or os.path.commonprefix(self.fit.runset.csv_files).rstrip("_") + ".nc"
        )

        # Get the data types for the method and stan variables
        method_var_dtypes = {
            "lp__": _NP_TYPE_MAP[precision]["float"],
            "accept_stat__": _NP_TYPE_MAP[precision]["float"],
            "stepsize__": _NP_TYPE_MAP[precision]["float"],
            "treedepth__": _NP_TYPE_MAP[precision]["int"],
            "n_leapfrog__": _NP_TYPE_MAP[precision]["int"],
            "divergent__": _NP_TYPE_MAP[precision]["int"],
            "energy__": _NP_TYPE_MAP[precision]["float"],
        }
        stan_var_dtypes, stan_var_dimnames = self._get_stan_var_dtypes_dimnames(
            precision
        )
        assert not set(stan_var_dtypes.keys()).intersection(
            set(method_var_dtypes.keys())
        ), "Stan variable names should not overlap with method variable names."

        # Create the HDF5 file
        with h5netcdf.File(filename, "w") as netcdf_file:

            # Write attributes to the file
            for attr in (
                "stan_version_major",
                "stan_version_minor",
                "stan_version_patch",
                "model",
                "start_datetime",
                "method",
                "num_samples",
                "num_warmup",
                "save_warmup",
                "max_depth",
                "num_chains",
                "data_file",
                "diagnostic_file",
                "seed",
                "sig_figs",
                "num_threads",
                "stanc_version",
            ):
                netcdf_file.attrs[attr] = self.fit.metadata.cmdstan_config[attr]

            # Set dimensions
            netcdf_file.dimensions = {
                "chain": self.config["total_chains"],
                "draw": self.num_draws,
                **{
                    dimname: dimsize
                    for varinfo in filter(
                        lambda x: len(x) > 0, stan_var_dimnames.values()
                    )
                    for dimname, dimsize in varinfo
                },
            }

            # We need a group for metadata, samples, posterior predictive checks,
            # observations, and transformed parameters.
            metadata_group = netcdf_file.create_group("sample_stats")
            sample_group = netcdf_file.create_group("posterior")
            ppc_group = netcdf_file.create_group("posterior_predictive")
            observed_group = netcdf_file.create_group("observed_data")

            # Create variables for each of the method variables. Build a mapping
            # from the variable name to the dataset object. We store all of the
            # method variables with a single chunk.
            varname_to_dset = {
                varname: metadata_group.create_variable(
                    name=varname,
                    dimensions=("chain", "draw"),
                    dtype=method_var_dtypes[varname],
                    chunks=(self.config["total_chains"], self.num_draws),
                )
                for varname in self.fit.metadata.method_vars.keys()
            }

            # Now we can create a dataset for each stan variable. We update
            # the mapping from the variable name to the dataset object
            for varname, stan_dtype in stan_var_dtypes.items():

                # Get the shape of the variable
                if len(shape_info := stan_var_dimnames[varname]) == 0:
                    named_shape, true_shape = (), ()
                else:
                    named_shape, true_shape = zip(*shape_info)

                # Calculate the chunk shape. We always hold the first two
                # dimensions frozen. This is because the first two dimensions
                # are what we are typically performing operations over.
                chunk_shape = utils.get_chunk_shape(
                    array_shape=(
                        self.config["total_chains"],
                        self.num_draws,
                        *true_shape,
                    ),
                    array_precision=precision,
                    mib_per_chunk=mib_per_chunk,
                    frozen_dims=(0, 1),
                )

                # We record without the '_ppc' suffix
                recorded_varname = varname.removesuffix("_ppc")

                # Build the group
                group = ppc_group if varname.endswith("_ppc") else sample_group
                varname_to_dset[varname] = group.create_variable(
                    name=recorded_varname,
                    dimensions=("chain", "draw", *named_shape),
                    dtype=stan_dtype,
                    chunks=chunk_shape,
                )

                # If an observable, also create a dataset in the observed group
                # and populate it with the data
                if varname.endswith("_ppc") and self.data is not None:
                    observed_group.create_variable(
                        name=recorded_varname,
                        data=self.data[recorded_varname].squeeze(),
                        dimensions=named_shape,
                        dtype=stan_dtype,
                        chunks=chunk_shape[2:],
                    )

            # Now we populate the datasets with the data from the csv files
            for chain_ind, csv_file in enumerate(
                tqdm(
                    sorted(self.fit.runset.csv_files),
                    desc="Converting CSV to NetCDF",
                )
            ):
                for draw_ind, draw in enumerate(
                    tqdm(
                        self._parse_csv(
                            filename=csv_file,
                            method_var_dtypes=method_var_dtypes,
                            stan_var_dtypes=stan_var_dtypes,
                        ),
                        total=self.num_draws,
                        desc=f"Processing chain {chain_ind + 1}",
                        leave=False,
                        position=1,
                    )
                ):
                    for varname, varvals in draw.items():
                        varname_to_dset[varname][chain_ind, draw_ind] = (
                            varvals.squeeze()
                        )

                # We must have all the draws for this chain
                assert draw_ind == self.num_draws - 1  # pylint: disable=W0631

        return filename

    def _get_stan_var_dtypes_dimnames(
        self, precision: Literal["double", "single", "half"]
    ) -> tuple[
        dict[str, Union[type[np.floating], type[np.integer]]],
        dict[str, tuple[tuple[str, int], ...]],
    ]:
        """Determine data types and dimension names for Stan variables.

        :param precision: Numerical precision specification
        :type precision: Literal["double", "single", "half"]

        :returns: Tuple of (data_types_dict, dimension_names_dict)
        :rtype: tuple[dict[str, Union[type[np.floating], type[np.integer]]],
            dict[str, tuple[tuple[str, int], ...]]]

        This method analyzes the SciStanPy model to determine appropriate
        NumPy data types and dimension naming schemes for all variables that
        will be stored in the NetCDF file.
        """

        def get_dimname() -> tuple[tuple[str, int], ...] | tuple[()]:
            """Retrieves the dimension names for the current component."""
            # Get the name of the dimensions
            named_shape = []
            for dimind, dimsize in enumerate(component.shape[::-1]):

                # See if we can get the name of the dimension. If we cannot,
                # this must be a singleton dimension
                if (dimname := dim_map.get((dimind, dimsize))) is None:
                    assert dimsize == 1
                    continue

                # If we have a name, record
                named_shape.append((dimname, dimsize))

            # If we have no dimensions, we return an empty tuple
            if len(named_shape) == 0:
                return ()

            # We have our named shape
            return tuple(named_shape[::-1])

        # We will need the map from dimension depth and size to dimension name
        dim_map = self.model.get_dimname_map()

        # Datatypes for the stan variables
        stan_var_dtypes = {}
        stan_var_dimnames = {}
        for varname, component in self.model.named_model_components_dict.items():

            # We only take parameters and transformed parameters
            if not isinstance(
                component,
                (parameters.Parameter, transformed_parameters.TransformedParameter),
            ):
                continue

            # Update the varname if needed
            if isinstance(component, parameters.Parameter) and component.observable:
                varname = f"{varname}_ppc"

            # Record the datatype
            stan_var_dtypes[varname] = _NP_TYPE_MAP[precision][
                (
                    "int"
                    if isinstance(component, parameters.DiscreteDistribution)
                    else "float"
                )
            ]

            # Record the dimension names
            stan_var_dimnames[varname] = get_dimname()

        return stan_var_dtypes, stan_var_dimnames

    def _parse_csv(
        self,
        filename: str,
        method_var_dtypes: dict[str, Union[type[np.floating], type[np.integer]]],
        stan_var_dtypes: dict[str, Union[type[np.floating], type[np.integer]]],
    ) -> Generator[dict[str, npt.NDArray], None, None]:
        """Parse CSV file and yield properly formatted arrays.

        :param filename: Path to CSV file to parse
        :type filename: str
        :param method_var_dtypes: Data types for method variables
        :type method_var_dtypes: dict[str, Union[type[np.floating], type[np.integer]]]
        :param stan_var_dtypes: Data types for Stan variables
        :type stan_var_dtypes: dict[str, Union[type[np.floating], type[np.integer]]]

        :yields: Dictionary of variable names to properly shaped arrays for
            each draw
        :rtype: Generator[dict[str, npt.NDArray], None, None]

        This generator function parses CSV files line by line, converting each
        row into properly typed and shaped NumPy arrays according to the
        variable specifications determined during initialization.
        """
        # Start parsing the file line by line
        with open(filename, "r", encoding="utf-8") as csv_file:
            for line in csv_file:

                # Skip the header information. This was parsed by the fit object.
                if line.startswith("#") or line.startswith("lp__"):
                    continue

                # Split the components of the line
                vals = line.strip().split(",")

                # Build the arrays for each variable
                processed_vals = {}
                for varname, dtype in itertools.chain(
                    method_var_dtypes.items(), stan_var_dtypes.items()
                ):

                    # Get the variable object from the metadata
                    var = getattr(
                        self.fit.metadata,
                        "stan_vars" if varname in stan_var_dtypes else "method_vars",
                    )[varname]

                    # Using that variable object, slice out the data, convert to
                    # an appropriately typed numpy array, reorder it such that it
                    # is a flattened row-major array, then reshape it to the final
                    # shape. Note that we convert to a float first followed by a
                    # conversion to the final type. This is because some numbers
                    # are stored in the CSV as scientific notation.
                    processed_val = np.array(
                        [float(val) for val in vals[var.start_idx : var.end_idx]]
                    )[self.varname_to_column_order[varname]].reshape(var.dimensions)
                    if issubclass(dtype, np.integer):
                        processed_val = np.rint(processed_val)
                    processed_val = processed_val.astype(dtype, order="C")

                    # Confirm that the array is c-contiguous and record
                    assert processed_val.flags["C_CONTIGUOUS"]
                    processed_vals[varname] = processed_val

                # Yield the processed values for this line
                yield processed_vals
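

# Illustrative use of the converter (a sketch; ``model`` and the CSV path are
# placeholders, and in practice the ``cmdstan_csv_to_netcdf`` wrapper below is
# the intended entry point):
#
#   converter = CmdStanMCMCToNetCDFConverter(fit="./run_1.csv", model=model)
#   netcdf_path = converter.write_netcdf(precision="single")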


def cmdstan_csv_to_netcdf(
    path: str | list[str] | os.PathLike | CmdStanMCMC,
    model: "Model",
    data: dict[str, Any] | None = None,
    output_filename: str | None = None,
    precision: Literal["double", "single", "half"] = "single",
    mib_per_chunk: custom_types.Integer | None = None,
) -> str:
    """Convert CmdStan CSV output to NetCDF format.

    This function provides a high-level interface for converting CmdStan
    sampling results from CSV format to NetCDF, enabling efficient storage and
    processing of large MCMC datasets.

    :param path: Path to CSV files or CmdStanMCMC object
    :type path: Union[str, list[str], os.PathLike, CmdStanMCMC]
    :param model: SciStanPy model used for sampling
    :type model: Model
    :param data: Observed data dictionary. Uses model default if None.
        Defaults to None.
    :type data: Optional[dict[str, Any]]
    :param output_filename: Output NetCDF filename. Auto-generated if None.
        Defaults to None.
    :type output_filename: Optional[str]
    :param precision: Numerical precision for stored arrays. Defaults to "single".
    :type precision: Literal["double", "single", "half"]
    :param mib_per_chunk: Memory limit per chunk in MiB. Defaults to None,
        meaning use Dask default.
    :type mib_per_chunk: Optional[custom_types.Integer]

    :returns: Path to created NetCDF file
    :rtype: str

    The conversion process:

    1. Analyzes model structure to determine optimal storage layout
    2. Creates NetCDF file with appropriate groups and dimensions
    3. Converts CSV data with proper chunking for memory efficiency
    4. Organizes results into ArviZ-compatible structure

    Benefits of NetCDF format:
        - Significantly faster loading compared to CSV
        - Memory-efficient access with chunking support
        - Metadata preservation and self-describing format
        - Integration with scientific Python ecosystem

    Example:
        >>> netcdf_path = cmdstan_csv_to_netcdf(
        ...     'model_output*.csv', model, precision='single'
        ... )
        >>> results = SampleResults.from_disk(netcdf_path)
    """
    # If no data, check for default data in the model. Otherwise, data provided
    # takes priority
    if data is None and model.has_default_data:
        data = model.default_data

    # Build the converter
    converter = CmdStanMCMCToNetCDFConverter(fit=path, model=model, data=data)

    # Run conversion
    return converter.write_netcdf(
        filename=output_filename,
        precision=precision,
        mib_per_chunk=mib_per_chunk,
    )


def dask_enabled_summary_stats(inference_obj: az.InferenceData) -> xr.Dataset:
    """Compute summary statistics using Dask for memory efficiency.

    This is used inside the :py:meth:`SampleResults.calculate_summaries()
    <scistanpy.model.results.hmc.SampleResults.calculate_summaries>` method
    when Dask is enabled.

    :param inference_obj: ArviZ InferenceData object containing posterior samples
    :type inference_obj: az.InferenceData

    :returns: Dataset containing computed summary statistics
    :rtype: xr.Dataset

    This function computes basic summary statistics (mean, standard deviation,
    and highest density intervals) using Dask for memory-efficient computation
    on large datasets that might not fit in memory.

    The function leverages Dask's lazy evaluation to:
        - Queue multiple computations for efficient execution
        - Minimize memory usage through chunked processing
        - Provide progress tracking for long-running computations

    Computed Statistics:
        - Mean across chains and draws
        - Standard deviation across chains and draws
        - 94% highest density intervals

    Example:
        >>> stats = dask_enabled_summary_stats(inference_data)
        >>> print(stats.sel(metric='mean'))
    """
    # Queue up the delayed computations
    with utils.az_dask():
        delayed_summaries = [
            inference_obj.posterior.mean(dim=("chain", "draw")),
            inference_obj.posterior.std(dim=("chain", "draw")),
            az.hdi(
                inference_obj,
                hdi_prob=0.94,
                dask_gufunc_kwargs={"output_sizes": {"hdi": 2}},
            ),
        ]

        # Compute the results
        mean, std, hdi = dask.compute(*delayed_summaries)

    # Concatenate the results
    return xr.concat(
        [
            mean.assign_coords(metric=["mean"]),
            std.assign_coords(metric=["sd"]),
            hdi.assign_coords(hdi=["hdi_3%", "hdi_97%"]).rename(hdi="metric"),
        ],
        dim="metric",
    )


def dask_enabled_diagnostics(inference_obj: az.InferenceData) -> xr.Dataset:
    """Compute MCMC diagnostics using Dask for memory efficiency.

    This is used inside the :py:meth:`SampleResults.calculate_summaries()
    <scistanpy.model.results.hmc.SampleResults.calculate_summaries>` method
    when Dask is enabled.

    :param inference_obj: ArviZ InferenceData object containing posterior samples
    :type inference_obj: az.InferenceData

    :returns: Dataset containing computed diagnostic metrics
    :rtype: xr.Dataset

    This function computes comprehensive MCMC diagnostic metrics using Dask
    for memory-efficient computation on large datasets. All diagnostics are
    computed simultaneously to maximize efficiency.

    Computed Diagnostics:
        - Monte Carlo standard errors (mean and sd methods)
        - Effective sample sizes (bulk and tail)
        - R-hat convergence diagnostic

    The Dask implementation enables:
        - Parallel computation across available cores
        - Memory-efficient processing of large datasets
        - Automatic load balancing and optimization

    Example:
        >>> diagnostics = dask_enabled_diagnostics(inference_data)
        >>> print(diagnostics.sel(metric='r_hat'))
    """
    # Run computations
    with utils.az_dask():
        diagnostics = dask.compute(
            az.mcse(inference_obj.posterior, method="mean"),
            az.mcse(inference_obj.posterior, method="sd"),
            az.ess(inference_obj.posterior, method="bulk"),
            az.ess(inference_obj.posterior, method="tail"),
            az.rhat(inference_obj.posterior),
        )

    # Concatenate the results and return
    return xr.concat(
        [
            dset.assign_coords(metric=[metric])
            for metric, dset in zip(
                ["mcse_mean", "mcse_sd", "ess_bulk", "ess_tail", "r_hat"], diagnostics
            )
        ],
        dim="metric",
    )


class SampleResults(mle.MLEInferenceRes):
    """Comprehensive analysis interface for HMC sampling results.

    This class should never be instantiated directly. Instead, use the
    `from_disk` method to load the appropriate results object from disk.

    This class extends MLEInferenceRes to provide specialized functionality
    for analyzing Hamiltonian Monte Carlo sampling results from Stan. It
    offers comprehensive diagnostic capabilities, interactive visualization
    tools, and efficient data management for large MCMC datasets.

    :param model: SciStanPy model used for sampling. Defaults to None.
    :type model: Optional[Model]
    :param fit: CmdStanMCMC object or path to CSV files. Defaults to None.
    :type fit: Optional[Union[str, list[str], os.PathLike, CmdStanMCMC]]
    :param data: Observed data dictionary. Defaults to None.
    :type data: Optional[dict[str, npt.NDArray]]
    :param precision: Numerical precision for arrays. Defaults to "single".
    :type precision: Literal["double", "single", "half"]
    :param inference_obj: Pre-existing InferenceData or NetCDF path. Defaults
        to None.
    :type inference_obj: Optional[Union[az.InferenceData, str]]
    :param mib_per_chunk: Memory limit per chunk in MiB. Defaults to None.
    :type mib_per_chunk: Optional[custom_types.Integer]
    :param use_dask: Whether to use Dask for computation. Defaults to False.
    :type use_dask: bool

    :ivar fit: CmdStanMCMC object containing sampling metadata
    :ivar use_dask: Flag controlling Dask usage for computation

    The class provides comprehensive functionality for:
        - MCMC convergence diagnostics and reporting
        - Sample quality assessment and visualization
        - Interactive analysis of problematic variables
        - Efficient handling of large datasets with Dask integration
        - Automated detection and reporting of sampling issues

    Key Diagnostic Features:
        - R-hat convergence assessment
        - Effective sample size evaluation
        - Energy-based diagnostics (E-BFMI)
        - Divergence detection and analysis
        - Tree depth saturation monitoring

    The class automatically handles NetCDF conversion for efficient storage
    and supports both in-memory and out-of-core computation depending on
    dataset size and available memory.

    Example:
        .. code-block:: python

            import scistanpy as ssp

            # Get MCMC results
            mcmc_results = model.mcmc(data=observed_data, chains=4, iter_sampling=2000)

            # Run full diagnostics
            diagnostics = mcmc_results.diagnose()

            # Posterior predictive check (interactive in notebook)
            mcmc_results.run_ppc()

            # Evaluate problematic samples (interactive in notebook)
            mcmc_results.plot_sample_failure_quantile_traces()

            # Evaluate problematic variables (interactive in notebook)
            mcmc_results.plot_variable_failure_quantile_traces()
    """

    def __init__(
        self,
        model: Union["Model", None] = None,
        fit: str | list[str] | os.PathLike | CmdStanMCMC | None = None,
        data: dict[str, npt.NDArray] | None = None,
        precision: Literal["double", "single", "half"] = "single",
        inference_obj: Optional[az.InferenceData | str] = None,
        mib_per_chunk: custom_types.Integer | None = None,
        use_dask: bool = False,
    ):
        # Store the CmdStanMCMC object
        self.fit = fit

        # Note whether we are using dask
        self.use_dask = use_dask

        # If the inference object is None, we assume that we need to create a
        # NetCDF file from the CmdStanMCMC object.
        if inference_obj is None:

            # Compile results to a NetCDF file
            inference_obj = cmdstan_csv_to_netcdf(
                path=fit,
                model=model,
                data=data,
                precision=precision,
                mib_per_chunk=mib_per_chunk,
            )

        # If the inference object is a string, we assume that it is a NetCDF
        # file to be loaded from disk
        if isinstance(inference_obj, str):

            # Load the inference object. Ignore warnings about chunking.
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    category=UserWarning,
                    message="The specified chunks separate the stored chunks along dimension",
                )
                inference_obj = az.from_netcdf(
                    filename=inference_obj,
                    engine="h5netcdf",
                    group_kwargs={
                        k: {"chunks": "auto" if use_dask else None}
                        for k in ("posterior", "posterior_predictive", "sample_stats")
                    },
                )

        # Initialize the parent class
        super().__init__(inference_obj)
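
    # Illustrative construction paths (a sketch; ``model`` and file paths are
    # placeholders). Loading a previously converted NetCDF file:
    #
    #   results = SampleResults.from_disk("./samples.nc")
    #
    # Building from CmdStan CSV output (triggers NetCDF conversion first):
    #
    #   results = SampleResults(model=model, fit="./samples_1.csv")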

    def calculate_summaries(
        self,
        var_names: list[str] | None = None,
        filter_vars: Literal[None, "like", "regex"] = None,
        kind: Literal["all", "stats", "diagnostics"] = "all",
        round_to: custom_types.Integer = 2,
        circ_var_names: list[str] | None = None,
        stat_focus: str = "mean",
        stat_funcs: Optional[Union[dict[str, callable], callable]] = None,
        extend: bool = True,
        hdi_prob: custom_types.Float = 0.94,
        skipna: bool = False,
        diagnostic_varnames: Sequence[str] = (
            "mcse_mean",
            "mcse_sd",
            "ess_bulk",
            "ess_tail",
            "r_hat",
        ),
    ) -> xr.Dataset:
        """Compute comprehensive summary statistics and diagnostics for MCMC results.

        This method extends the parent class functionality to provide
        HMC-specific diagnostic capabilities, including automatic separation
        of statistics and diagnostics into appropriate InferenceData groups.
        See ``az.summary`` for more detail on arguments.

        :param var_names: Variable names to include. Defaults to None (all variables).
        :type var_names: Optional[list[str]]
        :param filter_vars: Variable filtering method. Defaults to None.
        :type filter_vars: Optional[Literal[None, "like", "regex"]]
        :param kind: Type of computations to perform. Defaults to "all".
        :type kind: Literal["all", "stats", "diagnostics"]
        :param round_to: Decimal places for rounding. Defaults to 2.
        :type round_to: custom_types.Integer
        :param circ_var_names: Names of circular variables. Defaults to None.
        :type circ_var_names: Optional[list[str]]
        :param stat_focus: Primary statistic for focus. Defaults to "mean".
        :type stat_focus: str
        :param stat_funcs: Custom statistic functions. Defaults to None.
        :type stat_funcs: Optional[Union[dict[str, callable], callable]]
        :param extend: Whether to include extended statistics. Defaults to True.
            Only meaningful if `stat_funcs` is not `None`.
        :type extend: bool
        :param hdi_prob: Probability for highest density interval. Defaults to 0.94.
        :type hdi_prob: custom_types.Float
        :param skipna: Whether to skip NaN values. Defaults to False.
        :type skipna: bool
        :param diagnostic_varnames: Names of diagnostic metrics. Defaults to
            ("mcse_mean", "mcse_sd", "ess_bulk", "ess_tail", "r_hat").
        :type diagnostic_varnames: Sequence[str]

        :returns: Combined dataset with all computed metrics
        :rtype: xr.Dataset

        Enhanced Features:
            - Automatic Dask acceleration for large datasets
            - Separation of statistics and diagnostics into appropriate groups
            - Memory-efficient computation strategies

        The method automatically updates the InferenceData object with new groups:
            - variable_summary_stats: Basic summary statistics
            - variable_diagnostic_stats: MCMC diagnostic metrics
        """
        # We use custom functions if we are using dask
        if self.use_dask:

            # Calculate the two datasets
            summary_stats = dask_enabled_summary_stats(self.inference_obj)
            diagnostics = dask_enabled_diagnostics(self.inference_obj)

            # Combine datasets to get the summaries
            if kind == "all":
                summaries = xr.concat([summary_stats, diagnostics], dim="metric")
            elif kind == "stats":
                summaries = summary_stats
            elif kind == "diagnostics":
                summaries = diagnostics

        # Otherwise, we use the default ArviZ functions
        else:

            # Run the inherited method to get the summary statistics
            summaries = super().calculate_summaries(
                var_names=var_names,
                filter_vars=filter_vars,
                kind=kind,
                round_to=round_to,
                circ_var_names=circ_var_names,
                stat_focus=stat_focus,
                stat_funcs=stat_funcs,
                extend=extend,
                hdi_prob=hdi_prob,
                skipna=skipna,
            )

            # Identify the diagnostic and summary statistics
            noted_diagnostics = set(diagnostic_varnames)
            calculated_metrics = set(summaries.metric.values.tolist())
            diagnostic_metrics = list(noted_diagnostics & calculated_metrics)
            stat_metrics = list(calculated_metrics - noted_diagnostics)
            summary_stats = summaries.sel(metric=stat_metrics)
            diagnostics = summaries.sel(metric=diagnostic_metrics)

        # Update the groups
        if kind == "all" or kind == "diagnostics":
            self._update_group("variable_diagnostic_stats", diagnostics)
        if kind == "all" or kind == "stats":
            self._update_group("variable_summary_stats", summary_stats)

        return summaries
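
    # Illustrative call (a sketch; argument semantics follow ``az.summary``):
    #
    #   summaries = results.calculate_summaries(kind="diagnostics")
    #
    # populates only the ``variable_diagnostic_stats`` group, while the
    # default ``kind="all"`` also populates ``variable_summary_stats``.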

    def calculate_diagnostics(self) -> xr.Dataset:
        """Shortcut for running
        :py:meth:`~scistanpy.model.results.hmc.SampleResults.calculate_summaries`
        with ``kind="diagnostics"`` and no other arguments.

        :returns: Dataset containing diagnostic metrics
        :rtype: xr.Dataset

        The method is designed as a simple interface for users who only need
        diagnostic information without summary statistics.
        """
        return self.calculate_summaries(kind="diagnostics")

    def evaluate_sample_stats(
        self,
        max_tree_depth: custom_types.Integer | None = None,
        ebfmi_thresh: custom_types.Float = DEFAULT_EBFMI_THRESH,
    ) -> xr.Dataset:
        """Evaluate sample-level diagnostic statistics for MCMC quality assessment.

        :param max_tree_depth: Maximum tree depth threshold. Uses model default
            if None. Defaults to None.
        :type max_tree_depth: Optional[custom_types.Integer]
        :param ebfmi_thresh: E-BFMI threshold for energy diagnostics. Defaults
            to 0.2.
        :type ebfmi_thresh: custom_types.Float

        :returns: Dataset with boolean arrays indicating test failures
        :rtype: xr.Dataset

        This method evaluates sample-level diagnostic statistics to identify
        problematic samples in the MCMC chains. Tests are considered failures
        when samples exhibit the following characteristics:

        - **Tree Depth**: Sample reached maximum tree depth (saturation)
        - **E-BFMI**: Energy-based fraction of missing information below threshold
        - **Divergence**: Sample diverged during Hamiltonian dynamics

        The resulting boolean arrays have ``True`` values indicating failed
        samples and ``False`` values indicating successful samples. This
        information is stored in the 'sample_diagnostic_tests' group of the
        InferenceData object.

        Example:
            >>> sample_tests = results.evaluate_sample_stats(ebfmi_thresh=0.15)
            >>> n_diverged = sample_tests.diverged.sum().item()
            >>> print(f"Number of divergent samples: {n_diverged}")
        """
        # If not provided, extract the maximum tree depth from the attributes
        if max_tree_depth is None:
            max_tree_depth = self.inference_obj.attrs["max_depth"].item()

        # Run all tests and build a dataset
        # pylint: disable=no-member
        sample_tests = xr.Dataset(
            {
                "low_ebfmi": self.inference_obj.sample_stats.energy__ < ebfmi_thresh,
                "max_tree_depth_reached": self.inference_obj.sample_stats.treedepth__
                == max_tree_depth,
                "diverged": self.inference_obj.sample_stats.divergent__ == 1,
            }
        )
        # pylint: enable=no-member

        # Add the new group to the ArviZ object
        self._update_group("sample_diagnostic_tests", sample_tests)

        return sample_tests

    def evaluate_variable_diagnostic_stats(
        self,
        r_hat_thresh: custom_types.Float = DEFAULT_RHAT_THRESH,
        ess_thresh: custom_types.Float = DEFAULT_ESS_THRESH,
    ) -> xr.Dataset:
        """Evaluate variable-level diagnostic statistics for convergence assessment.

        :param r_hat_thresh: R-hat threshold for convergence. Defaults to 1.01.
        :type r_hat_thresh: custom_types.Float
        :param ess_thresh: ESS threshold *per chain*. Defaults to 100.
        :type ess_thresh: custom_types.Float

        :returns: Dataset with boolean arrays indicating variable-level test failures
        :rtype: xr.Dataset

        :raises ValueError: If variable_diagnostic_stats group doesn't exist
        :raises ValueError: If required metrics are missing

        This method evaluates variable-level diagnostic statistics to identify
        parameters that exhibit poor sampling behavior. Tests are considered
        failures when variables meet the following criteria:

        Failure Conditions:
            - **R-hat**: Split R-hat statistic >= threshold (poor convergence)
            - **ESS Bulk**: Bulk effective sample size / n_chains <= threshold
              per chain
            - **ESS Tail**: Tail effective sample size / n_chains <= threshold
              per chain

        Results are stored in the 'variable_diagnostic_tests' group with
        boolean arrays indicating which variables failed which tests.

        Example:
            >>> var_tests = results.evaluate_variable_diagnostic_stats(r_hat_thresh=1.02)
            >>> failed_convergence = var_tests.sel(metric='r_hat').sum()
            >>> print(f"Variables with poor convergence: {failed_convergence.sum().item()}")
        """
        # We need to check if the `variable_diagnostic_stats` group exists. If
        # it doesn't, we need to run `calculate_diagnostics` first.
        if not hasattr(self.inference_obj, "variable_diagnostic_stats"):
            raise ValueError(
                "The `variable_diagnostic_stats` group does not exist. Please run "
                "`calculate_diagnostics` first."
            )

        # All metrics should be present in the `variable_diagnostic_stats` group.
        # pylint: disable=no-member
        if missing_metrics := (
            {"r_hat", "ess_bulk", "ess_tail"}
            - set(self.inference_obj.variable_diagnostic_stats.metric.values.tolist())
        ):
            raise ValueError(
                "The following metrics are missing from the `variable_diagnostic_stats` "
                f"group: {missing_metrics}."
            )

        # Update the ess threshold based on the number of chains
        ess_thresh *= self.inference_obj.posterior.sizes["chain"]

        # Run all tests and build a dataset
        variable_tests = xr.concat(
            [
                self.inference_obj.variable_diagnostic_stats.sel(metric="r_hat")
                >= r_hat_thresh,
                self.inference_obj.variable_diagnostic_stats.sel(metric="ess_bulk")
                <= ess_thresh,
                self.inference_obj.variable_diagnostic_stats.sel(metric="ess_tail")
                <= ess_thresh,
            ],
            dim="metric",
        )
        # pylint: enable=no-member

        # Add the new group to the ArviZ object
        self._update_group("variable_diagnostic_tests", variable_tests)

        return variable_tests

    def identify_failed_diagnostics(self, silent: bool = False) -> tuple[
        "custom_types.StrippedTestRes",
        dict[str, "custom_types.StrippedTestRes"],
    ]:
        """Identify and report diagnostic test failures with comprehensive summary.

        :param silent: Whether to suppress printed output. Defaults to False.
        :type silent: bool

        :returns: Tuple of (sample_failures, variable_failures) dictionaries
        :rtype: tuple[custom_types.StrippedTestRes,
            dict[str, custom_types.StrippedTestRes]]

        This method analyzes the results of diagnostic tests and provides both
        programmatic access to failure information and human-readable summaries.
        It requires that diagnostic evaluation methods have been run previously.

        Return Structure:
            - **sample_failures**: Dictionary mapping test names to arrays of
              failed sample indices
            - **variable_failures**: Dictionary mapping metric names to
              dictionaries of failed variables

        The method processes test results to extract:
            - Indices of samples that failed each diagnostic test
            - Names of variables that failed each diagnostic metric
            - Summary statistics showing failure rates and percentages

        When not silent, provides detailed reporting including:
            - Failure counts and percentages for each test type
            - Variable-specific failure information organized by metric
            - Clear categorization of sample vs. variable-level issues
        """

        def process_test_results(
            test_res_dataarray: xr.Dataset,
        ) -> "custom_types.ProcessedTestRes":
            """
            Process the test results from a DataArray into a dictionary of
            test results.

            Parameters
            ----------
            test_res_dataarray : xr.Dataset
                The DataArray containing the test results.

            Returns
            -------
            custom_types.ProcessedTestRes
                A dictionary where the keys are the variable names and the
                values are tuples containing the indices of the failed tests
                and the total number of tests.
            """
            return {
                varname: (np.atleast_1d(tests.values).nonzero(), tests.values.size)
                for varname, tests in test_res_dataarray.items()
            }

        def strip_totals(
            processed_test_results: "custom_types.ProcessedTestRes",
        ) -> "custom_types.StrippedTestRes":
            """
            Strip the totals from the test results.

            Parameters
            ----------
            processed_test_results : custom_types.ProcessedTestRes
                The processed test results from the `process_test_results`
                function.

            Returns
            -------
            custom_types.StrippedTestRes
                The processed test results with the totals stripped.
            """
            return {k: v[0] for k, v in processed_test_results.items()}

        def report_test_summary(
            processed_test_results: "custom_types.ProcessedTestRes",
            type_: str,
            prepend_newline: bool = True,
        ) -> None:
            """
            Report the summary of the test results.

            Parameters
            ----------
            processed_test_results : dict[str, tuple[tuple[npt.NDArray, ...], int]]
                The processed test results from the `process_test_results`
                function.
            type_ : str
                The type of test results (e.g., "sample", "variable").
            prepend_newline : bool, optional
                Whether to prepend a newline before the summary, by default True.
            """
            if prepend_newline:
                print()
            header = f"{type_.capitalize()} diagnostic tests results' summaries:"
            print(header)
            print("-" * len(header))
            for varname, (
                failed_indices,
                total_tests,
            ) in processed_test_results.items():
                n_failures = len(failed_indices[0])
                assert all(n_failures == len(failed) for failed in failed_indices[1:])
                print(
                    f"{n_failures} of {total_tests} ({n_failures / total_tests:.2%}) {type_}s "
                    f"{message_map.get(varname, f'tests failed for {varname}')}."
                )

        # Different messages for different test types
        message_map = {
            "low_ebfmi": "had a low energy",
            "max_tree_depth_reached": "reached the maximum tree depth",
            "diverged": "diverged",
        }

        # Get the indices of the sampling tests that failed and the total
        # number of tests performed
        # pylint: disable=no-member
        sample_test_failures = process_test_results(
            self.inference_obj.sample_diagnostic_tests
        )

        # Get the indices of the variable diagnostic tests that failed and the
        # total number of tests performed
        variable_test_failures = {
            metric.item(): process_test_results(
                self.inference_obj.variable_diagnostic_tests.sel(metric=metric.item())
            )
            for metric in self.inference_obj.variable_diagnostic_tests.metric
        }
        # pylint: enable=no-member

        # Strip the totals from the test results and package as return values
        res = (
            strip_totals(sample_test_failures),
            {
                metric: strip_totals(failures)
                for metric, failures in variable_test_failures.items()
            },
        )

        # If silent, return the test results now
        if silent:
            return res

        # Report sample test failures
        report_test_summary(sample_test_failures, "sample", prepend_newline=False)

        # Report variable test failures
        for metric, failures in variable_test_failures.items():
            report_test_summary(failures, metric)

        return res
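
    # Illustrative use (a sketch; assumes ``evaluate_sample_stats`` and
    # ``evaluate_variable_diagnostic_stats`` have been run):
    #
    #   sample_failures, variable_failures = results.identify_failed_diagnostics(
    #       silent=True
    #   )
    #   sample_failures["diverged"]  # indices of draws that diverged
    #   variable_failures["r_hat"]   # per-variable R-hat test failures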

    def diagnose(
        self,
        max_tree_depth: custom_types.Integer | None = None,
        ebfmi_thresh: custom_types.Float = DEFAULT_EBFMI_THRESH,
        r_hat_thresh: custom_types.Float = DEFAULT_RHAT_THRESH,
        ess_thresh: custom_types.Float = DEFAULT_ESS_THRESH,
        silent: bool = False,
    ) -> tuple[
        "custom_types.StrippedTestRes", dict[str, "custom_types.StrippedTestRes"]
    ]:
        """Runs the complete MCMC diagnostic pipeline. This involves running,
        in order:

        1. :py:meth:`~scistanpy.model.results.hmc.SampleResults.calculate_diagnostics`
        2. :py:meth:`~scistanpy.model.results.hmc.SampleResults.evaluate_sample_stats`
        3. :py:meth:`~scistanpy.model.results.hmc.SampleResults.evaluate_variable_diagnostic_stats`
        4. :py:meth:`~scistanpy.model.results.hmc.SampleResults.identify_failed_diagnostics`

        Typically, users will want to use this method rather than calling the
        individual methods themselves.

        :param max_tree_depth: Maximum tree depth threshold. Uses model default
            if None. Defaults to None.
        :type max_tree_depth: Optional[custom_types.Integer]
        :param ebfmi_thresh: E-BFMI threshold for energy diagnostics. Defaults
            to 0.2.
        :type ebfmi_thresh: custom_types.Float
        :param r_hat_thresh: R-hat threshold for convergence assessment.
            Defaults to 1.01.
        :type r_hat_thresh: custom_types.Float
        :param ess_thresh: ESS threshold per chain. Defaults to 100.
        :type ess_thresh: custom_types.Float
        :param silent: Whether to suppress diagnostic output. Defaults to False.
        :type silent: bool

        :returns: Tuple of (sample_failures, variable_failures) as returned by
            identify_failed_diagnostics
        :rtype: tuple[custom_types.StrippedTestRes,
            dict[str, custom_types.StrippedTestRes]]

        The method provides comprehensive assessment of MCMC sampling quality,
        identifying both immediate issues (e.g., divergences, energy problems)
        and convergence concerns (e.g., R-hat, effective sample size).

        All intermediate results are stored in the ``inference_obj`` attribute
        for later access and further analysis.
        """
        # Run the diagnostics
        self.calculate_diagnostics()
        self.evaluate_sample_stats(
            max_tree_depth=max_tree_depth, ebfmi_thresh=ebfmi_thresh
        )
        self.evaluate_variable_diagnostic_stats(
            r_hat_thresh=r_hat_thresh, ess_thresh=ess_thresh
        )

        # Identify the failed diagnostics
        return self.identify_failed_diagnostics(silent=silent)
    @overload
    def plot_sample_failure_quantile_traces(
        self,
        *,
        display: Literal[True] = ...,
        width: custom_types.Integer = ...,
        height: custom_types.Integer = ...,
    ) -> hv.Layout: ...

    @overload
    def plot_sample_failure_quantile_traces(
        self,
        *,
        display: Literal[False],
        width: custom_types.Integer = ...,
        height: custom_types.Integer = ...,
    ) -> dict[str, hv.Overlay]: ...
    def plot_sample_failure_quantile_traces(
        self, *, display=True, width=600, height=600
    ):
        """Visualize quantile traces for samples that failed diagnostic tests.

        :param display: Whether to return formatted layout for display.
            Defaults to True.
        :type display: bool
        :param width: Width of plots in pixels. Defaults to 600.
        :type width: custom_types.Integer
        :param height: Height of plots in pixels. Defaults to 600.
        :type height: custom_types.Integer

        :returns: Quantile trace plots in requested format
        :rtype: Union[hv.Layout, dict[str, hv.Overlay]]

        :raises ValueError: If no samples failed diagnostic tests

        This method creates specialized trace plots showing how samples that
        failed diagnostic tests compare to those that passed. The visualization
        helps identify systematic patterns in sampling failures.

        Plot Structure:
            - **X-axis**: Cumulative fraction of parameters (0 to 1, sorted by
              typical quantile of failed samples)
            - **Y-axis**: Quantiles of failed samples relative to passing samples
            - **Individual traces**: Semi-transparent lines for each failed sample
            - **Typical trace**: Bold line showing median behavior across failures
            - **Reference line**: Diagonal indicating perfect calibration

        The plots reveal:
            - Whether failures are systematic across parameters
            - Patterns in how failed samples deviate from typical behavior
            - The severity and consistency of sampling problems

        Example:
            >>> # Display interactive traces
            >>> results.plot_sample_failure_quantile_traces()
        """

        # x-axis labels are meaningless, so we will use a hook to hide them
        def hook(plot, element):  # pylint: disable=unused-argument
            plot.state.xaxis.major_tick_line_color = None
            plot.state.xaxis.minor_tick_line_color = None
            plot.state.xaxis.major_label_text_font_size = "0pt"

        # If there are no failed samples, raise an error
        # pylint: disable=no-member
        if not any(
            self.inference_obj.sample_diagnostic_tests.apply(lambda x: x.any()).values()
        ):
            raise ValueError(
                "No samples failed the diagnostic tests. This error is a good thing!"
            )

        # First, we need to get all samples and the diagnostic tests. We will
        # reshape both to be 2D, with the first dimension being the samples and
        # the second dimension being the parameter values or the test results,
        # respectively.
        sample_arr = (
            self.inference_obj.posterior.to_stacked_array(
                new_dim="vals", sample_dims=["chain", "draw"], variable_dim="parameter"
            )
            .stack(samples=("chain", "draw"))
            .T
        )
        sample_test_arr = self.inference_obj.sample_diagnostic_tests.apply(
            lambda x: x.stack(samples=("chain", "draw"))
        )
        # pylint: enable=no-member

        # Now we need some metadata about the samples, such as the chain and
        # draw indices and the parameter names
        varnames = np.array(
            [
                ".".join([str(i) for i in name_tuple if isinstance(i, (str, int))])
                for name_tuple in sample_arr.vals.to_numpy()
            ]
        )
        chains = sample_arr.coords["chain"].to_numpy()
        draws = sample_arr.coords["draw"].to_numpy()

        # We will need the sample array as a numpy array from here
        sample_arr = sample_arr.to_numpy()

        # x-values are just incrementally increasing from 0 to 1 for the number
        # of parameters
        x = np.linspace(0, 1, sample_arr.shape[1])

        # Now we process each of the diagnostic tests
        plots = {}
        for testname, testmask in sample_test_arr.items():

            # Get the failed samples. If there are no failed samples, skip this
            # test
            testmask = testmask.to_numpy()
            if not testmask.any():
                continue
            failed_samples, failed_chains, failed_draws, passed_samples = (
                sample_arr[testmask],
                chains[testmask],
                draws[testmask],
                sample_arr[~testmask],
            )

            # Get the quantiles of the failed samples relative to the passed ones
            failed_quantiles = plotting.calculate_relative_quantiles(
                passed_samples, failed_samples
            )

            # Get the typical quantiles of the failed samples
            typical_failed_quantiles = np.median(failed_quantiles, axis=0)

            # Sort samples by the values of the typical failed samples
            sorted_inds = np.argsort(typical_failed_quantiles)
            (
                failed_samples,
                failed_quantiles,
                typical_failed_quantiles,
                resorted_varnames,
            ) = (
                failed_samples[:, sorted_inds],
                failed_quantiles[:, sorted_inds],
                typical_failed_quantiles[sorted_inds],
                varnames[sorted_inds],
            )

            # Build the traces
            plots[testname] = hv.Overlay(
                [
                    hv.Curve(
                        (
                            x,
                            quantile,
                            resorted_varnames,
                            failed_chains[i],
                            failed_draws[i],
                            failed_samples[i],
                        ),
                        kdims=["Fraction of Parameters"],
                        vdims=["Quantile", "Parameter", "Chain", "Draw", "Value"],
                    ).opts(line_color="blue", line_alpha=0.1, tools=["hover"])
                    for i, quantile in enumerate(failed_quantiles)
                ]
                + [
                    hv.Curve(
                        (x, typical_failed_quantiles),
                        kdims=["Fraction of Parameters"],
                        vdims=["Quantile"],
                        label="Typical Failed Quantiles",
                    ).opts(line_color="black", line_width=1),
                    hv.Curve(
                        ((0, 1), (0, 1)),
                        kdims=["Fraction of Parameters"],
                        vdims=["Quantile"],
                        label="Idealized Quantiles",
                    ).opts(line_color="black", line_width=1, line_dash="dashed"),
                ]
            ).opts(
                hooks=[hook],
                title=f"Quantiles of Samples Failing: {testname}",
                width=width,
                height=height,
            )

        # If requested, display the plots
        if display:
            return hv.Layout(plots.values()).cols(1).opts(shared_axes=False)
        return plots
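    # Illustrative non-display usage (a sketch): with ``display=False`` the
    # method returns the raw overlays keyed by test name, which can be saved
    # individually. Assumes at least one sample failed the "diverged" test.
    #
    # >>> overlays = results.plot_sample_failure_quantile_traces(display=False)
    # >>> hv.save(overlays["diverged"], "divergent_quantile_traces.html")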
    @overload
    def plot_variable_failure_quantile_traces(
        self,
        *,
        display: Literal[True] = ...,
        width: custom_types.Integer = ...,
        height: custom_types.Integer = ...,
        plot_quantiles: bool = ...,
    ) -> pn.pane.HoloViews: ...

    @overload
    def plot_variable_failure_quantile_traces(
        self,
        *,
        display: Literal[False],
        width: custom_types.Integer = ...,
        height: custom_types.Integer = ...,
        plot_quantiles: bool = ...,
    ) -> VariableAnalyzer: ...
    def plot_variable_failure_quantile_traces(
        self,
        *,
        display=True,
        width=800,
        height=400,
        plot_quantiles=False,
    ):
        """Create interactive analyzer for variables that failed diagnostic tests.

        :param display: Whether to return display-ready analyzer. Defaults to True.
        :type display: bool
        :param width: Width of plots in pixels. Defaults to 800.
        :type width: custom_types.Integer
        :param height: Height of plots in pixels. Defaults to 400.
        :type height: custom_types.Integer
        :param plot_quantiles: Whether to plot quantiles vs raw values.
            Defaults to False.
        :type plot_quantiles: bool

        :returns: Panel layout ready for display (``display=True``) or the
            underlying interactive analyzer (``display=False``)
        :rtype: Union[pn.pane.HoloViews, VariableAnalyzer]

        This method creates an interactive analysis tool for examining individual
        variables that failed diagnostic tests. The analyzer provides widgets for
        selecting specific variables, diagnostic metrics, and array indices.

        Interactive Features:
            - **Variable Selection**: Choose from variables that failed any test
            - **Metric Selection**: Focus on specific diagnostic failures
            - **Index Selection**: Examine individual array elements for
              multi-dimensional parameters

        The resulting trace plots show:
            - Sample trajectories across MCMC chains with distinct colors
            - Quantile analysis relative to parameters that passed tests
            - Hover information with detailed sample metadata
            - Chain-specific behavior identification

        This tool is particularly valuable for:
            - Understanding the nature of convergence problems
            - Identifying problematic parameter regions
            - Diagnosing systematic vs. sporadic sampling issues
            - Planning model reparameterization strategies

        Example:
            >>> # Interactive analysis in notebook
            >>> analyzer = results.plot_variable_failure_quantile_traces()
            >>> analyzer  # Display widget interface
        """
        # Build the analyzer object
        analyzer = VariableAnalyzer(
            self, plot_width=width, plot_height=height, plot_quantiles=plot_quantiles
        )

        # Return the analyzer object if not displaying
        if not display:
            return analyzer

        # Otherwise, display the plots
        return analyzer.display()
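    # Illustrative scripted usage (a sketch): grab the analyzer itself with
    # ``display=False`` to embed it in a larger Panel app rather than
    # displaying it directly. The app name and markdown header are hypothetical.
    #
    # >>> analyzer = results.plot_variable_failure_quantile_traces(display=False)
    # >>> app = pn.Column("## Failed-variable traces", analyzer.display())
    # >>> app.servable()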
    @classmethod
    def from_disk(
        cls,
        path: str,
        csv_files: list[str] | str | None = None,
        skip_fit: bool = False,
        use_dask: bool = False,
    ) -> "SampleResults":
        """Load SampleResults from saved NetCDF file with optional CSV metadata.

        :param path: Path to NetCDF file containing inference data
        :type path: str
        :param csv_files: Paths to CSV files output by Stan. Can also be a glob
            pattern in place of a list. Defaults to None (auto-detect based on
            ``path`` value).
        :type csv_files: Optional[Union[list[str], str]]
        :param skip_fit: Whether to skip loading CSV metadata. Defaults to False.
        :type skip_fit: bool
        :param use_dask: Whether to enable Dask for computation. Defaults to False.
        :type use_dask: bool

        :returns: Loaded SampleResults object ready for analysis
        :rtype: SampleResults

        :raises FileNotFoundError: If the specified NetCDF file doesn't exist

        This class method enables loading of previously saved MCMC results from
        NetCDF format, with optional access to original CSV metadata for complete
        functionality.

        Loading Modes:
            - **Full loading**: NetCDF + CSV metadata (complete functionality)
            - **NetCDF only**: Fast loading without CSV metadata (limited
              functionality)
            - **Auto-detection**: Automatically finds CSV files based on NetCDF
              path

        When use_dask=True, the loaded data supports out-of-core computation for
        memory-efficient analysis of large datasets. Management of Dask happens
        internally, so users do not need to be familiar with Dask to take
        advantage of it.

        Example:
            >>> # Load with auto-detected CSV files (csvs must have same basename)
            >>> results = SampleResults.from_disk('model_results.nc')
            >>>
            >>> # Load with explicit CSV files
            >>> results = SampleResults.from_disk(
            ...     'results.nc', csv_files=['chain_1.csv', 'chain_2.csv']
            ... )
            >>>
            >>> # Fast loading without CSV metadata
            >>> results = SampleResults.from_disk('results.nc', skip_fit=True)
        """
        # The path to the netcdf file must exist
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"The file {path} does not exist. Please provide a valid path."
            )

        # If csv files are not provided and we are not skipping the fit, we need
        # to derive them from the netcdf file path
        if skip_fit:
            csv_files = None
        elif csv_files is None:
            # If the path to the netcdf file does not end with ".nc", raise a
            # warning that we cannot automatically find the csv files. If it
            # does, find the csv files
            if path.endswith(".nc"):
                csv_files = list(glob(path.removesuffix(".nc") + "*.csv"))
            else:
                warnings.warn(
                    "Could not identify csv files automatically. Loading without. "
                    "To be auto-detected, csv files must be named according to the "
                    "following pattern: <extensionless_netcdf_filename>*.csv"
                )

        # Initialize the object
        return cls(
            model=None,
            fit=None if csv_files is None else fit_from_csv_noload(csv_files),
            inference_obj=path,
            use_dask=use_dask,
        )
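    # Auto-detection sketch: given a NetCDF export saved next to its Stan CSV
    # outputs, e.g. ``fit.nc`` alongside ``fit_1.csv`` and ``fit_2.csv``
    # (hypothetical names matching the <extensionless_netcdf_filename>*.csv
    # pattern), the CSV metadata is picked up automatically:
    #
    # >>> results = SampleResults.from_disk("fit.nc", use_dask=True)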
def fit_from_csv_noload(path: str | list[str] | os.PathLike) -> CmdStanMCMC:
    """Create CmdStanMCMC object from CSV files without loading data into memory.

    This function is adapted from ``cmdstanpy.from_csv``.

    :param path: Path specification for CSV files (single file, list, or glob
        pattern)
    :type path: Union[str, list[str], os.PathLike]

    :returns: CmdStanMCMC object with metadata but no loaded sample data
    :rtype: CmdStanMCMC

    :raises ValueError: If path specification is invalid or no CSV files found
    :raises ValueError: If CSV files are not valid Stan output

    This function provides a memory-efficient way to create CmdStanMCMC objects
    by parsing only the metadata from CSV files without loading the actual
    sample data. This is particularly useful for large datasets where memory
    usage is a concern.

    Path Specifications:
        - **Single file**: Direct path to one CSV file
        - **File list**: List of paths to multiple CSV files
        - **Glob pattern**: Wildcard pattern for automatic file discovery
        - **Directory**: Directory containing CSV files (loads all .csv files)

    The function performs validation to ensure:
        - All specified files exist and are readable
        - Files contain valid Stan CSV output
        - Sampling method is compatible (only 'sample' method supported)
        - Configuration is consistent across files

    This approach enables efficient processing workflows where sample data is
    converted to more efficient formats (like NetCDF) without requiring full
    memory loading of the original CSV files.

    Example:
        >>> # Load from glob pattern
        >>> fit = fit_from_csv_noload('model_output_*.csv')
        >>>
        >>> # Load from explicit list
        >>> fit = fit_from_csv_noload(['chain1.csv', 'chain2.csv'])
        >>>
        >>> # Use for conversion without memory loading
        >>> netcdf_path = cmdstan_csv_to_netcdf(fit, model)
    """

    def identify_files() -> list[str]:
        """Identifies CSV files from the given path."""
        csvfiles = []
        if isinstance(path, list):
            csvfiles = path
        elif isinstance(path, str) and "*" in path:
            splits = os.path.split(path)
            if splits[0]:  # `os.path.split` returns "" when there is no directory
                if not (os.path.exists(splits[0]) and os.path.isdir(splits[0])):
                    raise ValueError(
                        f"Invalid path specification, {path} unknown directory: {splits[0]}"
                    )
            csvfiles = glob(path)
        elif isinstance(path, (str, os.PathLike)):
            if os.path.exists(path) and os.path.isdir(path):
                for file in os.listdir(path):
                    if os.path.splitext(file)[1] == ".csv":
                        csvfiles.append(os.path.join(path, file))
            elif os.path.exists(path):
                csvfiles.append(str(path))
            else:
                raise ValueError(f"Invalid path specification: {path}")
        else:
            raise ValueError(f"Invalid path specification: {path}")

        if len(csvfiles) == 0:
            raise ValueError(f"No CSV files found in directory {path}")

        for file in csvfiles:
            if not (os.path.exists(file) and os.path.splitext(file)[1] == ".csv"):
                raise ValueError(
                    f"Bad CSV file path spec, includes non-csv file: {file}"
                )

        return csvfiles

    def get_config_dict() -> dict[str, Any]:
        """Reads the first CSV file and returns the configuration dictionary."""
        config_dict: dict[str, Any] = {}
        try:
            with open(csvfiles[0], "r", encoding="utf-8") as fd:
                scan_config(fd, config_dict, 0)
        except (IOError, OSError, PermissionError) as e:
            raise ValueError(f"Cannot read CSV file: {csvfiles[0]}") from e
        if "model" not in config_dict or "method" not in config_dict:
            raise ValueError(f"File {csvfiles[0]} is not a Stan CSV file.")
        if config_dict["method"] != "sample":
            raise ValueError(
                "Expecting Stan CSV output files from method sample, "
                f"found outputs from method {config_dict['method']}"
            )
        return config_dict

    def build_sampler_args() -> SamplerArgs:
        """Builds the sampler arguments."""
        sampler_args = SamplerArgs(
            iter_sampling=config_dict["num_samples"],
            iter_warmup=config_dict["num_warmup"],
            thin=config_dict["thin"],
            save_warmup=config_dict["save_warmup"],
        )

        # bugfix 425, check for fixed_params output
        try:
            check_sampler_csv(
                csvfiles[0],
                iter_sampling=config_dict["num_samples"],
                iter_warmup=config_dict["num_warmup"],
                thin=config_dict["thin"],
                save_warmup=config_dict["save_warmup"],
            )
        except ValueError:
            try:
                check_sampler_csv(
                    csvfiles[0],
                    is_fixed_param=True,
                    iter_sampling=config_dict["num_samples"],
                    iter_warmup=config_dict["num_warmup"],
                    thin=config_dict["thin"],
                    save_warmup=config_dict["save_warmup"],
                )
                sampler_args = SamplerArgs(
                    iter_sampling=config_dict["num_samples"],
                    iter_warmup=config_dict["num_warmup"],
                    thin=config_dict["thin"],
                    save_warmup=config_dict["save_warmup"],
                    fixed_param=True,
                )
            except ValueError as e:
                raise ValueError("Invalid or corrupt Stan CSV output file.") from e

        return sampler_args

    def build_fit() -> CmdStanMCMC:
        """Builds the CmdStanMCMC object."""
        chains = len(csvfiles)
        cmdstan_args = CmdStanArgs(
            model_name=config_dict["model"],
            model_exe=config_dict["model"],
            chain_ids=[x + 1 for x in range(chains)],
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=chains)

        # Point the runset at the existing CSV files and mark all chains as
        # having completed successfully
        # pylint: disable=protected-access
        runset._csv_files = csvfiles
        for i in range(len(runset._retcodes)):
            runset._set_retcode(i, 0)
        # pylint: enable=protected-access

        return CmdStanMCMC(runset)

    # Run the functions to parse the CSV files
    csvfiles = identify_files()
    config_dict = get_config_dict()
    sampler_args = build_sampler_args()
    return build_fit()