EXPRESSIVE YET TRACTABLE BAYESIAN DEEP LEARNING VIA SUBNETWORK INFERENCE

Abstract

The Bayesian paradigm has the potential to solve some of the core issues in modern deep learning, such as poor calibration, data inefficiency, and catastrophic forgetting. However, scaling Bayesian inference to the high-dimensional parameter spaces of deep neural networks requires restrictive approximations. In this paper, we propose performing inference over only a small subset of the model parameters while keeping all others as point estimates. This enables us to use expressive posterior approximations that would otherwise be intractable for the full model. In particular, we develop a practical and scalable Bayesian deep learning method that first trains a point estimate, and then infers a full-covariance Gaussian posterior approximation over a subnetwork. We propose a subnetwork selection procedure that aims to maximally preserve posterior uncertainty. We empirically demonstrate the effectiveness of our approach compared to point-estimated networks and methods that use less expressive posterior approximations over the full network.

1. INTRODUCTION

Deep neural networks (DNNs) still suffer from critical shortcomings that make them unfit for important applications. For instance, DNNs tend to be poorly calibrated and overconfident in their predictions, especially when there is a shift between the training and test distributions (Nguyen et al., 2015; Guo et al., 2017). To reliably inform decision making, DNNs must be able to robustly quantify the uncertainty in their predictions, which is particularly important in safety-critical areas such as healthcare or autonomous driving (Amodei et al., 2016; Filos et al., 2019a; Fridman et al., 2019). Bayesian modeling (Ghahramani, 2015; Gal, 2016) presents a principled way to capture predictive uncertainty via the posterior distribution over model parameters. Unfortunately, due to their nonlinearities, exact posterior inference in DNNs is intractable. Despite recent successes in the field of Bayesian deep learning (Blundell et al., 2015; Gal & Ghahramani, 2016; Osawa et al., 2019; Maddox et al., 2019; Dusenberry et al., 2020), existing methods only scale to modern DNNs with large numbers of parameters by invoking unrealistic assumptions. This severely limits the expressiveness of the inferred posterior and thus deteriorates the quality of the induced uncertainty estimates (Ovadia et al., 2019; Fort et al., 2019; Foong et al., 2019a; Ashukha et al., 2020a). Due to the heavy overparameterization of DNNs, their accuracy is well-preserved by a small subnetwork (Cheng et al., 2017). Additionally, recent work by Izmailov et al. (2019) has shown that performing inference over a low-dimensional subspace of the weights can result in accurate uncertainty quantification. These observations prompt the following question: can the model uncertainty of a full DNN be well-preserved by the model uncertainty of a small subnetwork? We answer this question in the affirmative.
We show both theoretically and empirically that the full network posterior can be well represented by a subnetwork posterior. As a result, we can use more expensive but faithful posterior approximations over just that subnetwork. We show that this achieves better uncertainty quantification than cheaper but cruder posterior approximations over the full network. The contributions of this paper are as follows:

1. We propose a new Bayesian deep learning approach that performs Bayesian inference over only a small subset of the model weights and keeps all other weights deterministic. This allows us to use expressive posterior approximations that are typically intractable in DNNs.

2. As a concrete instantiation of this framework, we develop a practical and scalable Bayesian deep learning method that uses the linearized Laplace approximation to infer a full-covariance Gaussian posterior over a subnetwork within a point-estimated neural network.

3. We formally characterize the discrepancy between the posterior distributions over a subnetwork and the full network (in terms of their Wasserstein distance) in the linearized model, and derive a theoretically motivated strategy to select a subnetwork that minimizes this discrepancy under certain assumptions.

4. We empirically show, on various benchmarks, that our method compares favourably against point-estimated networks and other Bayesian deep learning methods, experimentally confirming that expressive subnetwork inference is superior to crude inference over full networks.

2. SUBNETWORK POSTERIOR APPROXIMATION

Bayesian neural networks (BNNs) aim to capture model uncertainty, i.e., uncertainty about the choice of weights W arising from the multiple plausible explanations of the training data {y, X}. Here, y is the dependent variable (e.g. a classification label) and X is the feature matrix. A prior distribution p(W) is specified over the BNN's weights, and we wish to infer their full posterior distribution

p(W|y, X) ∝ p(y|X, W) p(W). (1)

To make predictions, we then estimate the posterior predictive distribution, which averages the network's predictions across all possible settings of the weights, weighted by their posterior probability:

p(y*|X*, y, X) = ∫ p(y*|X*, W) p(W|y, X) dW.

Unfortunately, due to the size of modern deep neural networks, it is not only intractable to infer the exact posterior distribution p(W|y, X) in Eq. (1), but it is even computationally challenging to properly approximate it. As a consequence, crude posterior approximations such as complete factorization are commonly employed (Blundell et al., 2015; Hernández-Lobato & Adams, 2015; Kingma et al., 2015; Khan et al., 2018; Osawa et al., 2019), i.e.

p(W|y, X) ≈ ∏_{d=1}^D q(w_d),

where w_d denotes the d-th weight in the D-dimensional neural network weight vector W ∈ R^D (the concatenation and flattening of all layers' weight matrices). Clearly, this is a very restrictive assumption; in practice, it suffers from severe pathologies (Foong et al., 2019a;b).

In this work, we question the implicit assumption that a good posterior approximation needs to include all BNN parameters. Instead, we aim to perform inference only over a small subset of the weights. This approach is well-motivated for two reasons:

1. Overparameterization: Maddox et al. (2020) have shown that, in the neighborhood of local optima, there are many directions that leave the NN's predictions unchanged. Moreover, NNs can be heavily pruned without sacrificing test-set accuracy (Frankle & Carbin, 2019). Thus, the majority of a NN's predictive power might be isolated to a small subnetwork.

2. Inference over submodels: Previous work has provided evidence that inference can be effective even when not done over the full parameter space. Izmailov et al. (2019) performed inference over a low-dimensional projection of the weights. Neural-linear models, which give a Bayesian treatment to only the last layer of a DNN, have been shown to be competitive with full-network approaches (Riquelme et al., 2018; Ober & Rasmussen, 2019).

We thus combine these ideas, making the following two-step approximation of the posterior in Eq. (1):

p(W|y, X) ≈ p(W_S|y, X) ∏_r δ(w_r − w*_r) ≈ q(W_S) ∏_r δ(w_r − w*_r).

The first approximation decomposes the full neural network posterior p(W|y, X) into a posterior p(W_S|y, X) over the subnetwork W_S and delta functions δ(w_r − w*_r) over all remaining weights {w_r}_r, keeping them at fixed values w*_r ∈ R. This can be viewed as pruning the variances of the weights {w_r}_r to zero, in contrast to ordinary weight pruning methods, which set the weights {w_r}_r themselves to zero. The second approximation accounts for the fact that posterior inference over the subnetwork is still intractable; we therefore introduce the approximate distribution q(W_S). Yet, as the subnetwork is much smaller than the full network, we can afford to make q(W_S) expressive and able to capture rich dependencies across the weights within the subnetwork.
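The two-step approximation above can be sketched numerically: an expressive full-covariance Gaussian over a handful of subnetwork weights, combined with delta functions (fixed point estimates) over the rest. All dimensions, indices, and the stand-in covariance below are illustrative assumptions, not the paper's code.

```python
import numpy as np

rng = np.random.default_rng(0)
D, S = 10, 3                        # total number of weights, subnetwork size
w_map = rng.normal(size=D)          # point estimate of all weights (toy values)
idx = np.array([0, 4, 7])           # indices of the S subnetwork weights

# Expressive full-covariance Gaussian q(W_S) over the subnetwork only
A = rng.normal(size=(S, S))
cov_S = A @ A.T + 1e-3 * np.eye(S)  # any SPD matrix stands in for the covariance

def sample_posterior(n):
    """Draw from q(W_S) * prod_r delta(w_r - w_r*): subnetwork weights vary,
    all remaining weights stay fixed at their point estimates."""
    samples = np.tile(w_map, (n, 1))
    samples[:, idx] = rng.multivariate_normal(w_map[idx], cov_S, size=n)
    return samples

ws = sample_posterior(1000)
# Weights outside the subnetwork carry zero variance; subnetwork weights do not
assert np.allclose(ws[:, 1], w_map[1])
assert ws[:, 0].std() > 0
```

Sampling the delta components is trivial (they are constants), which is exactly why the approximation is cheap: only the S×S covariance over the subnetwork must be estimated and stored.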

3. SUBNETWORK INFERENCE VIA LAPLACE APPROXIMATION

To obtain a method that is as practical as possible, we propose to use inference techniques that can estimate a posterior distribution post-hoc from a point-estimated network. The Laplace approximation (MacKay, 1992) is well-suited to this task, as it derives the approximate posterior from the local optimization landscape. Other inference procedures, such as SWAG (Maddox et al., 2019), could also be used; nevertheless, we focus on Laplace due to it being a well-studied, fundamental technique.

Step #1: Point Estimation. The first step of the procedure is to train a neural network to obtain a point estimate of the weights, denoted W_MAP. This estimate should respect the Bayesian model given in Eq. (1), and we therefore optimize the maximum a-posteriori (MAP) objective:

W_MAP = argmax_W [log p(y|X, W) + log p(W)].

This can be done using the standard stochastic gradient-based optimization methods common in modern deep learning (Goodfellow et al., 2016). This step is illustrated in Fig. 1 (a).

Step #2: Subnetwork Selection. The second step is to identify a small subnetwork W_S. Ideally, we would like to identify the subnetwork whose posterior is 'closest' to the full-network posterior. We formalize this argument in Section 4 and describe a principled strategy that, under certain conditions, minimizes the 2-Wasserstein distance between the sub- and full-network posteriors. All weights not belonging to that subnetwork are then fixed at their MAP estimates obtained in Step #1. See Fig. 1 (b) for an illustration of this step.

Step #3: Bayesian Inference. Given the subnetwork point estimate W_S^MAP, we use the Laplace approximation to infer a full-covariance Gaussian posterior distribution over the subnetwork W_S:

p(W_S|y, X) ≈ q(W_S) = N(W_S; W_S^MAP, H⁻¹), (5)

where the posterior covariance matrix H⁻¹ ∈ R^{D×D} is the inverse of the average Hessian of the negative log posterior, i.e. H = N · E[−∂² log p(y|X, W)/∂W²] + λI.
Here, the expectation is taken w.r.t. the data-generating distribution, and λ is the precision of a zero-mean factorized Gaussian prior p(W) = N(W; 0, λ⁻¹I). In practice, we approximate the Hessian H with the generalized Gauss-Newton (GGN) matrix H̃ (Schraudolph, 2002), i.e.

H̃ = Σ_{n=1}^N J_nᵀ H_n J_n + λI, with J_n = ∂f(x_n, W)/∂W and H_n = ∂²L(y_n, f(x_n, W))/∂f(x_n, W)², (6)

where J_n ∈ R^{O×D} is the Jacobian of the neural network outputs f(x_n, W) ∈ R^O w.r.t. the weights W, and H_n ∈ R^{O×O} is the Hessian of the loss L(y_n, f(x_n, W)) w.r.t. the outputs f(x_n, W). The GGN H̃ has clear practical advantages over the Hessian H; see Martens & Sutskever (2011) and Martens (2016). Using the Laplace approximation with the GGN can be viewed as an implicit local linearization of the underlying neural network f(x, W) at its MAP estimate W_MAP:

f_lin(x, W) = f(x, W_MAP) + J_MAP(x) (W − W_MAP), (7)

where J_MAP(x) = ∂f(x, W_MAP)/∂W_MAP ∈ R^{O×D} (Immer et al., 2020). Note that the model in Eq. (7) is linear in W: only the term J_MAP(x) W depends on W, while the other terms are constant w.r.t. W and can thus be subsumed into an additive bias term (Khan et al., 2019). The GGN approximation thus locally turns the underlying probabilistic model from a Bayesian neural network into a (generalized) linear model with basis function expansion J_MAP(x) of the covariate x (Immer et al., 2020). Put differently, linearized Laplace in the neural network f(x, W) is equivalent to ordinary Laplace in the linear model f_lin(x, W) of Eq. (7), as the GGN H̃ corresponding to f(x, W) in Eq. (6) equals the Hessian corresponding to f_lin(x, W) in Eq. (7) (Khan et al., 2019). This useful property will allow us to derive a principled subnetwork selection strategy in Section 4. This step is illustrated in Fig. 1 (c). We emphasize that this whole procedure (i.e. Steps #1-#3) is a perfectly valid mixed inference strategy, performing full Laplace inference over the selected subnetwork and MAP inference over all remaining weights.

Step #4: Prediction. Given the linearized Laplace approximation over the subnetwork W_S in Eqs. (5) and (6), i.e. q(W_S) = N(W_S; W_S^MAP, H̃⁻¹), we can then compute the posterior predictive distribution. Traditionally, one would compute it using the original Bayesian neural network likelihood, i.e. p(y|X, W) = p(y|f(x, W)); however, Immer et al. (2020) recently argued that, since inference was (implicitly) done in the GGN-linearized model, it is more principled to instead predict using the linearized likelihood of Eq. (7), i.e. p(y|X, W) = p(y|f_lin(x, W)). This provides a formal justification for the empirical superiority of this approach observed previously (Lawrence, 2001; Foong et al., 2019b). We thus obtain the linearized predictive distribution

p(y*|X*, y, X) ≈ ∫ p(y*|f_lin(X*, W)) N(W_S; W_S^MAP, H̃⁻¹) ∏_r δ(w_r − w*_r) dW. (8)

There are two ways to compute Eq. (8). Firstly, via the Monte Carlo approximation p(y*|X*, y, X) ≈ (1/M) Σ_{m=1}^M p(y*|f_lin(X*, W_m)), sampling W_m from N(W_S^MAP, H̃⁻¹) and from ∏_r δ(w_r − w*_r), the latter of which is trivial. Secondly, due to the linearity of f_lin(X*, W) in W, there are closed-form expressions that are exact for Gaussian likelihoods (i.e. regression) and approximate for categorical ones (i.e. classification) (Bishop, 2006; Gibbs, 1998). This step is illustrated in Fig. 1 (d).
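The Monte Carlo route through Eq. (8) can be sketched as follows. The network, its Jacobian, and the subnetwork covariance below are toy stand-ins (assumptions for illustration); in the method itself the Jacobian comes from the actual NN at W_MAP. For a Gaussian likelihood, the MC predictive standard deviation should agree with the closed-form value, since weights outside the subnetwork contribute no variance.

```python
import numpy as np

rng = np.random.default_rng(1)
D, S, O = 20, 4, 1                  # total weights, subnetwork size, output dim
idx = np.arange(S)                  # subnetwork = first S weights (illustrative)
w_map = rng.normal(size=D)          # MAP estimate from Step #1 (toy values)
cov_S = 0.1 * np.eye(S)             # stands in for the subnetwork covariance

def f_map(x):
    # f(x, W_MAP): the network's output at the MAP estimate (toy stand-in)
    return np.array([np.tanh(x).sum()])

def jacobian(x):
    # J_MAP(x) in R^{O x D}: Jacobian of f w.r.t. the weights (toy stand-in)
    return np.tile(np.tanh(x).mean(), (O, D)) * np.linspace(0.1, 1.0, D)

def mc_predictive(x, n_samples=200):
    """Monte Carlo version of Eq. (8): sample subnetwork weights from their
    Gaussian posterior, keep all remaining weights at their MAP values, and
    push each sample through the linearized model f_map + J (W - W_MAP)."""
    J = jacobian(x)
    ws = np.tile(w_map, (n_samples, 1))
    ws[:, idx] = rng.multivariate_normal(w_map[idx], cov_S, size=n_samples)
    preds = f_map(x) + (ws - w_map) @ J.T       # shape (n_samples, O)
    return preds.mean(axis=0), preds.std(axis=0)

mu, sd = mc_predictive(np.ones(3), n_samples=5000)
```

With enough samples, sd approaches the closed-form standard deviation obtained from the subnetwork block of the Jacobian, i.e. the square root of J_S cov_S J_Sᵀ.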

4. PRINCIPLED SUBNETWORK SELECTION FOR LINEAR(IZED) MODELS

We next analyze the subnetwork inference procedure described in Section 3 for the case of a generalized linear model (GLM) (Nelder & Baker, 1972), which models the expected response y_n given the basis function expansion of the covariates φ_n = φ(x_n) as E[y_n|φ_n] = g⁻¹(wᵀφ_n). Here, w ∈ R^D is the vector of model parameters (which subsumes a scalar bias β₀ for notational convenience) and g⁻¹(·) denotes an inverse link function mapping R to the mean of the response. In particular, we consider a Bayesian GLM, specifying a prior distribution p(w) over the model parameters and aiming to infer the posterior distribution p(w|y, Φ) ∝ p(y|Φ, w) p(w), where Φ = [φ_1, ..., φ_N]ᵀ.

1. Point Estimation. Obtain the MAP estimate, w_MAP = argmax_w [log p(y|Φ, w) + log p(w)]. For commonly-used link functions (e.g. the identity in case of a Gaussian likelihood for regression, or the sigmoid/softmax in case of a categorical likelihood for classification) and commonly-used priors (e.g. a Gaussian), the log posterior log p(y|Φ, w) + log p(w) + const is concave. This allows for simple gradient-based MAP optimization. It also makes a full-covariance Gaussian, estimated via Laplace, a faithful approximation to the true, uni-modal posterior, i.e.

p(w|y, Φ) ≈ p̃(w|y, Φ) = N(w; w_MAP, H⁻¹), (10)

where H is the Hessian defined in Section 3. Note that for the GLM we consider, the Hessian H is equivalent to the GGN H̃ defined in Eq. (6), meaning that an ordinary Laplace approximation is equivalent to a linearized Laplace approximation (Martens, 2016). For the case of an identity link function (i.e. a Gaussian likelihood with noise variance σ₀²) and a Gaussian prior w ∼ N(0, Λ₀⁻¹), the MAP estimate even has the closed-form expression w_MAP = (ΦᵀΦ + σ₀²Λ₀)⁻¹Φᵀy. In this case, the Laplace approximation in Eq. (10) exactly recovers the true posterior, i.e. p̃(w|y, Φ) = p(w|y, Φ). We will thus refer to p̃(w|y, Φ) in Eq. (10) as the full posterior.

2. Subnetwork Selection. Select a subset of S model weights via a method of choice, yielding a binary mask vector m ∈ {0, 1}^D, where m_d = 1 if the d-th weight is part of the subset and m_d = 0 otherwise. For convenience, we define the binary mask matrix M_S = m mᵀ ∈ R^{D×D}, which contains 1s in the rows/columns corresponding to the S subnetwork weights and 0s otherwise.

3. Bayesian Inference. Compute the posterior over the subnetwork via a Laplace approximation:

p_S(w|y, Φ) = N(w; w_MAP, M_S ∘ H⁻¹). (11)

Firstly, note that the mean of the subnetwork posterior in Eq. (11) is the MAP estimate w_MAP and thus equal to the mean of the full posterior in Eq. (10). Secondly, note that the covariance matrix of the subnetwork posterior in Eq. (11) is the element-wise product M_S ∘ H⁻¹, which masks the (co-)variances of all weights not belonging to the subnetwork to zero, effectively making them deterministic. More precisely, M_S ∘ H⁻¹ is a D × D matrix that equals the full posterior covariance H⁻¹ in the rows/columns of the S subnetwork weights and is zero in the rows/columns of the remaining D − S weights.

We consider what we term the posterior gap: the squared 2-Wasserstein distance between the posterior distribution over the full network and the posterior distribution over the subnetwork. The proofs for all results below are presented in Appendix A.

Proposition 1 (Posterior Gap). For a subnetwork of size S < D, the Wasserstein gap between the full posterior p̃(w|y, Φ) in Eq. (10) and the subnetwork posterior p_S(w|y, Φ) in Eq. (11) is

W₂²[p̃(w|y, Φ) || p_S(w|y, Φ)] = Σ_{d=1}^D (1 + m_d) σ_d² − 2 trace((H⁻¹ (M_S ∘ H⁻¹))^{1/2}), (12)

where σ_d² denotes the d-th diagonal entry of H⁻¹.

The optimal subnetwork should then minimize the posterior gap in Eq. (12). However, for full covariance matrices H⁻¹ and a large number of weights D, this will generally be infeasible, as Eq. (12) depends on all entries of the D × D matrix H⁻¹, which is intractable to compute and store. To derive a practical subnetwork selection strategy, we assume the covariance matrix to be diagonal.

Corollary 1.1 (Optimality of Maximum-Variance Subnetwork Selection under Decorrelation). For a generalized linear model with posterior covariance matrix H⁻¹ = diag(σ₁², ..., σ_D²), the optimal subnetwork under the Wasserstein gap comprises the S weights with the largest variances σ_d².

Finally, since a GGN-linearized neural network, as in Eq. (7), corresponds to a GLM with basis expansion φ_n = J_MAP(x_n) = ∂f(x_n, W_MAP)/∂W_MAP (see Step #3 in Section 3), Corollary 1.1 implies that, under the decorrelation assumption, the optimal subnetwork comprises the weights with the largest variances. In practice, even just computing the diagonal of the covariance matrix is challenging, so we use a diagonal Laplace approximation, which instead computes the inverse of the diagonal of the GGN (see e.g. Ritter et al. (2018)). Finally, note that we only make the decorrelation assumption for the purposes of subnetwork selection; when doing posterior inference over the selected subnetwork, we estimate a full covariance matrix for maximal expressiveness, as described in Step #3 in Section 3. In our experiments in Section 5, we empirically show that making the decorrelation assumption for subnetwork selection but then using a full-covariance Gaussian for inference performs significantly better than directly making the decorrelation assumption for inference (as done by e.g. mean-field variational inference or diagonal Laplace).
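For the Gaussian-likelihood GLM, the whole pipeline fits in a few lines: closed-form MAP estimate and exact posterior covariance, then top-S variance selection, which can be exhaustively verified to minimize the diagonal posterior gap. The data and hyperparameters below are hypothetical; this is a sketch of the math, not the paper's implementation.

```python
import numpy as np
from itertools import combinations

rng = np.random.default_rng(0)
N, D, S = 200, 8, 3

# Bayesian linear model with identity link: y = Phi w + noise (toy data)
Phi = rng.normal(size=(N, D))
sigma0, lam = 0.3, 2.0                       # noise std, prior precision
y = Phi @ rng.normal(size=D) + sigma0 * rng.normal(size=N)

# Closed-form MAP: w_MAP = (Phi^T Phi + sigma0^2 Lambda_0)^{-1} Phi^T y,
# and exact posterior covariance H^{-1} with H = Phi^T Phi / sigma0^2 + Lambda_0
w_map = np.linalg.solve(Phi.T @ Phi + sigma0**2 * lam * np.eye(D), Phi.T @ y)
H = Phi.T @ Phi / sigma0**2 + lam * np.eye(D)
cov = np.linalg.inv(H)

# Wasserstein-based selection (Corollary 1.1): keep the S largest-variance weights
var = np.diag(cov)
subnet = np.argsort(var)[-S:]

def diag_gap(selected):
    """Posterior gap under the decorrelation assumption: the sum of the
    variances of the weights NOT in the subnetwork (Eq. 12 with diagonal H^-1)."""
    mask = np.zeros(D, dtype=bool)
    mask[selected] = True
    return var[~mask].sum()

# Exhaustively confirm that top-S variance selection minimizes the diagonal gap
best_gap = min(diag_gap(list(c)) for c in combinations(range(D), S))
assert np.isclose(diag_gap(subnet), best_gap)
```

The exhaustive check is only feasible for toy dimensions; for real networks, the point of Corollary 1.1 is precisely that the top-S rule avoids any such search.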

5. EMPIRICAL ANALYSIS

We empirically assess the effectiveness of subnetwork inference compared to point-estimated NNs and methods that do less expressive inference over the full network. We consider three tasks: 1) small-scale toy regression, 2) medium-scale tabular regression, and 3) large-scale image classification.

Under review as a conference paper at ICLR 2021

[Figure 2 shows ten panels of 1D predictive distributions: Full Cov (2600), Wass 50% (1300), Wass 3% (78), Wass 1% (26), MAP (0), Diag (2600), Rand 50% (1300), Rand 3% (78), Rand 1% (26), and Final layer (50).]

Figure 2: Predictive distributions (mean ± std) for 1D regression. The numbers in brackets denote the number of parameters over which inference was done (out of 2600 in total). Wasserstein-based subnetwork inference maintains richer predictive uncertainties at smaller parameter counts.

5.1. HOW DOES SUBNETWORK INFERENCE RETAIN POSTERIOR PREDICTIVE UNCERTAINTY?

We first assess how the predictive distribution of a full-covariance Gaussian posterior over a selected subnetwork qualitatively compares to that obtained from 1) a full-covariance Gaussian over the full network (Full Cov), 2) a factorised Gaussian posterior over the full network (Diag), 3) a full-covariance Gaussian over only the final layer of the network (Final layer) (Kristiadi et al., 2020), and 4) a point estimate (MAP). For subnetwork inference, we consider both Wasserstein (Wass) (as described in Section 4) and uniform random subnetwork selection (Rand) to obtain subnetworks that comprise only 50%, 3% and 1% of the model parameters. Note that, while for this toy example we could in principle use the full covariance matrix for the purpose of subnetwork selection, we still use only its diagonal (as described in Section 4) for consistency. Our NN consists of 2 ReLU hidden layers with 50 hidden units each. We employ a homoscedastic Gaussian likelihood function whose noise variance is optimised with maximum likelihood. We use GGN-Laplace inference over network weights (not biases) in combination with the linearized predictive distribution in Eq. (8). Thus, all approaches considered share their predictive mean, allowing us to better compare their uncertainty estimates. All approaches share a single prior precision of λ = 3. We use a synthetic 1D regression task with two separated clusters of inputs (Antorán et al., 2020), allowing us to probe for 'in-between' uncertainty (Foong et al., 2019b). Results are shown in Fig. 2. Subnetwork inference preserves more of the uncertainty of full network inference than diagonal Gaussian or final layer inference, while doing inference over fewer weights. By capturing weight correlations, subnetwork inference retains uncertainty in between clusters of data. This is true for both random and Wasserstein subnetwork selection; however, the latter preserves more uncertainty with smaller subnetworks.
Finally, the clear superiority over diagonal Laplace shows that making a diagonal assumption for subnetwork selection but then using a full-covariance Gaussian for inference (as we do) performs significantly better than making a diagonal assumption for the inferred posterior directly. These results suggest that expressive inference over a carefully selected subnetwork retains more predictive uncertainty than crude approximations over the full network.

5.2. SUBNETWORK INFERENCE IN LARGE MODELS VS FULL INFERENCE OVER SMALL MODELS

Secondly, we study the following natural question: "Why should one use subnetwork inference in a large model when one can just perform full network inference over a smaller model?" We explore this by considering 4 fully connected models of increasing size, with numbers of hidden layers h_d ∈ {1, 2} and hidden layer widths w_d ∈ {50, 100}. For a dataset with input dimension i_d, the number of weights is given by D = (i_d + 1) w_d + (h_d − 1) w_d². Our 2-hidden-layer, 100-hidden-unit models have a weight count of the order 10^4. Full-covariance inference in these models borders the limit of computational tractability on commercial hardware. We first obtain a MAP estimate of each model's weights and of our homoscedastic likelihood function's noise variance. We then perform full network GGN-Laplace inference for each model. We also use our proposed Wasserstein rule to prune every network's weight variances such that the number of variances that remain matches the size of every smaller network under consideration. In all cases, we employ the linearized predictive distribution in Eq. (8). Consequently, networks with the same number of weights make the same mean predictions; increasing the number of weight variances considered will thus only increase predictive uncertainty. We employ 3 tabular datasets of increasing size (input dimensionality, n. points): wine (11, 1439), kin8nm (8, 7373) and protein (9, 41157). We consider their standard train-test splits (Hernández-Lobato & Adams, 2015) and their gap variants (Foong et al., 2019b), designed to test for out-of-distribution uncertainty. Details are provided in Appendix C.4. For each split, we set aside 15% of the train data as a validation set. We use these for early stopping when finding MAP estimates and for selecting the weights' prior precision. We keep other hyperparameters fixed across all models and datasets. Results are in Fig. 3.
We present mean test log-likelihood (LL) values, as these take into account both accuracy and uncertainty. Larger models tend to perform better when doing MAP inference, with wine-gap and protein-gap being exceptions. We also find larger models improve over their respective MAP LLs more than small ones when performing approximate inference over the same numbers of weights. We conjecture this is due to an abundance of degenerate directions (weights) in the weight posterior of all models (Maddox et al., 2020) . Full network inference in small models captures information about both useful and non-useful weights. In larger models, our subnetwork selection strategy allows us to dedicate a larger proportion of our resources to modelling informative weight variances and covariances. In 3 out of 6 datasets, we find abrupt increases in LL as we increase the number of weights over which we perform inference, followed by a plateau. Such plateaus might be explained by all of the most informative weight variances having already been accounted for. These results suggest that, given the same amount of compute, larger models benefit more from subnetwork inference than small ones.
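The weight-count formula used above, D = (i_d + 1) w_d + (h_d − 1) w_d², can be written as a one-line helper. The example dimensions follow the datasets listed above; note that, as in the text, output layer weights are not counted.

```python
def n_weights(input_dim, width, hidden_layers):
    """Weight count D = (i_d + 1) * w_d + (h_d - 1) * w_d^2 for the fully
    connected models in this section (output layer excluded, matching the
    formula in the text)."""
    return (input_dim + 1) * width + (hidden_layers - 1) * width ** 2

# e.g. the 2-hidden-layer, 100-unit model on kin8nm (input dim 8):
assert n_weights(8, 100, 2) == 10900   # order 10^4, as stated
```

For the smallest model on wine (input dim 11, one hidden layer of 50 units), the same formula gives 600 weights, illustrating the roughly 20x spread in model sizes considered.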

5.3. SCALING TO IMAGE CLASSIFICATION WITH DISTRIBUTION SHIFT

We now assess the robustness of large convolutional neural networks with subnetwork inference to distribution shift on image classification tasks, compared to the following baselines: point-estimated networks (MAP); Bayesian deep learning methods that do less expressive inference over the full network, namely MC Dropout (Gal & Ghahramani, 2016), diagonal Laplace and VOGN (Osawa et al., 2019) (all of which assume factorisation of the weight posterior) and SWAG (Maddox et al., 2019) (which assumes a diagonal plus low-rank posterior); and deep ensembles (Lakshminarayanan et al., 2017), which are considered state-of-the-art for uncertainty quantification in deep learning (Ovadia et al., 2019; Ashukha et al., 2020a). We use ensembles of 5 DNNs, as suggested by Ovadia et al. (2019), and 16 samples for MC Dropout, diagonal Laplace and SWAG. We use a Dropout probability of 0.1 and a prior precision of λ = 40,000 for diagonal Laplace, found via grid search. We apply all approaches to ResNet-18 (He et al., 2016), which is composed of an input convolutional block, 8 residual blocks and a linear layer, for a total of 11,168,000 parameters. For subnetwork inference, we compute the linearized predictive distribution in Eq. (8) via the closed-form approximation for integrals between Gaussians and multi-class cross-entropy likelihoods described in Gibbs (1998). We use Wasserstein subnetwork selection to retain only 0.38% of the weights, yielding a subnetwork with only 42,438 weights. This is the largest subnetwork for which we can tractably compute a full covariance matrix; its size is 42,438² × 4 bytes ≈ 7.2 GB. We use a prior precision of λ = 500, found via grid search. Finally, to assess the importance of principled subnetwork selection, we also consider a baseline where the subnetwork is selected uniformly at random (Ours (Rand)). We perform the following two experiments, with results in Fig. 4. See Appendix B for additional results.
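As a rough illustration of a closed-form classification predictive, the sketch below scales each logit by its predictive variance before applying the softmax, a probit-style approximation in the spirit of MacKay (1992) and Gibbs (1998). The exact expression used in the paper may differ; the function name and inputs here are hypothetical.

```python
import numpy as np

def probit_softmax_predictive(mu, var):
    """Approximate E[softmax(f)] under f ~ N(mu, diag(var)) by shrinking each
    logit with the probit-style factor 1 / sqrt(1 + pi * var / 8) before the
    softmax. mu, var: per-class predictive means/variances (assumed shapes)."""
    kappa = 1.0 / np.sqrt(1.0 + np.pi * var / 8.0)
    z = kappa * mu
    e = np.exp(z - z.max())          # numerically stable softmax
    return e / e.sum()

mu = np.array([2.0, 0.0, -1.0])
p_uncertain = probit_softmax_predictive(mu, np.array([4.0, 0.1, 0.1]))
p_certain = probit_softmax_predictive(mu, np.zeros(3))
# High variance on the top logit softens its probability without changing argmax
assert np.argmax(p_uncertain) == 0 and p_uncertain[0] < p_certain[0]
```

The qualitative effect matches the text: larger predictive variance under the linearized model yields less confident class probabilities on shifted inputs.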
Rotated MNIST: Following Ovadia et al. ( 2019); Antorán et al. (2020) , we train all methods on MNIST and evaluate their predictive distributions on increasingly rotated digits. While all methods perform well on the original MNIST test set, their accuracy degrades quickly for rotations larger than 30 degrees. In terms of LL, ensembles perform best out of our baselines. Subnetwork inference obtains significantly larger LL values than almost all baselines, including ensembles. The only exception is VOGN, which achieves slightly better performance. It was also observed in Ovadia et al. (2019) that mean-field variational inference (which VOGN also is an instance of) is very strong on MNIST, but its performance deteriorates on larger datasets. Subnetwork inference makes accurate predictions in-distribution while assigning higher uncertainty than the baselines to out-of-distribution points. Corrupted CIFAR: Again following Ovadia et al. (2019) ; Antorán et al. (2020) , we train on CIFAR10 and evaluate on data subject to 16 different corruptions with 5 levels of intensity each (Hendrycks & Dietterich, 2019) . Our approach matches a MAP estimated network in terms of predictive error as local linearization makes their predictions the same. Ensembles and SWAG are the most accurate. Even so, subnetwork inference differentiates itself by being the least overconfident, outperforming all baselines in terms of log-likelihood at all corruption levels. Here, VOGN performs rather badly; while this might appear in stark contrast to its strong performance on the MNIST benchmark, the behaviour that mean-field VI performs well on MNIST but not on larger datasets was also observed in Ovadia et al. (2019) . On both benchmarks, we furthermore find that randomly selecting the subnetwork performs substantially worse than using our more principled subnetwork selection strategy. This highlights the importance of the way subnetworks are selected. 
These results suggest that subnetwork inference results in better uncertainty calibration and robustness to distribution shift than other popular uncertainty quantification approaches.

6. RELATED WORK

Bayesian Deep Learning. There have been significant efforts to characterise the posterior distribution over NN weights p(W|D). Hamiltonian Monte Carlo (Neal, 1995) remains the gold standard for approximate inference in BNNs to this day. Although asymptotically unbiased, sampling-based approaches are difficult to scale to large datasets (Betancourt, 2015). As a result, approaches that find the best surrogate posterior among an approximating family (most often Gaussians) have gained popularity. The first of these was the Laplace approximation, introduced by MacKay (1992), who also proposed approximating the predictive posterior with that of the linearised model (Khan et al., 2019; Immer et al., 2020). The popularisation of larger NN models has made surrogate distributions that capture correlations between weights computationally intractable. Thus, most modern methods make use of the mean-field assumption (Blundell et al., 2015; Hernández-Lobato & Adams, 2015; Gal & Ghahramani, 2016; Mishkin et al., 2018; Osawa et al., 2019). This comes at the cost of limited expressivity (Foong et al., 2019a) and empirical under-performance (Ovadia et al., 2019; Antorán et al., 2020) of uncertainty estimates. Our proposed approach recovers predictive posterior expressivity while maintaining tractability by lowering the dimensionality of the weight space considered. This allows us to scale up approximations that do consider weight correlations (MacKay, 1992; Louizos & Welling, 2016; Maddox et al., 2019; Ritter et al., 2018). Neural Network Linearization. In the limit of infinite width, NNs converge to Gaussian process (GP) behaviour (Neal, 1995; Matthews, 2017; Garriga-Alonso et al., 2018). Recently, these results have been extended to finite-width BNNs whose surrogate posterior is Gaussian (Khan et al., 2019). We draw upon these results to formulate a subnetwork selection strategy for BNNs.
Neural-linear methods perform inference over only the last layer of a NN, while keeping all other layers fixed (Snoek et al., 2015; Riquelme et al., 2018; Ovadia et al., 2019; Ober & Rasmussen, 2019; Pinsler et al., 2019; Kristiadi et al., 2020). These amount to a generalised linear model in which the basis functions are defined by the first l − 1 layers of an l-layer NN. They can also be viewed as a special case of subnetwork inference, in which the subnetwork is simply defined to be the last NN layer. Inference over Subspaces. The subfield of NN pruning aims to increase the computational efficiency of NNs by identifying the smallest subset of weights required to make accurate predictions. Approaches trade off computational cost against compression efficiency, ranging from those that require multiple training runs (Frankle & Carbin, 2019) to those that prune before training (Wang et al., 2020). Our work differs in that it retains all NN weights but aims to find a small subset over which to perform probabilistic reasoning. Most closely related to our work is that of Izmailov et al. (2019), who propose to perform inference over a low-dimensional subspace of the weights, e.g. one constructed from the principal components of the SGD trajectory. Moreover, several recent approaches use low-rank parameterizations of approximate posteriors in the context of variational inference (Rossi et al., 2019; Swiatkowski et al., 2020; Dusenberry et al., 2020). This can also be viewed as doing inference over an implicit subspace of weight space. In contrast, we propose a technique to find subsets of individual weights that are relevant to predictive uncertainty, i.e., we identify axis-aligned subspaces. Finally, there have been recent works studying neural network sparsity / pruning from a Bayesian perspective (Ghosh & Doshi-Velez, 2017; Polson & Ročková, 2018; Cui et al., 2020; Louizos et al., 2017; Molchanov et al., 2017; Gomez et al., 2019; Lee et al., 2018).
While these works seem conceptually related at first glance, their goal is fundamentally different from ours: whereas those methods aim to perform model selection or sparsification by explicitly or implicitly pruning unnecessary weights, our goal is to make inference more tractable. More precisely, while sparse Bayesian deep learning methods prune individual weights, we instead only prune the variances of certain weights; importantly, this retains the full predictive power of the network and thus its predictive accuracy.

7. CONCLUSION

In this paper, we develop a practical and scalable method for expressive yet tractable probabilistic inference in deep neural networks. We approximate the posterior over a subset of the weights while keeping all other weights deterministic. Computational cost is decoupled from network size, allowing us to scale expressive approximations, such as full-covariance Gaussian distributions, to real-world sized NNs. Our approach can be applied post-hoc to any pre-trained model, making it particularly attractive for practical use. Our empirical analysis suggests that subnetwork inference 1) is more expressive and retains more uncertainty than crude approximations over the full network, 2) allows us to employ larger NNs, which fit a broader range of functions, without sacrificing the quality of our uncertainty estimates, and 3) is competitive with state-of-the-art uncertainty quantification methods, like deep ensembles (Lakshminarayanan et al., 2017) , on real-world scale problems.

A PROOFS FOR THE THEORETICAL RESULTS

We now provide the proofs for the results in Section 4.

A.1 PROOF OF PROPOSITION 1

Proof. Note that the posterior distributions p(w|y, Φ) and p_S(w|y, Φ) are both Gaussian. We thus consider the squared 2-Wasserstein distance between two Gaussian distributions N(μ_1, Σ_1) and N(μ_2, Σ_2), which has the following closed-form expression (Givens et al., 1984):

W_2^2\big[\, \mathcal{N}(\mu_1, \Sigma_1) \,\|\, \mathcal{N}(\mu_2, \Sigma_2) \,\big] = \| \mu_1 - \mu_2 \|_2^2 + \mathrm{tr}\big( \Sigma_1 + \Sigma_2 - 2\, (\Sigma_1 \Sigma_2)^{1/2} \big).

Plugging in \mu_1 = \mu_2 = w_{MAP}, \Sigma_1 = H^{-1} and \Sigma_2 = M_S H^{-1}, we obtain

W_2^2\big[\, p(w|y, X) \,\|\, p_S(w|y, X) \,\big]
= W_2^2\big[\, \mathcal{N}(w_{MAP}, H^{-1}) \,\|\, \mathcal{N}(w_{MAP}, M_S H^{-1}) \,\big]
= \underbrace{\| w_{MAP} - w_{MAP} \|_2^2}_{=\,0} + \mathrm{tr}\big( H^{-1} + M_S H^{-1} - 2\, (H^{-1} M_S H^{-1})^{1/2} \big)
= \mathrm{tr}\big( (I + M_S)\, H^{-1} \big) - \mathrm{tr}\big( 2\, (H^{-1} M_S H^{-1})^{1/2} \big)
= \sum_{d=1}^{D} (1 + m_{dd})\, \sigma_d^2 - \mathrm{tr}\big( 2\, (H^{-1} M_S H^{-1})^{1/2} \big),

where \sigma_d^2 denotes the d-th diagonal element of H^{-1}.

A.2 PROOF OF COROLLARY 1.1

Proof. For H^{-1} = \mathrm{diag}(\sigma_1^2, \ldots, \sigma_D^2), the Wasserstein posterior gap in Eq. (12) simplifies to

W_2^2\big[\, p(w|y, \Phi) \,\|\, p_S(w|y, \Phi) \,\big] = \sum_{d=1}^{D} \big( (1 + m_{dd})\, \sigma_d^2 - 2\, m_{dd}\, \sigma_d^2 \big).

The optimal subnetwork selection strategy amounts to choosing the binary vector m = [m_{dd}]_{d=1}^{D} with \sum_{d=1}^{D} m_{dd} = S (i.e., we select S out of D parameters) such that the posterior gap in Eq. (14) is minimized. Observing that the contribution of the d-th parameter to the posterior gap is (1 + 1)\sigma_d^2 - 1 \times 2\sigma_d^2 = 0 if it is selected (i.e. if m_{dd} = 1), and (1 + 0)\sigma_d^2 - 0 \times 2\sigma_d^2 = \sigma_d^2 if it is not selected (i.e. if m_{dd} = 0), we see that the optimal subnetwork comprises the S weights with the largest variances \sigma_d^2.
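As a numerical sanity check, the closed-form 2-Wasserstein distance and the resulting top-S selection rule of Corollary 1.1 can be sketched as follows (a minimal NumPy sketch; function and variable names are ours, and the matrix square root is computed by eigendecomposition since the masked covariance is singular):

```python
import numpy as np

def psd_sqrtm(a):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def w2_squared(mu1, cov1, mu2, cov2):
    """Closed-form squared 2-Wasserstein distance between two Gaussians
    (Givens et al., 1984); also valid for degenerate (singular) covariances."""
    s1 = psd_sqrtm(cov1)
    cross = psd_sqrtm(s1 @ cov2 @ s1)
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(cov1 + cov2 - 2 * cross))

# Corollary 1.1: with a diagonal posterior, the gap is minimised by keeping
# the S weights with the largest marginal variances sigma_d^2.
sigma2 = np.array([0.5, 3.0, 0.1, 2.0])
keep = np.argsort(sigma2)[-2:]                 # S = 2 -> indices {1, 3}
mask = np.zeros_like(sigma2); mask[keep] = 1.0
mu = np.zeros(4)
gap = w2_squared(mu, np.diag(sigma2), mu, np.diag(mask * sigma2))
# gap equals the sum of the discarded variances: 0.5 + 0.1 = 0.6
```

In the diagonal case the cross term collapses to the masked variances themselves, which is exactly why the gap reduces to the sum of the variances that were not selected.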



Footnotes:
1. See Section 6 for a more thorough discussion of related work.
2. For consistency, we will keep referring to the S selected linear model weights as a "subnetwork".
3. We use the Wasserstein distance instead of the more common Kullback-Leibler divergence because the Wasserstein distance is well-defined for degenerate distributions and is an actual distance metric (i.e. symmetric).
4. This also holds for our case of a degenerate Gaussian with singular covariance matrix (Givens et al., 1984).



Figure 1: Schematic illustration of our proposed approach. (a) We train a neural network using standard techniques to obtain a point estimate of the weights. (b) We identify a small subset of the weights. (c) We estimate a posterior distribution over the selected subnetwork via Bayesian inference techniques. (d) We make predictions using the full network of mixed Bayesian/deterministic weights.

Figure 3: Mean test log-likelihood values obtained on UCI datasets across all splits. Different markers indicate models with different numbers of weights. The horizontal axis indicates the number of weights over which full covariance inference is performed. 0 corresponds to MAP parameter estimation, and the rightmost setting for each marker corresponds to full network inference.

Figure 4: Results on the rotated MNIST (left) and the corrupted CIFAR (right) benchmarks of Ovadia et al. (2019), showing the mean ± std of the error (top) and log-likelihood (bottom) across three different seeds. Subnetwork inference retains better uncertainty calibration and robustness to distribution shift than point estimated networks and other Bayesian deep learning approaches.

Figure 5: Rotated MNIST (left) and Corrupted CIFAR10 (right) results for deep ensembles (Lakshminarayanan et al., 2017) with large numbers of ensemble members (i.e. up to 55). The horizontal axis denotes the number of ensemble members, and the vertical axis denotes performance in terms of log-likelihood. Straight horizontal lines correspond to the performance of our method, as a reference. Colors denote different levels of rotation (left) and corruption (right). It can clearly be observed that the performance of deep ensembles saturates after around 15 ensemble members, meaning that adding more members yields strongly diminishing returns. This is in agreement with recent works (Antorán et al., 2020; Ashukha et al., 2020a; Lobacheva et al., 2020). Our method significantly outperforms even very large deep ensembles, especially for high degrees of rotation/corruption.

Figure 6: Rejection-classification plots. We simulate a realistic OOD rejection scenario (Filos et al., 2019b) by jointly evaluating our models on an in-distribution and an OOD test set. We allow our methods to reject increasing proportions of the data based on predictive entropy before classifying the rest. All predictions on OOD samples are treated as incorrect. Following Nalisnick et al. (2019), we use CIFAR10 vs SVHN and MNIST vs FashionMNIST as in- and out-of-distribution datasets, respectively. Note that the SVHN test set is randomly sub-sampled down to a size of 10,000.
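The rejection-classification protocol described in this caption can be sketched as follows (a minimal NumPy sketch; the function and variable names are ours):

```python
import numpy as np

def rejection_curve(entropy, correct, n_points=50):
    """Accuracy on the retained samples as a function of the fraction of the
    joint (in-distribution + OOD) test set rejected by predictive entropy.
    `correct` is a boolean vector; OOD samples are always marked incorrect."""
    order = np.argsort(-entropy)                # most uncertain first
    correct = np.asarray(correct, float)[order]
    fracs = np.linspace(0.0, 0.9, n_points)
    accs = []
    for f in fracs:
        kept = correct[int(f * len(correct)):]  # reject the top-f fraction
        accs.append(kept.mean())
    return fracs, np.array(accs)

# Toy usage: confident-and-correct ID points, uncertain OOD points.
entropy = np.array([0.1, 0.2, 0.15, 2.1, 2.3])
correct = [True, True, True, False, False]      # OOD -> treated as incorrect
fracs, accs = rejection_curve(entropy, correct, n_points=5)
```

A model with well-calibrated uncertainty assigns the highest entropy to the OOD points, so accuracy on the retained set climbs towards 1 as the rejection fraction grows.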

Figure 7: Full MNIST rotation and CIFAR10 corruption results for ResNet-18, reporting predictive error, log-likelihood (LL), expected calibration error (ECE) and Brier score, respectively (from top to bottom).

Figure 8: MNIST rotation results for ResNet-50, reporting predictive error, log-likelihood (LL), expected calibration error (ECE) and Brier score. We choose a subnetwork containing only 0.167% (39,190 / 23,466,560) of the parameters of the full network. We see that subnetwork inference still results in an improvement in the calibration of predictive uncertainty. As expected, however, for ResNet-50 the improvement over MAP is smaller than for ResNet-18, where we were able to choose a subnetwork containing 0.38% of the parameters.

AUC-ROC scores for out-of-distribution detection, using CIFAR10 vs SVHN and MNIST vs FashionMNIST as in- (source) and out-of-distribution (target) datasets, respectively (Nalisnick et al., 2019).
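The AUC-ROC for entropy-based OOD detection can be computed without any plotting, since it equals the Mann-Whitney statistic on the two entropy samples (a minimal NumPy sketch; the function name is ours):

```python
import numpy as np

def ood_auc_roc(entropy_in, entropy_out):
    """AUC-ROC for separating in-distribution (source) from OOD (target)
    points, scoring each point by its predictive entropy. Computed as the
    Mann-Whitney statistic P(OOD entropy > ID entropy), with ties at 1/2."""
    e_in = np.asarray(entropy_in, float)[:, None]
    e_out = np.asarray(entropy_out, float)[None, :]
    return float((e_out > e_in).mean() + 0.5 * (e_out == e_in).mean())

# Perfect separation (all OOD entropies above all ID entropies) gives 1.0;
# indistinguishable scores give 0.5.
auc = ood_auc_roc([0.1, 0.2, 0.3], [1.5, 2.0])
```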

MNIST - 15° rotation.

MNIST - 30° rotation.

MNIST - 45° rotation.

MNIST - 60° rotation.

MNIST - 75° rotation.

MNIST - 90° rotation.

MNIST - 105° rotation.

MNIST - 120° rotation.

MNIST - 135° rotation.

MNIST - 150° rotation.

MNIST - 165° rotation.

MNIST - 180° rotation.

CIFAR10 - no corruption.

CIFAR10 - level 1 corruption.

CIFAR10 - level 2 corruption.

CIFAR10 - level 3 corruption.

CIFAR10 - level 4 corruption.

CIFAR10 - level 5 corruption.

C EXPERIMENTAL SETUP

C.1 TOY EXPERIMENTS

We train a single network with 2 hidden layers and 50 hidden ReLU units per layer using MAP inference until convergence. Specifically, we use SGD with a learning rate of 1 × 10⁻³, momentum of 0.9 and weight decay of 1 × 10⁻⁴. We use a batch size of 512. The objective we optimise is the Gaussian log-likelihood of our data, where the mean is output by the network and the variance is a hyperparameter learnt jointly with the NN parameters by SGD. This variance parameter is shared among all datapoints.

Once the network is trained, we perform post-hoc inference on it using different approaches. Since all of these involve the linearized approximation, the mean prediction is the same for all methods; only their uncertainty estimates vary. Note that while for this toy example we could in principle use the full covariance matrix for the purpose of subnetwork selection, we still just use its diagonal (as described in Section 4) for consistency. We use GGN Laplace inference over network weights (not biases) in combination with the linearized predictive distribution in Eq. (8). Thus, all approaches considered share their predictive mean, allowing us to better compare their uncertainty estimates.

All approaches share a single prior precision of λ = 3. We chose the prior precision such that the full covariance approach (an optimistic baseline) presents reasonable results, and use the same value for all other methods. We first tried a precision of 1 and found the full covariance approach to produce excessively large error bars (covering the whole plot); a value of 3 produces more reasonable results. Final layer inference is performed by computing the full Laplace covariance matrix and discarding all entries except those corresponding to the final layer of the NN. Results for random subnetwork selection are obtained with a single sample from a scaled uniform distribution over weight choices.
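The homoscedastic Gaussian log-likelihood objective described above can be sketched as follows (a minimal NumPy sketch; in training, the shared log-variance would be a scalar parameter optimised by SGD alongside the network weights):

```python
import numpy as np

def gaussian_nll(y, mu, log_var):
    """Mean negative Gaussian log-likelihood with a single shared
    (homoscedastic) variance exp(log_var). Parameterising by the
    log-variance keeps the variance positive under gradient descent."""
    var = np.exp(log_var)
    return float(np.mean(0.5 * np.log(2.0 * np.pi * var)
                         + 0.5 * (y - mu) ** 2 / var))

# With perfect predictions and unit variance, the NLL reduces to
# 0.5 * log(2 * pi) per point.
y = np.array([1.0, 2.0, 3.0])
nll = gaussian_nll(y, mu=y, log_var=0.0)
```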

C.2 UCI EXPERIMENTS

In this experiment, our fully connected NNs have h_d ∈ {1, 2} hidden layers and hidden layer widths w_d ∈ {50, 100}. For a dataset with input dimension i_d, the number of weights is given by D = (i_d + 1)·w_d + (h_d − 1)·w_d². Our 2 hidden layer, 100 hidden unit models have a weight count of the order of 10⁴. The non-linearity used is ReLU.

We first obtain a MAP estimate of each model's weights. Specifically, we use SGD with a learning rate of 1 × 10⁻³, momentum of 0.9 and weight decay of 1 × 10⁻⁴. We use a batch size of 512. The objective we optimise is the Gaussian log-likelihood of our data, where the mean is output by the network and the variance is a hyperparameter learnt jointly with the NN parameters by SGD. For each dataset split, we set aside 15% of the train data as a validation set, which we use for early stopping. Training runs for a maximum of 2000 epochs but stops early with a patience of 500 if validation performance does not increase; for the larger Protein dataset, these values are 500 and 125. The weight settings which provide the best validation performance are kept.

We then perform full network GGN Laplace inference for each model. We also use our proposed Wasserstein rule together with the diagonal Hessian assumption to prune every network's weight variances such that the number of variances that remain matches the size of every smaller network under consideration. The prior precision used for these steps is chosen such that the resulting predictor's log-likelihood on the validation set is maximised. Specifically, we employ a grid search over the values λ ∈ {0.0001, 0.001, 0.1, 0.5, 1, 2, 5, 10, 100, 1000}. In all cases, we employ the linearized predictive in Eq. (7). Consequently, networks with the same number of weights make the same mean predictions; increasing the number of weight variances considered will thus only increase predictive uncertainty.
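The weight-count formula above can be checked directly (a minimal sketch; the example input dimension of 8 is our assumption for illustration, not taken from the text):

```python
def n_weights(i_d, w_d, h_d):
    """D = (i_d + 1) * w_d + (h_d - 1) * w_d**2: weight count of a fully
    connected ReLU network with h_d hidden layers of width w_d and input
    dimension i_d, matching the formula in the text (the first term includes
    one bias per first-layer unit; the output layer is not counted)."""
    return (i_d + 1) * w_d + (h_d - 1) * w_d ** 2

# A 2-hidden-layer, 100-unit model is of order 10^4 weights, e.g. for an
# assumed input dimension of 8: (8 + 1) * 100 + 1 * 100**2 = 10900.
D = n_weights(i_d=8, w_d=100, h_d=2)
```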

C.3 IMAGE EXPERIMENTS

The results shown in Section 5.3 and Appendix B are obtained by training ResNet-18 (and ResNet-50) models using SGD with momentum. For each experiment repetition, we train 7 different models: the first is used for 'MAP', 'Ours', 'Ours (Rand)', 'SWAG', 'Diag-Laplace', and as the first element of 'Ensemble'. We train 4 additional 'Ensemble' elements, 1 network with 'Dropout', and, finally, 1 network for 'VOGN'. The methods 'Ours', 'Ours (Rand)', 'SWAG', and 'Diag-Laplace' are applied post training.

For all methods except 'VOGN' we use the following training procedure. The (initial) learning rate, momentum, and weight decay are 0.1, 0.9, and 1 × 10⁻⁴, respectively. For 'MAP' we use 4 Nvidia P100 GPUs with a total batch size of 2048. For the calculation of the Jacobian in the subnetwork selection phase we use a single P100 GPU with a batch size of 4. For the calculation of the Hessian we use a single P100 GPU with a batch size of 2. We train on 1 Nvidia P100 GPU with a batch size of 256 for all other methods. Each dataset is trained for a different number of epochs, shown in Table 2. We decay the learning rate by a factor of 10 at scheduled epochs, also shown in Table 2. Otherwise, all methods and datasets share hyperparameters. These hyperparameter settings are the defaults provided by PyTorch for training on ImageNet; we found them to perform well across the board. We report results obtained at the final training epoch. We do not use a separate validation set to determine the best epoch, as we found ResNet-18 and ResNet-50 not to overfit with the chosen schedules.

For 'Dropout', we add dropout to the standard ResNet-50 model (He et al., 2016) between the second and third convolutions in the bottleneck blocks. This approach follows Zagoruyko & Komodakis (2016) and Ashukha et al. (2020b), who add dropout in between the two convolutions of a WideResNet-50's basic block. Following Antorán et al. (2020), we choose a dropout probability of 0.1, as they found it to perform better than the value of 0.3 suggested by Ashukha et al. (2020b). We use 16 MC samples for predictions.

'Ensemble' uses 5 elements for prediction. Ensemble elements differ from each other in their initialisation, which is sampled from the He initialisation distribution (He et al., 2015). We do not use adversarial training as, in line with Ashukha et al. (2020b), we do not find it to improve results.

For 'VOGN' we use the same procedure and hyperparameters as used by Osawa et al. (2019) in their CIFAR10 experiments, with the exception that we use a learning rate of 1 × 10⁻³, as we found a value of 1 × 10⁻⁴ not to result in convergence. We train on a single Nvidia P100 GPU with a batch size of 256. See the authors' GitHub for more details: github.com/team-approx-bayes/dl-with-bayes/blob/master/distributed/classification/configs/cifar10/resnet18_vogn_bs256_8gpu.json.

We modify the standard ResNet-50 and ResNet-18 architectures such that the first 7 × 7 convolution is replaced with a 3 × 3 convolution. Additionally, we remove the first max-pooling layer. Following Goyal et al. (2017), we zero-initialise the last batch normalisation layer in residual blocks so that they act as identity functions at the start of training.

At test time, we tune the prior precision used for the 'Ours', 'Diag-Laplace' and 'SWAG' approximations on a validation set for each approach individually, as in Ritter et al. (2018); Kristiadi et al. (2020). We use a grid search from 1 × 10⁻⁴ to 1 × 10⁴ in logarithmic steps, and then a second, finer-grained grid search between the two best performing values (again with logarithmic steps).
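The two-stage grid search over the prior precision described above can be sketched as follows (a minimal sketch; `val_ll` stands in for an evaluation of validation log-likelihood at a given precision, and the 9-point grids are our choice for illustration):

```python
import numpy as np

def tune_prior_precision(val_ll, coarse=np.logspace(-4, 4, 9)):
    """Two-stage grid search over the prior precision lambda: a coarse
    logarithmic grid from 1e-4 to 1e4, then a finer logarithmic grid between
    the two best-performing coarse values. `val_ll` maps a precision to the
    validation log-likelihood of the resulting predictor."""
    scores = [val_ll(lam) for lam in coarse]
    top_two = np.argsort(scores)[-2:]           # two best coarse precisions
    lo, hi = sorted(coarse[top_two])
    fine = np.logspace(np.log10(lo), np.log10(hi), 9)
    return max(fine, key=val_ll)

# Toy usage with a mock validation objective peaking at lambda = 10.
best = tune_prior_precision(lambda lam: -(np.log10(lam) - 1.0) ** 2)
```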

C.4 DATASETS

The 1d toy dataset used in Section 5.1 was taken from Antorán et al. (2020). We obtained it from the authors' GitHub repo: https://github.com/cambridge-mlg/DUN. The image datasets used are summarised below:

Dataset | N. instances (train & test) | Input dimensionality | N. classes
FashionMNIST (Xiao et al., 2017) | 70,000 (60,000 & 10,000) | 784 (28 × 28) | 10
CIFAR10 (Krizhevsky & Hinton, 2009) | 60,000 (50,000 & 10,000) | 3072 (32 × 32 × 3) | 10
SVHN (Netzer et al., 2011) | 99,289 (73,257 & 26,032) | 3072 (32 × 32 × 3) | 10

