TRANSFORMERS ARE SAMPLE-EFFICIENT WORLD MODELS

Abstract

Deep reinforcement learning agents are notoriously sample inefficient, which considerably limits their application to real-world problems. Recently, many model-based methods have been designed to address this issue, with learning in the imagination of a world model being one of the most prominent approaches. However, while virtually unlimited interaction with a simulated environment sounds appealing, the world model has to be accurate over extended periods of time. Motivated by the success of Transformers in sequence modeling tasks, we introduce IRIS, a data-efficient agent that learns in a world model composed of a discrete autoencoder and an autoregressive Transformer. With the equivalent of only two hours of gameplay in the Atari 100k benchmark, IRIS achieves a mean human normalized score of 1.046, and outperforms humans on 10 out of 26 games, setting a new state of the art for methods without lookahead search. To foster future research on Transformers and world models for sample-efficient reinforcement learning, we release our code and models at https://github.com/eloialonso/iris.

1. INTRODUCTION

Deep Reinforcement Learning (RL) has become the dominant paradigm for developing competent agents in challenging environments. Most notably, deep RL algorithms have achieved impressive performance in a multitude of arcade (Mnih et al., 2015; Schrittwieser et al., 2020; Hafner et al., 2021), real-time strategy (Vinyals et al., 2019; Berner et al., 2019), board (Silver et al., 2016; 2018; Schrittwieser et al., 2020) and imperfect information (Schmid et al., 2021; Brown et al., 2020a) games. However, a common drawback of these methods is their extremely low sample efficiency. Indeed, experience requirements range from months of gameplay for DreamerV2 (Hafner et al., 2021) in Atari 2600 games (Bellemare et al., 2013b) to thousands of years for OpenAI Five in Dota2 (Berner et al., 2019). While some environments can be sped up for training agents, real-world applications often cannot. Besides, additional cost or safety considerations related to the number of environmental interactions may arise (Yampolskiy, 2018). Hence, sample efficiency is a necessary condition to bridge the gap between research and the deployment of deep RL agents in the wild.

Model-based methods (Sutton & Barto, 2018) constitute a promising direction towards data efficiency. Recently, world models were leveraged in several ways: pure representation learning (Schwarzer et al., 2021), lookahead search (Schrittwieser et al., 2020; Ye et al., 2021), and learning in imagination (Ha & Schmidhuber, 2018; Kaiser et al., 2020; Hafner et al., 2020; 2021). The latter approach is particularly appealing because training an agent inside a world model frees it from sample efficiency constraints. Nevertheless, this framework relies heavily on accurate world models since the policy is purely trained in imagination. In a pioneering work, Ha & Schmidhuber (2018) successfully built imagination-based agents in toy environments.
SimPLe recently showed promise in the more challenging Atari 100k benchmark (Kaiser et al., 2020). Currently, the best Atari agent learning in imagination is DreamerV2 (Hafner et al., 2021), although it was developed and evaluated with two hundred million frames available, far from the sample-efficient regime. Therefore, designing new world model architectures, capable of handling visually complex and partially observable environments with few samples, is key to realizing their potential as surrogate training grounds.

The Transformer architecture (Vaswani et al., 2017) is now ubiquitous in Natural Language Processing (Devlin et al., 2019; Radford et al., 2019; Brown et al., 2020b; Raffel et al., 2020), and is also gaining traction in Computer Vision (Dosovitskiy et al., 2021; He et al., 2022), as well as in Offline Reinforcement Learning (Janner et al., 2021; Chen et al., 2021). In particular, the GPT (Radford et al., 2018; 2019; Brown et al., 2020b) family of models delivered impressive results in language understanding tasks. Similarly to world models, these attention-based models are trained with high-dimensional signals and a self-supervised learning objective, thus constituting ideal candidates to simulate an environment. Transformers particularly shine when they operate over sequences of discrete tokens (Devlin et al., 2019; Brown et al., 2020b). For textual data, there are simple ways (Schuster & Nakajima, 2012; Kudo & Richardson, 2018) to build a vocabulary, but this conversion is not straightforward with images. A naive approach would be to treat pixels as image tokens, but standard Transformer architectures scale quadratically with sequence length, making this idea computationally intractable. To address this issue, VQGAN (Esser et al., 2021) and DALL-E (Ramesh et al., 2021) employ a discrete autoencoder (Van Den Oord et al., 2017) as a mapping from raw pixels to a much smaller number of image tokens. Combined with an autoregressive Transformer, these methods demonstrate strong unconditional and conditional image generation capabilities. Such results suggest a new approach to design world models.

In the present work, we introduce IRIS (Imagination with auto-Regression over an Inner Speech), an agent trained in the imagination of a world model composed of a discrete autoencoder and an autoregressive Transformer. IRIS learns behaviors by accurately simulating millions of trajectories. Our approach casts dynamics learning as a sequence modeling problem, where an autoencoder builds a language of image tokens and a Transformer composes that language over time.
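The discrete bottleneck that makes frames "token-like" can be sketched as follows. This is a minimal illustration of the vector quantization step of Van Den Oord et al. (2017), not the autoencoder used by IRIS: each continuous encoder output vector is replaced by the index of its nearest codebook entry, so a frame becomes a short sequence of K integer tokens. All sizes below are illustrative.

```python
import random

random.seed(0)
vocab_size, embed_dim, K = 512, 8, 16   # codebook size, token dim, tokens per frame

# Learnt codebook and the K continuous vectors produced by the encoder E for one frame.
codebook = [[random.gauss(0, 1) for _ in range(embed_dim)] for _ in range(vocab_size)]
encoder_out = [[random.gauss(0, 1) for _ in range(embed_dim)] for _ in range(K)]

def quantize(vec, codebook):
    """Index of the nearest codebook entry under squared L2 distance."""
    return min(range(len(codebook)),
               key=lambda i: sum((v - c) ** 2 for v, c in zip(vec, codebook[i])))

tokens = [quantize(v, codebook) for v in encoder_out]   # z = (z^1, ..., z^K)
quantized = [codebook[t] for t in tokens]               # embeddings passed to decoder D

assert len(tokens) == K and all(0 <= t < vocab_size for t in tokens)
```

A Transformer can then model these K-token frames autoregressively, exactly as it would model words in a sentence.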
With minimal tuning, IRIS outperforms a range of recent methods (Kaiser et al., 2020; Hessel et al., 2018; Laskin et al., 2020; Yarats et al., 2021; Schwarzer et al., 2021) for sample-efficient RL in the Atari 100k benchmark (Kaiser et al., 2020). After only two hours of real-time experience, it achieves a mean human normalized score of 1.046, and reaches superhuman performance on 10 out of 26 games. We describe IRIS in Section 2 and present our results in Section 3.
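For reference, the human normalized score (HNS) used throughout the Atari literature rescales a raw game score so that 0 corresponds to a random policy and 1 to the human reference. The numbers below are made up for illustration, not benchmark scores.

```python
def human_normalized_score(agent: float, random_score: float, human: float) -> float:
    """HNS = (agent - random) / (human - random); 1.0 means human-level play."""
    return (agent - random_score) / (human - random_score)

# Two hypothetical games: one superhuman, one subhuman.
scores = [human_normalized_score(150.0, 10.0, 80.0),   # (150-10)/(80-10) = 2.0
          human_normalized_score(45.0, 10.0, 80.0)]    # (45-10)/(80-10) = 0.5
mean_hns = sum(scores) / len(scores)                   # 1.25
```

A benchmark-wide result such as 1.046 is the mean of this per-game quantity over all 26 games.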



Figure 1: Unrolling imagination over time. This figure shows the policy π, depicted with purple arrows, taking a sequence of actions in imagination. The green arrows correspond to the encoder E and the decoder D of a discrete autoencoder, whose task is to represent frames in its learnt symbolic language. The backbone G of the world model is a GPT-like Transformer, illustrated with blue arrows. For each action that the policy π takes, G simulates the environment dynamics by autoregressively unfolding new frame tokens that D can decode. G also predicts a reward and a potential episode termination. More specifically, an initial frame x_0 is encoded with E into tokens z_0 = (z_0^1, ..., z_0^K) = E(x_0). The decoder D reconstructs an image x̂_0 = D(z_0), from which the policy π predicts the action a_0. From z_0 and a_0, G predicts the reward r̂_0, the episode termination d̂_0 ∈ {0, 1}, and, in an autoregressive manner, ẑ_1 = (ẑ_1^1, ..., ẑ_1^K), the tokens for the next frame. A dashed box indicates image tokens for a given time step, whereas a solid box represents the input sequence of G, i.e. (z_0, a_0) at t = 0, (z_0, a_0, ẑ_1, a_1) at t = 1, etc. The policy π is purely trained with imagined trajectories, and is only deployed in the real environment to improve the world model (E, D, G).
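The control flow of the caption above can be written schematically. In this sketch, `encode`, `decode`, `policy`, and `world_model_step` are stand-ins for E, D, π, and one autoregressive step of G; their bodies are stubs, and only the interleaving of tokens and actions reflects the figure.

```python
import random

random.seed(0)
K = 4  # tokens per frame (illustrative)

def encode(frame):                 # E: frame -> K integer tokens
    return [hash((frame, k)) % 512 for k in range(K)]

def decode(tokens):                # D: tokens -> reconstructed frame (stubbed)
    return tuple(tokens)

def policy(frame):                 # pi: pick an action from the decoded frame
    return random.randrange(4)

def world_model_step(sequence):    # G: predict reward, termination, next-frame tokens
    return 0.0, False, [random.randrange(512) for _ in range(K)]

def imagine(initial_frame, horizon):
    """Unroll an imagined trajectory (z_0, a_0, z_1, a_1, ...) as in Figure 1."""
    sequence = encode(initial_frame)          # z_0
    tokens, trajectory = sequence[-K:], []
    for _ in range(horizon):
        action = policy(decode(tokens))       # pi acts on D's reconstruction
        sequence = sequence + [action]
        reward, done, tokens = world_model_step(sequence)  # G unrolls z_{t+1}
        sequence = sequence + tokens
        trajectory.append((action, reward, done))
        if done:
            break
    return trajectory

traj = imagine(initial_frame="x0", horizon=5)
assert len(traj) == 5
```

Because the rollout never touches the real environment, the policy can be trained on as many such trajectories as compute allows.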

