EXCLUSIVE SUPERMASK SUBNETWORK TRAINING FOR CONTINUAL LEARNING

Abstract

Continual Learning (CL) methods mainly focus on avoiding catastrophic forgetting and learning representations that are transferable to new tasks. Recently, Wortsman et al. (2020) proposed a CL method, SupSup, which uses a randomly initialized, fixed base network (model) and finds a supermask for each new task that selectively keeps or removes each weight to produce a subnetwork. Forgetting is prevented because the network weights are never updated. Although there is no forgetting, the performance of the supermask subnetworks is sub-optimal because the fixed weights restrict their representational power. Furthermore, there is no accumulation or transfer of knowledge inside the model when new tasks are learned. Hence, we propose EXSSNET (Exclusive Supermask SubNEtwork Training), which performs exclusive and non-overlapping subnetwork weight training. This avoids conflicting updates to the shared weights by subsequent tasks, improving performance while still preventing forgetting. Furthermore, we propose a novel KNN-based Knowledge Transfer (KKT) module that dynamically initializes a new task's mask based on previous tasks to improve knowledge transfer. We demonstrate that EXSSNET outperforms SupSup and other strong previous methods on both text classification and vision tasks while preventing forgetting. Moreover, EXSSNET is particularly advantageous for sparse masks that activate 2-10% of the model parameters, resulting in an average improvement of 8.3% over SupSup. Additionally, EXSSNET scales to a large number of tasks (100), and our KKT module helps to learn new tasks faster while improving the overall performance. 1

1. INTRODUCTION

In artificial intelligence, the overarching goal is to develop autonomous agents that can learn to accomplish a set of tasks. Continual Learning (CL) (Ring, 1998; Thrun, 1998; Kirkpatrick et al., 2017) is a key ingredient for developing agents that can accumulate expertise on new tasks. However, when a model is sequentially trained on tasks t1 and t2 with different data distributions, the model's ability to extract meaningful features for the previous task t1 degrades. This loss in performance on previously learned tasks is referred to as catastrophic forgetting (CF) (McCloskey & Cohen, 1989; Zhao & Schmidhuber, 1996; Thrun, 1998; Goodfellow et al., 2013). Forgetting is a consequence of two phenomena happening in conjunction: (1) not having access to the data samples from the previous tasks, and (2) multiple tasks sequentially updating shared model parameters, resulting in conflicting updates, which is called parameter interference (McCloskey & Cohen, 1989). Recently, some CL methods avoid parameter interference by taking inspiration from the Lottery Ticket Hypothesis (Frankle & Carbin, 2018) and Supermasks (Zhou et al., 2019) to exploit the expressive power of sparse subnetworks. Zhou et al. (2019) observed that the number of sparse subnetwork combinations is combinatorially large, so that even within randomly weighted neural networks there exist supermasks that achieve good performance. A supermask is a sparse binary mask that selectively keeps or removes each connection in a fixed, randomly initialized network to produce a subnetwork with good performance on a given task. We call this subnetwork the supermask subnetwork; it is shown in Figure 1, highlighted by the red weights. Building upon this idea, Wortsman et al. (2020) proposed a CL method, SupSup, which initializes a network with fixed, random weights W(0).
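The supermask mechanism described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the layer shape, the random mask, and the helper name `subnetwork_forward` are all hypothetical, and in SupSup the mask is learned per task rather than sampled.

```python
import numpy as np

rng = np.random.default_rng(0)

# Fixed, randomly initialized weights of one linear layer (never updated in SupSup).
W = rng.standard_normal((4, 3))

# A binary supermask that keeps (1) or removes (0) each connection.
# Random here for illustration; in SupSup it is learned for each task.
M = (rng.random((4, 3)) < 0.5).astype(W.dtype)

def subnetwork_forward(x, W, M):
    """Forward pass through the supermask subnetwork: removed weights act as zero."""
    return x @ (W * M)

x = rng.standard_normal((2, 4))
y = subnetwork_forward(x, W, M)
```

Because the forward pass only ever sees `W * M`, changing the mask switches the subnetwork while leaving the underlying weights `W` untouched.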
SupSup then learns a different supermask for each new task. This prevents catastrophic forgetting (CF) because the model weights are fixed, so there is no parameter interference. Although SupSup (Wortsman et al., 2020) prevents CF, there are some problems with using supermasks for CL: (1) The fixed random model weights in SupSup limit the supermask subnetwork's representational power, resulting in sub-optimal performance. As shown in Figure 2, for a single task the test accuracy of SupSup is approximately 10% worse than that of a fully trained model where all model weights are updated. As a possible remedy, one could naively train the weights corresponding to the supermask subnetworks of different tasks; however, this can lead to CF, as shown in Figure 3, because subnetworks for different tasks can overlap and training their weights can result in parameter interference. (2) When learning a new task, there is no mechanism for transferring knowledge from previously learned tasks to better learn the current task. Moreover, the model does not accumulate knowledge over time, as the weights are never updated.
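The interference problem with naive subnetwork training can be seen in a toy example. The sketch below is purely illustrative (scalar weights, hand-picked masks and values stand in for actual gradient training): updating every weight in task 2's mask overwrites a weight that task 1's subnetwork relies on.

```python
import numpy as np

# Toy model: 5 shared scalar weights; each task's subnetwork is a binary mask.
W = np.zeros(5)
M1 = np.array([1, 1, 0, 0, 0], dtype=bool)  # task 1 uses weights 0 and 1
M2 = np.array([0, 1, 1, 0, 0], dtype=bool)  # task 2 overlaps task 1 at weight 1

# "Train" task 1: its masked weights settle at task-specific values.
W[M1] = 1.0
w1_after_task1 = W[M1].copy()

# Naively "train" task 2: ALL of task 2's masked weights are updated,
# including weight 1, which task 1 relies on -> parameter interference.
W[M2] = -1.0

interference = not np.array_equal(W[M1], w1_after_task1)
```

After task 2's update, task 1's subnetwork no longer contains the values it was trained with, which is exactly the forgetting that exclusive training is designed to avoid.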
To overcome the aforementioned issues, we propose our method, EXSSNET (Exclusive Supermask SubNEtwork Training), pronounced 'excess-net', which first learns a mask for a task and then selectively trains a subset of the weights from the supermask subnetwork. We train the weights of this subnetwork via exclusion, avoiding updates to parameters of the current subnetwork that have already been updated by any previous task. This helps us prevent forgetting. The procedure is demonstrated in Figure 1 for learning three tasks sequentially. Training the supermask subnetwork's weights increases its representational power and allows EXSSNET to encode task-specific knowledge inside the subnetwork. This solves the first problem: EXSSNET performs comparably to a fully trained network on individual tasks, and when learning multiple tasks, exclusive subnetwork training improves the performance of each task while still preventing forgetting. To address the second problem of knowledge transfer, we propose a k-nearest-neighbors-based knowledge transfer (KKT) module that transfers relevant information from previously learned tasks to improve performance on new tasks while learning them faster. Our KKT module uses KNN classification to select a subnetwork from the previously learned tasks that has better-than-random predictive power for the current task and uses it as a starting point for learning the new task. Next, we show our method's advantage by experimenting with both natural language and vision tasks. For natural language, we evaluate on WebNLP classification tasks (de



1 Our code is uploaded as supplementary material.



Figure 1: EXSSNET diagram. We start with random weights W(0). For task 1, we first learn a supermask M1 (the corresponding subnetwork is marked in red, column 2 row 1) and then train the weights corresponding to M1, resulting in weights W(1) (bold red lines, column 1 row 2). Similarly, for task 2, we learn the mask M2 over the fixed weights W(1). If the weights of mask M2 overlap with M1 (marked by bold dashed green lines in column 3 row 1), then only the non-overlapping weights (solid green lines) of the task 2 subnetwork are updated (bold solid green lines, column 3 row 2). These already-trained weights (bold lines) are not updated by any subsequent task. Finally, for task 3, we learn the mask M3 (blue lines) and update the solid blue weights.
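The exclusive update rule described in the caption can be sketched as follows. This is a minimal sketch, not the authors' implementation: the helper name `exclusive_train` is hypothetical, and the constant `new_value` stands in for whatever values gradient training would produce for the current task's free weights.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
W = rng.standard_normal(n)       # W(0): fixed random initialization
used = np.zeros(n, dtype=bool)   # weights already trained by earlier tasks

def exclusive_train(W, mask, used, new_value):
    """Update only the weights in `mask` not already claimed by earlier tasks."""
    free = mask & ~used          # exclusive, non-overlapping portion of the subnetwork
    W[free] = new_value          # "train" only the free weights
    used |= mask                 # the whole subnetwork is now claimed
    return W, used

M1 = np.array([1, 1, 1, 0, 0, 0, 0, 0], dtype=bool)
M2 = np.array([0, 0, 1, 1, 1, 0, 0, 0], dtype=bool)  # overlaps M1 at index 2

W, used = exclusive_train(W, M1, used, 1.0)
w_task1 = (W * M1).copy()        # snapshot of task 1's subnetwork weights
W, used = exclusive_train(W, M2, used, 2.0)
```

Because the overlapping weight (index 2) is excluded from task 2's update, task 1's subnetwork is reproduced exactly at inference time, so there is no forgetting by construction.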

De Lange & Tuytelaars, 2021), and SplitTinyImageNet (Buzzega et al., 2020) datasets. We show that for both language and vision domains, EXSSNET outperforms multiple strong and recent continual learning methods based on replay, regularization, distillation, and parameter isolation. For the vision domain, EXSSNET outperforms the strongest baseline by 4.8% and 1.4% on the SplitCIFAR and SplitTinyImageNet datasets respectively, while surpassing the multitask model and bridging the gap to training individual models for each task. In addition, for GLUE datasets, EXSSNET is 2%

