MARS: META-LEARNING AS SCORE MATCHING IN THE FUNCTION SPACE

Abstract

Meta-learning aims to extract useful inductive biases from a set of related datasets. In Bayesian meta-learning, this is typically achieved by constructing a prior distribution over neural network parameters. However, specifying families of computationally viable prior distributions over the high-dimensional neural network parameters is difficult. As a result, existing approaches resort to meta-learning restrictive diagonal Gaussian priors, severely limiting their expressiveness and performance. To circumvent these issues, we approach meta-learning through the lens of functional Bayesian neural network inference, which views the prior as a stochastic process and performs inference in the function space. Specifically, we view the meta-training tasks as samples from the data-generating process and formalize meta-learning as empirically estimating the law of this stochastic process. Our approach can seamlessly acquire and represent complex prior knowledge by meta-learning the score function of the data-generating process marginals instead of parameter space priors. In a comprehensive benchmark, we demonstrate that our method achieves state-of-the-art performance in terms of predictive accuracy and substantial improvements in the quality of uncertainty estimates.

1. INTRODUCTION

Using data from related tasks is of key importance for sample efficiency. Meta-learning attempts to extract prior knowledge (i.e., inductive bias) about the unknown data-generating process from these related tasks and embed it into the learner so that it generalizes better to new learning tasks (Thrun & Pratt, 1998; Vanschoren, 2018). Many meta-learning approaches try to amortize or re-learn the entire inference process (e.g., Santoro et al., 2016; Mishra et al., 2018; Garnelo et al., 2018) or significant parts of it (e.g., Finn et al., 2017; Yoon et al., 2018). As a result, they require large amounts of meta-training data and are prone to meta-overfitting (Qin et al., 2018; Rothfuss et al., 2021a).

The Bayesian framework provides a sound and statistically optimal method for inference by combining prior knowledge about the data-generating process with new empirical evidence in the form of a dataset. In this work, we adopt the Bayesian framework for inference at the task level and focus only on meta-learning informative Bayesian priors. Previous approaches (Amit & Meir, 2018; Rothfuss et al., 2021a) meta-learn Bayesian Neural Network (BNN) prior distributions from a set of related datasets; by meta-learning the prior distribution and applying regularization at the meta-level, they facilitate positive transfer from only a handful of meta-training tasks. However, BNNs lack a parametric family of (meta-)learnable priors over the high-dimensional space of neural network (NN) parameters that is both computationally viable and, at the same time, flexible enough to account for the over-parametrization of NNs. In practice, both approaches use a Gaussian family of priors with a diagonal covariance matrix, which is too restrictive to accurately match the complex probabilistic structure of the data-generating process. To address these shortcomings, we take a new approach to formulating the meta-learning problem and represent prior knowledge in a novel way.
We build on recent advances in functional approximate inference for BNNs that perform Bayesian inference in the function space rather than in the parameter space of neural networks (Wang et al., 2018; Sun et al., 2019). When viewing the BNN prior and posterior as stochastic processes, the perfect Bayesian prior is the (true) data-generating stochastic process itself. Hence, we view the meta-training datasets as samples from the meta-data-generating process and interpret meta-learning as empirically estimating the law of this stochastic process. More specifically, we meta-learn the score function of its marginal distributions, which can then directly be used as a source of prior knowledge when performing approximate functional BNN inference on a new target task. This ultimately allows us to use flexible neural network models for learning the score and to overcome the issues of meta-learning BNN priors in the parameter space.

In our experiments, we demonstrate that our proposed approach, called Meta-learning via Attention-based Regularised Score estimation (MARS), consistently outperforms previous meta-learners in predictive accuracy and yields significant improvements in the quality of uncertainty estimates. Notably, MARS enables positive transfer from only a handful of tasks while maintaining reliable uncertainty estimates. This promises fruitful future applications to domains like molecular biology or medicine, where meta-training data is scarce and reasoning about epistemic uncertainty is crucial.
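As a toy illustration of this view (not the MARS method itself), the sketch below treats each meta-training task as a draw from a hypothetical sinusoid data-generating process, collects its function values at a shared measurement set, and fits a naive full-covariance Gaussian to these marginal samples. The score of that Gaussian is a simple stand-in for the marginal score function that MARS instead meta-learns with a neural network; the process, names, and sizes are all assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data-generating process: each task draws a random
# sinusoid f_i(x) = a_i * sin(x + b_i); all tasks share measurement points.
n_tasks, n_points = 50, 8
X = np.linspace(-2.0, 2.0, n_points)             # shared measurement set
A = rng.uniform(0.5, 1.5, size=(n_tasks, 1))     # per-task amplitude
B = rng.uniform(-0.5, 0.5, size=(n_tasks, 1))    # per-task phase
F = A * np.sin(X[None, :] + B)                   # (n_tasks, n_points): marginal samples

# Naive parametric estimate of the marginal score: fit a full-covariance
# Gaussian to the function values; its score is -Sigma^{-1} (f - mu).
mu = F.mean(axis=0)
Sigma = np.cov(F, rowvar=False) + 1e-3 * np.eye(n_points)  # jitter for invertibility
Sigma_inv = np.linalg.inv(Sigma)

def gaussian_score(f):
    """Score grad_f log N(f | mu, Sigma) of the fitted Gaussian marginal."""
    return -Sigma_inv @ (f - mu)

# The score points from a function-value vector toward higher marginal density.
print(gaussian_score(F[0]).shape)  # (8,)
```

A Gaussian fit like this corresponds exactly to the restrictive GP-style assumption discussed below; replacing `gaussian_score` with a learned score network is what lets MARS match non-Gaussian marginals.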

2. RELATED WORK

Meta-Learning. Common approaches in meta-learning amortize the entire inference process (Santoro et al., 2016; Mishra et al., 2018; Ravi & Beatson, 2018; Garnelo et al., 2018), learn a good neural network initialization (Finn et al., 2017; Rothfuss et al., 2019; Nichol et al., 2018; Kim et al., 2018), or learn a shared embedding space (Baxter, 2000; Vinyals et al., 2016; Snell et al., 2017). Although these approaches can meta-learn complex inference patterns, they require a large amount of meta-training data and often perform poorly in settings where data is scarce. Another line of work uses a hierarchical Bayesian approach to meta-learn priors over the NN parameters (Pentina & Lampert, 2014; Amit & Meir, 2018; Rothfuss et al., 2021a). Such methods perform much better on small data. However, they suffer from the lack of expressive families of priors for the high-dimensional and complex parameter space of NNs, making too restrictive assumptions to represent complex inductive biases. Our approach overcomes these issues by viewing the problem in the function space and directly learning the score, which can easily be represented by a NN, instead of a prior distribution. Also related to our stochastic process approach are methods that meta-learn Gaussian Process (GP) priors (Fortuin et al., 2019; Rothfuss et al., 2021b; 2022). However, the GP assumption is quite limiting, while MARS can, in principle, match the marginals of any data-generating process.

Score estimation. We use score estimation as a central element of our meta-learning method. In particular, we use a parametric approach to score matching and employ an extended version of the score matching objective of Hyvärinen & Dayan (2005). For high-dimensional problems, Song et al. (2020) and Pang et al. (2020) propose randomly sliced variations of the score matching loss.
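To make the parametric score-matching objective concrete, here is a minimal sketch on toy Gaussian data. The linear score model and all variable names are illustrative assumptions (MARS itself uses a neural-network score model and an extended objective); the linear model is chosen because, for it, the Hyvärinen loss E[tr(∇_x s(x)) + ½‖s(x)‖²] has closed-form gradients, so no automatic differentiation is needed.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy target: samples from N(mu, Sigma); the true score is
# s*(x) = -Sigma^{-1} (x - mu).
d, n = 3, 5000
mu = np.array([1.0, -0.5, 0.3])
L = np.array([[1.0, 0.0, 0.0],
              [0.4, 0.8, 0.0],
              [-0.2, 0.1, 0.6]])
X = rng.standard_normal((n, d)) @ L.T + mu

# Hypothetical linear score model s(x) = W x + b. For this model the
# Hyvärinen loss reduces to E[ tr(W) + 0.5 * ||W x + b||^2 ].
W = np.zeros((d, d))
b = np.zeros(d)
lr = 0.1
for _ in range(2000):
    S = X @ W.T + b                    # model scores at all samples, (n, d)
    grad_W = np.eye(d) + S.T @ X / n   # d/dW of tr(W) + 0.5 * E||s(x)||^2
    grad_b = S.mean(axis=0)            # d/db of 0.5 * E||s(x)||^2
    W -= lr * grad_W
    b -= lr * grad_b

# At the optimum, W recovers minus the inverse empirical covariance,
# i.e., the score of the best Gaussian fit to the samples.
C_hat = np.cov(X, rowvar=False, bias=True)
print(np.abs(W + np.linalg.inv(C_hat)).max())  # close to 0
```

With a neural-network score model, the trace term is what the sliced variants of Song et al. (2020) approximate via random projections to avoid the cost of full Jacobian traces.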
Alternatively, there is a body of work on nonparametric score estimation (Canu & Smola, 2006; Liu et al., 2016; Shi et al., 2018; Engl et al., 1996; Zhou et al., 2020). Among those, the Spectral Stein Gradient Estimator (Shi et al., 2018) has been used to estimate the stochastic process marginals for functional BNN inference in a setting where the stochastic prior is user-defined and allows for generating arbitrarily many samples (Sun et al., 2019). Such estimators make it much harder to add explicit dependence on the measurement sets and to prevent meta-overfitting via regularization, making them less suited to our problem setting.

3. BACKGROUND

Bayesian Neural Networks. Consider a regression task with data D = (X_D, y_D) that consists of m i.i.d. noisy function evaluations y_j = f(x_j) + ϵ_j of an unknown function f : X → Y. Here, X_D = {x_j}_{j=1}^m ∈ X^m denotes the training inputs and y_D = {y_j}_{j=1}^m ∈ Y^m the corresponding noisy function values. Let h_θ : X → Y be a function parametrized by a NN with weights θ ∈ Θ. For regression, where Y ⊆ R, we can use h_θ to define a conditional distribution over the noisy observations p(y|x, θ) = N(y | h_θ(x), σ²), with flexible mean represented by h_θ and observation noise variance σ². Given a prior distribution p(θ) over the model parameters θ, Bayes' theorem yields the posterior distribution p(θ|X_D, y_D) ∝ p(y_D|X_D, θ) p(θ), where p(y_D|X_D, θ) = ∏_{j=1}^m p(y_j|x_j, θ). For an unseen test point x*, the predictive distribution p(y*|x*, X_D, y_D) = ∫ p(y*|x*, θ) p(θ|X_D, y_D) dθ is obtained by marginalizing out the posterior over the parameters θ. Posterior inference for BNNs is difficult due to the high-dimensional parameter space Θ and the over-parameterized nature of the NN mapping h_θ(x).

BNN inference in the function space.
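As a sketch of the function-space view: instead of working with the distribution p(θ) over weights, one looks at the distribution it induces over the function values f = (h_θ(x_1), ..., h_θ(x_k)) at a finite measurement set, i.e., the finite-dimensional marginals of the stochastic process h. The snippet below illustrates this for a diagonal-Gaussian weight prior; the network architecture, prior scales, and names are illustrative assumptions, not the paper's configuration.

```python
import numpy as np

rng = np.random.default_rng(3)

def h(theta, X):
    """One-hidden-layer network h_theta evaluated at all inputs in X."""
    W1, b1, w2, b2 = theta
    return np.tanh(X @ W1 + b1) @ w2 + b2        # shape (k, 1)

# A diagonal-Gaussian prior p(theta) over the NN weights, sampled directly;
# layer sizes and scales are illustrative assumptions.
def sample_prior():
    return (rng.normal(size=(1, 32)), rng.normal(size=32),
            rng.normal(scale=0.3, size=(32, 1)), rng.normal(size=1))

# Push parameter samples through the network at a measurement set
# X = (x_1, ..., x_k); each row of F is one draw from the k-dimensional
# marginal of the stochastic process induced by the BNN prior.
k, n_samples = 8, 500
X = np.linspace(-2.0, 2.0, k).reshape(-1, 1)
F = np.stack([h(sample_prior(), X).ravel() for _ in range(n_samples)])

print(F.shape)  # (500, 8): draws from the process marginal at X
```

Functional inference methods compare such induced marginals to those of a prior process at the measurement set, rather than comparing distributions in the weight space Θ.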

