GENERALIZED VARIATIONAL CONTINUAL LEARNING

Abstract

Continual learning deals with training models on new tasks and datasets in an online fashion. One strand of research has used probabilistic regularization for continual learning, with two of the main approaches in this vein being Online Elastic Weight Consolidation (Online EWC) and Variational Continual Learning (VCL). VCL employs variational inference, which in other settings has been improved empirically by applying likelihood-tempering. We show that applying this modification to VCL recovers Online EWC as a limiting case, allowing for interpolation between the two approaches. We term the general algorithm Generalized VCL (GVCL). In order to mitigate the observed overpruning effect of VI, we take inspiration from a common multi-task architecture, neural networks with task-specific FiLM layers, and find that this addition leads to significant performance gains, specifically for variational methods. In the small-data regime, GVCL strongly outperforms existing baselines. On larger datasets, GVCL with FiLM layers outperforms or is competitive with existing baselines in terms of accuracy, whilst also providing significantly better calibration.

1. INTRODUCTION

Continual learning methods enable learning when a set of tasks changes over time. This topic is of practical interest as many real-world applications require models to be regularly updated as new data is collected or new tasks arise. Standard machine learning models and training procedures fail in these settings (French, 1999), so bespoke architectures and fitting procedures are required. This paper makes two main contributions to continual learning for neural networks. First, we develop a new regularization-based approach to continual learning. Regularization approaches adapt parameters to new tasks while keeping them close to settings that are appropriate for old tasks. Two popular approaches of this type are Variational Continual Learning (VCL) (Nguyen et al., 2018) and Online Elastic Weight Consolidation (Online EWC) (Kirkpatrick et al., 2017; Schwarz et al., 2018). The former is based on a variational approximation of a neural network's posterior distribution over weights, while the latter uses Laplace's approximation. In this paper, we propose Generalized Variational Continual Learning (GVCL), of which VCL and Online EWC are two special cases. Under this unified framework, we are able to combine the strengths of both approaches. GVCL is closely related to likelihood-tempered Variational Inference (VI), which has been found to improve performance in standard learning settings (Zhang et al., 2018; Osawa et al., 2019). We also see significant performance improvements in continual learning. Our second contribution is to introduce an architectural modification to the neural network that combats the deleterious overpruning effect of VI (Trippe & Turner, 2018; Turner & Sahani, 2011). We analyze pruning in VCL and show how task-specific FiLM layers mitigate it. Combining this architectural change with GVCL results in a hybrid architectural-regularization based algorithm.
This additional modification results in performance that exceeds or is within statistical error of strong baselines such as HAT (Serra et al., 2018) and PathNet (Fernando et al., 2017). The paper is organized as follows. Section 2 outlines the derivation of GVCL, shows how it unifies many continual learning algorithms, and describes why it might be expected to perform better than them. Section 3 introduces FiLM layers, first from the perspective of multi-task learning, and then through the lens of variational over-pruning, showing how FiLM layers mitigate this pathology of VCL. Finally, in Section 5 we test GVCL and GVCL with FiLM layers on many standard benchmarks, including ones with few samples, a regime that could benefit more from continual learning. We find that GVCL with FiLM layers outperforms existing baselines on a variety of metrics, including raw accuracy, forward and backward transfer, and calibration error. In Section 5.4 we show that FiLM layers provide a disproportionate improvement to variational methods, confirming our hypothesis in Section 3.

2. GENERALIZED VARIATIONAL CONTINUAL LEARNING

In this section, we introduce Generalized Variational Continual Learning (GVCL) as a likelihood-tempered version of VCL, with further details in Appendix C. We show how GVCL recovers Online EWC. We also discuss further links between GVCL and the Bayesian cold posterior in Appendix D.

2.1. LIKELIHOOD-TEMPERING IN VARIATIONAL CONTINUAL LEARNING

Variational Continual Learning (VCL). Bayes' rule calculates a posterior distribution over model parameters θ based on a prior distribution p(θ) and some dataset D_T = {X_T, y_T}. Bayes' rule naturally supports online and continual learning by using the previous posterior p(θ|D_{T-1}) as a new prior when seeing new data (Nguyen et al., 2018). Due to the intractability of Bayes' rule in complicated models such as neural networks, approximations are employed, and VCL (Nguyen et al., 2018) uses one such approximation, Variational Inference (VI). This approximation is based on approximating the posterior p(θ|D_T) with a simpler distribution q_T(θ), such as a Gaussian. This is achieved by optimizing the ELBO for the optimal q_T(θ),

ELBO_VCL = E_{θ∼q_T(θ)}[log p(D_T|θ)] − D_KL(q_T(θ) ‖ q_{T−1}(θ)),

where q_{T−1}(θ) is the approximation to the previous task posterior. Intuitively, this refines a distribution over weight samples that balances good predictive performance (the first expected log-likelihood term) while remaining close to the prior (the second KL-divergence regularization term).

Likelihood-tempered VCL. Optimizing the ELBO will recover the true posterior if the approximating family is sufficiently rich. However, the simple families used in practice typically lead to poor test-set performance. Practitioners have found that performance can be improved by down-weighting the KL-divergence regularization term by a factor β, with 0 < β < 1. Examples of this are seen in Zhang et al. (2018) and Osawa et al. (2019), where the latter uses a "data augmentation factor" for down-weighting. In a similar vein, sampling from "cold posteriors" in SG-MCMC has also been shown to outperform the standard Bayes posterior, where the cold posterior is given by p_T(θ|D) ∝ p(θ|D)^{1/T}, T < 1 (Wenzel et al., 2020). Values of β > 1 have also been used to improve the disentanglement of representations learned by variational autoencoders (Higgins et al., 2017).
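As a concrete illustration of tempering (not part of the paper's derivation), the cold posterior has a closed form in the Gaussian case: raising N(m, v) to the power 1/T and renormalizing yields N(m, Tv), so T < 1 sharpens the posterior around its mode. A minimal numpy check, with all names our own:

```python
import numpy as np

def temper_gaussian(mean, var, T):
    """Cold posterior of a Gaussian: N(m, v)^(1/T), renormalized, equals
    N(m, T*v), so T < 1 sharpens the posterior around its mode."""
    return mean, T * var

# Numerical check on a grid: temper an unnormalized Gaussian log-density
# directly and compare against the closed form above.
x = np.linspace(-10.0, 10.0, 20001)
dx = x[1] - x[0]
m, v, T = 1.0, 2.0, 0.5

log_post = -(x - m) ** 2 / (2 * v)   # unnormalized log posterior
cold = np.exp(log_post / T)          # p(theta|D)^(1/T)
cold /= cold.sum() * dx              # renormalize on the grid

_, v_cold = temper_gaussian(m, v, T)
ref = np.exp(-(x - m) ** 2 / (2 * v_cold))
ref /= ref.sum() * dx

print(float(np.max(np.abs(cold - ref))))  # ~0: the two densities agree
```

The same sharpening intuition carries over to the KL-down-weighted ELBO above: shrinking the prior's influence lets the approximate posterior concentrate more tightly on the data.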
We down-weight the KL-divergence term in VCL, optimizing the β-ELBO,

β-ELBO = E_{θ∼q_T(θ)}[log p(D_T|θ)] − β D_KL(q_T(θ) ‖ q_{T−1}(θ)).

VCL is trivially recovered when β = 1. We will now show that, surprisingly, as β → 0 we recover a special case of Online EWC. Then, by modifying the objective further as required to recover the full version of Online EWC, we will arrive at our algorithm, Generalized VCL.
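To make the objective concrete, here is a hypothetical sketch (not the authors' implementation) of a Monte Carlo β-ELBO estimate for a mean-field Gaussian posterior, using the closed-form KL between diagonal Gaussians and the reparameterization trick; `log_lik_fn` stands in for log p(D_T|θ):

```python
import numpy as np

def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """Closed-form KL(q || p) between diagonal Gaussians."""
    return 0.5 * float(np.sum(
        np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
    ))

def beta_elbo(log_lik_fn, mu_q, var_q, mu_prev, var_prev, beta,
              n_samples=64, seed=0):
    """Monte Carlo estimate of E_q[log p(D_T|theta)] - beta * KL(q_T || q_{T-1}),
    with the expectation estimated via the reparameterization trick."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_samples, mu_q.size))
    theta = mu_q + np.sqrt(var_q) * eps   # theta = mu + sigma * eps
    expected_ll = float(np.mean([log_lik_fn(t) for t in theta]))
    return expected_ll - beta * kl_diag_gaussians(mu_q, var_q, mu_prev, var_prev)

# Example: a simple quadratic log-likelihood. beta = 1 recovers the VCL ELBO;
# beta -> 0 removes the KL regularization toward the previous posterior.
mu_prev, var_prev = np.zeros(2), np.ones(2)
mu_q, var_q = np.ones(2), np.ones(2)
log_lik = lambda th: -0.5 * float(np.sum(th ** 2))
print(beta_elbo(log_lik, mu_q, var_q, mu_prev, var_prev, beta=1.0))
```

In practice one would optimize {µ_T, Σ_T} by gradient ascent on this quantity; the sketch only evaluates it for fixed parameters.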

2.2. ONLINE EWC IS A SPECIAL CASE OF GVCL

We analyze the effect of KL-reweighting on VCL in the case where the approximating family is restricted to Gaussian distributions over θ. We will consider training all the tasks with a KL-reweighting factor of β, and then take the limit β → 0, recovering Online EWC. Let the approximate posteriors at the previous and current tasks be denoted as q_{T−1}(θ) = N(θ; µ_{T−1}, Σ_{T−1}) and q_T(θ) = N(θ; µ_T, Σ_T) respectively, where we are learning {µ_T, Σ_T}. The optimal Σ_T under the β-ELBO has the form (see Appendix C),

Σ_T^{−1} = (1/β) ∇_{µ_T} ∇_{µ_T} E_{q_T(θ)}[−log p(D_T|θ)] + Σ_{T−1}^{−1}.  (2)
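The recursion in equation (2) can be checked numerically. The sketch below (illustrative only; `precision_update` and the toy Hessians are our own) applies the update across two tasks whose expected negative log-likelihoods have known Hessians, and shows that β Σ_T^{−1} approaches the accumulated Hessian sum, i.e. the running Fisher sum that Online EWC maintains, as β → 0:

```python
import numpy as np

def precision_update(hessian, prev_precision, beta):
    """Optimal posterior precision under the beta-ELBO (equation 2):
    Sigma_T^{-1} = (1/beta) * H_T + Sigma_{T-1}^{-1}."""
    return hessian / beta + prev_precision

# Toy per-task Hessians of E_q[-log p(D_T | theta)] with respect to mu_T.
hessians = [np.diag([2.0, 0.5]), np.diag([1.0, 3.0])]
beta = 1e-6
prec = np.eye(2)                 # precision of the initial prior
for H in hessians:
    prec = precision_update(H, prec, beta)

# Rescaled by beta, the precision is the accumulated Hessian sum plus a
# vanishing prior term: beta * Sigma_T^{-1} = sum_t H_t + beta * I.
print(beta * prec)
```

Unrolling the recursion makes the Online EWC connection visible: the prior's contribution is weighted by β while each task's curvature enters at full strength, so the curvature terms dominate in the β → 0 limit.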



Code is available at https://github.com/yolky/gvcl.
We slightly abuse notation by writing the likelihood as p(D_T|θ) instead of p(y_T|θ, X_T).

