WHAT LEARNING ALGORITHM IS IN-CONTEXT LEARNING? INVESTIGATIONS WITH LINEAR MODELS

Abstract

Neural sequence models, especially transformers, exhibit a remarkable capacity for in-context learning. They can construct new predictors from sequences of labeled examples (x, f(x)) presented in the input without further parameter updates. We investigate the hypothesis that transformer-based in-context learners implement standard learning algorithms implicitly, by encoding smaller models in their activations, and updating these implicit models as new examples appear in the context. Using linear regression as a prototypical problem, we offer three sources of evidence for this hypothesis. First, we prove by construction that transformers can implement learning algorithms for linear models based on gradient descent and closed-form ridge regression. Second, we show that trained in-context learners closely match the predictors computed by gradient descent, ridge regression, and exact least-squares regression, transitioning between different predictors as transformer depth and dataset noise vary, and converging to Bayesian estimators for large widths and depths. Third, we present preliminary evidence that in-context learners share algorithmic features with these predictors: learners' late layers non-linearly encode weight vectors and moment matrices. These results suggest that in-context learning is understandable in algorithmic terms, and that (at least in the linear case) learners may rediscover standard estimation algorithms.

1. INTRODUCTION

One of the most surprising behaviors observed in large neural sequence models is in-context learning (ICL; Brown et al., 2020). When trained appropriately, models can map from sequences of (x, f(x)) pairs to accurate predictions f(x′) on novel inputs x′. This behavior occurs both in models trained on collections of few-shot learning problems (Chen et al., 2022; Min et al., 2022) and, surprisingly, in large language models trained on open-domain text (Brown et al., 2020; Zhang et al., 2022; Chowdhery et al., 2022). ICL requires a model to implicitly construct a map from in-context examples to a predictor without any updates to the model's parameters themselves. How can a neural network with fixed parameters learn a new function from a new dataset on the fly?

This paper investigates the hypothesis that some instances of ICL can be understood as implicit implementation of known learning algorithms: in-context learners encode an implicit, context-dependent model in their hidden activations, and train this model on in-context examples in the course of computing these internal activations. As in recent investigations of empirical properties of ICL (Garg et al., 2022; Xie et al., 2022), we study the behavior of transformer-based predictors (Vaswani et al., 2017) on a restricted class of learning problems, here linear regression. Unlike in past work, our goal is not to understand what functions ICL can learn, but how it learns these functions: the specific inductive biases and algorithmic properties of transformer-based ICL.

In Section 3, we investigate theoretically what learning algorithms transformer decoders can implement. We prove by construction that they require only a modest number of layers and hidden units to train linear models: for d-dimensional regression problems, with O(d) hidden size and constant depth, a transformer can implement a single step of gradient descent; and with O(d^2) hidden size and constant depth, a transformer can update a ridge regression solution to include a single new observation. Intuitively, n steps of these algorithms can be implemented with n times more layers.

In Section 4, we investigate empirical properties of trained in-context learners. We begin by constructing linear regression problems in which learner behavior is under-determined by training data (so different valid learning rules will give different predictions on held-out data). We show that model predictions are closely matched by existing predictors (including those studied in Section 3), and that they transition between different predictors as model depth and training-set noise vary, behaving like Bayesian predictors at large hidden sizes and depths.

Finally, in Section 5, we present preliminary experiments showing how model predictions are computed algorithmically. We show that important intermediate quantities computed by learning algorithms for linear models, including parameter vectors and moment matrices, can be decoded from in-context learners' hidden activations.

A complete characterization of which learning algorithms are (or could be) implemented by deep networks has the potential to improve both our theoretical understanding of their capabilities and limitations, and our empirical understanding of how best to train them. This paper offers first steps toward such a characterization: some in-context learning appears to involve familiar algorithms, discovered and implemented by transformers from sequence modeling tasks alone.

Correspondence to akyurek@mit.edu. Ekin is a student at MIT and began this work while an intern at Google Research. Code and reference implementations are released at this web page. This work was done while Tengyu Ma was a visiting researcher at Google Research.
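For reference, the two learning rules named above, a single gradient-descent step on the squared loss and the closed-form ridge estimator, can be written in a few lines of numpy. This is a sketch of the algorithms themselves on a noiseless toy problem, not of their transformer implementations; the learning rate, step count, and regularizer are illustrative choices:

```python
import numpy as np

def gd_step(w, X, y, lr=0.01):
    """One gradient-descent step on the squared loss 0.5 * ||Xw - y||^2."""
    grad = X.T @ (X @ w - y)
    return w - lr * grad

def ridge(X, y, lam=1.0):
    """Closed-form ridge regression: w = (X^T X + lam * I)^{-1} X^T y."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

# Noiseless toy problem: recover w_true from (X, y) pairs.
rng = np.random.default_rng(0)
d, n = 4, 32
w_true = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_true

w = np.zeros(d)
for _ in range(500):          # iterated GD converges to the least-squares solution
    w = gd_step(w, X, y, lr=0.01)

w_ridge = ridge(X, y, lam=1e-6)  # small lam: close to exact least squares
```

With noiseless data, both estimates approach w_true; Section 4's comparisons hinge on how these predictors diverge once noise and under-determination are introduced.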

2. PRELIMINARIES

Training a machine learning model involves many decisions, including the choice of model architecture, loss function, and learning rule. Since the earliest days of the field, research has sought to understand whether these modeling decisions can be automated using the tools of machine learning itself. Such "meta-learning" approaches typically treat learning as a bi-level optimization problem (Schmidhuber et al., 1996; Andrychowicz et al., 2016; Finn et al., 2017): they define "inner" and "outer" models and learning procedures, then train an outer model to set parameters for an inner procedure (e.g. an initializer or step size) to maximize inner-model performance across tasks.

Recently, a more flexible family of approaches has gained popularity. In in-context learning (ICL), meta-learning is reduced to ordinary supervised learning: a large sequence model (typically implemented as a transformer network) is trained to map from sequences [x_1, f(x_1), x_2, f(x_2), ..., x_n] to predictions f(x_n) (Brown et al., 2020; Olsson et al., 2022; Laskin et al., 2022; Kirsch & Schmidhuber, 2021). ICL does not specify an explicit inner learning procedure; instead, this procedure exists only implicitly through the parameters of the sequence model.

ICL has shown impressive results on synthetic tasks and naturalistic language and vision problems (Garg et al., 2022; Min et al., 2022; Zhou et al., 2022). Past work has characterized what kinds of functions ICL can learn (Garg et al., 2022; Laskin et al., 2022) and the distributional properties of pretraining that can elicit in-context learning (Xie et al., 2021; Chan et al., 2022). But how ICL learns these functions has remained unclear. What learning algorithms (if any) are implementable by deep network models? Which algorithms are actually discovered in the course of training?
This paper takes first steps toward answering these questions, focusing on a widely used model architecture (the transformer) and an extremely well-understood class of learning problems (linear regression).
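Concretely, a single ICL training instance in this setting interleaves inputs and labels into one prompt sequence whose final label is held out as the prediction target. The following numpy sketch is a hypothetical data pipeline, not the paper's exact setup; in particular, padding the scalar label f(x_i) to the input dimension is one common tokenization choice among several:

```python
import numpy as np

def make_icl_sequence(d=8, n=16, noise=0.0, rng=None):
    """Sample one linear-regression ICL prompt [x_1, f(x_1), ..., x_n]
    and the held-out target f(x_n).

    Assumed generative process: w ~ N(0, I), x_i ~ N(0, I),
    f(x) = w . x + noise * eps.
    """
    rng = rng or np.random.default_rng()
    w = rng.normal(size=d)
    X = rng.normal(size=(n, d))
    y = X @ w + noise * rng.normal(size=n)
    # Interleave inputs and labels; labels are zero-padded to dimension d.
    prompt = []
    for i in range(n - 1):
        prompt.append(X[i])
        prompt.append(np.r_[y[i], np.zeros(d - 1)])
    prompt.append(X[-1])          # final input; its label is the target
    return np.stack(prompt), y[-1]

seq, target = make_icl_sequence(d=8, n=16, rng=np.random.default_rng(0))
# seq has 2*(n-1) + 1 = 31 tokens, each of dimension d = 8
```

Training the sequence model to predict `target` from `seq`, across many freshly sampled w, is what leaves the inner learning procedure implicit in the model's parameters.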

2.1. THE TRANSFORMER ARCHITECTURE

Transformers (Vaswani et al., 2017) are neural network models that map a sequence of input vectors x = [x_1, ..., x_n] to a sequence of output vectors y = [y_1, ..., y_n]. Each layer in a transformer maps a matrix H^(l) (interpreted as a sequence of vectors) to a sequence H^(l+1). To do so, a transformer layer processes each column h_i^(l) of H^(l) in parallel. Here, we are interested in autoregressive (or "decoder-only") transformer models in which each layer first computes a self-attention:

    a_i = Attention(h_i^(l); W^F, W^Q, W^K, W^V)    (1)
        = W^F [b_1, ..., b_m],                      (2)

where each b_j is the response of an "attention head":

    b_j = softmax((W_j^Q h_i)^T (W_j^K H_:i)) (W_j^V H_:i)^T.    (3)
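The attention computation in Eqs. (1)-(3) can be sketched in numpy for a single position i. This is an illustrative reconstruction rather than reference code: the head dimension, number of heads, and the per-position (unbatched) formulation are assumptions, and H_:i is passed explicitly as the causal prefix:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention_head(h_i, H_prefix, Wq, Wk, Wv):
    """Response b_j of one head at position i, attending over the prefix H_:i (Eq. 3)."""
    q = Wq @ h_i                              # query for position i, shape (k,)
    scores = softmax(q @ (Wk @ H_prefix.T))   # attention weights over prefix tokens
    return (Wv @ H_prefix.T) @ scores         # attention-weighted sum of value vectors

def self_attention(h_i, H_prefix, Wf, heads):
    """a_i = W^F [b_1, ..., b_m]: concatenate m head responses, then project (Eqs. 1-2)."""
    bs = [attention_head(h_i, H_prefix, Wq, Wk, Wv) for (Wq, Wk, Wv) in heads]
    return Wf @ np.concatenate(bs)

# Toy usage with hypothetical sizes: d = 4, head dim k = 2, m = 3 heads, prefix length i = 5.
rng = np.random.default_rng(0)
heads = [tuple(rng.normal(size=(2, 4)) for _ in range(3)) for _ in range(3)]
Wf = rng.normal(size=(4, 3 * 2))
H = rng.normal(size=(5, 4))
a_i = self_attention(H[-1], H, Wf, heads)     # output has dimension d = 4
```

Restricting the attended-over matrix to the prefix H_:i is what makes the layer causal, which is what lets later positions act as predictors conditioned on earlier (x, f(x)) pairs.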

