GENERALIZATION PROPERTIES OF RETRIEVAL-BASED MODELS

Anonymous authors
Paper under double-blind review

Abstract

Many modern high-performing machine learning models such as GPT-3 primarily rely on scaling up models, e.g., transformer networks. Simultaneously, a parallel line of work aims to improve the model performance by augmenting an input instance with other (labeled) instances during inference. Examples of such augmentations include task-specific prompts and similar examples retrieved from the training data by a nonparametric component. Remarkably, retrieval-based methods have enjoyed success on a wide range of problems, ranging from standard natural language processing and vision tasks to protein folding, as demonstrated by many recent efforts, including WebGPT and AlphaFold. Despite growing literature showcasing the promise of these models, the theoretical underpinning for such models remains underexplored. In this paper, we present a formal treatment of retrieval-based models to characterize their generalization ability. In particular, we focus on two classes of retrieval-based classification approaches: First, we analyze a local learning framework that employs an explicit local empirical risk minimization based on retrieved examples for each input instance. Interestingly, we show that breaking down the underlying learning task into local sub-tasks enables the model to employ a low complexity parametric component to ensure good overall accuracy. The second class of retrieval-based approaches we explore learns a global model using kernel methods to directly map an input instance and retrieved examples to a prediction, without explicitly solving a local learning task.

1. INTRODUCTION

As our world is complex, we need expressive machine learning models to make high-accuracy predictions on real-world problems. There are multiple ways to increase the expressiveness of a machine learning model. A popular one is to homogeneously scale up the size of a parametric model, such as a neural network, which has been behind many recent high-performing models such as GPT-3 (Brown et al., 2020) and ViT (Dosovitskiy et al., 2021). Their performance (accuracy) improves monotonically with increasing model size, as demonstrated by "scaling laws" (Kaplan et al., 2020). Such large models, however, have their own limitations, including high computation cost, catastrophic forgetting (difficulty adapting to changing data), and lack of provenance and explainability. Classical instance-based models (Fix & Hodges, 1989), on the other hand, offer many desirable properties by design: efficient data structures, incremental learning (easy addition and deletion of knowledge), and some provenance for their predictions based on the nearest neighbors of the input. However, these models often suffer from weaker empirical performance compared to deep parametric models. Increasingly, a middle ground combining the two paradigms and retaining the best of both worlds is becoming popular across various domains, ranging from natural language (Das et al., 2021; Wang et al., 2022; Liu et al., 2022; Izacard et al., 2022), to vision (Liu et al., 2015; 2019; Iscen et al., 2022; Long et al., 2022), to reinforcement learning (Blundell et al., 2016; Pritzel et al., 2017; Ritter et al., 2020), to even protein structure prediction (Cramer, 2021). In such approaches, given a test input, one first retrieves relevant entries from a data index and then processes the retrieved entries along with the test input to make the final prediction using a machine learning model. This process is visualized in Figure 1b.
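To make the retrieve-then-process pattern concrete, the following is a minimal sketch (not any specific system from the literature): retrieval is a nearest-neighbor lookup over a labeled index, and the "processing" step is a toy distance-weighted vote over the retrieved labels. The function names and the weighting scheme are our own illustrative choices.

```python
import numpy as np

def retrieve(index_X, index_y, x, k=5):
    """Return the k labeled examples nearest to x (Euclidean distance)."""
    d = np.linalg.norm(index_X - x, axis=1)
    nn = np.argsort(d)[:k]
    return index_X[nn], index_y[nn]

def predict(x, R_X, R_y):
    """Toy 'processing' step: distance-weighted vote over retrieved labels.
    Real systems replace this with a learned model f(x, R_x)."""
    w = 1.0 / (np.linalg.norm(R_X - x, axis=1) + 1e-8)
    classes = np.unique(R_y)
    scores = np.array([w[R_y == c].sum() for c in classes])
    return classes[np.argmax(scores)]

# Tiny synthetic index: two well-separated clusters.
rng = np.random.default_rng(0)
index_X = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(3, 0.1, (20, 2))])
index_y = np.array([0] * 20 + [1] * 20)

x = np.array([2.9, 3.1])
R_X, R_y = retrieve(index_X, index_y, x, k=5)
y_hat = predict(x, R_X, R_y)  # the final prediction depends on both x and R_x
```

In practice the index is large and retrieval uses approximate nearest-neighbor search, but the two-stage structure is the same.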
For example, in semantic parsing, models that augment a parametric seq2seq model with similar examples have not only outperformed much larger models but are also more robust to changes in data (Das et al., 2021). While classical learning setups (cf. Figure 1a) have been studied extensively over decades, even basic properties and trade-offs pertaining to retrieval-based models (cf. Figure 1b), despite their aforementioned remarkable successes, remain highly under-explored. Most existing efforts on retrieval-based machine learning models focus solely on developing end-to-end domain-specific models, without identifying the key dataset properties or structures that are critical for realizing performance gains with such models. Furthermore, due to the high dependence between an input and its associated retrieved set, direct application of existing statistical learning techniques is not straightforward. This prompts the natural question: What is the right theoretical framework for rigorously showcasing the value of the retrieved set in ensuring the superior performance of modern retrieval-based models? In this paper, we take a first step towards answering this question, focusing on the classification setting (Sec. 2.1). We begin with the hypothesis that the model implicitly performs local learning with the retrieved set and then adapts its predictions to the neighborhood of the test point. This idea is inspired by Bottou & Vapnik (1992). Such local learning is potentially beneficial when the underlying task has local structure, i.e., a much simpler function class suffices to explain the data in a given local neighborhood even though the data can be complex overall (formally defined in Sec. 2.2). For instance, looking at a few answers on Stack Overflow, even ones not written for the exact same problem, may help us resolve an issue much faster than understanding the whole system. We formalize this effect.
We begin by analyzing an explicit local learning algorithm: For each test input, (1) we retrieve a few training examples located in the vicinity of the test input; (2) train a local model by performing empirical risk minimization (ERM) with only these retrieved examples (local ERM); and (3) apply the resulting local model to make a prediction on the test input. For this retrieval-based local ERM, we derive in Sec. 3 finite-sample generalization bounds that highlight a trade-off between the complexity of the underlying function class and the size of the neighborhood in which the local structure of the data distribution holds. Under this assumption of local regularity, we show that by using a much simpler function class for the local model, we can achieve a loss/error similar to that of a complex global model (Thm. 3.4). Thus, breaking down the underlying learning task into local sub-tasks enables the model to employ a low-complexity parametric component while ensuring good overall accuracy. Note that the local ERM setup is reminiscent of semiparametric polynomial regression (Fan & Gijbels, 2018) in statistics, which is a special case of our setup. However, semiparametric polynomial regression has only been analyzed asymptotically under the mean squared error loss (Ruppert & Wand, 1994), and its treatment under more general losses is unexplored. We acknowledge that such local learning cannot be the complete picture behind the effectiveness of retrieval-based models. As noted in Zakai & Ritov (2008), there always exists a model with a global component that is "preferable" to a local-only model. In Sec. 3.2, we extend local ERM to a two-stage setup: first learn a global representation using the entire dataset, then utilize this representation at test time while solving the local ERM as previously defined. This enables local learning to benefit from good-quality global representations, especially in sparse data regions.
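The three-step local ERM procedure can be sketched as follows. This is an illustrative instance with our own choices of local function class and loss (affine functions fit by ridge-regularized least squares), not the paper's analyzed setting; the test task is a globally nonlinear (XOR-like) problem that is linearly separable within any small neighborhood, matching the local-regularity intuition above.

```python
import numpy as np

def local_erm_predict(train_X, train_y, x, k=10, reg=1e-3):
    """Local ERM: retrieve the k nearest neighbors of x, fit a simple
    (affine) model on them by ridge-regularized least squares, apply at x."""
    # (1) retrieval step
    nn = np.argsort(np.linalg.norm(train_X - x, axis=1))[:k]
    X_loc, y_loc = train_X[nn], train_y[nn]
    # (2) local ERM over a low-complexity class: affine functions
    A = np.hstack([X_loc, np.ones((k, 1))])  # append a bias feature
    w = np.linalg.solve(A.T @ A + reg * np.eye(A.shape[1]), A.T @ y_loc)
    # (3) predict at the test input; threshold to get a class label
    score = np.append(x, 1.0) @ w
    return int(score > 0.5)

# Globally complex task (XOR-like labels) that is locally simple:
# within a small neighborhood, a single affine function fits the labels.
rng = np.random.default_rng(1)
train_X = rng.uniform(-1, 1, (400, 2))
train_y = ((train_X[:, 0] > 0) ^ (train_X[:, 1] > 0)).astype(float)

pred = local_erm_predict(train_X, train_y, np.array([0.8, -0.8]))
```

No single global affine model can fit XOR-like labels, yet the local ERM predictor succeeds with the same low-complexity class, which is exactly the trade-off the bounds in Sec. 3 quantify.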
Finally, we move beyond explicit local learning to a setting that resembles more closely the empirically successful systems such as REINA, WebGPT, and AlphaFold: A model that directly learns to predict from the input instance and associated retrieved similar examples end-to-end. Towards this, we take a preliminary step in Sec. 4 by studying a novel formulation of classification over an extended feature space (to account for the retrieved examples) by using kernel methods (Deshmukh et al., 2019) .
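One hypothetical way to define a kernel over the extended feature space is to combine an instance kernel with a set kernel on the retrieved examples; the product construction and the mean-embedding set kernel below are our own illustrative choices and need not match the formulation analyzed in Sec. 4.

```python
import numpy as np

def rbf(X, Z, gamma=1.0):
    """Pairwise RBF kernel matrix between rows of X and rows of Z."""
    d2 = ((X[:, None, :] - Z[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * d2)

def extended_kernel(x_pairs, z_pairs, gamma=1.0):
    """Kernel on extended inputs (x, R_x): an RBF kernel on the instances
    times a mean-embedding (average pairwise RBF) kernel on the retrieved
    sets. A product of kernels is itself a valid positive-definite kernel."""
    K = np.zeros((len(x_pairs), len(z_pairs)))
    for i, (x, Rx) in enumerate(x_pairs):
        for j, (z, Rz) in enumerate(z_pairs):
            k_inst = rbf(x[None], z[None], gamma)[0, 0]
            k_set = rbf(Rx, Rz, gamma).mean()
            K[i, j] = k_inst * k_set
    return K

# Four extended inputs: each is (instance in R^2, set of 3 retrieved points).
rng = np.random.default_rng(2)
pairs = [(rng.normal(size=2), rng.normal(size=(3, 2))) for _ in range(4)]
K = extended_kernel(pairs, pairs)  # symmetric Gram matrix over (x, R_x) pairs
```

Given such a Gram matrix, any standard kernel machine (e.g., kernel ridge or SVM) can be trained end-to-end on (input, retrieved set) pairs without an explicit local learning step.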



Figure 1: An illustration of a retrieval-based classification model. Given an input instance x, similar to an instance-based model, it retrieves similar (labeled) examples R_x = {(x_j, y_j)}_j from the training data. Subsequently, it processes the input instance along with the retrieved examples (potentially via a nonparametric method) to make the final prediction ŷ = f(x, R_x).

