TRANSFORMER-BASED WORLD MODELS ARE HAPPY WITH 100K INTERACTIONS

Abstract

Deep neural networks have been successful in many reinforcement learning settings. However, compared to human learners, they are overly data-hungry. To build a sample-efficient world model, we apply a transformer to real-world episodes in an autoregressive manner: not only the compact latent states and the taken actions but also the experienced or predicted rewards are fed into the transformer, so that it can attend flexibly to all three modalities at different time steps. The transformer allows our world model to access previous states directly, instead of viewing them through a compressed recurrent state. By utilizing the Transformer-XL architecture, it is able to learn long-term dependencies while staying computationally efficient. Our transformer-based world model (TWM) generates meaningful new experience, which is used to train a policy that outperforms previous model-free and model-based reinforcement learning algorithms on the Atari 100k benchmark. Our code is available at https://github.com/jrobine/twm.

1. INTRODUCTION

Deep reinforcement learning methods have shown great success on many challenging decision making problems. Notable methods include DQN (Mnih et al., 2015), PPO (Schulman et al., 2017), and MuZero (Schrittwieser et al., 2019). However, most algorithms require hundreds of millions of interactions with the environment, whereas humans can often achieve similar results with less than 1% of these interactions, i.e., they are more sample-efficient. The large amount of required data renders many potential real-world applications of reinforcement learning infeasible. Recent works have made considerable progress in advancing the sample efficiency of RL algorithms: model-free methods have been improved with auxiliary objectives (Laskin et al., 2020b), data augmentation (Yarats et al., 2021; Laskin et al., 2020a), or both (Schwarzer et al., 2021). Model-based methods have been successfully applied to complex image-based environments and have either been used for planning, such as EfficientZero (Ye et al., 2021), or for learning behaviors in imagination, such as SimPLe (Kaiser et al., 2020).

Figure 1: Our world model architecture. Observations o_{t-ℓ:t} are encoded using a CNN. Linear embeddings of stochastic, discrete latent states z_{t-ℓ:t}, actions a_{t-ℓ:t}, and rewards r_{t-ℓ:t} are fed into a transformer, which computes a deterministic hidden state h_t at each time step. Predictions of the reward r̂_t, discount factor γ̂_t, and next latent state ẑ_{t+1} are computed based on h_t using MLPs.

A promising model-based concept is learning in imagination (Ha & Schmidhuber, 2018; Kaiser et al., 2020; Hafner et al., 2020; 2021): instead of learning behaviors from the collected experience directly, a generative model of the environment dynamics is learned in a (self-)supervised manner.
Such a so-called world model can create new trajectories by iteratively predicting the next state and reward. This provides a potentially unlimited source of training data for the reinforcement learning algorithm without further interaction with the real environment. Owing to the generalization abilities of deep neural networks, a world model may also generalize to new, unseen situations, which has the potential to drastically increase the sample efficiency. This can be illustrated by a simple example: in the game of Pong, the paddles and the ball move independently. In the best case, a successfully trained world model would imagine trajectories with paddle and ball configurations that have never been observed before, which enables learning of improved behaviors. In this paper, we propose to model the world with transformers (Vaswani et al., 2017), which have significantly advanced the field of natural language processing and have been successfully applied to computer vision tasks (Dosovitskiy et al., 2021). A transformer is a sequence model consisting of multiple self-attention layers with residual connections. In each self-attention layer the inputs are mapped to keys, queries, and values. The outputs are computed by weighting the values by the similarity of the corresponding keys and queries. Combined with causal masking, which prevents the self-attention layers from accessing future time steps in the training sequence, transformers can be used as autoregressive generative models. The Transformer-XL architecture (Dai et al., 2019) is much more computationally efficient than vanilla transformers at inference time and introduces relative positional encodings, which remove the dependence on absolute time steps.
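As a rough, self-contained sketch (not our implementation), single-head self-attention with causal masking can be written in a few lines of NumPy; the projection matrices Wq, Wk, and Wv are placeholders for learned parameters:

```python
import numpy as np

def causal_self_attention(x, Wq, Wk, Wv):
    """Single-head self-attention with a causal mask.

    x: (T, d) token sequence; Wq, Wk, Wv: (d, d) learned projections.
    """
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(x.shape[1])            # (T, T) key-query similarities
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores[mask] = -np.inf                            # block attention to future steps
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)    # row-wise softmax
    return weights @ v                                # values weighted by similarity
```

With the causal mask in place, the output at the first time step depends only on the first input token, which is what allows such a model to be used autoregressively.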

Our contributions:

The contributions of this work can be summarized as follows:

1. We present a new autoregressive world model based on the Transformer-XL (Dai et al., 2019) architecture and a model-free agent trained in latent imagination. Running our policy is computationally efficient, as the transformer is not needed at inference time. This is in contrast to related works (Hafner et al., 2020; 2021; Chen et al., 2022) that require the full world model during inference.
2. Our world model is provided with information on how much reward has already been emitted by feeding back predicted rewards into the world model. As shown in our ablation study, this improves performance.
3. We rewrite the balanced KL divergence loss of Hafner et al. (2021) to allow us to fine-tune the relative weight of the involved entropy and cross-entropy terms.
4. We introduce a new thresholded entropy loss that stabilizes the policy's entropy during training and thereby simplifies the selection of hyperparameters that behave well across different games.
5. We propose a new effective sampling procedure for the growing dataset of experience, which balances the training distribution to shift the focus towards the latest experience. We demonstrate the efficacy of this procedure with an ablation study.
6. We compare our transformer-based world model (TWM) on the Atari 100k benchmark with recent sample-efficient methods and obtain excellent results. Moreover, we report empirical confidence intervals of the aggregate metrics as suggested by Agarwal et al. (2021).

2. METHOD

The goal is to find a policy π that maximizes the expected sum of discounted rewards E_π[Σ_{t=1}^∞ γ^{t-1} r_t], where γ ∈ [0, 1) is the discount factor. Learning in imagination consists of three steps that are repeated iteratively: learning the dynamics, learning a policy, and interacting in the real environment. In this section, we describe our world model and policy, concluding with the training procedure.
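The discounted return that the policy maximizes (in expectation) can be evaluated for a finite reward sequence with a simple backward recursion; a minimal sketch:

```python
def discounted_return(rewards, gamma):
    """Compute sum_{t=1}^{T} gamma^(t-1) * r_t via the recursion g = r_t + gamma * g."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g
```

For example, discounted_return([1.0, 1.0, 1.0], 0.5) evaluates to 1 + 0.5 + 0.25 = 1.75.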

2.1. WORLD MODEL

Our world model consists of an observation model and a dynamics model, which do not share parameters. Figure 1 illustrates our combined world model architecture.

Observation Model: The observation model is a variational autoencoder (Kingma & Welling, 2014), which encodes observations o_t into compact, stochastic latent states z_t and reconstructs the observations with a decoder, which in our case is only required to obtain a learning signal for z_t:

Observation encoder: z_t ∼ p_φ(z_t | o_t)
Observation decoder: ô_t ∼ p_φ(ô_t | z_t). (1)

We adopt the neural network architecture of DreamerV2 (Hafner et al., 2021) with slight modifications for our observation model. Thus, a latent state z_t is discrete and consists of a vector of 32 categorical variables with 32 categories. The observation decoder reconstructs the observation and predicts the means of independent standard normal distributions for all pixels. The role of the observation model is to capture only non-temporal information about the current time step, which is different from Hafner et al. (2021). However, we include short-time temporal information, since a single observation o_t consists of four frames (aka frame stacking, see also Section 2.2).

Autoregressive Dynamics Model: The dynamics model predicts the next time step conditioned on the history of its past predictions. The backbone is a deterministic aggregation model f_ψ which computes a deterministic hidden state h_t based on the history of the previously generated latent states, actions, and rewards. Predictors for the reward, discount, and next latent state are conditioned on the hidden state. The dynamics model consists of these components:

Aggregation model: h_t = f_ψ(z_{t-ℓ:t}, a_{t-ℓ:t}, r_{t-ℓ:t-1})
Reward predictor: r̂_t ∼ p_ψ(r̂_t | h_t)
Discount predictor: γ̂_t ∼ p_ψ(γ̂_t | h_t)
Latent state predictor: ẑ_{t+1} ∼ p_ψ(ẑ_{t+1} | h_t). (2)
The aggregation model is implemented as a causally masked Transformer-XL (Dai et al., 2019), which enhances vanilla transformers (Vaswani et al., 2017) with a recurrence mechanism and relative positional encodings. With these encodings, our world model learns the dynamics independent of absolute time steps. Following Chen et al. (2021), the latent states, actions, and rewards are sent through modality-specific linear embeddings before being passed to the transformer. The number of input tokens is 3ℓ - 1, because of the three modalities (latent states, actions, rewards) and the last reward not being part of the input. We consider the outputs at the action modality as the hidden states and disregard the outputs of the other two modalities (see Figure 1; orange boxes vs. gray boxes). The latent state, reward, and discount predictors are implemented as multilayer perceptrons (MLPs) and compute the parameters of a vector of independent categorical distributions, a normal distribution, and a Bernoulli distribution, respectively, conditioned on the deterministic hidden state. The next state is determined by sampling from p_ψ(ẑ_{t+1} | h_t). The reward and discount are determined by the mean of p_ψ(r_t | h_t) and p_ψ(γ_t | h_t), respectively. As a consequence of these design choices, our world model has the following beneficial properties:

1. The dynamics model is autoregressive and has direct access to its previous outputs.
2. Training is efficient since sequences are processed in parallel (compared with RNNs).
3. Inference is efficient because outputs are cached (compared with vanilla transformers).
4. Long-term dependencies can be captured by the recurrence mechanism.

We want to provide an intuition on why a fully autoregressive dynamics model is favorable: First, the direct access to previous latent states makes it possible to model more complex dependencies between them, compared with RNNs, which only see previous states indirectly through a compressed recurrent state.
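To make the token layout concrete, the following sketch (with hypothetical names; the actual embedding and caching code is more involved) interleaves per-step embeddings into the input token sequence and records the action positions, whose transformer outputs serve as the hidden states:

```python
import numpy as np

def interleave_tokens(z_emb, a_emb, r_emb):
    """Arrange embeddings as (z_1, a_1, r_1, ..., z_l, a_l): 3*l - 1 tokens in
    total, since the last reward is not part of the input.

    z_emb, a_emb: (l, d) arrays; r_emb: (l - 1, d) array.
    Returns the token sequence and the indices of the action positions.
    """
    l = z_emb.shape[0]
    tokens, action_positions = [], []
    for t in range(l):
        tokens.append(z_emb[t])
        action_positions.append(len(tokens))  # index of a_t in the token sequence
        tokens.append(a_emb[t])
        if t < l - 1:                         # skip the final reward
            tokens.append(r_emb[t])
    return np.stack(tokens), action_positions
```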
This direct access also has the potential to make inference more robust, since degenerate predictions can be ignored more easily. Second, because the model sees which rewards it has produced previously, it can react to its own predictions. This is even more significant when the rewards are sampled from a probability distribution, since the introduced noise cannot be observed without autoregression.

Loss Functions: The observation model can be interpreted as a variational autoencoder with a temporal prior, which is provided by the latent state predictor. The goal is to keep the distributions of the encoder and the latent state predictor close to each other, while slowly adapting to new observations and dynamics. Hafner et al. (2021) apply a balanced KL divergence loss, which lets them control which of the two distributions should be penalized more. To control the influences of its subterms more precisely, we disentangle this loss and obtain a balanced cross-entropy loss that computes the cross-entropy H(p_φ(z_{t+1} | o_{t+1}), p_ψ(ẑ_{t+1} | h_t)) and the entropy H(p_φ(z_t | o_t)) explicitly. Our derivation can be found in Appendix A.2. We call the cross-entropy term for the observation model the consistency loss, as its purpose is to prevent the encoder from diverging from the dynamics model. The entropy regularizes the latent states and prevents them from collapsing to one-hot distributions. The observation decoder is optimized via negative log-likelihood, which provides a rich learning signal for the latent states. In summary, we optimize a self-supervised loss function for the observation model that is the expected sum of the decoder loss, the entropy regularizer, and the consistency loss:

L^Obs._φ = E[ Σ_{t=1}^T -ln p_φ(o_t | z_t) (decoder) - α_1 H(p_φ(z_t | o_t)) (entropy regularizer) + α_2 H(p_φ(z_t | o_t), p_ψ(ẑ_t | h_{t-1})) (consistency) ], (3)

where the hyperparameters α_1, α_2 ≥ 0 control the relative weights of the terms.
For the balanced cross-entropy loss, we also minimize the cross-entropy in the loss of the dynamics model, which is how we train the latent state predictor. The reward and discount predictors are optimized via negative log-likelihood. This leads to a self-supervised loss for the dynamics model:

L^Dyn._ψ = E[ Σ_{t=1}^T H(p_φ(z_{t+1} | o_{t+1}), p_ψ(ẑ_{t+1} | h_t)) (latent state predictor) - β_1 ln p_ψ(r_t | h_t) (reward predictor) - β_2 ln p_ψ(γ_t | h_t) (discount predictor) ], (4)

with coefficients β_1, β_2 ≥ 0, and where γ_t = 0 for episode ends (d_t = 1) and γ_t = γ otherwise.
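For the categorical latent states, the entropy and cross-entropy terms appearing in both losses can be computed directly from logits. The sketch below handles a single latent state of independent categorical variables; the weights alpha1 and alpha2 are placeholders, not our tuned hyperparameters:

```python
import numpy as np

def log_softmax(logits):
    m = logits.max(axis=-1, keepdims=True)
    return logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))

def entropy_and_consistency(enc_logits, pred_logits, alpha1, alpha2):
    """Return -alpha1 * H(q) + alpha2 * H(q, p), summed over the independent
    categorical variables; q is the encoder, p the latent state predictor.

    enc_logits, pred_logits: (num_vars, num_classes) arrays, e.g. (32, 32).
    """
    log_q = log_softmax(enc_logits)
    q = np.exp(log_q)
    ent = -(q * log_q).sum()                    # entropy of the encoder
    ce = -(q * log_softmax(pred_logits)).sum()  # cross-entropy to the predictor
    return -alpha1 * ent + alpha2 * ce
```

When encoder and predictor agree exactly, the cross-entropy equals the entropy, so the combined term reduces to (alpha2 - alpha1) times the encoder entropy.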

2.2. POLICY

Our policy π_θ(a_t | ẑ_t) is trained on imagined trajectories using a mostly standard advantage actor-critic (Mnih et al., 2016) approach. We train two separate networks: an actor a_t ∼ π_θ(a_t | ẑ_t) with parameters θ and a critic v_ξ(ẑ_t) with parameters ξ. We compute the advantages via Generalized Advantage Estimation (Schulman et al., 2016), while using the discount factors γ̂_t predicted by the world model instead of a fixed discount factor for all time steps. As in DreamerV2 (Hafner et al., 2021), we weight the losses of the actor and the critic by the cumulative product of the discount factors, in order to softly account for episode ends.

Thresholded Entropy Loss: We penalize the objective of the actor with a slightly modified version of the usual entropy regularization term (Mnih et al., 2016). Our penalty normalizes the entropy and only takes effect when the entropy falls below a certain threshold:

L^Ent._θ = max(0, Γ - H(π_θ)/ln(m)), (5)

where 0 ≤ Γ ≤ 1 is the threshold hyperparameter, H(π_θ) is the entropy of the policy, m is the number of discrete actions, and ln(m) is the maximum possible entropy of the categorical action distribution. By doing this, we explicitly control the percentage of entropy that should be preserved across all games, independent of the number of actions. This ensures exploration in the real environment and in imagination without the need for ε-greedy action selection or changing the temperature of the action distribution. We also use the same stochastic policy when evaluating our agent in the experiments. The idea of applying a hinge loss to the entropy was first introduced in earlier work.

Choice of Policy Input: There are several choices for the input x_t of the policy, and this choice affects the design choices for the world model. Using x_t = o_t (or ô_t) is relatively stable even with imperfect reconstructions ô_t, as the underlying distribution of observations p(o) does not change during training.
However, it is also less computationally efficient, since it requires reconstructing the observations during imagination and additional convolutional layers for the policy. Using x_t = z_t (or ẑ_t) is slightly less stable, as the policy has to adapt to the changes of the distributions p_φ(z_t | o_t) and p_ψ(ẑ_{t+1} | h_t) during training. Nevertheless, the entropy regularizer and consistency loss in Equation (3) stabilize these distributions. Using x_t = [z_t, h_t] (or [ẑ_t, h_t]) provides the agent with a summary of the history of experience, but it also adds the burden of running the transformer at inference time. Model-free agents already perform well on most Atari games when using a stack of the most recent frames (e.g., Mnih et al. 2015; Schulman et al. 2017). Therefore, we choose x_t = z_t and apply frame stacking at inference time in order to incorporate short-time information directly into the latent states. At training time we use x_t = ẑ_t, i.e., the predicted latent states, meaning no frame stacking is applied. As a consequence, our policy is computationally efficient at training time (no reconstructions during imagination) and at inference time (no transformer when running in the real environment).
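The thresholded entropy loss described in this section can be sketched in a few lines; a minimal version operating on explicit action probabilities:

```python
import numpy as np

def thresholded_entropy_loss(action_probs, threshold):
    """max(0, Gamma - H(pi) / ln(m)): only penalize when the normalized policy
    entropy drops below the threshold Gamma in [0, 1]."""
    p = np.asarray(action_probs, dtype=float)
    entropy = -(p * np.log(np.clip(p, 1e-12, 1.0))).sum()
    return max(0.0, threshold - entropy / np.log(p.size))
```

A uniform policy has maximal normalized entropy 1, so the loss vanishes; a deterministic policy has entropy 0 and incurs the full threshold as penalty.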

2.3. TRAINING

As is usual for learning with world models, we repeatedly (i) collect experience in the real environment with the current policy, (ii) improve the world model using the past experience, and (iii) improve the policy using new experience generated by the world model. During training we build a dataset D = [(o_1, a_1, r_1, d_1), ..., (o_T, a_T, r_T, d_T)] of the collected experience. After collecting new experience with the current policy, we improve the world model by sampling N sequences of length ℓ from D and optimizing the loss functions in Equations (3) and (4) using stochastic gradient descent. After performing a world model update, we select M of the N × ℓ observations and encode them into latent states to serve as initial states for new trajectories. The dynamics model iteratively generates these M trajectories of length H based on actions provided by the policy. Subsequently, the policy is improved with standard model-free objectives, as described in Section 2.2. In Algorithm 1 we present pseudocode for training the world model and the policy.

Balanced Dataset Sampling: Since the dataset grows slowly during training, uniform sampling of trajectories focuses too heavily on early experience, which can lead to overfitting, especially in the low data regime. Therefore, we keep visitation counts v_1, ..., v_T, which are incremented every time an entry is sampled as the start of a sequence. These counts are converted to probabilities using the softmax function:

(p_1, ..., p_T) = softmax(-v_1/τ, ..., -v_T/τ), (6)

where τ > 0 is a temperature hyperparameter. With our sampling procedure, new entries in the dataset are oversampled and are selected more often than old ones. Setting τ = ∞ restores uniform sampling as a special case, whereas reducing τ increases the amount of oversampling. See Figure 2 for a comparison. We empirically show the effectiveness of this procedure in Section 3.3.
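A minimal sketch of one sampling step, assuming the visitation counts are kept in a plain list (the helper name is hypothetical):

```python
import numpy as np

def sample_start_index(visit_counts, tau, rng):
    """Draw a sequence start index with probabilities softmax(-v_i / tau).

    Rarely visited (i.e., newer) entries are oversampled; tau = inf
    recovers uniform sampling as a special case.
    """
    v = np.asarray(visit_counts, dtype=float)
    logits = np.zeros_like(v) if np.isinf(tau) else -v / tau
    logits -= logits.max()                       # numerical stability
    p = np.exp(logits) / np.exp(logits).sum()
    i = int(rng.choice(len(v), p=p))
    visit_counts[i] += 1                         # update the count of the drawn entry
    return i
```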
Figure 4: Performance profiles on the Atari 100k benchmark based on score distributions (Agarwal et al., 2021). It shows the fraction of runs across all games (y-axis) above a given human normalized score (x-axis). Shaded regions show pointwise 95% confidence bands.

We follow the advice of Agarwal et al. (2021), who found significant discrepancies between reported point estimates of mean (and median) scores and a thorough statistical analysis that includes statistical uncertainty. Thus, we report confidence interval estimates of the aggregate metrics median, interquartile mean (IQM), mean, and optimality gap in Figure 3, and performance profiles in Figure 4, which we created using the toolbox provided by Agarwal et al. (2021). The metrics are computed on human normalized scores, which are calculated as (score_agent - score_random) / (score_human - score_random). We report the unnormalized scores per game in Table 1. We compare with new scores for DER, CURL, DrQ, and SPR that were evaluated on 100 runs and provided by Agarwal et al. (2021). They report scores for the improved DrQ(ε), which is DrQ evaluated with standard ε-greedy parameters. We perform 5 runs per game and compute the average score over 100 episodes at the end of training for each run. TWM shows a significant improvement over previous approaches in all four aggregate metrics and brings the optimality gap closer to zero.
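The human normalized score used throughout the evaluation is a simple linear rescaling; as a sketch:

```python
def human_normalized_score(score_agent, score_random, score_human):
    """(score_agent - score_random) / (score_human - score_random): 0 matches a
    random policy, 1 matches the human reference score."""
    return (score_agent - score_random) / (score_human - score_random)
```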

3.2. ANALYSIS

In Figure 5 we show imagined trajectories of our world model. In Figure 6 we visualize an attention map of the transformer for an imagined sequence. In this example, a lot of weight is put on the current action and the last three states. However, the transformer also attends to states and rewards further in the past, while past actions are mostly ignored. The two high positive rewards also receive high attention, which confirms that the rewards in the input sequence are used by the world model. We hypothesize that these rewards correspond to events that happened in the environment and that this information can be useful for prediction. An extended analysis can be found in Appendix A.1, including more imagined trajectories and attention maps (and a description of the generation of the plots), sample efficiency, stochasticity of the world model, long sequence imagination, and frame stacking.

Published as a conference paper at ICLR 2023

3.3. ABLATION STUDIES

Uniform Sampling: To show the effectiveness of the sampling procedure described in Section 2.3, we evaluate three games with uniform dataset sampling, which is equivalent to setting τ = ∞ in Equation (6). In Figure 7 we show that balanced dataset sampling significantly improves the performance in these games. At the end of training, the dynamics loss from Equation (4) is lower when applying balanced sampling. One reason might be that with uniform sampling the world model overfits on early training data and performs poorly in later stages of training.

No Rewards: As described in Section 2.1, the predicted rewards are fed back into the transformer. In Figure 8 we show on three games that this can significantly increase the performance. In some games the performance is equivalent, probably because the world model can make correct predictions solely based on the latent states and actions. In Appendix A.1 we perform additional ablation studies, including the thresholded entropy loss, a shorter history length, conditioning the policy on [z, h], and increasing the sample efficiency.

4. RELATED WORK

The Dyna architecture (Sutton, 1991) introduced the idea of training a model of the environment and using it to further improve the value function or the policy. Ha & Schmidhuber (2018) introduced the notion of a world model, which tries to completely imitate the environment and is used to generate experience to train a model-free agent. They implement a world model as a VAE (Kingma & Welling, 2014) and an RNN and learn a policy in latent space with an evolution strategy. With SimPLe, Kaiser et al. (2020) propose an iterative training procedure that alternates between training the world model and the policy. Their policy operates on the pixel level and is trained using PPO (Schulman et al., 2017). Hafner et al. (2020) present Dreamer and implement a world model as a stochastic RNN that splits the latent state into a stochastic part and a deterministic part; this idea was first introduced by Hafner et al. (2019). This allows their world model to capture the stochasticity of the environment and simultaneously facilitates remembering information over multiple time steps. Robine et al. (2020) use a VQ-VAE to construct a world model with a drastically lower number of parameters. DreamerV2 (Hafner et al., 2021) achieves great performance on the Atari 50M benchmark after making some changes to Dreamer, the most important ones being categorical latent variables and an improved objective. Another direction of model-based reinforcement learning is planning, where the model is used at inference time to improve the action selection by looking ahead several time steps into the future. The most prominent work is MuZero (Schrittwieser et al., 2019), where a learned sequence model of rewards and values is combined with Monte-Carlo Tree Search (Coulom, 2006) without learning explicit representations of the observations. MuZero achieves impressive performance on the Atari 50M benchmark, but it is also computationally expensive and requires significant engineering effort.
EfficientZero (Ye et al., 2021) improves MuZero and achieves great performance on the Atari 100k benchmark. Transformers (Vaswani et al., 2017) advanced the effectiveness of sequence models in multiple domains, such as natural language processing and computer vision (Dosovitskiy et al., 2021) . Recently, they have also been applied to reinforcement learning tasks. The Decision Transformer (Chen et al., 2021) and the Trajectory Transformer (Janner et al., 2021) are trained on an offline dataset of trajectories. The Decision Transformer is conditioned on states, actions, and returns, and outputs optimal actions. The Trajectory Transformer trains a sequence model of states, actions, and rewards, and is used for planning. Chen et al. (2022) replace the RNN of Dreamer with a transformer and outperform Dreamer on Hidden Order Discovery tasks. However, their transformer has no access to previous rewards and they do not evaluate their method on the Atari 100k benchmark. Moreover, their policy depends on the outputs of the transformer, leading to higher computational costs during inference time. Concurrent to and independent from our work, Micheli et al. (2022) apply a transformer to sequences of frame tokens and actions and achieve state-of-the-art results on the Atari 100k benchmark.

5. CONCLUSION

In this work, we discuss a reinforcement learning approach using transformer-based world models. Our method (TWM) outperforms previous model-free and model-based methods in terms of human normalized score on the 26 games of the Atari 100k benchmark. By using the transformer only during training, we were able to keep the computational costs low during inference, i.e., when running the learned policy in a real environment. We show how feeding back the predicted rewards into the transformer is beneficial for learning the world model. Furthermore, we introduce the balanced cross-entropy loss for finer control over the trade-off between the entropy and cross-entropy terms in the loss functions of the world model. A new thresholded entropy loss effectively stabilizes the entropy of the policy. Finally, our novel balanced sampling procedure corrects issues of naive uniform sampling of past experience.

1. Thresholded Entropy Loss: [...] higher as well, probably because the entropy is in a more sensible range for the exploration-exploitation trade-off. This cannot be solved by adjusting the penalty coefficient η alone, since it would increase or decrease the entropy in all games.
2. History Length: We trained our world model with a shorter history and set ℓ = 4 instead of ℓ = 16. This has a negative impact on the score, as can be seen in Figure 16, demonstrating that more time steps into the past are important.
3. Choice of Policy Input: In Section 2.2 we explained why the input to the policy is only the latent state, i.e., x = z. In Figure 17 we show that using x = [z, h] can result in lower final scores. We hypothesize that the policy network has a hard time keeping up with the changes of the space of h during training and cannot ignore this additional information.
4.
Increasing the Sample Efficiency: To find out whether we can further increase the sample efficiency shown in Table 2, we train a subset of games again on 10K, 25K, and 50K interactions with the full training budget that we used for the 100K interactions. In Figure 18 we see that this can lead to significant improvements in some cases, which could mean that the policy benefits from more training on imagined trajectories, but it can even lead to worse performance in other cases, which could possibly be caused by overfitting of the world model. When the performance stays the same even with longer training, this could mean that better exploration in the real environment is required to get further improvements.

Hafner et al. (2021) propose to use a balanced KL divergence loss to jointly optimize the observation encoder q_θ and state predictor p_θ with shared parameters θ, i.e.,

λ D_KL(sg(q_θ) ‖ p_θ) + (1 - λ) D_KL(q_θ ‖ sg(p_θ)), (7)

where sg(•) denotes the stop-gradient operation and λ ∈ [0, 1] controls how much the state predictor adapts to the observation encoder and vice versa. We use the identity D_KL(q ‖ p) = H(q, p) - H(q), where H(q, p) is the cross-entropy of distribution p relative to distribution q, and show that our loss functions lead to the same gradients as the balanced KL objective, but with finer control over the individual components:

∇_θ [λ D_KL(sg(q_θ) ‖ p_θ) + (1 - λ) D_KL(q_θ ‖ sg(p_θ))] (8)
= ∇_θ [λ (H(sg(q_θ), p_θ) - H(sg(q_θ))) + (1 - λ) (H(q_θ, sg(p_θ)) - H(q_θ))] (9)
= ∇_θ [λ_1 H(sg(q_θ), p_θ) + λ_2 H(q_θ, sg(p_θ)) - λ_3 H(q_θ)], (10)

since ∇_θ H(sg(q_θ)) = 0 and by defining λ_1 = λ, λ_2 = 1 - λ, and λ_3 = 1 - λ. In this form, we have control over the cross-entropy of the state predictor relative to the observation encoder and vice versa. Moreover, we explicitly penalize the entropy of the observation encoder, instead of it being entangled inside the KL divergence.
As is common in the literature, we define the loss function by omitting the gradient in Equation (10), so that automatic differentiation computes this gradient. For our world model, we split the objective into two loss functions, as the observation encoder and state predictor have separate parameters, yielding Equations (3) and (4).
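The identity D_KL(q ‖ p) = H(q, p) − H(q) that the derivation relies on can be checked numerically for categorical distributions; a small self-contained sketch:

```python
import numpy as np

def kl(q, p):
    """KL divergence between two categorical distributions."""
    return float((q * (np.log(q) - np.log(p))).sum())

def entropy(q):
    return float(-(q * np.log(q)).sum())

def cross_entropy(q, p):
    return float(-(q * np.log(p)).sum())

# Verify D_KL(q || p) = H(q, p) - H(q) on an arbitrary example.
q = np.array([0.7, 0.2, 0.1])
p = np.array([0.5, 0.3, 0.2])
assert np.isclose(kl(q, p), cross_entropy(q, p) - entropy(q))
```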



Figure 2: Comparing our balanced dataset sampling procedure (see Equation (6)) for different values of τ with uniform sampling (τ = ∞). The x-axes correspond to the entries in dataset D in the order they are experienced. The left plot shows the number of times an entry has been selected for training the world model. The right plot shows the relative amount of training time that has been spent on the data up to that entry. E.g., with uniform sampling, 50% of the training time is used for the first 19K entries, whereas for τ = 20 approximately the same time is spent on both halves of the dataset.

(a) Boxing. The player (white) presses fire, hits the opponent, and gets a reward. (b) Freeway. The player moves up and bumps into a car. The world model correctly pushes the player down, although up is still pressed. The movement of the cars is modeled correctly.

Figure 5: Trajectories imagined by our world model. Above each frame we show the performed action and the produced reward.

Figure 6: Attention map of the learned transformer for the current hidden state h, computed on an imagined trajectory for the game Assault. The x-axis corresponds to the input sequence with the three modalities (states, actions, rewards), where the two rightmost columns are the current state and action. The y-axis corresponds to the layer of the transformer.

Figure 7: Comparison of the proposed balanced sampling procedure with uniform sampling on a random subset of games. We show the development of the human normalized score in the course of training. The score is higher with balanced sampling, demonstrating its importance.

Figure 8: Effect of removing rewards from the input. We show the human normalized score during training of a random subset of games. Conditioning on rewards can significantly increase the performance. Some games do not benefit from the rewards and the score stays roughly the same.

(a) This world model focuses on previous states. (b) This world model focuses on previous actions, indicating that the effect of actions can last longer than a single time step. (c) This world model attends to all three modalities in the recent past.

(d) This world model attends to states at all time steps, probably because of the complexity of this 3D game. (e) This world model mainly focuses on four states at specific time steps.

Figure 10: Average attention maps of the transformer, computed over time steps. They show how different games require a different focus on modalities and time steps.

Figure 11: Attention map for Freeway for a single time step. At this point the player hits a car and gets pushed back (see also Figure 5b), and the world model pays more attention to past states and rewards, compared with the average attention at other time steps shown in Figure 10e. The world model has learned to handle this situation separately.

Aggregate metrics on the Atari 100k benchmark with 95% stratified bootstrap confidence intervals (Agarwal et al., 2021). Higher median, interquartile mean (IQM), and mean, but lower optimality gap, indicate better performance. Scores for previous methods are from Agarwal et al. (2021) with 100 runs per game (except SimPLe with 5 runs). We evaluate 5 runs per game, leading to wider confidence intervals.

3 EXPERIMENTS

To compare data-efficient reinforcement learning algorithms, Kaiser et al. (2020) proposed the Atari 100k benchmark, which uses a subset of 26 Atari games from the Arcade Learning Environment (Bellemare et al., 2013) and limits the number of interactions per game to 100K. This corresponds to 400K frames (because of frame skipping) or roughly 2 hours of gameplay, which is 500 times less than the usual 200 million frames (e.g., Mnih et al. 2015; Schulman et al. 2017; Hafner et al. 2021).
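The aggregate metrics follow standard definitions; a minimal sketch of the human-normalized score and the interquartile mean (the stratified bootstrap confidence intervals themselves are provided by the rliable library of Agarwal et al. (2021); the numbers below are toy values, not benchmark scores):

```python
import numpy as np

def human_normalized(score, random_score, human_score):
    """Standard Atari human-normalized score."""
    return (score - random_score) / (human_score - random_score)

def iqm(scores):
    """Interquartile mean: the mean of the middle 50% of runs,
    more robust to outliers than the mean, less noisy than the median."""
    s = np.sort(np.asarray(scores).ravel())
    n = len(s)
    return s[n // 4 : n - n // 4].mean()

# Toy example with made-up values.
scores = [0.1, 0.4, 0.5, 0.6, 0.7, 0.9, 1.2, 5.0]
# iqm(scores) averages only [0.5, 0.6, 0.7, 0.9], ignoring the tails.
```

The optimality gap reported in the same table measures, roughly, how far runs fall short of the human-level score of 1.0, so lower is better, unlike the other three metrics.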

Mean scores on the Atari 100k benchmark per game as well as the aggregated human normalized mean and median. We perform 5 runs per game and compute the average over 100 episodes at the end of training for each run. Bold numbers indicate the best scores.

Denis Yarats, Ilya Kostrikov, and Rob Fergus. Image augmentation is all you need: Regularizing deep reinforcement learning from pixels. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. OpenReview.net, 2021. URL https://openreview.net/forum?id=GY6-6sTvGaf.

Performance of our method at different stages of training compared with final scores of previous methods. We show individual game scores and mean human normalized scores. The normalized mean of our method is higher than SimPLe after only 25K interactions, and higher than previous methods after 50K interactions.

Approximate runtime (i.e., training and evaluation time for a single run) of our method compared with previous methods that also evaluate on the Atari 100k benchmark. Runtimes of previous methods are taken from Schwarzer et al. (2021). They used an improved version of DER (van Hasselt et al., 2019), which is roughly equivalent to DrQ (Yarats et al., 2021), so the specified runtime might differ from the original DER implementation. There are data-augmented versions for SPR and DER. All runtimes are measured on a single NVIDIA P100 GPU. This suggests that our method could potentially outperform previous methods with shorter training, which would take less than 23.3 hours.

To determine how time-consuming the individual parts of our method are, we investigate the throughput of the models, with the batch sizes of our main experiments. The Transformer-XL version is almost twice as fast, which again shows the importance of this design choice. The throughputs were measured on an NVIDIA A100 GPU and are given in (approximate) samples per second:

APPENDIX

Additional Analysis:

1. We provide more example trajectories in Figure 9.
2. We present more attention plots in Figures 10 and 11. All attention maps are generated using the attention rollout method by Abnar & Zuidema (2020). Note that we had to modify the method slightly in order to take the causal masks into account.
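A minimal sketch of attention rollout (Abnar & Zuidema, 2020) for a causal transformer, assuming per-layer attention matrices that are already averaged over heads (the paper's exact modification for the causal masks may differ from this):

```python
import numpy as np

def attention_rollout(attentions):
    """Propagate attention through the residual stream by averaging each
    layer's attention with the identity and multiplying across layers.
    `attentions` is a list of (seq, seq) row-stochastic matrices. With a
    causal model the matrices are lower triangular, so the rollout stays
    causal automatically.
    """
    n = attentions[0].shape[0]
    rollout = np.eye(n)
    for att in attentions:
        att = 0.5 * att + 0.5 * np.eye(n)            # account for residuals
        att = att / att.sum(axis=-1, keepdims=True)  # renormalize rows
        rollout = att @ rollout
    return rollout

# Toy causal attention over 4 positions, identical in both layers.
raw = np.tril(np.ones((4, 4)))
raw = raw / raw.sum(axis=-1, keepdims=True)
r = attention_rollout([raw, raw])
# Each row of r is still a distribution, and r remains lower triangular.
```

The resulting matrix is what Figures 6, 10, and 11 visualize: how much the current hidden state ultimately depends on each earlier state, action, and reward token.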

3. Sample Efficiency:

We provide the scores of our main experiments after different amounts of interactions with the environment in Table 2 . After 50K interactions, our method already has a higher mean normalized score than previous sample-efficient methods. Our mean normalized score is higher than DER, CURL, and SimPLe after 25K interactions. This demonstrates the high sample efficiency of our approach.

4. Stochasticity:

The stochastic prediction of the next state allows the world model to sample a variety of trajectories, even from the same starting state, as can be seen in Figure 12 .
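The stochasticity comes from sampling the discrete latent state from a predicted categorical distribution. A minimal sketch, assuming DreamerV2-style categorical latents with straight-through gradients (the shapes and the `sample_latent` helper here are illustrative, not the released implementation):

```python
import torch
import torch.nn.functional as F

def sample_latent(logits):
    """Sample a one-hot discrete latent and use a straight-through
    estimator so gradients flow through the sampling step.
    `logits` has shape (..., n_categories).
    """
    dist = torch.distributions.OneHotCategorical(logits=logits)
    sample = dist.sample()
    probs = F.softmax(logits, dim=-1)
    # Forward pass uses the sample; backward pass uses the probabilities.
    return sample + probs - probs.detach()

# Hypothetical latent of 32 categorical variables with 32 classes each.
logits = torch.randn(2, 32, 32, requires_grad=True)
z1 = sample_latent(logits)
z2 = sample_latent(logits)
# The same logits can yield different samples, so imagination from the
# same starting state produces a variety of trajectories.
```

Repeatedly sampling in this way during imagination is what produces the diverse trajectories shown in Figure 12.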

5. Long Sequence Imagination:

The world model is trained using sequences of length 16; however, it generalizes well to very long trajectories, as shown in Figure 13.

6. Frame Stacking: In Figure 14 we visualize the learned stacks of frames. This shows that the world model encodes and predicts the motion of objects.

Additional Ablation Studies:

1. Thresholded Entropy Loss: In Figure 15 we compare (i) our thresholded entropy loss for the policy (see Section 2.2) with (ii) the usual entropy penalty. For (i) we use the same hyperparameters as in our main experiments, i.e., η = 0.01 and Γ = 0.1. For (ii) we set η = 0.001 and Γ = 1.0, which effectively disables the threshold. Without a threshold, the entropy is more likely to either collapse or diverge. When the threshold is used, the score is

Wall-Clock Times: For each run, we give the agent a total training and evaluation budget of roughly 10 hours on a single NVIDIA A100 GPU. The time can vary slightly, since the budget is based on the number of updates. An NVIDIA GeForce RTX 3090 requires 12-13 hours for the same amount of training and evaluation. When using a vanilla transformer, which does not use the memory mechanism of the Transformer-XL architecture (Dai et al., 2019), the runtime is roughly 15.5 hours on an NVIDIA A100 GPU, i.e., 1.5 times higher.

We compare the runtime of our method with previous methods in Table 3. Our method is more than 20 times faster than SimPLe, but slower than model-free methods. However, our method should be as fast as other model-free methods during inference. In Table 2 we have shown that our method
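One plausible reading of the thresholded entropy loss above, with η = 0.01 and Γ = 0.1 as in the main experiments, is that the entropy bonus is only active while the policy's normalized entropy is below the threshold Γ; above it, the term vanishes and the policy is free to become more deterministic. This is a hypothetical sketch of that reading, not the paper's exact loss:

```python
import torch

def thresholded_entropy_loss(logits, eta=0.01, gamma=0.1):
    """Penalize low entropy only below a normalized threshold `gamma`.
    Hypothetical interpretation; the exact form is defined in Section 2.2.
    """
    dist = torch.distributions.Categorical(logits=logits)
    max_entropy = torch.log(torch.tensor(float(logits.shape[-1])))
    normalized = dist.entropy() / max_entropy  # in [0, 1]
    # Zero whenever normalized entropy exceeds gamma; setting gamma = 1.0
    # recovers an always-on entropy penalty, as in ablation (ii).
    return eta * torch.clamp(gamma - normalized, min=0.0).mean()

# A near-deterministic policy is penalized; a uniform one is not.
peaked = torch.tensor([[10.0, 0.0, 0.0, 0.0]])
uniform = torch.zeros(1, 4)
```

Under this reading, the threshold prevents both failure modes mentioned above: entropy cannot collapse (the penalty kicks in below Γ), and there is no pressure pushing it toward the maximum (no penalty above Γ).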

A.3 ADDITIONAL TRAINING DETAILS

In Algorithm 1 we present pseudocode for training the world model and the actor-critic agent. We use the SiLU activation function (Elfwing et al., 2018) for all models. In Table 4 we summarize all hyperparameters that we used in our experiments. In 

