GENERAL-PURPOSE IN-CONTEXT LEARNING BY META-LEARNING TRANSFORMERS

Abstract

Modern machine learning requires system designers to specify aspects of the learning pipeline, such as losses, architectures, and optimizers. Meta-learning, or learning-to-learn, instead aims to learn those aspects, and promises to unlock greater capabilities with less manual effort. One particularly ambitious goal of meta-learning is to train general-purpose learning algorithms from scratch, using only black-box models with minimal inductive bias. Such a model takes in training data and produces test-set predictions across a wide range of problems, without any explicit definition of an inference model, training loss, or optimization algorithm. In this paper we show that Transformers and other black-box models can be meta-trained to act as general-purpose in-context learners. We characterize phase transitions between algorithms that generalize, algorithms that memorize, and algorithms that fail to meta-train at all, induced by changes in model size, number of tasks, and meta-optimization. We further show that the capabilities of meta-trained algorithms are bottlenecked by the accessible state size (memory) determining the next prediction, unlike standard models, which are thought to be bottlenecked by parameter count. Finally, we propose practical interventions, such as biasing the training distribution, that improve the meta-training and meta-generalization of general-purpose learning algorithms.

1. INTRODUCTION

Meta-learning is the process of automatically discovering new learning algorithms instead of designing them manually (Schmidhuber, 1987). An important quality of human-engineered learning algorithms is their applicability to a wide range of tasks or environments. For learning-to-learn to exceed those capabilities, the meta-learned learning algorithms must be similarly general-purpose. Recently, there has been significant progress toward this goal (Kirsch et al., 2019; Oh et al., 2020). The improved generality of the discovered learning algorithms has been achieved by introducing inductive biases, such as bottlenecking the architecture or hiding information, which encourage learning over memorization. Methods include restricting learning rules to use gradients (Metz et al., 2019; Kirsch et al., 2019; Oh et al., 2020), symbolic graphs (Real et al., 2020; Co-Reyes et al., 2021), or parameter sharing or symmetries (Kirsch & Schmidhuber, 2020; Kirsch et al., 2021). While enabling generalization, these inductive biases increase the effort required to design these systems and potentially restrict the space of discoverable learning algorithms. Instead, we seek to explore general-purpose meta-learning systems with minimal inductive bias. Good candidates for this are black-box sequence models as meta-learners, such as LSTMs (Hochreiter et al., 2001; Wang et al., 2016; Duan et al., 2016) or Transformers (Vaswani et al., 2017). These memory-based or in-context learners take in training data and produce test-set predictions without any explicit definition of an inference model, training loss, or optimization algorithm. This has led to strong few-shot learning ability within the context of, for example, language modeling (Brown et al., 2020). In this work, we investigate how such black-box meta-learners can be trained to (meta-)generalize and learn on datasets significantly different from those used during meta-training.
For this, we propose a Transformer-based General-Purpose In-Context Learner (GPICL), which is described together with an associated meta-training task distribution in Section 3. In Section 4.1, we characterize transitions between memorization, learning, and generalization that are induced by scaling the number of tasks or the model size used for meta-training. We further show in Section 4.2 that the capabilities of meta-trained algorithms are bottlenecked by their accessible state size determining the next prediction (such as the hidden state size in a recurrent network), unlike standard models, which are thought to be bottlenecked by parameter count. Finally, in Section 4.3, we propose practical interventions that improve the meta-training of general-purpose learning algorithms.

2. BACKGROUND

What is a (supervised) learning algorithm? In this paper, we focus on the setting of meta-learning supervised learning algorithms. Consider a mapping

$$\left(\{x_i, y_i\}_{i=1}^{N_D},\; x'\right) \mapsto y' \qquad (1)$$

from the training (support) set $D = \{x_i, y_i\}_{i=1}^{N_D}$ and a query input $x'$ to the query's prediction $y'$, where $x_i, x' \in \mathbb{R}^{N_x}$, $y_i, y' \in \mathbb{R}^{N_y}$, and $N_D, N_x, N_y \in \mathbb{N}^+$. The subset of these functions that qualify as learning algorithms are those that improve their predictions $y'$ as the training set $D$ grows. Meta-learning then corresponds to finding these functions via meta-optimization. As in other black-box meta-learning models, we use a neural network to represent such functions.

What is a general-purpose learning algorithm? A learning algorithm can be considered general-purpose if it learns on a wide range of possible tasks $D$ and their respective related queries $x', y'$. For example, gradient descent on a suitable loss function can be considered a very general-purpose human-engineered learning algorithm (where the gradient is obtained via backpropagation or other means).

… black-box model in the form of input and output permutation invariances. An alternative to this is the generation of new tasks (Schmidhuber, 2013; Clune, 2019; Such et al., 2020; Parker-Holder et al., 2022). Unfortunately, it is not easy to generate a wide range of tasks that are both diverse and contain structure as found in the real world. In this work, we take an intermediate step by augmenting existing datasets, in effect increasing the breadth of the task distribution based on existing task regularities.
We generate a large number of tasks by taking existing supervised learning datasets, randomly projecting their inputs, and permuting their classification labels. While the random projection removes spatial structure from the inputs, this structure is not believed to be central to the task (for instance, the performance of SGD-trained fully connected networks is invariant to projection by a random orthogonal matrix (Wadia et al., 2021)). Task augmentation allows us to investigate fundamental questions about learning-to-learn in the regime of many tasks without relying on huge amounts of existing tasks or elaborate schemes to generate those. A task or dataset $D$ is then defined by its corresponding base dataset $\bar{D} = \{\bar{x}_i, \bar{y}_i\}$, (linear) projection $A \in \mathbb{R}^{N_x \times N_x}$, with $A_{ij} \sim \mathcal{N}\left(0, \frac{1}{N_x}\right)$, and output permutation $\rho$, such that $D = \{A\bar{x}_i, \rho(\bar{y}_i)\}$. Unless noted otherwise, the distribution over output permutations $p(\rho)$ is uniform.

Figure 1: Our General-Purpose In-Context Learner (GPICL) is based on the vanilla Transformer, which is trained to make predictions for queries $x'$ given any prefix of a dataset $D := \{x_i, y_i\}_{i=1}^{N_D}$, as in Equation 2.
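The task augmentation above can be sketched in a few lines. This is a minimal illustration assuming NumPy and flattened inputs (for MNIST, $N_x = 784$), not the authors' implementation; the function name make_task is hypothetical. Note that $\mathcal{N}(0, 1/N_x)$ specifies a variance, so the standard deviation passed to the sampler is $\sqrt{1/N_x}$.

```python
import numpy as np


def make_task(x_base, y_base, num_classes, rng):
    """Create one augmented task D = {A x_i, rho(y_i)} from a base dataset.

    x_base: [N, N_x] flattened inputs; y_base: [N] integer class labels.
    Hypothetical helper for illustration only.
    """
    n_x = x_base.shape[1]
    # A_ij ~ N(0, 1/N_x): variance 1/N_x, hence standard deviation sqrt(1/N_x).
    A = rng.normal(0.0, np.sqrt(1.0 / n_x), size=(n_x, n_x))
    # rho: a uniformly random permutation of the class labels.
    rho = rng.permutation(num_classes)
    x_task = x_base @ A.T  # row i equals A @ x_i
    y_task = rho[y_base]   # every example of class c is relabeled to rho[c]
    return x_task, y_task, A, rho


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 784))
y = rng.integers(0, 10, size=100)
x_task, y_task, A, rho = make_task(x, y, 10, rng)
```

Calling make_task K times with independent random draws and sampling uniformly among the results gives the finite task distribution p(D) used in Algorithm 1.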

3.2. META-LEARNING

Given those generated tasks, we then meta-train jointly on mini-batches sampled from the whole distribution. We minimize $J(\theta)$, the sum of losses on the query prediction after observing any prefix of a dataset $D$ sampled from the augmented task distribution $p(D)$:

$$J(\theta) = \mathbb{E}_{D \sim p(D)}\left[\sum_{j=1}^{N_D} l\left(f_\theta(D_{1:j-1}, x_j),\, y_j\right)\right], \qquad (2)$$

where, in the classification setting, $l$ is the cross-entropy loss between the label $y_j$ and the prediction $y' = f_\theta(D_{1:j-1}, x_j)$, and $f_\theta$ is a neural network mapping to predictions $y'$ as in Equation 1. During meta-training, we take gradient steps on $J(\theta)$ using backpropagation and Adam (Kingma & Ba, 2014). To investigate the effect of the data distribution, we train on various numbers of tasks (Algorithm 1). Finally, we need to choose a black-box model for the function $f_\theta$. We use a vanilla Transformer (Vaswani et al., 2017) with learned positional embeddings, visualized in Figure 1. We call it the General-Purpose In-Context Learner (GPICL). Each token corresponds to a concatenated and transformed input $x_i$ and one-hot encoded label $y_{i-1}$, and the model predicts the corresponding logits $y' = y_i$ for the current input $x' = x_i$. When querying for the first input $x_1$, no label for the previous input is available, so we feed a zero vector.

Algorithm 1: Meta-Training for General-Purpose In-Context Learning (GPICL)
Require: base dataset $\bar{D} = \{\bar{x}_i, \bar{y}_i\}$, number of tasks $K \in \mathbb{N}^+$
  $\{A^{(k)}_{ij}\}_{k=1}^{K} \sim \mathcal{N}\left(0, \frac{1}{N_x}\right)$    (sample input projections)
  $\{\rho^{(k)}\}_{k=1}^{K} \sim p(\rho)$    (sample output permutations)
  $D^{(k)} = \{A^{(k)}\bar{x}_i, \rho^{(k)}(\bar{y}_i)\}$
  $p(D) := \mathrm{Uniform}\left[\{D^{(k)}\}_{k=1}^{K}\right]$
  while not converged do
    $\theta \leftarrow \theta - \alpha \nabla_\theta J(\theta)$    (meta-train across tasks $p(D)$, Equation 2)
  end while

Meta-testing. At meta-test time, no gradient-based learning is used. Instead, we simply obtain a prediction $y'$ by evaluating the neural network $f_\theta$ on the training dataset $D$ and query point $x'$.
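The token layout and the prefix loss in Equation 2 can be sketched as follows. This assumes NumPy, integer class labels, and a causally masked sequence model (not shown) that emits one row of logits per token; build_tokens and prefix_loss are illustrative names, and the learned transformation that GPICL applies to each concatenated token before the Transformer is omitted.

```python
import numpy as np


def build_tokens(x, y, num_classes):
    """Token j = [x_j, onehot(y_{j-1})]; the first token carries a zero label vector."""
    n = x.shape[0]
    prev_labels = np.zeros((n, num_classes))
    prev_labels[1:] = np.eye(num_classes)[y[:-1]]  # shift labels by one position
    return np.concatenate([x, prev_labels], axis=-1)


def prefix_loss(logits, y):
    """Sum over j of cross-entropy(f(D_{1:j-1}, x_j), y_j) for one sequence.

    logits: [N_D, N_y]; row j is assumed to come from a causal model that
    has seen only tokens 1..j, i.e. D_{1:j-1} and x_j.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(y)), y].sum()
```

With causal masking, a single forward pass scores every prefix in the sum at once; averaging prefix_loss over a batch of sequences drawn from p(D) gives a Monte Carlo estimate of $J(\theta)$.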

4. EXPERIMENTS ON THE EMERGENCE OF GENERAL LEARNING-TO-LEARN

Multi-task training with standard classifiers. Given a task distribution of many different classification tasks, we first ask under what conditions we expect "learning-to-learn" to emerge. We train a single model across many tasks, where each task corresponds to a random transformation of the MNIST dataset, but where an MLP receives only a single datapoint instead of a whole sequence as input. This corresponds to $N_D = 1$ in Equation 2. We would expect such a non-sequential classifier to be able to correctly predict for more tasks as its number of parameters increases. When plotting the network capacity against the number of tasks, we indeed observe a linear boundary where an increasing number of tasks can be fit the larger the network (Figure 2a). This is consistent with results in Collins et al. (2016), which found that a constant number of bits about the data distribution can be stored per model parameter, across a variety of model architectures and scales.

Figure 2 (partial caption): A sequence model (here the GPICL Transformer) that observes a dataset $D$ of inputs and labels transitions into generalizing to a seemingly unbounded number of tasks with an increase in model size. This is achieved by switching from a memorization solution to a learning solution that (c) generalizes to unseen tasks. This generalization does not occur with the MLP.

Learning-to-learn with large sequential models and data. In contrast to the MLP classifier, a sequence model that observes multiple observations and their labels from the same task could exceed that linear performance improvement by learning at inference time. Indeed, we observe that when switching to a Transformer that can observe a sequence of datapoints before making a prediction about the query, more tasks can be simultaneously fit (Figure 2b). At a certain model size and number of tasks, the model undergoes a phase transition, allowing it to generalize to a seemingly unbounded number of tasks.
We hypothesize that this is due to switching the prediction strategy from memorization to learning-to-learn. Further, when (meta-)testing the same trained models from the previous experiment on an unseen task (a new random transformation of MNIST), they generalize only in the regime of large numbers of tasks and large model sizes (Figure 2c). Because GPICL is an in-context learner, meta-testing involves no gradient updates, only running the model in forward mode.

Insight 1: It is possible to learn-to-learn with black-box models. Effective learning algorithms can be realized using black-box models with few inductive biases, given sufficient meta-training task diversity and large enough model sizes. To transition to the learning-to-learn regime, we needed at least $2^{13} = 8192$ tasks.

In the following, we study learning-to-learn from the perspective of the data distribution, the architecture, and the optimization dynamics. For the data distribution, we look at how data diversity affects the emergence and phase transitions of learning-to-learn, generalization, and memorization. For the architecture, we analyze the role of the model size and state size in various architectures. Finally, we observe challenges in meta-optimization and demonstrate how memorization followed by generalization is an important mechanism that can be facilitated by explicitly biasing the data distribution.

4.1. THE LARGE DATA REGIME: GENERALIZATION AND PHASE TRANSITIONS

Simple invariances in data lead to the emergence of learning-to-learn. To verify whether the observed generalizing solutions actually implement learning algorithms (as opposed to, e.g., zero-shot …

Figure 3: GPICL learns from examples at test time, and generalizes to unseen tasks and datasets. We meta-trained the Transformer on a set of tasks defined by random transformations of either MNIST (blue) or FashionMNIST (orange). We then meta-test on unseen tasks, and seen (ab) or unseen (ba) datasets. The plot shows the accuracy averaged across multiple runs at each inner step, with shading indicating 95% confidence intervals. The increase in performance at each step suggests we have discovered a learning algorithm.

Generalization. Naturally, the question arises to what extent these learning algorithms are general. While we have seen generalization to unseen tasks consisting of novel projections of the same dataset, do the learned algorithms also generalize to unseen datasets? In Figure 3, we observe out-of-distribution performance on Fashion MNIST after having trained on MNIST (b, blue). In this direction, there is no generalization gap compared to directly training on Fashion MNIST (b, orange). Similarly, when meta-training on Fashion MNIST and meta-testing on MNIST (a, orange), we observe that the learning algorithm generalizes, albeit with a larger generalization gap.

Comparison to other methods. Other datasets and baselines are shown in Table 1. Rather than focusing on state-of-the-art performance, we aim to validate whether methods with less inductive bias (such as our GPICL) can compete with methods that include more biases suited to learning-to-learn. These include stochastic gradient descent (SGD), which updates the parameters online after observing each datapoint, and MAML (Finn et al., 2017), which proceeds like SGD but uses a meta-learned neural network initialization. Both methods, which rely on backpropagation and gradient descent, learn more slowly than our Transformer.
In the case of MAML, this may be due to its main mechanism being feature reuse (Raghu et al., 2020), which is less useful when training across our wider task distribution. Among methods that do not hard-code gradient descent at meta-test time, we test VSML (Kirsch & Schmidhuber, 2020), which discovered learning algorithms that generalize significantly between tasks. Our GPICL comes surprisingly close to VSML without requiring the associated inductive bias. Finally, we compare to a standard LSTM that is trained with the same inputs as our Transformer. We observe that it performs worse, which we investigate further in Section 4.2.

Insight 2: Simple data augmentations are effective for learning-to-learn. The generality of the discovered learning algorithm can be controlled via the data distribution. Even when large task distributions are not (yet) naturally available, simple augmentations that promote permutation and scale invariance are effective.

Figure 4 (partial caption): (2) Tasks are memorized, which is evident as a within-sequence increase of performance. (3) When training across many tasks, we discover a learning algorithm that generalizes to unseen tasks and unseen datasets.

Transitioning from memorization to learning to generalizing. When do the found solutions correspond to memorizing, learning, and generalizing solutions? In Figure 4, we plot the accuracy difference between the last and first prediction for a seen task, an unseen task, and an unseen task with a different base dataset. We observe three phases. In the first phase, we memorize each instance, resulting in no within-sequence performance improvement. In the second phase, we memorize tasks and learn to identify tasks, resulting in a within-sequence improvement confined to seen task instances. In the third and final phase, we observe more general learning-to-learn: a performance improvement for unseen tasks, even with different base datasets (here FashionMNIST).
The last transition is abrupt: separate meta-training runs find either a task-memorization solution or a general learning-to-learn solution (see Appendix A.1).

Insight 3: The meta-learned behavior has phase transitions. When increasing the number of tasks, the meta-learned behavior transitions from instance memorization, to task identification, to general learning-to-learn.
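The quantity plotted in Figure 4, the accuracy difference between the last and first prediction within a sequence, can be computed as below. This sketch assumes a boolean correctness array collected by running the meta-trained model in forward mode on a batch of sequences from one task type (seen, unseen, or unseen with a different base dataset); the function name within_sequence_gain is hypothetical. A value near zero on seen tasks matches the instance-memorization phase, and a positive value on unseen tasks indicates learning-to-learn.

```python
import numpy as np


def within_sequence_gain(correct):
    """Accuracy at the last query minus accuracy at the first query.

    correct: boolean array [num_sequences, sequence_length]; entry (s, j) is
    True if the prediction for the j-th query of sequence s was correct.
    """
    acc_per_step = correct.mean(axis=0)  # accuracy at each inner step
    return acc_per_step[-1] - acc_per_step[0]


# Two toy sequences whose accuracy rises from 0.0 to 1.0 over three steps.
correct = np.array([[0, 1, 1], [0, 0, 1]], dtype=bool)
print(within_sequence_gain(correct))  # 1.0
```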

4.2. ARCHITECTURE: A LARGE STATE IS CRUCIAL FOR LEARNING

In the previous experiments, we observed that given sufficient task diversity and model size, Transformers can learn general-purpose learning algorithms. This raises the question of how essential the Transformer architecture is and whether other black-box models could be used. We hypothesize that for learning-to-learn, the size of the memory at meta-test time (or, more generally, the state) is particularly important in order to store learning progress. Through self-attention, Transformers have a particularly large state. We test this by training several architectures with various state sizes in our meta-learning setting. In Figure 5a, we vary the respective hyper-parameters that most influence the state size and observe that, for a given state size, the discovered learning algorithms perform similarly across architectures. In contrast, these architectures have markedly different numbers of parameters (Figure 5b).

Insight 4: Large state is more crucial than parameter count. This suggests that the model size in terms of parameter count plays a smaller role in the setting of learning-to-learn, and that Transformers have benefited in particular from the increase in state size provided by self-attention. Beyond learning-to-learn, this likely applies to other tasks that rely on storing large amounts of sequence-specific information.

4.3. CHALLENGES IN META-OPTIMIZATION

Meta-optimization is known to be challenging. Works using meta-gradients (Finn et al., 2017; Xu et al., 2018; Bechtle et al., 2021) and works with parameter sharing or weight updates in their architecture (Kirsch & Schmidhuber, 2020; Pedersen & Risi, 2021; Risi, 2021) have observed various difficulties: slower convergence, local minima, unstable training, or loss plateaus at the beginning of training (see Appendix Figure 18). We show that some of these problems also occur with black-box models and propose effective interventions.

Loss plateaus when meta-learning with black-box models. When training across a large number of randomly transformed tasks, memorizing any task-specific information becomes difficult. Instead, the model is forced to find solutions that learn directly. We observe that this results in (meta-)loss plateaus during meta-training, where the loss decreases only slightly for long periods of time (Figure 6a). Only after a large number of steps (here around 35 thousand) does the loss drop. In the loss plateau, the generalization loss increases on unseen tasks from both the same and a different base dataset (Figure 6b). This suggests that first being able to memorize slightly enables the subsequent learning-to-learn phase. Furthermore, we observe that all gradients have a very small norm, with the exception of those of the last layer (Appendix Figure 14).

Intervention 1: Increasing the batch size. High-variance gradients appear to be one reason training trajectories become trapped on the loss plateau (see Appendix Figures 12, 13). This suggests increasing the meta-batch size as a straightforward solution. When plotting various batch sizes against numbers of tasks, we obtain three kinds of solutions at the end of meta-training (Figure 7a): (1) solutions that generalize and learn, (2) solutions that memorize, and (3) solutions that are still in the loss plateau (due to the maximum of 50 thousand optimization steps).
The larger the batch size, the more tasks we can train on without getting stuck in a loss plateau. When plotting the length of the loss plateau against the task batch size (Figure 7b), we observe a power-law relationship, with increasing batch sizes decreasing the plateau length. At the same time, a larger batch size also increases the total number of tasks seen in the plateau (Appendix Figure 15). Thus, this intervention relies on parallelizability. An increase in the number of tasks also increases the plateau length (Figure 7c). This may be due to a larger number of tasks making the initial memorization phase more difficult.

Intervention 2: Changes in the meta-optimizer. Given that many gradients in the loss plateau have very small norm, Adam would rescale those element-wise, potentially alleviating the issue. In practice, we observe that the gradients are so small that the $\epsilon$ in Adam's gradient-rescaling denominator (added for numerical stability) limits the up-scaling of small gradients. Using a smaller $\epsilon$ more than halves the plateau length. Alternatively, discarding the magnitude of the gradient entirely by applying the sign operator to an exponential moving average of the gradient (replacing Adam's approximate magnitude normalization with direct magnitude normalization) has a similar effect, while also increasing numerical stability compared to Adam with a small $\epsilon$ (Appendix Figure 16).

Intervention 3: Biasing the data distribution / Curricula. GPICL mainly relies on the data distribution for learning-to-learn. This enables a different kind of intervention: biasing the data distribution. The approach is inspired by the observation that before leaving the loss plateau, the model memorizes biases in the data. Instead of sampling label permutations uniformly at random, we bias towards a specific permutation by using a fixed permutation for a fraction of each batch. This completely eliminates the loss plateau, enabling a smooth path from memorizing to learning (Figure 8).
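The optimizer variants of Intervention 2 and the biased label sampling of Intervention 3 can be sketched as below. This is a minimal NumPy illustration, not the authors' code: adam_step follows the standard bias-corrected Adam update (Kingma & Ba, 2014), sign_ema_step is a sign-of-momentum update in the spirit of signSGD with momentum, and sample_permutations uses the identity as its fixed permutation purely as a hypothetical choice. All function names are illustrative.

```python
import numpy as np


def adam_step(grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One bias-corrected Adam update; returns (update, m, v). Params move by -update."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    # When |grad| << eps, eps dominates the denominator and caps the step size.
    update = lr * m_hat / (np.sqrt(v_hat) + eps)
    return update, m, v


def sign_ema_step(grad, m, lr=1e-3, b1=0.9):
    """Sign of an exponential moving average of the gradient; magnitude is discarded."""
    m = b1 * m + (1 - b1) * grad
    return lr * np.sign(m), m


def sample_permutations(batch_size, num_classes, bias_fraction, rng, fixed=None):
    """Use one fixed label permutation for a fraction of the batch, uniform otherwise."""
    if fixed is None:
        fixed = np.arange(num_classes)  # hypothetical choice of the fixed permutation
    n_fixed = int(round(bias_fraction * batch_size))
    perms = [fixed.copy() for _ in range(n_fixed)]
    perms += [rng.permutation(num_classes) for _ in range(batch_size - n_fixed)]
    return np.stack(perms)
```

For a gradient entry of magnitude $10^{-9}$, the first Adam step with the common default $\epsilon = 10^{-8}$ is only about lr/11, whereas $\epsilon = 10^{-16}$ gives nearly the full step lr, and the sign update always moves by exactly lr. A bias fraction of 0.9 would mirror the 90% biased permutation distribution mentioned for VSML in Appendix A.5.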
Surprisingly, even when heavily biasing the distribution, memorization is followed by generalization. This biased data distribution can be viewed as a curriculum: solving an easier problem first enables the subsequent, harder learning-to-learn. Further investigation is required to understand how this transition occurs. This may be connected to grokking (Power et al., 2022), which we investigate in Appendix A.7. We hypothesize that many natural data distributions, including language, contain such sub-tasks that are easy to memorize, followed by generalization.

We demonstrated the feasibility of meta-learning in-context learning algorithms that are general-purpose. An even more useful learning algorithm would be capable of both generalizing and leveraging domain-specific information for learning when it is available. This would allow for considerably more efficient in-context learning, scaling to more difficult datasets without very long input sequences. Toward this goal, we investigate a simple scheme that leverages pre-trained neural networks as features to learn upon. These features could come from an unsupervised learner or a frozen large language model (Radford et al., 2021; Tsimpoukelli et al., 2021). Here, we first project the inputs $\bar{x}_i$ of a base dataset $\bar{D}$ into some latent space using a pre-trained network, and then proceed with meta-training and meta-testing as before, randomly projecting these alternative features. For the pre-trained network, we use a ResNet trained on ImageNet and remove its final layer. In Figure 9, we have meta-trained GPICL on MNIST either with the randomly transformed raw inputs or with randomly transformed embedded features. At meta-test time, the learning algorithm generalizes to a wide range of datasets, measured by the meta-test accuracy on the 100th example. At the same time, the pre-trained ImageNet features help to accelerate learning on datasets that have a matching domain, such as CIFAR10.
We observe that with only 100 examples, the learning algorithm meta-trained on MNIST can achieve about 45% accuracy on CIFAR10.

5. RELATED WORK

Inductive biases in meta-learning. Meta-learning approaches exist with a wide range of inductive biases, usually inspired by existing human-engineered learning algorithms. Some methods pre-wire the entire learning algorithm (Finn et al., 2017), pre-wire backpropagation and the structure of a gradient-based optimizer (Andrychowicz et al., 2016; Metz et al., 2019; 2020a), or hard-code gradient-based optimization but learn the loss function (Houthooft et al., 2018; Kirsch et al., 2019; Bechtle et al., 2021). Many methods search over hyper-parameters that alter existing learning algorithms (Xu et al., 2018; Metz et al., 2020b; Chen et al., 2022). Fast weight programmers or hypernetworks update the weights of the same or another neural network (Schmidhuber, 1992; 1993a; Ha et al., 2017; Irie et al., 2021; Sandler et al., 2021; Kirsch & Schmidhuber, 2022; Zhmoginov et al., 2022). Our work aims to keep such inductive biases to a minimum.

General-purpose meta-learning

There has been growing interest in meta-learning more general-purpose learning algorithms. The improved generality of the discovered learning algorithms has been achieved by introducing inductive bias, such as by bottlenecking the architecture or by hiding information, encouraging learning over memorization. Methods include enforcing learning rules to use gradients (Metz et al., 2019; Kirsch et al., 2019; Oh et al., 2020), symbolic graphs (Real et al., 2020; Co-Reyes et al., 2021), or parameter sharing and symmetries (Kirsch & Schmidhuber, 2020; Kirsch et al., 2021). Parameter sharing and symmetries have also been discussed in the context of self-organization (Tang & Ha, 2021; Risi, 2021; Pedersen & Risi, 2022).

Black-box meta-learning: MetaRNNs, RL$^2$, in-context learning. In contrast to these inductive biases, neural networks can also learn-to-learn purely in their activations with few architectural and algorithmic biases (Hochreiter et al., 2001; Wang et al., 2016; Duan et al., 2016; Ortega et al., 2019). This requires a feedback signal in the inputs that allows for learning, such as the reward in reinforcement learning or the label in supervised learning (Schmidhuber, 1993b). While a frequently used architecture is the LSTM (Hochreiter & Schmidhuber, 1997; Gers et al., 2000), this mechanism has also seen substantial recent attention in Transformer models (Brown et al., 2020; Chan et al., 2022) under the name of in-context learning. We refer to these networks simply as black-box meta-learners. Our method GPICL is in the class of these black-box meta-learners. In contrast to previous methods, GPICL implements general-purpose learning algorithms. Independently, Garg et al. (2022) recently studied generalization on synthetic functions, in contrast to our augmented datasets. PFNs (Mikulik et al., 2020) demonstrated learning-to-learn on small tabular datasets when meta-training on synthetically generated problems.
Experiments on more complex classification settings such as Omniglot relied on fine-tuning. In comparison, we investigate the generalization of learning algorithms directly to datasets such as MNIST, Fashion MNIST, and CIFAR10.

6. DISCUSSION AND CONCLUSION

By generating tasks from existing datasets, we demonstrated that black-box models such as Transformers can be used to meta-learn general-purpose in-context learning algorithms (GPICL). We observed that learning-to-learn arises in the regime of large models and large numbers of tasks, with several phase transitions from instance memorization, to system identification, to general learning. Across neural network architectures, the size of the memory or model state strongly determines how well an architecture can learn to learn. We identified difficulties in meta-optimization and proposed interventions in terms of optimizers, hyper-parameters, and a biased data distribution acting as a curriculum. We believe our findings open up new possibilities of data-driven general-purpose meta-learning with minimal inductive bias.

A current limitation is the applicability of the discovered learning algorithms to arbitrary input and output sizes. Appropriate tokenization to unified representations may solve this (Chowdhery et al., 2022). Furthermore, learning algorithms often process millions of inputs before outputting the final model. In the current black-box setting, this is still difficult to achieve. Recurrence-based models usually suffer from accumulating errors, whereas the computational complexity of Transformers grows quadratically in the sequence length.

When do the found solutions correspond to memorizing vs. generalizing solutions? In Figure 2, we observe a fairly discrete transition between memorizing and generalizing solutions as a function of the number of tasks. To analyze this transition, we perform multiple training runs with varying seeds and numbers of tasks in Figure 10, reporting the final training loss. We find that the distribution is bimodal: solutions at the end of training are either memorizing or generalizing. Memorization cluster: The larger the number of tasks, the more difficult it is to memorize all of them with a fixed model capacity.
Generalization cluster: At a certain number of tasks (here 6 thousand), a transition point is reached where optimization sometimes discovers a lower training loss that corresponds to a generalizing solution. For larger numbers of tasks the solutions always settle in the generalizing cluster.

A.2 WHAT CORRESPONDS TO STATE (MEMORY) IN VARIOUS ARCHITECTURES?

We hypothesize that for learning-to-learn, the size of the memory $N_S$ at meta-test time (or, more generally, the state) is particularly important in order to store learning progress. We test this by training several architectures with various $N_S$ in our meta-learning setting. Memory in the context of recurrent neural networks corresponds to the hidden state or context vector of size $N_H$, thus $N_S \in O(N_H)$. More generally, we can describe the state as the information bottleneck that the sequence has to pass through before making predictions. In the context of learning-to-learn, this state has to hold information about everything that has been learned so far. Standard learning algorithms such as neural networks trained via SGD would have a state that corresponds to the neural network parameters, iteratively updated via SGD. In Transformers, self-attention allows for a particularly large state of $N_S \in O(N_K N_L N_T)$, where $N_K$ is the size of the key, value, and query, $N_L$ is the number of layers, and $N_T$ is the length of the sequence.
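The asymptotic state sizes above can be made concrete with a small calculator. It is a rough sketch under stated assumptions: the LSTM state is counted as hidden plus cell vector, the outer-product LSTM state follows the listing in Appendix A.5 (one size × size cell matrix per head plus the hidden read-out), and the Transformer state is counted as the cached keys and values of every layer and token, with n_k taken as the total key width over all heads. The constant factors and function names are illustrative; only the scaling matters for the argument.

```python
def lstm_state_size(n_h):
    """Standard LSTM: hidden vector plus cell vector, so N_S is O(N_H)."""
    return 2 * n_h


def outer_product_lstm_state_size(n_heads, size):
    """Outer-product LSTM: one size x size cell matrix per head plus the hidden read-out."""
    return n_heads * size * size + n_heads * size


def transformer_state_size(n_k, n_layers, n_tokens):
    """Transformer: cached keys and values for every layer and every token so far.

    Grows linearly with sequence length, matching O(N_K * N_L * N_T).
    """
    return 2 * n_k * n_layers * n_tokens


# Assumed reading of the Appendix A.5 defaults: 8 heads of size 32 (n_k = 256), 4 layers.
print(transformer_state_size(256, 4, 100))  # 204800 after 100 tokens
print(lstm_state_size(256))                 # 512
```

Under these assumptions, the Transformer's state after 100 tokens is hundreds of times larger than that of an LSTM with hidden size 256, and it keeps growing with every observed example.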

A.3 SUMMARY OF INSIGHTS

Insight 1: It is possible to learn-to-learn with black-box models. Effective learning algorithms can be realized using black-box models with few inductive biases, given sufficient meta-training task diversity and large enough model sizes. To transition to the learning-to-learn regime, we needed at least $2^{13} = 8192$ tasks.

Insight 2: Simple data augmentations are effective for general learning-to-learn. The generality of the discovered learning algorithm can be controlled via the data distribution. Even when large task distributions are not (yet) naturally available, simple augmentations that promote permutation and scale invariance are effective.

Insight 3: The meta-learned behavior has phase transitions. When increasing the number of tasks, the meta-learned behavior transitions from instance memorization, to task identification, to general learning-to-learn. The last transition is discrete, with two distinct clusters.

Insight 4: Large state is more crucial than parameter count. We conclude that the specific inductive biases of each architecture matter to a smaller degree; the driving factor behind their ability to learn how to learn is the size of their state. Furthermore, this suggests that the model size in terms of number of parameters plays a smaller role in the setting of learning-to-learn, and that Transformers have benefited in particular from the increase in state size provided by self-attention. In non-meta-learning sequence tasks, parameter count is thought to be the performance bottleneck (Collins et al., 2016). Beyond learning-to-learn, this likely applies to other tasks that rely on processing and storing large amounts of sequence-specific information.

A.4 LIMITATIONS

Varying input and output sizes Compared to some previous works in meta-learning (Andrychowicz et al., 2016; Finn et al., 2017; Kirsch & Schmidhuber, 2020), the discovered learning algorithms are not applicable to arbitrary input and output sizes, which makes it more difficult to apply them to new, unseen problems. The same problem arises for Transformers applied to multiple tasks and modalities. Related work has addressed this problem by tokenizing inputs into compatible, unified representations (Chowdhery et al., 2022). We expect these or other techniques to be useful in the learning-to-learn context too.

Processing large datasets Learning algorithms often process millions of inputs before outputting the final model. In the black-box setting, this is still difficult to achieve. Recurrent models usually suffer from accumulating errors, whereas the computational complexity of Transformers grows quadratically with the sequence length. Additional work is required to build models capable of processing and being trained on long sequences. Alternatively, parallel processing, similar to batching in learning algorithms, may be a useful building block.

A.5 ARCHITECTURAL DETAILS AND HYPER-PARAMETERS

Transformer details By default, all Transformers have a key, value, and query size of 32, 8 heads, 4 layers, and a model size of $N_M = 256$. The model size defines the dimensionality of each token, and the MLP between layers scales this size up to a hidden representation of size $4 \times N_M$.

Outer-product LSTM We slightly modify an LSTM by replacing the context state with an outer-product update and an inner-product read-out:

    import numpy as np
    import jax
    import jax.numpy as jnp
    import haiku as hk

    def split_axis(x, shape):
        # Assumed helper: split the last axis of x into `shape`.
        return x.reshape(x.shape[:-1] + shape)

    # `gated` has shape (batch, num_heads, 8 * size).
    # i = input, g = cell gate, f = forget gate,
    # q = query, o = output gate
    sizes = (3 * size, 3 * size, size, size)
    indices = np.cumsum(sizes[:-1])
    k1, k2, q, o = jnp.split(gated, indices, axis=-1)
    scale = jax.nn.softplus(
        hk.get_parameter('key_scale', shape=(), dtype=k1.dtype, init=jnp.zeros))
    # Outer products of the two key sets give (size x size) matrices for i, g, f.
    i, g, f = jnp.einsum('bhki,bhkj->kbhij',
                         jnp.tanh(split_axis(k1, (3, size))) * scale,
                         jnp.tanh(split_axis(k2, (3, size))))
    f = jax.nn.sigmoid(f + 1)  # Forget bias
    c = f * prev_state.cell + jax.nn.sigmoid(i) * g
    # Inner-product read-out of the matrix-valued cell state with query q.
    read = jnp.einsum('bhij,bhi->bhj', c, q)
    h = hk.Flatten()(jax.nn.sigmoid(o) * jnp.tanh(read))

VSML We use a version of VSML with a single layer and self-messages (Kirsch et al., 2021) of size 8. Each LSTM has a hidden size of 16. For each LSTM update we use two micro-ticks. We train on $2^{25}$ tasks with a 90% biased permutation distribution. The task batch size is 8. All images are scaled to a size of 32 × 32 × 3.

VSML without symmetries Before activations are fed to a standard instantiation of VSML, all inputs are projected using a learnable linear projection.
Logits are generated using another linear projection, followed by a softmax. We use a version of VSML with a single layer and self-messages (Kirsch et al., 2021) of size 8. The LSTMs are arranged in a k × k grid, where $k \in \{1, 2, 4, 8, 16, 24\}$. Each LSTM has a hidden size of 64. For each LSTM update we use two micro-ticks. We train on $2^{25}$ tasks with a 90% biased permutation distribution. The task batch size is 128. All images are scaled to a size of 14 × 14.

LSTM For the results in Table 1, we used a hidden size of 256 and $10^5$ optimization steps. Larger hidden sizes were harder to optimize. We train on $2^{25}$ tasks with a 90% biased permutation distribution. The task batch size is 128. All images are scaled to a size of 32 × 32 × 3.
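Several configurations above train with a "90% biased permutation distribution". A minimal sketch of such a sampler follows, assuming that a fixed fraction `bias` of each task batch uses one single favoured task (here, arbitrarily, task 0) while the rest is drawn uniformly from all tasks. This reading of "biased" and the name `sample_task_ids` are assumptions for illustration.

```python
import numpy as np


def sample_task_ids(num_tasks, batch_size, bias, rng):
    """Sample a batch of task indices from an assumed biased task distribution."""
    num_fixed = int(round(bias * batch_size))
    # Uniform draws over all tasks for the whole batch...
    ids = rng.integers(0, num_tasks, size=batch_size)
    # ...then overwrite a fixed fraction with the single favoured task.
    ids[:num_fixed] = 0
    # Shuffle so that batch position carries no information.
    return rng.permutation(ids)


rng = np.random.default_rng(0)
ids = sample_task_ids(num_tasks=2**25, batch_size=128, bias=0.9, rng=rng)
```

Each task index could then seed a task generator (for example, the projection and label permutation of a task); that coupling is likewise an assumption.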

A.6 EXPERIMENTAL DETAILS

Most experiments can be run on a single GPU, while some require 16 GPUs due to the sequence length and large batch sizes; each GPU needs sufficient memory (around 16 GB). Some experiments, such as Figure 2, require up to 1000 runs of that kind to produce the final heat-map.

Input normalization Each dataset is z-normalized by its mean and standard deviation across all examples and pixels.
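A minimal sketch of this dataset-level z-normalization, assuming the dataset is a NumPy array of shape (examples, height, width[, channels]). A single scalar mean and standard deviation are computed over every example and pixel, rather than per-pixel statistics; the small constant in the denominator is an added assumption that guards against constant inputs.

```python
import numpy as np


def z_normalize_dataset(images):
    """Z-normalize with one mean and one std over all examples and pixels."""
    images = images.astype(np.float64)
    mean = images.mean()   # scalar over the whole dataset
    std = images.std()     # scalar over the whole dataset
    return (images - mean) / (std + 1e-8)


rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(50, 28, 28))  # stand-in for raw 8-bit images
normed = z_normalize_dataset(fake)
```

Using scalar statistics preserves the relative intensities between pixels, whereas per-pixel normalization would rescale each pixel position independently.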

Number of seeds and shading

If not noted otherwise, line plots use 8 seeds for meta-training and at least 512 seeds for meta-testing. Shading indicates 95% confidence intervals.

Figure 2 The MLP has two hidden layers of varying size with ReLU activations. The Transformer has the default parameters as defined above.

Figure 3 We use a Transformer model with a model size of 256. We train on $2^{25}$ tasks with a 90% biased permutation distribution. The task batch size is 128. All images are scaled to a size of 32 × 32 × 3. Inputs are z-normalized across the dataset and all input dimensions.

Table 1 The SGD baseline was obtained by sweeping over learning rates from $10^{-4}$ to 0.5, the optimizers SGD, Adam, and Adam with weight decay, one or two layers, and hidden sizes of 32, 64, or 128 on MNIST. The best (most sample-efficient) configuration corresponds to a learning rate of $10^{-3}$, Adam, and no hidden layers. SGD performs online updates, one for each of the 100 data points. MAML is equivalent to SGD except that we meta-train the weight initialization according to Equation 2, where $\theta$ are the initial parameters of the classifier that is then updated using SGD at meta-test time. None of the black-box approaches use gradient descent at meta-test time. All meta-learning approaches were meta-trained and tuned via grid search on MNIST.
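The online baseline for Table 1 can be sketched as follows, assuming a linear softmax classifier (the "no hidden layers" configuration) trained with Adam at learning rate $10^{-3}$, one example per update, and scored on each example before it is used for training. The evaluation protocol and the function name are assumptions for illustration; this is not the code used to produce Table 1.

```python
import numpy as np


def online_adam_softmax(x, y, num_classes, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Online linear softmax classifier trained with Adam, one example per step.

    Returns per-step correctness of the prediction made *before* each update,
    plus the final parameters (weights with an appended bias row).
    """
    d = x.shape[1]
    params = np.zeros((d + 1, num_classes))
    m = np.zeros_like(params)
    v = np.zeros_like(params)
    correct = []
    for t, (xi, yi) in enumerate(zip(x, y), start=1):
        features = np.append(xi, 1.0)  # constant feature for the bias
        logits = features @ params
        correct.append(int(np.argmax(logits) == yi))
        # Softmax probabilities; subtracting the max keeps exp() stable.
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        probs[yi] -= 1.0  # gradient of cross-entropy w.r.t. the logits
        grad = np.outer(features, probs)
        # Standard Adam moment estimates with bias correction.
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad**2
        m_hat = m / (1 - b1**t)
        v_hat = v / (1 - b2**t)
        params -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return np.array(correct), params


# Toy check: the same input with the same label, repeated 100 times.
x_demo = np.ones((100, 20))
y_demo = np.full(100, 3)
correct, params = online_adam_softmax(x_demo, y_demo, num_classes=10)
```

In this toy stream the first prediction is wrong (all logits start at zero) and every later prediction is correct, because Adam's first update already moves each logit by roughly the learning rate in the right direction.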

A.7 ADDITIONAL EXPERIMENTS

Sequence length In all experiments of the main paper, we meta-trained on a sequence length (number of examples) of 100. This is a small training dataset compared to those processed by many human-engineered learning algorithms. In general, as long as the learning algorithm does not overfit the training data, more examples should increase predictive performance. In Figure 11 we investigate how our model scales to longer sequence lengths. We observe that the final accuracy on the last query in the sequence consistently increases with longer sequences. Generalizing to sequences longer than those seen during meta-training is another important direction for future work.

Gradient and update statistics To better understand the properties of the loss plateau, we visualize different statistics of the gradients, optimizer, and updates. In Figure 12, we track the exponential moving average statistics of Adam before and after the loss plateau (dashed vertical line). In Figure 13 we investigate how gradients differ between settings with a plateau and settings with a biased distribution where the plateau is avoided. We plot the cosine similarity between consecutive optimization steps, the gradient L2-norm, and the similarity and norm of the weight updates after normalization with Adam. The statistics are plotted cumulatively or smoothed with a Gaussian filter for better readability. The gradient and update cosine similarities differ only marginally between cases with and without a plateau. We observe that the gradient L2-norm in the plateau is only about half as large as in the biased-distribution case, whereas the updates that Adam applies go towards zero. As a result, the parameters do not move far from their initialization while in the plateau. We hypothesize that this is related to gradient norms varying across individual parameter tensors (Figure 14): the gradients have a small norm for most tensors, except for the last layer.
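The statistics described above can be computed from a log of flattened gradient (or update) vectors, one per optimization step. This minimal NumPy sketch assumes such a log exists; the helper names are hypothetical, and the hand-rolled Gaussian smoothing is a stand-in for whatever filter was used for the figures.

```python
import numpy as np


def consecutive_cosine_similarity(vectors):
    """Cosine similarity between each pair of consecutive flattened vectors."""
    v = np.asarray(vectors, dtype=np.float64)
    dots = np.sum(v[1:] * v[:-1], axis=1)
    norms = np.linalg.norm(v[1:], axis=1) * np.linalg.norm(v[:-1], axis=1)
    return dots / np.maximum(norms, 1e-12)  # avoid division by zero


def gaussian_smooth(series, sigma):
    """Smooth a 1-D series with a normalized Gaussian kernel (reflected edges)."""
    radius = int(3 * sigma)
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()
    padded = np.pad(series, radius, mode="reflect")
    return np.convolve(padded, kernel, mode="valid")  # same length as `series`


rng = np.random.default_rng(0)
grad_log = rng.normal(size=(200, 1000))  # stand-in: 200 steps of flattened gradients
cos = consecutive_cosine_similarity(grad_log)
l2 = np.linalg.norm(grad_log, axis=1)
smoothed_cos = gaussian_smooth(cos, sigma=5.0)
```

For independent random gradients the consecutive cosine similarity hovers near zero, while a persistent descent direction pushes it towards one, so this statistic separates noisy from consistent optimization.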
Batch size and number of tasks influence on plateau length Instead of measuring the plateau length in terms of the number of optimization steps (Figure 7), we may also be concerned with the total number of tasks seen within the plateau. This is relevant in particular when the task batch is not processed fully in parallel but gradients are accumulated. Figure 15 shows the same analysis with the number of tasks seen within the plateau on the y-axis. Larger batch sizes actually increase the number of tasks required to leave the plateau, despite shortening the plateau in terms of optimization steps. Similarly, a larger task training distribution requires more tasks to be seen within the plateau.

Adjusting Adam's ε or changing the optimizer As discussed in the main paper and visualized in Figure 16b, decreasing ε significantly shortens the plateau. This is because ε limits how strongly Adam rescales very small gradient magnitudes. At the same time, a smaller ε incurs some instability. Directly normalizing the gradient by applying the sign function element-wise to the exponential gradient average (Figure 16a) shortens the plateau even further.

When memorization happens, can we elicit grokking? In Figure 7a we have seen that an insufficiently large task distribution can lead to memorization instead of general learning-to-learn. At the same time, Figure 8 showed that biasing the data distribution helps to avoid loss plateaus. Power et al. (2022) observed a phenomenon they called "grokking", in which test loss may suddenly decrease even after the training loss has converged. Large amounts of regularization, such as weight decay with a coefficient of 1.0, were found to facilitate this behavior.
Is grokking connected to the optimization behavior we observe, and if so, do similar interventions help in our setting? We look in particular at the boundary between memorization and generalization ($2^{14} = 16384$ tasks), where doubling the number of tasks a few more times would lead to generalization. Figure 17 shows three task settings, $2^{10}$, $2^{14}$, and $2^{16}$, and three weight decay coefficients, 0.01, 0.1, and 1.0. The setting with $2^{16}$ tasks generalizes by default and serves only as a baseline for the weight decay analysis. In the cases of memorization due to too few tasks, we have not been able to produce grokking behavior.

Optimization difficulties in VSML Previous work has observed several optimization difficulties: slower convergence, local minima, unstable training, or loss plateaus at the beginning of training. Figure 18 shows some of these difficulties in the context of VSML (Kirsch & Schmidhuber, 2020). Because VSML has permutation invariance built into the architecture as an inductive bias, changing the number of tasks has only a small effect. We observe that deeper architectures in particular make meta-optimization more difficult.

Figure 15: Instead of plotting the loss plateau length in terms of optimization steps, we look at the total number of tasks seen within the plateau as a function of the task batch size and the number of tasks in the training distribution. An increase in the task batch size leads to more tasks to be processed to leave the plateau.
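Returning to the optimizer interventions discussed earlier in this section (a smaller Adam ε, or sign normalization), a small numeric sketch shows why ε matters when gradients are tiny. It assumes a single parameter that receives the same tiny gradient at every step, in which case Adam's bias-corrected update reduces to lr · |g| / (|g| + ε). The function names are hypothetical.

```python
import numpy as np


def adam_step_size(grad, eps, b1=0.9, b2=0.999, steps=100, lr=1e-3):
    """Magnitude of Adam's update after `steps` identical scalar gradients."""
    m = v = 0.0
    for _ in range(steps):
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad**2
    # Bias correction: for a constant gradient, m_hat == grad and v_hat == grad**2.
    m_hat = m / (1 - b1**steps)
    v_hat = v / (1 - b2**steps)
    return lr * abs(m_hat) / (np.sqrt(v_hat) + eps)


def sign_momentum_update(grads, lr=1e-3, b1=0.9):
    """Update from the element-wise sign of the exponential gradient average."""
    m = np.zeros_like(grads[0])
    for g in grads:
        m = b1 * m + (1 - b1) * g
    return -lr * np.sign(m)  # every coordinate moves by exactly lr


for eps in (1e-8, 1e-12):
    print(f"eps={eps:g}: update={adam_step_size(1e-9, eps):.2e}")
```

With a gradient of magnitude 1e-9, the default-sized ε = 1e-8 shrinks the update to about lr/11, while ε = 1e-12 restores almost the full learning rate; the sign-based update ignores the gradient magnitude entirely.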






Figure 2: GPICL is able to generalize to unseen tasks. Each cell is a separate meta-training run. (a) An MLP classifier trained in a multi-task fashion across various numbers of tasks (generated based on MNIST) and network sizes can fit a number of tasks that grows linearly with its capacity. (b) A sequence model (here the GPICL Transformer) that observes a dataset D of inputs and labels transitions to generalizing to a seemingly unbounded number of tasks as model size increases. This is achieved by switching from a memorization solution to a learning solution that (c) generalizes to unseen tasks. This generalization does not occur with the MLP.

Figure 4: Transformers exhibit three different phases in terms of meta-learned behavior. (1) When training on a small number of tasks, specific instances are memorized. (2) Tasks are memorized, which is evident as a within-sequence increase of performance. (3) When training across many tasks, we discover a learning algorithm that generalizes to unseen tasks and unseen datasets.

Figure 5: The state size (accessible memory) of an architecture most strongly predicts its performance as a general-purpose learning algorithm. (a) A large state is crucial for learning-to-learn to emerge. (b) The parameter count correlates less well with learning capabilities.

Figure 6: Meta-training dynamics often involve an extended period where GPICL's performance is stuck on a plateau. (a) Meta-loss vs. meta-training step, for a uniform distribution over meta-training tasks. Training tasks are generated by random transformations of FashionMNIST. (b) A zoomed-in view of the plateau. The loss only decreases slightly and the model memorizes small biases in the training data (decreasing generalization) before the loss drops sharply.

Figure 7: Whether GPICL memorizes, generalizes, or remains trapped on a meta-loss plateau depends on the number of meta-training tasks and the meta-training batch size. (a) A phase diagram showing GPICL's behavior at the end of meta-training (50k steps). Solutions either memorize, generalize and learn, or remain in the loss plateau. With additional training steps, configurations in the plateau might eventually transition to memorization or generalization. Generalization only occurs with large enough batch sizes and sufficient, but not too many, tasks. (b) This behavior is explained by the plateau length decreasing with increasing batch size (reducing the noise contribution), and (c) increasing with larger numbers of tasks.

Figure 8: Biasing the training distribution is an effective intervention which prevents a meta-loss plateau. A uniform distribution over tasks leads to a long plateau (d), while increasing the training fraction that corresponds to a single task reduces the plateau (a, b, c).

Figure 9: Using pre-trained networks allows leveraging domain-specific knowledge while still generalizing to other datasets. GPICL is meta-trained on MNIST either with the randomly transformed raw inputs or randomly transformed pre-trained features. Pre-training helps to accelerate meta-test-time in-context learning on datasets that have a matching domain, such as CIFAR10. With only 100 examples, the learning algorithm can achieve about 45% accuracy on CIFAR10. The learning algorithms still generalize to a wide range of datasets. Error bars are 95% confidence intervals of the mean across meta-training runs.

Figure 10: Solutions found by GPICL after meta-training are bi-modal, with a memorization and generalization mode. Each point represents the training loss at the end of meta-training for runs with different seeds and for various numbers of tasks that include the transition boundary previously observed. Almost all solutions are either in a memorization cluster or in a generalization cluster.

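Outer-product LSTM input projection (computed before the gate split):

    import jax.numpy as jnp
    import haiku as hk
    from jax.ad_checkpoint import checkpoint_name

    # Project the concatenated input and previous hidden state to
    # 8 * size gate pre-activations per head.
    x_and_h = jnp.concatenate([inputs, prev_state.hidden], axis=-1)
    gated = hk.Linear(8 * size * self.num_heads)(x_and_h)
    gated = gated.reshape((batch_size, self.num_heads, 8 * size))
    # Name the tensor so that rematerialization policies can refer to it.
    gated = checkpoint_name(gated, 'gated')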
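For readers without Haiku, one step of the outer-product LSTM can be re-expressed as a single NumPy function. This sketch assumes explicit weight arguments in place of `hk.Linear`, a reshape in place of the `split_axis` helper, and a scalar `key_scale_param` standing in for the learned `key_scale` parameter (initialized to zero, as in the Haiku code). It mirrors the Haiku code in this section but is not the authors' implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softplus(x):
    return np.log1p(np.exp(x))


def outer_product_lstm_step(inputs, hidden, cell, weight, bias, key_scale_param=0.0):
    """One outer-product LSTM step (hypothetical NumPy re-expression).

    Shapes: inputs (B, D_in), hidden (B, H * S), cell (B, H, S, S),
    weight (D_in + H * S, H * 8 * S), bias (H * 8 * S,).
    """
    batch, num_heads, size, _ = cell.shape
    x_and_h = np.concatenate([inputs, hidden], axis=-1)
    gated = (x_and_h @ weight + bias).reshape(batch, num_heads, 8 * size)
    # Split into two key sets (3 * S each), a query (S) and an output gate (S).
    k1, k2, q, o = np.split(gated, np.cumsum([3 * size, 3 * size, size]), axis=-1)
    scale = softplus(key_scale_param)
    k1 = np.tanh(k1.reshape(batch, num_heads, 3, size)) * scale
    k2 = np.tanh(k2.reshape(batch, num_heads, 3, size))
    # Outer products give input (i), cell-gate (g) and forget (f) matrices (B, H, S, S).
    i, g, f = np.einsum("bhki,bhkj->kbhij", k1, k2)
    f = sigmoid(f + 1.0)  # forget bias of +1
    new_cell = f * cell + sigmoid(i) * g
    # Inner-product read-out of the matrix-valued cell with the query q.
    read = np.einsum("bhij,bhi->bhj", new_cell, q)
    new_hidden = (sigmoid(o) * np.tanh(read)).reshape(batch, num_heads * size)
    return new_hidden, new_cell


rng = np.random.default_rng(0)
B, H, S, D = 2, 3, 4, 5  # batch, heads, per-head size, input size (toy values)
weight = rng.normal(scale=0.1, size=(D + H * S, H * 8 * S))
bias = np.zeros(H * 8 * S)
h0 = np.zeros((B, H * S))
c0 = np.zeros((B, H, S, S))
h1, c1 = outer_product_lstm_step(rng.normal(size=(B, D)), h0, c0, weight, bias)
```

The matrix-valued cell holds $S^2$ entries per head instead of $S$, which is how this variant enlarges the state without a proportional increase in the size of the hidden output.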

Figure 10 We trained a Transformer with model size 64 and 32 seeds for each number-of-tasks configuration.

Figure 4 Input normalization is disabled.

Figure 5 The Transformer uses a task batch size of 512.

Figure 6 Trained on $2^{16}$ tasks generated from FashionMNIST with labels fully permuted.

Figure 7 Trained on $2^{16}$ tasks generated from FashionMNIST with labels fully permuted.

Figure 8 Trained on $2^{16}$ tasks generated from FashionMNIST with label permutations varied.

Figure 11: Increasing the sequence length during meta-training and meta-testing improves the predictive performance of the final query in the sequence. Error bars indicate 95% confidence intervals.

Figure 13: Gradient and Adam update statistics for differently biased data distributions. (a) Plateaus in the loss are influenced by the bias in the data distribution. Plateaus result in moving away slowly from the parameter initialization. (b) The cosine similarity of both gradients and updates in consecutive steps is only marginally different with or without a loss plateau. (c) While the gradient norm is about half as big when a plateau exists, the updates are going towards zero.

Figure 16: The plateau is significantly shorter when (a) replacing Adam with a sign normalization of the gradient or (b) reducing Adam's ε.

Figure 17: We investigate whether grokking as defined in Power et al. (2022) can be produced when we observe memorization on a smaller number of tasks. This would correspond to the test loss decreasing long after the training loss has converged. We have not been able to elicit this behavior when looking at different numbers of tasks and weight decay coefficients.

Figure 18: Loss plateaus and slow convergence with deeper variants of VSML.

Meta-test generalization to various datasets after meta-training on augmented MNIST and seeing 99 examples, predicting the 100th. We report the mean across 3 meta-training seeds, 16 sequences from each task, 16 tasks sampled from each base dataset.

Figure 14: Gradient L2 norms (left) and gradient cosine similarity for consecutive optimization steps (right) for different parameter tensors. The last (output) layer has the largest gradients. Most other gradients are small.

