10. Linear regression with multiple predictors¶

When moving from a simple model $y = a + bx + \epsilon$ to a multiple regression model $y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p + \epsilon$, several new issues arise:

  1. Which predictors should be included in the model?
  2. How do we interpret the coefficients $\beta_1, \beta_2, \ldots, \beta_p$?
  3. Interactions between predictors
  4. Construction of new predictors (e.g., polynomial terms, transformations)
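
Items 3 and 4 both amount to adding columns to the design matrix. The sketch below illustrates this with made-up data; `x1` and `x2` are hypothetical predictors, not columns from the dataset used later in this chapter.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
x1 = rng.normal(size=n)          # hypothetical continuous predictor
x2 = rng.integers(0, 2, size=n)  # hypothetical binary predictor

# Columns: intercept, x1, x2, interaction x1*x2, polynomial term x1^2
X = np.column_stack([np.ones(n), x1, x2, x1 * x2, x1 ** 2])
print(X.shape)  # (6, 5)
```

The model stays linear in its coefficients: each constructed column simply gets its own $\beta$.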

10.1 Adding predictors to a model¶

Regression coefficients are harder to interpret with multiple predictors because each coefficient depends in part on which other variables are in the model.

Each coefficient describes the expected difference in the response variable for a one-unit difference in the predictor of interest, holding the other predictors constant.
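
To see how a coefficient depends on the other variables in the model, here is a minimal sketch on simulated data (not the dataset used below). The true coefficients, the correlation between `x1` and `x2`, and the helper `ols` are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5000
x1 = rng.normal(size=n)
x2 = 0.8 * x1 + 0.6 * rng.normal(size=n)      # x2 is correlated with x1
y = 1 + 2 * x1 + 3 * x2 + rng.normal(size=n)  # true slopes: 2 and 3

def ols(y, *cols):
    """Least-squares coefficients, intercept first."""
    X = np.column_stack([np.ones(len(y)), *cols])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef

b_simple = ols(y, x1)     # slope absorbs x2's effect: roughly 2 + 3 * 0.8 = 4.4
b_multi = ols(y, x1, x2)  # slope for x1 with x2 held constant: roughly 2
print("x1 alone: ", b_simple.round(2))
print("x1 and x2:", b_multi.round(2))
```

The coefficient on `x1` changes from about 4.4 to about 2 once `x2` enters the model, even though the data are the same.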

Starting with a binary predictor¶

Model children's test scores given an indicator of whether the mother graduated from high school (0 = no, 1 = yes).

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import arviz as az
import pymc as pm 
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.robust.scale import mad
In [2]:
def fit_and_plot_lm(data, predictors, outcome, add_constant=True, show_plot=True, scatter_kws=None, line_kws=None):
    """
    Fit a linear model using statsmodels, print summary, plot, and show formula.
    Args:
        data: pandas DataFrame
        predictors: list of predictor column names (str)
        outcome: outcome column name (str)
        add_constant: whether to add intercept (default True)
        show_plot: whether to plot (default True)
        scatter_kws: dict, kwargs for scatterplot
        line_kws: dict, kwargs for regression line
    """
    X = data[predictors].copy()
    if add_constant:
        X = sm.add_constant(X, prepend=False)
    y = data[outcome]
    model = sm.OLS(y, X)
    results = model.fit()
    print(results.summary())
    # Print formula (build the predictor terms once, then prepend the intercept if present)
    params = results.params
    terms = " + ".join(f"{params[name]:.2f}*{name}" for name in predictors)
    if add_constant:
        formula = f"{outcome} = {params['const']:.2f} + {terms}"
    else:
        formula = f"{outcome} = {terms}"
    print("Formula:", formula)
    # Print residual standard deviation and its uncertainty
    sigma = np.sqrt(results.mse_resid)
    sigma_se = sigma / np.sqrt(2 * results.df_resid)
    print(f"Residual std dev (σ): {sigma:.2f} ± {sigma_se:.2f}")
    # Print median absolute deviation of residuals
    print("MAD of residuals:", round(mad(results.resid), 2))
    # Plot if only one predictor
    if show_plot and len(predictors) == 1:
        x_name = predictors[0]
        ax = sns.scatterplot(data=data, x=x_name, y=outcome, **(scatter_kws or {}))
        x_vals = np.linspace(data[x_name].min(), data[x_name].max(), 100)
        y_vals = params['const'] + params[x_name] * x_vals if add_constant else params[x_name] * x_vals
        ax.plot(x_vals, y_vals, color='red', **(line_kws or {}))
        ax.set_title('Linear Regression Fit')
        plt.show()
In [3]:
def fit_and_plot_bayes(data, *args,
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=True,
                       show_posterior=True, show_regression=True,
                       n_regression_lines=100):
    """
    Fit a Bayesian linear regression using PyMC and optionally plot diagnostics.
    Supports single or multiple predictors.
    Args:
        data: pandas DataFrame
        *args: predictor column name(s) followed by the outcome column name (last arg)
        intercept_mu, intercept_sigma: prior mean and std for intercept ~ Normal
        slope_mu, slope_sigma: prior mean and std for slope ~ Normal
        sigma_sigma: prior std for residual noise ~ HalfNormal
        samples: number of posterior draws
        tune: number of tuning steps
        hdi_prob: HDI probability for summaries and plots
        show_trace: plot trace and posterior density per parameter
        show_forest: plot forest (posterior means + HDI)
        show_posterior: plot posterior densities
        show_regression: plot data with posterior regression lines
        n_regression_lines: number of posterior draws to overlay on regression plot
    Returns:
        trace: PyMC InferenceData object
    """
    predictors = list(args[:-1])
    outcome = args[-1]
    y = data[outcome].values

    with pm.Model() as model:
        intercept = pm.Normal("intercept", mu=intercept_mu, sigma=intercept_sigma)
        slopes = []
        mu = intercept
        for pred in predictors:
            s = pm.Normal(f"slope_{pred}", mu=slope_mu, sigma=slope_sigma)
            slopes.append(s)
            mu = mu + s * data[pred].values
        sigma = pm.HalfNormal("sigma", sigma=sigma_sigma)
        likelihood = pm.Normal("y", mu=mu, sigma=sigma, observed=y)
        trace = pm.sample(samples, tune=tune)

    summary = pm.summary(trace, hdi_prob=hdi_prob)
    print(summary)

    # Print regression formula
    posterior = trace.posterior
    intercept_mean = posterior["intercept"].values.flatten().mean()
    formula = f"{outcome} = {intercept_mean:.2f}"
    for pred in predictors:
        slope_mean = posterior[f"slope_{pred}"].values.flatten().mean()
        formula += f" + {slope_mean:.2f}*{pred}"
    print(f"\nRegression formula: {formula}")

    if show_trace:
        az.plot_trace(trace)
        plt.tight_layout()
        plt.show()

    if show_forest:
        az.plot_forest(trace, hdi_prob=hdi_prob)
        plt.show()

    if show_posterior:
        az.plot_posterior(trace, hdi_prob=hdi_prob)
        plt.show()

    if show_regression:
        a_samples = posterior["intercept"].values.flatten()
        slope_samples = {pred: posterior[f"slope_{pred}"].values.flatten() for pred in predictors}
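        # Cap the number of overlaid lines at the number of posterior draws,
        # since sampling without replacement cannot exceed it
        n_lines = min(n_regression_lines, len(a_samples))
        idx = np.random.choice(len(a_samples), n_lines, replace=False)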

        fig, axes = plt.subplots(1, len(predictors), figsize=(6 * len(predictors), 5))
        if len(predictors) == 1:
            axes = [axes]

        for ax, pred in zip(axes, predictors):
            x = data[pred].values
            ax.scatter(x, y, alpha=0.5)
            x_grid = np.linspace(x.min(), x.max(), 100)

            # For each posterior draw, compute the line for this predictor
            # holding other predictors at their mean
            other_contribution = np.zeros(len(a_samples))
            for other_pred in predictors:
                if other_pred != pred:
                    other_contribution += slope_samples[other_pred] * data[other_pred].mean()

            for i in idx:
                y_line = a_samples[i] + other_contribution[i] + slope_samples[pred][i] * x_grid
                ax.plot(x_grid, y_line, alpha=0.05, color="gray")

            # Mean regression line
            mean_other = sum(slope_samples[op].mean() * data[op].mean() for op in predictors if op != pred)
            y_mean = a_samples.mean() + mean_other + slope_samples[pred].mean() * x_grid
            ax.plot(x_grid, y_mean, color="red")
            ax.set_xlabel(pred)
            ax.set_ylabel(outcome)
            ax.set_title(f"{outcome} vs {pred} (others at mean)")

        plt.tight_layout()
        plt.show()

    return trace
In [4]:
kidiq = pd.read_csv('../ros_data/kidiq.csv', skiprows=0)
display(kidiq.head())

fit_and_plot_lm(kidiq, ['mom_hs'], 'kid_score', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)

fit_and_plot_bayes(kidiq, 'mom_hs', 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)

fit_and_plot_bayes(kidiq, 'mom_iq', 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)
   kid_score  mom_hs      mom_iq  mom_work  mom_age
0         65       1  121.117529         4       27
1         98       1   89.361882         4       25
2         85       1  115.443165         4       27
3         83       1   99.449639         3       25
4        115       1   92.745710         4       27
                            OLS Regression Results                            
==============================================================================
Dep. Variable:              kid_score   R-squared:                       0.056
Model:                            OLS   Adj. R-squared:                  0.054
Method:                 Least Squares   F-statistic:                     25.69
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           5.96e-07
Time:                        07:27:27   Log-Likelihood:                -1911.8
No. Observations:                 434   AIC:                             3828.
Df Residuals:                     432   BIC:                             3836.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
mom_hs        11.7713      2.322      5.069      0.000       7.207      16.336
const         77.5484      2.059     37.670      0.000      73.502      81.595
==============================================================================
Omnibus:                       11.077   Durbin-Watson:                   1.464
Prob(Omnibus):                  0.004   Jarque-Bera (JB):               11.316
Skew:                          -0.373   Prob(JB):                      0.00349
Kurtosis:                       2.738   Cond. No.                         4.11
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: kid_score = 77.55 + 11.77*mom_hs
Residual std dev (σ): 19.85 ± 0.68
MAD of residuals: 19.27
[Figure: "Linear Regression Fit", scatterplot of kid_score vs mom_hs with the fitted least-squares line]
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     77.420  2.055    73.247     81.326      0.035    0.025   
slope_mom_hs  11.910  2.304     7.160     16.279      0.039    0.029   
sigma         19.899  0.676    18.629     21.238      0.010    0.009   

              ess_bulk  ess_tail  r_hat  
intercept       3524.0    3957.0    1.0  
slope_mom_hs    3442.0    4055.0    1.0  
sigma           4650.0    4112.0    1.0  

Regression formula: kid_score = 77.42 + 11.91*mom_hs
[Figure: trace plot for intercept, slope_mom_hs, and sigma]
[Figure: "kid_score vs mom_hs (others at mean)", data with posterior regression lines]
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_iq, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 2 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     25.503  5.946    13.957     37.328      0.114    0.086   
slope_mom_iq   0.613  0.059     0.494      0.725      0.001    0.001   
sigma         18.325  0.637    17.102     19.572      0.011    0.009   

              ess_bulk  ess_tail  r_hat  
intercept       2721.0    2836.0    1.0  
slope_mom_iq    2712.0    2826.0    1.0  
sigma           3345.0    3162.0    1.0  

Regression formula: kid_score = 25.50 + 0.61*mom_iq
[Figure: trace plot for intercept, slope_mom_iq, and sigma]
[Figure: "kid_score vs mom_iq (others at mean)", data with posterior regression lines]
Out[4]:
arviz.InferenceData
    • <xarray.Dataset> Size: 208kB
      Dimensions:       (chain: 4, draw: 2000)
      Coordinates:
        * chain         (chain) int64 32B 0 1 2 3
        * draw          (draw) int64 16kB 0 1 2 3 4 5 ... 1995 1996 1997 1998 1999
      Data variables:
          intercept     (chain, draw) float64 64kB 12.19 11.16 15.66 ... 17.75 16.33
          slope_mom_iq  (chain, draw) float64 64kB 0.749 0.7582 ... 0.6936 0.6918
          sigma         (chain, draw) float64 64kB 19.09 18.96 19.01 ... 18.62 18.55
      Attributes:
          created_at:                 2026-04-14T06:27:33.126720+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              2.053797960281372
          tuning_steps:               1000

    • <xarray.Dataset> Size: 1MB
      Dimensions:                (chain: 4, draw: 2000)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 16kB 0 1 2 3 4 ... 1996 1997 1998 1999
      Data variables: (12/18)
          energy_error           (chain, draw) float64 64kB -0.2203 -0.0232 ... 0.3145
          divergences            (chain, draw) int64 64kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
          diverging              (chain, draw) bool 8kB False False ... False False
          perf_counter_diff      (chain, draw) float64 64kB 0.0002032 ... 0.0001053
          step_size              (chain, draw) float64 64kB 0.2055 0.2055 ... 0.2346
          reached_max_treedepth  (chain, draw) bool 8kB False False ... False False
          ...                     ...
          lp                     (chain, draw) float64 64kB -1.89e+03 ... -1.889e+03
          tree_depth             (chain, draw) int64 64kB 3 3 4 4 5 5 ... 5 2 5 4 4 2
          acceptance_rate        (chain, draw) float64 64kB 0.9769 0.9714 ... 0.4684
          energy                 (chain, draw) float64 64kB 1.891e+03 ... 1.89e+03
          smallest_eigval        (chain, draw) float64 64kB nan nan nan ... nan nan
          step_size_bar          (chain, draw) float64 64kB 0.1565 0.1565 ... 0.1642
      Attributes:
          created_at:                 2026-04-14T06:27:33.197350+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              2.053797960281372
          tuning_steps:               1000

    • <xarray.Dataset> Size: 7kB
      Dimensions:  (y_dim_0: 434)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 3kB 0 1 2 3 4 5 6 7 ... 427 428 429 430 431 432 433
      Data variables:
          y        (y_dim_0) float64 3kB 65.0 98.0 85.0 83.0 ... 76.0 50.0 88.0 70.0
      Attributes:
          created_at:                 2026-04-14T06:27:33.211818+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2

kid_score = 77.55 + 11.77*mom_hs

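The intercept, 77.55 (about 78), is the expected score for children whose mothers did not graduate from high school. The slope, 11.77, means that children whose mothers graduated from high school are expected to score 11.77 points higher on average than children whose mothers did not. With mom_hs as the only predictor there are no other variables to hold constant: the slope is simply the difference between the two group means.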
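
With a single 0/1 predictor, the least-squares intercept equals the mean of the group coded 0 and the slope equals the difference between the two group means. The following sketch checks this on simulated data; the indicator `g`, the outcome `y`, and the numbers used to generate them are illustrative assumptions, not the kidiq data.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 200
g = rng.integers(0, 2, size=n)               # hypothetical 0/1 indicator
y = 78 + 12 * g + rng.normal(0, 20, size=n)  # simulated outcome

# Least-squares fit of y on an intercept and the indicator
X = np.column_stack([np.ones(n), g])
(intercept, slope), *_ = np.linalg.lstsq(X, y, rcond=None)

mean0 = y[g == 0].mean()
mean1 = y[g == 1].mean()
print(f"intercept = {intercept:.2f}, mean of group 0 = {mean0:.2f}")
print(f"slope     = {slope:.2f}, difference in means = {mean1 - mean0:.2f}")
```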

A single continuous predictor¶

kid_score = 25.31 + 0.62*mom_iq

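The intercept, 25.31, is the expected score for children whose mothers have an IQ of 0, which is not a meaningful value. The slope of 0.62 means that, comparing children whose mothers differ by one IQ point, the child of the mother with the higher IQ is expected to score 0.62 points higher, on average. Again, there are no other predictors in this model to hold constant.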
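
As a quick arithmetic check of this reading, the sketch below evaluates the fitted line at a few IQ values. The helper name `predict_kid_score` is hypothetical; the coefficients are the ones in the formula above.

```python
def predict_kid_score(mom_iq, intercept=25.31, slope=0.62):
    """Point prediction from the fitted line kid_score = 25.31 + 0.62*mom_iq."""
    return intercept + slope * mom_iq

# mom_iq = 0 reproduces the intercept, which is an extrapolation, not a realistic case
for iq in (0, 80, 100, 120):
    print(f"mom_iq = {iq:3d} -> predicted kid_score = {predict_kid_score(iq):.2f}")

# Mothers differing by 10 IQ points: predictions differ by 10 * slope
diff = predict_kid_score(110) - predict_kid_score(100)
print(f"difference for a 10-point gap in mom_iq: {diff:.2f}")
```

The slope is easier to interpret over a realistic gap, such as 10 IQ points, than over a single point.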

Include both predictors in the model¶

In [5]:
fit_and_plot_bayes(kidiq, 'mom_hs', 'mom_iq', 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, slope_mom_iq, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 3 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     25.393  5.858    14.039     36.862      0.090    0.074   
slope_mom_hs   5.948  2.242     1.630     10.406      0.030    0.029   
slope_mom_iq   0.567  0.060     0.441      0.679      0.001    0.001   
sigma         18.190  0.615    16.983     19.390      0.008    0.008   

              ess_bulk  ess_tail  r_hat  
intercept       4265.0    4475.0    1.0  
slope_mom_hs    5425.0    4497.0    1.0  
slope_mom_iq    4113.0    4245.0    1.0  
sigma           5425.0    4512.0    1.0  

Regression formula: kid_score = 25.39 + 5.95*mom_hs + 0.57*mom_iq
[Figures: trace plot and fitted regression lines for kid_score on mom_hs and mom_iq (show_trace=True, show_regression=True)]
Out[5]:
arviz.InferenceData containing three datasets: posterior draws (intercept, slope_mom_hs, slope_mom_iq, sigma; 4 chains x 2000 draws), sampler statistics (18 variables; 4 chains x 2000 draws), and the observed data (y; 434 observations).

In [6]:
# plot mother iq vs kid score colored by whether mother graduated high school and with regression lines for each group
sns.lmplot(data=kidiq, x='mom_iq', y='kid_score', hue='mom_hs', ci=None)
plt.title('Kid Score vs Mom IQ by High School Graduation')
plt.show()

# lm for mothers who graduated high school - filter data to only include those with mom_hs == 1
kidiq_graduated = kidiq[kidiq['mom_hs'] == 1]
kidiq_not_graduated = kidiq[kidiq['mom_hs'] == 0]

fit_and_plot_lm(kidiq_graduated, ['mom_iq'], 'kid_score', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)

fit_and_plot_lm(kidiq_not_graduated, ['mom_iq'], 'kid_score', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)
[Figure: Kid Score vs Mom IQ by High School Graduation, with a separate regression line for each mom_hs group]
                            OLS Regression Results                            
==============================================================================
Dep. Variable:              kid_score   R-squared:                       0.143
Model:                            OLS   Adj. R-squared:                  0.140
Method:                 Least Squares   F-statistic:                     56.42
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           5.24e-13
Time:                        07:27:37   Log-Likelihood:                -1462.0
No. Observations:                 341   AIC:                             2928.
Df Residuals:                     339   BIC:                             2936.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
mom_iq         0.4846      0.065      7.511      0.000       0.358       0.612
const         39.7862      6.663      5.971      0.000      26.679      52.893
==============================================================================
Omnibus:                        5.765   Durbin-Watson:                   1.612
Prob(Omnibus):                  0.056   Jarque-Bera (JB):                5.908
Skew:                          -0.314   Prob(JB):                       0.0521
Kurtosis:                       2.851   Cond. No.                         720.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: kid_score = 39.79 + 0.48*mom_iq
Residual std dev (σ): 17.66 ± 0.68
MAD of residuals: 15.55
[Figure: kid_score vs mom_iq with fitted line, mothers who graduated from high school]
                            OLS Regression Results                            
==============================================================================
Dep. Variable:              kid_score   R-squared:                       0.294
Model:                            OLS   Adj. R-squared:                  0.286
Method:                 Least Squares   F-statistic:                     37.87
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           2.00e-08
Time:                        07:27:37   Log-Likelihood:                -405.14
No. Observations:                  93   AIC:                             814.3
Df Residuals:                      91   BIC:                             819.3
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
mom_iq         0.9689      0.157      6.154      0.000       0.656       1.282
const        -11.4820     14.601     -0.786      0.434     -40.485      17.521
==============================================================================
Omnibus:                        2.584   Durbin-Watson:                   1.924
Prob(Omnibus):                  0.275   Jarque-Bera (JB):                2.360
Skew:                          -0.389   Prob(JB):                        0.307
Kurtosis:                       2.950   Cond. No.                         685.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: kid_score = -11.48 + 0.97*mom_iq
Residual std dev (σ): 19.07 ± 1.41
MAD of residuals: 18.01
[Figure: kid_score vs mom_iq with fitted line, mothers who did not graduate from high school]
Understanding the fitted model¶

kid_score = 25.39 + 5.95*mom_hs + 0.57*mom_iq + error

Intercept: for a child whose mother did not complete high school and has an IQ of 0, the expected score is 25.39. This is not a meaningful prediction, because no one has an IQ of 0.

Coefficient for mom_hs: comparing children whose mothers have the same IQ, those whose mothers graduated from high school are expected to score 5.95 points higher, on average, than those whose mothers did not graduate.

Coefficient for mom_iq: comparing children whose mothers have the same high school graduation status, a one-point difference in mother's IQ corresponds to an expected difference of 0.57 points in the child's score.
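
Because this model has no interaction term, it describes two parallel lines: the gap between the mom_hs groups is the same at every value of mom_iq. The sketch below illustrates this using the posterior means printed in the summary output above (25.39, 5.95, 0.57); the function name `predict` is hypothetical, and it returns point predictions only, ignoring posterior uncertainty.

```python
def predict(mom_hs, mom_iq, b0=25.39, b_hs=5.95, b_iq=0.57):
    """Point prediction from kid_score = b0 + b_hs*mom_hs + b_iq*mom_iq."""
    return b0 + b_hs * mom_hs + b_iq * mom_iq

for iq in (80, 100, 120):
    no_hs = predict(0, iq)
    hs = predict(1, iq)
    # The gap equals b_hs at every IQ because the two lines share one slope
    print(f"mom_iq = {iq}: no HS = {no_hs:.2f}, HS = {hs:.2f}, gap = {hs - no_hs:.2f}")
```

The constant gap is a consequence of the model form rather than a finding from the data.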

We can also look at the separate regressions for moms who did and did not graduate from high school to see how the relationship between mom_iq and kid_score differs by mom_hs status.

mom_hs = 0: kid_score = -11.48 + 0.97*mom_iq

mom_hs = 1: kid_score = 39.79 + 0.48*mom_iq

10.2 Interpreting regression coefficients¶

It’s not always possible to change one predictor while holding all others constant¶

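Regression slopes are interpreted as comparisons of individuals who differ in one predictor while being identical in all other predictors. In some settings we can also imagine manipulating one predictor while holding the others constant, but this is not always possible. For example, if a model includes both a predictor and its square, or an interaction built from that predictor, one cannot be changed while the other is held fixed.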

Counterfactual and predictive interpretations¶

Two ways to interpret multiple linear regression:

  1. Predictive interpretation: how the outcome differs, on average, when comparing two groups of items that differ by one unit in the predictor of interest while being identical in all other predictors.
  2. Counterfactual interpretation: expressed in terms of changes within individuals rather than comparisons between individuals. Here the coefficient is the expected change in the outcome caused by a one-unit change in the predictor within an individual, holding all else constant. For example, under this reading, changing maternal IQ by 1 point would change the child's score by 0.57 points. This interpretation arises in causal inference and requires stronger assumptions.

The most careful interpretation is in terms of comparisons: “When comparing two children whose mothers have the same level of education, the child whose mother is x IQ points higher is predicted to have a test score that is 0.57x higher, on average.” Or, more generally: “Comparing two items $i$ and $j$ that differ by an amount $x$ on predictor $k$ but are identical on all other predictors, the predicted difference $y_i - y_j$ is $\beta_k x$, on average.”
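
The general comparison statement can be checked numerically. The sketch below fits a two-predictor linear model by least squares to simulated data, then compares the predictions for two hypothetical items that differ by $x = 10$ on one predictor and are identical on the other. Everything here (the variable names, the simulated coefficients, the noise level) is an assumption made for illustration and is not taken from the kidiq data.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200

# Simulated predictors and outcome; the generating coefficients are arbitrary
x1 = rng.integers(0, 2, size=n).astype(float)   # binary predictor
x2 = rng.normal(100, 15, size=n)                # continuous predictor
y = 20 + 6 * x1 + 0.6 * x2 + rng.normal(0, 18, size=n)

# Least-squares estimates of (beta_0, beta_1, beta_2)
X = np.column_stack([np.ones(n), x1, x2])
beta, *_ = np.linalg.lstsq(X, y, rcond=None)

# Two items identical on x1 and differing by delta on x2
delta = 10.0
item_i = np.array([1.0, 1.0, 110.0])
item_j = np.array([1.0, 1.0, 110.0 - delta])
pred_diff = item_i @ beta - item_j @ beta

print(f"predicted difference: {pred_diff:.3f}")
print(f"beta_2 * delta:       {beta[2] * delta:.3f}")
```

The two printed values coincide whatever the shared value of the other predictor is, which is what "identical on all other predictors" buys in a model without interactions.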

10.3 Interactions¶

Previously, the slope of the regression of child's test score on mother's IQ was forced to be equal across the subgroups defined by mother's high school graduation status. But the separate fits suggest that the slopes differ by subgroup (0.97 for mom_hs = 0 versus 0.48 for mom_hs = 1). We can allow for this by including an interaction term between mom_hs and mom_iq, that is, a new predictor defined as the product of the two. Its coefficient is the difference in the mom_iq slope between the two mom_hs groups.
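
To see what the interaction does, write the model as $y = \beta_0 + \beta_1 h + \beta_2 q + \beta_3 h q + \epsilon$, with $h$ = mom_hs and $q$ = mom_iq (symbols chosen here for illustration). For $h = 0$ the line is $\beta_0 + \beta_2 q$; for $h = 1$ it is $(\beta_0 + \beta_1) + (\beta_2 + \beta_3) q$. The sketch below checks, on simulated data rather than kidiq, that a least-squares fit of the interaction model gives the same two lines as fitting each subgroup separately. The helper names `ols` and `group_fit` are hypothetical, and least squares is used in place of the Bayesian fit.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 300

# Simulated data with a different slope in each group (arbitrary coefficients)
hs = rng.integers(0, 2, size=n).astype(float)
iq = rng.normal(100, 15, size=n)
y = -10 + 50 * hs + 1.0 * iq - 0.5 * hs * iq + rng.normal(0, 18, size=n)

def ols(X, y):
    """Least-squares coefficients for design matrix X."""
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef

# Interaction model: intercept, hs, iq, hs*iq
b0, b1, b2, b3 = ols(np.column_stack([np.ones(n), hs, iq, hs * iq]), y)

def group_fit(g):
    """Separate simple regression of y on iq within the group hs == g."""
    mask = hs == g
    return ols(np.column_stack([np.ones(mask.sum()), iq[mask]]), y[mask])

a0, s0 = group_fit(0)
a1, s1 = group_fit(1)

print(f"hs=0: interaction model ({b0:.3f}, {b2:.3f}), separate fit ({a0:.3f}, {s0:.3f})")
print(f"hs=1: interaction model ({b0 + b1:.3f}, {b2 + b3:.3f}), separate fit ({a1:.3f}, {s1:.3f})")
```

With a Bayesian fit, the implied lines need not match separate least-squares fits exactly, because of the priors, the single shared residual standard deviation, and Monte Carlo error.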

In [7]:
# Add interaction term between mom_hs and mom_iq to allow slopes to differ by subgroup
kidiq['interaction_mom_hs_mom_iq'] = kidiq['mom_hs'] * kidiq['mom_iq']
fit_and_plot_bayes(kidiq, 'mom_hs', 'mom_iq', 'interaction_mom_hs_mom_iq', 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, slope_mom_iq, slope_interaction_mom_hs_mom_iq, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 6 seconds.
                                   mean      sd  hdi_2.5%  hdi_97.5%  \
intercept                        -7.200  13.057   -33.852     17.108   
slope_mom_hs                     46.206  14.645    18.697     76.067   
slope_mom_iq                      0.923   0.141     0.664      1.213   
slope_interaction_mom_hs_mom_iq  -0.431   0.155    -0.730     -0.124   
sigma                            18.034   0.622    16.759     19.201   

                                 mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat  
intercept                            0.320    0.204    1670.0    2402.0    1.0  
slope_mom_hs                         0.356    0.228    1691.0    2173.0    1.0  
slope_mom_iq                         0.003    0.002    1692.0    2331.0    1.0  
slope_interaction_mom_hs_mom_iq      0.004    0.002    1691.0    2223.0    1.0  
sigma                                0.011    0.010    3391.0    3235.0    1.0  

Regression formula: kid_score = -7.20 + 46.21*mom_hs + 0.92*mom_iq + -0.43*interaction_mom_hs_mom_iq
[Figures: trace plot and fitted regression lines for the model with the mom_hs x mom_iq interaction (show_trace=True, show_regression=True)]
Out[7]:
arviz.InferenceData
    • <xarray.Dataset> Size: 336kB
      Dimensions:                          (chain: 4, draw: 2000)
      Coordinates:
        * chain                            (chain) int64 32B 0 1 2 3
        * draw                             (draw) int64 16kB 0 1 2 ... 1997 1998 1999
      Data variables:
          intercept                        (chain, draw) float64 64kB -28.01 ... -6...
          slope_mom_hs                     (chain, draw) float64 64kB 66.23 ... 33.36
          slope_mom_iq                     (chain, draw) float64 64kB 1.16 ... 0.8988
          slope_interaction_mom_hs_mom_iq  (chain, draw) float64 64kB -0.6673 ... -...
          sigma                            (chain, draw) float64 64kB 17.48 ... 18.34
      Attributes:
          created_at:                 2026-04-14T06:27:44.579354+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              6.260993003845215
          tuning_steps:               1000

    • <xarray.Dataset> Size: 1MB
      Dimensions:                (chain: 4, draw: 2000)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 16kB 0 1 2 3 4 ... 1996 1997 1998 1999
      Data variables: (12/18)
          energy_error           (chain, draw) float64 64kB -0.04069 ... -0.01438
          divergences            (chain, draw) int64 64kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
          diverging              (chain, draw) bool 8kB False False ... False False
          perf_counter_diff      (chain, draw) float64 64kB 0.0002343 ... 0.001953
          step_size              (chain, draw) float64 64kB 0.0486 0.0486 ... 0.05203
          reached_max_treedepth  (chain, draw) bool 8kB False False ... False False
          ...                     ...
          lp                     (chain, draw) float64 64kB -1.891e+03 ... -1.891e+03
          tree_depth             (chain, draw) int64 64kB 3 6 4 5 7 5 ... 7 2 5 6 7 6
          acceptance_rate        (chain, draw) float64 64kB 0.9825 0.1145 ... 0.9998
          energy                 (chain, draw) float64 64kB 1.896e+03 ... 1.892e+03
          smallest_eigval        (chain, draw) float64 64kB nan nan nan ... nan nan
          step_size_bar          (chain, draw) float64 64kB 0.04712 ... 0.04721
      Attributes:
          created_at:                 2026-04-14T06:27:44.595295+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2
          sampling_time:              6.260993003845215
          tuning_steps:               1000

    • <xarray.Dataset> Size: 7kB
      Dimensions:  (y_dim_0: 434)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 3kB 0 1 2 3 4 5 6 7 ... 427 428 429 430 431 432 433
      Data variables:
          y        (y_dim_0) float64 3kB 65.0 98.0 85.0 83.0 ... 76.0 50.0 88.0 70.0
      Attributes:
          created_at:                 2026-04-14T06:27:44.599156+00:00
          arviz_version:              0.23.4
          inference_library:          pymc
          inference_library_version:  5.28.2

  1. The intercept is the predicted test score for children whose mothers did not complete high school and have an IQ of 0. This is not meaningful because no one has an IQ of 0.
  2. The coefficient for mom_hs is the difference in predicted test scores between children whose mothers completed high school and children whose mothers did not, among mothers with an IQ of 0. This is implausible to interpret directly because no one has an IQ of 0.
  3. The coefficient for mom_iq is the comparison of mean test scores across children whose mothers did not complete high school but whose IQs differ by 1 point.
  4. The coefficient for the interaction term is the difference in the slope for mom_iq between children whose mothers completed high school and children whose mothers did not. In other words, it is the difference in the expected change in test score for a one-unit increase in mom_iq between the two groups defined by mom_hs.

We can also look at separate regressions for moms who did and did not graduate from high school to see how the relationship between mom_iq and kid_score differs by mom_hs status.

Using the posterior means from the model output above:

kid_score = -7.20 + 46.21*mom_hs + 0.92*mom_iq - 0.43*mom_hs*mom_iq

mom_hs = 0: kid_score = -7.20 + 0.92*mom_iq

mom_hs = 1: kid_score = (-7.20 + 46.21) + (0.92 - 0.43)*mom_iq = 39.01 + 0.49*mom_iq

The slope is 0.92 for children whose mothers did not complete high school and 0.49 for children whose mothers did.
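
The same arithmetic can be written as a small helper. This is a minimal sketch rather than part of the original analysis: `group_line` is a hypothetical name, and the coefficients are the rounded posterior means from the "Regression formula" line printed in the model output.

```python
def group_line(intercept, b_hs, b_iq, b_inter, mom_hs):
    """Return (intercept, slope) of the kid_score vs. mom_iq line for one mom_hs group."""
    # Setting mom_hs to 0 or 1 folds the mom_hs and interaction terms
    # into the intercept and the mom_iq slope respectively.
    return intercept + b_hs * mom_hs, b_iq + b_inter * mom_hs


# Rounded posterior means taken from the printed regression formula
coefs = dict(intercept=-7.20, b_hs=46.21, b_iq=0.92, b_inter=-0.43)

for hs in (0, 1):
    a, b = group_line(**coefs, mom_hs=hs)
    print(f"mom_hs = {hs}: kid_score = {a:.2f} + {b:.2f}*mom_iq")
```

Because the interaction coefficient is negative, the line for mothers who completed high school starts higher but rises more slowly with mom_iq.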

When should we look for interactions?¶

Interactions can be important. A useful rule of thumb is to look for them among predictors that have large coefficients when entered without interaction.

For example, smoking is strongly associated with cancer, so when studying another risk factor such as radon exposure it is crucial to adjust for smoking, both on its own and in interaction with that factor. People who smoke and are also exposed to radon may have a much higher risk of cancer than those who only smoke or are only exposed to radon. This would be an interaction between smoking and radon exposure.

We can fit models separately for smokers and non-smokers to see if the relationship between radon exposure and cancer risk differs by smoking status. We can also include an interaction term in a single model to formally test for the presence of an interaction between smoking and radon exposure.
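
The two approaches can be compared on simulated data. The sketch below is an illustration under loud assumptions: the data are synthetic, the outcome is a made-up continuous risk score (a real cancer outcome would be binary and call for logistic regression), and the names `smoker`, `radon`, `risk`, and `ols` are hypothetical. It uses plain least squares rather than the Bayesian fits used elsewhere in this chapter.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
smoker = rng.integers(0, 2, size=n)   # synthetic 0/1 indicator
radon = rng.uniform(0, 10, size=n)    # synthetic exposure, arbitrary units

# Assumed data-generating model: radon slope is 0.5 for non-smokers, 2.0 for smokers
risk = 1.0 + 3.0 * smoker + 0.5 * radon + 1.5 * smoker * radon + rng.normal(0, 1, size=n)


def ols(X, y):
    """Least-squares coefficients for y ~ X (X must already contain a constant column)."""
    return np.linalg.lstsq(X, y, rcond=None)[0]


# Approach 1: fit a separate line for each smoking group
slopes = {}
for s in (0, 1):
    mask = smoker == s
    X_group = np.column_stack([np.ones(mask.sum()), radon[mask]])
    slopes[s] = ols(X_group, risk[mask])[1]

# Approach 2: one model with a smoker x radon interaction term
X_full = np.column_stack([np.ones(n), smoker, radon, smoker * radon])
beta = ols(X_full, risk)

print("difference in radon slopes from separate fits:", slopes[1] - slopes[0])
print("interaction coefficient from the single model:", beta[3])
```

With a binary grouping variable and a full interaction, the least-squares interaction coefficient equals the difference between the two separately fitted slopes. The single model has the advantage of reporting the uncertainty of that difference directly.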

Interpreting regression coefficients in the presence of interactions¶

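Models with interactions are easier to interpret when the predictors are centered, typically about their means. Each main effect then describes a comparison at the average value of the other predictor rather than at zero.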
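
A minimal sketch of the effect of centering, on synthetic data that only loosely mimics the kid_score example (the variables `hs`, `iq`, `y` and the helper `fit` are hypothetical, and least squares stands in for the Bayesian fit). Only the continuous predictor is centered here; the binary predictor could be centered in the same way.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 500
hs = rng.integers(0, 2, size=n).astype(float)   # synthetic binary predictor
iq = rng.normal(100, 15, size=n)                # synthetic continuous predictor
y = -10 + 50 * hs + 1.0 * iq - 0.5 * hs * iq + rng.normal(0, 18, size=n)


def fit(hs, iq, y):
    """Least-squares fit of y ~ 1 + hs + iq + hs*iq; returns the four coefficients."""
    X = np.column_stack([np.ones_like(iq), hs, iq, hs * iq])
    return np.linalg.lstsq(X, y, rcond=None)[0]


raw = fit(hs, iq, y)
centered = fit(hs, iq - iq.mean(), y)

names = ["intercept", "hs", "iq", "hs:iq"]
for name, r, c in zip(names, raw, centered):
    print(f"{name:>9}: raw = {r:8.3f}   centered = {c:8.3f}")
```

Centering leaves the iq slope and the interaction coefficient unchanged. The intercept and the hs coefficient change: they now refer to a mother of average IQ instead of a mother with an IQ of 0, which is why they become interpretable.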

10.4 Indicator variables¶

Indicator (dummy) variables take the value 0 or 1 and are used to represent categorical predictors in regression models.

In [8]:
earnings = pd.read_csv('../ros_data/earnings.csv', skiprows=0)
# earnings['earn_k'] = earnings['earn'] / 1000
# earnings['c_height'] = earnings['height'] - 66 # Center height around 66 inches for better interpretability of the intercept

display(earnings.head())

earnings_clean = earnings.dropna(subset=['height', 'weight'])

# Predict weight in pounds from height
# fit_and_plot_lm(earnings_clean, ['height'], 'weight', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)

earnings_model = fit_and_plot_bayes(earnings_clean, 'height', 'weight',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)

print(earnings_model)
   height  weight  male     earn  earnk  ethnicity  education  mother_education  father_education  walk  exercise  smokenow  tense  angry  age
0      74   210.0     1  50000.0   50.0      White       16.0              16.0              16.0     3         3       2.0    0.0    0.0   45
1      66   125.0     0  60000.0   60.0      White       16.0              16.0              16.0     6         5       1.0    0.0    0.0   58
2      64   126.0     0  30000.0   30.0      White       16.0              16.0              16.0     8         1       2.0    1.0    1.0   29
3      65   200.0     0  25000.0   25.0      White       17.0              17.0               NaN     8         1       2.0    0.0    0.0   57
4      63   110.0     0  50000.0   50.0      Other       16.0              16.0              16.0     5         6       2.0    0.0    0.0   91
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_height, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 5 seconds.
                 mean      sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept    -164.165  11.550  -186.812   -140.853      0.241    0.175   
slope_height    4.813   0.173     4.469      5.157      0.004    0.003   
sigma          28.970   0.490    28.004     29.921      0.008    0.007   

              ess_bulk  ess_tail  r_hat  
intercept       2293.0    2534.0    1.0  
slope_height    2283.0    2585.0    1.0  
sigma           3458.0    3324.0    1.0  

Regression formula: weight = -164.16 + 4.81*height
No description has been provided for this image
No description has been provided for this image
Inference data with groups:
	> posterior
	> sample_stats
	> observed_data
In [9]:
# Posterior prediction for a new observation (PyMC equivalent of R's posterior_predict)
# Predict weight for a person who is 66 inches tall
new_height = 66

posterior = earnings_model.posterior
intercept_samples = posterior["intercept"].values.flatten()  # all posterior draws for intercept (4 chains x 2000 draws = 8000 samples)
slope_samples = posterior["slope_height"].values.flatten()  # all posterior draws for slope
sigma_samples = posterior["sigma"].values.flatten()  # all posterior draws for residual std dev

# Point prediction for each posterior draw: E[weight | height=66]
mu_samples = intercept_samples + slope_samples * new_height  # linear predictor evaluated at new_height for each draw

# Full posterior predictive: sample a new observation from Normal(mu, sigma) for each draw
# This adds residual noise on top of the mean prediction, capturing both:
#   1. Parameter uncertainty (from the spread of intercept/slope draws)
#   2. Individual-level variation (from sigma)
# This is what makes it a *prediction* interval rather than just a *confidence* interval
pred_samples = np.random.normal(mu_samples, sigma_samples)

print(f"Predicted weight for height={new_height} inches:")
print(f"  Mean: {pred_samples.mean():.1f} lbs")
print(f"  Median: {np.median(pred_samples):.1f} lbs")
print(f"  50% interval: [{np.percentile(pred_samples, 25):.1f}, {np.percentile(pred_samples, 75):.1f}]")
print(f"  95% interval: [{np.percentile(pred_samples, 2.5):.1f}, {np.percentile(pred_samples, 97.5):.1f}]")

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(pred_samples, bins=50, density=True, alpha=0.7)
ax.axvline(pred_samples.mean(), color='red', linestyle='--', label=f'Mean: {pred_samples.mean():.1f}')
ax.set_xlabel('Predicted weight (lbs)')
ax.set_ylabel('Density')
ax.set_title(f'Posterior predictive distribution for height = {new_height} inches')
ax.legend()
plt.show()
Predicted weight for height=66 inches:
  Mean: 153.5 lbs
  Median: 153.0 lbs
  50% interval: [133.5, 173.0]
  95% interval: [97.5, 211.1]
[Figure: histogram of the posterior predictive distribution of weight for height = 66 inches, with the mean marked by a dashed line]
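
The wide 95% interval comes mostly from sigma, not from uncertainty in the regression line. The sketch below separates the two sources. It is illustrative only: `summarize_prediction` is a hypothetical helper, and the draws are stand-ins generated to resemble the fit above (the spread of 0.6 for the mean is an arbitrary assumption), not the notebook's actual posterior.

```python
import numpy as np


def summarize_prediction(mu_draws, sigma_draws, rng, probs=(2.5, 97.5)):
    """Return percentile intervals for the expected value and for a new observation."""
    mu_draws = np.asarray(mu_draws)
    sigma_draws = np.asarray(sigma_draws)
    # One simulated new observation per posterior draw
    y_new = rng.normal(mu_draws, sigma_draws)
    return np.percentile(mu_draws, probs), np.percentile(y_new, probs)


rng = np.random.default_rng(42)

# Stand-in draws (NOT the real posterior): mean near 153.5 lbs, sigma near 29 lbs
mu_draws = rng.normal(153.5, 0.6, size=8000)
sigma_draws = np.abs(rng.normal(29.0, 0.5, size=8000))

mean_interval, pred_interval = summarize_prediction(mu_draws, sigma_draws, rng)
print("95% interval for the mean weight:  ", np.round(mean_interval, 1))
print("95% interval for a new individual: ", np.round(pred_interval, 1))
```

In the notebook itself, the same function could be applied to `mu_samples` and `sigma_samples` from the cell above.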
In [10]:
earnings['c_height'] = earnings['height'] - 66 # Center height around 66 inches for better interpretability of the intercept

display(earnings.head())

earnings_clean = earnings.dropna(subset=['c_height', 'weight', 'male'])

# Predict weight in pounds from height and male

earnings_model = fit_and_plot_bayes(earnings_clean, 'male', 'c_height', 'weight',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)
   height  weight  male     earn  earnk  ethnicity  education  mother_education  father_education  walk  exercise  smokenow  tense  angry  age  c_height
0      74   210.0     1  50000.0   50.0      White       16.0              16.0              16.0     3         3       2.0    0.0    0.0   45         8
1      66   125.0     0  60000.0   60.0      White       16.0              16.0              16.0     6         5       1.0    0.0    0.0   58         0
2      64   126.0     0  30000.0   30.0      White       16.0              16.0              16.0     8         1       2.0    1.0    1.0   29        -2
3      65   200.0     0  25000.0   25.0      White       17.0              17.0               NaN     8         1       2.0    0.0    0.0   57        -1
4      63   110.0     0  50000.0   50.0      Other       16.0              16.0              16.0     5         6       2.0    0.0    0.0   91        -3
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_male, slope_c_height, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
                   mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept       149.509  0.916   147.630    151.197      0.013    0.010   
slope_male       11.840  1.968     7.871     15.538      0.030    0.023   
slope_c_height    3.892  0.253     3.399      4.381      0.004    0.003   
sigma            28.697  0.475    27.793     29.648      0.006    0.005   

                ess_bulk  ess_tail  r_hat  
intercept         4846.0    5811.0    1.0  
slope_male        4397.0    4569.0    1.0  
slope_c_height    4854.0    5437.0    1.0  
sigma             6527.0    6053.0    1.0  

Regression formula: weight = 149.51 + 11.84*male + 3.89*c_height
No description has been provided for this image
No description has been provided for this image

The coefficient of 11.8 (about 12) for male tells us that, comparing a man and a woman of the same height, the man is predicted to weigh about 12 pounds more on average.

In [11]:
# Predict the weight of a 70-inch-tall woman.

# Posterior prediction for a new observation (PyMC equivalent of R's posterior_predict)
new_height = 70
new_male = 0  # 0 = woman, 1 = man

posterior = earnings_model.posterior
intercept_samples = posterior["intercept"].values.flatten()  # all posterior draws for intercept (4 chains x 2000 draws = 8000 samples)
slope_height_samples = posterior["slope_c_height"].values.flatten()  # all posterior draws for the centered-height slope
slope_male_samples = posterior["slope_male"].values.flatten()  # all posterior draws for the male coefficient
sigma_samples = posterior["sigma"].values.flatten()  # all posterior draws for residual std dev

# Point prediction for each posterior draw: E[weight | height=70, male=0]
# Height is centered at 66 inches in this model, so subtract 66 before multiplying
mu_samples = intercept_samples + slope_height_samples * (new_height - 66) + slope_male_samples * new_male

# Full posterior predictive: sample a new observation from Normal(mu, sigma) for each draw
# This adds residual noise on top of the mean prediction, capturing both:
#   1. Parameter uncertainty (from the spread of intercept/slope draws)
#   2. Individual-level variation (from sigma)
# This is what makes it a *prediction* interval rather than just a *confidence* interval
pred_samples = np.random.normal(mu_samples, sigma_samples)

print(f"Predicted weight for height={new_height} inches:")
print(f"  Mean: {pred_samples.mean():.1f} lbs")
print(f"  Median: {np.median(pred_samples):.1f} lbs")
print(f"  50% interval: [{np.percentile(pred_samples, 25):.1f}, {np.percentile(pred_samples, 75):.1f}]")
print(f"  95% interval: [{np.percentile(pred_samples, 2.5):.1f}, {np.percentile(pred_samples, 97.5):.1f}]")

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(pred_samples, bins=50, density=True, alpha=0.7)
ax.axvline(pred_samples.mean(), color='red', linestyle='--', label=f'Mean: {pred_samples.mean():.1f}')
ax.set_xlabel('Predicted weight (lbs)')
ax.set_ylabel('Density')
ax.set_title(f'Posterior predictive distribution for height = {new_height} inches')
ax.legend()
plt.show()
Predicted weight for height=70 inches:
  Mean: 164.9 lbs
  Median: 165.1 lbs
  50% interval: [146.5, 183.4]
  95% interval: [107.8, 220.7]
[Figure: histogram of the posterior predictive distribution of weight for height = 70 inches, with the mean marked by a dashed line]
Using indicator variables for multiple levels of a categorical predictor¶

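Next we add ethnicity to the model. It is a categorical predictor with four levels (Black, Hispanic, Other, White), so it enters as three indicator variables, with Black as the reference group.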

In [12]:
# Create dummy variables for ethnicity
# Drop any existing eth_ columns first to avoid duplicates on re-run
earnings_clean = earnings_clean[[c for c in earnings_clean.columns if not c.startswith('eth_')]]
eth_dummies = pd.get_dummies(earnings_clean['ethnicity'], prefix='eth', dtype=int)
eth_dummies = eth_dummies.drop(columns='eth_Black')  # Black is reference group
earnings_clean = pd.concat([earnings_clean, eth_dummies], axis=1)
print(eth_dummies.columns.tolist())

with pm.Model() as earnings_eth_model:
    intercept = pm.Normal("intercept", mu=0, sigma=50)
    male = pm.Normal("male", mu=0, sigma=50)
    c_height = pm.Normal("c_height", mu=0, sigma=50)
    eth_Hispanic = pm.Normal("eth_Hispanic", mu=0, sigma=50)
    eth_Other = pm.Normal("eth_Other", mu=0, sigma=50)
    eth_White = pm.Normal("eth_White", mu=0, sigma=50)
    sigma = pm.HalfNormal("sigma", sigma=50)

    # intercept = expected weight for a Black female who is 66 inches tall (c_height = 0)
    # each eth_ coefficient = difference from Black reference group
    mu = (intercept
          + male * earnings_clean['male'].values
          + c_height * earnings_clean['c_height'].values
          + eth_Hispanic * earnings_clean['eth_Hispanic'].values
          + eth_Other * earnings_clean['eth_Other'].values
          + eth_White * earnings_clean['eth_White'].values)

    y = pm.Normal("y", mu=mu, sigma=sigma, observed=earnings_clean['weight'].values)
    trace_eth = pm.sample(2000, tune=1000)

print(pm.summary(trace_eth, hdi_prob=0.95))
az.plot_trace(trace_eth)
plt.tight_layout()
plt.show()
['eth_Hispanic', 'eth_Other', 'eth_White']
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, male, c_height, eth_Hispanic, eth_Other, eth_White, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 3 seconds.
                 mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     153.997  2.215   149.709    158.387      0.035    0.027   
male           12.127  1.989     8.135     15.971      0.028    0.021   
c_height        3.852  0.252     3.387      4.370      0.003    0.003   
eth_Hispanic   -5.772  3.558   -12.955      1.081      0.047    0.038   
eth_Other     -11.832  5.117   -21.836     -1.974      0.065    0.055   
eth_White      -4.880  2.262    -9.218     -0.416      0.034    0.027   
sigma          28.648  0.479    27.688     29.568      0.006    0.005   

              ess_bulk  ess_tail  r_hat  
intercept       4121.0    4687.0    1.0  
male            5005.0    5109.0    1.0  
c_height        5213.0    5139.0    1.0  
eth_Hispanic    5702.0    5752.0    1.0  
eth_Other       6244.0    5932.0    1.0  
eth_White       4323.0    4830.0    1.0  
sigma           7086.0    5873.0    1.0  

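Comparing a Black person with a Hispanic person of the same sex and height, the Hispanic person is about 5.8 pounds lighter on average (posterior mean -5.77). The 95% interval, [-12.96, 1.08], includes zero, so this difference is uncertain.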

Dummy variables act like switches that turn on or off the effect of a particular category. The coefficients for the dummy variables represent the difference in the response variable between the category represented by the dummy variable and the reference category, holding all else constant.

The reference category is the one for which all dummy variables are 0. In this case, the reference category is Black people, because the eth_Black column was dropped before fitting.

Using an index variable to access a group-level predictor¶

Sometimes we have predictors that are measured at a group level rather than an individual level. For example, we have data on 1000 students from 20 schools, and we want to include a predictor that is the average income of parents in each school. We can create an index variable that assigns a unique number to each school, and then use this index variable to merge the group-level predictor (average income) with the individual-level data on students. In simple terms this means that we allocate the same average income value to all students in the same school, and then include this variable in our regression model to see how it affects student outcomes.
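As a minimal sketch of the indexing idea, the snippet below uses simulated data (the school incomes and student assignments are made up for illustration) to expand a school-level vector to student level with a NumPy index array.

```python
import numpy as np

rng = np.random.default_rng(0)

n_schools = 20
n_students = 1000

# Group-level predictor: one value per school
# (hypothetical average parental income, in thousands)
school_income = rng.normal(60, 15, size=n_schools)

# Index variable: the school (0..19) each student belongs to
school_idx = rng.integers(0, n_schools, size=n_students)

# Indexing expands the 20 school values into 1000 student-level values,
# so every student in the same school gets the same income value
student_income = school_income[school_idx]

print(student_income.shape)
```

The resulting `student_income` vector can then be used like any other individual-level predictor in the regression.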

10.5 Formulating paired or blocked designs as a regression problem¶

Regression coefficients can be interpreted as comparisons. Conversely, we can express comparisons as regressions.

Completely randomised experiment¶

$n$ people are randomly assigned to treatment and control groups, with $n/2$ in each group. The estimate of the treatment effect is $\bar{y}_{\text{treatment}} - \bar{y}_{\text{control}}$, the difference in mean outcomes between the two groups. The standard error of this estimate is $\sqrt{\frac{sd^2_{\text{treatment}}}{n/2} + \frac{sd^2_{\text{control}}}{n/2}}$.

We can also express this as a regression problem by creating a binary predictor variable that indicates treatment assignment (0 for control, 1 for treatment). But why bother? We can use the regression framework to easily extend this to more complex designs, e.g., adjusting for other variables.
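The equivalence can be checked numerically. The sketch below uses simulated outcomes (the true effect of 2 and the noise level are arbitrary assumptions) and compares the difference in means with the least squares coefficient on a 0/1 treatment indicator.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 200

# Treatment indicator: 0 = control, 1 = treatment, n/2 in each group
z = np.repeat([0, 1], n // 2)

# Simulated outcomes with a hypothetical true treatment effect of 2
y = 10 + 2.0 * z + rng.normal(0, 3, n)

# Classical estimate: difference in group means and its standard error
diff = y[z == 1].mean() - y[z == 0].mean()
se = np.sqrt(y[z == 1].var(ddof=1) / (n / 2) + y[z == 0].var(ddof=1) / (n / 2))

# Regression estimate: least squares fit of y on an intercept and z
X = np.column_stack([np.ones(n), z])
beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)

print(f"difference in means: {diff:.3f} (se {se:.3f})")
print(f"regression slope:    {beta_hat[1]:.3f}")
```

The slope matches the difference in means, and the intercept equals the control group mean.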

Paired design¶

If individuals can be matched into pairs that are similar in some way (e.g., same age and sex), we can randomly assign one member of each pair to treatment and the other to control. This is a paired design.

The classic analysis computes the difference in outcomes within each pair and takes the mean of these differences as the estimate of the treatment effect.

With regression, we can create a binary predictor variable that indicates treatment assignment (0 for control, 1 for treatment) and include a fixed effect for each pair. This allows us to estimate the treatment effect within pairs.
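A small simulation can show that the two analyses agree. This is a sketch on made-up data: the pair baselines, the effect of 3, and the noise level are all assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
n_pairs = 50

# Each pair shares a baseline level; treatment adds a hypothetical effect of 3
pair_level = rng.normal(50, 10, n_pairs)
y_control = pair_level + rng.normal(0, 2, n_pairs)
y_treated = pair_level + 3.0 + rng.normal(0, 2, n_pairs)

# Classic paired analysis: mean of within-pair differences
mean_diff = (y_treated - y_control).mean()

# Regression version: stack to long format, one row per person
y = np.concatenate([y_control, y_treated])
z = np.concatenate([np.zeros(n_pairs), np.ones(n_pairs)])
pair_id = np.tile(np.arange(n_pairs), 2)

# One indicator column per pair (these replace the single intercept)
pair_dummies = (pair_id[:, None] == np.arange(n_pairs)[None, :]).astype(float)
X = np.column_stack([z, pair_dummies])

beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)

print(f"mean paired difference:     {mean_diff:.3f}")
print(f"regression treatment coeff: {beta_hat[0]:.3f}")
```

With one treated and one control unit in every pair, the treatment coefficient equals the mean paired difference exactly.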

Block design¶

When participants are grouped into blocks before experiment and then randomly assigned to treatment and control groups within each block, this is a block design. For example, we might group students by classroom and then randomly assign treatment and control within each classroom.

10.6 Example: uncertainty in predicting congressional elections¶

Construct a model to predict the 1988 election from the 1986 election, apply it to predict 1990 from 1988, and then check the predictions against the actual 1990 results.

Background¶

The US has 435 congressional districts. We define the outcome $y_i$ for $i=1,...,435$ to be the Democratic Party's share of the two-party vote in district $i$.

| Variable | Description |
|---|---|
| v86, v88, v90 | Democratic share of the two-party vote in each district (1986, 1988, 1990) |
| v86_adj, v88_adj, v90_adj | Adjusted vote share: uncontested races are imputed to avoid 0/1 extremes |
| inc86, inc88, inc90 | Incumbency status: +1 = Dem incumbent, -1 = Rep incumbent, 0 = open seat |

The adjusted variables (e.g., v86_adj) are defined as follows: where an election was uncontested, we impute the value 0.25 for uncontested Republicans and 0.75 for uncontested Democrats. These values were chosen to approximate the share of the vote the Democratic candidate would have received had the election actually been contested.

In [13]:
# read csv
df = pd.read_csv('../ros_data/congress.csv')
display(df.head())
inc86 inc88 inc90 v86 v88 v90 v86_adj v88_adj v90_adj
0 1 1 1 0.745036 0.772443 0.714029 0.745036 0.772443 0.714029
1 1 1 1 0.673845 0.636182 0.597050 0.673845 0.636182 0.597050
2 1 1 0 0.696457 0.664928 0.521043 0.696457 0.664928 0.521043
3 -1 -1 -1 0.464590 0.273834 0.234377 0.464590 0.273834 0.234377
4 -1 -1 0 0.391095 0.263613 0.477439 0.391095 0.263613 0.477439
In [14]:
print(df['inc86'].value_counts())
print(df['inc88'].value_counts())
print(df['inc90'].value_counts())

# histogram of v88, with 20 bins and a KDE curve
sns.histplot(df['v88'], bins=20, kde=True)
plt.title('Distribution of Vote Share in 1988 (v88)')
plt.xlabel('Vote Share in 1988 (v88)')
plt.ylabel('Count')  # histplot's default statistic is a count, not a density
plt.show()

# scatter plot of v86 vs v88, colored by inc88
sns.scatterplot(data=df, x='v86', y='v88', hue='inc88')
plt.title('Vote Share in 1986 vs 1988, colored by Incumbency in 1988')
plt.xlabel('Vote Share in 1986 (v86)')
plt.ylabel('Vote Share in 1988 (v88)')
plt.legend(title='Incumbent in 1988')
plt.show()

# scatter plot of v86_adj vs v88_adj, colored by inc88
sns.scatterplot(data=df, x='v86_adj', y='v88_adj', hue='inc88')
plt.title('Vote Share in 1986 vs 1988, colored by Incumbency in 1988')
plt.xlabel('Vote Share in 1986 (v86_adj)')
plt.ylabel('Vote Share in 1988 (v88_adj)')
plt.legend(title='Incumbent in 1988')
plt.show()
inc86
 1    231
-1    163
 0     41
Name: count, dtype: int64
inc88
 1    245
-1    160
 0     30
Name: count, dtype: int64
inc90
 1    247
-1    158
 0     30
Name: count, dtype: int64
In [15]:
sns.lmplot(data=df, x='v86_adj', y='v88_adj', hue='inc88', ci=None)
plt.title('Linear Regression of 1988 Vote Share on 1986 Vote Share, colored by Incumbency in 1988')
plt.xlabel('Vote Share in 1986 (v86_adj)')
plt.ylabel('Vote Share in 1988 (v88_adj)')
plt.show()
In [16]:
# Predict 1988 from 1986

congress_model = fit_and_plot_bayes(df, 'v86_adj', 'inc88', 'v88_adj',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)

# There are issues with this, suggesting interactions
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_v86_adj, slope_inc88, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 2 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept      0.238  0.017     0.204      0.271      0.000      0.0   
slope_v86_adj  0.522  0.033     0.455      0.583      0.001      0.0   
slope_inc88    0.096  0.007     0.082      0.109      0.000      0.0   
sigma          0.067  0.002     0.063      0.072      0.000      0.0   

               ess_bulk  ess_tail  r_hat  
intercept        2293.0    3138.0    1.0  
slope_v86_adj    2199.0    2853.0    1.0  
slope_inc88      2404.0    3791.0    1.0  
sigma            4504.0    4314.0    1.0  

Regression formula: v88_adj = 0.24 + 0.52*v86_adj + 0.10*inc88
Simulation for inferences and predictions of new data points¶
In [17]:
# Extract posterior simulations (equivalent to R's as.matrix(fit88))
posterior = congress_model.posterior

# Get parameter draws (flatten across chains)
sigma = posterior['sigma'].values.flatten()
beta0 = posterior['intercept'].values.flatten()
beta1 = posterior['slope_v86_adj'].values.flatten()
beta2 = posterior['slope_inc88'].values.flatten()
n_sims = len(sigma)

# Generate predicted vote shares for 1990 using 1988 data
# y_tilde_i = beta0 + beta1 * v88_adj_i + beta2 * inc90_i + noise
v88_adj = df['v88_adj'].values
inc90 = df['inc90'].values

y_tilde = np.zeros((n_sims, 435))
for s in range(n_sims):
    mu = beta0[s] + beta1[s] * v88_adj + beta2[s] * inc90
    y_tilde[s, :] = np.random.normal(mu, sigma[s])

# Count predicted Democratic wins per simulation
dem_wins = (y_tilde > 0.5).sum(axis=1)

# Build summary table like Figure 10.7
sim_table = pd.DataFrame({
    'sim': range(1, n_sims + 1),
    'sigma': sigma,
    'beta0': beta0,
    'beta1': beta1,
    'beta2': beta2,
})
# Add a few example district predictions
for j in [0, 1, 434]:
    sim_table[f'y_tilde_{j+1}'] = y_tilde[:, j]
sim_table['Dem_wins'] = dem_wins

# Show first few, last, and summary stats
display(sim_table.head())
display(sim_table.tail(1))
print('\nSummary:')
display(sim_table.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}))
sim sigma beta0 beta1 beta2 y_tilde_1 y_tilde_2 y_tilde_435 Dem_wins
0 1 0.067848 0.239418 0.515399 0.093285 0.696838 0.696872 0.691548 255
1 2 0.067314 0.217457 0.554205 0.093199 0.654943 0.577956 0.571270 258
2 3 0.068343 0.252716 0.493964 0.097277 0.693813 0.568496 0.584509 262
3 4 0.071545 0.250206 0.503098 0.101749 0.786488 0.743149 0.718841 262
4 5 0.067992 0.231123 0.539080 0.094944 0.725125 0.629334 0.636047 262
sim sigma beta0 beta1 beta2 y_tilde_1 y_tilde_2 y_tilde_435 Dem_wins
7999 8000 0.065646 0.20504 0.586001 0.082499 0.704776 0.656344 0.629696 258
Summary:
sim sigma beta0 beta1 beta2 y_tilde_1 y_tilde_2 y_tilde_435 Dem_wins
mean 4000.50000 0.067366 0.237602 0.521709 0.096226 0.736238 0.665941 0.630627 259.98475
median 4000.50000 0.067279 0.237542 0.521681 0.096297 0.735521 0.666381 0.629967 260.00000
std 2309.54541 0.002281 0.017228 0.032812 0.006860 0.067754 0.067164 0.068355 2.48456
Combining simulation and analytic calculations¶

Continuing the previous example: suppose we want to know the chance that an election ends in a tie or is decided by one vote.

This means the Democratic candidate's vote count would land in a very narrow range around 50%. In a district with 100,000 voters, that window runs from 49,999 to 50,001 votes. If we run 1,000 simulations and count how many land in this window, we almost always get 0.

Different ways to handle this:

  1. Brute force: run simulations, potentially millions.
  2. Widen the window and scale down: count how many simulations fall in a wider window around 50%, then work out what proportion of those would land in the narrower window.
  3. Use an analytic calculation - use theory. Predictions from linear regression follow a known distribution (t-distribution). We can use the center and spread to calculate tail probabilities to get the probability of landing in the narrow window around 50%.

Takeaway: when trying to predict a rare event, it can be more efficient to use an analytic calculation rather than brute force simulation.
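The contrast between approaches 1 and 3 can be sketched as follows. The predictive mean and standard deviation are hypothetical values, and a normal distribution stands in for the t-distribution as a simplifying assumption (the two are close when the degrees of freedom are large).

```python
import math
import numpy as np

def normal_cdf(x, mu, sd):
    """Cumulative distribution function of normal(mu, sd)."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sd * math.sqrt(2.0))))

def normal_pdf(x, mu, sd):
    """Density of normal(mu, sd)."""
    return math.exp(-0.5 * ((x - mu) / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))

# Hypothetical predictive distribution for one district's Democratic vote share
mu, sd = 0.52, 0.067
n_voters = 100_000

# Narrow window: within one vote of an exact tie
lo, hi = 0.5 - 1 / n_voters, 0.5 + 1 / n_voters

# Analytic calculation: probability mass inside the window
p_analytic = normal_cdf(hi, mu, sd) - normal_cdf(lo, mu, sd)

# Equivalent shortcut: density at 0.5 times the window width
p_density = normal_pdf(0.5, mu, sd) * (hi - lo)

# Brute force with 1,000 simulations: usually finds no hits at all
rng = np.random.default_rng(0)
draws = rng.normal(mu, sd, size=1000)
p_sim = np.mean((draws > lo) & (draws < hi))

print(f"analytic: {p_analytic:.2e}")
print(f"density x width: {p_density:.2e}")
print(f"simulation (1,000 draws): {p_sim:.2e}")
```

The analytic probability is on the order of one in ten thousand, which is why a thousand simulation draws are far too few to estimate it by counting.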

10.7 Mathematical notation and statistical inference¶

Predictors¶

Predictors are the columns in the X matrix. Predictor as a term can sometimes refer to the column in the X matrix, and sometimes refer to the variable (concept/idea) that the column represents. Sometimes the constant term (intercept) is also considered a predictor, and sometimes it is not.

Regression in vector-matrix notation¶

The outcome for the $i^{th}$ observation is denoted as $y_i$.

The deterministic prediction is $X_i \beta = \beta_{1}X_{i1} + \beta_{2}X_{i2} + ... + \beta_{k}X_{ik}$, indexing people as $i=1,...,n$.

In the child's test score example, $y_i$ is the test score for child $i$, and there are $k=4$ items in the vector $X_{i}$ (intercept, mom_hs, mom_iq, interaction_mom_hs_mom_iq). $X_{i1}$ is the constant term (always 1), $X_{i2}$ is mom_hs, $X_{i3}$ is mom_iq, and $X_{i4}$ is the interaction term. $\beta$ is the vector of coefficients corresponding to these predictors. It has 4 items: $\beta_1$ is the intercept, $\beta_2$ is the coefficient for mom_hs, $\beta_3$ is the coefficient for mom_iq, and $\beta_4$ is the coefficient for the interaction term.
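To make the notation concrete, here is a minimal sketch that builds $X$ for three children and computes $X\beta$. The predictor values and the coefficients are illustrative numbers only, not fitted estimates.

```python
import numpy as np

# Hypothetical predictor values for three children
mom_hs = np.array([1, 0, 1])
mom_iq = np.array([121.0, 89.0, 100.0])

# Columns: constant term, mom_hs, mom_iq, interaction
X = np.column_stack([np.ones(3), mom_hs, mom_iq, mom_hs * mom_iq])

# Illustrative coefficient vector (beta_1, ..., beta_4), not a fitted model
beta = np.array([-10.0, 50.0, 1.0, -0.5])

# Deterministic prediction X_i beta for every row at once
y_pred = X @ beta

print(X.shape)
print(y_pred)
```

Each row of `X` is one $X_i$, and the matrix product evaluates the sum $\beta_1 X_{i1} + ... + \beta_4 X_{i4}$ for all children together.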

When we fit a model to data, there is the true underlying model, which we cannot observe, and then there is the model we estimate from the data, which is an approximation of the true model. The distances between the data points and the true model are errors, which we denote as $\epsilon_i$. The distances between the data points and the estimated model are residuals, which we denote as $e_i$.

Errors are assumed to follow a normal distribution with mean 0 and standard deviation $\sigma$, which we write as $normal(0, \sigma)$.

Variables can be called predictors/independent variables and outcomes/dependent variables.

Two ways of writing the model¶

Classical linear regression model:

$$y_i = \beta_1 X_{i1} + \beta_2 X_{i2} + ... + \beta_k X_{ik} + \epsilon_i$$ for $i=1,...,n$.

Where errors $\epsilon_i$ have independent normal distributions with mean 0 and standard deviation $\sigma$.

An equivalent way to write this is in vector-matrix notation:

$y_{i} = X_i \beta + \epsilon_i$ for $i=1,...,n$.

Where X is an n by k matrix with $i^{th}$ row $X_i$

Multivariate notation:

$y_{i} \sim \text{normal}(X_i \beta, \sigma)$ for $i=1,...,n$.

Even more compactly, we can write this as:

$y \sim \text{multivariate\_normal}(X \beta, \sigma^2 I)$

Where $y$ is a vector of length $n$, $X$ is an $n \times k$ matrix of predictors, $\beta$ is a vector of length $k$, and $I$ is the $n \times n$ identity matrix.

Least squares, maximum likelihood, and Bayesian inference¶

Estimation and inference for linear regression with multiple predictors work the same way as with one predictor. We start with the least squares estimate, the vector $\hat{\beta}$ that minimizes the residual sum of squares, $RSS = \sum_{i=1}^n (y_i - X_i \hat{\beta})^2$. In the standard linear regression model with accurately measured predictors and independent, normally distributed errors, the least squares estimate is also the maximum likelihood estimate. The standard estimate of the residual standard deviation is $\hat{\sigma} = \sqrt{\frac{1}{n-k} \sum_{i=1}^n (y_i - X_i \hat{\beta})^2}$, where $k$ is the number of predictors including the intercept. Dividing by $n-k$ rather than $n$ adjusts for the $k$ coefficients estimated from the data.
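These formulas translate directly into a few lines of NumPy. The sketch below runs on simulated data; the true coefficients and noise level are arbitrary assumptions. The standard errors use the usual least squares formula $\hat{\sigma}\sqrt{\text{diag}((X^T X)^{-1})}$.

```python
import numpy as np

rng = np.random.default_rng(3)
n, k = 500, 3

# Simulated design matrix: constant term plus two predictors
X = np.column_stack([np.ones(n), rng.normal(size=(n, k - 1))])
beta_true = np.array([1.0, 2.0, -0.5])   # assumed for the simulation
sigma_true = 1.5
y = X @ beta_true + rng.normal(0, sigma_true, n)

# Least squares estimate: solve the normal equations (X'X) beta = X'y
beta_hat = np.linalg.solve(X.T @ X, X.T @ y)

# Residual sum of squares and residual standard deviation
resid = y - X @ beta_hat
rss = np.sum(resid ** 2)
sigma_hat = np.sqrt(rss / (n - k))

# Standard errors of the coefficients
se = sigma_hat * np.sqrt(np.diag(np.linalg.inv(X.T @ X)))

print("beta_hat: ", np.round(beta_hat, 3))
print("se:       ", np.round(se, 3))
print("sigma_hat:", round(float(sigma_hat), 3))
```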

Nonidentified parameters, collinearity, and the likelihood function¶

Nonidentifiability: if we imagine the likelihood function as a mountain, the peak of the mountain is the maximum likelihood estimate. If two or more parameters can be changed together in a way that does not change the value of the likelihood function, then the likelihood has a ridge rather than a single peak. Multiple combinations of parameters then yield the same maximum likelihood value, so there is no unique best estimate. The parameters along the ridge cannot be identified from the data, and their standard errors blow up to infinity.

Collinearity: the most common cause of nonidentifiability is collinearity, which is when one predictor is an exact linear combination of the others, so it adds no new information. For example, suppose the predictors are number of boys, number of girls, and total children. Since total children = boys + girls, the third predictor is completely redundant. The model can't tell whether the effect "belongs to" boys, girls, or the total.

Near collinearity: when one predictor is almost a combination of the others, we have near collinearity. This can lead to very large standard errors for the coefficients of the predictors involved in the near collinearity, making it difficult to determine which predictor is actually associated with the outcome variable. For example, if we try to estimate height from right foot length and left foot length, these two predictors are highly correlated and almost perfectly collinear. This can lead to large standard errors for the coefficients of right foot length and left foot length, making it difficult to determine which predictor is actually associated with height.
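Both situations can be diagnosed from the design matrix. The following sketch uses simulated counts and foot lengths (all values are made up) to show a rank-deficient matrix in the exact case and a large condition number in the near-collinear case.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 100

# Exact collinearity: total = boys + girls
boys = rng.integers(0, 4, n).astype(float)
girls = rng.integers(0, 4, n).astype(float)
total = boys + girls
X = np.column_stack([np.ones(n), boys, girls, total])

# Four columns but only three independent directions
rank = np.linalg.matrix_rank(X)

# The ridge: two different coefficient vectors give identical predictions
beta_a = np.array([0.0, 1.0, 1.0, 0.0])   # effect assigned to boys and girls
beta_b = np.array([0.0, 0.0, 0.0, 1.0])   # same effect assigned to total
same_fit = np.allclose(X @ beta_a, X @ beta_b)

# Near collinearity: left and right foot lengths are almost identical
left = rng.normal(25, 1.5, n)
right = left + rng.normal(0, 0.05, n)
cond_one = np.linalg.cond(np.column_stack([np.ones(n), left]))
cond_near = np.linalg.cond(np.column_stack([np.ones(n), left, right]))

print("rank of X:", rank, "of", X.shape[1], "columns")
print("identical predictions:", same_fit)
print(f"condition number, one foot:   {cond_one:.0f}")
print(f"condition number, both feet:  {cond_near:.0f}")
```

A much larger condition number after adding the second foot signals that the coefficients will be poorly determined, even though the matrix is technically full rank.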

Hypothesis testing: why we do not like t tests and F tests¶

A t-test asks: "Is this one regression coefficient different from 0?" Take the coefficient, divide by its standard error, and compare the result to a t-distribution; if the value is large enough, the coefficient is declared statistically significant.

An F-test asks: "Are all of the regression coefficients (except the intercept) equal to 0 or are at least some of them different from 0?"

Authors have a philosophical objection to these tests. They are not interested in testing whether a coefficient is exactly 0, which is almost never the case in real data. Some predictor will almost always have some association with the outcome variable, even if it is very small. So, testing this null hypothesis is not very informative.

Instead, their preferred approach:

  1. Look at standard errors alongside estimates. If standard error is huge compared to the estimate, then we have a lot of uncertainty about the estimate and it is not very informative.
  2. Use Bayesian inference with noisy estimates. This allows us to incorporate prior information and get a more nuanced understanding of the uncertainty around our estimates, rather than just a binary significant/not significant result from a t-test or F-test.
  3. Use cross validation if real question is 'does dropping this predictor hurt predictions?'. Test models against held-out data to see if dropping a predictor leads to worse predictions.

10.8 Weighted regression¶

Least squares regression picks the line that minimizes the sum of squared residuals. Maximum likelihood picks the line that makes the observed data most probable. The two are equivalent when errors are independent (one point being off doesn't affect the others) and normally distributed with equal variance (the normal curve has the same width everywhere). In the sum of squared residuals, each term gets the same weight.

But we might want to give different weights to different observations. Weighted least squares regression does this: the estimate $\hat{\beta}_{\text{wls}}$ minimizes the weighted sum of squared residuals, $\sum_{i=1}^n w_i (y_i - X_i \beta)^2$, for some specified weights $w = (w_1, w_2, ..., w_n)$. This is equivalent to maximum likelihood when errors are independent and normally distributed with mean 0 and variance $\sigma^2 / w_i$ for the $i^{th}$ observation. Points with higher weights count more, so the line is pulled closer to those points.

Matrix algebra:

$\hat{\beta}_{\text{wls}} = (X^T W X)^{-1} X^T W y$

Where $W$ is the diagonal matrix of weights, $W = \text{Diag}(w)$.
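The matrix formula can be evaluated directly. This is a sketch on simulated data with arbitrary weights; it also checks the result against ordinary least squares on rows rescaled by $\sqrt{w_i}$, which minimizes the same weighted sum of squares.

```python
import numpy as np

rng = np.random.default_rng(5)
n = 200

x = rng.uniform(0, 10, n)
X = np.column_stack([np.ones(n), x])

# Hypothetical weights; errors simulated with variance sigma^2 / w_i (sigma = 1)
w = rng.uniform(0.5, 2.0, n)
y = 1.0 + 0.5 * x + rng.normal(0, 1.0 / np.sqrt(w))

# Weighted least squares: (X' W X)^{-1} X' W y with W = Diag(w)
W = np.diag(w)
beta_wls = np.linalg.solve(X.T @ W @ X, X.T @ W @ y)

# Check: ordinary least squares after scaling each row by sqrt(w_i)
sw = np.sqrt(w)
beta_check, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)

print("WLS via matrix formula:", np.round(beta_wls, 4))
print("OLS on rescaled rows:  ", np.round(beta_check, 4))
```

For large $n$, forming the full $n \times n$ matrix `W` is wasteful; multiplying rows by the weights, as in the check, avoids it.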

Three models leading to weighted regression¶

Weighted regression can be derived from three different models:

  1. Using observed data to represent a larger population. A weighted regression is fit to sample data in order to estimate the (unweighted) regression in the population. For example, if we have a sample that is not representative of the population, we can use weights to adjust for this and get estimates that are more reflective of the population.
  2. Duplicate observations. If you have 10 observations that are identical, you can represent this as a single observation with a weight of 10. This is equivalent to having 10 separate observations that are the same.
  3. Unequal variances. When fitting a line to data that is precise (small error) and data that is noisy (large error), we want the line to be closer to the precise data and less influenced by the noisy data. We can achieve this by giving more weight to the precise data and less weight to the noisy data in a weighted regression.

All three give the same point estimate, but different standard errors and different predictive distributions.

The most common scenario is adjusting a sample to represent a population. In that case, renormalise the weights so they average 1 and then pass them into the regression.
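Two of these claims are easy to verify numerically: integer weights give the same point estimate as physically duplicating rows, and renormalising the weights to average 1 leaves the point estimate unchanged. The data and counts below are invented for illustration, and `wls` is a small helper defined here, not a library function.

```python
import numpy as np

def wls(X, y, w):
    """Weighted least squares point estimate (X' W X)^{-1} X' W y."""
    XtW = X.T * w          # same as X.T @ np.diag(w), without the n x n matrix
    return np.linalg.solve(XtW @ X, XtW @ y)

rng = np.random.default_rng(6)
x = rng.uniform(0, 10, 6)
y = 2.0 + 0.3 * x + rng.normal(0, 1, 6)
X = np.column_stack([np.ones(6), x])

# Hypothetical number of identical copies of each observation
counts = np.array([1, 3, 2, 1, 5, 2])

# Weighted fit on the 6 unique rows
beta_weighted = wls(X, y, counts)

# Unweighted fit on the expanded data with duplicated rows
X_dup = np.repeat(X, counts, axis=0)
y_dup = np.repeat(y, counts)
beta_dup, *_ = np.linalg.lstsq(X_dup, y_dup, rcond=None)

# Renormalised weights (average 1) give the same point estimate
w_norm = counts / counts.mean()
beta_norm = wls(X, y, w_norm)

print(np.round(beta_weighted, 4), np.round(beta_dup, 4), np.round(beta_norm, 4))
```

Only the point estimates agree here; the appropriate standard errors depend on which of the three models generated the weights.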

Using a matrix of weights to account for correlated errors¶

Previous examples are for independent errors. If we have correlated errors, we can use a matrix of weights to account for this. For example, if we have repeated measurements on the same individuals, the errors for those measurements are likely to be correlated.

Correlated errors can show up:

  1. Time series data
  2. Spatial data
  3. Clustered data (e.g., students within schools, patients within hospitals)
  4. Repeated measures data (e.g., multiple measurements on the same individuals)

Ignoring correlated errors can lead to underestimated standard errors and inflated type I error rates.
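One standard way to use a weight matrix for correlated errors is generalized least squares, where the weight matrix is the inverse of the error covariance matrix $\Sigma$, giving $\hat{\beta} = (X^T \Sigma^{-1} X)^{-1} X^T \Sigma^{-1} y$. The sketch below assumes $\Sigma$ is known and has a time-series-like structure in which correlation decays with distance; both the structure and the value of `rho` are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 100
rho = 0.7  # hypothetical correlation between neighbouring observations

# Assumed error covariance: correlation rho^|i - j| between observations i and j
idx = np.arange(n)
Sigma = rho ** np.abs(idx[:, None] - idx[None, :])

# Simulate data with correlated errors via the Cholesky factor of Sigma
x = np.linspace(0, 1, n)
X = np.column_stack([np.ones(n), x])
L = np.linalg.cholesky(Sigma)
y = X @ np.array([1.0, 2.0]) + L @ rng.normal(size=n)

# Generalized least squares with weight matrix Sigma^{-1}
Sigma_inv = np.linalg.inv(Sigma)
beta_gls = np.linalg.solve(X.T @ Sigma_inv @ X, X.T @ Sigma_inv @ y)

# Check: ordinary least squares on "whitened" data L^{-1} X, L^{-1} y
X_w = np.linalg.solve(L, X)
y_w = np.linalg.solve(L, y)
beta_check, *_ = np.linalg.lstsq(X_w, y_w, rcond=None)

print("GLS estimate:     ", np.round(beta_gls, 4))
print("whitened OLS check:", np.round(beta_check, 4))
```

In practice $\Sigma$ is rarely known and has to be modelled or estimated, which is where dedicated time series, spatial, or multilevel models come in.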

10.9 Fitting the same model to many datasets¶

Take earnings vs height as an example. Instead of fitting a single regression model to the entire dataset, we can fit separate regression models for different subgroups (e.g., different years, countries, etc.). This allows us to see how the relationship between earnings and height differs across these subgroups.

Patterns can jump out: is effect stable, growing, shrinking, etc. across subgroups. This can give us insights into how the relationship between earnings and height may be changing over time or across different contexts.

More sophisticated models are not always better. More things to go wrong, more things to interpret, more things to get wrong. Sometimes simpler models can be more robust and easier to interpret, even if they don't capture all the nuances of the data. It's important to balance model complexity with interpretability and robustness.

  1. Complete pooling: fit a single model to the entire dataset, ignoring subgroup differences. This is the simplest approach, but it may miss important differences between subgroups.
  2. No pooling: fit separate models for each subgroup, ignoring any similarities between subgroups. This allows for maximum flexibility, but it can lead to overfitting and unstable estimates, especially for subgroups with small sample sizes.
  3. Partial pooling: fit a hierarchical model that allows for both subgroup-specific effects and overall effects. This approach can borrow strength across subgroups, leading to more stable estimates, especially for subgroups with small sample sizes, while still allowing for subgroup differences. More complex to interpret.

The "secret weapon" is the no pooling approach: fit the same model separately to each subgroup and compare the estimates side by side.
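As a minimal sketch of the idea, the loop below fits the same one-predictor regression to each year of a simulated dataset and collects the slope and its standard error. The data, the slowly drifting slope, and the helper `fit_slope` are all assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(8)
years = np.arange(1972, 2002, 2)

def fit_slope(x, y):
    """Least squares slope and its standard error for y = a + b*x."""
    X = np.column_stack([np.ones(len(x)), x])
    XtX_inv = np.linalg.inv(X.T @ X)
    beta = XtX_inv @ X.T @ y
    resid = y - X @ beta
    sigma_hat = np.sqrt(resid @ resid / (len(y) - 2))
    return beta[1], sigma_hat * np.sqrt(XtX_inv[1, 1])

results = []
for i, year in enumerate(years):
    # Simulated survey for this year: predictor on a 1-7 scale
    x = rng.integers(1, 8, 300).astype(float)
    true_slope = 0.5 + 0.01 * i              # hypothetical drift over time
    y = 1.5 + true_slope * x + rng.normal(0, 1.9, 300)
    slope, se = fit_slope(x, y)
    results.append((int(year), slope, se))

# One line per subgroup makes trends across years easy to scan
for year, slope, se in results:
    print(f"{year}: slope = {slope:.2f} ± {se:.2f}")
```

Plotting the estimates with ±1 standard error bars against year is the usual way to display the result.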

Predicting party identification¶

Illustrate the secret weapon.

In [39]:
# read csv
nes = pd.read_csv('../ros_data/nes.txt', sep=' ')
# display(nes.head())
# filter only years 1972 to 2000
nes = nes[(nes['year'] >= 1972) & (nes['year'] <= 2000)]
display(nes.head())
# count unique years
print("Unique years in NES data:", nes['year'].nunique())
# print unique years
print("Unique years in NES data:", nes['year'].unique())
# print unique columns
print("Columns in NES data:", nes.columns.tolist())
# print unique values in ideo
print("Unique values in 'ideo':", nes['ideo'].unique())
# print unique values in real_ideo
print("Unique values in 'real_ideo':", nes['real_ideo'].unique())
# print unique values in partyid7
print("Unique values in 'partyid7':", nes['partyid7'].unique())

# loop through years 2 at a time starting 1972 and ending with 2000

for loop_year in range(1972, 2002, 2):
    print(f"\n--- Year: {loop_year} ---")
    # filter data for this year
    data = nes[nes['year'] == loop_year]
    # drop rows with missing values in real_ideo or partyid7
    data = data.dropna(subset=['real_ideo', 'partyid7'])
    # fit and plot linear model predicting partyid7 from real_ideo
    fit_and_plot_lm(data, ['real_ideo'], 'partyid7', add_constant=True, show_plot=False, scatter_kws=None, line_kws=None)
year resid weight1 weight2 weight3 age gender race educ1 urban ... parent_party white year_new income_new age_new vote.1 age_discrete race_adj dvote rvote
13790 1972 1 1.0 1.0 1.0 75 2 1 2 3.0 ... NaN 1 10 -2 2.947545 NaN 4 1.0 NaN NaN
13791 1972 3 1.0 1.0 1.0 24 1 1 2 3.0 ... -2.0 1 10 0 -2.152455 1.0 1 1.0 NaN NaN
13792 1972 4 1.0 1.0 1.0 21 1 1 2 3.0 ... 2.0 1 10 0 -2.452455 1.0 1 1.0 0.0 1.0
13794 1972 6 1.0 1.0 1.0 30 2 1 2 3.0 ... 2.0 1 10 1 -1.552455 NaN 2 1.0 1.0 0.0
13795 1972 7 1.0 1.0 1.0 73 1 1 2 3.0 ... 2.0 1 10 1 2.747545 1.0 4 1.0 0.0 1.0

5 rows × 70 columns

Unique years in NES data: 15
Unique years in NES data: [1972 1974 1976 1978 1980 1982 1984 1986 1988 1990 1992 1994 1996 1998
 2000]
Columns in NES data: ['year', 'resid', 'weight1', 'weight2', 'weight3', 'age', 'gender', 'race', 'educ1', 'urban', 'region', 'income', 'occup1', 'union', 'religion', 'educ2', 'educ3', 'martial_status', 'occup2', 'icpsr_cty', 'fips_cty', 'partyid7', 'partyid3', 'partyid3_b', 'str_partyid', 'father_party', 'mother_party', 'dlikes', 'rlikes', 'dem_therm', 'rep_therm', 'regis', 'vote', 'regisvote', 'presvote', 'presvote_2party', 'presvote_intent', 'ideo_feel', 'ideo7', 'ideo', 'cd', 'state', 'inter_pre', 'inter_post', 'black', 'female', 'age_sq', 'rep_presvote', 'rep_pres_intent', 'south', 'real_ideo', 'presapprov', 'perfin1', 'perfin2', 'perfin', 'presadm', 'age_10', 'age_sq_10', 'newfathe', 'newmoth', 'parent_party', 'white', 'year_new', 'income_new', 'age_new', 'vote.1', 'age_discrete', 'race_adj', 'dvote', 'rvote']
Unique values in 'ideo': [nan  1.  5.  3.]
Unique values in 'real_ideo': [nan  5.  6.  4.  7.  2.  3.  1.]
Unique values in 'partyid7': [ 1.  5.  6.  7.  2.  3.  4. nan]

--- Year: 1972 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.113
Model:                            OLS   Adj. R-squared:                  0.112
Method:                 Least Squares   F-statistic:                     169.2
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           1.69e-36
Time:                        07:43:01   Log-Likelihood:                -2769.0
No. Observations:                1330   AIC:                             5542.
Df Residuals:                    1328   BIC:                             5552.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.5394      0.041     13.008      0.000       0.458       0.621
const          1.5578      0.180      8.664      0.000       1.205       1.911
==============================================================================
Omnibus:                      495.446   Durbin-Watson:                   2.027
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               69.067
Skew:                           0.107   Prob(JB):                     1.01e-15
Kurtosis:                       1.904   Cond. No.                         15.4
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.56 + 0.54*real_ideo
Residual std dev (σ): 1.94 ± 0.04
MAD of residuals: 2.28

--- Year: 1974 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.118
Model:                            OLS   Adj. R-squared:                  0.117
Method:                 Least Squares   F-statistic:                     142.4
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           6.75e-31
Time:                        07:43:01   Log-Likelihood:                -2196.6
No. Observations:                1065   AIC:                             4397.
Df Residuals:                    1063   BIC:                             4407.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.5144      0.043     11.933      0.000       0.430       0.599
const          1.5302      0.186      8.214      0.000       1.165       1.896
==============================================================================
Omnibus:                      265.630   Durbin-Watson:                   1.919
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               48.578
Skew:                           0.099   Prob(JB):                     2.83e-11
Kurtosis:                       1.973   Cond. No.                         14.5
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.53 + 0.51*real_ideo
Residual std dev (σ): 1.91 ± 0.04
MAD of residuals: 2.2

--- Year: 1976 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.185
Model:                            OLS   Adj. R-squared:                  0.184
Method:                 Least Squares   F-statistic:                     268.5
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           1.52e-54
Time:                        07:43:01   Log-Likelihood:                -2406.7
No. Observations:                1184   AIC:                             4817.
Df Residuals:                    1182   BIC:                             4827.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.6563      0.040     16.387      0.000       0.578       0.735
const          1.0079      0.179      5.624      0.000       0.656       1.360
==============================================================================
Omnibus:                      146.930   Durbin-Watson:                   1.966
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               39.049
Skew:                          -0.026   Prob(JB):                     3.32e-09
Kurtosis:                       2.112   Cond. No.                         15.6
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.01 + 0.66*real_ideo
Residual std dev (σ): 1.85 ± 0.04
MAD of residuals: 2.46

--- Year: 1978 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.124
Model:                            OLS   Adj. R-squared:                  0.124
Method:                 Least Squares   F-statistic:                     205.8
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           9.38e-44
Time:                        07:43:01   Log-Likelihood:                -2903.6
No. Observations:                1453   AIC:                             5811.
Df Residuals:                    1451   BIC:                             5822.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.4994      0.035     14.347      0.000       0.431       0.568
const          1.5222      0.152      9.997      0.000       1.223       1.821
==============================================================================
Omnibus:                      195.606   Durbin-Watson:                   1.967
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               51.180
Skew:                           0.091   Prob(JB):                     7.70e-12
Kurtosis:                       2.099   Cond. No.                         14.9
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.52 + 0.50*real_ideo
Residual std dev (σ): 1.79 ± 0.03
MAD of residuals: 2.22

--- Year: 1980 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.185
Model:                            OLS   Adj. R-squared:                  0.183
Method:                 Least Squares   F-statistic:                     158.3
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           7.21e-33
Time:                        07:43:01   Log-Likelihood:                -1439.7
No. Observations:                 701   AIC:                             2883.
Df Residuals:                     699   BIC:                             2892.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.6462      0.051     12.582      0.000       0.545       0.747
const          0.9090      0.237      3.842      0.000       0.444       1.374
==============================================================================
Omnibus:                      108.506   Durbin-Watson:                   1.931
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               26.029
Skew:                          -0.040   Prob(JB):                     2.23e-06
Kurtosis:                       2.059   Cond. No.                         15.9
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 0.91 + 0.65*real_ideo
Residual std dev (σ): 1.89 ± 0.05
MAD of residuals: 2.01

--- Year: 1982 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.190
Model:                            OLS   Adj. R-squared:                  0.189
Method:                 Least Squares   F-statistic:                     190.3
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           4.85e-39
Time:                        07:43:01   Log-Likelihood:                -1670.1
No. Observations:                 814   AIC:                             3344.
Df Residuals:                     812   BIC:                             3354.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.6796      0.049     13.794      0.000       0.583       0.776
const          0.6245      0.223      2.800      0.005       0.187       1.062
==============================================================================
Omnibus:                      128.028   Durbin-Watson:                   2.040
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               30.499
Skew:                           0.058   Prob(JB):                     2.38e-07
Kurtosis:                       2.059   Cond. No.                         16.0
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 0.62 + 0.68*real_ideo
Residual std dev (σ): 1.89 ± 0.05
MAD of residuals: 2.4

--- Year: 1984 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.175
Model:                            OLS   Adj. R-squared:                  0.174
Method:                 Least Squares   F-statistic:                     259.0
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           5.33e-53
Time:                        07:43:01   Log-Likelihood:                -2565.0
No. Observations:                1226   AIC:                             5134.
Df Residuals:                    1224   BIC:                             5144.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.6611      0.041     16.092      0.000       0.581       0.742
const          1.0822      0.183      5.920      0.000       0.724       1.441
==============================================================================
Omnibus:                       98.065   Durbin-Watson:                   1.914
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               32.705
Skew:                          -0.049   Prob(JB):                     7.91e-08
Kurtosis:                       2.206   Cond. No.                         15.2
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.08 + 0.66*real_ideo
Residual std dev (σ): 1.96 ± 0.04
MAD of residuals: 2.46

--- Year: 1986 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.138
Model:                            OLS   Adj. R-squared:                  0.137
Method:                 Least Squares   F-statistic:                     235.0
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           2.46e-49
Time:                        07:43:01   Log-Likelihood:                -3054.5
No. Observations:                1476   AIC:                             6113.
Df Residuals:                    1474   BIC:                             6124.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.6085      0.040     15.331      0.000       0.531       0.686
const          1.1709      0.176      6.656      0.000       0.826       1.516
==============================================================================
Omnibus:                      291.207   Durbin-Watson:                   1.885
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               59.949
Skew:                           0.061   Prob(JB):                     9.60e-14
Kurtosis:                       2.020   Cond. No.                         16.4
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.17 + 0.61*real_ideo
Residual std dev (σ): 1.92 ± 0.04
MAD of residuals: 2.38

--- Year: 1988 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.189
Model:                            OLS   Adj. R-squared:                  0.189
Method:                 Least Squares   F-statistic:                     259.4
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           1.28e-52
Time:                        07:43:01   Log-Likelihood:                -2328.7
No. Observations:                1113   AIC:                             4661.
Df Residuals:                    1111   BIC:                             4671.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.6801      0.042     16.107      0.000       0.597       0.763
const          1.0416      0.194      5.367      0.000       0.661       1.422
==============================================================================
Omnibus:                      115.109   Durbin-Watson:                   1.949
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               39.222
Skew:                          -0.189   Prob(JB):                     3.04e-09
Kurtosis:                       2.162   Cond. No.                         15.8
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.04 + 0.68*real_ideo
Residual std dev (σ): 1.96 ± 0.04
MAD of residuals: 2.43

--- Year: 1990 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.123
Model:                            OLS   Adj. R-squared:                  0.122
Method:                 Least Squares   F-statistic:                     165.5
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           1.53e-35
Time:                        07:43:01   Log-Likelihood:                -2477.3
No. Observations:                1180   AIC:                             4959.
Df Residuals:                    1178   BIC:                             4969.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.5568      0.043     12.866      0.000       0.472       0.642
const          1.3893      0.189      7.332      0.000       1.018       1.761
==============================================================================
Omnibus:                      381.467   Durbin-Watson:                   1.889
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               57.041
Skew:                           0.042   Prob(JB):                     4.11e-13
Kurtosis:                       1.926   Cond. No.                         15.1
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 1.39 + 0.56*real_ideo
Residual std dev (σ): 1.98 ± 0.04
MAD of residuals: 2.31

--- Year: 1992 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.234
Model:                            OLS   Adj. R-squared:                  0.233
Method:                 Least Squares   F-statistic:                     411.8
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           4.11e-80
Time:                        07:43:01   Log-Likelihood:                -2774.0
No. Observations:                1350   AIC:                             5552.
Df Residuals:                    1348   BIC:                             5562.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.7271      0.036     20.294      0.000       0.657       0.797
const          0.7228      0.158      4.561      0.000       0.412       1.034
==============================================================================
Omnibus:                       76.752   Durbin-Watson:                   2.018
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               31.298
Skew:                          -0.106   Prob(JB):                     1.60e-07
Kurtosis:                       2.285   Cond. No.                         14.3
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 0.72 + 0.73*real_ideo
Residual std dev (σ): 1.89 ± 0.04
MAD of residuals: 2.29

--- Year: 1994 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.265
Model:                            OLS   Adj. R-squared:                  0.265
Method:                 Least Squares   F-statistic:                     447.2
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           5.54e-85
Time:                        07:43:01   Log-Likelihood:                -2510.5
No. Observations:                1240   AIC:                             5025.
Df Residuals:                    1238   BIC:                             5035.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.8022      0.038     21.146      0.000       0.728       0.877
const          0.5221      0.177      2.955      0.003       0.175       0.869
==============================================================================
Omnibus:                       28.987   Durbin-Watson:                   2.065
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               20.965
Skew:                          -0.213   Prob(JB):                     2.80e-05
Kurtosis:                       2.526   Cond. No.                         16.5
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 0.52 + 0.80*real_ideo
Residual std dev (σ): 1.83 ± 0.04
MAD of residuals: 2.09

--- Year: 1996 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.362
Model:                            OLS   Adj. R-squared:                  0.361
Method:                 Least Squares   F-statistic:                     590.4
Date:                Tue, 14 Apr 2026   Prob (F-statistic):          1.14e-103
Time:                        07:43:01   Log-Likelihood:                -2080.9
No. Observations:                1043   AIC:                             4166.
Df Residuals:                    1041   BIC:                             4176.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.9503      0.039     24.299      0.000       0.874       1.027
const         -0.2857      0.179     -1.599      0.110      -0.636       0.065
==============================================================================
Omnibus:                       12.887   Durbin-Watson:                   1.968
Prob(Omnibus):                  0.002   Jarque-Bera (JB):               12.555
Skew:                          -0.239   Prob(JB):                      0.00188
Kurtosis:                       2.753   Cond. No.                         15.4
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = -0.29 + 0.95*real_ideo
Residual std dev (σ): 1.78 ± 0.04
MAD of residuals: 1.7

--- Year: 1998 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.207
Model:                            OLS   Adj. R-squared:                  0.206
Method:                 Least Squares   F-statistic:                     244.9
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           3.20e-49
Time:                        07:43:01   Log-Likelihood:                -1911.4
No. Observations:                 939   AIC:                             3827.
Df Residuals:                     937   BIC:                             3837.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.7072      0.045     15.651      0.000       0.618       0.796
const          0.7717      0.201      3.845      0.000       0.378       1.166
==============================================================================
Omnibus:                       45.150   Durbin-Watson:                   1.856
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               20.393
Skew:                          -0.123   Prob(JB):                     3.73e-05
Kurtosis:                       2.322   Cond. No.                         15.4
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 0.77 + 0.71*real_ideo
Residual std dev (σ): 1.85 ± 0.04
MAD of residuals: 2.35

--- Year: 2000 ---
                            OLS Regression Results                            
==============================================================================
Dep. Variable:               partyid7   R-squared:                       0.272
Model:                            OLS   Adj. R-squared:                  0.270
Method:                 Least Squares   F-statistic:                     176.9
Date:                Tue, 14 Apr 2026   Prob (F-statistic):           1.60e-34
Time:                        07:43:01   Log-Likelihood:                -967.69
No. Observations:                 476   AIC:                             1939.
Df Residuals:                     474   BIC:                             1948.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
real_ideo      0.8048      0.061     13.300      0.000       0.686       0.924
const          0.3502      0.274      1.276      0.203      -0.189       0.889
==============================================================================
Omnibus:                       11.681   Durbin-Watson:                   1.918
Prob(Omnibus):                  0.003   Jarque-Bera (JB):                6.809
Skew:                          -0.100   Prob(JB):                       0.0332
Kurtosis:                       2.449   Cond. No.                         15.3
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: partyid7 = 0.35 + 0.80*real_ideo
Residual std dev (σ): 1.85 ± 0.06
MAD of residuals: 2.37
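
The per-year summaries above are easier to compare when the `real_ideo` slope and its standard error are gathered into a single table. The following is a minimal sketch, assuming the values are transcribed by hand from the printed output for 1976 to 2000. The names `slopes`, `summarize`, and `n_se` are illustrative and do not come from the notebook.

```python
# Slope and standard error of real_ideo, transcribed by hand from the
# per-year OLS summaries printed above (1976-2000).
slopes = [
    (1976, 0.6563, 0.040),
    (1978, 0.4994, 0.035),
    (1980, 0.6462, 0.051),
    (1982, 0.6796, 0.049),
    (1984, 0.6611, 0.041),
    (1986, 0.6085, 0.040),
    (1988, 0.6801, 0.042),
    (1990, 0.5568, 0.043),
    (1992, 0.7271, 0.036),
    (1994, 0.8022, 0.038),
    (1996, 0.9503, 0.039),
    (1998, 0.7072, 0.045),
    (2000, 0.8048, 0.061),
]


def summarize(rows, n_se=2):
    """Return (year, slope, lower, upper) using slope +/- n_se standard errors."""
    return [(year, b, b - n_se * se, b + n_se * se) for year, b, se in rows]


table = summarize(slopes)

# Print one compact row per year instead of a full regression summary
print(f"{'year':>4}  {'slope':>6}  {'-2 SE':>6}  {'+2 SE':>6}")
for year, b, lower, upper in table:
    print(f"{year:>4}  {b:6.3f}  {lower:6.3f}  {upper:6.3f}")

# Years with the smallest and largest estimated slope
min_year = min(slopes, key=lambda r: r[1])[0]
max_year = max(slopes, key=lambda r: r[1])[0]
print(f"smallest slope: {min_year}, largest slope: {max_year}")
```

In these estimates the slope is smallest in 1978 and largest in 1996, and the values from 1992 onward are generally higher than those from the 1970s and 1980s. In a live session the same table could be built inside the fitting loop from each fitted model's `params` and `bse` attributes rather than by transcription.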

10.10 Bibliographic note

10.11 Exercises