TRANSIENT NON-STATIONARITY AND GENERALISATION IN DEEP REINFORCEMENT LEARNING

Abstract

Non-stationarity can arise in Reinforcement Learning (RL) even in stationary environments. For example, most RL algorithms collect new data throughout training, using a non-stationary behaviour policy. Due to the transience of this non-stationarity, it is often not explicitly addressed in deep RL and a single neural network is continually updated. However, we find evidence that neural networks exhibit a memory effect where these transient non-stationarities can permanently impact the latent representation and adversely affect generalisation performance. Consequently, to improve generalisation of deep RL agents, we propose Iterated Relearning (ITER). ITER augments standard RL training by repeated knowledge transfer of the current policy into a freshly initialised network, which thereby experiences less non-stationarity during training. Experimentally, we show that ITER improves performance on the challenging generalisation benchmarks ProcGen and Multiroom.

1. INTRODUCTION

In RL, as an agent explores more of its environment and updates its policy and value function, the data distribution it uses for training changes. In deep RL, this non-stationarity is often not addressed explicitly. Typically, a single neural network model is initialised and continually updated during training. Conventional wisdom about catastrophic forgetting (Kemker et al., 2018) implies that old updates from a different data distribution will simply be forgotten. However, we provide evidence for an alternative hypothesis: networks exhibit a memory effect in their learned representations which can permanently harm generalisation when the data distribution changes over the course of training. To build intuition, we first study this phenomenon in a supervised setting on the CIFAR-10 dataset. We artificially introduce transient non-stationarity into the training data and investigate how this affects the asymptotic performance under the final, stationary data in the later epochs of training. Interestingly, we find that while asymptotic training performance is nearly unaffected, test performance degrades considerably, even after the data distribution has converged. In other words, we find that latent representations in deep networks learned under certain types of non-stationary data can be inadequate for good generalisation and might not be improved by later training on stationary data. Such transient non-stationarity is typical in RL. Consequently, we argue that this observed degradation of generalisation might contribute to the inferior generalisation properties recently attributed to many RL agents evaluated on held-out test environments (Zhang et al., 2018a;b; Zhao et al., 2019).
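To make the supervised setup concrete, the following is a minimal sketch of how transient non-stationarity could be injected into a classification dataset. The specific scheme here (a fixed random label permutation whose influence decays linearly until a transition epoch) is an illustrative assumption, not the exact modification used in our experiments; the function and parameter names are hypothetical.

```python
import numpy as np

def corrupt_labels(labels, epoch, transition_epoch=1000, num_classes=10, seed=0):
    """Illustrative transient non-stationarity (assumed scheme, not the
    paper's exact one): during the first `transition_epoch` epochs, a
    fraction of labels is remapped by a fixed random permutation; the
    fraction decays linearly to zero, after which training is stationary."""
    if epoch >= transition_epoch:
        return labels  # stationary phase: unaltered data
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_classes)  # fixed class permutation
    # Fraction of examples still drawn from the shifted distribution.
    frac = 1.0 - epoch / transition_epoch
    mask = rng.random(len(labels)) < frac
    out = labels.copy()
    out[mask] = perm[labels[mask]]
    return out
```

Training proceeds as usual on `corrupt_labels(y, epoch)`; only the label stream is non-stationary, so any asymptotic gap versus a stationary baseline can be attributed to the transient phase.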
Furthermore, in contrast to Supervised Learning (SL), simply re-training the agent from scratch once the data distribution has changed is infeasible in RL, as current state-of-the-art algorithms require data close to the on-policy distribution, even for off-policy algorithms like Q-learning (Fedus et al., 2020). To improve generalisation of RL agents despite this restriction, we propose Iterated Relearning (ITER). In this paradigm for deep RL training, the agent's policy and value function are periodically distilled into a freshly initialised student, which subsequently replaces the teacher for further optimisation. While this occasional distillation step simply aims to re-learn and replace the current policy and value outputs for the training data, it allows the student to learn a better latent representation with improved performance on unseen inputs, because it eliminates non-stationarity during distillation. We propose a practical implementation of ITER which performs the distillation in parallel to the training process without requiring additional training data. While this introduces a small amount of non-stationarity into the distillation step, it greatly improves sample efficiency without noticeably impacting performance. Experimentally, we evaluate ITER on the Multiroom environment, as well as several environments from the recently proposed ProcGen benchmark, and find that it improves generalisation. This provides further support for our hypothesis and indicates that the non-stationarity inherent to many RL algorithms, even when training on stationary environments, should not be ignored when aiming to learn robust agents. Lastly, to further support this claim and provide more insight into possible causes of the discovered effect, we perform additional ablation studies on the CIFAR-10 dataset.
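The distillation step at the core of ITER can be sketched as a simple objective: match the teacher's action distribution (e.g. via a KL term) and its value estimates on the training data. The sketch below assumes discrete actions and NumPy arrays of logits and values; the function names and the exact combination of terms are illustrative, not our precise implementation.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax over the action dimension."""
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def distillation_loss(student_logits, student_values,
                      teacher_probs, teacher_values):
    """Illustrative distillation objective: KL(teacher || student) on the
    policy plus a squared error on the value estimates. The teacher's
    outputs are treated as fixed targets."""
    student_probs = softmax(student_logits)
    kl = np.sum(teacher_probs * (np.log(teacher_probs + 1e-8)
                                 - np.log(student_probs + 1e-8)), axis=-1)
    value_err = (student_values - teacher_values) ** 2
    return kl.mean() + value_err.mean()
```

The loss is zero when the freshly initialised student exactly reproduces the teacher's outputs; minimising it on the current training data re-learns the policy and value while the student's representation is formed under (nearly) stationary targets.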

2. BACKGROUND

We describe an RL problem as a Markov decision process (MDP) $(\mathcal{S}, \mathcal{A}, T, r, p_0, \gamma)$ (Puterman, 2014) with actions $a \in \mathcal{A}$, states $s \in \mathcal{S}$, initial state $s_0 \sim p_0$, transition dynamics $s' \sim T(s, a)$, reward function $r(s, a) \in \mathbb{R}$ and discount factor $\gamma$. The unnormalised discounted state distribution induced by a policy $\pi$ is defined as $d^\pi(s) = \sum_{t=0}^{\infty} \gamma^t \Pr\big(S_t = s \,\big|\, S_0 \sim p_0, A_t \sim \pi(\cdot|S_t), S_{t+1} \sim T(S_t, A_t)\big)$. In ITER, we learn a sequence of policies and value functions, which we denote by $\pi^{(k)}(a|s)$ and $V^{(k)}(s)$ at the $k$th iteration ($k \in \{0, 1, 2, \ldots\}$), parameterised by $\theta_k$.

We briefly discuss some forms of non-stationarity which can arise in RL, even when the environment is stationary. For simplicity, we focus the exposition on actor-critic methods, which use samples from interaction with the environment to estimate the policy gradient $g = \mathbb{E}\big[\nabla_\theta \log \pi_\theta(a|s)\, A^\pi(s, a, s') \,\big|\, s, a, s' \sim d^\pi(s)\pi(a|s)T(s'|s, a)\big]$. The advantage is often estimated as $A^\pi(s, a, s') = r(s, a) + \gamma V^\pi(s') - V^\pi(s)$. Typically, we also use neural networks to approximate the baseline $V^\pi_\phi(s)$ and for bootstrapping from the future value $V^\pi_\phi(s')$. The parameters $\phi$ can be learned by minimising $\mathbb{E}[A^\pi(s, a, s')^2]$ by stochastic semi-gradient descent, treating $V^\pi_\phi(s')$ as a constant.

There are at least three main types of non-stationarity in deep RL. First, we update the policy $\pi_\theta$, which leads to changes in the state distribution $d^{\pi_\theta}(s)$: early in training, a random policy $\pi_\theta$ only explores states close to the initial states $s_0$; as $\pi_\theta$ improves, new states further from $s_0$ are encountered. Second, changes to the policy also change the true value function $V^\pi(s)$ which $V^\pi_\phi(s)$ is approximating. Lastly, due to the use of bootstrap targets in temporal difference learning, the learned value $V^\pi_\phi(s)$ is not regressed directly towards $V^\pi(s)$. Instead, $V^\pi_\phi$ fits a gradually evolving target sequence even under a fixed policy $\pi$, thereby also changing the policy gradient estimator $g$.
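The advantage estimate and the critic's semi-gradient objective above can be sketched in a few lines. This is a minimal illustration of the formulas, not an implementation of any particular algorithm; the function names are our own.

```python
import numpy as np

def one_step_advantage(rewards, values, next_values, gamma=0.99):
    """One-step advantage A(s, a, s') = r(s, a) + gamma * V(s') - V(s).
    In semi-gradient TD learning the bootstrap term V(s') is treated as a
    constant target, which is why we work with plain arrays here."""
    return rewards + gamma * next_values - values

def value_loss(rewards, values, next_values, gamma=0.99):
    """Mean squared one-step advantage E[A(s, a, s')^2], the quantity
    minimised over the critic parameters phi in the text."""
    adv = one_step_advantage(rewards, values, next_values, gamma)
    return np.mean(adv ** 2)
```

Note the third source of non-stationarity in this objective: as the critic improves, `next_values` changes, so the regression target moves even when the policy (and hence the sampled transitions) is held fixed.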

3. THE IMPACT OF NON-STATIONARITY ON GENERALISATION

In this section we investigate how asymptotic performance is affected by changes to the data distribution during training. In particular, we assume an initial, transient phase of non-stationarity, followed by an extended phase of training on a stationary data distribution. This is similar to the situation in RL where the data distribution is affected by a policy which converges over time. We show that this transient non-stationarity has a permanent effect on the learned representation and negatively impacts generalisation. As interventions in RL training can lead to confounding factors due to off-policy data or changed exploration behaviour, we utilise Supervised Learning (SL) here to provide initial evidence in a more controlled setup. We use the CIFAR-10 dataset for image classification (Krizhevsky et al., 2009) and artificially inject non-stationarity. Our goal is to provide qualitative results on the impact of non-stationarity, not to obtain optimal performance.

We use a ResNet18 (He et al., 2016) architecture, similar to those used by Espeholt et al. (2018) and Cobbe et al. (2019a). Parameters are updated using Stochastic Gradient Descent (SGD) with momentum and, following standard practice in RL, we use a constant learning rate and do not use batch normalisation. Weight decay is used for regularisation. Hyper-parameters and more details can be found in appendix B. We train for a total of 2500 epochs. While the last 1500 epochs are trained on the full, unaltered dataset, we modify the training data in three different ways during the first 1000 epochs. Test data is

