TRANSFORMER-BASED WORLD MODELS ARE HAPPY WITH 100K INTERACTIONS

Abstract

Deep neural networks have been successful in many reinforcement learning settings. However, compared to human learners they are overly data-hungry. To build a sample-efficient world model, we apply a transformer to real-world episodes in an autoregressive manner: not only the compact latent states and the taken actions but also the experienced or predicted rewards are fed into the transformer, so that it can attend flexibly to all three modalities at different time steps. The transformer allows our world model to access previous states directly, instead of viewing them through a compressed recurrent state. By utilizing the Transformer-XL architecture, it is able to learn long-term dependencies while staying computationally efficient. Our transformer-based world model (TWM) generates meaningful new experience, which is used to train a policy that outperforms previous model-free and model-based reinforcement learning algorithms on the Atari 100k benchmark.

1. INTRODUCTION

Deep reinforcement learning methods have shown great success on many challenging decision making problems. Notable methods include DQN (Mnih et al., 2015), PPO (Schulman et al., 2017), and MuZero (Schrittwieser et al., 2019). However, most algorithms require hundreds of millions of interactions with the environment, whereas humans can often achieve similar results with less than 1% of these interactions, i.e., they are more sample-efficient. The large amount of data required rules out many potential real-world applications of reinforcement learning. Recent works have made substantial progress in advancing the sample efficiency of RL algorithms: model-free methods have been improved with auxiliary objectives (Laskin et al., 2020b), data augmentation (Yarats et al., 2021; Laskin et al., 2020a), or both (Schwarzer et al., 2021). Model-based methods have been successfully applied to complex image-based environments and have been used either for planning, such as EfficientZero (Ye et al., 2021), or for learning behaviors in imagination, such as SimPLe (Kaiser et al., 2020).

Figure 1: Our world model architecture. Observations o_{t-…:t} are encoded using a CNN. Linear embeddings of the stochastic, discrete latent states z_{t-…:t}, actions a_{t-…:t}, and rewards r_{t-…:t} are fed into a transformer, which computes a deterministic hidden state h_t at each time step. Predictions of the reward r_t, discount factor γ_t, and next latent state z_{t+1} are computed based on h_t using MLPs.
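The data flow described in the caption can be sketched in a few lines of Python. The snippet below is a minimal, numpy-only illustration of how the three modalities (latent state, action, reward) are interleaved into a single token sequence before entering the transformer; the transformer itself is replaced by a causal mean-pooling stub, and all names, vocabulary sizes, and dimensions are our own illustrative choices, not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

T, d = 4, 8                       # time steps and embedding width (illustrative)
E_z = rng.normal(size=(32, d))    # embedding table for discrete latent codes
E_a = rng.normal(size=(6, d))     # embedding table for actions
E_r = rng.normal(size=(3, d))     # embedding table for (discretized) rewards

def interleave(z_ids, a_ids, r_ids):
    """Build the transformer input: one (z_t, a_t, r_t) token triple per step."""
    tokens = np.empty((3 * len(z_ids), d))
    tokens[0::3] = E_z[z_ids]
    tokens[1::3] = E_a[a_ids]
    tokens[2::3] = E_r[r_ids]
    return tokens

def transformer(x):
    """Stand-in for the causal Transformer-XL: each output only aggregates
    tokens up to its own position (here via a running mean)."""
    return x.cumsum(axis=0) / np.arange(1, len(x) + 1)[:, None]

z = rng.integers(0, 32, size=T)   # latent codes z_{t-3:t}
a = rng.integers(0, 6, size=T)    # actions a_{t-3:t}
r = rng.integers(0, 3, size=T)    # reward tokens r_{t-3:t}

h = transformer(interleave(z, a, r))
print(h.shape)  # (3*T, d): one hidden state per token
```

Because every past latent, action, and reward is its own token, a real attention layer in place of the stub can attend to any of the three modalities at any earlier time step, rather than reading them through a single compressed recurrent state; the hidden states h would then feed the MLP heads that predict r_t, γ_t, and z_{t+1}.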

Code availability

Our code is available at https://github.com/jrobine/twm.

