SEMI-PARAMETRIC INDUCING POINT NETWORKS AND NEURAL PROCESSES

Abstract

We introduce semi-parametric inducing point networks (SPIN), a general-purpose architecture that can query the training set at inference time in a compute-efficient manner. Semi-parametric architectures are typically more compact than parametric models, but their computational complexity is often quadratic. In contrast, SPIN attains linear complexity via a cross-attention mechanism between datapoints inspired by inducing point methods. Querying large training sets can be particularly useful in meta-learning, as it unlocks additional training signal, but it often exceeds the scaling limits of existing models. We use SPIN as the basis of the Inducing Point Neural Process, a probabilistic model that supports large contexts in meta-learning and achieves high accuracy where existing models fail. In our experiments, SPIN reduces memory requirements, improves accuracy across a range of meta-learning tasks, and improves state-of-the-art performance on an important practical problem, genotype imputation.

1. INTRODUCTION

Recent advances in deep learning have been driven by large-scale parametric models (Krizhevsky et al., 2012; Peters et al., 2018; Devlin et al., 2019; Brown et al., 2020; Ramesh et al., 2022). Modern parametric models rely on large numbers of weights to capture the signal contained in the training set and to facilitate generalization (Frankle & Carbin, 2018; Kaplan et al., 2020); as a result, they require non-trivial computational resources (Hoffmann et al., 2022), have limited interpretability (Belinkov, 2022), and impose a significant carbon footprint (Bender et al., 2021).

This paper focuses on an alternative semi-parametric approach, in which we have access to the training set D_train = {(x^(i), y^(i))}_{i=1}^n at inference time and learn a parametric mapping y = f_θ(x | D_train) conditioned on this dataset. Semi-parametric models can query the training set D_train and can therefore express rich and interpretable mappings with a compact f_θ. Examples of the semi-parametric framework include retrieval-augmented language models (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021) and non-parametric transformers (Wiseman & Stratos, 2019; Kossen et al., 2021). However, existing approaches are often specialized to specific tasks, e.g., language modeling (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021) or sequence generation (Graves et al., 2014), and their computation scales superlinearly in the size of the training set (Kossen et al., 2021). Here, we introduce semi-parametric inducing point networks (SPIN), a general-purpose architecture whose computational complexity at training time scales linearly in the size of the training set D_train and in the dimensionality of x, and whose inference-time complexity is constant in the size of D_train.
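To make the semi-parametric mapping y = f_θ(x | D_train) concrete, the following is a minimal sketch (our own illustration, not the SPIN architecture) of a predictor that queries the full training set at inference time via attention; each batch of m queries costs O(m · n), so inference cost grows with the training-set size, which is exactly the scaling issue discussed above:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def semi_parametric_predict(x_query, X_train, Y_train, temperature=1.0):
    """Predict labels for x_query by attending over the entire training set.

    x_query: (m, d) query features; X_train: (n, d); Y_train: (n, d_y).
    Each query computes a similarity to all n training points, so the cost
    is O(m * n): inference depends on |D_train|.
    """
    scores = x_query @ X_train.T / temperature   # (m, n) similarities
    weights = softmax(scores, axis=-1)           # attention over D_train
    return weights @ Y_train                     # (m, d_y) weighted labels
```

A query near a training point receives a prediction dominated by that point's label, which is what makes such mappings compact and interpretable but expensive to evaluate.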
Our architecture is inspired by inducing point approximations (Snelson & Ghahramani, 2005; Titsias, 2009; Wilson & Nickisch, 2015; Evans & Nair, 2018; Lee et al., 2018) and relies on a cross-attention mechanism between datapoints (Kossen et al., 2021). An important application of SPIN is in meta-learning, where conditioning on large training sets provides the model with additional signal and improves accuracy, but poses challenges for methods that scale superlinearly in the size of D_train. We use SPIN as the basis of the Inducing Point Neural Process (IPNP), a scalable probabilistic model that supports accurate meta-learning with large context sizes that cause existing methods to fail.

We evaluate SPIN and IPNP on a range of supervised and meta-learning benchmarks and demonstrate the efficacy of SPIN on a real-world task in genomics: genotype imputation (Li et al., 2009). In meta-learning experiments, IPNP supports querying larger training sets, which yields high accuracy in settings where existing methods run out of memory. In the genomics setting, SPIN outperforms highly engineered state-of-the-art software packages widely used within commercial genomics pipelines (Browning et al., 2018b), indicating that our technique has the potential to impact real-world systems.

Contributions In summary, we introduce SPIN, a semi-parametric neural architecture inspired by inducing point methods that is the first to achieve the following characteristics:
1. Linear time and space complexity in the size and the dimension of the data during training.
2. The ability to learn a compact encoding of the training set for downstream applications; as a result, computational complexity at inference time does not depend on the training set size.
We use SPIN as the basis of the IPNP, a probabilistic model that enables meta-learning with context sizes larger than existing methods support and that achieves high accuracy on important real-world tasks such as genotype imputation.
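The inducing-point idea underlying SPIN can be sketched as follows: a small, fixed set of h learned inducing points cross-attends to the n datapoints, producing a compact summary at O(h · n) cost, i.e., linear in n, in contrast to the O(n²) of self-attention between datapoints. This is a simplified illustration under our own naming, not the full SPIN layer (which also interleaves attention across attributes):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def inducing_cross_attention(H, X):
    """Cross-attention from h inducing points to n datapoints.

    H: (h, d) learned inducing-point array, with h a fixed hyperparameter,
       h << n, so the summary size does not grow with the dataset.
    X: (n, d) embedded datapoints.
    The score matrix is (h, n), so cost is O(h * n): linear in n.
    """
    scores = H @ X.T / np.sqrt(H.shape[1])  # (h, n) scaled dot products
    A = softmax(scores, axis=-1)            # each inducing point attends to X
    return A @ X                            # (h, d) compact summary of X
```

Because the summary has fixed size (h, d), downstream computation, e.g., answering queries against the encoded training set, no longer depends on n, which is the property the contributions list above refers to.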

2. BACKGROUND

Parametric and Semi-Parametric Machine Learning Most supervised methods in deep learning are parametric. Formally, given a training set D_train = {(x^(i), y^(i))}_{i=1}^n with features x ∈ X and labels y ∈ Y, we seek to learn a fixed number of parameters θ ∈ Θ of a mapping y = f_θ(x) using supervised learning. In contrast, non-parametric approaches learn a mapping y = f_θ(x | D_train) that can query the training set D_train at inference time; when the mapping f_θ has parameters, the approach is called semi-parametric. Many deep learning algorithms, including memory-augmented architectures (Graves et al., 2014; Santoro et al., 2016), retrieval-based language models (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021), and non-parametric transformers (Kossen et al., 2021), are instances of this approach, but they are often specialized to specific tasks, and their computation scales superlinearly in n. This paper develops scalable and domain-agnostic semi-parametric methods.

Meta-Learning and Neural Processes An important application of semi-parametric methods is in meta-learning, where we train a model to achieve high performance on new tasks using only a small amount of data from these tasks. Formally, consider a collection of D datasets (or a meta-dataset) {D^(d)}_{d=1}^D, each defining a task. Each dataset D^(d) = (D_c^(d), D_t^(d)) contains a set of context points D_c^(d) = {(x_c^(di), y_c^(di))}_{i=1}^m and target points D_t^(d) = {(x_t^(di), y_t^(di))}_{i=1}^n. Meta-learning seeks to produce a model f(x; D_c) that outputs accurate predictions for y on D_t and on pairs (D_c, D_t) not seen at training time. Neural Process (NP) architectures perform uncertainty-aware meta-learning by mapping context sets to representations r_c(D_c), which can be combined with target inputs to provide a distribution on target labels y_t ~ p(y | x_t, r_c(D_c)), where p is a probabilistic model. Recent successes in NPs have been driven by attention-based architectures (Kim et al., 2018; Nguyen & Grover, 2022), whose complexity scales superlinearly with the context size |D_c|; our method yields linear complexity. In concurrent work, Feng et al. (2023) propose a linear-time method that uses cross-attention to reduce the size of context datasets.

A Motivating Application: Genotype Imputation A specific motivating example for developing efficient semi-parametric methods is the problem of genotype imputation. Consider the problem of determining the genomic sequence y ∈ {A, T, C, G}^k of an individual; rather than directly measuring y, it is common to use an inexpensive microarray device to measure a small subset of genomic positions x ∈ {A, T, C, G}^p, where p ≪ k. Genotype imputation is the task of determining y from x via statistical methods and a dataset D_train = {(x^(i), y^(i))}_{i=1}^n of fully-sequenced individuals (Li et al., 2009). Imputation is part of most standard genome analysis workflows. It is also a natural candidate for semi-parametric approaches (Li & Stephens, 2003): a query genome y can normally be represented as a combination of sequences y^(i) from D_train because of the biological principle of recombination.
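The NP interface described above can be illustrated with a schematic (our own minimal sketch, using a mean-pooled context encoder in the style of conditional NPs rather than the attention-based encoder used by IPNP): a context set of any size m is mapped to a fixed-length, permutation-invariant representation r_c, which is then combined with each target input to parameterize p(y | x_t, r_c):

```python
import numpy as np

rng = np.random.default_rng(0)

def encode_context(Xc, Yc, W):
    """Map a context set of any size m to a fixed-length representation r_c.

    Xc: (m, dx) context inputs; Yc: (m, dy) context labels; W: (dx+dy, r).
    Mean pooling makes r_c invariant to the ordering of context points.
    """
    pairs = np.concatenate([Xc, Yc], axis=-1)  # (m, dx + dy) joined pairs
    return np.tanh(pairs @ W).mean(axis=0)     # (r,) pooled representation

def predict_targets(Xt, r_c, V):
    """Combine each target input with r_c to parameterize p(y | x_t, r_c).

    Here p is a Gaussian whose mean and log-std are read off a linear head V.
    """
    m = Xt.shape[0]
    z = np.concatenate([Xt, np.tile(r_c, (m, 1))], axis=-1)  # (m, dx + r)
    out = z @ V                                              # (m, 2)
    mu, log_sigma = out[:, :1], out[:, 1:]
    return mu, np.exp(log_sigma)                             # mean, std > 0
```

In an attention-based NP, the pooled `r_c` would be replaced by a target-dependent representation whose cost scales superlinearly with m; replacing that encoder with the inducing-point summary is what gives IPNP linear complexity in the context size.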

