META-LEARNING BAYESIAN NEURAL NETWORK PRIORS BASED ON PAC-BAYESIAN THEORY

Abstract

Bayesian deep learning is a promising approach towards improved uncertainty quantification and sample efficiency. Due to their complex parameter space, choosing informative priors for Bayesian Neural Networks (BNNs) is challenging. Thus, a naive, zero-centered Gaussian prior is often used, resulting in both poor generalization and poor uncertainty estimates when training data is scarce. In contrast, meta-learning aims to extract such prior knowledge from a set of related learning tasks. We propose a principled and scalable algorithm for meta-learning BNN priors based on PAC-Bayesian bounds. Whereas previous approaches require optimizing the prior and multiple variational posteriors in an interdependent manner, our method does not rely on difficult nested optimization problems, and moreover, it is agnostic to the variational inference method in use. Our experiments show that the proposed method is not only computationally more efficient but also yields better predictions and uncertainty estimates when compared to previous meta-learning methods and BNNs with standard priors.

1. INTRODUCTION

Bayesian Neural Networks (BNNs) offer a probabilistic interpretation of deep learning by inferring distributions over the model's weights (Neal, 1996). With the potential of combining the scalability and performance of neural networks (NNs) with a framework for uncertainty quantification, BNNs have lately received increased attention (Blundell et al., 2015; Gal & Ghahramani, 2016). In particular, their ability to express epistemic uncertainty makes them highly relevant for applications such as active learning (Hernández-Lobato & Adams, 2015) and reinforcement learning (Riquelme et al., 2018).

However, BNNs face two major issues: 1) the intractability of posterior inference and 2) the difficulty of choosing good Bayesian priors. While the former has been addressed in an extensive body of literature on variational inference (e.g. Blundell et al., 2015; Blei et al., 2016; Mishkin et al., 2018; Liu & Wang, 2016), the latter has only received limited attention (Vladimirova et al., 2019; Ghosh & Doshi-Velez, 2017). Choosing an informative prior for BNNs is particularly difficult due to the high-dimensional and hardly interpretable parameter space of NNs. Due to the lack of good alternatives, often a zero-centered, isotropic Gaussian is used, reflecting (almost) no a priori knowledge about the problem at hand. This not only leads to poor generalization when data is scarce, but also renders the Bayesian uncertainty estimates poorly calibrated (Kuleshov et al., 2018).

Meta-learning (Schmidhuber, 1987; Thrun & Pratt, 1998) acquires inductive bias in a data-driven way, thus constituting an alternative route for addressing this issue. In particular, meta-learners attempt to extract shared (prior) knowledge from a set of related learning tasks (i.e., datasets), aiming to accelerate learning in the face of a new, related task. Our work develops a principled and scalable algorithm for meta-learning BNN priors.
We build on the PAC-Bayesian framework (McAllester, 1999), a methodology from statistical learning theory for deriving generalization bounds. Previous PAC-Bayesian bounds for meta-learners (Pentina & Lampert, 2014; Amit & Meir, 2018) require solving a difficult optimization problem, involving the optimization of the prior as well as multiple variational posteriors in a nested manner. Aiming to overcome this issue, we present a PAC-Bayesian bound that does not rely on nested optimization and, unlike that of Rothfuss et al. (2020), can be tractably optimized for BNNs. This makes the resulting meta-learner, referred to as PACOH-NN, not only much more computationally efficient and scalable than previous approaches for meta-learning BNN priors (Amit & Meir, 2018), but also agnostic to the choice of approximate posterior inference method, which allows us to combine it freely with recent advances in MCMC (e.g. Chen et al., 2014) or variational inference (e.g. Wang et al., 2019). Our experiments demonstrate that the computational advantages of PACOH-NN do not come at the cost of degraded predictive performance. In fact, across several regression and classification environments, PACOH-NN achieves a comparable or better predictive accuracy than several popular meta-learning approaches, while improving the quality of the uncertainty estimates. Finally, we showcase how meta-learned PACOH-NN priors can be used in a real-world bandit task concerning the development of vaccines, suggesting that many other challenging real-world problems may benefit from our approach.
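To make the PAC-Bayesian ingredients concrete, the following is a minimal numerical sketch (not the bound derived in this paper): the classical McAllester-style complexity term for bounded losses, evaluated for diagonal Gaussian prior and posterior over NN weights. All numbers (dimension, means, variances, sample size) are hypothetical and chosen only for illustration.

```python
import numpy as np

def gaussian_kl(mu_q, sig_q, mu_p, sig_p):
    # KL(Q || P) between diagonal Gaussians Q = N(mu_q, diag(sig_q^2)),
    # P = N(mu_p, diag(sig_p^2)), summed over dimensions.
    var_q, var_p = sig_q ** 2, sig_p ** 2
    return 0.5 * np.sum(np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

def mcallester_gap(kl, m, delta=0.05):
    # Complexity term of the classical McAllester (1999) bound for losses in [0, 1]:
    # with probability >= 1 - delta, expected risk <= empirical risk + this gap.
    return np.sqrt((kl + np.log(2.0 * np.sqrt(m) / delta)) / (2.0 * m))

d = 1000                                              # number of NN weights (hypothetical)
mu_p, sig_p = np.zeros(d), 0.1 * np.ones(d)           # zero-centered isotropic prior
mu_q, sig_q = 0.05 * np.ones(d), 0.05 * np.ones(d)    # hypothetical variational posterior

kl = gaussian_kl(mu_q, sig_q, mu_p, sig_p)
print(mcallester_gap(kl, m=500))
```

The gap shrinks as the KL divergence between posterior and prior shrinks, which is precisely why an informative, meta-learned prior that lies close to good posteriors can tighten such bounds.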

2. RELATED WORK

Bayesian Neural Networks. The majority of research on BNNs focuses on approximating the intractable posterior distribution (Graves, 2011; Blundell et al., 2015; Liu & Wang, 2016; Wang et al., 2019). In particular, we employ the approximate inference method of Liu & Wang (2016). Another crucial question is how to select a good BNN prior (Vladimirova et al., 2019). While the majority of work (e.g. Louizos & Welling, 2016; Huang et al., 2020) employs a simple zero-centered, isotropic Gaussian, Ghosh & Doshi-Velez (2017) and Pearce et al. (2020) have proposed other prior distributions for BNNs. In contrast, we go the alternative route of choosing priors in a data-driven way.

Meta-learning. A range of popular methods in meta-learning attempt to learn the "learning program" in the form of a recurrent model (Hochreiter et al., 2001; Andrychowicz et al., 2016; Chen et al., 2017), learn an embedding space shared across tasks (Snell et al., 2017; Vinyals et al., 2016), or learn the initialization of a NN such that it can be quickly adapted to new tasks (Finn et al., 2017; Nichol et al., 2018; Rothfuss et al., 2019b). A group of recent methods also uses probabilistic modeling to allow for uncertainty quantification (Kim et al., 2018; Finn et al., 2018; Garnelo et al., 2018). Although the mentioned approaches are able to learn complex inference patterns, they rely on settings where meta-training tasks are abundant and fall short of performance guarantees. In contrast, we provide a formal assessment of the generalization properties of our algorithm. Moreover, PACOH-NN allows for principled uncertainty quantification, including separate treatment of epistemic and aleatoric uncertainty. This makes it particularly useful for sequential decision algorithms (Lattimore & Szepesvari, 2020).

PAC-Bayesian theory.
Previous work presents generalization bounds for randomized predictors, assuming a prior to be given exogenously (McAllester, 1999; Catoni, 2007; Germain et al., 2016; Alquier et al., 2016). More recent work explores data-dependent priors (Parrado-Hernandez et al., 2012; Dziugaite & Roy, 2016) or extends previous bounds to the scenario where priors are meta-learned (Pentina & Lampert, 2014; Amit & Meir, 2018). However, these meta-generalization bounds are hard to minimize as they leave both the hyper-posterior and the posteriors unspecified, which leads to nested optimization problems. Our work builds on the results of Rothfuss et al. (2020), who introduce the methodology to derive the closed-form solution of the PAC-Bayesian meta-learning problem. However, unlike ours, their approach suffers from (asymptotically) non-vanishing terms in the bounds and relies on a closed-form solution of the marginal log-likelihood. By contributing a numerically stable score estimator for the generalized marginal log-likelihood, we are able to overcome these limitations, making PAC-Bayesian meta-learning both tractable and scalable for a much larger array of models, including BNNs.

3. BACKGROUND: THE PAC-BAYESIAN FRAMEWORK

Bayesian Neural Networks. Consider a supervised learning task with data S = {(x_j, y_j)}_{j=1}^m drawn from an unknown distribution D. Here, X = {x_j}_{j=1}^m ∈ X^m denotes the training inputs and Y = {y_j}_{j=1}^m ∈ Y^m the corresponding targets. For brevity, we also write z_j := (x_j, y_j) ∈ Z. Let h_θ : X → Y be a function parametrized by a NN with weights θ ∈ Θ. Using the NN mapping, we define a conditional distribution p(y|x, θ). For regression, we set p(y|x, θ) = N(y | h_θ(x), σ²), where σ² is the observation noise variance. For classification, we choose p(y|x, θ) = Categorical(softmax(h_θ(x))). For Bayesian inference, one presumes a prior distribution p(θ) over the model parameters θ, which is combined with the training data S into a posterior distribution p(θ|X, Y) ∝ p(θ) p(Y|X, θ). For an unseen test point x*, we form the predictive distribution p(y*|x*, X, Y) = ∫ p(y*|x*, θ) p(θ|X, Y) dθ.

The Bayesian framework presumes partial knowledge of the data-generating process in the form of a prior distribution. However, due to the practical difficulties of choosing an appropriate BNN prior, the prior is typically strongly misspecified (Syring & Martin, 2018). As a result, modulating the influence of the prior relative to the likelihood during inference typically improves the empirical performance of BNNs and is thus a common practice (Wenzel et al., 2020). Using such a "tempered" posterior p^τ(θ|X, Y) ∝ p(θ) p(Y|X, θ)^τ with τ > 0 is also referred to as generalized Bayesian learning (Guedj, 2019).
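In practice, the integral defining the BNN predictive distribution p(y*|x*, X, Y) is intractable and is approximated by Monte Carlo averaging over posterior samples. The sketch below illustrates this for a toy regression model; the network architecture is arbitrary, and (since actual posterior inference is beyond this snippet) samples from a standard normal stand in for posterior samples θ_k ~ p(θ|X, Y).

```python
import numpy as np

rng = np.random.default_rng(0)

def h(theta, x):
    # Toy one-hidden-layer NN h_theta: R -> R with 10 hidden units;
    # theta packs (W1, b1, w2, b2) into a flat vector of length 31.
    W1, b1, w2, b2 = theta[:10], theta[10:20], theta[20:30], theta[30]
    return np.tanh(x * W1 + b1) @ w2 + b2

sigma = 0.1                                    # observation noise std of the Gaussian likelihood
thetas = rng.normal(0.0, 1.0, size=(200, 31))  # stand-in for posterior samples theta_k ~ p(theta|X, Y)

# Monte Carlo estimate: p(y*|x*, X, Y) ≈ (1/K) sum_k N(y* | h_{theta_k}(x*), sigma^2)
x_star = 0.5
means = np.array([h(t, x_star) for t in thetas])
pred_mean = means.mean()                # predictive mean
pred_var = sigma ** 2 + means.var()     # aleatoric (sigma^2) + epistemic (spread over theta_k) variance
```

Note how the predictive variance decomposes into an aleatoric part from the likelihood and an epistemic part from the disagreement among sampled networks; this separation is what makes BNNs attractive for active learning and sequential decision making.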

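The tempered (generalized Bayesian) posterior is straightforward to work with in log space, since most approximate inference methods only require an unnormalized log density. The following is a minimal sketch with a Gaussian prior and a linear-Gaussian likelihood standing in for the BNN likelihood; all model choices here are illustrative, not those of the paper.

```python
import numpy as np

def log_prior(theta, sig_p=1.0):
    # log N(theta | 0, sig_p^2 I): the standard zero-centered Gaussian prior
    d = theta.shape[0]
    return -0.5 * np.sum(theta ** 2) / sig_p ** 2 - 0.5 * d * np.log(2 * np.pi * sig_p ** 2)

def log_lik(theta, X, Y, sigma=0.1):
    # Gaussian regression log-likelihood, with a linear model standing in for h_theta
    resid = Y - X @ theta
    return -0.5 * np.sum(resid ** 2) / sigma ** 2 - 0.5 * len(Y) * np.log(2 * np.pi * sigma ** 2)

def log_tempered_posterior(theta, X, Y, tau=0.5):
    # Unnormalized log density of p_tau(theta | X, Y) ∝ p(theta) * p(Y|X, theta)^tau.
    # tau = 1 recovers the standard Bayesian posterior; tau < 1 down-weights the likelihood.
    return log_prior(theta) + tau * log_lik(theta, X, Y)
```

Because τ simply scales the log-likelihood, any sampler or variational method that consumes an unnormalized log density can target the tempered posterior without modification.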
