CORRECTING MOMENTUM IN TEMPORAL DIFFERENCE LEARNING

Anonymous authors
Paper under double-blind review

Abstract

A common optimization tool used in deep reinforcement learning is momentum, which consists of accumulating and discounting past gradients and reapplying them at each iteration. We argue that, unlike in supervised learning, momentum in Temporal Difference (TD) learning accumulates gradients that become doubly stale: not only does the gradient of the loss change due to parameter updates, but the loss itself also changes due to bootstrapping. We first show that this phenomenon exists, and then propose a first-order correction term to momentum. We show that this correction term improves sample efficiency in policy evaluation by correcting target value drift. An important insight of this work is that deep RL methods are not always best served by directly importing techniques from the supervised setting.

1. INTRODUCTION

Temporal Difference (TD) learning (Sutton, 1988) is a fundamental method of Reinforcement Learning (RL), which works by using estimates of future predictions to learn to make predictions about the present. To scale to large problems, TD has been used in conjunction with deep neural networks (DNNs) to achieve impressive performance (e.g. Mnih et al., 2013; Schulman et al., 2017; Hessel et al., 2018). Unfortunately, the naive application of TD to DNNs has been shown to be brittle (Machado et al., 2018; Farebrother et al., 2018; Packer et al., 2018; Witty et al., 2018), with extensions such as the n-step TD update (Fedus et al., 2020) or the TD(λ) update (Molchanov et al., 2016) only marginally improving performance and generalization capabilities when coupled with DNNs.

Part of the success of DNNs, including when applied to TD learning, is the use of adaptive or accelerated optimization methods (Hinton et al., 2012; Sutskever et al., 2013; Kingma & Ba, 2015) to find good parameters. In this work we investigate and extend the momentum algorithm (Polyak, 1964) as applied to TD learning in DNNs. While accelerated TD methods have received some attention in the literature, this is typically done in the context of linear function approximators (Baxter & Bartlett, 2001; Meyer et al., 2014; Pan et al., 2016; Gupta et al., 2019; Gupta, 2020; Sun et al., 2020). While some studies have considered the mix of DNNs and TD (Zhang et al., 2019; Romoff et al., 2020), many are limited to a high-level analysis of hyperparameter choices for existing optimization methods (Sarigül & Avci, 2018; Andrychowicz et al., 2020), or indeed the latter are simply applied as-is to train RL agents (Mnih et al., 2013; Hessel et al., 2018). For an extended discussion of related work, we refer the reader to appendix A. As a first step in going beyond the naive use of supervised learning tools in RL, we examine momentum.
We argue that momentum, especially as it is used in conjunction with TD and DNNs, adds an additional form of bias which can be understood as the staleness of accumulated information. We quantify this bias, and propose a corrected momentum algorithm that reduces this staleness and is capable of improving performance.

1.1. REINFORCEMENT LEARNING AND TEMPORAL DIFFERENCE LEARNING

A Markov Decision Process (MDP) (Bellman, 1957; Sutton & Barto, 2018) $M = \langle S, A, R, P, \gamma \rangle$ consists of a state space $S$, an action space $A$, a reward function $R : S \times A \to \mathbb{R}$, and a transition function $P(s'|s,a)$. RL agents usually aim to optimize the expectation of the long-term return, $G(S_t) = \sum_{k=t}^{\infty} \gamma^{k-t} R(S_k, A_k)$, where $\gamma \in [0, 1)$ is called the discount factor. Policies $\pi(a|s)$ map states to action distributions. Value functions $V^\pi$ and $Q^\pi$ map states and state-action pairs to expected returns, and can be written recursively:

$$V^\pi(S_t) = \mathbb{E}_\pi[G(S_t)] = \mathbb{E}_\pi[R(S_t, A_t) + \gamma V^\pi(S_{t+1}) \mid A_t \sim \pi(\cdot|S_t)]$$
$$Q^\pi(S_t, A_t) = \mathbb{E}_\pi[R(S_t, A_t) + \gamma \textstyle\sum_a \pi(a|S_{t+1}) Q^\pi(S_{t+1}, a)]$$

We approximate $V^\pi$ with $V_\theta$. We can train $V_\theta$ via regression to observed values of $G$, but these recursive equations also give rise to the Temporal Difference (TD) update rules for policy evaluation, which rely on current estimates of $V$ to bootstrap. In the tabular case, for example, this is written as:

$$V(S_t) \leftarrow V(S_t) - \alpha\big(V(S_t) - (R(S_t, A_t) + \gamma V(S_{t+1}))\big), \quad (1)$$

where $\alpha \in [0, 1)$ is the step size. Alternatively, estimates of this update can be performed by a so-called semi-gradient (Sutton & Barto, 2018) algorithm where the "TD(0) loss" is minimized:

$$\theta_{t+1} = \theta_t - \alpha \nabla_{\theta_t}\big(V_{\theta_t}(S_t) - (R(S_t, A_t) + \gamma \bar{V}_{\theta_t}(S_{t+1}))\big)^2, \quad (2)$$

with $\bar{V}$ meaning we consider $V$ constant for the purpose of gradient computation.
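The tabular TD(0) update above can be sketched on a toy problem. The 5-state deterministic chain below (the environment, its reward placement, and all hyperparameters are illustrative assumptions, not from the paper): the agent always moves right and receives reward 1 on the transition into the terminal state, whose value is defined as 0.

```python
import numpy as np

# Minimal sketch of tabular TD(0) policy evaluation on an assumed
# 5-state deterministic chain: states 0..4, always move right,
# reward 1 on entering terminal state 4 (terminal value is 0).
n_states, gamma, alpha = 5, 0.9, 0.1
V = np.zeros(n_states)

for episode in range(500):
    s = 0
    while s < n_states - 1:
        s_next = s + 1
        r = 1.0 if s_next == n_states - 1 else 0.0
        v_next = 0.0 if s_next == n_states - 1 else V[s_next]
        # TD(0): V(S_t) <- V(S_t) - alpha * (V(S_t) - (R + gamma * V(S_{t+1})))
        V[s] -= alpha * (V[s] - (r + gamma * v_next))
        s = s_next
```

Since bootstrapping propagates the reward backwards one state per episode, $V$ converges toward $[\gamma^3, \gamma^2, \gamma, 1, 0]$ here.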

1.2. BIAS AND STALENESS IN MOMENTUM

The usual form of momentum (Polyak, 1964; Sutskever et al., 2013) in stochastic gradient descent (SGD) maintains an exponential moving average with factor $\beta$ of gradients w.r.t. some objective $J$, changing parameters $\theta_t \in \mathbb{R}^n$ ($t$ is here SGD time rather than MDP time) with learning rate $\alpha$:

$$\mu_t = \beta \mu_{t-1} + (1 - \beta) \nabla_{\theta_{t-1}} J_t(\theta_{t-1}) \quad (3)$$
$$\theta_t = \theta_{t-1} - \alpha \mu_t \quad (4)$$

We assume here that the objective $J_t$ is time-dependent, as is the case for example in minibatch training or online learning. Note that other similar forms of this update exist, notably Nesterov's accelerated gradient method (Nesterov, 1983), as well as undampened variants that omit the $(1 - \beta)$ factor in (3) or replace it with $\alpha$, as found in popular deep learning packages (Paszke et al., 2019).

We make the observation that, at time $t$, the gradients accumulated in $\mu$ are stale: they were computed using past parameters rather than $\theta_t$, and in general we would expect $\nabla_{\theta_t} J_t(\theta_t) \neq \nabla_{\theta_k} J_t(\theta_k)$ for $k < t$. As such, the update in (4) is a biased update. In supervised learning, where one learns a mapping from $x$ to $y$, this staleness has only one source: $\theta$ changes but the target $y$ stays constant. We argue that in TD learning, momentum becomes doubly stale: not only does the value network change, but the target (the equivalent of $y$) itself changes[1] with every parameter update. Consider the TD objective in (2): when $\theta$ changes, not only does $V(s)$ change, but $V(s')$ as well. The objective itself changes, making past gradients stale and less aligned with recent gradients (even more so when there is gradient interference (Liu et al., 2019; Achiam et al., 2019; Bengio et al., 2020), whether constructive or destructive). Note that several sources of bias already exist in TD learning, notably the traditional parametric bias (of the bias-variance tradeoff when selecting capacity), as well as the bootstrapping bias (of the error in $V(s')$ when using it as a target; using a frozen target prevents this bias from compounding).
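The dampened recursion in (3)-(4) can be sketched in a few lines; the stationary quadratic objective below is an illustrative assumption chosen so the gradient is analytic. Even in this simple stationary case, $\mu_t$ is an average of gradients evaluated at past parameters, which is exactly the staleness discussed above.

```python
import numpy as np

# Sketch of the dampened momentum update, Eqs. (3)-(4), on an assumed
# stationary quadratic J(theta) = 0.5 * ||theta||^2, whose gradient
# at theta is simply theta.
alpha, beta = 0.1, 0.9
theta = np.array([5.0, -3.0])
mu = np.zeros(2)

for t in range(300):
    grad = theta                          # gradient of J at theta_{t-1}
    mu = beta * mu + (1 - beta) * grad    # Eq. (3): exponential average
    theta = theta - alpha * mu            # Eq. (4): apply the average
```

On this convex objective the stale gradients merely slow convergence rather than prevent it; the argument above is that TD's moving targets make the staleness more harmful.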
We argue that the staleness in momentum we describe is an additional form of bias, slowing down or preventing convergence. This has been hinted at before; e.g., Gupta (2020) suggests that momentum hinders learning in linear TD(0). We wish to understand and possibly correct this staleness in momentum. In this work we propose answers to the following questions:

• Is this bias significant in supervised learning? No, the bias exists but has at most a minimal effect when compared to an unbiased oracle.
• Is this bias significant in TD learning? Yes, we can quantify the bias, and comparisons to an unbiased oracle reveal significant differences.
• Can we correct $\mu_t$ to remove this bias? Yes, we derive an online update that approximately corrects $\mu_t$ using only first-order derivatives.
• Does the correction help in TD learning? Yes, using a staleness-corrected momentum improves sample complexity in policy evaluation, especially in an online setting.
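The "unbiased oracle" comparison mentioned above can be made concrete: the oracle re-evaluates every past minibatch gradient at the *current* parameters before averaging them with the same exponential weights, while standard momentum reuses the stale gradients as computed. The minibatch linear-regression losses $J_t$ below are an assumption used only to keep the gradients analytic; the paper's actual oracle construction may differ.

```python
import numpy as np

# Stale momentum vs. an assumed "oracle" momentum that recomputes all
# past gradients at the current parameters with the same weights
# (1 - beta) * beta^(t - k), matching the recursion in Eq. (3).
rng = np.random.default_rng(0)
alpha, beta, T = 0.05, 0.9, 50
theta = np.array([2.0, -1.0])
batches = [(rng.normal(size=(8, 2)), rng.normal(size=8)) for _ in range(T)]

def grad(theta, X, y):
    # Gradient of the minibatch least-squares loss at theta.
    return X.T @ (X @ theta - y) / len(y)

mu = np.zeros(2)
history = []  # the (X, y) minibatches whose gradients built mu
for t, (X, y) in enumerate(batches):
    history.append((X, y))
    mu = beta * mu + (1 - beta) * grad(theta, X, y)  # stale accumulation
    # Oracle: identical weights, but every gradient taken at current theta.
    weights = [(1 - beta) * beta ** (t - k) for k in range(t + 1)]
    mu_oracle = sum(w * grad(theta, Xk, yk)
                    for w, (Xk, yk) in zip(weights, history))
    theta = theta - alpha * mu

staleness = np.linalg.norm(mu - mu_oracle)
```

Because the parameters move between gradient evaluations, `staleness` is nonzero: the accumulated $\mu_t$ no longer matches what the same weighting of gradients would give at $\theta_t$.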



[1] Interestingly, even in most recent value-based control works (Hessel et al., 2018) a (usually frozen) copy of the value network is used to compute targets for stability, meaning that the target only changes when the copy is updated. This is considered a "trick" which it would be compelling to get rid of, since it slows down learning, and since most recent policy-gradient methods (which still use a value function) do not make use of such copies (Schulman et al., 2017).
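The frozen-copy mechanism in the footnote can be sketched minimally; the 3-state chain, tabular parameterization, and update period below are illustrative assumptions. The point is that the bootstrap target is read from a copy that only changes every `freeze_period` steps.

```python
import numpy as np

# Sketch of a frozen target copy on an assumed 3-state chain:
# transitions 0 -> 1 -> 2, reward 1 on entering terminal state 2.
gamma, alpha, freeze_period = 0.9, 0.1, 50
theta = np.zeros(3)        # online value estimates, one per state
theta_bar = theta.copy()   # frozen copy used to compute targets

for step in range(1000):
    for s in range(2):
        r = 1.0 if s + 1 == 2 else 0.0
        v_next = 0.0 if s + 1 == 2 else theta_bar[s + 1]  # bootstrap from copy
        theta[s] -= alpha * (theta[s] - (r + gamma * v_next))
    if (step + 1) % freeze_period == 0:
        theta_bar = theta.copy()  # the target only changes here
```

Between copy updates the regression target is genuinely constant, which prevents the bootstrapping bias from compounding but, as the footnote notes, slows down the propagation of value information.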

