DEEP CLASS CONDITIONAL GAUSSIANS FOR CONTINUAL LEARNING

Abstract

The current state of the art for continual learning with frozen, pre-trained embedding networks consists of simple probabilistic models defined over the embedding space, for example class-conditional Gaussians. However, in the task-incremental online setting, it has so far remained an open question how to extend these methods to the case where the embedding function must be learned from scratch. In this paper, we propose an empirical Bayesian framework that stores a fixed number of examples in memory, which are used to calculate the posterior of the probabilistic model and a conditional marginal likelihood term used to fit the embedding function. Learning the embedding function in this way can be interpreted as a variant of experience replay, a highly performant method for continual learning. As part of our framework, we decide which examples to store by selecting the subset that minimises the KL divergence between the true posterior and the posterior induced by the subset, which we show is necessary to achieve good performance. We demonstrate the performance of our method on a range of task-incremental online settings, including those with overlapping tasks, which have so far been under-explored. Our method outperforms all other methods tested, including several other replay-based methods, evidencing the potential of our approach.

1. INTRODUCTION

Real-world use of deep learning methods often necessitates dynamic updating of solutions on non-stationary data streams (Farquhar & Gal, 2018; Antoniou et al., 2020). This is one of the main problems studied in continual learning and, as a result, continual learning has become of increasing interest to the machine learning community, with many proposed approaches (Parisi et al., 2019) and settings (Hsu et al., 2018; Antoniou et al., 2020; Delange et al., 2021). Currently, the two biggest challenges in continual learning are preventing catastrophic forgetting and achieving positive transfer. Catastrophic forgetting describes the common occurrence where unconstrained deep models easily forget information derived from previous data after updating on new data. Positive transfer is the ability of a model, given the current data, to improve its understanding of previous data and of what future data might imply. While significant steps have been taken towards solving these problems (Delange et al., 2021; Mai et al., 2021), in many settings there are still gains to be made (Farquhar & Gal, 2018). In common with many works in continual learning, the setting considered here is task-incremental online learning, where a data stream is split into a sequential set of tasks and methods are told which task is the current one (van de Ven & Tolias, 2019; Prabhu et al., 2020). Each task is encapsulated by a representative dataset (considered to be sampled i.i.d. from a task distribution), which is given to a method batch by batch. Different tasks will generally be associated with different distributions, as well as different target problems, which are summarised in a task objective function. In our case the task objective is classification (Hsu et al., 2018), where the classes being considered vary between tasks.
The overall objective of a method, after seeing all of the tasks, is to perform well on all of them, given constraints on the amount of memory the method uses. Currently, one of the best ways to approach continual learning is to use a frozen pretrained embedding function and define a simple probabilistic model on top of it to classify the data (Ostapenko et al., 2022; Hayes & Kanan, 2020). However, in some real-world settings it is necessary to learn the embedding function online: for example, pretraining may be impossible due to a lack of data, or a frozen pretrained encoder may become outdated under distribution shift and so perform badly (Ostapenko et al., 2022). Therefore, it is an interesting question how these methods can be adapted to settings where the embedding function must be continually learnt. We explore this question, reformulating it within an empirical Bayesian framework, and propose a general approach: learning the parameters of the probabilistic model in a Bayesian manner, while learning the embedding function using a conditional marginal log-likelihood loss. The approach also stores a small number of previous examples, which are used to calculate both the posterior of the probabilistic model and the conditional marginal log-likelihood loss. Learning the embedding function in this way can be seen as a variant of experience replay, which is beneficial because experience replay is one of the best-performing continual learning methods (van de Ven & Tolias, 2019; Wu et al., 2022; Mirzadeh et al., 2020). Another key part of the approach is a method for selecting which examples to store in memory, which works by minimising the KL divergence between the true posterior and the one induced by the subset of data stored in memory; we show this selection is necessary for the method to achieve good performance (see Section 6.3).
We also present a specific instantiation of our general approach, DeepCCG, which places a class-conditional Gaussian model with unknown means and fixed variance on top of the neural network embedding function. We chose to explore this particular instantiation because its posterior is easy to compute and because the class-conditional Gaussian model has been shown to achieve state-of-the-art performance when using a frozen pretrained embedding function (Ostapenko et al., 2022; Hayes & Kanan, 2020). Additionally, in this case, our method for selecting which samples to store in memory reduces to selecting examples that preserve the means of the per-class clusters formed by the embedded data. In our experiments we look at two specific settings: the commonly used disjoint-tasks setting, where each task contains classes different from every other task (Delange et al., 2021), and an under-explored shifting-window setting, where the classes in successive tasks overlap. We look at a setting with overlap between tasks to probe a method's ability to achieve positive transfer, which is rewarded more in this setting as there is greater shared information between tasks (Bang et al., 2021). The results of our experiments show that DeepCCG performed the best of all methods tested, gaining an average improvement of 2.145% in mean average accuracy, showing the potential of our approach.
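To make the mean-preserving selection concrete: for fixed-variance Gaussians, the KL divergence between the mean-posterior induced by the full data and the one induced by a candidate subset grows with the squared distance between the two posterior means, so memory selection amounts to choosing, per class, a subset whose embedded mean tracks the full class mean. A greedy version of this idea is sketched below; `select_memory` is an invented name and the greedy strategy is an illustrative sketch, not necessarily the paper's exact procedure.

```python
import numpy as np

def select_memory(z_class, k):
    """Greedily pick k embeddings of one class whose running mean stays
    as close as possible to the full class mean. Matching means is the
    target because, for fixed-variance Gaussians, the KL between the
    full-data and subset-induced posteriors grows with the squared
    distance between their means."""
    target = z_class.mean(axis=0)
    chosen, remaining = [], list(range(len(z_class)))
    for _ in range(k):
        best, best_err = None, np.inf
        for i in remaining:
            cand_mean = z_class[chosen + [i]].mean(axis=0)
            err = ((cand_mean - target) ** 2).sum()
            if err < best_err:
                best, best_err = i, err
        chosen.append(best)
        remaining.remove(best)
    return chosen

# Toy usage on invented 2-d "embeddings" of a single class.
rng = np.random.default_rng(0)
z = rng.normal(size=(50, 2))
idx = select_memory(z, k=10)
print(np.linalg.norm(z[idx].mean(axis=0) - z.mean(axis=0)))
```

The printed gap between the subset mean and the full class mean should be small relative to the spread of the data, which is exactly the property the stored memory needs for the induced posterior to stay close to the true one.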

2. RELATED WORK

When using frozen pretrained embedding functions in continual learning, simple metric-based probabilistic models have been shown to achieve state-of-the-art performance (Ostapenko et al., 2022; Hayes & Kanan, 2020; Pelosin, 2022). For example, Hayes & Kanan (2020) show that online linear discriminant analysis (LDA) on top of a frozen pretrained backbone is the best-performing method in their experiments, and Ostapenko et al. (2022) show that class-conditional Gaussian models perform best in certain settings. These probabilistic models have the advantage that the same parameter estimates are obtained for any ordering of the data, including i.i.d. orderings. However, in many cases, including the one we consider in this work, frozen pretrained embedding functions cannot be used: there may be a lack of relevant data for pretraining, or a pretrained embedding function may perform badly under distribution shift (Ostapenko et al., 2022). When the embedding function has to be learnt, there are three main paradigms for solving continual learning problems: regularisation, parameter isolation and replay (Delange et al., 2021). Regularisation methods aim to prevent catastrophic forgetting by adding terms to the loss which discourage the network from erasing the information of previous tasks (Kirkpatrick et al., 2017; Huszár, 2018). Parameter-isolation methods control which parameters of a neural network are used for which tasks (Mallya & Lazebnik, 2018), and can often be seen as a hard version of regularisation methods. Finally, replay methods store a subset of previously seen samples, which are then trained on alongside new incoming data; our method belongs to this family. Replay methods have been shown to achieve competitive, if not the best, performance across many settings (van de Ven & Tolias, 2019; Wu et al., 2022; Mirzadeh et al., 2020).
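The frozen-embedding baseline described above is simple to state concretely: a class-conditional Gaussian with a shared identity covariance over the embeddings reduces to nearest-class-mean classification. The following is a hedged sketch of that baseline, with the function names (`fit_class_means`, `predict`) and the toy data invented for illustration; the cited works use richer variants (e.g. streaming LDA with a shared learned covariance).

```python
import numpy as np

def fit_class_means(Z, y, num_classes):
    """Fit a class-conditional Gaussian with identity covariance:
    one mean per class, estimated from embedded training data Z."""
    return np.stack([Z[y == c].mean(axis=0) for c in range(num_classes)])

def predict(Z, means):
    """Classify each embedding by highest Gaussian log-likelihood,
    which for identity covariance is the nearest class mean."""
    d2 = ((Z[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)

# Toy check: two well-separated clusters standing in for embeddings.
rng = np.random.default_rng(0)
Z0 = rng.normal(loc=[-5.0, 0.0], scale=0.1, size=(20, 2))
Z1 = rng.normal(loc=[5.0, 0.0], scale=0.1, size=(20, 2))
Z = np.concatenate([Z0, Z1])
y = np.array([0] * 20 + [1] * 20)
means = fit_class_means(Z, y, num_classes=2)
preds = predict(Z, means)
print((preds == y).mean())
```

Because the fitted means depend only on per-class sums and counts, they can be updated in a streaming fashion and give the same estimates for any ordering of the data, which is the order-invariance property noted above.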
The standard replay method is experience replay (ER) (Chaudhry et al., 2019b; Aljundi et al., 2019a), which, in the setting explored in this work, where we cannot store all the current task's data,

