A Gentle Stan vs. INLA Comparison

· 2018/02/02 · 11 minute read

Not long ago, I came across a nice blog post by Kathryn Morrison called A gentle INLA tutorial. The post helped me better appreciate INLA. But as a fan of the Stan probabilistic language, I felt that comparing INLA to JAGS is not really that relevant, as Stan should - at least in theory - be way faster and better than JAGS. Here, I run a comparison of INLA to Stan on the second example, called “Poisson GLM with an iid random effect”.

The TLDR is: For this model, Stan scales considerably better than JAGS, but still cannot scale to very large models. Also, for this model Stan and INLA give almost the same results. It seems that Stan becomes useful only when your model cannot be coded in INLA.

Please let me know (via an issue on GitHub) should you find any error or anything else that should be included in this post. Also, if you run the experiment on a different machine and/or with a different seed, let me know the results.

Here are the original numbers from Kathryn’s blog:

N       kathryn_rjags            kathryn_rinla
100     30.394                   0.383
500     142.532                  1.243
5000    1714.468                 5.768
25000   8610.32                  30.077
100000  got bored after 6 hours  166.819
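
For reference, the same numbers can be kept as an R data frame so that they can be merged with new timings. This is only a minimal sketch: the name kathryn_results matches the variable used in the timing summary later in this post, but storing the abandoned JAGS run as NA (rather than as text) is an assumption of this sketch.

```r
# Timings copied from the table above.
kathryn_results <- data.frame(
  N = c(100, 500, 5000, 25000, 100000),
  # NA marks the JAGS run that was abandoned after 6 hours (assumed encoding)
  kathryn_rjags = c(30.394, 142.532, 1714.468, 8610.32, NA),
  kathryn_rinla = c(0.383, 1.243, 5.768, 30.077, 166.819)
)

# How many times faster INLA was than JAGS at each data size
speedup <- with(kathryn_results, kathryn_rjags / kathryn_rinla)
print(round(speedup, 1))
```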

The full source of this post is available at this blog’s GitHub repo. Keep in mind that installing RStan is unfortunately not as straightforward as running install.packages. Please consult https://github.com/stan-dev/rstan/wiki/RStan-Getting-Started if you don’t have RStan already installed.

The model

The model we are interested in is a simple GLM with partial pooling of a random effect:

y_i ~ poisson(mu_i)
log(mu_i) = alpha + beta * x_i + nu_i
nu_i ~ normal(0, tau_nu)

Here tau_nu is the precision (not the sd) of the random effect, as in the original JAGS code.

The comparison

Let’s set up our libraries.

library(rstan)
library(brms)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
library(INLA)
library(tidyverse)
library(knitr) #Provides kable(), used for the tables below
set.seed(6619414)

The results are stored in files within the repository to let me rebuild the site with blogdown easily. Delete the cache directory to force a complete rerun.

cache_dir = "_stan_vs_inla_cache/"
if(!dir.exists(cache_dir)){
  dir.create(cache_dir)
}

Let’s start by simulating data.

#The sizes of datasets to work with
N_values = c(100, 500, 5000, 25000)
data = list()
for(N in N_values) {
  x = rnorm(N, mean=5,sd=1) 
  nu = rnorm(N,0,0.1)
  mu = exp(1 + 0.5*x + nu) 
  y = rpois(N,mu) 
  
  
  data[[N]] = list(
    N = N,
    x = x,
    y = y
  )  
}

Measuring Stan

Here is the model code in Stan (it is good practice to put it into a file, but I wanted to make this post self-contained). It is an almost 1-1 rewrite of the original JAGS code, with a few important changes:

  • JAGS parametrizes the normal distribution via precision, Stan via sd. The model converts precision to sd.
  • I added the ability to explicitly set the parameters of the prior distributions as data, which is useful later in this post.
  • With multilevel models, Stan works waaaaaay better with the so-called non-centered parametrization. This means that instead of having nu ~ N(0, nu_sigma), log(mu) = alpha + beta * x + nu we have nu_normalized ~ N(0,1), log(mu) = alpha + beta * x + nu_normalized * nu_sigma. This gives exactly the same inferences, but results in a geometry that Stan can explore efficiently.
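
The first and last points can be checked numerically in plain R. The sketch below is only an illustration with made-up values (tau_nu = 100 is chosen to match the sd of 0.1 used when simulating the data): it converts a precision to an sd and shows that scaling standard normal draws gives draws with that sd.

```r
set.seed(1)

# Precision -> sd conversion, as done inside the Stan model
tau_nu <- 100
nu_sigma <- sqrt(1 / tau_nu)

# Non-centered form: draw from N(0, 1), then scale by the sd
nu_normalized <- rnorm(1e5, mean = 0, sd = 1)
nu <- nu_normalized * nu_sigma

# The scaled draws behave like draws from N(0, nu_sigma)
print(nu_sigma)
print(sd(nu))
```

Both forms describe the same distribution for nu; the difference is only in which quantity the sampler moves around in.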

There are also packages that let you specify common models (including this one) without writing Stan code, using syntax similar to R-INLA - check out rstanarm and brms. The latter is more flexible, while the former is easier to install, as it ships with precompiled models and can be installed simply with install.packages.

Note also that the Stan developers advise against a Gamma(0.01,0.01) prior on precision in favor of a normal or Cauchy distribution on sd; see https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations.
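
One way to get a feeling for what a Gamma(0.01, 0.01) prior on precision implies is to translate its quantiles to the sd scale. This is only an illustrative sketch using base R's qgamma; it is not part of the measured comparison.

```r
# Quartiles of the Gamma(0.01, 0.01) prior on the precision tau_nu
probs <- c(0.25, 0.5, 0.75)
tau_quantiles <- qgamma(probs, shape = 0.01, rate = 0.01)

# sd = sqrt(1 / precision) is a decreasing transformation, so the
# p-quantile of the precision maps to the (1 - p)-quantile of the sd
sd_quantiles <- rev(sqrt(1 / tau_quantiles))
names(sd_quantiles) <- paste0(100 * probs, "%")
print(sd_quantiles)
```

The printed quartiles show how much prior mass this choice puts on very large sds.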

model_code = "
  data {
    int N;
    vector[N] x;
    int y[N];
  
    //Allowing to parametrize the priors (useful later)
    real alpha_prior_mean;
    real beta_prior_mean;
    real<lower=0> alpha_beta_prior_precision;
    real<lower=0> tau_nu_prior_shape;
    real<lower=0> tau_nu_prior_rate; 
  }

  transformed data {
    //Stan parametrizes normal with sd not precision
    real alpha_beta_prior_sigma = sqrt(1 / alpha_beta_prior_precision);
  }

  parameters {
    real alpha;
    real beta;
    vector[N] nu_normalized;
    real<lower=0> tau_nu;
  }

  model {
    real nu_sigma = sqrt(1 / tau_nu);
    vector[N] nu = nu_normalized * nu_sigma;

    //taking advantage of Stan's implicit vectorization here
    nu_normalized ~ normal(0,1);
    //The built-in poisson_log(x) === poisson(exp(x))
    y ~ poisson_log(alpha + beta*x + nu); 

    alpha  ~ normal(alpha_prior_mean, alpha_beta_prior_sigma);
    beta  ~ normal(beta_prior_mean, alpha_beta_prior_sigma); 
    tau_nu ~ gamma(tau_nu_prior_shape,tau_nu_prior_rate);
  }

//Uncomment this to have the model generate mu values as well
//Currently commented out as storing the samples of mu consumes 
//a lot of memory for the big models
/*  
  generated quantities {
    vector[N] mu = exp(alpha + beta*x + nu_normalized * sqrt(1 / tau_nu));
  }
*/
"

model = stan_model(model_code = model_code)

Below is the code to make the actual measurements. Some caveats:

  • The compilation of the Stan model is not counted (you can avoid it with rstanarm and need to do it only once otherwise).
  • There is some overhead in transferring the posterior samples from Stan to R. This overhead is non-negligible for the larger models, but you can get rid of it by storing the samples in a file and reading them separately. The overhead is not measured here.
  • Stan took > 16 hours to converge for the largest data size (1e5) and then I had issues fitting the posterior samples into memory on my computer. Notably, R-INLA also crashed on my computer for this size. The largest size is thus excluded here, but I have to conclude that if you get bored after 6 hours, Stan is not practical for such a big model.
  • I was not able to get rjags running in a reasonable amount of time, so I did not rerun the JAGS version of the model.

stan_times_file = paste0(cache_dir, "stan_times.csv")
stan_summary_file = paste0(cache_dir, "stan_summary.csv")
run_stan = TRUE
if(file.exists(stan_times_file) && file.exists(stan_summary_file)) {
  stan_times = read.csv(stan_times_file)
  stan_summary = read.csv(stan_summary_file) 
  if(setequal(stan_times$N, N_values) && setequal(stan_summary$N, N_values)) {
    run_stan = FALSE
  }
} 

if(run_stan) {
  stan_times_values = numeric(length(N_values))
  stan_summary_list = list()
  step = 1
  for(N in N_values) {
    data_stan = data[[N]]
    data_stan$alpha_prior_mean = 0
    data_stan$beta_prior_mean = 0
    data_stan$alpha_beta_prior_precision = 0.001
    data_stan$tau_nu_prior_shape = 0.01
    data_stan$tau_nu_prior_rate = 0.01
    
    
    fit = sampling(model, data = data_stan);
    stan_summary_list[[step]] = 
      as.data.frame(
        rstan::summary(fit, pars = c("alpha","beta","tau_nu"))$summary
      ) %>% rownames_to_column("parameter")
    stan_summary_list[[step]]$N = N
    
    all_times = get_elapsed_time(fit)
    stan_times_values[step] = max(all_times[,"warmup"] + all_times[,"sample"])
    
    step = step + 1
  }
  stan_times = data.frame(N = N_values, stan_time = stan_times_values)
  stan_summary = do.call(rbind, stan_summary_list)
  
  write.csv(stan_times, stan_times_file,row.names = FALSE)
  write.csv(stan_summary, stan_summary_file,row.names = FALSE)
}
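
The timing reduction used above can be seen on a small made-up example. get_elapsed_time returns one row per chain with "warmup" and "sample" columns; since the chains run in parallel, the slowest chain is taken as the running time. The numbers below are invented for illustration only.

```r
# A stand-in for the matrix returned by get_elapsed_time(fit),
# with invented timings (in seconds) for four chains
all_times <- matrix(
  c(1.2, 1.5, 1.1, 1.3,   # warmup
    0.8, 0.9, 1.0, 0.7),  # sample
  ncol = 2,
  dimnames = list(paste0("chain:", 1:4), c("warmup", "sample"))
)

# Total time per chain, then the slowest chain
per_chain_total <- all_times[, "warmup"] + all_times[, "sample"]
slowest_chain_time <- max(per_chain_total)
print(slowest_chain_time)
```

Note that this is an optimistic measure when there are fewer cores than chains, as some chains then have to wait for others to finish.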

Measuring INLA

inla_times_file = paste0(cache_dir,"inla_times.csv")
inla_summary_file = paste0(cache_dir,"inla_summary.csv")
run_inla = TRUE
if(file.exists(inla_times_file) && file.exists(inla_summary_file)) {
  inla_times = read.csv(inla_times_file)
  inla_summary = read.csv(inla_summary_file) 
  if(setequal(inla_times$N, N_values) && setequal(inla_summary$N, N_values)) {
    run_inla = FALSE
  }
} 

if(run_inla) {
  inla_times_values = numeric(length(N_values))
  inla_summary_list = list()
  step = 1
  for(N in N_values) {
    nu = 1:N 
    fit_inla = inla(y ~ x + f(nu,model="iid"), family = c("poisson"), 
               data = data[[N]], control.predictor=list(link=1)) 
    
    inla_times_values[step] = fit_inla$cpu.used["Total"]
    inla_summary_list[[step]] = 
      rbind(fit_inla$summary.fixed %>% select(-kld),
            fit_inla$summary.hyperpar) %>% 
      rownames_to_column("parameter")
    inla_summary_list[[step]]$N = N
    
    step = step + 1
  }
  inla_times = data.frame(N = N_values, inla_time = inla_times_values)
  inla_summary = do.call(rbind, inla_summary_list)
  
  write.csv(inla_times, inla_times_file,row.names = FALSE)
  write.csv(inla_summary, inla_summary_file,row.names = FALSE)
}

Checking inferences

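Here we see side-by-side comparisons of the inferences. They are pretty comparable between Stan and INLA, with the exception of the precision of nu for N = 100, where the two posterior means differ by about two orders of magnitude: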

for(N_to_show in N_values) {
  print(kable(stan_summary %>% filter(N == N_to_show) %>% 
                select(c("parameter","mean","sd")), 
              caption = paste0("Stan results for N = ", N_to_show)))
  print(kable(inla_summary %>% filter(N == N_to_show) %>% 
                select(c("parameter","mean","sd")), 
              caption = paste0("INLA results for N = ", N_to_show)))
}

Table 1: Stan results for N = 100
parameter         mean          sd
alpha             1.013559      0.0989778
beta              0.495539      0.0176988
tau_nu            162.001608    82.7700473

Table 1: INLA results for N = 100
parameter         mean          sd
(Intercept)       1.009037e+00  9.15248e-02
x                 4.971302e-01  1.61486e-02
Precision for nu  1.819654e+04  1.71676e+04

Table 1: Stan results for N = 500
parameter         mean          sd
alpha             1.0046284     0.0555134
beta              0.4977522     0.0102697
tau_nu            71.6301530    13.8264812

Table 1: INLA results for N = 500
parameter         mean          sd
(Intercept)       1.0053202     0.0538456
x                 0.4977124     0.0099593
Precision for nu  77.3311793    16.0255430

Table 1: Stan results for N = 5000
parameter         mean          sd
alpha             1.009930      0.0159586
beta              0.496859      0.0029250
tau_nu            101.548580    7.4655716

Table 1: INLA results for N = 5000
parameter         mean          sd
(Intercept)       1.0099282     0.0155388
x                 0.4968718     0.0028618
Precision for nu  103.1508773   7.6811740

Table 1: Stan results for N = 25000
parameter         mean          sd
alpha             0.9874707     0.0066864
beta              0.5019566     0.0012195
tau_nu            104.3599424   3.5391938

Table 1: INLA results for N = 25000
parameter         mean          sd
(Intercept)       0.9876218     0.0067978
x                 0.5019341     0.0012452
Precision for nu  104.8948949   3.4415929

Summary of the timing

You can see that Stan keeps its runtimes reasonable up to larger data sizes than JAGS did in the original blog post, but INLA is still way faster. Also, Kathryn probably got very lucky with her seed for N = 25 000, as her INLA run completed very quickly. In my (few) tests, INLA always took at least several minutes for N = 25 000. It may mean that Kathryn’s JAGS time is also too short.

my_results = merge.data.frame(inla_times, stan_times, by = "N")
kable(merge.data.frame(my_results, kathryn_results, by = "N"))

N      inla_time   stan_time  kathryn_rjags  kathryn_rinla
100    1.061742    1.885      30.394         0.383
500    1.401597    11.120     142.532        1.243
5000   10.608704   388.514    1714.468       5.768
25000  611.505543  5807.670   8610.32        30.077

You could obviously do multiple runs to reduce uncertainty etc., but this post has already taken too much of my time, so this will be left to others.

Testing quality of the results

I also had a hunch that maybe INLA is less precise than Stan, but that turned out to be based on an error. Thus, without much commentary, I put here my code to test this. Basically, I modify the random data generator to actually draw from priors (those priors are quite constrained to provide similar values of alpha, beta and tau_nu as in the original). I then give both algorithms the knowledge of these priors. I compute both the difference between the true parameters and a point estimate (mean) and the quantiles of the posterior distribution where the true parameter is found. If the algorithms give the best possible estimates, the distribution of such quantiles should be uniform over (0,1). It turns out that INLA and Stan give almost exactly the same results for almost all runs and the differences in quality are (for this particular model) negligible.
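
The quantile check described above can be illustrated on a toy model where the posterior is known exactly. The sketch below is not the model from this post: it assumes a normal prior on a single mean with known observation sd, so that the posterior is available in closed form, and all names in it are made up for the illustration.

```r
set.seed(42)
n_sims <- 2000
n_obs <- 10
prior_mean <- 0
prior_sd <- 1
obs_sd <- 1

posterior_quantiles <- replicate(n_sims, {
  # Draw the true parameter from the prior, then data given the parameter
  mu_true <- rnorm(1, prior_mean, prior_sd)
  y <- rnorm(n_obs, mu_true, obs_sd)

  # Exact conjugate posterior for a normal mean with known sd
  post_precision <- 1 / prior_sd^2 + n_obs / obs_sd^2
  post_mean <- (prior_mean / prior_sd^2 + sum(y) / obs_sd^2) / post_precision
  post_sd <- sqrt(1 / post_precision)

  # Posterior quantile at which the true value is found
  pnorm(mu_true, post_mean, post_sd)
})

# With a correct posterior, these quantiles should look uniform on (0, 1)
print(summary(posterior_quantiles))
print(ks.test(posterior_quantiles, "punif"))
```

In the code below, the exact posterior is replaced by the INLA marginals (via inla.pmarginal) and by the empirical distribution of the Stan samples (via ecdf).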

test_precision = function(N) {
  rejects <- 0
  repeat {
    #Set the priors so that they generate similar parameters as in the example above
    
    alpha_beta_prior_precision = 5
    prior_sigma = sqrt(1/alpha_beta_prior_precision)
    alpha_prior_mean = 1
    beta_prior_mean = 0.5
    alpha = rnorm(1, alpha_prior_mean, prior_sigma)
    beta = rnorm(1, beta_prior_mean, prior_sigma)
    
    tau_nu_prior_shape = 2
    tau_nu_prior_rate = 0.01
    tau_nu = rgamma(1,tau_nu_prior_shape,tau_nu_prior_rate)
    sigma_nu = sqrt(1 / tau_nu)
    
    x = rnorm(N, mean=5,sd=1) 
    
    
    nu =  rnorm(N,0,sigma_nu)
    linear = alpha + beta*x + nu
    
    #Rejection sampling to avoid NAs and ill-posed problems
    if(max(linear) < 15) {
      mu = exp(linear) 
      y = rpois(N,mu) 
      if(mean(y == 0) < 0.7) {
        break;
      }
    } 
    rejects = rejects + 1
  }
  
  #cat(rejects, "rejects\n")
  
  
  data = list(
    N = N,
    x = x,
    y = y
  )
  #cat("A:",alpha,"B:", beta, "T:", tau_nu,"\n")
  #print(linear)
  #print(data)
  
  #=============== Fit INLA
  nu = 1:N 
  fit_inla = inla(y ~ x + f(nu,model="iid",
                  hyper=list(theta=list(prior="loggamma",
                                        param=c(tau_nu_prior_shape,tau_nu_prior_rate)))), 
                  family = c("poisson"), 
                  control.fixed = list(mean = beta_prior_mean, 
                                       mean.intercept = alpha_prior_mean,
                                       prec = alpha_beta_prior_precision,
                                       prec.intercept = alpha_beta_prior_precision
                                       ),
             data = data, control.predictor=list(link=1)
             ) 
  
  time_inla = fit_inla$cpu.used["Total"]
  
  alpha_mean_diff_inla = fit_inla$summary.fixed["(Intercept)","mean"] - alpha
  beta_mean_diff_inla = fit_inla$summary.fixed["x","mean"] - beta
  tau_nu_mean_diff_inla = fit_inla$summary.hyperpar[,"mean"] - tau_nu
  
  alpha_q_inla = inla.pmarginal(alpha, fit_inla$marginals.fixed$`(Intercept)`)
  beta_q_inla = inla.pmarginal(beta, fit_inla$marginals.fixed$x)
  tau_nu_q_inla = inla.pmarginal(tau_nu, fit_inla$marginals.hyperpar$`Precision for nu`)

  
    
  #================ Fit Stan
  data_stan = data
  data_stan$alpha_prior_mean = alpha_prior_mean
  data_stan$beta_prior_mean = beta_prior_mean
  data_stan$alpha_beta_prior_precision = alpha_beta_prior_precision
  data_stan$tau_nu_prior_shape = tau_nu_prior_shape
  data_stan$tau_nu_prior_rate = tau_nu_prior_rate
  
  fit = sampling(model, data = data_stan, control = list(adapt_delta = 0.95)); 
  all_times = get_elapsed_time(fit)
  max_total_time_stan = max(all_times[,"warmup"] + all_times[,"sample"])

  samples = rstan::extract(fit, pars = c("alpha","beta","tau_nu"))
  alpha_mean_diff_stan = mean(samples$alpha) - alpha
  beta_mean_diff_stan = mean(samples$beta) - beta
  tau_nu_mean_diff_stan = mean(samples$tau_nu) - tau_nu
  
  alpha_q_stan = ecdf(samples$alpha)(alpha)
  beta_q_stan = ecdf(samples$beta)(beta)
  tau_nu_q_stan = ecdf(samples$tau_nu)(tau_nu)
  
  return(data.frame(time_rstan = max_total_time_stan,
                    time_rinla = time_inla,
                    alpha_mean_diff_stan = alpha_mean_diff_stan,
                    beta_mean_diff_stan = beta_mean_diff_stan,
                    tau_nu_mean_diff_stan = tau_nu_mean_diff_stan,
                    alpha_q_stan = alpha_q_stan,
                    beta_q_stan = beta_q_stan,
                    tau_nu_q_stan = tau_nu_q_stan,
                    alpha_mean_diff_inla = alpha_mean_diff_inla,
                    beta_mean_diff_inla = beta_mean_diff_inla,
                    tau_nu_mean_diff_inla = tau_nu_mean_diff_inla,
                    alpha_q_inla= alpha_q_inla,
                    beta_q_inla = beta_q_inla,
                    tau_nu_q_inla = tau_nu_q_inla
                    ))
}

Now to actually run the comparison. On some occasions, Stan does not converge; my best guess is that the data are somehow pathological, but I didn’t investigate thoroughly. You can see that the results for Stan and INLA are very similar, both as point estimates and in the distribution of posterior quantiles. The accuracy of the INLA approximation is also, AFAIK, going to improve with more data.

library(skimr) #Uses skimr to summarize results easily
precision_results_file = paste0(cache_dir,"precision_results.csv")
if(file.exists(precision_results_file)) {
  results_precision_df = read.csv(precision_results_file)
} else {
  results_precision = list()
  for(i in 1:100) {
    results_precision[[i]] = test_precision(50)
  }
  
  results_precision_df = do.call(rbind, results_precision)
  write.csv(results_precision_df,precision_results_file,row.names = FALSE)
}

#Remove uninteresting skim statistics
skim_with(numeric = list(missing = NULL, complete = NULL, n = NULL))

#Drop the row-name column X if the cached CSV contains one (no error if it is absent)
skimmed = results_precision_df %>% select(-matches("^X$")) %>% skim() 
#Now a hack to display skim histograms properly in the output:
skimmed_better = skimmed %>% rowwise() %>%  mutate(formatted = 
     if_else(stat == "hist", 
         utf8ToInt(formatted) %>% as.character() %>% paste0("&#", . ,";", collapse = ""), 
         formatted))  
mostattributes(skimmed_better) = attributes(skimmed)

skimmed_better %>% kable(escape = FALSE)

Skim summary statistics
n obs: 100
n variables: 14

Variable type: numeric

variable mean sd p0 p25 p50 p75 p100 hist
alpha_mean_diff_inla -0.0021 0.2 -0.85 -0.094 0.0023 0.095 0.53 ▁▁▁▂▇▇▁▁
alpha_mean_diff_stan -0.0033 0.2 -0.84 -0.097 -0.00012 0.093 0.52 ▁▁▁▂▇▇▁▂
alpha_q_inla 0.5 0.29 0.00084 0.25 0.5 0.73 0.99 ▅▇▇▆▇▆▆▇
alpha_q_stan 0.5 0.28 0.001 0.26 0.5 0.73 0.99 ▅▇▇▆▇▆▆▇
beta_mean_diff_inla -0.00088 0.04 -0.12 -0.016 -0.001 0.014 0.17 ▁▁▃▇▂▁▁▁
beta_mean_diff_stan -0.001 0.04 -0.12 -0.016 -5e-04 0.014 0.16 ▁▁▂▇▂▁▁▁
beta_q_inla 0.51 0.28 0.0068 0.26 0.52 0.75 1 ▆▆▅▆▇▅▆▆
beta_q_stan 0.51 0.28 0.0065 0.27 0.51 0.75 1 ▆▆▅▇▆▅▆▆
tau_nu_mean_diff_inla 4.45 90.17 -338.58 -26.74 4.49 53.38 193 ▁▁▁▂▅▇▃▂
tau_nu_mean_diff_stan 5.21 90 -339.89 -24.62 4.29 54.48 191.94 ▁▁▁▂▅▇▃▂
tau_nu_q_inla 0.53 0.26 0.023 0.32 0.52 0.74 0.99 ▃▅▆▆▇▆▅▅
tau_nu_q_stan 0.53 0.26 0.021 0.32 0.53 0.75 0.99 ▃▅▅▆▇▃▅▅
time_rinla 0.97 0.093 0.86 0.91 0.93 0.98 1.32 ▇▇▂▁▁▁▁▁
time_rstan 1.79 1.4 0.55 0.89 1.45 2.09 10.04 ▇▂▁▁▁▁▁▁

All content is licensed under the BSD 2-clause license. Source files for the blog are available at https://github.com/martinmodrak/blog.
