SAMPLING-BASED INFERENCE FOR LARGE LINEAR MODELS, WITH APPLICATION TO LINEARISED LAPLACE

Abstract

Large-scale linear models are ubiquitous throughout machine learning, with contemporary application as surrogate models for neural network uncertainty quantification; that is, the linearised Laplace method. Alas, the computational cost associated with Bayesian linear models constrains this method's application to small networks, small output spaces and small datasets. We address this limitation by introducing a scalable sample-based Bayesian inference method for conjugate Gaussian multioutput linear models, together with a matching method for hyperparameter (regularisation strength) selection. Furthermore, we use a classic feature normalisation method, the g-prior, to resolve a previously highlighted pathology of the linearised Laplace method. Together, these contributions allow us to perform linearised neural network inference with ResNet-18 on CIFAR100 (11M parameters, 100 output dimensions × 50k datapoints) and with a U-Net on a high-resolution tomographic reconstruction task (2M parameters, 251k output dimensions).

1. INTRODUCTION

The linearised Laplace method, originally introduced by Mackay (1992), has received renewed interest in the context of uncertainty quantification for modern neural networks (NN) (Khan et al., 2019; Immer et al., 2021b; Daxberger et al., 2021a). The method constructs a surrogate Gaussian linear model for the NN predictions, and uses the error bars of that linear model as estimates of the NN's uncertainty. However, the resulting linear model is very large; the design matrix is sized number of parameters by number of datapoints times number of output classes. Thus, both the primal (weight space) and dual (observation space) formulations of the linear model are intractable. This restricts the method to small network or small data settings. Moreover, the method is sensitive to the choice of regularisation strength for the linear model (Immer et al., 2021a; Antorán et al., 2022c).

Motivated by linearised Laplace, we study inference and hyperparameter selection in large linear models. To scale inference and hyperparameter selection in Gaussian linear regression, we introduce a sample-based Expectation Maximisation (EM) algorithm. It interleaves E-steps, where we infer the model's posterior distribution over parameters given some choice of hyperparameters, and M-steps, where the hyperparameters are improved given the current posterior. Our contributions here are two-fold:

1. We enable posterior sampling for large-scale conjugate Gaussian-linear models with a novel sample-then-optimise objective, which we use to approximate the E-step.
2. We introduce a method for hyperparameter selection that requires only access to posterior samples, and not the full posterior distribution. This forms our M-step.

Combined, these allow us to perform inference and hyperparameter selection by solving a series of quadratic optimisation problems using iterative optimisation, thus avoiding an explicit cubic cost in any of the problem's properties.
Our method readily extends to non-conjugate settings, such as classification problems, through the use of the Laplace approximation. In the context of linearised NNs, our approach also differs from previous work in that it avoids instantiating the full NN Jacobian matrix, an operation requiring as many backward passes as output dimensions in the network.

We demonstrate the strength of our inference technique in the context of the linearised Laplace procedure for image classification on CIFAR100 (100 classes × 50k datapoints) using an 11M parameter ResNet-18. We also consider a high-resolution (251k pixel) tomographic reconstruction (regression) task with a 2M parameter U-Net. In tackling these, we encounter a pathology in the M-step of the procedure first highlighted by Antorán et al. (2022c): the standard objective therein is ill-defined when the NN contains normalisation layers. Rather than using the solution proposed in Antorán et al. (2022c), which introduces more hyperparameters, we show that a standard feature-normalisation method, the g-prior (Zellner, 1986; Minka, 2000), resolves this pathology. For the tomographic reconstruction task, the regression problem requires a dual-form formulation of our E-step; interestingly, we show that this is equivalent to an optimisation viewpoint on Matheron's rule (Journel & Huijbregts, 1978; Wilson et al., 2020), a connection we believe to be novel.

2. CONJUGATE GAUSSIAN REGRESSION AND THE EM ALGORITHM

We study Bayesian conjugate Gaussian linear regression with multidimensional outputs, where we observe inputs $x_1, \dots, x_n \in \mathbb{R}^d$ and corresponding outputs $y_1, \dots, y_n \in \mathbb{R}^m$. We model these as $y_i = \phi(x_i)\theta + \eta_i$, where $\phi : \mathbb{R}^d \to \mathbb{R}^{m \times d'}$ is a known embedding function. The parameters $\theta$ are assumed sampled from $\mathcal{N}(0, A^{-1})$ with an unknown precision matrix $A \in \mathbb{R}^{d' \times d'}$, and for each $i \le n$, $\eta_i \sim \mathcal{N}(0, B_i^{-1})$ are additive noise vectors with precision matrices $B_i \in \mathbb{R}^{m \times m}$ relating the $m$ output dimensions.

Our goal is to infer the posterior distribution for the parameters $\theta$ given our observations, under the setting of $A$ of the form $A = \alpha I$ for $\alpha > 0$ most likely to have generated the observed data. For this, we use the iterative procedure of Mackay (1992), which alternates computing the posterior for $\theta$, denoted $\Pi$, for a given choice of $A$, and updating $A$, until the pair $(A, \Pi)$ converges to a locally optimal setting. This corresponds to an EM algorithm (Bishop, 2006).

Henceforth, we will use the following stacked notation: we write $Y \in \mathbb{R}^{nm}$ for the concatenation of $y_1, \dots, y_n$; $B \in \mathbb{R}^{nm \times nm}$ for a block diagonal matrix with blocks $B_1, \dots, B_n$; and $\Phi = [\phi(x_1)^T; \dots; \phi(x_n)^T]^T \in \mathbb{R}^{nm \times d'}$ for the embedded design matrix. We write $M := \Phi^T B \Phi$. Additionally, for a vector $v$ and a PSD matrix $G$ of compatible dimensions, $\|v\|_G^2 = v^T G v$. With that, the aforementioned EM algorithm starts with some initial $A \in \mathbb{R}^{d' \times d'}$, and iterates:

• (E step) Given $A$, the posterior for $\theta$, denoted $\Pi$, is computed exactly as
$$\Pi = \mathcal{N}(\bar\theta, H^{-1}) \quad \text{where} \quad H = M + A \quad \text{and} \quad \bar\theta = H^{-1} \Phi^T B Y. \tag{2}$$
• (M step) We lower bound the log-probability density of the observed data, i.e. the evidence, for the model with posterior $\Pi$ and precision $A'$ as (derivation in Appendix B.2)
$$\log p(Y; A') \ge -\tfrac{1}{2} \|\bar\theta\|_{A'}^2 - \tfrac{1}{2} \log\det(I + A'^{-1} M) + C =: \mathcal{M}(A'), \tag{3}$$
for $C$ independent of $A'$. We choose an $A$ that improves this lower bound.
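The exact E-step and the M-step objective above can be sketched in a few lines of NumPy. This is a minimal illustration on toy data, not the implementation used in the experiments; the names `e_step` and `m_objective`, the problem sizes and the diagonal choice of $B$ are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 15, 2, 4                       # datapoints, outputs, features d' (toy sizes)
Phi = rng.normal(size=(n * m, d))        # stacked design matrix, shape (nm, d')
b = rng.uniform(0.5, 2.0, size=n * m)    # diagonal B, a special case of block-diagonal B
B = np.diag(b)
Y = Phi @ np.linspace(-1.0, 1.0, d) + rng.normal(size=n * m) / np.sqrt(b)

def e_step(alpha):
    """Exact E-step for A = alpha * I: returns theta_bar, H and M."""
    M = Phi.T @ B @ Phi
    H = M + alpha * np.eye(d)
    theta_bar = np.linalg.solve(H, Phi.T @ B @ Y)
    return theta_bar, H, M

def m_objective(alpha_new, theta_bar, M):
    """Lower bound M(A') up to the constant C, for A' = alpha_new * I."""
    _, logdet = np.linalg.slogdet(np.eye(d) + M / alpha_new)
    return -0.5 * alpha_new * theta_bar @ theta_bar - 0.5 * logdet

theta_bar, H, M = e_step(alpha=1.0)
posterior_cov = np.linalg.inv(H)         # cubic in d': the cost the sampling approach avoids
```

Both the linear solve and the log-determinant scale cubically in $d'$, which is what motivates the stochastic approximation that follows.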

Limited scalability

The above inference and hyperparameter selection procedure for $\Pi$ and $A$ is intractable when both $d'$ and $nm$ are large. The E-step requires the inversion of a $d' \times d'$ matrix and the M-step requires evaluating its log-determinant, both cubic operations in $d'$. These may be rewritten to instead yield a cubic dependence on $nm$ (as in Section 3.3), but under our assumptions, that too is not computationally tractable. Instead, we now pursue a stochastic approximation to this EM procedure.

3.1. HYPERPARAMETER SELECTION USING POSTERIOR SAMPLES (M-STEP)

For now, assume that we have an efficient method of obtaining samples $\zeta_1, \dots, \zeta_k \sim \Pi_0$ at each step, where $\Pi_0$ is a zero-mean version of the posterior $\Pi$, and access to $\bar\theta$, the mean of $\Pi$. Evaluating the first order optimality condition for $\mathcal{M}$ (see Appendix B.3) yields that the optimal choice of $A$ satisfies
$$\|\bar\theta\|_A^2 = \operatorname{Tr}\{H^{-1} M\} =: \gamma,$$
where the quantity $\gamma$ is the effective dimension of the regression problem. It can be interpreted as the number of directions in which the weights $\theta$ are strongly determined by the data. Setting $A = \alpha I$ for $\alpha = \gamma / \|\bar\theta\|^2$ yields a contraction step converging towards the optimum of $\mathcal{M}$ (Mackay, 1992).

Computing $\gamma$ directly requires the inversion of $H$, a cubic operation. We instead rewrite $\gamma$ as an expectation with respect to $\Pi_0$ using Hutchinson (1990)'s trick, and approximate it using samples as
$$\gamma = \operatorname{Tr}\{H^{-1} M\} = \operatorname{Tr}\{H^{-1/2} M H^{-1/2}\} = \mathbb{E}[\zeta_1^T M \zeta_1] \approx \frac{1}{k} \sum_{j=1}^{k} \zeta_j^T \Phi^T B \Phi \zeta_j =: \hat\gamma.$$
We then select $\alpha = \hat\gamma / \|\bar\theta\|^2$. We have thus avoided the explicit cubic cost of computing the log-determinant in the expression for $\mathcal{M}$ (given in (3)) or inverting $H$. Due to the block structure of $B$, $\hat\gamma$ may be computed in order $n$ vector-matrix products.
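As a rough numerical check of the sample-based estimator of the effective dimension, the sketch below compares it against the exact trace on a toy problem. The exact posterior samples are drawn by Cholesky factorisation purely for illustration (an assumption that is only viable for small $d'$), and all sizes and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
nm, d, k, alpha = 40, 5, 20000, 2.0      # toy sizes; k is large only to make the check tight
Phi = rng.normal(size=(nm, d))
b = rng.uniform(0.5, 2.0, size=nm)       # diagonal of B
Y = Phi @ np.linspace(-1.0, 1.0, d) + rng.normal(size=nm) / np.sqrt(b)

M = Phi.T @ (b[:, None] * Phi)
H = M + alpha * np.eye(d)
theta_bar = np.linalg.solve(H, Phi.T @ (b * Y))
gamma_exact = np.trace(np.linalg.solve(H, M))          # Tr(H^{-1} M), cubic in d'

# Zero-mean posterior samples zeta_j ~ N(0, H^{-1}); rows of zeta are samples.
chol = np.linalg.cholesky(np.linalg.inv(H))
zeta = rng.normal(size=(k, d)) @ chol.T

# gamma_hat = (1/k) sum_j zeta_j^T Phi^T B Phi zeta_j, computed without forming M.
proj = zeta @ Phi.T                                    # rows are (Phi zeta_j)^T
gamma_hat = np.mean(np.einsum("ij,ij->i", proj * b, proj))

alpha_new = gamma_hat / (theta_bar @ theta_bar)        # sample-based M-step update
```

In the large-scale setting the samples would instead come from the optimisation-based E-step of Section 3.2, and far fewer of them are used.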

3.2. SAMPLING FROM THE LINEAR MODEL'S POSTERIOR USING SGD (E-STEP)

Now we turn to sampling from $\Pi_0 = \mathcal{N}(0, H^{-1})$. It is known (also shown in Appendix C.1) that for $\mathcal{E} \in \mathbb{R}^{nm}$ the concatenation of $\varepsilon_1, \dots, \varepsilon_n$ with $\varepsilon_i \sim \mathcal{N}(0, B_i^{-1})$ and $\theta_0 \sim \mathcal{N}(0, A^{-1})$, the minimiser of the following loss is a random variable $\zeta$ with distribution $\Pi_0$:
$$L(z) = \tfrac{1}{2} \|\Phi z - \mathcal{E}\|_B^2 + \tfrac{1}{2} \|z - \theta_0\|_A^2.$$
This is called the "sample-then-optimise" method (Papandreou & Yuille, 2010). We may thus obtain a posterior sample by optimising this quadratic loss for a given sample pair $(\mathcal{E}, \theta_0)$. Examining $L$:

• The first term is data dependent. It corresponds to the scaled squared error in fitting $\mathcal{E}$ as a linear combination of $\Phi$. Its gradient requires stochastic approximation for large datasets.
• The second term, a regulariser centred at $\theta_0$, does not depend on the data. Its gradient can thus be computed exactly at every optimisation step.

Predicting $\mathcal{E}$, i.e. random noise, from features $\Phi$ is hard. Due to this, the variance of a mini-batch estimate of the gradient of $\|\Phi z - \mathcal{E}\|_B^2$ may be large. Instead, for $\mathcal{E}$ and $\theta_0$ defined as before, we propose the following alternative loss, equal to $L$ up to an additive constant independent of $z$:
$$L'(z) = \tfrac{1}{2} \|\Phi z\|_B^2 + \tfrac{1}{2} \|z - \theta_n\|_A^2 \quad \text{with} \quad \theta_n = \theta_0 + A^{-1} \Phi^T B \mathcal{E}.$$
The mini-batch gradients of $L'$ and $L$ are equal in expectation (see Appendix C.1). However, in $L'$, the randomness from the noise samples $\mathcal{E}$ and the prior sample $\theta_0$ both feature within the regularisation term, the gradient of which can be computed exactly, rather than in the data-dependent term. This may lead to a lower minibatch gradient variance. To see this, consider the variance of the single-datapoint stochastic gradient estimators for both objectives' data dependent terms. At $z \in \mathbb{R}^{d'}$, for datapoint index $j \sim \mathrm{Unif}(\{1, \dots, n\})$, these are
$$\hat g = n \phi(x_j)^T B_j (\phi(x_j) z - \varepsilon_j) \quad \text{and} \quad \hat g' = n \phi(x_j)^T B_j \phi(x_j) z$$
for $L$ and $L'$, respectively. Direct calculation, presented in Appendix C.2, shows that
$$\tfrac{1}{n} [\operatorname{Var} \hat g - \operatorname{Var} \hat g'] = \operatorname{Var}(\Phi^T B \mathcal{E}) - 2 \operatorname{Cov}(\Phi^T B \Phi z, \Phi^T B \mathcal{E}) =: \Delta.$$
Note that both $\operatorname{Var} \hat g$ and $\operatorname{Var} \hat g'$ are $d' \times d'$ matrices. We impose an order on these by considering their traces: we prefer the new gradient estimator $\hat g'$ if the sum of its per-dimension variances is lower than that of $\hat g$; that is, if $\operatorname{Tr} \Delta > 0$. We analyse two key settings:

• At initialisation, taking $z = \theta_0$ (or any other initialisation independent of $\mathcal{E}$),
$$\operatorname{Tr} \Delta = \operatorname{Tr}\{\Phi^T B\, \mathbb{E}[\mathcal{E} \mathcal{E}^T] B \Phi\} - 2 \operatorname{Tr}\{\Phi^T B \Phi\, \mathbb{E}[\theta_0 \mathcal{E}^T] B \Phi\} = \operatorname{Tr}\{M\} > 0,$$
where we used that $\mathbb{E}[\mathcal{E} \mathcal{E}^T] = B^{-1}$ and, since $\mathcal{E}$ is zero mean and independent of $\theta_0$, we have $\mathbb{E}[\theta_0 \mathcal{E}^T] = \mathbb{E}[\theta_0]\, \mathbb{E}[\mathcal{E}^T] = 0$. Thus, the new objective $L'$ is always preferred at initialisation.
• At convergence, that is, when $z = \zeta$, $L'$ is preferred if
$$2 \alpha \gamma > \operatorname{Tr} M. \tag{11}$$
This is satisfied if $\alpha$ is large relative to the eigenvalues of $M$ (see Appendix C.4), that is, when the parameters are not strongly determined by the data relative to the prior.

When $L'$ is preferred both at initialisation and at convergence, we expect it to have lower variance for most minibatches throughout training. Even if the proposed objective $L'$ is not preferred at convergence, it may still be preferred for most of the optimisation, before the noise is fit well enough.
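To make the sample-then-optimise construction concrete, the following sketch minimises the alternative objective with plain full-batch gradient descent on a toy problem and compares the result with the closed-form minimiser of the standard objective. It then checks empirically that the minimisers have covariance close to $H^{-1}$. The optimiser, step size, problem sizes and the diagonal $B$ are illustrative assumptions; the experiments in this paper use stochastic optimisation instead.

```python
import numpy as np

rng = np.random.default_rng(1)
nm, d, alpha, k = 60, 4, 1.5, 20000      # toy sizes; A = alpha * I
Phi = rng.normal(size=(nm, d))
b = rng.uniform(0.5, 2.0, size=nm)       # diagonal of B
M = Phi.T @ (b[:, None] * Phi)
H = M + alpha * np.eye(d)

def draw_pair():
    """One prior sample theta_0 ~ N(0, A^{-1}) and one noise sample E ~ N(0, B^{-1})."""
    theta0 = rng.normal(size=d) / np.sqrt(alpha)
    E = rng.normal(size=nm) / np.sqrt(b)
    return theta0, E

def minimise_L_prime(theta0, E, steps=2000):
    """Gradient descent on L'(z) = 0.5 ||Phi z||_B^2 + 0.5 ||z - theta_n||_A^2."""
    theta_n = theta0 + Phi.T @ (b * E) / alpha          # theta_0 + A^{-1} Phi^T B E
    lr = 1.0 / np.linalg.eigvalsh(H)[-1]                # 1 / largest eigenvalue of H
    z = np.zeros(d)
    for _ in range(steps):
        grad = Phi.T @ (b * (Phi @ z)) + alpha * (z - theta_n)
        z = z - lr * grad
    return z

theta0, E = draw_pair()
z_gd = minimise_L_prime(theta0, E)
z_closed = np.linalg.solve(H, alpha * theta0 + Phi.T @ (b * E))   # minimiser of L

# Many closed-form minimisers at once, to check their covariance against H^{-1}.
theta0s = rng.normal(size=(k, d)) / np.sqrt(alpha)
Es = rng.normal(size=(k, nm)) / np.sqrt(b)
zetas = np.linalg.solve(H, (alpha * theta0s + (Es * b) @ Phi).T).T
H_inv = np.linalg.inv(H)
rel_err = np.linalg.norm(np.cov(zetas, rowvar=False) - H_inv) / np.linalg.norm(H_inv)
```

Since both objectives share a minimiser, the choice between them only affects the variance of the stochastic gradients, which this full-batch sketch does not exercise.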

3.3. DUAL FORM OF E-STEP: MATHERON'S RULE AS OPTIMISATION

The minimiser of both $L$ and $L'$ is $\zeta = H^{-1}(A \theta_0 + \Phi^T B \mathcal{E})$. The dual (kernelised) form of this is
$$\zeta = \theta_0 + A^{-1} \Phi^T (B^{-1} + \Phi A^{-1} \Phi^T)^{-1} (\mathcal{E} - \Phi \theta_0), \tag{12}$$
which is known in the literature as Matheron's rule (Wilson et al., 2020, our Appendix D). Evaluating (12) requires solving an $mn$-dimensional quadratic optimisation problem, which may be preferable to the primal problem for $mn < d'$; however, this dual form cannot be minibatched over observations. For small $n$, we can solve this optimisation problem using iterative full-batch quadratic optimisation algorithms (e.g. conjugate gradients), significantly accelerating our sample-based EM iteration.
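The equivalence between the primal minimiser and the dual form (12) can be checked numerically on a small problem with fewer observations than parameters. The sketch below uses a dense solver for both systems as a stand-in for the iterative solvers mentioned above; sizes and names are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
nm, d, alpha = 6, 50, 0.7                # few observations, many parameters (nm < d')
Phi = rng.normal(size=(nm, d))
b = rng.uniform(0.5, 2.0, size=nm)       # diagonal of B
theta0 = rng.normal(size=d) / np.sqrt(alpha)     # theta_0 ~ N(0, A^{-1}), A = alpha * I
E = rng.normal(size=nm) / np.sqrt(b)             # E ~ N(0, B^{-1})

# Primal form: a d' x d' linear system.
H = Phi.T @ (b[:, None] * Phi) + alpha * np.eye(d)
zeta_primal = np.linalg.solve(H, alpha * theta0 + Phi.T @ (b * E))

# Dual form (Matheron's rule): an nm x nm linear system.
K = np.diag(1.0 / b) + Phi @ Phi.T / alpha       # B^{-1} + Phi A^{-1} Phi^T
zeta_dual = theta0 + Phi.T @ np.linalg.solve(K, E - Phi @ theta0) / alpha
```

Here the dual system is 6 × 6 while the primal one is 50 × 50, which mirrors the regime in which the dual form is attractive.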

4. NN UNCERTAINTY QUANTIFICATION AS LINEAR MODEL INFERENCE

Consider the problem of $m$-output prediction. Suppose that we have trained a neural network of the form $f : \mathbb{R}^{d'} \times \mathbb{R}^d \to \mathbb{R}^m$, obtaining weights $\tilde w \in \mathbb{R}^{d'}$, using a loss of the form $\mathcal{L}(f(w, \cdot)) = \sum_{i=1}^{n} \ell(y_i, f(w, x_i)) + R(w)$, where $\ell$ is a data fit term (a negative log-likelihood) and $R$ is a regulariser. We now show how to use linearised Laplace to quantify uncertainty in the network predictions $f(\tilde w, \cdot)$. We then present the g-prior, a feature normalisation that resolves a certain pathology in the linearised Laplace method when the network $f$ contains normalisation layers.

4.1. THE LINEARISED LAPLACE METHOD

The linearised Laplace method consists of two consecutive approximations, the latter of which is necessary only if $\ell$ is non-quadratic (that is, if the likelihood is non-Gaussian):

1. We take a first-order Taylor expansion of $f$ around $\tilde w$, yielding the surrogate model
$$h(\theta, x) = f(\tilde w, x) + \phi(x)(\theta - \tilde w) \quad \text{for} \quad \phi(x) = \nabla_w f(\tilde w, x).$$
This is an affine model in the features $\phi(x)$ given by the network Jacobian at $x$.
2. We approximate the loss of the linear model $\mathcal{L}(h(\theta, \cdot))$ with a quadratic, and treat it as a negative log-density for the parameters $\theta$, yielding a Gaussian posterior of the form
$$\mathcal{N}\big(\bar\theta, (\nabla_\theta^2 \mathcal{L}(h(\bar\theta, \cdot)))^{-1}\big) \quad \text{where} \quad \bar\theta \in \operatorname{argmin}_\theta \mathcal{L}(h(\theta, \cdot)).$$

Direct calculation shows that $(\nabla_\theta^2 \mathcal{L})(h(\bar\theta, \cdot)) = A + \Phi^T B \Phi = H$, for $\nabla_w^2 R(\tilde w) = A$ and $B$ a block diagonal matrix with blocks $B_i = \nabla_{\hat y_i}^2 \ell(y_i, \hat y_i)$ evaluated at $\hat y_i = h(\bar\theta, x_i)$. We have thus recovered a conjugate Gaussian multi-output linear model. We treat $A$ as a learnable parameter thereafter.

In practice, we depart from the above procedure in two ways:

• We use the neural network output $f(\tilde w, \cdot)$ as the predictive mean, rather than the surrogate model mean $h(\bar\theta, \cdot)$. Nonetheless, we still need to compute $\bar\theta$ to use within the M-step of the EM procedure. To do this, we minimise $\mathcal{L}(h(\theta, \cdot))$ over $\theta \in \mathbb{R}^{d'}$.
• We compute the loss curvature $B_i$ at predictions $\hat y_i = f(\tilde w, x_i)$ in place of $h(\bar\theta, x_i)$, since the latter would change each time the regulariser $A$ is updated, requiring expensive re-evaluation.

Both of these departures are recommended within the literature (Antorán et al., 2022c).
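The first approximation, linearising the network around its trained weights, can be illustrated on a toy two-layer network. In the sketch below the network, its sizes and the finite-difference Jacobian are all assumptions made for a self-contained example; in practice the Jacobian is accessed through automatic differentiation and is never instantiated.

```python
import numpy as np

rng = np.random.default_rng(4)
d_in, hidden, m = 3, 5, 2
n_params = hidden * d_in + m * hidden    # d'

def f(w, x):
    """Toy network f(w, x) = V tanh(U x) with a flat weight vector w."""
    U = w[: hidden * d_in].reshape(hidden, d_in)
    V = w[hidden * d_in:].reshape(m, hidden)
    return V @ np.tanh(U @ x)

def jacobian(w, x, eps=1e-6):
    """phi(x) = d f / d w at w, shape (m, d'), by central finite differences."""
    J = np.zeros((m, w.size))
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = eps
        J[:, i] = (f(w + e, x) - f(w - e, x)) / (2 * eps)
    return J

w_tilde = rng.normal(size=n_params)      # stands in for the trained weights
x = rng.normal(size=d_in)
phi_x = jacobian(w_tilde, x)

def h(theta):
    """Surrogate h(theta, x) = f(w_tilde, x) + phi(x) (theta - w_tilde) at the fixed x."""
    return f(w_tilde, x) + phi_x @ (theta - w_tilde)

# The surrogate is exact at w_tilde and first-order accurate nearby.
direction = rng.normal(size=n_params)
direction /= np.linalg.norm(direction)
theta_near = w_tilde + 1e-3 * direction
lin_error = np.linalg.norm(h(theta_near) - f(theta_near, x))
```

The linearisation error shrinks quadratically with the distance from the trained weights, whereas the surrogate itself is affine in the parameters, which is what makes conjugate Gaussian inference applicable.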

4.2. ON THE COMPUTATIONAL ADVANTAGE OF SAMPLE-BASED PREDICTIONS

The linearised Laplace predictive posterior at an input $x$ is $\mathcal{N}(f(\tilde w, x), \phi(x) H^{-1} \phi(x)^T)$. Even given $H^{-1}$, evaluating this naïvely requires instantiating $\phi(x)$, at a cost of $m$ vector-Jacobian products (i.e. backward passes). This is prohibitive for large $m$. However, expectations of any function $r : \mathbb{R}^m \to \mathbb{R}$ under the predictive posterior can be approximated using only samples from $\Pi_0$ as
$$\mathbb{E}[r] \approx \frac{1}{k} \sum_{j=1}^{k} r(\psi_j) \quad \text{for} \quad \psi_j = f(\tilde w, x) + \phi(x) \zeta_j \quad \text{with} \quad \zeta_1, \dots, \zeta_k \sim \Pi_0,$$
requiring only $k$ Jacobian-vector products. In practice, we find $k$ much smaller than $m$ suffices.
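The sample-based predictive can be sketched as follows. Random matrices stand in for the Jacobian $\phi(x)$, the network output and the posterior precision, and the Jacobian-vector products are written as an explicit matrix product, which is only sensible at toy scale; the large $k$ is chosen to make the numerical comparison tight rather than to reflect practice.

```python
import numpy as np

rng = np.random.default_rng(5)
m, d, k = 3, 8, 20000
phi_x = rng.normal(size=(m, d))          # stand-in for the Jacobian phi(x)
f_x = rng.normal(size=m)                 # stand-in for the network output at x
G = rng.normal(size=(d, d))
H = G @ G.T + np.eye(d)                  # any symmetric positive definite precision

# Zero-mean posterior samples zeta_j ~ N(0, H^{-1}).
chol = np.linalg.cholesky(np.linalg.inv(H))
zeta = rng.normal(size=(k, d)) @ chol.T

# psi_j = f(x) + phi(x) zeta_j; each row would be one Jacobian-vector product.
psi = f_x + zeta @ phi_x.T

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

pred_probs = softmax(psi).mean(axis=0)   # Monte Carlo estimate of E[softmax(psi)]

exact_cov = phi_x @ np.linalg.solve(H, phi_x.T)
mc_cov = np.cov(psi, rowvar=False)
rel_err = np.linalg.norm(mc_cov - exact_cov) / np.linalg.norm(exact_cov)
```

The same samples can be reused for every test input, so the per-input cost is $k$ Jacobian-vector products regardless of $m$.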

4.3. FEATURE EMBEDDING NORMALISATION: THE DIAGONAL G-PRIOR

Due to symmetries and indeterminacies in neural networks, the embedding function $\phi(\cdot) = \nabla_w f(\tilde w, \cdot)$ used in the linearised Laplace method yields features with arbitrary scales across the $d'$ weight dimensions. Consequently, the dimensions of the embeddings may have an (arbitrarily) unequal weight under an isotropic prior; that is, considering $\phi(x) \theta_0$ for $\theta_0 \sim \mathcal{N}(0, \alpha^{-1} I)$. There are two natural solutions: either normalise the features by their (empirical) second moment, resulting in the normalised embedding function $\phi'$ given by
$$\phi'(x) = \phi(x) \operatorname{diag}(s) \quad \text{for } s \in \mathbb{R}^{d'} \text{ given by } s_i = [\Phi^T B \Phi]_{ii}^{-1/2},$$
or likewise scale the prior, setting $A = \alpha \operatorname{diag}((s_i^{-2})_{i=1}^{d'})$. The latter formulation is a diagonal version of what is known in the literature as the g-prior (Zellner, 1986) or scale-invariant prior (Minka, 2000).

The g-prior may, in general, improve the conditioning of the linear system. Furthermore, when the linearised network contains normalisation layers, such as batchnorm (that is, most modern networks), the g-prior is essential. Antorán et al. (2022c) show that normalisation layers lead to indeterminacies in NN Jacobians, that in turn lead to an ill-defined model evidence objective. They propose learning separate regularisation parameters for each normalised layer of the network. While fixing the pathology, this increases the complexity of model evidence optimisation. As we show in Appendix E, the g-prior cancels these indeterminacies, allowing for the use of a single regularisation parameter.
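The two formulations of the diagonal g-prior, and its invariance to feature rescaling, can be checked on a toy design matrix with badly scaled columns. The helper names `g_prior_scales` and `predictive_cov`, the sizes and the chosen scales are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(6)
nm, d, alpha = 40, 5, 2.0
Phi = rng.normal(size=(nm, d)) * np.array([0.1, 1.0, 10.0, 0.5, 5.0])  # unequal feature scales
b = rng.uniform(0.5, 2.0, size=nm)       # diagonal of B

def g_prior_scales(Phi):
    """s_i = [Phi^T B Phi]_ii^{-1/2}, using only the diagonal of M."""
    return 1.0 / np.sqrt(np.einsum("ni,n,ni->i", Phi, b, Phi))

def predictive_cov(Phi, A):
    """Phi H^{-1} Phi^T on the training inputs, with H = Phi^T B Phi + A."""
    H = Phi.T @ (b[:, None] * Phi) + A
    return Phi @ np.linalg.solve(H, Phi.T)

s = g_prior_scales(Phi)

# Option 1: normalise the features and keep an isotropic prior.
Phi_norm = Phi * s
M_norm = Phi_norm.T @ (b[:, None] * Phi_norm)
cov_features = predictive_cov(Phi_norm, alpha * np.eye(d))

# Option 2: keep the features and scale the prior, A = alpha * diag(s^-2).
cov_prior = predictive_cov(Phi, alpha * np.diag(s ** -2.0))

# Arbitrary per-feature rescaling leaves the g-prior predictive unchanged.
c = rng.uniform(0.1, 10.0, size=d)
Phi_c = Phi * c
cov_rescaled = predictive_cov(Phi_c, alpha * np.diag(g_prior_scales(Phi_c) ** -2.0))
```

After normalisation every diagonal entry of $M$ equals one, so its trace equals $d'$, a fact used when checking condition (11) in Section 5.1.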

5. DEMONSTRATION: SAMPLE-BASED LINEARISED LAPLACE INFERENCE

We demonstrate our linear model inference and hyperparameter selection approach on the problem of estimating the uncertainty in NN predictions with the linearised Laplace method. First, in Section 5.1, we perform an ablation analysis on the different components of our algorithm using small LeNet-style CNNs trained on MNIST. In this setting, full-covariance Laplace inference (that is, exact linear model inference) is tractable, allowing us to evaluate the quality of our approximations. We then demonstrate our method at scale on CIFAR100 classification with a ResNet18 (Section 5.2) and the dual (kernelised) formulation of our method on tomographic image reconstruction using a U-Net (Section 5.3). We look at both marginal and joint uncertainty calibration and at computational cost.

Algorithm 1: Sampling-based linearised Laplace hyperparameter learning and inference. Inputs: initial $\alpha > 0$; $k, k' \in \mathbb{N}$, […]

Figure 1: […] (2020). Top: prior function samples present large std-dev. (left). When these samples are optimised (middle shows a 2D slice of weight space), the resulting predictive errorbars are larger than the marginal target variance (right). Bottom: after EM, the std-dev. of prior functions roughly matches that of the targets (left), the overlap between prior and posterior is maximised, leading to shorter sample trajectories (centre), and the predictive errorbars are qualitatively more appealing (right).

For all experiments, our method avoids storing covariance matrices $H^{-1}$, computing their log-determinants, or instantiating Jacobian matrices $\phi(x)$, all of which have hindered previous linearised Laplace implementations. We interact with NN Jacobians only through Jacobian-vector and vector-Jacobian products, which have the same asymptotic computational and memory costs as a NN forward-pass (Novak et al., 2022). Unless otherwise specified, we use the diagonal g-prior and a scalar regularisation parameter.
Algorithm 1 summarises our method, Figure 1 shows an illustrative example, and full algorithmic detail is in Appendix F. An implementation of our method in JAX can be found here. Additional experimental results are provided in Appendix H and Appendix I.

5.1. ABLATION STUDY: LENET ON MNIST

We first evaluate our approach on MNIST $m$=10 class image classification, where exact linearised Laplace inference is tractable. The training set consists of $n$=60k observations and we employ 3 LeNet-style CNNs of increasing size: "LeNetSmall" ($d'$=14634), "LeNet" ($d'$=29226) and "LeNet-Big" ($d'$=46024). The latter is the largest model for which we can store the covariance matrix on an A100 GPU. We draw samples and estimate posterior modes using SGD with Nesterov momentum (full details in Appendix G). We use 5 seeds for each experiment, and report the mean and std. error.

Comparing sampling objectives

We first compare the proposed objective $L'$ with the one standard in the literature, $L$, using LeNet. The results are shown in Figure 2. We draw exact samples $(\zeta_j^\star)_{j \le k}$ through matrix inversion and assess sample fidelity in terms of normalised squared distance to these exact samples, $\|\zeta - \zeta^\star\|_2^2 / \|\zeta^\star\|_2^2$. All runs share a prior precision of $\alpha \approx 5.5$ obtained with EM iteration. The effective dimension is $\gamma \approx 1300$. Noting that the g-prior feature normalisation results in $\operatorname{Tr} M = d'$, we can see that condition (11) is not satisfied ($2 \times 5.5 \times 1300 < 29\text{k}$). Despite this, the proposed objective converges to more accurate samples even when using a 16-times smaller batch size (left plot). The right side plots relate sample error to categorical symmetrised-KL (sym. KL) and logit Wasserstein-2 (W2) distance between the sampled and exact lin. Laplace predictive distributions on the test-set. Both objectives' prediction errors stop decreasing below a sample error of ≈0.5 but, nevertheless, the proposed loss $L'$ reaches a lower prediction error.

Fidelity of sampling inference

We compare our method's uncertainty using 64 samples against approximate methods based on the NN weight point-estimate (MAP), a diagonal covariance, and a KFAC estimate of the covariance (Martens & Grosse, 2015; Ritter et al., 2018) implemented with the Laplace library, in terms of similarity to the full-covariance lin. Laplace predictive posterior. As standard, we compute categorical predictive distributions with the probit approximation (Daxberger et al., 2021a). All methods use the same layerwise prior precision obtained with 5 steps of full-covariance EM iteration. For all three LeNet sizes, the sampled approximation presents the lowest categorical sym. KL and logit W2 distance to the exact lin. Laplace pred. posterior (Figure 3, LHS). The fidelity of competing approximations degrades with model size but that of sampling increases.

Accuracy of sampling hyperparameter selection

We compare our sampling EM iteration with 16 samples to full-covariance EM on LeNet without the g-prior. Figure 3, right, shows that for a single precision hyperparameter, both approaches converge in about 3 steps to the same value. In this setting, the diagonal covariance approximation diverges, and KFAC converges to a biased solution. We also consider learning layer-wise prior precisions by extending the M-step update from Section 3 to diagonal but non-isotropic prior precision matrices (see Appendix B.4). Here, neither the full covariance nor sampling methods converge within 15 EM steps. The precisions for all but the final layer grow, revealing a pathology of this prior parametrisation: only the final layer's Jacobian, i.e. the final layer activations, are needed to accurately predict the targets; other features are pruned away.

5.2. RESNET18 ON CIFAR100

We showcase the stability and performance of our approach by applying it to CIFAR100 $m$=100-way image classification. The training set consists of $n$=50k observations, and we employ a ResNet-18 model with $d' \approx 11$M parameters. We perform optimisation using SGD with Nesterov momentum and a linear learning rate decay schedule. Unless specified otherwise, we run 8 steps of EM with 6 samples to select $\alpha$. We then optimise 64 samples to be used for prediction. We run each experiment with 5 different seeds, reporting mean and std. error. Full experimental details are in Appendix G.

Table 1: Comparison of methods' marginal and joint prediction performance for ResNet18 on CIFAR100.

              κ   MAP              Ensemble (5)     KFAC             Sampling
marginal LL   1   -1.40 ± 0.00     -0.90 ± 0.00     -1.12 ± 0.01     -1.07 ± 0.01
joint LL      2   -13.97 ± 0.01    -6.86 ± 0.01     -4.92 ± 0.04     -5.14 ± 0.04
              3   -27.89 ± 0.03    -14.17 ± 0.03    -10.83 ± 0.12    -10.77 ± 0.09
              4   -41.83 ± 0.03    -22.29 ± 0.04    -19.02 ± 0.22    -18.04 ± 0.18
              5   -55.89 ± 0.02    -31.07 ± 0.09    -29.40 ± 0.40    -26.75 ± 0.26

Stability and cost of sampling algorithm

[…] $2 \alpha \gamma \approx 1.4 \times 10^7 > 1.1 \times 10^7 \approx \operatorname{Tr} M$. Thus, (11) is satisfied and $L'$ is preferred. We use 50 epochs of optimisation for the posterior mode and 20 for sampling. When using 2 samples, the cost of one EM step with our method is 45 minutes on an A100 GPU; for the KFAC approximation, this takes 20 minutes.

Evaluating performance in the face of distribution shift

We employ the standard benchmark for evaluating methods' test Log-Likelihood (LL) on the increasingly corrupted data sets of Hendrycks & Gimpel (2017); Ovadia et al. (2019). We compare the predictions made with our approach to those from 5-element deep ensembles, arguably the strongest baseline for uncertainty quantification in deep learning (Lakshminarayanan et al., 2017; Ashukha et al., 2020), with point-estimated predictions (MAP), and with a KFAC approximation of the lin. Laplace covariance (Ritter et al., 2018). For the latter, constructing full Jacobian matrices for every test point is computationally intractable, so we use 64 samples for prediction, as suggested in Section 4.2. The KFAC covariance structure leads to fast log-determinant computation, allowing us to learn layer-wise prior precisions (following Immer et al., 2021a) for this baseline using 10 steps of non-sampled EM. For both lin. Laplace methods, we use the standard probit approximation to the categorical predictive (Daxberger et al., 2021b). Figure 5 shows that for in-distribution inputs, ensembles perform best and KFAC overestimates uncertainty, degrading LL even relative to point-estimated MAP predictions. Conversely, our method improves LL. For sufficiently corrupted data, our approach outperforms ensembles, also edging out KFAC, which fares well here due to its consistent overestimation of uncertainty.

Joint predictions

Joint predictions are essential for sequential decision making, but are often ignored in the context of NN uncertainty quantification (Janz et al., 2019). To address this, we replicate the "dyadic sampling" experiment proposed by Osband et al. (2022). We group our test-set into sets of $\kappa$ data points and then uniformly re-sample the points in each set until sets contain $\tau$ points. We then evaluate the LL of each set jointly. Since each set only contains $\kappa$ distinct points, a predictor that models self-covariances perfectly should obtain an LL value at least as large as its marginal LL for all values of $\kappa$. We use $\tau = 10(\kappa - 1)$ and repeat the experiment for 10 test-set shuffles. Our setup remains the same as above but we use Monte Carlo marginalisation instead of probit, since the latter discards covariance information. Table 1 shows that ensembles make calibrated predictions marginally but their joint predictions are poor, an observation also made by Osband et al. (2021). Our approach is competitive for all $\kappa$, performing best in the challenging large $\kappa$ cases.
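One possible reading of the dyadic sampling protocol and of the Monte Carlo joint log-likelihood is sketched below. The function names, the sampling-with-replacement scheme within each set and the use of per-sample class log-probabilities are assumptions made for illustration; the exact protocol is the one given by Osband et al. (2022).

```python
import numpy as np

def dyadic_sets(n_test, kappa, tau, rng):
    """Split shuffled test indices into sets of kappa points, then resample tau per set."""
    perm = rng.permutation(n_test)
    n_sets = n_test // kappa
    anchors = perm[: n_sets * kappa].reshape(n_sets, kappa)
    picks = rng.integers(0, kappa, size=(n_sets, tau))      # uniform, with replacement
    return np.take_along_axis(anchors, picks, axis=1)       # (n_sets, tau) test indices

def joint_ll(log_probs, labels, sets):
    """Joint LL per set: log mean_j prod_t p_j(y_t | x_t) over k posterior samples.

    log_probs has shape (k, n_test, classes), labels has shape (n_test,).
    """
    point_ll = np.take_along_axis(log_probs, labels[None, :, None], axis=2)[..., 0]
    set_ll = point_ll[:, sets].sum(axis=2)                  # (k, n_sets)
    mx = set_ll.max(axis=0)                                 # stabilised log-mean-exp
    return mx + np.log(np.exp(set_ll - mx).mean(axis=0))

rng = np.random.default_rng(0)
kappa = 3
sets = dyadic_sets(n_test=100, kappa=kappa, tau=10 * (kappa - 1), rng=rng)

# A single uniform predictor over 4 classes, used only to exercise the functions.
log_probs = np.log(np.full((1, 100, 4), 0.25))
labels = np.zeros(100, dtype=int)
lls = joint_ll(log_probs, labels, sets)
```

With several posterior samples, the product inside the average is what rewards predictors that capture correlations between the repeated points in a set.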

5.3. TOMOGRAPHIC RECONSTRUCTION

To demonstrate our approach in dual (kernelised) form, we replicate the setting of Barbano et al. (2022a;b) and Antorán et al. (2022b), where linearised Laplace is used to estimate uncertainty for a tomographic reconstruction output by a U-Net autoencoder. We provide an overview of the problem in Appendix G.3, but refer to Antorán et al. (2022b) for full detail. Whereas the authors use a single EM step to learn hyperparameters, we use our sample-based variant and run 5 steps. Unless otherwise specified, we use 16 samples for stochastic EM, and 1024 for prediction. We test on the real-measured µCT dataset of 251k pixel scans of a single walnut released by Der Sarkissian et al. (2019b). We train $d'$=2.97M parameter U-Nets on the $m$=7680 dimensional observation used by Antorán et al. (2022b) and a twice as large setting, $m$=15360. Here, the U-Net's input is clamped to a constant ($n = 1$), and its parameters are optimised to output the reconstructed image. As a result, we do not need mini-batching and can draw samples using Matheron's rule (12). We solve the linear system contained therein using conjugate gradient (CG) iteration implemented with GPyTorch (Gardner et al., 2018). We accelerate CG with a randomised SVD preconditioner of rank 400 (alg. 5.6 in Halko et al., 2011). See Appendix G for full experimental details.

Stability and cost of sampling algorithm

Figure 6 shows that sample-based EM iteration converges within 4 steps using as few as 2 samples. Table 2 shows the time taken to perform 5 sample-based EM steps for both the $m$=7680 and $m$=15360 settings; avoiding explicit estimation of the covariance log-determinant provides us with a two order of magnitude speedup relative to Antorán et al. (2022c) for hyperparameter learning. By avoiding covariance inversion, we obtain an order of magnitude speedup for prediction. Furthermore, while scaling to double the observations, $m$=15360, is intractable with the previous method, our sampling method requires only a 25% increase in computation time.

Predictive performance

Figure 7 shows, qualitatively, that the marginal standard deviation assigned to each pixel by our method aligns with the pixelwise error in the U-Net reconstruction in a fine-grained manner. By contrast, MC dropout (MCDO), the most common baseline for NN uncertainty estimation in tomographic reconstruction (Laves et al., 2020; Tölle et al., 2021), spreads uncertainty more uniformly across large sections of the image. Table 2 shows that the pixelwise LL obtained with our method exceeds that obtained by Antorán et al. (2022b), potentially due to us optimising the prior precision to convergence while the previous work could only afford a single EM step. The rightmost plot in Figure 6 displays joint test LL, evaluated on patches of neighbouring pixels. Our method performs best. MCDO's predictions are poor marginally. They improve when considering covariances, although they remain worse than lin. Laplace.

6. CONCLUSION

Our work introduced a sample-based approximation to inference and hyperparameter selection in Gaussian linear multi-output models. This allowed us to scale the linearised Laplace method to ResNet-18 on CIFAR100, where it was computationally intractable with existing methods. We also demonstrated the strength of the approach on a high resolution tomographic reconstruction task, where it decreases the cost of hyperparameter selection by two orders of magnitude. The uncertainty estimates obtained through our method are well-calibrated not just marginally, but also jointly across predictions. Thus, our work may be of interest in the fields of active and reinforcement learning, where joint predictions are of importance, and computation of posterior samples is often needed.

REPRODUCIBILITY STATEMENT

To aid the reproduction of our results, we provide a high-level overview of our procedure in Algorithm 1 and the fully detailed algorithms we use in our two major experiments in Appendix F. Appendix G provides full experimental details for all datasets and models used in our experiments. Our code is available in a repository at this link.

A RELATED WORK

Bayesian Gaussian linear models

This work builds on the rich literature of Bayesian linear regression (Gull, 1989; Bishop, 2006; Rasmussen & Williams, 2006). Specifically, we present a stochastic approximation to the iterative algorithm for hyperparameter selection introduced by Mackay (1992) and extended by Tipping (2001); Tipping & Faul (2003); Wipf & Nagarajan (2007); Antorán et al. (2022c). Analytical tractability makes linear models ubiquitous in machine learning, with applications in genomics (Runcie et al., 2021), reinforcement learning (Ash et al., 2022), and pandemic modelling (Nicholson et al., 2022) […]

[…] 2020), although in these settings it does not draw exact posterior samples. In this work, we show sample-then-optimise to be the primal form of Matheron's rule (Journel & Huijbregts, 1978; Hoffman & Ribak, 1991), a method for updating jointly Gaussian samples into conditional samples, which was recently repopularised by Wilson et al. (2020).

[…] introduce subnetwork and finite differences approaches, respectively, for faster inference with the linearised model. This line of work is also closely related to the neural tangent kernel (Jacot et al., 2018; Lee et al., 2019; Novak et al., 2020) in which NNs are linearised at initialisation.

The g-prior, originally introduced by Zellner (1986), consists of a centred Gaussian with covariance matching the inverse of the Fisher information matrix. As a result, the g-prior ensures inferences are independent of the units of measurement of the covariates (Minka, 2000). Since then, it has been used extensively in the context of model selection for generalised linear models (Liang et al., 2008; Bové & Held, 2011; Baragatti & Pommeret, 2012). In the large-scale setting, we overcome the computational intractability of the Fisher by diagonalising the g-prior while preserving its scale-invariance property.

B MODEL EVIDENCE LOWER BOUND AND THE EFFECTIVE DIMENSION

B.1 EQUIVALENT FORMULATIONS OF EFFECTIVE DIMENSION

We begin by relating two standard forms of effective dimension, which we use throughout. Starting with the form standard in the kernel-based literature (that without an explicit $d'$ dependence),
\begin{align}
\gamma = \operatorname{Tr}\{(A + M)^{-1} M\} &= \operatorname{Tr}\{(I + A^{-1} M)^{-1} A^{-1} M\} = \operatorname{Tr}\{I - (I + A^{-1} M)^{-1}\} \tag{18} \\
&= d' - \operatorname{Tr}\{A (A + M)^{-1}\}, \tag{19}
\end{align}
we have arrived at the form used within the finite-dimensional linear modelling literature (Mackay, 1992; Wipf & Nagarajan, 2007; Maddox et al., 2020).
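As a quick numerical sanity check of the two forms of the effective dimension above, the following minimal sketch evaluates both on a small random problem. The matrix sizes, the isotropic choice $A = 2I$ and the helper names are illustrative assumptions, not part of the original experiments.

```python
import numpy as np

def effective_dimension_kernel_form(A, M):
    """gamma = Tr{(A + M)^{-1} M}, the form without an explicit d' dependence."""
    return np.trace(np.linalg.solve(A + M, M))

def effective_dimension_primal_form(A, M):
    """gamma = d' - Tr{A (A + M)^{-1}}, the finite-dimensional form."""
    d_prime = A.shape[0]
    return d_prime - np.trace(A @ np.linalg.inv(A + M))

# Hypothetical toy problem: M = Phi^T Phi (unit noise precision), A = 2 I.
rng = np.random.default_rng(0)
Phi = rng.normal(size=(20, 5))
M = Phi.T @ Phi
A = 2.0 * np.eye(5)

gamma_kernel = effective_dimension_kernel_form(A, M)
gamma_primal = effective_dimension_primal_form(A, M)
```

Both values agree up to floating point error and lie strictly between 0 and $d' = 5$, as expected for a positive definite $A$.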

B.2 DERIVATION OF M AS A LOWER BOUND ON THE MODEL EVIDENCE

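Let $p_\theta$ be the Lebesgue density of $\mathcal{N}(\Phi\theta, B^{-1})$, $P = \mathcal{N}(0, A'^{-1})$ and $Q = \mathcal{N}(\bar\theta, (M + A')^{-1})$. Then,
\begin{align}
\log p(Y; A') = \log \int p_\theta(Y)\, dP = \log \int p_\theta(Y) \frac{dP}{dQ}\, dQ &\geq \int \log\left( p_\theta(Y) \frac{dP}{dQ} \right) dQ \tag{20} \\
&= \int \log p_\theta(Y)\, dQ - D(Q \| P), \tag{21}
\end{align}
where $D$ denotes the KL-divergence. Starting with the first term,
\begin{align}
\int \log p_\theta(Y)\, dQ &= \frac{1}{2} \int \left[ -n \log 2\pi + \log\det B - (Y - \Phi\theta)^T B (Y - \Phi\theta) \right] dQ \tag{22} \\
&= \frac{1}{2} \left[ -n \log 2\pi + \log\det B \right] - \frac{1}{2} \int (Y - \Phi\theta)^T B (Y - \Phi\theta)\, dQ, \tag{23}
\end{align}
and expanding the quadratic form,
\begin{align}
\int (Y - \Phi\theta)^T B (Y - \Phi\theta)\, dQ &= Y^T B Y - 2 Y^T B \Phi \int \theta\, dQ + \int \theta^T \Phi^T B \Phi \theta\, dQ \tag{24} \\
&= Y^T B Y - 2 Y^T B \Phi \bar\theta + \int \theta^T M \theta\, dQ. \tag{25}
\end{align}
To handle the final integral, consider that
\begin{align}
\gamma &= \operatorname{Tr}\{M (M + A')^{-1}\} \tag{26} \\
&= \operatorname{Tr}\left\{M \int (\theta - \bar\theta)(\theta - \bar\theta)^T dQ\right\} \tag{27} \\
&= -\operatorname{Tr}\{M \bar\theta \bar\theta^T\} + \operatorname{Tr}\left\{M \int \theta \theta^T dQ\right\} \tag{28} \\
&= -\bar\theta^T M \bar\theta + \int \theta^T M \theta\, dQ, \tag{29}
\end{align}
and thus
\begin{align}
\int \log p_\theta(Y)\, dQ &= \frac{1}{2} \left[ \log\det B - n \log 2\pi - (Y - \Phi\bar\theta)^T B (Y - \Phi\bar\theta) - \gamma \right] \tag{30} \\
&= \log p_{\bar\theta}(Y) - \frac{1}{2}\gamma. \tag{31}
\end{align}
The KL between two multivariate Gaussians is a standard result, yielding
\begin{align}
D(Q \| P) &= \frac{1}{2} \left[ -\log\det A' + \log\det(M + A') - d' + \bar\theta^T A' \bar\theta + \operatorname{Tr}\{A' (M + A')^{-1}\} \right] \tag{32} \\
&= \frac{1}{2} \left[ -\log\det A' + \log\det(M + A') + \bar\theta^T A' \bar\theta - \gamma \right], \tag{33}
\end{align}
where we used that $\gamma = d' - \operatorname{Tr}\{A' (M + A')^{-1}\}$. Putting together (31) and (33), we obtain
$$\log p(Y; A') \geq \log p_{\bar\theta}(Y) - \frac{1}{2} \log\det(A'^{-1} M + I) - \frac{1}{2} \|\bar\theta\|^2_{A'} = \mathcal{M}(A'), \tag{34}$$
which is the stated result up to taking $C = \log p_{\bar\theta}(Y)$.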
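The bound can be checked numerically in the conjugate Gaussian case. When $\bar\theta$ is the exact posterior mean under $A'$, the distribution $Q$ is the exact posterior, so Jensen's inequality holds with equality and the bound coincides with the log evidence; for any other $\bar\theta$ it is strictly lower. The sketch below illustrates this on a toy problem; the dimensions, the isotropic $A'$ and $B$, and the function names are assumptions made purely for illustration.

```python
import numpy as np

def gaussian_log_density(y, mean, precision):
    """log N(y; mean, precision^{-1})."""
    resid = y - mean
    _, logdet = np.linalg.slogdet(precision)
    return 0.5 * (logdet - y.size * np.log(2 * np.pi) - resid @ precision @ resid)

def evidence_lower_bound(Y, Phi, A, B, theta_bar):
    """log p_theta_bar(Y) - 0.5 log det(A^{-1} M + I) - 0.5 ||theta_bar||_A^2."""
    M = Phi.T @ B @ Phi
    _, logdet = np.linalg.slogdet(np.linalg.solve(A, M) + np.eye(A.shape[0]))
    return (gaussian_log_density(Y, Phi @ theta_bar, B)
            - 0.5 * logdet
            - 0.5 * theta_bar @ A @ theta_bar)

# Hypothetical toy problem.
rng = np.random.default_rng(0)
n, d = 25, 6
Phi = rng.normal(size=(n, d))
Y = rng.normal(size=n)
A = 0.7 * np.eye(d)   # prior precision A'
B = 2.0 * np.eye(n)   # noise precision

# Exact posterior mean and exact log evidence log N(Y; 0, B^{-1} + Phi A^{-1} Phi^T).
theta_mean = np.linalg.solve(Phi.T @ B @ Phi + A, Phi.T @ B @ Y)
marginal_cov = np.linalg.inv(B) + Phi @ np.linalg.inv(A) @ Phi.T
log_evidence = gaussian_log_density(Y, np.zeros(n), np.linalg.inv(marginal_cov))

bound_at_mean = evidence_lower_bound(Y, Phi, A, B, theta_mean)
bound_perturbed = evidence_lower_bound(Y, Phi, A, B, theta_mean + 0.1)
```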

B.3 FIRST ORDER OPTIMALITY CONDITION FOR M

Consider the derivative of $\mathcal{M}$. We have,
$$\nabla \mathcal{M}(A) = -\frac{1}{2} \left[ \nabla \|\bar\theta\|^2_A + \nabla \log\det(A + M) - \nabla \log\det A \right], \tag{35}$$
where we expanded $\log\det(I + A^{-1} M) = \log\det(A + M) - \log\det A$. Taking the respective derivatives and setting equal to zero at $A$, this leads to the condition
$$\bar\theta \bar\theta^T = (I - (I + A^{-1} M)^{-1}) A^{-1}. \tag{36}$$
Post-multiplying by $A$ and applying the push-through identity, we obtain
$$\bar\theta \bar\theta^T A = M (A + M)^{-1}. \tag{37}$$
For the above to hold, it is necessary that the traces of both sides are equal. Thus,
$$\|\bar\theta\|^2_A = \operatorname{Tr}\{\bar\theta \bar\theta^T A\} = \operatorname{Tr}\{M (A + M)^{-1}\} = \gamma, \tag{38}$$
which is the stated first order optimality condition, up to a cyclic permutation.
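To illustrate how the condition $\|\bar\theta\|^2_A = \gamma$ can be used as a fixed-point update, the sketch below iterates $\alpha \leftarrow \gamma / \|\bar\theta\|^2$ for an isotropic regulariser $A = \alpha I$ with exact linear algebra. This is a toy illustration only: it assumes a Gaussian likelihood with $B = I$, a small synthetic dataset, and exact posterior means rather than the sample-based estimates used in the paper; all function names are hypothetical.

```python
import numpy as np

def posterior_mean_and_gamma(Phi, Y, alpha):
    """Posterior mean and effective dimension for A = alpha * I and B = I."""
    M = Phi.T @ Phi
    H = M + alpha * np.eye(M.shape[0])
    theta_bar = np.linalg.solve(H, Phi.T @ Y)
    gamma = np.trace(np.linalg.solve(H, M))   # Tr{(A + M)^{-1} M}
    return theta_bar, gamma

def mackay_fixed_point(Phi, Y, alpha=10.0, num_steps=100):
    """Repeat alpha <- gamma / ||theta_bar||^2 for a fixed number of steps."""
    for _ in range(num_steps):
        theta_bar, gamma = posterior_mean_and_gamma(Phi, Y, alpha)
        alpha = gamma / (theta_bar @ theta_bar)
    return alpha

# Hypothetical synthetic regression problem.
rng = np.random.default_rng(0)
n, d = 200, 10
Phi = rng.normal(size=(n, d))
Y = Phi @ rng.normal(size=d) + rng.normal(size=n)

alpha_star = mackay_fixed_point(Phi, Y)

# Re-evaluate both sides of the optimality condition at the returned alpha.
theta_bar, gamma = posterior_mean_and_gamma(Phi, Y, alpha_star)
residual = alpha_star * (theta_bar @ theta_bar) - gamma
```

At the returned value, `residual` is numerically zero, i.e. $\alpha \|\bar\theta\|^2 = \gamma$ holds.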

B.4 M-STEP FOR FEATURE-WISE REGULARISATION STRENGTHS

We can leverage the primal form expression for the effective dimension given in Appendix B.1 to extend the above first order optimality condition to the feature-wise regulariser setting. Consider a sub-vector of our weight vector contiguous between the $i$th and $j$th weights, written as $\bar\theta_{i:j}$. Note that we only choose contiguous weights for notational convenience; it is not necessary to do so in general. The first order condition from Appendix B.3 is satisfied if for any $i, j$ with $i < j$ we have
$$\alpha \|\bar\theta_{i:j}\|^2 = j - i - \sum_{k=i}^{j} [A]_{kk} [(A + M)^{-1}]_{kk} := \gamma_{i:j}. \tag{39}$$
We assume $[A]_{kk} = \alpha$ for all $i \leq k < j$. Thus, we may update the regulariser for each separate weight sub-vector as $\alpha = \gamma_{i:j} / \|\bar\theta_{i:j}\|^2$.

C ANALYSIS OF LOSSES AND LOSS GRADIENT ESTIMATOR VARIANCES

C.1 ON LOSS MINIMA

The losses $L$ and $L'$ are strictly convex, thus to confirm they have the same unique minimum, it suffices to consider the respective first order optimality conditions, $\nabla L(\zeta) = 0$ and $\nabla L'(\zeta') = 0$. We have,
$$\nabla L(\zeta) = \Phi^T B (\Phi\zeta - E) + A(\zeta - \theta^0), \tag{40}$$
and
\begin{align}
\nabla L'(\zeta') &= \Phi^T B \Phi \zeta' + A(\zeta' - A^{-1} \Phi^T B E - \theta^0) \tag{41} \\
&= \Phi^T B (\Phi\zeta' - E) + A(\zeta' - \theta^0). \tag{42}
\end{align}
Thus $\zeta = \zeta'$ almost surely. Moreover, $L'(z) = L(z) + C$ for all $z$, for $C$ a constant independent of $z$. To determine the distribution of $\zeta$, note that it is a linear transformation of zero-mean Gaussian random variables, and thus itself a zero-mean Gaussian random variable. Rearranging the first order optimality condition, we find that
$$\zeta = H^{-1}(\Phi^T B E + A\theta^0). \tag{43}$$
Thus
\begin{align}
\mathbb{E}[\zeta\zeta^T] &= H^{-1}\, \mathbb{E}[(\Phi^T B E + A\theta^0)(\Phi^T B E + A\theta^0)^T]\, H^{-1} \tag{44} \\
&= H^{-1} \left( \Phi^T B\, \mathbb{E}[E E^T]\, B \Phi + A\, \mathbb{E}[\theta^0 (\theta^0)^T]\, A + 2 \Phi^T B\, \mathbb{E}[E (\theta^0)^T]\, A \right) H^{-1} \tag{45} \\
&= H^{-1} (\Phi^T B \Phi + A) H^{-1} = H^{-1} H H^{-1} = H^{-1}. \tag{46}
\end{align}
And so $\zeta \sim \mathcal{N}(0, H^{-1}) = \Pi^0$ as claimed.
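The argument above can be mirrored in a few lines of NumPy. The sketch draws one sample by solving the first order condition directly (in place of the iterative optimisation used in practice), checks that the gradient of $L$ vanishes there, and propagates the covariances of $E$ and $\theta^0$ through the linear map to recover $H^{-1}$. Problem sizes and the diagonal $B$ are illustrative assumptions.

```python
import numpy as np

# Hypothetical toy problem.
rng = np.random.default_rng(0)
n, d = 30, 4
Phi = rng.normal(size=(n, d))
B = np.diag(rng.uniform(0.5, 2.0, size=n))   # noise precision
A = 1.5 * np.eye(d)                          # prior precision
H = Phi.T @ B @ Phi + A

def draw_posterior_sample(rng):
    """Sample-then-optimise draw, solved in closed form for illustration."""
    E = rng.multivariate_normal(np.zeros(n), np.linalg.inv(B))       # E ~ N(0, B^{-1})
    theta0 = rng.multivariate_normal(np.zeros(d), np.linalg.inv(A))  # theta^0 ~ N(0, A^{-1})
    zeta = np.linalg.solve(H, Phi.T @ B @ E + A @ theta0)
    return zeta, E, theta0

zeta, E, theta0 = draw_posterior_sample(rng)

# Gradient of L at zeta; it should vanish at the minimiser.
grad_L = Phi.T @ B @ (Phi @ zeta - E) + A @ (zeta - theta0)

# Covariance of zeta implied by Cov(E) = B^{-1}, Cov(theta^0) = A^{-1}
# and independence of E and theta^0.
H_inv = np.linalg.inv(H)
cov_rhs = Phi.T @ B @ np.linalg.inv(B) @ B @ Phi + A @ np.linalg.inv(A) @ A
cov_zeta = H_inv @ cov_rhs @ H_inv
```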

C.4 ANALYSING CONDITION AT CONVERGENCE

To gain some intuition for the condition at convergence, denote by $\lambda_1, \dots, \lambda_{d'}$ the eigenvalues of $M$ (with multiplicity). We can use these to restate the condition as
$$2\alpha\gamma = 2\alpha \sum_{j=1}^{d'} \frac{\lambda_j}{\lambda_j + \alpha} > \sum_{j=1}^{d'} \lambda_j = \operatorname{Tr}\{M\}.$$
This formulation of effective dimension gives an interpretation of a soft count of the number of dimensions for which $\lambda_j$ is larger than $\alpha$; in that sense, $\lambda_j$ measures how well determined the corresponding dimension of the weight vector $\theta$ is by the observed data. From here, note that
$$\frac{2\alpha\lambda_j}{\lambda_j + \alpha} > \min\{\lambda_j, \alpha\},$$
and thus it is sufficient for $\operatorname{Tr} \Delta > 0$ to hold at convergence that $\alpha > \lambda_j$ for all $j$ (but, of course, not necessary), yielding the intuition that $L'$ is preferred when the problem is heavily regularised.
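As a small worked example of the eigenvalue form of the condition, the sketch below evaluates $2\alpha\gamma$ against $\operatorname{Tr}\{M\}$ for a hypothetical spectrum, once with a regulariser larger than every eigenvalue and once with a very small one. The eigenvalues and the two values of $\alpha$ are made up for illustration.

```python
import numpy as np

def effective_dimension(eigvals, alpha):
    """gamma = sum_j lambda_j / (lambda_j + alpha) for A = alpha * I."""
    return np.sum(eigvals / (eigvals + alpha))

def condition_holds(eigvals, alpha):
    """Check 2 * alpha * gamma > Tr{M} using the eigenvalues of M."""
    return 2 * alpha * effective_dimension(eigvals, alpha) > np.sum(eigvals)

eigvals = np.array([0.01, 0.1, 0.5, 1.0])   # hypothetical spectrum of M

heavily_regularised = condition_holds(eigvals, alpha=2.0)    # alpha > every lambda_j
weakly_regularised = condition_holds(eigvals, alpha=0.01)    # alpha <= every lambda_j
```

With $\alpha = 2$ the condition is satisfied, whereas with $\alpha = 0.01$ it fails, in line with the intuition that the condition favours the heavily regularised regime.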

D DUAL FORM OF THE SAMPLE-THEN-OPTIMISE LOSS: MATHERON'S RULE

Both losses $L$ and $L'$ result in a random variable $\zeta \sim \Pi^0$ given by $\zeta = H^{-1}(\Phi^T B E + A\theta^0)$. Recalling that $H = A + \Phi^T B \Phi$ and using the push-through identity, we can express $\zeta$ equivalently as
\begin{align}
\zeta &= H^{-1}((H - \Phi^T B \Phi)\theta^0 + \Phi^T B E) \tag{70} \\
&= \theta^0 + H^{-1} \Phi^T B (E - \Phi\theta^0) \tag{71} \\
&= \theta^0 + A^{-1} (I + \Phi^T B \Phi A^{-1})^{-1} \Phi^T B (E - \Phi\theta^0) \tag{72} \\
&= \theta^0 + A^{-1} \Phi^T B (I + \Phi A^{-1} \Phi^T B)^{-1} (E - \Phi\theta^0) \tag{73} \\
&= \theta^0 + A^{-1} \Phi^T (B^{-1} + \Phi A^{-1} \Phi^T)^{-1} (E - \Phi\theta^0). \tag{74}
\end{align}
Now taking a sample of the posterior Gaussian process evaluated at input $x$ to be $G = \phi(x)\zeta$ and the corresponding sample of the prior process to be $G_0 = \phi(x)\theta^0$, premultiplying the above expression by $\phi(x)$ we obtain
$$G = G_0 + \phi(x) A^{-1} \Phi^T (B^{-1} + \Phi A^{-1} \Phi^T)^{-1} (E - \Phi\theta^0), \tag{75}$$
which is Matheron's rule.
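The equivalence between the primal expression (70) and the dual expression (74) can be verified numerically. The following sketch does so for one random draw; the problem sizes and the diagonal precisions are assumptions chosen only to keep the example small.

```python
import numpy as np

# Hypothetical toy problem.
rng = np.random.default_rng(1)
n, d = 12, 20
Phi = rng.normal(size=(n, d))
B = np.diag(rng.uniform(0.5, 2.0, size=n))   # noise precision
A = np.diag(rng.uniform(0.5, 2.0, size=d))   # prior precision
A_inv, B_inv = np.linalg.inv(A), np.linalg.inv(B)

E = rng.multivariate_normal(np.zeros(n), B_inv)       # E ~ N(0, B^{-1})
theta0 = rng.multivariate_normal(np.zeros(d), A_inv)  # theta^0 ~ N(0, A^{-1})

# Primal (weight space) form: solve against the d x d matrix H.
H = A + Phi.T @ B @ Phi
zeta_primal = np.linalg.solve(H, Phi.T @ B @ E + A @ theta0)

# Dual (observation space) form: solve against the n x n kernel matrix.
K = B_inv + Phi @ A_inv @ Phi.T
zeta_dual = theta0 + A_inv @ Phi.T @ np.linalg.solve(K, E - Phi @ theta0)
```

Which form is cheaper depends on whether the number of parameters or the number of observations is smaller; the two results coincide up to numerical error.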

E RESOLVING FEATURE SCALE INDETERMINACIES IN THE NN JACOBIAN WITH THE G-PRIOR

Antorán et al. (2022c) show that for NNs with normalisation layers, the Jacobian features $\phi(\cdot) = \nabla_w f(\tilde w, \cdot)$ corresponding to each NN layer are scaled arbitrarily. To illustrate this, we divide the NN linearisation point into the concatenation of two weight vectors $\tilde w = [\tilde w_0, \tilde w_1]$. We assume the layer containing $\tilde w_0$ is followed by a normalisation layer, but not that containing $\tilde w_1$, which leads to the invariance
$$f([k \tilde w_0, \tilde w_1], \cdot) = f([\tilde w_0, \tilde w_1], \cdot) \tag{76}$$
for all $k > 0$. While $f$ is invariant to this scaling, the Jacobian feature embeddings $\phi(\cdot) = \nabla_w f(\tilde w, \cdot)$ are not. We separate the embeddings as
$$\phi(x) = [\phi_0(\cdot), \phi_1(\cdot)] = [\nabla_{w_0} f(\tilde w, \cdot), \nabla_{w_1} f(\tilde w, \cdot)]. \tag{77}$$
Antorán et al. (2022c) show that, given a reference pair $([\tilde w_0, \tilde w_1], [\phi_0(x), \phi_1(x)])$, and for $\tilde w_0$ normalised, scaling $\tilde w_0$ by $k$ results in the pair $([k \tilde w_0, \tilde w_1], [k^{-1}\phi_0(x), \phi_1(x)])$. Thus, using a single prior precision parameter, the regularisation strength applied to the weights multiplying $\phi_0(x)$ relative to those multiplying $\phi_1(x)$ will increase with $k$. The value of $k$, the scale of the linearisation point, depends on exogenous factors such as learning rate or batch size, and importantly is independent of the data, since it does not affect the output.

One way to resolve this is to assign the weights $\tilde w_0$ and $\tilde w_1$ separate regularisation parameters and learn these using the EM procedure outlined in Section 2. However, instead, consider using the g-prior normalised features $\phi'$ introduced in Section 4.3, and specifically, the scaling vector corresponding to normalised and non-normalised components $s = [s_0, s_1]$. For a reference pair $([\tilde w_0, \tilde w_1], [s_0, s_1])$ and for $\tilde w_0$ normalised, the $k$-scaled pair is
$$\left([k \tilde w_0, \tilde w_1],\ \left[\operatorname{diag}(k^{-1} \Phi_0^T B \Phi_0 k^{-1})^{\odot -\frac{1}{2}},\ \operatorname{diag}(\Phi_1^T B \Phi_1)^{\odot -\frac{1}{2}}\right]\right) = \left([k \tilde w_0, \tilde w_1],\ [k s_0, s_1]\right),$$
where $\odot$ in an exponent denotes an elementwise power. Since the $k$-scaled features are $[k^{-1}\phi_0(\cdot), \phi_1(\cdot)]$, when applying the g-prior normalisation we recover a feature vector independent of $k$. This resolves the aforementioned pathology.
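The scale invariance of the diagonal g-prior normalisation is easy to verify numerically. In the sketch below, the first `d0` feature columns play the role of $\Phi_0$ and are divided by $k$; the scaling vector then grows by a factor of $k$ on those entries and the normalised features are unchanged. The sizes, the value of $k$ and the function name are illustrative assumptions.

```python
import numpy as np

def g_prior_scale(Phi, B):
    """Diagonal g-prior scaling vector s = diag(Phi^T B Phi)^{-1/2} (elementwise power)."""
    return np.diag(Phi.T @ B @ Phi) ** -0.5

# Hypothetical toy features: the first d0 columns correspond to normalised weights.
rng = np.random.default_rng(0)
n, d0, d1 = 50, 3, 4
Phi = rng.normal(size=(n, d0 + d1))
B = np.diag(rng.uniform(0.5, 2.0, size=n))

k = 7.0
Phi_scaled = Phi.copy()
Phi_scaled[:, :d0] /= k   # scaling w_0 by k scales phi_0 by 1/k

s = g_prior_scale(Phi, B)
s_scaled = g_prior_scale(Phi_scaled, B)

normalised = Phi * s                       # features after g-prior normalisation
normalised_scaled = Phi_scaled * s_scaled  # same, starting from the k-scaled features
```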

F A PRACTICAL IMPLEMENTATION OF SAMPLE-BASED INFERENCE AND HYPERPARAMETER LEARNING FOR LINEARISED NEURAL NETWORKS

Algorithm 1 provides a high level overview of the procedure used for our experiments. This appendix expands on this, providing fully detailed algorithms for both sampled linearised Laplace applied to classification networks and the kernelised version of the method that we use for tomographic image reconstruction.

Image classification Algorithm 2 provides an algorithm for linearised Laplace inference using the stochastic EM iteration presented in Section 3 for hyperparameter selection and the g-prior normalisation described in Section 4.3. Therein, $\mu$ denotes the softmax function. The curvature of the cross entropy loss at $x_i$, denoted $B_i$, is given by $B_i = \operatorname{diag}(p_i) - p_i p_i^T$ for $p_i = \mu(f(\tilde w, x_i))$ our neural network's predictive probabilities. The notation $\odot$ refers to the elementwise product and to the elementwise power when used in an exponent. We refer to the Cholesky factorisation of a positive definite matrix as its $1/2$th power. In order to limit computational cost, we sample the stochastic regularisation terms $(\theta^n_j)$, per (7), only once at the start. Not resampling these at each E step results in the optima of the sampling objective being close for successive iterations. This comes at the cost of a small bias in our estimator, which we find to be negligible in practice. We separate $(\theta^n_j)$ into a sum consisting of a prior sample $(\theta^0_j)$ and a data dependent term, denoted $(\theta'_j)$. The former scales with $\alpha^{-1/2}$ while the latter scales with $\alpha^{-1}$, which allows us to update each term in closed form each time $\alpha$ changes. We initialise our samples at $(\theta^0_j)$ at the first EM iteration. We warm-start the posterior mode $\bar\theta$ at the previous solution between iterations, initialising it to zero for the first iteration. We estimate the g-prior scaling vector $s$ by noting that it relates to $\theta'_1$ as $s = \alpha^{-1} (\mathbb{E}[\theta'_1 \odot \theta'_1])^{\odot -1/2}$. We optimise both our samples $\zeta$ and posterior mean $\bar\theta$ using stochastic gradient descent with a Nesterov momentum parameter of 0.9.
We find that Polyak averaging is very effective at reducing gradient variance when optimising the sampling objective (per Dieuleveut et al., 2017). However, it has two limitations: 1) it slows down optimisation, increasing the number of steps needed, and 2) it doubles the memory needed to store posterior samples, which decreases the number of samples that can be optimised in parallel on a single hardware accelerator. Instead we employ a linear learning rate decay schedule, which we find to work nearly as well while not increasing computational burden. The regularised classification loss $L$ is non-quadratic and thus Polyak averaging is no longer optimal (Bach, 2014). Thus here we also employ a linear learning rate decay schedule. For optimising both the sampling and classification loss objectives we find that gradient clipping helps prevent oscillations at the start of training and as a result speeds up convergence.

The key hyperparameters of our algorithm are the number of samples to draw for the EM iteration, the number of EM steps to run, and SGD hyperparameters: learning rate, linear decay rate, number of steps, batch-size and gradient clipping. Empirically, we find that at most 5 EM steps are necessary for hyperparameter convergence and that as few as 3 samples can be used for the algorithm without degrading performance. Choosing SGD hyperparameters is more complicated. However, we are aided by the fact that lower loss values correspond to more precise posterior mean and sample estimates. As a result, we can tune these parameters on the train data; no validation set is required. The specific hyperparameter values used in our experiments are provided in Appendix G.

A final thing to note is that due to the presence of normalisation layers and a dense final layer, for our classification networks, the constant-in-$\theta$ terms cancel in the linearised model and we are left with $h(\theta, x) = \phi(x)\theta$ (Antorán et al., 2022c). In our algorithm, this fact is only relevant to the computation of the posterior mode $\bar\theta$ as the optimum of $L(h(\theta, \cdot))$.

Algorithm 2: Sampling-based linearised Laplace inference for image classification
Inputs: Linearised network $h$, linearisation point $\tilde w$, observations $x_1, \dots, x_n$, negative log-likelihood function $\ell$, initial precision $\alpha > 0$, number of samples $k$
Function $B(i)$:
    $p_i \leftarrow \mu(h(\tilde w, x_i))$
    return $\operatorname{diag}(p_i) - p_i p_i^T$
for $j = 1, \dots, k$ do
    $\theta^0_j \sim \mathcal{N}(0, \alpha^{-1} I)$
    $\theta'_j \leftarrow \alpha^{-1} \sum_{i=1}^{n} \phi(x_i)^T \varepsilon_j$ where $\varepsilon_j \sim \mathcal{N}(0, B(i))$
    $\zeta_j \leftarrow \theta^0_j$
$\bar\theta \leftarrow 0$
$s \leftarrow \alpha^{-1} \left( \frac{1}{k} \sum_{j=1}^{k} \theta_j'^{\odot 2} \right)^{\odot -1/2}$
while $\alpha$ has not converged do
    for $j = 1, \dots, k$ do
        $\zeta_j \leftarrow \mathrm{SGD}_z\left( \|\Phi(s \odot z)\|^2_B + \alpha \|z - \theta^0_j - (s \odot \theta'_j)\|^2_2,\ \text{init}=\zeta_j \right)$
    $\bar\theta \leftarrow \mathrm{SGD}_\theta\left( \sum_{i=1}^{n} \ell(y_i, h((s \odot \theta), x_i)) + \alpha \|\theta\|^2_2,\ \text{init}=\bar\theta \right)$
    $\gamma \leftarrow \frac{1}{k} \sum_{j=1}^{k} \sum_{i=1}^{n} \|(\zeta_j \odot s)^T \phi(x_i)^T B(i)^{\frac{1}{2}}\|^2_2$
    $\alpha' \leftarrow \gamma / \|\bar\theta\|^2_2$
    for $j = 1, \dots, k$ do
        $\theta^0_j \leftarrow \sqrt{\alpha / \alpha'}\, \theta^0_j$
        $\theta'_j \leftarrow (\alpha / \alpha')\, \theta'_j$
    $\alpha \leftarrow \alpha'$
Output:

Algorithm 3 (kernelised variant for tomographic reconstruction; only the following fragment is recoverable):
Function Kvp($v$):
    return $U\Phi(\alpha^{-1} \operatorname{diag}(s^{\odot 2}))\Phi^T U^T v + B^{-1} v$
$s \leftarrow \left( \sum_{i<m} (U_i \Phi)^{\odot 2} \right)^{-1/2}$
while $\alpha$ has not converged do
    $P \leftarrow$ Compute-preconditioner(Kvp)
    for $j = 1, \dots, k$ do
        $\zeta^0_j \leftarrow U\Phi(s \odot \theta^0_j) + E_j$ where $E_j \sim \mathcal{N}(0, B^{-1})$ and $\theta^0_j \sim \mathcal{N}(0, A^{-1})$
        $c_j \leftarrow \mathrm{CG}(\text{Kvp},\ \zeta^0_j,\ \text{precond.}=P)$
        $\zeta_j \leftarrow \zeta^0_j - U\Phi(\alpha^{-1} \operatorname{diag}(s^{\odot 2}))\Phi^T U^T c_j$
    $\delta \leftarrow U(\Phi \tilde w - f(\tilde w))$
    $c \leftarrow \mathrm{CG}(\text{Kvp},\ Y + \delta,\ \text{precond.}=P)$
    $\bar\theta \leftarrow s \odot \alpha^{-1} \Phi^T U^T c$
    $\gamma \leftarrow \frac{1}{k} \sum_{j=1}^{k} \|U\Phi(s \odot \zeta_j)\|^2_2$
    $\alpha' \leftarrow \gamma / \|\bar\theta\|^2$
    $\alpha \leftarrow \alpha'$
Output: Optimised precision $\alpha$

Tomographic reconstruction Algorithm 3 is the kernelised version of Algorithm 2 that we use for tomographic reconstruction. This problem is described in detail in Appendix G.3. Distinctly from the image classification setting, tomographic reconstruction is a regression problem for which we use a Gaussian likelihood with fixed noise precision $B = I$. The linear model's loss function $L$ is thus quadratic and the Laplace approximation is not needed.
Both the sample loss and the linear model's loss can be optimised in closed form by solving a linear system of equations given by the observation covariance, i.e. the kernel matrix, U Φ(α -1 diag(s ⊙2 ))Φ T U T + B -1 where the linear operator U represents the discrete Radon transform and combines with the U-Net Jacobian to build the feature embedding U Φ. We solve against the kernel matrix using the preconditioned conjugate gradient (CG) method described by Gardner et al. (2018) . As a preconditioner, we compute a 400-dimensional randomised eigendecomposition (alg. 5.6 in Halko et al., 2011) preconditioner, which we invert using the Woodbury identity. We find the preconditioner to provide important speedups to CG convergence and we re-estimate it after every hyperparameter update. Both computing the preconditioner and running preconditioned CG optimisation only interact with the kernel matrix by computing its products with vectors. Our algorithm defines our kernel vector product Kvp routine explicitly, as it is central to our implementation. We find that the GPyTorch CG implementation does not benefit from warm-starting the solution vector. Consequently, we re-draw prior and noise samples (θ 0 , E) at every E-step. Similarly to image classification, the key hyperparameters are the number of samples to draw for the EM iteration, the number of EM steps to run, and CG optimisation hyperparameters. Again, the number of samples can be kept low (e.g. 2) and we find around 5 steps to suffice for convergence of the prior precision α. The key conjugate gradients hyperparameters are the tolerance at which to stop optimisation and the maximum number of optimisation steps if the tolerance is not reached. We provide our choices in Appendix G.3 but note that our use of a large preconditioner results in CG always hitting the desired low error tolerance within 10 steps and never stopping due to reaching the maximum number of steps. 
In turn, this makes our kernelised EM algorithm notably faster than its primal form SGD-based counterpart. A particularity of this setting is that the U-Net does not have a dense final layer. As a result, the constant-in-θ terms in the linearised function h do not cancel (see Section 4.1), leading to the appearance of the target offset term δ when solving for the posterior mean.
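To make the matrix-free nature of the kernelised solve concrete, the sketch below implements a kernel vector product of the form $U\Phi(\alpha^{-1}\operatorname{diag}(s^{\odot 2}))\Phi^T U^T v + B^{-1} v$ with $B = I$ and solves against it with a plain conjugate gradient loop. It is a simplified stand-in, not the implementation used in the experiments: there is no preconditioner, the GPyTorch routine is not used, and the dense matrix `J` standing in for $U\Phi$, the sizes and the value of $\alpha$ are all hypothetical.

```python
import numpy as np

def conjugate_gradients(matvec, b, tol=1e-10, max_iter=200):
    """Unpreconditioned CG for a symmetric positive definite operator given by matvec."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        Kp = matvec(p)
        step = rs / (p @ Kp)
        x = x + step * p
        r = r - step * Kp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Hypothetical stand-in for the feature embedding U Phi (m observations, d' parameters).
rng = np.random.default_rng(0)
m, d_prime = 15, 40
J = rng.normal(size=(m, d_prime))
s = np.sum(J**2, axis=0) ** -0.5   # g-prior scaling vector for B = I
alpha = 0.5

def kvp(v):
    """Kernel vector product; touches J only through products with vectors."""
    return J @ ((s**2 / alpha) * (J.T @ v)) + v

b = rng.normal(size=m)
c = conjugate_gradients(kvp, b)

# Dense kernel matrix, built here only to check the matrix-free solve.
K_dense = J @ np.diag(s**2 / alpha) @ J.T + np.eye(m)
```

Only `kvp` is needed by the solver, so the same loop applies when `J` is available solely through Jacobian-vector and vector-Jacobian products.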

G EXPERIMENTAL DETAILS

In this appendix we provide experimental details and hyperparameter settings omitted from the main text.

G.1 MNIST EXPERIMENTS

MNIST $m = 10$ way classification experiments were performed using the LeNet-style CNN architectures of increasing size employed by Antorán et al. (2022c): "LeNetSmall" ($d' = 14634$), "LeNet" ($d' = 29226$) and "LeNetBig" ($d' = 46024$). We note that these models contain batch normalisation layers. Each model was trained using SGD with momentum of 0.9 for 90 epochs with a learning rate drop of a factor of 10 every 30 epochs. The MNIST dataset was downloaded from PyTorch torchvision. We employ standard per-channel mean and std-dev standardisation preprocessing and two pixel shift and crop data augmentation. For posterior mode optimisation and sampling, we do not perform data augmentation so as to avoid cold posterior effects (Izmailov et al., 2021). The details of our SGD approaches to convex optimisation for obtaining posterior modes and samples are as follows:
• Posterior mode optimisation: The linearised NN weights are trained using SGD with a Nesterov momentum coefficient of 0.9, and batch size 1000 for 40 epochs. We clip gradients to a maximum norm of 1. We use an initial learning rate of 1e-2 when using standard isotropic or layerwise Gaussian priors, and 1 for the g-prior. We employ a linear decay schedule that reduces the lr by a factor of 330 over the first 75% of the training procedure and holds it constant afterwards.
• Sampling: We optimise 32 samples in parallel using SGD with Nesterov momentum (=0.9) and a batch size of 1000 for 20 epochs. For standard Gaussian priors (isotropic and layerwise), we use a learning rate of 2e-1, whereas for the g-prior, we find a higher learning rate of 200 to work best.

Hyperparameter optimisation: We tuned the learning rate, decay schedule and gradient clipping strength using a rough grid search over multiple orders of magnitude. We chose the settings that reached the lowest loss values. These can be evaluated with just the train set.
We chose the largest batch size that could accommodate optimising 32 samples in parallel on a single hardware accelerator. We note that posterior mode and sample optimisation converge in less than half of the total epochs we use for their optimisation. The numbers of epochs chosen were set to be large enough to ensure convergence and not tuned. A decrease in computational cost can likely be achieved by stopping sample optimisation earlier.

Baseline methods. For the comparison of learning a single precision hyperparameter and layerwise hyperparameters in Figure 3, we extend the M-step update to $\alpha_l = \gamma_l / \|\bar\theta_l\|^2_2$, where $l$ indexes each layer's attributes, as done by Mackay (1992) and Tipping (2001). For the MAP, diagonal covariance and KFAC covariance baselines, we use the same pre-trained models when possible (i.e. not for the ensembles or dropout baselines). Since all baselines share the same linearisation point, they also share the same mean predictions. Differences in performance among baselines are thus only due to differences in uncertainty estimation. The diagonal approximation to the covariance is constructed by first computing the diagonal of the Hessian $M$ and then inverting it. For the KFAC covariance approximation, we exploit the equivalency between the Generalised Gauss Newton matrix (i.e. the Hessian of the linear model $h$) and the Fisher information matrix for exponential family likelihoods (i.e. the categorical). This allows us to formulate the Hessian as an expectation of likelihood gradients, which in turn we approximate using a single sample per training observation, as in Daxberger et al. (2021a). For completeness, we also state the probit approximation for sampled predictive posteriors over logits. For input $x$ and samples $\zeta_1, \dots, \zeta_k$, the predictive probability for class $i \in \{1, \dots, m\}$ is given by
$$\operatorname{softmax}\left( f(\tilde w, x) \odot \left( 1 + \frac{\pi}{2k} \sum_{j<k} (\phi(x)\zeta_j)^{\odot 2} \right)^{\odot -0.5} \right)_i.$$

G.2 CIFAR100 CLASSIFICATION

CIFAR100 $m = 100$ way classification experiments were performed using ResNet18 models ($d' \approx 11$M) with specific architecture details matching the PyTorch torchvision implementation. We train these models using SGD with momentum of 0.9 for 300 epochs. The starting lr is 0.1 and we reduce it by a factor of 10 every 100 epochs. The CIFAR100 dataset was also downloaded using torchvision and our data preprocessing and augmentation also follow the default implementation from this library. For posterior mode optimisation and sampling, we do not perform data augmentation. The SGD details used to solve the convex optimisation problems required for obtaining posterior modes and drawing samples are as follows:
• Posterior mode optimisation: The linearised NN weights are trained using SGD with Nesterov momentum (=0.9) and a batch size of 2000 for 40 epochs. We employ a linear decay learning rate schedule with an initial learning rate of 1e-1. It is decreased by a factor of 330 over the first 75% of training, and then held constant. We also employ gradient clipping with a maximum norm of 0.1.
• Sampling: We optimise 6 samples in parallel using SGD with Nesterov momentum (=0.9) and a batch size of 100 for 20 epochs. All other details match those of posterior mode optimisation. Upon convergence of the EM algorithm, we draw 64 further samples using the optimal prior precision by following the optimisation procedure described above. We initialise these samples at prior samples drawn with the optimised prior precision.

Hyperparameter optimisation: We tuned the learning rate, decay rate and gradient clipping strength using a rough grid search over orders of magnitude. We also chose the largest batch size for which we could simultaneously optimise 6 samples in parallel on a single hardware accelerator. Similarly to the MNIST experiments, we did not optimise the number of optimisation epochs and instead chose large values that would ensure convergence.
It is likely that our EM iteration can be sped up by decreasing the duration of the convex optimisation routines. Details for baselines and hyperparameters not mentioned explicitly in this subsection match those given for MNIST in the previous subsection.

G.2.1 EFFICIENT κ-ADIC SAMPLING

Osband et al. (2022; 2021) introduced dyadic test input sampling ($\kappa = 2$) as a practical way of evaluating joint predictions in discriminative tasks. This method samples $\kappa = 2$ random anchor points from the test dataset, and then randomly resamples them to create a batch of size $\tau = 10$. Test log-likelihood is evaluated jointly for each batch as
$$\log \int \exp\left( \sum_{i \leq \tau} \ell(y_i, f(\theta, x_i)) \right) d\Pi,$$
for $f$ the model being evaluated and $\Pi$ its posterior distribution over model parameters. This quantity can be estimated with posterior samples $\zeta_1, \dots, \zeta_k \sim \Pi$ as
$$\log \frac{1}{k} \sum_{j \leq k} \exp\left( \sum_{i \leq \tau} \ell(y_i, f(\zeta_j, x_i)) \right).$$
We extend this evaluation approach to larger $\kappa$ and $\tau$ values without increasing computational cost. We randomly sample $\kappa$ integers $\{b_1, \dots, b_\kappa\}$ such that they sum to $\tau$, i.e. $\sum_{l \leq \kappa} b_l = \tau$. The joint log-likelihood over the batch of size $\tau$ with $\kappa$ unique datapoints can then be estimated as
$$\log \frac{1}{k} \sum_{j \leq k} \exp\left( \sum_{l \leq \kappa} b_l\, \ell(y_l, f(\zeta_j, x_l)) \right),$$
where the inner sum is over the $\kappa$ distinct elements in the batch instead of the "total batch size" $\tau$. This is equivalent to the formulation proposed in Osband et al. (2022) for dyadic sampling, when $\kappa = 2$ and $\tau = 10$. We note that it is not possible to achieve augmented dyadic sampling, as described by Osband et al. (2021), with this approach. However, the authors mention that there is not a significant difference in the relative performance of methods when using augmented dyadic sampling compared to regular dyadic sampling. We introduce a final step, however, which is to repeat the computation for multiple shuffles (10) of the test dataset. This eliminates variance in our results from the choice of the $\kappa$ observations which get grouped together in each batch.
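The weighted estimator can be sketched as follows. The function takes per-sample log-likelihoods at the $\kappa$ distinct datapoints together with the multiplicities $b_l$, and uses a max-shift for numerical stability (an implementation detail assumed here, not stated in the text). The input values are made up, and how the multiplicities are drawn is left outside the sketch.

```python
import numpy as np

def joint_log_likelihood(log_liks, counts):
    """Estimate log (1/k) sum_j exp(sum_l b_l * ell(y_l, f(zeta_j, x_l))).

    log_liks: array of shape (k, kappa), log-likelihood of datapoint l under sample j.
    counts:   array of shape (kappa,), multiplicities b_l summing to tau.
    """
    per_sample = log_liks @ counts   # shape (k,)
    shift = per_sample.max()         # max-shift keeps the exponentials bounded
    return shift + np.log(np.mean(np.exp(per_sample - shift)))

# Hypothetical inputs: k = 5 posterior samples, kappa = 3 distinct points, tau = 10.
rng = np.random.default_rng(0)
log_liks = -rng.uniform(0.1, 3.0, size=(5, 3))
counts = np.array([2, 5, 3])

weighted_estimate = joint_log_likelihood(log_liks, counts)

# Reference: materialise the full batch of size tau and sum over all tau entries.
expanded = np.repeat(log_liks, counts, axis=1)   # shape (k, tau)
full_batch_estimate = joint_log_likelihood(expanded, np.ones(expanded.shape[1]))
```

The two estimates coincide, but the weighted version only evaluates the model at $\kappa$ inputs rather than $\tau$.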

G.3 TOMOGRAPHIC RECONSTRUCTION

Problem setup Tomographic reconstruction consists of solving a linear inverse problem in imaging where we observe a set of measurements $y \in \mathbb{R}^m$, which we assume to be generated as $y = U x^* + \eta$ for $U \in \mathbb{R}^{m \times d}$ the discrete Radon transform, $x^* \in \mathbb{R}^d$ the image to reconstruct and $\eta \sim \mathcal{N}(0, I)$ random noise. We have $m \ll d$, making the problem underconstrained. Our specific tomographic reconstruction task closely follows the one from Barbano et al. (2021). We perform a sparse-view reconstruction of an image of a slice of a walnut from a sub-sampled set of measurements. Specifically, from the full measurement set of Der Sarkissian et al. (2019a), which contains scans at 1200 equidistant angles over $[0^\circ, 360^\circ)$, we choose our measurement set by subsampling angles by a factor of either 10x or 20x, leading to measurements of size $m = 15360$ or $m = 7680$. As in Barbano et al. (2021); Antorán et al. (2022b); Barbano et al. (2023), we reduce the original 3D scan geometry to the 2D slice of interest by selecting the relevant subset of measurement pixels. We assemble the Radon operator $U$ as a sparse matrix taking in an image of resolution $(501\,\text{px})^2$ and outputting a measurement tensor coherent with the described geometry.

Methods To provide a reconstruction, we use the Deep Image prior (Ulyanov et al., 2020), which trains the parameters $w \in \mathbb{R}^{d'}$ of a fully convolutional U-Net autoencoder $f : \mathbb{R}^{d'} \to \mathbb{R}^d$, where the input is fixed and thus absent from our notation, until a satisfactory reconstruction $f(\tilde w)$ is obtained. The U-Net network architecture is the one proposed in Baguer et al. (2020). The optimisation of the U-Net parameters follows Barbano et al. (2021), although we note that faster optimisation strategies exist (Barbano et al., 2023). To estimate the uncertainty in this reconstruction, we linearise the U-Net around $\tilde w$, as described in Section 4.1. This leaves us with a model affine in the parameters and with design matrix $U\Phi \in \mathbb{R}^{m \times d'}$.
We may now proceed with linear model inference. While Antorán et al. (2022b) use the traditional EM iteration described in Section 2, we use the sample-based one from Section 3. For evaluation, we use the non-sparse reconstruction (using 1200 angles) provided by Der Sarkissian et al. (2019a) as the ground truth image $x^*$. To evaluate joint log-likelihoods we estimate the predictive covariance matrix for patches of neighbouring pixels using samples. Covariance matrix estimates from samples are known to be unreliable. We use the stabilised formulation of Maddox et al. (2019):
$$\Sigma = \frac{1}{2k} \sum_{j=1}^{k} \left[ \operatorname{diag}(x_j^{\odot 2}) + x_j x_j^T \right]$$
for $(x_j)_{j=1}^{k}$ samples from the predictive posterior over a patch. We then construct predictive distributions over pixels as $\mathcal{N}(f(\tilde w), \Sigma)$.

Hyperparameters We employ a low CG tolerance of 1e-3 and a maximum number of iterations of 150, which is never reached in practice as the error tolerance level is always hit in fewer steps.
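A minimal sketch of a stabilised covariance estimate of this kind is given below, assuming zero-mean predictive samples stacked row-wise and reading the first term as the diagonal matrix of elementwise squared samples. With fewer samples than pixels in the patch the plain sample covariance is singular, whereas the added diagonal term keeps the estimate positive definite. The patch size and sample count are hypothetical.

```python
import numpy as np

def stabilised_covariance(samples):
    """Average of a diagonal and a low-rank second moment estimate.

    samples: array of shape (k, p) holding k zero-mean predictive samples over a patch.
    """
    k = samples.shape[0]
    diag_part = np.diag(np.sum(samples**2, axis=0))   # sum_j diag(x_j^2)
    low_rank_part = samples.T @ samples               # sum_j x_j x_j^T
    return (diag_part + low_rank_part) / (2 * k)

# Hypothetical setting: k = 3 samples over a patch of p = 8 pixels.
rng = np.random.default_rng(0)
samples = rng.normal(size=(3, 8))
Sigma = stabilised_covariance(samples)

smallest_eigenvalue = np.linalg.eigvalsh(Sigma).min()
marginal_second_moments = np.mean(samples**2, axis=0)
```

The diagonal of `Sigma` equals the per-pixel sample second moments, so marginal variances are unaffected while off-diagonal terms are shrunk by one half.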

H CALIBRATION OF PREDICTIVE DISTRIBUTIONS

This appendix evaluates the calibration of the predictive distributions provided by the methods under consideration in our CIFAR100 classification experiment (Section 5.2) and U-net image reconstruction experiment (Section 5.3). For classification, we separate our predicted probabilities into 10 equal width bins between 0 and 1. For each bin, we plot the proportion of targets that coincide with the class for which the predicted probability falls into the bin. This is shown on the right hand side of Figure 8 . Consistent with the results from the main text, KFAC overestimates uncertainty at all confidence levels whereas MAP underestimates it. Both sample-based linearised Laplace and ensembling show significantly improved calibration. While ensembles show a small amount of uncertainty overestimation consistently, our method underestimates uncertainty for low predicted probabilities and overestimates it for large predicted probabilities. For image reconstruction regression, we first compute normalised residuals by subtracting our predictions from the targets and dividing by the predictive standard deviation. Our predictive distribution for these normalised residuals is the centered unit variance Gaussian. We consider posterior credible intervals centered at 0 and of increasing width and plot the proportion of test points that fall within them in the left side plot of Figure 8 . We find dropout inference to underestimate the magnitude of the residuals across all credible interval widths. Linearised inference with a single EM step, as in (Antorán et al., 2022b) , consistently overestimates uncertainty. Our approach, which performs 5 steps of EM, overestimates uncertainty, but to a much smaller degree, presenting the best overall calibration. The latter two approaches consist of the same model but with different regularisation strength. The difference between the two reveals the paramount importance of tuning the prior precision hyperparameter well.

I ADDITIONAL EXPERIMENTS

This appendix contains additional experiments and baselines that supplement the experimental results provided in the main text.

I.1 COMPARING PRIMAL AND DUAL EFFECTIVE DIMENSIONALITY ESTIMATORS

Our main-text experiments employ the kernelised effective dimension estimator introduced in (5). A different unbiased estimator may be obtained in primal form following the derivation provided in Appendix B.1. Figure 9 compares both estimators when applied to the 1d toy problem used to generate Figure 1 from the main text. In particular, we use a 2 hidden layer MLP with layernorm after every hidden layer and the "Matern" dataset of Antorán et al. (2020). We use 8 samples from the exact linearised Laplace posterior to compute effective dimension estimates and repeat this procedure 1000 times to characterise the behaviour of each estimator. As a reference, we also compute the exact effective dimension using eigendecomposition. Both estimators present distributions centred at the true effective dimension value. However, the prediction space (kernelised) estimator presents a much lower variance of 9.16 as opposed to 654.19 from the weight space estimator. Additionally, the weight space estimator distribution places a substantial amount of probability mass on negative effective dimension values. From the form of (19), we see that this is due to our 8-sample estimator overestimating posterior variance. On the other hand, the kernelised estimator in (5) can only produce positive values.

Figure 9: Distributions of effective dimension estimates from the weight space estimator (19) (var: 654.19) and the kernelised (prediction space) estimator (5) (var: 9.16). Both distributions are roughly centred at the true effective dimension but the kernelised estimator presents much lower variance.

I.2 EVALUATING APPROACHES TO UPDATING HYPERPARAMETERS IN THE M-STEP

This section empirically motivates the fixed-point iteration M-step introduced by Mackay (1992), and described in Section 3, by comparing it with alternative approaches to updating hyperparameters. In particular, we compare Mackay's update with the standard Laplace M-step evidence, denoted $\mathcal{M}$ and given in (3) and Appendix B.2, and a Gaussian ELBO with optimised mean and covariance. The latter two objectives differ in that the regulariser appears inside of the log-determinant term in $\mathcal{M}$, while the ELBO's covariance does not change with the regulariser while performing the M-step. Both of these objectives differ from the Mackay update in that they provide an objective which requires gradient-based optimisation in the M-step. In contrast, the Mackay update has a closed form. The plot on the left of Figure 10 compares the exact linearised Laplace evidence for a 2 hidden layer MLP with layernorm trained on the toy dataset of Antorán et al. (2020) with the bound $\mathcal{M}$ (3) and the decoupled ELBO. The initial regulariser is set to $\alpha = 500$. The ELBO is only tight for regulariser values very close to initialisation, resulting in very small M steps. $\mathcal{M}$ is tangent to the evidence at the same point as the ELBO but presents a much better approximation as we move away from $\alpha = 500$. The optimum of $\mathcal{M}$ is much closer to the optimum of the evidence. The Mackay update does not use a lower bound but instead provides an updated value for $\alpha$ which is even closer to the optimum of the evidence. The right hand side plot shows the change in the regularisation parameter across successive M-steps using the update methods under consideration. The Mackay M-step converges to the optimum of the evidence in 2 steps. Using $\mathcal{M}$ as an objective results in convergence after 5 steps. On the other hand, the ELBO update requires around 100 steps.
Figure 11 further illustrates hyperparameter learning in the 1d toy setting by showing the successive lower bounds obtained by each of the approaches under consideration at each M-step. Interestingly, the Mackay update produces regulariser updates that almost exactly maximise M.

Figure 12 compares the evidence lower bounds of the form of M, given in (3), when using different covariance matrix approximations in the MNIST classification setup presented in Section 5.1. In particular, we consider the full-covariance Laplace evidence, which we note does not match the exact model evidence due to the non-quadratic classification loss, the KFAC approximation to the covariance (labelled KFAC GGN), a single-sample KFAC Fisher estimate of the covariance, the KFAC empirical Fisher matrix (Immer et al., 2021a), and a diagonal Laplace covariance. We also include a 16 sample estimate of the ELBO described above. In all cases, we initialise the regulariser at an optimum found by applying the EM algorithm while using the full covariance M in the M-step. In this way, we may use the deviation of the different objectives' optima from the optimum of M as estimates of the bias in their corresponding approximations. Figure 12 shows that the KFAC and KFAC-Fisher approximations result in a systematic overestimation of the evidence optimum, which grows with model size. This issue is even more pronounced for the diagonal covariance approximation. On the other hand, we find the empirical Fisher to provide an accurate approximation. A similar finding is reported by Immer et al. (2021a). This is surprising, given that the empirical Fisher is known to provide a heavily biased estimate of loss curvature and thus perform poorly for optimisation tasks (Kunstner et al., 2019). The sample-based ELBO shows close to no bias when using 16 samples. This result agrees well with our experiments from Section 5.2, where the sample-based EM algorithm behaves well even when using very few samples.

Figure 12: Full covariance linearised Laplace evidence M together with approximations to this curve that rely on different covariance matrix approximations. We consider convolutional networks of increasing size (left to right) trained on the MNIST dataset.

I.3 CIFAR100 CLASSIFICATION

Additional Baselines. In the main text, we report the test log-likelihood obtained by our method as well as that of a point-estimated NN (MAP), an ensemble of 5 point-estimated NNs (Ensemble 5), and linearised Laplace with a KFAC-approximated posterior covariance matrix (KFAC). Here, we report further comparisons with other baselines that are standard in the literature: a diagonal approximation of the Laplace covariance matrix (diag), a Laplace approximation over a selected subset of the full NN weight space (subnetwork*) (Daxberger et al., 2021b), and a Laplace approximation over only the last-layer weights of the NN with a KFAC covariance matrix approximation (KFAC-LL*) (Eschenhagen et al., 2021). Note that the last layer contains 51200 weights and thus its full Laplace covariance matrix is too large to invert on our A100 GPUs. We mark the last two methods with a star (*) to denote that they require cross-validation with a held-out set to tune the regularisation strength hyperparameter. In particular, we use 50 held-out points from the test set, and evaluate these methods on the remaining 9950 points. This gives these methods a slight advantage over the approaches to uncertainty quantification considered in the main text. Similarly to the main text, we estimate the KFAC and sampling posterior distributions with 64 samples. We use exact marginalisation, computing full Jacobians, for the diagonal covariance approximation.

Robustness to distribution shift

We provide the test log-likelihood results obtained with all methods under data corruption of increasing intensity in Figure 13. Following the standard setup in the literature (Daxberger et al., 2021b;a; Eschenhagen et al., 2021), we employ the multi-class probit approximation to map Gaussian posteriors over NN outputs to class probabilities for all methods except ensembles and MAP. However, when combined with our sampling approach, we find the probit approximation to overestimate uncertainty in-distribution. We illustrate this by plotting an additional curve for sample-based inference with Monte Carlo marginalisation of the Gaussian distribution over NN outputs. This approach provides stronger in-distribution performance which comes very close to that of ensembles, subnetwork inference (*), and KFAC-LL (*). The strong performance of the latter two methods reveals the relevance of selecting a good regularisation strength parameter to uncertainty quantification with Laplace-style methods. In the out-of-distribution setting, the probit's increased uncertainty results in larger log-likelihood scores than Monte Carlo marginalisation. KFAC-LL performs very competitively both with ensembles in-distribution and with our approach in the out-of-distribution setting.

Joint LL. We report marginal and joint test log-likelihood for the KFAC-LL and subnetwork inference baselines in Table 3. We use the same κ-adic sampling setup as in Section 5.2, marginalising the Gaussian posterior over network outputs with Monte Carlo for all methods. KFAC-LL is once again quite competitive with our approach in terms of both marginal and joint LL.

Predictive uncertainty vs number of samples. In the main text, we report our method's predictive performance when drawing 64 samples. In Figure 14, we plot the degradation in test log-likelihood for the standard and progressively corrupted CIFAR100 test sets when decreasing the number of samples used for prediction.
We provide results for both Monte Carlo and probit marginalisation. Our results show two main trends:
1. Monte Carlo marginalisation provides better results in-distribution for all numbers of samples. This is consistent with our above observation that the probit approximation results in uncertainty overestimation. On the other hand, Monte Carlo marginalisation is unbiased.
2. The probit approximation benefits less from increased numbers of samples. This is expected, since [...]

MCDO systematically underestimates uncertainty for the pixels where the reconstruction error is largest. Interestingly, our method is slightly less well calibrated in the more data-rich setting, as the reconstruction error decreases faster than the predictive standard deviation.

Published as a conference paper at ICLR 2023
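The two marginalisation schemes compared in this subsection can be sketched as follows. This is an illustrative NumPy sketch under assumptions, not the evaluation code: it takes per-class marginal variances only (a diagonal output covariance), uses the commonly stated probit scaling of each logit by (1 + π/8 · σ²)^(-1/2), and the function names are hypothetical.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def probit_predictive(mean, var):
    """Multi-class probit approximation: shrink each logit by
    1 / sqrt(1 + pi / 8 * var) before the softmax (deterministic)."""
    return softmax(mean / np.sqrt(1.0 + np.pi / 8.0 * var))


def monte_carlo_predictive(mean, var, num_samples=64, seed=0):
    """Monte Carlo marginalisation: average the softmax over Gaussian
    samples of the logits (diagonal covariance assumed for this sketch)."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((num_samples,) + mean.shape)
    logits = mean + np.sqrt(var) * eps
    return softmax(logits).mean(axis=0)


mean = np.array([[2.0, 0.0, -1.0]])   # one test point, three classes
var = np.full_like(mean, 4.0)         # per-class output variances
print(probit_predictive(mean, var))
print(monte_carlo_predictive(mean, var, num_samples=64))
```

The probit path is deterministic given the mean and variance, whereas the Monte Carlo path is an average over draws and therefore changes with the number of samples.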



EVIDENCE MAXIMISATION USING STOCHASTIC APPROXIMATION

We now present our main contribution, a stochastic approximation (Nielsen, 2000) to the iterative algorithm presented in the previous section. Our M-step requires only access to samples from Π. We introduce a method to approximate posterior samples through stochastic optimisation for the E-step. The EM procedure from the previous section is for the conjugate Gaussian-linear model, where it carries guarantees on non-decreasing model evidence, and thus convergence to a local optimum. These guarantees do not hold for non-conjugate likelihood functions, e.g., the softmax-categorical, where the Laplace approximation is necessary.
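As background for the E-step, the following sketch illustrates the classic sample-then-optimise construction, in which a posterior sample is the minimiser of a randomly perturbed quadratic. It is a minimal sketch and not the objective proposed in this work: it assumes unit noise precision (B = I), an isotropic prior A = αI, and it solves each quadratic exactly with a linear solve where a large-scale implementation would use an iterative optimiser. The function name is hypothetical.

```python
import numpy as np


def sample_then_optimise(Phi, alpha, num_samples, seed=0):
    """Draw zero-mean posterior samples by minimising perturbed quadratics.

    Assumes unit noise precision (B = I) and prior precision A = alpha * I.
    Each sample is the minimiser over z of
        0.5 * ||Phi z - eps||^2 + 0.5 * alpha * ||z - theta0||^2
    with eps ~ N(0, I) and theta0 ~ N(0, I / alpha).
    """
    rng = np.random.default_rng(seed)
    n, d = Phi.shape
    H = Phi.T @ Phi + alpha * np.eye(d)
    eps = rng.standard_normal((num_samples, n))
    theta0 = rng.standard_normal((num_samples, d)) / np.sqrt(alpha)
    # First-order optimality gives H z = Phi^T eps + alpha * theta0.
    # Row s of `rhs` holds the right-hand side for sample s.
    rhs = eps @ Phi + alpha * theta0
    # Exact solve for this small problem; stands in for SGD at scale.
    return np.linalg.solve(H, rhs.T).T


Phi = np.random.default_rng(1).standard_normal((30, 4))
samples = sample_then_optimise(Phi, alpha=2.0, num_samples=20000)
H_inv = np.linalg.inv(Phi.T @ Phi + 2.0 * np.eye(4))
print(np.abs(np.cov(samples, rowvar=False) - H_inv).max())
```

The right-hand side has covariance ΦᵀΦ + αI = H, so the minimisers are distributed as N(0, H⁻¹), which is the zero-mean posterior under these assumptions.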



Figure 1: Illustration of our procedure for a fully connected NN on the toy dataset of Antorán et al. (2020). Top: prior function samples present large std-dev. (left). When these samples are optimised (middle shows a 2D slice of weight space), the resulting predictive errorbars are larger than the marginal target variance (right). Bottom: after EM, the std-dev. of prior functions roughly matches that of the targets (left), the overlap between prior and posterior is maximised, leading to shorter sample trajectories (center), and the predictive errorbars are qualitatively more appealing (right).

Figure 2: Left: optimisation traces for new and existing sampling losses averaged across 16 samples and 5 seeds. Right: for batch-size=1000, traces of sample-error vs distance to exact linear predictions.

Figure 4: Left: prior precision optimisation traces for ResNet18 on CIFAR100 varying n. samples. Middle: same for the eff. dim. Right: average sample norm and posterior mean norm throughout successive EM steps' SGD runs while varying n. samples. Note that traces almost perfectly overlap.

Figure 5: Performance under distribution shift for ResNet18 and CIFAR100.

Figure 6: Left 3 plots: traces of prior precision, eff. dim., and marginal test LL vs EM steps for the tomographic reconstruction task with m = 7680. Right: joint test LL for varying image patch sizes.

Figure 7: Original 501×501 pixel walnut image and reconstruction error for a m=7680 dimensional observation, along with pixel-wise std-dev obtained with sampling lin. Laplace and MCDO.

Linearised neural networks. Introduced by Mackay (1992), this approximation yields closed-form errorbars for Laplace posteriors. Lawrence (2000) and Ritter et al. (2018) found the Laplace approximation to underperform without the linearisation step. Khan et al. (2019) and Immer et al. (2021b) re-popularised the linearisation step by showing that it improves the quality of uncertainty estimates. Kristiadi et al. (2020) show that the Laplace approximation is sufficient to resolve certain pathologies of point-estimated NNs' predictions. Immer et al. (2021a) and Antorán et al. (2022a;c) explore the linear model's evidence for model selection. Immer et al. (2022) show the objective can even be used to learn invariances in deep models. Daxberger et al. (2021b) and Maddox et al. (2021)

Optimised precision α and weight samples ζ_1, ..., ζ_k

Algorithm 3: Kernelised sampling-based linearised NN inference for CT reconstruction
Inputs: Linearised network h, linearisation point w, measurements Y, discrete Radon transform U, U-Net Jacobian Φ, initial precision α > 0, number of samples k, noise precision B
Function Kvp(v, α, UΦ, s, B^{-1}):

Figure 8: Left: empirical coverage of test targets for posterior credible intervals of increasing width for our U-net tomographic reconstruction experiment (Section 5.3). Right: confidence vs accuracy plot (also known as a reliability diagram) for our CIFAR100 classification experiment (Section 5.2).

Figure 9: Histogram, with bin heights normalised to represent density estimates, of the effective dimension estimates produced by the primal form (weight space) estimator (19) and the kernelised (prediction space) estimator (5). Both distributions are roughly centered at the true effective dimension but the kernelised estimator presents much lower variance.

Figure 10: Left: exact linear model evidence for a linearised 2 hidden layer MLP with layer normalisation together with the lower bound presented in (3), M, and an ELBO where the Gaussian posterior covariance is decoupled from the regulariser. All curves use an initial regulariser of α = 500 and have a marker placed at their optima. Right: values of the regularisation strength α obtained over successive EM iterations while using the different update strategies under consideration for the M-step. Note that when we assume access to the exact evidence function, the regulariser converges in a single step and no EM iteration is necessary.

Figure 11: Exact linear model evidence for a linearised 2 hidden layer MLP with layer normalisation together with the lower bound presented in (3), M (left and middle plots), and an ELBO, where the Gaussian posterior covariance is decoupled from the regulariser (right hand side plot), at different EM steps. We update the regularisation strength with Mackay's fixed point iteration for the left side plot. Note that M curves are shown in this plot. We maximise M in the middle plot and we maximise the ELBO in the right hand side plot. All curves use an initial regulariser of α = 5 and we place a vertical dashed line at each step's update.

Figure: Performance under distribution shift for additional inference baselines applied to ResNet18 on the CIFAR100 dataset. We note that KFAC-LL and subnetwork inference require a held-out validation set to tune hyperparameters and thus we mark them with a star (*).

Figure 15: Sample-based EM iteration convergence for tomographic reconstruction given m = 15360. The prior precision α (left), the effective dimension γ (middle left), and the marginal test loglikelihood (LL) (middle right) are plotted as a function of the EM step. The plot on the right shows the joint test LL across image patches of increasing size for sampling inference and an MC dropout baseline (MCDO).

Figure 16: Original 501×501 pixel walnut image and reconstruction error for a m=15360 dimensional observation, along with pixel-wise std-dev obtained with sampling lin. Laplace and MCDO.

Figure 17: Histogram of the absolute pixelwise error computed between the reconstructed walnut image given m = 7680 observations and the ground-truth for both lin.-UNet and MCDO-UNet. We also include the histograms of both methods' predictive standard deviations across pixels.

At convergence, that is, when z = ζ, a more involved calculation presented in Appendix C.3 shows that L ′ is preferred if 2αγ > Tr M.
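The condition 2αγ > Tr M is cheap to evaluate when the spectrum of M is available. The sketch below is a small illustration under assumptions (B = I so that M = ΦᵀΦ, isotropic prior A = αI, hypothetical function name). Writing λ_i for the eigenvalues of M and γ = Σ_i λ_i / (λ_i + α), the condition is equivalent to Σ_i λ_i (α − λ_i) / (α + λ_i) > 0, so it holds whenever α exceeds every eigenvalue of M and fails whenever α lies below every nonzero eigenvalue.

```python
import numpy as np


def prefers_l_prime(Phi, alpha):
    """Evaluate the condition 2 * alpha * gamma > Tr M.

    Assumes B = I, so M = Phi^T Phi, and an isotropic prior A = alpha * I.
    """
    lam = np.clip(np.linalg.eigvalsh(Phi.T @ Phi), 0.0, None)
    gamma = np.sum(lam / (lam + alpha))  # effective dimension
    # Equivalent form: sum_i lam_i * (alpha - lam_i) / (alpha + lam_i) > 0.
    return bool(2.0 * alpha * gamma > np.sum(lam))


Phi = np.random.default_rng(0).standard_normal((50, 5))
for alpha in (1e-3, 1.0, 1e2, 1e4):
    print(alpha, prefers_l_prime(Phi, alpha))
```

In this form the condition favours the strongly regularised regime, where the prior precision dominates the curvature contributed by the data.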

Tomographic reconstruction: test LL and wall-clock times (A100 GPU) for both data sizes.

, among others. Alas, linear models are held back by a cost of inference cubic in the number of parameters when expressed in primal form, or cubic in the number of observations for the dual (i.e. kernelised or Gaussian Process) form. Additionally, for non-Gaussian likelihoods, e.g. in classification, inference is no longer closed form. The most common approximations used in these settings are Laplace's method (Mackay, 1992) and variational inference (Hensman et al., 2013). Khan et al. (2019) and Adam et al. (2021) show that every Gaussian approximation corresponds to the true posterior of a surrogate regression problem with the same features, a fact which we use in this work to apply sample-then-optimise to Laplace posteriors.

Sample-then-optimise. Papandreou & Yuille (2010); de G. Matthews et al. (2017) phrase sampling from a conjugate Gaussian-linear model as solving a perturbed quadratic optimisation problem. This method has been applied for uncertainty estimation in non-linearised NNs by Osband et al. (2018; 2021), and Pearce et al. (

ACKNOWLEDGEMENTS

The authors would like to thank Alex Terenin and Marine Schimel for helpful discussions. JA acknowledges support from Microsoft Research, through its PhD Scholarship Programme, and from the EPSRC. SP acknowledges support from the Harding Distinguished Postgraduate Scholars Programme Leverage Scheme. JMHL acknowledges support from a Turing AI Fellowship under grant EP/V023756/1. This work has been performed using resources provided by the Cambridge Tier-2 system operated by the University of Cambridge Research Computing Service (http://www.hpc.cam.ac.uk) funded by an EPSRC Tier-2 capital grant. This work was also supported with Cloud TPUs from Google's TPU Research Cloud (TRC).


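C.2 LOSS GRADIENT VARIANCE CONDITION

Taking $j \sim \mathrm{Unif}(\{1, \dots, n\})$, the gradient estimators for the data-dependent terms of $L$ and $L'$ are

$$\hat{g} = n \nabla \tfrac{1}{2} \|\phi(x_j) z - \varepsilon_j\|^2_{B_j} = n \phi(x_j)^T B_j \big(\phi(x_j) z - \varepsilon_j\big) \qquad (47)$$

and

$$\hat{g}' = n \nabla \tfrac{1}{2} \|\phi(x_j) z\|^2_{B_j} = n \phi(x_j)^T B_j \phi(x_j) z,$$

respectively. Their variances are related as

$$\mathrm{Var}(\hat{g}) = \mathrm{Var}\big(n \phi(x_j)^T B_j (\phi(x_j) z - \varepsilon_j)\big) = \mathrm{Var}\big(n \phi(x_j)^T B_j \phi(x_j) z\big) + \mathrm{Var}\big(n \phi(x_j)^T B_j \varepsilon_j\big) - 2\,\mathrm{Cov}\big(n \phi(x_j)^T B_j \phi(x_j) z,\; n \phi(x_j)^T B_j \varepsilon_j\big).$$

Evaluating the variance and covariance, we have [...] and thus $\mathrm{Var}\,\hat{g} -$ [...]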

C.3 CONDITION AT CONVERGENCE

Now consider Tr ∆ for z = ζ ∼ Π 0, the optimum of both L and L′. From the first order optimality condition, [...] Proceeding to rearrange the condition at z = ζ, [...] where we substituted in the definition of ζ, then used that E and θ 0 are independent, and that E[θ 0] = 0, and finally that [...] Writing M = Φ^T B Φ and recalling that H = (M + A), we have [...] where we have used that [...] for the fourth equality. Now consider the isotropic prior case A = αI and recall the effective dimension is written as γ = Tr{M (M + A)^{-1}}. The above implies Tr ∆ > 0 if and only if 2αγ > Tr Φ^T B Φ.

The regularisation strength converges faster in this more data-rich setting than in the 60 angle setting considered in the main text (Figure 6), with convergence occurring after 1 EM step instead of 2. We see a slightly larger sensitivity to the number of samples in this larger setting. However, the difference in test LL obtained after running stochastic EM with 2 and 256 samples remains smaller than 0.01 nats.

Further analysis of uncertainty calibration

The rightmost plot in Figure 15 compares the joint test log-likelihood obtained by our method and MC dropout on image patches of increasing size when the observation dimension is set to m = 15360. Similarly to the results shown in the main text, our method performs better across all patch sizes. Qualitatively, Figure 16 shows the sample-based approach to yield uncertainty estimates with a much larger dynamic range; some pixel regions are assigned large errorbars, while others are assigned small errorbars. MCDO produces less fine-grained outputs and assigns relatively small errorbars to whole sections of the image.

Finally, Figure 17 and Figure 18 compare the reconstruction error and uncertainty histograms for both uncertainty quantification methods under consideration for both the m = 7680 and m = 15360 settings. In both plots, sample-based linearised Laplace inference slightly overestimates uncertainty.
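A simple way to quantify over- or underestimation of uncertainty is the empirical coverage of credible intervals. The sketch below is illustrative and rests on assumptions: the per-pixel predictive is taken to be Gaussian with a given mean and standard deviation, intervals are central, the data are synthetic rather than the walnut reconstruction, and the function name is hypothetical.

```python
import numpy as np
from statistics import NormalDist


def empirical_coverage(targets, mean, std, levels):
    """Fraction of targets falling inside central Gaussian credible
    intervals of the given nominal levels (each level in (0, 1))."""
    z = np.abs(targets - mean) / std
    coverage = []
    for level in levels:
        # Half-width of the central interval in units of std.
        half_width = NormalDist().inv_cdf(0.5 + level / 2.0)
        coverage.append(float(np.mean(z <= half_width)))
    return np.array(coverage)


rng = np.random.default_rng(0)
targets = rng.standard_normal(20000)   # synthetic errors with unit std
mean = np.zeros_like(targets)
levels = [0.5, 0.9]
print(empirical_coverage(targets, mean, np.ones_like(targets), levels))
print(empirical_coverage(targets, mean, 2.0 * np.ones_like(targets), levels))
```

Coverage above the nominal level indicates overestimated uncertainty, and coverage below it indicates underestimated uncertainty.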

