ON THE CONVERGENCE OF GRADIENT FLOW ON MULTI-LAYER LINEAR MODELS
Anonymous authors
Paper under double-blind review

Abstract

In this paper, we analyze the convergence of gradient flow on a multi-layer linear model with a loss function of the form f(W_1 W_2 ⋯ W_L). We show that when f satisfies the gradient dominance property, proper weight initialization leads to exponential convergence of the gradient flow to a global minimum of the loss. Moreover, the convergence rate depends on two trajectory-specific quantities that are controlled by the weight initialization: the imbalance matrices, which measure the difference between the weights of adjacent layers, and the least singular value of the weight product W = W_1 W_2 ⋯ W_L. Our analysis provides improved rate bounds for several multi-layer network models studied in the literature, leading to novel characterizations of the effect of weight imbalance on the rate of convergence. Our results apply to most regression losses and extend to classification ones.

1. INTRODUCTION

The mysterious ability of gradient-based optimization algorithms to solve the non-convex neural network training problem is one of the many unexplained puzzles behind the success of deep learning in various applications (Krizhevsky et al., 2012; Hinton et al., 2012; Silver et al., 2016). A vast body of work has tried to theoretically understand this phenomenon by analyzing either the loss landscape or the dynamics of the training parameters.

The landscape-based analysis is motivated by the empirical observation that deep neural networks used in practice often have a benign landscape (Li et al., 2018a), which can facilitate convergence. Existing theoretical analysis (Lee et al., 2016; Sun et al., 2015; Jin et al., 2017) shows that gradient descent converges when the loss function satisfies the following properties: 1) all of its local minima are global minima; and 2) every saddle point has a Hessian with at least one strictly negative eigenvalue. Prior work suggests that the matrix factorization model (Ge et al., 2017), shallow networks (Kawaguchi, 2016), and certain positively homogeneous networks (Haeffele & Vidal, 2015; 2017) have such a landscape property, but unfortunately condition 2) does not hold for networks with multiple hidden layers (Kawaguchi, 2016). Moreover, the landscape-based analysis generally fails to provide a good characterization of the convergence rate, beyond a local rate around the equilibrium (Lee et al., 2016; Ge et al., 2017). In fact, during the early stages of training, gradient descent can take exponential time to escape some saddle points if not initialized properly (Du et al., 2017).

The trajectory-based analyses study the training dynamics of the weights given a specific initialization. For example, the case of small initialization has been studied for various models (Arora et al., 2019a; Gidel et al., 2019; Li et al., 2018b; Stöger & Soltanolkotabi, 2021; Li et al., 2021b; a).
Under this type of initialization, the trained model is implicitly biased towards low-rank (Arora et al., 2019a; Gidel et al., 2019; Li et al., 2018b; Stöger & Soltanolkotabi, 2021; Li et al., 2021b) and sparse (Li et al., 2021a) models. While the analysis for small initialization gives rich insights into the generalization of neural networks, the number of iterations required for gradient descent to find a good model often increases as the initialization scale decreases. This dependence proves to be logarithmic in the scale for the symmetric matrix factorization model (Li et al., 2018b; Stöger & Soltanolkotabi, 2021; Li et al., 2021b), but for deep networks, existing analysis at best shows a polynomial dependence (Li et al., 2021a). Therefore, the analysis for small initialization, while insightful for understanding the implicit bias of neural network training, is not suitable for understanding training efficiency in practice, since small initialization is rarely used due to its slow convergence.

Another line of work studies initialization in the kernel regime, where a randomly initialized, sufficiently wide neural network can be well approximated by its linearization at initialization (Jacot et al., 2018; Chizat et al., 2019; Arora et al., 2019b). In this regime, gradient descent enjoys a linear rate of convergence toward the global minimum (Du et al., 2019; Allen-Zhu et al., 2019; Du & Hu, 2019). However, the width requirement in the analysis is often unrealistic, and empirical evidence has shown that practical neural networks generally do not operate in the kernel regime (Chizat et al., 2019).

The study of non-small, non-kernel-regime initialization has been mostly centered around linear models. For matrix factorization models, spectral initialization (Saxe et al., 2014; Gidel et al., 2019; Tarmoun et al., 2021) allows for decoupling the training dynamics into several scalar dynamics.
For non-spectral initialization, the notion of weight imbalance, a quantity that depends on the differences between the weight matrices of adjacent layers, is crucial in most analyses. When the initialization is balanced, i.e., when the imbalance matrices are zero, the convergence relies on the initial end-to-end linear model being close to its optimum (Arora et al., 2018a; b). It has been shown that a non-zero imbalance can improve the convergence rate (Tarmoun et al., 2021; Min et al., 2021), but the analysis only works for two-layer models. For deep linear networks, the effect of weight imbalance on convergence has only been studied in the case where all imbalance matrices are positive semi-definite (Yun et al., 2020), which is often unrealistic in practice. Lastly, most of the aforementioned analyses study the ℓ2 loss for regression tasks, and it remains unknown whether they can be generalized to other types of losses commonly used in classification tasks.

Our contribution: This paper aims to provide a general framework for analyzing the convergence of gradient flow on multi-layer linear models. We consider the gradient flow on a loss function of the form L = f(W_1 W_2 ⋯ W_L), where f satisfies the gradient dominance property. We show that with proper initialization, the loss converges to its global minimum exponentially. More specifically:

• Our analysis shows that the convergence rate depends on two trajectory-specific quantities: 1) the imbalance matrices, which measure the difference between the weights of adjacent layers, and 2) a lower bound on the least singular value of the weight product W = W_1 W_2 ⋯ W_L. The former is time-invariant under gradient flow, thus fully determined by the initialization, while the latter can be controlled by initializing the product sufficiently close to its optimum.
• Our analysis covers most initialization schemes used in prior work (Saxe et al., 2014; Tarmoun et al., 2021; Arora et al., 2018a; b; Min et al., 2021; Yun et al., 2020) for both multi-layer linear networks and diagonal linear networks, while providing convergence guarantees for a wider range of initializations. Furthermore, our rate bounds characterize the general effect of weight imbalance on convergence.

• Our convergence results directly apply to loss functions commonly used in regression tasks, and can be extended to loss functions used in classification tasks under an alternative assumption on f, under which we show O(1/t) convergence of the loss.

Notations: For an n × m matrix A, we let A^T denote the matrix transpose of A and σ_i(A) denote its i-th singular value in decreasing order; we conveniently write σ_min(A) = σ_{min{n,m}}(A) and let σ_k(A) = 0 if k > min{n, m}. We also let ∥A∥_2 = σ_1(A) and ∥A∥_F = sqrt(tr(A^T A)). For a square matrix A of size n, we let tr(A) denote its trace, and we let diag{a_i}_{i=1}^n be a diagonal matrix with a_i as its i-th diagonal entry. For a Hermitian matrix A of size n, we let λ_i(A) denote its i-th eigenvalue in decreasing order, and we write A ⪰ 0 (A ⪯ 0) when A is positive semi-definite (negative semi-definite). For two square matrices A, B of the same size, we let ⟨A, B⟩_F = tr(A^T B). For a scalar-valued or matrix-valued function of time F(t), we write Ḟ, Ḟ(t), or (d/dt)F(t) for its time derivative. Additionally, we use I_n to denote the identity matrix of order n and O(n) to denote the set of n × n orthogonal matrices. Lastly, we use [·]_+ := max{·, 0}.

2. OVERVIEW OF THE ANALYSIS

This paper considers the problem of finding a matrix W that solves

min_{W ∈ R^{n×m}} f(W),    (1)

with the following assumptions on f.

Assumption 1. The function f is differentiable and satisfies:

A1: f satisfies the Polyak-Łojasiewicz (PL) condition, i.e., ∥∇f(W)∥_F^2 ≥ γ(f(W) - f^*), ∀W. This condition is also known as gradient dominance.

A2: f is K-smooth, i.e., ∥∇f(W) - ∇f(V)∥_F ≤ K∥W - V∥_F, ∀W, V, and f is µ-strongly convex, i.e., f(W) ≥ f(V) + ⟨∇f(V), W - V⟩_F + (µ/2)∥W - V∥_F^2, ∀W, V.

While classic work (Polyak, 1987) has shown that the gradient descent update on W with a proper step size ensures a linear rate of convergence of f(W) towards its optimal value f^*, the recent surge of research on the convergence and implicit bias of gradient-based methods for deep neural networks has led to a great amount of work on the overparametrized problem:

min_{{W_l}_{l=1}^L} L({W_l}_{l=1}^L) = f(W_1 W_2 ⋯ W_L),    (2)

where L ≥ 2, W_l ∈ R^{h_{l-1}×h_l}, l = 1, ⋯, L, with h_0 = n, h_L = m, and min{h_1, ⋯, h_{L-1}} ≥ min{n, m}. This assumption on min{h_1, ⋯, h_{L-1}} is necessary to ensure that the optimal value of (2) is also f^*, and in this case the product ∏_{l=1}^L W_l can represent an overparametrized linear network/model (Arora et al., 2018b; Tarmoun et al., 2021; Min et al., 2021).

2.1 CONVERGENCE VIA GRADIENT DOMINANCE

For problem (2), consider the gradient flow dynamics on the loss function L({W_l}_{l=1}^L):

Ẇ_l = -(∂/∂W_l) L({W_l}_{l=1}^L), l = 1, ⋯, L.    (3)

The gradient flow dynamics can be viewed as gradient descent with an "infinitesimal" step size, and convergence results for gradient flow can be used to understand the corresponding gradient descent algorithm with a sufficiently small step size (Elkabetz & Cohen, 2021). We have the following result regarding the time derivative of L under the gradient flow (3).

Lemma 1.
Under the continuous dynamics in (3), we have

L̇ = -∥∇L({W_l}_{l=1}^L)∥_F^2 = -⟨T_{{W_l}_{l=1}^L}(∇f(W)), ∇f(W)⟩_F,    (4)

where W = ∏_{l=1}^L W_l, and T_{{W_l}_{l=1}^L} is the following positive semi-definite linear operator on R^{n×m}:

T_{{W_l}_{l=1}^L}(E) = Σ_{l=1}^L (∏_{i=0}^{l-1} W_i)(∏_{i=0}^{l-1} W_i)^T E (∏_{i=l+1}^{L+1} W_i)^T (∏_{i=l+1}^{L+1} W_i), with W_0 = I_n, W_{L+1} = I_m.

Such an expression for ∥∇L∥_F^2 has been studied in Arora et al. (2018b), and we include a proof in Appendix C for completeness. Our convergence analysis proceeds as follows. For this overparametrized problem, the minimum L^* of (2) is f^*. Then from Lemma 1 and Assumption A1, we have

L̇ = -⟨T_{{W_l}_{l=1}^L}(∇f(W)), ∇f(W)⟩_F ≤ -λ_min(T_{{W_l}_{l=1}^L}) ∥∇f(W)∥_F^2    (min-max theorem (Teschl, 2014))    (5)
  ≤ -λ_min(T_{{W_l}_{l=1}^L}) γ (f(W) - f^*) = -λ_min(T_{{W_l}_{l=1}^L}) γ (L - L^*),    (by A1)

If we can find a lower bound α > 0 such that λ_min(T_{{W_l(t)}_{l=1}^L}) ≥ α, ∀t ≥ 0, then the following inequality holds along the entire training trajectory:

(d/dt)(L - L^*) ≤ -αγ(L - L^*).    (6)

Therefore, by Grönwall's inequality (Grönwall, 1919), we can show that the loss function L converges exponentially to its minimum, i.e.,

L(t) - L^* ≤ exp(-αγt)(L(0) - L^*), ∀t ≥ 0.    (7)

Therefore, to show exponential convergence of the loss, we need to lower bound λ_min(T_{{W_l(t)}_{l=1}^L}). Most existing work on the convergence of gradient flow/descent on linear networks implicitly provides such a lower bound, given additional assumptions on the initialization {W_l(0)}_{l=1}^L, though not presented with such generality. We revisit previous analyses to see how such a problem can be solved for two-layer linear networks, then present our new results for deep linear networks.
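As an illustrative numerical sketch of the argument above (not part of the formal development; all constants are our choice), note that the ℓ2 loss f(W) = ½∥W − Y∥_F² satisfies A1 with γ = 2 and f^* = 0. Discretizing the flow (3) for L = 2 with a small forward-Euler step exhibits the exponential decay predicted by (7):

```python
import numpy as np

rng = np.random.default_rng(1)
n, h, m = 3, 5, 3
Y = rng.standard_normal((n, m))
W1, W2 = rng.standard_normal((n, h)), rng.standard_normal((h, m))

def loss(W1, W2):
    # f(W) = 0.5 ||W - Y||_F^2 satisfies A1 with gamma = 2 and f* = 0
    return 0.5 * np.linalg.norm(W1 @ W2 - Y, 'fro') ** 2

eta = 1e-3                       # small step: forward-Euler proxy for the flow
losses = [loss(W1, W2)]
for _ in range(20000):
    G = W1 @ W2 - Y              # grad f(W) evaluated at W = W1 W2
    W1, W2 = W1 - eta * G @ W2.T, W2 - eta * W1.T @ G
    losses.append(loss(W1, W2))

assert losses[-1] < 1e-5 * losses[0]     # loss reaches the global minimum
assert all(b <= a + 1e-12 for a, b in zip(losses, losses[1:]))  # monotone
```

The decay rate observed here is governed by λ_min(T) along the trajectory, which the following sections bound via the imbalance and the product.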

3. LESSONS FROM TWO-LAYER LINEAR MODELS

In this section, we revisit prior work through the lens of our general convergence analysis in Section 2.1. A lower bound on λ_min(T_{{W_l(t)}_{l=1}^L}) can be obtained from the training invariances of the gradient flow. We first consider the following imbalance matrices:

D_l := W_l^T W_l - W_{l+1} W_{l+1}^T, l = 1, ⋯, L-1.

For such imbalance matrices, we have:

Lemma 2. Under the continuous dynamics (3), we have Ḋ_l(t) = 0, ∀t ≥ 0, l = 1, ⋯, L-1.

Such invariance of the weight imbalance has been studied in most work on linear networks (Arora et al., 2018a; Du et al., 2018; Yun et al., 2020). We include the proof in Appendix C for completeness. Since the imbalance matrices {D_l}_{l=1}^{L-1} are fixed at their initial values, any point {W_l(t)}_{l=1}^L on the training trajectory must satisfy the imbalance constraints W_l(t)^T W_l(t) - W_{l+1}(t) W_{l+1}(t)^T = D_l(0), l = 1, ⋯, L-1. Previous work has shown that enforcing certain non-zero imbalance at initialization leads to exponential convergence of the loss for two-layer networks (Tarmoun et al., 2021; Min et al., 2021) and for deep networks (Yun et al., 2020). Another line of work (Arora et al., 2018a; b) has shown that balanced initialization (D_l = 0, ∀l) yields exactly λ_min(T_{{W_l(t)}_{l=1}^L}) = L σ_min^{2-2/L}(W(t)), where W(t) = ∏_{l=1}^L W_l(t). This suggests that the bound on λ_min(T_{{W_l(t)}_{l=1}^L}) we are looking for should potentially depend on both the weight imbalance matrices and the weight product matrix. Indeed, for two-layer models, a re-statement of the results in (Min et al., 2022) provides a lower bound on λ_min(T_{W_1,W_2}) given knowledge of the imbalance and the product.

Lemma 3 (re-stated from Min et al. (2022)). When L = 2, given weights {W_1, W_2} with imbalance matrix D = W_1^T W_1 - W_2 W_2^T and product W = W_1 W_2, define

∆_+ = [λ_1(D)]_+ - [λ_n(D)]_+, ∆_- = [λ_1(-D)]_+ - [λ_m(-D)]_+, ∆ = [λ_n(D)]_+ + [λ_m(-D)]_+.
(8)

Then for the linear operator T_{W_1,W_2} defined in Lemma 1, we have

λ_min(T_{W_1,W_2}) ≥ (1/2) [ -∆_+ + sqrt((∆_+ + ∆)^2 + 4σ_n^2(W)) - ∆_- + sqrt((∆_- + ∆)^2 + 4σ_m^2(W)) ].    (9)

Min et al. (2022) include a detailed discussion of the bound, including its tightness. For our purposes, we note the following:

Effect of imbalance: It follows from (9) that λ_min(T_{W_1,W_2}) ≥ ∆, since σ_min(W) ≥ 0. Therefore, ∆ is always a lower bound on the convergence rate. This means that, for most initializations, the fact that the imbalance matrices are bounded away from zero (characterized by ∆ > 0) is already sufficient for exponential convergence.
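The bound (9) can be checked numerically against the exact value of λ_min(T_{W_1,W_2}), which for L = 2 equals λ_n(W_1 W_1^T) + λ_m(W_2^T W_2). The snippet below is an illustrative sketch under our reading of the restated lemma (dimensions and seed are our choice), not code from the paper:

```python
import numpy as np

rng = np.random.default_rng(2)
n, h, m = 2, 4, 2
W1, W2 = rng.standard_normal((n, h)), rng.standard_normal((h, m))
W = W1 @ W2
D = W1.T @ W1 - W2 @ W2.T                      # imbalance matrix, h x h

lam = np.sort(np.linalg.eigvalsh(D))[::-1]     # eigenvalues of D, decreasing
pos = lambda x: max(x, 0.0)
dp = pos(lam[0]) - pos(lam[n - 1])             # Delta_+
dm = pos(-lam[h - 1]) - pos(-lam[h - m])       # Delta_-: lam_i(-D) = -lam_{h+1-i}(D)
d = pos(lam[n - 1]) + pos(-lam[h - m])         # Delta

s = np.linalg.svd(W, compute_uv=False)[-1]     # sigma_min(W); here n = m
bound = 0.5 * (-dp + np.sqrt((dp + d) ** 2 + 4 * s ** 2)
               - dm + np.sqrt((dm + d) ** 2 + 4 * s ** 2))

# exact value for L = 2: lambda_min(T) = lambda_n(W1 W1^T) + lambda_m(W2^T W2)
exact = np.linalg.eigvalsh(W1 @ W1.T).min() + np.linalg.eigvalsh(W2.T @ W2).min()

assert d - 1e-9 <= bound <= exact + 1e-9       # Delta <= (9) <= lambda_min(T)
```

At a balanced point (D = 0), the computed bound reduces to 2σ_min(W), matching the exact balanced value L σ_min^{2-2/L}(W) for L = 2.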

Effect of product:

The role of the product in (9) is more nuanced. Assume n = m for simplicity, so that σ_n(W W^T) = σ_m(W^T W) = σ_min^2(W). We see that the non-negative quantities ∆_+, ∆_- control how much the product affects the convergence. More precisely, the lower bound in (9) is a decreasing function of both ∆_+ and ∆_-. When ∆_+ = ∆_- = 0, the lower bound reduces to sqrt(∆^2 + 4σ_min^2(W)), showing a joint contribution to convergence from both imbalance and product. However, as ∆_+, ∆_- increase, the bound decreases towards ∆, which means that the effect of imbalance always persists, while the effect of the product diminishes for large ∆_+, ∆_-. We note that ∆_+, ∆_- measure how different in magnitude the eigenvalues of the imbalance matrix D are, i.e., how "ill-conditioned" the imbalance matrix is.

Implication on convergence: Note that (9) is almost a lower bound for λ_min(T_{W_1(t),W_2(t)}), t ≥ 0, as the imbalance matrix D is time-invariant (and so are ∆_+, ∆_-, ∆), except that the right-hand side of (9) also depends on σ_min(W(t)). If f satisfies A2, then f has a unique minimizer W^*. Moreover, one can show that given an initial product W(0), W(t) is constrained to lie within the closed ball {W : ∥W - W^*∥_F ≤ (K/µ)∥W(0) - W^*∥_F}. That is, the product W(t) does not get too far away from W^* during training. We can use this to derive the following lower bound on σ_min(W(t)):

σ_min(W(t)) ≥ [σ_min(W^*) - (K/µ)∥W(0) - W^*∥_F]_+ := margin    (10)

(see Appendix A). This margin being positive guarantees that the closed ball excludes any W with σ_min(W) = 0. With this observation, we find a lower bound on λ_min(T_{W_1(t),W_2(t)}), t ≥ 0, that depends on both the weight imbalance and the margin, and the exponential convergence of the loss L follows:

Theorem 1. Let D be the imbalance matrix for L = 2. The continuous dynamics in (3) satisfy

L(t) - L^* ≤ exp(-α_2 γt)(L(0) - L^*), ∀t ≥ 0,    (11)

where

1. If f satisfies only A1, then α_2 = ∆;
2.
If f satisfies both A1 and A2, then

α_2 = (1/2) [ -∆_+ + sqrt((∆_+ + ∆)^2 + 4([σ_n(W^*) - (K/µ)∥W(0) - W^*∥_F]_+)^2) - ∆_- + sqrt((∆_- + ∆)^2 + 4([σ_m(W^*) - (K/µ)∥W(0) - W^*∥_F]_+)^2) ],    (12)

with W(0) = ∏_{l=1}^L W_l(0) and W^* the unique minimizer of f.

Please see Appendix E for the proof. Theorem 1 is new in that it generalizes the convergence result in Min et al. (2022) for two-layer linear networks, which covers only the ℓ2 loss in linear regression. Our result considers a general loss function f, including the losses for matrix factorization (Arora et al., 2018a), linear regression (Min et al., 2022), and matrix sensing (Arora et al., 2019a). Additionally, Arora et al. (2018a) first introduced the notion of margin for f in matrix factorization problems (K = 1, µ = 1), and we extend it to any f that is smooth and strongly convex.

Towards deep models: So far, we have revisited prior results on two-layer networks, showing how λ_min(T_{W_1,W_2}) can be lower bounded via the weight imbalance and product, from which the convergence result is derived. Can we generalize the analysis to deep networks? The main challenge is that even computing λ_min(T_{{W_l}_{l=1}^L}) given the weights {W_l}_{l=1}^L is complicated: for L = 2, we have λ_min(T_{W_1,W_2}) = λ_n(W_1 W_1^T) + λ_m(W_2^T W_2), but no such nice relation exists for L ≥ 3, which makes the search for a tight lower bound as in (9) potentially difficult. On the other hand, the findings in (9) shed light on what can potentially be shown in the deep case:

1. For two-layer networks, we always have the bound λ_min(T_{W_1,W_2}) ≥ ∆, which depends only on the imbalance. Can we find a lower bound on the convergence rate of a deep network that depends only on an imbalance quantity analogous to ∆? If yes, how does such a quantity depend on the network depth?

2. For two-layer networks, the bound reduces to sqrt(∆^2 + 4σ_min^2(W)) when the imbalance is "well-conditioned" (∆_+, ∆_- are small).
For deep networks, can we characterize such a joint contribution from the imbalance and the product under a similar assumption? We answer these questions as we present our convergence results for deep networks.
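The imbalance invariance of Lemma 2, which underlies all of the bounds in this section, can be verified by substituting the gradient-flow velocities into Ḋ. The sketch below (using the ℓ2 loss f(W) = ½∥W − Y∥_F² as a concrete instance, our choice rather than the paper's) checks that Ḋ vanishes at a random point:

```python
import numpy as np

rng = np.random.default_rng(4)
n, h, m = 3, 4, 3
Y = rng.standard_normal((n, m))
W1, W2 = rng.standard_normal((n, h)), rng.standard_normal((h, m))

# gradient-flow velocities for L = f(W1 W2) with f(W) = 0.5 ||W - Y||_F^2
G = W1 @ W2 - Y
V1, V2 = -G @ W2.T, -W1.T @ G          # W1_dot and W2_dot under (3)

# d/dt (W1^T W1 - W2 W2^T) along the flow: the four terms cancel pairwise
D_dot = V1.T @ W1 + W1.T @ V1 - V2 @ W2.T - W2 @ V2.T
assert np.allclose(D_dot, 0)
```

The cancellation holds for any differentiable f, since V1 = -∇f(W) W_2^T and V2 = -W_1^T ∇f(W) always produce matching cross terms.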

4.1. THREE-LAYER MODEL

Beyond two-layer models, the convergence of imbalanced networks outside the kernel regime has only been studied for specific initializations (Yun et al., 2020). In this section, we derive a novel rate bound for three-layer models that applies to a wide range of imbalanced initializations. For ease of presentation, we denote the two imbalance matrices for three-layer models, D_1 and D_2, as

-D_1 = W_2 W_2^T - W_1^T W_1 := D_21, D_2 = W_2^T W_2 - W_3 W_3^T := D_23.    (13)

Our lower bound on λ_min(T_{W_1,W_2,W_3}) comes after a few definitions.

Definition 1. Given two real symmetric matrices A, B of order n, we define the non-commutative binary operation ∧_r as A ∧_r B := diag{min{λ_i(A), λ_{i+1-r}(B)}}_{i=1}^n, where λ_j(·) = +∞, ∀j ≤ 0.

Definition 2. Given imbalance matrices (D_21, D_23) ∈ R^{h_1×h_1} × R^{h_2×h_2}, define

D̄_{h_1} = diag{max{λ_i(D_21), λ_i(D_23), 0}}_{i=1}^{h_1}, D̄_{h_2} = diag{max{λ_i(D_21), λ_i(D_23), 0}}_{i=1}^{h_2},    (14)

∆_21 = tr(D̄_{h_1}) - tr(D̄_{h_1} ∧_n D_21), ∆^{(2)}_21 = tr(D̄_{h_1}^2) - tr((D̄_{h_1} ∧_n D_21)^2),    (15)
∆_23 = tr(D̄_{h_2}) - tr(D̄_{h_2} ∧_m D_23), ∆^{(2)}_23 = tr(D̄_{h_2}^2) - tr((D̄_{h_2} ∧_m D_23)^2).    (16)

Theorem 2. When L = 3, given weights {W_1, W_2, W_3} with imbalance matrices (D_21, D_23), the linear operator T_{W_1,W_2,W_3} defined in Lemma 1 satisfies

λ_min(T_{W_1,W_2,W_3}) ≥ (1/2)(∆^{(2)}_21 + ∆_21^2) + ∆_21 ∆_23 + (1/2)(∆^{(2)}_23 + ∆_23^2).    (17)

Proof Sketch. It is generally difficult to work directly with λ_min(T_{W_1,W_2,W_3}), so we use the lower bound

λ_min(T_{W_1,W_2,W_3}) ≥ λ_n(W_1 W_2 W_2^T W_1^T) + λ_n(W_1 W_1^T) λ_m(W_3^T W_3) + λ_m(W_3^T W_2^T W_2 W_3).

We show that given (D_21, D_23), the optimal value of

min_{W_1,W_2,W_3} λ_n(W_1 W_2 W_2^T W_1^T) + λ_n(W_1 W_1^T) λ_m(W_3^T W_3) + λ_m(W_3^T W_2^T W_2 W_3)    (18)
s.t. W_2 W_2^T - W_1^T W_1 = D_21, W_2^T W_2 - W_3 W_3^T = D_23

is ∆^*(D_21, D_23) = (1/2)(∆^{(2)}_21 + ∆_21^2) + ∆_21 ∆_23 + (1/2)(∆^{(2)}_23 + ∆_23^2), the bound shown in (17).
Please see Appendix F for the complete proof and a detailed discussion of the proof idea. The theorem immediately yields the following corollary.

Corollary 1. When L = 3, given an initialization with imbalance matrices (D_21, D_23) and f satisfying A1, the continuous dynamics in (3) satisfy

L(t) - L^* ≤ exp(-α_3 γt)(L(0) - L^*), ∀t ≥ 0,    (19)

where α_3 = (1/2)(∆^{(2)}_21 + ∆_21^2) + ∆_21 ∆_23 + (1/2)(∆^{(2)}_23 + ∆_23^2).

We make the following remarks regarding this contribution.

Optimal bound via imbalance: First, as shown in the proof sketch, our bound is the best lower bound on λ_min(T_{W_1(t),W_2(t),W_3(t)}) one can obtain given knowledge of the imbalance matrices D_21 and D_23 alone. More importantly, this lower bound works for ANY initialization and plays the same role as ∆ does in two-layer linear networks, i.e., (17) quantifies the general effect of imbalance on convergence. Finding an improved bound that also accounts for the effect of the product through σ_min(W) is an interesting direction for future research.

Implication on convergence: Corollary 1 shows exponential convergence of the loss L(t) whenever α_3 > 0. While it is challenging to characterize all initializations with α_3 > 0, the case n = m = 1 is rather simple: in this case, D̄_{h_1} ∧_1 D_21 = D_21 and D̄_{h_2} ∧_1 D_23 = D_23. Then we have

∆_21 = tr(D̄_{h_1}) - tr(D_21) = Σ_{i=1}^{h_1-1} (λ_i(D̄_{h_1}) - λ_i(D_21)) + λ_{h_1}(D̄_{h_1}) - λ_{h_1}(D_21) ≥ -λ_{h_1}(D_21),

and similarly ∆_23 ≥ -λ_{h_2}(D_23). Therefore, α_3 ≥ ∆_21 ∆_23 ≥ λ_{h_1}(D_21) λ_{h_2}(D_23) > 0 whenever both D_21 and D_23 have negative eigenvalues, which is easy to satisfy since both D_21 and D_23 are differences of two positive semi-definite matrices. This observation can be generalized to show that α_3 > 0 when D_21 has at least n negative eigenvalues and D_23 has at least m negative eigenvalues.
Moreover, we show that α_3 > 0 under certain definiteness assumptions on D_21 and D_23; please refer to the remark after Theorem 3 in Section 4.2. A better characterization of the initializations with α_3 > 0 is an interesting topic for future research.

Technical contribution: We find the lower bound in (17) by studying the generalized eigenvalue interlacing relations imposed by the imbalance constraints. Specifically, W_2 W_2^T - W_1^T W_1 = D_21 implies that λ_{i+n}(W_2 W_2^T) ≤ λ_i(D_21) ≤ λ_i(W_2 W_2^T), ∀i, because W_2 W_2^T - D_21 is a matrix of rank at most n. From such interlacing relations, we derive novel eigenvalue bounds (see Lemma F.6) on λ_n(W_1^T W_1) and λ_n(W_1 W_2 W_2^T W_1^T) that depend on the eigenvalues of both W_2 W_2^T and D_21. The eigenvalues of W_2 W_2^T can in turn be controlled by the fact that W_2 must satisfy both imbalance equations in (13). Since imbalance equations like those in (13) also appear in deeper networks and certain nonlinear networks (Du et al., 2018; Le & Jegelka, 2022), we believe our mathematical results are potentially useful for understanding those networks as well.

Comparison with prior work: The convergence of multi-layer linear networks under balanced initialization (D_l = 0, ∀l) has been studied in Arora et al. (2018a; b), and our result is complementary, as we study the effect of non-zero imbalance on the convergence of three-layer networks. Some settings with imbalanced weights have been studied: Yun et al. (2020) study a special initialization scheme (D_l ⪰ 0, l = 1, ⋯, L-2, and D_{L-1} ⪰ λ I_{h_{L-1}}) that forces a partial ordering of the weights, and Wu et al. (2019) use a similar initialization to study linear residual networks. Our bound covers such initializations and also shows that this partial ordering is not necessary for convergence.
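To make the proof sketch concrete, T_{W_1,W_2,W_3} can be materialized as an nm × nm matrix using the identity vec(AEC) = (C^T ⊗ A)vec(E), and the three-term lower bound in (18) checked numerically. This is an illustrative sketch (dimensions and seed are our choice), not code from the paper:

```python
import numpy as np

rng = np.random.default_rng(3)
n, h1, h2, m = 2, 3, 3, 2
W1, W2, W3 = (rng.standard_normal(s) for s in [(n, h1), (h1, h2), (h2, m)])

# T(E) = E (W2 W3)^T (W2 W3) + W1 W1^T E W3^T W3 + (W1 W2)(W1 W2)^T E,
# built as an (nm x nm) matrix via vec(A E C) = (C^T kron A) vec(E)
I_n, I_m = np.eye(n), np.eye(m)
B23, B12 = W2 @ W3, W1 @ W2
T = (np.kron(B23.T @ B23, I_n)
     + np.kron(W3.T @ W3, W1 @ W1.T)
     + np.kron(I_m, B12 @ B12.T))

lam_min = np.linalg.eigvalsh(T).min()

# three-term lower bound from the proof sketch of Theorem 2
lb = (np.linalg.eigvalsh(B12 @ B12.T).min()
      + np.linalg.eigvalsh(W1 @ W1.T).min() * np.linalg.eigvalsh(W3.T @ W3).min()
      + np.linalg.eigvalsh(B23.T @ B23).min())

assert lam_min >= lb - 1e-9      # lambda_min of a sum >= sum of lambda_mins
assert lam_min > 0
```

The inequality holds because each of the three Kronecker terms is positive semi-definite and λ_min of a sum of PSD matrices is at least the sum of their λ_min values.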

4.2. DEEP LINEAR MODELS

The lower bound we derived for three-layer networks applies to any initialization. However, the bound is a fairly complicated function of all the imbalance matrices and is hard to interpret. Searching for such a general bound is even more challenging for models of arbitrary depth (L ≥ 3). Therefore, our results for deep networks rely on extra assumptions on the weights that simplify the lower bound and facilitate interpretability. Specifically, we consider the following properties of the weights:

Definition 3. A set of weights {W_l}_{l=1}^L with imbalance matrices {D_l := W_l^T W_l - W_{l+1} W_{l+1}^T}_{l=1}^{L-1} is said to be unimodal with index l^* if there exists some l^* ∈ [L] such that D_l ⪰ 0 for l < l^* and D_l ⪯ 0 for l ≥ l^*. We define its cumulative imbalances {d^{(i)}}_{i=1}^{L-1} as

d^{(i)} = Σ_{l=l^*}^{i} λ_m(-D_l) for i ≥ l^*, and d^{(i)} = Σ_{l=i}^{l^*-1} λ_n(D_l) for i < l^*.

Furthermore, for weights with unimodality index l^*, if additionally D_l = d_l I_{h_l}, l = 1, ⋯, L-1, with d_l ≥ 0 for l < l^* and d_l ≤ 0 for l ≥ l^*, then the weights are said to have homogeneous imbalance.

The unimodality assumption enforces an ordering of the weights w.r.t. the positive semi-definite cone. This is clearest for scalar weights {w_l}_{l=1}^L, in which case unimodality requires w_l^2 to be descending until index l^* and ascending afterward. Under this unimodality assumption, we show that the imbalance contributes to the convergence of the loss through a product of cumulative imbalances. Furthermore, we also show the combined effect of imbalance and weight product when the imbalance matrices are "well-conditioned" (in this case, homogeneous). More formally:

Theorem 3. For weights {W_l}_{l=1}^L with unimodality index l^*, we have

λ_min(T_{{W_l}_{l=1}^L}) ≥ ∏_{i=1}^{L-1} d^{(i)}.    (20)

Furthermore, if the weights have homogeneous imbalance, then

λ_min(T_{{W_l}_{l=1}^L}) ≥ sqrt( (∏_{i=1}^{L-1} d^{(i)})^2 + (L σ_min^{2-2/L}(W))^2 ), W = ∏_{l=1}^L W_l.    (21)
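For scalar weights (n = m = h_l = 1), the operator T reduces to multiplication by Σ_l ∏_{i≠l} w_i^2, so Theorem 3 can be checked directly. Below is a small sketch with arbitrarily chosen ascending magnitudes, so that l^* = 1 and d^{(i)} = w_{i+1}^2 − w_1^2:

```python
import numpy as np

w = np.array([0.5, 1.0, 1.5, 2.0])   # ascending magnitudes -> unimodal, l* = 1
L = len(w)

# for scalar weights, T reduces to the scalar sum_l prod_{i != l} w_i^2
lam_T = sum(np.prod(np.delete(w, l) ** 2) for l in range(L))

# cumulative imbalances with l* = 1:
# d^(i) = sum_{l=1}^{i} (w_{l+1}^2 - w_l^2) = w_{i+1}^2 - w_1^2
d = [w[i] ** 2 - w[0] ** 2 for i in range(1, L)]
lower = np.prod(d)

assert lam_T >= lower > 0            # Theorem 3's bound holds in this instance
```

Intuitively, keeping only the l = 1 term of the sum already gives ∏_{i=2}^{L} w_i^2 ≥ ∏_{i=2}^{L} (w_i^2 − w_1^2), which is exactly the product of cumulative imbalances.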
We make the following remarks:

Connection to the three-layer results: For three-layer networks, we presented the optimal bound λ_min(T_{W_1,W_2,W_3}) ≥ (1/2)(∆^{(2)}_21 + ∆_21^2) + ∆_21 ∆_23 + (1/2)(∆^{(2)}_23 + ∆_23^2) given knowledge of the imbalance. Interestingly, comparing it with our bound in (20), we have:

Claim. When L = 3, for weights {W_1, W_2, W_3} with unimodality index l^*,
1. If l^* = 1, then (1/2)(∆^{(2)}_23 + ∆_23^2) = ∏_{i=1}^{L-1} d^{(i)} and (1/2)(∆^{(2)}_21 + ∆_21^2) = ∆_21 ∆_23 = 0;
2. If l^* = 2, then ∆_21 ∆_23 = ∏_{i=1}^{L-1} d^{(i)} and (1/2)(∆^{(2)}_21 + ∆_21^2) = (1/2)(∆^{(2)}_23 + ∆_23^2) = 0;
3. If l^* = 3, then (1/2)(∆^{(2)}_21 + ∆_21^2) = ∏_{i=1}^{L-1} d^{(i)} and (1/2)(∆^{(2)}_23 + ∆_23^2) = ∆_21 ∆_23 = 0.

We refer the reader to Appendix G for the proof. The claim shows that the bound in (20) is optimal for three-layer unimodal weights, as it coincides with the one in Theorem 2. We conjecture that (20) is also optimal for multi-layer unimodal weights and leave the proof for future research. Interestingly, while the bound for three-layer models is complicated, the three terms (1/2)(∆^{(2)}_23 + ∆_23^2), ∆_21 ∆_23, and (1/2)(∆^{(2)}_21 + ∆_21^2) seem to roughly capture how close the weights are to being unimodal. This hints at a potential generalization of Theorem 2 to the deep case, where the bound should have L terms capturing how close the weights are to unimodal weights with each index l^* = 1, ⋯, L.

Effect of imbalance under unimodality: For simplicity, assume unimodality index l^* = L. The bound ∏_{i=1}^{L-1} d^{(i)}, as a product of cumulative imbalances, generally grows exponentially with the depth L. Prior work (Yun et al., 2020) studies the case D_l ⪰ 0, l = 1, ⋯, L-2, and D_{L-1} ⪰ λ I_{h_{L-1}}, in which case ∏_{i=1}^{L-1} d^{(i)} ≥ λ^{L-1}.
Our bound ∏_{i=1}^{L-1} d^{(i)} suggests that the dependence on L can be super-exponential: when λ_n(D_l) ≥ ϵ > 0 for l = 1, ⋯, L-1, we have ∏_{i=1}^{L-1} d^{(i)} = ∏_{i=1}^{L-1} Σ_{l=i}^{L-1} λ_n(D_l) ≥ ∏_{l=1}^{L-1} lϵ = ϵ^{L-1}(L-1)!, which grows faster in L than λ^{L-1} for any λ. Therefore, for the gradient flow dynamics, the depth L can greatly improve convergence in the presence of weight imbalance. One should note, however, that this analysis cannot be directly translated into fast convergence guarantees for the gradient descent algorithm, as careful tuning of the step size is required for the discrete weight updates to follow the trajectory of the continuous dynamics (Elkabetz & Cohen, 2021). With our bound in Theorem 3, we show convergence of deep linear models under various initializations:

Convergence under unimodality: The following is immediate from Theorem 3:

Corollary 2. If the initial weights {W_l(0)}_{l=1}^L are unimodal, then the continuous dynamics in (3) satisfy

L(t) - L^* ≤ exp(-α_L γt)(L(0) - L^*), ∀t ≥ 0,    (22)

where

1. If f satisfies A1 only, then α_L = ∏_{i=1}^{L-1} d^{(i)};
2. If f satisfies both A1 and A2, and the weights additionally have homogeneous imbalance, then

α_L = sqrt( (∏_{i=1}^{L-1} d^{(i)})^2 + (L ([σ_min(W^*) - (K/µ)∥W(0) - W^*∥_F]_+)^{2-2/L})^2 ),

with W(0) = ∏_{l=1}^L W_l(0) and W^* the unique minimizer of f.

Spectral initialization under ℓ2 loss: Suppose f = (1/2)∥Y - W∥_F^2 and W = ∏_{l=1}^L W_l. We write the SVD of Y ∈ R^{n×m} as

Y = P [Σ_Y 0; 0 0] Q^T := P Σ̄_Y Q^T, where P ∈ O(n), Q ∈ O(m).

Consider the spectral initialization

W_1(0) = P Σ_1 V_1^T, W_l(0) = V_{l-1} Σ_l V_l^T, l = 2, ⋯, L-1, W_L(0) = V_{L-1} Σ_L Q^T,

where Σ_l, l = 1, ⋯, L, are diagonal matrices of our choice and the V_l, l = 1, ⋯, L-1, have orthonormal columns, V_l^T V_l = I_{h_l}. It can be shown (see Appendix D.1 for details) that

W_1(t) = P Σ_1(t) V_1^T, W_l(t) = V_{l-1} Σ_l(t) V_l^T, l = 2, ⋯, L-1, W_L(t) = V_{L-1} Σ_L(t) Q^T.    (23)

Moreover, only the first m diagonal entries of each Σ_l change.
Let σ_{i,l} and σ_{i,y} denote the i-th diagonal entries of Σ_l and Σ̄_Y, respectively; then the dynamics of {σ_{i,l}}_{l=1}^L follow the gradient flow on

L_i({σ_{i,l}}_{l=1}^L) = (1/2)(σ_{i,y} - ∏_{l=1}^L σ_{i,l})^2, i = 1, ⋯, m,

which is exactly a multi-layer model with scalar weights: f(w) = |σ_{i,y} - w|^2/2, w = ∏_{l=1}^L w_l. Therefore, under spectral initialization the ℓ2 loss decomposes into m deep linear models with scalar weights, whose convergence is given by Corollary 2. Note that networks with scalar weights are always unimodal, because the gradient flow dynamics remain the same under any reordering of the weights, and they always have homogeneous imbalance, because the imbalances are scalars. The same analysis also applies to the linear regression loss f = (1/2)∥Y - XW∥_F^2, provided that {X, Y} is co-diagonalizable (Gidel et al., 2019); we refer the reader to Appendix D.1 for details.
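The decoupled scalar dynamics can be simulated directly. The sketch below runs forward-Euler on one mode L_i = ½(σ_{i,y} − ∏_l σ_{i,l})² (the target value, depth, and initialization are our choice) and confirms that the mode converges:

```python
import numpy as np

sigma_y = 2.0                     # target singular value of this mode
s = np.array([1.0, 1.2, 1.4])     # sigma_{i,l}, l = 1..L, one decoupled mode
eta, steps = 1e-3, 20000

for _ in range(steps):
    g = np.prod(s) - sigma_y      # dL_i/d(prod), with L_i = 0.5 (sigma_y - prod)^2
    # chain rule: dL_i/ds_l = g * (product of the other entries)
    s = s - eta * g * np.array([np.prod(np.delete(s, l)) for l in range(len(s))])

final_loss = 0.5 * (np.prod(s) - sigma_y) ** 2
assert final_loss < 1e-12
```

Since the three entries are unequal, the scalar imbalances are nonzero and Corollary 2 predicts exponential convergence, which the simulation reflects.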

Diagonal linear networks:

Consider f a function on R^n satisfying A1 and L = f(w_1 ⊙ ⋯ ⊙ w_L), where w_l ∈ R^n and ⊙ denotes the Hadamard (entrywise) product. The gradient flow on L cannot be decomposed into several scalar dynamics as in the previous example, but we can show (see Appendix D.2 for details) that

L̇ = -∥∇L∥_F^2 ≤ -(min_{1≤i≤n} λ_min(T_{{w_{l,i}}_{l=1}^L})) γ (L - L^*),

where w_{l,i} is the i-th entry of w_l. Theorem 3 then gives a lower bound on each λ_min(T_{{w_{l,i}}_{l=1}^L}). Again, the scalar weights {w_{l,i}}_{l=1}^L always have homogeneous imbalance.

Table 1: Lower bounds on λ_min(T_{{W_l}_{l=1}^L}) under different assumptions on the weights.

Assumptions           | Arora et al. (2018a) | Yun et al. (2020) | Ours
Unimodal weights      | N/A                  | λ^{L-1}           | ∏_{i=1}^{L-1} d^{(i)}
Homogeneous imbalance | N/A                  | λ^{L-1}           | sqrt((∏_{i=1}^{L-1} d^{(i)})^2 + (L σ_min^{2-2/L}(W))^2)
Balanced              | L σ_min^{2-2/L}(W)   | N/A               | L σ_min^{2-2/L}(W)

Note: Yun et al. (2020) assume D_l ⪰ 0, l = 1, ⋯, L-2, and D_{L-1} ⪰ λ I_{h_{L-1}}, which is a special case (l^* = L) of our setting.

The homogeneous imbalance assumption was first introduced in Tarmoun et al. (2021) for two-layer networks, and we generalize it to the deep case. In Table 1, we compare our bound to the existing work (Arora et al., 2018a; Yun et al., 2020) on the convergence of deep linear networks outside the kernel regime. Note that Yun et al. (2020) only study a special case of unimodal weights (l^* = L with d^{(i)} ≥ λ > 0, ∀i). For homogeneous imbalance, Yun et al. (2020) studied spectral initialization and diagonal linear networks, whose initialization necessarily has homogeneous imbalance, but the result does not generalize to the case of matrix weights. Our result for homogeneous imbalance also works for deep networks with matrix weights, and our rate also shows the effect of the product through L σ_min^{2-2/L}(W), thus covering the balanced initialization (Arora et al., 2018a) as well.

Remark 1. Note that the loss functions used in Gunasekar et al. (2018); Yun et al. (2020) are classification losses, such as the exponential loss, which do not satisfy A1.
However, they do satisfy a Polyak-Łojasiewicz-like condition $\|\nabla f(W)\|_F \ge \gamma\,(f(W) - f^*)$, $\forall W\in\mathbb{R}^{n\times m}$, which allows us to show $O(1/t)$ convergence of the loss function. We refer readers to Section 4.3 for details.
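Returning to the diagonal linear networks discussed above, a small simulation (illustrative values only, not from the paper) verifies that the entrywise imbalances are conserved along the trajectory and that the loss converges:

```python
import numpy as np

# Depth-3 diagonal linear network L = f(w1 ⊙ w2 ⊙ w3) with
# f(v) = 0.5 * ||v - y||^2, trained by forward-Euler gradient flow.
# The entrywise imbalances d_{l,i} = w_{l,i}^2 - w_{l+1,i}^2 should stay
# (approximately) constant. Sizes and initialization are illustrative.
rng = np.random.default_rng(0)
n = 5
y = 1.0 + rng.random(n)                 # positive targets keep weights positive
W = 0.5 + rng.random((3, n))            # rows are w1, w2, w3
d0 = W[:-1] ** 2 - W[1:] ** 2           # imbalances at initialization

eta = 1e-3
for _ in range(30000):
    v = W.prod(axis=0)                  # entrywise product w1 ⊙ w2 ⊙ w3
    r = v - y                           # gradient of f at v
    grads = np.empty_like(W)
    for l in range(3):
        # dL/dw_{l,i} = r_i * prod_{k != l} w_{k,i}
        grads[l] = r * np.delete(W, l, axis=0).prod(axis=0)
    W = W - eta * grads

dT = W[:-1] ** 2 - W[1:] ** 2
loss = 0.5 * np.sum((W.prod(axis=0) - y) ** 2)
```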

4.3. CONVERGENCE RESULTS FOR CLASSIFICATION TASKS

As discussed in Remark 1, the loss functions used in classification tasks generally do not satisfy our assumption A1 on $f$. Suppose instead that $f$ satisfies the following assumption.

Assumption 2. $f$ satisfies (A1') $\|\nabla f(W)\|_F \ge \gamma\,(f(W) - f^*)$, $\forall W\in\mathbb{R}^{n\times m}$.

Then we can show $O(1/t)$ convergence of the loss function, as stated below.

Theorem 4. Given an initialization $\{W_l(0)\}_{l=1}^L$ such that $\lambda_{\min}\big(\mathcal{T}_{\{W_l(t)\}_{l=1}^L}\big)\ge\alpha$, $\forall t\ge0$, and $f$ satisfying (A1'), then
$$\mathcal{L}(t) - \mathcal{L}^* \le \frac{\mathcal{L}(0) - \mathcal{L}^*}{(\mathcal{L}(0) - \mathcal{L}^*)\,\alpha\gamma^2\,t + 1}\,.$$

We refer readers to Appendix B for the proof. The lower bound on $\lambda_{\min}(\mathcal{T}_{\{W_l(t)\}_{l=1}^L})$ can be obtained for different networks from our results in the previous sections. The exponential loss satisfies (A1') (see Appendix D.2) and is studied in Gunasekar et al. (2017); Yun et al. (2020) for diagonal linear networks.
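For intuition, the $O(1/t)$ bound of Theorem 4 can be checked on a minimal example (our own construction, not from the paper): a two-layer scalar model with the exponential loss $f(w) = e^{-w}$, so $\gamma = 1$ and $f^* = 0$, and $\alpha = |w_1^2 - w_2^2|$ is a valid choice since $\lambda_{\min}(\mathcal{T}) = w_1^2 + w_2^2 \ge |w_1^2 - w_2^2|$ along the whole flow:

```python
import numpy as np

w1, w2 = 1.5, 0.5            # illustrative init; imbalance d = w1^2 - w2^2 = 2
alpha = abs(w1**2 - w2**2)   # lower bound on w1^2 + w2^2 for all t
gamma = 1.0                  # exp loss: |f'(w)| = exp(-w) = f(w), f* = 0

eta = 1e-3
hist = []
t = 0.0
for _ in range(50000):       # forward-Euler discretization of the flow
    w = w1 * w2
    hist.append((t, np.exp(-w)))
    g = -np.exp(-w)          # f'(w)
    w1, w2 = w1 - eta * g * w2, w2 - eta * g * w1
    t += eta

L0 = hist[0][1]
# check L(t) <= L0 / (L0 * alpha * gamma^2 * t + 1) along the trajectory
bound_ok = all(loss <= L0 / (L0 * alpha * gamma**2 * tk + 1) + 1e-3
               for tk, loss in hist)
```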

5. CONCLUSION AND DISCUSSION

In this paper, we study the convergence of gradient flow on multi-layer linear models with a loss of the form $f(W_1W_2\cdots W_L)$, where $f$ satisfies the gradient dominance property. We show that with proper initialization, the loss converges to its global minimum exponentially. Moreover, we derive a lower bound on the convergence rate that depends on two trajectory-specific quantities: the imbalance matrices, which measure the difference between the weights of adjacent layers, and the least singular value of the weight product $W = W_1W_2\cdots W_L$. Our analysis applies to various types of multi-layer linear networks, and our assumptions on $f$ are general enough to include loss functions used for both regression and classification tasks. Future directions include extending our results to gradient descent algorithms as well as to nonlinear networks.

Convergence of gradient descent: Exponential convergence of the gradient flow often suggests a linear rate of convergence for gradient descent when the step size is sufficiently small, and Elkabetz & Cohen (2021) formally establish such a relation. Indeed, Arora et al. (2018a) shows a linear rate of convergence of gradient descent on multi-layer linear networks under balanced initialization. A natural future direction is to translate our convergence results under imbalanced initialization for gradient flow to the convergence of gradient descent with a small step size.

A CONTROLLING PRODUCT WITH MARGIN

Most of our results regarding the lower bound on $\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big)$ are given as a value that depends on 1) the imbalance of the weights, and 2) the minimum singular value of the product $W = \prod_{l=1}^L W_l$. The former is time-invariant, thus determined at initialization. As discussed in Section 3, we require the notion of margin to lower bound $\sigma_{\min}(W(t))$ along the entire training trajectory. The following lemma will be used in subsequent proofs.

Lemma A.1. If $f$ satisfies A2, then the gradient flow dynamics (3) satisfy
$$\sigma_{\min}(W(t)) \ge \sigma_{\min}(W^*) - \sqrt{K/\mu}\,\|W(0) - W^*\|_F\,,\quad\forall t\ge0\,,$$
where $W(t) = \prod_{l=1}^L W_l(t)$ and $W^*$ is the unique minimizer of $f$.

Proof. From Polyak (1987), we know that if $f$ is $\mu$-strongly convex, then it has a unique minimizer $W^*$ and $f(W) - f^* \ge \frac{\mu}{2}\|W - W^*\|_F^2$. Additionally, if $f$ is $K$-smooth, then $f(W) - f^* \le \frac{K}{2}\|W - W^*\|_F^2$. This shows that for any $t\ge0$,
$$\frac{K}{2}\|W(t) - W^*\|_F^2 \ge \mathcal{L}(t) - \mathcal{L}^* \ge \frac{\mu}{2}\|W(t) - W^*\|_F^2\,.$$
Therefore we have the following chain:
$$\sigma_{\min}(W(t)) = \sigma_{\min}(W(t) - W^* + W^*)$$
$$\ge \sigma_{\min}(W^*) - \|W(t) - W^*\|_2 \qquad\text{(Weyl's inequality, (Horn \& Johnson, 2012, 7.3.P16))}$$
$$\ge \sigma_{\min}(W^*) - \|W(t) - W^*\|_F$$
$$\ge \sigma_{\min}(W^*) - \sqrt{\tfrac{2}{\mu}(\mathcal{L}(t) - \mathcal{L}^*)} \qquad\text{($f$ is $\mu$-strongly convex)}$$
$$\ge \sigma_{\min}(W^*) - \sqrt{\tfrac{2}{\mu}(\mathcal{L}(0) - \mathcal{L}^*)} \qquad\text{($\mathcal{L}(t)$ non-increasing under (3))}$$
$$\ge \sigma_{\min}(W^*) - \sqrt{\tfrac{K}{\mu}\|W(0) - W^*\|_F^2} = \sigma_{\min}(W^*) - \sqrt{K/\mu}\,\|W(0) - W^*\|_F\,. \qquad\text{($f$ is $K$-smooth)}$$

Lemma A.1 directly gives $\sigma_{\min}(W(t)) \ge \big[\sigma_{\min}(W^*) - \sqrt{K/\mu}\,\|W(0) - W^*\|_F\big]_+ =: \text{margin}$, and the margin is positive when the initial product $W(0)$ is sufficiently close to the optimum $W^*$.
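A quick numerical check of the Weyl-inequality step, and of the resulting margin for the simplest loss satisfying A2, namely $f(W) = \frac{1}{2}\|W - W^*\|_F^2$ with $\mu = K = 1$ (all sizes and noise levels are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 6, 4
W_star = rng.normal(size=(n, m))
s_star = np.linalg.svd(W_star, compute_uv=False)[-1]   # sigma_min(W*)

# Weyl: sigma_min(W) >= sigma_min(W*) - ||W - W*||_2 >= sigma_min(W*) - ||W - W*||_F
ok = True
for _ in range(200):
    W = W_star + 0.3 * rng.normal(size=(n, m))
    s_min = np.linalg.svd(W, compute_uv=False)[-1]
    ok &= s_min >= s_star - np.linalg.norm(W - W_star) - 1e-12

# margin for f(W) = 0.5 ||W - W*||_F^2, which has mu = K = 1
K = mu = 1.0
W0 = W_star + 0.1 * rng.normal(size=(n, m))
margin = max(s_star - np.sqrt(K / mu) * np.linalg.norm(W0 - W_star), 0.0)
```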

B CONVERGENCE ANALYSIS FOR CLASSIFICATION LOSSES

In this section, we consider $f$ that satisfies, instead of A1, the following assumption.

Assumption 3. $f$ satisfies (A1') the Łojasiewicz-inequality-like condition $\|\nabla f(W)\|_F \ge \gamma\,(f(W) - f^*)$, $\forall W\in\mathbb{R}^{n\times m}$.

Theorem 4 (Restated). Given an initialization $\{W_l(0)\}_{l=1}^L$ such that $\lambda_{\min}\big(\mathcal{T}_{\{W_l(t)\}_{l=1}^L}\big)\ge\alpha$, $\forall t\ge0$, and $f$ satisfying (A1'), then
$$\mathcal{L}(t) - \mathcal{L}^* \le \frac{\mathcal{L}(0) - \mathcal{L}^*}{(\mathcal{L}(0) - \mathcal{L}^*)\,\alpha\gamma^2\,t + 1}\,.$$

Proof. When $f$ satisfies (A1'), (5) becomes
$$\dot{\mathcal{L}} = -\big\langle\mathcal{T}_{\{W_l\}_{l=1}^L}\nabla f(W),\,\nabla f(W)\big\rangle_F \le -\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big)\|\nabla f(W)\|_F^2 \overset{\text{(A1')}}{\le} -\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big)\gamma^2(f(W) - f^*)^2 = -\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big)\gamma^2(\mathcal{L} - \mathcal{L}^*)^2\,.$$
This shows
$$-\frac{1}{(\mathcal{L} - \mathcal{L}^*)^2}\frac{d}{dt}(\mathcal{L} - \mathcal{L}^*) \ge \lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big)\gamma^2 \ge \alpha\gamma^2\,.$$
Integrating both sides over $[0, t]$, we have for any $t\ge0$,
$$\frac{1}{\mathcal{L} - \mathcal{L}^*}\Big|_0^t \ge \alpha\gamma^2 t\,,$$
which is $\mathcal{L}(t) - \mathcal{L}^* \le \frac{\mathcal{L}(0) - \mathcal{L}^*}{(\mathcal{L}(0) - \mathcal{L}^*)\alpha\gamma^2 t + 1}$.

Following a similar argument as in Yun et al. (2020), we can show that the exponential loss on linearly separable data satisfies (A1').

Claim. Let $f(w) = \sum_{i=1}^N \exp\big(-y_i\,(x_i^Tw)\big)$. If there exist $z\in\mathbb{S}^{n-1}$ and $\gamma > 0$ such that $y_i(x_i^Tz)\ge\gamma$, $\forall i = 1,\dots,N$, then $\|\nabla f(w)\| \ge \gamma f(w)$, $\forall w\in\mathbb{R}^n$.

Proof. Using linear separability, we have
$$\|\nabla f(w)\|^2 = \Big\|\sum_{i=1}^N \exp\big(-y_i(x_i^Tw)\big)y_ix_i\Big\|^2 \ge \Big\langle z,\,\sum_{i=1}^N \exp\big(-y_i(x_i^Tw)\big)y_ix_i\Big\rangle^2 \qquad\text{(Cauchy-Schwarz inequality)}$$
$$\ge \Big(\sum_{i=1}^N \exp\big(-y_i(x_i^Tw)\big)\gamma\Big)^2 = |\gamma f(w)|^2\,,$$
as desired. Therefore, our convergence results apply to classification tasks with the exponential loss.
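The claim can be verified numerically on synthetic separable data (sizes and the generating process are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(2)
n, N = 3, 40
z = rng.normal(size=n)
z /= np.linalg.norm(z)              # a unit separating direction
X = rng.normal(size=(N, n))
y = np.sign(X @ z)                  # labels separable by z by construction
gamma = (y * (X @ z)).min()         # margin: y_i <x_i, z> >= gamma > 0

def f_and_grad(w):
    """Exponential loss and its gradient on the synthetic data."""
    e = np.exp(-y * (X @ w))        # per-sample losses
    return e.sum(), -(e * y) @ X

# check ||grad f(w)|| >= gamma * f(w) at random points
ok = True
for _ in range(100):
    w = rng.normal(size=n)
    fw, g = f_and_grad(w)
    ok &= np.linalg.norm(g) >= gamma * fw - 1e-9
```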

C PROOFS IN SECTION 2

First we prove the expression for $\dot{\mathcal{L}}$ in Lemma 1.

Lemma 1 (Restated). Under the continuous dynamics in (3), we have
$$\dot{\mathcal{L}} = -\big\|\nabla\mathcal{L}\big(\{W_l\}_{l=1}^L\big)\big\|_F^2 = -\big\langle\mathcal{T}_{\{W_l\}_{l=1}^L}\nabla f(W),\,\nabla f(W)\big\rangle_F\,,$$
where $W = \prod_{l=1}^L W_l$, and $\mathcal{T}_{\{W_l\}_{l=1}^L}$ is a positive semi-definite linear operator on $\mathbb{R}^{n\times m}$ with
$$\mathcal{T}_{\{W_l\}_{l=1}^L}E = \sum_{l=1}^L\Big(\prod_{i=1}^{l-1}W_i\Big)\Big(\prod_{i=1}^{l-1}W_i\Big)^T E\,\Big(\prod_{i=l+1}^{L+1}W_i\Big)^T\Big(\prod_{i=l+1}^{L+1}W_i\Big)\,,\qquad W_0 = I_n,\ W_{L+1} = I_m\,.$$

Proof. The gradient flow dynamics (3) satisfy
$$\frac{d}{dt}W_l = -\frac{\partial}{\partial W_l}\mathcal{L}\big(\{W_l\}_{l=1}^L\big) = -\Big(\prod_{i=1}^{l-1}W_i\Big)^T\nabla f(W)\Big(\prod_{i=l+1}^{L+1}W_i\Big)^T\,,\qquad\text{(C.1)}$$
where $W = \prod_{l=1}^L W_l$ and $W_0 = I_n$, $W_{L+1} = I_m$. Therefore
$$\dot{\mathcal{L}} = \sum_{l=1}^L\Big\langle\frac{\partial\mathcal{L}}{\partial W_l},\,\frac{d}{dt}W_l\Big\rangle_F = -\sum_{l=1}^L\Big\|\frac{\partial\mathcal{L}}{\partial W_l}\Big\|_F^2 = -\sum_{l=1}^L\Big\langle\Big(\prod_{i=1}^{l-1}W_i\Big)\Big(\prod_{i=1}^{l-1}W_i\Big)^T\nabla f(W)\Big(\prod_{i=l+1}^{L+1}W_i\Big)^T\Big(\prod_{i=l+1}^{L+1}W_i\Big),\,\nabla f(W)\Big\rangle_F = -\big\langle\mathcal{T}_{\{W_l\}_{l=1}^L}\nabla f(W),\,\nabla f(W)\big\rangle_F\,.$$

Next, we prove that the imbalance matrices are time-invariant.

Lemma 2 (Restated). Under the continuous dynamics (3), we have $\dot{D}_l(t) = 0$, $\forall t\ge0$, $l = 1,\dots,L-1$.

Proof. Each imbalance matrix is defined as $D_l = W_l^TW_l - W_{l+1}W_{l+1}^T$, $l = 1,\dots,L-1$, so we only need to check that $\frac{d}{dt}\big(W_l^TW_l\big)$ and $\frac{d}{dt}\big(W_{l+1}W_{l+1}^T\big)$ are identical. From (C.1), for $l = 1,\dots,L-1$,
$$\frac{d}{dt}\big(W_l^TW_l\big) = \dot{W}_l^TW_l + W_l^T\dot{W}_l = -\Big(\prod_{i=l+1}^{L+1}W_i\Big)\nabla^Tf(W)\Big(\prod_{i=1}^{l}W_i\Big) - \Big(\prod_{i=1}^{l}W_i\Big)^T\nabla f(W)\Big(\prod_{i=l+1}^{L+1}W_i\Big)^T\,,$$
$$\frac{d}{dt}\big(W_{l+1}W_{l+1}^T\big) = \dot{W}_{l+1}W_{l+1}^T + W_{l+1}\dot{W}_{l+1}^T = -\Big(\prod_{i=1}^{l}W_i\Big)^T\nabla f(W)\Big(\prod_{i=l+1}^{L+1}W_i\Big)^T - \Big(\prod_{i=l+1}^{L+1}W_i\Big)\nabla^Tf(W)\Big(\prod_{i=1}^{l}W_i\Big)\,,$$
hence $\frac{d}{dt}\big(W_l^TW_l\big) = \frac{d}{dt}\big(W_{l+1}W_{l+1}^T\big)$, and therefore $\dot{D}_l(t) = 0$, $l = 1,\dots,L-1$.

D LINEAR REGRESSION

D.1 SPECTRAL INITIALIZATION

Consider the linear regression loss $f = \frac{1}{2}\|Y - XW\|_F^2$, where $\{X, Y\}$ is co-diagonalizable: $Y = P\Sigma_Y Q$ and $X = P\Sigma_X R^T$.

Remark 2.
In Section 4, we discussed the case $f = \frac{1}{2}\|Y - W\|_F^2$, which essentially corresponds to the aforementioned setting with $N = n$ and $X = I_n$.
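The loss-decrease identity of Lemma 1 and the conservation law of Lemma 2 can both be sanity-checked numerically. The following sketch (illustrative dimensions and loss, not from the paper) verifies the identity at a random point and the approximate conservation of the imbalances under a small-step Euler discretization of (3):

```python
import numpy as np

rng = np.random.default_rng(3)
n, h1, h2, m = 3, 4, 4, 2
W1 = 0.5 * rng.normal(size=(n, h1))
W2 = 0.5 * rng.normal(size=(h1, h2))
W3 = 0.5 * rng.normal(size=(h2, m))
G = rng.normal(size=(n, m))                 # stand-in for grad f(W)

# layer gradients dL/dW_l = (prefix_l)^T G (suffix_l)^T
g1 = G @ (W2 @ W3).T
g2 = W1.T @ G @ W3.T
g3 = (W1 @ W2).T @ G
lhs = sum(np.sum(g * g) for g in (g1, g2, g3))        # ||grad L||_F^2

# T(E) = sum_l prefix_l prefix_l^T E suffix_l^T suffix_l
s1 = W2 @ W3
p3 = W1 @ W2
TG = (G @ s1.T @ s1) + (W1 @ W1.T @ G @ W3.T @ W3) + (p3 @ p3.T @ G)
rhs = np.sum(TG * G)                                  # <T(G), G>_F

# Lemma 2: imbalances drift only at O(eta^2) per Euler step
Y = rng.normal(size=(n, m))
D1_0 = W1.T @ W1 - W2 @ W2.T
D2_0 = W2.T @ W2 - W3 @ W3.T
eta = 1e-4
for _ in range(2000):
    G = W1 @ W2 @ W3 - Y                    # grad f for f = 0.5||W - Y||_F^2
    W1, W2, W3 = (W1 - eta * G @ (W2 @ W3).T,
                  W2 - eta * W1.T @ G @ W3.T,
                  W3 - eta * (W1 @ W2).T @ G)
drift = max(np.abs(W1.T @ W1 - W2 @ W2.T - D1_0).max(),
            np.abs(W2.T @ W2 - W3 @ W3.T - D2_0).max())
```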

Given any set of weights $\{W_l\}_{l=1}^L$ such that
$$W_1 = R\Sigma_1V_1^T\,,\qquad W_l = V_{l-1}\Sigma_lV_l^T\,,\ l = 2,\dots,L-1\,,\qquad W_L = V_{L-1}\Sigma_LQ\,,$$
where $\Sigma_l$, $l = 1,\dots,L$, are diagonal matrices and $V_l\in\mathbb{R}^{n\times h_l}$, $l = 1,\dots,L-1$, with $V_l^TV_l = I_{h_l}$. The gradient flow dynamics require
$$\dot{W}_1 = -\frac{\partial\mathcal{L}}{\partial W_1} = X^T(Y - XW)\,W_L^TW_{L-1}^T\cdots W_2^T = R\Sigma_XP^T\Big(P\Sigma_YQ - P\Sigma_XR^T\cdot R\prod_{l=1}^L\Sigma_l\,Q\Big)\cdot Q^T\Sigma_LV_{L-1}^T\cdot V_{L-1}\Sigma_{L-1}V_{L-2}^T\cdots V_2\Sigma_2V_1^T$$
$$= R\,\Sigma_X\Big(\Sigma_Y - \Sigma_X\prod_{l=1}^L\Sigma_l\Big)QQ^T\prod_{l=2}^L\Sigma_l\,V_1^T = R\,\Sigma_X\Big(\Sigma_Y - \Sigma_X\prod_{l=1}^L\Sigma_l\Big)\begin{bmatrix}I_m & 0\\ 0 & 0\end{bmatrix}\prod_{l=2}^L\Sigma_l\,V_1^T\,,$$
which shows that the singular spaces $R$, $V_1$ of $W_1$ do not change under the gradient flow, and the singular values $\sigma_{i,1}$ of $W_1$ satisfy
$$\dot{\sigma}_{i,1} = \Big(\sigma_{i,y} - \sigma_{i,x}\prod_{l=1}^L\sigma_{i,l}\Big)\sigma_{i,x}\prod_{l=2}^L\sigma_{i,l}\,,\quad i = 1,\dots,m\,,$$
and $\dot{\sigma}_{i,1} = 0$, $i = m+1,\dots,n$. Similarly, we can show that
$$\dot{W}_l = V_{l-1}\Bigg[\Sigma_X\Big(\Sigma_Y - \Sigma_X\prod_{i=1}^L\Sigma_i\Big)\begin{bmatrix}I_m & 0\\ 0 & 0\end{bmatrix}\prod_{i\ne l}\Sigma_i\Bigg]V_l^T\,,\quad l = 2,\dots,L-1\,,\qquad \dot{W}_L = V_{L-1}\Bigg[\Sigma_X\Big(\Sigma_Y - \Sigma_X\prod_{i=1}^L\Sigma_i\Big)\begin{bmatrix}I_m & 0\\ 0 & 0\end{bmatrix}\prod_{i\ne L}\Sigma_i\Bigg]Q\,.$$
Overall, this shows that the singular spaces of $\{W_l\}_{l=1}^L$ do not change under the gradient flow, and their singular values satisfy, for $i = 1,\dots,m$,
$$\dot{\sigma}_{i,l} = \Big(\sigma_{i,y} - \sigma_{i,x}\prod_{k=1}^L\sigma_{i,k}\Big)\sigma_{i,x}\prod_{k\ne l}\sigma_{i,k}\,,\quad l = 1,\dots,L\,.$$

Each of these scalar dynamics is exactly the gradient flow on
$$\mathcal{L}_i\big(\{\sigma_{i,l}\}_{l=1}^L\big) = \frac{1}{2}\Big(\sigma_{i,y} - \sigma_{i,x}\prod_{l=1}^L\sigma_{i,l}\Big)^2\,.$$
Therefore, under spectral initialization, the dynamics of the weights decouple into at most $m$ scalar dynamics of the type discussed in Section 4.2.

D.2 DIAGONAL LINEAR NETWORKS

The loss function of diagonal linear networks (Gunasekar et al., 2017; Yun et al., 2020) is of the form $f(w_1\odot\cdots\odot w_L)$. We write
$$\mathcal{L}\big(\{w_l\}_{l=1}^L\big) = f(w_1\odot\cdots\odot w_L) = f\big(w^{(1)},\dots,w^{(n)}\big) = f\Big(\prod_{l=1}^Lw_{l,1},\dots,\prod_{l=1}^Lw_{l,n}\Big)\,,$$
i.e. $f$ takes $n$ variables $w^{(1)},\dots,w^{(n)}$, and each variable $w^{(i)}$ is overparametrized into $\prod_{l=1}^Lw_{l,i}$. Then we can show that
$$\dot{\mathcal{L}} = -\big\|\nabla_{\{w_l\}_{l=1}^L}\mathcal{L}\big\|_F^2 = -\sum_{i=1}^n\sum_{l=1}^L\Big(\frac{\partial\mathcal{L}}{\partial w_{l,i}}\Big)^2 = -\sum_{i=1}^n\Big(\frac{\partial f}{\partial w^{(i)}}\Big)^2\sum_{l=1}^L\Big(\frac{\partial w^{(i)}}{\partial w_{l,i}}\Big)^2 = -\sum_{i=1}^n\Big(\frac{\partial f}{\partial w^{(i)}}\Big)^2\tau_{\{w_{l,i}\}_{l=1}^L}$$
$$\le -\Big(\min_{1\le i\le n}\tau_{\{w_{l,i}\}_{l=1}^L}\Big)\sum_{i=1}^n\Big(\frac{\partial f}{\partial w^{(i)}}\Big)^2 \le -\Big(\min_{1\le i\le n}\tau_{\{w_{l,i}\}_{l=1}^L}\Big)\gamma(f - f^*) = -\Big(\min_{1\le i\le n}\tau_{\{w_{l,i}\}_{l=1}^L}\Big)\gamma(\mathcal{L} - \mathcal{L}^*)\,,$$
where the last inequality uses that $f$ satisfies A1. Moreover, the imbalances $\{d_l^{(i)} := w_{l,i}^2 - w_{l+1,i}^2\}_{l=1}^{L-1}$ are time-invariant for each $i = 1,\dots,n$ by Lemma 2. Therefore, we can lower bound each $\tau_{\{w_{l,i}\}_{l=1}^L}$ using the imbalance $\{d_l^{(i)}\}_{l=1}^{L-1}$ as in Proposition 3, from which one obtains the exponential convergence of $\mathcal{L}$.

For the two-layer case, we overall have
$$\lambda_{\min}\big(\mathcal{T}_{\{W_1(t),W_2(t)\}}\big) \overset{\text{Lemma 3}}{\ge} \frac{1}{2}\Big(-\Delta^+ + \sqrt{(\Delta^+ + \Delta)^2 + 4\sigma_n^2(W(t))}\Big) + \frac{1}{2}\Big(-\Delta^- + \sqrt{(\Delta^- + \Delta)^2 + 4\sigma_m^2(W(t))}\Big)$$
$$\ge \frac{1}{2}\Bigg(-\Delta^+ + \sqrt{(\Delta^+ + \Delta)^2 + 4\big[\sigma_n(W^*) - \sqrt{K/\mu}\|W(0) - W^*\|_F\big]_+^2}\Bigg) + \frac{1}{2}\Bigg(-\Delta^- + \sqrt{(\Delta^- + \Delta)^2 + 4\big[\sigma_m(W^*) - \sqrt{K/\mu}\|W(0) - W^*\|_F\big]_+^2}\Bigg) =: \alpha_2\,.$$
In either case, we have $\frac{d}{dt}(\mathcal{L}(t) - \mathcal{L}^*) \le -\alpha_2\gamma(\mathcal{L}(t) - \mathcal{L}^*)$, and by Grönwall's inequality,
$$\mathcal{L}(t) - \mathcal{L}^* \le \exp(-\alpha_2\gamma t)\,(\mathcal{L}(0) - \mathcal{L}^*)\,.$$
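The Grönwall step can be illustrated on a two-layer scalar model with $f(w) = \frac{1}{2}(y - w)^2$, which satisfies $|f'(w)|^2 = 2(f - f^*)$, i.e. $\gamma = 2$ in A1, and for which $\lambda_{\min}(\mathcal{T}) = w_1^2 + w_2^2 \ge |w_1^2 - w_2^2| =: \alpha$ holds along the whole flow (all numeric values are illustrative, not from the paper):

```python
import numpy as np

y = 3.0
w1, w2 = 2.0, 1.0            # imbalance d = w1^2 - w2^2 = 3
alpha = abs(w1**2 - w2**2)   # lower bound on w1^2 + w2^2 for all t
gamma = 2.0                  # f(w) = 0.5 (y - w)^2 gives |f'|^2 = 2 f

eta = 1e-4
L0 = 0.5 * (y - w1 * w2) ** 2
ok = True
t = 0.0
for _ in range(50000):       # forward-Euler discretization of the flow
    r = w1 * w2 - y
    w1, w2 = w1 - eta * r * w2, w2 - eta * r * w1
    t += eta
    loss = 0.5 * (y - w1 * w2) ** 2
    # Groenwall bound: L(t) - L* <= exp(-alpha * gamma * t) * (L(0) - L*)
    ok &= loss <= np.exp(-alpha * gamma * t) * L0 + 1e-12
```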

F PROOFS FOR THREE-LAYER MODEL

In Section F.1, we discuss the proof idea for Theorem 2, then present the proof afterwards. In Section G, we show a simplified bound when the weights can be ordered w.r.t. positive-semidefiniteness.

F.1 PROOF IDEA

We first discuss the proof idea behind Theorem 2, then provide the complete proof. Consider the case $n = m = 1$; we use the notation $\{w_1^T, W_2, w_3\}\in\mathbb{R}^{1\times h_1}\times\mathbb{R}^{h_1\times h_2}\times\mathbb{R}^{h_2\times1}$ for the weights. The quantity we need to lower bound is
$$\lambda_{\min}\big(\mathcal{T}_{\{w_1^T,W_2,w_3\}}\big) = w_1^TW_2W_2^Tw_1 + w_1^Tw_1\cdot w_3^Tw_3 + w_3^TW_2^TW_2w_3 = \|W_2^Tw_1\|^2 + \|w_1\|^2\|w_3\|^2 + \|W_2w_3\|^2\,,$$
where our linear operator $\mathcal{T}_{\{w_1^T,W_2,w_3\}}$ reduces to a scalar. It remains to find
$$\min_{w_1,W_2,w_3}\ \|W_2^Tw_1\|^2 + \|w_1\|^2\|w_3\|^2 + \|W_2w_3\|^2 \qquad\text{(F.6)}$$
$$\text{s.t.}\quad W_2W_2^T - w_1w_1^T = D_{21}\,,\qquad W_2^TW_2 - w_3w_3^T = D_{23}\,,$$
i.e., we seek the best lower bound on $\lambda_{\min}(\mathcal{T}_{\{w_1^T,W_2,w_3\}})$ given that the weights must satisfy the imbalance constraints from $D_{21}, D_{23}$; here $\lambda_{\min}(\mathcal{T}_{\{w_1^T,W_2,w_3\}})$ is determined by the norms of some weights, $\|w_1\|, \|w_3\|$, and the "alignment" between weights, $\|W_2^Tw_1\|, \|W_2w_3\|$. The general idea of the proof is to lower bound each term $\|W_2^Tw_1\|^2$, $\|w_1\|^2$, $\|w_3\|^2$, $\|W_2w_3\|^2$ individually given the imbalance constraints, then show the existence of some $\{w_1^T, W_2, w_3\}$ that attains all the lower bounds simultaneously. The following discussion is mostly about lower bounding $\|w_1\|$ and $\|W_2^Tw_1\|$, but the same argument applies to the other quantities. Understanding what can be chosen as the spectrum of $W_2W_2^T$ ($W_2^TW_2$) is the key to deriving a lower bound, and the imbalance constraints implicitly limit such choices. To see this, notice that $W_2W_2^T - w_1w_1^T = D_{21}$ implies an eigenvalue interlacing relation (Horn & Johnson, 2012, Corollary 4.3.9) between $W_2W_2^T$ and $D_{21}$, i.e.
$$\lambda_{h_1}(D_{21}) \le \lambda_{h_1}(W_2W_2^T) \le \lambda_{h_1-1}(D_{21}) \le \cdots \le \lambda_2(W_2W_2^T) \le \lambda_1(D_{21}) \le \lambda_1(W_2W_2^T)\,.$$
Therefore, any choice of $\{\lambda_i(W_2W_2^T)\}_{i=1}^{h_1}$ must satisfy the interlacing relation with $\{\lambda_i(D_{21})\}_{i=1}^{h_1}$.
Similarly, $\{\lambda_i(W_2^TW_2)\}_{i=1}^{h_2}$ must satisfy the interlacing relation with $\{\lambda_i(D_{23})\}_{i=1}^{h_2}$. Moreover, $\{\lambda_i(W_2W_2^T)\}_{i=1}^{h_1}$ and $\{\lambda_i(W_2^TW_2)\}_{i=1}^{h_2}$ agree on the non-zero eigenvalues. In short, an appropriate choice of the spectrum of $W_2W_2^T$ ($W_2^TW_2$) needs to respect the interlacing relations with the eigenvalues of both $D_{21}$ and $D_{23}$. We define the matrix
$$\hat{D}_{h_1} := \operatorname{diag}\big\{\max\{\lambda_i(D_{21}),\lambda_i(D_{23}),0\}\big\}_{i=1}^{h_1}$$
to be the "minimum" choice of the spectrum of $W_2W_2^T$ ($W_2^TW_2$), in the sense that any valid choice of $\{\lambda_i(W_2W_2^T)\}_{i=1}^{h_1}$ must satisfy
$$\lambda_i(W_2W_2^T) \ge \lambda_i(\hat{D}_{h_1}) \ge \lambda_i(D_{21})\,,\quad i = 1,\dots,h_1\,.$$
That is, the spectrum of $\hat{D}_{h_1}$ "lies between" that of $W_2W_2^T$ and that of $D_{21}$. Now consider the imbalance constraint again: $W_2W_2^T - w_1w_1^T = D_{21}$ shows that a rank-one downdate $w_1w_1^T$ takes the spectrum of $W_2W_2^T$ to that of $D_{21}$; more importantly, taking the trace of the imbalance equation forces the norm of $w_1$ to satisfy
$$\operatorname{tr}(W_2W_2^T) - \|w_1\|^2 = \operatorname{tr}(D_{21}) \ \Rightarrow\ \|w_1\|^2 = \operatorname{tr}(W_2W_2^T) - \operatorname{tr}(D_{21})\,.$$
Now, since $\hat{D}_{h_1}$ "lies in between", we have
$$\|w_1\|^2 = \operatorname{tr}(W_2W_2^T) - \operatorname{tr}(D_{21}) = \big(\text{changes from }\lambda_i(W_2W_2^T)\text{ to }\lambda_i(\hat{D}_{h_1})\big) + \big(\text{changes from }\lambda_i(\hat{D}_{h_1})\text{ to }\lambda_i(D_{21})\big) \ge \operatorname{tr}(\hat{D}_{h_1}) - \operatorname{tr}(D_{21})\,,$$
which is a lower bound on $\|w_1\|^2$. It is exactly the $\Delta_{21}$ in Theorem 2 (it takes a more complicated form when $n > 1$).
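The interlacing relation and the trace identity used above are easy to verify numerically on a random instance (illustrative sizes; $D_{23}$ is omitted here, so the clipped spectrum only reflects $D_{21}$):

```python
import numpy as np

rng = np.random.default_rng(4)
h1 = 6
W2 = rng.normal(size=(h1, 5))
w1 = rng.normal(size=(h1, 1))
M = W2 @ W2.T                       # plays the role of W2 W2^T
D21 = M - w1 @ w1.T                 # rank-one downdate: the imbalance matrix

lam_M = np.linalg.eigvalsh(M)[::-1]     # descending order
lam_D = np.linalg.eigvalsh(D21)[::-1]

# interlacing: lambda_{i+1}(M) <= lambda_i(D21) <= lambda_i(M)
upper = all(lam_D[i] <= lam_M[i] + 1e-10 for i in range(h1))
lower = all(lam_M[i + 1] <= lam_D[i] + 1e-10 for i in range(h1 - 1))

# the clipped ("minimum") spectrum sits between D21 and M
lam_clip = np.maximum(lam_D, 0.0)
between = all(lam_M[i] >= lam_clip[i] - 1e-10 for i in range(h1))

# trace identity: tr(M) - ||w1||^2 = tr(D21)
trace_identity = np.isclose(np.trace(M) - float(w1.T @ w1), np.trace(D21))
```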

A lower bound on $\|W_2^Tw_1\|^2$ requires carefully examining the changes from the spectrum of $\hat{D}_{h_1}$ to that of $D_{21}$. If $\lambda_{h_1}(D_{21}) < 0$, then the "changes from $\lambda_i(\hat{D}_{h_1})$ to $\lambda_i(D_{21})$" have two parts:
1. the changes from $\lambda_i(\hat{D}_{h_1})$ to $[\lambda_i(D_{21})]_+$, through the part of $w_1$ that is "aligned" with $W_2^T$;
2. the changes from $0$ to $\lambda_{h_1}(D_{21})$, through the part of $w_1$ that is "orthogonal" to $W_2^T$.
Only the former contributes to $\|W_2^Tw_1\|^2$, hence we need the expression $\Delta^{(2)}_{21} + \Delta^2_{21}$, which excludes the latter part. Using a similar argument we can lower bound $\|w_3\|^2$ and $\|W_2w_3\|^2$. Lastly, the existence of $\{w_1^T, W_2, w_3\}$ attaining the lower bound follows from the fact that $\hat{D}_{h_1}$ ($\hat{D}_{h_2}$) is a valid choice for the spectrum of $W_2W_2^T$ ($W_2^TW_2$). The complete proof of Theorem 2 follows the same idea, but with a generalized notion of eigenvalue interlacing and some related novel eigenvalue bounds.

F.2 PROOF OF THEOREM 2

Theorem 2 is a direct consequence of the following two results.

Lemma F.1. Given any set of weights $\{W_1, W_2, W_3\}\in\mathbb{R}^{n\times h_1}\times\mathbb{R}^{h_1\times h_2}\times\mathbb{R}^{h_2\times m}$, we have
$$\lambda_{\min}\big(\mathcal{T}_{\{W_1,W_2,W_3\}}\big) \ge \lambda_n(W_1W_2W_2^TW_1^T) + \lambda_n(W_1W_1^T)\lambda_m(W_3^TW_3) + \lambda_m(W_3^TW_2^TW_2W_3)\,.$$
(Note that $\lambda_{\min}(\mathcal{T}_{\{W_1,W_2,W_3\}})$ does not have a closed-form expression; one can only work with this lower bound.)

Theorem F.2. Given an imbalance matrix pair $(D_{21}, D_{23})\in\mathbb{R}^{h_1\times h_1}\times\mathbb{R}^{h_2\times h_2}$, the optimal value of
$$\min_{W_1,W_2,W_3}\ 2\Big[\lambda_n(W_1W_2W_2^TW_1^T) + \lambda_n(W_1W_1^T)\lambda_m(W_3^TW_3) + \lambda_m(W_3^TW_2^TW_2W_3)\Big]$$
$$\text{s.t.}\quad W_2W_2^T - W_1^TW_1 = D_{21}\,,\qquad W_2^TW_2 - W_3W_3^T = D_{23}$$
is $\Delta^*(D_{21}, D_{23})$.

Lemma F.1 is intuitive and easy to prove.

Proof of Lemma F.1. Notice that $\mathcal{T}_{\{W_1,W_2,W_3\}}$ is the sum of three positive semi-definite linear operators on $\mathbb{R}^{n\times m}$:
$$\mathcal{T}_{\{W_1,W_2,W_3\}} = \mathcal{T}_{12} + \mathcal{T}_{13} + \mathcal{T}_{23}\,,$$
where $\mathcal{T}_{12}E = W_1W_2W_2^TW_1^TE$, $\mathcal{T}_{13}E = W_1W_1^TEW_3^TW_3$, $\mathcal{T}_{23}E = EW_3^TW_2^TW_2W_3$, and
$$\lambda_{\min}(\mathcal{T}_{12}) = \lambda_n(W_1W_2W_2^TW_1^T)\,,\quad \lambda_{\min}(\mathcal{T}_{13}) = \lambda_n(W_1W_1^T)\lambda_m(W_3^TW_3)\,,\quad \lambda_{\min}(\mathcal{T}_{23}) = \lambda_m(W_3^TW_2^TW_2W_3)\,.$$
Therefore, letting $E_{\min}$ with $\|E_{\min}\|_F = 1$ be the eigenmatrix associated with $\lambda_{\min}(\mathcal{T}_{\{W_1,W_2,W_3\}})$, we have
$$\lambda_{\min}\big(\mathcal{T}_{\{W_1,W_2,W_3\}}\big) = \big\langle\mathcal{T}_{\{W_1,W_2,W_3\}}E_{\min},\,E_{\min}\big\rangle_F = \langle\mathcal{T}_{12}E_{\min},E_{\min}\rangle_F + \langle\mathcal{T}_{13}E_{\min},E_{\min}\rangle_F + \langle\mathcal{T}_{23}E_{\min},E_{\min}\rangle_F \ge \lambda_{\min}(\mathcal{T}_{12}) + \lambda_{\min}(\mathcal{T}_{13}) + \lambda_{\min}(\mathcal{T}_{23})\,.$$

The rest of this section is dedicated to proving Theorem F.2. We first state a few lemmas that will be used in the proof, then give the proof of Theorem F.2, and present the longer proofs of the auxiliary lemmas at the end.
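As a sanity check of Lemma F.1, the three operators can be represented as matrices acting on the (column-stacked) vectorization of $E$ via Kronecker products, and the minimum eigenvalues compared directly; the dimensions below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(5)
n, h1, h2, m = 3, 4, 5, 3
W1 = rng.normal(size=(n, h1))
W2 = rng.normal(size=(h1, h2))
W3 = rng.normal(size=(h2, m))

A12 = W1 @ W2 @ W2.T @ W1.T         # T12 E = (W1 W2 W2^T W1^T) E
A13 = W1 @ W1.T                     # T13 E = (W1 W1^T) E (W3^T W3)
B13 = W3.T @ W3
B23 = W3.T @ W2.T @ W2 @ W3         # T23 E = E (W3^T W2^T W2 W3)

# vec(A E B) = kron(B^T, A) vec(E); all factors here are symmetric PSD
T12 = np.kron(np.eye(m), A12)
T13 = np.kron(B13, A13)
T23 = np.kron(B23, np.eye(n))
T = T12 + T13 + T23

lam_min_T = np.linalg.eigvalsh(T)[0]
lower = (np.linalg.eigvalsh(A12)[0]
         + np.linalg.eigvalsh(A13)[0] * np.linalg.eigvalsh(B13)[0]
         + np.linalg.eigvalsh(B23)[0])
```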

F.4 PROOF OF THEOREM F.2

With these lemmas, we are ready to prove Theorem F.2.

Proof of Theorem F.2. The proof is presented in two parts: first, we show that $\Delta^*(D_{21}, D_{23})$ is a lower bound on the optimal value; then we construct an optimal solution $(W_1^*, W_2^*, W_3^*)$ that attains $\Delta^*(D_{21}, D_{23})$ as its objective value.

Showing $\Delta^*(D_{21}, D_{23})$ is a lower bound: Given any feasible triple $(W_1, W_2, W_3)$, the imbalance equations
$$W_2W_2^T - W_1^TW_1 = D_{21}\,,\qquad\text{(F.7)}$$
$$W_2^TW_2 - W_3W_3^T = D_{23}\,,\qquad\text{(F.8)}$$
imply $W_2W_2^T\succeq_nD_{21}$ and $W_2^TW_2\succeq_mD_{23}$ by Lemma F.3. These interlacing relations show $\lambda_i(W_2W_2^T)\ge\lambda_i(D_{21})$ and $\lambda_i(W_2^TW_2)\ge\lambda_i(D_{23})$ for all $i$, which gives
$$\lambda_i(W_2W_2^T) = \lambda_i(W_2^TW_2) \ge \max\{\lambda_i(D_{21}),\lambda_i(D_{23}),0\} = \lambda_i(\hat{D}_{h_1}) \ge 0\,,\quad\forall i\in[h_1]\,.\qquad\text{(F.9)}$$
Now by Lemma F.6, imbalance equation (F.7) implies
$$\lambda_n(W_1W_1^T) \ge \operatorname{tr}(W_2W_2^T) - \operatorname{tr}(W_2W_2^T\wedge_nD_{21})\,,$$
and
$$2\lambda_n(W_1W_2W_2^TW_1^T) \ge \operatorname{tr}\big((W_2W_2^T)^2\big) - \operatorname{tr}\big((W_2W_2^T\wedge_nD_{21})^2\big) + \big(\operatorname{tr}(W_2W_2^T) - \operatorname{tr}(W_2W_2^T\wedge_nD_{21})\big)^2\,.$$
Notice that
$$\lambda_n(W_1W_1^T) \ge \operatorname{tr}(W_2W_2^T) - \operatorname{tr}(W_2W_2^T\wedge_nD_{21}) = \sum_{i=1}^{h_1}\Big(\lambda_i(W_2W_2^T) - \min\{\lambda_i(W_2W_2^T),\lambda_{i+1-n}(D_{21})\}\Big) = \sum_{i=1}^{h_1}\max\{\lambda_i(W_2W_2^T) - \lambda_{i+1-n}(D_{21}),\,0\}$$
$$\ge \sum_{i=1}^{h_1}\max\{\lambda_i(\hat{D}_{h_1}) - \lambda_{i+1-n}(D_{21}),\,0\} = \operatorname{tr}(\hat{D}_{h_1}) - \operatorname{tr}(\hat{D}_{h_1}\wedge_nD_{21}) = \Delta_{21}\,,\qquad\text{(F.10)}$$
where the inequality holds because of (F.9) and the fact that the ReLU function $[x]_+ = \max\{x,0\}$ is monotonically non-decreasing. Since $\Delta_{21}$ can be viewed as a sum of ReLU outputs, it must be non-negative, so (F.10) also implies
$$\big(\operatorname{tr}(W_2W_2^T) - \operatorname{tr}(W_2W_2^T\wedge_nD_{21})\big)^2 \ge \Delta_{21}^2\,.
(F.11)$$
Next we have
$$2\lambda_n(W_1W_2W_2^TW_1^T) \ge \operatorname{tr}\big((W_2W_2^T)^2\big) - \operatorname{tr}\big((W_2W_2^T\wedge_nD_{21})^2\big) + \big(\operatorname{tr}(W_2W_2^T) - \operatorname{tr}(W_2W_2^T\wedge_nD_{21})\big)^2$$
$$\overset{\text{(F.11)}}{\ge} \Delta^2_{21} + \operatorname{tr}\big((W_2W_2^T)^2\big) - \operatorname{tr}\big((W_2W_2^T\wedge_nD_{21})^2\big) = \Delta^2_{21} + \sum_{i=1}^{h_1}\Big(\lambda^2_i(W_2W_2^T) - \min\{\lambda_i(W_2W_2^T),\lambda_{i+1-n}(D_{21})\}^2\Big)$$
$$\ge \Delta^2_{21} + \sum_{i=1}^{h_1}\Big(\lambda^2_i(\hat{D}_{h_1}) - \min\{\lambda_i(\hat{D}_{h_1}),\lambda_{i+1-n}(D_{21})\}^2\Big) = \Delta^2_{21} + \operatorname{tr}\big(\hat{D}^2_{h_1}\big) - \operatorname{tr}\big((\hat{D}_{h_1}\wedge_nD_{21})^2\big) = \Delta^2_{21} + \Delta^{(2)}_{21}\,,$$
where the last inequality is because of (F.9) and the fact that the function
$$g(x) = x^2 - (\min\{x,a\})^2 = \begin{cases}0\,, & x\le a\\ x^2 - a^2\,, & x > a\end{cases}$$
is monotonically non-decreasing on $\mathbb{R}_{\ge0}$ for any constant $a\in\mathbb{R}$. At this point, we have shown
$$\lambda_n(W_1W_1^T) \ge \Delta_{21}\,,\qquad 2\lambda_n(W_1W_2W_2^TW_1^T) \ge \Delta^2_{21} + \Delta^{(2)}_{21}\,.\qquad\text{(F.12)}$$
We can repeat the arguments above with the replacements $W_2\to W_2^T$, $W_1\to W_3^T$, $D_{21}\to D_{23}$, $\hat{D}_{h_1}\to\hat{D}_{h_2}$, and obtain
$$\lambda_m(W_3^TW_3) \ge \Delta_{23}\,,\qquad 2\lambda_m(W_3^TW_2^TW_2W_3) \ge \Delta^2_{23} + \Delta^{(2)}_{23}\,.\qquad\text{(F.13)}$$
The inequalities (F.12), (F.13) show that
$$\Delta^*(D_{21}, D_{23}) = \Delta^{(2)}_{21} + \Delta^2_{21} + 2\Delta_{21}\Delta_{23} + \Delta^{(2)}_{23} + \Delta^2_{23}$$
is a lower bound on the optimal value of our optimization problem. We now proceed to show tightness.

F.5 PROOFS FOR AUXILIARY LEMMAS

We finish this section by providing the proofs of the auxiliary lemmas used in the last section.

Proof of Lemma F.5. Since $(D_{21}, D_{23})$ is a pair of imbalance matrices, there exists $W_2$ such that
$$W_2W_2^T\succeq_nD_{21}\,,\qquad W_2^TW_2\succeq_mD_{23}\,,\qquad\text{(F.18)}$$
because at least our weight initialization $W_1(0), W_2(0), W_3(0)$ must satisfy $W_2(0)W_2(0)^T - W_1^T(0)W_1(0) = D_{21}$ and $W_2^T(0)W_2(0) - W_3(0)W_3^T(0) = D_{23}$. Therefore, for $0 < i\le h_1-n$,
$$\lambda_{i+n}(\hat{D}_{h_1}) = \max\{\lambda_{i+n}(D_{21}),\lambda_{i+n}(D_{23}),0\} \le \lambda_{i+n}(W_2W_2^T) \le \lambda_i(D_{21}) \le \lambda_i(\hat{D}_{h_1})\,,$$
where the first two inequalities use (F.18) and the fact that $\lambda_{i+n}(W_2W_2^T) = \lambda_{i+n}(W_2^TW_2)$; the last inequality is from the fact that $\lambda_i(\hat{D}_{h_1}) = \max\{\lambda_i(D_{21}),\lambda_i(D_{23}),0\}$, $\forall i\in[h_1]$. For $h_1\ge i > h_1-n$, we still have (with the convention that out-of-range eigenvalues are $-\infty$)
$$-\infty = \lambda_{i+n}(\hat{D}_{h_1}) \le \lambda_i(D_{21}) \le \lambda_i(\hat{D}_{h_1})\,.$$
Overall, we have $\lambda_{i+n}(\hat{D}_{h_1})\le\lambda_i(D_{21})\le\lambda_i(\hat{D}_{h_1})$ for all $i$, which is exactly $\hat{D}_{h_1}\succeq_nD_{21}$. Similarly, for $0 < i\le h_2-m$,
$$\lambda_{i+m}(\hat{D}_{h_2}) = \max\{\lambda_{i+m}(D_{21}),\lambda_{i+m}(D_{23}),0\} \le \lambda_{i+m}(W_2^TW_2) \le \lambda_i(D_{23}) \le \lambda_i(\hat{D}_{h_2})\,,$$
where the first two inequalities use (F.18) and the fact that $\lambda_{i+m}(W_2W_2^T) = \lambda_{i+m}(W_2^TW_2)$; the last inequality is from the fact that $\lambda_i(\hat{D}_{h_2}) = \max\{\lambda_i(D_{21}),\lambda_i(D_{23}),0\}$, $\forall i\in[h_2]$. For $h_2\ge i > h_2-m$, we still have
$$-\infty = \lambda_{i+m}(\hat{D}_{h_2}) \le \lambda_i(D_{23}) \le \lambda_i(\hat{D}_{h_2})\,.$$
Overall, we have $\lambda_{i+m}(\hat{D}_{h_2})\le\lambda_i(D_{23})\le\lambda_i(\hat{D}_{h_2})$ for all $i$, which is exactly $\hat{D}_{h_2}\succeq_mD_{23}$.

Proof of Lemma F.6. Notice that $\operatorname{rank}(Z^TZ)\le r$, hence we consider the eigendecomposition $Z^TZ = \sum_{i=1}^r\lambda_i(Z^TZ)v_iv_i^T$, where the $v_i$ are unit eigenvectors of $Z^TZ$. Then we can write
$$A - \lambda_r(Z^TZ)v_rv_r^T - \sum_{i=1}^{r-1}\lambda_i(Z^TZ)v_iv_i^T = B\,.$$
Let $\bar{D} = A - \lambda_r(Z^TZ)v_rv_r^T$; then by Lemma F.3, we know $A\succeq_1\bar{D}$ and $\bar{D}\succeq_{r-1}B$, which implies that for all $i$,
$$\lambda_{i+1}(A) \le \lambda_i(\bar{D}) \le \lambda_i(A)\,,\qquad\text{(F.19)}$$
$$\lambda_{i+r-1}(\bar{D}) \le \lambda_i(B) \le \lambda_i(\bar{D})\,.
(F.20)$$
In particular, we have $\lambda_i(\bar{D})\le\lambda_i(A)$ from (F.19) and $\lambda_i(\bar{D})\le\lambda_{i+1-r}(B)$ from (F.20), which gives
$$\lambda_i(\bar{D}) \le \min\{\lambda_i(A),\lambda_{i+1-r}(B)\} = \lambda_i(A\wedge_rB)\,,\quad\forall i\,.$$
Hence
$$\operatorname{tr}(A\wedge_rB) \ge \operatorname{tr}(\bar{D}) = \operatorname{tr}(A) - \lambda_r(Z^TZ)\operatorname{tr}(v_rv_r^T) = \operatorname{tr}(A) - \lambda_r(Z^TZ)\,,$$
which proves the first inequality. For the second inequality, let $x$ be the unit eigenvector associated with $\lambda_r(ZAZ^T)$, so that $\lambda_r(ZAZ^T) = x^TZAZ^Tx$. Now write
$$A - Z^Txx^TZ - Z^T(I - xx^T)Z = B\,.$$
Let $\bar{D} = A - Z^Txx^TZ$; then again by Lemma F.3, we have $A\succeq_1\bar{D}$ and $\bar{D}\succeq_{r-1}B$.

Notice that
$$\bar{D}^2 = (A - Z^Txx^TZ)^2 = A^2 + (Z^Txx^TZ)^2 - AZ^Txx^TZ - Z^Txx^TZA\,.$$
Taking the trace on both sides of this equation and using the cyclic property of the trace leads to
$$\operatorname{tr}(\bar{D}^2) = \operatorname{tr}(A^2) + \|Z^Tx\|^4 - 2\lambda_r(ZAZ^T)\,.\qquad\text{(F.21)}$$
We only need to lower bound $\|Z^Tx\|^4 - \operatorname{tr}(\bar{D}^2)$, for which we write the eigendecomposition of $\bar{D}$ using eigenpairs $\{(\lambda_i(\bar{D}), u_i)\}_{i=1}^n$ as
$$\bar{D} = \sum_{i=1}^n\lambda_i(\bar{D})u_iu_i^T = \sum_{j=1}^{n-1}\lambda_j(\bar{D})u_ju_j^T + \lambda_n(\bar{D})u_nu_n^T\,.$$
Then we have
$$\|Z^Tx\|^2 = \operatorname{tr}(Z^Txx^TZ) = \operatorname{tr}(A) - \operatorname{tr}(\bar{D}) = \operatorname{tr}(A) - \sum_{j=1}^{n-1}\lambda_j(\bar{D}) - \lambda_n(\bar{D}) \ge \operatorname{tr}(A) - \sum_{j=1}^{n-1}\lambda_j(A\wedge_rB) - \lambda_n(\bar{D}) = \big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big) + \big(\lambda_n(A\wedge_rB) - \lambda_n(\bar{D})\big)\,,$$
where the inequality follows a similar argument as in the previous part of the proof and uses
$$\lambda_i(\bar{D}) \le \min\{\lambda_i(A),\lambda_{i+1-r}(B)\} = \lambda_i(A\wedge_rB)\,.\qquad\text{(F.22)}$$
The first component $\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)$ is non-negative because $\lambda_i(A)\ge\lambda_i(A\wedge_rB)$, $\forall i$. The second component $\lambda_n(A\wedge_rB) - \lambda_n(\bar{D})$ is non-negative as well, by (F.22). Therefore the right-hand side is non-negative, and we can square both sides of the inequality, namely
$$\|Z^Tx\|^4 \ge \Big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB) + \lambda_n(A\wedge_rB) - \lambda_n(\bar{D})\Big)^2\,.\qquad\text{(F.23)}$$
We also have
$$\operatorname{tr}(\bar{D}^2) = \sum_{i=1}^{n-1}\lambda^2_i(\bar{D}) + \lambda^2_n(\bar{D}) \le \sum_{i=1}^{n-1}\lambda^2_i(A\wedge_rB) + \lambda^2_n(\bar{D}) = \operatorname{tr}\big((A\wedge_rB)^2\big) - \lambda^2_n(A\wedge_rB) + \lambda^2_n(\bar{D})\,,\qquad\text{(F.24)}$$
where the inequality holds because for $i = 1,\dots,n-1$,
$$0 \le \lambda_{i+1}(A) \le \lambda_i(\bar{D}) \le \lambda_i(A\wedge_rB)\,,$$
with the inequality on the left from $A\succeq_1\bar{D}$ and the inequality on the right due to (F.22). With the two inequalities (F.23), (F.24), we have (for simplicity, denote $\lambda_\wedge := \lambda_n(A\wedge_rB)$, $\bar\lambda := \lambda_n(\bar{D})$)
$$\|Z^Tx\|^4 - \operatorname{tr}(\bar{D}^2) - \big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big)^2 + \operatorname{tr}\big((A\wedge_rB)^2\big) \ge \lambda^2_\wedge + \bar\lambda^2 - 2\lambda_\wedge\bar\lambda + 2(\lambda_\wedge - \bar\lambda)\big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big) + \lambda^2_\wedge - \bar\lambda^2$$
$$= 2\lambda^2_\wedge - 2\lambda_\wedge\bar\lambda + 2(\lambda_\wedge - \bar\lambda)\big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big) = 2(\lambda_\wedge - \bar\lambda)\big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB) + \lambda_\wedge\big) \ge 0\,,$$
where the last inequality is due to the facts that $\lambda_\wedge\ge\bar\lambda$ by (F.22) and
$$\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB) + \lambda_\wedge = \sum_{i=1}^{n-1}\big(\lambda_i(A) - \lambda_i(A\wedge_rB)\big) + \lambda_n(A) \ge 0\,.$$
This shows
$$\|Z^Tx\|^4 - \operatorname{tr}(\bar{D}^2) \ge \big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big)^2 - \operatorname{tr}\big((A\wedge_rB)^2\big)\,.$$
Finally, from (F.21) we have
$$2\lambda_r(ZAZ^T) = \operatorname{tr}(A^2) + \|Z^Tx\|^4 - \operatorname{tr}(\bar{D}^2) \ge \operatorname{tr}(A^2) - \operatorname{tr}\big((A\wedge_rB)^2\big) + \big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big)^2\,.$$

To prove Lemma F.7, we need one final lemma.

Lemma F.8. Given two real symmetric matrices $A, B$ of order $n$, for any $r\le n$, if $A\succeq_rB$, then $A\succeq_1(A\wedge_rB)$ and $(A\wedge_rB)\succeq_{r-1}B$.

Proof. Denote $D := A\wedge_rB$. First,
$$\lambda_i(B) \le \min\{\lambda_i(A),\lambda_{i+1-r}(B)\} = \lambda_i(D)\,,\qquad\text{(F.27)}$$
where $\lambda_i(B)\le\lambda_i(A)$ is from $A\succeq_rB$, and
$$\lambda_{i+r-1}(D) = \min\{\lambda_{i+r-1}(A),\lambda_i(B)\} \le \lambda_i(B)\,.\qquad\text{(F.28)}$$
(F.27) and (F.28) together show $D\succeq_{r-1}B$. The relation $A\succeq_1(A\wedge_rB)$ follows similarly from $\lambda_i(D)\le\lambda_i(A)$ and $\lambda_{i+1}(A)\le\lambda_{i+1-r}(B)$ (a consequence of $A\succeq_rB$), which gives $\lambda_{i+1}(A)\le\lambda_i(D)$.

Then we are ready to prove Lemma F.7.

Proof of Lemma F.7. Denote $D := A\wedge_rB$. We have shown in Lemma F.8 that $A\succeq_1D$ and $D\succeq_{r-1}B$. With the two interlacing relations, we know there exist $x\in\mathbb{R}^{n\times1}$, $X\in\mathbb{R}^{n\times(r-1)}$ and some orthogonal matrices $V_1, V_2\in O(n)$ such that
$$A - xx^T = V_1DV_1^T\,,\qquad D - XX^T = V_2BV_2^T\,;\qquad\text{(F.29)}$$
then, letting $V := V_1V_2$, we have
$$A - xx^T - V_1XX^TV_1^T = V_1V_2BV_2^TV_1^T = VBV^T\,.\qquad\text{(F.30)}$$
Notice that $xx^T + V_1XX^TV_1^T = [x\ \ V_1X]\begin{bmatrix}x^T\\X^TV_1^T\end{bmatrix}$, so with $Z^T := [x\ \ V_1X]\in\mathbb{R}^{n\times r}$ we can write
$$A - Z^TZ = V_1V_2BV_2^TV_1^T = VBV^T\,.$$
It remains to show that $\lambda_r(ZZ^T)$ and $2\lambda_r(ZAZ^T)$ have exactly the stated expressions. Since $A - xx^T = V_1DV_1^T$, we have
$$\|x\|^2 = \operatorname{tr}(xx^T) = \operatorname{tr}(A - V_1DV_1^T) = \operatorname{tr}(A) - \operatorname{tr}(D)\,.\qquad\text{(F.31)}$$
Moreover, taking the trace on both sides of $(A - xx^T)^2 = (V_1DV_1^T)^2$ yields
$$\operatorname{tr}(A^2) - 2x^TAx + \|x\|^4 = \operatorname{tr}(D^2)\,,$$
from which we have
$$2x^TAx = \operatorname{tr}(A^2) - \operatorname{tr}(D^2) + \|x\|^4 = \operatorname{tr}(A^2) - \operatorname{tr}(D^2) + \big(\operatorname{tr}(A) - \operatorname{tr}(D)\big)^2\,.\qquad\text{(F.32)}$$
Finally, notice that the first diagonal entry of
$$ZZ^T = \begin{bmatrix}x^T\\X^TV_1^T\end{bmatrix}[x\ \ V_1X] = \begin{bmatrix}\|x\|^2 & x^TV_1X\\X^TV_1^Tx & X^TX\end{bmatrix}$$
is $\|x\|^2$; we have, by (Horn & Johnson, 2012, Corollary 4.3.34),
$$\lambda_r(ZZ^T) \le \|x\|^2 = \operatorname{tr}(A) - \operatorname{tr}(D) = \operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\,.$$
Since we have already shown in Lemma F.6 that $\lambda_r(ZZ^T)\ge\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)$, we must have the exact equality $\lambda_r(ZZ^T) = \operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)$.
Similarly, the first diagonal entry of
$$ZAZ^T = \begin{bmatrix}x^T\\X^TV_1^T\end{bmatrix}A\,[x\ \ V_1X] = \begin{bmatrix}x^TAx & x^TAV_1X\\X^TV_1^TAx & X^TV_1^TAV_1X\end{bmatrix}$$
is $x^TAx$; then we have, by (Horn & Johnson, 2012, Corollary 4.3.34),
$$2\lambda_r(ZAZ^T) \le 2x^TAx = \operatorname{tr}(A^2) - \operatorname{tr}\big((A\wedge_rB)^2\big) + \big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big)^2\,.$$
Again, Lemma F.6 shows the inequality in the opposite direction, hence equality must hold:
$$2\lambda_r(ZAZ^T) = 2x^TAx = \operatorname{tr}(A^2) - \operatorname{tr}\big((A\wedge_rB)^2\big) + \big(\operatorname{tr}(A) - \operatorname{tr}(A\wedge_rB)\big)^2\,.$$

G SIMPLIFICATION OF THE BOUND IN THEOREM 2 UNDER UNIMODALITY ASSUMPTION

Consider weights $\{W_1, W_2, W_3\}$ with unimodality index $l^*$; there are three cases.

Case $l^* = 1$ ($D_{21}\succeq0$, $D_{23}\preceq0$): the definiteness of the imbalance matrices puts rank constraints on the weight matrices. Since $W_2^TW_2 - W_3W_3^T = D_{23}\preceq0$, we have $\operatorname{rank}(W_3W_3^T)\le m$ and
$$\lambda_i(D_{23})\ \begin{cases}= 0\,, & 1\le i < h_2-m+1\,,\\ \le 0\,, & h_2-m+1\le i\le h_2\,,\end{cases}\qquad \lambda_i(D_{21})\ \begin{cases}\ge 0\,, & 1\le i\le m\,,\\ = 0\,, & m+1\le i\le h_1\,.\end{cases}$$
In this case,
$$\hat{D}_{h_1} = \operatorname{diag}\{\max\{\lambda_i(D_{21}),0\}\}_{i=1}^{h_1} = \operatorname{diag}\{\lambda_i(D_{21})\}_{i=1}^{h_1}\,,\qquad \hat{D}_{h_2} = \operatorname{diag}\{\max\{\lambda_i(D_{21}),0\}\}_{i=1}^{h_2}\,,$$
so that $\hat{D}_{h_1}\wedge_nD_{21} = \hat{D}_{h_1}$ (hence $\Delta_{21} = 0$) and
$$\big(\hat{D}_{h_2}\wedge_mD_{23}\big)_{ii} = \begin{cases}\lambda_i(D_{21})\,, & 1\le i\le m-1\,,\\ 0\,, & m\le i < h_2\,,\\ \lambda_{h_2+1-m}(D_{23})\,, & i = h_2\,.\end{cases}$$
When both imbalance matrices are negative semi-definite, we instead have $\hat{D}_{h_1} = 0$, $\hat{D}_{h_2} = 0$, and
$$\big(\hat{D}_{h_1}\wedge_nD_{21}\big)_{ii} = \begin{cases}0\,, & 1\le i < h_1\,,\\ \lambda_{h_1+1-n}(D_{21})\,, & i = h_1\,,\end{cases}\qquad \big(\hat{D}_{h_2}\wedge_mD_{23}\big)_{ii} = \begin{cases}0\,, & 1\le i < h_2\,,\\ \lambda_{h_2+1-m}(D_{23})\,, & i = h_2\,,\end{cases}$$
so that $\Delta_{21} = -\lambda_{h_1+1-n}(D_{21})$, $\Delta_{23} = -\lambda_{h_2+1-m}(D_{23})$, $\Delta^{(2)}_{21} = -\lambda^2_{h_1+1-n}(D_{21})$, $\Delta^{(2)}_{23} = -\lambda^2_{h_2+1-m}(D_{23})$. Therefore
$$2\Delta_{21}\Delta_{23} = 2\big(-\lambda_{h_1+1-n}(D_{21})\big)\big(-\lambda_{h_2+1-m}(D_{23})\big)\,,\qquad \Delta^2_{21} + \Delta^{(2)}_{21} = \Delta^2_{23} + \Delta^{(2)}_{23} = 0\,.$$

H PROOFS FOR DEEP MODELS

We prove Theorem 3 in two parts: first, we prove the lower bound under the unimodality assumption in Section H.1; then we show the bound for weights with homogeneous imbalance in Section H.2.

H.1 LOWER BOUND ON $\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big)$ UNDER UNIMODALITY

We need the following two lemmas (proofs in Section H.3).

Lemma 4. Given $A\in\mathbb{R}^{n\times h}$, $B\in\mathbb{R}^{h\times m}$, and $D = A^TA - BB^T\in\mathbb{R}^{h\times h}$, if $\operatorname{rank}(A)\le r$ and $D\succeq0$, then:
1. $\operatorname{rank}(B)\le r$ and $\operatorname{rank}(D)\le r$;
2. there exists $Q\in\mathbb{R}^{h\times r}$ with $Q^TQ = I_r$ such that
$$AQQ^TB = AB\,,\quad AQQ^TA^T = AA^T\,,\quad B^TQQ^TB = B^TB\,,$$
and $\lambda_i(Q^TDQ) = \lambda_i(D)$, $i = 1,\dots,r$.

Lemma 5. For $W_1\in\mathbb{R}^{n\times h_1}$, $W_2\in\mathbb{R}^{h_1\times h_2}$, ..., $W_{L-1}\in\mathbb{R}^{h_{L-2}\times h_{L-1}}$ and $W_L\in\mathbb{R}^{h_{L-1}\times h_L}$ such that
$$D_l = W_l^TW_l - W_{l+1}W_{l+1}^T\succeq0\,,\quad l = 1,\dots,L-1\,,$$
we have
$$\lambda_n\big(W_1W_2\cdots W_{L-1}W_{L-1}^T\cdots W_2^TW_1^T\big) \ge \prod_{i=1}^{L-1}\sum_{l=i}^{L-1}\lambda_n(D_l)\,.$$

Then we can prove the following.

Theorem H.1. For weights $\{W_l\}_{l=1}^L$ with unimodality index $l^*$, we have
$$\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big) \ge \prod_{i=1}^{L-1}d(i)\,.\qquad\text{(H.33)}$$

Proof. Recall that
$$\mathcal{T}_{\{W_l\}_{l=1}^L}E = \sum_{l=1}^L\Big(\prod_{i=1}^{l-1}W_i\Big)\Big(\prod_{i=1}^{l-1}W_i\Big)^TE\Big(\prod_{i=l+1}^{L+1}W_i\Big)^T\Big(\prod_{i=l+1}^{L+1}W_i\Big)\,,\qquad W_0 = I_n,\ W_{L+1} = I_m\,.$$
For simplicity, define the p.s.d. operators
$$\mathcal{T}_lE := \Big(\prod_{i=1}^{l-1}W_i\Big)\Big(\prod_{i=1}^{l-1}W_i\Big)^TE\Big(\prod_{i=l+1}^{L+1}W_i\Big)^T\Big(\prod_{i=l+1}^{L+1}W_i\Big)\,,\quad l = 1,\dots,L\,,$$
so that $\mathcal{T}_{\{W_l\}_{l=1}^L} = \sum_{l=1}^L\mathcal{T}_l$. When $l^* = L$, we have, by Lemma 5,
$$\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big) \ge \lambda_{\min}(\mathcal{T}_L) = \lambda_n\big(W_1\cdots W_{L-1}W_{L-1}^T\cdots W_1^T\big) \ge \prod_{i=1}^{L-1}\sum_{l=i}^{L-1}\lambda_n(D_l) = \prod_{i=1}^{L-1}d(i)\,.$$
When $l^* = 1$, we have, again by Lemma 5,
$$\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big) \ge \lambda_{\min}(\mathcal{T}_1) = \lambda_m\big(W_L^T\cdots W_2^TW_2\cdots W_L\big) \ge \prod_{i=1}^{L-1}\sum_{l=i}^{L-1}\lambda_m(-D_{L-l}) = \prod_{i=1}^{L-1}\sum_{l=1}^{L-i}\lambda_m(-D_l) = \prod_{i=1}^{L-1}\sum_{l=1}^{i}\lambda_m(-D_l) = \prod_{i=1}^{L-1}d(i)\,.$$
(To see that Lemma 5 applies to the case $l^* = 1$, consider the replacement
$$W_L^T\to W_1\,,\ \dots\,,\ W_{L-l+1}^T\to W_l\,,\ \dots\,,\ W_1^T\to W_L\,,$$
which naturally leads to $-D_{L-l}\to D_l$; the expressions on the right-hand side of the arrows are those appearing in Lemma 5.) Now for unimodality index $1 < l^* < L$, we have
$$\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big) \ge \lambda_{\min}(\mathcal{T}_{l^*}) = \lambda_n\big(W_1\cdots W_{l^*-1}W_{l^*-1}^T\cdots W_1^T\big)\,\lambda_m\big(W_L^T\cdots W_{l^*+1}^TW_{l^*+1}\cdots W_L\big)\,.$$
Now applying Lemma 5 to both $\{W_1,\dots,W_{l^*-1},W_{l^*}\}$ and $\{W_L^T,\dots,W_{l^*+1}^T,W_{l^*}^T\}$, we have
$$\lambda_n\big(W_1\cdots W_{l^*-1}W_{l^*-1}^T\cdots W_1^T\big) \ge \prod_{i=1}^{l^*-1}\sum_{l=i}^{l^*-1}\lambda_n(D_l) = \prod_{i=1}^{l^*-1}d(i)\,,\qquad\text{(H.34)}$$
and
$$\lambda_m\big(W_L^T\cdots W_{l^*+1}^TW_{l^*+1}\cdots W_L\big) \ge \prod_{i=1}^{L-l^*}\sum_{l=i}^{L-l^*}\lambda_m(-D_{L-l}) = \prod_{i=1}^{L-l^*}\sum_{l=l^*}^{L-i}\lambda_m(-D_l) = \prod_{i=l^*}^{L-1}\sum_{l=l^*}^{i}\lambda_m(-D_l) = \prod_{i=l^*}^{L-1}d(i)\,.\qquad\text{(H.35)}$$
Combining (H.34) and (H.35), we have
$$\lambda_n\big(W_1\cdots W_{l^*-1}W_{l^*-1}^T\cdots W_1^T\big)\,\lambda_m\big(W_L^T\cdots W_{l^*+1}^TW_{l^*+1}\cdots W_L\big) \ge \prod_{i=1}^{L-1}d(i)\,.$$

H.2 LOWER BOUND UNDER HOMOGENEOUS IMBALANCE

We need the following lemma.

Lemma H.2. For scalars $\{w_l\}_{l=1}^L$ with $d^{(i)} := w_i^2 - w_L^2 \ge 0$, $i = 1,\dots,L-1$, we have
$$\sum_{l=1}^L\prod_{i\ne l}w_i^2 = \sum_{l=1}^L\frac{w^2}{w_l^2} \ge \sqrt{\Big(\sum_{i=1}^{L-1}d^{(i)}\Big)^2 + \big(Lw^{2-2/L}\big)^2}\,,\qquad\text{(H.37)}$$
where $w = \prod_{l=1}^Lw_l$. Then we can prove the following.

Theorem H.3. For weights $\{W_l\}_{l=1}^L$ with homogeneous imbalance, we have
$$\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big) \ge \sqrt{\Big(\sum_{i=1}^{L-1}d(i)\Big)^2 + \big(L\sigma_{\min}^{2-2/L}(W)\big)^2}\,,\qquad W = \prod_{l=1}^LW_l\,.\qquad\text{(H.38)}$$

Proof. When all imbalance matrices are zero, this is the balanced case (Arora et al., 2018b) and $\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big) = L\sigma_{\min}^{2-2/L}(W)$. Here we only prove the case when some $d_l\ne0$. Notice that given the homogeneous imbalance constraint $W_l^TW_l - W_{l+1}W_{l+1}^T = d_lI$, the matrices $W_l^TW_l$ and $W_{l+1}W_{l+1}^T$ must be co-diagonalizable: if $Q$ with $Q^TQ = I$ is such that $Q^TW_l^TW_lQ$ is diagonal, then $Q^TW_{l+1}W_{l+1}^TQ$ must be diagonal as well, since $Q^TW_l^TW_lQ - Q^TW_{l+1}W_{l+1}^TQ = d_lI$. Moreover, if the diagonal entries of $Q^TW_l^TW_lQ$ are in decreasing order, then so are those of $Q^TW_{l+1}W_{l+1}^TQ$, because the latter is the former shifted by $d_l$. This shows that all $W_l$, $l = 1,\dots,L$, have the same rank, and one has the following decomposition of the weights:
$$W_l = Q_{l-1}\Sigma_lQ_l^T\,,\qquad Q_l\in\mathbb{R}^{h_l\times\min\{n,m\}}\ \text{with}\ Q_l^TQ_l = I\quad(h_0 = n,\ h_L = m)\,.$$
From such a decomposition, we have
$$W = W_1\cdots W_L = Q_0\Sigma_1Q_1^TQ_1\Sigma_2Q_2^T\cdots Q_{L-1}\Sigma_LQ_L^T = Q_0\Big(\prod_{l=1}^L\Sigma_l\Big)Q_L^T\,,\qquad\text{(H.40)}$$
thus
$$\sigma_{\min}(W) = \prod_{l=1}^L\lambda_{\min}(\Sigma_l)\,.\qquad\text{(H.41)}$$
Regarding the imbalance, we have
$$Q_l^T\big(W_l^TW_l - W_{l+1}W_{l+1}^T\big)Q_l = d_lI\ \Rightarrow\ \Sigma_l^2 - \Sigma_{l+1}^2 = d_lI\,,\qquad\text{(H.42)}$$
which shows that
$$\lambda^2_{\min}(\Sigma_l) - \lambda^2_{\min}(\Sigma_{l+1}) = d_l\,,\quad l = 1,\dots,L-1\,.\qquad\text{(H.43)}$$
Now consider the set of scalars $\{w_l\}_{l=1}^L$:
$$w_l = \lambda_{\min}(\Sigma_l)\,,\ l = 1,\dots,l^*-1\,;\qquad w_l = \lambda_{\min}(\Sigma_{l+1})\,,\ l = l^*,\dots,L-1\,;\qquad w_L = \lambda_{\min}(\Sigma_{l^*})\,.$$
Then $\{w_l\}_{l=1}^L$ satisfies the assumption of Lemma H.2:
$$w_i^2 - w_L^2 = d(i) \ge 0\,,\quad i = 1,\dots,L-1\,,\qquad\text{(H.44)}$$
where $d(i)$ is precisely the cumulative imbalance. Then Lemma H.2 gives (using (H.41))
$$\sum_{l=1}^L\prod_{i\ne l}w_i^2 \ge \sqrt{\Big(\sum_{i=1}^{L-1}d(i)\Big)^2 + \big(L\sigma_{\min}^{2-2/L}(W)\big)^2}\,.\qquad\text{(H.45)}$$
Recall that $\mathcal{T}_{\{W_l\}_{l=1}^L} = \sum_{l=1}^L\mathcal{T}_l$ with the p.s.d. operators $\mathcal{T}_l$ defined in the proof of Theorem H.1. Notice that each summand $\prod_{i\ne l}w_i^2$ exactly corresponds to one of the $\lambda_{\min}(\mathcal{T}_l)$. For example,
$$\lambda_{\min}(\mathcal{T}_1) = \lambda_{\min}\big(W_L^T\cdots W_2^TW_2\cdots W_L\big) = \lambda_{\min}\Big(Q_L\prod_{l=2}^L\Sigma_l^2\,Q_L^T\Big) = \prod_{i\ne1}w_i^2\,.\qquad\text{(H.46)}$$
More precisely, we have
$$\lambda_{\min}(\mathcal{T}_l) = \prod_{i\ne l}w_i^2\,,\ l < l^*\,;\qquad \lambda_{\min}(\mathcal{T}_l) = \prod_{i\ne l-1}w_i^2\,,\ l > l^*\,;\qquad \lambda_{\min}(\mathcal{T}_{l^*}) = \prod_{i\ne L}w_i^2\,.$$
Therefore, we finally have
$$\lambda_{\min}\big(\mathcal{T}_{\{W_l\}_{l=1}^L}\big) \ge \sum_{l=1}^L\lambda_{\min}(\mathcal{T}_l) = \sum_{l=1}^L\prod_{i\ne l}w_i^2 \ge \sqrt{\Big(\sum_{i=1}^{L-1}d(i)\Big)^2 + \big(L\sigma_{\min}^{2-2/L}(W)\big)^2}\,.$$

H.3 PROOFS OF THE AUXILIARY LEMMAS

Proof of Lemma 5. First consider the case $n = h_1 = h_2 = \cdots = h_{L-1}$. We have
$$\lambda_n\big(W_1W_2\cdots W_{L-1}W_{L-1}^T\cdots W_2^TW_1^T\big) \ge \lambda_n(W_{L-1}W_{L-1}^T)\cdot\lambda_n\big(W_1W_2\cdots W_{L-2}W_{L-2}^T\cdots W_2^TW_1^T\big)$$
$$\ge \lambda_n(W_{L-1}W_{L-1}^T)\cdot\lambda_n(W_{L-2}W_{L-2}^T)\cdot\lambda_n\big(W_1W_2\cdots W_{L-3}W_{L-3}^T\cdots W_2^TW_1^T\big) \ge \cdots \ge \prod_{i=1}^{L-1}\lambda_n(W_iW_i^T)\,.$$
It then remains to show that $\lambda_n(W_iW_i^T) \ge \sum_{l=i}^{L-1}\lambda_n(D_l)$. Therefore, we only need to show $\lambda_n(W_{L-1}W_{L-1}^T) \ge \lambda_n(D_{L-1})$; the rest follows by the induction above. Indeed,
$$\lambda_n(W_{L-1}W_{L-1}^T) = \lambda_n(W_{L-1}^TW_{L-1}) = \lambda_n(W_LW_L^T + D_{L-1}) \ge \lambda_n(D_{L-1})\,,$$
which finishes the proof for the case $n = h_1 = h_2 = \cdots = h_{L-1}$.
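The equal-dimension case of Lemma 5 just proved can be checked numerically by constructing weights with prescribed PSD imbalances (a hypothetical construction for illustration; `psd_sqrt` is a helper we introduce here, and all sizes are illustrative):

```python
import numpy as np

def psd_sqrt(M):
    """Symmetric PSD square root via eigendecomposition."""
    w, V = np.linalg.eigh(M)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

rng = np.random.default_rng(7)
n, m, L = 4, 3, 4                       # equal hidden dimensions: h_l = n
W = [None] * L
D = []
W[L - 1] = rng.normal(size=(n, m))      # last layer is n x m
for l in range(L - 2, -1, -1):
    G = rng.normal(size=(n, n))
    D_l = G @ G.T / n                   # a PSD imbalance matrix
    # enforce W_l^T W_l = W_{l+1} W_{l+1}^T + D_l by taking a symmetric root
    W[l] = psd_sqrt(W[l + 1] @ W[l + 1].T + D_l)
    D.insert(0, D_l)

P = np.eye(n)
for l in range(L - 1):                  # P = W_1 W_2 ... W_{L-1}
    P = P @ W[l]
lam = np.linalg.eigvalsh(P @ P.T)[0]    # lambda_n of the product Gram matrix
# Lemma 5 bound: prod_i sum_{l >= i} lambda_n(D_l)
bound = np.prod([sum(np.linalg.eigvalsh(D[l])[0] for l in range(i, L - 1))
                 for i in range(L - 1)])
```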
When the above assumption does not hold, Lemma 4 allows us to relate the set of weights {W_l}_{l=1}^L to a set {W̃_l}_{l=1}^L that satisfies the equal-dimension assumption. More specifically, apply Lemma 4 to each imbalance constraint

D_l = W_l^T W_l - W_{l+1} W_{l+1}^T ⪰ 0,  l = 1, ..., L-1,

to obtain a Q_l ∈ R^{h_l × n} that has all the properties in Lemma 4. Use these Q_l, l = 1, ..., L-1, to define

W̃_l = Q_{l-1}^T W_l Q_l,  l = 1, ..., L,   D̃_l = W̃_l^T W̃_l - W̃_{l+1} W̃_{l+1}^T,  l = 1, ..., L-1,

where Q_0 = I, Q_L = I. Now {W̃_l}_{l=1}^L satisfies the assumption n = h_1 = ⋯ = h_{L-1}, so

λ_n(W̃_1 W̃_2 ⋯ W̃_{L-1} W̃_{L-1}^T ⋯ W̃_2^T W̃_1^T) ≥ ∏_{i=1}^{L-1} Σ_{l=i}^{L-1} λ_n(D̃_l).  (H.48)

Using the properties of Q_l ∈ R^{h_l × n}, l = 1, ..., L-1, we have

λ_n(W̃_1 W̃_2 ⋯ W̃_{L-1} W̃_{L-1}^T ⋯ W̃_2^T W̃_1^T) = λ_n(W_1 Q_1 Q_1^T W_2 Q_2 ⋯ Q_{L-2}^T W_{L-1} Q_{L-1} Q_{L-1}^T W_{L-1}^T Q_{L-2} ⋯ Q_2^T W_2^T Q_1 Q_1^T W_1^T) = λ_n(W_1 W_2 ⋯ W_{L-1} W_{L-1}^T ⋯ W_2^T W_1^T),

and

∏_{i=1}^{L-1} Σ_{l=i}^{L-1} λ_n(D̃_l) = ∏_{i=1}^{L-1} Σ_{l=i}^{L-1} λ_n(Q_l^T D_l Q_l) = ∏_{i=1}^{L-1} Σ_{l=i}^{L-1} λ_n(D_l).

Therefore, (H.48) is exactly

λ_n(W_1 W_2 ⋯ W_{L-1} W_{L-1}^T ⋯ W_2^T W_1^T) ≥ ∏_{i=1}^{L-1} Σ_{l=i}^{L-1} λ_n(D_l).  (H.49)

Proof of Lemma 4. Since rank(A) ≤ r, A has a compact SVD A = P Σ_A Q^T with Q ∈ R^{h×r} and Q^T Q = I_r. This is exactly the Q we are looking for. Let Q_⊥ Q_⊥^T = I_h - Q Q^T be the projection onto the subspace orthogonal to the columns of Q. Then

D = A^T A - B B^T  ⇒  Q_⊥^T D Q_⊥ = Q_⊥^T A^T A Q_⊥ - Q_⊥^T B B^T Q_⊥  ⇒  Q_⊥^T D Q_⊥ + Q_⊥^T B B^T Q_⊥ = 0.

Here Q_⊥^T D Q_⊥ and Q_⊥^T B B^T Q_⊥ are two p.s.d. matrices whose sum is zero, which implies

Q_⊥^T D Q_⊥ = 0,  D Q_⊥ = 0,  Q_⊥^T B B^T Q_⊥ = 0,  B^T Q_⊥ = 0.

Q_⊥^T D Q_⊥ = 0 shows that the nullspace of D has dimension at least h - r, i.e., rank(D) ≤ r. Moreover,

A Q Q^T B = A(I_h - Q_⊥ Q_⊥^T)B = AB,   A Q Q^T A^T = A(I_h - Q_⊥ Q_⊥^T)A^T = A A^T,   B^T Q Q^T B = B^T (I_h - Q_⊥ Q_⊥^T) B = B^T B.

The last equality, B^T B = B^T Q Q^T B, shows that rank(B) ≤ r.
Lastly, we have, for i = 1, ..., r,

λ_i(Q^T D Q) = λ_i(Q Q^T D) = λ_i((I_h - Q_⊥ Q_⊥^T) D) = λ_i(D).

Before proving Lemma H.2, we state a lemma that will be used in the proof.

Proof. This follows from the fact that the arithmetic mean of {x_i}_{i=1}^n is no less than the geometric mean of {x_i}_{i=1}^n.

We are now ready to prove Lemma H.2.

Proof of Lemma H.2. We denote τ({w_l}_{l=1}^L) := Σ_{l=1}^L ∏_{i≠l} w_i² and write w_i² = w_L² + d^{(i)}. Therefore, when fixing {d^{(i)}}_{i=1}^{L-1}, τ can be viewed as a function of w_L².

When w = 0: one of the w_l must be zero, and because w_L² has the least value among all the weights, we know w_L² = 0. Then

τ({w_l}_{l=1}^L) = τ(0; {d^{(i)}}_{i=1}^{L-1}) = ∏_{i=1}^{L-1} d^{(i)},

i.e., we actually have equality when w = 0.

When w ≠ 0: then w_L² ≠ 0, and the derivative of p (defined below) satisfies

dp/dw_L² = Σ_{l=1}^L ∏_{i≠l} (w_L² + d^{(i)}) = τ ≥ L p^{(L-1)/L} > 0,

and the inverse function theorem (Rudin, 1953) shows the existence of a differentiable inverse. Whenever p^{-1} exists, its derivative is

dw_L²/dp = ( Σ_{l=1}^L ∏_{i≠l} (w_L² + d^{(i)}) )^{-1} = τ^{-1}.

Now pick any 0 < p_0 ≤ w²; we have, by the Fundamental Theorem of Calculus,
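A numerical sanity check of the Lemma H.2-type inequality is straightforward. The script below is ours, not the paper's; the exact constants follow our reading of the garbled display (H.37), namely Σ_l ∏_{i≠l} w_i² ≥ sqrt((∏_i d^{(i)})² + (L w^{2-2/L})²) with d^{(i)} = w_i² - w_L² and w = ∏_l w_l, and w_L the smallest weight:

```python
import math
import random

random.seed(1)

def lemma_h2_gap(w):
    """LHS minus RHS of the (reconstructed) Lemma H.2 inequality.
    Requires w[-1] to be the smallest entry so that d_i >= 0."""
    L = len(w)
    prod_all = math.prod(x * x for x in w)
    lhs = sum(prod_all / (x * x) for x in w)          # sum_l prod_{i != l} w_i^2
    d_prod = math.prod(x * x - w[-1] ** 2 for x in w[:-1])
    rhs = math.sqrt(d_prod ** 2 + (L * math.prod(w) ** (2 - 2 / L)) ** 2)
    return lhs - rhs

# Random trials: the gap should never be (meaningfully) negative.
for _ in range(100):
    L = random.randint(2, 6)
    w = sorted((random.uniform(0.1, 3.0) for _ in range(L)), reverse=True)
    assert lemma_h2_gap(w) >= -1e-6
```

For L = 2 the inequality is an identity, sqrt((w_1² - w_2²)² + 4 w_1² w_2²) = w_1² + w_2², so the gap vanishes there, which matches the two-layer rate bound in the main text.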



Note that A2 assumes µ-strong convexity, which implies A1 with γ = 2µ. However, we list A1 and A2 separately since they play different roles in our analysis. In Min et al. (2022), there is no general recipe for lower bounding λ_min(T_{{W_1, W_2}}), but their analyses essentially provide such a bound. Nonlinear networks: While the crucial ingredient of our analysis, the invariance of the weight imbalance, no longer holds in the presence of nonlinearities such as ReLU activations, Du et al. (2018) show that the diagonal entries of the imbalance are preserved, and Le & Jegelka (2022) show a stronger version of this invariance given additional assumptions on the training trajectory. Therefore, the weight imbalance could still be used to understand the training of nonlinear networks.
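The imbalance invariance is easy to see in the simplest linear case. Below is a hypothetical scalar two-layer example (our illustration, not an experiment from the paper): forward-Euler gradient descent on L(w1, w2) = ½(w2·w1 − y)² approximately conserves the imbalance d = w1² − w2² while driving the loss to zero; under the exact gradient flow, d is conserved exactly since d/dt(w1² − w2²) = −2w1w2·e + 2w2w1·e = 0.

```python
# Scalar two-layer linear model, loss L = 0.5 * (w2 * w1 - y)^2.
# Forward-Euler discretization of gradient flow; the imbalance
# d = w1^2 - w2^2 drifts only at O(eta) over the whole trajectory.
y, w1, w2 = 1.0, 0.9, 0.3
eta, steps = 1e-3, 10000

d0 = w1**2 - w2**2                  # initial imbalance
for _ in range(steps):
    e = w2 * w1 - y                 # residual
    g1, g2 = w2 * e, w1 * e         # dL/dw1, dL/dw2
    w1, w2 = w1 - eta * g1, w2 - eta * g2

loss = 0.5 * (w2 * w1 - y) ** 2
d = w1**2 - w2**2
assert loss < 1e-6                  # flow converges to a global minimum
assert abs(d - d0) < 1e-2 * abs(d0) # imbalance is (approximately) conserved
```

The conserved d sets the local curvature seen by the product w1·w2, which is exactly how the imbalance enters the convergence rates in the main text.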



MODELS RELATED TO SCALAR DYNAMICS

D.1 SPECTRAL INITIALIZATION UNDER ℓ2 LOSS

The spectral initialization (Saxe et al., 2014; Gidel et al., 2019; Tarmoun et al., 2021) considers the following setting: suppose f = ½‖Y - XW‖_F² and we have the overparametrized model W = ∏_{l=1}^L W_l. Additionally, we assume Y ∈ R^{N×m} and X ∈ R^{N×n} (n ≥ m) are co-diagonalizable, i.e., there exist P ∈ R^{N×n} with P^T P = I_n and Q ∈ O(m), R ∈ O(n) such that we can write the SVDs of Y and X as Y =

Combining these two results gives λ_min(T_{{W_1, W_2, W_3}}) ≥ ∆*(D_21, D_23)/2, as stated in Theorem 2.

(F.22) follows from the fact that A ⪰_1 D and D ⪰_{r-1} B. Now examine the right-hand side carefully: the first component tr(A) - tr

Denote D := A ∧_r B; we show A ⪰_1 D and D ⪰_{r-1} B. The following statements hold for any index i ∈ [n]. First of all, we have

λ_i(D) = min{λ_i(A), λ_{i+1-r}(B)} ≤ λ_i(A),  (F.25)

and

λ_{i+1}(A) ≤ min{λ_i(A), λ_{i+1-r}(B)} = λ_i(D),  (F.26)

where λ_{i+1}(A) ≤ λ_{i+1-r}(B) follows from A ⪰_r B. Together, (F.25) and (F.26) show A ⪰_1 D.
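The spectral inequalities (F.25) and (F.26), together with the D ⪰_{r-1} B direction, can be verified numerically. In the sketch below (ours, not the paper's), an instance of A ⪰_r B is generated via a rank-r p.s.d. perturbation, as in the Weyl-based Lemma F.3, and the spectrum of D = A ∧_r B is formed entrywise:

```python
import numpy as np

rng = np.random.default_rng(4)

def eig_desc(M):
    # Eigenvalues of a symmetric matrix, in decreasing order
    return np.linalg.eigvalsh(M)[::-1]

def min_wedge_spectrum(a, b, r):
    """Spectrum of D = A ^_r B: lambda_i(D) = min{lambda_i(A), lambda_{i+1-r}(B)},
    with lambda_j(B) treated as +infinity for out-of-range indices j <= 0."""
    n = len(a)
    return np.array([min(a[i], b[i + 1 - r]) if i + 1 - r >= 0 else a[i]
                     for i in range(n)])

n, r = 6, 2
B = rng.standard_normal((n, n)); B = (B + B.T) / 2   # random symmetric B
X = rng.standard_normal((n, r))
A = B + X @ X.T                      # rank-r p.s.d. perturbation => A >=_r B
a, b, tol = eig_desc(A), eig_desc(B), 1e-10
d = min_wedge_spectrum(a, b, r)

# A >=_1 D: (F.25) lambda_i(D) <= lambda_i(A), (F.26) lambda_{i+1}(A) <= lambda_i(D)
assert np.all(d <= a + tol)
assert np.all(a[1:] <= d[:-1] + tol)
# D >=_{r-1} B: lambda_i(D) >= lambda_i(B) >= lambda_{i+r-1}(D)
assert np.all(d >= b - tol)
assert np.all(b[:n - r + 1] >= d[r - 1:] - tol)
```

Since d is formed as a minimum of two decreasing sequences, it is itself a valid (decreasing) spectrum, which is what makes the interlacing comparisons above well posed.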

and ∆_23 = λ_m(D_21) - λ_{h_2+1-m}(D_23) ... λ_m(D_21)(λ_m(D_21) - λ_{h_2+1-m}(D_23)).

l* = 3: D_23 ⪰ 0, D_21 ⪯ 0. Similar to the previous cases (by considering the unimodal weights {W_3^T, ...}), ... λ_r(D_23)(λ_n(D_23) - λ_{h_1+1-n}(D_21)).

l* = 2: D_23 ⪯ 0, D_21 ⪯ 0. D_23 and D_21 being negative semi-definite implies rank(D_21) ≤ n and rank(D_23) ≤ m.

(H.39) Here, Σ_l, l = 1, ..., L, are diagonal matrices of size k = min{n, m} whose entries are in decreasing order, and Q

PROOFS FOR AUXILIARY LEMMAS

Proof of Lemma 5. The proof is rather simple when n = h_1 = h_2 = ... = h_{L-1}: Notice that

λ_n(W_i W_i^T) ≥ Σ_{l=i}^{L-1} λ_n(D_l) for i = 1, ..., L-1. Suppose λ_n(W_k W_k^T) ≥ Σ_{l=k}^{L-1} λ_n(D_l) for some k ∈ [L-1]; then we have

λ_n(W_{k-1} W_{k-1}^T) = λ_n(W_{k-1}^T W_{k-1}) = λ_n(W_k W_k^T + D_{k-1}) ≥ λ_n(W_k W_k^T) + λ_n(D_{k-1}) ≥ Σ_{l=k}^{L-1} λ_n(D_l) + λ_n(D_{k-1}) =
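The induction can be checked numerically. The sketch below is ours, not the paper's: it assumes square layers and chooses each W_l as a symmetric matrix square root (one valid choice among many), builds a chain with prescribed p.s.d. imbalances, and verifies the Lemma 5 bound λ_n(W_1⋯W_{L-1} W_{L-1}^T⋯W_1^T) ≥ ∏_{i=1}^{L-1} Σ_{l=i}^{L-1} λ_n(D_l).

```python
import numpy as np

rng = np.random.default_rng(2)

def psd_sqrt(G):
    # Symmetric square root of a p.s.d. matrix via eigendecomposition
    vals, vecs = np.linalg.eigh(G)
    return vecs @ np.diag(np.sqrt(np.clip(vals, 0, None))) @ vecs.T

def random_psd(n):
    X = rng.standard_normal((n, n))
    return X @ X.T / n

# Build W_L, ..., W_1 (all n x n) backwards so that
# W_l^T W_l = W_{l+1} W_{l+1}^T + D_l with each D_l p.s.d.
n, L = 3, 4
W = [rng.standard_normal((n, n))]          # W_L
D = []
for _ in range(L - 1):
    D_l = random_psd(n)
    G = W[0] @ W[0].T + D_l                # desired Gram matrix W_l^T W_l
    W.insert(0, psd_sqrt(G))               # symmetric choice of W_l
    D.insert(0, D_l)                       # final order: D_1, ..., D_{L-1}

P = np.linalg.multi_dot(W[:L - 1])         # W_1 ... W_{L-1}
lam_min = np.linalg.eigvalsh(P @ P.T)[0]
bound = np.prod([sum(np.linalg.eigvalsh(D[l])[0] for l in range(i, L - 1))
                 for i in range(L - 1)])
assert lam_min >= bound - 1e-8
```

The two steps of the proof appear directly: λ_min is super-multiplicative over the chain, and each λ_n(W_i W_i^T) telescopes into a sum of λ_n(D_l) terms by Weyl's inequality.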

Lemma H.4. Given positive x_i, i = 1, ..., n, we have (1/n) Σ_{i=1}^n x_i ≥ (∏_{i=1}^n x_i)^{1/n}.

w_i² = w_L² + d^{(i)}. Letting d^{(L)} = 0, we write the expression for τ as

τ({w_l}_{l=1}^L) = Σ_{l=1}^L ∏_{i≠l} (w_L² + d^{(i)}) := τ(w_L²; {d^{(i)}}_{i=1}^{L-1}).

w² = ∏_{l=1}^L (w_L² + d^{(l)}) := p(w_L²; {d^{(i)}}_{i=1}^{L-1}),

which shows that w² is a function of w_L² when {d^{(i)}}_{i=1}^{L-1} are fixed. Here we use p to denote w² for simplicity. Moreover, the function p: R_{≥0} → R_{≥0} has a differentiable inverse p^{-1} as long as p > 0

τ²(p^{-1}(w²); {d^{(i)}}_{i=1}^{L-1}) = τ²(p^{-1}(p_0); {d^{(i)}}_{i=1}^{L-1}) + ... For the first part, we have τ²(p^{-1}(p_0); {d^{(i)}}_{i=1}^{L-1}) = ( Σ_{l=1}^L ∏_{i≠l} (p^{-1}(p_0) + d^{(i)}) )² ... (L w^{2-2/L})².

We compare our rate bound with prior work on deep networks.

implies rank(D_23) ≤ m. (D_23 can only have negative eigenvalues, if nonzero ones, and any negative eigenvalue is contributed by W_3 W_3^T.) rank(D_23) ≤ m and D_23 ⪯ 0 together imply rank(W_2^T W_2) ≤ m (a positive invariant subspace of W_2^T W_2 with dimension larger than m would give D_23 a positive eigenvalue). In turn, rank(W_2^T W_2) ≤ m forces rank(D_21) ≤ m (D_21 can only have positive eigenvalues, if nonzero ones, and any positive eigenvalue is contributed by W_2^T W_2). In summary, we have rank(D_23) ≤ m and rank(D_21) ≤ m, which implies,

H.2 LOWER BOUND ON λ_min(T_{{W_l}_{l=1}^L}) UNDER HOMOGENEOUS IMBALANCE

We need the following lemma (proof in Section H.3):

Lemma H.2. Given any set of scalars {w_l}_{l=1}^L such that d^{(i)}

E PROOF FOR TWO-LAYER MODEL

Using Lemma 3, we can prove Theorem 1.

Theorem 1 (Restated). Let D be the imbalance matrix for L = 2. The continuous dynamics in (3) satisfy

L(t) - L* ≤ exp(-α_2 γ t)(L(0) - L*),  ∀ t ≥ 0,  (E.2)

where: 1. if f satisfies only A1, then α_2 = ∆; 2. if f satisfies both A1 and A2, then α_2 = ..., with W* equal to the unique optimizer of f.

Proof. As shown in (5) in Section 2, we have ... Consider any {W_1(t), W_2(t)} on the trajectory; we have, by Lemma 3, ... When f also satisfies A2, we need to prove ... When n = m, both inequalities are equivalent to ..., which is true by Lemma A.1. When n ≠ m, one of the two inequalities becomes trivial. For example, if n > m, then (E.4) is trivially 0 ≥ 0, and (E.5) is equivalent to ..., which is true by Lemma A.1.

F.3 AUXILIARY LEMMAS

The main ingredient used in proving Theorem F.2 is the notion of an r-interlacing relation between the spectra of two matrices, which is a natural generalization of the interlacing relation seen in the classical Cauchy interlacing theorem (Horn & Johnson, 2012, Theorem 4.3.17).

Definition 4. Given real symmetric matrices A, B of order n, write A ⪰_r B if

λ_i(A) ≥ λ_i(B) ≥ λ_{i+r}(A),  ∀ i,

where λ_j(·) = +∞ for j ≤ 0 and λ_j(·) = -∞ for j > n. The case r = 1 gives the interlacing relation.

Claim. We only need to check the inequalities for indices in [n] when showing A ⪰_r B.

Proof. Any inequality regarding an index outside [n] is trivial.

The following lemma is a direct consequence of Weyl's inequality (Horn & Johnson, 2012, Theorem 4.3.1), and is stated as a special case of (Horn & Johnson, 2012, Corollary 4.3.3). The converse is also true:

Lemma F.4. Given real symmetric matrices A, B of order n, if A ⪰_r B, then there exist a positive semi-definite matrix XX^T with rank(XX^T) ≤ r and a real orthogonal matrix V such that ...

Proof. The case r = 1 is proved in (Horn & Johnson, 2012, Theorem 4.3.26). The case r > 1 is proved in (Wang & Zheng, 2019, Theorem 1.3) by induction.

Specifically for our problem, we also need the following (D̃_{h_1} and D̃_{h_2} are defined in Section 4):

Lemma F.5. Given an imbalance matrix pair (D_21, D_23) ∈ R^{h_1×h_1} × R^{h_2×h_2}, we have D̃_{h_1} ⪰_n D_21 and D̃_{h_2} ⪰_m D_23.

In our analysis, the weights W_1, W_2, W_3 are "constrained" by the imbalances D_21, D_23; such constraints lead to some special eigenvalue bounds (the operation ∧_r was defined in Section 4):

Lemma F.6. Given a positive semi-definite matrix A of order n and Z ∈ R^{r×n} with r ≤ n, ... and this bound is actually tight:

Lemma F.7. Given two real symmetric matrices A, B of order n, if A ⪰_r B (r ≤ n), then there exist Z ∈ R^{r×n} and some orthogonal matrix ...

Remark 3. To see how Lemma F.6 is used, let A = W_2 W_2^T, Z = W_1, and B = D_21; one obtains a lower bound on λ_r(W_1 W_1^T) that depends on the entire spectrum of W_2 W_2^T and D_21.
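The Weyl direction (Lemma F.3) is easy to test numerically. The sketch below is our own check, with the ⪰_r relation implemented as λ_i(A) ≥ λ_i(B) ≥ λ_{i+r}(A), our reading of the garbled Definition 4; a random symmetric matrix is perturbed by a rank-r p.s.d. term:

```python
import numpy as np

rng = np.random.default_rng(3)

def eig_desc(M):
    """Eigenvalues of a symmetric matrix in decreasing order."""
    return np.linalg.eigvalsh(M)[::-1]

def r_interlaces(A, B, r, tol=1e-10):
    """Check A >=_r B: lambda_i(A) >= lambda_i(B) >= lambda_{i+r}(A),
    with out-of-range indices treated as trivially satisfied."""
    n = A.shape[0]
    a, b = eig_desc(A), eig_desc(B)
    upper = all(a[i] >= b[i] - tol for i in range(n))
    lower = all(b[i] >= a[i + r] - tol for i in range(n - r))
    return upper and lower

n, r = 6, 2
B = rng.standard_normal((n, n)); B = (B + B.T) / 2   # random symmetric matrix
X = rng.standard_normal((n, r))
A = B + X @ X.T                 # rank-r p.s.d. perturbation of B
assert r_interlaces(A, B, r)    # the Weyl direction: A >=_r B
```

The converse direction (Lemma F.4) says that, up to an orthogonal conjugation, every pair with A ⪰_r B arises this way.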
This bound is strictly better than λ_r(W_2 W_2^T) - λ_1(D_21), the one obtained from Weyl's inequality (Horn & Johnson, 2012). This should not be surprising, because we have "more information" about W_2 W_2^T and D_21 (the entire spectrum vs. a single eigenvalue).

Constructing the optimal solution:

By Lemma F.5, we know D̃_{h_1} ⪰_n D_21, and by Lemma F.7 there exist ... (F.14) and, most importantly, ... Similarly, by Lemma F.5, we know D̃_{h_2} ⪰_m D_23, and by Lemma F.7 there exist Z_3 ∈ R^{m×h_2} and an orthogonal ... (F.16) and, most importantly, ..., and

