Cross-validation — a fourth way to compute a Bayes factor

· 2024/03/23 · 8 minute read

In this post we’ll explore a particular link between Bayes factors and cross-validation I was introduced to via Fong & Holmes 2020. I’ll then argue why this is a reason to not trust Bayes factors too much. This is a followup to Three ways to compute a Bayes factor, though I will repeat all the important bits here.

Note on notation: I tried to be consistent and use plain symbols (\(y_1, z, ...\)) for variables, bold symbols (\(\mathbf{y}\)) for vectors and matrices, \(P(A)\) for the probability of event \(A\) and \(p(y)\) for the density of a random variable \(y\).

Example models

To make things specific, we will use very simple models as examples (the same ones as in the Three ways post). Our first model, \(\mathcal{M}_1\), assumes that the \(K\) data points are independent draws from a standard normal distribution, i.e.:

\[ \mathcal{M}_1 : \mathbf{y} = \{y_1, ... , y_K\} \\ y_i \sim N(0,1) \]

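Our second model, \(\mathcal{M}_2\), assumes that the mean of the normal distribution is a free parameter with a normal prior (the second argument of \(N\) is the standard deviation, matching prior_sd <- 2 in the code later in this post), i.e.: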

\[ \mathcal{M}_2: \mathbf{y} = \{y_1, ... , y_K\} \\ y_i \sim N(\alpha, 1) \\ \alpha \sim N(0,2) \]

Now, let’s take a simple vector of values to evaluate the models against:

y <- c(0.5,0.7, -0.4, 0.1)

Bayes factor and evidence

One way to define the Bayes factor is as the ratio of the evidence for the two models, i.e.:

\[ BF_{12} = \frac{P(\mathbf{y} | \mathcal{M}_1)}{P(\mathbf{y} | \mathcal{M}_2)} \]

where “evidence” is exactly the prior predictive density of the data, i.e. the density of the data after integrating out all the parameters.

Our models are simple enough that we can evaluate the evidence analytically — the Three ways post has the math.

For the given dataset we thus obtain \(\log P(\mathbf{y} | \mathcal{M}_1) \simeq -4.131\), \(\log P(\mathbf{y} | \mathcal{M}_2) \simeq -5.452\) and \(BF_{12} \simeq 3.748\).
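As a quick check of these numbers, here is a minimal sketch. It does not reproduce the derivation from the Three ways post. Instead it assumes the standard fact that, with \(\alpha\) integrated out, \(\mathbf{y}\) under \(\mathcal{M}_2\) is jointly normal with mean zero and covariance \(I + 2^2 \mathbf{1}\mathbf{1}^T\) (the prior standard deviation of 2 is taken from the code later in this post). All variable names are illustrative.

```r
# Data from the post
y <- c(0.5, 0.7, -0.4, 0.1)
K <- length(y)

# M1: independent standard normal draws, no parameters to integrate out
log_evidence_m1 <- sum(dnorm(y, mean = 0, sd = 1, log = TRUE))

# M2: with alpha ~ N(0, sd = 2) integrated out, y is multivariate normal
# with mean 0 and covariance I + prior_sd^2 * (matrix of ones)
prior_sd <- 2
Sigma <- diag(K) + prior_sd^2 * matrix(1, K, K)
log_det <- as.numeric(determinant(Sigma, logarithm = TRUE)$modulus)
quad_form <- sum(y * solve(Sigma, y))
log_evidence_m2 <- -0.5 * (K * log(2 * pi) + log_det + quad_form)

# Bayes factor as the ratio of evidences, computed on the log scale
log_bf_12 <- log_evidence_m1 - log_evidence_m2
round(c(m1 = log_evidence_m1, m2 = log_evidence_m2, bf_12 = exp(log_bf_12)), 3)
```

The three rounded values should agree with the ones quoted above.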

Cross-validation and evidence

One of the main results of Fong & Holmes 2020 is that evidence is related to cross-validation (Proposition 2 in the paper) [1]. For this to hold, we need to score cross-validation using the log posterior predictive density:

\[ s_\mathcal{M}(\tilde{y} \mid \mathbf{y}) = \log \int f_\mathcal{M}(\tilde{y}, \theta) \,{\rm d} p(\theta \mid \mathbf{y}, \mathcal{M}) = \log \int_\Theta f_\mathcal{M}(\tilde{y}, \theta) p(\theta \mid \mathbf{y}, \mathcal{M}) \,{\rm d} \theta \]

where \(\mathbf{\theta} \in \Theta\) is the vector of all parameters of the model and \(f_\mathcal{M}(y, \theta)\) is the likelihood of model \(\mathcal{M}\) evaluated for data point \(y\) and parameters \(\theta\). Note that this is the same score as used e.g. in the loo package for Bayesian cross-validation.

We can then define an exhaustive leave-\(J\)-out cross-validation of model \(\mathcal{M}\) with data \(\mathbf{y} = (y_1, ... , y_K)\) as the average of the log predictive densities over all possible held-out datasets of size \(J\):

\[ S^{\rm CV}_\mathcal{M} (\mathbf{y} ; J) = \frac{1}{{K \choose J}} \sum_{t=1}^{{K \choose J }} \frac{1}{J} \sum_{j=1}^{J} s_\mathcal{M}\bigl(\tilde{y}_{j}^{(t)} \;\big|\; y^{(t)}_{1:K-J}\bigr) \]

where \(\tilde{y}^{(t)}_j\) is the \(j\)-th element of the \(t\)-th combination of \(J\) elements out of \(K\) (the held-out set) and \(y^{(t)}_{1:K-J}\) are the remaining \(K - J\) elements, i.e. the complement of this combination. Finally, we express the logarithm of evidence as the sum of the cross-validation scores over all possible held-out dataset sizes:

\[ \log P(\mathbf{y} | \mathcal{M}) = \sum_{J=1}^{K} S^{\rm CV}_\mathcal{M} (\mathbf{y} ; J) \]
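A brief sketch of why this identity holds (my paraphrase of the usual argument, not a quote of the proof in the paper): by the chain rule, the evidence factorizes into one-step-ahead predictive densities for any ordering \(\pi\) of the data points,

\[ \log P(\mathbf{y} | \mathcal{M}) = \sum_{k=1}^{K} s_\mathcal{M}\bigl(y_{\pi(k)} \;\big|\; y_{\pi(1)}, ... , y_{\pi(k-1)}\bigr) \]

Since the left-hand side does not depend on \(\pi\), we can average the right-hand side over all \(K!\) orderings. In the \(k\)-th term, the conditioning set is then a uniformly chosen subset of size \(k - 1\) and the predicted point is a uniformly chosen element of the remaining \(K - k + 1\) points, which is exactly the leave-\(J\)-out score with \(J = K - k + 1\):

\[ \frac{1}{K!} \sum_{\pi} s_\mathcal{M}\bigl(y_{\pi(k)} \;\big|\; y_{\pi(1)}, ... , y_{\pi(k-1)}\bigr) = S^{\rm CV}_\mathcal{M} (\mathbf{y} ; K - k + 1) \]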

Note that the formula above holds regardless of the specific way we choose to partition \(\mathbf{y}\) into individual “data points”. At one extreme, we can treat all the data as a single indivisible element: we then have \(K = 1\) and recover the formula for evidence as the prior predictive density. We can partition by individual numerical values, but we can also partition by e.g. patients.

In all cases, we take the joint likelihood \(f_\mathcal{M}(\tilde{y}, \theta)\) to compute the expected log predictive density for each element of the partition. But for each cross-validation fold we then take the average of those densities. So a finer partition will do “more averaging” and treat small subsets of data as independent, while a coarser partition will consider the joint dependencies in each element of the partition, and then do “less averaging”.

Finally, the above formula is in most cases ridiculously impractical for actual computation and is therefore primarily of theoretical interest.
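To get a feel for how impractical, here is a small sketch counting how many distinct held-out sets (and thus posteriors) the exhaustive scheme needs. The function name n_posteriors is my own.

```r
# Number of held-out sets in the exhaustive scheme:
# one posterior per held-out combination, for every J from 1 to K
n_posteriors <- function(K) sum(choose(K, seq_len(K)))

n_posteriors(4)   # the toy dataset used in this post
n_posteriors(30)  # still a tiny dataset by most standards
```

The count is \(2^K - 1\), i.e. 15 for the toy dataset but already over a billion for 30 data points.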

Computing the example

Let’s explore how the formula works in code. We start with the \(\mathcal{M}_2\) model (intercept) as that’s more interesting. We will closely follow the formulae. Note that the posterior density \(p(\alpha | \mathbf{y})\) is available analytically and is normal (see the wiki page for derivation).

Since the posterior is normal and the observation model is normal, the posterior predictive density is also normal. The posterior predictive mean is exactly the posterior mean and posterior predictive variance is equal to the sum of observational and posterior variances.
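The closed form can be checked against the integral in the definition of \(s_\mathcal{M}\). The following is a minimal sketch using numerical integration. The split into three observed points and one held-out point is an arbitrary choice for illustration, and the variable names are my own.

```r
# Arbitrary split of the post's data: condition on three points, predict one
observed <- c(0.5, 0.7, -0.4)
y_tilde <- 0.1

# Conjugate normal posterior for alpha (prior N(0, sd = 2), observation sd = 1)
post_precision <- 1 / 2^2 + length(observed) / 1^2
post_mean <- sum(observed) / post_precision
post_sd <- sqrt(1 / post_precision)

# Score as defined by the integral: likelihood averaged over the posterior
integrand <- function(alpha) {
  dnorm(y_tilde, mean = alpha, sd = 1) *
    dnorm(alpha, mean = post_mean, sd = post_sd)
}
s_numeric <- log(integrate(integrand, lower = -Inf, upper = Inf)$value)

# Closed form: normal with the posterior mean and the summed variances
s_closed <- dnorm(y_tilde, mean = post_mean,
                  sd = sqrt(post_sd^2 + 1^2), log = TRUE)

c(numeric = s_numeric, closed_form = s_closed)
```

The two values should agree up to the tolerance of the numerical integration.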

Putting it all together we compute \(\frac{1}{J} \sum_{j=1}^{J} s\bigl(\tilde{y}_{j}^{(t)} \;\big|\; y^{(t)}_{1:K-J}\bigr)\) in the cv_score_m2_single function:

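cv_score_m2_single <- function(observed, held_out) {
  # Prior alpha ~ N(0, sd = 2) and observation noise sd = 1, as in M2
  prior_mean <- 0
  prior_sd <- 2
  obs_sd <- 1
  # Note: K here is the number of observed (not held-out) points
  K <- length(observed)
  if(K > 0) {
    # Conjugate normal update for the posterior of alpha
    prior_precision <- prior_sd ^ -2
    obs_precision <- obs_sd ^ -2
    obs_mean <- mean(observed)
    post_precision <- prior_precision + K * obs_precision
    post_sd <- sqrt(1/post_precision)
    post_mean <- (K * obs_precision * obs_mean + prior_precision * prior_mean) / 
      post_precision
  } else {
    # Nothing observed: the posterior is just the prior
    post_mean <- prior_mean
    post_sd <- prior_sd
  }
  # Posterior predictive variance = posterior variance + observation variance
  posterior_pred_sd <- sqrt(post_sd^2 + obs_sd^2)
  # Average log predictive density over the held-out points
  log_score <- sum(dnorm(
    held_out, mean = post_mean, sd = posterior_pred_sd, log = TRUE))
  return(log_score / length(held_out))
}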

Now cv_score_m2 loops over all possible combinations of size \(J\) and log_evidence_m2_cv adds it all together:

cv_score_m2 <- function(y, J) {
  K <- length(y)
  # Each column holds the indices of one held-out set of size J
  combinations <- combn(seq_len(K), J)
  res_unscaled <- 0
  for(t in seq_len(ncol(combinations))) {
    held_out <- y[combinations[,t]]
    observed <- y[setdiff(seq_len(K), combinations[,t])]
    res_unscaled <- res_unscaled + cv_score_m2_single(observed, held_out)
  }
  # Average over all held-out sets of size J
  return(res_unscaled / ncol(combinations))
}

log_evidence_m2_cv <- function(y) {
  res <- 0
  # Sum the scores over all held-out sizes J = 1, ..., K
  for(J in seq_along(y)) {
    res <- res + cv_score_m2(y, J)
  }
  return(res)
}

We obtain a result that is identical to the direct computation of evidence:

log_evidence_m2_cv(y)
## [1] -5.452067
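The same number can be reproduced much more cheaply by processing the data points sequentially, predicting each one from the points before it. That this agrees with the exhaustive sum, whatever ordering we pick, is a useful sanity check. A minimal sketch with my own function names, hard-coding the prior and observation standard deviations of \(\mathcal{M}_2\):

```r
# One-step-ahead log predictive density under M2
# (prior alpha ~ N(0, sd = 2), observation sd = 1)
log_pred_m2 <- function(y_new, observed) {
  post_precision <- 1 / 2^2 + length(observed)
  post_mean <- sum(observed) / post_precision
  dnorm(y_new, mean = post_mean, sd = sqrt(1 / post_precision + 1), log = TRUE)
}

# Chain rule: add up the predictive density of each point given the previous ones
log_evidence_m2_sequential <- function(y) {
  total <- 0
  for (k in seq_along(y)) {
    total <- total + log_pred_m2(y[k], y[seq_len(k - 1)])
  }
  total
}

y <- c(0.5, 0.7, -0.4, 0.1)
log_evidence_m2_sequential(y)
log_evidence_m2_sequential(rev(y))  # a different ordering of the same data
```

Both calls should return the same log evidence as the exhaustive computation.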

For the \(\mathcal{M}_1\) (null) model, we can avoid all this looping because the density of the held-out data does not depend on the data seen so far, so we have

\[ s_1(\tilde{y} \mid \mathbf{y}) = \mathtt{normal\_lpdf}(\tilde{y} | 0, 1) \]

where \(\mathtt{normal\_lpdf}\) is the log of the density function of a normal distribution. Since the cross-validation is exhaustive, each \(y_i\) is held out the same number of times: it appears in \({K - 1 \choose J - 1}\) of the \({K \choose J}\) held-out sets, i.e. in a fraction \(J / K\) of them. Together with the \(\frac{1}{J}\) averaging within each held-out set, this gives

\[ S^{\rm CV}_{\mathcal{M}_1} (\mathbf{y} ; J) = \frac{1}{K}\sum_{i = 1}^K \mathtt{normal\_lpdf}(y_i | 0, 1) \]

for every \(J\). Summing over \(J = 1, ... , K\), the evidence thus is:

\[ \log P(\mathbf{y} | \mathcal{M}_1) = \sum_{i = 1}^K \mathtt{normal\_lpdf}(y_i | 0, 1) \]

which happens to be exactly the same as the log prior predictive density and matches our expectations:

sum(dnorm(y, mean = 0, sd = 1, log = TRUE))
## [1] -4.130754
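For the skeptical, the shortcut can also be verified by brute force. The following sketch (function name cv_score_m1 is my own) runs the exhaustive loop for \(\mathcal{M}_1\) and shows that the leave-\(J\)-out score works out to the mean of the per-point log densities for every \(J\):

```r
y <- c(0.5, 0.7, -0.4, 0.1)
K <- length(y)

# Exhaustive leave-J-out score for M1; the observed part is ignored
# because the predictive density is always N(0, 1)
cv_score_m1 <- function(y, J) {
  held_out_sets <- combn(seq_along(y), J)
  fold_scores <- apply(held_out_sets, 2, function(idx) {
    mean(dnorm(y[idx], mean = 0, sd = 1, log = TRUE))
  })
  mean(fold_scores)
}

scores <- sapply(seq_len(K), function(J) cv_score_m1(y, J))
scores       # the same value for every J
sum(scores)  # adds up to the log evidence of M1
```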

Since we obtained the correct values for evidence, we also obtain the correct value for the Bayes factor.

And interestingly, the log Bayes factor is the difference of the log evidences, so it is in this sense an analog to the difference in log predictive density as reported for cross-validation by e.g. the loo package.

What does it mean?

Some people claim that this connection is a justification for using Bayes factors. Some even claim that if you accept cross-validation as valid you must accept Bayes factors as valid. I am personally not very convinced: as already mentioned by Fong & Holmes 2020, the cross-validation scheme we see here is pretty weird. Why would I want to include “leave-all-data-out” or “leave-almost-all-data-out” in my cross-validation?

I also agree with Aki Vehtari’s cross-validation FAQ (which is great overall), that the cross-validation scheme you use should be chosen with an eye toward the predictive task you want to handle. If you have a hierarchical model and you expect to never see new groups (e.g. groups are states), leaving out a single observation can make sense. If on the other hand predicting for new groups is essential (e.g. groups are patients), leaving out whole groups is much more reasonable. There’s no such flexibility in Bayes factors.
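To make the two schemes concrete, here is a small sketch of how the held-out index sets differ. The grouping vector is entirely hypothetical and the variable names are my own.

```r
# Hypothetical grouping: 6 observations from 3 groups (e.g. patients)
group <- c("a", "a", "b", "b", "c", "c")

# Leave-one-observation-out: each fold holds out a single row
folds_loo <- as.list(seq_along(group))

# Leave-one-group-out: each fold holds out all rows of one group
folds_logo <- split(seq_along(group), group)

length(folds_loo)   # one fold per observation
length(folds_logo)  # one fold per group
```

The model is the same in both cases. Only the question asked of it changes: predicting a new observation in a known group versus predicting a whole new group.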

You say you don’t care about predictions? Well, I subscribe to the view that everything is prediction, i.e. every inference task can be reframed as an equivalent prediction task. Do you want to select a model? Do you want to know the difference between groups? You are implicitly making predictions about future datasets. So I’d suggest you find the prediction task corresponding to your inference goals. Performing well in this task will lead to good inferences, and this performance can often be approximated well with cross-validation.

There are also practical considerations: as discussed in the Three ways post, Bayes factors are hard to compute, depend heavily on the choice of priors and are hard to interpret. To be fair, cross-validation can be shown to have some issues as the size of the dataset grows to infinity: you need to increase the proportion of held-out data as the data size increases to avoid those (see Gronau & Wagenmakers 2019 and the response in Vehtari et al. 2019 for more on this). But I don’t work with datasets that are effectively infinite…

This does not mean that I believe Bayes factors are never useful. I can still imagine scenarios where they may have some merit: if you have two well-specified substantive models with tightly constrained priors (e.g. from previous measurements), you have tons of data and you check that you can compute Bayes factors accurately, then they might provide value. I just think very few people are dealing with such situations.


  1. There are also other ways to relate cross-validation to Bayes factors as discussed e.g. in section 6.1.6 of Bernardo & Smith (1994), but those are not the focus of this post.


All content is licensed under the BSD 2-clause license. Source files for the blog are available at https://github.com/martinmodrak/blog.