In [1]:
import holoviews as hv
import numpy as np
import scistanpy as ssp

hv.extension('bokeh', inline=True)

Overview¶

This notebook contains an example workflow for modeling deep mutational scanning (DMS) data. If you're familiar with DMS, continue to the next section ("Bayesian Analysis of DMS"). Otherwise, this cell provides an overview of the core concept:

DMS Background¶

When engineering a protein, how well it performs a desired task (e.g., catalysis, binding, etc.) is commonly referred to as its "fitness". Deep mutational scanning is a method for assigning fitness labels to different protein variants at scale. While the exact implementation varies, a typical DMS study follows the below workflow:

  1. We begin with a population of organisms (typically a model organism such as E. coli or S. cerevisiae). Within this population are subpopulations consisting of organisms carrying DNA that encode a specific protein variant--each subpopulation is defined by the variant DNA it carries.
  2. DNA is harvested from the full population and sequenced by next-generation sequencing. This results in a set number of reads (or "counts" in the context of DMS), with the number of reads pertaining to each subpopulation reflecting the overall frequency of that subpopulation within the full population.
  3. The population is then subjected to a selective pressure. Importantly, the ability of organisms to resist this selective pressure is tied to the fitness of the protein of interest. For example, if the goal were to engineer a protein critical for the production of a core element of cellular metabolism, the population would be grown in an environment absent that element, meaning that only subpopulations whose proteins were active could grow and propagate.
  4. Once selection is complete, next generation sequencing is performed again. Subpopulations expressing proteins with higher fitness should be enriched in the returned counts relative to their initial abundance and vice versa.
  5. "Fitness" for a protein is then defined as the ratio of normalized ending counts (where "normalization" is relative to the wild-type (unmutated) protein) to normalized beginning counts. The next section covers this in more detail.

Bayesian Analysis of DMS¶

Notably, DMS does not provide direct readouts of fitness. Instead, fitness values must be inferred. While, again, the exact procedure for doing this will vary with the DMS study, a typical approach is to take the relative ratio of normalized ending counts to normalized starting counts as below:

$$ \begin{align*} f_k = \frac{c_{k, f} / c_{wt, f}}{c_{k, i} / c_{wt, i}} \end{align*}, $$

where $c_{k, f}$ and $c_{k, i}$ are the final and starting counts for variant $k$, respectively, and $c_{wt, f}$ and $c_{wt, i}$ are the final and starting counts for the wild-type protein. Performing the above calculation for all variants gives us their fitnesses.

The value $f_k$ is also commonly called the "enrichment ratio" of the DMS study. Importantly, it approximates both (i) the abundance of each variant relative to other variants in the population at a given timepoint (hence division by the wild-type counts in both the numerator and denominator) and (ii) the abundance of a given variant to itself across timepoints (hence division of the end timepoint by the starting timepoint). The word "approximates" is stressed in the previous sentence, as the counts returned by NGS are an indirect correlate of true variant abundance.
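As a small illustration of the enrichment ratio equation above, the following sketch computes $f_k$ from raw counts with NumPy. The function name `enrichment_ratios`, the `wt_index` argument, and the counts are hypothetical and chosen only for this example; they are not part of SciStanPy.

```python
import numpy as np


def enrichment_ratios(initial_counts, final_counts, wt_index=0):
    """Compute f_k = (c_kf / c_wtf) / (c_ki / c_wti) for every variant k."""
    initial = np.asarray(initial_counts, dtype=float)
    final = np.asarray(final_counts, dtype=float)

    # Normalize each timepoint to the wild-type counts at that timepoint
    initial_norm = initial / initial[wt_index]
    final_norm = final / final[wt_index]

    # Ratio of normalized ending counts to normalized starting counts
    return final_norm / initial_norm


# Hypothetical counts; index 0 is treated as the wild type
initial = [1000, 500, 200, 50]
final = [2000, 500, 800, 10]
ratios = enrichment_ratios(initial, final)
print(ratios)
```

By construction the wild type always receives a ratio of 1, so every other value is read relative to it.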

As an analogy, if I had a bag full of marbles (variants) and drew 10 marbles from it, 4 being blue and 6 being green (e.g., "blue" being one variant and "green" being another), I cannot know with certainty what the true proportion of marbles is in that bag. Absent other information, a guess of 40% blue and 60% green seems most reasonable, but we intuitively know that it could be something else. 45% blue and 55% green might be the actual proportions. It's even possible, however unlikely, that there are only 4 blue marbles in the entire bag and we just so happened to pick them all, in which case, assuming the bag contains >>10 total marbles, the proportion of blue marbles is vanishingly small. The key point is that, because our marble counts are the result of a random process, there will always be some measure of uncertainty in our inference of the true proportions.

NGS can be thought of as drawing counts of specific variants from a population of unknown proportions. Just as with the marble example, absent additional information, the relative ratios of counts is the most probable reflection of true proportions, and this is, of course, what the enrichment ratio equation above is calculating. "Most probable" does not necessarily mean "probable", however, and the goal of Bayesian analysis is, in addition to identifying the most probable solution, to also quantify how probable a given solution may be.
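To make the marble analogy concrete, here is a minimal sketch that quantifies the uncertainty in the blue proportion after drawing 4 blue and 6 green marbles. It assumes the bag is large enough that draws can be treated as multinomial, and it assumes a flat Dirichlet(1, 1) prior; both are choices made for this illustration rather than anything prescribed above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Observed draw: 4 blue and 6 green marbles
counts = np.array([4, 6])

# Assumption: flat Dirichlet(1, 1) prior. The Dirichlet is conjugate to the
# multinomial, so the posterior is Dirichlet(prior + counts).
posterior_samples = rng.dirichlet(1.0 + counts, size=20_000)
blue = posterior_samples[:, 0]

# The ratio of counts is the single most probable value under the flat prior
point_estimate = counts[0] / counts.sum()

# The posterior samples show how wide the range of plausible values is
lower, upper = np.quantile(blue, [0.025, 0.975])
print(f"Point estimate for blue: {point_estimate:.2f}")
print(f"95% credible interval for blue: ({lower:.2f}, {upper:.2f})")
```

With only 10 draws the interval is wide, which is the sense in which "most probable" need not mean "probable".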

Okay, with all background and motivation out of the way, this brings us to building a Bayesian model of the typical DMS scenario described above. One way to think about constructing a Bayesian model of any observable is to reconceptualize the data collection process as a data generation process. The output of this process is our observables, which are a reflection of all underlying contributors of uncertainty (modeled as "parameters" in SciStanPy) and their deterministic transformations (modeled as "transformations" in SciStanPy). For example, we have already established that one contributor of uncertainty is the sequencing process itself, which outputs discrete counts reflecting the true abundance of each variant in the selected population. A natural choice for modeling this situation is a Dirichlet distribution followed by a multinomial distribution:

$$ \begin{align*} \mathbf{\theta_i} &\sim \text{Dirichlet}(\mathbf{\alpha}) \\ \mathbf{c}_i &\sim \text{Multinomial}(\mathbf{\theta_i}), \end{align*} $$

where $\alpha \in \mathbb{R}^K$ is a hyperparameter controlling our prior beliefs about the initial distribution of proportions, $\mathbf{c}_i \in \mathbb{W}^{K}$ defines the starting counts for $K$ total proteins and $\mathbf{\theta_i} \in \{\mathbb{R}^K | 0 \leq \theta_{i_k} \leq 1, \sum_{k=1}^K \theta_{i_k} = 1\}$ their proportions (i.e., the probability simplex).

The above is the first step of our "data generation" process. The next step is to apply the selective pressure. We will model the selection process as applying some multiplicative correction (i.e. an enrichment ratio) to the input proportions, not counts.

$$ \begin{align*} \mathbf{f} &\sim \text{Exponential}(\beta) \\ \mathbf{\theta_f} &= \frac{\mathbf{\theta_i} * \mathbf{f}}{\sum_{k=1}^K \theta_{i_k} * f_k}, \end{align*} $$

where $\mathbf{f} \in \{\mathbb{R}^K | f_k \geq 0\}$, $*$ indicates elementwise multiplication, and $\mathbf{\theta_f}$ gives the proportions after selection. Two things should be highlighted:

  1. We place an exponential prior on $\mathbf{f}$. This is to reflect the belief that variants are more likely to have low fitness than high fitness. The extent to which this belief is enforced is controlled by the hyperparameter $\beta$.
  2. $\mathbf{\theta_f}$ is not drawn from a distribution. Instead, it is modeled as the deterministic transformation of two parameters, $\mathbf{\theta_i}$ and $\mathbf{f}$. Implicitly, then, our uncertainty of the value of $\mathbf{\theta_f}$ results from our uncertainty in both the initial population proportions and the fitness of those proportions. In terms of data generation, $\mathbf{\theta_f}$ is found at a later step in the generation process than the other parameters (it is further along in the dependency graph).
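The deterministic transformation from $\mathbf{\theta_i}$ and $\mathbf{f}$ to $\mathbf{\theta_f}$ can be sketched on the natural scale as follows. The helper name `apply_selection` and the example values are hypothetical and used only for illustration.

```python
import numpy as np


def apply_selection(theta_i, f):
    """Rescale starting proportions by fitness and renormalize to the simplex."""
    theta_i = np.asarray(theta_i, dtype=float)
    f = np.asarray(f, dtype=float)
    weighted = theta_i * f  # elementwise product
    return weighted / weighted.sum()


theta_i = np.array([0.5, 0.3, 0.2])  # starting proportions (sum to 1)
f = np.array([1.0, 0.1, 2.0])        # hypothetical fitness values
theta_f = apply_selection(theta_i, f)
print(theta_f, theta_f.sum())

# Multiplying every fitness value by the same constant leaves theta_f unchanged
print(np.allclose(apply_selection(theta_i, 10 * f), theta_f))
```

Because of the renormalization, the observed proportions only carry information about fitness values relative to one another; in the model above, the prior on $\mathbf{f}$ is what informs their overall scale.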

From the transformed parameter $\mathbf{\theta_f}$, we can model our output counts using another multinomial distribution:

$$ \mathbf{c}_f \sim \text{Multinomial}(\mathbf{\theta_f}). $$

In all, this gives us the following model:

$$ \begin{align*} \mathbf{\theta_i} &\sim \text{Dirichlet}(\mathbf{\alpha}) \\ \mathbf{f} &\sim \text{Exponential}(\beta) \\ \mathbf{\theta_f} &= \frac{\mathbf{\theta_i} * \mathbf{f}}{\sum_{k=1}^K \theta_{i_k} * f_k} \\ \mathbf{c}_i &\sim \text{Multinomial}(\mathbf{\theta_i}) \\ \mathbf{c}_f &\sim \text{Multinomial}(\mathbf{\theta_f}) \end{align*} $$

SciStanPy is designed to allow models like the above to be encoded in a Pythonic way and, by extension, to allow us to infer values of unobservables ($\mathbf{f}$ in our example here) from observables, propagating uncertainty between different model components. In the case of DMS, then, fitting the above model in SciStanPy gives us a distribution of potential models, and so a distribution of potential fitness values, all informed by the collected data.

Fitting a DMS Model in SciStanPy¶

The above section covered the key concepts of DMS. It also covered the core differences between standard modeling of DMS data and Bayesian modeling of DMS data. This section demonstrates how to fit the above-described Bayesian model using SciStanPy:

To begin, let's just simulate some example data:

In [2]:
def sample_data():
    """
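    Generate sample data for deep mutational scanning analysis.
    Returns:
        A dictionary with the following keys:
        INPUT_COUNTS: Array of input counts for each replicate and variant.
        LOG_INPUT_FREQS: Log frequencies of input variants.
        OUTPUT_COUNTS: Array of output counts for each replicate and variant.
        LOG_OUTPUT_FREQS: Log frequencies of output variants after selection.
        LOG_ENRICHMENT_FACTORS: Log enrichment factors applied during selection.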
    """
    # Sample input counts
    rng = np.random.default_rng(1025)
    input_freqs = rng.dirichlet(np.ones(10))
    log_input_freqs = np.log(input_freqs)
    input_counts = np.stack([rng.multinomial(10000, input_freqs)
                             for _ in range(3)])

    # Sample enrichment factors
    log_enrichment_factors = np.log(rng.exponential(0.1, size=(10,)))

    # Generate output counts after selection
    log_output_freqs = log_input_freqs + log_enrichment_factors
    log_output_freqs -= np.log(np.sum(np.exp(log_output_freqs))) # Normalize
    output_counts = np.stack([rng.multinomial(10000, np.exp(log_output_freqs))
                               for _ in range(3)])

    return {
        "INPUT_COUNTS": input_counts,
        "LOG_INPUT_FREQS": log_input_freqs,
        "OUTPUT_COUNTS": output_counts,
        "LOG_OUTPUT_FREQS": log_output_freqs,
        "LOG_ENRICHMENT_FACTORS": log_enrichment_factors
    }

SAMPLE_DATA = sample_data()

The above simulates a deep mutational scanning experiment of 10 variants with sequencing performed in triplicate at both the beginning and end. Note that, relative to the model described in the previous section, our implementation works in log space. For 10 variants, this is unlikely to be necessary, but as the size of the probability simplex grows, numerical precision becomes a problem on the standard scale. The log scale is used here both to (i) demonstrate a unique feature of SciStanPy (log-simplexes are not currently natively supported in Stan or PyTorch) and (ii) demonstrate a model that would be more practical for typical DMS experiments (which tend to have >>10 variants).
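The precision problem mentioned above can be seen in a small sketch. Here, normalizing on the standard scale fails because every exponentiated value underflows to zero, while normalizing in log space with the usual max-shift (log-sum-exp) trick works. The helper `stable_normalize_log` is a hypothetical stand-in written for this example and is not SciStanPy's implementation.

```python
import numpy as np


def stable_normalize_log(log_values):
    """Return log_values - logsumexp(log_values), computed with a max shift."""
    log_values = np.asarray(log_values, dtype=float)
    shift = log_values.max()  # subtracting the max keeps exp() in range
    log_total = shift + np.log(np.sum(np.exp(log_values - shift)))
    return log_values - log_total


# Unnormalized log-proportions that are far too small for the standard scale
log_unnormalized = np.array([-1000.0, -1001.0, -1002.0])

# Naive approach: exp() underflows to 0, so normalization computes 0 / 0
with np.errstate(divide="ignore", invalid="ignore"):
    naive = np.exp(log_unnormalized) / np.sum(np.exp(log_unnormalized))

stable = stable_normalize_log(log_unnormalized)
print("Naive:", naive)
print("Log-space:", np.exp(stable))
```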

We can now model the process using SciStanPy:

In [3]:
# All models inherit from `ssp.Model`
class DMSModel(ssp.Model):

    # Define the structure of the model in the `__init__` method
    def __init__(self, input_counts, output_counts):

        # We're going to register default data for the input and output counts.
        # This isn't necessary, but means you won't need to pass the observables
        # into later methods.
        super().__init__(
            default_data={"input_counts": input_counts, "output_counts": output_counts}
        )

        # We now define our priors. Let's assume that we expect our enrichment
        # ratios to follow an exponential distribution. The log-enrichment factors
        # will then follow an exponential-exponential (Gumbel) distribution.
        # Note: We define 10 independent log-enrichment factors using the "shape"
        # argument.
        # Note: The "beta" parameter here is the inverse of the scale parameter
        # in NumPy/SciPy.
        self.log_enrichment = ssp.parameters.ExpExponential(beta=10.0, shape=(10,))

        # We reason that the input and output counts are multinomially distributed
        # with some unknown frequency, which are the values we want to infer. To
        # handle potentially small values, we will use an Exp-Dirichlet prior to
        # model the log-frequencies.
        self.log_input_freqs = ssp.parameters.ExpDirichlet(alpha=1.0, shape=(10,))

        # From the log-input frequencies and log-enrichment factors, we can define
        # a transformation that takes us to the output frequencies. We're in log
        # space, so this is just addition followed by normalization. Note that,
        # currently, all reductions and normalizations are performed over the last
        # axis (this cannot be changed yet).
        self.log_output_freqs = ssp.operations.normalize_log(
            self.log_input_freqs + self.log_enrichment
        )

        # Finally, we can model our observed counts at both the beginning and end
        # as multinomially distributed. Note that the name of the observable must
        # match the name we used when registering default data. If not registering
        # default data, you will need to provide the observables as keyword arguments
        # in the relevant functions (again, with matching names).
        # Note: We are using an alternate parametrization of the multinomial distribution
        # here to keep in log space.
        # Note: NumPy broadcasting rules apply, so the below will use the same
        # 10 log-frequencies for all 3 replicates. This is why we need
        # `keepdims=True` when summing the counts to get `N`: shapes (3, 1) and
        # (10,) broadcast to (3, 10), while (3,) and (10,) do not broadcast.
        self.input_counts = ssp.parameters.MultinomialLogTheta(
            log_theta=self.log_input_freqs,
            N=input_counts.sum(axis=-1, keepdims=True),
            shape=(3, 10),
        )
        self.output_counts = ssp.parameters.MultinomialLogTheta(
            log_theta=self.log_output_freqs,
            N=output_counts.sum(axis=-1, keepdims=True),
            shape=(3, 10),
        )

That's it! The model is defined. You'll note that the variable names in the above model correspond neatly to variables defined in the model definition from the previous section:

| Name in Previous Section | Name in Model |
|---|---|
| $\ln{\mathbf{\theta_i}}$ | log_input_freqs |
| $\ln{\mathbf{\theta_f}}$ | log_output_freqs |
| $\ln{\mathbf{f}}$ | log_enrichment |
| $\mathbf{c_i}$ | input_counts |
| $\mathbf{c_f}$ | output_counts |

That is, SciStanPy models are designed to follow syntax that closely reflects standard probabilistic model definitions, with instance variables becoming model names and transformations automatically recorded.

Another note: SciStanPy models are, of course, Python classes, and can be extended in all the usual ways. This can allow for the construction of class hierarchies that greatly reduce the need for duplicated code.
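The broadcasting comment in the model definition (why `keepdims=True` is needed when computing `N`) can be checked directly with NumPy. This sketch uses placeholder arrays with the same shapes as in the model; the variable names are chosen for this example only.

```python
import numpy as np

log_theta = np.zeros(10)               # shape (10,): one value per variant
counts = np.ones((3, 10), dtype=int)   # shape (3, 10): replicates x variants

n_keepdims = counts.sum(axis=-1, keepdims=True)  # shape (3, 1)
n_flat = counts.sum(axis=-1)                     # shape (3,)

# (3, 1) and (10,) are compatible and broadcast to (3, 10)
print(np.broadcast_shapes(n_keepdims.shape, log_theta.shape))

# (3,) and (10,) are not compatible: the trailing dimensions differ
try:
    np.broadcast_shapes(n_flat.shape, log_theta.shape)
except ValueError as err:
    print("Cannot broadcast:", err)
```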

Now, what can we do with our model? Let's create an instance of it and test out some SciStanPy operations. First up, let's do a prior predictive check:

In [4]:
# Build an instance
EXAMPLE_MODEL = DMSModel(
    input_counts=SAMPLE_DATA["INPUT_COUNTS"],
    output_counts=SAMPLE_DATA["OUTPUT_COUNTS"]
)

# Run a prior predictive check
EXAMPLE_MODEL.prior_predictive()
Out[4]:
BokehModel(combine_events=True, render_bundle={'docs_json': {'143225d7-76f4-456e-b0bc-57edc7246af7': {'version…

The prior_predictive function brings up an interactive dashboard that lets you test out the effects of different values for hyperparameters on model observables and parameters. By default, updating any parameters with the sliders (and subsequently clicking "update model") will also update their values in the model. This allows you to explore model hyperparameter values interactively before moving on to fitting a model.

Depending on the parameter selected, you may want to increase the value for "Number of Experiments"--this is the number of draws made from the model to build the figure. The default ECDF view flattens the displayed array before calculation, which obscures any relationships within or between variables. You can additionally choose values for "Group By", which will plot separate lines (or violins) for each grouping dimension; any constants with appropriate dimensionality can also be selected as "Independent Variable" to plot relationships between components.

Let's say we are happy with our hyperparameter selection. We're not yet ready to commit to full MCMC sampling, but we want to get an estimate for our parameter values. We can perform a maximum likelihood estimate using PyTorch as the backend:

In [5]:
MLE = EXAMPLE_MODEL.mle(lr=0.01)
Epochs:   7%|▋         | 6823/100000 [01:03<14:21, 108.17it/s, -log pdf/pmf=245.32] 

By default, maximum likelihood estimation will run for 100,000 steps or until the loss (negative log likelihood) has not decreased for 10 steps, whichever comes first. The output results object exposes a lot of additional functionality, which is covered in great detail in the documentation. For the purposes of this example notebook, however, we'll look at just two features: extracting maximum likelihood estimates and bootstrapping observations:
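The stopping rule described above (a step cap combined with a "no improvement for 10 steps" check) can be sketched with plain gradient descent on a toy loss. This is only an illustration of the idea: the function `minimize_with_patience`, its arguments, and the quadratic loss are hypothetical, and SciStanPy's PyTorch-based optimizer may differ in its details.

```python
import numpy as np


def minimize_with_patience(loss_and_grad, x0, lr=0.01, max_steps=100_000, patience=10):
    """Gradient descent that stops early once the loss stops improving."""
    x = np.asarray(x0, dtype=float)
    best_loss = np.inf
    steps_without_improvement = 0

    for step in range(1, max_steps + 1):
        loss, grad = loss_and_grad(x)

        if loss < best_loss:
            # New best loss: reset the patience counter
            best_loss = loss
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                break  # loss has not decreased for `patience` steps

        x = x - lr * grad

    return x, best_loss, step


# Toy loss with its minimum at x = 3
def quadratic(x):
    return float(np.sum((x - 3.0) ** 2)), 2.0 * (x - 3.0)


x_hat, final_loss, n_steps = minimize_with_patience(quadratic, x0=[0.0], lr=0.1)
print(x_hat, final_loss, n_steps)
```

On this toy problem the loop ends long before the step cap because the loss stalls at floating-point precision, which is the same reason the progress bar above can finish well short of 100,000 epochs.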

In [6]:
MLE_ESTIMATES = {k: v.mle for k, v in MLE.model_varname_to_mle.items()}
MLE_ESTIMATES
Out[6]:
{'log_enrichment': array([-0.88175307, -3.0444229 , -2.52572081, -2.72127213, -2.63158634,
        -3.36346962, -4.35208336, -2.55653866, -0.82543532, -3.89821476]),
 'log_input_freqs': array([-2.31319812, -1.04716232, -2.52949171, -2.33462835, -2.90948524,
        -2.95035977, -2.54973785, -6.09674183, -2.15798691, -2.64728865]),
 'input_counts': None,
 'output_counts': None}

A dictionary of results linking model varnames to MLE estimates can be accessed with the model_varname_to_mle property. Note that observables do not have an MLE--observables are ground truth, so there is nothing to estimate.

Let's compare the estimates above to our known values from our simulated experiment:

In [7]:
hv.Scatter(
    data={
        "x": SAMPLE_DATA["LOG_ENRICHMENT_FACTORS"],
        "y": MLE_ESTIMATES["log_enrichment"],
    }
)
Out[7]:
In [8]:
hv.Scatter(
    data={"x": SAMPLE_DATA["LOG_INPUT_FREQS"], "y": MLE_ESTIMATES["log_input_freqs"]}
)
Out[8]:

As we'd expect, considering we know the exact generative process, the MLE and true values align well. In practice, of course, we do not know the generative process--our model defines what we believe it to be, then we fit the model and evaluate how well the fitted model describes the data. One simple way to evaluate goodness of fit is to bootstrap samples from the MLE and perform a posterior predictive check on the results:
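If a numeric summary is preferred over reading alignment off a scatter plot, a small helper like the one below could be used. It is a sketch: `alignment_summary` is a hypothetical function, and the arrays are deterministic stand-ins rather than the notebook's values. In the notebook, one would pass `SAMPLE_DATA["LOG_ENRICHMENT_FACTORS"]` and `MLE_ESTIMATES["log_enrichment"]` instead.

```python
import numpy as np


def alignment_summary(true_values, estimates):
    """Pearson correlation and root-mean-square error between two arrays."""
    true_values = np.asarray(true_values, dtype=float)
    estimates = np.asarray(estimates, dtype=float)
    pearson_r = np.corrcoef(true_values, estimates)[0, 1]
    rmse = np.sqrt(np.mean((estimates - true_values) ** 2))
    return {"pearson_r": pearson_r, "rmse": rmse}


# Stand-in data: "true" values plus a small deterministic perturbation
true_values = np.linspace(-4.0, -1.0, 10)
estimates = true_values + 0.1 * np.sin(np.arange(10))

summary = alignment_summary(true_values, estimates)
print(summary)
```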

In [9]:
INFERENCE_OBJ = MLE.get_inference_obj() # Bootstrapping
INFERENCE_OBJ.run_ppc() # Checking fit
Out[9]:
BokehModel(combine_events=True, render_bundle={'docs_json': {'315e15f2-3cee-49ae-bc1b-8ab500f48833': {'version…

The above example is contrived, so the plots look a little too pristine compared to what you'd get in a real-world setting. Here's how to interpret them, though:

  1. The first plot shows the true rank (x-axis) of an observation against its values bootstrapped from the model. The x-axis does not mean anything in this case--it is shown by true rank to simplify data presentation when there are thousands or more observed points. The y-axis, however, shows the distribution of bootstrapped values (grey) and associated observed values (gold). A model that is fit well will have most gold dots falling within the shaded regions.
  2. The next plot has the same axis, but now shows the quantile of true observations relative to the bootstrapped distribution of observations. This figure is also designed to work with thousands or more data points, so it aggregates nearby points into hexagonal bins. It is not the most useful figure for data at the scale of this example; however, for larger datasets, what you would want to see is uniform distribution of probability across all quantiles at all values with the median line (grey) running down the center.
  3. The final plot is a quantile-quantile plot. Effectively, it is an ECDF over the quantile values for observables relative to bootstrapped samples. Because, by definition, each quantile should be uniformly represented in a sufficiently large sample from a perfectly calibrated model, a perfectly calibrated (and fit) model will have a diagonal ECDF, indicated in the figure by the dashed line. The annotation "Absolute Deviance" gives the absolute area between the observed and ideal ECDF curves and can be used to measure how well calibrated and fit a given model is relative to another.
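The idea behind the quantile-quantile plot and its deviance annotation can be sketched in NumPy as follows. This is an illustration under stated assumptions, not SciStanPy's implementation: the function names, the grid-based area calculation, and the simulated "well calibrated" and "too narrow" models are all hypothetical.

```python
import numpy as np


def observation_quantiles(samples, observed):
    """Fraction of simulated draws at or below each observed value.

    samples has shape (n_draws, n_obs); observed has shape (n_obs,).
    """
    return np.mean(samples <= observed, axis=0)


def calibration_deviance(quantiles, grid_size=1001):
    """Area between the ECDF of the quantiles and the diagonal on [0, 1]."""
    grid = np.linspace(0.0, 1.0, grid_size)
    sorted_q = np.sort(quantiles)
    ecdf = np.searchsorted(sorted_q, grid, side="right") / sorted_q.size
    gap = np.abs(ecdf - grid)
    # Trapezoidal rule over the grid
    return float(np.sum((gap[1:] + gap[:-1]) / 2.0 * np.diff(grid)))


rng = np.random.default_rng(0)
observed = rng.normal(0.0, 1.0, size=500)

# A model that matches the data-generating process
good_samples = rng.normal(0.0, 1.0, size=(1000, 500))
# A model that is too narrow (underestimates the spread of the data)
narrow_samples = rng.normal(0.0, 0.5, size=(1000, 500))

good = calibration_deviance(observation_quantiles(good_samples, observed))
narrow = calibration_deviance(observation_quantiles(narrow_samples, observed))
print(f"Well calibrated: {good:.3f}, too narrow: {narrow:.3f}")
```

The smaller value belongs to the model whose simulated observations cover the real ones uniformly, which is the comparison the deviance annotation is meant to support.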

It should also be noted that the above plot is interactive! Use the dropdown to update the plot for different observables.

Now, one final note on the INFERENCE_OBJ variable. It exposes a special property inference_obj which is an ArviZ InferenceData instance holding all bootstrapped data. Use it to plug directly into the ArviZ ecosystem:

In [10]:
INFERENCE_OBJ.inference_obj  # ArviZ InferenceData instance
Out[10]:
arviz.InferenceData
    • <xarray.Dataset> Size: 160kB
      Dimensions:                (chain: 1, draw: 1000, a: 10)
      Coordinates:
          log_input_freqs.alpha  (a) float64 80B 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
      Dimensions without coordinates: chain, draw, a
      Data variables:
          log_enrichment         (chain, draw, a) float64 80kB -3.052 ... -0.9929
          log_input_freqs        (chain, draw, a) float64 80kB -3.075 ... -2.559
      xarray.Dataset
        • chain: 1
        • draw: 1000
        • a: 10
        • log_input_freqs.alpha
          (a)
          float64
          1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
          array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
        • log_enrichment
          (chain, draw, a)
          float64
          -3.052 -3.484 ... -4.59 -0.9929
          array([[[-3.05162068, -3.48388047, -2.45199896, ..., -2.11331152,
                   -2.77574977, -1.50744984],
                  [-3.54165014, -2.46428759, -2.49184172, ..., -2.16073524,
                   -0.56073761, -3.71667996],
                  [-3.08529589, -2.08587083, -3.46449429, ..., -3.97047433,
                   -2.21290193, -2.53005757],
                  ...,
                  [-3.01478083, -2.26081935, -3.46390801, ..., -2.97368984,
                   -4.6902851 , -1.10722984],
                  [-3.99862541, -1.6489174 , -1.80430865, ..., -1.89673003,
                   -3.35455209, -2.33587573],
                  [-3.92141236, -2.46620363, -4.53229531, ..., -2.70640658,
                   -4.58967283, -0.99287571]]], shape=(1, 1000, 10))
        • log_input_freqs
          (chain, draw, a)
          float64
          -3.075 -1.426 ... -1.771 -2.559
          array([[[-3.07481256, -1.42619118, -1.75907631, ..., -3.67515058,
                   -2.24584176, -1.95670032],
                  [-3.1103593 , -5.45017524, -2.43984753, ..., -2.78305232,
                   -1.32553558, -1.52997179],
                  [-2.44140875, -3.04075668, -4.79461183, ..., -4.18685807,
                   -1.74526816, -2.56184337],
                  ...,
                  [-2.37260952, -2.01023605, -1.21402788, ..., -2.81535244,
                   -3.32033473, -1.66475567],
                  [-5.16400392, -2.0192338 , -2.43027598, ..., -2.51321467,
                   -4.54908608, -1.86393857],
                  [-1.50167806, -4.53370862, -2.74664954, ..., -2.60190626,
                   -1.77074476, -2.5592491 ]]], shape=(1, 1000, 10))

      • <xarray.Dataset> Size: 480kB
        Dimensions:                (chain: 1, draw: 1000, b: 3, a: 10)
        Coordinates:
            input_counts.N         (b) int64 24B 10000 10000 10000
            log_input_freqs.alpha  (a) float64 80B 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
            output_counts.N        (b) int64 24B 10000 10000 10000
        Dimensions without coordinates: chain, draw, b, a
        Data variables:
            input_counts           (chain, draw, b, a) float64 240kB 979.0 ... 716.0
            output_counts          (chain, draw, b, a) float64 240kB 3.211e+03 ... 122.0
        xarray.Dataset
          • chain: 1
          • draw: 1000
          • b: 3
          • a: 10
          • input_counts.N
            (b)
            int64
            10000 10000 10000
            array([10000, 10000, 10000])
          • log_input_freqs.alpha
            (a)
            float64
            1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
            array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
          • output_counts.N
            (b)
            int64
            10000 10000 10000
            array([10000, 10000, 10000])
          • input_counts
            (chain, draw, b, a)
            float64
            979.0 3.523e+03 ... 1.149e+03 716.0
            array([[[[ 979., 3523.,  846., ...,   22., 1111.,  646.],
                     [1005., 3505.,  801., ...,   19., 1201.,  665.],
                     [ 984., 3508.,  783., ...,   25., 1230.,  683.]],
            
                    [[ 993., 3499.,  765., ...,   21., 1140.,  726.],
                     [ 990., 3578.,  777., ...,   20., 1110.,  736.],
                     [ 984., 3496.,  838., ...,   16., 1228.,  688.]],
            
                    [[ 967., 3471.,  804., ...,   22., 1214.,  701.],
                     [1006., 3531.,  746., ...,   24., 1150.,  692.],
                     [ 988., 3472.,  796., ...,   12., 1128.,  686.]],
            
                    ...,
            
                    [[1032., 3509.,  774., ...,   26., 1148.,  702.],
                     [ 985., 3586.,  815., ...,   29., 1158.,  694.],
                     [ 970., 3603.,  766., ...,   26., 1176.,  679.]],
            
                    [[ 985., 3514.,  788., ...,   22., 1190.,  711.],
                     [1046., 3545.,  795., ...,   21., 1120.,  672.],
                     [ 958., 3575.,  823., ...,   15., 1143.,  644.]],
            
                    [[ 958., 3528.,  793., ...,   19., 1227.,  772.],
                     [1014., 3499.,  801., ...,   19., 1148.,  713.],
                     [ 969., 3522.,  770., ...,   26., 1149.,  716.]]]],
                  shape=(1, 1000, 3, 10))
          • output_counts
            (chain, draw, b, a)
            float64
            3.211e+03 1.322e+03 ... 122.0
            array([[[[3211., 1322.,  472., ...,   11., 3837.,  131.],
                     [3129., 1326.,  491., ...,   17., 3868.,  103.],
                     [3179., 1328.,  493., ...,    5., 3898.,  118.]],
            
                    [[3294., 1353.,  503., ...,   17., 3727.,  101.],
                     [3181., 1271.,  478., ...,    6., 3968.,  103.],
                     [3154., 1310.,  508., ...,   20., 3884.,  114.]],
            
                    [[3120., 1308.,  479., ...,   15., 3977.,  107.],
                     [3256., 1241.,  465., ...,    8., 3894.,   94.],
                     [3195., 1293.,  481., ...,    9., 3894.,  100.]],
            
                    ...,
            
                    [[3186., 1300.,  506., ...,   14., 3924.,   99.],
                     [3150., 1292.,  505., ...,   13., 3941.,   94.],
                     [3154., 1301.,  507., ...,   18., 3859.,  125.]],
            
                    [[3156., 1283.,  505., ...,   22., 3939.,   90.],
                     [3168., 1246.,  491., ...,   11., 3946.,  105.],
                     [3194., 1273.,  434., ...,   14., 3924.,  121.]],
            
                    [[3200., 1296.,  478., ...,   13., 3921.,  126.],
                     [3211., 1223.,  505., ...,   10., 3939.,  100.],
                     [3158., 1329.,  464., ...,   15., 3891.,  122.]]]],
                  shape=(1, 1000, 3, 10))

        • <xarray.Dataset> Size: 480B
          Dimensions:        (b: 3, a: 10)
          Dimensions without coordinates: b, a
          Data variables:
              input_counts   (b, a) int64 240B 965 3504 777 963 545 ... 755 21 1125 691
              output_counts  (b, a) int64 240B 3172 1273 481 471 312 ... 86 14 3936 108
          xarray.Dataset
            • b: 3
            • a: 10
              • input_counts
                (b, a)
                int64
                965 3504 777 963 ... 21 1125 691
                array([[ 965, 3504,  777,  963,  545,  538,  799,   21, 1173,  715],
                       [ 994, 3476,  831,  970,  535,  492,  790,   25, 1166,  721],
                       [1006, 3551,  783,  973,  555,  540,  755,   21, 1125,  691]])
              • output_counts
                (b, a)
                int64
                3172 1273 481 471 ... 14 3936 108
                array([[3172, 1273,  481,  471,  312,  151,   72,   10, 3953,  105],
                       [3213, 1309,  483,  511,  297,  128,   74,   16, 3850,  119],
                       [3115, 1292,  514,  495,  300,  140,   86,   14, 3936,  108]])

          Alright, we're satisfied with our maximum likelihood estimation and ready to move on to full Markov chain Monte Carlo (MCMC) sampling with Stan. Just call the below and SciStanPy will:

          1. Convert your SciStanPy model into Stan code.
          2. Compile that code.
          3. Run that code.
          4. Organize and return the results.
          In [11]:
          HMC_RES = EXAMPLE_MODEL.mcmc(iter_warmup=2000, iter_sampling=1000)
          
          13:16:34 - cmdstanpy - INFO - compiling stan file /tmp/tmpmmet_h7o/model.stan to exe file /tmp/tmpmmet_h7o/model
          
          13:17:03 - cmdstanpy - INFO - compiled model executable: /tmp/tmpmmet_h7o/model
          13:17:03 - cmdstanpy - WARNING - Stan compiler has produced 2 warnings:
          13:17:03 - cmdstanpy - WARNING - 
          --- Translating Stan model to C++ code ---
          bin/stanc --filename-in-msg=model.stan --warn-pedantic --O1 --include-paths=/home/bwittmann/micromamba/envs/ssp_test/lib/python3.12/site-packages/scistanpy/model/stan --o=/tmp/tmpmmet_h7o/model.hpp /tmp/tmpmmet_h7o/model.stan
          Warning: The parameter log_input_freqs_raw has no priors. This means either
              no prior is provided, or the prior(s) depend on data variables. In the
              later case, this may be a false positive.
          Warning: The parameter log_enrichment has no priors. This means either no
              prior is provided, or the prior(s) depend on data variables. In the later
              case, this may be a false positive.
          
          --- Compiling C++ code ---
          g++ -std=c++17 -pthread -D_REENTRANT -Wno-sign-compare -Wno-ignored-attributes -Wno-class-memaccess     -DSTAN_THREADS -I stan/lib/stan_math/lib/tbb_2020.3/include    -O3 -I src -I stan/src -I stan/lib/rapidjson_1.1.0/ -I lib/CLI11-1.9.1/ -I stan/lib/stan_math/ -I stan/lib/stan_math/lib/eigen_3.4.0 -I stan/lib/stan_math/lib/boost_1.87.0 -I stan/lib/stan_math/lib/sundials_6.1.1/include -I stan/lib/stan_math/lib/sundials_6.1.1/src/sundials    -DBOOST_DISABLE_ASSERTS          -c -Wno-ignored-attributes   -x c++ -o /tmp/tmpmmet_h7o/model.o /tmp/tmpmmet_h7o/model.hpp
          
          --- Linking model ---
          g++ -std=c++17 -pthread -D_REENTRANT -Wno-sign-compare -Wno-ignored-attributes -Wno-class-memaccess     -DSTAN_THREADS -I stan/lib/stan_math/lib/tbb_2020.3/include    -O3 -I src -I stan/src -I stan/lib/rapidjson_1.1.0/ -I lib/CLI11-1.9.1/ -I stan/lib/stan_math/ -I stan/lib/stan_math/lib/eigen_3.4.0 -I stan/lib/stan_math/lib/boost_1.87.0 -I stan/lib/stan_math/lib/sundials_6.1.1/include -I stan/lib/stan_math/lib/sundials_6.1.1/src/sundials    -DBOOST_DISABLE_ASSERTS               -Wl,-L,"/home/bwittmann/.cmdstan/cmdstan-2.37.0/stan/lib/stan_math/lib/tbb"   -Wl,-rpath,"/home/bwittmann/.cmdstan/cmdstan-2.37.0/stan/lib/stan_math/lib/tbb"      /tmp/tmpmmet_h7o/model.o src/cmdstan/main_threads.o       -ltbb   stan/lib/stan_math/lib/sundials_6.1.1/lib/libsundials_nvecserial.a stan/lib/stan_math/lib/sundials_6.1.1/lib/libsundials_cvodes.a stan/lib/stan_math/lib/sundials_6.1.1/lib/libsundials_idas.a stan/lib/stan_math/lib/sundials_6.1.1/lib/libsundials_kinsol.a  stan/lib/stan_math/lib/tbb/libtbb.so.2 -o /tmp/tmpmmet_h7o/model
          rm /tmp/tmpmmet_h7o/model.hpp /tmp/tmpmmet_h7o/model.o
          
          13:17:03 - cmdstanpy - INFO - CmdStan start processing
          
          chain 1 |          | 00:00 Status
          chain 2 |          | 00:00 Status
          chain 3 |          | 00:00 Status
          chain 4 |          | 00:00 Status
                                                                                                                                                                                                                                                                                                                                          
          13:17:05 - cmdstanpy - INFO - CmdStan done processing.
          
          
          
          Converting CSV to NetCDF: 100%|██████████| 4/4 [00:27<00:00,  6.75s/it]
          

          You can ignore the warnings about parameters having no priors. They arise from how data are passed to the model: the priors depend on data variables, which is exactly the false-positive case the warning describes.

          Now that we have the results, let's run some diagnostics:

          In [12]:
          _ = HMC_RES.diagnose()
          
          Sample diagnostic tests results' summaries:
          -------------------------------------------
          0 of 4000 (0.00%) samples had a low energy.
          0 of 4000 (0.00%) samples reached the maximum tree depth.
          0 of 4000 (0.00%) samples diverged.
          
          R_hat diagnostic tests results' summaries:
          ------------------------------------------
          0 of 10 (0.00%) r_hats tests failed for log_enrichment.
          0 of 10 (0.00%) r_hats tests failed for log_input_freqs.
          0 of 10 (0.00%) r_hats tests failed for log_output_freqs.
          
          Ess_bulk diagnostic tests results' summaries:
          ---------------------------------------------
          0 of 10 (0.00%) ess_bulks tests failed for log_enrichment.
          0 of 10 (0.00%) ess_bulks tests failed for log_input_freqs.
          0 of 10 (0.00%) ess_bulks tests failed for log_output_freqs.
          
          Ess_tail diagnostic tests results' summaries:
          ---------------------------------------------
          0 of 10 (0.00%) ess_tails tests failed for log_enrichment.
          0 of 10 (0.00%) ess_tails tests failed for log_input_freqs.
          0 of 10 (0.00%) ess_tails tests failed for log_output_freqs.
          

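          The diagnostic tests and their meanings are described in greater detail in the full documentation. For our purposes here, what matters is that they all passed: no sample was flagged for low energy, maximum tree depth, or divergence, and no R-hat or ESS test failed for any parameter.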
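To give a feel for what the R-hat test measures, here is a minimal sketch of the classic split R-hat statistic in pure Python. It is not the code that diagnose() runs (ArviZ, for example, defaults to a rank-normalized refinement of this idea), and the draws are hypothetical stand-ins generated on the spot. The statistic compares the variance between chain halves with the variance within them, so values close to 1 indicate that the chains agree.

```python
import random
import statistics


def split_rhat(chains):
    """Classic split R-hat for one scalar parameter.

    `chains` is a list of equal-length lists, one list of draws per chain.
    """
    half = len(chains[0]) // 2
    # Split every chain in two so that drift within a chain also shows up
    # as disagreement between "chains".
    halves = [c[:half] for c in chains] + [c[half:2 * half] for c in chains]
    # W: average within-half variance; B / n: variance of the half means.
    within = statistics.fmean([statistics.variance(h) for h in halves])
    between_over_n = statistics.variance([statistics.fmean(h) for h in halves])
    var_plus = (half - 1) / half * within + between_over_n
    return (var_plus / within) ** 0.5


rng = random.Random(0)
# Hypothetical draws: 4 chains of 1000 draws from the same distribution...
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# ...and 4 chains that each got stuck around a different value.
stuck = [[x + 3.0 * i for x in chain] for i, chain in enumerate(mixed)]

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
print(f"well-mixed chains: R-hat = {rhat_mixed:.3f}")
print(f"stuck chains:      R-hat = {rhat_stuck:.3f}")
```

The well-mixed chains give a value very close to 1, while the chains centered on different values give a value far above it, which is the pattern an R-hat test is designed to catch.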

          We stored the return value of diagnose() in a throwaway variable because we don't need it here. When tests do fail, it contains the indices of the failed samples and variables. The SampleResults object also provides the plot_sample_failure_quantile_traces and plot_variable_failure_quantile_traces methods to help diagnose failed samples when they are present. See the full documentation for details on these methods.

          Finally, we want to run a posterior predictive check on our Stan samples. Unlike in the MLE example, these samples should be representative of the full posterior, not just the MLE. The workflow is nonetheless the same as the one used to evaluate the MLE fit:

          In [13]:
          HMC_RES.run_ppc()
          
          Out[13]:
          BokehModel(combine_events=True, render_bundle={'docs_json': {'1dffcae8-2c6d-4ab4-b57c-21e451fb8dc8': {'version…
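For readers who want to see the idea behind a posterior predictive check spelled out, the following is a generic sketch in pure Python, not the implementation of run_ppc. It makes two loudly stated assumptions: the counts are treated as multinomial, and the "posterior" is a stand-in Dirichlet(observed + 1) distribution (the conjugate posterior under a flat prior) rather than the model fitted in this notebook. The observed values are the first replicate of input_counts from this notebook. For each posterior draw we simulate a replicate dataset, then check whether each observed count falls inside the central 95% interval of its simulated counterparts.

```python
import random
import statistics
from collections import Counter

rng = random.Random(0)

# First replicate of the observed input counts shown in this notebook.
observed = [965, 3504, 777, 963, 545, 538, 799, 21, 1173, 715]
total = sum(observed)
n_variants = len(observed)


def draw_frequencies():
    """One draw from a Dirichlet(observed + 1) stand-in posterior."""
    gammas = [rng.gammavariate(count + 1, 1.0) for count in observed]
    norm = sum(gammas)
    return [g / norm for g in gammas]


# For every posterior draw, simulate one replicate dataset of the same depth.
simulated = [[] for _ in range(n_variants)]
for _ in range(100):
    freqs = draw_frequencies()
    counts = Counter(rng.choices(range(n_variants), weights=freqs, k=total))
    for variant in range(n_variants):
        simulated[variant].append(counts[variant])

# Compare each observed count with its 95% posterior predictive interval.
covered = []
for variant, obs in enumerate(observed):
    cuts = statistics.quantiles(simulated[variant], n=40)  # 2.5%, ..., 97.5%
    low, high = cuts[0], cuts[-1]
    covered.append(low <= obs <= high)
    print(f"variant {variant}: observed {obs}, 95% interval [{low:.0f}, {high:.0f}]")
```

The plot returned by run_ppc conveys the same kind of comparison graphically, using the samples from the fitted model instead of a stand-in posterior.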

          As before, the fit looks reasonable. Also as before, note that we can access an underlying ArviZ InferenceData object for further analysis. It will have some additional fields compared to the MLE-associated one, reflecting the richer information content of HMC samples:

          In [14]:
          HMC_RES.inference_obj
          
          Out[14]:
          arviz.InferenceData
            • <xarray.Dataset> Size: 480kB
              Dimensions:           (chain: 4, draw: 1000, a: 10)
              Dimensions without coordinates: chain, draw, a
              Data variables:
                  log_enrichment    (chain, draw, a) float32 160kB -1.069 -3.236 ... -3.867
                  log_input_freqs   (chain, draw, a) float32 160kB -2.328 -1.053 ... -2.638
                  log_output_freqs  (chain, draw, a) float32 160kB -1.142 -2.034 ... -4.514
              xarray.Dataset
                • chain: 4
                • draw: 1000
                • a: 10
                  • log_enrichment
                    (chain, draw, a)
                    float32
                    -1.069 -3.236 ... -0.735 -3.867
                    array([[[-1.068688, -3.236369, ..., -1.05469 , -4.08567 ],
                            [-1.262433, -3.427649, ..., -1.232311, -4.283817],
                            ...,
                            [-0.932845, -3.147924, ..., -0.896691, -4.035187],
                            [-0.77195 , -2.941099, ..., -0.727211, -3.724041]],
                    
                           [[-1.046526, -3.162458, ..., -0.94579 , -4.045295],
                            [-0.956545, -3.104215, ..., -0.877397, -4.097172],
                            ...,
                            [-0.64909 , -2.812824, ..., -0.626214, -3.604606],
                            [-0.882263, -3.056537, ..., -0.805593, -3.905529]],
                    
                           [[-1.333376, -3.492797, ..., -1.334005, -4.496884],
                            [-1.137693, -3.253041, ..., -1.087074, -4.189535],
                            ...,
                            [-1.194294, -3.381477, ..., -1.101621, -4.325435],
                            [-1.203106, -3.426734, ..., -1.159328, -4.311948]],
                    
                           [[-1.623296, -3.79377 , ..., -1.538482, -4.637006],
                            [-1.534369, -3.703756, ..., -1.514616, -4.493636],
                            ...,
                            [-0.645705, -2.858608, ..., -0.678251, -3.628458],
                            [-0.844686, -2.973438, ..., -0.735004, -3.867251]]],
                          shape=(4, 1000, 10), dtype=float32)
                  • log_input_freqs
                    (chain, draw, a)
                    float32
                    -2.328 -1.053 ... -2.189 -2.638
                    array([[[-2.328192, -1.052654, ..., -2.146743, -2.650253],
                            [-2.30359 , -1.053144, ..., -2.144945, -2.63379 ],
                            ...,
                            [-2.328521, -1.047004, ..., -2.156443, -2.659002],
                            [-2.331172, -1.044762, ..., -2.165393, -2.639492]],
                    
                           [[-2.285582, -1.0619  , ..., -2.161284, -2.598546],
                            [-2.297102, -1.060655, ..., -2.149428, -2.589555],
                            ...,
                            [-2.320424, -1.054399, ..., -2.120126, -2.644726],
                            [-2.313679, -1.037961, ..., -2.181865, -2.686515]],
                    
                           [[-2.337826, -1.047284, ..., -2.141611, -2.589049],
                            [-2.306165, -1.057289, ..., -2.148264, -2.649077],
                            ...,
                            [-2.322836, -1.0529  , ..., -2.184392, -2.63638 ],
                            [-2.337645, -1.03459 , ..., -2.165001, -2.632983]],
                    
                           [[-2.320177, -1.038026, ..., -2.196003, -2.647448],
                            [-2.314257, -1.045154, ..., -2.130461, -2.659276],
                            ...,
                            [-2.336945, -1.042501, ..., -2.12371 , -2.692032],
                            [-2.302911, -1.04821 , ..., -2.189054, -2.637959]]],
                          shape=(4, 1000, 10), dtype=float32)
                  • log_output_freqs
                    (chain, draw, a)
                    float32
                    -1.142 -2.034 ... -0.9327 -4.514
                    array([[[-1.141578, -2.033722, ..., -0.946132, -4.480623],
                            [-1.138317, -2.053086, ..., -0.94955 , -4.489902],
                            ...,
                            [-1.140021, -2.073582, ..., -0.931789, -4.572843],
                            [-1.154388, -2.037127, ..., -0.94387 , -4.414799]],
                    
                           [[-1.152483, -2.044734, ..., -0.92745 , -4.464217],
                            [-1.157013, -2.068236, ..., -0.93019 , -4.590092],
                            ...,
                            [-1.152963, -2.050672, ..., -0.92979 , -4.43278 ],
                            [-1.150977, -2.049534, ..., -0.942494, -4.547079]],
                    
                           [[-1.144643, -2.013521, ..., -0.949056, -4.559373],
                            [-1.156536, -2.023007, ..., -0.948016, -4.55129 ],
                            ...,
                            [-1.154477, -2.071724, ..., -0.923361, -4.599162],
                            [-1.152518, -2.073091, ..., -0.936096, -4.556698]],
                    
                           [[-1.150744, -2.039067, ..., -0.941757, -4.491725],
                            [-1.148419, -2.048703, ..., -0.944871, -4.452706],
                            ...,
                            [-1.132241, -2.0507  , ..., -0.951551, -4.47008 ],
                            [-1.156215, -2.030266, ..., -0.932676, -4.513828]]],
                          shape=(4, 1000, 10), dtype=float32)

                • <xarray.Dataset> Size: 960kB
                  Dimensions:        (chain: 4, draw: 1000, b: 3, a: 10)
                  Dimensions without coordinates: chain, draw, b, a
                  Data variables:
                      input_counts   (chain, draw, b, a) int32 480kB 1020 3470 792 ... 24 1168 681
                      output_counts  (chain, draw, b, a) int32 480kB 3149 1251 454 ... 8 3920 119
                  xarray.Dataset
                    • chain: 4
                    • draw: 1000
                    • b: 3
                    • a: 10
                      • input_counts
                        (chain, draw, b, a)
                        int32
                        1020 3470 792 1027 ... 24 1168 681
                        array([[[[1020, ...,  712],
                                 ...,
                                 [ 992, ...,  679]],
                        
                                ...,
                        
                                [[ 924, ...,  705],
                                 ...,
                                 [ 963, ...,  751]]],
                        
                        
                               ...,
                        
                        
                               [[[1024, ...,  743],
                                 ...,
                                 [1018, ...,  738]],
                        
                                ...,
                        
                                [[ 966, ...,  704],
                                 ...,
                                 [ 954, ...,  681]]]], shape=(4, 1000, 3, 10), dtype=int32)
                      • output_counts
                        (chain, draw, b, a)
                        int32
                        3149 1251 454 502 ... 71 8 3920 119
                        array([[[[3149, ...,   97],
                                 ...,
                                 [3232, ...,  116]],
                        
                                ...,
                        
                                [[3141, ...,  102],
                                 ...,
                                 [3077, ...,  129]]],
                        
                        
                               ...,
                        
                        
                               [[[3123, ...,  117],
                                 ...,
                                 [3154, ...,  118]],
                        
                                ...,
                        
                                [[3184, ...,  109],
                                 ...,
                                 [3151, ...,  119]]]], shape=(4, 1000, 3, 10), dtype=int32)

                    • <xarray.Dataset> Size: 112kB
                      Dimensions:        (chain: 4, draw: 1000)
                      Dimensions without coordinates: chain, draw
                      Data variables:
                          lp__           (chain, draw) float32 16kB ...
                          accept_stat__  (chain, draw) float32 16kB ...
                          stepsize__     (chain, draw) float32 16kB ...
                          treedepth__    (chain, draw) int32 16kB 7 7 8 7 7 7 7 7 ... 7 7 7 7 7 7 8 7
                          n_leapfrog__   (chain, draw) int32 16kB ...
                          divergent__    (chain, draw) int32 16kB 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0
                          energy__       (chain, draw) float32 16kB 270.7 269.4 275.2 ... 284.3 285.1
                      xarray.Dataset
                        • chain: 4
                        • draw: 1000
                          • lp__
                            (chain, draw)
                            float32
                            ...
                            [4000 values with dtype=float32]
                          • accept_stat__
                            (chain, draw)
                            float32
                            ...
                            [4000 values with dtype=float32]
                          • stepsize__
                            (chain, draw)
                            float32
                            ...
                            [4000 values with dtype=float32]
                          • treedepth__
                            (chain, draw)
                            int32
                            7 7 8 7 7 7 7 7 ... 7 7 7 7 7 7 8 7
                            array([[7, 7, 8, ..., 6, 7, 7],
                                   [7, 7, 7, ..., 7, 7, 7],
                                   [6, 7, 7, ..., 7, 6, 7],
                                   [7, 7, 8, ..., 7, 8, 7]], shape=(4, 1000), dtype=int32)
                          • n_leapfrog__
                            (chain, draw)
                            int32
                            ...
                            [4000 values with dtype=int32]
                          • divergent__
                            (chain, draw)
                            int32
                            0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0
                            array([[0, 0, 0, ..., 0, 0, 0],
                                   [0, 0, 0, ..., 0, 0, 0],
                                   [0, 0, 0, ..., 0, 0, 0],
                                   [0, 0, 0, ..., 0, 0, 0]], shape=(4, 1000), dtype=int32)
                          • energy__
                            (chain, draw)
                            float32
                            270.7 269.4 275.2 ... 284.3 285.1
                            array([[270.7353 , 269.35608, 275.17273, ..., 277.98468, 278.795  , 276.4616 ],
                                   [283.55234, 284.7645 , 282.1037 , ..., 281.1396 , 280.9188 , 280.6225 ],
                                   [288.111  , 288.29886, 281.03134, ..., 280.14832, 286.42044, 286.70654],
                                   [279.701  , 278.34268, 272.34003, ..., 281.4624 , 284.348  , 285.05832]],
                                  shape=(4, 1000), dtype=float32)

                        • <xarray.Dataset> Size: 240B
                          Dimensions:        (b: 3, a: 10)
                          Dimensions without coordinates: b, a
                          Data variables:
                              input_counts   (b, a) int32 120B 965 3504 777 963 545 ... 755 21 1125 691
                              output_counts  (b, a) int32 120B 3172 1273 481 471 312 ... 86 14 3936 108
                          xarray.Dataset
                            • b: 3
                            • a: 10
                              • input_counts
                                (b, a)
                                int32
                                965 3504 777 963 ... 21 1125 691
                                array([[ 965, 3504,  777,  963,  545,  538,  799,   21, 1173,  715],
                                       [ 994, 3476,  831,  970,  535,  492,  790,   25, 1166,  721],
                                       [1006, 3551,  783,  973,  555,  540,  755,   21, 1125,  691]],
                                      dtype=int32)
                              • output_counts
                                (b, a)
                                int32
                                3172 1273 481 471 ... 14 3936 108
                                array([[3172, 1273,  481,  471,  312,  151,   72,   10, 3953,  105],
                                       [3213, 1309,  483,  511,  297,  128,   74,   16, 3850,  119],
                                       [3115, 1292,  514,  495,  300,  140,   86,   14, 3936,  108]],
                                      dtype=int32)

                            • <xarray.Dataset> Size: 1kB
                              Dimensions:           (metric: 5, a: 10)
                              Coordinates:
                                * metric            (metric) <U9 180B 'mcse_mean' 'mcse_sd' ... 'r_hat'
                              Dimensions without coordinates: a
                              Data variables:
                                  log_enrichment    (metric, a) float64 400B 0.01 0.01 0.01 ... 1.0 1.0 1.0
                                  log_input_freqs   (metric, a) float64 400B 0.0 0.0 0.0 0.0 ... 1.0 1.0 1.0
                                  log_output_freqs  (metric, a) float64 400B 0.0 0.0 0.0 0.0 ... 1.0 1.0 1.0
                              xarray.Dataset
                                • metric: 5
                                • a: 10
                                • metric
                                  (metric)
                                  <U9
                                  'mcse_mean' 'mcse_sd' ... 'r_hat'
                                  array(['mcse_mean', 'mcse_sd', 'ess_bulk', 'ess_tail', 'r_hat'], dtype='<U9')
                                • log_enrichment
                                  (metric, a)
                                  float64
                                  0.01 0.01 0.01 0.01 ... 1.0 1.0 1.0
                                  array([[1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02,
                                          1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02],
                                         [1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02,
                                          1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02],
                                         [5.99790e+02, 6.00140e+02, 5.97170e+02, 6.02250e+02, 6.03920e+02,
                                          6.04780e+02, 6.15870e+02, 7.26230e+02, 5.99890e+02, 6.12910e+02],
                                         [7.49200e+02, 7.85320e+02, 7.46800e+02, 8.37950e+02, 7.42870e+02,
                                          7.85930e+02, 8.38430e+02, 1.04279e+03, 7.93680e+02, 8.03140e+02],
                                         [1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00,
                                          1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00]])
                                • log_input_freqs
                                  (metric, a)
                                  float64
                                  0.0 0.0 0.0 0.0 ... 1.0 1.0 1.0 1.0
                                  array([[0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [4.14929e+03, 5.99540e+03, 4.78792e+03, 5.00634e+03, 4.49178e+03,
                                          4.28000e+03, 5.19919e+03, 2.24705e+03, 5.62300e+03, 5.54843e+03],
                                         [2.81655e+03, 3.46402e+03, 3.07469e+03, 2.76553e+03, 3.24866e+03,
                                          2.85078e+03, 2.66078e+03, 2.28510e+03, 3.36446e+03, 3.00511e+03],
                                         [1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00,
                                          1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00]])
                                • log_output_freqs
                                  (metric, a)
                                  float64
                                  0.0 0.0 0.0 0.0 ... 1.0 1.0 1.0 1.0
                                  array([[0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [3.94222e+03, 3.90681e+03, 4.39845e+03, 4.19513e+03, 3.91825e+03,
                                          3.79242e+03, 4.10996e+03, 4.43777e+03, 4.16450e+03, 3.82122e+03],
                                         [3.39500e+03, 2.56051e+03, 3.32536e+03, 3.47363e+03, 2.86068e+03,
                                          3.11030e+03, 2.46409e+03, 2.67234e+03, 3.00552e+03, 2.65292e+03],
                                         [1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00,
                                          1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00]])
                                • metric
                                  PandasIndex
                                  PandasIndex(Index(['mcse_mean', 'mcse_sd', 'ess_bulk', 'ess_tail', 'r_hat'], dtype='object', name='metric'))

                            • <xarray.Dataset> Size: 1kB
                              Dimensions:           (metric: 5, a: 10)
                              Coordinates:
                                * metric            (metric) <U9 180B 'ess_bulk' 'ess_tail' ... 'r_hat'
                              Dimensions without coordinates: a
                              Data variables:
                                  log_enrichment    (metric, a) float64 400B 599.8 600.1 597.2 ... 1.0 1.0 1.0
                                  log_input_freqs   (metric, a) float64 400B 4.149e+03 5.995e+03 ... 1.0 1.0
                                  log_output_freqs  (metric, a) float64 400B 3.942e+03 3.907e+03 ... 1.0 1.0
                              xarray.Dataset
                                • metric: 5
                                • a: 10
                                • metric
                                  (metric)
                                  <U9
                                  'ess_bulk' 'ess_tail' ... 'r_hat'
                                  array(['ess_bulk', 'ess_tail', 'mcse_mean', 'mcse_sd', 'r_hat'], dtype='<U9')
                                • log_enrichment
                                  (metric, a)
                                  float64
                                  599.8 600.1 597.2 ... 1.0 1.0 1.0
                                  array([[5.99790e+02, 6.00140e+02, 5.97170e+02, 6.02250e+02, 6.03920e+02,
                                          6.04780e+02, 6.15870e+02, 7.26230e+02, 5.99890e+02, 6.12910e+02],
                                         [7.49200e+02, 7.85320e+02, 7.46800e+02, 8.37950e+02, 7.42870e+02,
                                          7.85930e+02, 8.38430e+02, 1.04279e+03, 7.93680e+02, 8.03140e+02],
                                         [1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02,
                                          1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02],
                                         [1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02,
                                          1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02, 1.00000e-02],
                                         [1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00,
                                          1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00]])
                                • log_input_freqs
                                  (metric, a)
                                  float64
                                  4.149e+03 5.995e+03 ... 1.0 1.0
                                  array([[4.14929e+03, 5.99540e+03, 4.78792e+03, 5.00634e+03, 4.49178e+03,
                                          4.28000e+03, 5.19919e+03, 2.24705e+03, 5.62300e+03, 5.54843e+03],
                                         [2.81655e+03, 3.46402e+03, 3.07469e+03, 2.76553e+03, 3.24866e+03,
                                          2.85078e+03, 2.66078e+03, 2.28510e+03, 3.36446e+03, 3.00511e+03],
                                         [0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00,
                                          1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00]])
                                • log_output_freqs
                                  (metric, a)
                                  float64
                                  3.942e+03 3.907e+03 ... 1.0 1.0
                                  array([[3.94222e+03, 3.90681e+03, 4.39845e+03, 4.19513e+03, 3.91825e+03,
                                          3.79242e+03, 4.10996e+03, 4.43777e+03, 4.16450e+03, 3.82122e+03],
                                         [3.39500e+03, 2.56051e+03, 3.32536e+03, 3.47363e+03, 2.86068e+03,
                                          3.11030e+03, 2.46409e+03, 2.67234e+03, 3.00552e+03, 2.65292e+03],
                                         [0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,
                                          0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
                                         [1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00,
                                          1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00, 1.00000e+00]])
                                • metric
                                  PandasIndex
                                  PandasIndex(Index(['ess_bulk', 'ess_tail', 'mcse_mean', 'mcse_sd', 'r_hat'], dtype='object', name='metric'))

                            • <xarray.Dataset> Size: 12kB
                              Dimensions:                 (chain: 4, draw: 1000)
                              Dimensions without coordinates: chain, draw
                              Data variables:
                                  low_ebfmi               (chain, draw) bool 4kB False False ... False False
                                  max_tree_depth_reached  (chain, draw) bool 4kB False False ... False False
                                  diverged                (chain, draw) bool 4kB False False ... False False
                              xarray.Dataset
                                • chain: 4
                                • draw: 1000
                                  • low_ebfmi
                                    (chain, draw)
                                    bool
                                    False False False ... False False
                                    array([[False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False]], shape=(4, 1000))
                                  • max_tree_depth_reached
                                    (chain, draw)
                                    bool
                                    False False False ... False False
                                    array([[False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False]], shape=(4, 1000))
                                  • diverged
                                    (chain, draw)
                                    bool
                                    False False False ... False False
                                    array([[False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False],
                                           [False, False, False, ..., False, False, False]], shape=(4, 1000))

                                • <xarray.Dataset> Size: 198B
                                  Dimensions:           (metric: 3, a: 10)
                                  Coordinates:
                                    * metric            (metric) <U9 108B 'r_hat' 'ess_bulk' 'ess_tail'
                                  Dimensions without coordinates: a
                                  Data variables:
                                      log_enrichment    (metric, a) bool 30B False False False ... False False
                                      log_input_freqs   (metric, a) bool 30B False False False ... False False
                                      log_output_freqs  (metric, a) bool 30B False False False ... False False
                                  xarray.Dataset
                                    • metric: 3
                                    • a: 10
                                    • metric
                                      (metric)
                                      <U9
                                      'r_hat' 'ess_bulk' 'ess_tail'
                                      array(['r_hat', 'ess_bulk', 'ess_tail'], dtype='<U9')
                                    • log_enrichment
                                      (metric, a)
                                      bool
                                      False False False ... False False
                                      array([[False, False, False, False, False, False, False, False, False,
                                              False],
                                             [False, False, False, False, False, False, False, False, False,
                                              False],
                                             [False, False, False, False, False, False, False, False, False,
                                              False]])
                                    • log_input_freqs
                                      (metric, a)
                                      bool
                                      False False False ... False False
                                      array([[False, False, False, False, False, False, False, False, False,
                                              False],
                                             [False, False, False, False, False, False, False, False, False,
                                              False],
                                             [False, False, False, False, False, False, False, False, False,
                                              False]])
                                    • log_output_freqs
                                      (metric, a)
                                      bool
                                      False False False ... False False
                                      array([[False, False, False, False, False, False, False, False, False,
                                              False],
                                             [False, False, False, False, False, False, False, False, False,
                                              False],
                                             [False, False, False, False, False, False, False, False, False,
                                              False]])
                                    • metric
                                      PandasIndex
                                      PandasIndex(Index(['r_hat', 'ess_bulk', 'ess_tail'], dtype='object', name='metric'))
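Once the diagnostics pass, a common next step is to reduce the draws for each variant to a point estimate and an interval. The sketch below shows that reduction in pure Python on hypothetical draws. The nested-list layout draws[chain][draw][variant] is an assumption chosen to mirror the (chain, draw, a) dimension order printed above, and the centers and spread are made up for illustration. In practice the same reduction would be applied to the arrays held in the InferenceData object.

```python
import random
import statistics

rng = random.Random(0)

# Hypothetical posterior draws laid out as draws[chain][draw][variant],
# mirroring the (chain, draw, a) dimension order of the datasets above.
centers = [-1.0, -3.2, -4.0]
draws = [
    [[rng.gauss(center, 0.2) for center in centers] for _ in range(1000)]
    for _ in range(4)
]


def summarize(draws, variant):
    """Posterior mean and central 95% interval for one variant."""
    # Pooling chains is only sensible once the convergence checks have passed.
    pooled = [draw[variant] for chain in draws for draw in chain]
    cuts = statistics.quantiles(pooled, n=40)  # cut points at 2.5%, ..., 97.5%
    return statistics.fmean(pooled), cuts[0], cuts[-1]


summaries = [summarize(draws, v) for v in range(len(centers))]
for v, (mean, low, high) in enumerate(summaries):
    print(f"variant {v}: mean {mean:.2f}, 95% interval [{low:.2f}, {high:.2f}]")
```

Reporting the interval alongside the mean keeps the uncertainty that HMC sampling provides and that a single MLE value does not.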

                              In [ ]: