Is the LKJ(1) prior uniform? “Yes”

This post is about how the LKJ(1) prior is indeed "uniform", but perhaps not in the way people expect it to be. To be clear, this is a well-known result, and nothing in this post is novel. However, I was confused by it for a long time, and I see others confused by it as well.

The LKJ prior is a spherical prior on correlation matrices. I emphasize matrices because it is not a marginal prior on the individual elements. It is commonly used in Stan, partly for performance reasons and partly because it makes specifying priors on correlations separately from variances much easier. The latter point is especially useful when estimating multivariate scale models, in which variances can vary but we may want the correlations to remain constant, or to vary according to a different model.

The LKJ prior takes one parameter, $\eta \in \mathbb{R}^+$. As $\eta$ increases, more mass is placed near the identity matrix (ones on the diagonal, zeros elsewhere; i.e., little to no correlation). When $\eta = 1$, the LKJ prior is a uniform prior over correlation matrices. Because people (myself included) are generally poor at visualizing and intuitively understanding spaces with more than three dimensions, one may think that the elements of these correlation matrices should therefore also be uniformly distributed.
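For reference, the LKJ density on a $K \times K$ correlation matrix (written $\Omega$ here) depends only on its determinant, which makes the "uniform" claim for $\eta = 1$ immediate:

$$
p(\Omega \mid \eta) \propto \det(\Omega)^{\eta - 1}, \qquad \text{so} \qquad p(\Omega \mid \eta = 1) \propto \det(\Omega)^{0} = 1.
$$

Because $\det(\Omega) \le 1$ for any correlation matrix, with equality only at the identity, $\eta > 1$ favors matrices near the identity, while $\eta = 1$ assigns the same density to every valid correlation matrix.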

However, when you sample from the LKJ(1) prior with $K > 2$ variables (i.e., more than one correlation, in a $K \times K$ matrix) and look at the marginal distributions of the elements, they are not uniformly distributed. What gives?

The answer lies in the constraints on a correlation matrix. Correlation matrices must be symmetric and positive semi-definite (PSD). The PSD constraint alters where probability mass can exist, and therefore where it accumulates marginally. When you only have $K = 2$ variables, this is not an issue: the determinant $1 - \rho^2$ is non-negative for any $\rho \in [-1, 1]$, so every value is permissible. But when you have $K > 2$ variables, one correlation's value constrains what the other correlations can be if the PSD constraint is to be met. That is, you cannot create a $K \times K$ correlation matrix, fill the off-diagonals with uniform(-1, 1) values, and expect it to be PSD. As $K$ increases, the chance that a symmetric matrix with uniformly distributed off-diagonal elements is also PSD rapidly decreases toward zero.

The LKJ(1) prior can therefore be thought of as a "uniform prior on permissible correlation matrices"; that is, it is uniform subject to the PSD constraint.
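The constraint is easy to see numerically. The base-R sketch below uses hypothetical helpers (is_psd(), feasible_rho23(), psd_rate()), and its eigenvalue-based PSD check is an assumption standing in for matrixcalc::is.positive.semi.definite, which the post's own code uses. It first shows a 3 × 3 matrix whose entries are all legal correlations but which is not PSD, along with the interval that $\rho_{2,3}$ is actually restricted to given $\rho_{1,2}$ and $\rho_{1,3}$ (from requiring $1 - \rho_{1,2}^2 - \rho_{1,3}^2 - \rho_{2,3}^2 + 2\rho_{1,2}\rho_{1,3}\rho_{2,3} \ge 0$). It then estimates how often uniform(-1, 1) off-diagonals give a PSD matrix as $K$ grows.

```r
# Base-R sketch; is_psd(), feasible_rho23(), and psd_rate() are hypothetical helpers.

# TRUE if a symmetric matrix is positive semi-definite (smallest eigenvalue >= -tol).
is_psd <- function(m, tol = 1e-8) {
  min(eigen(m, symmetric = TRUE, only.values = TRUE)$values) >= -tol
}

# K = 3: given rho12 and rho13, the interval of rho23 values that keeps det >= 0.
feasible_rho23 <- function(rho12, rho13) {
  half_width <- sqrt((1 - rho12^2) * (1 - rho13^2))
  c(lower = rho12 * rho13 - half_width, upper = rho12 * rho13 + half_width)
}

# Every entry is a legal correlation, but the matrix as a whole is not PSD.
bad <- matrix(c( 1.0,  0.9, -0.9,
                 0.9,  1.0,  0.9,
                -0.9,  0.9,  1.0), nrow = 3, byrow = TRUE)
is_psd(bad)                 # FALSE
feasible_rho23(0.9, -0.9)   # rho23 would need to lie in [-1, -0.62]

# Monte Carlo estimate of P(PSD) when the K(K - 1)/2 unique off-diagonal
# elements are drawn independently from uniform(-1, 1).
psd_rate <- function(k, n_draws = 10000) {
  hits <- 0
  for (i in seq_len(n_draws)) {
    m <- diag(k)
    m[lower.tri(m)] <- runif(k * (k - 1) / 2, -1, 1)
    m[upper.tri(m)] <- t(m)[upper.tri(m)]   # mirror to make m symmetric
    hits <- hits + is_psd(m)
  }
  hits / n_draws
}

set.seed(2024)
rates <- sapply(2:5, psd_rate)
names(rates) <- paste0("K=", 2:5)
rates
```

$K = 2$ always passes, since the eigenvalues $1 \pm \rho$ are positive for any $\rho \in (-1, 1)$. The $K = 4$ rate should come out near the roughly 18% implied by the footnote's 81.58% rejection figure, and by $K = 5$ it is much smaller still.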

To build that intuition, I randomly generated 50,000 $K = 4$ "correlation matrices", in which each unique off-diagonal element was sampled from a uniform(-1, 1) distribution. For each matrix, I checked whether the PSD constraint was met, and I then resampled each invalid matrix until it was (i.e., rejection sampling)[^1].

I then plotted the marginal distribution of $\rho_{1,2}$. For comparison, I also sampled from LKJ(1) using rethinking::rlkjcorr. The plot below overlays the two: red shows $\rho_{1,2}$ from the resampling method, and blue shows $\rho_{1,2}$ from the LKJ(1) distribution.

[Figure: Marginal distribution of $\rho_{1,2}$; PSD resampling (red) vs. LKJ(1) (blue)]

The two are effectively identical (within sampling error). As $K$ increases, the PSD constraint causes the marginal distributions to be increasingly concentrated around zero. It is hard to say that the prior is "regularizing more" as dimensionality increases; rather, in large matrices, the region in which the PSD constraint is met is considerably smaller than it is when dimensionality is low.
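If you would rather not depend on rethinking for the LKJ(1) comparison draws, LKJ samples can also be generated in base R. The sketch below swaps in the C-vine method described by Lewandowski, Kurowicka, and Joe (the "LKJ" authors); I have not checked whether this is how rethinking::rlkjcorr works internally. Partial correlations are drawn from symmetric Beta distributions rescaled to (-1, 1) and then converted into ordinary correlations; rlkj_vine() is a hypothetical name, not a package function. The second part tracks the marginal variance of the last correlation, $\rho_{K-1,K}$, as $K$ grows: if Ben Goodrich's Beta(K/2, K/2) result quoted below holds, that variance should shrink like $1/(K + 1)$, which is the concentration around zero described above.

```r
# Hypothetical base-R helper: one draw from LKJ(eta) using the C-vine method.
rlkj_vine <- function(k, eta = 1) {
  partial <- matrix(0, k, k)  # partial[l, i]: partial corr. of (l, i) given 1, ..., l - 1
  omega <- diag(k)
  beta <- eta + (k - 1) / 2
  for (l in 1:(k - 1)) {
    beta <- beta - 1 / 2
    for (i in (l + 1):k) {
      # Partial correlation ~ Beta(beta, beta), rescaled from (0, 1) to (-1, 1).
      partial[l, i] <- 2 * rbeta(1, beta, beta) - 1
      rho <- partial[l, i]
      # Remove the conditioning on variables l - 1, ..., 1 to recover the correlation.
      if (l > 1) {
        for (m in (l - 1):1) {
          rho <- rho * sqrt((1 - partial[m, i]^2) * (1 - partial[m, l]^2)) +
            partial[m, i] * partial[m, l]
        }
      }
      omega[l, i] <- omega[i, l] <- rho
    }
  }
  omega
}

set.seed(1)
# Comparison draws for K = 4, analogous to rethinking::rlkjcorr(10000, 4, 1),
# but stored as a 4 x 4 x 10000 array instead of 10000 x 4 x 4.
draws <- replicate(10000, rlkj_vine(4))
rho34 <- draws[3, 4, ]
c(mean = mean(rho34), var = var(rho34))

# Marginal variance of rho_{K-1,K} as K grows, next to 1 / (K + 1).
ks_to_try <- c(2, 4, 8)
var_by_k <- sapply(ks_to_try, function(k) var(replicate(5000, rlkj_vine(k)[k - 1, k])))
rbind(K = ks_to_try, var = var_by_k, one_over_K_plus_1 = 1 / (ks_to_try + 1))
```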

You may want to know what the marginal priors for each correlation are, given a uniform LKJ and K variables. A useful result is provided, as always, by Ben Goodrich. He stated:

In general, when there are K variables, then the marginal distribution of a single correlation is Beta on the (-1,1) interval with both shape parameters equal to K / 2.
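Written as a density on $(-1, 1)$ (a change of variables from $\tilde\rho = (\rho + 1)/2 \sim \text{Beta}(K/2, K/2)$, with $\mathrm{B}$ the beta function), this result reads:

$$
p(\rho) = \frac{(1 - \rho^2)^{K/2 - 1}}{2^{K-1}\, \mathrm{B}\!\left(\tfrac{K}{2}, \tfrac{K}{2}\right)}, \qquad -1 < \rho < 1, \qquad \operatorname{Var}(\rho) = \frac{1}{K + 1}.
$$

For $K = 2$ this is the uniform density $1/2$; for $K = 4$ it is $\tfrac{3}{4}(1 - \rho^2)$, with variance $1/5$. The variance shrinking like $1/(K + 1)$ is the increasing concentration around zero described above.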

We can examine this nifty result. I took the PSD-enforced samples of $\rho_{1,2}$ and transformed them onto the $[0, 1]$ interval ($\tilde\rho_{1,2}^{(s)} = 0.5\rho_{1,2}^{(s)} + 0.5$ for each sample $s = 1, \dots, S$). I then plotted the $\text{Beta}(x \mid 4/2, 4/2)$ density on top of the normalized histogram. This is the result:

[Figure: Marginal (transformed) distribution of $\tilde\rho_{1,2}$, with the Beta(2, 2) density overlaid]

Indeed, up to sampling error, the marginal distribution matches the $\text{Beta}(\frac{K}{2}, \frac{K}{2})$ distribution, here $\text{Beta}(2, 2)$.
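The histogram comparison can also be made quantitative. The base-R sketch below rejection-samples $\rho_{1,2}$ the same way (using a hypothetical draw_rho12() helper, an eigenvalue-based PSD check instead of matrixcalc, and 5,000 rather than 50,000 draws to keep it quick), applies the same transformation, and compares the result with $\text{Beta}(2, 2)$ via a Kolmogorov-Smirnov test and the theoretical variance $\frac{(K/2)^2}{K^2 (K + 1)} = \frac{1}{4(K + 1)}$, which is $1/20$ for $K = 4$.

```r
# Hypothetical helper: one rho_12 draw from a symmetric matrix with uniform(-1, 1)
# off-diagonals, redrawn until the matrix is PSD (rejection sampling).
draw_rho12 <- function(k) {
  repeat {
    m <- diag(k)
    m[lower.tri(m)] <- runif(k * (k - 1) / 2, -1, 1)
    m[upper.tri(m)] <- t(m)[upper.tri(m)]      # mirror to make m symmetric
    if (min(eigen(m, symmetric = TRUE, only.values = TRUE)$values) >= 0) {
      return(m[2, 1])
    }
  }
}

set.seed(42)
k <- 4
rho12 <- replicate(5000, draw_rho12(k))
rho12_tilde <- 0.5 * rho12 + 0.5               # map (-1, 1) onto (0, 1)

ks <- ks.test(rho12_tilde, "pbeta", k / 2, k / 2)
ks$p.value                                     # KS test against Beta(k/2, k/2)
c(sample_var = var(rho12_tilde), beta_var = (k / 2)^2 / (k^2 * (k + 1)))
```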

To recap: I generated the elements of a correlation matrix from a uniform(-1, 1) distribution. I then enforced the PSD constraint by resampling each matrix until it was PSD. The marginal distributions of the resulting correlations are equivalent to those from the LKJ(1) distribution. In other words, the LKJ(1) distribution is indeed a "uniform" prior over matrices, subject to the constraint that the matrix must be positive semi-definite, and it is this constraint that prevents the marginals from being uniformly distributed when $K > 2$.

The LKJ prior is not the only possible prior for correlations. The Wishart can also be used, provided that you divide out the diagonal (i.e., rescale each sampled covariance matrix into a correlation matrix). If you want more control over the marginal priors of the elements, a better option may be the matrix-F distribution: with it, you can get effectively uniform marginal priors on the correlations. It is not standard, though it is not too hard to implement.
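Here is a sketch of the Wishart route, assuming an identity scale matrix. By a standard change of variables, if $W \sim \text{Wishart}(\nu, I_K)$ and $D$ is the diagonal of $W$, the rescaled matrix $\Omega = D^{-1/2} W D^{-1/2}$ has density proportional to $\det(\Omega)^{(\nu - K - 1)/2}$, i.e., LKJ with $\eta = (\nu - K + 1)/2$, so $\nu = K + 1$ gives LKJ(1). The code below uses base R's rWishart() and cov2cor() and checks one margin: for $K = 4$, $\rho_{1,2}$ should have variance $1/5$.

```r
set.seed(7)
k  <- 4
nu <- k + 1   # assumption for this sketch: nu = K + 1, which corresponds to eta = 1
W <- rWishart(20000, df = nu, Sigma = diag(k))   # k x k x 20000 array of covariances
# Divide out the diagonal of each draw, keeping rho_{1,2}.
rho12 <- apply(W, 3, function(s) cov2cor(s)[2, 1])
c(mean = mean(rho12), var = var(rho12))          # LKJ(1) with K = 4 implies variance 1/5
```

An identity scale is the simplest choice; other scale matrices change the implied correlation prior, so this equivalence should not be assumed for them.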

R Code

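# Naive "uniform" sampler: fill the off-diagonals with uniform(-1, 1) draws and,
# if enforce_spd = TRUE, redraw every matrix that is not PSD (rejection sampling).
# Requires the Matrix and matrixcalc packages; rethinking is used below for comparison.
rlkj_naive <- function(N, k, enforce_spd = TRUE, as_samples = TRUE) {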
    n_cors <- k*(k - 1)/2
    out <- list()
    gen_unif_matrix <- function(n_cors, k) {
        cors <- runif(n_cors, -1, 1)
        mat <- diag(1, k, k)
        mat[lower.tri(mat)] <- cors
        mat <- as.matrix(Matrix::forceSymmetric(mat, uplo = "L"))
        return(mat)
    }
    for(n in 1:N) {
        out[[n]] <- gen_unif_matrix(n_cors, k)
    }

    if(enforce_spd) {
        is_not_spd <- !sapply(out, matrixcalc::is.positive.semi.definite)
        if(any(is_not_spd)) {
            warning("PSD not met; rejection sampling. ", mean(is_not_spd), " of the mats are not PSD.")
        }
        # Redraw only the matrices that failed, until every matrix is PSD.
        while(any(is_not_spd)) {
            n_not_spd <- sum(is_not_spd)
            out[is_not_spd] <- replicate(n_not_spd, gen_unif_matrix(n_cors, k), simplify = FALSE)
            is_not_spd <- !sapply(out, matrixcalc::is.positive.semi.definite)
        }
    }
    
    # Return the unique lower-triangular correlations as an N x k(k-1)/2 matrix.
    if(as_samples) {
        out <- t(sapply(out, function(x){x[lower.tri(x)]}))
    }
    return(out)
}

cors <- rlkj_naive(50000, 4)
cors_better <- rethinking::rlkjcorr(50000, 4, 1)

png("psdResampling_vs_lkj.png")
hist(cors[,1], col = rgb(1,0,0,.5), xlab = expression(rho[12]), freq = FALSE, main = "Marginal")
hist(cors_better[,2,1], add = TRUE, col = rgb(0,0,1,.5), freq = FALSE)
dev.off()

png("corTrans_vs_beta.png")
cor12_trans <- (cors[,1] / 2) + .5
hist(cor12_trans, freq = FALSE, xlab = expression(tilde(rho)[12]), main = "Marginal (transformed)")
curve(dbeta(x, 4/2, 4/2), 0, 1, add = TRUE)
dev.off()

Matrix-F distribution (for those who are curious)

$E(V) = \frac{\nu}{\delta - 2} B$, defined for $\delta > 2$

$\nu$ and $\delta$ control the shape, and $B$ scales it. For correlations, you can sample from the matrix-F and divide out the diagonal to produce a correlation matrix. If $\nu = K$, where $K$ is the dimension of the covariance matrix, and $\delta = 3$ (the smallest integer for which the mean is defined, since that requires $\delta > 2$), then you get a fairly uniform prior on the correlations.
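For reference, the density that matrix_f_lpdf below evaluates (with $B = I_K$) is, as I read the code, the following, where $\Gamma_K$ is the multivariate gamma function (Stan's lmgamma is its log):

$$
p(V \mid \nu, \delta, B) = \frac{\Gamma_K\!\left(\frac{\nu + \delta + K - 1}{2}\right)}{\Gamma_K\!\left(\frac{\nu}{2}\right)\Gamma_K\!\left(\frac{\delta + K - 1}{2}\right)} \, |B|^{-\nu/2} \, |V|^{(\nu - K - 1)/2} \, \left|I_K + V B^{-1}\right|^{-(\nu + \delta + K - 1)/2}.
$$

The "fast" variants drop the leading constant, which is fine when $\nu$, $\delta$, and $B$ are fixed. For $K = 1$ and $B = 1$ this reduces to the beta-prime kernel in f_lpdf, whose mean $\nu/(\delta - 2)$ matches $E(V)$ above.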

functions {
  // Matrix-F log density of a covariance matrix `cov` with scale B = identity,
  // including the normalizing constant (lmgamma is the log multivariate gamma).
  real matrix_f_lpdf(matrix cov, real nu, real delta){
    int k = cols(cov);
    return(lmgamma(k, (nu + delta + k - 1)/2) - (lmgamma(k, nu/2) + lmgamma(k, (delta + k - 1)/2)) + log_determinant(cov)*((nu -k - 1)/2) - (nu + delta + k - 1)/2 * log_determinant(cov + diag_matrix(rep_vector(1, k))));
  }
  
  // Same density (B = identity) without the normalizing constant, so nu and delta
  // must be fixed; log determinants are computed from Cholesky factors.
  real matrix_f_fast_lpdf(matrix cov, real nu, real delta){
    int k = cols(cov);
    real log_det_cov = 2*sum(log(diagonal(cholesky_decompose(cov))));
    real I_Sig_log_det = 2*sum(log(diagonal(cholesky_decompose(diag_matrix(rep_vector(1,k)) + cov))));
    return(log_det_cov*(nu - k - 1)/2 - I_Sig_log_det*(nu + delta + k - 1)/2);
    // return(log_determinant(cov)*(nu - k - 1)/2 + log_determinant(cov + diag_matrix(rep_vector(1,k)))*-(nu + delta + k - 1)/2);
  }
  
  // Unnormalized univariate F (beta-prime) kernel: the k = 1 case of the above.
  real f_lpdf(real v, real nu, real delta){
    return((nu/2 - 1)*log(v) - (nu + delta)/2*log(1 + v));
  }
  
  // Unnormalized matrix-F (B = identity) evaluated on the Cholesky factor L of the
  // covariance. Adds the log Jacobian of cov = L * L':
  // k*log(2) + sum_i (k - i + 1) * log(L[i, i]).
  real matrix_f_fast_cholesky_lpdf(matrix L, real nu, real delta){
    int k = cols(L);
    vector[k] L_diag = diagonal(L);
    real log_det_cov = 2*sum(log(L_diag));
    real log_jac = k*log(2);
    real I_Sig_log_det;
    for(i in 1:k){
      log_jac += (k - i + 1)*log(L_diag[i]);
    }
    // I_Sig_log_det = log_determinant(diag_matrix(rep_vector(1,k)) + multiply_lower_tri_self_transpose(L));
    I_Sig_log_det = 2*sum(log(diagonal(cholesky_decompose(diag_matrix(rep_vector(1,k)) + multiply_lower_tri_self_transpose(L)))));
    return(log_jac + log_det_cov*((nu - k - 1)/2) - ((nu + delta + k - 1)/2)*I_Sig_log_det);
  }
  
  // As above, but with scale matrix B * identity (B a positive scalar). The
  // |B|^(-nu/2) normalizing term is omitted, so B should be fixed, not estimated.
  real matrix_f_B_fast_cholesky_lpdf(matrix L, real nu, real delta,real B){
    int k = cols(L);
    vector[k] L_diag = diagonal(L);
    real log_det_cov = 2*sum(log(L_diag));
    real log_jac = k*log(2);
    matrix[k,k] BmatInv = diag_matrix(rep_vector(1/B,k));
    real I_Sig_log_det;
    for(i in 1:k){
      log_jac += (k - i + 1)*log(L_diag[i]);
    }
    I_Sig_log_det = 2*sum(log(diagonal(cholesky_decompose(diag_matrix(rep_vector(1,k)) + multiply_lower_tri_self_transpose(L)*BmatInv))));
    return(log_jac + log_det_cov*((nu - k - 1)/2) - ((nu + delta + k - 1)/2)*I_Sig_log_det);
  }
}

Footnotes

[^1]: This took a long time. When you construct a symmetric matrix with uniformly distributed elements, it is surprisingly unlikely to be PSD, even when the matrix is a measly 4×4. For reference, of the 50,000 initial matrices, 81.58% were not PSD.
