CRITICAL POINTS AND CONVERGENCE ANALYSIS OF GENERATIVE DEEP LINEAR NETWORKS TRAINED WITH BURES-WASSERSTEIN LOSS

Abstract

We consider a deep matrix factorization model of covariance matrices trained with the Bures-Wasserstein distance. While recent works have made important advances in the study of the optimization problem for overparametrized low-rank matrix approximation, much emphasis has been placed on discriminative settings and the square loss. In contrast, our model considers another interesting type of loss and connects with the generative setting. We characterize the critical points and minimizers of the Bures-Wasserstein distance over the space of rank-bounded matrices. For low-rank matrices the Hessian of this loss can blow up, which creates challenges for analyzing the convergence of optimization methods. We establish convergence results for gradient flow using a smooth perturbative version of the loss, and convergence results for finite step size gradient descent under certain assumptions on the initial weights.

1. INTRODUCTION

We investigate generative deep linear networks and their optimization using the Bures-Wasserstein distance. More precisely, we consider the problem of approximating a target Gaussian distribution with a deep linear neural network generator of Gaussian distributions by minimizing the Bures-Wasserstein distance. This problem is of interest in two important ways. First, it pertains to the optimization of deep linear networks for a type of loss that is qualitatively different from the well-studied and very particular square loss. Second, it can be regarded as a simplified but instructive instance of the parameter optimization problem in generative networks, specifically Wasserstein generative adversarial networks, which are currently not as well understood as discriminative networks. The optimization landscapes and the properties of parameter optimization procedures for neural networks are among the most puzzling and actively studied topics in theoretical deep learning (see, e.g., Mei et al., 2018; Liu et al., 2022). Deep linear networks, i.e., neural networks having the identity as activation function, serve as a simplified model for such investigations (Baldi & Hornik, 1989; Kawaguchi, 2016; Trager et al., 2020; Kohn et al., 2022; Bah et al., 2021). The study of linear networks has guided the development of several useful notions and intuitions in the theoretical analysis of neural networks, from the absence of bad local minima to the role of parametrization and overparametrization in gradient optimization (Arora et al., 2018; 2019a;b). Many previous works have focused on discriminative or autoregressive settings and have emphasized the square loss. Although the square loss is indeed a popular choice in regression tasks, it interacts in a very special way with the particular geometry of linear networks (Trager et al., 2020).
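To make the model concrete: with identity activations, a deep linear network collapses to the product of its weight matrices, so its end-to-end map is a matrix whose rank is bounded by the smallest layer width. A minimal numpy sketch (the depth and widths below are arbitrary choices for illustration, not taken from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

# A depth-3 linear network f(x) = W3 W2 W1 x with layer widths 5 -> 2 -> 4 -> 5.
W1 = rng.standard_normal((2, 5))   # bottleneck layer of width d = 2
W2 = rng.standard_normal((4, 2))
W3 = rng.standard_normal((5, 4))

# With identity activations, the network collapses to a single 5 x 5 matrix.
W = W3 @ W2 @ W1

# Its rank is bounded by the smallest layer width, here d = 2
# (and equals 2 almost surely for generic Gaussian weights).
print(np.linalg.matrix_rank(W))
```

This bottleneck rank bound is what links the function space of a linear network to rank-constrained matrix approximation.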
The behavior of linear networks optimized with different losses has also been considered in several works (Laurent & Brecht, 2018; Lu & Kawaguchi, 2017; Trager et al., 2020) but is less well understood. The Bures-Wasserstein distance was introduced by Bures (1969) to study Hermitian operators in quantum information, particularly density matrices. It induces a metric on the space of positive semi-definite matrices, and it corresponds to the 2-Wasserstein distance between two centered Gaussian distributions (Bhatia et al., 2019). Wasserstein distances enjoy several properties that make them good candidates and indeed popular choices of loss for learning generative models: for example, they remain well defined between measures with disjoint supports, and they admit duality formulations that allow for practical implementations (Villani, 2003). A well-known case is the Wasserstein Generative Adversarial Network (GAN) (Arjovsky et al., 2017). While the 1-Wasserstein distance has been most commonly used in this context, the Bures-Wasserstein distance has also attracted much interest, e.g. in the works of Muzellec & Cuturi (2018); Chewi et al. (2020); Mallasto et al. (2022), and has also appeared in the context of linear quadratic Wasserstein generative adversarial networks (Feizi et al., 2020). A 2-Wasserstein GAN is a minimum 2-Wasserstein distance estimator expressed in Kantorovich duality (see details in Appendix B). This model can serve as an attractive platform to develop the theory, particularly when the inner problem can be solved in closed form. Such a formula is available when comparing pairs of Gaussian distributions; in the case of centered Gaussians it corresponds precisely to the Bures-Wasserstein distance. Strikingly, even in this simple case, the optimization properties of the corresponding problem are not well understood, which we aim to address in the present work.
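The closed-form formula alluded to above is explicit: for positive semi-definite Σ₁, Σ₂, the squared Bures-Wasserstein distance is tr Σ₁ + tr Σ₂ − 2 tr[(Σ₁^{1/2} Σ₂ Σ₁^{1/2})^{1/2}], which equals the squared 2-Wasserstein distance between N(0, Σ₁) and N(0, Σ₂). A small numpy/scipy sketch of this formula (the function name and diagonal test matrices are ours, for illustration):

```python
import numpy as np
from scipy.linalg import sqrtm

def bures_wasserstein_sq(sigma1, sigma2):
    """Squared Bures-Wasserstein distance between PSD matrices.

    Equals the squared 2-Wasserstein distance between the centered
    Gaussians N(0, sigma1) and N(0, sigma2).
    """
    root1 = sqrtm(sigma1)
    cross = sqrtm(root1 @ sigma2 @ root1)
    # sqrtm may return a complex array with negligible imaginary part.
    return np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(cross).real

# Sanity checks: the distance vanishes when the arguments coincide, and for
# commuting (here diagonal) matrices it reduces to ||sqrt(S1) - sqrt(S2)||_F^2.
s1 = np.diag([4.0, 1.0])
s2 = np.diag([1.0, 9.0])
print(bures_wasserstein_sq(s1, s1))  # ~0.0
print(bures_wasserstein_sq(s1, s2))  # (2-1)^2 + (1-3)^2 = 5.0
```

Note that the matrix square roots are what make this loss non-smooth when the covariance drops rank, which is the difficulty the smoothed loss in the paper is designed to handle.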

1.1. CONTRIBUTIONS

We establish a series of results on the optimization of deep linear networks trained with the Bures-Wasserstein loss, which we can summarize as follows.

• We obtain an analogue of the Eckart-Young-Mirsky theorem characterizing the critical points and minimizers of the Bures-Wasserstein distance over matrices of a given rank (Theorem 4.2).

• To circumvent the non-smooth behaviour of the Bures-Wasserstein loss when the matrices drop rank, we introduce a smooth perturbative version (Definition 5 and Lemma 3.3), characterize its critical points and minimizers over rank-constrained matrices (Theorem 4.4), and link them to the critical points on the parameter space (Proposition 4.5).

• For the smooth Bures-Wasserstein loss, in Theorem 5.6 we show exponential convergence of the gradient flow assuming balanced initial weights (Definition 2.1) and a uniform margin deficiency condition (Definition 5.2).

• For the Bures-Wasserstein loss and its smooth version, in Theorem 5.7 we show convergence of gradient descent provided the step size is small enough and assuming balanced initial weights.
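The balancedness assumption appearing in the convergence results is commonly taken (e.g. in the deep linear network literature) to mean that consecutive weights satisfy W_{j+1}ᵀ W_{j+1} = W_j W_jᵀ; we do not restate the paper's Definition 2.1 here. The following sketch constructs one family of balanced weights and checks the condition numerically; the particular construction W_j = O_{j+1} D O_jᵀ with orthogonal factors is our illustrative choice, not necessarily the paper's:

```python
import numpy as np

rng = np.random.default_rng(2)
d, N = 3, 4  # layer width and network depth (arbitrary choices)

# Construction: shared positive diagonal D and orthogonal matrices O_j,
# with W_j = O_{j+1} D O_j^T. Then W_{j+1}^T W_{j+1} = O_{j+1} D^2 O_{j+1}^T
# = W_j W_j^T, so consecutive layers are balanced.
D = np.diag(rng.uniform(0.5, 1.5, size=d))
O = [np.eye(d)] + [np.linalg.qr(rng.standard_normal((d, d)))[0] for _ in range(N)]
Ws = [O[j + 1] @ D @ O[j].T for j in range(N)]

# Numerically verify balancedness for every consecutive pair of layers.
for j in range(N - 1):
    assert np.allclose(Ws[j + 1].T @ Ws[j + 1], Ws[j] @ Ws[j].T)
print("balanced")
```

Balanced weights are a standard device in this literature because the condition is preserved along the gradient flow, which lets statements about the product matrix be transferred to the factors.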

1.2. RELATED WORKS

Low-rank matrix approximation  The function space of a linear network corresponds to n × m matrices of rank at most d, the smallest width of the network. Hence optimization in function space is closely related to the problem of approximating a given data matrix by a low-rank matrix. When the approximation error is measured in Frobenius norm, Eckart & Young (1936) characterized the optimal bounded-rank approximation of a given matrix in terms of its singular value decomposition. Mirsky (1960) obtained the same characterization for the more general case of unitarily invariant matrix norms, which include the Euclidean operator norm and the Schatten-p norms. There are generalizations to certain weighted norms (Ruben & Zamir, 1979; Dutta & Li, 2017). However, for general norms the problem is known to be difficult (Song et al., 2017; Gillis & Vavasis, 2018; Gillis & Shitov, 2019).

Loss landscape of deep linear networks  For the square loss, the optimization landscape of linear networks has been studied in numerous works. The pioneering work of Baldi & Hornik (1989) showed, focusing on the two-layer case, that there is a single minimum (up to a trivial parametrization symmetry) and all other critical points are saddle points. Kawaguchi (2016) obtained corresponding results for deep linear nets and showed the existence of bad saddles (with no negative Hessian eigenvalues) for networks with more than three layers. For losses different from the square loss there are also several results. Laurent & Brecht (2018) showed that deep linear nets with no bottlenecks have no local minima that are not global for any convex differentiable loss. Lu & Kawaguchi (2017) showed that if the loss is such that any local minimizer in parameter space can be perturbed to an equally good minimizer with full-rank factor matrices, then all local minima in parameter space are local minima in function space.
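The Eckart-Young characterization admits an essentially one-line implementation: the best rank-k approximation in Frobenius norm is obtained by truncating the singular value decomposition, and the optimal error is the tail sum of squared singular values. A short numpy sketch (the function name and random test matrix are ours):

```python
import numpy as np

def best_rank_k(A, k):
    """Best rank-k approximation of A in Frobenius (and operator) norm,
    obtained by truncating the SVD (Eckart-Young-Mirsky)."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vt[:k, :]

rng = np.random.default_rng(1)
A = rng.standard_normal((6, 4))
A2 = best_rank_k(A, 2)

# The optimal squared error equals the tail of the spectrum: sum_{i>k} s_i^2.
s = np.linalg.svd(A, compute_uv=False)
err = np.linalg.norm(A - A2, "fro") ** 2
print(np.isclose(err, np.sum(s[2:] ** 2)))  # True
```

Part of the paper's point is that this clean SVD picture is specific to unitarily invariant norms; for the Bures-Wasserstein loss an analogous but distinct characterization is needed (Theorem 4.2).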



Yun et al. (2018) found sets of parameters such that any critical point in this set is a global minimum and any critical point outside is a saddle. Variations include the study of critical points for different types of architectures, such as deep linear residual networks (Hardt & Ma, 2017) and deep linear convolutional networks (Kohn et al., 2022).

