DEEP TRANSFORMER Q-NETWORKS FOR PARTIALLY OBSERVABLE REINFORCEMENT LEARNING

Abstract

Real-world reinforcement learning tasks often involve some form of partial observability, where observations give only a partial or noisy view of the true state of the world. Such tasks typically require some form of memory, in which the agent has access to multiple past observations, in order to perform well. One popular way to incorporate memory is by using a recurrent neural network to access the agent's history. However, recurrent neural networks in reinforcement learning are often fragile and difficult to train, and sometimes fail to learn at all. In this work, we propose Deep Transformer Q-Networks (DTQN), a novel architecture utilizing transformers and self-attention to encode an agent's history. DTQN is designed modularly, and we compare results against several modifications to our base model. Our experiments demonstrate that our approach can solve partially observable tasks faster and more stably than previous recurrent approaches.

1. INTRODUCTION

In recent years, deep neural networks have become the computational backbone of reinforcement learning, achieving strong performance across a wide array of difficult tasks, including games (Mnih et al., 2015; Silver et al., 2016) and robotics (Levine et al., 2018; Gao et al., 2020). In particular, Deep Q-Networks (DQN) (Mnih et al., 2015) revolutionized the field of deep RL by achieving super-human performance on Atari 2600 games in the Arcade Learning Environment (Bellemare et al., 2013). Since then, several advancements have been proposed to improve DQN (Hessel et al., 2018), and deep RL has been shown to excel in continuous control tasks as well (Haarnoja et al., 2018; Fujimoto et al., 2018). However, most deep RL methods assume the agent is operating within a fully observable environment; that is, one in which the agent has access to the environment's full state information. This assumption does not hold in many realistic domains due to factors such as noisy sensors, occluded images, or additional unknown agents. These domains are partially observable, and pose a much greater challenge for RL than the standard fully observable setting. Indeed, naïve methods often fail to learn in partially observable environments without additional architectural or training support (Pinto et al., 2017; Igl et al., 2018; Ma et al., 2020).

To solve partially observable domains, RL agents may need to remember some (or possibly all) previous observations (Kaelbling et al., 1998). As a result, RL methods typically add a memory component, allowing them to store or refer back to recent observations in order to make more informed decisions. The current state-of-the-art approaches integrate recurrent neural networks, such as LSTMs (Hochreiter & Schmidhuber, 1997) or GRUs (Cho et al., 2014), with fully observable deep RL architectures to process an agent's history (Ni et al., 2021).
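To make the recurrent-memory approach concrete, the sketch below shows a minimal DRQN-style Q-network in PyTorch: an observation encoder feeds an LSTM whose hidden state carries the agent's memory across timesteps, and a linear head emits Q-values. This is an illustrative sketch, not the paper's implementation; the layer sizes and class name are assumptions chosen for clarity.

```python
import torch
import torch.nn as nn

class RecurrentQNetwork(nn.Module):
    """Minimal recurrent Q-network (DRQN-style) sketch.
    NOTE: hyperparameters and structure are illustrative, not from the paper."""

    def __init__(self, obs_dim: int, num_actions: int, hidden: int = 64):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, hidden)          # per-step observation encoder
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        self.q_head = nn.Linear(hidden, num_actions)       # maps features to Q-values

    def forward(self, obs_seq: torch.Tensor, hidden_state=None):
        # obs_seq: (batch, seq_len, obs_dim) -- a history of observations
        x = torch.relu(self.encoder(obs_seq))
        x, hidden_state = self.lstm(x, hidden_state)       # hidden state carries memory
        q_values = self.q_head(x[:, -1])                   # Q-values at the final timestep
        return q_values, hidden_state
```

The returned `hidden_state` is what makes training delicate in practice: it must be initialized (or "warmed up") at the start of each sampled training sequence, which is the fragility the next paragraph discusses.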
But recurrent neural networks (RNNs) can be fragile and difficult to train, often requiring complicated "warm-up" strategies to initialize their hidden states at the start of each training batch (Lample & Chaplot, 2017). Conversely, the Transformer has been shown to model sequences far better than RNNs; it is ubiquitous in natural language processing (NLP) (Devlin et al., 2018) and increasingly common in computer vision (Dosovitskiy et al., 2020). Therefore, we propose Deep Transformer Q-Networks (DTQN), a novel architecture using self-attention to solve partially observable RL domains. DTQN leverages a transformer decoder architecture with learned positional encodings to represent an agent's history and accurately predict Q-values at each timestep. Rather than the standard approach of training on only the most recent timestep of a given history, we train DTQN to predict Q-values for every timestep in its input history.
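The architecture just described can be sketched as follows: observation embeddings are summed with learned positional encodings, passed through causally masked self-attention blocks, and a linear head produces Q-values at every timestep of the history. This is a hedged sketch of the idea, not the authors' code; it approximates the decoder-style blocks with PyTorch's `TransformerEncoderLayer` plus a causal mask, and the context length and layer sizes are assumptions.

```python
import torch
import torch.nn as nn

class TransformerQNetwork(nn.Module):
    """Sketch of a DTQN-style model: learned positional encodings plus
    causally masked self-attention, emitting Q-values at every timestep.
    NOTE: sizes and the use of TransformerEncoderLayer are illustrative."""

    def __init__(self, obs_dim: int, num_actions: int,
                 context_len: int = 50, d_model: int = 64, n_layers: int = 2):
        super().__init__()
        self.embed = nn.Linear(obs_dim, d_model)
        self.pos = nn.Embedding(context_len, d_model)      # learned positional encodings
        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.q_head = nn.Linear(d_model, num_actions)

    def forward(self, obs_seq: torch.Tensor) -> torch.Tensor:
        # obs_seq: (batch, seq_len, obs_dim) -- a history of observations
        b, t, _ = obs_seq.shape
        positions = torch.arange(t, device=obs_seq.device)
        x = self.embed(obs_seq) + self.pos(positions)
        # Causal mask so each timestep attends only to itself and the past
        mask = nn.Transformer.generate_square_subsequent_mask(t)
        x = self.blocks(x, mask=mask)
        return self.q_head(x)  # Q-values for every timestep: (batch, seq_len, actions)
```

Because the head is applied at every position, a single forward pass yields a training target for each timestep of the history rather than only the last one, which is the per-timestep training scheme described above.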

