ANATOMY OF CATASTROPHIC FORGETTING: HIDDEN REPRESENTATIONS AND TASK SEMANTICS

Abstract

Catastrophic forgetting is a recurring challenge in developing versatile deep learning models. Despite its ubiquity, there is limited understanding of how it relates to neural network (hidden) representations and task semantics. In this paper, we address this knowledge gap. Through quantitative analysis of neural representations, we find that deeper layers are disproportionately responsible for forgetting, with sequential training erasing the representational subspaces of earlier tasks. Methods that mitigate forgetting stabilize these deeper layers, but differ in their precise effects: some increase feature reuse, while others store task representations orthogonally, preventing interference. These insights also enable an analytic argument and empirical picture relating forgetting to task semantic similarity, where we find that maximal forgetting occurs for task sequences with intermediate similarity.

1. INTRODUCTION

While the past few years have seen the development of increasingly versatile machine learning systems capable of learning complex tasks (Stokes et al., 2020; Raghu & Schmidt, 2020; Wu et al., 2019b), catastrophic forgetting remains a core challenge. Catastrophic forgetting is the ubiquitous phenomenon in which machine learning models trained on non-stationary data distributions suffer performance losses on older data instances. More specifically, if a model is trained on a sequence of tasks, its accuracy on earlier tasks drops significantly. The problem manifests in many sub-domains of machine learning, including continual learning (Kirkpatrick et al., 2017), multi-task learning (Kudugunta et al., 2019), standard supervised learning under input distribution shift (Toneva et al., 2019; Snoek et al., 2019; Rabanser et al., 2019; Recht et al., 2019), and data augmentation (Gontijo-Lopes et al., 2020). Mitigating catastrophic forgetting has been an important research focus (Goodfellow et al., 2013; Kirkpatrick et al., 2017; Lee et al., 2017; Li et al., 2019; Serrà et al., 2018; Ritter et al., 2018; Rolnick et al., 2019), but many methods are only effective in specific settings (Kemker et al., 2018), and progress is hindered by limited understanding of catastrophic forgetting's fundamental properties. How does catastrophic forgetting affect the hidden representations of neural networks? Are earlier tasks forgotten equally across all parameters? Are there underlying principles common to methods that mitigate forgetting? How is catastrophic forgetting affected by (semantic) similarities between sequential tasks? This paper takes steps toward answering these questions. Specifically:

1. With experiments on split CIFAR-10, a novel distribution-shift CIFAR-100 variant, CelebA, and ImageNet, we analyze neural network layer representations, finding that higher layers are disproportionately responsible for catastrophic forgetting, with sequential training erasing earlier task subspaces.

2. We investigate different methods for mitigating forgetting, finding that while all stabilize higher-layer representations, some encourage greater feature reuse in higher layers, while others store task representations as orthogonal subspaces, preventing interference.

3. We study the connection between forgetting and task semantics, finding that the semantic similarity between subsequent tasks consistently controls the degree of forgetting.

4. Informed by the representation results, we construct an analytic model relating task similarity to representation interference and forgetting. This yields a quantitative empirical measure of task similarity, and together these show that forgetting is most severe for task sequences of intermediate similarity.
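As an illustration of the kind of layerwise representational analysis referred to in contribution 1, per-layer similarity between activations recorded before and after training on a second task can be measured with a representation-similarity metric. The sketch below uses linear centered kernel alignment (CKA); the metric choice and all variable names are illustrative assumptions, not details taken from this paper:

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA similarity between two activation matrices.

    X, Y -- (num_examples, num_features) hidden activations of one layer,
            e.g. recorded before vs. after training on a second task.
    Returns a value in [0, 1]; 1 means the representations agree up to
    an isotropic scaling and a shift.
    """
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = np.linalg.norm(X.T @ Y, ord='fro') ** 2
    den = np.linalg.norm(X.T @ X, ord='fro') * np.linalg.norm(Y.T @ Y, ord='fro')
    return num / den

rng = np.random.default_rng(0)
acts_before = rng.normal(size=(200, 32))   # stand-in for layer activations before Task 2
acts_after = rng.normal(size=(200, 32))    # stand-in for activations after Task 2
print(linear_cka(acts_before, acts_before))  # identical representations -> 1.0
print(linear_cka(acts_before, acts_after))   # unrelated representations -> much smaller
```

A layer whose similarity stays near 1 across sequential training has a stable representation; a layer whose similarity collapses is a candidate locus of forgetting.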

2. RELATED WORK

Mitigation strategies. Developing mitigation strategies for catastrophic forgetting is an active area of research (Kemker et al., 2018). Popular approaches based on structural regularization include Elastic Weight Consolidation (Kirkpatrick et al., 2017) and Synaptic Intelligence (Zenke et al., 2017). An alternative, functional regularization approach is based on storing earlier data in a replay buffer and replaying it during later training (Schaul et al., 2015; Robins, 1995; Rolnick et al., 2019). We study how these mitigation methods affect the internal network representations to better understand their role in preventing catastrophic forgetting.

Understanding catastrophic forgetting. Our work is similar in spirit to existing empirical studies aimed at better understanding the catastrophic forgetting phenomenon (Goodfellow et al., 2013; Toneva et al., 2019). We focus specifically on understanding how layerwise representations change, and on the relation between forgetting and task semantics, neither of which has previously been explored. The semantic aspect is related to recent work by Nguyen et al. (2019) examining the influence of task sequences on forgetting. It also builds on work showing that learning in both biological and artificial neural networks is affected by semantic properties of the training data (Saxe et al., 2019; Mandler & McDonough, 1993).

Fine-tuning and transfer learning. While not studied in the context of catastrophic forgetting, layerwise learning dynamics have been investigated in settings other than continual learning. For example, Raghu et al. (2017) showed that layerwise network representations roughly converge bottom-up (from input to output). Neyshabur et al. (2020) observed similar phenomena in a transfer-learning setting with images, as have others recently in transformers for NLP (Wu et al., 2020; Merchant et al., 2020). Furthermore, the observation that early layers in networks learn general features like edge detectors, while later layers learn more task-specific features, is a well-known result in computer vision (Erhan et al., 2009); here we quantitatively study its ramifications for sequential training and catastrophic forgetting.
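To make the structural-regularization idea concrete, Elastic Weight Consolidation adds a quadratic penalty that anchors parameters deemed important for Task 1 while Task 2 is trained. The sketch below is a minimal, simplified version assuming a precomputed diagonal Fisher importance estimate; the array values are hypothetical:

```python
import numpy as np

def ewc_penalty(theta, theta_star, fisher_diag, lam=1.0):
    """Quadratic EWC penalty: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2.

    theta       -- current parameters (partway through Task 2 training)
    theta_star  -- parameters at the end of Task 1 training
    fisher_diag -- diagonal Fisher estimate of each parameter's importance for Task 1
    lam         -- regularization strength
    """
    diff = theta - theta_star
    return 0.5 * lam * np.sum(fisher_diag * diff ** 2)

# Parameters with a large Fisher value (important for Task 1) are penalized
# more heavily for drifting during Task 2 training.
theta_star = np.array([1.0, -2.0, 0.5])
fisher = np.array([10.0, 0.1, 1.0])  # first parameter is "important"
theta = np.array([1.5, -1.0, 0.5])
print(ewc_penalty(theta, theta_star, fisher, lam=1.0))  # 0.5 * (10*0.25 + 0.1*1.0) = 1.3
```

In training, this penalty is simply added to the Task 2 loss, so gradient descent trades off new-task accuracy against movement along directions important to the old task.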

3. SETUP

Tasks: We conduct this study over several tasks and datasets:
(i) split CIFAR-10, where the ten-class dataset is split into two tasks of five classes each;
(ii) input distribution shift CIFAR-100, where each task is to distinguish between the CIFAR-100 superclasses, but the input data for each task is drawn from a different subset of the constituent classes of each superclass (see Appendix A.3);
(iii) CelebA attribute prediction, where the two tasks take as input images of either men or women, and the target is to predict either smiling or mouth open;
(iv) ImageNet superclass prediction, constructed analogously to the CIFAR-100 variant.
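A minimal sketch of how the split CIFAR-10 tasks could be constructed from a label array (the toy labels below stand in for the real dataset, which is not loaded here; the helper name is our own):

```python
import numpy as np

def split_task_indices(labels, task_classes):
    """Return indices of examples whose label falls in task_classes."""
    mask = np.isin(labels, list(task_classes))
    return np.flatnonzero(mask)

# Toy stand-in for CIFAR-10 labels; the real array would come from the dataset.
labels = np.array([0, 3, 7, 2, 9, 4, 1, 5, 8, 6])

task1 = split_task_indices(labels, range(0, 5))   # Task 1: classes 0-4
task2 = split_task_indices(labels, range(5, 10))  # Task 2: classes 5-9
print(task1, task2)
```

Within each task, labels are then remapped to a common 0-4 range (e.g. subtracting 5 for Task 2) so both tasks share a five-way output head.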

Models:

We perform experiments with three neural network architectures commonly used in image classification: VGG (Simonyan & Zisserman, 2014), ResNet (He et al., 2015), and DenseNet (Huang et al., 2016). Examples of catastrophic forgetting in these models are shown in Figure 1.



Figure 1: Examples of catastrophic forgetting across different architectures and datasets. We plot the accuracy of Task 1 (purple) and Task 2 (green) on both the split CIFAR-10 task and the CIFAR-100 distribution-shift task, across multiple architectures. Catastrophic forgetting is seen as the significant drop in Task 1 accuracy when Task 2 training begins.
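The accuracy drop described in the caption is often summarized by a single forgetting score: peak Task 1 accuracy minus its accuracy at the end of training. This is a common convention assumed here for illustration (the excerpt does not define a specific metric), and the accuracy trace below is hypothetical:

```python
def forgetting(task1_acc_history):
    """Peak Task 1 accuracy minus its final accuracy after Task 2 training."""
    return max(task1_acc_history) - task1_acc_history[-1]

# Hypothetical Task 1 accuracy trace: the task is learned, then Task 2
# training begins and Task 1 accuracy collapses (the purple curves above).
trace = [0.10, 0.78, 0.85, 0.84, 0.41, 0.36]
print(forgetting(trace))  # 0.85 - 0.36, i.e. about 0.49
```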

