LEARNING LOW DIMENSIONAL STATE SPACES WITH OVERPARAMETERIZED RECURRENT NEURAL NETS

Abstract

Overparameterization in deep learning typically refers to settings where a trained neural network (NN) has representational capacity to fit the training data in many ways, some of which generalize well while others do not. In the case of Recurrent Neural Networks (RNNs), there exists an additional layer of overparameterization, in the sense that a model may exhibit many solutions that generalize well for sequence lengths seen in training, some of which extrapolate to longer sequences while others do not. Numerous works have studied the tendency of Gradient Descent (GD) to fit overparameterized NNs with solutions that generalize well. On the other hand, its tendency to fit overparameterized RNNs with solutions that extrapolate has been discovered only recently and is far less understood. In this paper, we analyze the extrapolation properties of GD when applied to overparameterized linear RNNs. In contrast to recent arguments suggesting an implicit bias towards short-term memory, we provide theoretical evidence for learning low-dimensional state spaces, which can also model long-term memory. Our result relies on a dynamical characterization showing that GD (with small step size and near-zero initialization) strives to maintain a certain form of balancedness, as well as on tools developed in the context of the moment problem from statistics (recovery of a probability distribution from its moments). Experiments corroborate our theory, demonstrating extrapolation via learning low-dimensional state spaces with both linear and non-linear RNNs.

[Footnote 1] There is no formal contradiction between our results and those of Emami et al. (2021) and Cohen-Karlik et al. (2022). These works make restrictive assumptions (namely, the former assumes that the teacher is stable and its impulse response decays exponentially fast, and the latter assumes that the teacher is memoryless) under which implicit extrapolation via learning low-dimensional state spaces leads to solutions with short-term memory. Our work, on the other hand, is not limited by these assumptions, and we show that in cases where they are violated, learning yields solutions with low-dimensional state spaces that do not result in short-term memory.

1. INTRODUCTION

Neural Networks (NNs) are often overparameterized, in the sense that their representational capacity far exceeds what is necessary for fitting training data. Surprisingly, training overparameterized NNs via (variants of) Gradient Descent (GD) tends to produce solutions that generalize well, despite the existence of many solutions that do not. This implicit generalization phenomenon has attracted considerable scientific interest, resulting in various theoretical explanations (see, e.g., Woodworth et al. (2020); Yun et al. (2020); Zhang et al. (2017); Li et al. (2020); Ji & Telgarsky (2018); Lyu & Li (2019)). Recent studies have surfaced a new form of implicit bias that arises in Recurrent Neural Networks (RNNs) and their variants (e.g., Long Short-Term Memory (Hochreiter & Schmidhuber, 1997) and Gated Recurrent Units (Chung et al., 2014)). For such models, the length of sequences in training is often shorter than in testing, and it is not clear to what extent a learned solution will be able to extrapolate beyond the sequence lengths seen in training. In the overparameterized regime, where the representational capacity of the learned model exceeds what is necessary for fitting short sequences, there may exist solutions that generalize but do not extrapolate, meaning that their accuracy is high over short sequences but arbitrarily poor over long ones (see Cohen-Karlik et al. (2022)). In practice, however, when training RNNs using GD, accurate extrapolation is often observed. We refer to this phenomenon as the implicit extrapolation of GD. As opposed to the implicit generalization of GD, little is formally known about its implicit extrapolation. Existing theoretical analyses of the latter focus on linear RNNs, also known as Linear Dynamical Systems (LDS), and either treat infinitely wide models (Emami et al., 2021), or models of finite width that learn from a memoryless teacher (Cohen-Karlik et al., 2022).
In these regimes, GD has been argued to exhibit an implicit bias towards short-term memory. While such results are informative, their generality remains in question, particularly since infinitely wide NNs are known to differ substantially from their finite-width counterparts, and since a memoryless teacher essentially neglects the main characteristic of RNNs (memory). In this paper, we theoretically investigate the implicit extrapolation of GD when applied to overparameterized finite-width linear RNNs learning from a teacher with memory. We consider models with symmetric transition matrices, in the case where a student (learned model) with state space dimension $d$ is trained on sequences of length $k$ generated by a teacher with state space dimension $\hat{d}$. Our interest lies in the overparameterized regime, where $d$ is greater than both $k$ and $\hat{d}$, meaning that the student's state space dimension is large enough to fully agree with the teacher on sequences of length $k$ while potentially disagreeing with it on longer sequences. As a necessary assumption on initialization, we follow prior work and focus on a certain balancedness condition, which is known (see the experiments in Cohen-Karlik et al. (2022), as well as our theoretical analysis) to capture near-zero initialization as commonly employed in practice. Our main theoretical result states that GD originating from a balanced initialization leads the student to extrapolate, irrespective of how large its state space dimension is. Key to the result is a surprising connection to a moment matching theorem from Cohen & Yeredor (2011), whose proof relies on ideas from compressed sensing (Elad, 2010; Eldar & Kutyniok, 2012) and neighborly polytopes (Gale, 1963). This connection may be of independent interest, and in particular may prove useful in deriving other results concerning the implicit properties of GD.
We corroborate our theory with experiments, which demonstrate extrapolation via learning low-dimensional state spaces both in the analyzed setting and in settings involving non-linear RNNs. The implicit extrapolation of GD is an emerging and exciting area of inquiry. Our results suggest that, contrary to prior belief, short-term memory does not suffice to explain it. We hope the techniques developed in this paper will contribute to a further understanding of this phenomenon.

2. RELATED WORK

The study of linear RNNs, or LDS, has a rich history, dating back at least to the early works of Kalman (Kalman, 1960; 1963). An extensively studied question relevant to extrapolation is that of system identification, which explores when the parameters of a teacher LDS can be recovered (see Ljung (1999)). Another related topic concerns finding compact realizations of systems, i.e. realizations of the same input-output mapping as a given LDS with a lower state space dimension (see Antoulas (2005)). Despite the relation, our focus is fundamentally different from the above: we ask what happens when one learns an LDS using GD. Since GD is not explicitly designed to find a low-dimensional state space, it is not clear that applying GD to an overparameterized student allows system identification through a compact realization. The fact that it does relates to the implicit properties of GD, and to our knowledge has not been investigated in the classic LDS literature. The implicit generalization of GD in training RNNs has been a subject of theoretical study for at least several years (see, e.g., Hardt et al. (2016); Allen-Zhu et al. (2019); Lim et al. (2021)). In contrast, works analyzing the implicit extrapolation of GD have surfaced only recently, specifically Emami et al. (2021) and Cohen-Karlik et al. (2022). Emami et al. (2021) analyzes linear RNNs in the infinite width regime, suggesting that in this case GD is implicitly biased towards impulse responses corresponding to short-term memory. Cohen-Karlik et al. (2022) studies finite-width linear RNNs (as we do), showing that when the teacher is memoryless (has state space dimension zero), GD emanating from a balanced initialization successfully extrapolates. Our work tackles an arguably more realistic and challenging setting: we analyze the regime in which the teacher has memory.
Our results suggest that the implicit extrapolation of GD does not originate from a bias towards short-term memory, but rather from a tendency to learn low-dimensional state spaces. We note that there have been works studying extrapolation in the context of non-recurrent NNs, e.g. Xu et al. (2020). That type of extrapolation deals with the behavior of learned functions outside the support of the training distribution, and is thus fundamentally different from the type of extrapolation we consider, which deals with the behavior of learned functions over sequences longer than those seen in training. Linear RNNs fundamentally differ from the commonly studied model of linear (feed-forward) NNs (see, e.g., Arora et al. (2018); Ji & Telgarsky (2018); Arora et al. (2019a;b); Razin & Cohen (2020)). One of the key differences is that a linear RNN entails different powers of a parameter (transition) matrix, leading to a loss function that roughly corresponds to a sum of losses for multiple linear NNs having different architectures and shared weights. This precludes the use of a vast array of theoretical tools tailored to linear NNs, rendering the analysis of linear RNNs technically challenging. On the empirical side, extrapolation of NNs to sequence lengths beyond those seen in training has been experimentally demonstrated in numerous recent works, covering both modern language and attention models (Press et al., 2022; Anil et al., 2022; Zhang et al., 2022) and RNNs with transition matrices of particular forms (Gu et al., 2022; 2021; 2020; Gupta, 2022). The current paper is motivated by these findings and takes a step towards theoretically explaining them.

3. LINEAR RECURRENT NEURAL NETWORKS

Our theoretical analysis applies to single-input single-output (SISO) linear RNNs with symmetric transition matrices. Given a state space dimension $d \in \mathbb{N}$, this model is defined by the update rule:

$$s_{t+1} = A s_t + B x_t \,, \quad y_t = C s_t \,, \quad t = 0, 1, 2, \ldots \,, \qquad (3.1)$$

where $A \in \mathbb{R}^{d \times d}$, $B \in \mathbb{R}^{d \times 1}$ and $C \in \mathbb{R}^{1 \times d}$ are configurable parameters, the transition matrix $A$ satisfies $A = A^\top$; $x_0, x_1, \ldots \in \mathbb{R}$ form an input sequence; $y_0, y_1, \ldots \in \mathbb{R}$ form the corresponding output sequence; and $s_t \in \mathbb{R}^{d \times 1}$ represents the internal state at time $t$, assumed to equal zero at the outset (i.e. it is assumed that $s_0 = 0$). As with any linear time-invariant system (Porat, 1996), the input-output mapping realized by the RNN is determined by its impulse response.

Definition 1. The impulse response of the RNN is the output sequence corresponding to the input sequence $(x_0, x_1, x_2, \ldots) = (1, 0, 0, \ldots)$. Namely, it is the sequence $(CB, CAB, CA^2 B, \ldots)$.

For brevity, we employ the shorthand $\Theta := (A, B, C)$. The $d \times d$ symmetric transition matrix $A$ is parameterized through a $d(d+1)/2$-dimensional vector holding its upper triangular elements, and with a slight overloading of notation, the symbol $A$ is also used to refer to this parameterization. We note that our theory readily extends to multiple-input multiple-output (MIMO) networks; the focus on the SISO case is merely for simplicity of presentation. Note also that the restriction to symmetric transition matrices is customary in both theory (Hazan et al., 2018) and practice (Gupta, 2022), and represents a generalization of the canonical modal form, which under mild non-degeneracy conditions does not limit generality (Boyd & Lessard, 2006).

Given a length $k$ input sequence $x = (x_0, \ldots, x_{k-1}) \in \mathbb{R}^k$, consider the output at the last time step, i.e. $y := y_k \in \mathbb{R}$, and denote it by $RNN(x)$. Using this output as a label, we define an empirical loss induced by a training set $S = \{(x^{(1)}, y^{(1)}), \ldots, (x^{(N)}, y^{(N)})\} \subset \mathbb{R}^k \times \mathbb{R}$:

$$\mathcal{L}_S(A, B, C) = \frac{1}{N} \sum_{i=1}^{N} \ell\big(RNN(x^{(i)}), y^{(i)}\big) \,, \qquad (3.2)$$

where $\ell(y, \hat{y}) = (y - \hat{y})^2$ is the square loss. By the update rule of the RNN (Equation 3.1), we have:

$$\mathcal{L}_S(A, B, C) = \frac{1}{N} \sum_{i=1}^{N} \Big( \sum_{j=0}^{k-1} C A^{k-1-j} B \, x^{(i)}_j - y^{(i)} \Big)^2 \,. \qquad (3.3)$$

Suppose that ground truth labels are generated by an RNN as defined in Equation 3.1, and denote the state space dimension and parameters of this teacher network by $\hat{d}$ and $\hat{\Theta} = (\hat{A}, \hat{B}, \hat{C})$ respectively. We employ the common assumption (e.g., see Hardt et al. (2016)) by which input sequences are drawn from a whitened distribution, i.e. a distribution where $\mathbb{E}[x_j x_{j'}]$ equals $1$ if $j = j'$ and $0$ otherwise. The population loss over length $k$ sequences can then be written as (see Lemma E.1):

$$\mathcal{L}(A, B, C) = \sum_{j=0}^{k-1} \big( C A^j B - \hat{C} \hat{A}^j \hat{B} \big)^2 \,. \qquad (3.4)$$

Equation 3.4 implies that a solution $\Theta = (A, B, C)$ achieves zero population loss over length $k$ sequences if and only if $C A^j B = \hat{C} \hat{A}^j \hat{B}$ for $j = 0, \ldots, k-1$. To what extent does such a solution imply that the student (i.e., the learned RNN) extrapolates to longer sequences? This depends on how close $C A^j B$ is to $\hat{C} \hat{A}^j \hat{B}$ for $j \ge k$.

Definition 2. For $\epsilon \ge 0$ and $q \in \mathbb{N}$, we say that the student $\epsilon$-extrapolates with horizon $q$ with respect to (w.r.t.) the teacher if:

$$\big| C A^j B - \hat{C} \hat{A}^j \hat{B} \big| \le \epsilon \,, \quad \forall j \in \{0, 1, \ldots, q-1\} \,. \qquad (3.5)$$

If the above holds for all $q \in \mathbb{N}$ then the student is said to $\epsilon$-extrapolate w.r.t. the teacher, and if it holds for all $q \in \mathbb{N}$ with $\epsilon = 0$ then the student is simply said to extrapolate w.r.t. the teacher.

Per Definition 2, $\epsilon$-extrapolation with horizon $q$ is equivalent to the first $q$ elements of the student's impulse response being $\epsilon$-close to those of the teacher, whereas extrapolation means that the student's impulse response fully coincides with the teacher's. The latter condition implies that the student realizes the same input-output mapping as the teacher, for any sequence length (this corresponds to the notion of system identification; see Section 2).
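To make the definitions above concrete, the following short sketch (ours, not from the paper; the parameter values are hypothetical) computes the impulse response of Equation 3.1 and the population loss of Equation 3.4:

```python
import numpy as np

def impulse_response(A, B, C, length):
    """First `length` elements of (CB, CAB, CA^2 B, ...) for the RNN of Eq. 3.1."""
    out, s = [], B.copy()
    for _ in range(length):
        out.append(float(C @ s))  # element j equals C A^j B
        s = A @ s
    return np.array(out)

def population_loss(student, teacher, k):
    """Eq. 3.4: squared distance between the first k impulse-response elements."""
    h = impulse_response(*student, k)
    h_hat = impulse_response(*teacher, k)
    return float(((h - h_hat) ** 2).sum())

# Hypothetical student with a random symmetric transition matrix.
rng = np.random.default_rng(0)
M = rng.standard_normal((3, 3)); A = (M + M.T) / 2
B = rng.standard_normal((3, 1)); C = rng.standard_normal((1, 3))
```

In this formulation, zero population loss over length-$k$ sequences is exactly agreement of the first $k$ impulse-response elements, which is the characterization used throughout Section 4.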
Notice that when the student is overparameterized, in the sense that $d$ is greater than $k$ and $\hat{d}$, it may perfectly generalize, i.e. drive the population loss over length $k$ sequences (Equation 3.4) to zero, and yet fail to extrapolate, as stated in the following proposition.

Proposition 3. Assume $d > k$, and let $\epsilon \ge 0$ and $q \in \{k+1, k+2, \ldots\}$. Then, for any teacher parameters $\hat{\Theta}$, there exist student parameters $\Theta$ with which the population loss in Equation 3.4 equals zero, and yet the student does not $\epsilon$-extrapolate with horizon $q$.

Proof sketch (for complete proof see Appendix E.1.2). The result follows from the fact that the first $d$ elements of the student's impulse response can be assigned freely via a proper choice of $\Theta$.

We are interested in the extent to which student parameters learned by GD extrapolate in the overparameterized regime. Proposition 3 implies that, regardless of how many (length $k$) sequences are used in training, if GD leads to any form of extrapolation, it must be a result of some implicit bias induced by the algorithm. Note that in our setting, extrapolation cannot be explained via classic tools from statistical learning theory, as evaluation over sequences longer than those seen in training violates the standard assumption of train and test data originating from the same distribution. To decouple the question of extrapolation from that of generalization, we consider the case where the training set $S$ is large, or more formally, where the empirical loss $\mathcal{L}_S(\cdot)$ (Equation 3.3) is well represented by the population loss $\mathcal{L}(\cdot)$ (Equation 3.4). We model GD with small step size via Gradient Flow (GF), as customary in the theory of NNs (see Saxe et al. (2013); Gunasekar et al. (2017); Arora et al. (2018; 2019b); Lyu & Li (2019); Li et al. (2020); Azulay et al. (2021) for examples where it is used, and Elkabetz & Cohen (2021) for a theoretical justification of its usage).
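The mechanism behind Proposition 3 can be illustrated numerically. The sketch below (our illustration; all concrete numbers are hypothetical choices) takes a diagonal student with $d > k$ and shows that the first $k$ impulse-response elements can be matched exactly while the element at position $k$ is moved freely along the null space of the matching constraints:

```python
import numpy as np

k, d = 4, 6
a_hat = 0.9                                  # hypothetical teacher: d_hat = 1, h_j = 0.9^j
h_hat = a_hat ** np.arange(k + 1)

a = np.linspace(-0.9, 0.9, d)                # distinct student eigenvalues (diagonal A)
V = np.vander(a, k + 1, increasing=True).T   # V[j, i] = a_i^j, shape (k+1, d)

# Match the first k impulse-response elements: solve V[:k] w = h_hat[:k],
# where w_i plays the role of C_i B_i.  The system is underdetermined (d > k).
w0, *_ = np.linalg.lstsq(V[:k], h_hat[:k], rcond=None)

# Null-space direction of the constraints, aligned with the k-th response row.
null = np.linalg.svd(V[:k])[2][k:].T         # orthonormal basis of the null space
n = null @ (null.T @ V[k])                   # component of V[k] inside the null space
w1 = w0 + 5.0 * n / np.linalg.norm(n)        # still zero loss, different element k
```

Both `w0` and `w1` attain zero population loss over length-$k$ sequences, yet their impulse responses diverge at position $k$, so the second student fails to $\epsilon$-extrapolate for small $\epsilon$.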
Using the GF formulation, we analyze the following dynamics:

$$\dot{\alpha}(\tau) := \frac{d}{d\tau} \alpha(\tau) = -\frac{\partial}{\partial \alpha} \mathcal{L}\big(A(\tau), B(\tau), C(\tau)\big) \,, \quad \tau \ge 0 \,, \qquad (3.6)$$

where $\alpha \in \{A, B, C\}$. If no assumption on initialization is made, no form of extrapolation can be established (indeed, the initial point may be a global minimizer of $\mathcal{L}(\cdot)$ that fails to extrapolate, and GF will stay there). Following prior work (see Cohen-Karlik et al. (2022)), we assume that the initialization adheres to the following balancedness condition:

Definition 4. An RNN with parameters $\Theta = (A, B, C)$ is said to be balanced if $B = C^\top$.

It was shown empirically in Cohen-Karlik et al. (2022) that the balancedness condition captures near-zero initialization as commonly employed in practice. We support this finding theoretically in Section 4.3. Aside from the initialization of the student, we will also assume that the teacher adheres to the balancedness condition.

4. THEORETICAL ANALYSIS

We turn to our theoretical analysis. Section 4.1 proves that in the setting of Section 3, convergence of GF to a zero loss solution leads the student to extrapolate, irrespective of how large its state space dimension is. Section 4.2 extends this result by establishing that, under mild conditions, approximate convergence leads to approximate extrapolation. The results of Sections 4.1 and 4.2 assume that GF emanates from a balanced initialization, which empirically is known to capture near-zero initialization as commonly employed in practice (see Section 3). Section 4.3 theoretically supports this empirical premise, by showing that with high probability, random near-zero initialization leads to balancedness.

We introduce notation that will be used throughout the analysis. For a matrix $Q \in \mathbb{R}^{m \times n}$, we let $\|Q\|_F$, $\|Q\|_\infty$ and $\|Q\|_2$ denote the Frobenius, $\ell_\infty$ and $\ell_2$ (spectral) norms, respectively. For a vector $v \in \mathbb{R}^m$, we use $\|v\|$ to denote the Euclidean norm and $v_i$ to denote its $i$th entry.

4.1. CONVERGENCE TO ZERO LOSS LEADS TO EXTRAPOLATION

Theorem 5. Suppose that $d > k > 2\hat{d}$, that the teacher is balanced, and that GF over the loss $\mathcal{L}(\cdot)$ (Equation 3.4) emanates from a balanced initialization. Then, if GF converges to a point $\Theta^*$ satisfying $\mathcal{L}(\Theta^*) = 0$, the student extrapolates w.r.t. the teacher.

In order to prove Theorem 5, we introduce two lemmas: Lemma 6, which shows that balancedness is preserved under GF; and Lemma 7, which (through a surprising connection to a moment problem from statistics) establishes that a balanced solution attaining zero loss necessarily extrapolates. With Lemmas 6 and 7 in place, the proof of Theorem 5 readily follows.

Lemma 6. Let $\Theta(\tau)$, with $\tau \ge 0$, be a curve brought forth by applying GF to the loss $\mathcal{L}(\cdot)$ starting from a balanced initialization. Then, $\Theta(\tau)$ is balanced for every $\tau \ge 0$.

Proof sketch (for complete proof see Appendix E.1.5). The result follows from the symmetric roles of $B$ and $C$ in the loss $\mathcal{L}(\cdot)$.

Lemma 7. Suppose that $d > k > 2\hat{d}$, the teacher is balanced, and the student parameters $\Theta$ are balanced and satisfy $\mathcal{L}(\Theta) = 0$. Then $\Theta$ extrapolates.

Proof sketch (for complete proof see Appendix E.2).
The proof is based on a surprising connection that we draw to the moment problem from statistics (recovery of a probability distribution from its moments), which has been studied for decades (see, e.g., Schmüdgen (2017)). Without loss of generality, we may assume that $\hat{A}$ is diagonal (if this is not the case then we apply an orthogonal eigendecomposition to $\hat{A}$ and subsequently absorb the eigenvectors into $\hat{B}$ and $\hat{C}$). We may also assume that $\hat{C}\hat{B} = 1$ (otherwise we absorb a scaling factor into $\hat{B}$ and/or $\hat{C}$). Since $\hat{C}^\top = \hat{B}$ (the teacher is balanced), we may define a probability vector (i.e. a vector with non-negative entries summing to one) $\hat{p} \in \mathbb{R}^{\hat{d}}$ via $\hat{p}_i = \hat{C}_i \hat{B}_i$, $i = 1, \ldots, \hat{d}$. We let $\hat{Z}$ denote the random variable supported on $\{\hat{A}_{1,1}, \ldots, \hat{A}_{\hat{d},\hat{d}}\}$ which assumes the value $\hat{A}_{i,i}$ with probability $\hat{p}_i$, $i = 1, \ldots, \hat{d}$. Notice that for every $j \in \mathbb{N}$:

$$\hat{C} \hat{A}^j \hat{B} = \sum\nolimits_{i=1}^{\hat{d}} \hat{p}_i \hat{A}_{i,i}^j = \mathbb{E}\big[\hat{Z}^j\big] \,,$$

meaning that the elements of the teacher's impulse response are precisely the moments of $\hat{Z}$. Similarly to the above we may assume $A$ is diagonal, and since $\mathcal{L}(\Theta) = 0$ it holds that $CB = \hat{C}\hat{B} = 1$. We may thus define a probability vector $p \in \mathbb{R}^d$ via $p_i = C_i B_i$, $i = 1, \ldots, d$, and a random variable $Z$ which assumes the value $A_{i,i}$ with probability $p_i$, $i = 1, \ldots, d$. For every $j \in \mathbb{N}$:

$$C A^j B = \sum\nolimits_{i=1}^{d} p_i A_{i,i}^j = \mathbb{E}\big[Z^j\big] \,,$$

and so the elements of the student's impulse response are precisely the moments of $Z$. The probabilistic formulation we set forth admits an interpretation of extrapolation as a moment problem. Namely, since $\mathcal{L}(\Theta) = 0$ (i.e. $C A^j B = \hat{C} \hat{A}^j \hat{B}$ for $j = 0, \ldots, k-1$), the random variables $Z$ and $\hat{Z}$ agree on their first $k-1$ moments, and the question is whether they agree on all higher moments as well. We note that this question is somewhat more challenging than those tackled in classic instances of the moment problem, since the support of the random variable whose moments we match ($Z$) is not known to coincide with the support of the random variable we seek to recover ($\hat{Z}$).
Luckily, a recent powerful result allows us to address the question we face: Cohen & Yeredor (2011) showed that the first $2n$ moments of a discrete random variable $X$ taking at most $n \in \mathbb{N}$ values uniquely define $X$, in the sense that any discrete random variable agreeing with these $2n$ moments must be identical to $X$. Translating this result to our setting, we have that if $Z$ agrees with $\hat{Z}$ on its first $2\hat{d}$ moments, it must be identical to $\hat{Z}$, and in particular it must agree with $\hat{Z}$ on all higher moments as well. The fact that $k - 1 \ge 2\hat{d}$ then concludes the proof.

To attain some intuition for the result we imported from Cohen & Yeredor (2011), consider the simple case where $\hat{d} = 1$. The transition matrix $\hat{A}$ is then a scalar $\hat{a} \in \mathbb{R}$, the random variable $\hat{Z}$ is deterministically equal to $\hat{a}$, and the teacher's impulse response is given by the moments $\mathbb{E}[\hat{Z}^j] = \hat{a}^j$, $j = 0, 1, \ldots$. Since we assume $k > 2\hat{d}$, the fact that $\mathcal{L}(\Theta) = 0$ means the random variable corresponding to the student, $Z$, agrees with the first two moments of $\hat{Z}$. That is, $Z$ satisfies $\mathbb{E}[Z] = \hat{a}$ and $\mathbb{E}[Z^2] = \hat{a}^2$. This implies that $\mathrm{Var}[Z] = \mathbb{E}[Z^2] - \mathbb{E}[Z]^2 = 0$, and therefore $Z$ is deterministically equal to $\hat{a}$, i.e. it is identical to $\hat{Z}$. The two random variables thus agree on all of their moments, meaning the impulse responses of the student and the teacher are the same.

Proof of Theorem 5. By Lemma 6 (as well as continuity considerations), $\Theta^*$ is balanced. Therefore, Lemma 7 implies that it extrapolates.
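The correspondence between impulse responses and moments is easy to check numerically. The sketch below (our illustration, with hypothetical teacher values) builds a balanced teacher with diagonal $\hat{A}$ and verifies that its impulse response equals the moment sequence of $\hat{Z}$:

```python
import numpy as np

# Hypothetical balanced teacher: diagonal A_hat, B_hat = C_hat^T, C_hat B_hat = 1.
eigs = np.array([0.9, -0.5, 0.3])            # diagonal entries of A_hat
w = np.array([0.6, 0.3, 0.5])
w = w / np.linalg.norm(w)                    # normalize so that C_hat B_hat = 1
A_hat = np.diag(eigs)
B_hat, C_hat = w.reshape(-1, 1), w.reshape(1, -1)

p = w ** 2                                   # probability vector: p_i = C_i B_i = w_i^2
assert abs(p.sum() - 1.0) < 1e-12            # non-negative entries summing to one

# Impulse-response element j equals the j-th moment E[Z_hat^j].
for j in range(8):
    impulse_j = float(C_hat @ np.linalg.matrix_power(A_hat, j) @ B_hat)
    moment_j = float((p * eigs ** j).sum())
    assert abs(impulse_j - moment_j) < 1e-12
```

Balancedness is what makes this work: it forces the weights $C_i B_i = w_i^2$ to be non-negative, so they form a legitimate probability vector.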

4.2. APPROXIMATE CONVERGENCE LEADS TO APPROXIMATE EXTRAPOLATION

Theorem 5 in Section 4.1 proves extrapolation in the case where GF converges to a zero loss solution. Theorem 8 below extends this result by establishing that, under mild conditions, approximate convergence leads to approximate extrapolation; more formally, for any $\epsilon > 0$ and $q \in \mathbb{N}$, when GF leads the loss to be sufficiently small, the student $\epsilon$-extrapolates with horizon $q$.

Theorem 8. Assume the conditions of Theorem 5, and that the teacher parameters $\hat{\Theta}$ are stable, i.e. the eigenvalues of $\hat{A}$ are in $[-1, 1]$. Assume also that $\hat{\Theta}$ is non-degenerate, in the sense that the input-output mapping it realizes is not identically zero. Finally, assume that the student parameters $\Theta$ learned by GF are confined to some bounded domain in parameter space. Then, for any $\epsilon > 0$ and $q \in \mathbb{N}$, there exists $\delta(\epsilon, q) > 0$ such that whenever $\mathcal{L}(\Theta) \le \delta(\epsilon, q)$, the student $\epsilon$-extrapolates with horizon $q$.

Proof sketch (for complete proof see Appendix E.3). Let $\delta > 0$ be a constant whose value will be chosen later, and suppose GF reached a point $\Theta$ satisfying $\mathcal{L}(\Theta) \le \delta$. Following the proof of Lemma 7, $\hat{\Theta}$ is identified with a distribution supported on the eigenvalues of $\hat{A}$, whose $j$th moment is $\hat{m}_j := \hat{C} \hat{A}^j \hat{B} \, (\hat{C}\hat{B})^{-1}$ for every $j \in \mathbb{N}$. Similarly, $\Theta$ is identified with a distribution supported on the eigenvalues of $A$, whose $j$th moment is $m_j := C A^j B \, (CB)^{-1}$ for every $j \in \mathbb{N}$. The fact that $\mathcal{L}(\Theta) \le \delta$ implies $|CB - \hat{C}\hat{B}| \le \sqrt{\delta}$, and in addition $|\hat{m}_j - m_j| \le O(\sqrt{\delta})$ for every $j = 1, \ldots, k-1$. To conclude the proof it suffices to show that

$$|\hat{m}_j - m_j| \le O(\epsilon) \quad \forall j \in \{1, \ldots, q-1\} \qquad (4.1)$$

given a small enough choice of $\delta$ (this choice then serves as $\delta(\epsilon, q)$ in the theorem statement). We establish Equation 4.1 by employing the theory of Wasserstein distances (Vaserstein, 1969). For $p \in \mathbb{N}$, denote by $W_p$ the $p$-Wasserstein distance between the distributions identified with $\hat{\Theta}$ and $\Theta$. Since $k > 2\hat{d}$, it holds that $|\hat{m}_j - m_j| \le O(\sqrt{\delta})$ for every $j = 1, \ldots, 2\hat{d}$. Proposition 2 in Wu & Yang (2020) then implies $W_1 \le O\big(\delta^{1/(4\hat{d})}\big)$.
For any $p \in \mathbb{N}$, $W_p \le O\big(W_1^{1/p}\big)$ (see Section 2.3 in Panaretos & Zemel (2019)) and $|\hat{m}_p - m_p| \le O(W_p)$ (see Section 1.2 in Biswas & Mackey (2021)). Combining the latter three inequalities, we have that $|\hat{m}_p - m_p| \le O\big(\delta^{1/(4\hat{d}p)}\big)$ for any $p \in \mathbb{N}$. Choosing $\delta = O\big(\epsilon^{4\hat{d}(q-1)}\big)$ therefore establishes Equation 4.1.
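For readability, the chain of bounds used in the proof sketch can be laid out in one place (constants are suppressed inside the $O(\cdot)$ notation, and the bounded-domain assumption of Theorem 8 is what licenses the second and third steps):

```latex
\begin{align*}
|\hat{m}_j - m_j| &\le O\big(\sqrt{\delta}\big) \quad \text{for } j = 1, \ldots, 2\hat{d}
  && \text{(since } \mathcal{L}(\Theta) \le \delta \text{ and } k > 2\hat{d}\text{)} \\
\Longrightarrow \quad W_1 &\le O\big(\delta^{1/(4\hat{d})}\big)
  && \text{(Proposition 2 in Wu \& Yang (2020))} \\
\Longrightarrow \quad W_p &\le O\big(W_1^{1/p}\big) \le O\big(\delta^{1/(4\hat{d}p)}\big)
  && \text{(Section 2.3 in Panaretos \& Zemel (2019))} \\
\Longrightarrow \quad |\hat{m}_p - m_p| &\le O\big(W_p\big) \le O\big(\delta^{1/(4\hat{d}p)}\big)
  && \text{(Section 1.2 in Biswas \& Mackey (2021))}
\end{align*}
```

Setting $\delta = O\big(\epsilon^{4\hat{d}(q-1)}\big)$ gives $\delta^{1/(4\hat{d}p)} = O\big(\epsilon^{(q-1)/p}\big) \le O(\epsilon)$ for every $p \le q - 1$ (assuming $\epsilon \le 1$), which is exactly Equation 4.1.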

4.3. BALANCEDNESS CAPTURES NEAR-ZERO INITIALIZATION

Theorems 5 and 8 assume that GF emanates from a balanced initialization, i.e. from a point $\Theta = (A, B, C)$ satisfying $B = C^\top$. It was shown in Cohen-Karlik et al. (2022) that theoretical predictions derived assuming balanced initialization faithfully match experiments conducted with near-zero initialization (an initialization commonly used in practice). Proposition 9 below theoretically supports this finding, establishing that with high probability, random near-zero initialization leads GF to arrive at an approximately balanced point, i.e. a point $\Theta = (A, B, C)$ for which the difference between $B$ and $C^\top$ is negligible compared to their size.

Proposition 9. Suppose that: (i) $d > 20$; (ii) the teacher parameters $\hat{\Theta}$ are balanced and non-degenerate, in the sense that the input-output mapping they realize is not identically zero; and (iii) the student parameters are learned by applying GF to the loss $\mathcal{L}(\cdot)$. Let $\Theta$ be a random point in parameter space, with entries drawn independently from the standard normal distribution. For $\epsilon > 0$, consider the case where GF emanates from the initialization $\epsilon \Theta$, and denote the resulting curve by $\Theta_\epsilon(\tau) = (A_\epsilon(\tau), B_\epsilon(\tau), C_\epsilon(\tau))$, with $\tau \ge 0$. Then, with probability at least $0.75$, for every $\epsilon > 0$ there exists $\tau_\epsilon \ge 0$ such that:

$$\lim_{\epsilon \to 0} \frac{\big\| B_\epsilon(\tau_\epsilon) - C_\epsilon^\top(\tau_\epsilon) \big\|_F}{\big\| B_\epsilon(\tau_\epsilon) + C_\epsilon^\top(\tau_\epsilon) \big\|_F} = 0 \,. \qquad (4.2)$$

Proof sketch (for complete proof see Appendix E.4). The idea behind the proof is as follows. Assume $\epsilon$ is sufficiently small. Then, when the entries of $\Theta = (A, B, C)$ are on the order of $\epsilon$, we have $\frac{\partial}{\partial B} \mathcal{L}(\Theta) \approx -2\hat{C}\hat{B} \cdot C^\top$ and $\frac{\partial}{\partial C} \mathcal{L}(\Theta) \approx -2\hat{C}\hat{B} \cdot B^\top$. This implies that during the first part of the curve $\Theta_\epsilon(\tau)$ it holds that $\frac{d}{d\tau}\big(B_\epsilon(\tau) - C_\epsilon^\top(\tau)\big) \approx -2\hat{C}\hat{B} \cdot \big(B_\epsilon(\tau) - C_\epsilon^\top(\tau)\big)$ and similarly $\frac{d}{d\tau}\big(B_\epsilon(\tau) + C_\epsilon^\top(\tau)\big) \approx 2\hat{C}\hat{B} \cdot \big(B_\epsilon(\tau) + C_\epsilon^\top(\tau)\big)$.
Since $\hat{C}\hat{B} > 0$ (this follows from the teacher parameters being balanced and non-degenerate), the entries of $B_\epsilon(\tau) - C_\epsilon^\top(\tau)$ shrink exponentially fast, while those of $B_\epsilon(\tau) + C_\epsilon^\top(\tau)$ grow at the same rate. This exponential shrinkage/growth leads $\| B_\epsilon(\tau) - C_\epsilon^\top(\tau) \|_F \, / \, \| B_\epsilon(\tau) + C_\epsilon^\top(\tau) \|_F$ to become extremely small, the more so the smaller $\epsilon$ is.
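This dynamic can be observed in a small simulation (ours, not from the paper; dimensions and constants are hypothetical, and GF is discretized by plain gradient descent on the population loss). Starting from an unbalanced near-zero initialization, the imbalance ratio of Equation 4.2 collapses:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical balanced teacher (d_hat = 2) with C_hat B_hat = 1.
A_hat = np.diag([0.8, -0.4])
w = np.array([[0.8], [0.6]])
B_hat, C_hat = w, w.T
k, d, lr, eps = 5, 4, 0.01, 1e-3

target = [float(C_hat @ np.linalg.matrix_power(A_hat, j) @ B_hat) for j in range(k)]

# Unbalanced near-zero student initialization (entries of order eps).
M = rng.standard_normal((d, d)); A = eps * (M + M.T) / 2
B = eps * rng.standard_normal((d, 1)); C = eps * rng.standard_normal((1, d))

def imbalance(B, C):
    return np.linalg.norm(B - C.T) / np.linalg.norm(B + C.T)

r0 = imbalance(B, C)
for _ in range(400):  # Euler discretization of GF (Eq. 3.6) with a small step
    gA = np.zeros((d, d)); gB = np.zeros((d, 1)); gC = np.zeros((1, d))
    for j in range(k):
        e = float(C @ np.linalg.matrix_power(A, j) @ B) - target[j]
        gB += 2 * e * np.linalg.matrix_power(A, j).T @ C.T
        gC += 2 * e * (np.linalg.matrix_power(A, j) @ B).T
        for l in range(j):  # gradient of C A^j B w.r.t. A, symmetrized
            G = np.linalg.matrix_power(A, l).T @ C.T \
                @ (np.linalg.matrix_power(A, j - 1 - l) @ B).T
            gA += e * (G + G.T)
    A, B, C = A - lr * gA, B - lr * gB, C - lr * gC
```

In runs of this sketch the ratio drops by orders of magnitude relative to its initial value, mirroring the exponential shrinkage/growth argument; taking `eps` smaller drives it further toward zero.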

5. EXPERIMENTS

In this section we present experiments corroborating our theoretical analysis (Section 4). The latter establishes that, under certain conditions, a linear RNN with state space dimension $d$ extrapolates when learning from a teacher network with state space dimension $\hat{d}$ via training sequences of length $k$, irrespective of how large $d$ is compared to $\hat{d}$ and $k$. A key condition underlying the result is that $k$ is larger than $2\hat{d}$. Section 5.1 below considers the theoretically analyzed setting, and empirically evaluates extrapolation as $k$ varies. Its results demonstrate a phase transition: extrapolation takes place when $k > 2\hat{d}$, in compliance with the theory, but fails when $k$ falls below $2\hat{d}$, in which case the theory indeed does not guarantee extrapolation. Section 5.2 displays the same phenomenon with linear RNNs that do not adhere to some of the assumptions made by the theory (in particular the assumption of symmetric transition matrices, and those concerning balancedness). Finally, Section 5.3 considers non-linear RNNs (specifically, Gated Recurrent Unit networks; Chung et al. (2014)), and shows that they too exhibit a phase transition in extrapolation as the training sequence length varies. For brevity, we defer some of the implementation details, as well as additional experiments, to Appendix B.

5.1. THEORETICALLY ANALYZED SETTING

Our first experiment considers the setting described in Section 3 and theoretically analyzed in Section 4. As representative values for the state space dimensions of the teacher and the (overparameterized) student, we choose $\hat{d} = 5$ and $d = 40$ respectively (higher state space dimensions for the student, namely $d = 100$ and $d = 200$, yield qualitatively identical results). For a given training sequence length $k$, the student is learned via GD applied directly to the population loss defined in Equation 3.4 (applying GD to the empirical loss defined in Equation 3.3, with $N = 10{,}000$ training examples, led to similar results). Figure 1(a) reports the extrapolation error (quantified by the $\ell_\infty$ distance between the impulse response of the learned student and that of the teacher) as a function of $k$. As can be seen, extrapolation exhibits a phase transition that accords with our theory: when $k > 2\hat{d}$ the extrapolation error is low, whereas when $k$ falls below $2\hat{d}$ the extrapolation error is high.

Published as a conference paper at ICLR 2023
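The phase transition can be reproduced in miniature (our sketch, not the paper's implementation; the dimensions are far smaller than the $\hat{d} = 5$, $d = 40$ used in the experiment so that it runs in seconds, and all concrete values are hypothetical). A student is trained by GD on the population loss for one training length above and one below the $2\hat{d}$ threshold, and the $\ell_\infty$ extrapolation errors are compared:

```python
import numpy as np

def train_student(k, d=3, steps=20000, lr=0.05, seed=0):
    """GD on the population loss (Eq. 3.4) from a balanced near-zero init.
    Teacher: d_hat = 1 with a_hat = 0.9 and b_hat = c_hat = 1, so h_j = 0.9^j."""
    rng = np.random.default_rng(seed)
    target = [0.9 ** j for j in range(k)]
    M = rng.standard_normal((d, d)); A = 0.01 * (M + M.T) / 2
    B = 0.01 * rng.standard_normal((d, 1)); C = B.T.copy()   # balanced init
    for _ in range(steps):
        gA = np.zeros((d, d)); gB = np.zeros((d, 1)); gC = np.zeros((1, d))
        for j in range(k):
            e = float(C @ np.linalg.matrix_power(A, j) @ B) - target[j]
            gB += 2 * e * np.linalg.matrix_power(A, j).T @ C.T
            gC += 2 * e * (np.linalg.matrix_power(A, j) @ B).T
            for l in range(j):  # gradient of C A^j B w.r.t. A, symmetrized
                G = np.linalg.matrix_power(A, l).T @ C.T \
                    @ (np.linalg.matrix_power(A, j - 1 - l) @ B).T
                gA += e * (G + G.T)
        A, B, C = A - lr * gA, B - lr * gB, C - lr * gC
    return A, B, C

def extrapolation_error(A, B, C, q=20):
    """l_inf distance between student and teacher impulse responses (length q)."""
    return max(abs(float(C @ np.linalg.matrix_power(A, j) @ B) - 0.9 ** j)
               for j in range(q))

err_above = extrapolation_error(*train_student(k=3))  # k > 2 d_hat = 2
err_below = extrapolation_error(*train_student(k=1))  # k below the threshold
```

In runs of this sketch the student trained with `k=3` tracks the teacher's impulse response well beyond the training length, while the student trained with `k=1` matches only the first element, in line with the phase transition reported in Figure 1(a).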

5.2. OTHER SETTINGS WITH LINEAR RECURRENT NEURAL NETWORKS

To assess the generality of our findings, we experiment with linear RNNs in settings that do not adhere to some of the assumptions made by our theory. Specifically, we evaluate settings in which: (i) the teacher is unbalanced, meaning $\hat{B} \neq \hat{C}^\top$, and its transition matrix $\hat{A}$ is non-symmetric; (ii) the student's transition matrix $A$ is not restricted to be symmetric; (iii) learning is implemented by optimizing the empirical loss defined in Equation 3.3 (rather than the population loss defined in Equation 3.4); and (iv) optimization is based on Adam (Kingma & Ba, 2014) (rather than GD), emanating from standard near-zero initialization, which is generally unbalanced (namely, $B \neq C^\top$).

5.3. NON-LINEAR RECURRENT NEURAL NETWORKS

As a final experiment, we explore implicit extrapolation with non-linear RNNs, namely GRU networks. Specifically, we evaluate the extent to which a student GRU with state space dimension $d_g = 100$ extrapolates when learning from a teacher GRU with state space dimension $\hat{d}_g = 10$ (higher state space dimensions for the student, namely $d_g = 200$ and $d_g = 500$, yield qualitatively identical results). The student is learned by optimizing an empirical loss comprising training sequences of a predetermined length $k_g$. Optimization is based on Adam emanating from standard near-zero initialization. Figure 2(a) reports the extrapolation error (quantified by the $\ell_\infty$ distance between the response of the learned student and that of the teacher, averaged across randomly generated input sequences) for different choices of $k_g$. As can be seen, similarly to the case with linear RNNs (see Sections 5.1 and 5.2), there exists a critical threshold for the training sequence length $k_g$, above which the extrapolation error is low and below which it is high (note that this critical threshold is around four times the teacher's state space dimension, whereas with linear RNNs it was around twice the teacher's state space dimension; theoretically explaining this difference is an interesting direction for future work). We additionally compare two learned students, one trained with sequences of length $k_g = 30$ and the other with sequences of length $k_g = 60$. As expected, the impulse response of each student tracks that of the teacher for the first $k_g$ time steps (where $k_g$ is student-dependent). However, while the student for which $k_g = 30$ fails to track the teacher beyond $k_g$ time steps, the student for which $k_g = 60$ succeeds, thereby exemplifying implicit extrapolation.
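For concreteness, here is a minimal sketch (ours; tiny hypothetical dimensions, standard GRU equations in one common convention, and no training) of the response computation and the averaged $\ell_\infty$ error metric used for comparing a student GRU to a teacher GRU:

```python
import numpy as np

def gru_step(x, h, P):
    """One GRU step for scalar input x and state h (one common convention)."""
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = P
    sig = lambda v: 1 / (1 + np.exp(-v))
    z = sig(Wz * x + Uz @ h + bz)               # update gate
    r = sig(Wr * x + Ur @ h + br)               # reset gate
    h_tilde = np.tanh(Wh * x + Uh @ (r * h) + bh)
    return (1 - z) * h + z * h_tilde

def response(P, C, xs):
    """Output sequence of the GRU (with linear readout C) on input sequence xs."""
    h, ys = np.zeros_like(P[2]), []
    for x in xs:
        h = gru_step(x, h, P)
        ys.append(float(C @ h))
    return np.array(ys)

def extrapolation_error(P_s, C_s, P_t, C_t, length=50, trials=8, seed=1):
    """l_inf distance between student and teacher responses, averaged over
    random input sequences (the metric reported in Figure 2(a))."""
    rng = np.random.default_rng(seed)
    errs = [np.max(np.abs(response(P_s, C_s, xs) - response(P_t, C_t, xs)))
            for xs in (rng.standard_normal(length) for _ in range(trials))]
    return float(np.mean(errs))

# Hypothetical random GRU parameters with state dimension 4.
rng = np.random.default_rng(0)
n = 4
P = tuple(rng.standard_normal((n, n)) if i % 3 == 1 else rng.standard_normal(n)
          for i in range(9))
C = rng.standard_normal((1, n))
```

Since each state coordinate is a convex combination of its previous value and a `tanh` output, the state stays in $(-1, 1)$, so the responses (and hence the error metric) are always bounded.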

6. CONCLUSION

This paper studies the question of extrapolation in RNNs, and more specifically, of whether a student RNN trained on data generated by a teacher RNN can capture the behavior of the teacher over sequences longer than those seen in training. We focus on overparameterized students, which can perfectly fit the training sequences while producing a wide range of behaviors over longer sequences. Such a student will fail to extrapolate, unless the teacher possesses a certain structure and the learning algorithm is biased towards solutions adhering to that structure. We show, theoretically for linear RNNs and empirically for both linear and non-linear RNNs, that such implicit extrapolation takes place when the teacher has a low-dimensional state space and the learning algorithm is GD. Existing studies of implicit extrapolation in (linear) RNNs (Emami et al., 2021; Cohen-Karlik et al., 2022) suggest that GD is biased towards solutions with short-term memory. While a low-dimensional state space and short-term memory may coincide in some cases, in general they do not, and a solution with a low-dimensional state space may entail long-term memory. Our theory and experiments show that in settings where a low-dimensional state space and short-term memory contradict each other, the implicit extrapolation of GD chooses the former over the latter. An important direction for future work is extending our theory to non-linear RNNs. We believe this is possible, in the same way that theories for linear (feed-forward) NNs were extended to account for non-linear NNs (see, e.g., Razin et al. (2021; 2022); Lyu & Li (2019)). An additional direction to explore is the applicability of our results to the recently introduced S4 model (Gu et al., 2022).

A NECESSITY OF LOWER BOUND ON TRAINING SEQUENCE LENGTH

Our theoretical guarantees of implicit extrapolation (Theorems 5 and 8) assumed that the training sequence length $k$ is greater than two times the teacher's state space dimension $\hat{d}$. Below we show that this assumption is necessary (up to a small additive constant). More precisely, we prove that if $k \le 2\hat{d} - 1$, implicit extrapolation cannot be guaranteed.

Lemma A.1. For any $\hat{d} \in \mathbb{N}$, there exist two configurations of teacher parameters $\hat{\Theta}_1 = (\hat{A}_1, \hat{B}_1, \hat{C}_1)$ and $\hat{\Theta}_2 = (\hat{A}_2, \hat{B}_2, \hat{C}_2)$, both balanced (Definition 4), stable (meaning the eigenvalues of $\hat{A}_1$ and $\hat{A}_2$ are in $[-1, 1]$) and non-degenerate (meaning the input-output mappings realized by $\hat{\Theta}_1$ and $\hat{\Theta}_2$ are not identically zero), such that $\hat{C}_1 \hat{A}_1^j \hat{B}_1 = \hat{C}_2 \hat{A}_2^j \hat{B}_2$ for all $j = 0, 1, \ldots, 2\hat{d} - 2$, and yet $\hat{C}_1 \hat{A}_1^j \hat{B}_1 \neq \hat{C}_2 \hat{A}_2^j \hat{B}_2$ for $j = 2\hat{d} - 1$.

Proof. A derivation as in the proof sketch of Lemma 7 shows that any $\hat{d}$-atomic distribution (i.e. any distribution supported on a set of $\hat{d}$ real numbers) can be associated with a balanced configuration of teacher parameters $\hat{\Theta} = (\hat{A}, \hat{B}, \hat{C})$ satisfying $\hat{C}\hat{B} = 1$, such that the values to which the distribution assigns non-zero probability are the eigenvalues of $\hat{A}$, and the $j$'th moment of the distribution is equal to $\hat{C}\hat{A}^j\hat{B}$ for every $j \in \mathbb{N}$. In light of this, and of the fact that any configuration of teacher parameters $\hat{\Theta} = (\hat{A}, \hat{B}, \hat{C})$ satisfying $\hat{C}\hat{B} = 1$ is non-degenerate (the first element of its impulse response is non-zero), it suffices to prove that there exist two $\hat{d}$-atomic distributions supported on $[-1, 1]$ which agree on their first $2\hat{d} - 2$ moments yet disagree on their $(2\hat{d} - 1)$'th moment. This follows from Lemmas 4 and 30 in Wu & Yang (2020).

Corollary A.2. Assume the conditions of Theorem 5, but with $k > 2\hat{d}$ replaced by $k \le 2\hat{d} - 1$. Then, the stated result does not hold, i.e. GF may converge to a point $\Theta^*$ satisfying $L(\Theta^*) = 0$ which does not extrapolate (Definition 2). Similarly, replacing $k > 2\hat{d}$ by $k \le 2\hat{d} - 1$ in the conditions of Theorem 8 renders the stated result false, meaning there exist $\epsilon > 0$ and $q \in \mathbb{N}$ such that for every $\delta > 0$, there is some $\Theta$ satisfying $L(\Theta) \le \delta$ which does not $\epsilon$-extrapolate with horizon $q$ (Definition 2).

Proof. In the context of either Theorem 5 or Theorem 8, if $k \le 2\hat{d} - 1$ then by Lemma A.1 there exist two configurations of teacher parameters, both satisfying the conditions of the theorem, which lead to different impulse responses (Definition 1) yet induce the same loss $L(\cdot)$ (Equation 3.4). If the result stated in the theorem were true, it would mean extrapolation, or $\epsilon$-extrapolation with horizon $q$ for arbitrarily small $\epsilon > 0$ and arbitrarily large $q \in \mathbb{N}$ (see Definition 2), simultaneously with respect to both teachers, which is a contradiction.
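The construction underlying Lemma A.1 can be illustrated concretely. Below is a small numerical sketch (the atoms and probabilities are our own choice, not taken from the paper) exhibiting, for $\hat{d} = 2$, two 2-atomic distributions supported on $[-1, 1]$ that agree on moments $0, 1, 2$ (i.e. on the first $2\hat{d} - 2$ moments, plus the trivial $0$'th) but differ on moment $3 = 2\hat{d} - 1$:

```python
import numpy as np

# Two hypothetical 2-atomic distributions on [-1, 1], chosen for illustration.
atoms1, probs1 = np.array([0.5, -0.5]), np.array([0.5, 0.5])
atoms2, probs2 = np.array([1.0, -0.25]), np.array([0.2, 0.8])

def moment(atoms, probs, j):
    """j'th moment of a discrete distribution given its atoms and probabilities."""
    return float(np.sum(probs * atoms ** j))

# Moments 0, 1, 2 coincide (1, 0 and 0.25 respectively for both distributions)...
for j in range(3):
    assert np.isclose(moment(atoms1, probs1, j), moment(atoms2, probs2, j))
# ...but moment 3 differs: 0 for the symmetric distribution vs 0.1875 for the other.
assert not np.isclose(moment(atoms1, probs1, 3), moment(atoms2, probs2, 3))
```

Via the correspondence in the proof, each distribution yields a balanced, stable, non-degenerate teacher, and the two teachers share the first three entries of the impulse response while disagreeing on the fourth.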

B FURTHER EXPERIMENTS

In this section we provide additional experiments that are not included in the main manuscript due to space constraints.

B.1 BALANCED TEACHER

In Section 5.1, we experimented with our proposed theoretical setup. In this section we provide additional figures and experiments.

B.1.1 UNBALANCED STUDENT

In this experiment we use the same balanced teacher with $\hat{d} = 5$ as in Section 5.1. Instead of the diagonal student with balanced initialization, we use a general (non-diagonal) student with $d = 40$ and weights sampled from a Gaussian with scale $10^{-5}$. Results are depicted in Figure 3(a). A phase transition phenomenon similar to the one in Figure 1 is found here as well.

B.1.2 EFFECT OF INITIALIZATION SCALE

As can be seen in Figure 4, extrapolation deteriorates for larger initialization scales, in the sense that longer training sequences are required for obtaining good extrapolation error. This suggests that the condition of small initialization required by our theory is not an artifact of our proof technique, but rather a necessary condition for extrapolation to occur.

B.2 UNBALANCED TEACHER

In Section 5.2, we tested extrapolation with respect to a specific unbalanced teacher and observed a phase transition similar to the one predicted by the theory of Section 4 and the empirical evaluation of Section 5.1. Here we show that the phase transition is not limited to that specific teacher, by repeating the experiment with a randomly generated unbalanced teacher (see Section C.2.2 for its generation).

B.3 IMPULSE RESPONSE FIGURES

In Section 5 we presented extrapolation performance in different settings. In order to better convey the meaning of extrapolating vs. non-extrapolating solutions, we present here figures of the impulse responses of different models. We start with the impulse response corresponding to the experiment described in Section 5.1. Figure 5 depicts the balanced teacher with $\hat{d} = 5$ and two selected students (with $d = 40$), one trained with $k = 10$ and the other with $k = 20$. We can see that the student trained with $k = 10$ tracks the teacher several steps beyond the 10th time step and then decays to zero. For $k = 20$ we see near perfect extrapolation over the horizon evaluated. Next we turn to Section 5.2 and depict the average impulse responses (Figure 6) of the "delay teacher" and the students trained with respect to it. Since this teacher has $\hat{d} = 10$, a model trained with $k = 8$ is trained with respect to the zero impulse response (see Section C.2.2 for details on the delay teacher), and as expected yields the "zero" solution. We can see that for $k = 18$ the student diverges from the teacher shortly after the 18th time step. For $k = 20$ we see near perfect extrapolation up to the horizon considered.

C IMPLEMENTATION DETAILS

All the experiments are implemented using PyTorch.

C.1 OPTIMIZATION

In Section 5.1 we optimize the population loss, which entails minimizing Equation 3.4 with respect to the parameters of the learned model. We use 15K optimization steps with the Adam optimizer and a learning rate of $10^{-3}$. In this experiment, the results were not sensitive to the initialization scale of the (balanced) student. In Section 5.2 and Section 5.3, in the experiments that involve minimizing the empirical loss, we use 50K optimization steps with early stopping (most experiments required less than 10K steps). The batch size is set to 100, and data is sampled from a Gaussian with zero mean and scale of 1. Experiments were not sensitive to most hyper-parameters other than learning rate and initialization scale. The examination of the effect of initialization scale presented in Section B.1.2 is done with the learning rate scheduler torch.optim.lr_scheduler.MultiStepLR, using milestones at $[5000, 10000, 15000, 30000]$ and a decay factor of $\gamma = 0.1$.

C.2 TEACHER GENERATION

One of the main challenges in empirically evaluating extrapolation is that randomly sampling weights from a Gaussian distribution may result in an RNN of low effective rank (i.e. an RNN that may be accurately approximated by another RNN with a smaller hidden dimension). We now describe the teacher generation scheme for the different experiments.

C.2.1 BALANCED TEACHER GENERATION

A balanced teacher consists of $\hat{d}$ entries corresponding to the diagonal of the transition matrix and $\hat{d}$ entries representing $\hat{B} = \hat{C}^\top$. In order to avoid rapid decay of the impulse response on the one hand, and exponential growth on the other, we set the eigenvalues to be spread uniformly between 0.6 and 1.05. The values of $\hat{B}$ and $\hat{C}$ are randomly sampled from a Gaussian with mean 0.5 and scale 1, and then normalized such that $\hat{C}\hat{B} = 1$.
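The generation scheme just described can be sketched as follows. This is a minimal NumPy illustration of our reading of the scheme (the paper's experiments use PyTorch, and the exact sampling details may differ); note that for a balanced teacher $\hat{C} = \hat{B}^\top$, so normalizing $\hat{B}$ to unit norm gives $\hat{C}\hat{B} = \|\hat{B}\|^2 = 1$:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
eigs = np.linspace(0.6, 1.05, d)      # eigenvalues spread evenly in [0.6, 1.05]
A = np.diag(eigs)                     # diagonal transition matrix
B = rng.normal(loc=0.5, scale=1.0, size=d)
B = B / np.linalg.norm(B)             # balanced: C = B^T, so C @ B = ||B||^2 = 1
C = B.copy()

assert np.isclose(C @ B, 1.0)
impulse = [C @ np.linalg.matrix_power(A, i) @ B for i in range(10)]
```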

C.2.2 UNBALANCED TEACHER GENERATION

In this experiment, the teacher has a general (non-symmetric) matrix $\hat{A}$ and $\hat{B} \neq \hat{C}^\top$. We set the weights as described next.

Delay Teacher    A "delay" teacher has an impulse response of 1 at time step $i = d - 1$, that is, the teacher has an impulse response of $(0, \ldots, 0, 1, 0, \ldots)$. In order to generate this impulse response we set the weights as follows:

$$A = \begin{pmatrix} 0 & 1 & & \\ & 0 & \ddots & \\ & & \ddots & 1 \\ & & & 0 \end{pmatrix}, \quad B = (0, \ldots, 0, 1)^\top, \quad C = (1, 0, \ldots, 0). \tag{C.1}$$

Note that $B, C$ above are set to extract the last entry of the first row of $A^i$, and $A$ is a nilpotent shift matrix with ones on its superdiagonal. It is straightforward to verify that $CA^iB = 1$ for $i = d - 1$ and $0$ otherwise.
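A minimal NumPy sketch of the delay teacher (the specific vectors $B = e_d$, $C = e_1^\top$ are our reading of "extract the last entry of the first row of $A^i$"):

```python
import numpy as np

d = 10
A = np.diag(np.ones(d - 1), k=1)   # nilpotent shift matrix: ones on the superdiagonal
B = np.zeros(d); B[-1] = 1.0       # extracts the last column of A^i
C = np.zeros(d); C[0] = 1.0        # extracts the first row of A^i
impulse = [C @ np.linalg.matrix_power(A, i) @ B for i in range(2 * d)]

# The impulse response is 1 at i = d - 1 and 0 everywhere else.
assert impulse[d - 1] == 1.0
assert all(v == 0.0 for i, v in enumerate(impulse) if i != d - 1)
```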

Random Unbalanced Teacher

The second unbalanced teacher is randomly generated. In order to avoid the caveats mentioned in Section B.1, we randomly sample the diagonal of $A$ (from a Gaussian with zero mean and scale 0.1) and its superdiagonal (from a Gaussian with mean 0.7 and scale 0.1). We set $B, C$ as in Equation C.1. The structure of $A$ ensures properties similar to those of the delay teacher; specifically, the first entries of the impulse response are zero and the teacher is "revealed" only after $d$ time steps.
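The structural claim, that the first $d - 1$ impulse-response entries vanish exactly for an upper bidiagonal $A$ with $B, C$ as in Equation C.1, can be checked directly: reaching column $d$ from row 1 requires $d - 1$ superdiagonal steps, regardless of the diagonal entries. An illustrative sketch (sampling details per the description above, seed our own):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 10
# Upper bidiagonal A: diagonal ~ N(0, 0.1), superdiagonal ~ N(0.7, 0.1).
A = np.diag(rng.normal(0.0, 0.1, size=d)) \
    + np.diag(rng.normal(0.7, 0.1, size=d - 1), k=1)
B = np.zeros(d); B[-1] = 1.0
C = np.zeros(d); C[0] = 1.0
impulse = np.array([C @ np.linalg.matrix_power(A, i) @ B for i in range(2 * d)])

# (A^i)[0, d-1] needs at least d - 1 superdiagonal moves, so the first
# d - 1 entries of the impulse response are exactly zero.
assert np.allclose(impulse[: d - 1], 0.0)
assert abs(impulse[d - 1]) > 0.0   # "revealed" at step d - 1
```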

C.2.3 NON-LINEAR TEACHER GENERATION

As opposed to the linear teachers discussed in previous sections, when the teacher is a Gated Recurrent Unit (GRU) network, it is unclear how to generate a non-trivial teacher. When randomly generating a teacher GRU, the result is either a trivial model whose impulse response quickly decays to zero or a teacher with an exploding impulse response (depending on the scale of the initialization). In order to produce a teacher with interesting extrapolation behaviour, we initialize a model with an initialization scale of $10^{-6}$ and train it for 1000 steps to mimic an arbitrarily chosen impulse response. The result of this procedure is a teacher GRU with non-trivial behaviour. Figure 2(b) shows that with this non-trivial teacher we get the phase transition phenomenon described in Section 5.3.

C.3 EXTRAPOLATION ERROR

The concept of extrapolation is very intuitive, and yet it does not admit any standard error measure. A proper extrapolation error measure should: (a) capture fine differences between two models with good extrapolation behaviour; and on the other hand, (b) be insensitive to the scale at which two non-extrapolating models explode. A natural approach, which we take here, is to report the $\ell_\infty$ norm of the difference on the tail of the impulse response. A model is considered non-extrapolating if its extrapolation error is worse than that of a trivial solution whose impulse response is all zeros.
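A possible implementation of this error measure (our own sketch; the function names and the exact tail convention, time steps $\ge k$, are assumptions, not taken from the paper):

```python
import numpy as np

def extrapolation_error(student_ir, teacher_ir, k):
    """l_inf distance on the tail (time steps >= k) of the impulse responses."""
    diff = np.asarray(student_ir)[k:] - np.asarray(teacher_ir)[k:]
    return float(np.max(np.abs(diff)))

def extrapolates(student_ir, teacher_ir, k):
    # Non-extrapolating if worse than the trivial all-zeros impulse response.
    baseline = extrapolation_error(np.zeros_like(np.asarray(teacher_ir, float)),
                                   teacher_ir, k)
    return extrapolation_error(student_ir, teacher_ir, k) < baseline

teacher = 0.9 ** np.arange(40)
good = teacher.copy()                                       # tracks the teacher
bad = np.concatenate([teacher[:20], 10.0 ** np.arange(20)])  # explodes past step 20
assert extrapolates(good, teacher, k=20)
assert not extrapolates(bad, teacher, k=20)
```

Criterion (b) is satisfied because a diverging student is simply classified as non-extrapolating, irrespective of how fast it explodes.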

D ACCUMULATING LOSS

In the main paper the analysis is performed for the loss function defined in Section 3, which corresponds to a regression problem over sequences. Another important and common loss function is an accumulating loss defined over the full output sequence. Specifically, the empirical loss of Equation 3.3 is replaced with

$$L_S(A, B, C) = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=0}^{k-1} \ell\left(RNN(x^{(i)})_j,\, y^{(i)}_j\right). \tag{D.1}$$

In this section we discuss the adaptations required to accommodate our theory to the loss defined in Equation D.1.

D.1 POPULATION LOSS

A derivation similar to that of the population loss described in Appendix E.1.1 can be applied to Equation D.1. The difference is that an additional summation is introduced and preserved throughout the analysis, resulting in

$$\mathbb{E}_{x \sim \mathcal{D}}\left[\sum_{j=0}^{k-1} \ell\left(RNN(x)_j, y_j\right)\right] = \sum_{j=0}^{k-1} \sum_{i=0}^{j} \left(CA^iB - w_i\right)^2. \tag{D.2}$$

The loss above can be viewed as a different weighting of the original population loss, i.e. Equation D.2 can be written as

$$\sum_{i=0}^{k-1} (k - i)\left(CA^iB - w_i\right)^2. \tag{D.3}$$

It is clear that the minimizers of Equation D.2 are the same as the minimizers of Equation 3.4 (i.e. $CA^iB = w_i$ for $i = 0, \ldots, k-1$). Thus Lemma 7 holds with no additional modifications.
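The equality between the double summation in Equation D.2 and the weighted form in Equation D.3 follows from counting, for each $i$, the number of outer indices $j \ge i$, which is $k - i$. A quick numerical check with arbitrary residuals:

```python
import numpy as np

rng = np.random.default_rng(0)
k = 12
a = rng.normal(size=k)   # stands in for the per-step residuals CA^iB - w_i

# Equation D.2: sum over j of the first j+1 squared residuals.
double_sum = sum(a[i] ** 2 for j in range(k) for i in range(j + 1))
# Equation D.3: each residual i appears in k - i of the outer sums.
weighted = sum((k - i) * a[i] ** 2 for i in range(k))

assert np.isclose(double_sum, weighted)
```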

D.2 APPROXIMATE EXTRAPOLATION

For Theorem 8, the analysis in the proof makes use of the fact that the difference in the moments defined by the student and teacher is bounded by $O(\sqrt{\delta})$. The same is true for the case of the weighted loss: specifically, if the loss is at most $\delta$, then for all $i = 0, \ldots, k-1$, $\sqrt{k-i}\,|CA^iB - w_i| \le \sqrt{\delta}$. Since $k - i \ge 1$ for $i = 0, \ldots, k-1$, we have $|CA^iB - w_i| \le \sqrt{k-i}\,|CA^iB - w_i| \le \sqrt{\delta}$, and the remainder of the proof is the same.

D.3 IMPLICIT BIAS FOR BALANCEDNESS

The proof of the implicit bias towards balancedness involves the gradients of the population loss defined in Equation 3.4. For the weighted population loss the gradients differ, but the symmetries are all preserved (the gradient computation boils down to adding an outer summation to the terms computed in Section E.1.4). The same steps described in Section E.4 apply for the weighted loss.

E DEFERRED PROOFS

Here we provide complete proofs for the results in the paper.

E.1 AUXILIARY PROOFS

In this section we provide missing proofs from the main paper and additional lemmas to be used in the main proofs.

E.1.1 POPULATION LOSS

Lemma E.1 (Proof of Equation 3.4). Assume $x \sim \mathcal{D}$ such that $\mathbb{E}_{x \sim \mathcal{D}}[x] = 0$ and $\mathbb{E}_{x \sim \mathcal{D}}[xx^\top] = I_k \in \mathbb{R}^{k \times k}$, where $I_k$ is the identity matrix. $y$ is given by $y = \widehat{RNN}(x)$, where $\widehat{RNN}(\cdot)$ denotes the output of a teacher RNN $\hat{\Theta} = (\hat{A}, \hat{B}, \hat{C})$. Denoting $w_i = \hat{C}\hat{A}^i\hat{B}$, the loss for the student RNN satisfies:

$$\mathbb{E}_{x \sim \mathcal{D}}\left[\ell\left(RNN(x), y\right)\right] = \sum_{i=0}^{k-1}\left(CA^iB - w_i\right)^2. \tag{E.1}$$

Proof of Lemma E.1. The population loss for training with sequences of length $k$ is

$$\mathbb{E}_{x \sim \mathcal{D}}\left[\ell\left(RNN(x), y\right)\right] = \mathbb{E}_{x \sim \mathcal{D}}\left[\left(\sum_{i=0}^{k-1} CA^{k-1-i}Bx_i - \sum_{j=0}^{k-1} w_{k-1-j}x_j\right)^2\right]. \tag{E.2}$$

Reversing the order of summation and expanding the terms,

$$\mathbb{E}_{x \sim \mathcal{D}}\left[\ell\left(RNN(x), y\right)\right] = \mathbb{E}_{x \sim \mathcal{D}}\left[\left(\sum_{i=0}^{k-1} CA^iBx_{k-1-i} - \sum_{j=0}^{k-1} w_jx_{k-1-j}\right)^2\right] \tag{E.3}$$
$$= \sum_{i,j=0}^{k-1}\left[CA^iB\,CA^jB - 2CA^iB\,w_j + w_iw_j\right]\mathbb{E}_{x \sim \mathcal{D}}\left[x_{k-1-i}x_{k-1-j}\right] \tag{E.4}$$
$$= \sum_{i,j=0}^{k-1}\left[CA^iB\,CA^jB - 2CA^iB\,w_j + w_iw_j\right]\mathbf{1}_{k-1-i = k-1-j} \tag{E.5}$$
$$= \sum_{i,j=0}^{k-1}\left[CA^iB\,CA^jB - 2CA^iB\,w_j + w_iw_j\right]\mathbf{1}_{i=j} \tag{E.6}$$
$$= \sum_{i=0}^{k-1}\left[(CA^iB)^2 - 2CA^iB\,w_i + w_i^2\right] = \sum_{i=0}^{k-1}\left(CA^iB - w_i\right)^2, \tag{E.7}$$

where the transition from E.4 to E.5 is by our assumption that $\mathbb{E}_{x \sim \mathcal{D}}[x_ix_j] = \mathbf{1}_{i=j}$. Therefore,

$$\mathbb{E}_{x \sim \mathcal{D}}\left[\ell\left(RNN(x), y\right)\right] = \sum_{i=0}^{k-1}\left(CA^iB - w_i\right)^2, \tag{E.8}$$

concluding the proof.
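The identity in Lemma E.1 can be sanity-checked by Monte Carlo simulation (our own sketch; the dimensions and weight scales are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
k, d, dh = 6, 4, 2
# Random student (A, B, C) and teacher (Ah, Bh, Ch); scale keeps powers bounded.
A = rng.normal(scale=0.3, size=(d, d));  B = rng.normal(size=d);  C = rng.normal(size=d)
Ah = rng.normal(scale=0.3, size=(dh, dh)); Bh = rng.normal(size=dh); Ch = rng.normal(size=dh)

c = np.array([C @ np.linalg.matrix_power(A, i) @ B for i in range(k)])     # student IR
w = np.array([Ch @ np.linalg.matrix_power(Ah, i) @ Bh for i in range(k)])  # teacher IR

x = rng.normal(size=(500_000, k))   # i.i.d. inputs, so E[x x^T] = I_k
out_student = x @ c[::-1]           # sum_i C A^{k-1-i} B x_i
out_teacher = x @ w[::-1]           # sum_j w_{k-1-j} x_j
mc_loss = np.mean((out_student - out_teacher) ** 2)
closed_form = np.sum((c - w) ** 2)  # Equation E.1
assert np.isclose(mc_loss, closed_form, rtol=0.05)
```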

E.1.2 PERFECT GENERALIZATION AND FAILED EXTRAPOLATION

Proposition E.2 (Proposition 3 in main paper). Assume $d > k$, and let $\epsilon \ge 0$ and $q \in \{k+1, k+2, \ldots\}$. Then, for any teacher parameters $\hat{\Theta}$, there exist student parameters $\Theta$ with which the population loss in Equation 3.4 equals zero, and yet the student does not $\epsilon$-extrapolate with horizon $q$.

Proof. Consider a student $\Theta$ such that $A$ is symmetric (and therefore has an orthogonal eigendecomposition). Denote $A = U\Lambda U^\top$. The impulse response at time step $i$ can be expressed as $CA^iB = CU\Lambda^iU^\top B$. The latter can be written compactly in matrix form as $Vg$, where $V$ is the Vandermonde matrix with $\mathrm{diag}(\Lambda)$ as its values,

$$V = \begin{pmatrix} 1 & 1 & \cdots & 1 \\ \lambda_1 & \lambda_2 & \cdots & \lambda_d \\ \lambda_1^2 & \lambda_2^2 & \cdots & \lambda_d^2 \\ \vdots & \vdots & & \vdots \\ \lambda_1^{d-1} & \lambda_2^{d-1} & \cdots & \lambda_d^{d-1} \end{pmatrix},$$

and $g$ is defined as $g = (CU)^\top \odot U^\top B$.foot_1 A known result on square Vandermonde matrices is that they are invertible if and only if $\lambda_i \neq \lambda_j$ for all $i \neq j$. Given a fixed set of distinct values $(\lambda_1, \ldots, \lambda_d)$ and an arbitrary impulse response $r \in \mathbb{R}^d$, in order for the student to generate the impulse response $r$ (i.e. $Vg = r$), one can set the coefficient vector $g = V^{-1}r$ and end up with a symmetric student with $r$ as its impulse response of length $d$. Given a teacher RNN $\hat{\Theta} = (\hat{A}, \hat{B}, \hat{C})$, we can set the first $k$ entries of $r$ to $r_i = \hat{C}\hat{A}^{i-1}\hat{B}$ for all $i \in \{1, \ldots, k\}$. We are therefore left with $d - k$ degrees of freedom, which yields many different students that agree with the teacher on the first $k$ entries while fitting arbitrary values beyond the $k$ considered.
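The Vandermonde argument can be reproduced numerically (illustrative sketch; the eigenvalues, the teacher prefix and the tail values are arbitrary choices of ours):

```python
import numpy as np

d, k = 8, 5
lam = np.linspace(-0.9, 0.9, d)              # distinct eigenvalues -> V invertible
V = np.vander(lam, N=d, increasing=True).T   # V[i, j] = lam_j ** i, as in the proof

teacher_prefix = 0.8 ** np.arange(k)   # first k impulse-response entries to match
tail = np.array([5.0, -3.0, 2.0])      # arbitrary values beyond step k (d - k = 3)
r = np.concatenate([teacher_prefix, tail])

g = np.linalg.solve(V, r)              # coefficient vector realizing response r
assert np.allclose(V @ g, r)           # student fits the prefix AND the free tail
```

Different tails yield different students, all with zero population loss over length-$k$ sequences, which is exactly the failure of extrapolation asserted by the proposition.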

E.1.3 EQUIVALENCE BETWEEN BALANCED RNNS WITH SYMMETRIC AND DIAGONAL TRANSITION MATRICES

Lemma E.3. A balanced RNN $\Theta = (A, B, C)$ with a symmetric transition matrix (i.e. $B = C^\top$ and $A = A^\top$) has an equivalent (i.e. generating the same impulse response) RNN $\Theta' = (A', B', C')$ which is balanced and whose transition matrix is diagonal.

Lemma E.3 allows alternating between systems with symmetric and diagonal matrices. This is useful for simplifying the analysis in Section 4.

Proof of Lemma E.3. Any symmetric matrix admits an orthogonal eigendecomposition with real (non-imaginary) eigenvalues. Denote $A = U\Lambda U^\top$. We can define $A' = \Lambda$, $B' = U^\top B$ and $C' = CU$. The $i$'th entry of the impulse response is given by $CA^iB = CU\Lambda^iU^\top B = C'(A')^iB'$, concluding that $\Theta$ and $\Theta'$ have the same impulse response of any length.
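The equivalence in Lemma E.3 is easy to verify numerically (sketch; dimensions arbitrary). Note that balancedness is preserved: $C'^\top = (CU)^\top = U^\top C^\top = U^\top B = B'$.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6
M = rng.normal(size=(d, d))
A = (M + M.T) / 2                     # symmetric transition matrix
B = rng.normal(size=d); C = B.copy()  # balanced: B = C^T

lam, U = np.linalg.eigh(A)            # A = U diag(lam) U^T, real eigenvalues
A1, B1, C1 = np.diag(lam), U.T @ B, C @ U

ir  = [C  @ np.linalg.matrix_power(A,  i) @ B  for i in range(10)]
ir1 = [C1 @ np.linalg.matrix_power(A1, i) @ B1 for i in range(10)]
assert np.allclose(ir, ir1)           # identical impulse responses
assert np.allclose(C1, B1)            # the diagonal system is balanced too
```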

E.1.4 GRADIENT DERIVATION

For completeness and for use in Section E.4, we compute the gradients in the general setting.

Lemma E.4. Given the population loss

$$L(A, B, C) = \sum_{j=0}^{k-1}\left(CA^jB - \hat{C}\hat{A}^j\hat{B}\right)^2, \tag{3.4 revisited}$$

and denoting $\nabla\ell_i = CA^iB - w_i$, the derivatives of the loss with respect to $B$ and $C$ satisfy:

$$\frac{\partial L}{\partial B} = \sum_{i=0}^{k-1}\nabla\ell_i\,(A^i)^\top C^\top, \tag{E.9}$$
$$\frac{\partial L}{\partial C} = \sum_{i=0}^{k-1}\nabla\ell_i\,B^\top(A^i)^\top. \tag{E.10}$$

Proof of Lemma E.4. Note that for $j \ge 0$, the derivative of $CA^jB$ with respect to $B$ is given by

$$\frac{\partial(CA^jB)}{\partial B} = (A^j)^\top C^\top. \tag{E.11}$$

Similarly, the derivative of $CA^jB$ with respect to $C$ is given by

$$\frac{\partial(CA^jB)}{\partial C} = B^\top(A^j)^\top. \tag{E.12}$$

Using these derivatives, we can calculate the derivative of the population loss (assigning $w_i = \hat{C}\hat{A}^i\hat{B}$),

$$L(A, B, C) = \mathbb{E}_{x \sim \mathcal{D}}\left[\ell\left(RNN(x), y\right)\right] = \sum_{i=0}^{k-1}\left(CA^iB - w_i\right)^2. \tag{E.13}$$

Denoting $\nabla\ell_i = CA^iB - w_i$, and noting that $w_i$ is constant (it depends only on $\hat{\Theta}$), we have for $X \in \{B, C\}$ (omitting a constant factor of 2, which does not affect the analysis):

$$\frac{\partial L}{\partial X} = \sum_{i=0}^{k-1}\frac{\partial\left(CA^iB - w_i\right)^2}{\partial X} \propto \sum_{i=0}^{k-1}\nabla\ell_i\,\frac{\partial\left(CA^iB\right)}{\partial X}. \tag{E.14}$$

E.2 THEOREM 5

In (Cohen & Yeredor, 2011, Theorem 1) and in (Wu & Yang, 2020, Lemma 4) it is shown that the first $2\hat{d}$ moments of a discrete random variable taking at most $\hat{d}$ different values uniquely define this random variable. Therefore, any other discrete random variable agreeing with the teacher on $2\hat{d}$ moments must be the same random variable, and therefore agrees on higher moments as well. Since we assumed $k > 2\hat{d}$, this result immediately implies that equality in the first $k - 1$ moments implies equality in all other moments. The case $\hat{C}\hat{B} = 0$ is handled as follows: from our assumption that the teacher is balanced, the condition is met only if $\hat{C}_i = \hat{B}_i = 0$ for $i = 1, \ldots, \hat{d}$. Such a teacher has an impulse response of zeros, and for $k \ge 1$, a student minimizing the loss must also satisfy $CB = 0$ and therefore has all zeros as its impulse response (recall the student is balanced), thus extrapolating with respect to the said teacher.
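The gradient formula for $B$ can be checked against finite differences (our own sketch in the SISO setup; note the explicit factor 2 from the chain rule, which the lemma's convention absorbs into a constant):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 4, 6
A = rng.normal(scale=0.3, size=(d, d))
B = rng.normal(size=d); C = rng.normal(size=d)
w = rng.normal(size=k)   # stands in for the teacher impulse response

def loss(B):
    return sum((C @ np.linalg.matrix_power(A, i) @ B - w[i]) ** 2 for i in range(k))

# Analytic gradient: sum_i 2 * grad_l_i * (A^i)^T C^T  (factor 2 written out).
grad = sum(2 * (C @ np.linalg.matrix_power(A, i) @ B - w[i])
           * (np.linalg.matrix_power(A, i).T @ C) for i in range(k))

eps = 1e-6
numeric = np.array([(loss(B + eps * e) - loss(B - eps * e)) / (2 * eps)
                    for e in np.eye(d)])
assert np.allclose(grad, numeric, atol=1e-5)
```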

E.3 THEOREM 8 (APPROXIMATE EXTRAPOLATION)

This section is devoted to the proof of Theorem 8, which ties the approximation error of optimization to that of extrapolation.

Theorem E.9 (Theorem 8 in main paper). Consider the minimization of Equation 3.4 and assume: (i) $d > k > 2\hat{d}$; (ii) the teacher is balanced and stable (i.e. the eigenvalues of $\hat{A}$ are in $[-1, 1]$); (iii) the teacher is non-degenerate, i.e. the input-output mapping it realizes is not identically zero; (iv) the student parameters are learned by applying GF to the loss $L(\cdot)$, starting from a balanced initialization; (v) the student parameters $\Theta$ are bounded. Then, for any $\epsilon > 0$ and $q \in \mathbb{N}$, there exists $\delta(\epsilon, q) > 0$ such that whenever $L(\Theta) \le \delta(\epsilon, q)$, the student $\epsilon$-extrapolates with horizon $q$.

Proof of Theorem E.9. Let $\delta > 0$ be a constant whose value will be chosen later, and suppose GF reached a point $\Theta$ satisfying $L(\Theta) \le \delta$. Following the proof of Lemma 7, $\hat{\Theta}$ is identified with a distribution supported on the eigenvalues of $\hat{A}$, whose $j$'th moment is $\hat{m}_j := \hat{C}\hat{A}^j\hat{B}(\hat{C}\hat{B})^{-1}$ for every $j \in \mathbb{N}$. Similarly, $\Theta$ is identified with a distribution supported on the eigenvalues of $A$, whose $j$'th moment is $m_j := CA^jB(CB)^{-1}$ for every $j \in \mathbb{N}$. From our assumption that $L(\Theta) \le \delta$,

$$L(\Theta) = \sum_{j=0}^{k-1}\left(CA^jB - \hat{C}\hat{A}^j\hat{B}\right)^2 \le \delta, \tag{E.28}$$

and in particular each term satisfies $(CA^jB - \hat{C}\hat{A}^j\hat{B})^2 \le \delta$ for $j = 0, \ldots, k-1$; specifically, $(CB - \hat{C}\hat{B})^2 \le \delta$. Denote $\beta = \hat{C}\hat{B} - CB$; then $\beta \in [-\sqrt{\delta}, \sqrt{\delta}]$. Note that $\hat{C}\hat{B}$ is a (positive) constant; multiplying the loss by $(\hat{C}\hat{B})^{-2}$, each term is at most $\delta(\hat{C}\hat{B})^{-2}$. We can write, for each $j = 0, \ldots, k-1$,

$$\left(\frac{CA^jB}{\hat{C}\hat{B}} - \frac{\hat{C}\hat{A}^j\hat{B}}{\hat{C}\hat{B}}\right)^2 = \left(\frac{CA^jB}{\hat{C}\hat{B}} - \frac{CA^jB}{CB} + \frac{CA^jB}{CB} - \frac{\hat{C}\hat{A}^j\hat{B}}{\hat{C}\hat{B}}\right)^2 = \left(CA^jB\left(\frac{1}{\hat{C}\hat{B}} - \frac{1}{CB}\right) + \underbrace{\frac{CA^jB}{CB} - \frac{\hat{C}\hat{A}^j\hat{B}}{\hat{C}\hat{B}}}_{m_j - \hat{m}_j}\right)^2. \tag{E.29, E.30}$$

We can further expand the term on the left:

$$\frac{1}{\hat{C}\hat{B}} - \frac{1}{CB} = \frac{1}{\hat{C}\hat{B}} - \frac{1}{\hat{C}\hat{B} - \beta} = \frac{\hat{C}\hat{B} - \beta - \hat{C}\hat{B}}{\hat{C}\hat{B}(\hat{C}\hat{B} - \beta)} = \frac{\beta}{\hat{C}\hat{B}(\beta - \hat{C}\hat{B})}. \tag{E.31}$$

Plugging back into the above, with $\kappa = \frac{CA^jB}{\hat{C}\hat{B}(\beta - \hat{C}\hat{B})}$, we have

$$\frac{\delta}{(\hat{C}\hat{B})^2} \ge \left(\frac{CA^jB}{\hat{C}\hat{B}} - \frac{\hat{C}\hat{A}^j\hat{B}}{\hat{C}\hat{B}}\right)^2 = \left(\beta\kappa + (m_j - \hat{m}_j)\right)^2 \tag{E.32}$$
$$= \beta^2\kappa^2 + 2\beta\kappa(m_j - \hat{m}_j) + (m_j - \hat{m}_j)^2 \tag{E.33}$$
$$\ge 2\beta\kappa(m_j - \hat{m}_j) + (m_j - \hat{m}_j)^2 \tag{E.34}$$
$$\ge -2\left|\sqrt{\delta}\,\kappa\,(m_j - \hat{m}_j)\right| + (m_j - \hat{m}_j)^2. \tag{E.35}$$

From assumption (ii), the teacher is stable and therefore $\hat{m}_j \le \hat{C}\hat{B}$ for all $j = 0, \ldots, k-1$. Similarly, from assumption (v) the student parameters are bounded and therefore $CA^jB$ is bounded by $\tau^{j+2}$ (where $\tau = \max\{1, \eta\}$ and $\eta$ is a bound on the Frobenius norms of $A, B, C$); $m_j$ is bounded in a similar fashion by $\tau^j$. Combining the above, for $\delta < \hat{C}\hat{B}$ we have

$$\frac{\delta}{(\hat{C}\hat{B})^2} \ge -\frac{2\sqrt{\delta}\,\tau^{j+2}(\tau^j + 1)}{2(\hat{C}\hat{B})^2} + (m_j - \hat{m}_j)^2. \tag{E.36}$$

Setting $\delta' < \frac{\delta(\hat{C}\hat{B})^2}{1 + \tau^{k+1}(\tau^{k-1} + 1)}$, if $L(\Theta) \le \delta'$ then $|m_j - \hat{m}_j| \le \sqrt{\delta}$ for $j = 1, \ldots, k-1$. Proposition 2 in Wu & Yang (2020) then implies $W_1(\hat{\Theta}, \Theta) \le O(\delta^{1/4}\hat{d})$.foot_5 Denote $\Omega = \left(\bigcup_{i=1}^{d} A_{ii}\right) \cup \left(\bigcup_{j=1}^{\hat{d}} \hat{A}_{jj}\right)$ (the union of the supports of $\hat{\Theta}$ and $\Theta$). From Section 2.3 in Panaretos & Zemel (2019), for $q > p$ the $q$'th and $p$'th Wasserstein distances satisfy $W_q^q(\hat{\Theta}, \Theta) \le W_p^p(\hat{\Theta}, \Theta)\,\gamma^{q-p}$, where $\gamma = \max_{x, y \in \Omega}|x - y|$. In particular, for $p = 1$, $W_q(\hat{\Theta}, \Theta) \le \left(W_1(\hat{\Theta}, \Theta)\,\gamma^{q-1}\right)^{1/q}$. Note that $\gamma$ is bounded by $\gamma \le \tau + 1 \le 2\tau$ (recall that the student is bounded and the teacher is stable). Finally, $|m_q - \hat{m}_q| \le W_q(\hat{\Theta}, \Theta)$ (see Section 1.2 in Biswas & Mackey (2021)). Combining the steps above, for all $q \in \mathbb{N}$,

$$|m_q - \hat{m}_q| \le W_q(\hat{\Theta}, \Theta) \le \left(W_1(\hat{\Theta}, \Theta)(2\tau)^{q-1}\right)^{1/q} \le \left(\rho\,\delta^{1/4}\hat{d}\,(2\tau)^{q-1}\right)^{1/q}, \tag{E.37}$$

where $\rho$ is a constant satisfying $W_1(\hat{\Theta}, \Theta) \le \rho\delta^{1/4}\hat{d}$. To achieve $|m_q - \hat{m}_q| < \epsilon$ for any $\epsilon > 0$, we can set $\delta(\epsilon, q) < \left(\frac{\epsilon^q}{\rho\gamma^{q-1}\hat{d}}\right)^4$, concluding the proof.

E.4 PROPOSITION 9 (IMPLICIT BIAS FOR BALANCEDNESS)

The proof of Proposition 9 consists of several steps. First, we bound with high probability the norms of $B$ and $C$ at initialization (Lemma E.10). We then derive bounds on the differential equations of $\frac{d}{dt}(B(t) + C^\top(t))$ and $\frac{d}{dt}(B(t) - C^\top(t))$ (Lemma E.14), and show that as the initialization scale tends to zero, the ratio between them tends to zero (Lemma E.13). Before turning to prove Proposition 9, we first bound the initial values for a vector $v \in \mathbb{R}^n$ initialized with $\mathcal{N}(0, \frac{\epsilon^2}{n})$ coordinates.

Lemma E.10. Assume a vector $v \in \mathbb{R}^n$ with coordinates drawn i.i.d. from $\mathcal{N}(0, \frac{\epsilon^2}{n})$. Then:

$$\Pr\left(\frac{\epsilon}{2} < \|v\| < \frac{3\epsilon}{2}\right) \ge 1 - 2\exp(-9n/64). \tag{E.38}$$

Proof of Lemma E.10. The proof of E.10 uses known results on the Chi-square distribution (Laurent & Massart, 2000), applied to our specific setting to achieve the desired bounds. We begin by changing variables, $\tilde{v}_i = v_i \cdot \frac{\sqrt{n}}{\epsilon}$. The entries $\tilde{v}_i$ are standard Gaussian variables, and the squared norm of $\tilde{v}$ distributes according to the Chi-square distribution. The bound then follows from (Laurent & Massart, 2000, Lemma 1).

Note that for a matrix $X \in \mathbb{R}^{m \times p}$, Lemma E.10 bounds the Frobenius norm: $\Pr\left(\frac{\epsilon}{2} < \|X\|_F < \frac{3\epsilon}{2}\right) \ge 1 - 2\exp(-9mp/64)$. This is straightforward by applying the lemma to $X$'s vectorized form.

Proposition E.11 (Proposition 9 in main paper). Suppose that: (i) $d > 4$; (ii) the teacher parameters $\hat{\Theta}$ are balanced and non-degenerate, in the sense that the input-output mapping they realize is not identically zero; and (iii) the student parameters are learned by applying GF to the loss $L(\cdot)$. Let $\Theta$ be a random point in parameter space, with entries drawn independently from the standard normal distribution. For $\epsilon > 0$, consider the case where GF emanates from the initialization $\epsilon\Theta$, and denote the resulting curve by $\Theta_\epsilon(\tau) = (A_\epsilon(\tau), B_\epsilon(\tau), C_\epsilon(\tau))$, with $\tau \ge 0$. Then, w.p. at least 0.75, for every $\epsilon > 0$ there exists $\tau_\epsilon \ge 0$ such that:

$$\lim_{\epsilon \to 0}\frac{\|B_\epsilon(\tau_\epsilon) - C_\epsilon^\top(\tau_\epsilon)\|_F}{\|B_\epsilon(\tau_\epsilon) + C_\epsilon^\top(\tau_\epsilon)\|_F} = 0. \tag{E.45}$$

The consequence of Proposition 9 is that as $\epsilon$ converges to zero, $B$ and $C$ converge towards each other. For convenience, we refer to the above initialization scheme (where every coordinate of a vector is initialized as $\mathcal{N}(0, \frac{\epsilon^2}{d})$) as $\epsilon$-normal initialization. In order to prove the proposition we define a few relevant terms:

$$Y = B - C^\top, \quad W = B + C^\top, \quad w_0 = \hat{C}\hat{B}. \tag{E.46}$$

We will in fact prove the following, stronger lemma, for any matrix $A$, not necessarily symmetric.

Lemma E.12. Assume $w_0 > 0$, and that $A, B, C$ are $\epsilon$-normally initialized. Then there exists $\bar{t}$ such that

$$\lim_{\epsilon \to 0}\frac{\|Y(\bar{t})\|^2}{\|W(\bar{t})\|^2} = 0, \quad \lim_{\epsilon \to 0}\frac{\|A(\bar{t})\|_F^2}{\|W(\bar{t})\|^2} = 0. \tag{E.47}$$

The proof of Lemma E.12 follows three steps: (1) establish a time in the optimization up to which the norms of all parameters are bounded (Lemma E.13); (2) derive upper (and lower) bounds for the differential equations describing the evolution of $Y$, $W$ and $A$. Our approximations are limited to the initial phase of training: concretely, we show that for $0 \le t \le \frac{1}{2w_0}\ln\left(\frac{1}{\epsilon^{0.5}}\right)$ all norms are bounded, so it is possible to obtain meaningful bounds on the ODEs of $Y$ and $W$ while $A$ remains at the magnitude of initialization (Lemma E.14); (3) using these bounds, show that as the initialization scale tends to zero, so do the limits in Equation E.47. As it turns out, there is a critical time $\bar{t} = O\left(\ln\left(\frac{1}{\epsilon^{0.5}}\right)\right)$ up until which the considered bounds are valid (see details in the proof of Lemma E.13).

Lemma E.13. Assume $d > 20$, that the student parameters are $\epsilon$-normally initialized, and that the teacher is balanced. Then w.p. at least 0.75, for all $0 \le t \le \bar{t}$, there exist $M_1, M_2$ such that

$$\|C(t)\|, \|B(t)\| < M_1\epsilon^{0.75} \tag{E.48}$$

and

$$\|A(t)\|_F < M_2\epsilon. \tag{E.49}$$

To prove this, we note that at initialization $A$, $B$ and $C$ satisfy these bounds. By continuity, there exists a maximal time for which they are satisfied. We bound the rate of their growth, and thus show that for all $t$ as described, we remain within this region.

Lemma E.14. Assume $w_0 > 0$ (see Lemma E.1 for the definition of $w_i$) and that $A, B, C$ are $\epsilon$-normally initialized. Then the following bounds hold for all $0 \le t \le \bar{t}$ w.p. at least 0.75:

$$Y(t)^\top Y(t) \le c_1\epsilon^2e^{-2w_0t} + c_2\epsilon^{2.5}, \tag{E.50}$$

and

$$W(t)^\top W(t) \ge c_3\epsilon^2e^{2w_0t} - c_4\epsilon^{2.5}. \tag{E.51}$$

Lemma E.14 shows that the growth rate of $W(t)^\top W(t)$ and the decay rate of $Y(t)^\top Y(t)$ both depend on the sign of $w_0$. In our analysis we assume the teacher is balanced and therefore $w_0 > 0$; the same analysis applies for $w_0 < 0$ with the roles of $Y$ and $W$ reversed. The proof of Lemma E.14 follows from writing the leading terms of the ODE and bounding the remaining terms by their upper bounds over the time interval considered. Using these lemmas, we proceed to prove Lemma E.12.

Proof of Lemma E.12. Consider the dynamics at time $\bar{t} = C\ln\left(\frac{1}{\epsilon^{0.5}}\right)$ with $C = \frac{1}{2w_0}$, so that $e^{-2w_0\bar{t}} = \epsilon^{0.5}$. By Lemma E.14, w.p. at least 0.75 we have

$$Y(\bar{t})^\top Y(\bar{t}) \le c_1\epsilon^2e^{-2w_0\bar{t}} + c_2\epsilon^{2.5} = (c_1 + c_2)\epsilon^{2.5} \tag{E.52}$$

and

$$W(\bar{t})^\top W(\bar{t}) \ge c_3\epsilon^2e^{2w_0\bar{t}} - c_4\epsilon^{2.5} = c_3\epsilon^{1.5} - c_4\epsilon^{2.5}. \tag{E.53}$$

We can then calculate the limit:

$$\lim_{\epsilon \to 0}\frac{\|Y(\bar{t})\|^2}{\|W(\bar{t})\|^2} \le \lim_{\epsilon \to 0}\frac{(c_1 + c_2)\epsilon^{2.5}}{c_3\epsilon^{1.5} - c_4\epsilon^{2.5}} = 0. \tag{E.54}$$

From Lemma E.13, $\|A(\bar{t})\|_F \le M_2\epsilon$, so

$$\lim_{\epsilon \to 0}\frac{\|A(\bar{t})\|_F^2}{\|W(\bar{t})\|^2} \le \lim_{\epsilon \to 0}\frac{M_2^2\epsilon^2}{c_3\epsilon^{1.5} - c_4\epsilon^{2.5}} = 0, \tag{E.55}$$

which concludes the proof.

Proof of Lemma E.13. Applying Lemma E.10 with $d > 20$ yields that the bounds hold at initialization with probability at least $1 - 2\exp(-9 \cdot 20^2/64)$ for $A$, and at least $1 - 2\exp(-9 \cdot 20/64)$ for each of $B$ and $C$. The probability of $A, B, C$ satisfying the inequalities simultaneously is at least $(1 - 2\exp(-9 \cdot 20/64))^3 \approx 0.83 > 0.75$. Suppose that the norm bounds of Equation E.38 are satisfied at $t = 0$. In particular, there exist $M_1, M_2$ such that

$$\|B(0)\|, \|C(0)\| < 2\epsilon < M_1\epsilon^{0.75} \tag{E.56}$$

and

$$\|A(0)\|_F < 2\epsilon < M_2\epsilon, \tag{E.57}$$

where $2 < M_2 < \frac{1}{\epsilon}$ and $M_1 > 4\epsilon^{0.25}$. Denote by $t_A$ the minimal time for which $\|A(t_A)\| = M_2\epsilon$, and similarly let $t_B, t_C$ be the minimal times for which $\|B(t_B)\| = \|C(t_C)\| = M_1\epsilon^{0.75}$. Denote also $\bar{t} = \min\{t_A, t_B, t_C\}$. Proving the lemma amounts to showing there exists $C \in \mathbb{R}$ such that $\bar{t} = C\ln\left(\frac{1}{\epsilon^{0.5}}\right)$.

Next, we develop the differential inequalities of the norms, which will later be used to lower bound the time until violation of the mentioned bounds. Recall the derivative of $B$ with respect to time (see Section E.1.4):

$$\dot{B} = -\sum_{i=0}^{k-1}\nabla\ell_i(A^\top)^iC^\top. \tag{E.58}$$

Using the Cauchy-Schwarz inequality, for all $t \in [0, \bar{t}]$ the norm of $\dot{B}$ is upper bounded by

$$\|\dot{B}\| = \left\|-\sum_{i=0}^{k-1}\nabla\ell_i(A^\top)^iC^\top\right\| \le \sum_{i=0}^{k-1}|\nabla\ell_i|\,\|A^\top\|_F^i\,\|C^\top\|. \tag{E.59}$$

We now bound the norms of $\nabla\ell_i$, $A$ and $C$ in order to turn this into a differential inequality. Denote $M = \max_i(|w_i|) + M_1^2\epsilon^{1.5}$; then $M > |\nabla\ell_i|$, and

$$\|\dot{B}\| \le \sum_{i=0}^{k-1}|\nabla\ell_i|\,\|A^\top\|_F^i\,\|C^\top\| < M(\|B\| + 4\epsilon)\sum_{i=0}^{k-1}\|A^\top\|_F^i < Mk(\|B\| + 4\epsilon). \tag{E.66, E.67}$$

Denoting $\gamma = \|B\|^2 = B^\top B$, we have $\dot{\gamma} = \dot{B}^\top B + B^\top\dot{B} = 2B^\top\dot{B}$.

Next we show that $\bar{t} = O\left(\ln\left(\frac{1}{\epsilon^{0.5}}\right)\right)$. Suppose this is not the case; then there exists $\tilde{t} < \ln\left(\frac{1}{\epsilon^{0.5}}\right)$ at which one of the bounds is violated: (i) $\|B(\tilde{t})\| \ge M_1\epsilon^{0.75} > 4\epsilon$; (ii) $\|C(\tilde{t})\| \ge M_1\epsilon^{0.75} > 4\epsilon$; or (iii) $\|A(\tilde{t})\| \ge M_2\epsilon > 2\epsilon$.

Consider case (i).foot_6 By continuity there exists $t_1 \in \mathbb{R}$ such that $\|B(t_1)\| = 4\epsilon$, and $t_2 \in \mathbb{R}$ such that for any $t \in [t_1, t_2]$, $4\epsilon \le \|B(t)\| \le M_1\epsilon^{0.75}$. On this interval, since $4\epsilon \le \|B(t)\|$, we also have $\dot{\gamma} = 2B^\top\dot{B} \le 2\|B\|Mk(\|B\| + 4\epsilon) \le 4kM\gamma$, which may be further manipulated (by integration) to reach, for $s \in [t_1, t_2]$,

$$|\gamma(s)| < e^{4kM\ln\left(\frac{1}{\epsilon^{0.5}}\right)}|\gamma(t_1)| = (4\epsilon)^2e^{4kM\ln\left(\frac{1}{\epsilon^{0.5}}\right)}, \tag{E.75}$$

where we have used $\|B(t_1)\| = 4\epsilon$. The final bound on $\gamma(s)$ is therefore

$$|\gamma(s)| < 16\epsilon^2\epsilon^{-0.5}e^{4kM} = 16\epsilon^{1.5}e^{4kM}. \tag{E.76}$$

Denoting $M_1^2 = 16e^{4kM}$ and taking the square root of the above,

$$\|B(s)\| < M_1\epsilon^{0.75}. \tag{E.77}$$

We have shown that for all $0 \le t \le \bar{t} \le \ln\left(\frac{1}{\epsilon^{0.5}}\right)$, there exists $M_1$ such that $\|B(t)\| < M_1\epsilon^{0.75}$ (the same proof applies for case (ii)).
Consider case (iii): we need to show that the bound on $\|A\|_F$ applies for $t \in [0, \bar{t}]$. Notice that for a matrix, $\|A\|_F^2 = \mathrm{Tr}(A^\top A)$, and

$$\frac{d}{dt}\mathrm{Tr}(A^\top A) = \mathrm{Tr}\left(\frac{d}{dt}(A^\top A)\right) = \mathrm{Tr}(\dot{A}^\top A) + \mathrm{Tr}(A^\top\dot{A}) = 2\mathrm{Tr}(A^\top\dot{A}), \tag{E.78-E.80}$$

where we have used the linearity of the trace and its invariance to transposition. The derivative of $A$ with respect to time (see Section E.1.4) is

$$\dot{A} = -\sum_{i=1}^{k-1}\nabla\ell_i\sum_{r=0}^{i-1}(A^\top)^rC^\top B^\top(A^\top)^{i-r-1}. \tag{E.81}$$

Multiplying from the left by $A^\top$ and taking the trace provides us with

$$\mathrm{Tr}(A^\top\dot{A}) = -\sum_{i=1}^{k-1}\nabla\ell_i\sum_{r=0}^{i-1}\mathrm{Tr}\left((A^\top)^{r+1}C^\top B^\top(A^\top)^{i-r-1}\right). \tag{E.82}$$

Taking a transpose and then using the cyclic property of the trace, for each summand,

$$\mathrm{Tr}\left((A^\top)^{r+1}C^\top B^\top(A^\top)^{i-r-1}\right) = \mathrm{Tr}\left(B^\top(A^\top)^iC^\top\right) = \mathrm{Tr}\left(CA^iB\right) = CA^iB. \tag{E.83-E.85}$$

Equation E.82 thus evaluates to

$$\mathrm{Tr}(A^\top\dot{A}) = -\sum_{i=1}^{k-1}\nabla\ell_i\sum_{r=0}^{i-1}CA^iB = -\sum_{i=1}^{k-1}\nabla\ell_i \cdot i \cdot CA^iB. \tag{E.86}$$

Bounding $\mathrm{Tr}(A^\top\dot{A})$,

$$\mathrm{Tr}(A^\top\dot{A}) \le \sum_{i=1}^{k-1}|\nabla\ell_i| \cdot i \cdot |CA^iB|. \tag{E.87}$$

Using the Cauchy-Schwarz inequality and then plugging in $M > |\nabla\ell_i|$ and the bounds found for $\|C\|, \|B\| < M_1\epsilon^{0.75}$ leads to

$$|\nabla\ell_i| \cdot i \cdot |CA^iB| \le M \cdot i \cdot \|C\|\|A\|_F^i\|B\| \le M \cdot i \cdot M_1^2\epsilon^{1.5}\|A\|_F^i. \tag{E.88}$$

Putting the bound of Equation E.88 into Equation E.87 results in

$$\mathrm{Tr}(A^\top\dot{A}) \le \sum_{i=1}^{k-1}M \cdot i \cdot M_1^2\epsilon^{1.5}\|A\|_F^i = M M_1^2\epsilon^{1.5}\sum_{i=1}^{k-1}i\|A\|_F^i. \tag{E.89}$$

We make use of the fact that $\ln(x) < x$ for all $x > 0$ to bound $\tilde{M}_2\epsilon^{1.5}\ln\left(\frac{1}{\epsilon^{0.5}}\right) \le \tilde{M}_2\epsilon^{1.5}\epsilon^{-0.5} = \tilde{M}_2\epsilon$; integrating the bound above over $[0, \bar{t}]$, and recalling from our assumption on initialization that $\|A(0)\|_F < 2\epsilon$, the bound $\|A(t)\|_F < M_2\epsilon$ is maintained throughout.

We now turn to the dynamics of $Y$. From Equation E.58 and its analogue for $C$, $\dot{Y} = \dot{B} - \dot{C}^\top = -\sum_{i=0}^{k-1}\nabla\ell_i\left((A^\top)^iC^\top - A^iB\right)$. Denote by $(A^i)_S = \frac{A^i + (A^\top)^i}{2}$ and $(A^i)_{AS} = \frac{A^i - (A^\top)^i}{2}$ the symmetric and anti-symmetric parts of $A^i$. We can write

$$(A^\top)^iC^\top - A^iB = \left[(A^i)_S - (A^i)_{AS}\right]C^\top - \left[(A^i)_S + (A^i)_{AS}\right]B = (A^i)_S(C^\top - B) - (A^i)_{AS}(C^\top + B) = -(A^i)_SY - (A^i)_{AS}W. \tag{E.103}$$

Note that for $i = 0$ we have $A^0 = I$, whose anti-symmetric part is the zero matrix; writing the $i = 0$ term separately and assigning Equation E.103 into the expression for $\dot{Y}$,

$$\dot{Y} = \nabla\ell_0Y + \sum_{i=1}^{k-1}\nabla\ell_i\left((A^i)_SY + (A^i)_{AS}W\right). \tag{E.104}$$

Let us look at $\frac{d}{dt}(Y^\top Y) = \dot{Y}^\top Y + Y^\top\dot{Y} = 2Y^\top\dot{Y}$. Multiplying Equation E.104 from the left by $Y^\top$ evaluates to

$$Y^\top\dot{Y} = \nabla\ell_0Y^\top Y + \sum_{i=1}^{k-1}\nabla\ell_i\left(Y^\top(A^i)_SY + Y^\top(A^i)_{AS}W\right). \tag{E.105}$$

We now turn to bound the terms in the sum:

$$Y^\top\dot{Y} \le \nabla\ell_0Y^\top Y + \sum_{i=1}^{k-1}|\nabla\ell_i|\left(|Y^\top(A^i)_SY| + |Y^\top(A^i)_{AS}W|\right). \tag{E.106}$$

We can bound each term using Cauchy-Schwarz. We first need to bound $\|Y\|$ and $\|W\|$, which are trivially bounded by

$$\|Y\|, \|W\| \le \|C\| + \|B\| \le 2M_1\epsilon^{0.75}. \tag{E.107}$$

As for the symmetric and anti-symmetric parts of $A^i$,

$$\|(A^i)_S\|_F = \left\|\frac{A^i + (A^i)^\top}{2}\right\|_F \le \frac{1}{2}\left(\|A^i\|_F + \|(A^i)^\top\|_F\right) = \|A^i\|_F \le \|A\|_F^i, \tag{E.108}$$

where the last inequality follows again from Cauchy-Schwarz (the same considerations apply for $(A^i)_{AS}$). From Cauchy-Schwarz we can bound $|Y^\top(A^i)_{AS}W| \le \|Y\|\,\|(A^i)_{AS}\|_F\,\|W\| \le \|Y\|\,\|A\|_F^i\,\|W\|$; denote $M_3 = \max\{M_1, M_2\}$ and derive

$$|Y^\top(A^i)_{AS}W| \le 2M_1\epsilon^{0.75}(M_2\epsilon)^i\,2M_1\epsilon^{0.75} < 4M_3^3\epsilon^{1.5+i}, \tag{E.109}$$

which is maximized when $i = 1$.

Published as a conference paper at ICLR 2023

We again bound $M = \max_i(|w_i|) + M_1^2\epsilon^{1.5} > |\nabla\ell_i|$, so the terms in Equation E.106 are bounded by

$$|\nabla\ell_i|\left(|Y^\top(A^i)_SY| + |Y^\top(A^i)_{AS}W|\right) \le 8MM_3^3\epsilon^{2.5}. \tag{E.110}$$

Plugging back into Equation E.106:

$$Y^\top\dot{Y} \le \nabla\ell_0Y^\top Y + \sum_{i=1}^{k-1}8MM_3^3\epsilon^{2.5}. \tag{E.111}$$

We can also bound $\nabla\ell_0 = CB - w_0 \le -w_0 + |CB| \le -w_0 + \|C\|\|B\| \le -w_0 + M_1^2\epsilon^{1.5}$. Since this factor multiplies $Y^\top Y$, we can bound $\nabla\ell_0Y^\top Y \le -w_0Y^\top Y + M_1^2\epsilon^{1.5}(2M_1\epsilon^{0.75})^2$, and collecting all $\epsilon^{2.5}$-order terms into a constant $M_4$, with $\frac{d}{dt}(Y^\top Y) = \dot{Y}^\top Y + Y^\top\dot{Y} = 2Y^\top\dot{Y}$,

$$\frac{d}{dt}(Y^\top Y) \le -2w_0Y^\top Y + M_4\epsilon^{2.5}. \tag{E.115}$$

Denoting $z(t) = Y(t)^\top Y(t)$ and $x(t) = W(t)^\top W(t)$, and using Lemma E.15, we have the desired bounds:

$$z(t) \le \frac{1}{2w_0}\left[(33w_0d\epsilon^2)e^{-2w_0t} + M_4\epsilon^{2.5}\right]. \tag{E.116}$$

In particular, we can write

$$z(t) \le c_1\epsilon^2e^{-2w_0t} + c_2\epsilon^{2.5} \tag{E.117}$$

and

$$x(t) \ge c_3\epsilon^2e^{2w_0t} - c_4\epsilon^{2.5}, \tag{E.118}$$

where the $c_i$ are positive constants. Note that the derivation for $x(t)$ is exactly the same as for $z(t)$, with opposite signs and bounding from below instead.

E.4.2 INTEGRAL BOUND OF DIFFERENTIAL EQUATIONS

Lemma E.15. Assume $\dot{z} < -2 w_0 z + M_4 \epsilon^{2.5} < 0$ and $\dot{x} > 2 w_0 x - M_4 \epsilon^{2.5} > 0$, where $w_0 > 0$. Then, under the assumptions of Lemma E.10:

$$z(t_1) < \frac{1}{2 w_0} \big( \exp(-2 w_0 t_1) \cdot (75 w_0 d \epsilon^2) + M_4 \epsilon^{2.5} \big) \quad (E.119)$$
$$x(t_2) > \frac{1}{2 w_0} \Big( \exp(2 w_0 t_2) \cdot \frac{w_0 \epsilon^2}{25 d} + M_4 \epsilon^{2.5} \Big) \quad (E.120)$$

Proof of Lemma E.15. Assume $\dot{z} < -2 w_0 z + M_4 \epsilon^{2.5}$, where $w_0 > 0$ and $2 w_0 z > M_4 \epsilon^{2.5}$. Similarly, assume $\dot{x} > 2 w_0 x - M_4 \epsilon^{2.5}$. Then:

$$\frac{\dot{z}}{2 w_0 z - M_4 \epsilon^{2.5}} < -1 \quad (E.121)$$
$$\frac{\dot{x}}{2 w_0 x - M_4 \epsilon^{2.5}} > 1 \quad (E.122)$$

Integrating both sides with respect to $t$, and using integration by substitution, we note:

$$\int_{z(0)}^{z(t_1)} \frac{dz}{2 w_0 z - M_4 \epsilon^{2.5}} = \frac{1}{2 w_0} \big[ \ln\big(2 w_0 z(t_1) - M_4 \epsilon^{2.5}\big) - \ln\big(2 w_0 z(0) - M_4 \epsilon^{2.5}\big) \big] \quad (E.125)$$

$$\Rightarrow \int_{z(0)}^{z(t_1)} \frac{dz}{2 w_0 z - M_4 \epsilon^{2.5}} = \frac{1}{2 w_0} \ln\left( \frac{2 w_0 z(t_1) - M_4 \epsilon^{2.5}}{2 w_0 z(0) - M_4 \epsilon^{2.5}} \right) \quad (E.126)$$

$$\int_{x(0)}^{x(t_2)} \frac{dx}{2 w_0 x - M_4 \epsilon^{2.5}} = \frac{1}{2 w_0} \ln\left( \frac{2 w_0 x(t_2) - M_4 \epsilon^{2.5}}{2 w_0 x(0) - M_4 \epsilon^{2.5}} \right) \quad (E.127)$$

Combining the equations, we have:

$$\frac{1}{2 w_0} \ln\left( \frac{2 w_0 z(t_1) - M_4 \epsilon^{2.5}}{2 w_0 z(0) - M_4 \epsilon^{2.5}} \right) < -t_1 \quad (E.128)$$

$$\Rightarrow \ln\left( \frac{2 w_0 z(t_1) - M_4 \epsilon^{2.5}}{2 w_0 z(0) - M_4 \epsilon^{2.5}} \right) < -2 w_0 t_1 \quad (E.129)$$

$$\Rightarrow \frac{2 w_0 z(t_1) - M_4 \epsilon^{2.5}}{2 w_0 z(0) - M_4 \epsilon^{2.5}} < \exp(-2 w_0 t_1) \quad (E.130)$$

$$\Rightarrow 2 w_0 z(t_1) - M_4 \epsilon^{2.5} < \exp(-2 w_0 t_1) \cdot \big( 2 w_0 z(0) - M_4 \epsilon^{2.5} \big) \quad (E.131)$$

$$\Rightarrow z(t_1) < \frac{1}{2 w_0} \big[ \exp(-2 w_0 t_1) \cdot \big( 2 w_0 z(0) - M_4 \epsilon^{2.5} \big) + M_4 \epsilon^{2.5} \big] \quad (E.132)$$

$$x(t_2) > \frac{1}{2 w_0} \big[ \exp(2 w_0 t_2) \cdot \big( 2 w_0 x(0) - M_4 \epsilon^{2.5} \big) + M_4 \epsilon^{2.5} \big] \quad (E.133)$$

Note that $z(0) = Y^\top(0) Y(0)$. From Cauchy-Schwarz, $z(0) < 3 \epsilon^2$ with high probability. $W$ is distributed as $Y$; therefore $x(0) > \frac{1}{2} \epsilon^2$. Assuming $M_4 \epsilon^{2.5} < \frac{w_0 \epsilon^2}{2}$, we have:

$$z(t_1) < \frac{1}{2 w_0} \big( \exp(-2 w_0 t_1) \cdot (6 w_0 \epsilon^2) + M_4 \epsilon^{2.5} \big) \quad (E.135)$$
$$x(t_2) > \frac{1}{2 w_0} \big( \exp(2 w_0 t_2) \cdot (w_0 \epsilon^2) + M_4 \epsilon^{2.5} \big) \quad (E.136)$$

concluding the proof.
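The integral bound of Lemma E.15 can be sanity-checked numerically by integrating the worst case $\dot{z} = -2 w_0 z + M_4 \epsilon^{2.5}$ allowed by the differential inequality and comparing against the closed-form bound of Equation E.132. The constants $w_0$, $M_4$, $\epsilon$ below are arbitrary illustrative choices, not values from the paper.

```python
import numpy as np

# Illustrative constants.
w0, M4, eps = 1.0, 0.5, 1e-2
c = M4 * eps ** 2.5
z0 = 3 * eps ** 2                 # z(0) < 3*eps^2 with high probability

def bound(t):
    # Right-hand side of Equation E.132.
    return (np.exp(-2 * w0 * t) * (2 * w0 * z0 - c) + c) / (2 * w0)

# Forward-Euler integration of z' = -2*w0*z + M4*eps^2.5 (the worst case).
dt, z, t = 1e-3, z0, 0.0
for _ in range(20000):
    z += dt * (-2 * w0 * z + c)
    t += dt

assert z <= bound(t) + 1e-15          # trajectory stays below the bound
assert abs(z - c / (2 * w0)) < 1e-9   # and settles near the M4*eps^2.5/(2*w0) floor
```

The forward-Euler iterate contracts at least as fast as the exact solution of the linear ODE, so it remains below the closed-form bound at every step.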



Note that with GRU networks, in contrast to linear RNNs, the impulse response does not identify the input-output mapping realized by a network; it is presented in Figure 2(b) for demonstrative purposes.
Here $\odot$ denotes the Hadamard (elementwise) product.
The last equality follows since in the SISO setup, $B^\top B$ and $C C^\top$ are scalars, and therefore the trace operator can be omitted.
The case in which $\hat{C} \hat{B} = 0$ is handled separately.
Equality of the $0$th moment ensures the student induces a valid probability distribution, i.e. $\sum_i C_i B_i = \sum_i \hat{C}_i \hat{B}_i = 1$.
Here we overload notation and denote the distributions of the teacher and student by $\hat{\Theta}$ and $\Theta$, respectively.
The case of $\|C(t)\| \ge M_1 \epsilon^{0.75}$ is handled similarly.



Figure 1(b) reports the results of an experiment where the state space dimensions of the teacher and the (overparameterized) student are $\hat{d} = 10$ and $d = 50$ respectively (higher state space dimensions for the student, namely $d = 100$ and $d = 200$, yield qualitatively identical results), and where the teacher implements a delay line of $\hat{d}$ time steps (for details see Appendix C.2.2). Similar results obtained with randomly generated teachers are reported in Appendix B. As can be seen, despite the fact that our theory does not apply to the evaluated settings, its conclusions still hold: extrapolation error is low when the training sequence length $k$ is greater than $2\hat{d}$, and high when $k$ falls below $2\hat{d}$.
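For concreteness, a delay-line teacher of the kind used in this experiment can be written down directly. The sketch below (with an illustrative dimension) builds the standard shift-matrix realization and verifies that its impulse response is a single spike; note that this teacher is unbalanced, so the theory's assumptions indeed do not hold for it.

```python
import numpy as np

d_hat = 10  # teacher state space dimension (illustrative)

# Delay line: the state shifts one coordinate per step.
A = np.eye(d_hat, k=-1)                 # ones on the subdiagonal
B = np.zeros((d_hat, 1)); B[0, 0] = 1.0  # input enters the first coordinate
C = np.zeros((1, d_hat)); C[0, -1] = 1.0  # output reads the last coordinate

# Impulse response h_n = C A^n B: a single spike (here at n = d_hat - 1,
# the exact offset depending on the indexing convention).
h = np.array([float(C @ np.linalg.matrix_power(A, n) @ B) for n in range(2 * d_hat)])
assert h[d_hat - 1] == 1.0 and h.sum() == 1.0
```

Since $C^\top \ne B$, this realization is maximally unbalanced, and its impulse response has full-length memory rather than an exponentially decaying one.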

Figure 1: Demonstration of implicit extrapolation with linear RNNs. Plots show extrapolation error (averaged over three random seeds, with shaded region marking standard deviation) as a function of training sequence length $k$, for a student with state space dimension $d$ learning from a teacher with state space dimension $\hat{d}$, where $d > \hat{d}$. (a) Models adhere to the setting described in Section 3 and theoretically analyzed in Section 4, with $\hat{d} = 5$, $d = 40$. (b) Models do not adhere to some of the assumptions made by the theory, and $\hat{d} = 10$, $d = 50$. Notice that extrapolation exhibits a phase transition that accords with the theory: when $k > 2\hat{d}$ extrapolation error is low, and when $k$ falls below $2\hat{d}$ extrapolation error is high. The gradual transition exhibited is due to numerical errors introduced by the optimization not reaching an exact global minimum. For further details see Sections 5.1 and 5.2 and Appendix B.

Figure 2: Demonstration of implicit extrapolation with non-linear RNNs, namely GRU networks. Plots show results for a student with state space dimension $d_g = 100$ learning from a teacher with state space dimension $\hat{d}_g = 10$ using training sequences of length $k_g$, where $k_g$ varies. (a) Extrapolation error (averaged over ten random seeds, with shaded region marking standard deviation) as a function of $k_g$. (b) Average output over several inputs of teacher and student for different choices of $k_g$. Notice that, similarly to the case of linear RNNs, there exists a critical threshold for $k_g$ above which extrapolation error is low and below which extrapolation error is high. See details in Appendix B.

Figure 3: Extrapolation error as a function of the training sequence length $k$. (a) A balanced teacher with state space dimension $\hat{d} = 5$ and a general (unbalanced and non-diagonal) student with $d = 40$. (b) A random unbalanced teacher (see Section B.2) with dimension $\hat{d} = 5$, and a student with a non-diagonal transition matrix trained with standard (small) initialization, with state space dimension $d = 50$. In both plots results are averaged over three seeds.

Figure 4: Extrapolation error as a function of training sequence length k for different initialization scales. Extrapolation error increases along with the scale of initialization.

Figure 5: Impulse responses of a balanced teacher and of students trained with $k = 10$ and $k = 20$ with respect to the balanced teacher described in Section C.2.1. As can be seen, both students track the teacher up to the $k$ used in training; the student trained with $k = 10$ does not extrapolate beyond that point, whereas the student trained with $k = 20$ tracks the teacher well beyond the sequence length used in training.

Figure 6: Impulse responses of an unbalanced (delay) teacher and of students trained with $k = 8, 18, 20$ with respect to the unbalanced delay teacher described in Section C.2.2. The student trained with $k = 18$ diverges from the teacher on longer sequences, while the student trained with $k = 20$, i.e. with merely two additional time steps, extrapolates and tracks the teacher almost perfectly.

E.4.1 BOUNDING THE DIFFERENTIAL EQUATIONS

Proof of Lemma E.14. Denote $Y = C^\top - B$ and $W = C^\top + B$.

In particular, there exists $M_4$ such that

$$Y^\top \dot{Y} \le -w_0 Y^\top Y + M_4 \epsilon^{2.5}. \quad (E.114)$$

Recall that we were interested in bounding $\frac{d}{dt}(Y^\top Y)$.

By linearity of the sum of variances, $Y(0)$'s entries are distributed

7. ACKNOWLEDGEMENTS

This work was supported by the European Research Council (ERC) under the European Union's Horizon 2020 research and innovation programme (grant ERC HOLI 819080), the Tel-Aviv University Data-Science and AI Center (TAD), a Google Research Scholar Award, a Google Research Gift, the Yandex Initiative in Machine Learning, the Israel Science Foundation (grant 1780/21), Len Blavatnik and the Blavatnik Family Foundation, and Amnon and Anat Shashua.


Plugging in Equation E.11 and Equation E.12, we have

$$\frac{\partial L}{\partial C} = \sum_{i=0}^{k-1} \nabla\ell_i \, B^\top (A^i)^\top. \quad (E.16)$$

E.1.5 LEMMA 6 (CONSERVATION OF BALANCEDNESS)

Lemma E.5. [Lemma 6 in main paper] When optimizing Equation 3.4 with GF emanating from a balanced initialization $\Theta(0)$, the parameters $\Theta(\tau)$ are balanced for all $\tau \in \mathbb{R}_+$.

We prove the above result by first showing it for GD and then translating the result to GF. The GD result is stated below, and generalizes a result that was shown in Cohen-Karlik et al. (2022) for the memoryless case.

Lemma E.6. When optimizing Equation 3.4 with GD from balanced initial conditions, for all $t \in \mathbb{N}$, $\Theta$ has a balanced weight configuration, i.e. $B_t = C_t^\top$.

Proof of Lemma E.6. We prove by induction. By our assumption, the condition holds for $t = 0$. Assume $B_t = C_t^\top$; our goal is to show the condition holds for $(B_{t+1}, C_{t+1})$. In order to show that $B_{t+1} = C_{t+1}^\top$, we only need to show that $\frac{\partial L}{\partial B_t} = \big(\frac{\partial L}{\partial C_t}\big)^\top$. Writing the gradients (Lemma E.4), we have

$$\frac{\partial L}{\partial B_t} = \sum_{i=0}^{k-1} \nabla\ell_i \, (A_t^\top)^i C_t^\top = \sum_{i=0}^{k-1} \nabla\ell_i \, A_t^i B_t = \left( \frac{\partial L}{\partial C_t} \right)^\top,$$

where the second equality follows from the induction assumption and the symmetric structure of $A_t$. To conclude, the gradients at time $t$ are transposes of one another, and by the induction assumption $B_{t+1} = C_{t+1}^\top$, arriving at the desired result.

The proof of Lemma 6 follows from Lemma E.6 and the fact that for sufficiently small step size GD approximates GF with arbitrary precision (see Theorem 3 in Elkabetz & Cohen, 2021).
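Lemma E.6 is easy to check numerically: running GD from a balanced initialization (symmetric $A_0$ and $B_0 = C_0^\top$) keeps $B_t = C_t^\top$ and $A_t$ symmetric at every step, up to floating-point error. Dimensions, step size, and the teacher impulse response below are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, lr = 5, 6, 1e-2
w = np.array([0.5, 0.2, 0.1, 0.05, 0.02, 0.01])  # arbitrary teacher response

# Balanced initialization: A symmetric and B = C^T.
A = rng.standard_normal((d, d)) * 0.1
A = (A + A.T) / 2
B = rng.standard_normal((d, 1)) * 0.1
C = B.T.copy()

for _ in range(500):
    Ap = [np.linalg.matrix_power(A, i) for i in range(k)]
    g = [float(C @ Ap[i] @ B) - w[i] for i in range(k)]
    gB = sum(g[i] * Ap[i].T @ C.T for i in range(k))
    gC = sum(g[i] * B.T @ Ap[i].T for i in range(k))
    gA = sum(g[i] * sum(Ap[r].T @ C.T @ B.T @ Ap[i - r - 1].T for r in range(i))
             for i in range(1, k))
    A, B, C = A - lr * gA, B - lr * gB, C - lr * gC

# Balancedness is conserved along the whole GD trajectory.
assert np.allclose(B, C.T, atol=1e-8)
assert np.allclose(A, A.T, atol=1e-8)
```

The symmetry of $A_t$ is preserved for the same reason as balancedness: the gradient with respect to $A$ is a sum of terms that pair up into transposes of one another when $B_t = C_t^\top$.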

E.1.6 CONSERVATION OF DIFFERENCE OF NORMS

Appendix E.1.5 shows that if the weights are initialized to be balanced, this property is conserved throughout optimization. Here we show that under standard initialization schemes, the difference between the norms of $B$ and $C$ is also conserved.

Lemma E.7. When optimizing Equation 3.4 with GF, the difference between the norms of $B$ and $C$ is conserved, i.e., $\frac{d}{dt}\big(\|B\|^2 - \|C\|^2\big) = 0$.

Proof of Lemma E.7. We wish to prove that the difference between the norms is conserved over time. Consider the quantity $\alpha(t) = \|B(t)\|^2 - \|C(t)\|^2$. With this notation, we just need to prove that $\dot{\alpha} = 0$. The derivatives of $B$ and $C$ with respect to time are given by

$$\dot{B} = -\sum_{i=0}^{k-1} \nabla\ell_i \, (A^\top)^i C^\top, \qquad \dot{C} = -\sum_{i=0}^{k-1} \nabla\ell_i \, B^\top (A^\top)^i. \quad (E.21, E.22)$$

Using the interchangeability of derivative and transpose, we have $\frac{d}{dt}\|B\|^2 = 2 B^\top \dot{B}$ and $\frac{d}{dt}\|C\|^2 = 2 C \dot{C}^\top$. Plugging in Equation E.21 and Equation E.22, we get

$$\dot{\alpha} = -2 \sum_{i=0}^{k-1} \nabla\ell_i \, B^\top (A^\top)^i C^\top + 2 \sum_{i=0}^{k-1} \nabla\ell_i \, C A^i B = 0,$$

where the last equality holds since each summand is a scalar and $B^\top (A^\top)^i C^\top = (C A^i B)^\top = C A^i B$.
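The conservation law of Lemma E.7 can likewise be checked numerically, using small-step GD as a proxy for GF and starting from a generic (unbalanced) initialization; all numerical values below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
d, k, lr = 5, 6, 1e-4   # small step size, so GD tracks GF closely
w = np.array([0.5, 0.2, 0.1, 0.05, 0.02, 0.01])  # illustrative teacher

A = rng.standard_normal((d, d)) * 0.1   # generic unbalanced initialization
B = rng.standard_normal((d, 1)) * 0.1
C = rng.standard_normal((1, d)) * 0.1

alpha0 = float(B.T @ B - C @ C.T)       # ||B||^2 - ||C||^2 at t = 0

for _ in range(5000):
    Ap = [np.linalg.matrix_power(A, i) for i in range(k)]
    g = [float(C @ Ap[i] @ B) - w[i] for i in range(k)]
    gB = sum(g[i] * Ap[i].T @ C.T for i in range(k))
    gC = sum(g[i] * B.T @ Ap[i].T for i in range(k))
    gA = sum(g[i] * sum(Ap[r].T @ C.T @ B.T @ Ap[i - r - 1].T for r in range(i))
             for i in range(1, k))
    A, B, C = A - lr * gA, B - lr * gB, C - lr * gC

alpha1 = float(B.T @ B - C @ C.T)
# The norm difference drifts only by the O(lr) discretization error.
assert abs(alpha1 - alpha0) < 1e-4
```

Under exact GF the drift would be identically zero; the residual tolerance accounts for the discretization introduced by finite-step GD.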

E.2 LEMMA 7 (EXACT EXTRAPOLATION)

Lemma E.8. [Lemma 7 in main paper] Suppose that $d > k > 2\hat{d}$, the teacher is balanced, and that the student parameters $\Theta$ are balanced and satisfy $L(\Theta) = 0$. Then $\Theta$ extrapolates.

Proof of Lemma E.8. By Lemma E.3, a balanced RNN with a symmetric transition matrix has an equivalent (generating the same impulse response) balanced RNN with a diagonal transition matrix. We will continue under the assumption of diagonal transition matrices.

Without loss of generality, we assume $\hat{C}\hat{B} = 1$. Otherwise, the problem can be rescaled by $\hat{C}\hat{B}$, which is equivalent to rescaling the initial conditions and provides no additional information.

From the balancedness assumption, we have $\hat{C}^\top = \hat{B}$. Denote $\hat{p} = \hat{C}^\top \odot \hat{B} = \hat{B} \odot \hat{B}$; then $\hat{p}_i \ge 0$ and $\sum_i \hat{p}_i = 1$, and therefore $\hat{p}$ may be interpreted as a distribution of a random variable with $\hat{d}$ possible values. We shall assume that these values are $\hat{A}_{1,1}, \ldots, \hat{A}_{\hat{d},\hat{d}}$, and denote the corresponding random variable by $\hat{Z}$.

Furthermore, we can also interpret the elements of the impulse response of $\hat{\Theta}$ as moments of this distribution. Let us write the $n$th element of the impulse response as

$$\hat{C} \hat{A}^n \hat{B} = \sum_{i=1}^{\hat{d}} \hat{p}_i \hat{A}_{i,i}^n = \mathbb{E}_{\hat{p}}[\hat{Z}^n],$$

where $\mathbb{E}_p[Z]$ is the expected value of a random variable $Z$ under the distribution $p$. In the same way, we can define for the learned model $\Theta$ a distribution $p_i = C_i B_i$, and write the learned impulse response as

$$C A^n B = \sum_{i=1}^{d} p_i A_{i,i}^n = \mathbb{E}_p[Z^n].$$

This view provides us with a moment-matching interpretation of the learning problem. Namely, the fact that $\Theta$ matches the first $k$ elements of the teacher's impulse response is the same as saying the two distributions agree on the first $k - 1$ moments $\mathbb{E}[Z^j]$ for $j \in \{1, \ldots, k - 1\}$. The question of extrapolation is whether equality of the first $k - 1$ moments implies equality of all remaining moments.
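The moment-matching view can be made concrete in a few lines: for a balanced diagonal teacher, the $n$th impulse response element equals the $n$th moment of the distribution $\hat{p}$. The diagonal values and probabilities below are illustrative choices.

```python
import numpy as np

# Balanced teacher with diagonal transition matrix (illustrative values).
a = np.array([0.9, -0.5, 0.2])      # diagonal of A-hat: the support of Z-hat
p = np.array([0.5, 0.3, 0.2])       # p-hat_i = B-hat_i^2, sums to 1
B = np.sqrt(p).reshape(-1, 1)       # balancedness: C-hat = B-hat^T
C = B.T
A = np.diag(a)

# The n-th impulse response element C A^n B equals the moment E_p[Z^n];
# n = 0 recovers C B = sum(p) = 1 (the zeroth moment).
for n in range(6):
    h_n = float(C @ np.linalg.matrix_power(A, n) @ B)
    assert np.isclose(h_n, np.sum(p * a ** n))
```

Matching the teacher's impulse response on a training prefix is thus exactly a truncated moment problem for the distribution $\hat{p}$ supported on the diagonal of $\hat{A}$.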

