LEARNING BY DISTILLING CONTEXT

Abstract

Language models significantly benefit from context tokens, such as prompts or scratch-pads. They perform better when prompted with informative instructions, and they acquire new reasoning capabilities by generating a scratch-pad before predicting the final answer. However, they do not internalize these performance gains, which disappear once the context tokens are gone. Our work proposes to apply context distillation so that a language model can improve itself by internalizing these gains. Concretely, given a synthetic unlabeled input for the target task, we condition the model on "[instructions] + [task-input]" to predict "[scratch-pad] + [final answer]"; then we fine-tune the same model to predict its own "[final answer]" conditioned on the "[task-input]", without seeing the "[instructions]" or using the "[scratch-pad]". We show that context distillation is a general method to train language models, and that it can effectively internalize three types of training signals. First, it can internalize abstract task instructions along with explanations, so we can recursively update the model parameters with new instructions and overwrite old ones. Second, it can internalize concrete training examples, and it outperforms directly learning with gradient descent by 9% on the SPIDER Text-to-SQL dataset; furthermore, combining multiple context distillation operations can internalize more training examples than the context window size allows. Finally, we show preliminary evidence that it can internalize step-by-step reasoning on 8-digit addition, and such a newly acquired capability proves useful for other downstream tasks.

1. INTRODUCTION

Recent work has shown that language models significantly benefit from context tokens. When prompted with task definitions, language models can perform zero-shot learning (Wei et al., 2022a; Sanh et al., 2022), and the performance further improves with additional in-context examples and explanations (Chen et al., 2022; Scheurer et al., 2022). They also acquire the capability to perform more complex tasks by generating step-by-step reasoning in the context window before predicting the final answer (Nye et al., 2021b; Wei et al., 2022b; Zhou et al., 2022). However, language models cannot internalize these performance gains, which disappear once the context tokens are gone. Consequently, we always need to pay extra computation to run inference on context tokens; this is undesirable, as the task instructions and the scratch-pad can sometimes be more than 10x longer than the actual task inputs. Furthermore, it is unclear how to leverage context tokens when their total length exceeds the context window size.

These shortcomings are analogous to how humans are slow at performing complex cognitive tasks (Wason & Evans, 1974) and can hold only a limited amount of information in working memory (Baddeley, 1992). Humans get around this by practicing. Consider, for example, learning to type a friend's phone number. The first few times you type it, you must consciously recall the number from working memory and slowly decide which button to press. After repeatedly typing the same number, it becomes a habit, and you can type it quickly without conscious reasoning. Through repeated practice, the knowledge of your friend's phone number is "distilled" into your muscle memory.foot_0 This mechanism for distilling knowledge is critical for learning complex tasks because it allows us to incrementally build up our knowledge and skills, so that we can learn to accomplish increasingly complex tasks.

We propose to apply a similar method, context distillation, to fine-tune language models. For example, as shown in Figure 1, to make a language model internalize the step-by-step addition capability, we first synthesize a large number of "practice" addition questions; we then ask the model to follow the more informative instruction and reason step-by-step before generating the target answer; finally, we fine-tune the language model to directly predict the answer conditioned on a simpler student prompt. As a result of practicing on many addition problems, the ability to add is distilled into its parameters.

We formally state our generalized context distillation framework in Section 2. Section 3 shows that we can apply context distillation in a wide range of settings: learning from abstract statements, learning from concrete examples, and learning from step-by-step reasoning. Section 3.1 (Figure 2) shows that context distillation can effectively internalize task instructions along with natural language explanations from Natural-Instructions-V2 (Wang et al., 2022b); additionally, we can teach the student to associate numerical indices with certain tasks, and then recursively re-assign these task indices, overwriting the student's past associations.
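The teacher-to-student loop described above can be sketched in a few lines of Python. Here, `sample_from_model` and `fine_tune` are hypothetical placeholders for conditional sampling from the language model and a gradient update, and the "Answer:" delimiter is an assumed formatting convention rather than the paper's exact prompt format.

```python
def make_distillation_pair(task_input, teacher_instruction, sample_from_model):
    """Build one (input, target) fine-tuning pair for the student."""
    # 1. Teacher pass: condition on [instructions] + [task-input] and
    #    sample [scratch-pad] + [final answer].
    teacher_prompt = teacher_instruction + "\n" + task_input
    generation = sample_from_model(teacher_prompt)
    scratch_pad, final_answer = generation.rsplit("Answer:", 1)
    # 2. Student target: the final answer alone, conditioned on the raw
    #    task input -- no instructions, no scratch-pad.
    return task_input, "Answer:" + final_answer


def context_distillation(model, task_inputs, teacher_instruction,
                         sample_from_model, fine_tune):
    """Distill the teacher prompt's gains into the model's parameters."""
    for task_input in task_inputs:
        pair = make_distillation_pair(task_input, teacher_instruction,
                                      sample_from_model)
        fine_tune(model, *pair)  # student learns to predict the answer directly
    return model
```

Note that the same model plays both roles: it generates the scratch-pad and answer as the teacher, then is fine-tuned as the student on only the (input, answer) pair.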
Section 3.2 (Figure 3) shows that context distillation can be used to internalize Text-to-SQL training examples from the SPIDER dataset (Yu et al., 2018) into Incoder (Fried et al., 2022), and it outperforms directly learning with gradient descent by 9% for 8-shot adaptation; additionally, as we distill more training examples than can fit in the context window, we observe continual improvements in performance. Section 3.3 (Figure 3) shows that we can internalize step-by-step reasoning to perform 8-digit addition, and that such a capability can transfer to downstream question answering tasks; we hope these preliminary results will generalize to more complex and realistic tasks that larger models can perform with chain-of-thought reasoning (Wei et al., 2022b; Zhou et al., 2022). Overall, context distillation demonstrates promising potential as a general method for learning. As discussed in Section 4, we predict that future models will be better able to learn from context than today's models, and researchers will use these models to tackle increasingly complex tasks that



foot_0: See declarative learning vs. procedural learning for a friendly but more in-depth discussion: https://en.wikipedia.org/wiki/Declarative_learning



Figure 1: Top. An overview of our context distillation framework. We sample a raw task input, form the teacher's prompt by prepending a detailed instruction that might contain more examples and explanations, and ask the language model to conditionally sample a scratch-pad and a final answer. Then we fine-tune the same language model to directly predict the final answer with a minimal instruction. We formalize this framework mathematically in Section 2. Bottom. An instantiation of our framework that internalizes step-by-step reasoning for 8-digit addition.
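To make the bottom panel concrete, the step-by-step addition scratch-pad can be generated programmatically when synthesizing practice questions. The digit-by-digit, carry-tracking format below is illustrative only; the exact scratch-pad wording used in the experiments may differ.

```python
def addition_scratch_pad(a: int, b: int) -> str:
    """Produce a step-by-step scratch-pad for a + b, one digit per line."""
    lines = []
    carry, result_digits = 0, []
    da, db = str(a)[::-1], str(b)[::-1]  # process least-significant digit first
    for i in range(max(len(da), len(db))):
        x = int(da[i]) if i < len(da) else 0
        y = int(db[i]) if i < len(db) else 0
        carry_in = carry
        total = x + y + carry_in
        carry, digit = divmod(total, 10)
        result_digits.append(str(digit))
        lines.append(f"{x} + {y} + carry {carry_in} = {total}, "
                     f"write {digit}, carry {carry}")
    if carry:
        result_digits.append(str(carry))
    answer = "".join(reversed(result_digits))
    lines.append(f"Final answer: {answer}")
    return "\n".join(lines)
```

A teacher prompted to reason step-by-step produces output of this shape; the student is then fine-tuned to emit only the final answer line.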

