LATENT BOTTLENECKED ATTENTIVE NEURAL PROCESSES

Abstract

Neural Processes (NPs) are popular methods in meta-learning that can estimate predictive uncertainty on target datapoints by conditioning on a context dataset. The previous state-of-the-art method, Transformer Neural Processes (TNPs), achieves strong performance but requires quadratic computation with respect to the number of context datapoints, significantly limiting its scalability. Conversely, existing sub-quadratic NP variants perform significantly worse than TNPs. Tackling this issue, we propose Latent Bottlenecked Attentive Neural Processes (LBANPs), a new computationally efficient sub-quadratic NP variant whose querying computational complexity is independent of the number of context datapoints. The model encodes the context dataset into a constant number of latent vectors on which self-attention is performed. When making predictions, the model retrieves higher-order information from the context dataset via multiple cross-attention mechanisms on the latent vectors. We empirically show that LBANPs achieve results competitive with the state-of-the-art on meta-regression, image completion, and contextual multi-armed bandits. We demonstrate that LBANPs can trade off computational cost and performance according to the number of latent vectors. Finally, we show that LBANPs can scale beyond existing attention-based NP variants to larger dataset settings.

1. INTRODUCTION

Meta-learning aims to learn a model that can adapt quickly and computationally efficiently to new tasks. Neural Processes (NPs) are a popular class of meta-learning methods that model the conditional distribution of a target datapoint's prediction given a set of labelled (context) datapoints, providing uncertainty estimates. NP variants (Garnelo et al., 2018a; Gordon et al., 2019; Kim et al., 2019) adapt via a conditioning step in which they compute embeddings representative of the context dataset. NPs can be divided into two categories: (1) computationally efficient (sub-quadratic complexity) but poorly performing, and (2) computationally expensive (quadratic complexity) but well performing. Early NP variants were especially computationally efficient, requiring only linear computation in the number of context datapoints, but suffered from underfitting and, as a result, overall poor performance. In contrast, recent state-of-the-art methods use self-attention mechanisms such as transformers. However, these methods are computationally expensive in that they require quadratic computation in the number of context datapoints, making them inapplicable in settings with a large number of datapoints and in low-resource settings. ConvCNPs (Gordon et al., 2019) partly address this problem by using convolutional neural networks to encode the context dataset instead of a self-attention mechanism, but (1) this requires the data to have a grid-like structure, limiting the method to low-dimensional settings, and (2) the recent attention-based method Transformer Neural Processes (TNPs) has been shown to greatly outperform ConvCNPs. Inspired by recent developments in efficient attention mechanisms, (1) we propose Latent Bottlenecked Attentive Neural Processes (LBANPs), a computationally efficient NP variant whose querying computational complexity is independent of the number of context datapoints.
Furthermore, the model requires only linear computational overhead for conditioning on the context dataset. The model encodes the context dataset into a fixed number of latent vectors on which self-attention is performed. When making predictions, the model retrieves higher-order information from the context dataset via multiple cross-attention mechanisms on the latent vectors.

Method                        | Total        | Conditioning | Querying
CANP (Kim et al., 2019)       | N^2 + NM     | N^2          | NM
NP (Garnelo et al., 2018b)    | N + M        | N            | M
ANP (Kim et al., 2019)        | N^2 + NM     | N^2          | NM
BNP (Lee et al., 2020)        | (N + M)K     | KN           | KM
BANP (Lee et al., 2020)       | (N^2 + NM)K  | KN^2         | KNM
TNP-D (Nguyen & Grover, 2022) | (N + M)^2    | -            | (N + M)^2
EQTNP (Ours)                  | N^2 + NM     | N^2          | NM
LBANP (Ours)                  | (N + M + L)L | NL + L^2     | ML

(2) We empirically show that LBANPs achieve results competitive with the state-of-the-art on meta-regression, image completion, and contextual multi-armed bandits. (3) We demonstrate that LBANPs can trade off computational cost and performance according to the number of latent vectors. (4) In addition, we show that LBANPs can scale to larger dataset settings where existing attention-based NP variants fail to run because of their expensive computational requirements. (5) Lastly, we show that, similarly to TNPs, different variants of LBANPs can be proposed for different settings.
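The conditioning/querying split described above can be sketched with plain single-head attention in NumPy. This is a minimal illustration, not the paper's implementation: there are no learned projections or training, the "latents" are random stand-ins for learned latent vectors, and all shapes are illustrative. The point is that the querying step touches only the L latent vectors, never the N context datapoints.

```python
import numpy as np

def attention(q, kv):
    """Single-head scaled dot-product attention (keys = values, no projections)."""
    scores = q @ kv.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ kv

rng = np.random.default_rng(0)
N, M, L, d = 512, 4, 8, 16            # context, target, latent counts; embed dim
context = rng.normal(size=(N, d))     # embedded context datapoints
latents = rng.normal(size=(L, d))     # latent vectors (learned in the real model)
targets = rng.normal(size=(M, d))     # embedded target queries

# Conditioning: O(NL) cross-attention into the bottleneck, then O(L^2) self-attention.
latents = attention(latents, context)
latents = attention(latents, latents)

# Querying: O(ML) cross-attention on the L latents -- independent of N.
preds = attention(targets, latents)
```

Because the context dataset is compressed into L vectors during conditioning, predictions for new targets never revisit the N context datapoints, which is what yields the ML querying cost in the table above.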

2. BACKGROUND

2.1. META-LEARNING FOR PREDICTIVE UNCERTAINTY ESTIMATION

In meta-learning, models are trained on a distribution of tasks $\Omega(\mathcal{T})$. During each meta-training iteration, a batch of $B$ tasks $\mathbf{T} = \{\mathcal{T}_i\}_{i=1}^{B}$ is sampled from the task distribution $\Omega(\mathcal{T})$. A task $\mathcal{T}_i$ is a tuple $(\mathcal{X}, \mathcal{Y}, \mathcal{L}, q)$, where $\mathcal{X}$ is the input space, $\mathcal{Y}$ is the output space, $\mathcal{L}$ is the task-specific loss function, and $q(x, y)$ is a distribution over datapoints. For each $\mathcal{T}_i \in \mathbf{T}$, we sample from $q_{\mathcal{T}_i}$ a context dataset $\mathcal{D}_i^{\mathrm{context}} = \{(x, y)_{i,j}\}_{j=1}^{N}$ and a target dataset $\mathcal{D}_i^{\mathrm{target}} = \{(x, y)_{i,j}\}_{j=1}^{M}$, where $N$ and $M$ are the fixed numbers of context and target datapoints, respectively. The context data is used to perform an update to the model $f$; the type of update differs depending on the model. In Neural Processes, this update is the conditioning step, in which embeddings of the context dataset are computed. Afterwards, the updated model is evaluated on the target data and the update rule is adjusted. In this setting, we use a neural network to model the probabilistic predictive distribution $p_\theta(y \mid x, \mathcal{D}^{\mathrm{context}})$, where $\theta$ are the parameters of the network.
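The sampling step above can be sketched for a toy task distribution, with random sinusoids standing in for $\Omega(\mathcal{T})$. The function names and the choice of sinusoids are illustrative assumptions, not part of the paper's setup.

```python
import numpy as np

def sample_task(rng):
    """Sample a toy task T_i: a random sinusoid defining q(x, y)."""
    amplitude = rng.uniform(0.5, 2.0)
    phase = rng.uniform(0.0, np.pi)
    return lambda x: amplitude * np.sin(x + phase)

def sample_datasets(rng, n_context=10, m_target=15):
    """Draw a context set of size N and a target set of size M from q_{T_i}."""
    f = sample_task(rng)
    x = rng.uniform(-2.0, 2.0, size=(n_context + m_target, 1))
    y = f(x)
    context = (x[:n_context], y[:n_context])
    target = (x[n_context:], y[n_context:])
    return context, target

rng = np.random.default_rng(0)
(ctx_x, ctx_y), (tgt_x, tgt_y) = sample_datasets(rng)
```

A meta-training iteration would repeat this B times (once per task in the batch), condition the model on each context set, and score its predictions on the corresponding target set.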

2.2. NEURAL PROCESSES

Neural Processes (NPs) are a class of models that define an infinite family of conditional distributions, where one can condition on an arbitrary number of context datapoints (labelled datapoints) and make predictions for an arbitrary number of target datapoints (unlabelled datapoints), while preserving invariance to the ordering of the context. Several NP models have proposed to model this distribution as follows:

$p(y \mid x, \mathcal{D}^{\mathrm{context}}) := p(y \mid x, r_C)$ (1)

where $r_C$ is an embedding summarising the context dataset $\mathcal{D}^{\mathrm{context}}$, computed during the conditioning step.
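A minimal sketch of this order-invariant conditioning, assuming a CNP-style mean-pooled embedding for $r_C$; the fixed random linear encoder here is a stand-in for a learned one, and all names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy "encoder": a fixed random linear map from (x, y) pairs to embeddings.
W_enc = rng.normal(size=(2, 16))

def encode_context(ctx_x, ctx_y):
    """r_C: mean-pool per-datapoint embeddings, so ordering cannot matter."""
    pairs = np.concatenate([ctx_x, ctx_y], axis=-1)   # (N, 2)
    return np.tanh(pairs @ W_enc).mean(axis=0)        # (16,)

ctx_x = rng.normal(size=(8, 1))
ctx_y = np.sin(ctx_x)
r_c = encode_context(ctx_x, ctx_y)

# Permuting the context leaves r_C unchanged: invariance to context ordering.
perm = rng.permutation(8)
r_c_perm = encode_context(ctx_x[perm], ctx_y[perm])
assert np.allclose(r_c, r_c_perm)
```

Mean pooling is just one permutation-invariant aggregator; attention-based NP variants replace it with richer mechanisms while preserving the same invariance.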



Table 1: Computational complexity in Big-O notation of each model with respect to the number of context datapoints (N) and the number of target datapoints per batch (M). L and K are prespecified hyperparameters: L is the number of latent vectors, and K is the number of bootstrapping samples for BNP and BANP. It is important that the cost of performing the querying step is low.
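The querying columns of the table can be compared numerically. This is a toy sketch only: constant factors are omitted, and the particular values of N, M, and L below are illustrative assumptions.

```python
# Illustrative querying costs from the complexity table (constant factors omitted).
def tnp_query_cost(n, m):
    """TNP-D queries attend over all N + M points jointly: O((N + M)^2)."""
    return (n + m) ** 2

def lbanp_query_cost(m, n_latents):
    """LBANP queries cross-attend only to the L latents: O(ML), independent of N."""
    return m * n_latents

# With N = 10,000 context points, M = 100 targets, L = 128 latents:
assert tnp_query_cost(10_000, 100) == 102_010_000
assert lbanp_query_cost(100, 128) == 12_800
```

Under these (illustrative) sizes the gap is roughly four orders of magnitude, and it grows with N for TNP-D while staying constant for LBANP.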

