EXPRESSIVE YET TRACTABLE BAYESIAN DEEP LEARNING VIA SUBNETWORK INFERENCE

Abstract

The Bayesian paradigm has the potential to solve some of the core issues in modern deep learning, such as poor calibration, data inefficiency, and catastrophic forgetting. However, scaling Bayesian inference to the high-dimensional parameter spaces of deep neural networks requires restrictive approximations. In this paper, we propose performing inference over only a small subset of the model parameters while keeping all others as point estimates. This enables us to use expressive posterior approximations that would otherwise be intractable for the full model. In particular, we develop a practical and scalable Bayesian deep learning method that first trains a point estimate, and then infers a full covariance Gaussian posterior approximation over a subnetwork. We propose a subnetwork selection procedure which aims to maximally preserve posterior uncertainty. We empirically demonstrate the effectiveness of our approach compared to point-estimated networks and methods that use less expressive posterior approximations over the full network.

1. INTRODUCTION

Deep neural networks (DNNs) still suffer from critical shortcomings that make them unfit for important applications. For instance, DNNs tend to be poorly calibrated and overconfident in their predictions, especially when there is a shift between the training and test distributions (Nguyen et al., 2015; Guo et al., 2017). To reliably inform decision making, DNNs must be able to robustly quantify the uncertainty in their predictions; this is particularly important in safety-critical areas such as healthcare or autonomous driving (Amodei et al., 2016; Filos et al., 2019a; Fridman et al., 2019).

Bayesian modeling (Ghahramani, 2015; Gal, 2016) provides a principled way to capture predictive uncertainty via the posterior distribution over model parameters. Unfortunately, due to their nonlinearities, exact posterior inference in DNNs is intractable. Despite recent successes in Bayesian deep learning (Blundell et al., 2015; Gal & Ghahramani, 2016; Osawa et al., 2019; Maddox et al., 2019; Dusenberry et al., 2020), existing methods only scale to modern DNNs with large numbers of parameters by invoking unrealistic assumptions. This severely limits the expressiveness of the inferred posterior and thus deteriorates the quality of the induced uncertainty estimates (Ovadia et al., 2019; Fort et al., 2019; Foong et al., 2019a; Ashukha et al., 2020a).

Due to the heavy overparameterization of DNNs, their accuracy is well preserved by a small subnetwork (Cheng et al., 2017). Additionally, recent work by Izmailov et al. (2019) has shown that performing inference over a low-dimensional subspace of the weights can yield accurate uncertainty quantification. These observations prompt the following question: Can a full DNN's model uncertainty be well preserved by a small subnetwork's model uncertainty? We answer this question in the affirmative.
We show both theoretically and empirically that the posterior over the full network can be well represented by the posterior over a subnetwork. As a result, we can use more expensive but more faithful posterior approximations over just that subnetwork. We show that this achieves better uncertainty quantification than cheaper but cruder posterior approximations over the full network. The contributions of this paper are as follows:

1. We propose a new Bayesian deep learning approach that performs Bayesian inference over only a small subset of the model weights and keeps all other weights deterministic. This allows us to use expressive posterior approximations that are typically intractable in DNNs.

2. As a concrete instantiation of this framework, we develop a practical and scalable Bayesian deep learning method that uses the linearized Laplace approximation to infer a full-covariance Gaussian posterior over a subnetwork within a point-estimated neural network.

3. We formally characterize the discrepancy between the posterior distributions over a subnetwork and the full network (in terms of their Wasserstein distance) in the linearized model, and derive a theoretically motivated strategy to select a subnetwork that minimizes this discrepancy under certain assumptions.

4. We empirically show, on various benchmarks, that our method compares favourably against point-estimated networks and other Bayesian deep learning methods, experimentally confirming that expressive subnetwork inference is superior to crude inference over full networks.

2. SUBNETWORK POSTERIOR APPROXIMATION

Bayesian neural networks (BNNs) aim to capture model uncertainty, i.e., uncertainty about the choice of weights W that arises due to multiple plausible explanations of the training data {y, X}. Here, y is the dependent variable (e.g. a classification label) and X is the feature matrix. A prior distribution p(W) is specified over the BNN's weights, and we wish to infer their full posterior distribution

p(W | y, X) ∝ p(y | X, W) p(W) .   (1)

To make predictions, we then estimate the posterior predictive distribution, which averages the network's predictions across all possible settings of the weights, weighted by their posterior probability:

p(y* | X*, y, X) = ∫ p(y* | X*, W) p(W | y, X) dW .   (2)

Unfortunately, due to the size of modern deep neural networks, it is not only intractable to infer the exact posterior distribution p(W | y, X) in Eq. (1), but it is even computationally challenging to properly approximate it. As a consequence, crude posterior approximations such as complete factorization are commonly employed (Blundell et al., 2015; Hernández-Lobato & Adams, 2015; Kingma et al., 2015; Khan et al., 2018; Osawa et al., 2019), i.e.

p(W | y, X) ≈ ∏_{d=1}^{D} q(w_d) ,

where w_d denotes the d-th weight in the D-dimensional neural network weight vector W ∈ R^D (the concatenation and flattening of all layers' weight matrices). Clearly, this is a very restrictive assumption; in practice, it suffers from severe pathologies (Foong et al., 2019a;b).

In this work, we question the implicit assumption that a good posterior approximation needs to include all BNN parameters. Instead, we aim to perform inference only over a small subset of the weights. This approach is well-motivated for two reasons:

1. Overparameterization: Maddox et al. (2020) have shown that, in the neighborhood of local optima, there are many directions that leave the NN's predictions unchanged. Moreover, NNs can be heavily pruned without sacrificing test-set accuracy (Frankle & Carbin, 2019). Thus, the majority of a NN's predictive power might be isolated to a small subnetwork.
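The posterior predictive integral in Eq. (2) is typically approximated by Monte Carlo: draw weight samples from the (approximate) posterior and average the network's predictions. A minimal sketch, using a toy logistic-regression "network" as a stand-in for a DNN forward pass and a hypothetical mean-field Gaussian q(W) of the factorized form discussed above (all means and standard deviations are illustrative placeholders, not values from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def predict(x, w):
    # Toy "network": logistic regression standing in for a DNN forward pass.
    return 1.0 / (1.0 + np.exp(-x @ w))

# Hypothetical factorized Gaussian posterior q(W) = prod_d N(w_d | mu_d, sigma_d^2),
# i.e. the crude mean-field approximation discussed in the text.
mu = np.array([0.5, -1.2, 0.8])
sigma = np.array([0.1, 0.3, 0.2])

def posterior_predictive(x, n_samples=1000):
    # Monte Carlo estimate of Eq. (2):
    # p(y*|x*, y, X) ~= (1/S) * sum_s p(y*|x*, W_s),  with W_s ~ q(W)
    samples = rng.normal(mu, sigma, size=(n_samples, mu.size))
    return np.mean([predict(x, w) for w in samples])

x_star = np.array([1.0, 0.5, -0.3])
p = posterior_predictive(x_star)
```

Because the sigmoid is nonlinear, this average generally differs from the single prediction at the posterior mean; the gap between the two is one way predictive uncertainty manifests.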



Figure 1: Schematic illustration of our proposed approach. (a) We train a neural network using standard techniques to obtain a point estimate of the weights. (b) We identify a small subset of the weights. (c) We estimate a posterior distribution over the selected subnetwork via Bayesian inference techniques. (d) We make predictions using the full network of mixed Bayesian/deterministic weights.
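The four steps of Figure 1 can be sketched in a few lines. Here the "network" is a toy linear model, and both the subnetwork indices and the covariance matrix are illustrative stand-ins (the paper selects the subnetwork with a principled procedure and obtains the covariance from a linearized Laplace approximation; neither is reproduced here):

```python
import numpy as np

rng = np.random.default_rng(1)

# (a) Point estimate of all D weights (toy linear model standing in for a DNN).
w_map = rng.normal(size=10)

# (b) Hypothetical subnetwork: indices of the few weights treated as Bayesian.
sub_idx = np.array([0, 3, 7])

# (c) Full-covariance Gaussian posterior over the subnetwork only.
#     Illustrative symmetric positive-definite covariance, not a real Laplace fit.
A = rng.normal(size=(3, 3))
cov_sub = A @ A.T + 0.1 * np.eye(3)

def sample_weights():
    # (d) Mixed Bayesian/deterministic weights: sample the subnetwork,
    #     keep all remaining weights at their point estimate.
    w = w_map.copy()
    w[sub_idx] = rng.multivariate_normal(w_map[sub_idx], cov_sub)
    return w

def predictive_mean(x, n_samples=500):
    # Average predictions over posterior samples of the subnetwork.
    return np.mean([x @ sample_weights() for _ in range(n_samples)])

x_star = rng.normal(size=10)
```

Note that a full D-by-D covariance would need D^2 entries, while the subnetwork covariance needs only |S|^2 for |S| << D; this is what makes the expressive full-covariance approximation tractable here.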

