IN-CONTEXT REINFORCEMENT LEARNING WITH ALGORITHM DISTILLATION

Abstract

We propose Algorithm Distillation (AD), a method for distilling reinforcement learning (RL) algorithms into neural networks by modeling their training histories with a causal sequence model. Algorithm Distillation treats learning to reinforcement learn as an across-episode sequential prediction problem. A dataset of learning histories is generated by a source RL algorithm, and then a causal transformer is trained by autoregressively predicting actions given their preceding learning histories as context. Unlike sequential policy prediction architectures that distill post-learning or expert sequences, AD is able to improve its policy entirely in-context without updating its network parameters. We demonstrate that AD can reinforcement learn in-context in a variety of environments with sparse rewards, combinatorial task structure, and pixel-based observations, and find that AD learns a more data-efficient RL algorithm than the one that generated the source data.




1. INTRODUCTION

Transformers have emerged as powerful neural network architectures for sequence modeling (Vaswani et al., 2017). A striking property of pre-trained transformers is their ability to adapt to downstream tasks through prompt conditioning or in-context learning. After pre-training on large offline datasets, large transformers have been shown to generalize to downstream tasks in text completion (Brown et al., 2020), language understanding (Devlin et al., 2018), and image generation (Yu et al., 2022). Recent work demonstrated that transformers can also learn policies from offline data by treating offline Reinforcement Learning (RL) as a sequential prediction problem. While Chen et al. (2021) showed that transformers can learn single-task policies from offline RL data via imitation learning, subsequent work showed that transformers can also extract multi-task policies in both same-domain (Lee et al., 2022) and cross-domain settings (Reed et al., 2022). These works suggest a promising paradigm for extracting generalist multi-task policies: first collect a large and diverse dataset of environment interactions, then extract a policy from the data via sequence modeling. We refer to the family of approaches that learns policies from offline RL data via imitation learning as Offline Policy Distillation, or simply Policy Distillation¹ (PD). Despite its simplicity and scalability, a substantial drawback of PD is that the resulting policy does not improve incrementally from additional interaction with the environment.

We hypothesize that the reason Policy Distillation does not improve through trial and error is that it trains on data that does not show learning progress. Current methods either learn policies from data that contains no learning (e.g., by distilling fixed expert policies) or from data that does contain learning (e.g., the replay buffer of an RL agent) but with a context size that is too small to capture policy improvement.
Our key observation is that the sequential nature of learning within RL algorithm training could, in principle, make it possible to model the process of reinforcement learning itself as a causal sequence prediction problem. Specifically, if a transformer's context is long enough to include policy improvement due to learning updates, it should be able to represent not only a fixed policy but also a policy improvement operator, by attending to states, actions, and rewards from previous episodes. This opens the possibility that any RL algorithm can be distilled into a sufficiently powerful sequence model, such as a transformer, via imitation learning, converting it into an in-context RL algorithm. By in-context RL we mean that the transformer should improve its policy through trial and error within the environment by attending to its context, without updating its parameters.

We present Algorithm Distillation (AD), a method that learns an in-context policy improvement operator by optimizing a causal sequence prediction loss on the learning histories of an RL algorithm. AD has two components. First, a large multi-task dataset is generated by saving the training histories of an RL algorithm on many individual tasks. Next, a transformer models actions causally using the preceding learning history as its context. Since the policy improves throughout the course of training of the source RL algorithm, AD is forced to learn the improvement operator in order to accurately model the actions at any given point in the training history. Crucially, the transformer context size must be sufficiently large (i.e., across-episodic) to capture improvement in the training data. The full method is shown in Fig. 1. We show that by imitating gradient-based RL algorithms using a causal transformer with sufficiently large contexts, AD can reinforcement learn new tasks entirely in-context.
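To make the across-episodic context concrete, the following is a minimal sketch (plain Python, our own illustration rather than the paper's implementation) of how a learning history could be flattened into (context, target-action) pairs for the causal sequence prediction loss. The function name and data layout are hypothetical; in practice these token sequences would be fed to a causal transformer trained with a cross-entropy loss on the action tokens.

```python
def make_training_pairs(history, context_len):
    """Build (context, target-action) pairs from one learning history.

    history: a list of episodes, each a list of (obs, action, reward) steps,
             ordered by training time of the source RL algorithm.
    context_len: context size measured in tokens; it must be large enough to
             span several episodes so that the context can exhibit policy
             improvement, not just a single fixed policy.
    """
    # Flatten all episodes into one token stream: o_0, a_0, r_0, o_1, a_1, r_1, ...
    stream = []
    for episode in history:
        for obs, action, reward in episode:
            stream.extend([obs, action, reward])

    pairs = []
    # Every action token sits at stream index 3k + 1; predict it from the
    # (truncated) across-episodic history that precedes it.
    for k in range(len(stream) // 3):
        a_idx = 3 * k + 1
        context = stream[max(0, a_idx - context_len):a_idx]
        pairs.append((context, stream[a_idx]))
    return pairs
```

The design choice mirrored here is that `context_len` counts tokens across episode boundaries, so a single context window can cover multiple episodes of the source algorithm's history, which is what lets the model observe (and imitate) policy improvement.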
We evaluate AD across a number of partially observed environments that require exploration, including the pixel-based Watermaze (Morris, 1981) from DMLab (Beattie et al., 2016). We show that AD is capable of in-context exploration, temporal credit assignment, and generalization. We also show that AD learns a more data-efficient algorithm than the one that generated the source data for transformer training. To the best of our knowledge, AD is the first method to demonstrate in-context reinforcement learning via sequential modeling of offline data with an imitation loss.

2. BACKGROUND

Partially Observable Markov Decision Processes: A Markov Decision Process (MDP) consists of states s ∈ S, actions a ∈ A, rewards r ∈ R, a discount factor γ, and a transition probability function p(s_{t+1}|s_t, a_t), where t is an integer denoting the timestep and (S, A) are the state and action spaces. In a partially observable MDP (POMDP), the agent does not have direct access to the underlying state and instead receives observations that convey only partial information about it.
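As a toy illustration of these definitions (our own sketch, not from the paper), the code below rolls out a fixed policy in a small tabular MDP, with a transition table standing in for p(s_{t+1}|s_t, a_t) and the discount factor γ accumulating the discounted return:

```python
import random

def rollout(transitions, rewards, policy, s0, gamma, horizon, rng):
    """Sample a trajectory in a tabular MDP and return its discounted return.

    transitions[s][a]: list of (next_state, probability) pairs, i.e. p(s'|s, a).
    rewards[s][a]:     scalar reward r(s, a).
    policy:            maps a state s to an action a.
    """
    s, ret, discount = s0, 0.0, 1.0
    for _ in range(horizon):
        a = policy(s)
        ret += discount * rewards[s][a]
        discount *= gamma
        # Sample the next state from p(s'|s, a).
        next_states, probs = zip(*transitions[s][a])
        s = rng.choices(next_states, weights=probs)[0]
    return ret
```

For example, in a two-state MDP that deterministically alternates between a rewarding and a non-rewarding state, the discounted return under γ = 0.5 over four steps is 1 + 0 + 0.25 + 0 = 1.25.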



¹ What we refer to as Policy Distillation is similar to Rusu et al. (2016), except that the policy is distilled from offline data rather than from a teacher network.



Figure 1: Algorithm Distillation (AD) has two steps: (i) a dataset of learning histories is collected from individual single-task RL algorithms solving different tasks; (ii) a causal transformer predicts actions from these histories using across-episodic contexts. Since the RL policy improves throughout the learning histories, by predicting actions accurately AD learns to output an improved policy relative to the one seen in its context. AD models state-action-reward tokens and does not condition on returns.

For instance, the Multi-Game Decision Transformer (MGDT; Lee et al., 2022) learns a return-conditioned policy that plays many Atari games, while Gato (Reed et al., 2022) learns a policy that solves tasks across diverse environments by inferring tasks through context; yet neither method can improve its policy in-context through trial and error. MGDT adapts to new tasks by finetuning the model weights, while Gato requires prompting with an expert demonstration to adapt to a new task. In short, Policy Distillation methods learn policies, not reinforcement learning algorithms.

