Neural Networks in Stan: Or how I was utterly surprised that it worked at all.

Feed-forward neural networks are a staple in machine learning. The basic feed-forward NN (which I’ll just call a NN from here on out) is a relatively simple idea.

The tale is as old as (statistical) time: You have a set of "features" (covariates) and you want to predict an outcome. This outcome might be continuous, or it might be categorical (for classification tasks). However, you may not care whatsoever about interpreting the model's parameters; you just want good predictions (or "inferences", depending on your background). Stated differently, you want to estimate a function that takes some input and spits out a reasonably accurate output.

The set of predictive models is enormous. Given my Bayesian background, I tend to prefer principled approaches, such as Gaussian process models or Dirichlet process models, for function-finding. Look no further than my omegad package for evidence of that. There, I use a GP to model the relationship between exogenous variables, latent variables, and a latent measurement error term to get non-parametric, non-linear predictions of reliability! In a forthcoming package called mires, I use Dirichlet processes to estimate some terrible posterior density functions from MCMC output for Bayes factor computation.

But I digress, you are all here for neural nets in Stan.

NN in a nutshell

A neural network takes features, and projects them into a latent space. It does this via a multivariate linear model. Then those linear variates are fed through an activation function, which maps the linear variate to a constrained space. Some examples are the inverse logit function, tanh, and rectified linear units (ReLU).
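
For concreteness, each activation is just an elementwise map from the reals onto a constrained range; in R, the first and last could be written as something like the following (tanh is already built into R):

inv_logit <- function(z) 1 / (1 + exp(-z)) # squashes to (0, 1)
relu      <- function(z) pmax(0, z)        # zeroes out negative values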

You then take those activation outputs, and project them again into a second latent space (a second hidden layer). You again run these through activation functions. Rinse and repeat through the hidden layers until you arrive at the output stage.

In the output stage, you project the last hidden layer’s values into your prediction, again using a (possibly multivariate) linear model. If you are doing classification, then this last stage involves a multinomial logistic regression. If you are doing regression, then this last stage is a linear model prediction.

To make it more concrete, let’s look at formulas. I am going to describe it as a multivariate logistic model; however, in the machine learning world, they often work with all these things transposed. I am also going to fit a rectangular NN, such that all hidden layers have the same number of neurons (or nodes, or linear outputs).

Let $H_{it}$ be the $t$-th hidden layer for observation $i$, and let $X_i$ be the vector of inputs. $Y_i$ is the output. $W_t$ is the linear coefficient matrix, or "weights", for the $t$-th hidden layer. $b_t$ is the row vector of intercepts, or "biases", for the $t$-th hidden layer. $n_H$ is the number of neurons per hidden layer.

Step 1: Project $X_i$ into latent space, and call activation function $f$ on each element: $$ H_{i1} = f(X_i W_1 + b_1) $$ $H_{i1}$ now contains the $n_H$ activations. I use the inverse logit function for this, so that each element in $H_{i1}$ is squashed to be between 0 and 1.

Step 2: Project the neural outputs onto a second layer: $$ H_{i2} = f(H_{i1} W_2 + b_2) $$

Step 3 through T: Continue this process until the final layer. $$ H_{it} = f(H_{i,t-1} W_t + b_t) $$

Step T+1: Generate output predictions.

How this step is implemented depends on the task. For classification, you are predicting multinomial logits (usually with reference to a base class). $$ \begin{align*} \hat y_\text{logit} &= H_T W_\text{out} + b_\text{out} \\ \hat y_\text{prob} &= \text{softmax}(\hat y_\text{logit}) \\ y &\sim Categorical(\pi = \hat y_\text{prob}) \tag{If using Bayes} \end{align*} $$ For regression, it is just a linear model: $$ \begin{align*} \hat y &= H_T W_\text{out} + b_\text{out} \\ y &\sim \mathcal{N}(\hat y, \sigma) \tag{If using Bayes} \end{align*} $$

That’s it. That’s the model. Then for predictions, you just feed in new $X_i$ values and carry through the matrix multiplication and activations. This is, quite literally, a multivariate logistic regression model in the hidden layers, then a multinomial regression for classification, or a linear model for regression.
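
In fact, the whole forward pass fits in a few lines of R. This is only an illustrative sketch (the names nn_forward, W, b, W_out, and b_out are mine, not taken from the Stan code below), but it is the same computation the Stan models implement:

# Forward pass for a rectangular NN (illustrative sketch).
# X: N x P matrix; W: list of weight matrices; b: list of bias vectors;
# W_out, b_out: output-layer weights and bias(es).
inv_logit <- function(z) 1 / (1 + exp(-z))

nn_forward <- function(X, W, b, W_out, b_out) {
  H <- X
  for (t in seq_along(W)) {
    # H_t = f(H_{t-1} W_t + b_t); the bias is added to every row
    H <- inv_logit(sweep(H %*% W[[t]], 2, b[[t]], "+"))
  }
  # Final linear projection; for classification, feed each row through softmax
  sweep(H %*% W_out, 2, b_out, "+")
}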

There are all sorts of tweaks you can include if you choose. Perhaps you don’t want a normal likelihood (assuming you’re using a probabilistic loss at all). Perhaps you want to model both the $\mu$ and $\sigma$ parameters with a NN (given my work in location scale models, this is an interesting idea to me). You could also embed this in larger models to jointly include random effects and latent variables. And this doesn’t even begin to get into, e.g., dropout (i.e., Bayesian spike and slab) or regularization.

Bayesian models need priors (alternatively stated, your model deserves priors). I was lazy, so I just threw standard normal priors on all weights and biases. In effect, this is similar to L2 regularization. For sampling, I also assume the biases within each layer are ordered, because neural nets are notoriously overparameterized and underidentified.
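
To see the L2 connection: with independent standard normal priors on the weights, the negative log prior is $$ -\log p(W) = \frac{1}{2} \sum_j w_j^2 + \text{const}, $$ which is exactly a ridge (L2) penalty on the weights.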

And this is why I was surprised it worked at all in Stan. Bayesian NNs are extremely hard, because the posterior is intractably aliased. All hidden layers contain multiple neurons, and it does not much matter which neuron does what. In essence, it’s a label-switching problem as seen in mixture modeling, but taken to a new extreme.
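
To see the aliasing concretely, here is a small illustrative example (all names are hypothetical) showing that relabeling the hidden neurons of a one-layer net, along with the matching weights and biases, leaves the predictions untouched:

# Relabeling the hidden neurons does not change predictions.
set.seed(1)
inv_logit <- function(z) 1 / (1 + exp(-z))
N <- 5; P <- 3; n_H <- 4
X     <- matrix(rnorm(N * P), N, P)
W1    <- matrix(rnorm(P * n_H), P, n_H) # data -> hidden weights
b1    <- rnorm(n_H)                     # hidden biases
w_out <- rnorm(n_H)                     # hidden -> output weights

perm <- sample(n_H) # an arbitrary relabeling of the neurons

yhat1 <- inv_logit(sweep(X %*% W1, 2, b1, "+")) %*% w_out
yhat2 <- inv_logit(sweep(X %*% W1[, perm], 2, b1[perm], "+")) %*% w_out[perm]

all.equal(yhat1, yhat2) # TRUE: the likelihood cannot tell these two apart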

When estimating NNs, the MCMC chains should either oscillate between equivalent (aliased) solutions within each chain, or each chain should find a different mode. I am sure that my implementation would simply not sample well for even moderately large NNs. Stan is just not optimized for this problem; it would take until the heat death of the universe to fit a truly large, deep NN, and you’d have long exhausted the memory needed for it. But in my small regime of, say, two layers each with 50 nodes, it sampled fine! More specifically, the posterior predictive distribution sampled fine. Sometimes even the NN parameters themselves sampled fine, much to my utter surprise.

The classification model

functions {
  vector[] nn_predict(matrix x, matrix d_t_h, matrix[] h_t_h, matrix h_t_d, row_vector[] hidden_bias, row_vector y_bias) {
    int N = rows(x);
    int n_H = cols(d_t_h);
    int H = size(hidden_bias);
    int num_labels = cols(y_bias) + 1;
    matrix[N, n_H] hidden_layers[H];
    vector[num_labels] output_layer_logit[N];
    vector[N] ones = rep_vector(1., N);

    hidden_layers[1] = inv_logit(x * d_t_h + ones * hidden_bias[1]);
    for(h in 2:H) {
      hidden_layers[h] = inv_logit(hidden_layers[h-1] * h_t_h[h - 1] + ones * hidden_bias[h]);
    }
    for(n in 1:N) {
      output_layer_logit[n, 1] = 0.0;
      output_layer_logit[n, 2:num_labels] = (hidden_layers[H, n] * h_t_d + y_bias)';
    }
    return(output_layer_logit);
  }
}

data {
  int N; // Number of training samples
  int P; // Number of predictors (features)
  matrix[N, P] x; // Feature data
  int labels[N]; // Outcome labels
  int H; // Number of hidden layers
  int n_H; // Number of nodes per layer (All get the same)

  int N_test; // Number of test samples
  matrix[N_test, P] x_test; // Test predictors
}

transformed data {
  int num_labels = max(labels); // How many labels are there
}

parameters {
  matrix[P, n_H] data_to_hidden_weights; // Data -> Hidden 1
  matrix[n_H, n_H] hidden_to_hidden_weights[H - 1]; // Hidden[t] -> Hidden[t+1]
  matrix[n_H, num_labels - 1] hidden_to_data_weights; // Hidden[T] -> Labels. Base class gets 0.
  // ordered[n_H] hidden_bias[H]; // Use ordered if using NUTS
  row_vector[n_H] hidden_bias[H]; // Hidden layer biases
  row_vector[num_labels - 1] labels_bias; // Labels biases. Base class gets 0.
}

transformed parameters {
  vector[num_labels] output_layer_logit[N]; // Predicted output layer logits

  output_layer_logit = nn_predict(x,
                                  data_to_hidden_weights,
                                  hidden_to_hidden_weights,
                                  hidden_to_data_weights,
                                  hidden_bias,
                                  labels_bias);

}

model {
  // Priors
  to_vector(data_to_hidden_weights) ~ std_normal();

  for(h in 1:(H-1)) {
    to_vector(hidden_to_hidden_weights[h]) ~ std_normal();
  }

  to_vector(hidden_to_data_weights) ~ std_normal();

  for(h in 1:H) {
    to_vector(hidden_bias[h]) ~ std_normal();
  }
  labels_bias ~ std_normal();

  for(n in 1:N) { // Likelihood
    labels[n] ~ categorical_logit(output_layer_logit[n]);
  }
}

generated quantities {
  vector[num_labels] output_layer_logit_test[N_test] = nn_predict(x_test,
							   data_to_hidden_weights,
							   hidden_to_hidden_weights,
							   hidden_to_data_weights,
							   hidden_bias,
							   labels_bias);
  matrix[N_test, num_labels] output_test;
  for(n in 1:N_test) {
    output_test[n] = softmax(output_layer_logit_test[n])';
  }
}

The regression model

functions {
  vector nn_predict(matrix x, matrix d_t_h, matrix[] h_t_h, vector h_t_d, row_vector[] hidden_bias, real y_bias) {
    int N = rows(x);
    int n_H = cols(d_t_h);
    int H = size(hidden_bias);
    matrix[N, n_H] hidden_layers[H];
    vector[N] output_layer;
    vector[N] ones = rep_vector(1., N);

    hidden_layers[1] = inv_logit(x * d_t_h + ones * hidden_bias[1]);
    for(h in 2:H) {
      hidden_layers[h] = inv_logit(hidden_layers[h-1] * h_t_h[h - 1] + ones * hidden_bias[h]);
    }
    output_layer = hidden_layers[H] * h_t_d + y_bias;
    return(output_layer);
  }		
}

data {
  int N; // Number of training samples
  int P; // Number of predictors (features)
  matrix[N, P] x; // Feature data
  vector[N] y; // Outcome
  int H; // Number of hidden layers
  int n_H; // Number of nodes per layer (All get the same)

  int N_test; // Number of test samples
  matrix[N_test, P] x_test; // Test predictors
}

transformed data {
}

parameters {
  matrix[P, n_H] data_to_hidden_weights; // Data -> Hidden 1
  matrix[n_H, n_H] hidden_to_hidden_weights[H - 1]; // Hidden[t] -> Hidden[t+1]
  vector[n_H] hidden_to_data_weights;
  // ordered[n_H] hidden_bias[H]; // Use ordered if using NUTS
  row_vector[n_H] hidden_bias[H]; // Hidden layer biases
  real y_bias; // Bias. 
  real<lower=0> sigma;
}

transformed parameters {
  vector[N] output_layer;

  output_layer = nn_predict(x,
                            data_to_hidden_weights,
                            hidden_to_hidden_weights,
                            hidden_to_data_weights,
                            hidden_bias,
                            y_bias);

}

model {
  // Priors
  to_vector(data_to_hidden_weights) ~ std_normal();

  for(h in 1:(H-1)) {
    to_vector(hidden_to_hidden_weights[h]) ~ std_normal();
  }

  to_vector(hidden_to_data_weights) ~ std_normal();

  for(h in 1:H) {
    to_vector(hidden_bias[h]) ~ std_normal();
  }
  y_bias ~ std_normal();

  sigma ~ std_normal();


  y ~ normal(output_layer, sigma);
}

generated quantities {
  vector[N_test] output_test = nn_predict(x_test,
				          data_to_hidden_weights,
				          hidden_to_hidden_weights,
				          hidden_to_data_weights,
				          hidden_bias,
				          y_bias);
  vector[N_test] output_test_rng;
  for(n in 1:N_test) {
    output_test_rng[n] = normal_rng(output_test[n], sigma);
  }
}

Fitting the categorical model

As is typical, you need to compile the Stan model. I also include a (quickly hacked together, messy) function to fit the model and give some sane output (Stan’s optimizing output is… less than stellar).

library(rstan)
library(magrittr)
sm <- stan_model("nn_cat.stan")

fit_nn_cat <- function(x_train, y_train, x_test, y_test, H, n_H, method = "optimizing", ...) {
    stan_data <- list(
        N = nrow(x_train),
        P = ncol(x_train),
        x = x_train,
        labels = y_train,
        H = H,
        n_H = n_H,
        N_test = length(y_test),
        x_test = x_test # Test features are required by the model's generated quantities
    )
    if(method == "optimizing") {
        optOut <- optimizing(sm, data = stan_data)
        test_char <- paste0("output_test[",1:length(y_test), ",",rep(1:max(y_train), each = length(y_test)),"]") 
        y_test_pred <- matrix(optOut$par[test_char], stan_data$N_test, max(y_train))
        y_test_cat <- apply(y_test_pred, 1, which.max)
        out <- list(y_test_pred = y_test_pred,
                    y_test_cat = y_test_cat,
                    conf = table(y_test_cat, y_test),
                    fit = optOut)
        return(out)
    } else if(method == "sampling") {
        out <- sampling(sm, data = stan_data, pars = "output_test", ...)
        return(out)
    } 
}

Iris

Now let’s get the Iris data, split into a training and testing sample, and fit.

data(iris)
x <- iris[,1:4]
y <- as.numeric(as.factor((iris[,"Species"])))

N_test <- 50
test_indices <- sample(1:nrow(x), N_test)
x_train <- x[-test_indices,]
y_train <- y[-test_indices]
x_test <- x[test_indices,]
y_test <- y[test_indices]

I just used two hidden layers, each with 50 neurons. Fewer neurons seem to be fine too. Adding more layers seems to break prediction considerably under optimization, and sampling takes ages due to the rather insane increase in parameters that new layers bring.

fit_opt <- fit_nn_cat(x_train, y_train, x_test, y_test, 2, 50, method = "optimizing")
fit_opt$conf

Let’s see how we did:

> fit_opt$conf
          y_test
y_test_cat  1  2  3
         1 16  0  0
         2  0 13  0
         3  0  0 21

Nice! Our held-out sample was classified perfectly.

Let’s try with NUTS/MCMC:

fit_nuts <- fit_nn_cat(x_train, y_train, x_test, y_test, 2, 50, method = "sampling", cores = 4, iter = 1000)

And… let’s see how we did:

post_means <- summary(fit_nuts, pars = "output_test")$summary[, "mean"]
# Index by parameter name so the ordering of the summary rows doesn't matter
test_char <- paste0("output_test[", 1:N_test, ",", rep(1:3, each = N_test), "]")
cat_nuts <- post_means[test_char] %>%
                            matrix(N_test, 3) %>%
                            apply(1, which.max)
table(cat_nuts, y_test)

That terribly named variable contains the classifications of each test output, based on the posterior expectation (not the mode) of its predicted class probabilities.

After about 30 seconds, we get:

>         y_test
cat_nuts  1  2  3
       1 16  0  0
       2  0 13  0
       3  0  0 21

Which is great!

The Iris dataset is fairly easy to predict, because the labels are well-separated in the feature space. Let’s give another classic a go.

MNIST

The MNIST dataset contains handwritten digits from 0 to 9. The dataset I import has all the pixel data, flattened into rows, for each image. It then has the true label as the last column. Altogether, this dataset has 70,000 images, and 784 features (one for each pixel in a 28×28 px image).

Because Stan is, and I cannot emphasize this enough, not built for neural nets, I cannot estimate a large NN, nor can I feed it a lot of data. When you have ‘bigger’ data and lots of matrix multiplication, you really need GPU acceleration and multithreaded computations. Stan now has some GPU acceleration, and some support for multithreading; rstan (2.21) does not currently have GPU acceleration, and its support for multithreading depends on map_rect, which would be awful to implement for this model.

Therefore, I’m only going to optimize, not sample, and with only 5,000 training and 1,000 test samples. Moreover, I’ll only use 2 hidden layers, each with 30 neurons.

library(snedata) # https://github.com/jlmelville/snedata

mnist <- download_mnist()

x <- mnist[,c(-ncol(mnist))]
y <- as.numeric(mnist[,"Label"])

N_train <- 5000
N_test <- 1000

x_train <- head(x, N_train)
y_train <- head(y, N_train)
x_test <- tail(x, N_test)
y_test <- tail(y, N_test)

fit <- fit_nn_cat(x_train, y_train, x_test, y_test, 2, 30)
fit$conf

After waiting for a bit, we get:

# Classification matrix for MNIST (labels 1-10 correspond to digits 0-9)
          y_test
y_test_cat   1   2   3   4   5   6   7   8   9  10
        1   97   0   5   0   0   1   1   0   0   0
        2    0 115   0   0   0   0   0   0   2   0
        3    0   2  85   2   1   3   1   5   0   0
        4    0   0   2  94   0   4   0   0   1   0
        5    1   0   0   0  91   1   2   0   1  10
        6    2   0   1   1   0  72   2   0   5   0
        7    1   0   0   0   0   4  95   0   4   0
        8    0   0   1   1   0   0   0 109   0   5
        9    1   2   4   3   0   0   1   0  81   0
        10   0   0   1   1   0   0   0   1   0  75

This is an accuracy of .914. For only using around 7% of the data, and a relatively small NN, this is decent.
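
For reference, that accuracy is just the share of the confusion matrix’s mass on the diagonal:

> sum(diag(fit$conf)) / sum(fit$conf)
[1] 0.914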

Summary

Stan was used to estimate relatively small neural nets on the Iris and MNIST datasets. Despite Stan not being built for it, it can technically do it. I am not saying you should do it, but as a heavy Stan user, and a moderate neural net user, it was nice to see that Stan didn’t collapse in on itself like a dying star when trying to fit one.

In practice, if you really wanted to use Stan for optimizing NNs, I would suggest doing so at the C++ level, and implementing a gradient descent approach with Stan’s autodiff. The optimizing algorithm used in Stan is a quasi-Newton method (L-BFGS). A great optimizer, no doubt, but not well-suited for extremely high-dimensional models with irregular loss surfaces. Gradient descent approaches, by contrast, are fairly cheap, and probably easier to parallelize. Primarily though, you really need GPU acceleration to make quick work of NNs; here’s to looking forward to more GPU compute from the Stan team.
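
Short of dropping to C++, rstan does expose the model’s log posterior and its autodiff gradient through log_prob() and grad_log_prob() on a stanfit object, so a crude gradient-ascent loop can at least be sketched from R. This is only a sketch with made-up settings (learning rate, iteration count), not a tuned optimizer:

# Sketch: plain gradient ascent on the (unnormalized) log posterior,
# using rstan's autodiff via grad_log_prob(). `sf` can be any stanfit
# for the model created in this session, e.g. from an earlier sampling() call.
sf      <- fit_nuts
n_upars <- rstan::get_num_upars(sf)        # number of unconstrained parameters
theta   <- rnorm(n_upars, 0, 0.1)          # random initialization
lr      <- 1e-3                            # hypothetical learning rate

for (i in 1:2000) {
  g     <- rstan::grad_log_prob(sf, theta) # gradient of the log posterior
  theta <- theta + lr * g                  # take a step uphill
}
rstan::log_prob(sf, theta)                 # log posterior at the final point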

For sampling neural nets, I wouldn’t recommend it. The Stan team wouldn’t recommend it either. It’s just an intractable posterior. On the Iris dataset, I was lucky to have acceptable $\hat R$ values, even for the NN parameters. The ordered bias constraint (and priors) should not be enough to identify the NN parameters. Nevertheless, I did shockingly obtain convergence across the whole set of parameters.

Fitting a NN regression

Let’s first generate data from a somewhat ugly function. We’ll only work with a univariate function for now, because visualization is simple.

N <- 1000
x <- rnorm(1000)
sigma <- .5

y_normal_f <- function(x) {2 * x^2 - 2*x + 3*cos(3*x) - sqrt(abs(x) / 10) + .2 * (x - 3)^2 - 10}
y <- y_normal_f(x) + rnorm(N,0, sigma)

When plotted, the data look like this:
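
The plot was produced with something like:

plot(x, y, xlab = "X", ylab = "Y", col = "gray50")
curve(y_normal_f, col = "purple", lwd = 2, add = TRUE)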

No doubt that a GP could make quick work of that. But today is neural net day, so let’s fit a NN regression.

Let’s split our data again into train and test sets. Then we’ll optimize a 2-layer, 10-neuron NN regression model.

x <- matrix(x, N, 1)

N_test <- 200 
test_indices <- sample(1:nrow(x), N_test)
x_train <- x[-test_indices,, drop = FALSE]
y_train <- y[-test_indices]
x_test <- x[test_indices,, drop = FALSE]
y_test <- y[test_indices]

fit_opt <- fit_nn_reg(x_train, y_train, x_test, y_test, 2, 10, method = "optimize")

The mean squared error was .275, and the predictions look spot on.

This plot contains the true function (in purple), and the test points. The blue line represents the (modal) predicted values, and the dashed lines are just the prediction $\pm 1.96 \sigma$. Therefore, the dashed lines are modal prediction intervals.

plot(x_test, y_test, xlab = "X", ylab = "Y")
curve(y_normal_f, col = "purple", add = TRUE)
lines(sort(x_test), fit_opt$y_test_pred[order(x_test)], col = "blue")
lines(sort(x_test), fit_opt$y_test_pred[order(x_test)] - fit_opt$sigma*1.96, lty = "dashed")
lines(sort(x_test), fit_opt$y_test_pred[order(x_test)] + fit_opt$sigma*1.96, lty = "dashed")
legend("topright", legend = c("True", "NN"), lty = "solid", col = c("purple", "blue"))

Optimizing the NN therefore did a decent job, and only took around 4 seconds. It also recovered the $\sigma$ parameter well; the true value was .5, the estimated value was .51.

Now let’s use NUTS/MCMC to sample the posterior:

fit_nuts <- fit_nn_reg(x_train, y_train, x_test, y_test, 2, 10, method = "sampling", cores = 4, iter = 300)

Unsurprisingly, sampling takes longer: about 7 minutes, for a total of only 600 posterior samples.

Let’s plot the predictions now. Note that we no longer have modal estimates, but posterior expectations; and by the magic of probabilistic inference, we have full distributions over both the parameters and the predictions. The uncertainty expressed in the predictive plot therefore includes uncertainty in the neural net parameters and in the realizations. The dotted lines represent the 95% predictive interval, which includes uncertainty in the NN parameters, the $\sigma$ parameter, and the realizations. The dashed lines represent the 95% fitted interval, which only includes uncertainty in the function itself (i.e., the NN parameters).
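
For reference, both intervals can be computed from the posterior draws of output_test (fitted values) and output_test_rng (posterior predictive realizations) saved in the generated quantities block, where fit_nuts$fit is the stanfit returned by fit_nn_reg(). A sketch:

# Sketch: 95% fitted and predictive intervals from the posterior draws.
draws <- rstan::extract(fit_nuts$fit, pars = c("output_test", "output_test_rng"))

fitted_int <- apply(draws$output_test,     2, quantile, probs = c(.025, .975))
pred_int   <- apply(draws$output_test_rng, 2, quantile, probs = c(.025, .975))

ord <- order(x_test)
plot(x_test, y_test, xlab = "X", ylab = "Y")
lines(sort(x_test), colMeans(draws$output_test)[ord], col = "blue")
lines(sort(x_test), fitted_int[1, ord], lty = "dashed")
lines(sort(x_test), fitted_int[2, ord], lty = "dashed")
lines(sort(x_test), pred_int[1, ord], lty = "dotted")
lines(sort(x_test), pred_int[2, ord], lty = "dotted")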

Again, NUTS did well here. There is not much uncertainty in the expected value, which is why the predictive intervals look so similar to those of the optimized solution.

The $\hat R$ values for the predicted quantities are all acceptable (i.e., $\approx 1.0$). This is not unexpected, even for unidentified models: in mixtures or other unidentifiable models, the aliased or otherwise unidentified parameters can have terrible convergence diagnostics, yet functions of those parameters can still be identified. E.g., a model with $a = b + c$ may be unidentified because the values of $b$ and $c$ can be traded off without affecting $a$: $10 = 5 + 5 = 4 + 6 = -10 + 20$, and so on. Without prior information, $b$ and $c$ are therefore unidentified. However, the MCMC samples of $a$ can nevertheless be convergent.

That is the case here. Even if the NN parameters are not identified, the resulting predictions remain the same (in a sense, this is why it is unidentified). Therefore, the model converged with respect to the NN predictions.

Summary

I wrote out a quick Bayesian neural network model for both regression and classification. Despite Stan not being built for such models, it performed well! Optimization was fairly quick, and NUTS converged with respect to the predictions.

Would I use Stan for NNs?

No. No I wouldn’t. This was a fun, toy exercise for an afternoon, just to see if it could be done. I’ve had good luck with mxnet and JAX for actually estimating NNs; and these frameworks are general enough that you can still add priors (penalties) on the parameters if you want to (look at the "loss" function as an unnormalized negative log posterior density for why).

What would I use in Stan instead?

Gaussian processes, where data permit. Approximate Gaussian processes (as I did in omegad) are fairly quick. GAMs would also work well in these examples. Stan really shines in generative modeling contexts, to build custom models for processes of interest; it lets you build things like multivariate mixed effects location scale models with latent variables, interactions, non-linearities, etc. It lets you estimate models with structures and assumptions that you simply can’t do otherwise.

But Stan is really not meant for big data, nor for models with inherently bad posteriors like neural nets. But as a testament to the robustness of Stan, it still estimated a neural net, despite not being built for it. And that’s cool.

Edit

I just realized I forgot to include the fit_nn_reg function:

sm_reg <- stan_model("nn_reg.stan")

fit_nn_reg <- function(x_train, y_train, x_test, y_test, H, n_H, method = "optimize", ...) {
    stan_data <- list(
        N = nrow(x_train),
        P = ncol(x_train),
        x = x_train,
        y = y_train,
        H = H,
        n_H = n_H,
        N_test = length(y_test),
        x_test = x_test # Test features are required by the model's generated quantities
    )
    if(method == "optimize") {
        optOut <- optimizing(sm_reg, data = stan_data)
        test_char <- paste0("output_test[", 1:stan_data$N_test, "]")
        y_test_pred <- optOut$par[test_char]
        mse <- mean((y_test_pred - y_test)^2)
        out <- list(y_test_pred = y_test_pred,
                    sigma = optOut$par["sigma"],
                    mse  = mse,
                    fit = optOut)
        return(out)
    } else {
        if(method == "sampling") {
            out <- sampling(sm_reg, data = stan_data, ...)
        } else if (method == "vb") {
            out <- vb(sm_reg, data = stan_data, pars = c("output_test", "sigma", "output_test_rng"), ...)
        }
        y_test_pred <- summary(out, pars = "output_test")$summary
        sigma <- summary(out, pars = "sigma")$summary
        out <- list(y_test_pred = y_test_pred,
                    sigma = sigma,
                    fit = out)
        return(out)
    }
}
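
The function also has a "vb" branch for rstan’s ADVI, which I didn’t use above; a call would look something like:

fit_vb <- fit_nn_reg(x_train, y_train, x_test, y_test, 2, 10, method = "vb", output_samples = 500)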
