REMEMBERING FOR THE RIGHT REASONS: EXPLANATIONS REDUCE CATASTROPHIC FORGETTING

Abstract

The goal of continual learning (CL) is to learn a sequence of tasks without suffering from the phenomenon of catastrophic forgetting. Previous work has shown that leveraging memory in the form of a replay buffer can reduce performance degradation on prior tasks. We hypothesize that forgetting can be further reduced when the model is encouraged to remember the evidence for previously made decisions. As a first step towards exploring this hypothesis, we propose a simple novel training paradigm, called Remembering for the Right Reasons (RRR), that additionally stores visual model explanations for each example in the buffer and ensures the model has "the right reasons" for its predictions by encouraging its explanations to remain consistent with those used to make decisions at training time. Without this constraint, explanations drift and forgetting increases as conventional continual learning algorithms learn new tasks. We demonstrate how RRR can be easily added to any memory- or regularization-based approach and results in reduced forgetting and, more importantly, improved model explanations. We evaluate our approach in the standard and few-shot settings and observe a consistent improvement across various CL approaches using different architectures and explanation-generation techniques, demonstrating a promising connection between explainability and continual learning. Our code is available at https://github.com/SaynaEbrahimi/Remembering-for-the-Right-Reasons.

1. INTRODUCTION

Humans are capable of continuously learning novel tasks by leveraging their lifetime knowledge and expanding it when they encounter a new experience. They can remember the majority of their prior knowledge despite the never-ending nature of their learning process by simply keeping a running tally of the observations distributed over time or presented in summary form. The field of continual learning or lifelong learning (Thrun & Mitchell, 1995; Silver et al., 2013) aims at maintaining previous performance and avoiding so-called catastrophic forgetting of previous experience (McCloskey & Cohen, 1989; McClelland et al., 1995) when learning new skills. The goal is to develop algorithms that continually update or add parameters to accommodate an online stream of data over time. An active line of research in continual learning explores the effectiveness of using small memory budgets to store data points from the training set (Castro et al., 2018; Rajasegaran et al., 2020; Rebuffi et al., 2017; Wu et al., 2019), gradients (Lopez-Paz et al., 2017), or an online generative model that can synthesize them later (Shin et al., 2017). Memory has also been exploited in the form of accommodating space for architecture growth and storage to fully recover the old performance when needed (Ebrahimi et al., 2020b; Rusu et al., 2016). Some methods store an old snapshot of the model to distill the features (Li & Hoiem, 2016) or attention maps (Dhar et al., 2019) between the teacher and student models. The internal reasoning process of deep models is often treated as a black box and remains hidden from the user.
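The episodic memory described above can be sketched as a small bounded buffer of past examples; in the RRR variant, each entry additionally carries the saliency map the model produced when the example was first learned. The class name, the reservoir-sampling eviction policy, and the field layout below are illustrative choices, not prescribed by the paper:

```python
import random


class ReplayBuffer:
    """Bounded episodic memory of (input, label, saliency) triples.

    Storing the saliency map alongside each replayed example is the
    RRR addition; plain experience replay would keep only (x, y).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.entries = []   # list of (x, y, saliency) triples
        self.seen = 0       # total examples offered so far

    def add(self, x, y, saliency=None):
        """Offer one example; reservoir sampling keeps each seen
        example in memory with equal probability."""
        self.seen += 1
        if len(self.entries) < self.capacity:
            self.entries.append((x, y, saliency))
        else:
            j = random.randrange(self.seen)
            if j < self.capacity:
                self.entries[j] = (x, y, saliency)

    def sample(self, k):
        """Draw a replay minibatch (without replacement)."""
        return random.sample(self.entries, min(k, len(self.entries)))
```

During training on a new task, minibatches from `sample` would be interleaved with the current task's data, which is the usual experience-replay recipe.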
However, recent work in explainable artificial intelligence (XAI) has developed methods to create human-interpretable explanations for model decisions (Simonyan et al., 2013; Zhang et al., 2018; Petsiuk et al., 2018; Zhou et al., 2016; Selvaraju et al., 2017). We posit that the catastrophic forgetting phenomenon is due in part to the model no longer being able to rely on the same reasoning it used for a previously seen observation. Therefore, we hypothesize that forgetting can be mitigated when the model is encouraged to remember the evidence for previously made decisions; in other words, a model should not only remember its final decision but also be able to reconstruct the reasoning that originally supported it. Based on this idea, we develop a novel strategy that exploits explainable models to improve performance. Among the various explainability techniques proposed in XAI, saliency methods have emerged as a popular tool to identify the support of a model prediction in terms of relevant features in the input. These methods produce saliency maps, defined as regions of visual evidence upon which a network makes a decision. Our goal is to investigate whether augmenting experience replay with explanation replay reduces forgetting, and how enforcing the model to remember its explanations affects the explanations themselves. Figure 1 illustrates our proposed method. In this work, we propose RRR, a training strategy guided by model explanations generated by any white-box differentiable explanation method; RRR adds an explanation loss to continual learning. White-box methods generate an explanation by using some internal state of the model, such as gradients, enabling their use in end-to-end training.
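The explanation loss described above can be sketched as a penalty on drift between the saliency map stored in the buffer and the one the current model produces for the same input. The L1 distance and the trade-off weight `lam` below are our illustrative choices; the paper's exact formulation may differ:

```python
import numpy as np


def rrr_loss(saliency_stored, saliency_current):
    """L_RRR: mean absolute drift between the saliency map saved when
    the example was first learned and the current model's map."""
    return float(np.mean(np.abs(saliency_stored - saliency_current)))


def total_loss(task_loss, saliency_stored, saliency_current, lam=1.0):
    """Combined objective L_task + lam * L_RRR, where lam is a
    hypothetical trade-off weight for the explanation term."""
    return task_loss + lam * rrr_loss(saliency_stored, saliency_current)
```

Because white-box saliency maps are differentiable functions of the model's internal state, this extra term can be backpropagated through alongside the task loss.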
We evaluate our approach using various popular explanation methods, including vanilla backpropagation (Zeiler & Fergus, 2014), backpropagation with smoothed gradients (SmoothGrad) (Smilkov et al., 2017), Guided Backpropagation (Springenberg et al., 2014), and Gradient Class Activation Mapping (Grad-CAM) (Selvaraju et al., 2017), and compare their performance against their computational feasibility. We integrate RRR into several state-of-the-art class incremental learning (CIL) methods, including iTAML (Rajasegaran et al., 2020), EEIL (Castro et al., 2018), BiC (Wu et al., 2019), TOPIC (Tao et al., 2020), iCaRL (Rebuffi et al., 2017), EWC (Kirkpatrick et al., 2017), and LwF (Li & Hoiem, 2016). Note that RRR does not require task IDs at test time. We qualitatively and quantitatively analyze model explanations in the form of saliency maps and demonstrate that RRR remembers its earlier decisions in a sequence of tasks due to the requirement to focus on the right evidence. We empirically show the effect of RRR in standard and few-shot CIL scenarios on popular benchmark datasets, including CIFAR100, ImageNet100, and Caltech-UCSD Birds 200, using different network architectures, where RRR improves overall accuracy and forgetting over experience replay and other memory-based methods. Our contribution is threefold: we first propose our novel, simple, yet effective memory constraint, which we call Remembering for the Right Reasons (RRR), and show that it reduces catastrophic forgetting by encouraging the model to look at the same explanations it initially found for its decisions. Second, we show how RRR can be readily combined with memory-based and regularization-based continual learning approaches.
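Among the explanation methods listed above, Grad-CAM is a representative white-box choice. Given a convolutional layer's activations and the gradients of a class score with respect to them (which a real model would supply via backpropagation), the map is a ReLU of a gradient-weighted sum over channels. The sketch below is a minimal NumPy rendering of that standard recipe, not the paper's implementation:

```python
import numpy as np


def grad_cam(feature_maps, gradients):
    """Grad-CAM saliency from cached activations and gradients.

    feature_maps: (K, H, W) activations of a chosen conv layer.
    gradients:    (K, H, W) gradients of the target class score
                  with respect to those activations.
    Returns an (H, W) map normalized to [0, 1].
    """
    # Channel importance weights: global-average-pool the gradients.
    weights = gradients.mean(axis=(1, 2))                        # (K,)
    # Weighted combination of feature maps, then ReLU.
    cam = np.maximum((weights[:, None, None] * feature_maps).sum(axis=0), 0.0)
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam
```

In practice the resulting low-resolution map is upsampled to the input size for visualization, but it can be stored at conv-layer resolution, which keeps the memory footprint small.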



[Figure 1 panel labels: left panel trained with the task loss L_task(y, y_pred) alone; right panel trained with the combined loss L_task(y, y_pred) + L_RRR(ŝ, s_pred).]

Figure 1: An illustration of applying the RRR paradigm. (Left) In a typical experience replay scenario, samples from prior tasks are kept in a memory buffer M_rep and revisited during training. (Right) In our proposed idea (RRR), in addition to M_rep, we also store model explanations (saliency maps) as M_RRR for those samples and encourage the model to remember the original reasoning for the prediction. Note that the saliency maps are small masks resulting in a negligible memory overhead (see Section 4.1).
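A back-of-envelope calculation illustrates why the caption can call the overhead negligible. The concrete sizes below are our illustrative assumptions (a 224×224 RGB input and a 7×7 Grad-CAM map at the final conv layer of a ResNet-style backbone), not figures quoted from Section 4.1:

```python
# Memory cost of storing a saliency mask alongside each buffered image,
# assuming 8-bit storage for both. Sizes are illustrative assumptions.
img_bytes = 224 * 224 * 3   # 224x224 RGB image
sal_bytes = 7 * 7           # Grad-CAM map at conv-layer resolution
overhead = sal_bytes / img_bytes
print(f"saliency adds {overhead:.3%} on top of each stored image")
```

Even a full-resolution single-channel mask would only add one extra channel's worth of storage; at conv-layer resolution the cost is a small fraction of a percent per example.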

