META-LEARNING BAYESIAN NEURAL NETWORK PRIORS BASED ON PAC-BAYESIAN THEORY

Abstract

Bayesian deep learning is a promising approach towards improved uncertainty quantification and sample efficiency. Due to their complex parameter space, choosing informative priors for Bayesian Neural Networks (BNNs) is challenging. Thus, often a naive, zero-centered Gaussian is used, resulting in both poor generalization and poorly calibrated uncertainty estimates when training data is scarce. In contrast, meta-learning aims to extract such prior knowledge from a set of related learning tasks. We propose a principled and scalable algorithm for meta-learning BNN priors based on PAC-Bayesian bounds. Whereas previous approaches require optimizing the prior and multiple variational posteriors in an interdependent manner, our method does not rely on difficult nested optimization problems and, moreover, is agnostic to the variational inference method in use. Our experiments show that the proposed method is not only computationally more efficient but also yields better predictions and uncertainty estimates when compared to previous meta-learning methods and BNNs with standard priors.

1. INTRODUCTION

Bayesian Neural Networks (BNNs) offer a probabilistic interpretation of deep learning by inferring distributions over the model's weights (Neal, 1996). With the potential of combining the scalability and performance of neural networks (NNs) with a principled framework for uncertainty quantification, BNNs have lately received increased attention (Blundell et al., 2015; Gal & Ghahramani, 2016). In particular, their ability to express epistemic uncertainty makes them highly relevant for applications such as active learning (Hernández-Lobato & Adams, 2015) and reinforcement learning (Riquelme et al., 2018). However, BNNs face two major issues: 1) the intractability of posterior inference and 2) the difficulty of choosing good Bayesian priors. While the former has been addressed in an extensive body of literature on variational inference (e.g., Blundell et al., 2015; Blei et al., 2016; Mishkin et al., 2018; Liu & Wang, 2016), the latter has received only limited attention (Vladimirova et al., 2019; Ghosh & Doshi-Velez, 2017). Choosing an informative prior for BNNs is particularly difficult due to the high-dimensional and hardly interpretable parameter space of NNs. In the absence of good alternatives, a zero-centered, isotropic Gaussian is often used, reflecting (almost) no a priori knowledge about the problem at hand. This not only leads to poor generalization when data is scarce, but also renders the Bayesian uncertainty estimates poorly calibrated (Kuleshov et al., 2018).

Meta-learning (Schmidhuber, 1987; Thrun & Pratt, 1998) acquires inductive bias in a data-driven way and thus constitutes an alternative route for addressing this issue. In particular, meta-learners attempt to extract shared (prior) knowledge from a set of related learning tasks (i.e., datasets), aiming to accelerate learning on a new, related task. Our work develops a principled and scalable algorithm for meta-learning BNN priors.
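The effect of the prior under scarce data can be made concrete in a conjugate Bayesian linear model, where the posterior is available in closed form. The following is a minimal sketch (the function name and the specific prior values are ours, purely for illustration; it is not the paper's model): a naive zero-centered isotropic Gaussian prior is contrasted with an informative prior of the kind meta-learning aims to provide.

```python
import numpy as np

def bayes_linreg_posterior(X, y, prior_mean, prior_var, noise_var=0.1):
    """Closed-form Gaussian posterior N(mu_n, Sigma_n) over linear weights."""
    prior_prec = np.eye(X.shape[1]) / prior_var
    Sigma_n = np.linalg.inv(prior_prec + X.T @ X / noise_var)
    mu_n = Sigma_n @ (prior_prec @ prior_mean + X.T @ y / noise_var)
    return mu_n, Sigma_n

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])      # weights shared by a family of related tasks
X = rng.normal(size=(3, 2))         # only 3 observations: data is scarce
y = X @ true_w + 0.1 * rng.normal(size=3)

# Naive zero-centered isotropic Gaussian prior ...
mu_naive, _ = bayes_linreg_posterior(X, y, np.zeros(2), prior_var=1.0)
# ... versus an informative prior, e.g. distilled from related tasks
mu_info, _ = bayes_linreg_posterior(X, y, np.array([1.8, -0.9]), prior_var=0.05)

print("error, zero-centered prior:", np.linalg.norm(mu_naive - true_w))
print("error, informative prior:  ", np.linalg.norm(mu_info - true_w))
```

With so few observations, the zero-centered prior shrinks the posterior mean towards the origin, while the informative prior keeps it near the shared solution; for BNNs the same effect plays out in a far higher-dimensional weight space.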
We build on the PAC-Bayesian framework (McAllester, 1999), a methodology from statistical learning theory for deriving generalization bounds. Previous PAC-Bayesian bounds for meta-learners (Pentina & Lampert, 2014; Amit & Meir, 2018) require solving a difficult optimization problem, involving the optimization of the prior as well as multiple variational posteriors in a nested manner. To overcome this issue, we present a PAC-Bayesian bound that does not rely on nested optimization and, unlike Rothfuss et al. (2020), can be tractably optimized for BNNs. This makes the resulting meta-learner, referred to as PACOH-NN, not only much more computationally efficient and scalable than previous approaches for meta-learning BNN priors (Amit & Meir, 2018), but also agnostic to the choice of approximate posterior inference method, which allows us to combine it freely with recent advances in MCMC (e.g., Chen et al., 2014) or variational inference (e.g., Wang et al., 2019).
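To see why the nested structure of earlier bounds is costly, consider a deliberately simplified toy in which the prior and the per-task posteriors are plain scalars and the per-task objective is a squared-error stand-in with a quadratic complexity term; the names and losses are illustrative only, not the actual objectives of Amit & Meir (2018) or PACOH-NN:

```python
import numpy as np

# Nested structure of earlier PAC-Bayesian meta-learners (schematic):
# an outer loop over the prior, and an inner optimization of one
# "posterior" per task that must be re-run after every prior update.

def inner_loss(post_mean, prior_mean, task_data):
    # data-fit term + quadratic complexity term pulling towards the prior
    return np.mean((task_data - post_mean) ** 2) + 0.1 * (post_mean - prior_mean) ** 2

tasks = [np.array([1.0, 1.2]), np.array([0.8, 1.1]), np.array([1.3, 0.9])]
prior_mean, lr = 0.0, 0.1

for _ in range(200):                       # outer loop: prior update
    post_means = []
    for data in tasks:                     # inner loops: per-task "posteriors"
        m = prior_mean
        for _ in range(20):                # nested optimization per task
            grad = -2 * np.mean(data - m) + 0.2 * (m - prior_mean)
            m -= lr * grad
        post_means.append(m)
    # gradient step on the prior, holding the task solutions fixed
    prior_mean -= lr * sum(0.2 * (prior_mean - m) for m in post_means)

print(round(prior_mean, 2))  # converges near the average task solution
```

Every outer step on the prior triggers a full inner re-optimization for each task, so cost scales multiplicatively in the number of tasks and inner steps; avoiding this interdependence is precisely what makes a bound without nested optimization attractive.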

