TRANSFORMER-BASED WORLD MODELS ARE HAPPY WITH 100K INTERACTIONS

Abstract

Deep neural networks have been successful in many reinforcement learning settings. However, compared to human learners they are overly data-hungry. To build a sample-efficient world model, we apply a transformer to real-world episodes in an autoregressive manner: not only the compact latent states and the actions taken but also the experienced or predicted rewards are fed into the transformer, so that it can attend flexibly to all three modalities at different time steps. The transformer allows our world model to access previous states directly, instead of viewing them through a compressed recurrent state. By utilizing the Transformer-XL architecture, it is able to learn long-term dependencies while staying computationally efficient. Our transformer-based world model (TWM) generates meaningful new experience, which is used to train a policy that outperforms previous model-free and model-based reinforcement learning algorithms on the Atari 100k benchmark.

1. INTRODUCTION

Deep reinforcement learning methods have shown great success on many challenging decision making problems. Notable methods include DQN (Mnih et al., 2015), PPO (Schulman et al., 2017), and MuZero (Schrittwieser et al., 2019). However, most algorithms require hundreds of millions of interactions with the environment, whereas humans can often achieve similar results with less than 1% of these interactions, i.e., they are more sample-efficient. The large amount of data required rules out many potential real-world applications of reinforcement learning.

Recent works have made considerable progress in advancing the sample efficiency of RL algorithms: model-free methods have been improved with auxiliary objectives (Laskin et al., 2020b), data augmentation (Yarats et al., 2021; Laskin et al., 2020a), or both (Schwarzer et al., 2021). Model-based methods have been successfully applied to complex image-based environments and have been used either for planning, such as EfficientZero (Ye et al., 2021), or for learning behaviors in imagination, such as SimPLe (Kaiser et al., 2020).

A promising model-based concept is learning in imagination (Ha & Schmidhuber, 2018; Kaiser et al., 2020; Hafner et al., 2020; 2021): instead of learning behaviors from the collected experience directly, a generative model of the environment dynamics is learned in a (self-)supervised manner. Such a so-called world model can create new trajectories by iteratively predicting the next state and reward. This provides potentially unlimited training data for the reinforcement learning algorithm without further interaction with the real environment. Owing to the generalization capabilities of deep neural networks, a world model might also generalize to new, unseen situations, which has the potential to drastically increase sample efficiency. This can be illustrated by a simple example: in the game of Pong, the paddles and the ball move independently.
In the best case, a successfully trained world model would imagine trajectories with paddle and ball configurations that have never been observed before, which enables learning of improved behaviors.

Our contributions: The contributions of this work can be summarized as follows:

1. We present a new autoregressive world model based on the Transformer-XL (Dai et al., 2019) architecture and a model-free agent trained in latent imagination. Running our policy is computationally efficient, as the transformer is not needed at inference time. This is in contrast to related works (Hafner et al., 2020; 2021; Chen et al., 2022) that require the full world model during inference.

2. Our world model is provided with information on how much reward has already been emitted by feeding back predicted rewards into the world model. As shown in our ablation study, this improves performance.

3. We rewrite the balanced KL divergence loss of Hafner et al. (2021) to allow us to fine-tune the relative weights of the involved entropy and cross-entropy terms.

4. We introduce a new thresholded entropy loss that stabilizes the policy's entropy during training and thereby simplifies the selection of hyperparameters that behave well across different games.

5. We propose a new, effective sampling procedure for the growing dataset of experience, which balances the training distribution to shift the focus towards the latest experience. We demonstrate the efficacy of this procedure with an ablation study.

6. We compare our transformer-based world model (TWM) on the Atari 100k benchmark with recent sample-efficient methods and obtain excellent results. Moreover, we report empirical confidence intervals of the aggregate metrics as suggested by Agarwal et al. (2021).
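Contribution 3 builds on the identity KL(q ‖ p) = H(q, p) − H(q), i.e., the KL divergence is the cross-entropy minus the entropy. A minimal sketch of exposing separate weights on the two terms for categorical distributions, with hypothetical weight names w_cross and w_ent (the paper's actual loss and parameterization are not reproduced here):

```python
import math

def cross_entropy_and_entropy(q, p, eps=1e-8):
    """Return (H(q, p), H(q)) for categorical distributions given as
    probability lists; eps guards against log(0)."""
    cross = -sum(qi * math.log(pi + eps) for qi, pi in zip(q, p))
    ent = -sum(qi * math.log(qi + eps) for qi in q)
    return cross, ent

def weighted_kl(q, p, w_cross=1.0, w_ent=1.0):
    """KL(q || p) = H(q, p) - H(q), with separate (hypothetical) weights
    on each term; with w_cross == w_ent == 1 this is the exact KL."""
    cross, ent = cross_entropy_and_entropy(q, p)
    return w_cross * cross - w_ent * ent
```

With unequal weights the quantity is no longer a true divergence, but it lets the relative pull of the entropy and cross-entropy terms be tuned independently.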
Learning in imagination consists of three steps that are repeated iteratively: learning the dynamics, learning a policy, and interacting in the real environment. In this section, we describe our world model and policy, concluding with the training procedure.
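The three-step loop above can be sketched with toy stand-ins (a tabular, deterministic model and stub policies — hypothetical simplifications; the actual agent uses a CNN encoder, a Transformer-XL dynamics model, and an actor-critic policy):

```python
def collect_experience(env_step, policy, start=0, steps=20):
    """Step 3: interact with the real environment, recording (s, a, r, s')."""
    s, data = start, []
    for _ in range(steps):
        a = policy(s)
        s2, r = env_step(s, a)
        data.append((s, a, r, s2))
        s = s2
    return data

def learn_world_model(data):
    """Step 1: fit a model of the dynamics (here: a tabular lookup of
    observed transitions instead of a learned neural network)."""
    return {(s, a): (s2, r) for s, a, r, s2 in data}

def imagine_rollout(model, policy, start, horizon=5):
    """Step 2: generate a new trajectory purely from the learned model;
    a real agent would train its policy on these imagined rewards."""
    s, rewards = start, []
    for _ in range(horizon):
        a = policy(s)
        if (s, a) not in model:
            break  # transition never observed, model cannot predict it
        s, r = model[(s, a)]
        rewards.append(r)
    return rewards
```

For example, with a toy cyclic environment `env = lambda s, a: ((s + a) % 3, 1.0 if (s + a) % 3 == 0 else 0.0)` and a fixed policy, the model learned from real interaction can roll out reward sequences without touching the environment again.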



Our world model architecture. Observations o_{t−…:t} are encoded using a CNN. Linear embeddings of stochastic, discrete latent states z_{t−…:t}, actions a_{t−…:t}, and rewards r_{t−…:t} are fed into a transformer, which computes a deterministic hidden state h_t at each time step. Predictions of the reward r̂_t …

In this paper, we propose to model the world with transformers (Vaswani et al., 2017), which have significantly advanced the field of natural language processing and have been successfully applied to computer vision tasks (Dosovitskiy et al., 2021). A transformer is a sequence model consisting of multiple self-attention layers with residual connections. In each self-attention layer, the inputs are mapped to keys, queries, and values. The outputs are computed by weighting the values by the similarity of the corresponding keys and queries. Combined with causal masking, which prevents the self-attention layers from accessing future time steps in the training sequence, transformers can be used as autoregressive generative models. The Transformer-XL architecture (Dai et al., 2019) is much more computationally efficient than vanilla transformers at inference time and introduces relative positional encodings, which remove the dependence on absolute time steps.
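The key–query–value computation with causal masking can be sketched in a few lines of NumPy (a single-head illustration of the standard mechanism, not the paper's multi-head Transformer-XL implementation):

```python
import numpy as np

def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention with a causal mask.
    x: (T, d) input sequence; w_q, w_k, w_v: (d, d_k) projections."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v       # queries, keys, values
    scores = q @ k.T / np.sqrt(k.shape[-1])   # similarity of queries and keys
    mask = np.triu(np.ones(scores.shape, dtype=bool), 1)
    scores = np.where(mask, -np.inf, scores)  # hide future time steps
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the past
    return weights @ v                        # weighted sum of values
```

Because of the mask, the output at time step t depends only on inputs up to t, which is what makes autoregressive generation possible.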

We consider a partially observable Markov decision process (POMDP) with discrete time steps t ∈ ℕ, scalar rewards r_t ∈ ℝ, high-dimensional image observations o_t ∈ ℝ^{h×w×c}, and discrete actions a_t ∈ {1, …, m}, which are generated by some policy a_t ∼ π(a_t | o_{1:t}, a_{1:t−1}), where o_{1:t} and a_{1:t−1} denote the sequences of observations and actions up to time steps t and t−1, respectively. Episode ends are indicated by a boolean variable d_t ∈ {0, 1}. Observations, rewards, and episode ends are jointly generated by the unknown environment dynamics o_t, r_t, d_t ∼ p(o_t, r_t, d_t | o_{1:t−1}, a_{1:t−1}). The goal is to find a policy π that maximizes the expected sum of discounted rewards E_π[Σ_{t=1}^∞ γ^{t−1} r_t], where γ ∈ [0, 1) is the discount factor.
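The discounted-return objective Σ_{t=1}^∞ γ^{t−1} r_t can be computed for a finite reward sequence either directly or via the backward recursion G_t = r_t + γ G_{t+1}; a small sketch:

```python
def discounted_return(rewards, gamma=0.99):
    """Sum_{t=1}^{T} gamma^(t-1) * r_t for a finite reward sequence."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))

def discounted_return_recursive(rewards, gamma=0.99):
    """Same quantity via the backward recursion G_t = r_t + gamma * G_{t+1}."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g
```

For instance, rewards [1, 0, 2] with γ = 0.5 give 1 + 0·0.5 + 2·0.25 = 1.5 under both formulations.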

CODE AVAILABILITY

Our code is available at https://github.com/jrobine/twm.

