CRITICAL POINTS AND CONVERGENCE ANALYSIS OF GENERATIVE DEEP LINEAR NETWORKS TRAINED WITH BURES-WASSERSTEIN LOSS

Abstract

We consider a deep matrix factorization model of covariance matrices trained with the Bures-Wasserstein distance. While recent works have made substantial advances in the study of the optimization problem for overparametrized low-rank matrix approximation, much emphasis has been placed on discriminative settings and the square loss. In contrast, our model considers another interesting type of loss and connects with the generative setting. We characterize the critical points and minimizers of the Bures-Wasserstein distance over the space of rank-bounded matrices. For low-rank matrices the Hessian of this loss can blow up, which creates challenges for analyzing the convergence of optimization methods. We establish convergence results for gradient flow using a smooth perturbative version of the loss, and convergence results for finite step size gradient descent under certain assumptions on the initial weights.

1. INTRODUCTION

We investigate generative deep linear networks and their optimization using the Bures-Wasserstein distance. More precisely, we consider the problem of approximating a target Gaussian distribution with a deep linear neural network generator of Gaussian distributions by minimizing the Bures-Wasserstein distance. This problem is of interest in two important ways. First, it pertains to the optimization of deep linear networks for a type of loss that is qualitatively different from the well-studied and very particular square loss. Second, it can be regarded as a simplified but instructive instance of the parameter optimization problem in generative networks, specifically Wasserstein generative adversarial networks, which are currently not as well understood as discriminative networks. The optimization landscapes and the properties of parameter optimization procedures for neural networks are among the most puzzling and actively studied topics in theoretical deep learning (see, e.g., Mei et al., 2018; Liu et al., 2022). Deep linear networks, i.e., neural networks having the identity as activation function, serve as a simplified model for such investigations (Baldi & Hornik, 1989; Kawaguchi, 2016; Trager et al., 2020; Kohn et al., 2022; Bah et al., 2021). The study of linear networks has guided the development of several useful notions and intuitions in the theoretical analysis of neural networks, from the absence of bad local minima to the role of parametrization and overparametrization in gradient optimization (Arora et al., 2018; 2019a;b). Many previous works have focused on discriminative or autoregressive settings and have emphasized the square loss. Although the square loss is indeed a popular choice in regression tasks, it interacts in a very special way with the particular geometry of linear networks (Trager et al., 2020).
The behavior of linear networks optimized with different losses has also been considered in several works (Laurent & Brecht, 2018; Lu & Kawaguchi, 2017; Trager et al., 2020) but is less well understood. The Bures-Wasserstein distance was introduced by Bures (1969) to study Hermitian operators in quantum information, particularly density matrices. It induces a metric on the space of positive semi-definite matrices. The Bures-Wasserstein distance corresponds to the 2-Wasserstein distance between two centered Gaussian distributions (Bhatia et al., 2019). Wasserstein distances enjoy several properties, e.g., they remain well defined between disjointly supported measures and have duality formulations that allow for practical implementations (Villani, 2003), that make them good candidates and indeed popular choices of a loss for learning generative models, a well-known case being Wasserstein Generative Adversarial Networks (GANs) (Arjovsky et al., 2017). While the 1-Wasserstein distance has been most commonly used in this context, the Bures-Wasserstein distance has also attracted much interest, e.g. in the works of Muzellec & Cuturi (2018), Chewi et al. (2020), and Mallasto et al. (2022), and has also appeared in the context of linear quadratic Wasserstein generative adversarial networks (Feizi et al., 2020). A 2-Wasserstein GAN is a minimum 2-Wasserstein distance estimator expressed in Kantorovich duality (see details in Appendix B). This model can serve as an attractive platform to develop the theory, particularly when the inner problem can be solved in closed form. Such a formula is available when comparing pairs of Gaussian distributions; in the case of centered Gaussians it corresponds precisely to the Bures-Wasserstein distance. Strikingly, even in this simple case, the optimization properties of the corresponding problem are not well understood, which we aim to address in the present work.

1.1. CONTRIBUTIONS

We establish a series of results on the optimization of deep linear networks trained with the Bures-Wasserstein loss, which we can summarize as follows.
• We obtain an analogue of the Eckart-Young-Mirsky theorem characterizing the critical points and minimizers of the Bures-Wasserstein distance over matrices of a given rank (Theorem 4.2).
• To circumvent the non-smooth behaviour of the Bures-Wasserstein loss when the matrices drop rank, we introduce a smooth perturbative version (Definition 5 and Lemma 3.3), characterize its critical points and minimizers over rank-constrained matrices (Theorem 4.4), and link them to the critical points on the parameter space (Proposition 4.5).
• For the smooth Bures-Wasserstein loss, in Theorem 5.6 we show exponential convergence of the gradient flow assuming balanced initial weights (Definition 2.1) and a deficiency margin condition (Definition 5.2).
• For the Bures-Wasserstein loss and its smooth version, in Theorem 5.7 we show convergence of gradient descent provided the step size is small enough and assuming balanced initial weights.

1.2. RELATED WORKS

Low-rank matrix approximation
The function space of a linear network corresponds to n × m matrices of rank at most d, the lowest width of the network. Hence optimization in function space is closely related to the problem of approximating a given data matrix by a low-rank matrix. When the approximation error is measured in the Frobenius norm, Eckart & Young (1936) characterized the optimal bounded-rank approximation of a given matrix in terms of its singular value decomposition. Mirsky (1960) obtained the same characterization for the more general case of unitarily invariant matrix norms, which include the Euclidean operator norm and the Schatten-p norms. There are generalizations to certain weighted norms (Ruben & Zamir, 1979; Dutta & Li, 2017). However, for general norms the problem is known to be difficult (Song et al., 2017; Gillis & Vavasis, 2018; Gillis & Shitov, 2019).

Loss landscape of deep linear networks
For the square loss, the optimization landscape of linear networks has been studied in numerous works. The pioneering work of Baldi & Hornik (1989) showed, focusing on the two-layer case, that there is a single minimum (up to a trivial parametrization symmetry) and all other critical points are saddle points. Kawaguchi (2016) obtained corresponding results for deep linear networks and showed the existence of bad saddles (with no negative Hessian eigenvalues) for networks with more than three layers. Yun et al. (2018) found sets of parameters such that any critical point in this set is a global minimum and any critical point outside is a saddle. Variations include the study of critical points for different types of architectures, such as deep linear residual networks (Hardt & Ma, 2017) and deep linear convolutional networks (Kohn et al., 2022). For losses different from the square loss there are also several results.
Laurent & Brecht (2018) showed that deep linear networks with no bottlenecks have no local minima that are not global for arbitrary convex differentiable losses. Lu & Kawaguchi (2017) showed that if the loss is such that any local minimizer in parameter space can be perturbed to an equally good minimizer with full-rank factor matrices, then all local minima in parameter space are local minima in function space. Trager et al. (2020) found that for linear networks with an arbitrarily rank-constrained function space, only for the square loss does one have the nonexistence of non-global local minima. However, for arbitrary convex losses, non-global local minima, when they exist, are always pure, meaning that they correspond to local minima in function space.

Optimization dynamics of deep linear networks
Saxe et al. (2014) studied the learning dynamics of deep linear networks under different classes of initial conditions. Arora et al. (2019b) obtained a closed-form expression for the parametrization along time in a deep linear network for the square loss. Notably, the authors found that solutions with a lower rank are preferred as the depth of the network increases. Arora et al. (2018) derived several invariances of the flow and compared the dynamics in parameter and function space. For the square loss, Arora et al. (2019a) proved linear convergence for linear networks with no bottlenecks, approximately balanced initial weights, and initial loss smaller than that of any rank-deficient solution. A detailed analysis of the dynamics in the shallow case with the square loss, including symmetric factorization, was conducted by Tarmoun et al. (2021) and Min et al. (2021), where the role of imbalance was also highlighted. For the deep case, also focusing on the square loss, Bah et al. (2021) showed that the gradient flow converges to a critical point and to a global minimizer on the manifold of fixed-rank matrices of some rank. More recently, Nguegnang et al. (2021) extended this analysis to obtain corresponding results for gradient descent.

Bures-Wasserstein distance
Chewi et al. (2020) studied the convergence of gradient descent algorithms for the Bures-Wasserstein barycenter, proving linear rates of convergence for gradient descent. In contrast to our work, they use a Polyak-Łojasiewicz inequality derived from optimal transport theory to circumvent the lack of geodesic convexity of the barycenter problem. In the same vein, Muzellec & Cuturi (2018) exploit properties of optimal transport to optimize the distance between two elliptical distributions. To avoid rank deficiency, they perturb the diagonal elements of the covariance matrix by a small parameter. Feizi et al. (2020) characterized the optimal solution of a 2-Wasserstein GAN with rank-k linear generator as the k-PCA solution. We will obtain an analogous result in our setting, along with a description of the critical points.

1.3. NOTATIONS

We adopt the following notation throughout the paper. For any n ∈ N, denote [n] = {1, 2, . . . , n}. Let S(n) be the space of real symmetric matrices of size n. We denote by S + (n) (resp. S ++ (n)) the space of real symmetric positive semi-definite (resp. positive definite) matrices of size n. Given k ⩽ n, the set of rank-k positive semi-definite matrices is denoted by S + (k, n). We use M k (resp. M ⩽k ) to denote the set of matrices of rank exactly k (resp. of rank at most k), with the format of the matrix understood from context. The scalar product between two matrices A, B ∈ R n×m is ⟨A, B⟩ = tr A ⊤ B, and its associated (Frobenius) norm is denoted ∥•∥ F . The identity matrix of size n is denoted I n , or I when n is clear from context. For a (Fréchet) differentiable function f : X → Y , we denote its differential at x ∈ X in the direction v by df (x)[v]. Finally, Crit(f ) denotes the set of critical points of f , i.e., the set of points at which the differential of f is 0.

2. LINEAR NETWORKS AND THEIR GRADIENT DYNAMICS

We consider a linear network with d 0 inputs and N layers of widths d 1 , . . . , d N , which is a model of linear functions of the form x → W N • • • W 1 x, parametrized by the weight matrices W j ∈ R dj×dj−1 , j ∈ [N ]. We will denote the tuple of weight matrices by -→ W = (W 1 , . . . , W N ) and the space of all such tuples by Θ. This is the parameter space of our model. To slightly simplify the notation we will also denote the input and output dimensions by m = d 0 and n = d N , respectively, and write W = W N • • • W 1 for the end-to-end matrix. For any 1 ⩽ i ⩽ j ⩽ N , we will also write W j:i = W j • • • W i for the matrix product of layers i up to j. We note that the represented function is linear in the network input x, but the parametrization is not linear in the parameters -→ W . We denote the parametrization map by µ : Θ → R dN×d0 ; -→ W = (W 1 , . . . , W N ) → W N:1 = W N • • • W 1 . The function space of the network, i.e., the set of linear functions it can represent, corresponds to the set of possible end-to-end matrices, which are the n × m matrices of rank at most d := min{d 0 , . . . , d N }. When d = min{d 0 , d N }, the function space is a vector space, but otherwise, when there is a bottleneck so that d < min{d 0 , d N }, it is a non-convex subset of R n×m determined by polynomial constraints, namely the vanishing of the (d + 1) × (d + 1) minors. Next we collect a few results on the gradient dynamics of linear networks for general differentiable losses, which have been established in previous works, even if in some cases the focus was on the square loss (Kawaguchi, 2016; Bah et al., 2021; Chitour et al., 2022; Arora et al., 2018). In the interest of conciseness, here we only provide the main takeaways and defer a more detailed discussion to Appendix C. For the remainder of this section let L 1 : R n×m → R be any differentiable loss and let L N be defined through the parametrization µ as L N ( -→ W ) = L 1 • µ( -→ W ).
For any such loss, the gradient flow t → -→ W (t) is defined by d -→ W (t)/dt = -∇L N ( -→ W (t)), i.e., for all j ∈ [N ], dW j (t)/dt = -∇ Wj L N (W 1 (t), . . . , W N (t)). (GF) This governs the evolution of the parameters during gradient minimization of the loss. Further, we observe that the gradient of L N is given by ∇ Wj L N (W 1 , . . . , W N ) = W j+1 ⊤ • • • W N ⊤ ∇L 1 (W ) W 1 ⊤ • • • W j-1 ⊤ for all j ∈ [N ]. (1) As it turns out, the gradient flow dynamics preserves the difference of the Gramians of subsequent layer weight matrices, which are thus invariants of the gradient flow: d/dt (W j+1 (t) ⊤ W j+1 (t)) = d/dt (W j (t) W j (t) ⊤ ).

The important notion of balancedness for the weights of linear networks was first introduced by Fukumizu (1998) in the shallow case and generalized to the deep case by Du et al. (2018). It is useful in particular as a way of removing the redundancy of the parametrization when investigating the dynamics in function space, and has been considered in numerous works.

Definition 2.1 (Balanced weights). The weights W 1 , . . . , W N are said to be balanced if, for all j ∈ [N -1], W j W j ⊤ = W j+1 ⊤ W j+1 .

Assuming balanced initial weights, if the flow of each W j is defined and bounded, then the rank of the end-to-end matrix W remains constant during training (Bah et al., 2021, Proposition 4.4). Moreover, the products W N:1 W N:1 ⊤ and W N:1 ⊤ W N:1 can be written in a concise manner; namely, we have W N:1 W N:1 ⊤ = (W N W N ⊤ ) N and W N:1 ⊤ W N:1 = (W 1 ⊤ W 1 ) N , which simplifies many computations.

Remark 2.2. In order to relax the balanced initial weights assumption, some works also consider approximate balancedness (Arora et al., 2019a), which requires only that there exists δ > 0 such that ∥W j W j ⊤ - W j+1 ⊤ W j+1 ∥ F ≤ δ for all j ∈ [N -1]. We use exactly balanced initialization in our proofs, but they would also go through by invoking approximate balancedness.
Another alternative initialization has been proposed by Yun et al. (2021) . We defer such extensions to future work favoring here the discussion of the Bures-Wasserstein loss.
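To make the product identities under balancedness concrete, here is a small numerical sketch (our illustration; the matrix A, the widths, and the depth N are arbitrary choices). It constructs an exactly balanced factorization of a given end-to-end matrix from its SVD and checks both Definition 2.1 and the identity W N:1 W N:1 ⊤ = (W N W N ⊤ ) N :

```python
import numpy as np

rng = np.random.default_rng(0)
n, N = 4, 3  # square layers of width n; N layers (hypothetical sizes)

# Hypothetical end-to-end matrix W with SVD W = U S V^T.
A = rng.standard_normal((n, n))
U, s, Vt = np.linalg.svd(A)
S_root = np.diag(s ** (1.0 / N))  # S^{1/N}

# Balanced factorization: W_1 = S^{1/N} V^T, middle layers S^{1/N}, W_N = U S^{1/N}.
Ws = [S_root @ Vt] + [S_root.copy() for _ in range(N - 2)] + [U @ S_root]

# Balancedness (Definition 2.1): W_j W_j^T = W_{j+1}^T W_{j+1}.
for j in range(N - 1):
    assert np.allclose(Ws[j] @ Ws[j].T, Ws[j + 1].T @ Ws[j + 1])

# The end-to-end product W_{N:1} = W_N ... W_1 recovers A ...
W = np.linalg.multi_dot(Ws[::-1])
assert np.allclose(W, A)

# ... and the product identity W_{N:1} W_{N:1}^T = (W_N W_N^T)^N holds.
assert np.allclose(W @ W.T, np.linalg.matrix_power(Ws[-1] @ Ws[-1].T, N))
```

The construction inserts S^{1/N} in every layer, so each consecutive pair of Gramians coincides by design; this is the standard way balanced initializations are produced in practice.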

3.1. THE BURES-WASSERSTEIN LOSS

The Bures-Wasserstein (BW) distance is defined on the space of positive semi-definite matrices (or covariance space) S + (n). Here we collect some of its key properties and discuss its gradient; we refer the reader to Bhatia et al. (2019) for further details on this interesting distance.

Definition 3.1 (Bures-Wasserstein distance). Given two symmetric positive semi-definite matrices Σ 0 , Σ ∈ S + (n), the squared Bures-Wasserstein distance between Σ 0 and Σ is defined as B 2 (Σ, Σ 0 ) = tr(Σ + Σ 0 - 2(Σ 0 1/2 Σ Σ 0 1/2 ) 1/2 ). (2)

Kroshnin et al. (2021, Lemma A.3) show that the matrix square root is differentiable on the set of positive definite matrices, and as a consequence we can differentiate the BW distance at Σ 0 , Σ ∈ S ++ (n). However, the mapping Σ → B 2 (Σ, Σ 0 ) is not differentiable over all of S + (n). Indeed, letting ΓQΓ ⊤ be a spectral decomposition of Σ 0 1/2 Σ Σ 0 1/2 , (2) can be written as B 2 (Σ, Σ 0 ) = ∥Σ 1/2 ∥ F 2 + ∥Σ 0 1/2 ∥ F 2 - 2 tr Q 1/2 . (3) Due to the square root on Q, the map Σ → B 2 (Σ, Σ 0 ) is not differentiable when the rank of Σ 0 1/2 Σ Σ 0 1/2 , i.e., the number of positive eigenvalues of Q, changes. More specifically, while one can compute the gradient over the set of matrices of rank k for any given k, the norm of the gradient blows up if the matrix changes rank. The gradient of B 2 restricted to the set of full-rank matrices is given in Appendix B.
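As a numerical illustration (ours, with arbitrary random matrices), the following sketch evaluates the squared BW distance both via the trace formula (2) and via the spectral form (3), and checks that the two agree and that the distance of a matrix to itself vanishes:

```python
import numpy as np

def psd_sqrt(S):
    """Symmetric PSD square root via eigendecomposition."""
    w, Q = np.linalg.eigh(S)
    return (Q * np.sqrt(np.clip(w, 0.0, None))) @ Q.T

def bw_squared(Sigma, Sigma0):
    """Squared Bures-Wasserstein distance, formula (2)."""
    root0 = psd_sqrt(Sigma0)
    return np.trace(Sigma + Sigma0) - 2.0 * np.trace(psd_sqrt(root0 @ Sigma @ root0))

rng = np.random.default_rng(1)
n = 5
A, B = rng.standard_normal((n, n)), rng.standard_normal((n, n))
Sigma, Sigma0 = A @ A.T, B @ B.T  # random positive semi-definite matrices

# Spectral form (3): note ||Sigma^{1/2}||_F^2 = tr Sigma, and Q collects the
# eigenvalues of Sigma0^{1/2} Sigma Sigma0^{1/2}.
root0 = psd_sqrt(Sigma0)
q = np.linalg.eigvalsh(root0 @ Sigma @ root0)
via_spectrum = np.trace(Sigma) + np.trace(Sigma0) - 2.0 * np.sqrt(np.clip(q, 0.0, None)).sum()

assert np.isclose(bw_squared(Sigma, Sigma0), via_spectrum)
assert abs(bw_squared(Sigma0, Sigma0)) < 1e-8  # B^2(Sigma0, Sigma0) = 0
assert bw_squared(Sigma, Sigma0) >= -1e-8      # nonnegativity
```

The square root on the eigenvalues q is exactly where differentiability fails when eigenvalues of Q cross zero, as discussed above.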

3.2. LINEAR WASSERSTEIN GAN

The distance defined in (2) corresponds to the 2-Wasserstein distance between two zero-centered Gaussians and can be used as a loss for training models of Gaussian distributions, in particular generative linear networks. Recall that zero-centered Gaussian distributions are completely specified by their covariances. Given a bias-free linear network and a latent Gaussian probability measure N (0, I m ), the network computes a push-forward of the latent distribution, which is again a Gaussian distribution: if Z ∼ N (0, I m ) and X = W Z, then X ∼ N (0, W W ⊤ ) =: ν. Given a target distribution ν 0 = N (0, Σ 0 ), or simply the covariance matrix Σ 0 (which may be a sample covariance matrix), one can select W by minimizing B 2 (W W ⊤ , Σ 0 ) = W 2 2 (ν, ν 0 ), so that the network approximates the distribution N (0, Σ 0 ). We denote the map that takes the end-to-end matrix W to the covariance matrix W W ⊤ by π : R n×m → R n×n ; W → W W ⊤ .

Loss in covariance, function, and parameter spaces
We consider the following losses, which differ only in the choice of the search variable, taking a covariance, function, or parameter space view.
• First, we denote the loss over covariance matrices Σ ∈ S + (n) as L : Σ → B 2 (Σ, Σ 0 ).
• Secondly, given π : W → W W ⊤ ∈ S + (n), we define the loss in function space, i.e., over end-to-end matrices W ∈ R n×m , as L 1 : W → L • π(W ). This is given by, for all W ∈ R n×m , L 1 (W ) = tr(W W ⊤ + Σ 0 - 2(Σ 0 1/2 W W ⊤ Σ 0 1/2 ) 1/2 ). (4) This loss is not convex on R n×m , as can be seen already in the one-dimensional case.
• Lastly, for a tuple of weight matrices -→ W = (W 1 , . . . , W N ), we compose L 1 with the parametrization map µ : -→ W → W N:1 to define the loss in parameter space as L N : -→ W → L • π • µ( -→ W ), for -→ W ∈ Θ. Observe that this is, again, a non-convex loss. Thus, for -→ W ∈ Θ, L N ( -→ W ) = L 1 (µ( -→ W )) = L(π • µ( -→ W )) = B 2 (π • µ( -→ W ), Σ 0 ).
While the gradient flow (GF) is defined on the parameters -→ W , the covariance space perspective is useful since it leads to a convex objective function, even if this may be subject to non-convex constraints. One of our goals will be to translate properties between L, L 1 , and L N .

Smooth perturbative loss
As mentioned before, the Bures-Wasserstein loss is non-smooth at matrices with vanishing singular values. In turn, the usual analysis tools for proving uniqueness and convergence of the gradient flow do not apply to this loss. Therefore, we introduce a smooth perturbative version. Consider the perturbation map φ τ : Σ → Σ + τ I n , where τ > 0 plays the role of a regularization strength. Then the perturbative loss on the covariance space can be expressed as Lτ = L • φ τ , and the perturbative loss on function space as L 1 τ = Lτ • π. More explicitly, we let L 1 τ (W ) = tr(W W ⊤ + τ I n + Σ 0 - 2(Σ 0 1/2 (W W ⊤ + τ I n )Σ 0 1/2 ) 1/2 ). (5) This function is smooth enough to apply the usual convergence arguments for the gradient flow. Likewise, L N τ := Lτ • π • µ is well defined and smooth on Θ.

Remark 3.2. The perturbative loss (5) is differentiable. Many results from Bah et al. (2021) can be carried over to the differentiable Bures-Wasserstein loss. For example, the uniform boundedness of the end-to-end matrix holds, ∥W (t)∥ ⩽ 2L 1 (W (0)) + tr Σ 0 . Similar observations apply to L 1 in the case that the matrix W W ⊤ remains positive definite throughout training, in which case the gradient flow remains well defined and the loss is monotonically decreasing. We expand on this in Appendix C.

The next lemma, proved in Appendix B.4, provides a quantitative bound on the difference between the original and the perturbative loss. For this lemma, the rank of W W ⊤ is held fixed, and to compare the two losses we evaluate both at the same W.

Lemma 3.3. Let W ∈ R n×m , τ > 0, and let Σ τ = W W ⊤ + τ I n . Assume that rank(W W ⊤ ) = r, rank Σ 0 = n, L 1 (W ) is given by (4), and L(Σ τ ) is given by (5). Then | L(Σ τ ) - L 1 (W )| ≤ τ n + rτ 1/2 . (6)

We observe that the bound (6) vanishes as τ goes to zero, so the perturbative loss approaches the original loss.
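Lemma 3.3 can be probed numerically. The sketch below (ours; the dimensions and the rank r are arbitrary) checks the qualitative content of the bound, namely that the gap between the original and the perturbative loss shrinks to zero as τ → 0:

```python
import numpy as np

def psd_sqrt(S):
    w, Q = np.linalg.eigh(S)
    return (Q * np.sqrt(np.clip(w, 0.0, None))) @ Q.T

def bw_loss(Sigma, Sigma0):
    root0 = psd_sqrt(Sigma0)
    return np.trace(Sigma + Sigma0) - 2.0 * np.trace(psd_sqrt(root0 @ Sigma @ root0))

rng = np.random.default_rng(2)
n, m, r = 5, 5, 3
B = rng.standard_normal((n, n))
Sigma0 = B @ B.T + np.eye(n)                                   # rank Sigma0 = n
W = rng.standard_normal((n, r)) @ rng.standard_normal((r, m))  # rank(W W^T) = r

L1_val = bw_loss(W @ W.T, Sigma0)  # original loss (4)
taus = [1e-2, 1e-4, 1e-6]
diffs = [abs(bw_loss(W @ W.T + t * np.eye(n), Sigma0) - L1_val) for t in taus]

# The gap |L(Sigma_tau) - L^1(W)| shrinks as tau -> 0, consistent with Lemma 3.3.
assert diffs[0] > diffs[2] > 0
assert diffs[2] < 0.05
```

The dominant contribution to the gap comes from the n − r directions where W W ⊤ is singular, which is why the bound carries a τ 1/2 term rather than being linear in τ.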

4. CRITICAL POINTS

In this section, we characterize the critical points of the Bures-Wasserstein loss restricted to matrices of a given rank. The proofs of the results in this section are given in Appendix D. For k ∈ N, denote by M k the manifold of rank-k matrices, i.e., M k = {W ∈ R n×m | rank W = k}. Similarly, denote by M ⩽k the set of matrices of rank at most k. The format n × m of the matrices is to be inferred from the context. The manifold M k is viewed as an embedded submanifold of the linear space (R n×m , ⟨•, •⟩ F ), with codimension (n - k)(m - k) (Boumal 2022, §2.6; Uschmajew & Vandereycken 2020, §9.2.2). Given a function f : R n×m → R, its restriction to M k is denoted by f | M k : M k ∋ W → f (W ). Even if a function f is not differentiable over all of R n×m , its restriction to M k may be differentiable.

Definition 4.1. Let M be a smooth manifold. Let f : R n×m → R be any function such that its restriction to M is differentiable. A point W ∈ M is said to be a critical point for f | M if the differential of f | M at W is the zero function, df | M (W ) = 0.

4.1. CRITICAL POINTS OF L 1 OVER M k

Given a matrix A ∈ R n×n and a set J k ⊆ [n] of k indices, we denote by A J k ∈ R n×k the submatrix of A consisting of the columns with index in J k . If the matrix A is diagonal, we denote by Ā J k ∈ R k×k the diagonal submatrix which extracts the rows and columns with index in J k . The following result characterizes the critical points of the loss in function space. It can be regarded as a type of Eckart-Young-Mirsky result for the Bures-Wasserstein loss.

Theorem 4.2 (Critical points of L 1 ). Assume Σ 0 has n distinct, positive eigenvalues, ordered decreasingly. Let Σ 0 = ΩΛΩ ⊤ be a spectral decomposition of Σ 0 (so Ω ∈ U (n)). Let k ∈ [min{n, m}]. A matrix W * ∈ M k is a critical point of L 1 | M k if and only if W * = Ω J k (Λ̄ J k ) 1/2 V ⊤ for some J k ⊆ [n] with |J k | = k and V ∈ R m×k with V ⊤ V = I k . In particular, the number of critical points, up to the choice of V , is the binomial coefficient (n choose k). The minimum is attained for J k = [k].
In particular, inf M k L 1 (W ) = min M k L 1 (W ) and min M k L 1 (W ) = min M ⩽k L 1 (W ). The proof relies on an expression for the gradient ∇L 1 | M k (see Lemma D.3) and an evaluation of its zeros. The value of the loss at these critical points allows us to rank them and identify those that are minimal.

Remark 4.3. Interestingly, the critical points and the minimizer of L 1 characterized in the above result agree with those of the square loss (Eckart & Young, 1936; Mirsky, 1960). Nonetheless, we observe that (2) is only defined for positive semi-definite matrices. Hence the notion of unitary invariance considered by Mirsky (1960) only makes sense here for left and right multiplication by the same matrix. Moreover, while we can establish unitary invariance for a variational extension of the distance, it is still not a norm, in the sense that there is no function B : R n×n → R such that B(Σ, Σ 0 ) = B(Σ - Σ 0 ); hence it does not fall into the framework of Mirsky (1960). We offer more details in Appendix B.
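The characterization in Theorem 4.2 can be verified numerically. The sketch below (ours; the dimensions and eigenvalues are arbitrary choices) builds a critical point W * = Ω J k (Λ̄ J k ) 1/2 V ⊤ for every index set J k , confirms that the loss there equals the sum of the eigenvalues of Σ 0 left out of J k , and checks that the minimum over the (n choose k) families is attained at J k = [k]:

```python
import numpy as np
from itertools import combinations

def psd_sqrt(S):
    w, Q = np.linalg.eigh(S)
    return (Q * np.sqrt(np.clip(w, 0.0, None))) @ Q.T

def L1(W, Sigma0):
    root0 = psd_sqrt(Sigma0)
    S = W @ W.T
    return np.trace(S + Sigma0) - 2.0 * np.trace(psd_sqrt(root0 @ S @ root0))

rng = np.random.default_rng(3)
n = m = 4
k = 2
lam = np.array([4.0, 3.0, 2.0, 1.0])  # distinct positive eigenvalues, decreasing
Omega, _ = np.linalg.qr(rng.standard_normal((n, n)))
Sigma0 = Omega @ np.diag(lam) @ Omega.T

V = np.eye(m)[:, :k]  # any m x k matrix with V^T V = I_k
losses = {}
for J in combinations(range(n), k):
    W_star = Omega[:, list(J)] @ np.diag(np.sqrt(lam[list(J)])) @ V.T
    losses[J] = L1(W_star, Sigma0)
    # loss at a critical point = sum of the eigenvalues NOT indexed by J
    assert np.isclose(losses[J], lam.sum() - lam[list(J)].sum())

assert len(losses) == 6                      # (4 choose 2) critical-point families
assert min(losses, key=losses.get) == (0, 1) # minimum at J_k = [k], the top-k eigenvectors
```

The closed-form value of the loss at a critical point, sum of the discarded eigenvalues, is the BW analogue of the familiar tail-of-spectrum error in the Eckart-Young theorem.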

4.2. CRITICAL POINTS OF THE PERTURBATIVE LOSS

For the critical points of the perturbative loss L 1 τ (W ) we obtain the following results.

Theorem 4.4 (Critical points of L 1 τ ). Let Σ 0 = ΩΛΩ ⊤ be a spectral decomposition of Σ 0 . A point W * ∈ M k is a critical point for L 1 τ if and only if there exist a subset J k ⊆ [n] and a semi-orthogonal matrix V ∈ R m×k (i.e., such that V ⊤ V = I k ) with W * = Ω J k (Λ̄ J k - τ I k ) 1/2 V ⊤ . The (unique) minimum over M ⩽k is attained for J k = [k].

Note that the above characterization of the critical points imposes an upper bound on τ : for a given W * to be a critical point, one must have τ < λ j for all j ∈ J k , because the eigenvalues of Λ̄ J k - τ I k must be positive. In order to link the critical points in parameter space to the critical points in function space, we appeal to the correspondence drawn in Trager et al. (2020, Propositions 6 and 7). For the Bures-Wasserstein loss, this allows us to conclude the following.

Proposition 4.5. Assume a full-rank target Σ 0 with distinct, decreasing eigenvalues and spectral decomposition Σ 0 = ΩΛΩ ⊤ . Let τ ∈ (0, λ n ]. If -→ W * ∈ Crit(L N τ ), then Σ * = π(µ( -→ W * )) is a critical point of the loss Lτ | S + (k,n) , where k = rank Σ * . Moreover, if k = d, then -→ W * is a local minimizer for the loss L N τ if and only if Σ * = π(µ( -→ W * )) is a local minimizer, and therefore the global minimizer, of the loss Lτ | S + (d,n) . In this case, Σ * τ = Σ * + τ I n is the τ -best rank-d approximation of the target in the covariance space, in the sense that Σ * τ = Ω Λ τ [d] Ω ⊤ , where Λ τ [d] = diag(λ 1 , . . . , λ d , τ, . . . , τ ) retains the d largest eigenvalues of Λ and replaces the remaining ones by τ.

Proposition 4.5 ensures that, under the assumption that the solution of the gradient flow is a (local) minimizer and has the highest possible rank d given the network architecture, the solution in the covariance space is the best rank-d approximation of the target in the sense of the τ -smoothed Bures-Wasserstein distance.

Remark 4.6. Under the balancedness assumption, one can show that the rank of the end-to-end matrix does not drop during training (Bah et al., 2021, Proposition 4.4), and that the strict saddle points are almost surely avoided (Bah et al., 2021, Theorem 6.3). If the initialization of the network has rank d, the matrices W (t), t > 0, maintain rank d throughout training. There can be issues in the limit, however, since M d is not closed. Proving that the limit point also belongs to M d is an interesting open problem that we leave for future work.

5. CONVERGENCE ANALYSIS

The Bures-Wasserstein distance can be viewed through the lens of the Procrustes metric (Dryden et al., 2009; Masarotto et al., 2019). In fact, it can be obtained by the following minimization problem.

Lemma 5.1 (Bhatia et al. 2019, Theorem 1). For Σ, Σ 0 ∈ S + (n), B 2 (Σ, Σ 0 ) = min U ∈U (n) ∥Σ 1/2 - Σ 0 1/2 U ∥ F 2 , (7) where U (n) denotes the group of n × n orthogonal matrices. Moreover, the minimizer Ū is given by the orthogonal factor in the polar decomposition of Σ 1/2 Σ 0 1/2 .

We emphasize that in the above description of the Bures-Wasserstein distance the minimizer Ū depends on Σ, so that B 2 fundamentally differs from a squared Frobenius norm. Moreover, the square root on Σ θ can lead to singularities when differentiating the loss. Nonetheless, the expression (7) can be used to avoid such singularities, by leveraging the following deficiency margin concept.

Definition 5.2 (Modified deficiency margin). Given a target matrix Σ 0 ∈ R n×n and a positive constant c > 0, we say that Σ θ ∈ R n×n has a modified deficiency margin c with respect to Σ 0 if min U ∈U (n) ∥Σ θ 1/2 - Σ 0 1/2 U ∥ F ≤ σ min (Σ 0 1/2 ) - c. (8) With a slight abuse of terminology, we will say that W has a modified deficiency margin if W W ⊤ has one.

This deficiency margin idea can be traced back to Arora et al. (2019a). Note that we can write Σ θ 1/2 = (W W ⊤ ) 1/2 ; since the minimum in (8) is invariant under right multiplication of Σ θ 1/2 by an orthogonal matrix, the square root can equivalently be realized by, e.g., a Cholesky factor. Notice also that if we initialize close to the target, then the bound (8) holds trivially. In fact, if the initial condition W (0) satisfies the modified deficiency margin, then the least singular value of W (k)W (k) ⊤ remains bounded below by c for all times k ≥ 0, along the gradient flow or along gradient descent with decreasing L N :

Lemma 5.3. Suppose W (0)W (0) ⊤ has a modified deficiency margin c with respect to Σ 0 . Then σ min (W (k)W (k) ⊤ ) ≥ c for all k ≥ 0. (9)

The proof of this and of all results in this section are provided in Appendix E.
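Lemma 5.1 can be checked directly. The sketch below (our illustration, with arbitrary random covariances) computes the optimal orthogonal matrix from the polar decomposition and verifies that the Procrustes objective reproduces the trace formula (2) and is no larger at the minimizer than at a random orthogonal U:

```python
import numpy as np

def psd_sqrt(S):
    w, Q = np.linalg.eigh(S)
    return (Q * np.sqrt(np.clip(w, 0.0, None))) @ Q.T

rng = np.random.default_rng(4)
n = 5
A, B = rng.standard_normal((n, n)), rng.standard_normal((n, n))
Sigma, Sigma0 = A @ A.T, B @ B.T
rS, rS0 = psd_sqrt(Sigma), psd_sqrt(Sigma0)

# Squared BW distance via the trace formula (2).
bw2 = np.trace(Sigma + Sigma0) - 2.0 * np.trace(psd_sqrt(rS0 @ Sigma @ rS0))

# Optimal U from the polar decomposition of M = Sigma^{1/2} Sigma0^{1/2}:
# if M = P diag(d) Q^T, then tr(M U) is maximized at U = Q P^T.
P, d, Qt = np.linalg.svd(rS @ rS0)
U_opt = (P @ Qt).T
procrustes = np.linalg.norm(rS - rS0 @ U_opt, "fro") ** 2
assert np.isclose(bw2, procrustes)

# Any other orthogonal U gives a value at least as large.
U_rand, _ = np.linalg.qr(rng.standard_normal((n, n)))
assert procrustes <= np.linalg.norm(rS - rS0 @ U_rand, "fro") ** 2 + 1e-9
```

The equality follows because the singular values of Σ 1/2 Σ 0 1/2 are the square roots of the eigenvalues of Σ 0 1/2 Σ Σ 0 1/2 , so the maximized cross term equals the nuclear norm appearing in (2).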
We note that, while the modified deficiency margin assumption is sufficient for Lemma 5.3 to hold, it is by no means necessary. We assume it for simplicity of exposition, but the gradient flow analysis in the next paragraph only requires the weaker conclusion of Lemma 5.3 to hold.

Convergence of the gradient flow for the smooth loss
Because we cannot exclude the possibility that the rank of W W ⊤ drops along the gradient flow of the BW loss, we consider the smooth perturbation as a way to avoid singularities. We consider the gradient flow (GF) for the perturbative loss (5). The gradient of (5) is given by ∇L 1 τ (W ) = 2(W - Σ 0 1/2 (Σ 0 1/2 (W W ⊤ + τ I n )Σ 0 1/2 ) -1/2 Σ 0 1/2 W ). On the other hand, denoting by Σ τ := W W ⊤ + τ I n the regularized covariance matrix, we may express the loss in terms of the optimal transport map between Σ τ and Σ 0 (Kroshnin et al., 2021): L(Σ τ ) = tr(Σ τ + Σ 0 - 2(Σ 0 1/2 Σ τ Σ 0 1/2 ) 1/2 ) = ∥(T Σ0 Στ - I)Σ τ 1/2 ∥ F 2 = tr((T Σ0 Στ - I) Σ τ (T Σ0 Στ - I)), where T Σ0 Στ = Σ 0 1/2 (Σ 0 1/2 Σ τ Σ 0 1/2 ) -1/2 Σ 0 1/2 = Σ τ -1/2 (Σ τ 1/2 Σ 0 Σ τ 1/2 ) 1/2 Σ τ -1/2 . The perturbation τ I n ensures strict convexity, as shown in the following result.

Lemma 5.4. The function Σ τ → L(Σ τ ) is strictly convex on S ++ (n).

Proof. First we observe that the function f (X) = tr X 1/2 is strictly concave on S ++ (n); for details we refer the reader to Bhatia et al. (2019, Theorem 7). Moreover, the map Σ τ → Σ 0 1/2 Σ τ Σ 0 1/2 is linear and injective, since Σ 0 is positive definite. Hence Σ τ → tr(Σ 0 1/2 Σ τ Σ 0 1/2 ) 1/2 is strictly concave, and therefore Σ τ → L(Σ τ ) = tr Σ 0 + tr Σ τ - 2 tr(Σ 0 1/2 Σ τ Σ 0 1/2 ) 1/2 is strictly convex on S ++ (n).

Moreover, the loss L is strongly convex on S ++ (n), as stated in the next lemma.

Lemma 5.5. The Hessian of the loss L satisfies ∇ 2 Στ L(Σ τ ) ≽ K I n for any Σ τ ∈ S ++ (n), with K = √τ λ min (Σ 0 )/(2C 0 2 ), where C 0 = 2( L(Σ τ (0)) + tr Σ 0 ).
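The closed-form gradient ∇L 1 τ given above can be sanity-checked against finite differences. The sketch below (ours; W, Σ 0 , and τ = 10^-2 are arbitrary) compares the directional derivative of the loss with the inner product of the gradient along random directions:

```python
import numpy as np

def psd_power(S, power):
    """S^power for symmetric positive definite S, via eigendecomposition."""
    w, Q = np.linalg.eigh(S)
    return (Q * np.clip(w, 1e-15, None) ** power) @ Q.T

rng = np.random.default_rng(5)
n, m, tau = 4, 3, 1e-2
B = rng.standard_normal((n, n))
Sigma0 = B @ B.T + np.eye(n)
root0 = psd_power(Sigma0, 0.5)

def loss_tau(W):
    Sigma_tau = W @ W.T + tau * np.eye(n)
    return np.trace(Sigma_tau + Sigma0) - 2.0 * np.trace(psd_power(root0 @ Sigma_tau @ root0, 0.5))

def grad_tau(W):
    # grad L1_tau(W) = 2 (W - Sigma0^{1/2} (Sigma0^{1/2} Sigma_tau Sigma0^{1/2})^{-1/2} Sigma0^{1/2} W)
    Sigma_tau = W @ W.T + tau * np.eye(n)
    inv_root = psd_power(root0 @ Sigma_tau @ root0, -0.5)
    return 2.0 * (W - root0 @ inv_root @ root0 @ W)

W = rng.standard_normal((n, m))
G = grad_tau(W)
h = 1e-6
for _ in range(3):  # directional derivatives along random directions
    E = rng.standard_normal((n, m))
    fd = (loss_tau(W + h * E) - loss_tau(W - h * E)) / (2 * h)
    assert abs(fd - np.sum(G * E)) < 1e-5 * max(1.0, abs(fd))
```

Note that the regularization τ I n keeps the inner matrix bounded away from singularity, so the inverse square root in the gradient is well defined.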
Let us denote the minimizer of the perturbative loss L(Σ τ ) by Σ * τ . Let ∆ * 0 = Σ τ (0) -Σ * τ be the distance of the initialization from the optimal solution. Equipped with the strict convexity by Lemma 5.4, we are ready to show that the gradient flow has convergence rate O(e -Kc,N Kt ), where K is the constant from the Hessian bound given by Lemma E.5, and Kc,N is a constant which depends on the modified margin deficiency and the depth of the linear neural network. Recall that Σ τ = W N :1 W ⊤ N :1 + τ I n , so we prove convergence of gradient flow for the parametrization -→ W . Theorem 5.6 (Convergence of gradient flow). Assume both balancedness (Definition 2.1) and the modified uniform deficiency margin conditions (Definition 5.2). Then the gradient flow (GF) converges as L(Σ τ (t)) -L(Σ * τ ) ≤ e -8N c 2(2N -1) N Kt ∆ * 0 , where K = √ τ λmin(Σ0) 2C 2 0 is the strong convexity parameter from Lemma 5.5, with C 0 = 2( L(Σ τ (0)) + tr(Σ 0 )). Convergence of gradient descent for the BW loss Assuming that the initial end-to-end matrix W have a uniform deficiency margin, we can establish the following convergence result for gradient descent with finite step sizes, which is valid for the perturbed loss and also for the non-perturbed original loss. Given an initial value -→ W (0), we consider the gradient descent iteration -→ W (k + 1) = -→ W (k) -η∇L N ( -→ W (k)), k = 0, 1, . . . , GD) where η > 0 is the learning rate or step size and the gradient is given by (1). Theorem 5.7 (Convergence of gradient descent). Assume that the initial values W i (0), 1 ≤ i ≤ N , are balanced and are so that W (0) = W N :1 (0) has uniform deficiency margin c. 
If the learning rate η > 0 satisfies

η ≤ min { c² / (8M L^1(W(0))), N c^{2(N−1)/N} / (2∆), 1 / (4N c^{2(N−1)/N}) },

where

∆ = 2^{N+1} c² N² M^{(4N−3)/N} λ_max^{1/2}(Σ_0) + 8N(N−1) M^{(3N−4)/N} (M^{1/N} + ∥Σ_0^{1/2}∥_F),  M = 2L^1(W(0)) + ∥Σ_0^{1/2}∥_F^2,

then, for any ϵ > 0, the gradient descent (GD) achieves loss at most ϵ at every iteration

k ≥ (1 / (2ηN c^{2(N−1)/N})) log(L^1(W(0)) / ϵ).

Remark 5.8. Our Theorems 5.6 and 5.7 show that the depth of the network can accelerate the convergence of the gradient algorithms.
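To make the iteration (GD) concrete, here is a minimal NumPy sketch of finite step-size gradient descent on a depth-N linear network with the perturbative BW loss, using the factor gradients of Lemma C.1 (item 1). The target, near-identity initialization, and step size are illustrative assumptions of ours, not the paper's exact setup.

```python
import numpy as np

def psd_sqrt(A):
    lam, Q = np.linalg.eigh(A)
    return (Q * np.sqrt(np.clip(lam, 0.0, None))) @ Q.T

def psd_inv_sqrt(A):
    lam, Q = np.linalg.eigh(A)
    return (Q * (1.0 / np.sqrt(lam))) @ Q.T

def loss1(W, S0, tau):
    n = S0.shape[0]
    R0 = psd_sqrt(S0)
    St = W @ W.T + tau * np.eye(n)
    return np.trace(St) + np.trace(S0) - 2.0 * np.trace(psd_sqrt(R0 @ St @ R0))

def grad1(W, S0, tau):
    n = S0.shape[0]
    R0 = psd_sqrt(S0)
    M = psd_inv_sqrt(R0 @ (W @ W.T + tau * np.eye(n)) @ R0)
    return 2.0 * (W - R0 @ M @ R0 @ W)

rng = np.random.default_rng(1)
n, N, tau, eta = 3, 3, 1e-2, 1e-2
A = rng.standard_normal((n, n)); S0 = A @ A.T + np.eye(n)  # assumed target covariance
Ws = [np.eye(n) + 0.1 * rng.standard_normal((n, n)) for _ in range(N)]  # near-identity init

def end_to_end(Ws):
    P = np.eye(n)
    for Wi in Ws:
        P = Wi @ P          # W_N ... W_1
    return P

losses = []
for _ in range(200):
    W = end_to_end(Ws)
    losses.append(loss1(W, S0, tau))
    G = grad1(W, S0, tau)
    new_Ws = []
    for j in range(N):
        left = np.eye(n)
        for Wi in Ws[j + 1:]:
            left = Wi @ left    # W_N ... W_{j+1}
        right = np.eye(n)
        for Wi in Ws[:j]:
            right = Wi @ right  # W_{j-1} ... W_1
        # Lemma C.1, item 1: grad w.r.t. W_j is left^T G right^T
        new_Ws.append(Ws[j] - eta * left.T @ G @ right.T)
    Ws = new_Ws
```

With a small enough step size the loss decreases along the iterates, consistent with the linear convergence statement.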

6. CONCLUSION

In this work, we studied the training of generative linear neural networks using the Bures-Wasserstein distance. We characterized the critical points and minimizers of this loss in function space, over the sets of matrices of fixed rank k. We introduced a smooth approximation of the BW loss, obtained by regularizing the covariance matrix, and characterized its critical points in function space as well. Furthermore, under the assumption of balanced initial weights satisfying a uniform deficiency margin condition, we established a convergence guarantee to the global minimizer for the gradient flow of the perturbative loss, with an exponential rate of convergence. Finally, we considered finite step-size gradient descent and established a linear convergence result for both the original and the perturbed loss, provided the step size is small enough, depending on the uniform deficiency margin condition. We collect our results in Table 1. These results contribute to the ongoing efforts to better characterize the optimization problems that arise in learning with deep neural networks beyond the commonly considered discriminative settings with the square loss. In future work, it would be interesting to refine our characterization of the critical points of the Bures-Wasserstein loss in parameter space. Moreover, the uniform deficiency margin condition that we invoked in order to establish our convergence results constrains the parametrization to be of full rank; relaxing this assumption is an interesting endeavor to pursue.

Loss  | Parametrization | Critical points                         | Initialization | Convergence
L^1   | W_{N:1}         | Ω_{J_k} Λ_{J_k}^{1/2} V^⊤               | -              | -
L^1_τ | W_{N:1}         | Ω_{J_k} (Λ_{J_k} − τ I_k)^{1/2} V^⊤     | -              | -
L̃_τ   | Σ_τ             | Σ_0                                     | Balanced, MDM  | GF: exponential rate
L^1   | W_{N:1}         | ΩΛ^{1/2}V^⊤                             | Balanced, MDM  | GD: O(log(1/ϵ)) iterations

Table 1: Summary of the results. The target is assumed full rank, with distinct eigenvalues and spectral decomposition Σ_0 = ΩΛΩ^⊤. The end-to-end matrix is W_{N:1}, and the regularized covariance is Σ_τ. V ∈ R^{m×k} is any semi-orthogonal matrix, and J_k ⊂ [n] is an index set. "Balanced" stands for balanced weights (Definition 2.1), "MDM" stands for modified deficiency margin (Definition 5.2).

APPENDIX

The appendix is organized as follows.
• Appendix A presents a table summarizing the different geometrical and convergence results.
• Appendix B discusses the background on the Bures-Wasserstein loss and related optimal transport topics.
• Appendix C presents some general properties of linear neural networks and classical results on convergence in the parameter space.
• Appendix D gathers the proofs that were omitted in Section 4.
• Finally, Appendix E presents the proofs from Section 5.

A SUMMARY OF THE RESULTS

Table 1 presents a summary of the different results obtained in this paper. Note that, even though the different losses are expressed in the function or covariance spaces, the gradient flow and gradient descent algorithms are performed in the parameter space Θ.

B BACKGROUND ON THE BURES-WASSERSTEIN DISTANCE

B.1 DEFINITION OF W²₂

The Bures-Wasserstein distance has a natural connection with the 2-Wasserstein distance on a metric space: in the case of zero-centered Gaussian measures, the two distances are identical. We present here the general definition of the 2-Wasserstein distance, which enjoys desirable properties. Given a metric space (X, ∥·∥), the 2-Wasserstein distance is a well-established metric on the space P_2(X) of quadratically integrable probability measures.

Definition B.1 (2-Wasserstein distance). Given two quadratically integrable measures (ν, ν_0) ∈ (P_2(X))², the 2-Wasserstein distance is defined via the following minimization problem:

W²_2(ν, ν_0) = inf_{π∈Π(ν,ν_0)} ∫ ∥x − y∥² dπ(x, y),

where Π(ν, ν_0) = {π ∈ P_2(X × X) | π_1 = ν, π_2 = ν_0} is the set of couplings with the prescribed marginals, with π_i the marginal of π along the i-th variable.

It is known that this distance metrizes weak convergence on the space P_2, see e.g. (Villani, 2008, Theorem 6.9), and it can therefore be leveraged when designing a system that relies on comparing probability distributions, such as a GAN. On the other hand, the computational burden of such a loss can quickly become prohibitive (Pele & Werman, 2009); only in a very few cases can the loss (12) be computed efficiently. This contrasts with a usual WGAN, first introduced by Arjovsky et al. (2017), where the loss is computed using a neural network, based on the dual expression of the (1-)Wasserstein distance. Indeed, between two Gaussian measures the 2-Wasserstein distance has a closed-form expression, so that no adversarial training of a discriminator is needed. We will consider two centered Gaussian distributions, which are described by their covariance matrices.
In the case of centered Gaussian distributions, the 2-Wasserstein distance reduces to the Bures-Wasserstein distance between the covariance matrices Σ_0 and Σ (Dowson & Landau, 1982):

Lemma B.2. If ν = N(m, Σ) and ν_0 = N(m_0, Σ_0), then W²_2(ν, ν_0) = ∥m − m_0∥² + B²(Σ, Σ_0).

It is well known (see Kantorovitch (1958), (Villani, 2003, Theorem 1.3), or (Villani, 2008, Theorem 5.10)) that the squared 2-Wasserstein distance has the following dual expression, also known as the Kantorovich duality:

W²_2(ν_0, ν_θ) = sup_{(f,g)∈L^1(ν_θ)×L^1(ν_0)} { ∫ f(x) dν_θ(x) + ∫ g(x) dν_0(x) | ∀(x, y), f(x) + g(y) ⩽ ∥x − y∥² },  (13)

where L^1(ν) is the set of integrable functions with respect to a measure ν. The dual variables f and g are therefore required to be integrable with respect to the source and target measures, and to fulfil the cost inequality.

Remark B.3. In the context of GANs it is common to consider the 1-Wasserstein distance with cost given by the distance ∥x − y∥, which has a dual expression referred to as the Kantorovich-Rubinstein formula (Villani, 2008, §6.2); it allows for a more tractable computation in practice, with, for instance, only one dual variable. Nonetheless, no closed-form solution is known when c(x, y) = ∥x − y∥.

B.2 FURTHER PROPERTIES OF THE BW LOSS

In this section, we provide further background on the Bures-Wasserstein distance. First, we show that, except in some particular cases (Lemma B.4), the Bures-Wasserstein distance between two covariance matrices is not translation invariant (Lemma B.5).

Lemma B.4. In the case that Σ_0 and Σ commute, the Bures-Wasserstein distance reduces to the Hellinger distance: Σ_0 Σ = Σ Σ_0 =⇒ B²(Σ, Σ_0) = ∥Σ^{1/2} − Σ_0^{1/2}∥_F^2.

Proof. This simply follows from the fact that, if Σ and Σ_0 commute, so do Σ^{1/2} and Σ_0^{1/2}, so that Σ_0^{1/2} Σ Σ_0^{1/2} = (Σ_0^{1/2} Σ^{1/2})² and tr((Σ^{1/2})² + (Σ_0^{1/2})² − 2Σ_0^{1/2}Σ^{1/2}) = tr((Σ^{1/2} − Σ_0^{1/2})(Σ^{1/2} − Σ_0^{1/2})^⊤) = ∥Σ^{1/2} − Σ_0^{1/2}∥_F^2, as claimed.

From this, one remarks that the problem of minimizing the BW distance between two commuting covariance matrices does fall under the framework of the Eckart-Young-Mirsky theorem if the optimization variable is Σ^{1/2} = (W W^⊤)^{1/2}, as it reduces to a problem cast with the square loss. Nonetheless, in the case where Σ and Σ_0 do not commute, we do not have such a correspondence: in general, the BW distance is not translation invariant, neither when considered as a function of (Σ, Σ_0) nor when considered as a function of (Σ^{1/2}, Σ_0^{1/2}).

Lemma B.5 (BW is not translation invariant). There exist positive semidefinite matrices (Σ, Σ_0) ∈ S_+(n) × S_+(n) and a translation T ∈ S_+(n) such that B²(Σ + T, Σ_0 + T) ̸= B²(Σ, Σ_0). The same statement also holds for the loss L defined on the matrix square roots, L(Σ^{1/2}, Σ_0^{1/2}) = B²(Σ, Σ_0).

Proof. For the first part of the statement, take Σ = diag(1, 1), Σ_0 = diag(1, 2), and T = diag(t, t) with t > 0. Then B²(Σ + T, Σ_0 + T) − B²(Σ, Σ_0) = (√(2 + t) − √(1 + t))² − (√2 − 1)², which is non-zero.
For the second part of the statement, take Σ_0^{1/2} = diag(1, 2), Σ^{1/2} = [[1, 1], [1, 2]], and T = diag(1, 1). One computes

L(Σ^{1/2}, Σ_0^{1/2}) = ∥Σ^{1/2}∥_F^2 + ∥Σ_0^{1/2}∥_F^2 − 2 tr(Σ_0^{1/2} Σ Σ_0^{1/2})^{1/2} = 12 − 2 tr([[2, 6], [6, 20]])^{1/2}

and

L(Σ^{1/2} + T, Σ_0^{1/2} + T) = ∥Σ^{1/2} + T∥_F^2 + ∥Σ_0^{1/2} + T∥_F^2 − 2 tr((Σ_0^{1/2} + T)(Σ^{1/2} + T)(Σ^{1/2} + T)(Σ_0^{1/2} + T))^{1/2} = 28 − 2 tr([[20, 30], [30, 90]])^{1/2},

which gives the difference L(Σ^{1/2} + T, Σ_0^{1/2} + T) − L(Σ^{1/2}, Σ_0^{1/2}) ≈ 0.121229 ̸= 0.

Lemma B.5 therefore implies that, in the general case, one cannot express the Bures-Wasserstein distance (either on the covariances or on their square roots) as a norm of a difference (otherwise, the loss would be translation invariant). This hinders a direct application of the Eckart-Young-Mirsky theorem, where the problem is cast as min_X ∥A − X∥ with a fixed A, for some unitarily invariant norm. Even if this is not possible in general, a similar expression exists, relying on the following variational formulation.

Lemma B.6. The Bures-Wasserstein distance between two covariance matrices Σ_0 and Σ in S_+(n) coincides with the variational formulation (7),

min_{U∈U(n)} ∥Σ^{1/2} − UΣ_0^{1/2}∥_F^2 = tr(Σ_0 + Σ − 2(Σ_0^{1/2} Σ Σ_0^{1/2})^{1/2}).

Proof. We write min_{U∈U(n)} ∥Σ^{1/2} − UΣ_0^{1/2}∥_F^2 = min_{U∈U(n)} tr((Σ^{1/2} − UΣ_0^{1/2})^⊤(Σ^{1/2} − UΣ_0^{1/2})) = tr Σ + tr Σ_0 − 2 max_{U∈U(n)} tr(U^⊤ Σ^{1/2} Σ_0^{1/2}). Let Σ^{1/2}Σ_0^{1/2} = V_1 R^{1/2} V_2^⊤ be a singular value decomposition, with V_1, V_2 unitary and R = (Σ^{1/2}Σ_0^{1/2})^⊤ Σ^{1/2}Σ_0^{1/2} = Σ_0^{1/2} Σ Σ_0^{1/2}. Then tr(U^⊤ Σ^{1/2}Σ_0^{1/2}) = tr(V_2^⊤ U^⊤ V_1 R^{1/2}) is maximized when V_2^⊤ U^⊤ V_1 is the identity, i.e., U = V_1 V_2^⊤, which gives the maximum tr R^{1/2}. Thus we get (14) and the proof is complete.

Lemma B.7. For any symmetric matrix C ∈ S(n) and any matrices (A, B) ∈ (R^{n×n})², one has tr(CAB^⊤) = tr(CBA^⊤).

Proof. tr(CAB^⊤) = tr((CAB^⊤)^⊤) = tr(BA^⊤C) = tr(CBA^⊤).
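The numerical example in the proof of Lemma B.5 can be checked directly; the following NumPy sketch (the helper `psd_sqrt` and the function names are ours) reproduces the difference ≈ 0.121229.

```python
import numpy as np

def psd_sqrt(A):
    # Symmetric PSD square root via eigendecomposition.
    lam, Q = np.linalg.eigh(A)
    return (Q * np.sqrt(np.clip(lam, 0.0, None))) @ Q.T

def loss_on_roots(Rs, R0):
    # L(Sigma^{1/2}, Sigma0^{1/2}) = ||Rs||_F^2 + ||R0||_F^2 - 2 tr (R0 Rs Rs R0)^{1/2},
    # valid here since Rs is symmetric, so Sigma = Rs^2.
    return (np.linalg.norm(Rs, 'fro') ** 2 + np.linalg.norm(R0, 'fro') ** 2
            - 2.0 * np.trace(psd_sqrt(R0 @ Rs @ Rs @ R0)))

R0 = np.diag([1.0, 2.0])                 # Sigma_0^{1/2}
Rs = np.array([[1.0, 1.0], [1.0, 2.0]])  # Sigma^{1/2}
T = np.eye(2)
diff = loss_on_roots(Rs + T, R0 + T) - loss_on_roots(Rs, R0)
# diff is approximately 0.121229, confirming non-invariance under translation.
```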

B.3 GRADIENT OF THE BURES-WASSERSTEIN LOSS

We give here the gradient of the squared Bures-Wasserstein distance between two full-rank covariance matrices.

Lemma B.8 (Gradient of B² for full-rank matrices). Suppose Σ, Σ_0 ∈ S_{++}(n). Then the gradient of B² is given by

∇_Σ B²(Σ, Σ_0) = I − Σ_0^{1/2} (Σ_0^{1/2} Σ Σ_0^{1/2})^{−1/2} Σ_0^{1/2}.  (16)

The proof of this lemma is given in Appendix B.4. The second term on the right hand side of (16) is the optimal transport map between two centered Gaussian distributions (Bhatia et al. 2019; Muzellec & Cuturi 2018, eq. 7), whose Fréchet differentiability has been explored by Kroshnin et al. (2021, Lemma A.2). This is a formulation that we use in the computation of upper bounds for the Hessian in Appendix E.1.

B.4 PROOFS OF SECTION 3

In this section, we provide the proofs of Section 3.

Proof of Lemma B.8. Recall that the BW distance is given by B²(Σ, Σ_0) = tr Σ + tr Σ_0 − 2 tr(Σ_0^{1/2} Σ Σ_0^{1/2})^{1/2}. Its gradient is

∇_Σ B²(Σ, Σ_0) = I − 2 ∇_Σ tr(Σ_0^{1/2} Σ Σ_0^{1/2})^{1/2}.  (17)

Since Σ, Σ_0 are positive definite, we differentiate f(Σ, Σ_0) = tr(Σ_0^{1/2} Σ Σ_0^{1/2})^{1/2} with respect to Σ. Writing A := Σ_0^{1/2} Σ Σ_0^{1/2} and using vectorized (Kronecker) calculus,

vec(∇_Σ f(Σ, Σ_0)) = (∂_Σ A^{1/2})^⊤ vec(I)
= [(A^{1/2} ⊗ I + I ⊗ A^{1/2})^{−1} (Σ_0^{1/2} ⊗ Σ_0^{1/2})]^⊤ vec(I)
= (Σ_0^{1/2} ⊗ Σ_0^{1/2}) (A^{1/2} ⊗ I + I ⊗ A^{1/2})^{−1} vec(I)
= ½ (Σ_0^{1/2} ⊗ Σ_0^{1/2}) vec(A^{−1/2})
= ½ vec(Σ_0^{1/2} A^{−1/2} Σ_0^{1/2}),  (18)

where we used that ∂_Σ A = Σ_0^{1/2} ⊗ Σ_0^{1/2}, that the differential of the matrix square root satisfies (A^{1/2} ⊗ I + I ⊗ A^{1/2}) ∂_Σ A^{1/2} = ∂_Σ A, and that (A^{1/2} ⊗ I + I ⊗ A^{1/2})^{−1} vec(I) = ½ vec(A^{−1/2}). Substituting this expression into (17), we get

∇_Σ B²(Σ, Σ_0) = I − Σ_0^{1/2} (Σ_0^{1/2} Σ Σ_0^{1/2})^{−1/2} Σ_0^{1/2}.  (19)

Proof of Lemma 3.3. First note that the difference between the original loss and the perturbative loss is

|L̃(Σ_τ) − L^1(W)| = |τ n − 2 tr((Σ_0^{1/2} Σ_τ Σ_0^{1/2})^{1/2} − (Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{1/2})| ≤ τ n + 2 |tr((Σ_0^{1/2} Σ_τ Σ_0^{1/2})^{1/2} − (Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{1/2})|.  (20)

Let (Σ_0^{1/2} Σ_τ Σ_0^{1/2})^{1/2} = Q Λ_τ^{1/2} Q^⊤, where Λ_τ = diag(λ_1 + τ, . . . , λ_r + τ) with λ_1 > λ_2 > . . . > λ_r and r = rank(W W^⊤). Similarly, let (Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{1/2} = U Λ^{1/2} V^⊤, where Λ = diag(λ_1, . . . , λ_r). Since the trace is invariant under conjugation by unitary matrices, tr((Σ_0^{1/2} Σ_τ Σ_0^{1/2})^{1/2}) = tr(Λ_τ^{1/2}) = ∑_{i=1}^r (λ_i + τ)^{1/2}, and likewise tr((Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{1/2}) = tr(Λ^{1/2}) = ∑_{i=1}^r λ_i^{1/2}. By the subadditivity of the square root, ∑_{i=1}^r (λ_i + τ)^{1/2} ≤ ∑_{i=1}^r λ_i^{1/2} + rτ^{1/2}. Therefore, returning to Lemma 3.3, we get |L̃(Σ_τ) − L^1(W)| ≤ τ n + 2rτ^{1/2}.
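The gradient formula of Lemma B.8 can be verified against a central finite difference along a symmetric direction; this NumPy sketch (helper names are ours) assumes random positive definite Σ and Σ_0.

```python
import numpy as np

def psd_sqrt(A):
    lam, Q = np.linalg.eigh(A)
    return (Q * np.sqrt(np.clip(lam, 0.0, None))) @ Q.T

def psd_inv_sqrt(A):
    lam, Q = np.linalg.eigh(A)
    return (Q * (1.0 / np.sqrt(lam))) @ Q.T

def bw2(S, S0):
    # Squared Bures-Wasserstein distance between covariance matrices.
    R0 = psd_sqrt(S0)
    return np.trace(S) + np.trace(S0) - 2.0 * np.trace(psd_sqrt(R0 @ S @ R0))

def grad_bw2(S, S0):
    # Lemma B.8: I - S0^{1/2} (S0^{1/2} S S0^{1/2})^{-1/2} S0^{1/2}
    R0 = psd_sqrt(S0)
    return np.eye(S.shape[0]) - R0 @ psd_inv_sqrt(R0 @ S @ R0) @ R0

rng = np.random.default_rng(3)
n = 4
A = rng.standard_normal((n, n)); S = A @ A.T + np.eye(n)
B = rng.standard_normal((n, n)); S0 = B @ B.T + np.eye(n)
H = rng.standard_normal((n, n)); H = H + H.T        # symmetric perturbation direction

eps = 1e-6
fd = (bw2(S + eps * H, S0) - bw2(S - eps * H, S0)) / (2.0 * eps)
an = float(np.sum(grad_bw2(S, S0) * H))             # <grad, H>
err = abs(fd - an)
```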

C GENERAL RESULTS FOR LINEAR NETWORKS

This section deals with general properties of linear networks and their convergence in parameter space. We first recall well-known results that hold for any differentiable loss L^1 and its parametrization L^N = L^1 ∘ µ (Bah et al., 2021, Lemma 2.1).

Lemma C.1 (Gradient flow). For any differentiable loss L^1 and the parametrization L^N = L^1 ∘ µ, with µ(W_1, . . . , W_N) = W_N · · · W_1, one has:

1. For all j ∈ [N], ∇_{W_j} L^N(W_1, . . . , W_N) = W_{j+1}^⊤ · · · W_N^⊤ ∇L^1(W) W_1^⊤ · · · W_{j−1}^⊤.

2. Assume each of the W_i(t) satisfies the flow (GF). Then the product W_{N:1} = W_N · · · W_1 satisfies

dW(t)/dt = − ∑_{j=1}^N W_N · · · W_{j+1} W_{j+1}^⊤ · · · W_N^⊤ ∇L^1(W) W_1^⊤ · · · W_{j−1}^⊤ W_{j−1} · · · W_1.  (22)

3. For all j ∈ [N] and all t ⩾ 0, we have that d/dt (W_{j+1}^⊤(t) W_{j+1}(t)) = d/dt (W_j(t) W_j^⊤(t)).  (23)

4. If W_1(0), . . . , W_N(0) are balanced, then for all t ≥ 0, W_{j+1}^⊤(t) W_{j+1}(t) = W_j(t) W_j^⊤(t), and

R(t) := dW(t)/dt + ∑_{j=1}^N (W(t) W^⊤(t))^{(N−j)/N} ∇L^1(W) (W^⊤(t) W(t))^{(j−1)/N} = 0.  (24)

The BW loss satisfies the Łojasiewicz inequality. Indeed, the following equality can be computed.

Lemma C.2. For any W ∈ R^{d_N×d_0} and for the loss L^1 defined in (4), we have

∥∇_W L^1(W)∥_F^2 = 4 L^1(W).  (25)

Proof. This equality can be obtained directly by computation. Since ∇L^1(W) = 2W − 2Σ_0^{1/2}(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{−1/2} Σ_0^{1/2} W, we have

∥∇_W L^1(W)∥_F^2 = 4 tr[(W − Σ_0^{1/2}(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{−1/2} Σ_0^{1/2} W)(W^⊤ − W^⊤ Σ_0^{1/2}(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{−1/2} Σ_0^{1/2})]
= 4 tr(W W^⊤) − 4 tr(W W^⊤ Σ_0^{1/2}(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{−1/2} Σ_0^{1/2}) − 4 tr(Σ_0^{1/2}(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{−1/2} Σ_0^{1/2} W W^⊤) + 4 tr(Σ_0).  (27)

The two middle terms above are equal, and they can be further simplified as

tr(W W^⊤ Σ_0^{1/2}(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{−1/2} Σ_0^{1/2}) = tr(Σ_0^{1/2}(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{−1/2} Σ_0^{1/2} W W^⊤) = tr((Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{1/2}).  (28)

Combining all the terms together, we get the equality (25).
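The Łojasiewicz-type identity of Lemma C.2 can be checked numerically for a full-rank W; the following NumPy sketch (helper names are ours) assumes a random positive definite target.

```python
import numpy as np

def psd_sqrt(A):
    lam, Q = np.linalg.eigh(A)
    return (Q * np.sqrt(np.clip(lam, 0.0, None))) @ Q.T

def psd_inv_sqrt(A):
    lam, Q = np.linalg.eigh(A)
    return (Q * (1.0 / np.sqrt(lam))) @ Q.T

def loss1(W, S0):
    # BW loss (4): tr(W W^T) + tr(S0) - 2 tr((S0^{1/2} W W^T S0^{1/2})^{1/2})
    R0 = psd_sqrt(S0)
    return np.trace(W @ W.T) + np.trace(S0) - 2.0 * np.trace(psd_sqrt(R0 @ W @ W.T @ R0))

def grad1(W, S0):
    R0 = psd_sqrt(S0)
    return 2.0 * W - 2.0 * R0 @ psd_inv_sqrt(R0 @ W @ W.T @ R0) @ R0 @ W

rng = np.random.default_rng(4)
n = 4
A = rng.standard_normal((n, n)); S0 = A @ A.T + np.eye(n)
W = rng.standard_normal((n, n)) + 2.0 * np.eye(n)   # full rank, so the gradient exists

lhs = np.linalg.norm(grad1(W, S0), 'fro') ** 2
rhs = 4.0 * loss1(W, S0)
# Lemma C.2 predicts lhs == rhs.
```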
In the case of a general twice differentiable loss L^1 and the parametrization L^N = L^1 ∘ µ, one can express the second-order differential structure of the loss.

Lemma C.3 (Second-order differential). Let (U⃗, V⃗) ∈ Θ × Θ be two parameters, U⃗ = (U_1, . . . , U_N), V⃗ = (V_1, . . . , V_N). The second-order differential of the loss L^N at W⃗ = (W_1, . . . , W_N) ∈ Θ is

d²L^N(W⃗)[U⃗, V⃗] = ∑_{i=1}^N ∑_{j≠i} ⟨U_i, W_{i+1}^⊤ · · · V_j^⊤ · · · W_N^⊤ ∇L^1(W) W_1^⊤ · · · W_{i−1}^⊤⟩
+ ∑_{i=1}^N ∑_{j=1}^N vec(U_i)^⊤ (W_{i−1:1} ⊗ (W_{N:i+1})^⊤) · ∇²L^1(W) · ((W_{j−1:1})^⊤ ⊗ W_{N:j+1}) vec(V_j),  (29)

where ∇²L^1(W) ∈ R^{n²×n²} is the matrix such that, for all (U, V) ∈ (R^{n×n})², d²L^1(W)[U, V] = vec(U)^⊤ ∇²L^1(W) vec(V).

Corollary C.4 (Hessian of the loss). The Hessian of L^N, ∇²L^N(θ), can be represented as a d_θ × d_θ matrix, where d_θ is the total number of parameters. It is a block matrix, the blocks corresponding to the different layers. Each block ∇²_{W_i,W_j} L^N(W⃗) has dimension d_i d_{i−1} × d_j d_{j−1} and corresponds to the differential d²L^N(W⃗)[U⃗_i, U⃗_j], where U⃗_i = (0, . . . , 0, U_i, 0, . . . , 0). The diagonal blocks are

∇²_{W_i} L^N(W⃗) = (W_{i−1:1} ⊗ (W_{N:i+1})^⊤) · ∇²L^1(W) · ((W_{i−1:1})^⊤ ⊗ W_{N:i+1}),

and the off-diagonal blocks are

∇²_{W_i,W_j} L^N(W⃗) = (W_{i−1:1} ⊗ (W_{N:i+1})^⊤) · ∇²L^1(W) · ((W_{j−1:1})^⊤ ⊗ W_{N:j+1})
+ ((W_{i−1} · · · W_1 ∇L^1(W)^⊤ W_N · · · W_{j+1}) ⊗ (W_{i+1}^⊤ · · · W_{j−1}^⊤)) K_{d_j d_{j−1}},  (30)

where K_{pq} is the pq-commutation matrix (for X ∈ R^{p×q}, K_{pq} vec X = vec X^⊤).

The invariance property of the gradient flow (GF) (Lemma C.1, item 3) is key in numerous analyses. Another useful property of the gradient flow (GF) is its convergence, under mild assumptions on the loss L^1, to a critical point of L^N. Namely, if the trajectory t ↦ W⃗(t) remains bounded for all t ⩾ 0, and if L^1 is an analytic function (i.e.
locally given by a power series), then (GF) converges to a critical point of L^N, i.e., a point θ* such that ∇L^N(θ*) = 0. This is stated in the next theorem.

Theorem C.5 (Gradient flow converges to a critical point of L^N). Let L^1 be analytic and such that the trajectory t ↦ µ(θ(t)) remains bounded under the gradient flow evolution θ̇ = −∇(L^1 ∘ µ)(θ). Then the flows of W_i(t) given by (GF) and of W(t) given by (22) are defined and bounded for all t ⩾ 0, and (W_1, . . . , W_N) converges to a critical point of L^N = L^1 ∘ µ as t → ∞.

This result relies on Łojasiewicz' argument for the convergence of gradient flows (Absil et al., 2005). Bah et al. (2021) show how to bound each of the factors {W_i}_{i=1}^N once the end-to-end product W_{N:1} is bounded. The boundedness of ∥W∥ can be shown depending on the loss that is considered; for example, it holds for the regularized loss L^1_τ. In the case of the perturbative loss introduced in (5), one can bound the norm of W throughout training; since L^1_τ is moreover analytic, convergence of (GF) to a critical point follows (Lemma C.10). We first give a simple test for the boundedness of a trajectory under (GF), enabled by the decrease of the loss along training.

Lemma C.6. Assume that, for a given loss L^1, there exists an increasing function f such that, for any t ⩾ 0, ∥W(t)∥ ⩽ f(L^1(W(t))). Then the trajectory t ↦ W(t) under (GF) is bounded.

Proof. Under gradient flow, for any t ⩾ 0, L^1(W(t)) ⩽ L^1(W(0)). Indeed, by the chain rule and the gradient flow (22),

d/dt L^1(W(t)) = ∑_j D_{W_j} L^N(W_1(t), . . . , W_N(t)) [dW_j(t)/dt] = − ∑_j ∥∇_{W_j} L^N(W_1, . . . , W_N)∥_F^2 ⩽ 0.

Therefore, for any t ⩾ 0, L^1(W(t)) ⩽ L^1(W(0)). Now, since f is increasing, f(L^1(W(t))) ⩽ f(L^1(W(0))).
Therefore, if ∥W(t)∥ ⩽ f(L^1(W(t))) for all t ⩾ 0, then ∥W(t)∥ ⩽ f(L^1(W(t))) ⩽ f(L^1(W(0))) is bounded.

The assumption of Lemma C.6 is satisfied by several losses, including the square loss (Bah et al., 2021) and the L^1_τ loss, as shown in Lemma C.8. It allows one to consider losses that "grow with the weights", so that the end-to-end matrix is bounded when the loss converges to zero. We now show the boundedness of the weights when considering the Bures-Wasserstein loss (4).

Lemma C.7 (Boundedness for the BW loss L̃). The loss L̃(Σ) can be lower-bounded by the quantity ½ tr Σ − tr Σ_0.

Proof. By the dual expression of the Wasserstein distance (13),

L̃(Σ) = W²_2(ν_0, ν_θ) = sup_{f∈L^1(ν_θ)} ∫ f(x) dν_θ(x) + ∫ f^{∥·∥²}(y) dν_0(y),

with ν_θ = N(0, Σ), ν_0 = N(0, Σ_0), and f^{∥·∥²} the ∥·∥²-transform of f, defined for all y ∈ R^d by f^{∥·∥²}(y) = inf_{x∈R^d} ∥x − y∥² − f(x). With f: x ↦ ½∥x∥², the ∥·∥²-transform of f is f^{∥·∥²}: y ↦ −∥y∥², and we get

L̃(Σ) = W²_2(ν_0, ν_θ) ⩾ ½ ∫ ∥x∥² dν_θ(x) − ∫ ∥y∥² dν_0(y) = ½ tr Σ − tr Σ_0,  (31)

as claimed.

Lemma C.8 (Boundedness for the loss L^1_τ). The norm of the end-to-end matrix W is upper-bounded when using the loss L^1_τ defined in (5).

Proof. With φ_τ(Σ) = Σ + τ I_n =: Σ_τ, the loss L^1_τ satisfies

L^1_τ(W) = L̃(φ_τ(π(W))) ⩾ ½ tr Σ_τ − tr Σ_0 = ½ tr(W W^⊤) − tr Σ_0 + (n/2)τ,  (32)-(33)

which implies ∥W∥ ⩽ (2L^1_τ(W) + 2 tr Σ_0 − nτ)^{1/2}.  (34)

Therefore, there exists an increasing function f such that ∥W∥ ⩽ f(L^1_τ(W)). Since the loss decreases under gradient flow, one has ∥W(t)∥ ⩽ (2L^1_τ(W(0)) + 2 tr Σ_0 − nτ)^{1/2}, and the boundedness of t ↦ W(t) follows.

Corollary C.9. For the Bures-Wasserstein loss, if W W^⊤ is positive definite (so that the loss is differentiable), the norm of the weights throughout training is uniformly bounded, ∥W∥ ≤ (2L^1(W(0)) + 2 tr Σ_0)^{1/2}, by arguments similar to those of Lemma C.8.

Lemma C.10.
The gradient flow (GF) on the perturbative loss (5) converges to a critical point θ* of L^N_τ in the limit. This property of the gradient flow is necessary in order to prove the convergence of the training to a minimizer of L^1_τ. Note that the critical points of L^N_τ do not in general correspond to critical points of L^1_τ, since the parametrization µ also comes into play. This led Trager et al. (2020) to distinguish between pure and spurious critical points, i.e., the points that are shared between L^N and L^1, and those that are exclusive to L^N.
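The lower bound of Lemma C.7 can be probed numerically on random positive semidefinite pairs; this NumPy sketch (helper names are ours) checks it on a handful of draws.

```python
import numpy as np

def psd_sqrt(A):
    # Symmetric PSD square root via eigendecomposition.
    lam, Q = np.linalg.eigh(A)
    return (Q * np.sqrt(np.clip(lam, 0.0, None))) @ Q.T

def bw2(S, S0):
    # Squared Bures-Wasserstein distance between covariance matrices.
    R0 = psd_sqrt(S0)
    return np.trace(S) + np.trace(S0) - 2.0 * np.trace(psd_sqrt(R0 @ S @ R0))

rng = np.random.default_rng(8)
n = 4
ok = True
for _ in range(20):
    A = rng.standard_normal((n, n)); S = A @ A.T
    B = rng.standard_normal((n, n)); S0 = B @ B.T
    # Lemma C.7: BW loss is bounded below by (1/2) tr S - tr S0.
    ok = ok and (bw2(S, S0) >= 0.5 * np.trace(S) - np.trace(S0) - 1e-9)
```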

D PROOFS OF SECTION 4

In this section, we provide the proofs for the critical points of L^1|_{M_k} and L^1_τ|_{M_k}.

D.1 CRITICAL POINTS OF L^1|_{M_k}

First, the loss L^1 is expressed on the manifold M_k (Lemma D.2), where it is differentiable (Lemma D.3). Then, necessary conditions on the critical points can be expressed (Lemma D.5), leading to the proof of Theorem 4.2. The second part of Theorem 4.2 is then proven by evaluating the loss at the critical points found, and ranking them. Recall Definition 4.1. Computing the differential of the restriction L^1|_{M_k} will allow us to characterize the different critical points.

Definition D.1 (Gradient). Given an embedded manifold M and a function f with a smooth restriction f|_M, the gradient of f|_M at x ∈ M is the (unique) element ∇f|_M(x) of the tangent space T_x M such that, for all v ∈ T_x M, df|_M(x)[v] = ⟨∇f|_M(x), v⟩.

We begin by expressing the loss L^1|_{M_k} via the singular value decomposition (SVD) of Σ_0^{1/2} W.

Lemma D.2. Let U S V^⊤ = Σ_0^{1/2} W be a thin SVD of Σ_0^{1/2} W, so that U ∈ R^{n×k}, V ∈ R^{m×k}, U^⊤U = V^⊤V = I_k, and S = diag(s_1, . . . , s_k) ∈ R^{k×k}, where k = rank(Σ_0^{1/2} W) = rank(W). The loss L^1 from (4) on M_k can be expressed as

L^1|_{M_k}(W) = ∥W∥_F^2 + ∥Σ_0^{1/2}∥_F^2 − 2 tr S.  (37)

Proof. If U S V^⊤ = Σ_0^{1/2} W is a thin SVD of Σ_0^{1/2} W, then (Σ_0^{1/2} W (Σ_0^{1/2} W)^⊤)^{1/2} = U S U^⊤. Therefore, the expression of the loss L^1 given by (4) can be written as L^1|_{M_k}(W) = tr(W W^⊤) + tr Σ_0 − 2 tr(U S U^⊤) = ∥W∥_F^2 + ∥Σ_0^{1/2}∥_F^2 − 2 tr S, as claimed.

We then give the gradient of L^1|_{M_k}.

Lemma D.3 (Gradient of L^1|_{M_k}). Let (n, m) ∈ N²_*, and let k ⩽ min{n, m}. The loss L^1|_{M_k} (as given in (37)) is twice continuously differentiable on M_k. For W ∈ M_k with thin SVD U S V^⊤ = Σ_0^{1/2} W, its gradient is

∇L^1|_{M_k}(W) = 2W − 2Σ_0^{1/2} U V^⊤.  (38)

In order to derive this expression, the differential of the singular values is required.
But first, a note on the differential notation used throughout the derivations.

Notation (Differential). The differential of a function f can be written using different formalisms. Explicitly, df(X)[H] is the differential of f at X in the direction H. Sometimes, with Y = f(X), the shorthand notation dY is preferred, where the same symbol is used for both the variable and the function. In this case, it is understood that the direction H is a small perturbation dX around X. For instance, if Y = f(X) = XX^⊤, then dY = dX X^⊤ + X dX^⊤, which would be written df(X)[H] = HX^⊤ + XH^⊤ in the full notation.

Lemma D.4 (Differential of the SVD). Let k ⩽ min{n, m} and let X ∈ M_k be a matrix with rank X = k. Let U S V^⊤ = X be a thin SVD of X, with U ∈ R^{n×k}, S ∈ R^{k×k}, V ∈ R^{m×k}, S diagonal and U^⊤U = V^⊤V = I_k. Then the differential dS is dS = I_k ⊙ (U^⊤ dX V), where A ⊙ B denotes the Hadamard product of A and B.

Proof. Let U S V^⊤ = X be the decomposition as given in the lemma statement. The differential rules ensure that dX = dU S V^⊤ + U dS V^⊤ + U S dV^⊤. This implies that

U^⊤ dX V = U^⊤ dU S V^⊤ V + U^⊤ U dS V^⊤ V + U^⊤ U S dV^⊤ V = U^⊤ dU S + dS + S dV^⊤ V
=⇒ dS = U^⊤ dX V − U^⊤ dU S − S dV^⊤ V.

Since U^⊤U = I_k, dU^⊤ U + U^⊤ dU = 0, so A := U^⊤ dU = −dU^⊤ U = −A^⊤. Likewise, B := V^⊤ dV is also antisymmetric. The matrices A and B being antisymmetric, their diagonals are null; hence so are the diagonals of AS and SB, i.e., I_k ⊙ (AS) = I_k ⊙ (SB) = 0. Since S is constrained to be diagonal, dS must also be diagonal, i.e., I_k ⊙ dS = dS. Therefore, dS = I_k ⊙ (U^⊤ dX V), as was claimed.

Proof of Lemma D.3. For W ∈ M_k, let U S V^⊤ = Σ_0^{1/2} W be a thin SVD of Σ_0^{1/2} W =: X. Lemma D.2 ensures that L^1|_{M_k}(W) = ∥W∥_F^2 + ∥Σ_0^{1/2}∥_F^2 − 2 tr S. According to Lemma D.4, the matrix S is differentiable, with differential dS = I_k ⊙ (U^⊤ dX V). Therefore, the loss L^1|_{M_k} is differentiable. Using the fact that d tr S = tr dS (see e.g.
(Magnus & Neudecker, 2019, Chap. 8, Eq. 18)), we can compute

d tr S = tr dS = tr(I_k ⊙ (U^⊤ dX V)) = tr(U^⊤ dX V) = ⟨U V^⊤, dX⟩ = ⟨U V^⊤, Σ_0^{1/2} dW⟩ = ⟨Σ_0^{1/2} U V^⊤, dW⟩.

Moreover, d∥W∥_F^2 = 2⟨W, dW⟩, and so

dL^1|_{M_k}(W) = d∥W∥_F^2 − 2 d tr S = 2⟨W − Σ_0^{1/2} U V^⊤, dW⟩, i.e., ∇L^1|_{M_k}(W) = 2(W − Σ_0^{1/2} U V^⊤).

Since the matrices (U, V) are continuously differentiable on M_k, ∇L^1|_{M_k}(W) = 2(W − Σ_0^{1/2} U V^⊤) is again continuously differentiable, and L^1|_{M_k} is twice continuously differentiable.

We are now ready to give the proof of Theorem 4.2. We divide the proof into necessary and sufficient conditions for a point to be a critical point of L^1|_{M_k}.

Lemma D.5 (Necessary condition on the critical points of L^1|_{M_k}). Assume Σ_0 has n distinct eigenvalues. Let W* ∈ M_k be a critical point of L^1|_{M_k}. Then, with U* S* V*^⊤ = Σ_0^{1/2} W* a thin SVD of Σ_0^{1/2} W* and ΩΛΩ^⊤ = Σ_0 a spectral decomposition of Σ_0 (i.e., with Ω ∈ O(n)), there exists J_k ⊆ {1, . . . , n} such that S* = Λ_{J_k} and U* = Ω_{J_k}.

Proof. Since W* ∈ M_k and U* S* V*^⊤ = Σ_0^{1/2} W* is a thin SVD of Σ_0^{1/2} W*, we have S* ∈ R^{k×k}. Then, by (38),

∇L^1|_{M_k}(W*) = 0 =⇒ W* = Σ_0^{1/2} U* V*^⊤ =⇒ Σ_0^{1/2} W* = Σ_0 U* V*^⊤ =⇒ U* S* V*^⊤ = Σ_0 U* V*^⊤ =⇒ S* = U*^⊤ Σ_0 U*,

using U*^⊤ U* = I_k and V*^⊤ V* = I_k. Therefore, U*^⊤ Σ_0 U* must be diagonal; since U* is semi-orthogonal, this is the case if and only if the columns of U* are eigenvectors of Σ_0, by uniqueness of the spectral decomposition of Σ_0. Therefore, there exist indices j_1, . . . , j_k between 1 and n such that U* = (ω_{j_1} · · · ω_{j_k}) = Ω_{J_k}, in which case S* = Ω_{J_k}^⊤ Σ_0 Ω_{J_k} = diag(λ_{j_1}, . . . , λ_{j_k}) = Λ_{J_k}.

Now we are ready to prove the first part of Theorem 4.2.

Proof of Theorem 4.2, first part. Consider the expression for the gradient of L^1|_{M_k} given in (38).
The necessary condition follows from Lemma D.5, since

∇L^1|_{M_k}(W*) = 0 =⇒ Σ_0^{1/2} W* = Ω_{J_k} Λ_{J_k} V^⊤
=⇒ W* = Σ_0^{−1/2} Ω_{J_k} Λ_{J_k} V^⊤ = ΩΛ^{−1/2} Ω^⊤ Ω_{J_k} Λ_{J_k} V^⊤ = ΩΛ^{−1/2} P_{J_k} Λ_{J_k} V^⊤ = Ω P_{J_k} Λ_{J_k}^{−1/2} Λ_{J_k} V^⊤ = Ω P_{J_k} Λ_{J_k}^{1/2} V^⊤ = Ω_{J_k} Λ_{J_k}^{1/2} V^⊤,

where P_{J_k} = Ω^⊤ Ω_{J_k} selects the columns indexed by J_k. This corresponds to the necessary condition in Theorem 4.2. The sufficient condition can be verified as follows. With W* = Ω_{J_k} Λ_{J_k}^{1/2} V^⊤, one has Σ_0^{1/2} W* = ΩΛ^{1/2} Ω^⊤ Ω_{J_k} Λ_{J_k}^{1/2} V^⊤ = Ω_{J_k} Λ_{J_k} V^⊤, and, as this is a valid thin SVD of Σ_0^{1/2} W*, Lemma D.3 gives ∇L^1|_{M_k}(W*) = 2(W* − Σ_0^{1/2} Ω_{J_k} V^⊤). Further, Σ_0^{1/2} Ω_{J_k} = ΩΛ^{1/2} Ω^⊤ Ω_{J_k} = ΩΛ^{1/2} P_{J_k} = Ω P_{J_k} Λ_{J_k}^{1/2} = Ω_{J_k} Λ_{J_k}^{1/2}. Hence

∇L^1|_{M_k}(W*) = 2(W* − Σ_0^{1/2} Ω_{J_k} V^⊤) = 2(Ω_{J_k} Λ_{J_k}^{1/2} V^⊤ − Ω_{J_k} Λ_{J_k}^{1/2} V^⊤) = 0,

and the sufficient condition is verified.

Now, the loss can be evaluated at the critical points in order to identify its minimizers.

Corollary D.6 (Value of L^1 at the critical points). The value of the loss L^1 at a critical point W* = Ω_{J_k} Λ_{J_k}^{1/2} V^⊤ is L^1(W*) = tr Λ − tr Λ_{J_k} = ∑_{i∉J_k} λ_i.

Proof. For k ⩾ 0, let W* be a critical point of L^1|_{M_k}. By Theorem 4.2, with Σ_0 = ΩΛΩ^⊤ a spectral decomposition of Σ_0, there exist an index set J_k and a semi-orthogonal matrix V ∈ R^{m×k} such that W* = Ω_{J_k} Λ_{J_k}^{1/2} V^⊤. One can then compute the value of the loss at W*:

L^1(W*) = tr(W* W*^⊤) + tr Σ_0 − 2 tr((Σ_0^{1/2} W*)(Σ_0^{1/2} W*)^⊤)^{1/2} = tr(Ω_{J_k} Λ_{J_k} Ω_{J_k}^⊤) + tr Λ − 2 tr(Ω_{J_k} Λ_{J_k}² Ω_{J_k}^⊤)^{1/2} = tr Λ_{J_k} + tr Λ − 2 tr Λ_{J_k} = tr Λ − tr Λ_{J_k}.

We now have all we need to prove the second part of Theorem 4.2.

Proof of Theorem 4.2, second part. The first part of the statement is readily implied by Corollary D.6, since the eigenvalues are in decreasing order.
The second part is implied by the fact that the minimum of L^1|_{M_k} is indeed achieved for any k ⩽ n (by selecting the k largest eigenvalues of Σ_0), and that the optimal value L*_k of the loss is smaller when more eigenvalues are considered, i.e., min_{M_k} L^1 ⩽ min_{M_{k′}} L^1 for k′ < k.

Moreover, it can be shown that only one point per set M_k is a minimizer of the loss L^1|_{M_k}; all other critical points are (strict) saddle points. We recall the definition of a strict saddle point: a point where there exists a descent direction.

Definition D.7 (Strict saddle point). A critical point x of a function f is said to be a strict saddle point if the Hessian of f at x has a strictly negative eigenvalue. If every critical point of f is either a strict saddle point or a global minimizer, then we say that f satisfies the strict saddle point property.

If the gradient flow can be expressed on a manifold, with a Riemannian gradient corresponding to a given metric, there is an equivalent characterization of such saddle points, which will be handy to use; see (Bah et al., 2021, §6.1) for the details.

Proposition D.8. The loss L^1|_{M_k} satisfies the strict saddle point property.

Proof. Let Σ_0 = UΛU^⊤ be the spectral decomposition of Σ_0 with decreasing eigenvalues. For k ∈ N, according to Theorem 4.2, W* is a critical point of L^1|_{M_k} if and only if there exists J_k ⊂ {1, . . . , n} such that W* = U_{J_k} Λ_{J_k}^{1/2} V^⊤, with any V ∈ R^{m×k} such that V^⊤V = I_k. If J_k = {1, . . . , k}, then W* is a global minimum of L^1|_{M_k}, as shown in Corollary D.6, and the proposition holds. Assume J_k ̸= {1, . . . , k}. Then there exists j_0 ∈ J_k such that λ_{j_0} < λ_k, and there exists j_1 ∉ J_k with j_1 ∈ {1, 2, . . . , k} such that λ_{j_1} > λ_{j_0}. We will show that W* is a strict saddle point of L^1|_{M_k}. The critical point W* can equivalently be expressed as W* = Σ_0^{−1/2} ∑_{i∈J_k} λ_i u_i v_i^⊤, where u_i, v_i are the corresponding orthonormal columns of U and V, and the λ_i are eigenvalues in Λ.
The key is the following perturbation: for t ∈ (−1, 1), define u_{j_0}(t) = t u_{j_1} + √(1 − t²) u_{j_0} and the curve γ: (−1, 1) → M_k,

γ(t) = Σ_0^{−1/2} (λ_{j_0} u_{j_0}(t) v_{j_0}^⊤ + ∑_{i∈J_k∖{j_0}} λ_i u_i v_i^⊤).

Note that γ(0) = W*. Recall L^1(W) = tr(W W^⊤ + Σ_0 − 2(Σ_0^{1/2} W W^⊤ Σ_0^{1/2})^{1/2}). It is enough to show that (Bah et al., 2021, §6.1)

d²/dt² L^1(γ(t))|_{t=0} < 0.

We check this term by term. First,

tr(γ(t)γ(t)^⊤) = tr(Σ_0^{−1} (λ_{j_0}² u_{j_0}(t) u_{j_0}(t)^⊤ + ∑_{i∈J_k∖{j_0}} λ_i² u_i u_i^⊤)) = tr((∑_{1≤i≤n} λ_i^{−1} u_i u_i^⊤)(λ_{j_0}² u_{j_0}(t) u_{j_0}(t)^⊤ + ∑_{i∈J_k∖{j_0}} λ_i² u_i u_i^⊤)) = λ_{j_0}² λ_{j_1}^{−1} t² + λ_{j_0} (1 − t²) + ∑_{i∈J_k∖{j_0}} λ_i.

Second, since u_{j_0}(t) is a unit vector orthogonal to the u_i, i ∈ J_k∖{j_0}, the matrix inside the square root below has eigenvalues λ_{j_0}² and (λ_i²)_{i∈J_k∖{j_0}}, so

tr((Σ_0^{1/2} γ(t)γ(t)^⊤ Σ_0^{1/2})^{1/2}) = tr((λ_{j_0}² u_{j_0}(t) u_{j_0}(t)^⊤ + ∑_{i∈J_k∖{j_0}} λ_i² u_i u_i^⊤)^{1/2}) = (t² + (1 − t²))^{1/2} |λ_{j_0}| + ∑_{i∈J_k∖{j_0}} |λ_i| = λ_{j_0} + ∑_{i∈J_k∖{j_0}} λ_i.

Thus, since λ_{j_1} > λ_{j_0} > 0,

d²/dt² L^1(γ(t))|_{t=0} = 2(λ_{j_0}² λ_{j_1}^{−1} − λ_{j_0}) < 0.

This completes the proof.

The loss L^1_τ satisfies the strict saddle point property in a similar fashion.

Lemma D.9. The loss L^1_τ|_{M_k} satisfies the strict saddle point property.

Proof of Lemma D.9. The proof of Proposition D.8 can be adapted, using the expression of the critical points: if Σ_0 = ΩΛΩ^⊤ and V ∈ R^{m×k} is any semi-orthogonal matrix, then

W* = (Σ_0 − τ I_n)^{−1/2} ∑_{i∈J_k} (λ_i − τ) ω_i v_i^⊤.

D.2 CRITICAL POINTS OF THE PERTURBATIVE LOSS

In this section, we provide the derivations for Section 4.2. The structure parallels that of Theorem 4.2: first the gradient of L^1_τ is computed, then the critical points are characterized and ranked. Lemma D.10 (Differential of L̄). The differential of L̄ on S_{++}(n) is, for all Σ ∈ S_{++}(n) and X ∈ S(n), dL̄(Σ)[X] = tr(X − Σ_0^{1/2}[Σ_0^{1/2}ΣΣ_0^{1/2}]^{-1/2}Σ_0^{1/2}X). Corollary D.11 (Gradient of L̄). The gradient of L̄ on S_{++}(n) is, for all Σ ∈ S_{++}(n), ∇L̄(Σ) = I − Σ_0^{1/2}[Σ_0^{1/2}ΣΣ_0^{1/2}]^{-1/2}Σ_0^{1/2}. Lemma D.12 (Gradient of L^1_τ). The loss L^1_τ has the following gradient: for all W ∈ ℝ^{n×m}, ∇L^1_τ(W) = 2(W − Σ_0^{1/2}[Σ_0^{1/2}(WW^⊤ + τI_n)Σ_0^{1/2}]^{-1/2}Σ_0^{1/2}W). Proof. This result follows from the chain rule for the loss L^1_τ(W) = L̄∘φ_τ∘π(W). With Σ = π(W) = WW^⊤ and Σ_τ = φ_τ(Σ) = Σ + τI_n, and since dπ(W)[Z] = WZ^⊤ + ZW^⊤ and dφ_τ(Σ) = id, one has dL^1_τ(W)[Z] = d(L̄∘φ_τ∘π)(W)[Z] = dL̄(Σ_τ)[WZ^⊤ + ZW^⊤]. Hence ⟨∇L^1_τ(W), Z⟩ = ⟨∇L̄(Σ_τ), WZ^⊤ + ZW^⊤⟩, which is equivalent to ∇L^1_τ(W) = (∇L̄(Σ_τ) + ∇L̄(Σ_τ)^⊤)W = 2(W − Σ_0^{1/2}[Σ_0^{1/2}Σ_τΣ_0^{1/2}]^{-1/2}Σ_0^{1/2}W). Proof of Theorem 4.4. The eigenvectors of WW^⊤ + τI_n are the same as those of WW^⊤, and the eigenvalues are shifted by τ. Therefore, the expression of the critical points of the original loss can be adapted, so that the modified critical points have the same left singular vectors and shifted singular values. This leads to W* = Ω_{J_k}(Λ_{J_k} − τI_k)^{1/2}V^⊤, with V ∈ ℝ^{m×k} such that V^⊤V = I_k. One checks that ∇L^1_τ(W*) = 0. The value at such a critical point W* = Ω_{J_k}(Λ_{J_k} − τI_k)^{1/2}V^⊤ is L^1_τ(W*) = Σ_{j∉J_k}(λ_j + τ − 2√(τλ_j)), which is uniquely minimized for J_k = [k] when the eigenvalues of Σ_0 are distinct and in descending order.
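The closed-form gradient of Lemma D.12 can be validated against a finite-difference directional derivative. A minimal sketch, assuming numpy; the target Σ_0, the regularization τ, and the dimensions are illustrative choices, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, tau = 3, 2, 0.1
Sigma0 = np.diag([3.0, 2.0, 1.0])
S0h = np.sqrt(Sigma0)                        # Sigma0^{1/2} (diagonal case)

def sqrtm(A):
    """Square root of a symmetric positive semidefinite matrix."""
    ev, V = np.linalg.eigh(A)
    return V @ np.diag(np.sqrt(np.clip(ev, 0.0, None))) @ V.T

def inv_sqrtm(A):
    """Inverse square root of a symmetric positive definite matrix."""
    ev, V = np.linalg.eigh(A)
    return V @ np.diag(ev**-0.5) @ V.T

def loss(W):
    """L^1_tau(W) = BW^2(W W^T + tau I, Sigma0)."""
    St = W @ W.T + tau * np.eye(n)
    return float(np.trace(St) + np.trace(Sigma0) - 2.0 * np.trace(sqrtm(S0h @ St @ S0h)))

def grad(W):
    """Closed-form gradient from Lemma D.12."""
    St = W @ W.T + tau * np.eye(n)
    return 2.0 * (W - S0h @ inv_sqrtm(S0h @ St @ S0h) @ S0h @ W)

W = rng.standard_normal((n, m))
Z = rng.standard_normal((n, m))
eps = 1e-6
fd = (loss(W + eps * Z) - loss(W - eps * Z)) / (2 * eps)   # directional derivative
ip = float(np.sum(grad(W) * Z))                            # <grad, Z>
print(fd, ip)
```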

D.3 PROOF OF PROPOSITION 4.5

We state here the proof of Proposition 4.5. We will transfer the results obtained on the space of linear maps M_{⩽k} to the space of covariance matrices S_+(k, n). Borrowing the terminology of Levin et al. (2022), we introduce the following notations and definitions. Let M be any smooth manifold, E a linear space, and φ : M → E a smooth (over)parametrization (or lift) of the search space X = φ(M) ⊆ E. The following problems are considered: min_{x∈X} f(x) (P) and min_{y∈M} f∘φ(y) (Q), where we assume that f : E → ℝ is smooth, and hence so is g := f∘φ. The following property is relevant for us. Definition D.13 (Levin et al. 2022, Definition 2.7). The lift φ : M → X satisfies the "1 ⇒ 1" property at y if, for every differentiable f : X → ℝ, whenever y is a critical point for (Q), the point x = φ(y) is a critical point for (P). Recall that S_+(k, n) = {Σ ∈ S(n) : Σ ≽ 0, rank(Σ) = k}, and let S_+(⩽k, n) = {Σ ∈ S(n) : Σ ≽ 0, rank(Σ) ⩽ k}. We will make use of the following result from Levin et al. (2022). Proposition D.14 (Levin et al. 2022, Proposition 3.4). Let k ⩽ n, and let φ : ℝ^{n×k} → S_+(⩽k, n) be the parametrization φ(R) = RR^⊤. Then φ satisfies the "1 ⇒ 1" property at R ∈ ℝ^{n×k} if and only if rank R = k. Said differently, the points at which a critical point of ℝ^{n×k} maps through φ to a critical point on S_+(k, n) are exactly the full-rank points ℝ^{n×k}_*. The image of the parametrization µ is ℝ^{n×m}_{⩽k}. Therefore, we need to adapt Proposition D.14 in order to work on ℝ^{n×m}_{⩽k} (which is not smooth) instead of ℝ^{n×k}. This is done in the next proposition. Proposition D.15. Let k ⩽ min(n, m), and let π : ℝ^{n×m}_{⩽k} → S_+(⩽k, n) be the parametrization π(W) = WW^⊤. Then π satisfies the "1 ⇒ 1" property at W ∈ ℝ^{n×m}_{⩽k} if rank W = k. Proof. Let φ : ℝ^{n×k} → S_+(⩽k, n), R ↦ RR^⊤ and π : ℝ^{n×m}_{⩽k} → S_+(⩽k, n), W ↦ WW^⊤ be the covariance parametrizations.
Since the manifold ℝ^{n×m}_{⩽k} is not smooth, we focus on the smooth manifold ℝ^{n×m}_k. Let φ_* : ℝ^{n×k}_* → S_+(k, n) and π_* : ℝ^{n×m}_k → S_+(k, n) be the parametrizations φ, π restricted to matrices of rank exactly k. We know that φ_* satisfies the "1 ⇒ 1" property, and want to show that π_* satisfies it as well. The idea of the proof is to pass through quotient spaces on which the maps are equivalent. Let O_k be the set of k×k orthogonal matrices, with the dimension omitted when it can be inferred from context. Consider the equivalence relation on ℝ^{n×m}_k (or ℝ^{n×k}_*) given by X_1 ∼ X_2 ⇔ X_1X_1^⊤ = X_2X_2^⊤. From (Massart & Absil, 2020, Proposition 2.1), we know that X_1 ∼ X_2 ⇔ ∃Q ∈ O, X_1 = X_2Q. Denote the equivalence class [X] = {XQ : Q ∈ O} = XO. Let p : ℝ^{n×m}_k → ℝ^{n×m}_k/O_m, W ↦ [W] be the quotient map on ℝ^{n×m}_k, and let Π : ℝ^{n×m}_k/O_m → S_+(k, n), WO_m ↦ WW^⊤ be the induced map on the quotient space, so that π_* = Π∘p. Likewise, let q : ℝ^{n×k}_* → ℝ^{n×k}_*/O_k, R ↦ [R] be the quotient map on ℝ^{n×k}_*, and let Φ : ℝ^{n×k}_*/O_k → S_+(k, n), RO_k ↦ RR^⊤ be the induced map on the quotient space, so that φ_* = Φ∘q. The map Φ is a diffeomorphism (Massart & Absil, 2020, Proposition A.7), and therefore satisfies the "1 ⇒ 1" property. For W ∈ ℝ^{n×m}_k, we can find R ∈ ℝ^{n×k}_* such that WW^⊤ = RR^⊤; it is unique up to an orthogonal matrix. Therefore, let ι : ℝ^{n×m}_k/O_m → ℝ^{n×k}_*/O_k, [W] ↦ [R] be the identification map between the quotient spaces. With the next two lemmas, we will be able to finish the proof of Proposition D.15. Lemma D.16. The map ι : ℝ^{n×m}_k/O_m → ℝ^{n×k}_*/O_k, [W] ↦ [R] is a diffeomorphism. Lemma D.17. The map ι∘p is a submersion from ℝ^{n×m}_k onto ℝ^{n×k}_*/O_k. To conclude the proof of Proposition D.15, note that the map π_* can be written π_* = Π∘p = Φ∘ι∘p.
Since Φ, being a diffeomorphism, satisfies the "1 ⇒ 1" property, and since ι∘p is a submersion (Lemma D.17), by (Levin et al., 2022, Proposition 2.42(b)) the map Φ∘ι∘p = π_* satisfies the "1 ⇒ 1" property. It remains to show Lemmas D.16 and D.17. Proof of Lemma D.16. From (Massart & Absil, 2020, Proposition A.7), the mapping Φ : ℝ^{n×k}_*/O_k → S_+(k, n), [R] ↦ RR^⊤ is a diffeomorphism. Likewise, the mapping Π : ℝ^{n×m}_k/O_m → S_+(k, n) is also a diffeomorphism, since both π_* and p are submersions. Then, since Φ([R]) = Π([W]), we have [R] = Φ^{-1}∘Π([W]) =: ι([W]), and ι is a diffeomorphism. Proof of Lemma D.17. The map p is a submersion (Massart & Absil, 2020, Proposition A.5) and the map ι is a diffeomorphism (Lemma D.16), hence a submersion. Therefore, the composition ι∘p is a submersion. We are now ready to prove Proposition 4.5. Proof of Proposition 4.5. From (Trager et al., 2020, Proposition 5), we know that a critical point -→W in the parameter space with rank(-→W) = k yields a critical point for L^1_τ|_{M_{⩽k}}. Now, from Proposition D.15, a critical point W* for L^1_τ|_{M_{⩽k}} with rank W* = k is such that π(W*) is a critical point for L̄|_{S_+(k,n)}, and the first part of the proposition is proved. For the second part, assume that k = d = min_i{d_i}. Then, according to Trager et al. (2020), if Σ_0 = ΩΛΩ^⊤ is a spectral decomposition of Σ_0, we have W* = Ω_{[d]}(Λ_{[d]} − τI_d)^{1/2}V_{[d]}^⊤, so that Σ*_τ = W*W*^⊤ + τI_n = ΩΛ^τ_{[d]}Ω^⊤ is also a minimizer of L̄_τ|_{S_+(d,n)}.

E PROOFS OF SECTION 5

E.1 BOUNDS ON THE HESSIAN

In this section, we provide bounds on the Hessian of the perturbative loss L^1_τ. We first express the loss as a function of the covariance matrix, in which case the Hessian is known (Kroshnin et al., 2021). A chain rule for the differential then allows us to express the Hessian when the loss is viewed as a function of the end-to-end matrix W. Lemma E.1 (Second-order differential of L̄_τ, Kroshnin et al. 2021, Lemma A.2). Let W ∈ ℝ^{n×m} and let τ > 0. Define Σ_τ = WW^⊤ + τI_n to be the regularized covariance matrix. Since Σ_τ ≻ 0, the loss L̄_τ is twice continuously differentiable for any W. Let ΓQΓ^⊤ = Σ_0^{1/2}Σ_τΣ_0^{1/2} be a spectral decomposition of Σ_0^{1/2}Σ_τΣ_0^{1/2}, with Q = diag(q_1, …, q_n). For Y ∈ ℝ^{n×n}, define Δ(Y) ∈ ℝ^{n×n} to be the matrix with entries Δ(Y)_{ij} = (Γ^⊤YΓ)_{ij}/(√q_i + √q_j). Let H̄ be the linear operator defined as H̄(Y) = Σ_0^{1/2}ΓQ^{-1/2}Δ(Y)Q^{-1/2}Γ^⊤Σ_0^{1/2}. Then the second-order differential of L̄_τ is given by d²L̄_τ(Σ_τ)[X, Y] = ⟨X, H̄(Y)⟩ for all (X, Y) ∈ (ℝ^{n×n})². Proof. Applying formula (49) to L^1_τ = L̄_τ∘π gives, with Σ = π(W) and d²π(W)[U, V] = UV^⊤ + VU^⊤,

d²L^1_τ(W)[U, V] = d²L̄_τ(Σ_τ)[dπ(W)[U], dπ(W)[V]] + dL̄_τ(Σ_τ)[d²π(W)[U, V]]
= ⟨UW^⊤ + WU^⊤, H̄(VW^⊤ + WV^⊤)⟩ + tr(UV^⊤ + VU^⊤) − tr(Σ_0^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})^{-1/2}Σ_0^{1/2}(UV^⊤ + VU^⊤))
= 2⟨U, H̄(VW^⊤ + WV^⊤)W + V − Σ_0^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})^{-1/2}Σ_0^{1/2}V⟩ =: ⟨U, H(V)⟩,

where we used the symmetry of Σ_0^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})^{-1/2}Σ_0^{1/2} to simplify the expression. The maximal eigenvalue of H will then be computed as λ_max(H) = sup_{U : ∥U∥_F = 1} ⟨U, H(U)⟩ in Lemma E.6.

E.2 LIPSCHITZ-SMOOTHNESS OF L^1_τ

One can use the bounds of Kroshnin et al. (2021, Lemma A.3) to bound the Hessian of the loss. Lemma E.4 (Bounds on the second-order differential, Kroshnin et al. 2021, Lemma A.3). Let H̄(X) be defined as in (47).
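The second-order differential of Lemma E.1 can be checked by finite differences on the covariance space. The sketch below is not from the paper and assumes numpy; it evaluates the quadratic form ⟨X, H̄(X)⟩ in the eigenbasis of Σ_0^{1/2}Σ_τΣ_0^{1/2}, applying Δ to Σ_0^{1/2}XΣ_0^{1/2} as the chain rule suggests, and compares it with a central second difference of L̄ along a random symmetric direction.

```python
import numpy as np

rng = np.random.default_rng(0)
n, tau = 3, 0.5
A = rng.standard_normal((n, n)); Sigma0 = A @ A.T + n * np.eye(n)
ev0, V0 = np.linalg.eigh(Sigma0)
S0h = V0 @ np.diag(np.sqrt(ev0)) @ V0.T          # Sigma0^{1/2}

def trace_sqrt(M):
    return np.sqrt(np.clip(np.linalg.eigvalsh(M), 0.0, None)).sum()

def bw_loss_cov(Sigma):
    """bar-L(Sigma) = tr(Sigma) + tr(Sigma0) - 2 tr[(Sigma0^{1/2} Sigma Sigma0^{1/2})^{1/2}]."""
    return float(np.trace(Sigma) + np.trace(Sigma0) - 2.0 * trace_sqrt(S0h @ Sigma @ S0h))

W = rng.standard_normal((n, n))
Sigma_tau = W @ W.T + tau * np.eye(n)
q, Gamma = np.linalg.eigh(S0h @ Sigma_tau @ S0h)   # spectral decomposition Gamma Q Gamma^T

def quad_form(X):
    """<X, Hbar(X)> evaluated entrywise in the eigenbasis; assumes the Delta
    of Lemma E.1 acts on Sigma0^{1/2} X Sigma0^{1/2} (chain-rule convention)."""
    Axx = Gamma.T @ (S0h @ X @ S0h) @ Gamma
    sq = np.sqrt(q)
    denom = np.outer(sq, sq) * (sq[:, None] + sq[None, :])
    return float((Axx * Axx / denom).sum())

# Finite-difference check of the second derivative along a random symmetric direction.
B = rng.standard_normal((n, n)); X = (B + B.T) / 2
eps = 1e-4
fd = (bw_loss_cov(Sigma_tau + eps * X) - 2 * bw_loss_cov(Sigma_tau)
      + bw_loss_cov(Sigma_tau - eps * X)) / eps**2
print(fd, quad_form(X))
```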
The second-order differential of L̄_τ satisfies the following bounds:

⟨X, H̄(X)⟩ ⩽ (λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/2) ∥Σ_τ^{-1/2}XΣ_τ^{-1/2}∥²_F, (52a)
⟨X, H̄(X)⟩ ⩾ (λ_min^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/2) ∥Σ_τ^{-1/2}XΣ_τ^{-1/2}∥²_F. (52b)

These in turn bound the extremal eigenvalues of the Hessian, defined as λ_max(H̄) = sup_{X≠0} ⟨X, H̄(X)⟩/∥X∥²_F and λ_min(H̄) = inf_{X≠0} ⟨X, H̄(X)⟩/∥X∥²_F. Lemma E.5 (Bounds on the Hessian H̄). Let H̄ be defined as in (47). Then the extremal eigenvalues of H̄ are bounded as

λ_max(H̄) ⩽ √(C_0 λ_max(Σ_0))/(2τ²), λ_min(H̄) ⩾ √(τ λ_min(Σ_0))/(2C_0²),

where C_0 = 2(L̄(Σ_τ(0)) + tr(Σ_0)) is initialization-dependent. In particular, the loss L̄_τ is strongly convex, with parameter K = √(τ λ_min(Σ_0))/(2C_0²). Proof. We first provide the proof for the maximal eigenvalue, defined as λ_max(H̄) = sup_{X : ∥X∥_F = 1} ⟨X, H̄(X)⟩. From the upper bound (52a), one has

sup_{X : ∥X∥_F = 1} ⟨X, H̄(X)⟩ ⩽ (λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/2) sup_{X : ∥X∥_F = 1} ∥Σ_τ^{-1/2}XΣ_τ^{-1/2}∥²_F = (λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/2) λ²_max(Σ_τ^{-1}) ⩽ λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/(2τ²).

The last inequality comes from the definition of Σ_τ: if λ_1 ⩾ λ_2 ⩾ ⋯ ⩾ λ_k > 0 are the positive eigenvalues of WW^⊤, then Σ_τ^{-1} = (WW^⊤ + τI_n)^{-1} has eigenvalues τ^{-1} (with multiplicity n−k) and (λ_k + τ)^{-1} ⩾ ⋯ ⩾ (λ_1 + τ)^{-1}, all of which are at most τ^{-1}. For positive definite matrices A, B ∈ S_{++}(n) and any k ∈ [n], we know that λ_k(A)λ_min(B) ⩽ λ_k(AB) = λ_k(A^{1/2}BA^{1/2}) ⩽ λ_k(A)λ_max(B). Therefore, we have the bound λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2}) ⩽ λ_max^{1/2}(Σ_0)λ_max^{1/2}(Σ_τ). Moreover, λ_max(Σ_τ) ⩽ tr Σ_τ, and from Lemma C.7 we know that tr Σ_τ ⩽ 2(L̄(Σ_τ) + tr(Σ_0)), which along the flow is at most 2(L̄(Σ_τ(0)) + tr(Σ_0)) =: C_0. Therefore, we obtain λ_max(H̄) ⩽ √(C_0 λ_max(Σ_0))/(2τ²).
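Lemma E.5 asserts strong convexity of L̄_τ on the covariance space. As a lightweight sanity check (not from the paper, assuming numpy), one can at least verify midpoint convexity of L̄ on random positive definite pairs: tr(Σ) is linear in Σ while Σ ↦ tr[(Σ_0^{1/2}ΣΣ_0^{1/2})^{1/2}] is concave, so L̄ is convex.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 4
A = rng.standard_normal((n, n)); Sigma0 = A @ A.T + np.eye(n)
ev0, V0 = np.linalg.eigh(Sigma0)
S0h = V0 @ np.diag(np.sqrt(ev0)) @ V0.T          # Sigma0^{1/2}

def bw_loss_cov(Sigma):
    """bar-L(Sigma): squared Bures-Wasserstein distance as a function of Sigma."""
    ev = np.clip(np.linalg.eigvalsh(S0h @ Sigma @ S0h), 0.0, None)
    return float(np.trace(Sigma) + np.trace(Sigma0) - 2.0 * np.sqrt(ev).sum())

# Midpoint convexity on random positive definite pairs.
ok = True
for _ in range(200):
    B = rng.standard_normal((n, n)); Sa = B @ B.T + 0.1 * np.eye(n)
    C = rng.standard_normal((n, n)); Sb = C @ C.T + 0.1 * np.eye(n)
    mid = bw_loss_cov((Sa + Sb) / 2)
    avg = (bw_loss_cov(Sa) + bw_loss_cov(Sb)) / 2
    ok = ok and (mid <= avg + 1e-10)
print(ok)
```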
The proof for the minimal eigenvalue is similar and follows from the lower bound (52b); in this case, the term λ_min^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2}) can be lower bounded by √(τ λ_min(Σ_0)). We now turn to the Hessian of L^1_τ. Lemma E.6 (Spectral bound of H). Let H be defined as in (51). The maximal eigenvalue of the Hessian of L^1_τ satisfies

λ_max(H) ⩽ λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2}) (2C²/τ²) + 2(1 − λ_min(Σ_0^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})^{-1/2}Σ_0^{1/2})).

Proof. From (52a), one has for any X ∈ S(n), ⟨X, H̄(X)⟩ ⩽ (λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/2) ∥Σ_τ^{-1/2}XΣ_τ^{-1/2}∥²_F. Let U ∈ ℝ^{n×m}. With X(U) = UW^⊤ + WU^⊤, the bound becomes 2⟨U, H̄(X(U))W⟩ ⩽ (λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/2) ∥Σ_τ^{-1/2}X(U)Σ_τ^{-1/2}∥²_F. Therefore,

⟨U, H(U)⟩ = 2⟨U, H̄(X(U))W + (I − Σ_0^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})^{-1/2}Σ_0^{1/2})U⟩ ⩽ (λ_max^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})/2) ∥Σ_τ^{-1/2}X(U)Σ_τ^{-1/2}∥²_F + 2⟨U, (I − Σ_0^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})^{-1/2}Σ_0^{1/2})U⟩.

We proceed by bounding each of the summands. First consider the term ∥Σ_τ^{-1/2}X(U)Σ_τ^{-1/2}∥²_F = ∥Σ_τ^{-1}X(U)∥²_F. If U is such that ∥U∥_F = 1, then ∥X(U)∥²_F = ∥UW^⊤ + WU^⊤∥²_F ⩽ 4∥W∥²_F. We know that ∥W∥_F ⩽ C for some constant C, cf. (35). Therefore, ∥U∥_F = 1 implies ∥X(U)∥_F ⩽ 2C, and

sup_{U : ∥U∥_F = 1} ∥Σ_τ^{-1}X(U)∥²_F ⩽ sup_{X : ∥X∥_F ⩽ 2C} ∥Σ_τ^{-1}X∥²_F = 4C² sup_{X : ∥X∥_F = 1} ∥Σ_τ^{-1}X∥²_F = 4C² λ²_max(Σ_τ^{-1}) = 4C²/τ².

The second summand is at most 2(1 − λ_min(Σ_0^{1/2}(Σ_0^{1/2}Σ_τΣ_0^{1/2})^{-1/2}Σ_0^{1/2})) for ∥U∥_F = 1, which yields the claimed bound.

Now we are ready to prove the convergence of finite step size gradient descent for the BW loss. In the remainder of the proof we assume perfect balancedness of the initial values W_i(0), 1 ⩽ i ⩽ N. The approximately balanced case can also be carried out, but requires more involved auxiliary estimates; we leave the approximate balancedness assumption as a future direction.

Recall from the proof of Lemma 5.3 that, since L^1(W(k)) ⩽ L^1(W(0)) for all k ⩾ 0, with Ū(k) the optimal unitary matrix at step k,

σ_min((W(k)W(k)^⊤)^{1/2}) ⩾ σ_min(Σ_0^{1/2}Ū(k)) − ∥(W(k)W(k)^⊤)^{1/2} − Σ_0^{1/2}Ū(k)∥_F = σ_min(Σ_0^{1/2}) − ∥(W(k)W(k)^⊤)^{1/2} − Σ_0^{1/2}Ū(k)∥_F,

where the cancellation in the last equality holds because multiplication with an arbitrary unitary matrix does not change singular values.

Proof of Theorem 5.7. Let us start from the gradient descent update of the loss with respect to each layer:

W_j(k+1) = W_j(k) − η∇_{W_j}L^N(W_1(k), …, W_N(k)) = W_j(k) − η W_{j+1:N}(k)^⊤ ∇_W L^1(W(k)) W_{1:j−1}(k)^⊤, 1 ⩽ j ⩽ N,

with the boundary conditions W_{1:0}(k) = I_{d_0} and W_{N+1:N}(k) = I_{d_N} for all k ⩾ 0. With the notations -→W = (W_1, W_2, …, W_N) and ∇L^N(-→W) = (∇_{W_1}L^N(-→W), …, ∇_{W_N}L^N(-→W)), we have the uniform upper bound, for all k ⩾ 0,

∥A_{ξ,i}(k)∥_F ⩽ (1−ξ)∥W_i(k)∥_F + ξ∥W_i(k+1)∥_F ⩽ M^{1/N}.

Using A_{ξ,i}(k) = W_i(k) − ξη W_{i+1:N}(k)^⊤ ∇_W L^1(W(k)) W_{1:i−1}(k)^⊤, we obtain a lower bound in terms of the minimum singular value:

σ_min(A_{ξ,i}(k)A_{ξ,i}(k)^⊤) ⩾ σ_min(W_i(k)W_i(k)^⊤) − 2ξη ∥W_i(k)∥_F ∥W_{i+1:N}(k)∥_F ∥W_{1:i−1}(k)∥_F ∥∇_W L^1(W(k))∥_F ⩾ c² − 4ηM √(L^1(W(k))) ⩾ c² − 4ηM √(L^1(W(0))),

where we utilize (25), (66) and (65), as well as the non-increase of L^1(W) throughout training. We denote X_j = −η W_{j+1:N}(k)^⊤ ∇_W L^1(W(k)) W_{1:j−1}(k)^⊤. We may choose η ⩽ c²/(8M √(L^1(W(0)))), so that for all k ⩾ 0,

σ_min(A_{ξ,i}(k)A_{ξ,i}(k)^⊤) ⩾ c²/2, σ_min(A_ξ(k)A_ξ(k)^⊤) ⩾ c^{2N}/2^N.
Then, combining all estimates above, we have

⟨-→W(k+1) − -→W(k), ∇²L^N(-→A_ξ(k))[-→W(k+1) − -→W(k)]⟩
⩽ Σ_{j=1}^N ⟨X_j, (∂²L^N(-→A_ξ(k))/∂W_j²) X_j⟩ + Σ_{j=1}^N Σ_{i=1, i≠j}^N ⟨X_j, (∂²L^N(-→A_ξ(k))/∂W_i∂W_j) X_i⟩
⩽ Σ_{j=1}^N (λ_max^{1/2}(Σ_0^{1/2}A_ξ(k)A_ξ(k)^⊤Σ_0^{1/2})/2) ∥X_j (A_ξ(k)A_ξ(k)^⊤)^{-1}∥²_F M^{2(N−1)/N}
+ Σ_{j=1}^N Σ_{i=1, i≠j}^N M^{(N−2)/N} ∥X_i∥_F ∥X_j∥_F ∥∇_W L^1(A_ξ(k))∥_F
+ Σ_{j=1}^N Σ_{i=1, i≠j}^N (λ_max^{1/2}(Σ_0^{1/2}A_ξ(k)A_ξ(k)^⊤Σ_0^{1/2})/2) ∥X_j (A_ξ(k)A_ξ(k)^⊤)^{-1}∥_F ∥X_i (A_ξ(k)A_ξ(k)^⊤)^{-1}∥_F M^{2(N−1)/N},

by using (67), (E.4) and applying the Cauchy-Schwarz inequality to the last term. Notice that ∥X_i∥_F ⩽ 2ηM^{(N−1)/N}∥∇_W L^1(W(k))∥_F. Now, combining all the bounds obtained previously, in addition to (69), we get

⟨-→W(k+1) − -→W(k), ∇²L^N(-→A_ξ(k))[-→W(k+1) − -→W(k)]⟩ ⩽ 2η²N² (∥A_ξ(k)∥_F λ_max^{1/2}(Σ_0) M^{4(N−1)/N}/σ_min(A_ξ(k)A_ξ(k)^⊤)) ∥∇_W L^1(W(k))∥²_F + 4η²N(N−1) M^{(3N−4)/N} ∥∇_W L^1(A_ξ(k))∥_F ∥∇_W L^1(W(k))∥²_F.

Moreover, we can use (7) and (25) again to get

∥∇_W L^1(A_ξ(k))∥_F ⩽ 2√(L^1(A_ξ(k))) = 2∥(A_ξ(k)A_ξ(k)^⊤)^{1/2} − Σ_0^{1/2}U∥_F ⩽ 2∥(A_ξ(k)A_ξ(k)^⊤)^{1/2}∥_F + 2∥Σ_0^{1/2}∥_F ⩽ 2M^{1/N} + 2∥Σ_0^{1/2}∥_F.
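The qualitative conclusion of Theorem 5.7, monotone decrease of L^1 under a small enough constant step size from a balanced initialization, can be illustrated with a toy simulation. This is a sketch under illustrative assumptions (numpy, a two-layer network N = 2, a diagonal target), not the paper's construction.

```python
import numpy as np

n = 2
Sigma0 = np.diag([2.0, 1.0])
S0h = np.sqrt(Sigma0)                        # Sigma0^{1/2} (diagonal case)

def sym_inv_sqrt(M):
    ev, V = np.linalg.eigh(M)
    return V @ np.diag(ev**-0.5) @ V.T

def loss(W):
    """L^1(W) = BW^2(W W^T, Sigma0)."""
    ev = np.clip(np.linalg.eigvalsh(S0h @ W @ W.T @ S0h), 0.0, None)
    return float(np.trace(W @ W.T) + np.trace(Sigma0) - 2.0 * np.sqrt(ev).sum())

def grad_end_to_end(W):
    """Gradient of L^1 at a full-rank end-to-end matrix (Lemma D.12 with tau = 0)."""
    M = S0h @ W @ W.T @ S0h
    return 2.0 * (W - S0h @ sym_inv_sqrt(M) @ S0h @ W)

# Two layers with balanced initialization (W2^T W2 = W1 W1^T), small constant step.
W1, W2 = np.eye(n), np.eye(n)
eta = 0.01
losses = []
for _ in range(2000):
    W = W2 @ W1
    losses.append(loss(W))
    G = grad_end_to_end(W)
    # Simultaneous layer-wise updates W_j <- W_j - eta * grad_{W_j}
    W1, W2 = W1 - eta * (W2.T @ G), W2 - eta * (G @ W1.T)
print(losses[0], losses[-1])
```

The recorded losses decrease monotonically (up to floating-point noise) and approach the global minimum value 0, consistent with the theorem's conclusion for sufficiently small η.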



Σ_τ(0) − Σ*_τ, which is the distance to optimality from the initialization. Finally, we get the desired exponential rate L̄(Σ_τ(t)) − L̄(Σ*_τ) ⩽ e^{−8Nc…}.

E.4 PROOFS OF GRADIENT DESCENT CONVERGENCE

We start by proving Lemma 5.3, so that, with the uniform margin deficiency assumption on the initial weights, WW^⊤ does not degenerate along the gradient descent iterations.

Proof of Lemma 5.3. Let Ū(k) := argmin_{U ∈ U(n)} ∥(W(k)W(k)^⊤)^{1/2} − Σ_0^{1/2}U∥²_F.
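The proof of Lemma 5.3 rests on two elementary facts: the Weyl-type perturbation bound σ_min(A) ⩾ σ_min(B) − ∥A − B∥_F, and the unitary invariance of singular values. A quick randomized check of both (not from the paper, assuming numpy):

```python
import numpy as np

rng = np.random.default_rng(1)
smin = lambda M: np.linalg.svd(M, compute_uv=False)[-1]

ok_weyl, ok_unitary = True, True
for _ in range(200):
    A = rng.standard_normal((4, 4))
    B = rng.standard_normal((4, 4))
    # Weyl-type bound: sigma_min(A) >= sigma_min(B) - ||A - B||_F
    ok_weyl = ok_weyl and (smin(A) >= smin(B) - np.linalg.norm(A - B, "fro") - 1e-12)
    # Unitary invariance: sigma_min(B U) = sigma_min(B) for unitary U
    U, _ = np.linalg.qr(rng.standard_normal((4, 4)))
    ok_unitary = ok_unitary and (abs(smin(B @ U) - smin(B)) < 1e-10)
print(ok_weyl, ok_unitary)
```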


Proof. We begin by stating the first-order differential of the loss L̄ evaluated at the positive definite matrix Σ_τ; this is given in Lemma D.10. Let GL(n) = {A ∈ ℝ^{n×n} : det A ≠ 0}, and let f : GL(n) ∋ F ↦ F^{-1}; then f is differentiable with differential df(F)[X] = −F^{-1}XF^{-1} (Magnus & Neudecker, 2019, Theorem 8.3). Let g : S_{++}(n) ∋ A ↦ A^{1/2} be the matrix square root. The function g is differentiable on S_{++}(n), and its differential can be computed as follows (Kroshnin et al., 2021, Lemma A.1). Let A ∈ S_{++}(n), and let ΓQΓ^⊤ be its spectral decomposition, with Q = diag(q_i)_{i=1}^n. For X ∈ S(n), define Δ(X) ∈ ℝ^{n×n} to be the matrix with entries Δ(X)_{ij} = (Γ^⊤XΓ)_{ij}/(√q_i + √q_j); then dg(A)[X] = ΓΔ(X)Γ^⊤. Therefore, the chain rule on the differentials gives the stated expression. In order to express the Hessian of the loss as a function of the end-to-end matrix W, we need the chain rule for the second-order differential, which we first recall. Lemma E.2 (Chain rule for the second-order differential, Magnus & Neudecker 2019, Theorem 6.9). Let f : R → S and g : S → T be two differentiable functions on open sets, such that h = g∘f : R → T is well defined. Then, given two directions u, v, the second-order differential of h at c is d²h(c)[u, v] = d²g(f(c))[df(c)[u], df(c)[v]] + dg(f(c))[d²f(c)[u, v]]. With this computation rule, we are able to give the second-order differential of L^1_τ, where H̄ is defined as in (47). The second summation in (55) can be bounded similarly. Proof. This directly follows from the boundedness of the Hessian shown previously and the convexity of L^1_τ, using a Taylor approximation. Once the Lipschitz-smoothness of the loss has been proven, one can turn to showing that the rank is preserved under balanced initial conditions. Proposition E.8 (Bah et al. 2021, Proposition 4.4). Let L^1 : ℝ^{n×m} → ℝ be a Lipschitz-smooth function (i.e., a differentiable function with Lipschitz gradient). Suppose that W_1(t), …, W_N(t) are solutions of the gradient flow (GF) of L^N with balanced initial values W_j(0), and define the product W(t) = W_N(t)⋯W_1(t).
The proof follows if the gradient flow is locally Lipschitz continuous in P, Q, W, so that the curves P, Q, W are uniquely determined by an initial datum P(0), Q(0), W(0). This follows from Equations (GF) and (21). Now, with the assumption of Lipschitz continuity of the flow, a given solution is uniquely determined by the initial data P_0, Q_0, W_0, and the proof tools of Bah et al. (2021, Proposition 4.4) can be used here as well. Remark E.9. The loss L^1_τ satisfies the conditions of Proposition E.8; therefore, the flow of L^1_τ remains in the manifold M_k if W(t_0) ∈ M_k for some t_0.
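The balancedness invariant behind Proposition E.8 can be checked directly for one gradient descent step: the first-order terms in W_2^⊤W_2 − W_1W_1^⊤ cancel, leaving an exactly η²-sized drift, so the continuous-time gradient flow preserves balancedness exactly. A sketch (not from the paper) with a two-layer network and the BW loss gradient of Lemma D.12 with τ = 0; all sizes and seeds are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 3
Sigma0 = np.diag([3.0, 2.0, 1.0])
S0h = np.sqrt(Sigma0)

def grad_end_to_end(W):
    """Gradient of L^1 at a full-rank end-to-end matrix W (Lemma D.12, tau = 0)."""
    ev, V = np.linalg.eigh(S0h @ W @ W.T @ S0h)
    inv_sqrt = V @ np.diag(ev**-0.5) @ V.T
    return 2.0 * (W - S0h @ inv_sqrt @ S0h @ W)

# Balanced initialization: W2^T W2 = W1 W1^T (shared singular values).
U, s, Vt = np.linalg.svd(rng.standard_normal((n, n)))
P, _ = np.linalg.qr(rng.standard_normal((n, n)))
W1 = U @ np.diag(s) @ Vt
W2 = P @ np.diag(s) @ U.T

G = grad_end_to_end(W2 @ W1)

def balance_drift(eta):
    """Change of W2^T W2 - W1 W1^T after one gradient descent step of size eta."""
    W1n = W1 - eta * (W2.T @ G)
    W2n = W2 - eta * (G @ W1.T)
    before = W2.T @ W2 - W1 @ W1.T
    after = W2n.T @ W2n - W1n @ W1n.T
    return after - before

# First-order terms cancel: the drift equals eta^2 (W1 G^T G W1^T - W2^T G G^T W2).
C = W1 @ G.T @ G @ W1.T - W2.T @ G @ G.T @ W2
print(np.allclose(balance_drift(1e-2), 1e-4 * C))
```

As η → 0 the drift vanishes at order η², which is the discrete trace of the exact conservation of balancedness under the flow.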

E.3 PROOFS OF GRADIENT FLOW CONVERGENCE

Proof of Theorem 5.6. The idea of the proof is to transfer the strong convexity property from L̄_τ to the evolution of the parameters. Let us start from the inequality that holds due to strong convexity, where K is the constant from Lemma E.5, and rearrange its terms. On the covariance space, for the perturbative loss, the gradient flow can be written in terms of Σ_τ. Since ∇L^1(W) = 2∇L̄(Σ)W, and from the balancedness assumption, the flow of the product matrix can be expressed through the layer matrices. For one layer ℓ ∈ [N], we then have the corresponding factor L = S^{2(N−ℓ)/N}. We evaluate ⟨XR, LX⟩ for diagonal R, L, and obtain a lower bound due to the uniform margin deficiency assumption. From the strong convexity of L̄ (56), we get the corresponding bound. We write the Taylor expansion in the form (61); the first-order term in (61), under (60), can then be rewritten, for all 1 ⩽ i ⩽ N−1, in the symmetric structure above. Therefore, thanks to Lemma 5.3, the first-order term is controlled. Let us mention that Arora et al. (2018, Theorem 1 and Claim 1) provide rigorous derivations of the equalities above. The second-order term in (61) is more complicated to handle. Thanks to Corollary C.4, we have expressions for the relevant quantities; note that we have the boundedness (C.9), and the remaining estimates are straightforward. Thus, we conclude the estimate for the second-order term. Let us denote the constant Δ := 2^{N+1}c^{-2N}N²M^{(4N−3)/N}λ_max^{1/2}(Σ_0) + 8N(N−1)M^{(3N−4)/N}(M^{1/N} + ∥Σ_0^{1/2}∥_F); then we can write the iteration in a contracting form. For η sufficiently small, we have a contraction factor of the form 1 − 2ηNc…
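The exponential rate of Theorem 5.6 can be visualized with an explicit Euler discretization of the gradient flow for a single layer (N = 1) on the perturbative loss. This toy instance (numpy, illustrative Σ_0 and τ, not from the paper) tracks the optimality gap, which should shrink by a roughly constant factor per fixed number of steps.

```python
import numpy as np

n, tau, eta = 2, 0.1, 0.01
Sigma0 = np.diag([3.0, 2.0])
S0h = np.sqrt(Sigma0)

def sym_inv_sqrt(M):
    ev, V = np.linalg.eigh(M)
    return V @ np.diag(ev**-0.5) @ V.T

def loss_tau(W):
    """L^1_tau(W) = BW^2(W W^T + tau I, Sigma0); its minimum is 0 since Sigma0 > tau I."""
    St = W @ W.T + tau * np.eye(n)
    ev = np.clip(np.linalg.eigvalsh(S0h @ St @ S0h), 0.0, None)
    return float(np.trace(St) + np.trace(Sigma0) - 2.0 * np.sqrt(ev).sum())

def grad_tau(W):
    St = W @ W.T + tau * np.eye(n)
    return 2.0 * (W - S0h @ sym_inv_sqrt(S0h @ St @ S0h) @ S0h @ W)

# Explicit Euler discretization of the gradient flow, single layer (N = 1).
W = np.eye(n)
gaps = []
for _ in range(800):
    gaps.append(loss_tau(W))
    W = W - eta * grad_tau(W)

r1 = gaps[200] / gaps[100]   # decay factor over steps 100 -> 200
r2 = gaps[300] / gaps[200]   # decay factor over steps 200 -> 300
print(gaps[0], gaps[-1], r1, r2)
```

The two window ratios r1 and r2 are close to each other and well below one, reflecting the geometric (exponential-in-time) decay predicted by the strong convexity argument.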

