BAYESADAPTER: BEING BAYESIAN, INEXPENSIVELY AND ROBUSTLY, VIA BAYESIAN FINE-TUNING

Abstract

Despite their theoretical appeal, Bayesian neural networks (BNNs) lag far behind deterministic NNs in real-world adoption, mainly due to their limited scalability in training and the low fidelity of their uncertainty estimates. In this work, we develop a new framework, named BayesAdapter, to address these issues and bring Bayesian deep learning to the masses. The core idea of BayesAdapter is to adapt pre-trained deterministic NNs into BNNs via Bayesian fine-tuning. We implement Bayesian fine-tuning with a plug-and-play instantiation of stochastic variational inference, and propose exemplar reparameterization to reduce gradient variance and stabilize the fine-tuning. Together, these enable training BNNs as if one were training deterministic NNs, with minimal added overhead. During Bayesian fine-tuning, we further introduce an uncertainty regularization to supervise and calibrate the uncertainty quantification of the learned BNNs at low cost. To empirically evaluate BayesAdapter, we conduct extensive experiments on a diverse set of challenging benchmarks, and observe satisfactory training efficiency, competitive predictive performance, and calibrated, faithful uncertainty estimates.

1. INTRODUCTION

Much effort has been devoted to developing flexible and efficient Bayesian deep models that make accurate, robust, and well-calibrated decisions (MacKay, 1992; Neal, 1995; Graves, 2011; Blundell et al., 2015), with Bayesian neural networks (BNNs) as popular examples. The principled uncertainty quantification inside BNNs is critical for realistic decision-making, and has been validated in scenarios ranging from model-based reinforcement learning (Depeweg et al., 2016) and active learning (Hernández-Lobato & Adams, 2015) to healthcare (Leibig et al., 2017) and autonomous driving (Kendall & Gal, 2017). BNNs are also known to resist over-fitting. However, fundamental obstacles confront ML practitioners trying to push BNNs to larger datasets and deeper architectures: (i) the scalability of existing BNNs is generally restrictive, owing to the essential difficulty of learning a complex, non-degenerate distribution over parameters in a high-dimensional, over-parameterized space (Liu & Wang, 2016; Louizos & Welling, 2017; Sun et al., 2019); (ii) Bayes posteriors learned from scratch are often systematically worse than their point-estimate counterparts in predictive performance unless "cold posterior" strategies are applied (Wenzel et al., 2020); (iii) BNNs may assign low (epistemic) uncertainty to realistic out-of-distribution (OOD) data (e.g., adversarial examples), rendering their uncertainty estimates unreliable in safety-critical scenarios (Grosse et al., 2018). To solve these problems, we present a scalable workflow, named BayesAdapter, to learn more reliable BNNs. In a holistic view, we unfold the learning of a BNN into two steps: deterministic pre-training of the deep neural network (DNN) counterpart of the BNN, followed by several rounds of Bayesian fine-tuning.
This enables us to learn a principled BNN with only slightly more effort than training a regular DNN, and lets us embrace qualified off-the-shelf pre-trained DNNs (e.g., those on PyTorch Hub). The converged parameters of the deterministic model serve as a strong starting point for Bayesian fine-tuning, allowing us to bypass the extensive local optima suffered when learning a BNN directly[1]. To render the fine-tuning in the style of training normal NNs, we resort to stochastic variational inference (VI) to update the approximate posterior. We develop optimizers with built-in weight decay for the parameters of the variational distribution to absorb the regularization effects of the prior, and develop exemplar reparameterization to reduce the gradient variance. Moreover, to make the uncertainty estimates of the learned models reliable, we propose to additionally and explicitly regularize the model to behave uncertainly on representative, foreseeable OOD data during fine-tuning. This regularization takes the form of a margin loss and is readily applicable to most existing BNNs. Figure 1 depicts the whole framework of BayesAdapter. Extensive empirical studies validate the efficiency and effectiveness of our workflow. In summary, our contributions are as follows:

1. We propose BayesAdapter to quickly and cheaply adapt a pre-trained DNN to be Bayesian, without compromising performance on new tasks.
2. We provide an easy-to-use instantiation of stochastic VI, which allows learning a BNN as if training a deterministic NN and frees users from the tedious details of BNNs.
3. We augment the fine-tuning with a generally applicable uncertainty regularization term to rectify the predictive uncertainty according to a collection of OOD data.
4. Extensive studies validate that BayesAdapter is scalable, the delivered BNN models are high-quality, and the acquired uncertainty quantification is calibrated and transferable.
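The adaptation at the heart of this workflow can be sketched in a few lines: center a mean-field Gaussian posterior at the pretrained weights and draw samples via the reparameterization trick. The following NumPy sketch is illustrative only; the function names and the initial log-variance value are assumptions, not the paper's exact implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_variational(w_pretrained, init_log_var=-10.0):
    """Center a diagonal Gaussian q(w) = N(mu, diag(psi)) at the
    pretrained deterministic weights (the Bayesian fine-tuning start
    point). init_log_var is an assumed small initial log-variance."""
    mu = w_pretrained.copy()
    log_var = np.full_like(w_pretrained, init_log_var)
    return mu, log_var

def sample_weight(mu, log_var):
    """Reparameterized sample: w = mu + sigma * eps, eps ~ N(0, I),
    which keeps the sampling step differentiable w.r.t. (mu, log_var)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

# Toy usage: a "pretrained" 3x2 weight matrix becomes a stochastic one.
w = rng.standard_normal((3, 2))
mu, log_var = init_variational(w)
w_sample = sample_weight(mu, log_var)
```

Because the initial variance is tiny, early samples stay close to the deterministic solution, so fine-tuning starts from the pretrained model's predictive behavior rather than from scratch.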

2. BAYESADAPTER

In this section, we first motivate BayesAdapter by drawing a connection between maximum a posteriori (MAP) and Bayesian inference. We then describe the proposed procedure Bayesian fine-tuning, and a practical and robust implementation of stochastic VI to realize it. Figure 1 illustrates the overall workflow of BayesAdapter.

2.1. FROM DNNS TO BNNS

Let D = {(x_i, y_i)}_{i=1}^n be a given training set, where x_i ∈ R^d and y_i ∈ Y denote the input data and label, respectively. A DNN model can be fit via MAP as follows:

    max_w  (1/n) Σ_i log p(y_i | x_i; w) + (1/n) log p(w),

where w ∈ R^p denotes the high-dimensional model parameters and p(y|x; w) the predictive distribution associated with the model. The prior term p(w), when taking the form of an isotropic Gaussian, reduces to the common L2 weight decay regularizer in optimization. Despite their wide adoption, DNNs are known to be prone to over-fitting, to generate over-confident predictions, and to be unable to convey valuable information on the trustworthiness of their predictions. Naturally, Bayesian neural networks (BNNs) come into the picture to address these limitations.
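The prior-as-weight-decay equivalence can be verified numerically: for an isotropic Gaussian prior, the gradient of (1/n) log p(w) is exactly an L2 decay update with coefficient 1/(n·σ²). The values of n and σ² below are illustrative, not taken from the paper.

```python
import numpy as np

# Prior term of the MAP objective with p(w) = N(0, sigma^2 I):
# up to constants, (1/n) log p(w) = -||w||^2 / (2 n sigma^2),
# whose gradient -w / (n sigma^2) is an L2 weight-decay update.
n, sigma2 = 1000.0, 1.0

def prior_term(w):
    return -np.sum(w ** 2) / (2.0 * n * sigma2)

w = np.array([0.5, -1.2, 2.0])
analytic = -w / (n * sigma2)  # the weight-decay gradient

# Central finite-difference check of the analytic gradient.
eps = 1e-6
numeric = np.array([
    (prior_term(w + eps * e) - prior_term(w - eps * e)) / (2 * eps)
    for e in np.eye(3)
])
assert np.allclose(analytic, numeric, atol=1e-8)
```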



[1] Here, BNN mainly refers to mean-field variational BNNs; the results in Sec. 4.1 testify to this point.



Figure 1: The workflow of BayesAdapter. We assume a three-layer model for simplicity. We first pre-train a DNN counterpart of the target BNN via maximum a posteriori (MAP) estimation, then transform it into a BNN by replacing the point-estimate parameters with a diagonal Gaussian centered at them, from which parameter samples are drawn for computation. After that, we build separate optimizers with built-in weight decay for the Gaussian mean and variance, and perform fine-tuning to fit the data under uncertainty regularization using autodiff libraries.
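The uncertainty regularization in the caption can be illustrated with a small sketch: penalize OOD inputs whose predictive entropy (averaged over posterior samples) falls below a margin. The entropy-based uncertainty measure and the margin value are assumptions for illustration; the paper's exact regularizer may differ in form.

```python
import numpy as np

def predictive_entropy(probs):
    """Entropy of the posterior-averaged predictive distribution.
    probs has shape (S, K): S posterior samples, K classes."""
    p = probs.mean(axis=0)
    return float(-np.sum(p * np.log(p + 1e-12)))

def uncertainty_margin_loss(probs_ood, margin=1.0):
    """Hinge-style penalty: positive when the model is too confident
    (low entropy) on an OOD input, zero once entropy exceeds `margin`."""
    return max(0.0, margin - predictive_entropy(probs_ood))

# A confident OOD prediction is penalized; a maximally uncertain one is not.
confident = np.tile([0.97, 0.01, 0.01, 0.01], (5, 1))  # 5 samples, 4 classes
uniform = np.full((5, 4), 0.25)
```

During fine-tuning, a term like this would be added to the variational objective on a batch of OOD data, pushing the learned BNN to report high epistemic uncertainty away from the training distribution.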

