SEMI-PARAMETRIC INDUCING POINT NETWORKS AND NEURAL PROCESSES

Abstract

We introduce semi-parametric inducing point networks (SPIN), a general-purpose architecture that can query the training set at inference time in a compute-efficient manner. Semi-parametric architectures are typically more compact than parametric models, but their computational complexity is often quadratic. In contrast, SPIN attains linear complexity via a cross-attention mechanism between datapoints inspired by inducing point methods. Querying large training sets can be particularly useful in meta-learning, as it unlocks additional training signal, but often exceeds the scaling limits of existing models. We use SPIN as the basis of the Inducing Point Neural Process, a probabilistic model which supports large contexts in meta-learning and achieves high accuracy where existing models fail. In our experiments, SPIN reduces memory requirements, improves accuracy across a range of meta-learning tasks, and improves upon the state of the art on an important practical problem, genotype imputation.

1. INTRODUCTION

Recent advances in deep learning have been driven by large-scale parametric models (Krizhevsky et al., 2012; Peters et al., 2018; Devlin et al., 2019; Brown et al., 2020; Ramesh et al., 2022). Modern parametric models rely on large numbers of weights to capture the signal contained in the training set and to facilitate generalization (Frankle & Carbin, 2018; Kaplan et al., 2020); as a result, they require non-trivial computational resources (Hoffmann et al., 2022), have limited interpretability (Belinkov, 2022), and impose a significant carbon footprint (Bender et al., 2021).

This paper focuses on an alternative semi-parametric approach, in which we have access to the training set $\mathcal{D}_{\text{train}} = \{x^{(i)}, y^{(i)}\}_{i=1}^{n}$ at inference time and learn a parametric mapping $y = f_\theta(x \mid \mathcal{D}_{\text{train}})$ conditioned on this dataset. Semi-parametric models can query the training set $\mathcal{D}_{\text{train}}$ and can therefore express rich and interpretable mappings with a compact $f_\theta$. Examples of the semi-parametric framework include retrieval-augmented language models (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021) and non-parametric transformers (Wiseman & Stratos, 2019; Kossen et al., 2021). However, existing approaches are often specialized to specific tasks (e.g., language modeling (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021) or sequence generation (Graves et al., 2014)), and their computation scales superlinearly in the size of the training set (Kossen et al., 2021).

Here, we introduce semi-parametric inducing point networks (SPIN), a general-purpose architecture whose computational complexity at training time scales linearly in the size of the training set $\mathcal{D}_{\text{train}}$ and in the dimensionality of $x$, and which is constant in the size of $\mathcal{D}_{\text{train}}$ at inference time. Our architecture is inspired by inducing point approximations (Snelson & Ghahramani, 2005; Titsias, 2009; Wilson & Nickisch, 2015; Evans & Nair, 2018; Lee et al., 2018) and relies on a cross-attention mechanism between datapoints (Kossen et al., 2021). An important application of SPIN is meta-learning, where conditioning on large training sets provides the model with additional signal and improves accuracy, but often exceeds the scaling limits of existing models.
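To make the mechanism concrete, below is a minimal PyTorch sketch of cross-attention from a small set of learned inducing points to the datapoint embeddings; the class name `InducingCrossAttention` and all hyperparameters are illustrative assumptions, not the authors' implementation. Each of the $m$ inducing points attends over all $n$ datapoints, so the cost per layer is $O(nm)$ rather than the $O(n^2)$ of full self-attention between datapoints.

```python
import torch
import torch.nn as nn

class InducingCrossAttention(nn.Module):
    """Illustrative sketch: cross-attention from m learned inducing points
    to n datapoint embeddings. Cost is O(n * m) per layer, versus O(n^2)
    for full self-attention between datapoints."""

    def __init__(self, dim: int, num_inducing: int, num_heads: int = 4):
        super().__init__()
        # The m inducing points are learned parameters that act as queries.
        self.inducing = nn.Parameter(torch.randn(num_inducing, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        # data: (batch, n, dim) embeddings of the n training datapoints.
        batch = data.shape[0]
        queries = self.inducing.unsqueeze(0).expand(batch, -1, -1)  # (batch, m, dim)
        # Each inducing point attends over all n datapoints, producing a
        # fixed-size, dataset-conditioned summary.
        summary, _ = self.attn(queries, data, data)
        return summary  # (batch, m, dim)


# Usage: summarize 10,000 datapoints into 16 inducing-point embeddings.
layer = InducingCrossAttention(dim=64, num_inducing=16)
summary = layer(torch.randn(1, 10_000, 64))  # -> shape (1, 16, 64)
```

Because the summary has a fixed size $m$ independent of $n$, it can be computed once from $\mathcal{D}_{\text{train}}$ and cached, which is what makes inference cost constant in the size of the training set.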

