11. Assumptions, diagnostics, and model evaluation¶

11.1 Assumptions of regression analysis¶

In decreasing order of importance:

  1. Validity: Data analysing should map to research question. Ideally output variable should reflect phenomenon of interest, model include all relevant predictors and model should generalise to the cases to which it will be applied. Not all models can be generalised to all cases. In practice all these criteria will be perfectly upheld.
  2. Representativeness: regression model fit to data and used to make inferences about a larger population, hence implicit assumption is that the sample is representative of the population. If not, then inferences may be biased and not generalisable. Key assumption is that data are representative of the outcome y given the predictors x. E.g., regression of earnings on height and sex, it is OK to have more females and tall people (overrepresented), but not OK to have people with higher earnings. The model will not see true relationship between earnings and height and sex, particularly for lower earnings. More predictors allows to control for more factors. Example: You fit a regression on past elections, then use it to predict the next one. Assumption 1: 'superpopulation' - past elections and future are drawn from the same imaginary pool of elections. Assumption 2: 'random errors' - regression captures the relationship, every election deviates by some error. Errors are from same normal distribution. These are saying the same thing, but in different ways.
  3. Additivity and linearity: most important assumption of linear regression is that the deterministic component of the model is a linear function of the separate predictors. If violated, then consider transforming predictors, adding interaction terms, or using a different type of model (e.g., generalized additive model).
  4. Independence of errors: errors are independent of each other. Violated in timeseries, spatial and multilevel data.
  5. Equal variance of errors (homoscedasticity): errors are how far real points fall from the regression line. Homoscedasticity = errors are equally spread out, heteroscedasticity = errors are more spread out for some values of predictors than others. Does it matter? If just understanding the relationship, then no. If want to make inferences about the relationship, then yes. The fix is to use weighted least squares regression.
  6. Normality of errors: errors are normally distributed. Less of an issue than we think. 1. Estimating the regression line - normality barely matters. 2. Predicting data with uncertainty - normality matters more, because shape of the errors shape prediction intervals.
Failures of the assumptions¶

What to do if assumptions are violated?

Extend the model. E.g., measurement error models to address issues with validity, selection models to address non-representativeness, nonadditive and nonlinear models, correlated errors or latent variables to address independence.

Change data or model so assumptions are more reasonable. E.g., obtain cleaner data, add predictors, add interactions, transform predictors.

Alternatively, change or restrict research question to align with data.

Causal inference¶

A regression coefficient is not automatically a causal effect. Regression represents the average relationship between the outcome and predictor, holding other predictors constant. E.g., 'on average, two people who differ by one inch in height differ about 1,000 in earnings'. This is a comparison across people, not a comparison within the same person.

The key question for causal claims: what is the treatment? To say something causes something, we have to specify how it would happen.

Regression tells us patterns of comparison in existing data. If we want to say what would happen if we intervened, then we need to make extra assumptions and a clear specification of the treatment.

11.2 Plotting the data and fitted model¶

Displaying a regression line as a function of one input variable¶
In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import arviz as az
import pymc as pm 
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.robust.scale import mad
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/arviz/__init__.py:50: FutureWarning: 
ArviZ is undergoing a major refactor to improve flexibility and extensibility while maintaining a user-friendly interface.
Some upcoming changes may be backward incompatible.
For details and migration guidance, visit: https://python.arviz.org/en/latest/user_guide/migration_guide.html
  warn(
In [2]:
def fit_and_plot_lm(data, predictors, outcome, add_constant=True, show_plot=True, scatter_kws=None, line_kws=None):
    """
    Fit a linear model using statsmodels, print summary, plot, and show formula.
    Args:
        data: pandas DataFrame
        predictors: list of predictor column names (str)
        outcome: outcome column name (str)
        add_constant: whether to add intercept (default True)
        show_plot: whether to plot (default True)
        scatter_kws: dict, kwargs for scatterplot
        line_kws: dict, kwargs for regression line
    """
    X = data[predictors].copy()
    if add_constant:
        X = sm.add_constant(X, prepend=False)
    y = data[outcome]
    model = sm.OLS(y, X)
    results = model.fit()
    print(results.summary())
    # Print formula
    params = results.params
    formula = f"{outcome} = " + " + ".join([f"{params[name]:.2f}*{name}" for name in predictors])
    if add_constant:
        formula = f"{outcome} = {params['const']:.2f} + " + " + ".join([f"{params[name]:.2f}*{name}" for name in predictors])
    print("Formula:", formula)
    # Print residual standard deviation and its uncertainty
    sigma = np.sqrt(results.mse_resid)
    sigma_se = sigma / np.sqrt(2 * results.df_resid)
    print(f"Residual std dev (σ): {sigma:.2f} ± {sigma_se:.2f}")
    # Print median absolute deviation of residuals
    print("MAD of residuals:", round(mad(results.resid), 2))
    # Plot if only one predictor
    if show_plot and len(predictors) == 1:
        x_name = predictors[0]
        ax = sns.scatterplot(data=data, x=x_name, y=outcome, **(scatter_kws or {}))
        x_vals = np.linspace(data[x_name].min(), data[x_name].max(), 100)
        y_vals = params['const'] + params[x_name] * x_vals if add_constant else params[x_name] * x_vals
        ax.plot(x_vals, y_vals, color='red', **(line_kws or {}))
        ax.set_title('Linear Regression Fit')
        plt.show()
In [3]:
def fit_and_plot_bayes(data, predictors, outcome,
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=True,
                       show_posterior=True, show_regression=True,
                       show_residuals=False,
                       n_regression_lines=100):
    """
    Fit a Bayesian linear regression with PyMC and produce diagnostic plots.

    The model is:
        y ~ Normal(mu, sigma)
        mu = intercept + sum_k (slope_k * x_k)
        intercept ~ Normal(intercept_mu, intercept_sigma)
        slope_k   ~ Normal(slope_mu, slope_sigma)     for each predictor k
        sigma     ~ HalfNormal(sigma_sigma)

    Supports one or multiple predictors. After sampling, prints a posterior
    summary and the mean regression formula, then optionally shows trace,
    forest, posterior density, regression-line, and residual plots.

    Args:
        data (pd.DataFrame): DataFrame containing predictor and outcome columns.
        predictors (str | list[str]): Predictor column name(s). A single string
            is treated as a one-predictor model; a list allows multiple predictors.
        outcome (str): Outcome column name.
        intercept_mu (float): Prior mean for the intercept Normal prior.
        intercept_sigma (float): Prior std dev for the intercept Normal prior.
        slope_mu (float): Prior mean for each slope's Normal prior.
        slope_sigma (float): Prior std dev for each slope's Normal prior.
        sigma_sigma (float): Scale for the HalfNormal prior on residual noise.
        samples (int): Number of posterior draws per chain.
        tune (int): Number of NUTS tuning (warm-up) steps.
        hdi_prob (float): Probability mass used for HDI summaries and plots.
        show_trace (bool): If True, plot MCMC traces and marginal densities.
        show_forest (bool): If True, plot a forest plot of posterior means + HDIs.
        show_posterior (bool): If True, plot posterior density for each parameter.
        show_regression (bool): If True, plot data with overlaid posterior
            regression lines (one subplot per predictor; other predictors held at mean).
        show_residuals (bool): If True, plot residuals (y - y_hat using posterior
            mean coefficients) vs fitted values and vs each predictor.
        n_regression_lines (int): Number of posterior draws to overlay on the
            regression plot when show_regression is True.

    Returns:
        arviz.InferenceData: The PyMC trace (posterior samples and sampling stats).
    """
    if isinstance(predictors, str):  # Allow a single predictor to be passed as a bare string.
        predictors = [predictors]  # Normalise to a list so downstream code can iterate uniformly.
    y = data[outcome].values  # Extract outcome values as a NumPy array for PyMC.

    with pm.Model() as model:  # Open a PyMC model context; variables below are registered to it.
        intercept = pm.Normal("intercept", mu=intercept_mu, sigma=intercept_sigma)  # Normal prior on the intercept.
        slopes = []  # Collect slope RVs (not strictly needed, but handy if you want to inspect them).
        mu = intercept  # Start building the linear predictor; will accumulate slope*x terms below.
        for pred in predictors:  # Loop over each predictor column name.
            s = pm.Normal(f"slope_{pred}", mu=slope_mu, sigma=slope_sigma)  # Normal prior on this predictor's slope.
            slopes.append(s)  # Keep a reference to the slope RV.
            mu = mu + s * data[pred].values  # Add this predictor's contribution to the linear predictor.
        sigma = pm.HalfNormal("sigma", sigma=sigma_sigma)  # HalfNormal prior on residual std dev (must be positive).
        likelihood = pm.Normal("y", mu=mu, sigma=sigma, observed=y)  # Likelihood: observed y conditional on mu and sigma.
        trace = pm.sample(samples, tune=tune, idata_kwargs={"log_likelihood": True})  # Run NUTS; store log-likelihood for LOO.

    summary = pm.summary(trace, hdi_prob=hdi_prob)  # Compute posterior summary stats (mean, sd, HDI, r_hat, ess).
    print(summary)  # Print the summary table for the user.

    # Build and print the mean-posterior regression formula for quick inspection.
    posterior = trace.posterior  # Shortcut to the posterior group of the InferenceData.
    intercept_mean = posterior["intercept"].values.flatten().mean()  # Posterior mean of the intercept (flatten chains+draws).
    formula = f"{outcome} = {intercept_mean:.2f}"  # Start the formula string with the intercept.
    for pred in predictors:  # Append a "+ slope*predictor" term for each predictor.
        slope_mean = posterior[f"slope_{pred}"].values.flatten().mean()  # Posterior mean slope for this predictor.
        formula += f" + {slope_mean:.2f}*{pred}"  # Append this term to the formula string.
    print(f"\nRegression formula: {formula}")  # Print the assembled formula.

    sigma_draws = posterior["sigma"].values.flatten()
    print(f"Residual std dev (σ): {sigma_draws.mean():.2f} ± {sigma_draws.std():.2f}")

    intercept_draws = posterior["intercept"].values.flatten()
    X_mat = data[predictors].values
    slope_mat = np.column_stack([posterior[f"slope_{pred}"].values.flatten() for pred in predictors])
    y_hat_all = intercept_draws[:, None] + slope_mat @ X_mat.T
    bayes_r2 = az.r2_score(y, y_hat_all)
    print(f"Bayesian R²: {bayes_r2['r2']:.3f} ± {bayes_r2['r2_std']:.3f}")

    loo = az.loo(trace, pointwise=True)  # PSIS-LOO; pointwise=True gives per-obs Pareto-k diagnostics.
    n_obs = len(y)
    print(f"LOO-ELPD: {loo.elpd_loo:.2f} ± {loo.se:.2f}  (p_loo={loo.p_loo:.1f})")
    print(f"LOO log score (per obs): {loo.elpd_loo / n_obs:.3f} ± {loo.se / n_obs:.3f}")
    n_bad_k = int((loo.pareto_k.values > 0.7).sum())
    if n_bad_k > 0:
        print(f"  Warning: {n_bad_k} observations with Pareto-k > 0.7 (unreliable LOO estimates)")

    if show_trace:  # Optional: MCMC trace diagnostics.
        az.plot_trace(trace)  # Plot chain traces and marginal densities per parameter.
        plt.tight_layout()  # Tidy up subplot spacing.
        plt.show()  # Render the figure.

    if show_forest:  # Optional: forest plot of parameter posteriors.
        az.plot_forest(trace, hdi_prob=hdi_prob)  # Show posterior means and HDIs for all parameters.
        plt.show()  # Render the figure.

    if show_posterior:  # Optional: posterior density plots.
        az.plot_posterior(trace, hdi_prob=hdi_prob)  # Density plot per parameter, annotated with HDI.
        plt.show()  # Render the figure.

    if show_regression:  # Optional: overlay posterior regression lines on data.
        a_samples = posterior["intercept"].values.flatten()  # Flattened posterior draws of the intercept.
        slope_samples = {pred: posterior[f"slope_{pred}"].values.flatten() for pred in predictors}  # Flattened slope draws per predictor.
        idx = np.random.choice(len(a_samples), n_regression_lines, replace=False)  # Random subset of draw indices to plot.

        fig, axes = plt.subplots(1, len(predictors), figsize=(6 * len(predictors), 5))  # One subplot per predictor.
        if len(predictors) == 1:  # plt.subplots returns a single Axes when ncols=1; wrap it for uniform handling.
            axes = [axes]  # Make axes iterable in the single-predictor case.

        for ax, pred in zip(axes, predictors):  # Plot each predictor's partial regression view.
            x = data[pred].values  # Observed values of this predictor.
            ax.scatter(x, y, alpha=0.5)  # Scatter of y vs this predictor.
            x_grid = np.linspace(x.min(), x.max(), 100)  # Dense grid across this predictor for smooth lines.

            # For each posterior draw, compute the contribution of *other* predictors
            # held at their sample means so the plotted line isolates this predictor's effect.
            other_contribution = np.zeros(len(a_samples))  # Per-draw constant offset from held-fixed predictors.
            for other_pred in predictors:  # Sum contributions across all other predictors.
                if other_pred != pred:  # Skip the predictor currently on the x-axis.
                    other_contribution += slope_samples[other_pred] * data[other_pred].mean()  # slope_draw * mean(x_other).

            for i in idx:  # Overlay a thin gray line per sampled posterior draw.
                y_line = a_samples[i] + other_contribution[i] + slope_samples[pred][i] * x_grid  # Predicted y on the grid.
                ax.plot(x_grid, y_line, alpha=0.05, color="gray")  # Low alpha so overlap shades posterior uncertainty.

            # Posterior mean line for emphasis.
            mean_other = sum(slope_samples[op].mean() * data[op].mean() for op in predictors if op != pred)  # Mean offset from other predictors.
            y_mean = a_samples.mean() + mean_other + slope_samples[pred].mean() * x_grid  # Mean posterior prediction on the grid.
            ax.plot(x_grid, y_mean, color="red")  # Plot the mean line in red.
            ax.set_xlabel(pred)  # Label x-axis with the predictor name.
            ax.set_ylabel(outcome)  # Label y-axis with the outcome name.
            ax.set_title(f"{outcome} vs {pred} (others at mean)")  # Title notes that other predictors are held at mean.

        plt.tight_layout()  # Tidy spacing across subplots.
        plt.show()  # Render the figure.

    if show_residuals:  # Optional: residual diagnostics using posterior mean coefficients.
        a_mean = posterior["intercept"].values.flatten().mean()  # Posterior mean intercept (point estimate).
        slope_means = {pred: posterior[f"slope_{pred}"].values.flatten().mean() for pred in predictors}  # Posterior mean slope per predictor.
        y_hat = a_mean + sum(slope_means[pred] * data[pred].values for pred in predictors)  # Fitted values using mean coefficients.
        residuals = y - y_hat  # Residuals: observed minus fitted.

        fig, axes = plt.subplots(1, len(predictors) + 1, figsize=(6 * (len(predictors) + 1), 5))  # One subplot for fitted + one per predictor.

        axes[0].scatter(y_hat, residuals, alpha=0.5)  # Residuals vs fitted values — should look like noise around 0.
        axes[0].axhline(0, color="red", linestyle="--")  # Zero reference line to judge bias/structure.
        axes[0].set_xlabel("Fitted values")  # x-axis label.
        axes[0].set_ylabel("Residuals")  # y-axis label.
        axes[0].set_title("Residuals vs Fitted")  # Subplot title.

        for ax, pred in zip(axes[1:], predictors):  # Residuals vs each predictor — check for unmodeled structure.
            ax.scatter(data[pred].values, residuals, alpha=0.5)  # Scatter residuals against this predictor.
            ax.axhline(0, color="red", linestyle="--")  # Zero reference line.
            ax.set_xlabel(pred)  # x-axis label is the predictor name.
            ax.set_ylabel("Residuals")  # y-axis label.
            ax.set_title(f"Residuals vs {pred}")  # Subplot title.

        plt.tight_layout()  # Tidy subplot spacing.
        plt.show()  # Render the figure.

    return trace  # Return the InferenceData so the caller can do further analysis.
In [4]:
# Simualte some fake data

n = 100
x = np.random.uniform(0, 1, n)                  
z = np.random.choice([0, 1], size=n, replace=True)
a = 1
b = 2
theta = 5
sigma = 2
y = a + b * x + theta * z + np.random.normal(0, sigma, n)
data = pd.DataFrame({"x": x, "z": z, "y": y})
display(data.head())
z0 = data[data["z"] == 0]
z1 = data[data["z"] == 1]

# fit_and_plot_lm(z0, predictors=["x"], outcome="y")
# fit_and_plot_lm(z1, predictors=["x"], outcome="y")

fit_and_plot_bayes(data = z0, predictors = ['x'], outcome = 'y',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=True,
                       show_residuals=True,
                       n_regression_lines=100)

fit_and_plot_bayes(data = z1, predictors = ['x'], outcome = 'y',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=True,
                       show_residuals=True,
                       n_regression_lines=100)
x z y
0 0.454181 0 -0.878259
1 0.045388 1 8.343305
2 0.869509 1 8.853543
3 0.608016 0 1.662860
4 0.876591 0 1.602928
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_x, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
            mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  \
intercept  0.071  0.541    -0.999      1.110      0.009    0.007    3304.0   
slope_x    2.925  0.945     1.104      4.782      0.016    0.013    3362.0   
sigma      1.891  0.195     1.540      2.297      0.003    0.003    4535.0   

           ess_tail  r_hat  
intercept    3671.0    1.0  
slope_x      3751.0    1.0  
sigma        4044.0    1.0  

Regression formula: y = 0.07 + 2.93*x
Residual std dev (σ): 1.89 ± 0.19
Bayesian R²: 0.173 ± 0.082
LOO-ELPD: -105.72 ± 5.67  (p_loo=3.2)
LOO log score (per obs): -2.073 ± 0.111
No description has been provided for this image
No description has been provided for this image
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_x, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
            mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  \
intercept  6.536  0.642     5.259      7.788      0.011    0.008    3609.0   
slope_x    0.982  1.154    -1.272      3.199      0.020    0.015    3443.0   
sigma      2.376  0.254     1.912      2.882      0.004    0.004    4570.0   

           ess_tail  r_hat  
intercept    3834.0    1.0  
slope_x      3757.0    1.0  
sigma        4454.0    1.0  

Regression formula: y = 6.54 + 0.98*x
Residual std dev (σ): 2.38 ± 0.25
Bayesian R²: 0.033 ± 0.039
LOO-ELPD: -112.48 ± 4.29  (p_loo=2.6)
LOO log score (per obs): -2.295 ± 0.088
No description has been provided for this image
No description has been provided for this image
Out[4]:
arviz.InferenceData
    • <xarray.Dataset> Size: 208kB
      Dimensions:    (chain: 4, draw: 2000)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999
      Data variables:
          intercept  (chain, draw) float64 64kB 6.727 5.65 7.266 ... 6.756 6.207 5.803
          slope_x    (chain, draw) float64 64kB 1.087 2.203 -0.1839 ... 1.712 3.025
          sigma      (chain, draw) float64 64kB 2.374 2.181 2.424 ... 1.984 2.467
      Attributes:
          created_at:                 2026-04-29T07:00:51.381770+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              0.9912471771240234
          tuning_steps:               1000
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • intercept
          (chain, draw)
          float64
          6.727 5.65 7.266 ... 6.207 5.803
          array([[6.72710349, 5.64955217, 7.26568056, ..., 6.12894035, 7.07360699,
                  7.12477433],
                 [6.81643398, 7.10963166, 6.04177393, ..., 7.41239817, 6.95230818,
                  8.12770863],
                 [6.34811953, 6.73884944, 7.08485316, ..., 6.76842204, 6.57492001,
                  6.86952576],
                 [7.4640708 , 6.18458982, 7.28452013, ..., 6.75622576, 6.20667112,
                  5.80268361]], shape=(4, 2000))
        • slope_x
          (chain, draw)
          float64
          1.087 2.203 -0.1839 ... 1.712 3.025
          array([[ 1.08741335,  2.20305061, -0.18386786, ...,  1.53466826,
                   0.55354542,  0.1551512 ],
                 [ 0.8599489 , -0.43373484,  2.29078828, ...,  0.01050248,
                  -0.20784181, -1.174153  ],
                 [ 1.59363173, -0.65205742,  0.0836463 , ...,  0.3859668 ,
                   1.33591482,  1.82016785],
                 [ 0.74533642,  1.39676882, -0.7548674 , ...,  0.38888366,
                   1.71223275,  3.02465847]], shape=(4, 2000))
        • sigma
          (chain, draw)
          float64
          2.374 2.181 2.424 ... 1.984 2.467
          array([[2.37382028, 2.18125645, 2.42402279, ..., 2.18736949, 2.56960411,
                  2.05295925],
                 [2.47676545, 2.42243586, 2.32148306, ..., 2.23274638, 2.99191026,
                  2.78261971],
                 [2.10729856, 2.64678418, 2.31919387, ..., 2.04663682, 2.15710605,
                  2.3389439 ],
                 [2.49878051, 2.66511234, 2.04948719, ..., 2.671922  , 1.98437969,
                  2.46705802]], shape=(4, 2000))
      • created_at :
        2026-04-29T07:00:51.381770+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2
        sampling_time :
        0.9912471771240234
        tuning_steps :
        1000

    • <xarray.Dataset> Size: 3MB
      Dimensions:  (chain: 4, draw: 2000, y_dim_0: 49)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999
        * y_dim_0  (y_dim_0) int64 392B 0 1 2 3 4 5 6 7 8 ... 41 42 43 44 45 46 47 48
      Data variables:
          y        (chain, draw, y_dim_0) float64 3MB -2.001 -1.907 ... -2.576 -4.547
      Attributes:
          created_at:                 2026-04-29T07:00:51.569730+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • y_dim_0: 49
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • y_dim_0
          (y_dim_0)
          int64
          0 1 2 3 4 5 6 ... 43 44 45 46 47 48
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48])
        • y
          (chain, draw, y_dim_0)
          float64
          -2.001 -1.907 ... -2.576 -4.547
          array([[[-2.00127376, -1.9071813 , -1.78680414, ..., -2.00860865,
                   -2.70110159, -3.79207455],
                  [-2.40583448, -1.87328913, -1.70083008, ..., -1.77598902,
                   -2.38928641, -4.10978162],
                  [-1.90472047, -2.06429253, -1.80550863, ..., -2.06276634,
                   -2.71446034, -3.1878915 ],
                  ...,
                  [-2.18232528, -1.90360344, -1.70278335, ..., -1.8354776 ,
                   -2.51397298, -3.91722429],
                  [-1.97998521, -1.99039417, -1.86527225, ..., -2.10158193,
                   -2.71713703, -3.44660038],
                  [-1.81234065, -1.93959878, -1.63832966, ..., -1.98326498,
                   -2.89789803, -3.75938284]],
          
                 [[-2.00632355, -1.96139813, -1.82731514, ..., -2.0378881 ,
                   -2.67027496, -3.56638993],
                  [-1.93756174, -2.18703646, -1.82011608, ..., -1.99801285,
                   -2.58037826, -2.92502777],
                  [-2.20918708, -1.82351368, -1.77024561, ..., -1.91217426,
                   -2.58994558, -4.33708736],
          ...
                  [-1.92465015, -2.00050051, -1.64158725, ..., -1.87572781,
                   -2.70237279, -3.64332328],
                  [-2.00108933, -1.8217844 , -1.69237929, ..., -1.93494551,
                   -2.75959997, -4.21825684],
                  [-1.94552174, -1.78336141, -1.83108685, ..., -2.11961729,
                   -2.97786976, -4.66459015]],
          
                 [[-1.89197393, -1.87875763, -1.87434138, ..., -2.23128701,
                   -3.00486962, -3.96246428],
                  [-2.20824209, -2.048099  , -1.90058705, ..., -1.99149783,
                   -2.44761832, -3.33928584],
                  [-1.77874709, -2.22603761, -1.66376989, ..., -1.93622532,
                   -2.75970684, -3.0807886 ],
                  ...,
                  [-2.07424318, -2.11847839, -1.90586252, ..., -2.04065848,
                   -2.5232757 , -3.07465395],
                  [-2.14251335, -1.77453508, -1.60531154, ..., -1.80703509,
                   -2.69687065, -4.60187734],
                  [-2.29646996, -1.83651777, -1.84812398, ..., -1.95266553,
                   -2.57580734, -4.54748609]]], shape=(4, 2000, 49))
      • created_at :
        2026-04-29T07:00:51.569730+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2

    • <xarray.Dataset> Size: 1MB
      Dimensions:                (chain: 4, draw: 2000)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 16kB 0 1 2 3 4 ... 1996 1997 1998 1999
      Data variables: (12/18)
          step_size_bar          (chain, draw) float64 64kB 0.5248 0.5248 ... 0.5089
          acceptance_rate        (chain, draw) float64 64kB 0.4822 0.9832 ... 0.3597
          energy_error           (chain, draw) float64 64kB 0.0 -0.004262 ... 0.7436
          step_size              (chain, draw) float64 64kB 0.7636 0.7636 ... 0.4776
          tree_depth             (chain, draw) int64 64kB 3 4 3 3 4 3 ... 3 2 3 3 3 3
          reached_max_treedepth  (chain, draw) bool 8kB False False ... False False
          ...                     ...
          smallest_eigval        (chain, draw) float64 64kB nan nan nan ... nan nan
          index_in_trajectory    (chain, draw) int64 64kB 0 -5 -3 -2 9 ... -5 -2 -6 4
          lp                     (chain, draw) float64 64kB -122.8 -123.8 ... -124.5
          divergences            (chain, draw) int64 64kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
          diverging              (chain, draw) bool 8kB False False ... False False
          energy                 (chain, draw) float64 64kB 126.5 124.9 ... 130.3
      Attributes:
          created_at:                 2026-04-29T07:00:51.395403+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              0.9912471771240234
          tuning_steps:               1000
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • step_size_bar
          (chain, draw)
          float64
          0.5248 0.5248 ... 0.5089 0.5089
          array([[0.52478406, 0.52478406, 0.52478406, ..., 0.52478406, 0.52478406,
                  0.52478406],
                 [0.49747689, 0.49747689, 0.49747689, ..., 0.49747689, 0.49747689,
                  0.49747689],
                 [0.5362839 , 0.5362839 , 0.5362839 , ..., 0.5362839 , 0.5362839 ,
                  0.5362839 ],
                 [0.50894367, 0.50894367, 0.50894367, ..., 0.50894367, 0.50894367,
                  0.50894367]], shape=(4, 2000))
        • acceptance_rate
          (chain, draw)
          float64
          0.4822 0.9832 ... 0.9986 0.3597
          array([[0.48218082, 0.98317366, 0.99446019, ..., 0.48653159, 0.44894383,
                  0.94803415],
                 [0.94345332, 0.99365965, 0.65835871, ..., 0.96278292, 0.89338694,
                  0.89024013],
                 [0.98123739, 0.54250468, 0.74554748, ..., 0.9951409 , 0.80011238,
                  0.4797377 ],
                 [0.72995395, 0.92669897, 0.97939769, ..., 0.81472165, 0.99859878,
                  0.35968311]], shape=(4, 2000))
        • energy_error
          (chain, draw)
          float64
          0.0 -0.004262 ... 0.003065 0.7436
          array([[ 0.        , -0.00426156, -0.07091429, ...,  0.07495619,
                   0.46578558, -0.18731174],
                 [-0.14385017,  0.0150046 , -0.08935785, ...,  0.03352991,
                   0.05171789, -0.02200433],
                 [-0.10602541,  0.94946951, -1.0800202 , ...,  0.02179526,
                   0.14726186,  1.17598283],
                 [ 0.48819529, -0.82842012,  0.12725916, ...,  0.05098099,
                   0.00306504,  0.74359128]], shape=(4, 2000))
        • step_size
          (chain, draw)
          float64
          0.7636 0.7636 ... 0.4776 0.4776
          array([[0.76355337, 0.76355337, 0.76355337, ..., 0.76355337, 0.76355337,
                  0.76355337],
                 [0.54566522, 0.54566522, 0.54566522, ..., 0.54566522, 0.54566522,
                  0.54566522],
                 [0.47421957, 0.47421957, 0.47421957, ..., 0.47421957, 0.47421957,
                  0.47421957],
                 [0.47760963, 0.47760963, 0.47760963, ..., 0.47760963, 0.47760963,
                  0.47760963]], shape=(4, 2000))
        • tree_depth
          (chain, draw)
          int64
          3 4 3 3 4 3 3 3 ... 3 3 3 2 3 3 3 3
          array([[3, 4, 3, ..., 2, 3, 3],
                 [3, 3, 4, ..., 3, 3, 3],
                 [2, 3, 3, ..., 4, 3, 2],
                 [2, 4, 3, ..., 3, 3, 3]], shape=(4, 2000))
        • reached_max_treedepth
          (chain, draw)
          bool
          False False False ... False False
          array([[False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False]], shape=(4, 2000))
        • max_energy_error
          (chain, draw)
          float64
          2.365 -0.1313 ... -0.09614 2.0
          array([[ 2.36466824, -0.13126155, -0.13131518, ...,  1.40022442,
                   1.74131118,  0.18842019],
                 [ 0.15177779, -0.06515655,  1.08260718, ..., -0.22764418,
                   0.22429157, -0.4376875 ],
                 [-0.17501824,  0.94946951,  1.22790658, ..., -0.05271714,
                   0.38649523,  2.0348698 ],
                 [ 0.55142396, -1.06955136,  0.12725916, ...,  0.62476027,
                  -0.09614241,  2.00012327]], shape=(4, 2000))
        • n_steps
          (chain, draw)
          float64
          7.0 15.0 7.0 7.0 ... 7.0 7.0 7.0
          array([[ 7., 15.,  7., ...,  3.,  7.,  7.],
                 [ 7.,  7., 15., ...,  7.,  7.,  7.],
                 [ 3.,  7.,  7., ..., 11.,  7.,  3.],
                 [ 3., 11.,  7., ...,  7.,  7.,  7.]], shape=(4, 2000))
        • largest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]], shape=(4, 2000))
        • process_time_diff
          (chain, draw)
          float64
          0.000245 0.00046 ... 0.000182
          array([[2.45e-04, 4.60e-04, 2.07e-04, ..., 9.10e-05, 1.75e-04, 2.11e-04],
                 [1.86e-04, 1.88e-04, 3.79e-04, ..., 1.87e-04, 1.91e-04, 1.87e-04],
                 [1.20e-04, 2.30e-04, 2.39e-04, ..., 2.72e-04, 1.75e-04, 8.70e-05],
                 [9.30e-05, 2.75e-04, 1.86e-04, ..., 1.77e-04, 1.83e-04, 1.82e-04]],
                shape=(4, 2000))
        • perf_counter_start
          (chain, draw)
          float64
          216.5 216.5 216.5 ... 217.1 217.1
          array([[216.46579413, 216.46609567, 216.46660958, ..., 216.96442971,
                  216.96455442, 216.96484875],
                 [216.47920287, 216.47943083, 216.47966046, ..., 217.01169387,
                  217.01192108, 217.01215246],
                 [216.48796838, 216.48814367, 216.48842863, ..., 217.04845263,
                  217.04875979, 217.04896829],
                 [216.5171585 , 216.51729058, 216.5176055 , ..., 217.10147196,
                  217.10168287, 217.10190092]], shape=(4, 2000))
        • perf_counter_diff
          (chain, draw)
          float64
          0.0002453 0.000461 ... 0.000181
          array([[2.45333e-04, 4.60959e-04, 2.07833e-04, ..., 9.00420e-05,
                  1.74834e-04, 2.11875e-04],
                 [1.85333e-04, 1.89250e-04, 3.80292e-04, ..., 1.86708e-04,
                  1.89958e-04, 1.86250e-04],
                 [1.27416e-04, 2.35917e-04, 2.44000e-04, ..., 2.71666e-04,
                  1.74542e-04, 8.59590e-05],
                 [9.26250e-05, 2.74958e-04, 1.85833e-04, ..., 1.75667e-04,
                  1.83416e-04, 1.80959e-04]], shape=(4, 2000))
        • smallest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]], shape=(4, 2000))
        • index_in_trajectory
          (chain, draw)
          int64
          0 -5 -3 -2 9 1 ... -5 -3 -5 -2 -6 4
          array([[ 0, -5, -3, ...,  2, -2,  5],
                 [-2,  2, -8, ..., -4,  4, -3],
                 [-3,  3,  4, ..., -1, -2,  1],
                 [-1, -4, -5, ..., -2, -6,  4]], shape=(4, 2000))
        • lp
          (chain, draw)
          float64
          -122.8 -123.8 ... -123.8 -124.5
          array([[-122.80455344, -123.79043381, -123.22315952, ..., -122.84520064,
                  -123.55298145, -123.59353941],
                 [-122.97500916, -123.36239739, -123.26842175, ..., -123.70492136,
                  -125.80271196, -126.35335343],
                 [-123.06490026, -125.35208924, -122.84605708, ..., -123.27958818,
                  -122.93161193, -125.16548171],
                 [-125.44501645, -123.66662008, -124.56088137, ..., -123.65563891,
                  -123.84508362, -124.45630705]], shape=(4, 2000))
        • divergences
          (chain, draw)
          int64
          0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0
          array([[0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0]], shape=(4, 2000))
        • diverging
          (chain, draw)
          bool
          False False False ... False False
          array([[False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False]], shape=(4, 2000))
        • energy
          (chain, draw)
          float64
          126.5 124.9 125.3 ... 124.0 130.3
          array([[126.45886604, 124.89474881, 125.3229483 , ..., 124.68103099,
                  126.70125879, 125.15592127],
                 [123.74277195, 123.75493186, 124.65889162, ..., 124.91896272,
                  126.59229143, 128.19712774],
                 [123.29940048, 126.160353  , 126.01393673, ..., 123.62301632,
                  123.94233779, 126.30846934],
                 [127.48632603, 126.24808291, 124.69813294, ..., 124.33315731,
                  124.04403975, 130.27658564]], shape=(4, 2000))
      • created_at :
        2026-04-29T07:00:51.395403+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2
        sampling_time :
        0.9912471771240234
        tuning_steps :
        1000

    • <xarray.Dataset> Size: 784B
      Dimensions:  (y_dim_0: 49)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 392B 0 1 2 3 4 5 6 7 8 ... 41 42 43 44 45 46 47 48
      Data variables:
          y        (y_dim_0) float64 392B 8.343 8.854 7.259 ... 5.467 3.925 3.051
      Attributes:
          created_at:                 2026-04-29T07:00:51.399353+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
      xarray.Dataset
        • y_dim_0: 49
        • y_dim_0
          (y_dim_0)
          int64
          0 1 2 3 4 5 6 ... 43 44 45 46 47 48
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48])
        • y
          (y_dim_0)
          float64
          8.343 8.854 7.259 ... 3.925 3.051
          array([ 8.34330526,  8.85354305,  7.25869525,  7.29690135,  7.38499755,
                  7.20958198,  6.86486283, 10.74456448,  2.3092309 , 10.62321524,
                  8.23060168,  9.56501182,  7.69919064,  7.25552857,  6.82056785,
                  1.99970726,  6.33792861,  4.52109585,  3.98643413,  6.61431601,
                  7.27946944,  9.30840993, 10.58716248,  5.59163934,  6.23922678,
                  3.61317323, 10.33649293,  6.60014275,  6.15650261,  7.24625047,
                  3.88831746,  8.8306032 ,  8.55006783,  3.70976237,  5.97295254,
                  5.31937005,  7.72663442,  8.84736953, 11.38541811, 10.26911096,
                  6.78846401,  4.92328308,  7.87518768,  6.17003853,  9.67523265,
                  8.06296313,  5.46682215,  3.92530565,  3.05060448])
      • created_at :
        2026-04-29T07:00:51.399353+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2

Forming a linear predictor from a multiple regression¶
In [5]:
# simulate fake data

N = 100                                              # Number of observations (sample size).
K = 10                                               # Number of predictors (features).
X = np.random.uniform(0, 1, size=(N, K))             # N×K design matrix of Uniform(0,1) draws.
# print(X[:5])                                         # Show the first 5 rows of the design matrix
z = np.random.choice([0, 1], size=N, replace=True)   # Binary treatment indicator per observation.
a = 1                                                # Intercept (constant term).
b = np.arange(1, K + 1)                              # Length-K slope vector: [1, 2, ..., K].
theta = 5                                            # Treatment effect coefficient on z.
sigma = 2                                            # Residual standard deviation.
# X @ b is matrix multiplication: for each row of X, multiply each value by the matching entry in b and sum — one number per observation (same as np.dot(X, b)).
y = a + X @ b + theta * z + np.random.normal(0, sigma, N)  # Outcome generated by the linear model with noise.
data = pd.DataFrame(X, columns=[f"x{k+1}" for k in range(K)])  # Create DataFrame with predictor columns named x1, x2, ..., xK.
data["z"] = z  # Add treatment indicator to the DataFrame.
data["y"] = y  # Add outcome variable to the DataFrame.
display(data.head())  # Show the first few rows of the simulated dataset.
# describe the data
# print(data.describe())

z0 = data[data["z"] == 0]  # Subset of data where treatment indicator z is 0.
z1 = data[data["z"] == 1]  # Subset of data where

fit_and_plot_bayes(data=z0, predictors=[f"x{k+1}" for k in range(K)], outcome="y",
                   intercept_mu=0, intercept_sigma=50,
                   slope_mu=0, slope_sigma=50,
                   sigma_sigma=50,
                   samples=2000, tune=1000, hdi_prob=0.95,
                   show_trace=False, show_forest=False,
                   show_posterior=False, show_regression=True,
                   show_residuals=True,
                   n_regression_lines=100)
x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 z y
0 0.613026 0.339508 0.610554 0.524063 0.060038 0.373551 0.778760 0.970273 0.416569 0.192414 1 32.956398
1 0.568241 0.511727 0.355487 0.502818 0.973214 0.458990 0.618661 0.919096 0.955292 0.107091 0 34.749450
2 0.707802 0.010413 0.983991 0.519175 0.924139 0.742563 0.089703 0.706747 0.554307 0.875053 0 38.297788
3 0.139362 0.989462 0.740087 0.051541 0.601359 0.139781 0.604497 0.090807 0.892681 0.931586 0 27.729313
4 0.648994 0.508998 0.785647 0.091857 0.470251 0.622525 0.060170 0.631203 0.295915 0.295320 0 21.876481
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_x1, slope_x2, slope_x3, slope_x4, slope_x5, slope_x6, slope_x7, slope_x8, slope_x9, slope_x10, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 4 seconds.
             mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  \
intercept  -0.199  1.785    -3.669      3.410      0.023    0.022    5839.0   
slope_x1    2.423  1.062     0.408      4.601      0.012    0.013    8016.0   
slope_x2    1.574  1.109    -0.716      3.690      0.013    0.013    7408.0   
slope_x3    1.640  1.116    -0.515      3.865      0.012    0.013    8343.0   
slope_x4    3.203  1.147     0.986      5.466      0.013    0.013    8356.0   
slope_x5    4.763  1.155     2.578      7.088      0.013    0.015    8040.0   
slope_x6    7.626  1.322     5.015     10.197      0.016    0.015    7081.0   
slope_x7    6.541  1.203     4.150      8.878      0.014    0.013    7215.0   
slope_x8    7.775  1.222     5.420     10.226      0.014    0.014    7868.0   
slope_x9   10.043  1.144     7.871     12.399      0.013    0.014    7501.0   
slope_x10  10.637  1.151     8.450     12.939      0.013    0.013    7698.0   
sigma       2.155  0.248     1.717      2.668      0.003    0.003    5885.0   

           ess_tail  r_hat  
intercept    4644.0    1.0  
slope_x1     5844.0    1.0  
slope_x2     5625.0    1.0  
slope_x3     5223.0    1.0  
slope_x4     5638.0    1.0  
slope_x5     5801.0    1.0  
slope_x6     5623.0    1.0  
slope_x7     6250.0    1.0  
slope_x8     6195.0    1.0  
slope_x9     5290.0    1.0  
slope_x10    5908.0    1.0  
sigma        5380.0    1.0  

Regression formula: y = -0.20 + 2.42*x1 + 1.57*x2 + 1.64*x3 + 3.20*x4 + 4.76*x5 + 7.63*x6 + 6.54*x7 + 7.77*x8 + 10.04*x9 + 10.64*x10
Residual std dev (σ): 2.15 ± 0.25
Bayesian R²: 0.884 ± 0.015
LOO-ELPD: -122.75 ± 5.44  (p_loo=12.3)
LOO log score (per obs): -2.316 ± 0.103
  Warning: 1 observations with Pareto-k > 0.7 (unreliable LOO estimates)
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/arviz/stats/stats.py:782: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
No description has been provided for this image
No description has been provided for this image
Out[5]:
arviz.InferenceData
    • <xarray.Dataset> Size: 784kB
      Dimensions:    (chain: 4, draw: 2000)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999
      Data variables:
          intercept  (chain, draw) float64 64kB 0.6593 1.075 ... -1.733 -0.5384
          slope_x1   (chain, draw) float64 64kB 2.437 0.7655 4.487 ... 4.868 4.52 2.59
          slope_x2   (chain, draw) float64 64kB 2.77 0.7728 2.016 ... 2.61 2.483 2.557
          slope_x3   (chain, draw) float64 64kB 2.857 0.04978 2.969 ... 2.047 0.1386
          slope_x4   (chain, draw) float64 64kB 4.503 2.468 3.083 ... 3.275 2.759 3.24
          slope_x5   (chain, draw) float64 64kB 3.212 7.074 3.169 ... 1.914 4.987
          slope_x6   (chain, draw) float64 64kB 5.582 8.671 5.807 ... 9.141 5.452
          slope_x7   (chain, draw) float64 64kB 6.996 5.081 8.09 ... 6.613 6.227 8.688
          slope_x8   (chain, draw) float64 64kB 8.627 7.771 8.292 ... 9.32 8.912 8.456
          slope_x9   (chain, draw) float64 64kB 8.754 8.907 10.01 ... 9.699 9.822
          slope_x10  (chain, draw) float64 64kB 8.06 11.32 8.755 ... 12.24 11.8 10.75
          sigma      (chain, draw) float64 64kB 2.283 1.931 1.922 ... 2.021 2.174 2.14
      Attributes:
          created_at:                 2026-04-29T07:00:57.001070+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              3.8593673706054688
          tuning_steps:               1000
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • intercept
          (chain, draw)
          float64
          0.6593 1.075 ... -1.733 -0.5384
          array([[ 0.659261  ,  1.07534163, -0.6314563 , ...,  2.54743372,
                  -0.03963473,  0.56361164],
                 [-2.56738798,  1.11078748, -1.06883906, ..., -0.93649978,
                   1.70669729,  1.35469173],
                 [ 1.42755547,  1.98204049,  1.11902833, ...,  0.2567874 ,
                   0.29159445,  1.0119711 ],
                 [-2.93372553,  1.0749503 , -2.56214487, ..., -2.33643222,
                  -1.73280783, -0.53836776]], shape=(4, 2000))
        • slope_x1
          (chain, draw)
          float64
          2.437 0.7655 4.487 ... 4.52 2.59
          array([[ 2.43718408,  0.76551819,  4.48685264, ...,  1.87906378,
                   2.42579195,  2.56241793],
                 [ 3.2038377 ,  2.92683358,  3.17884942, ...,  3.71344339,
                   0.32203344,  1.3269384 ],
                 [-0.78298398,  0.05443411,  1.66003421, ...,  2.54947986,
                   2.95515566,  3.11891663],
                 [ 1.23122215,  3.38929187,  1.31793544, ...,  4.8681612 ,
                   4.52011339,  2.59042843]], shape=(4, 2000))
        • slope_x2
          (chain, draw)
          float64
          2.77 0.7728 2.016 ... 2.483 2.557
          array([[ 2.76963758,  0.77279238,  2.0159312 , ...,  1.76034474,
                   1.68680731,  1.59233773],
                 [ 2.34997156,  0.80317456,  2.8164883 , ...,  2.77213417,
                  -0.10017218,  0.59150421],
                 [-1.30129232, -0.79679087,  1.09352213, ...,  0.39981337,
                   0.20003431,  3.0678798 ],
                 [ 2.09999103,  1.74366577,  1.36863481, ...,  2.61005484,
                   2.4831952 ,  2.55709024]], shape=(4, 2000))
        • slope_x3
          (chain, draw)
          float64
          2.857 0.04978 ... 2.047 0.1386
          array([[2.8572581 , 0.04977967, 2.96859956, ..., 3.70509529, 3.303281  ,
                  3.03002672],
                 [1.01452107, 1.70474038, 0.8540556 , ..., 3.44341512, 0.77098033,
                  0.14132172],
                 [2.25063124, 2.22378959, 0.33605347, ..., 3.41655869, 3.30421625,
                  0.04137156],
                 [2.34138699, 1.64334864, 1.71249974, ..., 1.9714886 , 2.04671014,
                  0.13864106]], shape=(4, 2000))
        • slope_x4
          (chain, draw)
          float64
          4.503 2.468 3.083 ... 2.759 3.24
          array([[4.50333033, 2.46813411, 3.08306086, ..., 3.55950627, 3.64311525,
                  3.58809367],
                 [3.29558134, 1.75183416, 3.65133803, ..., 3.74297463, 3.79236032,
                  3.55942515],
                 [3.2788629 , 2.97314115, 3.01638939, ..., 3.0801868 , 4.93135567,
                  2.39828052],
                 [4.77753305, 2.3105622 , 4.18860809, ..., 3.27450343, 2.75934961,
                  3.24022081]], shape=(4, 2000))
        • slope_x5
          (chain, draw)
          float64
          3.212 7.074 3.169 ... 1.914 4.987
          array([[3.21200626, 7.07431316, 3.16871989, ..., 2.45195685, 3.54173787,
                  3.79917675],
                 [6.65405292, 4.88533144, 4.44419083, ..., 3.47126476, 4.86954263,
                  5.24884691],
                 [4.72753698, 5.43329502, 4.75343868, ..., 6.59053687, 5.63644379,
                  3.1125412 ],
                 [5.81546208, 4.42819008, 6.41849206, ..., 1.76064011, 1.91388464,
                  4.98719957]], shape=(4, 2000))
        • slope_x6
          (chain, draw)
          float64
          5.582 8.671 5.807 ... 9.141 5.452
          array([[5.58219259, 8.67092779, 5.80723942, ..., 4.43171395, 6.8483098 ,
                  6.78165632],
                 [7.81685445, 6.93450716, 6.32515696, ..., 7.82527263, 7.53216941,
                  8.45392934],
                 [8.56534622, 7.44353423, 8.11618569, ..., 6.06155288, 6.19297692,
                  7.86904929],
                 [8.93363882, 5.22536537, 9.54305122, ..., 7.91721726, 9.14087209,
                  5.45221323]], shape=(4, 2000))
        • slope_x7
          (chain, draw)
          float64
          6.996 5.081 8.09 ... 6.227 8.688
          array([[6.99587379, 5.08148873, 8.09023345, ..., 4.44935157, 6.86025112,
                  6.51501715],
                 [6.84316776, 7.04174192, 5.63366306, ..., 8.05390426, 5.1155853 ,
                  4.48206068],
                 [7.17232418, 6.72251539, 5.98521582, ..., 4.29986576, 4.52647164,
                  6.78024787],
                 [8.5166784 , 5.69385051, 7.88050226, ..., 6.61264356, 6.22702328,
                  8.68794309]], shape=(4, 2000))
        • slope_x8
          (chain, draw)
          float64
          8.627 7.771 8.292 ... 8.912 8.456
          array([[8.62656259, 7.77097636, 8.29187001, ..., 6.55445315, 7.0095811 ,
                  6.20162959],
                 [8.94687865, 8.50139337, 8.84151299, ..., 7.00040071, 6.47351667,
                  6.28508497],
                 [6.2778288 , 6.78918205, 5.92511815, ..., 8.1739672 , 7.43141102,
                  7.70019117],
                 [6.8662634 , 9.09661673, 6.71583167, ..., 9.3203462 , 8.91223992,
                  8.45585872]], shape=(4, 2000))
        • slope_x9
          (chain, draw)
          float64
          8.754 8.907 10.01 ... 9.699 9.822
          array([[ 8.75428543,  8.90670926, 10.01493593, ..., 10.29094165,
                   9.74537788,  9.6475015 ],
                 [ 9.4449613 ,  9.46351762,  9.71064847, ...,  7.64243376,
                  11.9679877 , 11.15227661],
                 [11.23023037, 11.15961033, 11.16395484, ..., 10.58585809,
                  10.18501563,  9.59929172],
                 [ 9.87136934, 10.05544499,  9.55785226, ...,  9.8180702 ,
                   9.69876377,  9.82202319]], shape=(4, 2000))
        • slope_x10
          (chain, draw)
          float64
          8.06 11.32 8.755 ... 11.8 10.75
          array([[ 8.05998131, 11.31974595,  8.75511389, ..., 10.93820408,
                  11.00767922, 11.71808541],
                 [11.10288526,  8.8571721 , 12.00455524, ...,  9.86941003,
                  12.12563407, 12.08934835],
                 [10.2310801 , 10.96715908, 11.98602528, ..., 10.86453362,
                  10.30075105, 10.72508103],
                 [11.32006525, 10.34504304, 11.82041192, ..., 12.23801559,
                  11.80001765, 10.75468675]], shape=(4, 2000))
        • sigma
          (chain, draw)
          float64
          2.283 1.931 1.922 ... 2.174 2.14
          array([[2.28328763, 1.93061658, 1.92208488, ..., 2.10564546, 2.14925892,
                  2.16773919],
                 [2.31773208, 1.82595457, 2.36431408, ..., 2.03187503, 2.19591056,
                  2.49223659],
                 [2.12975474, 2.2380275 , 1.79197433, ..., 1.8873989 , 2.02892107,
                  2.26490924],
                 [2.49820918, 1.76424457, 2.21928606, ..., 2.02059665, 2.17438394,
                  2.13969505]], shape=(4, 2000))
      • created_at :
        2026-04-29T07:00:57.001070+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2
        sampling_time :
        3.8593673706054688
        tuning_steps :
        1000

    • <xarray.Dataset> Size: 3MB
      Dimensions:  (chain: 4, draw: 2000, y_dim_0: 53)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999
        * y_dim_0  (y_dim_0) int64 424B 0 1 2 3 4 5 6 7 8 ... 45 46 47 48 49 50 51 52
      Data variables:
          y        (chain, draw, y_dim_0) float64 3MB -1.812 -4.134 ... -1.717 -2.156
      Attributes:
          created_at:                 2026-04-29T07:00:57.393594+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • y_dim_0: 53
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • y_dim_0
          (y_dim_0)
          int64
          0 1 2 3 4 5 6 ... 47 48 49 50 51 52
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52])
        • y
          (chain, draw, y_dim_0)
          float64
          -1.812 -4.134 ... -1.717 -2.156
          array([[[-1.81174632, -4.13438484, -1.93369282, ..., -2.70568747,
                   -1.8744191 , -2.06027064],
                  [-1.64507097, -1.90974177, -2.1803751 , ..., -1.58385889,
                   -1.6291338 , -1.92721713],
                  [-1.58568892, -3.92760758, -2.15349147, ..., -2.92830977,
                   -1.95582772, -1.73077388],
                  ...,
                  [-2.62978454, -2.70859685, -3.94975287, ..., -1.91808879,
                   -2.15619936, -1.83978594],
                  [-2.02549296, -2.3600765 , -3.16218893, ..., -2.08894765,
                   -1.95253319, -1.84525975],
                  [-2.12763299, -2.09463347, -3.80631013, ..., -1.92009584,
                   -1.91532508, -1.73549082]],
          
                 [[-1.77888242, -2.1400485 , -2.21559034, ..., -1.82959716,
                   -2.04399191, -1.87079158],
                  [-1.52109379, -3.07578757, -1.89115636, ..., -2.13966222,
                   -1.58428728, -1.8117346 ],
                  [-1.87416555, -2.53477731, -2.51557504, ..., -1.80672161,
                   -2.18626874, -1.8744403 ],
          ...
                  [-1.59958954, -1.65743149, -3.46404857, ..., -1.58524147,
                   -2.34821584, -1.6152586 ],
                  [-1.63291201, -1.62651344, -2.22822612, ..., -2.4008845 ,
                   -2.83781428, -1.76440048],
                  [-1.82316797, -3.52572588, -2.75784014, ..., -1.77298499,
                   -1.74497359, -1.73730648]],
          
                 [[-1.87290125, -2.35616072, -2.78300924, ..., -2.03754029,
                   -1.97618482, -2.25024093],
                  [-1.49274626, -2.4514186 , -3.08411994, ..., -1.50822241,
                   -1.73330229, -1.5493688 ],
                  [-1.83376183, -2.09838675, -2.53544082, ..., -1.99619787,
                   -1.95800386, -2.22210051],
                  ...,
                  [-2.08431642, -2.50444261, -2.11090924, ..., -2.25682666,
                   -2.44292042, -1.64091632],
                  [-2.07640124, -2.2279681 , -2.09148663, ..., -2.27041802,
                   -2.07460067, -1.69842458],
                  [-1.68435457, -4.29006097, -2.9876584 , ..., -1.68136699,
                   -1.71667757, -2.15612335]]], shape=(4, 2000, 53))
      • created_at :
        2026-04-29T07:00:57.393594+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2

    • <xarray.Dataset> Size: 1MB
      Dimensions:                (chain: 4, draw: 2000)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 16kB 0 1 2 3 4 ... 1996 1997 1998 1999
      Data variables: (12/18)
          step_size_bar          (chain, draw) float64 64kB 0.1588 0.1588 ... 0.1556
          acceptance_rate        (chain, draw) float64 64kB 0.6502 0.9646 ... 0.9687
          energy_error           (chain, draw) float64 64kB 0.3776 0.02622 ... 0.05557
          step_size              (chain, draw) float64 64kB 0.1279 0.1279 ... 0.1416
          tree_depth             (chain, draw) int64 64kB 5 5 5 5 5 5 ... 5 5 5 5 5 4
          reached_max_treedepth  (chain, draw) bool 8kB False False ... False False
          ...                     ...
          smallest_eigval        (chain, draw) float64 64kB nan nan nan ... nan nan
          index_in_trajectory    (chain, draw) int64 64kB 10 12 13 -4 14 ... 9 8 -3 13
          lp                     (chain, draw) float64 64kB -172.6 -173.4 ... -171.9
          divergences            (chain, draw) int64 64kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
          diverging              (chain, draw) bool 8kB False False ... False False
          energy                 (chain, draw) float64 64kB 174.4 183.2 ... 180.6
      Attributes:
          created_at:                 2026-04-29T07:00:57.016038+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              3.8593673706054688
          tuning_steps:               1000
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • step_size_bar
          (chain, draw)
          float64
          0.1588 0.1588 ... 0.1556 0.1556
          array([[0.15876393, 0.15876393, 0.15876393, ..., 0.15876393, 0.15876393,
                  0.15876393],
                 [0.15897783, 0.15897783, 0.15897783, ..., 0.15897783, 0.15897783,
                  0.15897783],
                 [0.15052133, 0.15052133, 0.15052133, ..., 0.15052133, 0.15052133,
                  0.15052133],
                 [0.15557044, 0.15557044, 0.15557044, ..., 0.15557044, 0.15557044,
                  0.15557044]], shape=(4, 2000))
        • acceptance_rate
          (chain, draw)
          float64
          0.6502 0.9646 ... 0.9987 0.9687
          array([[0.65023366, 0.96456717, 0.98960313, ..., 0.96875983, 1.        ,
                  0.9094316 ],
                 [0.99849222, 0.87105247, 1.        , ..., 0.68183881, 0.91750653,
                  0.86712844],
                 [0.91510152, 0.94239205, 0.89240487, ..., 0.93679464, 0.53167259,
                  0.8840302 ],
                 [0.996665  , 0.65487302, 0.60753608, ..., 0.57937089, 0.99866595,
                  0.96872736]], shape=(4, 2000))
        • energy_error
          (chain, draw)
          float64
          0.3776 0.02622 ... -0.02699 0.05557
          array([[ 0.37760059,  0.02621513, -0.19460971, ...,  0.12635546,
                  -0.51144457,  0.15564271],
                 [-0.22765804,  0.37744282, -0.26145516, ..., -0.78557444,
                  -0.0454037 ,  0.01905106],
                 [-0.10917734, -0.22113269, -1.32555079, ...,  0.44750679,
                  -0.55502553,  0.07123601],
                 [-0.1377945 , -1.05524952,  0.36128586, ..., -0.08363901,
                  -0.02698775,  0.05556591]], shape=(4, 2000))
        • step_size
          (chain, draw)
          float64
          0.1279 0.1279 ... 0.1416 0.1416
          array([[0.127863  , 0.127863  , 0.127863  , ..., 0.127863  , 0.127863  ,
                  0.127863  ],
                 [0.14036138, 0.14036138, 0.14036138, ..., 0.14036138, 0.14036138,
                  0.14036138],
                 [0.14377983, 0.14377983, 0.14377983, ..., 0.14377983, 0.14377983,
                  0.14377983],
                 [0.14162905, 0.14162905, 0.14162905, ..., 0.14162905, 0.14162905,
                  0.14162905]], shape=(4, 2000))
        • tree_depth
          (chain, draw)
          int64
          5 5 5 5 5 5 5 5 ... 5 5 5 5 5 5 5 4
          array([[5, 5, 5, ..., 5, 4, 5],
                 [5, 5, 5, ..., 5, 5, 5],
                 [5, 4, 4, ..., 5, 4, 5],
                 [5, 5, 5, ..., 5, 5, 4]], shape=(4, 2000))
        • reached_max_treedepth
          (chain, draw)
          bool
          False False False ... False False
          array([[False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False]], shape=(4, 2000))
        • max_energy_error
          (chain, draw)
          float64
          1.201 -0.4797 ... -0.07767 0.07997
          array([[ 1.20145446, -0.47966153, -0.69869497, ..., -0.44193219,
                  -0.51144457,  0.26564209],
                 [-0.33276664,  0.46921659, -0.62842738, ...,  1.74303428,
                   0.29410643,  0.51302955],
                 [-1.08773105, -1.07922356, -1.32555079, ..., -0.70665903,
                   2.44079829,  0.32255459],
                 [-0.275816  ,  1.8726262 ,  1.14330892, ...,  1.51314346,
                  -0.07766506,  0.07996967]], shape=(4, 2000))
        • n_steps
          (chain, draw)
          float64
          31.0 31.0 31.0 ... 31.0 31.0 15.0
          array([[31., 31., 31., ..., 31., 15., 31.],
                 [31., 31., 31., ..., 31., 31., 31.],
                 [31., 15., 15., ..., 31., 15., 31.],
                 [31., 31., 31., ..., 31., 31., 15.]], shape=(4, 2000))
        • largest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]], shape=(4, 2000))
        • process_time_diff
          (chain, draw)
          float64
          0.001643 0.001264 ... 0.000637
          array([[0.001643, 0.001264, 0.001363, ..., 0.001485, 0.000727, 0.001403],
                 [0.001206, 0.001196, 0.001226, ..., 0.001209, 0.001205, 0.001197],
                 [0.001256, 0.000637, 0.000611, ..., 0.001219, 0.000601, 0.001193],
                 [0.001211, 0.001209, 0.001226, ..., 0.001178, 0.001339, 0.000637]],
                shape=(4, 2000))
        • perf_counter_start
          (chain, draw)
          float64
          220.2 220.2 220.2 ... 222.7 222.7
          array([[220.21731862, 220.21960012, 220.22101492, ..., 222.6002825 ,
                  222.60184625, 222.60265258],
                 [220.23135321, 220.23261508, 220.23386717, ..., 222.52058467,
                  222.52184142, 222.523094  ],
                 [220.26989125, 220.27123004, 220.27192567, ..., 222.71365779,
                  222.71493242, 222.71558304],
                 [220.240736  , 220.24200158, 220.24326296, ..., 222.65061879,
                  222.65184504, 222.65329129]], shape=(4, 2000))
        • perf_counter_diff
          (chain, draw)
          float64
          0.002196 0.001263 ... 0.0006368
          array([[0.00219625, 0.00126329, 0.00155337, ..., 0.00148538, 0.00073629,
                  0.00140175],
                 [0.00120629, 0.00119687, 0.00122587, ..., 0.00120796, 0.00120471,
                  0.001198  ],
                 [0.00125488, 0.000638  , 0.00060996, ..., 0.00121904, 0.00060092,
                  0.00119371],
                 [0.00121096, 0.00120833, 0.00122571, ..., 0.00117737, 0.0013485 ,
                  0.00063679]], shape=(4, 2000))
        • smallest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]], shape=(4, 2000))
        • index_in_trajectory
          (chain, draw)
          int64
          10 12 13 -4 14 -13 ... 17 9 8 -3 13
          array([[ 10,  12,  13, ...,  -9,   7,   3],
                 [-21, -15, -16, ...,  17,  15,  -5],
                 [ 22,   3,   7, ...,  23,   6, -22],
                 [ 27, -19,  22, ...,   8,  -3,  13]], shape=(4, 2000))
        • lp
          (chain, draw)
          float64
          -172.6 -173.4 ... -175.2 -171.9
          array([[-172.64789753, -173.3603423 , -172.87524265, ..., -173.5907686 ,
                  -167.48955694, -168.46578515],
                 [-169.11931229, -169.1235382 , -170.20477703, ..., -172.44109193,
                  -171.49351231, -171.69369033],
                 [-176.76035819, -173.66105713, -169.09697012, ..., -176.21895382,
                  -172.78900327, -170.91999052],
                 [-171.76734389, -169.35833998, -170.44215135, ..., -175.36288627,
                  -175.18787136, -171.91389883]], shape=(4, 2000))
        • divergences
          (chain, draw)
          int64
          0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0
          array([[0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0]], shape=(4, 2000))
        • diverging
          (chain, draw)
          bool
          False False False ... False False
          array([[False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False]], shape=(4, 2000))
        • energy
          (chain, draw)
          float64
          174.4 183.2 178.9 ... 180.7 180.6
          array([[174.40356532, 183.24514741, 178.90949434, ..., 177.53307613,
                  175.83502973, 171.12485157],
                 [174.4772862 , 172.69955422, 173.68023009, ..., 184.71957662,
                  177.22215308, 175.50914841],
                 [185.05158817, 183.08054606, 178.40489589, ..., 184.08539082,
                  182.86408077, 178.42796745],
                 [178.30405151, 174.50438879, 174.23929253, ..., 180.63053007,
                  180.74612917, 180.60774525]], shape=(4, 2000))
      • created_at :
        2026-04-29T07:00:57.016038+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2
        sampling_time :
        3.8593673706054688
        tuning_steps :
        1000

    • <xarray.Dataset> Size: 848B
      Dimensions:  (y_dim_0: 53)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 424B 0 1 2 3 4 5 6 7 8 ... 45 46 47 48 49 50 51 52
      Data variables:
          y        (y_dim_0) float64 424B 34.75 38.3 27.73 21.88 ... 33.73 20.73 24.72
      Attributes:
          created_at:                 2026-04-29T07:00:57.020720+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
      xarray.Dataset
        • y_dim_0: 53
        • y_dim_0
          (y_dim_0)
          int64
          0 1 2 3 4 5 6 ... 47 48 49 50 51 52
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52])
        • y
          (y_dim_0)
          float64
          34.75 38.3 27.73 ... 20.73 24.72
          array([34.74944994, 38.29778823, 27.72931262, 21.87648085, 23.31941818,
                 30.400501  , 42.11592338, 23.08651503, 21.93162637, 29.94131778,
                 35.97276041, 19.20847615, 24.14998291, 20.57070352, 33.69674373,
                 29.35925003, 17.84673987, 24.49960502, 28.60466993, 25.33150995,
                 33.39921976, 30.09022634, 45.21079098, 24.34062382, 32.78113225,
                 27.83488103, 28.97746014, 36.7791798 , 26.96766555, 23.05292286,
                 24.79557661, 30.46409438, 25.80993696, 26.60660173, 30.91371479,
                 30.79109767, 23.29275854, 33.41135546, 18.26669972, 29.02664593,
                 33.77670179, 35.89068873, 24.01032875, 23.19579217, 30.70931753,
                 27.23213401, 29.81131977, 25.0022173 , 39.70223606, 24.16532033,
                 33.72826622, 20.73242716, 24.71603269])
      • created_at :
        2026-04-29T07:00:57.020720+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2

11.3 Residual plots¶

We can evaluate fits by looking at the difference between data and their expectations: the residuals.

$r_i = y_i - X_i\hat{\beta}$

If model is correct, then residuals should be random noise.

A confusing choice: plot residuals vs. predicted values, or residuals vs. observed values?¶
In [6]:
# Import gradesW4315.dat from /ros_data
grades = pd.read_csv('../ros_data/gradesW4315.dat', 
                     sep=r'\s+')

# Display the first few rows of the dataset
display(grades.head(10))

fit_and_plot_bayes(data=grades, predictors=["midterm"], outcome="final",
                     intercept_mu=0, intercept_sigma=50,
                     slope_mu=0, slope_sigma=50,
                     sigma_sigma=50,
                     samples=2000, tune=1000, hdi_prob=0.95,
                     show_trace=False, show_forest=False,
                     show_posterior=False, show_regression=True,
                     show_residuals=True,
                     n_regression_lines=100)
hw1 hw2 hw3 hw4 midterm hw5 hw6 hw7 final
0 95 88 100 95 80 96 99 0 103
1 0 74 74 0 53 83 97 0 79
2 100 0 105 100 91 96 100 96 122
3 0 90 76 100 63 91 95 0 78
4 100 96 99 100 91 93 100 92 135
5 90 83 95 100 73 89 100 90 117
6 95 98 100 100 59 98 98 94 135
7 80 100 97 100 69 94 98 101 123
8 95 90 98 90 78 95 99 100 109
9 90 94 95 98 91 94 100 89 126
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_midterm, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 2 seconds.
                 mean      sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept      57.736  16.722    26.616     90.677      0.349    0.251   
slope_midterm   0.789   0.211     0.387      1.196      0.004    0.003   
sigma          15.103   1.535    12.253     18.174      0.026    0.021   

               ess_bulk  ess_tail  r_hat  
intercept        2302.0    3062.0    1.0  
slope_midterm    2323.0    3050.0    1.0  
sigma            3379.0    3274.0    1.0  

Regression formula: final = 57.74 + 0.79*midterm
Residual std dev (σ): 15.10 ± 1.54
Bayesian R²: 0.211 ± 0.080
LOO-ELPD: -216.22 ± 5.44  (p_loo=3.4)
LOO log score (per obs): -4.158 ± 0.105
No description has been provided for this image
No description has been provided for this image
Out[6]:
arviz.InferenceData
    • <xarray.Dataset> Size: 208kB
      Dimensions:        (chain: 4, draw: 2000)
      Coordinates:
        * chain          (chain) int64 32B 0 1 2 3
        * draw           (draw) int64 16kB 0 1 2 3 4 5 ... 1995 1996 1997 1998 1999
      Data variables:
          intercept      (chain, draw) float64 64kB 54.06 49.2 57.57 ... 52.49 50.71
          slope_midterm  (chain, draw) float64 64kB 0.8117 0.8722 ... 0.9173 0.8182
          sigma          (chain, draw) float64 64kB 15.48 13.51 13.52 ... 15.54 15.64
      Attributes:
          created_at:                 2026-04-29T07:01:01.821555+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              2.4053969383239746
          tuning_steps:               1000
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • intercept
          (chain, draw)
          float64
          54.06 49.2 57.57 ... 52.49 50.71
          array([[54.06140119, 49.19791635, 57.57081119, ..., 56.03097253,
                  44.30725101, 36.20968289],
                 [77.68917924, 75.08995137, 75.40488869, ..., 31.42233148,
                  43.55833196, 43.55833196],
                 [69.15300694, 71.11707757, 82.00498459, ..., 38.96762794,
                  39.44146605, 62.67034526],
                 [57.09143096, 58.39663951, 65.5706168 , ..., 48.08757355,
                  52.48867068, 50.71341153]], shape=(4, 2000))
        • slope_midterm
          (chain, draw)
          float64
          0.8117 0.8722 ... 0.9173 0.8182
          array([[0.8117044 , 0.87222632, 0.81376283, ..., 0.80819193, 0.9618803 ,
                  1.07709491],
                 [0.56655731, 0.54933781, 0.56478379, ..., 1.11027455, 0.98584288,
                  0.98584288],
                 [0.60612526, 0.62413849, 0.49098806, ..., 1.02569847, 1.01724878,
                  0.71006307],
                 [0.80773335, 0.82447506, 0.69045279, ..., 0.84647492, 0.91734581,
                  0.81822213]], shape=(4, 2000))
        • sigma
          (chain, draw)
          float64
          15.48 13.51 13.52 ... 15.54 15.64
          array([[15.47931605, 13.50742957, 13.52086122, ..., 17.12125826,
                  13.24439912, 12.78969066],
                 [16.19796407, 17.12971685, 17.91941082, ..., 14.5457252 ,
                  17.45960705, 17.45960705],
                 [16.31049468, 16.31528358, 14.40357201, ..., 17.56871734,
                  16.68001146, 15.63670752],
                 [13.91153561, 13.88263641, 14.45372326, ..., 15.39162652,
                  15.53906548, 15.63557629]], shape=(4, 2000))
      • created_at :
        2026-04-29T07:01:01.821555+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2
        sampling_time :
        2.4053969383239746
        tuning_steps :
        1000

    • <xarray.Dataset> Size: 3MB
      Dimensions:  (chain: 4, draw: 2000, y_dim_0: 52)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999
        * y_dim_0  (y_dim_0) int64 416B 0 1 2 3 4 5 6 7 8 ... 44 45 46 47 48 49 50 51
      Data variables:
          y        (chain, draw, y_dim_0) float64 3MB -4.192 -4.341 ... -5.052 -4.018
      Attributes:
          created_at:                 2026-04-29T07:01:02.037775+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • y_dim_0: 52
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • y_dim_0
          (y_dim_0)
          int64
          0 1 2 3 4 5 6 ... 46 47 48 49 50 51
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51])
        • y
          (chain, draw, y_dim_0)
          float64
          -4.192 -4.341 ... -5.052 -4.018
          array([[[-4.19249634, -4.34069809, -3.73173635, ..., -4.37747147,
                   -4.77380429, -4.18859621],
                  [-4.2216368 , -4.26158536, -3.64048881, ..., -4.44589524,
                   -5.06748383, -4.19037224],
                  [-4.58157553, -4.81109683, -3.77645348, ..., -4.87481429,
                   -4.55951978, -4.57418367],
                  ...,
                  [-4.29280794, -4.43236331, -3.85716979, ..., -4.45950855,
                   -4.54007285, -4.29033952],
                  [-4.45267258, -4.25861933, -3.77841268, ..., -4.68548475,
                   -4.92776938, -4.37029481],
                  [-4.61529636, -4.09226365, -3.92442451, ..., -4.83438118,
                   -5.00043134, -4.45704012]],
          
                 [[-4.4671442 , -5.27533975, -3.8038778 , ..., -4.73964856,
                   -4.23223525, -4.55877494],
                  [-4.19799593, -4.84227861, -3.77591491, ..., -4.4003413 ,
                   -4.47281669, -4.26917104],
                  [-4.28647723, -4.88501842, -3.84070234, ..., -4.48454951,
                   -4.37051443, -4.35304735],
          ...
                  [-4.31127974, -4.11768802, -3.95712088, ..., -4.42955319,
                   -4.65947501, -4.24744688],
                  [-4.30391737, -4.10350881, -3.91326109, ..., -4.43553273,
                   -4.71307433, -4.23645923],
                  [-4.22363364, -4.5966473 , -3.72570069, ..., -4.43384946,
                   -4.62487536, -4.25434789]],
          
                 [[-4.45608085, -4.68032542, -3.74252286, ..., -4.72235436,
                   -4.62338346, -4.4523465 ],
                  [-4.73264836, -4.93319978, -3.88815064, ..., -5.0291596 ,
                   -4.37990974, -4.71912776],
                  [-4.34879005, -4.87417347, -3.68797896, ..., -4.61894606,
                   -4.55661508, -4.39619539],
                  ...,
                  [-3.99885976, -4.06352932, -3.67326511, ..., -4.14464262,
                   -5.15277378, -3.98641289],
                  [-4.74595591, -4.67438556, -4.06625272, ..., -4.96939351,
                   -4.28183375, -4.69108163],
                  [-4.023294  , -4.13353554, -3.68906078, ..., -4.17367739,
                   -5.05223737, -4.01840229]]], shape=(4, 2000, 52))
      • created_at :
        2026-04-29T07:01:02.037775+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2

    • <xarray.Dataset> Size: 1MB
      Dimensions:                (chain: 4, draw: 2000)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 16kB 0 1 2 3 4 ... 1996 1997 1998 1999
      Data variables: (12/18)
          step_size_bar          (chain, draw) float64 64kB 0.1243 0.1243 ... 0.1162
          acceptance_rate        (chain, draw) float64 64kB 0.1184 0.9801 ... 0.9403
          energy_error           (chain, draw) float64 64kB 0.0 0.08551 ... -0.1477
          step_size              (chain, draw) float64 64kB 0.1157 0.1157 ... 0.131
          tree_depth             (chain, draw) int64 64kB 1 5 5 5 5 5 ... 4 4 5 5 3 4
          reached_max_treedepth  (chain, draw) bool 8kB False False ... False False
          ...                     ...
          smallest_eigval        (chain, draw) float64 64kB nan nan nan ... nan nan
          index_in_trajectory    (chain, draw) int64 64kB 0 -8 12 -20 8 ... 6 20 7 2 2
          lp                     (chain, draw) float64 64kB -225.2 -225.7 ... -227.2
          divergences            (chain, draw) int64 64kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
          diverging              (chain, draw) bool 8kB False False ... False False
          energy                 (chain, draw) float64 64kB 228.0 226.4 ... 228.1
      Attributes:
          created_at:                 2026-04-29T07:01:01.839898+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              2.4053969383239746
          tuning_steps:               1000
      xarray.Dataset
        • chain: 4
        • draw: 2000
        • chain
          (chain)
          int64
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int64
          0 1 2 3 4 ... 1996 1997 1998 1999
          array([   0,    1,    2, ..., 1997, 1998, 1999], shape=(2000,))
        • step_size_bar
          (chain, draw)
          float64
          0.1243 0.1243 ... 0.1162 0.1162
          array([[0.12434117, 0.12434117, 0.12434117, ..., 0.12434117, 0.12434117,
                  0.12434117],
                 [0.12428381, 0.12428381, 0.12428381, ..., 0.12428381, 0.12428381,
                  0.12428381],
                 [0.12284527, 0.12284527, 0.12284527, ..., 0.12284527, 0.12284527,
                  0.12284527],
                 [0.11622674, 0.11622674, 0.11622674, ..., 0.11622674, 0.11622674,
                  0.11622674]], shape=(4, 2000))
        • acceptance_rate
          (chain, draw)
          float64
          0.1184 0.9801 0.64 ... 1.0 0.9403
          array([[0.11840362, 0.98005072, 0.63996002, ..., 0.72378318, 0.73975165,
                  0.75105456],
                 [0.65315679, 0.996179  , 0.96513164, ..., 0.99989967, 0.98012689,
                  0.27574276],
                 [0.99496016, 0.98691321, 0.98888783, ..., 0.72110549, 0.94554398,
                  0.80552803],
                 [0.14357896, 0.40182469, 0.98840607, ..., 0.58085816, 1.        ,
                  0.94029411]], shape=(4, 2000))
        • energy_error
          (chain, draw)
          float64
          0.0 0.08551 ... -0.07972 -0.1477
          array([[ 0.00000000e+00,  8.55142207e-02,  3.38458804e-03, ...,
                   1.78658902e-02, -4.85554049e-01,  1.44671301e-01],
                 [ 4.05956311e-01, -1.02448354e-01, -5.01368541e-02, ...,
                  -2.08095814e-01,  4.42546064e-02,  0.00000000e+00],
                 [-1.19643706e-03, -3.38656948e-01,  1.73795062e-02, ...,
                   1.90430915e-01, -1.36696451e-02,  1.21374003e-02],
                 [ 0.00000000e+00,  9.11739382e-01, -8.87318753e-01, ...,
                   1.31991142e+00, -7.97210167e-02, -1.47701111e-01]],
                shape=(4, 2000))
        • step_size
          (chain, draw)
          float64
          0.1157 0.1157 ... 0.131 0.131
          array([[0.11574222, 0.11574222, 0.11574222, ..., 0.11574222, 0.11574222,
                  0.11574222],
                 [0.13938968, 0.13938968, 0.13938968, ..., 0.13938968, 0.13938968,
                  0.13938968],
                 [0.16130634, 0.16130634, 0.16130634, ..., 0.16130634, 0.16130634,
                  0.16130634],
                 [0.13096899, 0.13096899, 0.13096899, ..., 0.13096899, 0.13096899,
                  0.13096899]], shape=(4, 2000))
        • tree_depth
          (chain, draw)
          int64
          1 5 5 5 5 5 5 3 ... 5 5 4 4 5 5 3 4
          array([[1, 5, 5, ..., 5, 5, 4],
                 [5, 4, 5, ..., 5, 5, 2],
                 [5, 2, 5, ..., 5, 4, 6],
                 [1, 1, 5, ..., 5, 3, 4]], shape=(4, 2000))
        • reached_max_treedepth
          (chain, draw)
          bool
          False False False ... False False
          array([[False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False]], shape=(4, 2000))
        • max_energy_error
          (chain, draw)
          float64
          2.134 -0.255 1.41 ... -1.337 -1.313
          array([[ 2.13365602, -0.25495097,  1.4104675 , ...,  0.95357606,
                   1.02554703,  0.59345248],
                 [ 0.88422827, -0.22137917, -0.15194501, ..., -0.25576827,
                   0.06597192,  1.5528476 ],
                 [-0.40599127, -0.33865695,  0.04267365, ...,  0.70190611,
                   0.12040347,  0.67319792],
                 [ 1.94087013,  0.91173938, -0.94726767, ...,  1.38624232,
                  -1.33652195, -1.31288658]], shape=(4, 2000))
        • n_steps
          (chain, draw)
          float64
          1.0 17.0 31.0 ... 19.0 7.0 15.0
          array([[ 1., 17., 31., ..., 31., 31., 15.],
                 [31., 15., 31., ..., 31., 31.,  3.],
                 [31.,  3., 31., ..., 31., 15., 47.],
                 [ 1.,  1., 31., ..., 19.,  7., 15.]], shape=(4, 2000))
        • largest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]], shape=(4, 2000))
        • process_time_diff
          (chain, draw)
          float64
          4.9e-05 0.00049 ... 0.000366
          array([[4.900e-05, 4.900e-04, 7.230e-04, ..., 7.370e-04, 7.400e-04,
                  3.690e-04],
                 [7.140e-04, 3.660e-04, 8.290e-04, ..., 8.250e-04, 7.650e-04,
                  9.100e-05],
                 [7.780e-04, 1.010e-04, 7.750e-04, ..., 7.670e-04, 3.770e-04,
                  1.878e-03],
                 [4.700e-05, 4.500e-05, 8.170e-04, ..., 4.590e-04, 1.860e-04,
                  3.660e-04]], shape=(4, 2000))
        • perf_counter_start
          (chain, draw)
          float64
          226.0 226.0 226.0 ... 227.5 227.5
          array([[225.97933542, 225.97953079, 225.98011096, ..., 227.53929063,
                  227.54006929, 227.54084879],
                 [225.93518233, 225.93593833, 225.93634279, ..., 227.34164083,
                  227.34263404, 227.34344412],
                 [225.97125996, 225.97208767, 225.97223067, ..., 227.52950067,
                  227.53031429, 227.53101396],
                 [225.95602512, 225.95611546, 225.9562585 , ..., 227.52954637,
                  227.53004487, 227.53027258]], shape=(4, 2000))
        • perf_counter_diff
          (chain, draw)
          float64
          4.888e-05 0.0005233 ... 0.0003672
          array([[4.887500e-05, 5.232920e-04, 7.238750e-04, ..., 7.375410e-04,
                  7.400000e-04, 3.690840e-04],
                 [7.136250e-04, 3.664170e-04, 8.365000e-04, ..., 9.263750e-04,
                  7.645420e-04, 9.162500e-05],
                 [7.772080e-04, 1.016250e-04, 7.749590e-04, ..., 7.673340e-04,
                  3.770840e-04, 1.957625e-03],
                 [4.754100e-05, 4.545800e-05, 8.171250e-04, ..., 4.591250e-04,
                  1.863330e-04, 3.672080e-04]], shape=(4, 2000))
        • smallest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]], shape=(4, 2000))
        • index_in_trajectory
          (chain, draw)
          int64
          0 -8 12 -20 8 -20 ... 3 6 20 7 2 2
          array([[  0,  -8,  12, ...,  12,  22,  10],
                 [ 20,   7,   5, ..., -27, -12,   0],
                 [  9,   1,  -9, ...,  20,   5,   9],
                 [  0,   1,  -8, ...,   7,   2,   2]], shape=(4, 2000))
        • lp
          (chain, draw)
          float64
          -225.2 -225.7 ... -227.5 -227.2
          array([[-225.16529992, -225.66063281, -225.43843924, ..., -225.78117096,
                  -225.6910413 , -227.31540968],
                 [-226.35439073, -226.59642119, -227.05124733, ..., -226.08468075,
                  -226.47486369, -226.47486369],
                 [-226.4003464 , -225.5542709 , -225.82968845, ..., -226.54791356,
                  -225.88330276, -225.08892583],
                 [-224.86754612, -226.42599904, -224.74926   , ..., -227.64527605,
                  -227.53536551, -227.19140806]], shape=(4, 2000))
        • divergences
          (chain, draw)
          int64
          0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0
          array([[0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0],
                 [0, 0, 0, ..., 0, 0, 0]], shape=(4, 2000))
        • diverging
          (chain, draw)
          bool
          False False False ... False False
          array([[False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False]], shape=(4, 2000))
        • energy
          (chain, draw)
          float64
          228.0 226.4 227.1 ... 227.7 228.1
          array([[227.98845339, 226.40383196, 227.11593265, ..., 227.15836049,
                  226.86492647, 228.11607301],
                 [227.24639337, 226.91437031, 227.42072247, ..., 226.99657392,
                  227.18777349, 231.09341303],
                 [228.08072027, 226.15333004, 226.68983539, ..., 227.27078002,
                  226.86418622, 227.98886932],
                 [226.54017161, 226.5628077 , 225.97520043, ..., 227.89988843,
                  227.73228783, 228.1159298 ]], shape=(4, 2000))
      • created_at :
        2026-04-29T07:01:01.839898+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2
        sampling_time :
        2.4053969383239746
        tuning_steps :
        1000

    • <xarray.Dataset> Size: 832B
      Dimensions:  (y_dim_0: 52)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 416B 0 1 2 3 4 5 6 7 8 ... 44 45 46 47 48 49 50 51
      Data variables:
          y        (y_dim_0) float64 416B 103.0 79.0 122.0 78.0 ... 98.0 134.0 99.0
      Attributes:
          created_at:                 2026-04-29T07:01:01.844280+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
      xarray.Dataset
        • y_dim_0: 52
        • y_dim_0
          (y_dim_0)
          int64
          0 1 2 3 4 5 6 ... 46 47 48 49 50 51
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51])
        • y
          (y_dim_0)
          float64
          103.0 79.0 122.0 ... 134.0 99.0
          array([103.,  79., 122.,  78., 135., 117., 135., 123., 109., 126., 124.,
                 126., 130., 117., 133., 118., 127., 118., 138., 143., 135., 129.,
                  91.,  95., 121., 131., 144., 112., 132., 121., 121., 126., 116.,
                 112., 114., 121., 126., 122., 118., 123.,  87., 136., 120., 140.,
                 125.,  86., 145., 128., 121.,  98., 134.,  99.])
      • created_at :
        2026-04-29T07:01:01.844280+00:00
        arviz_version :
        0.23.4
        inference_library :
        pymc
        inference_library_version :
        5.28.2

11.4 Comparing data to replications from a fitted model¶

Posterior predictive checks: simulate data from the fitted model and compare to observed data. If model is a good fit, then simulated data should look like observed data.

Example: simulation-based checking of a fitted normal distribution¶

Demonstrate how lack of fit can be seen using predictive replications.

Inappropriately fit normal distribution to the data, by fitting a model with no predictors. Implicit model for linear regression is normally distributed errors.

Simon Newcomb’s measurements for estimating the speed of light, from Stigler (1977). The data represent the amount of time required for light to travel a distance of 7442 meters and are recorded as deviations from 24 800 nanoseconds.

In [7]:
# read csv
newcomb = pd.read_csv('../ros_data/newcomb.txt', sep=' ')
display(newcomb.head())

#plot histogram of the data
sns.histplot(newcomb['y'], bins=30, kde=True)
plt.title('Newcomb\'s Speed of Light Measurements')
plt.xlabel('Time (microseconds)')
plt.ylabel('Frequency')
plt.show()

# using pymc to fit a normal distribution to the data, with no predictors (just an intercept) - dont use fit_and_plot_bayes since it expects a predictor, so we will write a custom pymc model for this
with pm.Model() as model:
    mu = pm.Normal('mu', mu=0, sigma=50)  # Prior for the mean (intercept).
    sigma = pm.HalfNormal('sigma', sigma=50)  # Prior for the standard deviation (must be positive).
    likelihood = pm.Normal('y', mu=mu, sigma=sigma, observed=newcomb['y'].values)  # Likelihood of observed data given mu and sigma.
    trace = pm.sample(2000, tune=1000, chains=4)  # Sample from the posterior distribution using NUTS.

az.plot_trace(trace)  # Plot the MCMC traces and marginal densities for mu and sigma.
plt.tight_layout()  # Adjust subplot spacing.
plt.show()  # Render the trace plot.

# Stack the posterior samples of mu and sigma into a 2-column matrix.
# Each row is one posterior draw: [mu_draw, sigma_draw].
# This mirrors R's `sims <- as.matrix(fit)`.
sims = np.column_stack([
    trace.posterior["mu"].values.ravel(),       # flatten chains x draws into 1D
    trace.posterior["sigma"].values.ravel(),
])

# Total number of posterior draws we have available.
n_sims = sims.shape[0]

# Build fake ("replicated") datasets from the fitted model.
# Seeded RNG so results are reproducible.
rng = np.random.default_rng(0)

# Pre-allocate an (n_sims x n) array; each row will be one simulated dataset of size n.
y_rep = np.empty((n_sims, n))

# For each posterior draw, simulate n new observations from Normal(mu, sigma).
for s in range(n_sims):
    y_rep[s] = rng.normal(sims[s, 0], sims[s, 1], size=n)

# Reset the RNG and pick 20 random draws to visualize (no repeats).
rng = np.random.default_rng(0)
sample_idx = rng.choice(n_sims, size=20, replace=False)

# Create a 5-row x 4-column grid of subplots (20 panels total).
fig, axes = plt.subplots(5, 4, figsize=(12, 10))

# Plot a histogram of one replicated dataset in each panel.
for ax, s in zip(axes.flat, sample_idx):       # axes.flat iterates all 20 subplots
    ax.hist(y_rep[s, :], bins=15, edgecolor="black")
    ax.set_title(f"draw {s}", fontsize=8)      # label which posterior draw it came from

plt.tight_layout()   # stop titles/axes from overlapping
plt.show()           # render the figure
y
0 28
1 26
2 33
3 24
4 34
No description has been provided for this image
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 4 seconds.
No description has been provided for this image
No description has been provided for this image

Very low values are not replicated in the simulated data, which suggests that the model is not a good fit to the data.

Checking model fit using a numerical data summary¶
In [8]:
# Define a test statistic: here, the minimum of a dataset.
def test(y):
    return np.min(y)

# Apply the test statistic to each row of y_rep.
# axis=1 means: collapse columns, one result per row (one per replicated dataset).
# Equivalent to R's apply(y_rep, 1, test).
test_rep = np.apply_along_axis(test, 1, y_rep)
# Simpler / faster alternative: test_rep = y_rep.min(axis=1)

# Observed test statistic (minimum of the real data).
t_obs = test(newcomb["y"].values)

# xlim spans both the observed value and the full range of replicated minima
# so the observed-data line is always visible on the axis.
lo = min(t_obs, test_rep.min())
hi = max(t_obs, test_rep.max())

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(test_rep, bins=30, edgecolor="black", range=(lo, hi))

# Vertical line marking the observed test statistic.
# R: lines(rep(test(y),2), c(0,n))
ax.axvline(t_obs, color="red", linewidth=2)

ax.set_xlabel("min(y_rep)")
ax.set_ylabel("Frequency")
ax.set_title("Posterior predictive check: minimum")
plt.show()
No description has been provided for this image

11.5 Example: predictive simulation to check the fit of a time-series model¶

Predictive simulation more complex for time-series data, which are setup so that the distribution for each point depends on earlier data.

Fitting a first-order autoregression to the unemployment series¶
In [9]:
# read csv
unemp = pd.read_csv('../ros_data/unemp.txt', sep=' ')
display(unemp.head())

# plot year vs y
sns.lineplot(data=unemp, x='year', y='y')
plt.title('Unemployment Rate Over Time')
plt.xlabel('Year')
plt.ylabel('Unemployment Rate (%)')
plt.show()

unemp['y_lag'] = unemp['y'].shift(1)

unemp_model = fit_and_plot_bayes(data=unemp.dropna(), predictors=['y_lag'], outcome='y',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=True,
                       show_residuals=True,
                       n_regression_lines=100)
year y
0 1947 3.9
1 1948 3.8
2 1949 5.9
3 1950 5.3
4 1951 3.3
No description has been provided for this image
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_y_lag, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
              mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  \
intercept    1.365  0.472     0.414      2.250      0.009    0.007    2951.0   
slope_y_lag  0.767  0.078     0.618      0.923      0.001    0.001    2994.0   
sigma        1.043  0.092     0.879      1.238      0.002    0.001    3693.0   

             ess_tail  r_hat  
intercept      2811.0    1.0  
slope_y_lag    2563.0    1.0  
sigma          3801.0    1.0  

Regression formula: y = 1.37 + 0.77*y_lag
Residual std dev (σ): 1.04 ± 0.09
Bayesian R²: 0.592 ± 0.051
LOO-ELPD: -102.03 ± 8.29  (p_loo=3.7)
LOO log score (per obs): -1.479 ± 0.120
No description has been provided for this image
No description has been provided for this image
In [10]:
# print(unemp_model)
posterior = unemp_model.posterior
# print(posterior)

intercept_sims = posterior["intercept"].values.ravel()
slope_sims = posterior["slope_y_lag"].values.ravel()
sigma_sims = posterior["sigma"].values.ravel()

print("Posterior mean of intercept:", np.mean(intercept_sims))
print("Posterior mean of slope:", np.mean(slope_sims))
print("Posterior mean of sigma:", np.mean(sigma_sims))

# len of sims
print("Number of posterior draws:", len(intercept_sims))

n = len(unemp) # number of time points in the original data
n_sims = len(intercept_sims) # number of posterior draws
rng = np.random.default_rng(0) # seeded RNG for reproducibility

y_rep = np.full((n_sims, n), np.nan) # Start with an array of NaNs; we'll fill in the first column with the initial value and then iteratively fill in the rest.
# print(y_rep)

for s in range(n_sims):
    y_rep[s, 0] = unemp['y'].iloc[0] # Set the first value of each simulated series to the same initial value as the real data.
    for t in range(1, n):
        y_rep[s, t] = intercept_sims[s] + slope_sims[s] * y_rep[s, t-1] + rng.normal(0, sigma_sims[s]) # Simulate the next value based on the previous one using the AR(1) model with noise.
        
print(y_rep.shape) # Should be (n_sims, n)

# plot a random 10 simulated series, each as sub plots in a 5x2 grid and plot the original data on top of each subplot for comparison
rng = np.random.default_rng(0)
sample_idx = rng.choice(n_sims, size=10, replace=False) # Randomly select 10 posterior draws to visualize.
fig, axes = plt.subplots(5, 2, figsize=(12, 15)) # Create a 5-row x 2-column grid of subplots.
for ax, s in zip(axes.flat, sample_idx):
    ax.plot(unemp['year'], y_rep[s, :], label='Simulated', alpha=0.7) # Plot the simulated series for this draw.
    ax.plot(unemp['year'], unemp['y'], label='Observed', color='red') # Overlay the original data in red for comparison.
    ax.set_title(f"Draw {s}") # Title indicating which posterior draw it is.
    ax.set_xlabel('Year') # x-axis label.
    ax.set_ylabel('Unemployment Rate (%)') # y-axis label.
    ax.legend() # Show legend to differentiate simulated vs observed.
plt.tight_layout() # Adjust subplot spacing.
plt.show() # Render the figure.
Posterior mean of intercept: 1.365171392702648
Posterior mean of slope: 0.7671765233022904
Posterior mean of sigma: 1.042686294763032
Number of posterior draws: 8000
(8000, 70)
No description has been provided for this image
Visual and numerical comparisons of replicated to actual data¶

Data appears much more jagged than the simulated data. We quantify this below.

In [11]:
# Test statistic: count the number of turning points in a series.
# A turning point is where the direction changes (rising to falling or vice versa).
def test(y):
    n = len(y)
    y_lag = np.roll(y, 1)       # y[t-1]
    y_lag[0] = np.nan
    y_lag_2 = np.roll(y, 2)     # y[t-2]
    y_lag_2[:2] = np.nan
    diff1 = np.sign(y - y_lag)          # direction of current move
    diff2 = np.sign(y_lag - y_lag_2)    # direction of previous move
    return np.nansum(diff1 != diff2)    # count where direction changed

# Observed test statistic
test_y = test(unemp['y'].values)

# Apply to each simulated series (each row of y_rep)
test_rep = np.apply_along_axis(test, 1, y_rep)

# Plot
lo = min(test_y, test_rep.min())
hi = max(test_y, test_rep.max())

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(test_rep, bins=30, edgecolor="black", range=(lo, hi))
ax.axvline(test_y, color="red", linewidth=2, label=f"Observed = {test_y}")
ax.set_xlabel("Number of turning points")
ax.set_ylabel("Frequency")
ax.set_title("Posterior predictive check: turning points")
ax.legend()
plt.show()
No description has been provided for this image

11.6 Residual standard deviation σ and explained variance R²¶

Residual standard deviation \sigma is the standard deviation of the residuals. It is a measure of how far the data points are from the fitted regression line, on average.

For test score example, \sigma = 15. This means that the typical error in predicting a student’s test score is about 15 points.

The size of \sigma only means something when compared with how much the data vary. If the data vary a lot, then a large \sigma might be acceptable. If the data do not vary much, then a large \sigma might indicate a poor fit.

Generally small \sigma is better, but it depends on the variance of the data.

Model fit summarised by \sigma and by $R^2$ (coefficient of determination). $R^2$ is the proportion of variance in the data that is explained by the model. The unexplained variance is $\sigma^2$ and if we label $s_{y}$ as the standard deviation of the data, then $R^2 = 1 - \frac{\sigma^2}{s_y^2}$.

When fit using least squares $R^2 = V_{i=1}^n \frac{\hat{y}_i}{s_y})^2$.

Where $\hat{y}_i = X_i\hat{\beta}$ and we are using the nottion V for the sample variance: $V_{i=1}^n z_i = \frac{1}{n-1} \sum_{i=1}^n (z_i - \bar{z})^2$ for any vector z of length n.

$R^2$ is between 0 and 1, with higher values indicating better fit.

$R^2$ does not care about units - we can rescale the data and $R^2$ will be the same.

With one predictor, $R^2$ is the square of the correlation between the predictor and the outcome.

Degrees of freedom (n-k): number of data points minus number of parameters.

Difficulties in interpreting residual standard deviation and explained variance¶

We are generally more interested in the deterministic component of the model, $X \beta$, than the variation $\epsilon$. But when look at standard deviation, we are interested in the unexplained variation in the data (how much scatter remains), or for precision (bigger \sigma means less precision).

$R^2$ can be misleading because it is a fraction: unexplained variance divided by total variance. Top and bottom part of the fraction can change independently. Therefore, $R^2$ can shift even if the model fit does not change.

Bayesian R²¶

Bayesian methods are concerned with uncertainty, so the regular $R^2$ formulas do not work for two reasons: 1. they ignore uncertainty and 2. they can give values greater than 1 and less than 0. This can arise if we have strong prior belief, the model might produce fitted values that vary more than the actual data, giving a $R^2$ greater than 1.

Bayesian $R^2$ is defined as: variance of predictions divided by variance of predictions plus variance of residuals. This asks of all the variation in the data, what fraction comes from the models predictions vs leftover noise?

Bayesian $R^2$ gives us $R^2$ values for each draw, so we get a distribution of $R^2$ values - telling us "how well does the model fit?" and "how certain are we about that fit?".

11.7 External validation: checking fitted model on new data¶

Most fundamental way to test a model is to make predictions and compare to actual data it has never seen before.

  1. Fit model on original (old) data
  2. Feed new cases into the model and make predictions
  3. Compare predictions to actual data

Plot actual vs predicted values for new data. If model is a good fit, then points should be close to the line of equality (y = x).

Plot residuals (actual - predicted) vs predicted values. If model is a good fit, then residuals should be randomly scattered around zero, with no clear pattern.

11.8 Cross validation¶

Split data into training and test (hold-out) sets. Fit model on training set, evaluate on test set. Hold-out set can be used as a proxy for new data. If no prediction for future, this can be seen as a way to check how well model generalises to new data.

We can:

  1. Hold out individual data points (leave-one-out cross validation)
  2. Hold out groups of data points (leave-one-group-out cross validation)
  3. Use past data to predict future data (leave-future-out cross validation)

Leave-one-group-out and leave-future-out cross validation is useful for timeseries data, where we want to predict future data from past data.

Cross-validation removes overfitting arising from using same data to fit and evaluate the model.

Leave-one-out cross validation (LOO)¶

Naive approach: fit model n times, once for each held-out data point. This is computationally expensive.

In [12]:
# simulated dataset

n = 20
# array 1:n
x = np.arange(1, n + 1)
a = 0.2
b = 0.3
sigma = 1
# set seed for reproducibility
rng = np.random.default_rng(2141)
y = a + b * x + rng.normal(0, sigma, n)
data = pd.DataFrame({"x": x, "y": y})
display(data.head())

fit_and_plot_bayes(data, 'x', 'y',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=True,
                       show_residuals=True,
                       n_regression_lines=100)

# remove 15th observation
data_dropped = data.drop(index=15).reset_index(drop=True)
display(data_dropped.head())

new_model = fit_and_plot_bayes(data_dropped, 'x', 'y',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=True,
                       show_residuals=True,
                       n_regression_lines=100)

# predict the 15 observation using the new model
posterior = new_model.posterior
intercept_sims = posterior["intercept"].values.ravel()
slope_sims = posterior["slope_x"].values.ravel()
sigma_sims = posterior["sigma"].values.ravel()
x_new = 15
y_pred_sims = intercept_sims + slope_sims * x_new + rng.normal(0, sigma_sims)
y_pred_mean = np.mean(y_pred_sims)
y_pred_hdi = az.hdi(y_pred_sims, hdi_prob=0.95)
print(f"Predicted value for x=15: {y_pred_mean:.2f} (95% HDI: [{y_pred_hdi[0]:.2f}, {y_pred_hdi[1]:.2f}])")

# plot distribution of predicted values for x=15
sns.histplot(y_pred_sims, bins=30, kde=True)
plt.axvline(y_pred_mean, color='red', label='Predicted mean')
plt.axvline(y_pred_hdi[0], color='blue', linestyle='--', label='95% HDI')
plt.axvline(y_pred_hdi[1], color='blue', linestyle='--')
plt.title('Posterior predictive distribution for x=15')
plt.xlabel('Predicted y')
plt.ylabel('Frequency')
plt.legend()
plt.show()
x y
0 1 0.640604
1 2 2.168083
2 3 2.893252
3 4 1.091838
4 5 2.274745
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_x, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
            mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  \
intercept  0.305  0.541    -0.740      1.408      0.010    0.008    3184.0   
slope_x    0.300  0.045     0.211      0.387      0.001    0.001    3181.0   
sigma      1.156  0.217     0.774      1.576      0.004    0.004    3585.0   

           ess_tail  r_hat  
intercept    3740.0    1.0  
slope_x      3513.0    1.0  
sigma        3090.0    1.0  

Regression formula: y = 0.30 + 0.30*x
Residual std dev (σ): 1.16 ± 0.22
Bayesian R²: 0.722 ± 0.070
LOO-ELPD: -31.89 ± 2.56  (p_loo=2.4)
LOO log score (per obs): -1.595 ± 0.128
No description has been provided for this image
No description has been provided for this image
x y
0 1 0.640604
1 2 2.168083
2 3 2.893252
3 4 1.091838
4 5 2.274745
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_x, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
            mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  \
intercept  0.395  0.491    -0.548      1.390      0.009    0.007    2801.0   
slope_x    0.281  0.041     0.201      0.362      0.001    0.001    2946.0   
sigma      1.054  0.196     0.706      1.443      0.004    0.003    2888.0   

           ess_tail  r_hat  
intercept    3453.0    1.0  
slope_x      3219.0    1.0  
sigma        2493.0    1.0  

Regression formula: y = 0.39 + 0.28*x
Residual std dev (σ): 1.05 ± 0.20
Bayesian R²: 0.735 ± 0.070
LOO-ELPD: -28.49 ± 2.42  (p_loo=2.4)
LOO log score (per obs): -1.499 ± 0.128
No description has been provided for this image
No description has been provided for this image
Predicted value for x=15: 4.60 (95% HDI: [2.44, 6.85])
No description has been provided for this image

Model predicts y=4.6 for x=15, which is far from the original value of 7.5.

Fast leave-one-out cross validation¶

In Bayes, the posterior is proportional to the likelihood times the prior. The likelihood is a product of n terms, one for each data point. If we want to leave out one data point, we can just divide out the likelihood term for that data point, which is equivalent to leaving out that data point from the posterior. This allows us to compute the leave-one-out posterior without refitting the model n times. The idea is that when we leave one data point out, if the data point was very likely it gets downweighted, which mimics the posterior we would get if we refit the model without that data point. If the data point was very unlikely, then it gets upweighted.

Reweighted samples then approximate the predictive distribution for the held-out data point.

Smoothing exists to prevent extreme weights from dominating the approximation. This is called Pareto smoothed importance sampling (PSIS-LOO). It tames the extreme weights.

Key note: dividing by a likelihood factor is equivalent to never having observed the data point.

Summarizing prediction error using the log score and deviance¶

Various methods for assessing accuracy of model predictions. Residual standard deviation and $R^2$ are interpretable for linear but not for logistic and other discrete models. Residual standard deviation and $R^2$ ignore uncertainty.

Log score is more suitable for probability models - evaluate how well a probability model predicts the outcome. Take the log of the probability assigned to the observed outcome by the model. Higher log score indicates better predictive performance.

The log score rewards models that make tight predictions (small variance) that land close to the actual data (small residuals). Smaller variance improves one part of the log score, but if the predictions are too tight and miss the data, then the log score will be worse. The log score balances these two aspects of predictive performance.

Expected log predictive density (ELPD) is the sum of the log scores for each data point. It is a measure of how well the model predicts new data. Higher ELPD indicates better predictive performance.

Deviance is a measure of model fit that is based on the log score. It is defined as -2 times the log score. Lower deviance indicates better model fit.

Log score goes beyond linear regression and can be used for any model that produces a probability distribution for the outcome. It is a more general measure of predictive performance than residual standard deviation and $R^2$. It also evaluates the entire predictive distribution, not just the mean prediction. This is important because a model that makes accurate predictions with high uncertainty may be more useful than a model that makes inaccurate predictions with low uncertainty.

Overfitting and AIC¶

Models log score on data used to fit the model will be higher than log score on new data. Cross validation solved this by evaluating each point as if it was new. AIC (Akaike Information Criterion) takes a different approach.

AIC inflates the within-sample log score by roughly one unit. So when fitting k parameters, just subtract k from the log score. On deviance scale, this is equivalent to adding 2k to the deviance.

AIC connects to degrees of freedom, which is the number of parameters in the model. The more parameters, the more flexible the model, and the better it can fit the data. However, a more flexible model is also more likely to overfit the data, which is why AIC penalizes models with more parameters.

Cross validation makes fewer assumptions than AIC, but is more computationally expensive. For small samples or complex models, cross validation is better.

Interpreting differences in log scores¶

Log scores are only for comparing two models.

Adding useless (noise) predictors makes training score look 0.5 better, but makes the honest LOO score 0.5 worse - this is overfitting in action.

In practice, if adding a predictor and LOO score drops by 0.5, predictor is noise. If LOO score improves or remains flat, predictor is useful. This scales linearly, so adding 10 noise predictors would drop LOO score by 5.

Small fluctuation between models, however, could be random. So must pay attention to the uncertainty in the comparison.

Demonstration of adding pure noise predictors to a model¶

When adding predictors, even if pure noise, fit to data and within sample predictive measures will generally improve. We want to fit to data, but with finite data, model will also fit to noise. We should understand the amount of noise in data.

In [13]:
kidiq = pd.read_csv('../ros_data/kidiq.csv', skiprows=0)
display(kidiq.head())

# fit_and_plot_lm(kidiq, ['mom_hs'], 'kid_score', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)

fit1 = fit_and_plot_bayes(kidiq, ['mom_hs', 'mom_iq'], 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=False,
                       n_regression_lines=100)

# add 5 pure noise predictors to the data — matches R: array(rnorm(5*n), c(n,5))
n = len(kidiq)
noise = rng.normal(0, 1, size=5 * n).reshape(n, 5, order='F')
for i in range(5):
    kidiq[f'noise_{i+1}'] = noise[:, i]

fit2 = fit_and_plot_bayes(kidiq, ['mom_hs', 'mom_iq', 'noise_1', 'noise_2', 'noise_3', 'noise_4', 'noise_5'], 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=False,
                       n_regression_lines=100)

fit3 = fit_and_plot_bayes(kidiq, ['mom_hs'], 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=False, show_forest=False,
                       show_posterior=False, show_regression=False,
                       n_regression_lines=100)

az.compare({'model1': fit1, 'model2': fit2, 'model3': fit3}, ic='loo')
kid_score mom_hs mom_iq mom_work mom_age
0 65 1 121.117529 4 27
1 98 1 89.361882 4 25
2 85 1 115.443165 4 27
3 83 1 99.449639 3 25
4 115 1 92.745710 4 27
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, slope_mom_iq, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 3 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     25.315  5.943    13.465     36.531      0.093    0.075   
slope_mom_hs   5.895  2.196     1.707     10.264      0.031    0.027   
slope_mom_iq   0.568  0.061     0.445      0.685      0.001    0.001   
sigma         18.186  0.623    17.039     19.471      0.009    0.008   

              ess_bulk  ess_tail  r_hat  
intercept       4126.0    4585.0    1.0  
slope_mom_hs    5167.0    4453.0    1.0  
slope_mom_iq    4061.0    4277.0    1.0  
sigma           5310.0    4574.0    1.0  

Regression formula: kid_score = 25.31 + 5.89*mom_hs + 0.57*mom_iq
Residual std dev (σ): 18.19 ± 0.62
Bayesian R²: 0.217 ± 0.031
LOO-ELPD: -1876.06 ± 14.19  (p_loo=4.0)
LOO log score (per obs): -4.323 ± 0.033
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, slope_mom_iq, slope_noise_1, slope_noise_2, slope_noise_3, slope_noise_4, slope_noise_5, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 5 seconds.
                 mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept      26.137  5.918    14.888     37.916      0.079    0.067   
slope_mom_hs    5.948  2.245     1.509     10.193      0.022    0.029   
slope_mom_iq    0.560  0.061     0.439      0.680      0.001    0.001   
slope_noise_1   0.880  0.857    -0.802      2.558      0.008    0.010   
slope_noise_2  -0.360  0.885    -2.090      1.423      0.009    0.012   
slope_noise_3  -0.878  0.856    -2.545      0.796      0.009    0.010   
slope_noise_4  -0.491  0.891    -2.326      1.202      0.009    0.010   
slope_noise_5   0.404  0.861    -1.271      2.101      0.008    0.011   
sigma          18.234  0.613    17.078     19.463      0.006    0.007   

               ess_bulk  ess_tail  r_hat  
intercept        5609.0    5608.0    1.0  
slope_mom_hs    10261.0    4910.0    1.0  
slope_mom_iq     5538.0    5411.0    1.0  
slope_noise_1   11330.0    5827.0    1.0  
slope_noise_2   10340.0    5061.0    1.0  
slope_noise_3    9593.0    5846.0    1.0  
slope_noise_4   10568.0    5944.0    1.0  
slope_noise_5   10876.0    5628.0    1.0  
sigma           10033.0    5774.0    1.0  

Regression formula: kid_score = 26.14 + 5.95*mom_hs + 0.56*mom_iq + 0.88*noise_1 + -0.36*noise_2 + -0.88*noise_3 + -0.49*noise_4 + 0.40*noise_5
Residual std dev (σ): 18.23 ± 0.61
Bayesian R²: 0.226 ± 0.030
Initializing NUTS using jitter+adapt_diag...
LOO-ELPD: -1879.66 ± 14.09  (p_loo=8.9)
LOO log score (per obs): -4.331 ± 0.032
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     77.439  2.057    73.581     81.559      0.034    0.026   
slope_mom_hs  11.863  2.306     7.352     16.294      0.038    0.028   
sigma         19.908  0.664    18.643     21.226      0.009    0.008   

              ess_bulk  ess_tail  r_hat  
intercept       3560.0    4081.0    1.0  
slope_mom_hs    3631.0    3936.0    1.0  
sigma           5234.0    4375.0    1.0  

Regression formula: kid_score = 77.44 + 11.86*mom_hs
Residual std dev (σ): 19.91 ± 0.66
Bayesian R²: 0.058 ± 0.021
LOO-ELPD: -1914.77 ± 13.75  (p_loo=3.0)
LOO log score (per obs): -4.412 ± 0.032
Out[13]:
rank elpd_loo p_loo elpd_diff weight se dse warning scale
model1 0 -1876.059170 4.007582 0.000000 9.843872e-01 14.189357 0.000000 False log
model2 1 -1879.663074 8.926097 3.603904 9.466809e-16 14.088293 1.788628 False log
model3 2 -1914.768343 3.019895 38.709173 1.561281e-02 13.754367 8.401360 False log

Interpretting table above: model 1 is best (rank 0 and highest elpd_loo). Model 2 is close second. Model 3 is clearly worse than model 1 and 2.

If difference in elpd_diff is greater than 4, then we can say that the model with the higher elpd_loo is better. If difference in elpd_diff is less than 4 it is basically noise. Additionally, difference should be greater than 4 times the standard error of the difference to be considered significant.

In example above, difference between 1 and 2 is 3.8, so we cannot say that model 1 is better than model 2. Difference between 1 and 3 is 38.7, and the DSE is 8.4

K-fold cross validation¶

LOO is computationally expensive. It takes data removes one point, fits model, makes a prediction for that point, and repeats this for each point.

LOO uses a clever approximation where we fit model once then mathematically estimates what would happen if we leave each one out. Does this using weights.

This doesnt work when a data point is very unusual, which means its very unlikely under the model.

K-fold cross validation is an alternative to LOO that is less computationally expensive. It:

  1. Randomly splits data into K groups (folds)
  2. For each fold, fits model on remaining K-1 folds and evaluates on held-out fold
  3. Only need K model fits, which is much less than n for LOO

K=10 is a common choice for K. It provides a good balance between bias and computational efficiency.

Demonstration of K-fold cross validation¶

I cannot find a suitable way to demonstrate K-fold cross validation using a Python package.

Concerns about model selection¶

Dont simply pick the model with the best cross validation score:

  1. Averaging across models can be better than picking the single best. If we want to predict new data, combining predictions from multiple models can often give better performance than picking the single best model.
  2. Sample might not represent population. If we apply the model to a different population, cross validation score might not be a good indicator of performance. Errors might need to be reweighted to reflect the new population - called post-stratification.
  3. Build a bigger model. Rather than picking a subset of predictors, we can build a bigger model that includes all predictors and use regularization to prevent overfitting - this pulls coefficients toward zero. This is called continuous model expansion.
  4. In causal inference, predictive power isnt the goal. Good prediction and good causal estimation are different goals, and variables good for one are not necessarily good for the other. CV scores reward good prediction, but not necessarily good causal estimation. Generally for causal inference, we dont include post-treatment variables.

Bigger picture: CV tells us about prediction. But we might have other goals, and each goal might need a different strategy for model selection.

11.9 Bibliographic note¶

11.10 Exercises¶