MOMENTUM STIEFEL OPTIMIZER, WITH APPLICATIONS TO SUITABLY-ORTHOGONAL ATTENTION, AND OPTIMAL TRANSPORT

Abstract

The problem of optimization on the Stiefel manifold, i.e., minimizing functions of (not necessarily square) matrices that satisfy orthogonality constraints, has been extensively studied. Yet a new approach is proposed, based, for the first time, on an interplay between thoughtfully designed continuous and discrete dynamics. It leads to a gradient-based optimizer with intrinsically added momentum. This method exactly preserves the manifold structure but does not require additional operations to keep the momentum in the changing (co)tangent space, and thus has low computational cost and pleasant accuracy. Its generalization to adaptive learning rates is also demonstrated. Notable performances are observed in practical tasks. For instance, we found that placing orthogonal constraints on the attention heads of a trained-from-scratch Vision Transformer (Dosovitskiy et al., 2020) could markedly improve its performance when our optimizer is used, and that it is better for each head to be made orthogonal within itself, but not necessarily to the other heads. This optimizer also makes the useful notion of Projection Robust Wasserstein Distance (Paty and Cuturi, 2019; Lin et al., 2020) for high-dimensional optimal transport even more effective.

1. INTRODUCTION

Matrices that satisfy orthogonality constraints play important roles in various areas, including machine learning. A range of studies showed this both theoretically (Saxe et al., 2013; Xiao et al., 2018) and experimentally: for example, orthogonality can boost the performance of many architectures[1], e.g., MLP (Cisse et al., 2017), CNN and ResNet (Li et al., 2020; Wang et al., 2020), RNN (Arjovsky et al., 2016; Wisdom et al., 2016; Helfrich et al., 2018), and Transformer (Zhang et al., 2021). The training of such models amounts to optimization under orthogonal constraints, whose applications to machine learning, however, are not restricted to improving deep learning models. For example, such optimization helps construct optimal low-dimensional projections of high-dimensional data, which can be used for, e.g., robust and efficient approximation of the Wasserstein distance (Paty and Cuturi, 2019; Lin et al., 2020). Therefore, this article focuses on this non-Euclidean / Riemannian optimization problem. In fact, matrices with orthonormal columns constitute a Riemannian manifold known as the Stiefel manifold. Given integers n ≥ m > 0, it is defined as St(n, m) := {X ∈ R^{n×m} : X^⊤X = I_{m×m}} (see Apdx. B.3 for the n = m special case, which is almost SO(n)). Then, given a C^1 function f : St(n, m) → R, we consider, under gradient oracle, the smooth optimization problem min_{X ∈ St(n,m)} f(X), which is equivalent to min_{X ∈ R^{n×m}, X^⊤X = I_{m×m}} f(X). Due to the importance of this problem, a wide range of approaches have been proposed. Methods based on regularizers, which approximate the constraints (and hence the manifold structure), for example, have been popular in the machine learning literature (Cisse et al., 2017; Bansal et al., 2018; Wang et al., 2020; Zhang et al., 2021). Meanwhile, efforts have been continuously made to construct algorithms that truly preserve the constraints (i.e.
the manifold structure), although challenges such as computational scalability and how to add momentum still remain; see the second-next paragraph for how our method addresses them, and more discussion of existing milestones in Sec. 1.1.

Our strategy toward a good Stiefel optimizer is the following: first, formulate a variational principle and use it to construct an optimizing dynamics in continuous time, described by a system of ODEs corresponding to damped mechanical systems on a constrained manifold; then, design a delicate time discretization of these ODEs, which yields optimization algorithms that precisely preserve the constraints and mimic the continuous dynamics. Our optimizer has several pleasant properties: 1) It exactly preserves the manifold structure, not only of the Stiefel manifold, but in fact of its tangent bundle. In other words, throughout the course of optimization, the position variable remains exactly on the Stiefel manifold, and the momentum variable remains exactly in the (co)tangent space. 2) Typically, in order to maintain the manifold structure, some kind of projection/retraction/exponential-map operation is needed, and since we have both position and momentum, such an operation is needed for both variables (i.e., to maintain the cotangent-bundle structure). However, our carefully designed ODE and its discretization make the structure preservation of the momentum automatic, meaning that no extra operation (projection, retraction, parallel transport, etc.) is needed for the momentum variable. This not only improves computational efficiency, but also serves as indirect evidence of a reduced overall (i.e., both position and momentum) local error. 3) We use a quadratically convergent iterative solver for our specific position-retraction operation, which makes it fast. 4) Due to 2)+3), our per-iteration computational complexity, O(nm^2), has a small constant factor (see Sec. C for details).
5) Our discretization is also numerically stable, so that it preserves the structure well even under low machine precision and numerous iterations, which is beneficial in machine-learning contexts. 6) Because our algorithm is derived from a variational framework that unifies both Euclidean and Stiefel variables, the same hyperparameters can be used for both types of parameters; see Sec. 3, and note that this difference from previous milestones (e.g., Li et al. (2020)) significantly reduces tuning effort. 7) Our algorithm works for a range of Riemannian metrics, allowing extra flexibility in choosing a suitable geometry to optimize performance for a specific problem.

Selected (due to space) experimental tests of our optimizer are: (1) We consider the simple problem of leading eigenvalues, which is nevertheless practically important in data science. It systematically investigates algorithmic performance under different parameters. (2) We show that the elegant idea of approximating optimal-transport distances in high dimension via a good low-dimensional projection (Paty and Cuturi, 2019; Lin et al., 2020) can be made even more efficacious by our optimizer. (3) We note that Vision Transformer (ViT) can be further improved by requiring attention heads to be orthogonal; more precisely, consider training ViT from scratch. We discover that 1) requiring each head to be orthogonal within itself improves both training and testing accuracies the most. An important recent work by Zhang et al. (2021) applied orthogonality to transformers and demonstrated improved performance in NLP tasks. It concatenates each of the W_Q^i, W_K^i, W_V^i matrices in attention across all heads, and applies an orthogonal constraint to each of the three via a regularizer. This makes each head (approximately) orthogonal, not only within itself but also to the other heads. An orthogonal constraint is also applied, via a regularizer, to each weight matrix of the feed-forward layers in their case. With our Stiefel optimizer, which is not restricted to square matrices, we can instead make each head exactly, and only, orthogonal within itself, which leads to further improvements, at least in CV tasks. Meanwhile, 2) having orthogonality both within and across heads is found to be less effective than 1), but still better than requiring no orthogonality at all (i.e., vanilla ViT). No orthogonality on feed-forward layers was used in either 1) or 2). In addition, 3) to achieve these improvements, our Stiefel optimizer needs to be used; methods that lack momentum or do not exactly preserve the structure (e.g., regularizer-based ones) are observed not to fully exploit the benefit of orthogonality.

[1] It helped computer vision applications prior to the deep learning era as well (e.g., Liu et al., 2003).

1.1 MORE ON RELATED WORK

Orthogonality in deep learning. Initialization: orthogonal initialization is both theoretically and experimentally shown to benefit deep neural networks (Saxe et al., 2013). CNN: Cisse et al. (2017) used a regularizer to make weight matrices orthogonal when training MLPs and CNNs, and showed both improved accuracy and robustness. Rodríguez et al. (2017) showed that imposing orthogonality on CNNs reduces overfitting. Bansal et al. (2018) experimentally compared several orthogonal regularizers on ResNet. RNN: Arjovsky et al. (2016); Wisdom et al. (2016); Helfrich et al. (2018); Vorontsov et al. (2017); Lezcano-Casado and Martínez-Rubio (2019) showed orthogonality helps plain RNNs achieve long-term memory and even outperform advanced models such as LSTM. Transformer: despite its success, the Transformer is a relatively recent model; to the best of our knowledge, the community has only just started applying orthogonal constraints to Transformers, for NLP tasks (Zhang et al., 2021). Our work is the first to apply them to Vision Transformer (Dosovitskiy et al., 2020).

Optimization algorithms. Manifold optimization is a profound field. General approaches to first-order (i.e., gradient-based) optimization on (Riemannian) manifolds typically involve retraction.
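For readers less familiar with the retraction-based template discussed above, here is a minimal NumPy sketch (our illustration, not this paper's proposed algorithm; all function names are ours, and the QR retraction and projection-based momentum transport are standard but not unique choices), applied to the leading-eigenvalue problem of experiment (1):

```python
import numpy as np

def qr_retraction(X):
    # Map a full-rank n-by-m matrix back onto St(n, m) via QR
    # (one standard retraction choice; not the only one).
    Q, R = np.linalg.qr(X)
    return Q * np.sign(np.diag(R))  # sign fix makes the map continuous

def tangent_projection(X, G):
    # Orthogonal projection of G onto the tangent space at X in St(n, m):
    # P_X(G) = G - X sym(X^T G), where sym(A) = (A + A^T) / 2.
    XtG = X.T @ G
    return G - X @ ((XtG + XtG.T) / 2)

def riemannian_gd_momentum(euclid_grad, X, lr=0.005, beta=0.9, iters=3000):
    # Generic retraction-based gradient descent with momentum.
    # Note the explicit re-projection of the momentum at every step --
    # exactly the kind of extra bookkeeping that a discretization
    # preserving the cotangent-bundle structure would avoid.
    M = np.zeros_like(X)
    for _ in range(iters):
        G = tangent_projection(X, euclid_grad(X))
        M = beta * tangent_projection(X, M) + G
        X = qr_retraction(X - lr * M)
    return X

# Leading-eigenvalue toy problem: min over St(n, m) of -tr(X^T A X);
# the minimizers span the top-m eigenspace of the symmetric matrix A.
n, m = 20, 3
A = np.diag(np.arange(1.0, n + 1))   # eigenvalues 1, 2, ..., 20
grad = lambda X: -2.0 * A @ X        # Euclidean gradient of -tr(X^T A X)

rng = np.random.default_rng(0)
X0, _ = np.linalg.qr(rng.standard_normal((n, m)))
X = riemannian_gd_momentum(grad, X0)

orth_err = np.linalg.norm(X.T @ X - np.eye(m))         # distance from St(n, m)
obj_gap = abs(np.trace(X.T @ A @ X) - (18 + 19 + 20))  # gap to the optimum
print(orth_err, obj_gap)
```

The line `M = beta * tangent_projection(X, M) + G` is the per-step momentum transport that this generic template requires; the construction proposed here keeps the momentum in the correct (co)tangent space automatically, removing that operation.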

Availability: https://github.

