INTERNEURONS ACCELERATE LEARNING DYNAMICS IN RECURRENT NEURAL NETWORKS FOR STATISTICAL ADAPTATION

Abstract

Early sensory systems in the brain rapidly adapt to fluctuating input statistics, which requires recurrent communication between neurons. Mechanistically, such recurrent communication is often indirect and mediated by local interneurons. In this work, we explore the computational benefits of mediating recurrent communication via interneurons compared with direct recurrent connections. To this end, we consider two mathematically tractable recurrent linear neural networks that statistically whiten their inputs: one with direct recurrent connections and the other with interneurons that mediate recurrent communication. By analyzing the corresponding continuous synaptic dynamics and numerically simulating the networks, we show that the network with interneurons is more robust to initialization than the network with direct recurrent connections, in the sense that the convergence time for the synaptic dynamics in the network with interneurons (resp. direct recurrent connections) scales logarithmically (resp. linearly) with the spectrum of the initialization. Our results suggest that interneurons are computationally useful for rapid adaptation to changing input statistics. Interestingly, the network with interneurons is an overparameterized solution of the whitening objective for the network with direct recurrent connections, so our results can be viewed as a recurrent linear neural network analogue of the implicit acceleration phenomenon observed in overparameterized feedforward linear neural networks.

1. INTRODUCTION

Efficient coding and redundancy reduction theories of neural coding hypothesize that early sensory systems decorrelate and normalize neural responses to sensory inputs (Barlow, 1961; Laughlin, 1989; Barlow & Földiák, 1989; Simoncelli & Olshausen, 2001; Carandini & Heeger, 2012; Westrick et al., 2016; Chapochnikov et al., 2021), operations closely related to statistical whitening of inputs. Since the input statistics are often in flux due to dynamic environments, this calls for early sensory systems that can rapidly adapt (Wark et al., 2007; Whitmire & Stanley, 2016). Decorrelating neural activities requires recurrent communication between neurons, which is typically indirect and mediated by local interneurons (Christensen et al., 1993; Shepherd et al., 2004). Why do neuronal circuits for statistical adaptation mediate recurrent communication using interneurons, which take up valuable space and metabolic resources, rather than using direct recurrent connections? A common explanation for why communication between neurons is mediated by interneurons is Dale's principle, which states that each neuron has exclusively inhibitory or excitatory effects on all of its targets (Strata & Harvey, 1999). While Dale's principle provides a physiological constraint that explains why recurrent interactions are mediated by interneurons, we seek a computational principle that can account for using interneurons rather than direct recurrent connections. This perspective is useful for two reasons. First, perhaps Dale's principle is not a hard constraint; see (Saunders et al., 2015; Granger et al., 2020) for results along these lines. In this case, a computational benefit of interneurons would provide a normative explanation for the existence of interneurons to mediate recurrent communication.
Second, decorrelation and whitening are useful in statistical and machine learning methods (Hyvärinen & Oja, 2000; Krizhevsky, 2009), especially recent self-supervised methods (Ermolov et al., 2021; Zbontar et al., 2021; Bardes et al., 2022; Hua et al., 2021), so our analysis is potentially relevant to the design of artificial neural networks, which are not bound by Dale's principle. In this work, to better understand the computational benefits of interneurons for statistical adaptation, we analyze the learning dynamics of two mathematically tractable recurrent neural networks that statistically whiten their inputs using Hebbian/anti-Hebbian learning rules: one with direct recurrent connections and the other with indirect recurrent interactions mediated by interneurons (Figure 1). We show that the learning dynamics of the network with interneurons are more robust than the learning dynamics of the network with direct recurrent connections. In particular, we prove that the convergence time of the continuum limit of the network with direct lateral connections scales linearly with the spectrum of the initialization, whereas the convergence time of the continuum limit of the network with interneurons scales logarithmically with the spectrum of the initialization. We also numerically test the networks and, consistent with our theoretical results, find that the network with interneurons is more robust to initialization. Our results suggest that interneurons are computationally important for rapid adaptation to fluctuating input statistics. Our analysis is closely related to analyses of learning dynamics in feedforward linear networks trained using backpropagation (Saxe et al., 2014; Arora et al., 2018; Saxe et al., 2019; Gidel et al., 2019; Tarmoun et al., 2021).
The optimization problems for deep linear networks are overparameterizations of linear problems, and this overparameterization can accelerate convergence of gradient descent or gradient flow optimization, a phenomenon referred to as implicit acceleration (Arora et al., 2018). Our results can be viewed as an analogous phenomenon for gradient flows corresponding to recurrent linear networks trained using Hebbian/anti-Hebbian learning rules. In our setting, the network with interneurons is naturally viewed as an overparameterized solution of the whitening objective for the network with direct recurrent connections. In analogy with the feedforward setting, the interneurons can be viewed as a hidden layer that overparameterizes the optimization problem. In summary, our main contribution is a theoretical and numerical analysis of the synaptic dynamics of two linear recurrent neural networks for statistical whitening: one with direct lateral connections and one with indirect lateral connections mediated by interneurons. Our analysis shows that the synaptic dynamics converge significantly faster in the network with interneurons than in the network with direct lateral connections (logarithmic versus linear convergence times). Our results have potential broader implications: (i) they suggest biological interneurons may facilitate rapid statistical adaptation, see also (Duong et al., 2023); (ii) including interneurons in recurrent neural networks for solving other learning tasks may also accelerate learning; (iii) overparameterized whitening objectives may be useful for developing online self-supervised learning algorithms in machine learning.

2. STATISTICAL WHITENING

Let n ≥ 2 and x_1, ..., x_T be a sequence of n-dimensional centered inputs with positive definite (empirical) covariance matrix C_xx := (1/T) XX⊤, where X := [x_1, ..., x_T] is the n × T data matrix of concatenated inputs. The goal of statistical whitening is to linearly transform the inputs so that the n-dimensional outputs y_1, ..., y_T have identity covariance; that is, C_yy := (1/T) YY⊤ = I_n, where Y := [y_1, ..., y_T] is the n × T data matrix of concatenated outputs. Statistical whitening is not a unique transformation. For example, if C_yy = I_n, then left multiplication of Y by any n × n orthogonal matrix results in another data matrix with identity covariance (Kessy et al., 2018). We focus on Zero-phase Components Analysis (ZCA) whitening (Bell & Sejnowski, 1997), also referred to as Mahalanobis whitening, given by Y = C_xx^{-1/2} X, which is the whitening transformation that minimizes the mean-squared error between the inputs X and the outputs Y (Eldar & Oppenheim, 2003). In other words, it is the unique solution to the objective

    arg min_{Y ∈ R^{n×T}} (1/T) ‖Y − X‖²_Frob  subject to  (1/T) YY⊤ = I_n.    (1)

The goal of this work is to analyze the learning dynamics of 2 recurrent neural networks that learn to perform ZCA whitening in the online, or streaming, setting using Hebbian/anti-Hebbian learning rules. The derivations of the 2 networks, which we include here for completeness, are closely related to derivations of PCA networks carried out in (Pehlevan & Chklovskii, 2015; Pehlevan et al., 2018). Our main theoretical and numerical results are presented in sections 6 and 7, respectively.
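As a concrete reference point, the ZCA transform in equation 1 can be computed directly from an eigendecomposition of C_xx. The following minimal numpy sketch (the dimensions and synthetic covariance are illustrative choices, not the paper's experimental settings) checks that the output covariance is the identity:

```python
import numpy as np

rng = np.random.default_rng(0)
n, T = 5, 10_000

# Synthetic centered inputs with an anisotropic covariance.
X = np.diag([3.0, 2.5, 2.0, 1.5, 1.0]) @ rng.standard_normal((n, T))
X = X - X.mean(axis=1, keepdims=True)
Cxx = X @ X.T / T

# ZCA (Mahalanobis) whitening: Y = Cxx^{-1/2} X via the eigendecomposition of Cxx.
eigvals, U = np.linalg.eigh(Cxx)
Y = (U @ np.diag(eigvals ** -0.5) @ U.T) @ X

Cyy = Y @ Y.T / T  # equals I_n up to floating point error
```

Because C_xx^{-1/2} is symmetric, the outputs stay maximally aligned with the inputs, which is the mean-squared-error optimality property cited above.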

3. OBJECTIVES FOR DERIVING ZCA WHITENING NETWORKS

In this section, we rewrite the ZCA whitening objective in equation 1 to obtain 2 objectives that will be the starting points for deriving our 2 recurrent whitening networks. We first expand the square in equation 1, substitute in the constraint YY⊤ = T I_n, drop the terms that do not depend on Y, and finally flip the sign to obtain the following objective in terms of the neural activities data matrix Y:

    max_{Y ∈ R^{n×T}} (2/T) Tr(YX⊤)  subject to  (1/T) YY⊤ = I_n.

3.1. OBJECTIVE FOR THE NETWORK WITH DIRECT RECURRENT CONNECTIONS

To derive a network with direct recurrent connections, we introduce the positive definite matrix M as a Lagrange multiplier to enforce the upper bound YY⊤ ⪯ T I_n:

    min_{M ∈ S^n_++} max_{Y ∈ R^{n×T}} f(M, Y),    (2)

where S^n_++ denotes the set of positive definite n × n matrices and

    f(M, Y) := (2/T) Tr(YX⊤) − (1/T) Tr(M(YY⊤ − T I_n)).

Here, we have interchanged the order of optimization, which is justified because f(M, Y) is linear in M and strongly concave in Y, so it satisfies the saddle point property (see Appendix A). The maximization over Y ensures that the upper bound YY⊤ ⪯ T I_n is in fact saturated, so the whitening constraint holds. The minimax objective in equation 2 will be the starting point for our derivation of a statistical whitening network with direct recurrent connections in section 4.

3.2. OVERPARAMETERIZED OBJECTIVE FOR THE NETWORK WITH INTERNEURONS

To derive a network with k ≥ n interneurons, we replace the matrix M in equation 2 with the overparameterized product WW⊤, where W is an n × k matrix, which yields the minimax objective

    min_{W ∈ R^{n×k}} max_{Y ∈ R^{n×T}} g(W, Y),    (3)

where g(W, Y) := f(WW⊤, Y) = (2/T) Tr(YX⊤) − (1/T) Tr(WW⊤(YY⊤ − T I_n)). The overparameterized minimax objective in equation 3 will be the starting point for our derivation of a statistical whitening network with interneurons in section 5.

4. DERIVATION OF A NETWORK WITH DIRECT RECURRENT CONNECTIONS

Starting from the ZCA whitening objective in equation 2, we derive offline and online algorithms and map the online algorithm onto a network with direct recurrent connections. We assume that the neural dynamics operate at a faster timescale than the synaptic updates, so we first optimize over the neural activities and then take gradient-descent steps with respect to the synaptic weight matrix. In the offline setting, at each iteration, we first maximize f(M, Y) with respect to the neural activities Y by taking gradient ascent steps until convergence:

    Y ← Y + γ(X − MY)  ⇒  Y = M^{-1}X,

where γ > 0 is a small constant. After the neural activities converge, we minimize f(M, Y) with respect to M by taking a gradient descent step:

    M ← M + η((1/T) YY⊤ − I_n).    (4)

In the online setting, at each time step t, we only have access to the input x_t rather than the whole dataset X. In this case, we first optimize f(M, Y) with respect to the neural activities vector y_t by taking gradient steps until convergence: y_t ← y_t + γ(x_t − M y_t) ⇒ y_t = M^{-1} x_t. After the neural activities converge, the synaptic weight matrix M is updated by taking a stochastic gradient descent step: M ← M + η(y_t y_t⊤ − I_n). This results in Algorithm 1.

Algorithm 1: A whitening network with direct recurrent connections

    input: centered inputs {x_t}; parameters γ, η
    initialize: positive definite n × n matrix M
    for t = 1, ..., T do
        y_t ← 0
        repeat
            y_t ← y_t + γ(x_t − M y_t)
        until convergence
        M ← M + η(y_t y_t⊤ − I_n)
    end for

Algorithm 1 can be implemented in a network with n principal neurons with direct recurrent connections −M, Figure 1 (left). At each time step t, the external input to the principal neurons is x_t, the output of the principal neurons is y_t and the recurrent input to the principal neurons is −M y_t. Therefore, the total input to the principal neurons is encoded in the n-dimensional vector x_t − M y_t.
The neural outputs are updated according to the neural dynamics in Algorithm 1 until they converge to y_t = M^{-1} x_t. After the neural activities converge, the synaptic weight matrix M is updated according to Algorithm 1.
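The online loop of Algorithm 1 is straightforward to sketch in numpy. For brevity, the repeat-until-convergence inner loop is replaced by its fixed point y_t = M^{-1} x_t; the input statistics, initialization and step size below are illustrative choices, not the paper's experimental settings:

```python
import numpy as np

rng = np.random.default_rng(1)
n, T, eta = 3, 20_000, 1e-3

# Centered inputs with covariance approximately diag(9, 4, 1).
X = np.diag([3.0, 2.0, 1.0]) @ rng.standard_normal((n, T))
M = np.eye(n)  # positive definite initialization

for t in range(T):
    x = X[:, t]
    # Fast neural dynamics settle at y_t = M^{-1} x_t (closed form replaces the repeat loop).
    y = np.linalg.solve(M, x)
    # Anti-Hebbian update of the direct recurrent weights.
    M += eta * (np.outer(y, y) - np.eye(n))

# At the fixed point E[y yᵀ] = I, so M ≈ Cxx^{1/2} and the outputs are whitened on average.
Cxx = X @ X.T / T
Minv = np.linalg.inv(M)
whitening_error = np.linalg.norm(Minv @ Cxx @ Minv - np.eye(n))
```

The update keeps M symmetric, and for small η it remains positive definite throughout the run.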

5. DERIVATION OF A NETWORK WITH INTERNEURONS

Starting from the overparameterized ZCA whitening objective in equation 3, we derive offline and online algorithms and map the online algorithm onto a network with k ≥ n interneurons. As in the last section, we assume that the neural dynamics operate at a faster timescale than the synaptic updates. In the offline setting, at each iteration, we first maximize g(W, Y) with respect to the neural activities Y by taking gradient steps until convergence: Y ← Y + γ(X − WW⊤Y) ⇒ Y = (WW⊤)^{-1}X. After the neural activities converge, we minimize g(W, Y) with respect to W by taking a gradient descent step:

    W ← W + η((1/T) YY⊤W − W).    (5)

In the online setting, at each time step t, we first optimize g(W, Y) with respect to the neural activities vector y_t by running the neural dynamics until convergence: y_t ← y_t + γ(x_t − WW⊤ y_t) ⇒ y_t = (WW⊤)^{-1} x_t. After convergence of the neural activities, the matrix W is updated by taking a stochastic gradient descent step: W ← W + η(y_t y_t⊤ W − W). To implement the online algorithm in a recurrent neural network with interneurons, we let z_t := W⊤ y_t denote the k-dimensional vector of interneuron activities at time t. After substituting into the neural dynamics and the synaptic update rules, we obtain Algorithm 2.

Algorithm 2: A whitening network with interneurons

    input: centered inputs {x_t}; k ≥ n; parameters γ, η
    initialize: full rank n × k matrix W
    for t = 1, ..., T do
        y_t ← 0
        repeat
            z_t ← W⊤ y_t
            y_t ← y_t + γ(x_t − W z_t)
        until convergence
        W ← W + η(y_t z_t⊤ − W)
    end for

Algorithm 2 can be implemented in a network with n principal neurons and k interneurons, Figure 1 (right). The principal neurons (resp. interneurons) are connected to the interneurons (resp. principal neurons) via the synaptic weight matrix W⊤ (resp. −W).
At each time step t, the external input to the principal neurons is x_t, the activity of the principal neurons is y_t, the activity of the interneurons is z_t = W⊤ y_t and the recurrent input to the principal neurons is −W z_t. The neural activities are updated according to the neural dynamics in Algorithm 2 until they converge to y_t = (WW⊤)^{-1} x_t and z_t = W⊤ y_t. After the neural activities converge, the synaptic weight matrix W is updated according to Algorithm 2. Here, the principal neuron-to-interneuron weight matrix W⊤ is the negative transpose of the interneuron-to-principal neuron weight matrix −W, Figure 1 (right). In general, enforcing such symmetry is not biologically plausible; this is commonly referred to as the weight transport problem. In addition, we do not sign-constrain the weights, so the network can violate Dale's principle. In Appendix B, we modify the algorithm to be more biologically realistic, map the modified algorithm onto the vertebrate olfactory bulb, and show that the algorithm is consistent with several experimental observations.
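Algorithm 2 admits an equally short numpy sketch. As before, the inner neural loop is replaced by its fixed point y_t = (WW⊤)^{-1} x_t, and all concrete values (dimensions, initialization, step size) are illustrative:

```python
import numpy as np

rng = np.random.default_rng(2)
n, k, T, eta = 3, 6, 20_000, 1e-3

X = np.diag([3.0, 2.0, 1.0]) @ rng.standard_normal((n, T))
W = np.hstack([np.eye(n), np.eye(n)])  # full-rank n x k initialization, W Wᵀ = 2I

for t in range(T):
    x = X[:, t]
    # Fast neural dynamics settle at y_t = (W Wᵀ)^{-1} x_t; interneurons carry z_t = Wᵀ y_t.
    y = np.linalg.solve(W @ W.T, x)
    z = W.T @ y
    # Hebbian update of the principal-to-interneuron weights (its negative transpose
    # is the interneuron-to-principal pathway).
    W += eta * (np.outer(y, z) - W)

# At the fixed point W Wᵀ ≈ Cxx^{1/2}.
Cxx = X @ X.T / T
P = np.linalg.inv(W @ W.T)
whitening_error = np.linalg.norm(P @ Cxx @ P - np.eye(n))
```

Note that only the product WW⊤ is determined at the fixed point; the factor W itself is free up to a k × k orthogonal transformation, which is the overparameterization exploited in section 6.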

6. ANALYSES OF CONTINUOUS SYNAPTIC DYNAMICS

We now present our main theoretical results on the convergence of the continuous synaptic dynamics corresponding to the updates of M and W. We first show that the synaptic updates are naturally viewed as (stochastic) gradient descent algorithms. We then analyze the corresponding continuous gradient flows. Detailed proofs of our results are provided in Appendix C.

6.1. GRADIENT DESCENT ALGORITHMS

We first show that the offline and online synaptic dynamics are naturally viewed as gradient descent and stochastic gradient descent algorithms for minimizing whitening objectives. Let U(M) be the convex function defined by

    U(M) := max_{Y ∈ R^{n×T}} f(M, Y) = Tr(M^{-1} C_xx + M).

Substituting the optimal neural activities Y = M^{-1}X into the offline update in equation 4, we see that the offline algorithm is a gradient descent algorithm for minimizing U(M):

    M ← M + η(M^{-1} C_xx M^{-1} − I_n) = M − η∇U(M).

Similarly, Algorithm 1 is a stochastic gradient descent algorithm for minimizing U(M). Next, let V(W) be the nonconvex function defined by

    V(W) := max_{Y ∈ R^{n×T}} g(W, Y) = Tr((WW⊤)^{-1} C_xx + WW⊤).

Again, substituting the optimal neural activities Y = (WW⊤)^{-1}X into the offline update in equation 5, we see the offline algorithm is a gradient descent algorithm for minimizing V(W):

    W ← W + η((WW⊤)^{-1} C_xx (WW⊤)^{-1} W − W) = W − (η/2)∇V(W).

Similarly, Algorithm 2 is a stochastic gradient descent algorithm for minimizing V(W). A common approach for studying stochastic gradient descent algorithms is to analyze the corresponding continuous gradient flows (Saxe et al., 2014; 2019; Tarmoun et al., 2021), which are more mathematically tractable and are useful approximations of the average behavior of the stochastic gradient descent dynamics when the step size is small. In the remainder of this section, we analyze and compare the continuous gradient flows associated with Algorithms 1 and 2. To further facilitate the analysis, we consider so-called 'spectral initializations' that commute with C_xx (Saxe et al., 2014; 2019; Gidel et al., 2019; Tarmoun et al., 2021). Specifically, we say that A_0 ∈ S^n_++ is a spectral initialization if A_0 = U_x diag(σ_1, ..., σ_n) U_x⊤, where U_x is the n × n orthogonal matrix of eigenvectors of C_xx and σ_1, ..., σ_n > 0.
To characterize the convergence rates of the gradient flows, we define the Lyapunov function

    ℓ(A) := ‖C_xx − A²‖_Frob,  A ∈ S^n_++.

6.2. GRADIENT FLOW ANALYSIS OF ALGORITHM 1

The gradient flow of U(M) is given by

    dM(t)/dt = −∇U(M(t)) = M(t)^{-1} C_xx M(t)^{-1} − I_n.    (6)

To analyze the solutions M(t), we focus on spectral initializations. Throughout, we write C_xx = U_x diag(λ_1², ..., λ_n²) U_x⊤ with λ_1, ..., λ_n > 0; that is, λ_i denotes the i-th eigenvalue of C_xx^{1/2}.

Lemma 1. Suppose M_0 is a spectral initialization. Then the solution M(t) of the ODE 6 is of the form M(t) = U_x diag(σ_1(t), ..., σ_n(t)) U_x⊤, where σ_1(t), ..., σ_n(t) are the solutions of the ODE

    dσ_i(t)/dt = λ_i²/σ_i(t)² − 1,  i = 1, ..., n.    (7)

Consequently,

    d/dt (σ_i(t)² − λ_i²)² = −(4/σ_i(t)) (σ_i(t)² − λ_i²)²,  i = 1, ..., n.    (8)

From Lemma 1, we see that for a spectral initialization with σ_i(0) ≤ λ_i, the dynamics of σ_i(t) satisfy d/dt (σ_i(t)² − λ_i²)² ≤ −(4/λ_i)(σ_i(t)² − λ_i²)². It follows that σ_i(t)² converges to λ_i² exponentially with convergence rate at least 2/λ_i. On the other hand, suppose σ_i(0) ≫ λ_i. From equation 7, we see that while σ_i(t) ≫ λ_i, σ_i(t) decays at approximately unit rate; that is,

    dσ_i(t)/dt ≈ −1.    (9)

Therefore, the time for σ_i(t) to converge to λ_i grows linearly with σ_i(0). We make these statements precise in the following proposition.

Proposition 1. Suppose M_0 is a spectral initialization with eigenvalues σ_1, ..., σ_n and let M(t) denote the solution of the ODE 6 starting from M_0. If σ_i ≤ λ_i for all i = 1, ..., n, then for ϵ < ℓ(M_0),

    min{t ≥ 0 : ℓ(M(t)) ≤ ϵ} ≤ (λ_max/2) log(ℓ(M_0) ϵ^{-1}),

where λ_max := max_i λ_i. On the other hand, if σ_i > λ_i for some i = 1, ..., n, then for ϵ < σ_i² − λ_i²,

    min{t ≥ 0 : ℓ(M(t)) ≤ ϵ} ≥ σ_i − √(λ_i² + ϵ).
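The linear scaling of the convergence time with σ_i(0) is easy to see numerically by Euler-integrating the one-dimensional ODE from Lemma 1 (here with λ_i = 1; the tolerance and integration step are illustrative):

```python
import numpy as np

def time_to_converge(sigma0, lam=1.0, tol=1e-3, dt=1e-3, t_max=500.0):
    """Euler-integrate dsigma/dt = lam^2/sigma^2 - 1 and return the first
    time at which |sigma^2 - lam^2| <= tol."""
    sigma, t = float(sigma0), 0.0
    while abs(sigma**2 - lam**2) > tol and t < t_max:
        sigma += dt * (lam**2 / sigma**2 - 1.0)
        t += dt
    return t

# While sigma >> lam, sigma decays at roughly unit rate, so increasing the
# initialization by 10 adds roughly 10 to the convergence time.
t10 = time_to_converge(10.0)
t20 = time_to_converge(20.0)
```

The difference t20 − t10 is close to 10, the gap in the initializations, matching the lower bound in Proposition 1.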

6.3. GRADIENT FLOW ANALYSIS OF ALGORITHM 2

The gradient flow of the overparameterized cost V(W) is given by

    dW(t)/dt = −(1/2)∇V(W(t)) = (W(t)W(t)⊤)^{-1} C_xx (W(t)W(t)⊤)^{-1} W(t) − W(t).    (12)

Next, we show that ℓ(W(t)W(t)⊤) converges to zero exponentially for any full-rank initialization W_0.

Proposition 2. Let W(t) denote the solution of the ODE 12 starting from any full-rank W_0 ∈ R^{n×k}. Let ϵ < ℓ(W_0 W_0⊤). For spectral initializations W_0 W_0⊤,

    min{t ≥ 0 : ℓ(W(t)W(t)⊤) ≤ ϵ} = (1/4) log(ℓ(W_0 W_0⊤) ϵ^{-1}).

For general initializations W_0 W_0⊤,

    min{t ≥ 0 : ℓ(W(t)W(t)⊤) ≤ ϵ} ≤ (1/2) log(ℓ(W_0 W_0⊤) ϵ^{-1}).
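The contrast between Propositions 1 and 2 can be reproduced by Euler-integrating both matrix flows from matched initializations M_0 = W_0 W_0⊤ = α I. This is a small illustrative sketch; the covariance, tolerance and integration step are arbitrary choices:

```python
import numpy as np

n, k = 2, 4
Cxx = np.diag([4.0, 1.0])  # eigenvalues lam1^2 = 4, lam2^2 = 1
In = np.eye(n)

def direct_time(alpha, eps=1e-2, dt=1e-3, t_max=500.0):
    """First time l(M) = ||Cxx - M^2|| <= eps under dM/dt = M^-1 Cxx M^-1 - I, M(0) = alpha*I."""
    M, t = alpha * np.eye(n), 0.0
    while np.linalg.norm(Cxx - M @ M) > eps and t < t_max:
        Mi = np.linalg.inv(M)
        M, t = M + dt * (Mi @ Cxx @ Mi - In), t + dt
    return t

def interneuron_time(scale, eps=1e-2, dt=1e-3, t_max=500.0):
    """First time l(W W^T) <= eps under the overparameterized flow, W(0)W(0)^T = scale^2*I."""
    W, t = scale * np.hstack([np.eye(n), np.zeros((n, k - n))]), 0.0
    while np.linalg.norm(Cxx - (W @ W.T) @ (W @ W.T)) > eps and t < t_max:
        Pi = np.linalg.inv(W @ W.T)
        W, t = W + dt * (Pi @ Cxx @ Pi @ W - W), t + dt
    return t

# Matched initializations M0 = W0 W0^T = 9*I and 25*I.
td9, td25 = direct_time(9.0), direct_time(25.0)
ti3, ti5 = interneuron_time(3.0), interneuron_time(5.0)
# Direct connections: convergence time grows linearly with the initial spectrum;
# interneurons: only logarithmically.
```

Increasing the initial spectrum from 9 to 25 adds roughly 16 time units for the direct flow, while the interneuron flow's convergence time barely changes.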

7. NUMERICAL EXPERIMENTS

In this section, we numerically test the offline and online algorithms on synthetic datasets. Let

    Whitening error(t) := ‖A_t^{-1} C_xx A_t^{-1} − I_n‖_Frob,  A_t ∈ {M_t, W_t W_t⊤},

where A_t is the value of the matrix after the t-th iterate. To quantify the convergence time, define

    Convergence time := min{t ≥ 1 : Whitening error(t) < 0.1}.

In plots with multiple runs, lines and shaded regions respectively denote the means and 95% confidence intervals over 10 runs.

7.1. OFFLINE ALGORITHMS

Let n = 5, k = 10 and T = 10^5. We generate a data matrix X with i.i.d. entries x_{t,i} chosen uniformly from the interval (0, 12). The eigenvalues of C_xx are {24.01, 16.42, 10.45, 6.59, 3.28}. We initialize W_0 = Q √α Σ P⊤, where α > 0, Σ = diag(5, 4, 3, 2, 1), Q is an n × n orthogonal matrix and P is a random k × n matrix with orthonormal column vectors. We set M_0 = W_0 W_0⊤ = αQΣ²Q⊤. We consider spectral initializations (i.e., Q = U_x) and nonspectral initializations (i.e., Q is a random orthogonal matrix). We use step size η = 10^{-3}. The results of running the offline algorithms for α = 0.1, 1, 10 are shown in Figure 2. Consistent with Propositions 1 and 2, the network with direct recurrent connections converges slowly for 'large' α, whereas the network with interneurons converges exponentially for all α. Further, as predicted by equation 9, when the eigenvalues of M are 'large', i.e., σ_i ≫ λ_i, they decay linearly. In Figure 3, we plot the convergence times of the offline algorithms with spectral and nonspectral initializations, for α = 1, 2, ..., 20. Consistent with our analysis of the gradient flows, the convergence time for the network with direct lateral connections (resp. interneurons) grows linearly (resp. logarithmically) with α.

7.2. ONLINE ALGORITHMS

We consider a pair of 2 × 2 covariance matrices C_A := U_A Σ_A U_A⊤ and C_B := U_B Σ_B U_B⊤ with diagonal Σ_A, Σ_B, where U_A, U_B are random 2 × 2 rotation matrices. We generate a 2 × 4T dataset X = [x_1, ..., x_{4T}] with independent samples x_1, ..., x_T, x_{2T+1}, ..., x_{3T} ∼ N(0, C_A) and x_{T+1}, ..., x_{2T}, x_{3T+1}, ..., x_{4T} ∼ N(0, C_B). We evaluated our online algorithms with step size η = 10^{-4} and initializations M_0 = W_0 W_0⊤, where W_0 = Q diag(σ_1, σ_2) P⊤, Q is a random 2 × 2 rotation matrix, P is a random 4 × 2 matrix with orthonormal column vectors and σ_1, σ_2 are independent random variables chosen uniformly from the interval (1, 1.5). The results are shown in Figure 4.
Consistent with our theoretical analyses, we see that the network with interneurons adapts to changing distributions faster than the network with direct recurrent connections.
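The switching-statistics setting can be mimicked with the interneuron network (Algorithm 2) in a few lines. The two covariances below are hypothetical stand-ins for C_A and C_B, and the step size, horizon and seed are likewise illustrative:

```python
import numpy as np

rng = np.random.default_rng(4)
n, k, T, eta = 2, 4, 30_000, 1e-3

# Hypothetical stand-ins for the two input covariances C_A and C_B.
CA = np.array([[3.0, 1.0], [1.0, 2.0]])
CB = np.array([[1.0, -0.5], [-0.5, 4.0]])
LA, LB = np.linalg.cholesky(CA), np.linalg.cholesky(CB)

W = np.hstack([np.eye(n), np.eye(n)])  # full-rank initialization
errs = []
for t in range(2 * T):
    L, C = (LA, CA) if t < T else (LB, CB)  # input statistics switch at t = T
    x = L @ rng.standard_normal(n)
    y = np.linalg.solve(W @ W.T, x)
    W += eta * (np.outer(y, W.T @ y) - W)
    P = np.linalg.inv(W @ W.T)
    errs.append(np.linalg.norm(P @ C @ P - np.eye(n)))

# The whitening error spikes when the statistics switch, then decays again
# as the interneuron network re-adapts.
```

Tracking the same error for Algorithm 1 would show the slower, spectrum-dependent re-adaptation predicted by Proposition 1.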

7.3. APPLICATION TO PRINCIPAL SUBSPACE LEARNING

A challenge in unsupervised and self-supervised learning is preventing collapse (i.e., degenerate solutions). A recent approach in self-supervised learning is to decorrelate or whiten the feature representation (Ermolov et al., 2021; Zbontar et al., 2021; Bardes et al., 2022; Hua et al., 2021). Here, using Oja's online principal component algorithm (Oja, 1982) as a tractable example, we demonstrate that the speed of whitening can affect the accuracy of the learned representation. Consider a neuron whose input and output at time t are respectively s_t ∈ R^d and y_t := v⊤ s_t, where v ∈ R^d represents the synaptic weights connecting the inputs to the neuron. Oja's algorithm learns the top principal component of the inputs by updating the vector v as follows:

    v ← v + ζ(y_t s_t − y_t² v),    (15)

where ζ > 0 is the step size. Next, consider a population of 2 ≤ n ≤ d neurons with outputs y_t ∈ R^n and feedforward synaptic weight vectors v_1, ..., v_n ∈ R^d connecting the inputs s_t to the n neurons. The goal is to project the inputs onto their n-dimensional principal subspace. Running n instances of Oja's algorithm in parallel without lateral connections results in collapse: each synaptic weight vector v_i converges to the top principal component. One way to avoid collapse is to whiten the output y_t using recurrent connections, Figure 5 (left). Here we show that when the subspace projection and output whitening are learned simultaneously, it is critical that the whitening transformation is learned sufficiently fast. We set d = 3, n = 2 and generate i.i.d. inputs s_t ∼ N(0, diag(5, 2, 1)). We initialize two random vectors v_1, v_2 ∈ R³ with independent N(0, 1) entries. At each time step t, we use the projection x_t := (v_1⊤ s_t, v_2⊤ s_t) as the input to either Algorithm 1 or 2 and we let y_t be the output; that is, y_t = M^{-1} x_t or y_t = (WW⊤)^{-1} x_t. For i = 1, 2, we update v_i according to equation 15 with ζ = 10^{-3} and with v_i (resp. y_{t,i}) in place of v (resp. y_t). We update the recurrent weights M or W according to Algorithm 1 or 2. To measure the performance, we define

    Subspace error := ‖V(V⊤V)^{-1}V⊤ − diag(1, 1, 0)‖²_Frob,  V := [v_1, v_2] ∈ R^{3×2},

which is equal to zero when v_1, v_2 span the 2-dimensional principal subspace of the inputs. We plot the subspace error for Algorithms 1 and 2 using learning rates η = ζ and η = 10ζ, Figure 5 (right). The results suggest that in order to learn the correct subspace, the whitening transformation must be learned sufficiently fast relative to the feedforward vectors v_i; for additional results along these lines, see (Pehlevan et al., 2018, section 6). Therefore, since interneurons accelerate learning of the whitening transform, they are useful for accurate or optimal representation learning.
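The experiment above can be sketched by coupling Oja updates for V = [v_1, v_2] with the interneuron whitening network (Algorithm 2) at η = 10ζ. The seed, initialization and horizon below are illustrative, and the fast inner neural loop is again replaced by its fixed point:

```python
import numpy as np

rng = np.random.default_rng(5)
d, n, k = 3, 2, 4
zeta, eta = 1e-3, 1e-2      # whitening learned 10x faster than the projection
T = 50_000

V = rng.standard_normal((d, n))        # feedforward weights [v1, v2]
W = np.hstack([np.eye(n), np.eye(n)])  # interneuron weights, W Wᵀ = 2I
S = np.diag([5.0, 2.0, 1.0]) ** 0.5    # inputs s_t ~ N(0, diag(5, 2, 1))

for t in range(T):
    s = S @ rng.standard_normal(d)
    x = V.T @ s                        # projected input x_t
    y = np.linalg.solve(W @ W.T, x)    # whitened output y_t
    # Oja's rule per neuron, with the whitened y_i in place of v_i's own output.
    V += zeta * (np.outer(s, y) - V * y**2)
    W += eta * (np.outer(y, W.T @ y) - W)

# The projection onto span(v1, v2) should approach the top-2 principal subspace.
P = V @ np.linalg.inv(V.T @ V) @ V.T
subspace_error = np.linalg.norm(P - np.diag([1.0, 1.0, 0.0])) ** 2
```

With the slower whitening rate η = ζ, the same loop tends to leave a larger residual subspace error, which is the effect plotted in Figure 5 (right).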

8. DISCUSSION

We analyzed the gradient flow dynamics of 2 recurrent neural networks for ZCA whitening: one with direct recurrent connections and one with interneurons. For spectral initializations, we can analytically estimate the convergence time for both gradient flows. For nonspectral initializations, we show numerically that the convergence times are close to those for spectral initializations with the same initial spectrum. Our results show that the recurrent neural network with interneurons is more robust to initialization. An interesting question is whether including interneurons in other classes of recurrent neural networks also accelerates the learning dynamics of those networks. We hypothesize that interneurons accelerate learning dynamics when the objective for the network with interneurons can be viewed as an overparameterization of the objective for the recurrent neural network with direct connections.

B.1 DECOUPLING THE SYNAPTIC WEIGHT MATRICES

We begin by replacing W and W⊤ in Algorithm 2 with separate weight matrices W_zy and W_yz, respectively. Then the neural activities are updated according to

    y_t ← y_t + γ(x_t − W_zy z_t),  z_t ← z_t + γ(W_yz y_t − z_t),    (16)

which converge to y_t = (W_zy W_yz)^{-1} x_t and z_t = W_yz y_t. After the neural activities converge, the synaptic weights are updated according to the update rules

    W_zy ← W_zy + η(y_t z_t⊤ − W_zy),  W_yz ← W_yz + η(z_t y_t⊤ − W_yz).

Let W_zy,t and W_yz,t denote the values of the synaptic weight matrices after t updates. By iterating the above updates, we see that the difference between the weight matrices after t updates is given by

    W_yz,t − W_zy,t⊤ = (1 − η)^t (W_yz,0 − W_zy,0⊤).

Therefore, provided 0 < η < 1, the difference decays exponentially.
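The exponential decay of the asymmetry is easy to verify numerically: whatever Hebbian term y_t z_t⊤ the two updates share, the antisymmetric part contracts by exactly (1 − η) per step. A small numpy check (dimensions and η are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(6)
n, k, eta = 3, 5, 0.1

Wzy = rng.standard_normal((n, k))
Wyz = rng.standard_normal((k, n))
d0 = np.linalg.norm(Wyz - Wzy.T)

for t in range(50):
    # Arbitrary activity vectors: the Hebbian terms outer(y, z) and outer(z, y)
    # are transposes of each other, so they cancel in the asymmetry.
    y, z = rng.standard_normal(n), rng.standard_normal(k)
    Wzy = Wzy + eta * (np.outer(y, z) - Wzy)
    Wyz = Wyz + eta * (np.outer(z, y) - Wyz)

d50 = np.linalg.norm(Wyz - Wzy.T)  # equals (1 - eta)^50 * d0 up to float error
```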

B.2 SIGN-CONSTRAINING THE INTERNEURON WEIGHTS

In equation 16, the matrix W_yz is preceded by a positive sign and the matrix W_zy is preceded by a negative sign, which is consistent with the fact that the principal neuron-to-interneuron synapses are excitatory whereas the interneuron-to-principal neuron synapses are inhibitory. However, this interpretation is superficial because the matrices are not constrained to be non-negative, so the network can violate Dale's principle. One approach to enforcing Dale's principle is to take projected gradient descent steps when updating the synaptic weight matrices, which results in the offline updates

    W_zy ← [W_zy + η((1/T) YZ⊤ − W_zy)]_+,  W_yz ← [W_yz + η((1/T) ZY⊤ − W_yz)]_+,

and online updates

    W_zy ← [W_zy + η(y_t z_t⊤ − W_zy)]_+,  W_yz ← [W_yz + η(z_t y_t⊤ − W_yz)]_+,

where [·]_+ denotes the element-wise rectification operation. This results in Algorithm 3.

Algorithm 3: A whitening network with interneurons and sign-constrained weights

    input: centered inputs {x_t}; parameters γ, η
    initialize: non-negative matrices W_yz, W_zy
    for t = 1, ..., T do
        y_t ← 0; z_t ← 0
        repeat
            y_t ← y_t + γ(x_t − W_zy z_t)
            z_t ← z_t + γ(W_yz y_t − z_t)
        until convergence
        W_yz ← [W_yz + η(z_t y_t⊤ − W_yz)]_+
        W_zy ← [W_zy + η(y_t z_t⊤ − W_zy)]_+
    end for

In general, this modification results in a network output that does not correspond to ZCA whitening of the inputs. That being said, it is still worth comparing the performance of the network to the network with interneurons derived in section 5. In Figure 6, we compare the performance of the offline versions of Algorithms 2 and 3. We see that Algorithm 3 equilibrates at a rate that is comparable to Algorithm 2 (and therefore much faster than Algorithm 1 for large α). In particular, this modified algorithm appears to also exhibit the accelerated dynamics due to the overparameterization of the objective.
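An offline sketch of Algorithm 3's rectified updates, using the neural steady state Y = (W_zy W_yz)^{-1} X so that C_yy = P^{-1} C_xx P^{-⊤} with P := W_zy W_yz. The diagonal C_xx and near-identity non-negative initialization are illustrative; for this particular input an exact non-negative factorization of C_xx^{1/2} exists, so the rectified network whitens essentially exactly:

```python
import numpy as np

n, k, eta, iters = 3, 6, 1e-2, 4000
Cxx = np.diag([4.0, 2.0, 1.0])

# Non-negative initialization (identity padded with zeros for the extra interneurons).
Wzy = np.hstack([np.eye(n), np.zeros((n, k - n))])
Wyz = Wzy.T.copy()

for _ in range(iters):
    P = Wzy @ Wyz
    Pi = np.linalg.inv(P)
    Cyy = Pi @ Cxx @ Pi.T              # (1/T) Y Yᵀ at the neural steady state
    # Rectified (projected) gradient steps keep both weight matrices non-negative:
    # (1/T) Y Zᵀ = Cyy Wyzᵀ and (1/T) Z Yᵀ = Wyz Cyy.
    Wzy = np.maximum(Wzy + eta * (Cyy @ Wyz.T - Wzy), 0.0)
    Wyz = np.maximum(Wyz + eta * (Wyz @ Cyy - Wyz), 0.0)

Pi = np.linalg.inv(Wzy @ Wyz)
whitening_error = np.linalg.norm(Pi @ Cxx @ Pi.T - np.eye(n))
```

With a generic (non-diagonal) C_xx, the rectification generally prevents exact whitening, which is the behavior reported for Algorithm 3 above.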
As expected, the outputs of the projected gradient descent algorithm are not fully whitened: after the synaptic dynamics of Algorithm 3 equilibrate, the eigenvalues of the output covariance C_yy are {1.39, 1.01, 1.00, 0.98, 0.61}. This can be compared with the eigenvalues of the input covariance matrix C_xx, {24.01, 16.42, 10.45, 6.59, 3.28}, which suggests that the network significantly normalizes the eigenvalues of the output covariance without exactly performing ZCA whitening. In general, understanding the exact statistical transformation that results from the projected gradient descent algorithm is more mathematically challenging. For instance, in the case of Algorithm 2, the weights W adapt so that WW⊤ = C_xx^{1/2}; that is, the algorithm solves a symmetric matrix factorization problem whose solutions can be written in closed form in terms of the SVD of the covariance matrix C_xx. In the case of Algorithm 3, the algorithm is essentially solving for the optimal (approximate) non-negative symmetric matrix factorization of C_xx^{1/2}, for which there do not exist general closed form solutions. That being said, the empirical results suggest that the more biologically realistic network still performs rapid statistical adaptation.

B.3 MAPPING ALGORITHM 3 ONTO THE OLFACTORY BULB

In vertebrates, an early stage of olfactory processing occurs in the olfactory bulb. The olfactory bulb receives direct olfactory inputs, which it processes and transmits to higher order brain regions (Shepherd et al., 2004) , Figure 7 . Olfactory receptor neurons project to the olfactory bulb, where their axon terminals cluster by receptor type into spherical structures called glomeruli. Mitral cells, which are the main projection neurons of the olfactory bulb, receive direct inputs from the olfactory receptor neuron axons and output to the rest of the brain. Each mitral cell extends its apical dendrite into a single glomerulus, where it forms synapses with the axons of the olfactory receptor neurons expressing a common receptor type. The mitral cell activities are modulated by lateral inhibition from interneurons called granule cells, which are axonless neurons that form reciprocal dendrodendritic synapses with the basal dendrites of mitral cells (Shepherd, 2009) . Experimental evidence indicates that the olfactory bulb transforms its inputs so that mitral cell responses to distinct odors are approximately orthogonal (Friedrich & Laurent, 2001; 2004; Giridhar et al., 2011; Gschwend et al., 2015; Wanner & Friedrich, 2020) , a transformation referred to as pattern separation, which is closely related to statistical whitening (Wick et al., 2010) . We explore the possibility that the mitral-granule cell microcircuit implements Algorithm 3, Figure 7 . In this case, the vectors x t and y t represent the inputs to and outputs from the mitral cells, respectively. The vector z t represents the granule cell outputs. We let W yz (resp. -W zy ) denote the mitral cell-to-granule cell (resp. granule cell-to-mitral cell) synaptic weights. We compare our model with additional experimental observations. First, Algorithm 3 learns by adapting the matrices W yz and W zy . 
This is in line with experimental observations that granule cells in the olfactory bulb are highly plastic, suggesting that the dendrodendritic synapses are a synaptic substrate for experience-dependent modulation of the olfactory bulb output (Nissant et al., 2009; Sailor et al., 2016). Second, a consequence of Algorithm 3 is that the mitral cell-to-granule cell synapses and granule cell-to-mitral cell synapses are (asymptotically) symmetric; that is, W_zy = W_yz^⊤. There is experimental evidence in support of this symmetry: the dendrodendritic synapses are mainly present in reciprocal pairs (Woolf et al., 1991). In other words, the connectivity matrix of mitral cell-to-granule cell synapses (the matrix obtained by setting the non-zero entries in W_yz to 1) is approximately equal to the transpose of the connectivity matrix of granule cell-to-mitral cell synapses. Finally, Algorithm 3 requires that k ≥ n; that is, there are at least as many granule cells as mitral cells. This is consistent with measurements in the olfactory bulb indicating that there are approximately 50-100 times more granule cells than mitral cells (Shepherd et al., 2004).
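The connectivity-matrix comparison described above can be made concrete with a toy example (all weight values below are made up for illustration; only the sparsity pattern matters):

```python
import numpy as np

# Hypothetical sparse non-negative weights for 3 mitral and 6 granule cells.
W_yz = np.array([[0.0, 0.8, 0.0, 0.3, 0.0, 0.5],
                 [0.4, 0.0, 0.0, 0.0, 0.9, 0.0],
                 [0.0, 0.0, 0.7, 0.0, 0.0, 0.2]])

# If the dendrodendritic synapses come in reciprocal pairs, W_zy shares
# W_yz's sparsity pattern (transposed) even when the magnitudes differ.
jitter = np.exp(0.1 * np.random.default_rng(1).standard_normal(W_yz.shape))
W_zy = (W_yz * jitter).T

# Connectivity matrices: set the non-zero entries to 1.
conn_yz = (W_yz != 0).astype(int)
conn_zy = (W_zy != 0).astype(int)
print(np.array_equal(conn_zy, conn_yz.T))  # True: reciprocal pairs
```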

C.1 PROOF OF LEMMA 1

Proof. Let σ_1(t), . . . , σ_n(t) be solutions of the ODE 7 and define M(t) := U_x Σ(t) U_x^⊤, where Σ(t) := diag(σ_1(t), . . . , σ_n(t)). Then

dM(t)/dt = U_x diag(λ_1^2/σ_1(t)^2, . . . , λ_n^2/σ_n(t)^2) U_x^⊤ − I_n = M(t)^{-1} C_xx M(t)^{-1} − I_n.

In particular, we see that M(t) must be the unique solution of the ODE 6, where uniqueness of solutions follows because the right-hand side of equation 6 is locally Lipschitz continuous on its domain of definition. Equation 8 then follows from the ODE 7 and the chain rule.
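The matrix ODE appearing in this proof can be checked numerically. The sketch below (forward Euler with an arbitrarily chosen step size and seed) integrates dM/dt = M^{-1} C_xx M^{-1} − I_n from M(0) = I_n and confirms that the flow settles at M = C_xx^{1/2}, the point where the right-hand side vanishes:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n))
C_xx = A @ A.T + np.eye(n)  # a generic SPD input covariance

# Equilibrium of dM/dt = M^{-1} C_xx M^{-1} - I_n is M = C_xx^{1/2}.
evals, U = np.linalg.eigh(C_xx)
C_sqrt = U @ np.diag(np.sqrt(evals)) @ U.T

# Forward-Euler integration from M(0) = I_n. The step size is chosen small
# enough for stability; M(t) stays symmetric positive definite along the flow.
M, dt = np.eye(n), 1e-2
for _ in range(30_000):
    M_inv = np.linalg.inv(M)
    M += dt * (M_inv @ C_xx @ M_inv - np.eye(n))

print(np.linalg.norm(M - C_sqrt))  # essentially zero
```

Because the exact equilibrium is also a fixed point of the Euler update, the discretization introduces no bias at convergence.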



Figure 1: Recurrent neural networks for ZCA whitening with direct recurrent connections (left, Algorithm 1) and with interneurons (right, Algorithm 2).

Figure 2: Comparison of whitening error and eigenvalue evolution for the offline algorithms with (a) spectral initializations and (b) non-spectral initializations.

Figure 3: Comparison of convergence times for the offline algorithms with both spectral and non-spectral initializations, as functions of α = 1, 2, . . . , 20.

Figure 4: Comparison of whitening error and eigenvalue evolution for the online algorithms on a dataset with switching distributions (white vs. light orange backgrounds).

Figure 5: Left: networks for principal subspace projection with output whitening using direct lateral connections or interneurons. Right: comparison of subspace error for the two networks with relative learning rates η = ζ or η = 10ζ.

Figure 6: Comparison of whitening error and eigenvalue evolution for the offline algorithms (with non-spectral initializations) corresponding to the network with interneurons with unconstrained or non-negative weights. The lines and shaded regions respectively denote the means and 95% confidence intervals over 10 runs.

Figure 7: A simplified schematic of the olfactory bulb.

ACKNOWLEDGMENTS

We thank Yanis Bahroun, Nikolai Chapochnikov, Lyndon Duong, Johannes Friedrich, Siavash Golkar, Jason Moore and Tiberiu Teşileanu for helpful feedback on an earlier draft of this work. C. Pehlevan acknowledges support from the Intel Corporation.

A SADDLE POINT PROPERTY

Here we recall the following minimax property for a function that satisfies the saddle point property (Boyd & Vandenberghe, 2004, section 5.4).

Theorem 1. Let V ⊆ R^n, W ⊆ R^m and f : V × W → R. Suppose f satisfies the saddle point property; that is, there exists (v*, w*) ∈ V × W such that f(v*, w) ≤ f(v*, w*) ≤ f(v, w*) for all v ∈ V and w ∈ W. Then

inf_{v ∈ V} sup_{w ∈ W} f(v, w) = sup_{w ∈ W} inf_{v ∈ V} f(v, w) = f(v*, w*).
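As a concrete instance of Theorem 1 (an illustrative example, not taken from the paper): f(v, w) = v² − w² is convex in v, concave in w, and has the saddle point (v*, w*) = (0, 0). A grid evaluation confirms that the min-max and max-min values coincide at f(v*, w*):

```python
import numpy as np

# f(v, w) = v^2 - w^2 on [-1, 1] x [-1, 1], saddle point at (0, 0).
v = np.linspace(-1.0, 1.0, 201)
w = np.linspace(-1.0, 1.0, 201)
F = v[:, None] ** 2 - w[None, :] ** 2  # F[i, j] = f(v_i, w_j)

minimax = F.max(axis=1).min()  # min over v of max over w
maximin = F.min(axis=0).max()  # max over w of min over v
print(minimax, maximin)  # both equal f(v*, w*) = 0
```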

B RELATION TO NEURONAL CIRCUITS

In this section, we modify Algorithm 2 to satisfy additional biological constraints and we map the modified algorithm onto the vertebrate olfactory bulb.

B.1 DECOUPLING THE INTERNEURON SYNAPSES

The neural circuit implementation of Algorithm 2 requires that the principal neuron-to-interneuron synaptic weight matrix W^⊤ is the negative transpose of the interneuron-to-principal neuron synaptic weight matrix −W. In general, enforcing this symmetry is not biologically plausible; this requirement is commonly referred to as the weight transport problem. Here, following (Golkar et al., 2020, appendix D), we decouple the synapses and show that the (asymptotic) symmetry of the synaptic weight matrices follows from the symmetry of the local Hebbian/anti-Hebbian updates.

Published as a conference paper at ICLR 2023
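The mechanism behind the decoupling argument can be illustrated with a deliberately simplified sketch. Here both synaptic populations receive the same Hebbian outer-product update with decay (a stand-in for the paper's actual update rules, and the activities are random placeholders rather than network outputs); the asymmetry W_zy − W_yz^⊤ then contracts by a factor (1 − η) per step, regardless of the inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, eta = 3, 5, 0.1

# Independently initialized, so the two populations start out asymmetric.
W_yz = rng.standard_normal((n, k))
W_zy = rng.standard_normal((k, n))

for _ in range(2000):
    y = rng.standard_normal(n)  # placeholder principal neuron activity
    z = rng.standard_normal(k)  # placeholder interneuron activity
    # Both populations see the same local outer-product (Hebbian) signal.
    W_zy += eta * (np.outer(z, y) - W_zy)
    W_yz += eta * (np.outer(y, z) - W_yz)

# The asymmetry D = W_zy - W_yz^T obeys D <- (1 - eta) D at every step.
print(np.linalg.norm(W_zy - W_yz.T))  # decays to ~0
```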

C.2 PROOF OF PROPOSITION 1

Proof. Suppose σ_i(0) ≤ λ_i for i = 1, . . . , n. By equation 8, σ_i(t) ≤ λ_i for all t ≥ 0, and it follows that equation 10 holds. Now suppose σ_i(0) ≥ λ_i for some i ∈ {1, . . . , n}. Then, by equation 7, it follows that equation 11 holds.
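The convergence-time asymmetry behind this proposition can be seen in the scalar dynamics. Assuming ODE 7 has the form dσ_i/dt = λ_i²/σ_i² − 1 (the form consistent with the proof of Lemma 1), an initialization far above λ_i converges in time roughly linear in σ_i(0), since dσ_i/dt ≈ −1 when σ_i ≫ λ_i, whereas an initialization below λ_i recovers quickly:

```python
def time_to_converge(sigma0, lam, eps=1e-3, dt=1e-3, t_max=200.0):
    """First time at which |sigma(t) - lam| < eps under forward-Euler
    integration of d(sigma)/dt = lam**2 / sigma**2 - 1."""
    sigma, t = float(sigma0), 0.0
    while abs(sigma - lam) >= eps and t < t_max:
        sigma += dt * (lam ** 2 / sigma ** 2 - 1.0)
        t += dt
    return t

lam = 2.0
print(time_to_converge(0.1, lam))   # small: fast recovery from below
print(time_to_converge(20.0, lam))  # roughly sigma(0) - lam, plus a log term
print(time_to_converge(40.0, lam))  # about 20 time units longer than above
```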

C.3 PROOF OF PROPOSITION 2

Proof. Define A(t) := W(t)W(t)^⊤. By the product rule,

dA(t)/dt = (dW(t)/dt) W(t)^⊤ + W(t) (dW(t)/dt)^⊤,

and the chain rule together with equation 17 gives the evolution of ℓ(A(t))^2 for all t ≥ 0. Suppose W_0 W_0^⊤ is a spectral initialization. Then A(t) commutes with C_xx, and so

dℓ(A(t))^2/dt = −8 ℓ(A(t))^2 ⇒ ℓ(A(t)) = ℓ(A(0)) exp(−4t).

Thus, equation 13 holds. For general initializations,

dℓ(A(t))^2/dt ≤ −4 ℓ(A(t))^2 ⇒ ℓ(A(t)) ≤ ℓ(A(0)) exp(−2t),

and therefore inequality 14 holds.

