# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Interactive prior predictive check functionality for SciStanPy models.
This module provides interactive widgets and visualizations for conducting
prior predictive checks on SciStanPy models. Prior predictive checks allow
users to examine the behavior of their models before fitting to data by
sampling from prior distributions and visualizing the resulting predictions.
The module centers around the :py:class:`~scistanpy.plotting.prior_predictive.PriorPredictiveCheck`
class, which creates an interactive dashboard with sliders for model hyperparameters
and dropdown menus for selecting visualization options. This enables rapid
exploration of how different prior specifications affect model behavior.
The interface automatically adapts to the structure of the provided model,
exposing only relevant parameters and visualization options based on the
data dimensions and available coordinates.
.. note::
Instances of the :py:class:`~scistanpy.plotting.prior_predictive.PriorPredictiveCheck`
class will not typically be accessed directly by users. Instead, they are accessed
via the :py:meth:`Model.prior_predictive() <~scistanpy.model.Model.prior_predictive>`
method.
"""
from __future__ import annotations
import itertools
import re
from copy import deepcopy
from typing import Optional, TYPE_CHECKING
import holoviews as hv
import hvplot.pandas # pylint: disable=unused-import
import numpy as np
import pandas as pd
import panel as pn
import panel.widgets as pnw
import xarray as xr
from param.parameterized import Event
from scistanpy.model.components import constants
if TYPE_CHECKING:
from scistanpy import model as ssp_model
# We need a regular expression for separating the variable name from its indices
_INDEX_EXTRACTOR = re.compile(r"([A-Za-z0-9_\.]+)\[?([0-9, ]*)\]?")
[docs]
class PriorPredictiveCheck:
"""Interactive dashboard for conducting prior predictive checks.
This class creates a comprehensive interface for exploring model behavior
through prior predictive sampling. It automatically generates appropriate
widgets based on the model structure and provides multiple visualization
modes for examining parameter distributions and relationships.
:param model: SciStanPy model to analyze
:type model: ssp_model.Model
:param copy_model: Whether to create a copy of the model to avoid modifying the original.
Defaults to False, meaning updates to model parameters in the
interactive plot will be reflected in the original model object.
:type copy_model: bool
:ivar model: The model being analyzed (copy or reference)
:ivar float_sliders: Dictionary of parameter adjustment sliders
:ivar target_dropdown: Widget for selecting which parameter to visualize
:ivar group_dim_dropdown: Widget for selecting grouping dimension
:ivar independent_var_dropdown: Widget for selecting independent variable
:ivar plot_type_dropdown: Widget for selecting visualization type
:ivar draw_seed_entry: Widget for setting random seed
:ivar draw_entry: Widget for setting number of experiments
:ivar update_model_button: Button to update model and redraw data
:ivar update_plot_button: Button to update plot without redrawing data
:ivar fig: HoloViews pane containing the current plot
"""
def __init__(self, model: "ssp_model.Model", copy_model: bool = False):
# Copy the model if requested. If we don't copy, then we can modify our
# values on the model directly.
self.model = deepcopy(model) if copy_model else model
# Initialize widgets.
self.float_sliders = self._init_float_sliders()
self.target_dropdown = pnw.Select(
name="Viewed Parameter",
options=[
k
for k, v in self.model.named_model_components_dict.items()
if not isinstance(v, constants.Constant)
],
value=self.model.observables[0].model_varname,
)
self.group_dim_dropdown = pnw.Select(name="Group By", options=[], value="")
self.independent_var_dropdown = pnw.Select(
name="Independent Variable", options=[], value=""
)
self.plot_type_dropdown = pnw.Select(name="Plot Type", options=[], value="ECDF")
self.draw_seed_entry = pnw.IntInput(name="Seed", value=1025)
self.draw_entry = pnw.IntInput(name="Number of Experiments", value=1)
self.update_model_button = pnw.Button(
name="Update Model", button_type="primary"
)
self.update_plot_button = pnw.Button(name="Update Plot", button_type="primary")
# We need additional components for the plotting data
self.fig = pn.pane.HoloViews(
hv.Curve([]),
name="Plot",
align="center",
sizing_mode="stretch_both",
)
# The update model button will run the full pipeline, including drawing new
# data and updating the plot. The update plot button will only update the
# plot, without redrawing the data.
self.update_model_button.on_click(self._full_pipeline)
self.update_plot_button.on_click(self._update_plot)
# We need to store the data and the plotting data
self._xarray_data: xr.Dataset = xr.Dataset()
self._processed_data: pd.DataFrame = pd.DataFrame()
# Draw initial data. We need this for setting up the remaining reactive
# components. This will set the `_xarray_data` attribute.
self._draw_data()
# Create reactive components whose values depend on the data. These components
# have no effect on the data that is drawn, only what is shown.
self.target_dropdown.param.watch(self.set_group_dim_options, "value")
self.group_dim_dropdown.param.watch(self.set_independent_var_options, "value")
self.group_dim_dropdown.param.watch(self.set_plot_type_options, "value")
self.independent_var_dropdown.param.watch(self.set_plot_type_options, "value")
# Set initial values for the components relevant to showing data
self.target_dropdown.param.trigger("value")
self.group_dim_dropdown.param.trigger("value")
self.independent_var_dropdown.param.trigger("value")
# Run the full pipeline to set up the plotting data
self._full_pipeline()
def _init_float_sliders(self) -> dict[str, pnw.EditableFloatSlider]:
"""Initialize float sliders for toggleable model parameters.
This method scans the model for constant parameters marked as toggleable
and creates appropriate slider widgets for interactive adjustment. It
handles both scalar and multi-dimensional parameters appropriately.
:returns: Dictionary mapping parameter names to slider widgets
:rtype: dict[str, pnw.EditableFloatSlider]
For scalar parameters or those with enforced uniformity, a single slider
is created. For multi-dimensional parameters, individual sliders are
created for each array element with indexed naming conventions.
Slider ranges and step sizes are determined by the parameter's configured
slider properties (slider_start, slider_end, slider_step_size).
"""
# Each togglable parameter gets its own float slider
sliders = {}
# Process all constants in the model
for (
hyperparam_name,
hyperparam_val,
) in self.model.all_model_components_dict.items():
# Skip non-constants and non-togglable parameters
if (
not isinstance(hyperparam_val, constants.Constant)
or not hyperparam_val.is_togglable
):
continue
# If no dimensions OR if we are forcing uniformity across a multidimensional
# array, just create a slider
if hyperparam_val.ndim == 0 or hyperparam_val.enforce_uniformity:
sliders[hyperparam_name] = pnw.EditableFloatSlider(
name=hyperparam_name,
value=np.unique(hyperparam_val.value).item(),
start=hyperparam_val.slider_start,
end=hyperparam_val.slider_end,
step=hyperparam_val.slider_step_size,
)
continue
# Otherwise, create a slider for each entry
for arr_ind in np.ndindex(hyperparam_val.shape):
# Build the slider name
name = f"{hyperparam_name}[{', '.join(map(str, arr_ind))}]"
# Get the slider value
slider_val = hyperparam_val.value[arr_ind]
# Add the slider
sliders[name] = pnw.EditableFloatSlider(
name=name,
value=slider_val,
start=hyperparam_val.slider_start,
end=hyperparam_val.slider_end,
step=hyperparam_val.slider_step_size,
)
return sliders
def _update_model(self) -> None:
"""Update the underlying model with current slider values.
This method reads the current values from all parameter sliders and
updates the corresponding constant values in the model. It handles
both scalar parameters and indexed multi-dimensional parameters.
For indexed parameters, the method parses the slider name to extract
the parameter name and array indices, then updates the appropriate
array element.
"""
for paramname, slider in self.float_sliders.items():
# Get the parameter name and the indices
paramname, indices = _INDEX_EXTRACTOR.match(paramname).groups()
if indices:
tuple(map(int, indices.split(",")))
else:
indices = ...
# Update the value of the constant
self.model[paramname].value[indices] = slider.value
def _draw_data(self) -> None:
"""Draw new data from the model using current parameter values.
This method generates samples from the model's prior predictive
distribution using the current slider values and stores the results
as an xarray Dataset for efficient manipulation and visualization.
The number of samples and random seed are controlled by the
corresponding widget values (draw_entry.value and draw_seed_entry.value).
"""
self._xarray_data = self.model.draw(
n=self.draw_entry.value,
named_only=True,
as_xarray=True,
seed=self.draw_seed_entry.value,
)
def _process_data(self) -> None:
"""Process raw xarray data into DataFrame format appropriate for plotting.
This method transforms the xarray Dataset into a pandas DataFrame
with appropriate structure for the selected visualization type. It
handles data reshaping, grouping, and computation of derived quantities
like cumulative probabilities for ECDF plots.
The processing logic adapts based on the selected plot type:
- ECDF plots: Computes cumulative probabilities and sorts appropriately
- Relationship plots: Sorts by independent variable and adds separators
- Other plots: Basic reshaping and filtering
The method respects the current widget selections for target parameter,
grouping dimension, and independent variable.
"""
# We need to define aggregation functions for the different plot types
def build_ecdfs(group):
# We need to record the original values of the dependent variable
new_df = group[[self.target_dropdown.value]]
# Rank the target variable
new_df["Cumulative Probability"] = group[[self.target_dropdown.value]].rank(
method="max", pct=True
)
# Add the independent variable values
sort_keys = ["Cumulative Probability"]
if self._independent_label is not None:
sort_keys.insert(0, self._independent_label)
new_df[self._independent_label] = group[self._independent_label]
# Sort and return
return new_df.sort_values(by=sort_keys)
def build_relations(group):
# Sort the data by the independent variable and add a NaN row to separate
# the data by the appropriate dependent variable
return pd.concat(
[
group.sort_values(by=self.independent_var_dropdown.value),
pd.DataFrame({self.independent_var_dropdown.value: [np.nan]}),
]
)
# Gather the target data
selected_data = self._xarray_data[
[self.target_dropdown.value]
+ (
[]
if self.independent_var_dropdown.value == ""
else [self.independent_var_dropdown.value]
)
]
# Reshape the data as appropriate and convert the extracted data to a DataFrame.
# We keep the grouping dimension separate from the stacked results.
df = (
selected_data.stack(
stacked=[
dim
for dim in selected_data.dims
if dim != self.group_dim_dropdown.value
],
create_index=False,
)
.to_dataframe()
.reset_index()
)
# Filter to just the columns needed. These are the grouping dimensions and
# independent variables, if any.
target_cols = [self.target_dropdown.value]
if self.group_dim_dropdown.value != "":
target_cols.append(self.group_dim_dropdown.value)
if self.independent_var_dropdown.value != "":
target_cols.extend([self.independent_var_dropdown.value, "stacked"])
df = df[target_cols]
# Final processing for certain plots.
if self.plot_type_dropdown.value == "ECDF":
if self.group_dim_dropdown.value != "":
df = df.groupby(self.group_dim_dropdown.value, group_keys=False).apply(
build_ecdfs
)
else:
df = build_ecdfs(df)
elif self.plot_type_dropdown.value == "Relationship":
df = df.groupby("stacked").apply(build_relations)
# Store the processed data
self._processed_data = df
def _update_plot( # pylint: disable=unused-argument
self, event: Optional[Event] = None
) -> None:
"""Update the plot display using current data and widget settings.
This method reprocesses the current xarray data according to the
selected visualization options and updates the plot display. It does
not redraw data from the model, making it efficient for exploring
different visualization modes.
:param event: Panel event object (unused, for callback compatibility).
Defaults to None.
:type event: Optional[Event]
The method automatically selects appropriate plotting functions and
styling based on the plot type dropdown selection, handling both
hvplot-based plots and specialized HoloViews elements like violin plots.
"""
# Update plot button to loading
self.update_plot_button.loading = True
# Reformat the data
self._process_data()
# Update the plot kwargs
plot_kwargs = {
"ECDF": self.get_ecdf_kwargs,
"KDE": self.get_kde_kwargs,
"Violin": self.get_violin_kwargs,
"Relationship": self.get_relationship_kwargs,
}[self.plot_type_dropdown.value]()
# Update the plot
if self.plot_type_dropdown.value == "Violin":
self.fig.object = hv.Violin(plot_kwargs.pop("args"), **plot_kwargs).opts(
show_legend=False
)
else:
self.fig.object = self._processed_data.hvplot(**plot_kwargs)
# Update plot button to not be loading
self.update_plot_button.loading = False
def _full_pipeline( # pylint: disable=unused-argument
self, event: Optional[Event] = None
) -> None:
"""Execute complete model update and visualization pipeline.
This method performs the full sequence of operations: updating model
parameters from sliders, drawing new data from the updated model,
and refreshing the plot display. It provides a complete refresh of
the analysis when model parameters change.
:param event: Panel event object (unused, for callback compatibility).
Defaults to None.
:type event: Optional[Event]
The method sets loading states on relevant buttons to provide user
feedback during potentially time-consuming operations like data
generation for complex models.
"""
# Buttons to loading mode
self.update_model_button.loading = True
self.update_plot_button.loading = True
# Update the model
self._update_model()
# Draw new data
self._draw_data()
# Update the plot
self._update_plot()
# Buttons to not be loading
self.update_model_button.loading = False
self.update_plot_button.loading = False
[docs]
def set_group_dim_options(self, event: Event) -> None:
"""Update grouping dimension options based on selected target parameter.
This method configures the group dimension dropdown with appropriate
options based on the dimensionality of the currently selected target
parameter. It ensures that grouping options are only available for
multi-dimensional parameters.
:param event: Panel event containing the new target parameter selection
:type event: Event
The method updates both the dropdown options and the descriptive text
showing dimension sizes to help users understand the data structure.
If the previously selected grouping dimension is no longer valid,
it automatically resets to a sensible default.
"""
# Get the dimensions of the target variable
partial_opts = list(self._xarray_data[event.new].dims[1:])
# Grouping can only be performed when we have more than one dimension
target_dim_opts = [""]
if len(partial_opts) > 1:
target_dim_opts += partial_opts
# If the previous dependent dimension is not in the options, reset it
if self.group_dim_dropdown.value not in target_dim_opts:
self.group_dim_dropdown.value = target_dim_opts[-1]
# Update the description of the grouping dimension
description = ", ".join(
f"[{dim}: {self._xarray_data.sizes[dim]}]" for dim in target_dim_opts[1:]
)
self.group_dim_dropdown.name = f"Group By ({description})"
# Update the dropdown options
self.group_dim_dropdown.options = target_dim_opts
[docs]
def set_independent_var_options(self, event: Event) -> None:
"""Update independent variable options based on grouping dimension.
This method configures the independent variable dropdown with coordinates
and data variables that are compatible with the selected grouping
dimension. Only variables that vary along the grouping dimension are
included as options.
:param event: Panel event containing the new grouping dimension selection
:type event: Event
The method scans both coordinates and data variables in the xarray
Dataset to find suitable independent variables, ensuring compatibility
with the selected visualization approach.
"""
# The independent variable must be a coordinate or data variable that contains
# the `Group By` dimension.
independent_var_opts = [""] + [
varname
for varname, arr in itertools.chain(
self._xarray_data.coords.items(), self._xarray_data.data_vars.items()
)
if arr.sizes.get(event.new, 0) > 1
]
# If the previous independent variable is not in the options, reset it
if self.independent_var_dropdown.value not in independent_var_opts:
self.independent_var_dropdown.value = ""
# Update the dropdown options
self.independent_var_dropdown.options = independent_var_opts
[docs]
def set_plot_type_options( # pylint: disable=unused-argument
self, event: Event
) -> None:
"""Update available plot types based on current dimension selections.
This method determines which visualization types are appropriate
given the current selections for target parameter, grouping dimension,
and independent variable. It enables more sophisticated plot types
as more structure is specified.
:param event: Panel event (used for callback compatibility is not used)
:type event: Event
Plot Type Logic:
- ECDF and KDE: Always available for any parameter
- Violin: Available when grouping dimension is selected
- Relationship: Available when both grouping and independent variable are selected
The method automatically selects the most sophisticated available plot
type as the default when options change.
"""
# We can always have an ECDF and KDE plot
plot_type_opts = ["ECDF", "KDE"]
default_plot = "ECDF"
# If a grouping dimension is set, then we can also have violin plots
if self.group_dim_dropdown.value != "":
plot_type_opts.append("Violin")
default_plot = "Violin"
# If an independent variable is set, then we can also have relationship
# plots
if self.independent_var_dropdown.value != "":
plot_type_opts.append("Relationship")
default_plot = "Relationship"
# If the previous plot type is not in the options, reset it
if self.plot_type_dropdown.value not in plot_type_opts:
self.plot_type_dropdown.value = default_plot
# Update to the plot type options
self.plot_type_dropdown.options = plot_type_opts
[docs]
def display(self) -> pn.Row:
"""Create and return the complete interactive dashboard layout.
This method assembles all widgets and the plot display into a
comprehensive dashboard layout suitable for display in Jupyter
notebooks, Panel applications, or web interfaces.
:returns: Panel layout containing all interface elements
:rtype: pn.Row
The layout consists of:
- Left panel: Model hyperparameter sliders and viewing options
- Right panel: Interactive plot display that updates based on selections
Example:
>>> check = PriorPredictiveCheck(model)
>>> dashboard = check.display() # For Jupyter notebook
>>> dashboard.servable() # For web deployment
"""
# Organize widgets and plot
return pn.Row(
pn.WidgetBox(
pn.WidgetBox(
"# Model Hyperparameters",
*self.float_sliders.values(),
self.draw_seed_entry,
self.draw_entry,
self.update_model_button,
),
pn.WidgetBox(
"# Viewing Options",
self.target_dropdown,
self.group_dim_dropdown,
self.independent_var_dropdown,
self.plot_type_dropdown,
self.update_plot_button,
),
),
self.fig,
)
[docs]
def get_kde_kwargs(self) -> dict:
"""Generate keyword arguments for kernel density estimate plots.
This method constructs the parameter dictionary needed for creating
KDE visualizations using hvplot, including appropriate grouping
and styling options.
:returns: Dictionary of hvplot parameters for KDE plotting
:rtype: dict
The returned dictionary includes:
- 'kind': Set to 'kde' for kernel density estimation
- 'x': `None` for KDE plots.
- 'y': Target parameter name for the y-axis
- 'by': Independent label for grouping (if applicable)
- 'datashade': Disabled (False) for KDE plots to maintain clarity
"""
return {
"kind": "kde",
"x": None,
"y": self.target_dropdown.value,
"by": self._independent_label,
"datashade": False,
}
[docs]
def get_ecdf_kwargs(self) -> dict:
"""Generate keyword arguments for empirical CDF plots.
This method constructs the parameter dictionary needed for creating
ECDF visualizations using hvplot, including cumulative probability
calculations and appropriate hover interactions.
:returns: Dictionary of hvplot parameters for ECDF plotting
:rtype: dict
The returned dictionary includes:
- 'kind': Set to 'line' for step-like ECDF appearance
- 'x': Target parameter values
- 'y': 'Cumulative Probability' (computed during data processing)
- 'by': Independent label for grouping multiple ECDFs
- 'datashade': Disabled (False) for ECDF plots to maintain clarity
- 'hover': Set to 'hline' for horizontal hover lines
"""
return {
"kind": "line",
"x": self.target_dropdown.value,
"y": "Cumulative Probability",
"by": self._independent_label,
"datashade": False,
"hover": "hline",
}
[docs]
def get_violin_kwargs(self) -> dict:
"""Generate arguments for violin plot creation.
This method constructs the arguments needed for creating violin plots
using HoloViews, handling complex grouping scenarios and determining
appropriate categorization based on data structure.
:returns: Dictionary containing plot arguments and dimensions
:rtype: dict
The method handles multiple grouping scenarios:
- Single grouping by dimension index
- Grouping by both dimension and independent variable
- Automatic determination of primary vs. secondary grouping
The returned dictionary includes:
- 'args': Tuple of (groups..., values) for HoloViews Violin constructor
- 'kdims': List of key dimension names
- 'vdims': Value dimension name (target parameter)
"""
# This is only an option if we have a grouping dimension
group_indices = self._processed_data[self.group_dim_dropdown.value].to_numpy()
# If the independent label is provided, then it is the grouping variable
# and the group index is the category variable IF there are more unique
# combinations of indices and independent labels than there are group indices
if self.independent_var_dropdown.value != "":
independent_labels = self._processed_data[
self.independent_var_dropdown.value
].to_numpy()
if len(np.unique(group_indices)) < len(
self._processed_data[
[self.group_dim_dropdown.value, self.independent_var_dropdown.value]
].drop_duplicates()
):
groups = [group_indices, independent_labels]
kdims = [
self.independent_var_dropdown.value,
self.group_dim_dropdown.value,
]
else:
groups = [independent_labels]
kdims = [self.independent_var_dropdown.value]
# Otherwise, we just have the group indices
else:
groups = [group_indices]
kdims = [self.group_dim_dropdown.value]
return {
"args": tuple(
groups + [self._processed_data[self.target_dropdown.value].to_numpy()]
),
"kdims": kdims,
"vdims": self.target_dropdown.value,
}
[docs]
def get_relationship_kwargs(self) -> dict:
"""Generate keyword arguments for relationship plots.
This method constructs the parameter dictionary for creating
relationship visualizations that show how parameters vary with
respect to independent variables, using datashading for performance.
:returns: Dictionary of hvplot parameters for relationship plotting
:rtype: dict
The returned dictionary includes:
- kind: Set to 'line' for continuous relationships
- x: Independent variable name
- y: Target parameter name
- datashade: Enabled (True) for performance with large datasets
- dynamic: Disabled (False), resulting in all data being embedded in output
- aggregator: Set to 'count' for density-based coloring
- cmap: 'inferno' colormap
"""
return {
"kind": "line",
"x": self.independent_var_dropdown.value,
"y": self.target_dropdown.value,
"by": None,
"datashade": True,
"dynamic": False,
"aggregator": "count",
"cmap": "inferno",
}
@property
def _independent_label(self) -> Optional[str]:
"""Determine the effective independent label for plotting.
This property provides a unified interface for determining which
variable should be used as the independent label in plots, following
a priority hierarchy based on user selections.
:returns: Name of the independent variable, or None if not applicable
:rtype: Optional[str]
Priority order:
1. Independent variable dropdown selection (if not empty)
2. Group dimension dropdown selection (if not empty)
3. None (for simple univariate plots)
This property enables consistent handling of grouping and independent
variables across different visualization types.
"""
if self.independent_var_dropdown.value != "":
return self.independent_var_dropdown.value
elif self.group_dim_dropdown.value != "":
return self.group_dim_dropdown.value