SEMI-PARAMETRIC INDUCING POINT NETWORKS AND NEURAL PROCESSES

Abstract

We introduce semi-parametric inducing point networks (SPIN), a general-purpose architecture that can query the training set at inference time in a compute-efficient manner. Semi-parametric architectures are typically more compact than parametric models, but their computational complexity is often quadratic. In contrast, SPIN attains linear complexity via a cross-attention mechanism between datapoints inspired by inducing point methods. Querying large training sets can be particularly useful in meta-learning, as it unlocks additional training signal, but often exceeds the scaling limits of existing models. We use SPIN as the basis of the Inducing Point Neural Process, a probabilistic model which supports large contexts in meta-learning and achieves high accuracy where existing models fail. In our experiments, SPIN reduces memory requirements, improves accuracy across a range of meta-learning tasks, and improves state-of-the-art performance on an important practical problem, genotype imputation.

1. INTRODUCTION

Recent advances in deep learning have been driven by large-scale parametric models (Krizhevsky et al., 2012; Peters et al., 2018; Devlin et al., 2019; Brown et al., 2020; Ramesh et al., 2022). Modern parametric models rely on large numbers of weights to capture the signal contained in the training set and to facilitate generalization (Frankle & Carbin, 2018; Kaplan et al., 2020); as a result, they require non-trivial computational resources (Hoffmann et al., 2022), have limited interpretability (Belinkov, 2022), and impose a significant carbon footprint (Bender et al., 2021). This paper focuses on an alternative semi-parametric approach, in which we have access to the training set D_train = {(x^(i), y^(i))}_{i=1}^n at inference time and learn a parametric mapping y = f_θ(x | D_train) conditioned on this dataset. Semi-parametric models can query the training set D_train and can therefore express rich and interpretable mappings with a compact f_θ. Examples of the semi-parametric framework include retrieval-augmented language models (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021) and non-parametric transformers (Wiseman & Stratos, 2019; Kossen et al., 2021). However, existing approaches are often specialized to specific tasks (e.g., language modeling (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021) or sequence generation (Graves et al., 2014)), and their computation scales superlinearly in the size of the training set (Kossen et al., 2021). Here, we introduce semi-parametric inducing point networks (SPIN), a general-purpose architecture whose computational complexity at training time scales linearly in the size of the training set D_train and in the dimensionality of x, and which is constant in the size of D_train at inference time.
Our architecture is inspired by inducing point approximations (Snelson & Ghahramani, 2005; Titsias, 2009; Wilson & Nickisch, 2015; Evans & Nair, 2018; Lee et al., 2018) and relies on a cross-attention mechanism between datapoints (Kossen et al., 2021). An important application of SPIN is in meta-learning, where conditioning on large training sets provides the model additional signal and improves accuracy, but poses challenges for methods that scale superlinearly with D_train. We use SPIN as the basis of the Inducing Point Neural Process (IPNP), a scalable probabilistic model that supports accurate meta-learning with large context sizes that cause existing methods to fail. We evaluate SPIN and IPNP on a range of supervised and meta-learning benchmarks and demonstrate the efficacy of SPIN on a real-world task in genomics: genotype imputation (Li et al., 2009). In meta-learning experiments, IPNP supports querying larger training sets, which yields high accuracy in settings where existing methods run out of memory. In the genomics setting, SPIN outperforms highly engineered state-of-the-art software packages widely used within commercial genomics pipelines (Browning et al., 2018b), indicating that our technique has the potential to impact real-world systems.
Contributions In summary, we introduce SPIN, a semi-parametric neural architecture inspired by inducing point methods that is the first to achieve the following characteristics:
1. Linear time and space complexity in the size and the dimension of the data during training.
2. The ability to learn a compact encoding of the training set for downstream applications. As a result, at inference time, computational complexity does not depend on training set size.
We use SPIN as the basis of the IPNP, a probabilistic model that enables meta-learning with context sizes larger than what existing methods support and that achieves high accuracy on important real-world tasks such as genotype imputation.

2. BACKGROUND

Parametric and Semi-Parametric Machine Learning Most supervised methods in deep learning are parametric. Formally, given a training set D_train = {(x^(i), y^(i))}_{i=1}^n with features x ∈ X and labels y ∈ Y, we seek to learn a fixed number of parameters θ ∈ Θ of a mapping y = f_θ(x) using supervised learning. In contrast, non-parametric approaches learn a mapping y = f_θ(x | D_train) that can query the training set D_train at inference time; when the mapping f_θ has parameters, the approach is called semi-parametric. Many deep learning algorithms, including memory-augmented architectures (Graves et al., 2014; Santoro et al., 2016), retrieval-based language models (Grave et al., 2016; Guu et al., 2020; Rae et al., 2021), and non-parametric transformers (Kossen et al., 2021), are instances of this approach, but they are often specialized to specific tasks, and their computation scales superlinearly in n. This paper develops scalable and domain-agnostic semi-parametric methods.
Meta-Learning and Neural Processes An important application of semi-parametric methods is in meta-learning, where we train a model to achieve high performance on new tasks using only a small amount of data from these tasks. Formally, consider a collection of D datasets (or a meta-dataset) {D^(d)}_{d=1}^D, each defining a task. Each D^(d) = (D_c^(d), D_t^(d)) contains a set of context points D_c^(d) = {(x_c^(di), y_c^(di))}_{i=1}^m and target points D_t^(d) = {(x_t^(di), y_t^(di))}_{i=1}^n. Meta-learning seeks to produce a model f(x; D_c) that outputs accurate predictions for y on D_t and on pairs (D_c, D_t) not seen at training time. Neural Process (NP) architectures perform uncertainty-aware meta-learning by mapping context sets to representations r_c(D_c), which can be combined with target inputs to provide a distribution on target labels y_t ∼ p(y | x_t, r_c(D_c)), where p is a probabilistic model.
Recent successes in NPs have been driven by attention-based architectures (Kim et al., 2018; Nguyen & Grover, 2022), whose complexity scales super-linearly with the context size |D_c|; our method yields linear complexity. In concurrent work, Feng et al. (2023) propose a linear-time method that uses cross-attention to reduce the size of context datasets.
A Motivating Application: Genotype Imputation A specific motivating example for developing efficient semi-parametric methods is the problem of genotype imputation. Consider the problem of determining the genomic sequence y ∈ {A, T, C, G}^k of an individual; rather than directly measuring y, it is common to use an inexpensive microarray device to measure a small subset of genomic positions x ∈ {A, T, C, G}^p, where p ≪ k. Genotype imputation is the task of determining y from x via statistical methods and a dataset D_train = {(x^(i), y^(i))}_{i=1}^n of fully-sequenced individuals (Li et al., 2009). Imputation is part of most standard genome analysis workflows. It is also a natural candidate for semi-parametric approaches (Li & Stephens, 2003): a query genome y can normally be represented as a combination of sequences y^(i) from D_train because of the biological principle of recombination (Kendrew, 2009), as shown in Figure 1. Additionally, the problem is a poor fit for parametric models: k can be as high as 10^9 and there is little correlation across non-proximal parts of y. Thus, we would need an unwieldy number of parametric models (one per subset of y), whereas a single semi-parametric model can run imputation across the genome.
Figure 1: Genotype recombination.
Attention Mechanisms Our approach for designing semi-parametric models relies on modern attention mechanisms (Vaswani et al., 2017), specifically dot-product attention Att(Q, K, V), which combines a query matrix Q ∈ R^{d_q×e_q} with key and value matrices K ∈ R^{d_v×e_q}, V ∈ R^{d_v×e_v} as

Att(Q, K, V) = softmax(QK^⊤ / √e_q) V.

To attend to different aspects of the keys and values, multi-head attention (MHA) extends this mechanism via e_h attention heads:

MHA(Q, K, V) = concat(O_1, ..., O_{e_h}) W^O, where O_j = Att(Q W_j^Q, K W_j^K, V W_j^V).

Each attention head projects Q, K, V into a lower-dimensional space using learnable projection matrices W_j^Q, W_j^K ∈ R^{e_q×e_qh}, W_j^V ∈ R^{e_v×e_vh}, and mixes the outputs of the heads using W^O ∈ R^{e_h·e_vh×e_o}. As is commonly done, we assume that e_vh = e_v/e_h, e_qh = e_q/e_h, and e_o = e_q. Given two matrices X, H ∈ R^{d×e}, a multi-head attention block (MAB) wraps MHA together with layer normalization and a fully connected layer:

MAB(X, H) = O + FF(LayerNorm(O)), where O = X + MHA(LayerNorm(X), H, H).

Attention in semi-parametric models normally scales quadratically in the dataset size (Kossen et al., 2021); our work is inspired by efficient attention architectures (Lee et al., 2018; Jaegle et al., 2021b) and develops scalable semi-parametric models with linear computational complexity.
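As a concrete illustration, the dot-product attention and MAB definitions above can be sketched in a few lines of NumPy. This is a simplified single-head version: the learned projection matrices of full MHA are omitted, and the feed-forward block is a single ReLU layer, so it shows shapes and data flow rather than a faithful implementation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def att(Q, K, V):
    # Att(Q, K, V) = softmax(Q K^T / sqrt(e_q)) V
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def layer_norm(X, eps=1e-5):
    return (X - X.mean(-1, keepdims=True)) / np.sqrt(X.var(-1, keepdims=True) + eps)

def mab(X, H, W_ff):
    # MAB(X, H) = O + FF(LayerNorm(O)),  O = X + MHA(LayerNorm(X), H, H)
    # (single-head attention stands in for MHA in this sketch)
    O = X + att(layer_norm(X), H, H)
    return O + np.maximum(layer_norm(O) @ W_ff, 0.0)  # ReLU feed-forward

rng = np.random.default_rng(0)
X, H = rng.normal(size=(4, 8)), rng.normal(size=(10, 8))
out = mab(X, H, rng.normal(size=(8, 8)))  # shape (4, 8): one row per query
```

Note that the output has one row per query in X, regardless of how many rows H has; this asymmetry is what the inducing-point construction in Section 3 exploits.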

3.1. SEMI-PARAMETRIC LEARNING BASED ON NEURAL INDUCING POINTS

A key challenge posed by semi-parametric methods, one affecting both classical kernel methods (Hearst et al., 1998) and recent attention-based approaches (Kossen et al., 2021), is the O(n^2) computational cost per gradient update at training time, due to pairwise comparisons between training set points. Our work introduces methods that reduce this cost to O(hn), where h ≪ n is a hyper-parameter, without sacrificing performance.
Neural Inducing Points Our approach is based on inducing points, a technique popular in approximate kernel methods (Wilson & Nickisch, 2015; Lee et al., 2018). A set of inducing points H = {h^(j)}_{j=1}^h can be thought of as a "virtual" set of training instances that can replace the training set D_train. Intuitively, when D_train is large, many datapoints are redundant; for example, groups of similar x^(i) can be replaced with a single inducing point h^(j) with little loss of information. The key challenge in developing inducing point methods is finding a good set H. While classical approaches rely on optimization techniques (Wilson & Nickisch, 2015), we use an attention mechanism to produce H. Each inducing point h^(j) ∈ H attends over the training set D_train to select its relevant "neighbors" and updates itself based on them. We implement attention between H and D_train in O(hn) time.
Dataset Encoding Note that once we have a good set of inducing points H, it becomes possible to discard D_train and use H instead for all future predictions. The parametric part of the model makes predictions based on H only. This is an important capability of our architecture: computational complexity at inference time is now independent of D_train, and we envision this feature being useful in applications where sharing D_train is not possible (e.g., for computational or privacy reasons).
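To make the complexity argument concrete, the following sketch (illustrative sizes, randomly initialized inducing points) contrasts the attention score matrices of pairwise self-attention between datapoints and of inducing-point cross-attention:

```python
import numpy as np

n, h, e = 1000, 16, 32          # training points, inducing points, embedding size
rng = np.random.default_rng(0)
D = rng.normal(size=(n, e))     # embedded training set
H = rng.normal(size=(h, e))     # inducing points (randomly initialized here)

# Self-attention between datapoints compares all pairs:
# an n x n score matrix, i.e., O(n^2) time and memory.
self_attn_scores = D @ D.T

# Inducing-point cross-attention: each of the h inducing points attends over
# the n training points, giving an h x n score matrix, i.e., O(hn).
cross_attn_scores = H @ D.T
```

For n = 1000 and h = 16, the cross-attention score matrix is roughly 60 times smaller, and the gap grows linearly in n.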

3.2. SEMI-PARAMETRIC INDUCING POINT NETWORKS

Next, we describe semi-parametric inducing point networks (SPIN), a domain-agnostic architecture with linear-time complexity.

Notation and Data Embedding

The SPIN model relies on a training set D_train = {(x^(i), y^(i))}_{i=1}^n with input features x^(i) ∈ X and labels y^(i) ∈ Y, whose entries take values in a shared input and output vocabulary V. We embed each dimension (each attribute) of x and y into an e-dimensional embedding and represent D_train as a tensor of embeddings D = Embed(D_train), D ∈ R^{n×d×e}, where d = p + k is obtained by concatenating the sequence of embeddings for each x^(i) and y^(i). The set D_train is used to learn inducing points H = {h^(j)}_{j=1}^h; similarly, we represent H via a tensor H ∈ R^{h×f×e} of h ≤ n inducing points, each being a sequence of f ≤ d embeddings of size e. To make predictions and measure loss on a set of b examples D_query = {(x^(i), y^(i))}_{i=1}^b, we use the same embedding procedure to obtain a tensor of input embeddings X_query ∈ R^{b×d×e} by embedding {(x^(i), 0)}_{i=1}^b, in which the labels have been masked with zeros. We also use a tensor Y_gold ∈ R^{b×d} to store the ground-truth labels and inputs (the objective function we use requires the model to make predictions on masked input elements as well; see below for details).
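A minimal sketch of this embedding step, assuming a single shared embedding table over the vocabulary V (toy sizes; the paper's actual embedding modules may differ):

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, k, e = 8, 5, 3, 16   # datapoints, input attributes, label attributes, embedding size
d = p + k
vocab_size = 4
E = rng.normal(size=(vocab_size, e))  # shared token embedding table over V

X = rng.integers(0, vocab_size, size=(n, p))  # input attributes x^(i)
Y = rng.integers(0, vocab_size, size=(n, k))  # label attributes y^(i)

# D = Embed(D_train): look up an e-dim embedding per attribute -> (n, d, e)
D = E[np.concatenate([X, Y], axis=1)]

# Query tensor: the same procedure with the label positions masked with zeros.
X_query = D.copy()
X_query[:, p:, :] = 0.0
```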

3.2.1. ARCHITECTURE OF THE ENCODER AND PREDICTOR

Each layer of the encoder consists of three sublayers, denoted XABA, XABD, and ABLA. An encoder layer takes as input H_A, H_D and feeds its outputs H′_A, H′_D (defined below) into the next layer. The initial inputs H_A, H_D of the first encoder layer are randomly initialized learnable parameters.

H′_A = XABA(H_A, D);  H′_D = XABD(H_D, H′_A);  H′_A ← ABLA(H′_A)

Cross-Attention Between Attributes (XABA) An XABA layer captures the dependencies among attributes via cross-attention between the sequence of latent encodings in H and the sequence of datapoint features in D:

XABA(H_A, D) = MAB(H_A, D).

This updates the features of each datapoint in H_A to be a combination of the features of the corresponding datapoints in D. In effect, this reduces the dimensionality of the datapoints (from n × d × e to n × f × e). The time complexity of this layer is O(ndfe), where f is the attribute dimensionality of the reduced tensor.

Cross-Attention Between Datapoints (XABD) The XABD layer is the key module that takes into account the entire training set to generate inducing points. First, it reshapes ("unfolds") its input tensors H′_A ∈ R^{n×f×e} and H_D ∈ R^{h×f×e} into ones of dimensions (1 × n × fe) and (1 × h × fe), respectively. It then performs cross-attention between the unfolded tensors:

XABD(H_D, H′_A) = MAB(unfold(H_D), unfold(H′_A)).

Each inducing point attends over the datapoints in H′_A and uses its selected datapoints to update its own representation. The computational complexity of this operation is O(nhfe), which is linear in the training set size n.

Self-Attention Between Latent Attributes (ABLA) The third type of layer further captures dependencies among attributes by computing regular self-attention across attributes:

ABLA(H′_A) = MAB(H′_A, H′_A).

This enables the inducing points to refine their internal representations. The dataset encoder consists of a sequence of the above layers; see Figure 3. The ABLA layers are optional based on validation performance. The input H_D to the first layer is part of the learned model parameters; the initial H_A is a linear projection of D. The output of the encoder is the output H′_D of the final layer.
Predictor Architecture The predictor is a parametric model that maps an input tensor X_query to an output tensor of logits Y_query. The predictor can use any parametric model. We propose an architecture based on a simple cross-attention operation followed by a linear projection to the vocabulary size, as shown in Figure 3:

Predict(X_query, H) = FF(MAB(unfold(X_query), unfold(H))).
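Putting the three sublayers together, one encoder layer can be sketched as follows. The MAB here is a simplified single-head version without learned projections, layer normalization, or a feed-forward block, so the focus is on tensor shapes and attention directions rather than a faithful implementation:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mab(X, H):
    # Simplified MAB: single-head scaled dot-product cross-attention with a
    # residual connection only.
    return X + softmax(X @ np.swapaxes(H, -1, -2) / np.sqrt(X.shape[-1])) @ H

rng = np.random.default_rng(0)
n, d, e = 32, 12, 8   # datapoints, attributes, embedding size
h, f = 4, 6           # inducing datapoints, latent attributes
D = rng.normal(size=(n, d, e))
H_A = rng.normal(size=(n, f, e))  # per-datapoint latent attributes
H_D = rng.normal(size=(h, f, e))  # inducing points

# XABA: for each datapoint, f latent attributes attend over its d attributes.
H_A = mab(H_A, D)                                   # (n, f, e)
# XABD: unfold to (h, f*e) and (n, f*e); inducing points attend over datapoints.
H_D = mab(H_D.reshape(h, f * e),
          H_A.reshape(n, f * e)).reshape(h, f, e)   # (h, f, e)
# ABLA: self-attention among the latent attributes.
H_A = mab(H_A, H_A)                                 # (n, f, e)
```

The XABD step is the only one whose cost grows with n, and it does so linearly: its attention score matrix has shape (h, n).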

3.3. INDUCING POINT NEURAL PROCESSES

An important application of SPIN is in meta-learning, where conditioning on larger training sets provides more information to the model and therefore has the potential to improve predictive accuracy. However, existing methods scale superlinearly with the size of D_train and may not effectively leverage large contexts. We use SPIN as the basis of the Inducing Point Neural Process (IPNP), a scalable probabilistic model that supports fast and accurate meta-learning on large context sizes. An IPNP defines a probabilistic model p(y | x, r(x, D_c)) of a target variable y conditioned on an input x and a context dataset D_c. This context is represented via a fixed-dimensional context vector r(x, D_c), and we use the SPIN architecture to parameterize r as a function of D_c. Specifically, we define r_c = Encoder(Embed(D_c)), where Encoder is the SPIN encoder, producing a tensor of inducing points. Then, we compute r(x, r_c) = MAB(x, r_c) via cross-attention. The model p(y | x, r) is a distribution with parameters φ(x, r), e.g., a Normal distribution with φ = (μ, Σ) or a Bernoulli with φ ∈ [0, 1]. We parameterize the mapping φ(x, r) with a fully-connected neural network. We further extend IPNPs to incorporate a latent variable z drawn from a Gaussian p(z | D_c) parameterized by φ_z = m(Encoder(Embed(D_c))), where m denotes mean pooling across datapoints. This latent variable can be thought of as capturing global uncertainty (Garnelo et al., 2018). This results in a distribution p(y, z | x, D_c) = p(y | z, x, D_c) p(z | D_c), where p(y | z, x, D_c) is parameterized by φ(z, x, r_c), with φ itself being a fully-connected neural network. See Appendix A.6 for more detailed architectural breakdowns. Following terminology in the NP literature, we refer to our model as a conditional IPNP (CIPNP) when there is no latent variable z present.
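The conditional prediction path can be sketched as follows (toy dimensions; a single-head cross-attention stands in for the full MAB, and a linear head with a softplus stands in for the fully-connected network φ):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
h, e, n_t = 8, 16, 5
r_c = rng.normal(size=(h, e))    # inducing-point summary of D_c from the SPIN encoder
x_t = rng.normal(size=(n_t, e))  # embedded target inputs

# r(x, r_c) = MAB(x, r_c): each target cross-attends over the context summary
# (single head, residual connection only).
r = x_t + softmax(x_t @ r_c.T / np.sqrt(e)) @ r_c

# phi(x, r): a head mapping each target representation to Normal parameters.
W_mu = rng.normal(size=(e, 1))
W_sig = rng.normal(size=(e, 1))
mu = r @ W_mu
sigma = np.logaddexp(0.0, r @ W_sig)  # softplus keeps sigma positive
```

Because the context enters only through the fixed-size summary r_c, the cost of a prediction does not grow with the number of context points.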

3.4. OBJECTIVE FUNCTION

SPIN We train SPIN models using a supervised learning loss L_labels (e.g., ℓ_2 loss for regression, cross-entropy for classification). We also randomly mask attributes and add an additional loss term L_attributes that asks the model to reconstruct the missing attributes, yielding the following objective:

L_SPIN = (1 − λ) L_labels + λ L_attributes.

Following Kossen et al. (2021), we start with a weight λ of 0.5 and anneal it towards zero. We detail the loss terms and the construction of mask matrices over labels and attributes in Appendix A.2. Following prior work (Devlin et al., 2019; Ghazvininejad et al., 2019; Kossen et al., 2021), we use random token-level masking. Additionally, we propose chunk masking, similar to the span masking introduced in Joshi et al. (2019), where a fraction ρ of the samples have the mask matrix for labels set to M^(i) = 1; we show the effectiveness of chunk masking in Table 5.
IPNP Following the NP literature, IPNPs are trained on a meta-dataset {D^(d)}_{d=1}^D of context and target points D^(d) = (D_c^(d), D_t^(d)) to maximize the log-likelihood of the target labels under the learned parametric distribution:

L_IPNP = − (1/|D|) Σ_{d=1}^{D} Σ_{i=1}^{n} log p(y_t^(di) | D_c^(d), x_t^(di)).

For latent variable NPs, the objective is a variational lower bound; see Appendix A.6 for more details.
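The combined SPIN objective and the two masking schemes can be sketched as follows. The per-position losses are stand-in random values, and the normalization over masked positions is our assumption rather than something the text specifies:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, p = 16, 4, 6
lam = 0.5                            # annealed towards zero during training

loss_labels = rng.random((n, k))     # stand-in per-position label losses
loss_attrs = rng.random((n, p))      # stand-in attribute-reconstruction losses
m_labels = rng.random((n, k)) < 0.15 # random token-level masks
m_attrs = rng.random((n, p)) < 0.15

L_labels = (loss_labels * m_labels).sum() / max(m_labels.sum(), 1)
L_attrs = (loss_attrs * m_attrs).sum() / max(m_attrs.sum(), 1)
L_spin = (1 - lam) * L_labels + lam * L_attrs

# Chunk masking: a fraction rho of samples have *all* label positions masked,
# so the model must impute whole label blocks rather than isolated tokens.
rho = 0.3
chunk = rng.random(n) < rho
m_labels_chunk = np.repeat(chunk[:, None], k, axis=1)
```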

4. EXPERIMENTS

Semi-parametric models, including Neural Processes for meta-learning, benefit from large context sets D_c, as they provide additional training signal. However, existing methods scale superlinearly with the size of D_c and quickly run out of memory. In this section, we show that SPIN and IPNP outperform state-of-the-art models by scaling to large context sets that existing methods do not support.

4.1. UCI DATASETS

We present experimental results for 10 standard UCI benchmarks, namely Yacht, Concrete, Boston Housing, Protein (regression datasets), Kick, Income, Breast Cancer, Forest Cover, Poker Hand, and Higgs Boson (classification datasets). We compare SPIN with Transformer baselines such as NPT (Kossen et al., 2021) and Set Transformers (Set-TF) (Lee et al., 2018). We also evaluate against Gradient Boosting (GBT) (Friedman, 2001), Multi-Layer Perceptron (MLP) (Hinton, 1989; Glorot & Bengio, 2010), and K-Nearest Neighbours (KNN) (Altman, 1992). Following Kossen et al. (2021), we measure the average ranking of the methods, standardized across all UCI tasks. To show the memory efficiency of our approach, we also report peak GPU memory usage, both in absolute terms and as a fraction of the GPU memory used by NPT, for different splits of the test dataset in Table 1.
Results SPIN achieves the best average ranking on the 10 UCI datasets and uses half the GPU memory compared to NPT. We provide detailed results on each of the datasets and hyperparameter details in Appendix A.4. Importantly, SPIN achieves high performance by supporting larger context sets; we illustrate this in Table 2, where we compare SPIN and NPT on the Poker Hand dataset (70/20/10 split) using various context sizes. SPIN and NPT achieve 80-82% accuracy with small contexts, but the performance of SPIN approaches 99% as context size is increased, whereas NPT quickly runs out of GPU memory and fails to reach comparable performance.

4.2. NEURAL PROCESSES FOR META-LEARNING

Experimental Setup Following previous work (Kim et al., 2018; Nguyen & Grover, 2022), we perform a Gaussian process meta-learning experiment, for which we create a collection of datasets {D^(d)}_{d=1}^D, where each D^(d) contains random points (x^(di))_{i=1}^m with x^(di) ∈ R, and target points y^(di) = f^(d)(x^(di)) obtained from a function f^(d) sampled from a Gaussian process, with context sizes ranging up to [1024, 2048]. We train several different NP models for 100,000 steps and evaluate their log-likelihood on 3,000 held-out batches, with B, m, n taking the same values as at training time. We evaluate conditional (CIPNP) and latent variable (IPNP) variations of our model (using h = ½ · min_ctx inducing points) and compare them to other attention-based NPs: Conditional ANPs (CANP) (Kim et al., 2018), Bootstrap ANPs (BANP) (Lee et al., 2020), and latent variable ANPs (ANP) (Kim et al., 2018).
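The data-generating process for one meta-learning task can be sketched as follows. The RBF kernel and its lengthscale here are illustrative choices; the exact kernel settings used in the experiment are not restated in the text:

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=0.5, variance=1.0):
    # k(x, x') = variance * exp(-0.5 * (x - x')^2 / lengthscale^2)
    return variance * np.exp(-0.5 * ((x1[:, None] - x2[None, :]) / lengthscale) ** 2)

def sample_gp_task(rng, m, n):
    # One task D^(d): draw f^(d) from a GP prior at m + n input locations,
    # then split the evaluations into context and target sets.
    x = rng.uniform(-2.0, 2.0, size=m + n)
    K = rbf_kernel(x, x) + 1e-6 * np.eye(m + n)  # jitter for stability
    y = rng.multivariate_normal(np.zeros(m + n), K)
    return (x[:m], y[:m]), (x[m:], y[m:])        # D_c, D_t

rng = np.random.default_rng(0)
(xc, yc), (xt, yt) = sample_gp_task(rng, m=64, n=32)
```

Repeating this for every batch yields an endless stream of fresh tasks, so the meta-dataset is effectively unlimited.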

Results

The IPNP models attain higher performance than all baselines at most context sizes (Figure 4). Interestingly, IPNPs generalize better; recall that IPNPs are more compact models with fewer parameters, hence less likely to overfit. We also found that increased context size led to improved performance for all models; however, baseline NPs required excessive resources, and BANPs ran out of memory entirely. In contrast, IPNPs scaled to large context sizes using up to 50% fewer resources.

4.3. GENOTYPE IMPUTATION

Genotype imputation is the task of inferring the sequence y of an entire genome via statistical methods from a small subset of positions x, usually obtained from an inexpensive DNA microarray device (Li et al., 2009), and a dataset D_train = {(x^(i), y^(i))}_{i=1}^n of fully-sequenced individuals (Li et al., 2009). Imputation is part of most standard workflows in genomics (Lou et al., 2021) and involves mature imputation software (Browning et al., 2018a; Rubinacci et al., 2020) that benefits from over a decade of engineering (Li & Stephens, 2003). These systems are fully non-parametric and match genomes in D_train to x, y; their scalability to modern datasets of up to millions of individuals is a known problem in the field (Maarala et al., 2020). Improved imputation holds the potential to reduce sequencing costs and improve workflows in medicine and agriculture.
Experiment Setup We compare against one of the state-of-the-art packages, Beagle (Browning et al., 2018a), on the 1000 Genomes dataset (Clarke et al., 2016), following the methodology described in Rubinacci et al. (2020). We use 5008 complete sequences y that we divide into train/val/test splits of 0.86/0.12/0.02, respectively, following Browning et al. (2018b). We construct inputs x by masking positions that do not appear on the Illumina Omni2.5 array (Wrayner). Our experiments in Table 3 focus on five sections of the genome for chromosome 20. We pre-process the input into sequences of K-mers for all methods (see Appendix A.3). Performance on this task is measured via the Pearson correlation coefficient R^2 between the imputed SNPs and their true values at each position. We compare against NPTs, Set Transformers, and classical machine learning methods. NPT-16, SPIN-16 and Set Transformer-16 refer to models using an embedding dimension of 16, a model depth of 4, and one attention head.
NPT-64, SPIN-64 and Set Transformer-64 refer to models using an embedding dimension of 64, a model depth of 4, and 4 attention heads. SPIN uses 10 inducing points for datapoints (h=10, f=10). A batch size of 256 is used for Transformer methods, and we train using the lookahead Lamb optimizer (Zhang et al., 2019). We further create three independent versions of this experiment, denoted Full, 50%, and 25%, in which the segments defining (D_c^(d), D_t^(d)) contain 400, 200, and 100 SNPs respectively. We fit an NPT and a CIPNP model parameterized by the SPIN-64 architecture and apply chunk-level masking instead of token-level masking.
Results Table 4 shows that both the CIPNP (SPIN-64) and the NPT-64 model support the meta-learning approach to genotype imputation and achieve high performance, with CIPNP being more accurate. We provide performance for each region within the datasets in Appendix A.3, Table 8. However, the NPT model cannot handle full-length genomic segments and runs out of memory on the Full experiment. This again highlights the ability of SPIN to scale and thus solve problems that existing models cannot. Table 5 shows the effect of chunk-style masking over token-level masking for SPIN when learning the imputation algorithm. As genomes are created by copying over chunks due to the biological principle of recombination, we find that chunk-style masking of labels at train time provides significant improvements over random token-level masking for the meta-learning genotype imputation task.
Ablation Analysis To evaluate the effectiveness of each module, we perform an ablation analysis by gradually removing components from SPIN. We remove components one at a time and compare the performance with the default SPIN configuration. In Table 6, we observe that for the genomics dataset (SNPs 424600-424700) and the UCI Boston Housing (BH) dataset, both XABD and XABA are crucial components. We discuss an ablation with a synthetic experiment setup in Appendix A.7.

5. RELATED WORK

Non-Parametric and Semi-Parametric Methods Non-parametric methods include approaches based on kernels (Davis et al., 2011), such as Gaussian processes (Rasmussen, 2003) and support vector machines (Hearst et al., 1998). These methods feature quadratic complexity (Bach, 2013), which motivates a long line of approximate methods based on random projections (Achlioptas et al., 2001), Fourier analysis (Rahimi & Recht, 2007), and inducing point methods (Wilson et al., 2015). Inducing points have been widely applied in kernel machines (Nguyen et al., 2020), Gaussian process classification (Izmailov & Kropotov, 2016), regression (Cao et al., 2013), semi-supervised learning (Delalleau et al., 2005), and more (Hensman et al., 2015; Tolstikhin et al., 2021).
Deep Semi-Parametric Models Deep Gaussian Processes (Damianou & Lawrence, 2013), Deep Kernel Learning (Wilson et al., 2016), and Neural Processes (Garnelo et al., 2018) build upon classical methods. Deep GPs rely on sophisticated variational inference methods (Wang et al., 2016), making them challenging to implement. Retrieval-augmented transformers (Bonetta et al., 2021) use attention to query external datasets in specific domains such as language modeling (Grave et al., 2016), question answering (Yang et al., 2018), and reinforcement learning (Goyal et al., 2022), in a way that is similar to earlier memory-augmented models (Graves et al., 2014).

Attention Mechanisms

The quadratic cost of self-attention (Vaswani et al., 2017) can be reduced using efficient architectures such as sparse attention (Beltagy et al., 2020), Set Transformers (Lee et al., 2018), the Performer (Choromanski et al., 2020), the Nystromer (Xiong et al., 2021), Long Ranger (Grigsby et al., 2021), Big Bird (Zaheer et al., 2020), Shared Workspace (Goyal et al., 2021), the Perceiver (Jaegle et al., 2021b;a), and others (Katharopoulos et al., 2020; Wang et al., 2020). Our work most closely resembles the Set Transformer (Lee et al., 2018) and Perceiver (Jaegle et al., 2021b;a) mechanisms; we extend these mechanisms to cross-attention between datapoints and use them to attend to datapoints, similar to Non-Parametric Transformers (Kossen et al., 2021).
Set Transformers Lee et al. (2018) introduce induced set attention blocks (ISAB), which replace self-attention with a more efficient cross-attention mechanism that maps a set of d tokens to a new set of d tokens. In contrast, SPIN cross-attention compresses sets of size d into smaller sets of size h < d. Each ISAB also uses a different set of inducing points, whereas SPIN layers iteratively update the same set of inducing points, resulting in a smaller memory footprint. Finally, while Set Transformers perform cross-attention over features, SPIN performs cross-attention between datapoints.

6. CONCLUSION

In this paper, we introduce a domain-agnostic, general-purpose architecture, the semi-parametric inducing point network (SPIN), and use it as the basis for the Inducing Point Neural Process (IPNP). Unlike previous semi-parametric approaches, whose computational cost grows quadratically with the size of the dataset, our approach scales linearly in the size and dimensionality of the data by leveraging a cross-attention mechanism between datapoints and induced latents. This allows our method to scale to large datasets and enables meta-learning with large contexts. We present empirical results on 10 UCI datasets, a Gaussian process meta-learning task, and an important real-world task in genomics, genotype imputation, and show that our method can achieve competitive, if not better, performance relative to state-of-the-art methods at a fraction of the computational cost.

APPENDIX: SEMI-PARAMETRIC INDUCING POINT NETWORKS AND NEURAL PROCESSES A EXPERIMENTAL DETAILS

A.1 COMPUTE RESOURCES We use 24GB NVIDIA GeForce RTX 3090, Tesla V100-SXM2-16GB and NVIDIA RTX A6000-48GB GPUs for experiments in this paper. A result is reported as OOM if it did not fit in the 24GB GPU memory. We do not use multi-GPU training or other memory-saving techniques such as gradient checkpointing, pruning, mixed precision training, etc. but note that these are orthogonal to our approach and can be used to further reduce the computational complexity.

A.2 TRAINING OBJECTIVE

We define a binary mask matrix for a given sample i as M^(i) = [m_1^(i), m_2^(i), ..., m_l^(i)], where l = k for labels and l = p for attributes. The loss over labels and attributes for each sample i is then given by

L^{labels,(i)}(y_pred^(i), y_true^(i), M^{labels,(i)}) = Σ_{j=1}^{k} m_j^(i) L(y_pred,j^(i), y_true,j^(i))
L^{attributes,(i)}(x_pred^(i), x_true^(i), M^{attributes,(i)}) = Σ_{j=1}^{p} m_j^(i) L(x_pred,j^(i), x_true,j^(i))

where, for C-way classification, L is the cross-entropy loss

L(y_pred,j^(i), y_true,j^(i)) = − Σ_{c=1}^{C} y_true,j,c^(i) log(softmax(y_pred,j,c^(i)))

and, for regression, L is the MSE loss L(y_pred,j^(i), y_true,j^(i)) = (y_true,j^(i) − y_pred,j^(i))^2. The loss L(x_pred,j^(i), x_true,j^(i)) for reconstructed attributes is computed analogously. For chunk masking, a fraction ρ of the samples have all label positions masked:

M^(i) = 1 with probability ρ, and M^(i) = 0 otherwise.
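The masked label loss above, instantiated for C-way classification (the helper name is ours, not the paper's):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def masked_label_loss(y_pred, y_true, m):
    # L^{labels,(i)} = sum_j m_j * L(y_pred_j, y_true_j), with L the
    # cross-entropy loss for C-way classification.
    # y_pred: (k, C) logits, y_true: (k,) class indices, m: (k,) binary mask.
    log_p = np.log(softmax(y_pred))
    ce = -log_p[np.arange(len(y_true)), y_true]
    return (m * ce).sum()

y_pred = np.zeros((3, 4))       # uniform logits over C = 4 classes
y_true = np.array([0, 1, 2])
m = np.array([1.0, 0.0, 1.0])   # only positions 0 and 2 contribute
loss = masked_label_loss(y_pred, y_true, m)  # 2 * log(4): two masked positions
```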

A.3 GENOMIC SEQUENCE IMPUTATION

Imputation is performed on single-nucleotide polymorphisms (SNPs) with a corresponding marker panel specifying the microarray. We randomly sample five sections of chromosome 20 for conducting experiments. Each section is selected with 100 SNPs to be predicted, and the 150 closest SNPs are obtained. For compact encoding of SNPs, we form K-mers, which are commonly used in various genomics applications (Compeau et al., 2011), where K is a hyper-parameter that controls the granularity of tokenization (how many nucleotides are treated as a single token). Imputation then becomes a 2^K-way classification task. We set K to 5 for all the genomics experiments, so that there are 20 (100/5) target SNPs to be imputed and 30 (150/5) attributes per sampled section. We report Pearson R^2 for each of the five sections in Table 7 with error bars per window for five different seeds. For computational load, we report peak GPU memory usage in GB where applicable, average train time per epoch in seconds, and parameter count per method. Table 8 provides Pearson R^2 for each of the 10 regions using a single model, thus learning the genotype imputation algorithm. In Table 9, we analyze the effect of increasing the number of reference haplotypes during training on the Pearson R^2 achieved by NPT and SPIN. The reference haplotypes in the train dataset are gradually increased from 1% to 100% of those available. Pearson R^2 is reported cumulatively for 10 randomly selected regions with window size 300. We observe that the performance of both SPIN and NPT improves with an increasing reference dataset. However, NPT cannot be used beyond a certain number of reference samples due to its GPU memory footprint, while SPIN continues to yield improved performance.
Hyperparameters In Table 10, we provide the range of hyper-parameters that were grid-searched for different methods. Beagle is specialized software using dynamic programming and does not require any hyper-parameters from the user.
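The K-mer encoding described above can be sketched as follows, assuming a binary allele encoding at each SNP site (consistent with the 2^K-way classification; the actual tokenization details may differ):

```python
def to_kmers(bits, K=5):
    # Group K consecutive sites into one token in [0, 2^K), turning
    # imputation into a 2^K-way classification problem per token.
    assert len(bits) % K == 0
    return [int("".join(str(b) for b in bits[i:i + K]), 2)
            for i in range(0, len(bits), K)]

seq = [0, 1, 0, 1, 1, 1, 0, 0, 0, 1]  # 10 SNP sites (binary alleles)
tokens = to_kmers(seq, K=5)           # 2 tokens: 0b01011 = 11, 0b10001 = 17
```

With K = 5, a section of 100 target SNPs and 150 input SNPs reduces to 20 target tokens and 30 attribute tokens, matching the counts above.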
In Table 11, we report results for 10 cross-validation (CV) splits for the Yacht and Concrete datasets, 5 CV splits for the Boston Housing dataset, and 1 CV split for the Protein dataset. The number of splits was chosen according to computational requirements. Below we provide details about each dataset.

• Yacht consists of 308 instances, 1 continuous, and 5 categorical features.
• Boston Housing consists of 506 instances, 11 continuous, and 2 categorical features.
• Concrete consists of 1030 instances, and 9 continuous features.
• Protein consists of 45,730 instances, and 9 continuous features.

The IPNP is trained by minimizing the negative evidence lower bound (ELBO):

L_IPNP,ELBO = (1/|D|) Σ_{d=1}^{|D|} [ −Σ_{i=1}^{n} log p_θ(y_t^{(di)} | z, D_c^{(d)}, x_t^{(di)}) + KL(q(z | D_t^{(d)}, D_c^{(d)}) ∥ p(z | D_c^{(d)})) ],

where KL is the Kullback-Leibler divergence, q(z | D_t^{(d)}, D_c^{(d)}) is the posterior distribution conditioned on the target and context sets, and p(z | D_c^{(d)}) is the prior conditioned only on the context.

Hyperparameters Learning rates for all experiments were set to 5e-4 with a cosine annealing learning rate schedule. Model parameters were optimized using the Adam optimizer (Kingma & Ba, 2014).

Architectures In Tables 16 and 17, we detail the architectures for the conditional NPs (CANP, BANP, CIPNP) and the latent-variable NPs (ANP, IPNP) used in Section 4.2, respectively. Note that although the conditional NPs do not have a latent path, in order to make them comparable in terms of number of parameters, we add another deterministic encoding Pooling Encoder to these models, as described in Lee et al. (2020). In these tables, we remark where XABD is used as opposed to regular self-attention between context data points. We use the following shorthand notation: X_c denotes context features for a batch of datasets stacked into a tensor of size B × m × 1 × 1, and X_t is defined similarly. D denotes the full dataset, features and labels for both context and target, stacked into a tensor of size B × (m + n) × 2 × 1.
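For intuition, the KL term in the objective above has a closed form when the posterior and prior over z are diagonal Gaussians, as in the latent path of these models. A minimal sketch, with a helper name of our own choosing:

```python
import numpy as np

def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2)) ), summed over dims."""
    return float(np.sum(
        np.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
        - 0.5
    ))

# The KL vanishes when posterior and prior coincide, and grows as the
# posterior (informed by the targets) moves away from the context-only prior.
mu, sigma = np.zeros(128), np.ones(128)
print(gaussian_kl(mu, sigma, mu, sigma))        # 0.0
print(gaussian_kl(mu + 1.0, sigma, mu, sigma))  # 64.0 (0.5 per dimension)
```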
Note that although the equations described in Section 3.3 and in Tables 16 and 17 use tensors of order 4, in practice we use tensors of order 3 and permute the dimensions of the tensor to ensure that attention is performed along the correct dimension (i.e., data points). Finally, in Figure 5, we provide a more detailed diagram of the (conditional) IPNP architecture, which excludes the additional Pooling Encoder.

In outline, the architectures proceed as follows. Encoder (deterministic path): queries Q ∈ R^{B×n×128×1} are computed from X_t via an MLP (1 hidden layer, ReLU activation); keys K ∈ R^{B×m×128} are computed from X_c ∈ R^{B×m×1×1} via an MLP (1 hidden layer, ReLU activation); values V = r_c ∈ R^{B×m×128×1} are computed from D ∈ R^{B×m×2×1} via an MLP (1 hidden layer, ReLU activation); cross-attention between query and context points yields r ∈ R^{B×n×128×1}. Latent path: MAB(D, D) (self-attention between data points) for ANP, or MAB(H_D, D) (XABD) for IPNP, followed by mean pooling over context points and an MLP (1 hidden layer, ReLU activation) whose output is chunked into μ_z, Σ_z ∈ R^{B×128×1}; a latent z ~ N(μ_z, Σ_z²) is sampled and repeated n times to give z ∈ R^{B×n×128×1}. Decoder: concat(X_t, r, z) ∈ R^{B×n×257×1} (or concat(X_t, r, r′) for the conditional models, with r′ produced by the Pooling Encoder) is passed through a fully connected layer and an MLP (2 hidden layers, ReLU activation), and the output is chunked into μ, Σ ∈ R^{B×n×1}, from which Y_t ∈ R^{B×n×1} ~ N(μ, Σ²) is sampled.

Qualitative Uncertainty Estimation Results In Figure 6, we show baseline and inducing-point NP models trained with context sizes ∈ [64, 128] and display the output of these models on new datasets with varying numbers of context points (4, 8, 16). We observe that the CIPNP and IPNP models better capture uncertainty in regions where context points have not been observed.

Quantitative Calibration Results

To provide more quantitative results on how well our NP models capture uncertainty relative to baselines, we take models trained with context sizes ∈ [64, 128] and evaluate them on 1,000 evaluation batches, each with a number of target points ranging from 4 to 64. We repeat this experiment three times with varying numbers of context points (4, 8, 16) available to the models. In Figure 7, we see that in lower-context regimes, the CIPNP and IPNP models are better calibrated than the other baselines. As context size increases, the calibration of all the models deteriorates. This is further reflected in Table 18, where we display model calibration scores. Letting CI be confidence intervals ranging from 0 to 1.0 in steps of 0.1, p_CI be the fraction of target labels that fall within confidence interval CI, and n be the number of confidence intervals, the calibration score is

(1/n) Σ_{CI=0}^{1} (p_CI − CI)².

This score measures the deviation of each model's calibration plot from the 45° line. Future work will explore the mechanisms that enable inducing-point models to better capture uncertainty.
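The calibration score above is straightforward to compute; a small sketch (function and variable names are ours):

```python
import numpy as np

def calibration_score(p_ci, cis):
    """Mean squared deviation of the calibration curve from the 45-degree line."""
    p_ci, cis = np.asarray(p_ci), np.asarray(cis)
    return float(np.mean((p_ci - cis) ** 2))

cis = np.arange(0.0, 1.01, 0.1)            # confidence intervals 0.0, 0.1, ..., 1.0
print(calibration_score(cis, cis))          # 0.0 for a perfectly calibrated model
overconfident = np.clip(cis - 0.1, 0, 1)    # coverage falls short of each CI by 0.1
print(round(calibration_score(overconfident, cis), 4))
```

A perfectly calibrated model, whose empirical coverage p_CI matches every confidence interval CI, scores exactly zero; any systematic over- or under-confidence increases the score.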

A.7 ABLATION ANALYSIS WITH SYNTHETIC EXPERIMENT

We formulate a synthetic experiment in which the model can only learn via XABD layers. First, we initialize a random binary matrix with a 50% probability of 1's, 5000 rows, and 50 columns. We set the last 20 columns to be target labels. Next, we copy a section of the dataset and divide it into three equal and disjoint parts for the train, validation, and test queries. Since there is no correlation between the features and the target, the only way for the model to learn is via the XABD layers (for a small dataset the model could also memorize the entire training dataset). This is similar to the synthetic experiment in NPT (Kossen et al., 2021), except that there is no relation between the features and the target in our setup. We find that both the default SPIN and a variant with only the XABD component achieve 100% binary classification accuracy, whereas a variant with only the XABA component achieves 70.01% classification accuracy, indicating the effectiveness of the XABD component.
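The dataset construction above can be sketched as follows. This is a hedged illustration: the array names and the size of the duplicated query slice are our own choices, not taken from the paper's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# 5000 x 50 random binary matrix (50% probability of 1's);
# the last 20 columns are treated as target labels.
data = rng.integers(0, 2, size=(5000, 50))

# Duplicate a slice of the training rows to form queries. Because features and
# labels are independent by construction, a query's labels can only be
# recovered by looking up its duplicate row via XABD-style attention.
queries = data[:600].copy()
train_q, val_q, test_q = np.split(queries, 3)  # three equal, disjoint parts

assert train_q.shape == val_q.shape == test_q.shape == (200, 50)
# Every query row has an exact duplicate in the training data.
print(bool((data[:600] == queries).all()))  # True
```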

A.8 QUALITATIVE ANALYSIS FOR CROSS-ATTENTION

In order to understand what type of inducing points are learnt by the latent H_D, we formulate a toy synthetic dataset, shown in Figure 8. We start with two centroids consisting of binary strings with 120 bits and add Bernoulli noise with p = 0.1. We create the labels as another set of 4-bit binary strings and apply Bernoulli noise with p = 0.1. In this way we create a dataset whose datapoints belong to two separate clusters. Figure 8(a) shows the projection of this dataset onto its two principal components. We duplicate a part of this data to form the query samples so that they can be looked up from the latent via the cross-attention mechanism. Figure 8(b) shows a schematic of the dataset with masked values for the query labels. We use a SPIN model with a single XABD module, two induced latents (h = 2), and an input embedding dimension of 32, and inspect the cross-attention mechanism of the decoder. In Figure 8(c), we plot the decoder cross-attention map between the test queries and the induced latents and observe that the grouped query datapoints attend to the two latents consistently with the data-generating clusters.

A.9 IMAGE CLASSIFICATION EXPERIMENTS

We conducted additional experiments comparing SPIN and NPT on two image classification datasets. Following NPT, we compare results for the image classification task using a linear patch encoder for the MNIST and CIFAR-10 datasets (which is the reason for the lower accuracies compared to using CNN-based encoders). Table 19 shows that with the linear patch encoder, SPIN and NPT perform similarly in terms of accuracy, but SPIN uses far fewer parameters. We also conducted a sensitivity analysis of SPIN's performance with respect to h and f and found that SPIN is fairly robust to the choice of these hyper-parameters, as evidenced by the low standard deviations in Table 20. This reflects redundancy in the data and explains why attending to the entire dataset is inefficient.

B COMPLEXITY ANALYSIS

We provide the time complexity for one gradient step in Table 21, with n_l the number of layers and batch size b equal to the training dataset size n during training, and for one query sample during inference, for the transformer-based methods. Two operations contribute most heavily to the time complexity: the computation of QKᵀ, and the four-times expansion in the feedforward layers. For NPT, the time complexity is given by the maximum of ABD, ABA, and the four-times feedforward expansion during ABD, that is, max(n_l n² d e, n_l n d² e, 4 n_l n d² e²) during both training and inference. The Set Transformer consists of ISAB blocks that perform, in each layer, one cross-attention between the latents and the dataset to project into a smaller space, and one cross-attention between the dataset and the latents to project back into the input space. This results in a complexity of max(2 n_l n d f e, 2 n_l n h f e, 8 n_l n d² e²) during training and inference. For SPIN, the time complexity is given by the maximum of XABD, XABA, ABLA, the four-times feedforward expansions, and the single cross-attention in the Predictor module. This can be formulated as max(n_l n h f e, n_l n d f e, n_l n f² e, 4 n_l n f² e, n h d e, 4 n d² e²). At inference, SPIN only uses the Predictor module, with resulting complexity max(h d e, 4 d² e²).

C CODE AND DATA AVAILABILITY

UCI and Genomic Task Code The experimental results for the UCI and genomic tasks can be reproduced from here.

Neural Processes Code The experimental results for the Neural Processes task can be reproduced from here.

Data for Genomics Experiment

The vcf file containing genotypes can be downloaded from 1000Genomes chromosome 20 vcf file. Additionally, the microarray used for genomics experiment can be downloaded from HumanOmni2.5 microarray. Beagle software, used as baseline, can be obtained from Beagle 5.1.

UCI Datasets

All UCI datasets can be obtained from UCI Data Repository.



Footnotes:

• We use the pre-norm parameterization for residual connections and omit details such as dropout; see Nguyen & Salazar (2019) for the full parameterization.
• Here we consider the case where both input and output are discrete, but our approach easily generalizes to continuous input and output spaces.
• NPT reports a mean of 1.27 on this task that we could not reproduce. However, we emphasize that for the UCI experiments, all the model parameters are kept the same for all the transformer methods.
• The complexity at test time for SPIN is max(n h d e, 4 n d² e²) when either using the optional cross-attention in the predictor module or when the encoder is enabled at test time, such as in the multiple-windows genomic experiment.



Figure 2: SPIN Model Structure

Figure 3: SPIN Architecture. Each layer of the encoder consists of sublayers XABA, ABLA and XABD, and the predictor consists of a cross-attention layer. We omit feedforward layers for simplicity.

…the two unfolded tensors. The output of cross-attention has dimension (1 × h × fe); it is reshaped ("folded") into an output tensor of size (h × f × e):

XABD(H_D, H′_A) = fold(MAB(unfold(H_D), unfold(H′_A)))

This layer produces inducing points. Each inducing point in H_D attends to dimensionality-reduced datapoints in H′_A and uses its selected datapoints to update its own representation. The computational complexity of this operation is O(nhfe), which is linear in the training set size n.
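The fold/unfold cross-attention can be illustrated with a minimal single-head NumPy sketch under our own naming; the paper's XABD uses full multihead attention blocks (MAB) with residual and feedforward sublayers, which we omit here.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def xabd(H_D, H_A):
    """Single-head sketch of XABD: h inducing points attend to n datapoints.

    H_D has shape (h, f, e); H_A has shape (n, f, e). Each tensor is unfolded
    so every inducing point / datapoint becomes one vector of size f * e.
    """
    h, f, e = H_D.shape
    n = H_A.shape[0]
    Q = H_D.reshape(h, f * e)                  # unfold inducing points
    K = V = H_A.reshape(n, f * e)              # unfold reduced datapoints
    attn = softmax(Q @ K.T / np.sqrt(f * e))   # (h, n); logits cost O(n*h*f*e)
    return (attn @ V).reshape(h, f, e)         # fold back into (h, f, e)

rng = np.random.default_rng(0)
H_D = rng.normal(size=(16, 8, 4))    # h = 16 inducing points
H_A = rng.normal(size=(5000, 8, 4))  # n = 5000 reduced datapoints
print(xabd(H_D, H_A).shape)          # (16, 8, 4)
```

Note that the output shape depends only on (h, f, e), not on n: the dataset is summarized into a fixed number of inducing points, which is what makes the downstream layers independent of the training set size.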

Figure 4: Inducing Point NPs outperform NP baselines and train much faster. Plots display mean ± std. deviation from 5 runs with different random seeds.

…of D independent genomic segments, where D_c^{(d)} is the set of reference genomes in that segment and D_t^{(d)} is the set of genomes that we want to learn to impute. At each meta-training step, we sample a new pair (D_c^{(d)}, D_t^{(d)}) and update the model parameters to maximize the likelihood of D_t^{(d)}.

Non-Parametric Transformers (Kossen et al., 2021) use a domain-agnostic architecture based on attention that runs in O(n 2 d 2 ) at training time and O(nd 2 ) at inference time, while ours runs in O(nd) and O(d), respectively.


Figure 5: CIPNP architecture diagram

Conditional path (Table 16): MAB(D, D) (self-attention between data points) for CANP/BANP, or MAB(H_D, D) (XABD) for CIPNP; cross-attention between query and context points produces r ∈ R^{B×n×128×1}. The Pooling Encoder takes D ∈ R^{B×m×2×1} through an MLP (1 hidden layer, ReLU activation), mean-pools over context points, applies another MLP (1 hidden layer, ReLU activation), and repeats the result n times to produce r′ ∈ R^{B×n×128×1}, which is passed to the decoder.

Figure 6: Predicted mean ± standard deviation of y for different NP models given varying context sizes: 4 (top), 8 (middle), and 16 (bottom).

Figure 7: Calibration of NP models given varying context sizes: 4 (left), 8 (middle), and 16 (right).

Figure 8: Synthetic experiment analyzing cross-attention: (a) PCA projection of the dataset, highlighting that it consists of two distinct clusters; (b) schematic of the data matrix with duplicated query samples and their labels masked; (c) cross-attention map between query samples and the latent H_D.

Performance Summary on UCI Datasets

…sampled from a Gaussian process. At each meta-training step, we sample B = 16 functions {f^(b)}_{b=1}^{B}. For each f^(b), we sample m ~ U[min_ctx, max_ctx] context points and n ~ U[min_tgt, max_tgt] target points. The range for n is fixed across all experiments at [4, 64]. The range for m is varied over [64, 128], [128, 256], [256, 512], [512, 1024], …
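The per-step sampling of context and target sizes can be sketched as follows; the ranges come from the text above, while the function name is our own.

```python
import random

def sample_batch_sizes(B=16, ctx_range=(64, 128), tgt_range=(4, 64), seed=0):
    """For each of B sampled functions, draw a context size m and target size n
    uniformly from the given ranges, as in the meta-training loop above."""
    rng = random.Random(seed)
    return [(rng.randint(*ctx_range), rng.randint(*tgt_range)) for _ in range(B)]

sizes = sample_batch_sizes()
print(len(sizes))  # 16 (m, n) pairs, one per sampled function
```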

Performance Summary on Genomic Sequence Imputation. (*) denotes parametric models. A difference of 0.5% is statistically significant at p-value 0.05. [Pearson R² ↑ values from the table: 87.63, 95.31, 89.70, 95.64, 95.84 ± 0.06, 95.97 ± 0.09, 95.92 ± 0.12.]

Table 3 presents the main results for genotype imputation. Compared to the previous state-of-the-art commercial software, Beagle, which is specialized to this task, all transformer-based methods achieve strong performance despite making fewer assumptions and being more general. While all three transformer-based approaches report similar Pearson R², SPIN achieves competitive performance with a much smaller parameter count. Among traditional ML approaches, the MLP performs best, but requires training one model per imputed SNP and hence cannot scale to full genomes. We provide additional details on resource usage and hyper-parameter tuning in Appendix A.3.





Ablation Analysis

Performance on Genomics Imputation

Performance on UCI Regression Datasets

We report results for 10 CV splits for the Breast Cancer dataset and 1 CV split for the Kick, Income, Forest Cover, Poker-Hand, and Higgs Boson datasets. The number of splits was chosen according to computational requirements. Below we provide details about each dataset.

Average Ranking on UCI Classification Datasets (Breast Cancer, Kick, Income, Forest Cover, Poker-Hand, Higgs Boson) based on Classification Accuracy

CANP / BANP / CIPNP architecture (no latent path)

ANP / IPNP architecture (latent path)

Calibration scores (↓) for NP models across context sizes. The calibration score equals the mean squared deviation from the 45° line of the fraction of target points falling within confidence intervals ranging from 0 to 1 in steps of 0.1; see Equation (1). Best (lowest) scores for each context size are bolded.

Image Classification Experiments

Effect of induced points h, f for one genomic window (SNPs 424600-424700). Columns: Induced Points h, Induced Points f, Pearson R² ↑, Peak GPU (GB).

We note that during training for NPT, if n > 4de, then the QKᵀ computation in ABD dominates; otherwise the four-times feedforward expansion for ABD dominates. For the Set Transformer, the four-times feedforward expansion usually dominates. For SPIN, depending on the values of n_l, d, f, h, and e, different computations can dominate, but the cost is always linear in the dataset size n. During inference, SPIN's time complexity is independent of the number of layers n_l and the dataset size n, and depends entirely on the number of inducing datapoints h, the model embedding dimension e, and the feature+target dimension d.

Table 21 (training / inference complexity):
NPT: max(n_l n² d e, n_l n d² e, 4 n_l n d² e²) / max(n_l n² d e, n_l n d² e, 4 n_l n d² e²)
STF: max(2 n_l n d f e, 2 n_l n h f e, 8 n_l n d² e²) / max(2 n_l n d f e, 2 n_l n h f e, 8 n_l n d² e²)
SPIN: max(n_l n h f e, n_l n d f e, n_l n f² e, 4 n_l n f² e, n h d e, 4 n d² e²) / max(h d e, 4 d² e²)
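The dominant-term formulas above can be checked numerically. In this small sketch, the helper names and the illustrative constants (d, e, n_l, h, f) are our own; it shows that doubling n doubles SPIN's training cost but (once the n² term dominates) quadruples NPT's.

```python
# Evaluate the dominant-cost formulas from Table 21 (our own helper names).

def npt_train_cost(n, d, e, n_layers):
    return max(n_layers * n**2 * d * e,
               n_layers * n * d**2 * e,
               4 * n_layers * n * d**2 * e**2)

def spin_train_cost(n, d, e, n_layers, h, f):
    return max(n_layers * n * h * f * e,
               n_layers * n * d * f * e,
               n_layers * n * f**2 * e,
               4 * n_layers * n * f**2 * e,
               n * h * d * e,
               4 * n * d**2 * e**2)

d, e, n_layers, h, f = 30, 64, 4, 16, 32
print(spin_train_cost(20_000, d, e, n_layers, h, f) /
      spin_train_cost(10_000, d, e, n_layers, h, f))  # 2.0 (linear in n)
print(npt_train_cost(20_000, d, e, n_layers) /
      npt_train_cost(10_000, d, e, n_layers))         # 4.0 (quadratic in n)
```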

7. ACKNOWLEDGMENTS

This work was supported by Tata Consulting Services, the Cornell Initiative for Digital Agriculture, the Hal & Inge Marcus PhD Fellowship, and an NSF CAREER grant (#2145577). We thank Edgar Marroquin for help with preprocessing the raw genomic data, and the NPT authors, Jannic and Neil, for helpful discussions and correspondence regarding the NPT architecture. We also thank the anonymous reviewers for their significant effort in providing suggestions and helpful feedback, thereby improving our paper.

8. REPRODUCIBILITY

We provide details on the compute resources in Appendix A.1, including GPU specifications. Code and data used to reproduce the experimental results are provided in Appendix C. We provide error bars on the reported results by varying seeds or test splits; however, for certain large datasets, such as the Kick, Forest Cover, Protein, and Higgs UCI datasets and the genomic imputation experiments with large output sizes, we report results from a single run due to computational limitations. These details are provided in Appendix A.3, Appendix A.4, and Appendix A.5.


• Kick consists of 72,983 instances, 14 continuous and 18 categorical features, and 2 target classes.
• Income consists of 299,285 instances, 6 continuous and 36 categorical features, and 2 target classes.
• Forest Cover consists of 581,012 instances, 10 continuous and 44 categorical features, and 7 target classes.
• Poker-Hand consists of 1,025,010 instances, 10 categorical features, and 10 target classes.
• Higgs Boson consists of 11,000,000 instances, 28 continuous features, and 2 target classes.

We provide the range of hyperparameters for the UCI datasets in Table 13. Additionally, we provide average rankings separated by regression and classification tasks in Table 14 and Table 15, respectively.

