EFFICIENT PARAMETRIC APPROXIMATIONS OF NEURAL NETWORK FUNCTION SPACE DISTANCE

Abstract

It is often useful to compactly summarize important properties of a training dataset so that they can be used later without storing and/or iterating over the entire dataset. We consider a specific case of this: approximating the function space distance (FSD) over the training set, i.e. the average distance between the outputs of two neural networks. We propose an efficient approximation to FSD for ReLU neural networks based on approximating the architecture as a linear network with stochastic gating. Despite requiring only one parameter per unit of the network, our approach outcompetes other parametric approximations with larger memory requirements. Applied to continual learning, our parametric approximation is competitive with state-of-the-art nonparametric approximations which require storing many training examples. Furthermore, we show its efficacy in influence function estimation, allowing influence functions to be accurately estimated without iterating over the full dataset.

1. INTRODUCTION

There are many situations in which we would like to compactly summarize a model's training data. One motivation is to reduce storage costs: in continual learning, an agent continues interacting with its environment over a long time period (longer than it is able to store explicitly), but we would still like it to avoid overwriting its previously learned knowledge as it learns new tasks (Goodfellow et al., 2013). Even in cases where it is possible to store the entire training set, one might desire a compact representation in order to avoid expensive iterative procedures over the full data. Examples include influence function estimation (Koh & Liang, 2017; Bae et al., 2022a), model editing (De Cao et al., 2021; Mitchell et al., 2021), and unlearning (Bourtoule et al., 2021).

While there are many different aspects of the training data that one might like to summarize, we are often particularly interested in preventing the model from changing its predictions too much on the distribution of previously seen data. Methods to prevent such catastrophic forgetting, especially in the field of continual learning, can be categorized at a high level into parametric and nonparametric approaches. Parametric approaches store the parameters of a previously trained network, together with additional information about how important different directions in parameter space are for preserving past knowledge; the canonical example is Elastic Weight Consolidation (EWC; Kirkpatrick et al., 2017), which uses a diagonal approximation to the Fisher information matrix. Nonparametric approaches explicitly store a collection (coreset) of training examples, often optimized directly to be the most important or memorable ones (Rudner et al., 2022; Pan et al., 2020; Titsias et al., 2019). Currently, the most effective approaches to prevent catastrophic forgetting are nonparametric, since it is difficult to find sufficiently accurate parametric models.
However, this advantage comes at the expense of high storage requirements.

We focus on the problem of approximating function space distance (FSD): the amount by which the outputs of two networks differ, in expectation over the training distribution. Benjamin et al. (2018) observed that regularizing FSD over the previous task data is an effective way to prevent catastrophic forgetting. Other tasks such as influence estimation (Bae et al., 2022a), model editing (Mitchell et al., 2021), and second-order optimization (Amari, 1998; Bae et al., 2022b) have also been formulated in terms of FSD regularization or equivalent locality constraints. In this paper, we formulate the problem of approximating neural network FSD and propose novel parametric approximations. Our methods significantly outperform previous parametric approximations despite being much more memory-efficient, and are also competitive with nonparametric approaches to continual learning.

Several parametric approximations, like EWC, are based on a second-order Taylor approximation to the FSD, leading to a quadratic form involving the Fisher information matrix F_θ or some other metric matrix G_θ, where θ denotes the network parameters. Second-order approximations are convenient because one can approximate F_θ or G_θ by sampling vectors from a distribution with these matrices as the covariance; tractable probabilistic models can then be fit to these samples to approximate the corresponding distribution. Unfortunately, such approximations tend to be inaccurate for continual learning, especially in comparison with nonparametric approaches. We believe the culprit is the second-order Taylor approximation itself: we show in several examples that even the exact second-order Taylor approximation can be a poor match to the FSD over the scales relevant to continual learning, such as average performance over sequentially learned tasks.
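To make the second-order family of approximations concrete, the sketch below (illustrating EWC-style methods, not our contribution) computes the quadratic FSD surrogate ½ (θ₂ − θ₁)ᵀ F (θ₂ − θ₁) with a diagonal Fisher estimate. The helper names, toy dimensions, and Gaussian placeholder gradients are assumptions made for this demo.

```python
import numpy as np

def diag_fisher(grads):
    """Diagonal Fisher estimate: mean of squared per-example gradients."""
    return np.mean(grads ** 2, axis=0)

def quadratic_fsd(theta1, theta2, fisher_diag):
    """Second-order FSD surrogate with a diagonal metric (EWC-style)."""
    delta = theta2 - theta1
    return 0.5 * np.sum(fisher_diag * delta ** 2)

rng = np.random.default_rng(0)
grads = rng.normal(size=(128, 5))          # placeholder per-example gradients at theta1
theta1 = rng.normal(size=5)
theta2 = theta1 + 0.1 * rng.normal(size=5)
print(quadratic_fsd(theta1, theta2, diag_fisher(grads)))
```

Note that this surrogate is exact only up to second order in θ₂ − θ₁, which is precisely the limitation discussed above: over the parameter changes incurred across sequential tasks, the quadratic form can diverge substantially from the true FSD.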
This is consistent with a recent line of results finding linearized approximations of neural networks to be an inaccurate model of their behavior (Seleznova & Kutyniok, 2022a;b; Hanin & Nica, 2019; Bai et al., 2020; Huang & Yau, 2020).

Our contribution, the Bernoulli Gated Linear Network (BGLN), is a parametric approximation to neural network FSD that avoids a second-order Taylor approximation in parameter space, and is hence able to capture nonlinear interactions between the parameters of the network. Specifically, it linearizes each layer of the network with respect to its inputs. In the case of ReLU networks, our approximation yields a linear network with stochastic gating: linearizing the ReLU requires its derivative, which we approximate as an independent Bernoulli random variable for each unit. We derive both a stochastic and a deterministic estimate of the FSD in this setting, each of which relies only on the first two moments of the data.

To demonstrate the practical usefulness of our approximation, we evaluate its closeness to the true empirical FSD. We show that our method estimates and optimizes the true FSD better than other estimators in settings that are prone to forgetting. Further, we demonstrate its efficacy in two applications. Applied to continual learning, it outcompetes state-of-the-art methods on sequential MNIST and CIFAR-100 tasks, with at least 90% lower memory requirements than nonparametric methods. Applied to influence function estimation, our method achieves over 95% correlation with the ground truth, without iterating over or storing the full dataset.
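The stochastic-gating idea can be sketched in a few lines. The toy code below (a simplification: the paper's estimators also exploit the first two moments of the data rather than raw inputs) replaces each ReLU with a Bernoulli gate per unit, whose probability is the one stored parameter per unit, and forms a Monte Carlo FSD estimate by propagating both parameter sets through shared gate samples. All function names and shapes here are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def gated_forward(x, weights, gate_samples):
    """Forward pass through a linearized ReLU net with fixed gate samples."""
    h = x
    for W, g in zip(weights[:-1], gate_samples):
        h = g * (W @ h)          # Bernoulli gate stands in for the ReLU
    return weights[-1] @ h       # linear output layer

def stochastic_fsd(x_batch, w_old, w_new, gate_probs, n_samples=50):
    """Monte Carlo FSD estimate: the same sampled gates are applied to
    both parameter sets, so only the weights differ between the two passes."""
    total = 0.0
    for x in x_batch:
        for _ in range(n_samples):
            gates = [rng.binomial(1, p).astype(float) for p in gate_probs]
            d = gated_forward(x, w_new, gates) - gated_forward(x, w_old, gates)
            total += np.sum(d ** 2)
    return total / (len(x_batch) * n_samples)
```

Because the gates are sampled rather than recomputed from pre-activations, no training inputs need to be stored to evaluate this estimate, only the per-unit gate probabilities and (in the full method) the data moments.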

2. BACKGROUND

Let z = f(x, θ) denote the function computed by a neural network, which takes inputs x and parameters θ. Consistent with prior work, we use FSD to refer to the expected output space distance¹ ρ between the outputs of two neural networks (Benjamin et al., 2018; Grosse, 2021; Bae et al., 2022b), with respect to the training distribution, as defined in Equation 1:

D(θ₁, θ₂) = E_{x∼p(x)} [ρ(f(x, θ₁), f(x, θ₂))].    (1)

When the population distribution p(x) is unavailable, the expectation is taken over the empirical distribution of the training set.
¹ Note that we use the term distance throughout, since we focus on Euclidean distance in practice. However, other metrics like the KL divergence, which are not distances, are also possible and commonly used.
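For concreteness, the empirical version of this FSD with squared Euclidean distance can be computed directly as follows; the tiny linear model is an illustrative stand-in, not the networks used in the experiments.

```python
import numpy as np

def empirical_fsd(f, xs, theta1, theta2):
    """Average squared Euclidean distance between outputs over the data."""
    dists = [np.sum((f(x, theta1) - f(x, theta2)) ** 2) for x in xs]
    return np.mean(dists)

f = lambda x, theta: theta @ x             # toy linear model
xs = [np.ones(3), np.array([1.0, 2.0, 3.0])]
theta1 = np.eye(3)
theta2 = np.eye(3)
print(empirical_fsd(f, xs, theta1, theta2))  # 0.0 for identical parameters
```

This direct computation requires iterating over (and hence storing) the training inputs, which is exactly the cost the parametric approximations in this paper aim to avoid.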



Figure 1: Comparison of FSD regularization on a one-dimensional regression task. (Left) Training sequentially on two tasks (blue, then yellow) results in catastrophic forgetting. (Right) BGLN retains performance on Task 1 after training on Task 2 more accurately than other methods. Note that the y-axis represents the function space distance for each datapoint.

