LINEAR MODE CONNECTIVITY IN MULTITASK AND CONTINUAL LEARNING

Abstract

Continual (sequential) training and multitask (simultaneous) training often attempt to solve the same overall objective: to find a solution that performs well on all considered tasks. The main difference lies in the training regimes: continual learning can only access one task at a time, which for neural networks typically leads to catastrophic forgetting. That is, the solution found for a subsequent task no longer performs well on the previous ones. However, the relationship between the minima that the two training regimes arrive at is not well understood. What sets them apart? Is there a local structure that could explain the difference in performance achieved by the two schemes? Motivated by recent work showing that different minima of the same task are typically connected by very simple curves of low error, we investigate whether multitask and continual solutions are similarly connected. We empirically find that such connectivity can indeed be reliably achieved and, more interestingly, that it can be achieved by a linear path, conditioned on both solutions sharing the same initialization. We thoroughly analyze this observation and discuss its significance for the continual learning process. Furthermore, we exploit this finding to propose an effective algorithm that constrains the sequentially learned minima to behave like the multitask solution. We show that our method outperforms several state-of-the-art continual learning algorithms on various vision benchmarks.

1. INTRODUCTION

One major consequence of learning multiple tasks in a continual learning (CL) setting, where tasks are learned sequentially and the model can only access one task at a time, is catastrophic forgetting (McCloskey & Cohen, 1989). This is in contrast to multitask learning (MTL), where the learner has simultaneous access to all tasks and generally learns to perform well on all of them without suffering from catastrophic forgetting. This limitation hinders the ability of the model to learn continually and efficiently. Recently, several approaches have been proposed to tackle this problem. They have mostly tried to mitigate catastrophic forgetting by using different approximations of the multitask loss. For example, some regularization methods take a quadratic approximation of the loss of previous tasks (e.g. Kirkpatrick et al., 2017; Yin et al., 2020). As another example, rehearsal methods attempt to directly use compressed past data, either by selecting a representative subset (e.g. Chaudhry et al., 2019; Titsias et al., 2019) or by relying on generative models (e.g. Shin et al., 2017; Robins, 1995). In this work, we depart from the literature and start from the non-conventional question of understanding "What is the relationship, potentially in terms of local geometric properties, between the multitask and the continual learning minima?". Our work is inspired by recent work on mode connectivity (Draxler et al., 2018; Garipov et al., 2018; Frankle et al., 2020), which finds that different optima obtained by gradient-based optimization methods are connected by simple paths of non-increasing loss. We try to understand whether the multitask and continual solutions are also connected by a manifold of low error, and what the simplest form of this manifold can be. Surprisingly, we find that a linear manifold, as illustrated in Fig. 1 (right), reliably connects the multitask solution to the continual ones, provided that the multitask training shares the same initialization with the continual learning run, as described below. This is a significant finding both for understanding the phenomenon of catastrophic forgetting through the lens of loss landscapes and optimization trajectories, and for designing better continual learning algorithms.

To reach this conclusion, we consider the particular learning regime depicted in Fig. 1 (left): after learning the first task using the data D_1 and obtaining ŵ_1, we either sequentially learn a second task, obtaining ŵ_2, or continue by training on both tasks simultaneously (i.e., train on D_1 + D_2), obtaining the multitask solution w*_2. We investigate the relationship between the two solutions ŵ_2 and w*_2. Note that w*_2 is not the typical multitask solution, which would normally start from w_0 and train on both datasets. We chose this slightly non-conventional setup to minimize the number of potential confounding factors that lead to discrepancies between the two solutions (Fort et al., 2019). We also rely on the observation from Frankle et al. (2020) that initialization can have a big impact on the connectivity between solutions found on the same task; sharing the same starting point, as we do between ŵ_2 and w*_2, might warrant a linear path of low error between the two solutions. Moreover, Neyshabur et al. (2020) noted that, in the context of transfer learning, there is no performance barrier between two minima that start from pre-trained weights, which suggests that the pre-trained weights guide the optimization to a flat basin of the loss landscape. In contrast, barriers clearly exist if the two minima start from randomly initialized weights.

Our contributions can be summarized as follows: 1.
To the best of our knowledge, our work is the first to study the connectivity between continual learning and multitask learning solutions. 2. We show that, unlike conventional similarity measures such as Euclidean distance or Centered Kernel Alignment (Kornblith et al., 2019), which are incapable of meaningfully relating these minima, connectivity through a manifold of low error can reliably be established; moreover, the connecting path is linear, even when considering more than 20 tasks in a row. 3. Motivated by this, we propose an effective CL algorithm (Mode Connectivity SGD, or MC-SGD) that outperforms several established methods on standard CL benchmarks.
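The linear connectivity discussed above can be checked directly: evaluate the loss along the straight line between two parameter vectors and look for a barrier, i.e., a point on the path whose loss exceeds that of both endpoints. The following is a minimal sketch of this check on a toy quadratic loss; the function names and the toy loss are illustrative stand-ins, not the paper's implementation.

```python
import numpy as np

def interpolation_losses(w_a, w_b, loss_fn, n_points=11):
    """Evaluate loss_fn along the linear path (1 - alpha) * w_a + alpha * w_b.

    A low, barrier-free loss profile indicates that the two solutions
    are linearly mode-connected.
    """
    alphas = np.linspace(0.0, 1.0, n_points)
    losses = np.array([loss_fn((1 - a) * w_a + a * w_b) for a in alphas])
    return alphas, losses

# Toy quadratic loss standing in for a task loss (hypothetical example).
rng = np.random.default_rng(0)
H = np.diag([1.0, 0.5])          # curvature of the toy loss
w_star = np.array([1.0, -2.0])   # its minimizer

def loss_fn(w):
    d = w - w_star
    return 0.5 * d @ H @ d

# Two "solutions" near the same minimum, e.g. a continual and a multitask one.
w_cl = w_star + 0.01 * rng.standard_normal(2)
w_mtl = w_star + 0.01 * rng.standard_normal(2)

alphas, losses = interpolation_losses(w_cl, w_mtl, loss_fn)
barrier = losses.max() - max(losses[0], losses[-1])
print(f"max barrier along the path: {barrier:.4f}")
```

For a convex loss such as this toy quadratic, the loss restricted to the line is itself convex, so no barrier can appear; for neural networks the interesting empirical question is whether the profile stays similarly flat.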
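For comparison, linear Centered Kernel Alignment (Kornblith et al., 2019), one of the conventional similarity measures mentioned above, can be sketched in a few lines. This is a minimal numpy sketch using the HSIC-based formulation for linear kernels; the representation matrices are random stand-ins, not activations from the paper's experiments.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear Centered Kernel Alignment between representation matrices
    X, Y of shape (n_examples, n_features). Returns a value in [0, 1]."""
    X = X - X.mean(axis=0)  # center each feature
    Y = Y - Y.mean(axis=0)
    cross = np.linalg.norm(Y.T @ X, "fro") ** 2
    return cross / (np.linalg.norm(X.T @ X, "fro")
                    * np.linalg.norm(Y.T @ Y, "fro"))

rng = np.random.default_rng(0)
A = rng.standard_normal((200, 32))                  # representations of one model
Q, _ = np.linalg.qr(rng.standard_normal((32, 32)))  # random orthogonal matrix
B = A @ Q                                           # same representations, rotated
N = rng.standard_normal((200, 32))                  # unrelated representations

print(f"CKA(A, A)     = {linear_cka(A, A):.3f}")  # identical: 1.0
print(f"CKA(A, A @ Q) = {linear_cka(A, B):.3f}")  # rotation-invariant: 1.0
print(f"CKA(A, noise) = {linear_cka(A, N):.3f}")  # unrelated: small
```

CKA's invariance to orthogonal transformations and isotropic scaling is what makes it attractive for comparing networks, but, as noted above, it fails to meaningfully relate the continual and multitask minima studied here.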

1.1. RELATED WORK

With the growing popularity of deep learning, continual learning has gained critical importance, as the catastrophic forgetting problem poses key challenges for deploying deep learning models in various applications (e.g. Lange et al., 2019; Kemker et al., 2018). A growing body of research has attempted to tackle this problem in recent years (e.g. Parisi et al., 2018; Toneva et al., 2018; Nguyen et al., 2019; Farajtabar et al., 2019; Hsu et al., 2018; Rusu et al., 2016; Li et al., 2019; Kirkpatrick et al., 2017; Zenke et al., 2017; Shin et al., 2017; Rolnick et al., 2018; Lopez-Paz & Ranzato, 2017; Chaudhry et al., 2018b; Riemer et al., 2018; Mirzadeh et al., 2020; Wallingford et al., 2020). Among these works, our proposed MC-SGD bears the most similarity to rehearsal-based methods (e.g. Shin et al., 2017; Chaudhry et al., 2018b) and regularization-based methods (e.g. Kirkpatrick et al., 2017; Zenke et al., 2017), similar to Titsias et al. (2019). Following Lange et al. (2019), one can categorize continual learning methods into three general categories, based on how they approach dealing with catastrophic forgetting. Experience replay: experience-replay methods build and store a memory of the knowledge learned so far (Rebuffi et al., 2016; Lopez-Paz & Ranzato, 2017; Shin et al., 2017; Riemer et al., 2018; Rios & Itti,



Figure 1: Left: depiction of the training regime considered. First, ŵ_1 is learned on task 1. Afterwards, we either reach ŵ_2 by learning the second task or w*_2 by training on both tasks simultaneously. Right: depiction of linear connectivity between w*_2 and ŵ_1 and between w*_2 and ŵ_2.

