SAMPLING-BASED INFERENCE FOR LARGE LINEAR MODELS, WITH APPLICATION TO LINEARISED LAPLACE

Abstract

Large-scale linear models are ubiquitous throughout machine learning, with contemporary application as surrogate models for neural network uncertainty quantification; that is, the linearised Laplace method. Alas, the computational cost associated with Bayesian linear models constrains this method's application to small networks, small output spaces and small datasets. We address this limitation by introducing a scalable sample-based Bayesian inference method for conjugate Gaussian multioutput linear models, together with a matching method for hyperparameter (regularisation strength) selection. Furthermore, we use a classic feature normalisation method, the g-prior, to resolve a previously highlighted pathology of the linearised Laplace method. Together, these contributions allow us to perform linearised neural network inference with ResNet-18 on CIFAR100 (11M parameters, 100 output dimensions × 50k datapoints) and with a U-Net on a high-resolution tomographic reconstruction task (2M parameters, 251k output dimensions).

1. INTRODUCTION

The linearised Laplace method, originally introduced by MacKay (1992), has received renewed interest in the context of uncertainty quantification for modern neural networks (NNs) (Khan et al., 2019; Immer et al., 2021b; Daxberger et al., 2021a). The method constructs a surrogate Gaussian linear model for the NN predictions, and uses the error bars of that linear model as estimates of the NN's uncertainty. However, the resulting linear model is very large: its design matrix has as many rows as datapoints times output classes, and as many columns as network parameters. Thus, both the primal (weight space) and dual (observation space) formulations of the linear model are intractable, restricting the method to small-network or small-data settings. Moreover, the method is sensitive to the choice of regularisation strength for the linear model (Immer et al., 2021a; Antorán et al., 2022c).

Motivated by linearised Laplace, we study inference and hyperparameter selection in large linear models. To scale both, we introduce a sample-based Expectation Maximisation (EM) algorithm. It interleaves E-steps, where we infer the model's posterior distribution over parameters given some choice of hyperparameters, and M-steps, where the hyperparameters are improved given the current posterior. Our contributions here are two-fold:

1. We enable posterior sampling for large-scale conjugate Gaussian-linear models with a novel sample-then-optimize objective, which we use to approximate the E-step.
2. We introduce a method for hyperparameter selection that requires access only to posterior samples, and not to the full posterior distribution. This forms our M-step.

Combined, these allow us to perform inference and hyperparameter selection by solving a series of quadratic optimisation problems with iterative methods, thus avoiding a cost that is explicitly cubic in any of the problem's dimensions.
Our method readily extends to non-conjugate settings, such as classification problems, through the use of the Laplace approximation. In the context of linearised NNs, our approach also differs from previous work in that it avoids instantiating the full NN Jacobian matrix, an operation requiring as many backward passes as there are output dimensions. We demonstrate the strength of our inference technique in the context of the linearised Laplace procedure for image classification on CIFAR100 (100 classes × 50k datapoints) using an 11M parameter ResNet-18, and on a high-resolution (251k pixel) tomographic reconstruction (regression) task with a 2M parameter U-Net. In tackling these, we encounter a pathology in the M-step of the procedure, first highlighted by Antorán et al. (2022c): the standard objective therein is ill-defined when the NN contains normalisation layers. Rather than adopting the solution proposed by Antorán et al. (2022c), which introduces more hyperparameters, we show that a standard feature-normalisation method, the g-prior (Zellner, 1986; Minka, 2000), resolves this pathology. For the tomographic reconstruction task, the regression problem requires a dual formulation of our E-step; interestingly, we show that this is equivalent to an optimisation viewpoint on Matheron's rule (Journel & Huijbregts, 1978; Wilson et al., 2020), a connection we believe to be novel.
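To make the connection to Matheron's rule concrete, the following toy sketch draws an exact posterior sample for a conjugate Gaussian linear model by transforming a joint prior sample, using the dual ($nm \times nm$) matrix. The names ($\Phi$, $A$, $B$, $Y$) follow the notation of Section 2; the problem sizes are illustrative assumptions, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy conjugate Gaussian regression: Y = Phi @ theta + noise.
n_out, d = 50, 10                      # stacked output dim (nm) and parameter dim (d')
Phi = rng.standard_normal((n_out, d))  # embedded design matrix
A = 2.0 * np.eye(d)                    # prior precision, here A = alpha * I
B = 4.0 * np.eye(n_out)                # noise precision (block diagonal in general)
Y = rng.standard_normal(n_out)

# Matheron's rule: a prior draw, corrected towards the data, is a posterior draw.
theta0 = rng.multivariate_normal(np.zeros(d), np.linalg.inv(A))   # theta0 ~ N(0, A^-1)
eta = rng.multivariate_normal(np.zeros(n_out), np.linalg.inv(B))  # eta ~ N(0, B^-1)
S = Phi @ np.linalg.inv(A) @ Phi.T + np.linalg.inv(B)             # dual (nm x nm) matrix
theta_s = theta0 + np.linalg.inv(A) @ Phi.T @ np.linalg.solve(S, Y - Phi @ theta0 - eta)

# By the Woodbury identity, the deterministic part of this map recovers the
# posterior mean theta_hat = H^-1 Phi^T B Y, with H = Phi^T B Phi + A.
H = Phi.T @ B @ Phi + A
theta_hat = np.linalg.solve(H, Phi.T @ B @ Y)
```

The key point is that the linear solve involves only the $nm \times nm$ matrix $S$, which is why a dual formulation of the E-step is natural when $nm \ll d'$.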

2. CONJUGATE GAUSSIAN REGRESSION AND THE EM ALGORITHM

We study Bayesian conjugate Gaussian linear regression with multidimensional outputs, where we observe inputs x 1 , . . . , x n ∈ R d and corresponding outputs y 1 , . . . , y n ∈ R m . We model these as y i = ϕ(x i )θ + η i , where ϕ : R d → R m × R d ′ is a known embedding function. The parameters θ are assumed sampled from N (0, A -1 ) with an unknown precision matrix A ∈ R d ′ ×d ′ , and for each i ≤ n, η i ∼ N (0, B -1 i ) are additive noise vectors with precision matrices B i ∈ R m×m relating the m output dimensions. Our goal is to infer the posterior distribution for the parameters θ given our observations, under the setting of A of the form A = αI for α > 0 most likely to have generated the observed data. For this, we use the iterative procedure of Mackay (1992), which alternates computing the posterior for θ, denoted Π, for a given choice of A, and updating A, until the pair (A, Π) converge to a locally optimal setting. This corresponds to an EM algorithm (Bishop, 2006) . Henceforth, we will use the following stacked notation: we write Y ∈ R nm for the concatenation of y 1 , . . . , y n ; B ∈ R nm×nm for a block diagonal matrix with blocks B 1 , . . . , B n and Φ = [ϕ(X 1 ) T ; . . . ; ϕ(X n ) T ] T ∈ R nm×d ′ for the embedded design matrix. We write M := Φ T BΦ. Additionally, for a vector v and a PSD matrix G of compatible dimensions, ∥v∥ 2 G = v T Gv. With that, the aforementioned EM algorithm starts with some initial A ∈ R d ′ ×d ′ , and iterates: • (E step) Given A, the posterior for θ, denoted Π, is computed exactly as Π = N ( θ, H -1 ) where H = M + A and θ = H -1 Φ T BY. (2) • (M step) We lower bound the log-probability density of the observed data, i.e. the evidence, for the model with posterior Π and precision A ′ as (derivation in Appendix B.2) log p(Y ; A ′ ) ≥ -1 2 ∥ θ∥ 2 A ′ -1 2 log det(I + A ′-1 M ) + C =: M(A ′ ), for C independent of A ′ . We choose an A that improves this lower bound.

Limited scalability

The above inference and hyperparameter selection procedure for $\Pi$ and $A$ is intractable when both $d'$ and $nm$ are large. The E-step requires inverting a $d' \times d'$ matrix, and the M-step requires evaluating its log-determinant; both are cubic operations in $d'$. These may be rewritten to instead yield a cubic dependence on $nm$ (as in Section 3.3), but under our assumptions, that too is not computationally tractable. Instead, we now pursue a stochastic approximation to this EM procedure.
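The primal/dual trade-off can be checked numerically on a toy problem: by the Woodbury and Sylvester determinant identities, both the posterior mean and the M-step log-determinant admit equivalent forms at cubic cost in either $d'$ or $nm$. This sketch uses the paper's notation with illustrative sizes.

```python
import numpy as np

rng = np.random.default_rng(2)

n_out, d = 30, 8
Phi = rng.standard_normal((n_out, d))
A = 0.5 * np.eye(d)                     # prior precision
B = 2.0 * np.eye(n_out)                 # noise precision
Y = rng.standard_normal(n_out)
M_mat = Phi.T @ B @ Phi

# Primal (d' x d') forms: invert H = M + A.
H = M_mat + A
mean_primal = np.linalg.solve(H, Phi.T @ B @ Y)
logdet_primal = np.linalg.slogdet(np.eye(d) + np.linalg.inv(A) @ M_mat)[1]

# Dual (nm x nm) forms: solve against S = Phi A^-1 Phi^T + B^-1 instead.
S = Phi @ np.linalg.inv(A) @ Phi.T + np.linalg.inv(B)
mean_dual = np.linalg.inv(A) @ Phi.T @ np.linalg.solve(S, Y)
logdet_dual = np.linalg.slogdet(np.eye(n_out) + B @ Phi @ np.linalg.inv(A) @ Phi.T)[1]
```

Neither form helps when $d'$ and $nm$ are both in the millions, which is what motivates the stochastic approximation of the next section.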

3. EVIDENCE MAXIMISATION USING STOCHASTIC APPROXIMATION

We now present our main contribution, a stochastic approximation (Nielsen, 2000) to the iterative algorithm presented in the previous section. Our M-step requires access only to samples from $\Pi$; for the E-step, we introduce a method to draw approximate posterior samples through stochastic optimisation.
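The principle behind optimisation-based posterior sampling can be illustrated with the classic sample-then-optimize construction (Papandreou & Yuille, 2010): the minimiser of a quadratic objective whose targets and regulariser anchor are randomly perturbed is an exact posterior sample. The paper's own objective is a novel variant; this sketch only shows the underlying idea, with illustrative sizes, and uses the closed-form minimiser where at scale one would run an iterative solver.

```python
import numpy as np

rng = np.random.default_rng(3)

n_out, d = 40, 6
Phi = rng.standard_normal((n_out, d))
A = np.eye(d)                           # prior precision
B = 3.0 * np.eye(n_out)                 # noise precision
Y = rng.standard_normal(n_out)
H = Phi.T @ B @ Phi + A

def posterior_sample():
    # Perturb the targets with a noise draw and the anchor with a prior draw...
    eps = rng.multivariate_normal(np.zeros(n_out), np.linalg.inv(B))
    theta0 = rng.multivariate_normal(np.zeros(d), np.linalg.inv(A))
    # ...then minimise ||Y + eps - Phi theta||_B^2 + ||theta - theta0||_A^2.
    # The minimiser of this quadratic is available in closed form; iterating
    # a first-order optimiser to convergence would give the same point.
    return np.linalg.solve(H, Phi.T @ B @ (Y + eps) + A @ theta0)

samples = np.stack([posterior_sample() for _ in range(2000)])
```

Averaging over many such minimisers recovers the posterior mean $\hat\theta$ and covariance $H^{-1}$, so each E-step reduces to a batch of independent quadratic optimisation problems.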



* Equal contribution. Correspondence to ja666@cam.ac.uk and sp2058@cam.ac.uk .

