DEEP TRANSFORMER Q-NETWORKS FOR PARTIALLY OBSERVABLE REINFORCEMENT LEARNING

Abstract

Real-world reinforcement learning tasks often involve some form of partial observability where the observations only give a partial or noisy view of the true state of the world. Such tasks typically require some form of memory, where the agent has access to multiple past observations, in order to perform well. One popular way to incorporate memory is by using a recurrent neural network to access the agent's history. However, recurrent neural networks in reinforcement learning are often fragile and difficult to train and sometimes fail completely as a result. In this work, we propose Deep Transformer Q-Networks (DTQN), a novel architecture utilizing transformers and self-attention to encode an agent's history. DTQN is designed modularly, and we compare results against several modifications to our base model. Our experiments demonstrate that our approach can solve partially observable tasks faster and more stably than previous recurrent approaches.

1. INTRODUCTION

In recent years, deep neural networks have become the computational backbone of reinforcement learning, achieving strong performance across a wide array of difficult tasks including games (Mnih et al., 2015; Silver et al., 2016) and robotics (Levine et al., 2018; Gao et al., 2020). In particular, Deep Q-Networks (DQN) (Mnih et al., 2015) revolutionized the field of deep RL by achieving super-human performance on Atari 2600 games in the Arcade Learning Environment (Bellemare et al., 2013). Since then, several advancements have been proposed to improve DQN (Hessel et al., 2018), and deep RL has been shown to excel in continuous control tasks as well (Haarnoja et al., 2018; Fujimoto et al., 2018). However, most deep RL methods assume the agent is operating within a fully observable environment; that is, one in which the agent has access to the environment's full state information. This assumption does not hold for many realistic domains due to components such as noisy sensors, occluded images, or additional unknown agents. These domains are partially observable, and pose a much greater challenge for RL than the standard fully observable setting. Indeed, naïve methods often fail to learn in partially observable environments without additional architectural or training support (Pinto et al., 2017; Igl et al., 2018; Ma et al., 2020).

To solve partially observable domains, RL agents may need to remember some or possibly all previous observations (Kaelbling et al., 1998). As a result, RL methods typically add some form of memory component, allowing them to store or refer back to recent observations in order to make more informed decisions. The current state-of-the-art approaches integrate recurrent neural networks, such as LSTMs (Hochreiter & Schmidhuber, 1997) or GRUs (Cho et al., 2014), with fully observable deep RL architectures to process an agent's history (Ni et al., 2021).
But recurrent neural networks (RNNs) can be fragile and difficult to train, often requiring complicated "warm-up" strategies to initialize their hidden states at the start of each training batch (Lample & Chaplot, 2017). In contrast, the transformer has been shown to model sequences much better than RNNs and is ubiquitous in natural language processing (NLP) (Devlin et al., 2018) and increasingly common in computer vision (Dosovitskiy et al., 2020). Therefore, we propose Deep Transformer Q-Networks (DTQN), a novel architecture using self-attention to solve partially observable RL domains. DTQN leverages a transformer decoder architecture with learned positional encodings to represent an agent's history and accurately predict Q-values at each timestep. Rather than the standard approach of training on a single next step for a given history, we propose a training regime called intermediate Q-value prediction, which allows us to train DTQN on the Q-values generated for each timestep in the agent's observation history, providing more robust learning. DTQN encodes an agent's history more effectively than recurrent methods, which we show empirically across several challenging partially observable environments. We evaluate and analyze several architectural components, including gated skip connections (Parisotto et al., 2020), positional encodings, identity map reordering (Parisotto et al., 2020), and intermediate Q-value prediction (Al-Rfou et al., 2019). Our results provide strong evidence that our approach can successfully represent agents' histories in partially observable domains. We visualize attention weights showing that DTQN learns an understanding of the domains as it works to solve tasks.
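The core architectural idea can be illustrated in a few lines of PyTorch. The following is a minimal sketch, not the authors' implementation: a transformer over a history of observation embeddings with learned positional encodings, using a causal mask so each timestep attends only to its past, and emitting Q-values at every position so a loss can be applied to all of them (intermediate Q-value prediction). All hyperparameter values here are illustrative assumptions.

```python
import torch
import torch.nn as nn

class TransformerQNet(nn.Module):
    """Sketch of a DTQN-style Q-network: history in, per-timestep Q-values out."""

    def __init__(self, obs_dim, num_actions, d_model=64, n_heads=4,
                 n_layers=2, context_len=50):
        super().__init__()
        self.embed = nn.Linear(obs_dim, d_model)
        # Learned positional encodings: one trainable vector per context position.
        self.pos = nn.Parameter(torch.zeros(context_len, d_model))
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.head = nn.Linear(d_model, num_actions)

    def forward(self, obs_seq):
        # obs_seq: (batch, seq_len, obs_dim), with seq_len <= context_len.
        seq_len = obs_seq.size(1)
        x = self.embed(obs_seq) + self.pos[:seq_len]
        # Causal mask: position t may attend only to positions <= t.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
        x = self.encoder(x, mask=mask)
        # Q-values for every timestep in the history: (batch, seq_len, num_actions).
        return self.head(x)
```

During training, the TD loss can be averaged over all `seq_len` predictions rather than only the final one; at action-selection time, only the last timestep's Q-values are needed.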

2. BACKGROUND

When an environment does not emit its full state to the agent, the problem can be modeled as a Partially Observable Markov Decision Process (POMDP) (Kaelbling et al., 1998). A POMDP is formally described as the 6-tuple $(S, A, T, R, \Omega, O)$. $S$, $A$, and $\Omega$ represent the environment's sets of states, actions, and observations, respectively. $T$ is the state transition function $T(s, a, s') = P(s' \mid s, a)$, denoting the probability of transitioning from state $s$ to state $s'$ given action $a$. $R$ describes the reward function $R : S \times A \to \mathbb{R}$; that is, the resultant scalar reward emitted by the environment for an agent that was in some state $s \in S$ and took some action $a \in A$. And $O$ is the observation function $O(s', a, o) = P(o \mid s', a)$, the probability of observing $o$ when action $a$ is taken resulting in state $s'$. At each time step $t$, the agent is in the environment's state $s_t \in S$, takes action $a_t \in A$, transitions the environment to some state $s_{t+1} \in S$ based on the transition probability $T(s_t, a_t, s_{t+1})$, and receives a reward $r_t = R(s_t, a_t)$. The goal of the agent is to maximize $\mathbb{E}\left[\sum_t \gamma^t r_t\right]$, its expected discounted return for some discount factor $\gamma \in [0, 1)$ (Sutton & Barto, 2018). Because agents in POMDPs do not have access to the environment's full state information, they must rely on the observations $o_t \in \Omega$, which relate to the state via the observation function, $O(s_{t+1}, a_t, o_t) = P(o_t \mid s_{t+1}, a_t)$. In general, agents acting in a partially observable space cannot simply use observations as a proxy for state, since several states may be aliased into the same observation. Instead, they often consider some form of their full history of information, $h_t = \{(o_0, a_0), (o_1, a_1), \ldots, (o_{t-1}, a_{t-1})\}$. Because the history grows indefinitely as the agent proceeds in a trajectory, various ways of encoding the history exist.
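The definitions above can be made concrete with a toy example. The following is a hypothetical two-state POMDP (not from the paper): both states emit the same observation, so the state is aliased and the agent cannot act optimally from the current observation alone; the discounted-return computation mirrors the objective $\mathbb{E}\left[\sum_t \gamma^t r_t\right]$.

```python
# Hypothetical two-state POMDP: (S, A, T, R, Omega, O) with aliased observations.
S = ["left", "right"]     # states
A = ["stay", "switch"]    # actions
OBS = ["wall"]            # both states emit the same observation

def T(s, a):
    # Deterministic transition: "switch" toggles the state, "stay" keeps it.
    if a == "switch":
        return "right" if s == "left" else "left"
    return s

def R(s, a):
    # Reward only for switching out of "left" -- knowing the state matters.
    return 1.0 if (s == "left" and a == "switch") else 0.0

def O(s_next, a):
    # Observation function P(o | s', a): always "wall", so states are aliased.
    return "wall"

def discounted_return(s0, actions, gamma=0.9):
    # sum_t gamma^t * r_t for a fixed open-loop action sequence.
    s, g, total = s0, 1.0, 0.0
    for a in actions:
        total += g * R(s, a)
        s = T(s, a)
        g *= gamma
    return total
```

Since `O` returns `"wall"` regardless of the state, an agent seeing only the current observation cannot tell whether `"switch"` will be rewarded; it must remember its history (e.g., how many times it has switched) to infer the state.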
Previous work has truncated the history to make it a fixed length (Zhu et al., 2017) or used an agent's belief, which represents the estimate of the current state (Kaelbling et al., 1998). Since the deep learning revolution, others have used forms of recurrency, such as LSTMs and GRUs, to encode the history (Hausknecht & Stone, 2015; Yang & Nguyen, 2021).

2.1. DEEP RECURRENT Q-NETWORKS

Q-learning (Watkins & Dayan, 1992) aims to learn a function $Q : S \times A \to \mathbb{R}$ which represents the value of each state-action pair in an MDP. Given a state $s$, action $a$, reward $r$, next state $s'$, and learning rate $\alpha$, the Q-function is updated with the equation

$$Q(s, a) := Q(s, a) + \alpha \left( r + \gamma \max_{a' \in A} Q(s', a') - Q(s, a) \right) \qquad (1)$$

In more challenging domains, however, the state-action space of the environment is often too large to be able to learn an exact Q-value for each state-action pair. Instead of learning a tabular Q-function, DQN (Mnih et al., 2015) learns an approximate Q-function featuring strong generalization capabilities over similar states and actions. DQN is trained to minimize the mean squared Bellman error

$$L(\theta) = \mathbb{E}_{(s, a, r, s') \sim D}\left[ \left( r + \gamma \max_{a' \in A} Q(s', a'; \theta') - Q(s, a; \theta) \right)^2 \right] \qquad (2)$$

where transition tuples of states, actions, rewards, and next states $(s, a, r, s')$ are sampled uniformly from a replay buffer, $D$, of past experiences while training. The target $r + \gamma \max_{a' \in A} Q(s', a'; \theta')$ invokes DQN's target network (parameterized by $\theta'$), which lags behind the main network (parameterized by $\theta$) to produce more stable updates. However, in partially observable domains, DQN may not learn a good policy by simply replacing the network's input from states to observations (i.e., an agent can often perform better by remembering

