MOMENTUM STIEFEL OPTIMIZER, WITH APPLICATIONS TO SUITABLY-ORTHOGONAL ATTENTION, AND OPTIMAL TRANSPORT

Abstract

The problem of optimization on the Stiefel manifold, i.e., minimizing functions of (not necessarily square) matrices that satisfy orthogonality constraints, has been extensively studied. Yet, a new approach is proposed based on, for the first time, an interplay between thoughtfully designed continuous and discrete dynamics. It leads to a gradient-based optimizer with intrinsically added momentum. This method exactly preserves the manifold structure, yet requires no additional operation to keep the momentum in the changing (co)tangent space, and thus has low computational cost and good accuracy. Its generalization to adaptive learning rates is also demonstrated. Notable performance is observed in practical tasks. For instance, we found that placing orthogonality constraints on the attention heads of a trained-from-scratch Vision Transformer (Dosovitskiy et al., 2020) can markedly improve its performance when our optimizer is used, and that it is better to make each head orthogonal within itself but not necessarily to other heads. This optimizer also makes the useful notion of Projection Robust Wasserstein Distance (Paty and Cuturi, 2019; Lin et al., 2020) for high-dimensional optimal transport even more effective.

1. INTRODUCTION

Matrices that satisfy orthogonality constraints play important roles in various areas, including machine learning. A range of studies has shown this both theoretically (Saxe et al., 2013; Xiao et al., 2018) and experimentally; for example, orthogonality can boost the performance of many architectures, e.g., MLP (Cisse et al., 2017), CNN and ResNet (Li et al., 2020; Wang et al., 2020), RNN (Arjovsky et al., 2016; Wisdom et al., 2016; Helfrich et al., 2018), and Transformer (Zhang et al., 2021). Training such models amounts to optimization under orthogonality constraints, whose applications to machine learning, however, are not restricted to improving deep learning models. For example, such optimization helps construct an optimal low-dimensional projection of high-dimensional data, which can be used for, e.g., robust and efficient approximation of the Wasserstein distance (Paty and Cuturi, 2019; Lin et al., 2020). Therefore, this article focuses on this non-Euclidean (Riemannian) optimization problem. Matrices with orthonormal columns constitute a Riemannian manifold known as the Stiefel manifold. Given integers $n \ge m > 0$, it is defined as $\mathrm{St}(n,m) := \{X \in \mathbb{R}^{n\times m} : X^\top X = I_{m\times m}\}$ (see Apdx. B.3 for the special case $n = m$, which is almost $\mathrm{SO}(n)$). Then, given a $C^1$ function $f : \mathrm{St}(n,m) \to \mathbb{R}$, we consider, under a gradient oracle, the smooth optimization problem

$$\min_{X \in \mathrm{St}(n,m)} f(X), \quad \text{which is equivalent to} \quad \min_{X \in \mathbb{R}^{n\times m},\ X^\top X = I_{m\times m}} f(X). \tag{1}$$

Due to the importance of this problem, a wide range of approaches have been proposed. Methods based on regularizers, which approximate the constraints (and hence the manifold structure), for example, have been popular in the machine learning literature (Cisse et al., 2017; Bansal et al., 2018; Wang et al., 2020; Zhang et al., 2021). Meanwhile, efforts have been continuously made to construct algorithms that truly preserve the constraints (i.e.
the manifold structure), although challenges such as computational scalability and how to add momentum still remain; see the properties listed below for how our method addresses them, and Sec. 1.1 for more discussion of existing milestones. Our strategy toward a good Stiefel optimizer is the following: first, formulate a variational principle and use it to construct an optimizing dynamics in continuous time, described by a system of ODEs corresponding to damped mechanical systems on a constrained manifold; then, design a delicate time discretization of these ODEs, which yields optimization algorithms that precisely preserve the constraints and mimic the continuous dynamics. Our optimizer has several pleasant properties: 1) It exactly preserves the manifold structure, not only of the Stiefel manifold, but in fact of its tangent bundle. In other words, throughout the course of optimization, the position variable remains exactly on the Stiefel manifold, and the momentum variable remains exactly in the (co)tangent space. 2) Typically, in order to maintain the manifold structure, some kind of projection/retraction/exponential-map operation is needed, and since we have both position and momentum, such an operation is needed for both variables (i.e., to maintain the cotangent bundle structure). However, our carefully designed ODE and its discretization make the structure preservation of momentum automatic, meaning that no extra operation (projection, retraction, parallel transport, etc.) is needed for the momentum variable. This not only improves computational efficiency, but also serves as indirect evidence of a reduced overall (i.e., both position and momentum) local error. 3) We use a quadratically convergent iterative solver for our specific position retraction operation, which makes it fast. 4) Due to 2)+3), our per-iteration computational complexity, $O(nm^2)$, has a small constant factor (see Apdx. C for details).
5) Our discretization is also numerically stable, so it preserves the structure well even under low machine precision and over numerous iterations, which is beneficial in machine learning contexts. 6) Because our algorithm is derived from a variational framework that unifies both Euclidean and Stiefel variables, the same hyperparameters can be used for both kinds of parameters, which significantly reduces tuning effort; see Sec. 3, and note that this differs from previous milestones (e.g., Li et al. (2020)). 7) Our algorithm works for a range of Riemannian metrics, allowing extra flexibility in choosing a suitable geometry to optimize performance for a specific problem. Due to space limits, we present selected experimental tests of our optimizer: (1) We consider the simple yet practically important problem of computing leading eigenvalues in data science, and use it to systematically investigate algorithmic performance under different parameters. (2) We show that the elegant idea of approximating the optimal transport distance in high dimensions via a good low-dimensional projection (Paty and Cuturi, 2019; Lin et al., 2020) can be made even more efficacious by our optimizer. (3) We note that the Vision Transformer (ViT) can be further improved by constraining its attention heads to be orthogonal. More precisely, consider training ViT from scratch. We discover that 1) requiring each head to be orthogonal within itself improves both training and testing accuracies the most. An important recent work by Zhang et al. (2021) applied orthogonality to transformers and demonstrated improved performance in NLP tasks. It concatenates each of the $W^Q_i$, $W^K_i$, $W^V_i$ matrices in attention across all heads, and applies an orthogonality constraint to each of the three concatenations via a regularizer. This makes each head (approximately) orthogonal, not only within itself, but also to other heads. In their case, an orthogonality constraint is also applied, via a regularizer, to each weight matrix of the feed-forward layers.
With our Stiefel optimizer, which is not restricted to square matrices, we can now make each head exactly and only orthogonal within itself, which leads to further improvements, at least in CV tasks. Meanwhile, 2) having orthogonality both within and across heads is found less effective than 1), but it is still better than requiring no orthogonality (i.e., vanilla ViT). No orthogonality on feed-forward layers was used in either 1) or 2). In addition, 3) to achieve these improvements, our Stiefel optimizer needs to be used; methods that do not have momentum or do not exactly preserve the structure (e.g., regularizer-based ones) are observed not to fully exploit the benefit of orthogonality.
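The two attention constraints compared above can be made concrete with a small NumPy sketch. It only checks the two constraint sets and does not train anything. The sizes `d_model`, `n_heads`, `d_head` and the helpers `random_stiefel` and `orth_error` are hypothetical, and constraint 2) is only feasible when `d_model >= n_heads * d_head`.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, d_head = 64, 4, 16   # hypothetical sizes, not taken from the experiments

def random_stiefel(n, m, rng):
    """Random n x m matrix with orthonormal columns (reduced QR of a Gaussian matrix)."""
    q, _ = np.linalg.qr(rng.standard_normal((n, m)))
    return q

def orth_error(W):
    """Frobenius-norm violation of W^T W = I."""
    return np.linalg.norm(W.T @ W - np.eye(W.shape[1]))

# Constraint 1): each head's matrix W^Q_i lies on St(d_model, d_head) on its own.
per_head = [random_stiefel(d_model, d_head, rng) for _ in range(n_heads)]
within_errs = [orth_error(W) for W in per_head]
# Independently orthogonal heads are generally NOT orthogonal to each other.
across_err_1 = orth_error(np.concatenate(per_head, axis=1))

# Constraint 2): the concatenation [W^Q_1, ..., W^Q_H] lies on St(d_model, H * d_head),
# which forces orthogonality both within and across heads.
joint = random_stiefel(d_model, n_heads * d_head, rng)
heads_2 = np.split(joint, n_heads, axis=1)
within_errs_2 = [orth_error(W) for W in heads_2]
across_err_2 = orth_error(joint)

print(max(within_errs), across_err_1)
print(max(within_errs_2), across_err_2)
```

Constraint 1) is a product of $H$ small Stiefel manifolds, while constraint 2) is a single larger one contained in that product, which is why it is strictly more restrictive.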

1.1. MORE ON RELATED WORK

Orthogonality in deep learning. Initialization: Orthogonal initialization has been shown, both theoretically and experimentally, to benefit deep neural networks (Saxe et al., 2013). CNN: Cisse et al. (2017) used a regularizer to make weight matrices orthogonal when training MLP and CNN, and showed both improved accuracy and robustness. Rodríguez et al. (2017) showed that imposing orthogonality on CNN reduces overfitting. Bansal et al. (2018) experimentally compared several orthogonality regularizers on ResNet. RNN: Arjovsky et al. (2016); Wisdom et al. (2016); Helfrich et al. (2018); Vorontsov et al. (2017); Lezcano-Casado and Martínez-Rubio (2019) showed that orthogonality helps plain RNN achieve long-term memory and even outperform advanced models such as LSTM. Transformer: Despite its success, the Transformer is a relatively recent model; to the best of our knowledge, the community has only just started applying orthogonality constraints to Transformers for NLP tasks (Zhang et al., 2021). Our work is the first to apply them to the Vision Transformer (Dosovitskiy et al., 2020). Optimization algorithms. Manifold optimization is a profound field. General approaches for 1st-order (i.e., gradient-based) optimization on a (Riemannian) manifold typically involve a retraction (see Apdx. D), which is not present in Euclidean cases. We thus categorize 1st-order Stiefel optimizers into two types: O1) retraction-based, and O2) non-retraction-based. O1) Retraction-based. Given a point and a tangent vector, the exponential map gives how the point moves along the tangent vector on the manifold. In practice, an approximation of the exponential map, known as a retraction, is often used to reduce the computational cost. The exponential map and various retractions on the Stiefel manifold are well studied. For Stiefel, the exponential map can be computed with complexity $O(nm^2)$ (Edelman et al., 1998), whose constant prefactor is however large, as the map essentially requires matrix exponentiation.
A well-known retraction is the Cayley map (Eq. (32)), and Wen and Yin (2013) proposed a gradient descent method with a smartly lowered computational cost of the Cayley map (hereafter referred to as Momentumless Stiefel (S)GD). SVD can also help construct a retraction; Li et al. (2019) did so by first following the gradient off the manifold and then projecting back by modifying the singular values, which is interesting but expensive. Meanwhile, a remaining challenge is to add momentum, which is nontrivial in a manifold setting but helpful for improving the speed of convergence (Zhang and Sra, 2018; Ahn and Sra, 2020; Alimisis et al., 2021). In this case, the geometry that needs to be maintained is that of the augmented state space, known as the tangent bundle; i.e., the position has to be on the Stiefel manifold and the momentum in its tangent space. One could follow a general Riemannian retraction idea (e.g., Absil and Malick (2012)) and use a retraction that projects both position and momentum onto the tangent bundle to ensure manifold preservation. This was done, for example, in Li et al. (2020) (hereafter referred to as Projected Stiefel SGD/Adam). Another technique to maintain the structure of momentum is parallel transport (e.g., Bécigneul and Ganea (2019)), which is, however, computationally expensive and possibly inexact (Edelman et al., 1998). Our work could be viewed as a retraction-based approach, because our position variable is projected back to St at each step; however, it does not require any retraction or projection of the momentum variable: thanks to our variational approach and delicate discretization, position and momentum intrinsically move together, and once the position $X$ is placed on St, the momentum $Q$ automatically lies in $T_X\mathrm{St}$. This leads to better computational efficiency, as well as improved accuracy, because the variables stay closer to the tangent bundle. O2) Non-retraction-based.
In some sense, one can always call an exact manifold-preserving method retraction-based, but approximate structure preservation is also a possibility. Regularizers, for instance, can be used to approach $X^\top X = I$. Whether exact or approximate structure preservation works better is theoretically still an open question and is task-dependent, but insightful empirical results exist. For example, regularizer-based methods usually need more hyperparameter tuning, and the final result is likely sensitive to the regularization strength (Wang et al., 2020). Lai and Osher (2014) also discussed their slow convergence and possible improvements. Regularizer-based methods typically have low computational cost, but because their trajectories are not exactly on the manifold, for multimodal problems they may converge to (the neighborhood of) a local minimum different from the one an exact algorithm converges to. A similar observation was made for non-regularizer-based approximate methods too; e.g., Vorontsov et al. (2017) suggested that, at least in certain cases, 'hard' constraints are better than 'soft' ones. The aforementioned Li et al. (2020) is actually inexact too, because in their practical implementation the retraction was performed inexactly for the sake of speed. Meanwhile, we also note an interesting, non-regularizer-based recent result (Ablin and Peyré, 2022) that converges to the manifold despite being approximate prior to convergence. Additional related work. We also mention the existence of useful algorithms for electronic structure calculations that can be viewed as Stiefel optimizers (e.g., Zhang et al. (2014); Dai et al. (2017); Gao et al. (2018; 2019); Hu et al. (2019); Dai et al. (2020)). They did not use momentum. Another inexact-manifold-preserving but interesting result is Bu and Chang (2022). The problem can also be approached via a constrained optimization formulation (e.g., Leimkuhler et al. (2021)).
In addition, there are other interesting optimizers also based on discretizing damped mechanical systems, but not via the splitting approach used here, e.g., Lee et al. (2021); Duruisseaux and Leok (2021); however, they are mostly implicit and not always scalable to deep learning applications. Finally, it is possible to represent a matrix by other matrices and optimize those instead. The special case of SO(n) was considered by Lezcano-Casado and Martínez-Rubio (2019) and Helfrich et al. (2018), where $X = \mathrm{expm}(A)$ or $X = \mathrm{Cayley}(A)$ was used and the skew-symmetric $A$ was optimized instead. Although these authors did not consider a Stiefel generalization, one might be possible; but even without the generalization, the computational cost is already high. This paper considers exact structure preservation. For the applications included here, it is either empirically more beneficial (Sec. 3.2) or a necessity (Sec. 3.1 & Apdx. Q). See comparisons of the most relevant approaches in Tab. 2 and more on complexity in Apdx. C.

2. DERIVATION OF THE OPTIMIZERS AND THEIR PROPERTIES

We represent the Stiefel manifold as $\mathrm{St}(n,m) := \{X \in \mathbb{R}^{n\times m} : X^\top X = I_{m\times m}\}$ where $n \ge m$, i.e., matrices with orthonormal columns. This is a natural embedding in Euclidean space. Based on the identity isomorphism between cotangent and tangent spaces (see Apdx. B.2), it also gives:

Proposition 1. The (co)tangent space of St at $X \in \mathrm{St}$ is $T_X\mathrm{St} := \{\Delta \in \mathbb{R}^{n\times m} : X^\top\Delta + \Delta^\top X = 0\}$; its (co)tangent bundle is $T\mathrm{St} := \{(X,\Delta) : X \in \mathrm{St},\ \Delta \in T_X\mathrm{St}\}$.

Throughout this paper, we focus on a family of Riemannian metrics on St defined as follows.

Definition 1 (canonical-type metric). For a fixed constant $a < 1$, the canonical-type metric $g_X : T_X\mathrm{St} \times T_X\mathrm{St} \to \mathbb{R}$ is defined as

$$g_X(\Delta_1, \Delta_2) = \mathrm{Tr}\left(\Delta_1^\top (I - aXX^\top)\Delta_2\right), \quad \forall\, \Delta_1, \Delta_2 \in T_X\mathrm{St}. \tag{2}$$

The following are two commonly used examples of the canonical-type metric (Tagare, 2011):
• When $a = 0$, $g_X(\Delta_1, \Delta_2) = \mathrm{Tr}(\Delta_1^\top\Delta_2)$, which corresponds to the Euclidean metric.
• When $a = 1/2$, $g_X(\Delta_1, \Delta_2) = \mathrm{Tr}\left(\Delta_1^\top (I - \tfrac12 XX^\top)\Delta_2\right)$, which is known as the canonical metric.

Notation. Denote the Euclidean gradient in the ambient space by $\frac{\partial f}{\partial X} := \left(\frac{\partial f}{\partial X_{ij}}(X)\right)_{ij}$. Denote by $\odot$, $\oslash$, and $A^{\circ c}$ the element-wise product, division, and $c$-th power, respectively. Denote by expm the matrix exponential.
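A minimal NumPy sketch of Proposition 1 and Definition 1, assuming dense matrices. The helper names `on_stiefel`, `in_tangent_space` and `canonical_type_metric` are ours, and the tangent vectors are produced with the standard Euclidean projection $W - X\,\mathrm{sym}(X^\top W)$, which is our choice and not part of the text above. The last part illustrates why $a < 1$ is required: for $\Delta = XS$ with $S$ skew-symmetric, $g_X(\Delta,\Delta) = (1-a)\|S\|_F^2$, which degenerates at $a = 1$.

```python
import numpy as np

def on_stiefel(X, tol=1e-10):
    """Check X^T X = I."""
    return np.linalg.norm(X.T @ X - np.eye(X.shape[1])) < tol

def in_tangent_space(X, D, tol=1e-10):
    """Check X^T D + D^T X = 0 (Proposition 1)."""
    return np.linalg.norm(X.T @ D + D.T @ X) < tol

def canonical_type_metric(X, D1, D2, a):
    """g_X(D1, D2) = Tr(D1^T (I - a X X^T) D2), Eq. (2); a genuine metric only for a < 1."""
    n = X.shape[0]
    return np.trace(D1.T @ (np.eye(n) - a * X @ X.T) @ D2)

rng = np.random.default_rng(1)
n, m = 7, 3
X, _ = np.linalg.qr(rng.standard_normal((n, m)))
W1, W2 = rng.standard_normal((2, n, m))
# Tangent vectors: remove the symmetric part of X^T W (our construction).
D1 = W1 - X @ (X.T @ W1 + W1.T @ X) / 2
D2 = W2 - X @ (X.T @ W2 + W2.T @ X) / 2

euclid = canonical_type_metric(X, D1, D2, a=0.0)   # equals Tr(D1^T D2)
canon = canonical_type_metric(X, D1, D2, a=0.5)    # canonical metric

# A purely "vertical" tangent direction D = X S with S skew-symmetric:
# g_X(D, D) = (1 - a) ||S||_F^2, which vanishes at a = 1 (evaluated here only to show this).
S = rng.standard_normal((m, m))
S = S - S.T
D_vert = X @ S
g_vert = {a: canonical_type_metric(X, D_vert, D_vert, a) for a in (0.0, 0.5, 1.0)}
print(euclid, canon, g_vert)
```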

2.1. OPTIMIZATION DYNAMICS IN CONTINUOUS TIME

In this section, we study the Stiefel optimization problem (1) from a variational perspective. The manifold setup yields a constrained variational problem, which is hard to handle due to the nonlinear geometry of the function space. We thus introduce an alternative variational formulation based on a functional Lagrange multiplier to bypass this difficulty. It generates an ODE that never leaves $T\mathrm{St}$ and, when started near an isolated local minimizer of the function f on St, converges to it (Theorem 1). In Euclidean space, momentum has long been used to accelerate gradient descent (Nesterov, 1983). More recently, the continuous-time limit has brought useful tools to complement the classical analyses of such optimizers (e.g., Su et al. (2014)), and a variational formulation was then established to unify a large family of momentum-based Euclidean optimizers (Wibisono et al., 2016). This formulation also provides a powerful and intrinsic way to generalize momentum-based optimization to non-Euclidean settings; one demonstration is accelerated optimization on Lie groups (Tao and Ohsawa, 2020). Following the same line of thought, we consider Stiefel optimization by first defining a time-dependent Lagrangian $L : T\mathrm{St} \times \mathbb{R}_+ \to \mathbb{R}$ as

$$L(X, \dot X, t) := r(t)\left[\tfrac12 g_X(\dot X, \dot X) - f(X)\right] = r(t)\left[\tfrac12 \mathrm{Tr}\left(\dot X^\top (I - aXX^\top)\dot X\right) - f(X)\right]. \tag{3}$$

Without r(t), this would be a standard Lagrangian for mechanics on a manifold, and it would generate dynamics in which the kinetic energy $\frac12 g_X(\dot X, \dot X)$ and the potential energy $f(X)$ continuously exchange with each other. Choosing r(t) to be an increasing function, however, breaks time-translational symmetry and introduces dissipation into the system; consequently, the total energy $\frac12 g_X(\dot X, \dot X) + f(X)$ monotonically decreases in time, leading to the minimization of f.
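To see why an increasing r(t) dissipates energy, here is the computation in the simplest setting: an unconstrained Euclidean variable $x$ with the flat metric. This is our simplification for illustration; the Stiefel case additionally involves the constraint forces handled in Theorem 1, which we do not treat here. With $\gamma(t) = \dot r(t)/r(t)$ as in Theorem 1:

```latex
\begin{aligned}
L(x,\dot x,t) &= r(t)\Big(\tfrac12\|\dot x\|^2 - f(x)\Big),\\
\frac{d}{dt}\big(r(t)\,\dot x\big) &= -r(t)\,\nabla f(x)
  \;\Longrightarrow\; \ddot x = -\gamma(t)\,\dot x - \nabla f(x),\\
\frac{d}{dt}\Big(\tfrac12\|\dot x\|^2 + f(x)\Big)
  &= \dot x^\top \ddot x + \nabla f(x)^\top \dot x
   = -\gamma(t)\,\|\dot x\|^2 \;\le\; 0 .
\end{aligned}
```

So whenever r is increasing ($\gamma > 0$), the total energy is non-increasing and stays constant only while the velocity vanishes; the choice $r(t) = e^{\gamma t}$ gives a constant friction coefficient $\gamma$.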
Given this Lagrangian, we can define a variational problem (VP) on the manifold whose stationarity condition (known as the Euler-Lagrange equation) yields an ODE that optimizes f in continuous time:

$$\text{Constrained VP:}\quad \delta \int_0^T L(X(t), \dot X(t), t)\,dt = 0, \quad \text{s.t. } (X(t), \dot X(t)) \in T\mathrm{St},\ \forall\, 0 \le t \le T. \tag{4}$$

However, to obtain concrete algorithms, this constrained VP is very difficult to handle, because the variation of $X(\cdot)$ has to keep it in a nonlinearly constrained function space. Tao and Ohsawa (2020) applied a reduction via the Lie algebra to solve the case of Lie group manifolds, but for the Stiefel manifold St(n, m), unless $n = m$, there is no group structure and this technique no longer applies. Therefore, we transform (4) into an unconstrained variational problem, using tools described in, e.g., Chen et al. (2021). By adding a Lagrange multiplier $\Lambda(t) \in \mathbb{R}^{m\times m}$ for the constraint $X(t)^\top X(t) - I = 0$, we augment the Lagrangian to

$$L(X, \dot X, \Lambda, t) = r(t)\left[\tfrac12 \mathrm{Tr}\left(\dot X^\top (I - aXX^\top)\dot X\right) - f(X)\right] - \tfrac12 \mathrm{Tr}\left(\Lambda^\top (X^\top X - I)\right), \tag{5}$$

where $\Lambda(t)$ is a symmetric matrix. The benefit is that $X$ can now be varied in an unconstrained, flat function space, corresponding to

$$\text{Unconstrained VP:}\quad \delta \int_0^T L(X(t), \dot X(t), \Lambda(t), t)\,dt = 0. \tag{6}$$

To sum up the above discussion, problem (1) can be resolved through the (un)constrained variational problems (4), (6), i.e., $(1) \Leftarrow (4) \Leftrightarrow (6)$, in the sense that the solutions of (4), (6) solve problem (1); meanwhile, (6) can be explicitly solved, as detailed in the following (proof is in Apdx. G):

Theorem 1 (Optimization dynamics on $T\mathrm{St}$). To use problem (6) for solving problem (1), we have:
1.
The solution of the unconstrained variational problem (6) is the following ODE:

$$\begin{cases} \dot X = Q, \\ \dot Q = -\gamma Q - XQ^\top Q - \frac{3a}{2}(I - XX^\top)QQ^\top X - \frac{\partial f}{\partial X} + \frac{1+b}{2}XX^\top\frac{\partial f}{\partial X} + \frac{1-b}{2}X\left(\frac{\partial f}{\partial X}\right)^\top X, \end{cases} \tag{7}$$

where $(X(t), Q(t)) = (X(t), \dot X(t)) \in \mathbb{R}^{n\times m} \times \mathbb{R}^{n\times m}$ is the tuple of the state variable and its momentum, $\gamma(t) := \dot r(t)/r(t)$ is the scalar friction coefficient, and $b := \frac{a}{a-1}$ is a constant depending on the canonical-type metric (2).
2. For any isolated local minimum $X^* \in \mathrm{St}$ of f, there exists a neighbourhood $U \subset T\mathrm{St}$ of $(X^*, 0)$ such that, for any initial condition $(X_0, Q_0) \in U$, the solution $(X(t), Q(t))$ of system (7) converges to $(X^*, 0)$ as $t \to \infty$.

One feature of ODE (7) is the preservation of constraints: even though the ODE is defined in Euclidean space, as long as it starts on the manifold, it never leaves the manifold:

Theorem 2 (Constrained optimization with unconstrained dynamics). If the initial condition of (7) is on $T\mathrm{St}$, the (co)tangent bundle of the Stiefel manifold, then dynamics (7) automatically stays on $T\mathrm{St}$; i.e., if $X(0)^\top X(0) = I_{m\times m}$ and $X(0)^\top Q(0) + Q(0)^\top X(0) = 0_{m\times m}$, then for all $t \ge 0$, $X(t)^\top X(t) = I_{m\times m}$ and $X(t)^\top Q(t) + Q(t)^\top X(t) = 0_{m\times m}$. Proof is in Apdx. H.
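Theorem 2 can be sanity-checked infinitesimally: at any point of $T\mathrm{St}$, the vector field of (7) should make the time derivatives of $X^\top X$ and $X^\top Q + Q^\top X$ vanish. The NumPy sketch below evaluates (7) term by term and checks this for several values of $a < 1$. The quadratic test objective $f(X) = -\frac12\mathrm{Tr}(X^\top AX)$ and the helper name `ode7_rhs` are our assumptions; the check does not integrate the ODE.

```python
import numpy as np

def ode7_rhs(X, Q, G, gamma, a):
    """Right-hand side of ODE (7); G is the Euclidean gradient df/dX at X."""
    b = a / (a - 1.0)
    P = np.eye(X.shape[0]) - X @ X.T
    dX = Q
    dQ = (-gamma * Q
          - X @ (Q.T @ Q)
          - 1.5 * a * P @ Q @ Q.T @ X
          - G
          + 0.5 * (1 + b) * X @ (X.T @ G)
          + 0.5 * (1 - b) * X @ G.T @ X)
    return dX, dQ

rng = np.random.default_rng(6)
n, m, gamma = 6, 3, 1.0
A = rng.standard_normal((n, n))
A = A + A.T
X, _ = np.linalg.qr(rng.standard_normal((n, m)))
W = rng.standard_normal((n, m))
Q = W - X @ (X.T @ W + W.T @ X) / 2     # (X, Q) lies on T St
G = -A @ X                              # gradient of f(X) = -Tr(X^T A X) / 2

rates = {}
for a in (-1.0, 0.0, 0.5):              # any a < 1
    dX, dQ = ode7_rhs(X, Q, G, gamma, a)
    d_orth = dX.T @ X + X.T @ dX                          # d/dt (X^T X)
    d_tan = dX.T @ Q + X.T @ dQ + dQ.T @ X + Q.T @ dX     # d/dt (X^T Q + Q^T X)
    rates[a] = (np.abs(d_orth).max(), np.abs(d_tan).max())
print(rates)
```

Expanding by hand, the gradient terms cancel pairwise in $X^\top\dot Q + \dot Q^\top X$ for any $b$, and $-XQ^\top Q$ cancels the $2Q^\top Q$ coming from $\dot X = Q$, leaving $-\gamma(X^\top Q + Q^\top X) = 0$; the printed rates should therefore sit at rounding level.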

2.2. STRUCTURE-PRESERVING DISCRETIZATION VIA VARIABLE DECOMPOSITION AND OPERATOR SPLITTING

Although the continuous dynamics (7) preserves the constraints, such preservation is in general not guaranteed when time is discretized. The construction of our discretization follows four steps: a geometric decomposition of momentum, a carefully designed operator splitting scheme to approximate the ODEs, structure-preserving approximations of the split system, and a structure-preserving relaxation that further reduces the computational cost. Details follow. Preparation: from a static decomposition of the tangent space to a decomposition of the Q dynamics. To retain the preservation of geometric structures through time discretization, we first decompose the tangent space $T_X\mathrm{St}$ into $X$ and $X^\perp$ components (see more details in Tagare (2011)); more precisely, given a tangent vector $Q$ represented by an $n\times m$ matrix, we rewrite it as $Q = XY + V$ for $Y \in \mathbb{R}^{m\times m}$, $V \in \mathbb{R}^{n\times m}$, and use $Y, V$ to replace $Q$. This transformation changes the constraint $X^\top Q + Q^\top X = 0$ into $\{Y^\top + Y = 0,\ X^\top V = 0\}$ (see Apdx. E.1). Another advantage of this decomposition is that $Y, V$ naturally split the canonical-type metric (2) (see Apdx. E.2). Although a fixed tangent vector $Q$ can be uniquely decomposed into $Y$ and $V$, when we consider a dynamical $Q(t)$ and make the decomposition at each $t$, we need to understand how the corresponding $Y(t)$ and $V(t)$ evolve. This is not a trivial question, but it can be proved that the new dynamics is

$$\begin{cases} \dot X = XY + V, \\ \dot Y = -\gamma Y - \frac{1-b}{2}\left(X^\top\frac{\partial f}{\partial X} - \left(\frac{\partial f}{\partial X}\right)^\top X\right), \\ \dot V = -\gamma V + \frac{3a-2}{2}VY - XV^\top V - (I - XX^\top)\frac{\partial f}{\partial X}, \end{cases} \tag{8}$$

with initial condition satisfying $X(0)^\top X(0) = I$, $Y(0)^\top + Y(0) = 0$, $X(0)^\top V(0) = 0$. This system is equivalent to (7) via $Q(t) = X(t)Y(t) + V(t)$, and it preserves $X(t)^\top X(t) = I$, $Y(t)^\top + Y(t) = 0$, $X(t)^\top V(t) = 0$ for all $t$ (Thm. 8). Moreover, it is amenable to a good discretization and is thus the basis for constructing our optimizer.
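Since $X^\top X = I$, the decomposition has the explicit form $Y = X^\top Q$ and $V = (I - XX^\top)Q$: left-multiply $Q = XY + V$ by $X^\top$ and use $X^\top V = 0$. The following NumPy sketch, with arbitrary sizes of our choosing, checks that this $Y$ is skew-symmetric, that $X^\top V = 0$, and that the canonical-type metric (2) splits as $g_X(Q,Q) = (1-a)\|Y\|_F^2 + \|V\|_F^2$, which is one concrete sense in which $Y$ and $V$ split the metric.

```python
import numpy as np

rng = np.random.default_rng(5)
n, m, a = 7, 3, 0.5
X, _ = np.linalg.qr(rng.standard_normal((n, m)))
W = rng.standard_normal((n, m))
Q = W - X @ (X.T @ W + W.T @ X) / 2    # a tangent vector at X

# Decomposition Q = X Y + V
Y = X.T @ Q                            # X-component, m x m
V = Q - X @ Y                          # X-perp component, n x m

skew_err = np.abs(Y + Y.T).max()
perp_err = np.abs(X.T @ V).max()
recon_err = np.abs(X @ Y + V - Q).max()

# The canonical-type metric splits: g_X(Q, Q) = (1 - a) ||Y||_F^2 + ||V||_F^2
g_QQ = np.trace(Q.T @ (np.eye(n) - a * X @ X.T) @ Q)
g_split = (1 - a) * np.sum(Y ** 2) + np.sum(V ** 2)
print(skew_err, perp_err, recon_err, g_QQ, g_split)
```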
With this decomposition defined, we can finally make the phrase 'structure preservation' precise:

Definition 2. Structure preservation means that the variables exactly satisfy all constraints for all time. That is, for the 'XQ'-system, $X(t)^\top X(t) = I$ and $X(t)^\top Q(t) + Q(t)^\top X(t) = 0$, $\forall t$; for the 'XYV'-system, $X(t)^\top X(t) = I$, $Y(t)^\top + Y(t) = 0$, and $X(t)^\top V(t) = 0$, $\forall t$.

In comparison, staying exactly on Stiefel, i.e., $X(t)^\top X(t) = I$, is termed 'feasible' (Wen and Yin, 2013; Ablin and Peyré, 2022), but we have additional constraints due to momentum.

Step I: operator splitting of the ODEs. To handle the high nonlinearity of ODE (8) and maintain the preservation of constraints after discretization, we adopt an operator splitting method, based on the general fact that a numerical discretization of an ODE can be obtained by composing the (approximate) flow maps of split ODEs (McLachlan and Quispel, 2002). More precisely, we split the vector field of (8) into a sum of three vector fields, each associated with one of the following ODEs:

$$\begin{cases} \dot X = XY \\ \dot Y = -\gamma Y - \frac{1-b}{2}\left(X^\top\frac{\partial f}{\partial X} - \left(\frac{\partial f}{\partial X}\right)^\top X\right) \\ \dot V = 0 \end{cases} \tag{9}$$

$$\begin{cases} \dot X = 0 \\ \dot Y = 0 \\ \dot V = -\gamma V + \frac{3a-2}{2}VY - (I - XX^\top)\frac{\partial f}{\partial X} \end{cases} \tag{10}$$

$$\begin{cases} \dot X = V \\ \dot Y = 0 \\ \dot V = -XV^\top V \end{cases} \tag{11}$$

Define the corresponding time-$h$ evolution maps $\phi_1, \phi_2, \phi_3$ of Eqs. (9), (10), (11) as $\phi_j : [X(t), Y(t), V(t)] \mapsto [X(t+h), Y(t+h), V(t+h)]$ for system $j = 1, 2, 3$. Note that $\phi_1, \phi_2, \phi_3$ give the exact solutions of these split ODEs. Then we see that our specific splitting honors all constraints:

Theorem 3. $\phi_1, \phi_2, \phi_3$ are all structure preserving. Proof is in Apdx. J.

Step II: structure-preserving approximation of exact flow maps. Due to the nonlinearity, $\phi_1$ and $\phi_3$ do not admit analytical expressions; $\phi_2$ does have an explicit expression (Eq. (30)), but an approximation still reduces the computational cost while maintaining a certain accuracy (see Fig. 4).
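The splitting in Step I can be checked numerically: the three fields (9), (10), (11) should add up to (8), and each should individually keep the time derivatives of $X^\top X$, $Y + Y^\top$ and $X^\top V$ at zero on the constraint set, an infinitesimal counterpart of Theorem 3. This NumPy sketch uses a random matrix `G` in place of $\partial f/\partial X$ (an assumption; the identities hold for any G), and the helper names are ours.

```python
import numpy as np

def split_fields(X, Y, V, G, gamma, a):
    """Vector fields (dX, dY, dV) of the split ODEs (9), (10), (11); G = df/dX at X."""
    b = a / (a - 1.0)
    P = np.eye(X.shape[0]) - X @ X.T
    S = X.T @ G
    zX, zY, zV = np.zeros_like(X), np.zeros_like(Y), np.zeros_like(V)
    f9 = (X @ Y, -gamma * Y - 0.5 * (1 - b) * (S - S.T), zV)
    f10 = (zX, zY, -gamma * V + 0.5 * (3 * a - 2) * V @ Y - P @ G)
    f11 = (V, zY, -X @ V.T @ V)
    return f9, f10, f11

def full_field(X, Y, V, G, gamma, a):
    """Vector field of ODE (8)."""
    b = a / (a - 1.0)
    P = np.eye(X.shape[0]) - X @ X.T
    S = X.T @ G
    dX = X @ Y + V
    dY = -gamma * Y - 0.5 * (1 - b) * (S - S.T)
    dV = -gamma * V + 0.5 * (3 * a - 2) * V @ Y - X @ V.T @ V - P @ G
    return dX, dY, dV

def constraint_rates(X, Y, V, dX, dY, dV):
    """Time derivatives of X^T X, Y + Y^T and X^T V along (dX, dY, dV)."""
    return (dX.T @ X + X.T @ dX, dY + dY.T, dX.T @ V + X.T @ dV)

rng = np.random.default_rng(4)
n, m, gamma, a = 7, 3, 1.0, 0.5
X, _ = np.linalg.qr(rng.standard_normal((n, m)))
W = rng.standard_normal((m, m))
Y = W - W.T                                      # skew-symmetric
V = rng.standard_normal((n, m))
V = V - X @ (X.T @ V)                            # X^T V = 0
G = rng.standard_normal((n, m))                  # stand-in for df/dX

fields = split_fields(X, Y, V, G, gamma, a)
max_rate = max(np.abs(r).max() for fld in fields for r in constraint_rates(X, Y, V, *fld))
summed = [sum(fld[k] for fld in fields) for k in range(3)]
sum_err = max(np.abs(s - t).max() for s, t in zip(summed, full_field(X, Y, V, G, gamma, a)))
print(max_rate, sum_err)
```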
We denote the 1st-order approximations of the exact flow maps $\phi_1, \phi_2, \phi_3$ by $\hat\phi_1, \hat\phi_2, \hat\phi_3$, where $\hat\phi_j : [X_0, Y_0, V_0] \mapsto [X_h, Y_h, V_h] = \phi_j([X_0, Y_0, V_0]) + O(h^2)$, $j = 1, 2, 3$. Then

$$\hat\phi_1:\ \begin{cases} X_h = X_0\,\mathrm{expm}(hY_h) \\ Y_h = \exp(-\gamma h)Y_0 - \frac{(1-b)(1-\exp(-\gamma h))}{2\gamma}\left(X_0^\top\frac{\partial f}{\partial X_0} - \left(\frac{\partial f}{\partial X_0}\right)^\top X_0\right) \\ V_h = V_0 \end{cases} \tag{12}$$

$$\hat\phi_3:\ \begin{cases} X_\dagger = X_0 + hV_0X_0^\top X_0 \\ X_h = X_\dagger(X_\dagger^\top X_\dagger)^{-1/2} \\ Y_h = Y_0 \\ V_h = V_0 - hX_0V_0^\top V_0 \end{cases} \tag{13}$$

$$\hat\phi_2:\ X_h = X_0,\quad V_h = (1 - \gamma h)V_0 + \frac{3a-2}{2}hV_0Y_0 - h(I - X_0X_0^\top)\frac{\partial f}{\partial X}(X_0),\quad Y_h = Y_0. \tag{14}$$

Theorem 4. $\hat\phi_1, \hat\phi_2, \hat\phi_3$ are all structure preserving. Proof is in Apdx. K.

This shows that not only the split flow maps but also their carefully designed discretizations maintain the constraints of $X, Y, V$. Several specially designed features ensure the numerical stability of the integrators: 1) $X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$ in $\hat\phi_3$ is part of a nontrivial discretization scheme and coincides with the polar retraction (Absil et al., 2009). It directly makes $\hat\phi_3$ preserve the geometry of the position variable; i.e., even if the input of $\hat\phi_3$ has an error such that $X_0^\top X_0 \ne I$, the output always satisfies $X_h^\top X_h = I$, and this does not impair the $O(h^2)$ order of the local discretization error. See more about its connection to numerical stability against arithmetic errors in Apdx. M.1. 2) In the first equation of $\hat\phi_3$, $X_0 + hV_0X_0^\top X_0$ is used instead of $X_0 + hV_0$, even though most of the time $X_0^\top X_0 = I$. This guarantees that even if $X_0^\top X_0 \ne I$, the constraint $X_0^\top V_0 = 0$ still leads to $X_h^\top V_h = 0$, which improves numerical stability. Moreover, combined with 1), this property also enables the use of a cheaper, relaxed version of $\hat\phi_1$ in Step III. 3) The forward-Euler-like discretization of $\phi_3$ is carefully chosen; updating $V_h$ using $X_h$ instead of $X_0$, for example, would destroy its structure preservation.
4) No extra handling other than forward Euler is applied to the momentum variables $Y, V$; i.e., this discretization by itself guarantees momentum structure preservation.

Step III: relaxation and composition of the operators. Our goal is to obtain a structure-preserving scheme for the original ODE (8), rather than to require structure preservation of each operator for the split ODEs (9), (10), (11). The latter ensures the former, as we have:

Theorem 5. The composition of any ordering of $\hat\phi_1, \hat\phi_2, \hat\phi_3$, e.g., $\hat\phi_1 \circ \hat\phi_2 \circ \hat\phi_3$, is a structure-preserving, 1st-order (in $h$) numerical integrator of (8). Proof is in Apdx. L.

However, the latter (all operators preserving structure) is not needed for the former (their composition preserving structure). We can still obtain a structure-preserving integrator, without the costly expm, by relaxing some of the structure preservation requirements for $\hat\phi_1$.

Theorem 6. Consider a consistent approximation of $\hat\phi_1$ by $\bar\phi_1$ (i.e., $\bar\phi_1 = \hat\phi_1 + O(h^2)$). Assume $\bar\phi_1$ satisfies that if initially $X_0^\top X_0 = I$, $Y_0^\top + Y_0 = 0$, $X_0^\top V_0 = 0$, then $Y_h^\top + Y_h = 0$ and $X_h^\top V_h = 0$. Then the specific composition $\hat\phi_3 \circ \bar\phi_1 \circ \hat\phi_2$ is structure preserving. Moreover, the composition of any ordering of $\bar\phi_1, \hat\phi_2, \hat\phi_3$ is a 1st-order (in $h$) integrator of (8). Proof is in Apdx. M.3.

Therefore, our default recommendation is to use $\hat\phi_3 \circ \bar\phi_1 \circ \hat\phi_2$, where $\bar\phi_1$ is

$$\bar\phi_1:\ X_h = X_0 + hX_0Y_h,\quad Y_h = (1 - \gamma h)Y_0 - \frac{1-b}{2}h\left(X_0^\top\frac{\partial f}{\partial X_0} - \left(\frac{\partial f}{\partial X_0}\right)^\top X_0\right),\quad V_h = V_0.$$

See Apdx. M.4 for more information. A few side remarks follow. If $m$ is small, then $\hat\phi_1 \circ \phi_2 \circ \hat\phi_3$ (or any permutation) can also be used, although experimentally no clear advantage was observed (Fig. 4); otherwise, the computational cost of expm can become prohibitive. Moreover, the computational cost of the matrix inverse square root in the polar retraction, $(X_\dagger^\top X_\dagger)^{-1/2}$, should not be a concern, since it is only for $m\times m$ matrices instead of $n\times n$ ($n \ge m$). Meanwhile, it can be computed to machine precision rapidly using a quadratically convergent iterative method (Higham, 1997); see Algo. 3 in Apdx. M.1 for details, and Apdx. C for computational complexity.

Algorithm 1: Momentum (S)GD on St(n, m) (for SGD, $\partial f/\partial X$ is replaced by a stochastic estimator)
Hyperparameters: $\eta \in (0, +\infty)$, $\mu \in [0, 1)$, maximum number of iterations $N$
Initialization: $X_0, U_0, Z_0$ s.t. $X_0^\top X_0 = I$, $X_0^\top U_0 = 0$, $Z_0 + Z_0^\top = 0$
for $i = 0, \dots, N-1$ do
  Compute 'gradients': $f_i = \frac{1-b}{2}\left(X_i^\top\frac{\partial f}{\partial X}(X_i) - \frac{\partial f}{\partial X}^\top(X_i)X_i\right)$; $g_i = (I - X_iX_i^\top)\frac{\partial f}{\partial X}(X_i)$
  Update $\hat\phi_2$: $U_{i+\frac12} = \mu U_i - \frac{3a-2}{2}\eta U_iZ_i - g_i$
  Update $\bar\phi_1$: $Z_{i+1} = \mu Z_i - f_i$; $X_{i+\frac12} = X_i + \eta X_iZ_i$
  Update $\hat\phi_3$: $X_\dagger = X_{i+\frac12} + \eta U_{i+\frac12}X_{i+\frac12}^\top X_{i+\frac12}$; compute $(X_\dagger^\top X_\dagger)^{-1/2}$ using Algo. 3; $X_{i+1} = X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$; $U_{i+1} = U_{i+\frac12} - \eta X_{i+\frac12}U_{i+\frac12}^\top U_{i+\frac12}$
end
return $X_N$
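Putting the maps together, here is a minimal NumPy sketch of one step of $\hat\phi_3 \circ \bar\phi_1 \circ \hat\phi_2$, written directly in terms of the step size $h$ and friction $\gamma$ from Eqs. (13), (14) and the definition of $\bar\phi_1$, rather than the rescaled $\eta, \mu$ of Algorithm 1. Two parts are our assumptions: a coupled Newton-Schulz iteration (with Frobenius-norm scaling) stands in for Algo. 3, whose details are not reproduced here, and the test problem is the leading-eigenspace objective $f(X) = -\frac12\mathrm{Tr}(X^\top AX)$ with a hand-picked spectrum. The loop records the worst constraint violation over all iterations.

```python
import numpy as np

def inv_sqrt_newton_schulz(M, tol=1e-12, max_iter=100):
    """M^{-1/2} for symmetric positive definite M via the coupled Newton-Schulz iteration
    (quadratically convergent). Scaling by the Frobenius norm puts all eigenvalues in (0, 1]."""
    I = np.eye(M.shape[0])
    s = np.linalg.norm(M)
    Y, Z = M / s, I.copy()
    for _ in range(max_iter):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y, Z = Y @ T, T @ Z              # Y -> (M/s)^{1/2}, Z -> (M/s)^{-1/2}
        if np.linalg.norm(T - I) < tol:
            break
    return Z / np.sqrt(s)

def momentum_stiefel_step(X, Y, V, grad, h, gamma, a):
    """One step of phi3_hat o phi1_bar o phi2_hat (phi2_hat is applied first)."""
    b = a / (a - 1.0)
    G = grad(X)
    S = X.T @ G
    # phi2_hat, Eq. (14): only V changes.
    V = (1 - gamma * h) * V + 0.5 * (3 * a - 2) * h * (V @ Y) - h * (G - X @ S)
    # phi1_bar: update Y, then X with the new Y (X is slightly off St afterwards).
    Y = (1 - gamma * h) * Y - 0.5 * (1 - b) * h * (S - S.T)
    X = X + h * (X @ Y)
    # phi3_hat, Eq. (13): both updates use the input X of this map.
    X_dag = X + h * V @ (X.T @ X)
    V = V - h * X @ (V.T @ V)
    X = X_dag @ inv_sqrt_newton_schulz(X_dag.T @ X_dag)   # polar retraction
    return X, Y, V

rng = np.random.default_rng(0)
n, m = 8, 3
R, _ = np.linalg.qr(rng.standard_normal((n, n)))
eigs = np.array([1.0, 0.9, 0.8, 0.1, 0.1, 0.1, 0.1, 0.1])
A = R @ np.diag(eigs) @ R.T
f = lambda X: -0.5 * np.trace(X.T @ A @ X)    # minimized by the top-3 eigenspace of A
grad = lambda X: -A @ X
f_opt = -0.5 * eigs[:m].sum()

X, _ = np.linalg.qr(rng.standard_normal((n, m)))
Y, V = np.zeros((m, m)), np.zeros((n, m))
max_err = [0.0, 0.0, 0.0]
for _ in range(3000):
    X, Y, V = momentum_stiefel_step(X, Y, V, grad, h=0.1, gamma=1.0, a=0.5)
    errs = (np.linalg.norm(X.T @ X - np.eye(m)), np.linalg.norm(Y + Y.T), np.linalg.norm(X.T @ V))
    max_err = [max(e0, e1) for e0, e1 in zip(max_err, errs)]

# Compare the Newton-Schulz inverse square root with an eigendecomposition reference.
M = rng.standard_normal((m, m))
M = M @ M.T + m * np.eye(m)
w, E = np.linalg.eigh(M)
ns_err = np.abs(inv_sqrt_newton_schulz(M) - (E / np.sqrt(w)) @ E.T).max()
print(f(X), f_opt, max_err, ns_err)
```

Note that X leaves St inside $\bar\phi_1$ and is only brought back by the polar retraction in $\hat\phi_3$, which is exactly the relaxation that Theorem 6 permits; the momentum variables receive no projection at all.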

2.3. AN ADAPTIVE LEARNING RATE VERSION

Tuning for the best SGD learning rate can be labor-intensive and computationally unaffordable, and sometimes SGD even performs worse than adaptive methods (Zhang et al., 2020). Thus, in this section, we propose an adaptive version of our Stiefel optimizer. More precisely, we establish, as an example, a Stiefel version of Adam (Kingma and Ba, 2015), which estimates the 1st and 2nd moments of gradients to obtain element-wise adaptive step sizes. The algorithm is built on the following ideas. The 'gradient' in this Adam version is constructed from our Stiefel SGD with momentum (Algo. 1), where the 'gradients' in the Y/Z and V/U directions can be interpreted as $\frac{1-b}{2}\left(X^\top\frac{\partial f}{\partial X} - \left(\frac{\partial f}{\partial X}\right)^\top X\right)$ and $(I - X(X^\top X)^{-1}X^\top)\frac{\partial f}{\partial X}$, respectively. The main difficulty in extending Stiefel SGD to Stiefel Adam, which does not appear in the Euclidean case, is that element-wise operations on the momentum destroy the tangent vector structure. We resolve this for each direction: (1) In the Y/Z direction, skew-symmetry is preserved under the symmetric element-wise operation $Z \oslash (p^{\circ\frac12} + \epsilon)$, where $p$ is symmetric. (2) In the V/U direction, we apply the projection $I - X(X^\top X)^{-1}X^\top$ to the element-wise rescaled momentum $U \oslash (q^{\circ\frac12} + \epsilon)$, ensuring that $X^\top V = 0$. Combining all of the above, we obtain the Adam-Stiefel optimizer. Denote by $\tilde\phi_1, \tilde\phi_2, \tilde\phi_3$ the modifications of $\phi_1, \phi_2, \phi_3$ in the Adam version (see detailed expressions in Apdx. M.5). Then the integrator is defined as $\tilde\phi_3 \circ \tilde\phi_1 \circ \tilde\phi_2$. The overall method is shown in Algo. 2.

Theorem 7. The Adam-Stiefel integrator $\tilde\phi_3 \circ \tilde\phi_1 \circ \tilde\phi_2$ is structure-preserving. Proof is in Apdx. N.
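The two fixes for element-wise operations can be seen in isolation. In this NumPy sketch (random data and `eps` of our choosing), dividing a skew-symmetric Z element-wise by a symmetric positive array keeps it skew, whereas element-wise rescaling of a U with $X^\top U = 0$ generally breaks that orthogonality until the projection $I - X(X^\top X)^{-1}X^\top$ is applied.

```python
import numpy as np

rng = np.random.default_rng(2)
n, m, eps = 6, 3, 1e-8

X, _ = np.linalg.qr(rng.standard_normal((n, m)))
Z = rng.standard_normal((m, m))
Z = Z - Z.T                                   # skew-symmetric momentum (Y/Z direction)
F = rng.standard_normal((m, m))
F = F - F.T
p = F ** 2                                    # symmetric 2nd-moment estimate, like f_i^{o2}
U = rng.standard_normal((n, m))
U = U - X @ (X.T @ U)                         # momentum with X^T U = 0 (V/U direction)
q = rng.standard_normal((n, m)) ** 2          # arbitrary positive 2nd-moment estimate

# (1) Element-wise division by a symmetric array keeps Z skew-symmetric.
Z_scaled = Z / (np.sqrt(p) + eps)
skew_err = np.abs(Z_scaled + Z_scaled.T).max()

# (2) Element-wise rescaling generally breaks X^T U = 0 ...
U_scaled = U / (np.sqrt(q) + eps)
broken = np.abs(X.T @ U_scaled).max()
# ... and the projection I - X (X^T X)^{-1} X^T restores it.
P = np.eye(n) - X @ np.linalg.solve(X.T @ X, X.T)
U_proj = P @ U_scaled
restored = np.abs(X.T @ U_proj).max()
print(skew_err, broken, restored)
```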

3. EXPERIMENTS

This section demonstrates our Stiefel optimizers on Projection Robust Wasserstein Distance (Paty and Cuturi, 2019; Lin et al., 2020) and on a trained-from-scratch Vision Transformer (Dosovitskiy et al., 2020). They are also compared with other popular Stiefel optimizers summarized in Tab. 2. The canonical metric (i.e., $a = 1/2$ in Eq. (2)) is used in both examples, to show that the performance gain is due to algorithmic innovation rather than an extra tuned knob. An additional experiment on the leading eigenvalue problem is deferred to Apdx. Q, where we also compare the convergence rates and time consumption of different algorithms, test various choices of metrics, and study why our retraction is tailored to our algorithm. Good performance is observed in all three examples. More experimental details are in Apdx. P.

Algorithm 2: Adam on Stiefel manifold St(n, m)
Hyperparameters: $\eta \in (0, +\infty)$, $\beta_1 \in [0, 1)$, $\beta_2 \in [0, 1)$, $0 < \epsilon \ll 1$, number of iterations $N$
Initialization: $X_0, U_0, Z_0, p_0, q_0$ s.t. $X_0^\top X_0 = I$, $X_0^\top U_0 = 0$, $Z_0 + Z_0^\top = 0$, $p_0 = p_0^\top$
for $i = 0, \dots, N-1$ do
  Compute 'gradients': $f_i = \frac{1-b}{2}\left(X_i^\top\frac{\partial f}{\partial X}(X_i) - \frac{\partial f}{\partial X}(X_i)^\top X_i\right)$; $g_i = (I - X_iX_i^\top)\frac{\partial f}{\partial X}(X_i)$
  2nd-moment estimation: $p_{i+1} = \beta_2 p_i + (1 - \beta_2)f_i^{\circ 2}$; $q_{i+1} = \beta_2 q_i + (1 - \beta_2)g_i^{\circ 2}$
  Update $\tilde\phi_2$: $U_{i+\frac12} = \beta_1 U_i - \frac{3a-2}{2}\eta U_iZ_i - (1 - \beta_1)g_i$
  Update $\tilde\phi_1$: $Z_{i+1} = \beta_1 Z_i - (1 - \beta_1)f_i$; $X_{i+\frac12} = X_i + \eta\sqrt{1 - \beta_2^{i+1}}\,X_i\left(Z_{i+1} \oslash (p_{i+1}^{\circ\frac12} + \epsilon)\right)$
  Update $\tilde\phi_3$: $\tilde U = \sqrt{1 - \beta_2^{i+1}}\,\left(I - X_{i+\frac12}(X_{i+\frac12}^\top X_{i+\frac12})^{-1}X_{i+\frac12}^\top\right)\left(U_{i+\frac12} \oslash (q_{i+1}^{\circ\frac12} + \epsilon)\right)$; $X_\dagger = X_{i+\frac12} + \eta\tilde U X_{i+\frac12}^\top X_{i+\frac12}$; $X_{i+1} = X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$; $U_{i+1} = U_{i+\frac12} - \eta X_{i+\frac12}\tilde U^\top U_{i+\frac12}$
end
return $X_N$
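For concreteness, here is a direct NumPy transcription of Algorithm 2 as listed, used to check Theorem 7 numerically on the leading-eigenspace objective $f(X) = -\frac12\mathrm{Tr}(X^\top AX)$. Assumptions: the inverse square root uses `numpy.linalg.eigh` instead of Algo. 3, the hyperparameter values are illustrative rather than tuned, and the bias-correction factor is read as $\sqrt{1 - \beta_2^{i+1}}$. It is a sketch, not a reference implementation.

```python
import numpy as np

def inv_sqrt_sym(M):
    """M^{-1/2} for symmetric positive definite M via eigh (stand-in for Algo. 3)."""
    w, E = np.linalg.eigh(M)
    return (E / np.sqrt(w)) @ E.T

def adam_stiefel_step(X, U, Z, p, q, G, i, eta, beta1, beta2, a, eps=1e-8):
    """One iteration of Algorithm 2 as listed; G is df/dX evaluated at X."""
    b = a / (a - 1.0)
    S = X.T @ G
    f_i = 0.5 * (1 - b) * (S - S.T)               # skew 'gradient' (Y/Z direction)
    g_i = G - X @ S                               # (I - X X^T) G (V/U direction)
    p = beta2 * p + (1 - beta2) * f_i ** 2        # stays symmetric
    q = beta2 * q + (1 - beta2) * g_i ** 2
    corr = np.sqrt(1 - beta2 ** (i + 1))
    U_half = beta1 * U - 0.5 * (3 * a - 2) * eta * (U @ Z) - (1 - beta1) * g_i
    Z = beta1 * Z - (1 - beta1) * f_i
    X_half = X + eta * corr * X @ (Z / (np.sqrt(p) + eps))   # rescaled Z is still skew
    XtX = X_half.T @ X_half
    P = np.eye(X.shape[0]) - X_half @ np.linalg.solve(XtX, X_half.T)
    U_tld = corr * P @ (U_half / (np.sqrt(q) + eps))         # re-project after rescaling
    X_dag = X_half + eta * U_tld @ XtX
    U = U_half - eta * X_half @ (U_tld.T @ U_half)
    X = X_dag @ inv_sqrt_sym(X_dag.T @ X_dag)
    return X, U, Z, p, q

rng = np.random.default_rng(0)
n, m = 8, 3
R, _ = np.linalg.qr(rng.standard_normal((n, n)))
eigs = np.array([1.0, 0.9, 0.8, 0.1, 0.1, 0.1, 0.1, 0.1])
A = R @ np.diag(eigs) @ R.T
f = lambda X: -0.5 * np.trace(X.T @ A @ X)
f_opt = -0.5 * eigs[:m].sum()

X, _ = np.linalg.qr(rng.standard_normal((n, m)))
U, Z = np.zeros((n, m)), np.zeros((m, m))
p, q = np.zeros((m, m)), np.zeros((n, m))
f_init = f(X)
max_err = [0.0, 0.0, 0.0]
for i in range(3000):
    X, U, Z, p, q = adam_stiefel_step(X, U, Z, p, q, -A @ X, i,
                                      eta=0.01, beta1=0.9, beta2=0.99, a=0.5)
    errs = (np.linalg.norm(X.T @ X - np.eye(m)), np.linalg.norm(Z + Z.T), np.linalg.norm(X.T @ U))
    max_err = [max(e0, e1) for e0, e1 in zip(max_err, errs)]
print(f_init, f(X), f_opt, max_err)
```

The constraint on U survives the final step because both $X_{i+\frac12}^\top U_{i+\frac12}$ and $\tilde U^\top X_{i+\frac12}$ vanish, so the cross terms in $X_\dagger^\top U_{i+1}$ cancel exactly, mirroring the non-adaptive case.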

3.1. PROJECTION ROBUST WASSERSTEIN DISTANCE

Projection Robust Wasserstein Distance (PRW) (Paty and Cuturi, 2019; Lin et al., 2020) is a recently proposed notion that improves the robustness of the standard Wasserstein metric, especially in high-dim. settings. The idea is to simultaneously look for the best projection of high-dim. data ($x_i$, $y_j$, with weights $r_i$, $c_j$, respectively) to low-dim. ones, and an entropy-regularized optimal transport plan between the projected data, i.e.,
$$\max_{U \in St(d,k)}\ \min_{\pi \in \mathbb{R}_+^{n\times n},\ \sum_j \pi_{ij} = r_i,\ \sum_i \pi_{ij} = c_j}\ \sum_{i=1}^n \sum_{j=1}^n \pi_{i,j}\left\|U^\top x_i - U^\top y_j\right\|^2 + \eta\left\langle \pi, \log(\pi) - \mathbf{1}_n \mathbf{1}_n^\top\right\rangle$$
This problem is geodesically-nonconvex w.r.

We test these two types of constraints on vanilla ViT (Dosovitskiy et al., 2020), trained from scratch. Results summarized in Tab. 1 (and Fig. 2 in Apdx. P.3) show: (1) requiring each head to be orthogonal just within itself leads to the most significant improvement in performance; (2) additionally requiring heads to be orthogonal to each other is less effective, although still better than requiring no orthogonality; (3) the choice of optimizer matters: our methods give the best results and outperform non-constrained baselines in all cases, whereas not all existing Stiefel optimizers beat non-constrained baselines in all cases; (4) momentum significantly helps train orthogonal ViT; (5) simply imposing orthogonal constraints (with our optimizer) makes ViT outperform some carefully designed models of the same size that are trained from scratch (Tab. A3 in Zhang et al. (2022)). Note that models trained from scratch are very different from pre-trained ViT. To set realistic expectations for the performance of trained-from-scratch ViT, recall, for example, that an improved ViT, PVT-T (Wang et al., 2021), has 9.49% and 30.38% error on CIFAR-10 and CIFAR-100 using 12.8M parameters, and another improvement, DeiT-T (Touvron et al., 2021), achieves 11.62% and 32.48%, respectively, with 5.3M parameters. Our model in Tab. 1 (and Fig. 2 in Apdx. P.3) uses 6.3M parameters.
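Returning to the PRW objective of Sec. 3.1: for a fixed $U$, the inner minimization over $\pi$ is an entropic optimal transport problem between the projected points, which Sinkhorn iterations (Cuturi, 2013) solve. The sketch below assumes uniform weights, toy Gaussian point clouds, and a hypothetical helper name `projected_entropic_ot`; the outer maximization over $U \in St(d, k)$, which is where a Stiefel optimizer enters, is not shown.

```python
import numpy as np

def projected_entropic_ot(x, y, r, c, U, eta=1.0, n_iter=500):
    """Inner problem of the PRW objective for a fixed U in St(d, k): entropic OT
    between U^T x_i and U^T y_j, solved with plain Sinkhorn iterations."""
    px, py = x @ U, y @ U                                        # rows: U^T x_i, U^T y_j
    C = ((px[:, None, :] - py[None, :, :]) ** 2).sum(axis=-1)   # ||U^T x_i - U^T y_j||^2
    K = np.exp(-C / eta)                                         # Gibbs kernel
    u, v = np.ones_like(r), np.ones_like(c)
    for _ in range(n_iter):
        u = r / (K @ v)          # match row marginals:    sum_j pi_ij = r_i
        v = c / (K.T @ u)        # match column marginals: sum_i pi_ij = c_j
    pi = u[:, None] * K * v[None, :]
    return pi, float((pi * C).sum())

rng = np.random.default_rng(0)
n, d, k = 20, 5, 2
x = 0.5 * rng.standard_normal((n, d))
y = 0.5 * rng.standard_normal((n, d)) + 0.3
r = np.full(n, 1.0 / n)
c = np.full(n, 1.0 / n)
U, _ = np.linalg.qr(rng.standard_normal((d, k)))    # a point on St(d, k)
pi, cost = projected_entropic_ot(x, y, r, c, U)
print(cost)
```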
More details about the ViT experiment can be found in Apdx. P.3.

Remark 1 (No additional hyperparameters needed for the Stiefel part). Our Stiefel-ViT model has both non-Euclidean and Euclidean parameters, the latter because we do not impose orthogonality on the feedforward layers and $W_i^V$. Our optimizer can use the same learning rate for both sets of parameters. This is because the optimizer is the time discretization of ODEs, and synchronous discretization is sufficient for convergence. This contrasts with other approaches such as projected Stiefel SGD/Adam (Li et al., 2020), where the learning rates for these two sets of parameters need to be adjusted separately and could differ by a factor of 40 in some tasks. See Apdx. P.1 for more discussion.

$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$, $\text{Attention}(\tilde Q, \tilde K, \tilde V) = \text{softmax}\left(\frac{\tilde Q \tilde K^\top}{\sqrt{d_k}}\right)\tilde V$, matrices $W_i^Q \in \mathbb{R}^{d_{model}\times d_k}$, $W_i^K \in \mathbb{R}^{d_{model}\times d_k}$, $W_i^V \in \mathbb{R}^{d_{model}\times d_v}$ and $W^O \in \mathbb{R}^{n_{head}d_v \times d_{model}}$ correspond

Published as a conference paper at ICLR 2023

Notation. For simplicity, we denote $G := \frac{\partial f}{\partial X}$ when there is no confusion. If we require exact Stiefel manifold preservation for all time, i.e., $X(t)^\top X(t) = I$, a simple time differentiation gives a 2nd-order constraint $X(t)^\top Q(t) + Q(t)^\top X(t) = 0$ for $Q = \dot X$. This is why the momentum $Q$ has the additional structure $Q(t) \in T^*_{X(t)}\mathrm{St}$.

B.2 WHAT EXACTLY IS MOMENTUM?

Throughout this paper, we abused notation and called $Q := \dot X$ momentum, following common practice in the community. It really should be called velocity instead: although in canonical Euclidean spaces this does not make much of a difference, for mechanical systems on manifolds they are not the same thing. For example, in our case, the geometrically intrinsic momentum variable should be given by the Legendre transform $P := \frac{\partial L}{\partial \dot X} = r(I - aXX^\top)\dot X$ instead, and the velocity $Q$ lies in the tangent space while the momentum $P$ lies in the cotangent space. In the paper we used the identity map as the isomorphism between the tangent and cotangent spaces, but if we would like to make our variational formulation more elegant by viewing the kinetic energy as a natural pairing between the velocity and momentum variables, then the isomorphism should be given by the metric as in (15). None of this affects the correctness or the efficacy of the results in this paper; this discussion is only about terminology.

For the inner loop computing the inverse matrix square root (Algo. 3), due to its quadratic convergence (Higham, 1997), it takes $O(\log\log(1/u))$ steps to achieve machine precision $u$. Note that early stopping can be applied to this loop so that the complexity can be further reduced while the order of the overall method (1st-order) is still maintained. Combining all the above gives complexity $O(nm^2) + O(m^3 \log\log(1/u))$. Commonly used optimizers on the Stiefel manifold (e.g., Tab. 2) also have per-iteration complexity $O(nm^2)$. Although it is difficult to compare them with our method in terms of the constant prefactors in the respective complexity bounds, we can use heuristics to estimate these prefactors, which demonstrates the low-cost advantage of our method. For our Stiefel SGD (Algo. 1), we count the number of matrix multiplications needed per iteration, each of which costs $nm^2$. Namely, step 2: 2 multiplications; step 3: 1; step 4: 1; step 5: 6.
In step 6 we need to call Algo. 3 to compute the inverse matrix square root, and a closer look gives three $m^3$-cost matrix multiplications in each iteration of Algo. 3. Since Fig. 5(c) shows that 8 inner iterations are enough for this inner loop to converge under double precision, in total 10 $nm^2$-cost matrix multiplications and 24 $m^3$-cost matrix multiplications are needed in each (outer) iteration of our optimizer (Algo. 1). We also count the number of matrix multiplications needed by the other aforementioned algorithms. For Momentumless Stiefel (S)GD (Wen and Yin, 2013), the smartly designed Cayley map retraction takes about 10 $nm^2$-cost matrix multiplications and one inversion of a $(2m)\times(2m)$ matrix in each iteration. In comparison, Projected Stiefel SGD (Li et al., 2020) needs 9 $nm^2$-cost matrix multiplications for projecting momentum and 6 $nm^2$-cost matrix multiplications in each iteration of the Cayley loop. Since 8 iterations are needed to compute the Cayley transform to single precision (see Fig. 5; note this is only single precision, rather than the double precision reached by our inner loop), a total of $9 + 6 \times 8 = 57$ $nm^2$-cost matrix multiplications are needed per outer iteration. The above estimation only covers maintaining the structure of the position variable. We now turn to the momentum variable. Our carefully designed way of introducing momentum bypasses the costly transport of momentum while keeping it in tangent spaces. Note that Projected Stiefel SGD first moves momentum in the Euclidean space and then uses a cleverly designed projection, which markedly improves computational efficiency, but it still devotes 9 matrix multiplications to the momentum projection alone, not to mention that relying on projection reduces accuracy. In addition, parallel transport is another existing way of moving momentum, but it cannot even be computed with $O(nm^2)$ complexity.
Altogether, the above estimates heuristically show that our prefactor is close to that of Momentumless Stiefel (S)GD but much smaller than those of existing methods that have momentum. This matches the experimental evidence in Fig. 3.
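The prefactor bookkeeping above is simple arithmetic; the following sketch just reproduces it, charging one unit per $nm^2$- or $m^3$-cost product (an assumption that ignores constants inside each product and all other work). The function names are hypothetical.

```python
def matmul_counts(inner_iters=8):
    """Per-iteration product counts quoted in Apdx. C."""
    ours_nm2 = 2 + 1 + 1 + 6                # steps 2-5 of Algo. 1
    ours_m3 = 3 * inner_iters               # Algo. 3: three m^3-cost products per inner iteration
    projected_nm2 = 9 + 6 * inner_iters     # momentum projection + Cayley loop (Li et al., 2020)
    return ours_nm2, ours_m3, projected_nm2

def rough_cost(n, m, inner_iters=8):
    """Unit cost n*m^2 or m^3 per product; constants inside a product are ignored."""
    ours_nm2, ours_m3, projected_nm2 = matmul_counts(inner_iters)
    ours = ours_nm2 * n * m**2 + ours_m3 * m**3
    projected = projected_nm2 * n * m**2
    return ours, projected

print(matmul_counts())        # (10, 24, 57)
print(rough_cost(1000, 10))   # (1024000, 5700000)
```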

D TERMINOLOGIES OF RIEMANNIAN OPTIMIZATION

In each iteration of optimization, the gradient needs to be a vector in the tangent space, and the operation that maps the non-intrinsic gradient to the tangent space is called projection. Then the variable takes one step on the manifold along the tangent vector, ideally via the exponential map; since the exponential map is expensive or may not even admit a closed-form computation, an approximation known as a retraction is used instead. When momentum, itself a tangent vector, is involved, the tangent space changes as the position is updated. An intrinsic way to move tangent vectors is parallel transport, which is an isomorphism between the tangent spaces of two points connected by a geodesic.

E ABOUT THE Y, V DECOMPOSITION

E.1 THE GEOMETRIC VERSION AND THE DYNAMICAL VERSION OF THEIR STRUCTURAL CONSTRAINTS

A tangent vector $Q \in T_X\mathrm{St}$ can be decomposed, as studied in Edelman et al. (1998); Absil et al. (2009), into
$$Q = \begin{bmatrix} X & X_\perp \end{bmatrix}\begin{bmatrix} Y \\ Z \end{bmatrix} = XY + X_\perp Z,$$
where $X \in St(n,m)$, $Y \in \mathbb{R}^{m\times m}$, $Z \in \mathbb{R}^{(n-m)\times m}$, and $X_\perp \in \mathbb{R}^{n\times(n-m)}$ is a matrix in the orthogonal complement of $X$, i.e., $X^\top X_\perp = 0_{m\times(n-m)}$. Let $V = X_\perp Z$. We have $X \perp V$ and
$$Y = X^\top Q, \qquad V = Q - XY = (I - XX^\top)Q.$$
The above representation immediately implies that $Y$ is skew-symmetric, since $X^\top Q + Q^\top X = 0 \Rightarrow Y^\top = -Y$. In our derivation, however, $X$ and $Q$ are variables (they really should be $X(t) \in \mathrm{St}$, $Q(t) \in T_{X(t)}\mathrm{St}$) that change with time, and this poses a new challenge. More precisely, given the $Q$ dynamics, i.e., $\dot Q$, there could be infinitely many ways of assigning $\dot Y$ and $\dot V$ so that together they produce the given $\dot Q$; however, in general they do not maintain the tangent decomposition structure. That is, starting with $Y(0)^\top + Y(0) = 0$ and $X(0)^\top V(0) = 0$ (and hence $X(0)^\top Q(0) + Q(0)^\top X(0) = 0$), one may not have $Y(t)^\top + Y(t) = 0$ and $X(t)^\top V(t) = 0$ for $t > 0$, even though $X(t)^\top Q(t) + Q(t)^\top X(t) = 0$ is still guaranteed. However, we have found a nontrivial choice of $\dot Y$ and $\dot V$ (Eq. 8) such that the (static) geometric constraint becomes dynamically true, i.e., $Y(t)^\top + Y(t) = 0$ and $X(t)^\top V(t) = 0$ for all $t$. Therefore we can simply maintain the decomposition
$$Y(t) = X(t)^\top Q(t), \qquad V(t) = Q(t) - X(t)Y(t) = (I - X(t)X(t)^\top)Q(t).$$
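The decomposition can be checked numerically. The following NumPy sketch (random matrices, hypothetical sizes) builds $Q = XY + V$, verifies tangency, recovers $Y$ and $V$ from $Q$ via $Y = X^\top Q$ and $V = (I - XX^\top)Q$, and confirms that an explicit complement $X_\perp$ from a complete QR factorization gives the same $V = X_\perp Z$.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 6, 2
X, _ = np.linalg.qr(rng.standard_normal((n, m)))        # X in St(n, m)

# Build a tangent vector Q = X Y + V with Y skew-symmetric and X^T V = 0.
S = rng.standard_normal((m, m))
Y = S - S.T
W = rng.standard_normal((n, m))
V = W - X @ (X.T @ W)
Q = X @ Y + V

# Recover the decomposition from Q alone.
Y_rec = X.T @ Q
V_rec = Q - X @ Y_rec                                   # = (I - X X^T) Q

# Same V via an explicit orthonormal complement X_perp and Z = X_perp^T Q.
Q_full, _ = np.linalg.qr(X, mode="complete")
X_perp = Q_full[:, m:]
Z = X_perp.T @ Q
print(np.abs(X.T @ Q + Q.T @ X).max())                  # tangency residual
```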

E.2 THE ADVANTAGE OF THE Y, V DECOMPOSITION

The $Y, V$ decomposition is near 'orthogonal' in the sense that $XY \perp V$, and it makes the metric separable. More precisely, consider our $Y, V$ representation. For $\Delta_i \in T_X\mathrm{St}$, $i = 1, 2$, if we denote $Y_{\Delta,i} = X^\top \Delta_i$ and $V_{\Delta,i} = (I - XX^\top)\Delta_i$, which means $\Delta_i = XY_{\Delta,i} + V_{\Delta,i}$, then the metric (2) can be rewritten as
$$g_X(\Delta_1, \Delta_2) = \mathrm{Tr}\left((XY_{\Delta,1} + V_{\Delta,1})^\top (I - aXX^\top)(XY_{\Delta,2} + V_{\Delta,2})\right) = (1-a)\,\mathrm{Tr}(Y_{\Delta,1}^\top Y_{\Delta,2}) + \mathrm{Tr}(V_{\Delta,1}^\top V_{\Delta,2}).$$
This means the metric on the tangent space is a linear combination of the metrics of the $Y$ and $V$ directions. This gives an intuition for why our continuous dynamics and numerical discretization can handle the whole family of canonical-type metrics.
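A quick numerical confirmation of the separability identity, assuming random tangent vectors and a few values of $a$ ($a = 0$ gives the Euclidean metric and $a = 1/2$ the canonical one). The helper names are hypothetical.

```python
import numpy as np

def metric(X, D1, D2, a):
    """Canonical-type metric g_X(D1, D2) = Tr(D1^T (I - a X X^T) D2)."""
    return np.trace(D1.T @ (D2 - a * X @ (X.T @ D2)))

def metric_split(X, D1, D2, a):
    """Same metric evaluated through the Y, V decomposition."""
    Y1, Y2 = X.T @ D1, X.T @ D2
    V1, V2 = D1 - X @ Y1, D2 - X @ Y2
    return (1 - a) * np.trace(Y1.T @ Y2) + np.trace(V1.T @ V2)

def random_tangent(X, rng):
    """X S + W_perp with S skew-symmetric and X^T W_perp = 0."""
    m = X.shape[1]
    S = rng.standard_normal((m, m))
    W = rng.standard_normal(X.shape)
    return X @ (S - S.T) + (W - X @ (X.T @ W))

rng = np.random.default_rng(2)
X, _ = np.linalg.qr(rng.standard_normal((7, 3)))
D1, D2 = random_tangent(X, rng), random_tangent(X, rng)
for a in (0.0, 0.5):        # a = 0: Euclidean, a = 1/2: canonical
    print(a, metric(X, D1, D2, a), metric_split(X, D1, D2, a))
```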

F CHOICE OF γ

Although monotonicity of $r$ suffices for convergence, the choice of $\gamma = \dot r/r$ affects the convergence speed, just as in Euclidean cases (e.g., Wibisono et al. (2016); Wilson et al. (2021)). Popular choices include constant $\gamma$ and $\gamma(t) = 3/t$, but the methods proposed here work for all $\gamma$'s (e.g., $\gamma = 3/t + ct^p$ for variance reduction (Tao and Ohsawa, 2020)). For simplicity, however, we use constant $\gamma$ from now on.

G PROOF OF THEOREM 1

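Proof of Theorem 1, Part 1. Given the Lagrangian
$$L(X, \dot X, \Lambda, t) = r(t)\left[\frac12 \mathrm{Tr}\left(\dot X^\top (I - aXX^\top)\dot X\right) - f(X)\right] - \frac12 \mathrm{Tr}\left(\Lambda^\top (X^\top X - I)\right),$$
the Legendre transform gives the momentum $P$ as $P := \frac{\partial L}{\partial \dot X} = r(I - aXX^\top)\dot X$. One can equivalently switch to the Hamiltonian picture; the corresponding Hamiltonian is $H : T^*\mathrm{St} \to \mathbb{R}$ with $H(X, P) := \mathrm{Tr}(P^\top \dot X) - L$, i.e.,
$$H(X, P) = \frac{1}{2r}\mathrm{Tr}\left(P^\top (I - bXX^\top)P\right) + r f(X) + \frac12 \mathrm{Tr}\left(\Lambda^\top (X^\top X - I_{m\times m})\right),$$
where $b := \frac{a}{a-1}$ solves $(I - aXX^\top)(I - bXX^\top) = I$. Hence we get Hamilton's equations
$$\dot X = \frac{\partial H}{\partial P} = \frac1r (I - bXX^\top)P, \qquad \dot P = -\frac{\partial H}{\partial X} = \frac{b}{2r}\left(PP^\top X + XP^\top P\right) - r\frac{\partial f}{\partial X} - X\Lambda.$$
Since $X^\top X \equiv I$, we have $\dot X^\top X + X^\top \dot X = 0$, and hence $X^\top P + P^\top X = 0$. Taking the time derivative gives $\dot X^\top P + X^\top \dot P + \dot P^\top X + P^\top \dot X = 0$; using that $\Lambda$ is symmetric, we can solve for $\Lambda$:
$$\Lambda = \frac{b+2}{2r}P^\top P - \frac{b}{2r}X^\top PP^\top X - \frac r2\left(X^\top G + G^\top X\right),$$
and the $(X, P)$ system becomes
$$\dot X = \frac{\partial H}{\partial P} = \frac1r (I - bXX^\top)P, \qquad \dot P = -\frac{\partial H}{\partial X} = \frac{b}{2r}PP^\top X - \frac1r XP^\top P + \frac{b}{2r}XX^\top PP^\top X - rG + \frac r2\left(XX^\top G + XG^\top X\right). \tag{23}$$
By the coordinate change $Q := \frac1r (I - bXX^\top)P \in T_X\mathrm{St}$ and defining the friction parameter $\gamma := \dot r/r$, Eq. (23) becomes
$$\dot Q = -\frac{\dot r}{r^2}(I - bXX^\top)P + \frac1r (I - bXX^\top)\dot P - \frac br \dot X X^\top P - \frac br X\dot X^\top P = -\gamma Q - XQ^\top Q - \frac{3a}{2}(I - XX^\top)QQ^\top X - G + \frac{1+b}{2}XX^\top G + \frac{1-b}{2}XG^\top X,$$
i.e.,
$$\dot X = Q, \qquad \dot Q = -\gamma Q - XQ^\top Q - \frac{3a}{2}(I - XX^\top)QQ^\top X - G + \frac{1+b}{2}XX^\top G + \frac{1-b}{2}XG^\top X.$$
In order to prove the second part of Thm. 1, we need the following lemma.
Lemma 1 (First-order stationary point). If $X \in \mathrm{St}$ satisfies $G - \frac{1+b}{2}XX^\top G - \frac{1-b}{2}XG^\top X = 0$, then $\mathrm{Tr}(G^\top \Delta) = 0$ for all $\Delta \in T_X\mathrm{St}$, which means $X$ is a first-order stationary point of $f$.
Proof. Left-multiplying both sides of $G - \frac{1+b}{2}XX^\top G - \frac{1-b}{2}XG^\top X = 0$ by $XX^\top$, we get $XX^\top G = XG^\top X$, and further $G = XG^\top X$.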

So we have

$$\mathrm{Tr}(G^\top \Delta) = \mathrm{Tr}(X^\top G X^\top \Delta) = -\mathrm{Tr}(X^\top G \Delta^\top X) = -\mathrm{Tr}(\Delta^\top XX^\top G) = -\mathrm{Tr}(\Delta^\top G).$$
Thus $\mathrm{Tr}(\Delta^\top G) = 0$ for all $\Delta \in T_X\mathrm{St}$, which means $X$ is a first-order stationary point of $f$.
Proof of Theorem 1, Part 2. Let $t \mapsto (X(t), Q(t))$ be a solution of Eq. (7). Define the 'energy' function $E : T\mathrm{St} \to \mathbb{R}$ as
$$E(X, Q) := \frac12 \mathrm{Tr}\left(Q^\top (I - aXX^\top)Q\right) + f(X).$$
This gives a Lyapunov function. So we have a neighbourhood $U$ of $(X^*, 0)$ such that $E(X, Q) \ge f(X) \ge f(X^*)$ for any $(X, Q) \in U$. Moreover,
$$\frac{dE}{dt}(X(t), Q(t)) = \mathrm{Tr}\left(Q^\top (I - aXX^\top)\dot Q\right) - a\,\mathrm{Tr}\left(Q^\top \dot X X^\top Q\right) + \mathrm{Tr}\left(\frac{\partial f}{\partial X}^\top \dot X\right).$$
Using the fact that $X^\top Q + Q^\top X \equiv 0$, we have $\mathrm{Tr}(Q^\top XQ^\top Q) = 0$ and $\mathrm{Tr}(Q^\top XX^\top G) + \mathrm{Tr}(Q^\top XG^\top X) = 0$, which gives
$$\frac{dE}{dt}(X(t), Q(t)) = -\gamma\,\mathrm{Tr}\left(Q^\top (I - aXX^\top)Q\right) \le 0.$$
Since $r$ is monotonically increasing, $\gamma = \dot r/r > 0$ for all $t$. Hence the energy decreases monotonically, implying that $Q = 0$ at convergence. By Lemma 1, the limit point of $X(t)$ is a first-order stationary point, which is $X^*$ since it is an isolated local minimum.
Remark 2. For the sake of length, the rate of convergence in specific situations (e.g., under geodesic convexity) will not be quantified, but it should be obtainable via the tools in Wilson et al. (2021); Duruisseaux and Leok (2021).
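Two facts used in Part 1 can be verified directly. Since $XX^\top$ is an orthogonal projector when $X^\top X = I$, $(I - aXX^\top)(I - bXX^\top) = I - (a + b - ab)XX^\top$, which equals $I$ exactly when $b = a/(a-1)$; consequently $\dot X = \frac1r(I - bXX^\top)P$ inverts the Legendre transform. The NumPy sketch below checks both with random data (sizes and $r$ chosen arbitrarily).

```python
import numpy as np

rng = np.random.default_rng(3)
n, m, a, r = 8, 3, 0.5, 2.0
b = a / (a - 1)                        # canonical metric a = 1/2 gives b = -1
X, _ = np.linalg.qr(rng.standard_normal((n, m)))
Pi = X @ X.T                           # orthogonal projector onto span(X)
I = np.eye(n)

# (I - a Pi)(I - b Pi) = I - (a + b - a b) Pi, and a + b - a b = 0 for b = a/(a-1).
M = (I - a * Pi) @ (I - b * Pi)

# Legendre transform P = r (I - a X X^T) Xdot, inverted by Xdot = (1/r)(I - b X X^T) P.
Xdot = rng.standard_normal((n, m))
P = r * (I - a * Pi) @ Xdot
Xdot_back = (I - b * Pi) @ P / r
print(b, np.abs(M - I).max(), np.abs(Xdot_back - Xdot).max())
```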

H PROOF OF THEOREM 2

Proof. We can derive a new system of ODEs:
$$\frac{d}{dt}(X^\top X) = X^\top Q + Q^\top X,$$
$$\frac{d}{dt}(X^\top Q) = -\gamma X^\top Q + (I - X^\top X)Q^\top Q - \frac{3a}{2}X^\top (I - XX^\top)QQ^\top X - X^\top G + \frac{1+b}{2}X^\top XX^\top G + \frac{1-b}{2}X^\top XG^\top X,$$
where $X^\top G$ is viewed as a matrix function of $t$ following Eq. (7). By the existence and uniqueness theorem for ODEs, this system with variables $(X^\top X, X^\top Q)$ and initial condition $(I, 0)$ has the unique solution $X^\top X \equiv I$, $X^\top Q \equiv 0$.
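Theorem 2 and the energy argument of Theorem 1 can be illustrated numerically by integrating Eq. (7) with a generic, non-structure-preserving RK4 scheme: the constraints $X^\top X = I$ and $X^\top Q + Q^\top X = 0$ should then be violated only at the level of the discretization error, and $E$ should decrease. This sketch assumes the test objective $f(X) = -\frac12\mathrm{Tr}(X^\top AX)$, the canonical metric $a = 1/2$, constant $\gamma = 1$, and $Q(0) = 0$; all helper names are hypothetical.

```python
import numpy as np

def rhs(X, Q, A, gamma, a):
    """Right-hand side of Eq. (7) for the test function f(X) = -0.5 Tr(X^T A X)."""
    b = a / (a - 1)
    G = -A @ X                                            # G = df/dX
    Qdot = (-gamma * Q - X @ (Q.T @ Q)
            - 1.5 * a * (Q - X @ (X.T @ Q)) @ (Q.T @ X)
            - G + 0.5 * (1 + b) * X @ (X.T @ G) + 0.5 * (1 - b) * X @ (G.T @ X))
    return Q, Qdot                                        # (Xdot, Qdot)

def rk4(X, Q, A, gamma=1.0, a=0.5, h=1e-2, steps=300):
    """Classical RK4; not structure-preserving, so any drift is discretization error."""
    for _ in range(steps):
        k1 = rhs(X, Q, A, gamma, a)
        k2 = rhs(X + 0.5 * h * k1[0], Q + 0.5 * h * k1[1], A, gamma, a)
        k3 = rhs(X + 0.5 * h * k2[0], Q + 0.5 * h * k2[1], A, gamma, a)
        k4 = rhs(X + h * k3[0], Q + h * k3[1], A, gamma, a)
        X = X + h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        Q = Q + h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
    return X, Q

def energy(X, Q, A, a=0.5):
    """E(X, Q) = 0.5 Tr(Q^T (I - a X X^T) Q) + f(X)."""
    return 0.5 * np.trace(Q.T @ (Q - a * X @ (X.T @ Q))) - 0.5 * np.trace(X.T @ A @ X)

rng = np.random.default_rng(4)
n, m = 6, 2
B = rng.standard_normal((n, n))
A = (B + B.T) / (2 * np.sqrt(n))
X0, _ = np.linalg.qr(rng.standard_normal((n, m)))
Q0 = np.zeros((n, m))                                     # X0^T Q0 + Q0^T X0 = 0
X1, Q1 = rk4(X0, Q0, A)
print(np.abs(X1.T @ X1 - np.eye(m)).max(), np.abs(X1.T @ Q1 + Q1.T @ X1).max())
print(energy(X0, Q0, A), energy(X1, Q1, A))
```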

L PROOF OF THEOREM 5

Proof. The composition of structure-preserving maps is structure-preserving. $\phi_1 \circ \phi_2 \circ \phi_3$ is a 1st-order integrator due to operator splitting theory (McLachlan and Quispel, 2002), and its convergence order is kept after any $\phi_j$ is replaced by $\tilde\phi_j$, as long as the difference between $\tilde\phi_j$ and $\phi_j$ is of higher order (Tao, 2016).

M DISCUSSIONS ON OUR NUMERICAL INTEGRATOR

M.1 BENEFITS OF THE STEP $X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$

To keep the position $X$ on the manifold, techniques that pull points back to the manifold are needed. For example, in Li et al. (2020); Lin et al. (2020), for an $n$-by-$m$ matrix $X$ that is not on the manifold, a QR decomposition $X = QR$ is performed and the new $X$ is taken to be the first $m$ columns of $Q$. However, the QR decomposition is not unique, and such a step does not have a closed-form expression. More importantly, it cannot help design a structure-preserving scheme. Instead, our algorithm has a step with a similar function via the inverse matrix square root, $X_h = X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$, which gives $X_h^\top X_h = I$. This step is carefully derived from the discretization of the ODE, ensuring that the structures of $X$, $Y$, $V$ are preserved at the same time. Through this update, we obtain a structure-preserving method with a closed-form expression (see the proof of Theorem 4 in Apdx. K). Note that the inverse matrix square root can be computed iteratively at low cost (see Algo. 3; Higham (1997)). This iteration is proved to be quadratically convergent, which means only a couple of iterations are needed to reach machine precision, and it requires only matrix multiplications. This makes our inner loop more efficient than that of Li et al. (2020), in both cost per iteration and convergence speed. See Sec. Q for more details.
  $Y_{k+1} = \frac12 Y_k(3I - Z_kY_k)$
  $Z_{k+1} = \frac12 (3I - Z_kY_k)Z_k$
  $k \leftarrow k + 1$
end
return $Y_k \approx A^{1/2}$ and $Z_k \approx A^{-1/2}$

Another benefit of this step is that it is more robust to the truncation error produced by finite machine precision (which can become a significant issue if the model is trained in single precision on consumer graphics cards, or is even quantized). Although $X_\dagger^\top X_\dagger$ satisfies
$$X_\dagger^\top X_\dagger = (X_0 + hV_0)^\top (X_0 + hV_0) = X_0^\top X_0 + hX_0^\top V_0 + hV_0^\top X_0 + h^2 V_0^\top V_0 = I + h^2 V_0^\top V_0,$$
the stability of $X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$ is much better than that of $X_\dagger(I + h^2V_0^\top V_0)^{-1/2}$. In particular, when $n = m$, it is the identity map if there is no machine error.
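Algo. 3 and the polar step can be combined into a short sketch. The code below follows the coupled Newton-Schulz updates listed above, applies them to $A = X_\dagger^\top X_\dagger = I + h^2V_0^\top V_0$, and retracts $X_\dagger$ to the Stiefel manifold. The helper name, sizes, step $h$, tolerance, and iteration cap are assumptions; convergence of this iteration requires $A$ to be sufficiently close to $I$, which holds here for small $h$.

```python
import numpy as np

def sqrt_and_inv_sqrt(A, tol=1e-12, max_iter=50):
    """Coupled Newton-Schulz iteration of Algo. 3 (Higham, 1997, Eq. (2.6)).
    Assumes A is symmetric and close to I (e.g. A = I + h^2 V^T V with small h);
    returns (Y, Z) with Y ~ A^{1/2} and Z ~ A^{-1/2}."""
    I = np.eye(A.shape[0])
    Y, Z = A.copy(), I.copy()
    for _ in range(max_iter):
        if np.linalg.norm(Y @ Y - A) < tol:
            break
        T = 3 * I - Z @ Y
        Y, Z = 0.5 * Y @ T, 0.5 * T @ Z
    return Y, Z

rng = np.random.default_rng(5)
n, m, h = 9, 3, 0.1
X0, _ = np.linalg.qr(rng.standard_normal((n, m)))
W = rng.standard_normal((n, m))
V0 = W - X0 @ (X0.T @ W)                   # X0^T V0 = 0
X_dag = X0 + h * V0                        # off the manifold: X_dag^T X_dag = I + h^2 V0^T V0
_, Zinv = sqrt_and_inv_sqrt(X_dag.T @ X_dag)
X_h = X_dag @ Zinv                         # polar retraction back to St(n, m)
print(np.abs(X_h.T @ X_h - np.eye(m)).max())
```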

M.2 SIMPLIFICATION OF MATRIX EXPONENTIAL

There are two steps in the discretizations (30) and (12) that use the matrix exponential of $m\times m$ matrices. Similar matrix exponentials also appear in Li et al. (2020); Wen and Yin (2013), but with matrices of much larger size $n \times n$ (note $n > m$). We introduce two ways of simplifying the computation of the matrix exponential. First, note that the Cayley transform is a 2nd-order structure-preserving approximation of the matrix exponential, defined as
$$\mathrm{Cayley}(hY) = (I - hY/2)^{-1}(I + hY/2) = \exp(hY) + O(h^3). \tag{32}$$
By applying the Cayley transform, the computation is reduced to a matrix inversion and matrix multiplications while the variable is still kept on the manifold. Second, we can use first-order forward Euler to discretize the ODEs for the two steps, which is also the first-order truncation of the matrix exponential. In this case, the $X$ update in scheme (12) is just $X_h = X_0 + hX_0Y_h$, which is no longer structure-preserving, but the structure-preserving property of the overall algorithm is not affected (see Thm. 6). For scheme (30), its Cayley map approximation is structure-preserving and defined as
$$V_h = V_0\,\mathrm{Cayley}\left(-\gamma hI + \frac{3a-2}{2}hY_0\right) - \frac{1-\exp(-\gamma h)}{\gamma}(I - X_0X_0^\top)\frac{\partial f}{\partial X_0}.$$
Note that the variable $V$ still satisfies the constraint (28) even if we change the update of $V$ to forward Euler:
$$V_h = V_0 - \gamma hV_0 + \frac{3a-2}{2}hV_0Y_0 - \frac{1-\exp(-\gamma h)}{\gamma}(I - X_0X_0^\top)\frac{\partial f}{\partial X_0}.$$
In fact, the above two ways and the original matrix exponential share the same complexity per iteration. Also, we do not find any significant difference in convergence in numerical experiments (Apdx. Q). Hence we always use forward Euler by default, as it has the lowest computational cost.
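The order claim in Eq. (32), and the contrast between the Cayley transform (structure-preserving) and forward Euler (not), can be checked numerically. In this sketch, `expm_taylor` is a hypothetical truncated-series helper that is only reliable for the small-norm arguments used here; halving $h$ should shrink the Cayley error by roughly $2^3 = 8$.

```python
import numpy as np

def cayley(M):
    """Cayley(M) = (I - M/2)^{-1} (I + M/2), cf. Eq. (32)."""
    I = np.eye(M.shape[0])
    return np.linalg.solve(I - 0.5 * M, I + 0.5 * M)

def expm_taylor(M, terms=30):
    """Truncated Taylor series of exp(M); accurate only for small ||M||, as used here."""
    out = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms):
        term = term @ M / k
        out = out + term
    return out

rng = np.random.default_rng(6)
S = rng.standard_normal((4, 4))
Y = S - S.T
Y /= np.linalg.norm(Y, 2)                  # skew-symmetric with spectral norm 1

def err(h):
    return np.linalg.norm(cayley(h * Y) - expm_taylor(h * Y))

C = cayley(0.1 * Y)
E = np.eye(4) + 0.1 * Y                    # forward Euler, i.e. first-order truncation
print(np.abs(C.T @ C - np.eye(4)).max())   # Cayley of a skew matrix is orthogonal
print(np.abs(E.T @ E - np.eye(4)).max())   # forward Euler is not (O(h^2) defect)
print(err(0.02) / err(0.01))               # roughly 2^3, i.e. O(h^3) local error
```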

M.3 PROOF OF THEOREM 6

Proof. The composition of $\tilde\phi_1, \tilde\phi_2, \tilde\phi_3$ in any order gives a structure-preserving scheme, because in Sec. K we proved that each of them is structure-preserving. When we apply them in the specific order $\tilde\phi_3 \circ \tilde\phi_1 \circ \tilde\phi_2$, we prove that even when $X^\top X = I$ is no longer preserved by $\tilde\phi_2$, the assembled scheme is still structure-preserving. We only need to prove that $[X_h, Y_h, V_h] = \tilde\phi_3(X_0, Y_0, V_0)$ satisfies the following: when the initial condition satisfies $Y_0^\top + Y_0 = 0$ and $X_0^\top V_0 = 0$, we have $X_h^\top X_h = I$, $Y_h^\top + Y_h = 0$ and $X_h^\top V_h = 0$. Since we have the step $X_h = X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$, we have $X_h^\top X_h = I$. $Y$ is not updated, so $Y_h + Y_h^\top = 0$. Moreover,
$$X_\dagger^\top V_h = X_0^\top V_0 - hX_0^\top X_0V_0^\top V_0 + hX_0^\top X_0V_0^\top V_0 - h^2 X_0^\top X_0V_0^\top X_0V_0^\top V_0 = 0.$$
As a result, $X_h^\top V_h = (X_\dagger^\top X_\dagger)^{-1/2}X_\dagger^\top V_h = 0$.

M.4 THE INTEGRATOR IN RESCALED COORDINATES USED IN ALGORITHM 1

$$\tilde\phi_1: \begin{cases} X_\eta = X_0 + \eta X_0 Z_\eta \\ Z_\eta = \mu Z_0 - \frac{1-b}{2}\left(X_0^\top \frac{\partial f}{\partial X_0} - \frac{\partial f}{\partial X_0}^\top X_0\right) \\ U_\eta = U_0 \end{cases} \qquad \tilde\phi_2: \begin{cases} X_\eta = X_0 \\ Z_\eta = Z_0 \\ U_\eta = \mu U_0 + \frac{3a-2}{2}\eta U_0Z_0 - (I - X_0X_0^\top)\frac{\partial f}{\partial X_0} \end{cases} \qquad \tilde\phi_3: \begin{cases} X_\dagger = X_0 + \eta U_0 X_0^\top X_0 \\ X_\eta = X_\dagger(X_\dagger^\top X_\dagger)^{-1/2} \\ Z_\eta = Z_0 \\ U_\eta = U_0 - \eta X_0 U_0^\top U_0 \end{cases}$$

… (see Tao and Ohsawa (2020)) such that the momentum $Y$ can be recognized as an element of $T_I SO(n)$, the tangent space at the identity. In the latter interpretation, the momentum always stays in the same tangent space $T_I SO(n)$, which suggests that trivialization almost reproduces the convenience of Euclidean space for $SO(n)$. Hence, unlike the general Adam-Stiefel optimizer (Algo. 2), no extra technique (for example, projection) is needed for Adam-$SO(n)$. Detailed iterations are shown in Algo. 5. Note that this is also a structure-preserving scheme (see Thm. 7).

Algorithm 5: Adam on SO(n)
Hyperparameters: $\eta \in (0, +\infty)$, $\beta_1 \in [0, 1)$, $\beta_2 \in [0, 1)$, $0 < \epsilon \ll 1$, maximum number of iterations $N$
Initialization: $X_0, V_0, Y_0$ s.t.
$X_0^\top X_0 = I$, $X_0^\top V_0 = 0$, $Y_0 + Y_0^\top = 0$, $p_0 = 0$, $q_0 = 0$
for $i = 0, \dots, N-1$ do
  $f_i = \frac{1-b}{2}\left(X_i^\top \frac{\partial f}{\partial X}(X_i) - \frac{\partial f}{\partial X}(X_i)^\top X_i\right)$
  $p_{i+1} = \beta_2 p_i + (1-\beta_2) f_i^{\circ 2}$
  $Y_{i+1} = \beta_1 Y_i - (1-\beta_1) f_i$
  $X_\dagger = X_i \exp\left(\eta\sqrt{1-\beta_2^{i+1}}\, Y_{i+1} \oslash (p_{i+1}^{\circ 1/2} + \epsilon)\right)$
  $X_{i+1} = X_\dagger(X_\dagger^\top X_\dagger)^{-1/2}$
end
return $X_N$

P EXPERIMENTAL DETAILS

Note: code is provided in the supplementary materials. Experiments are conducted on a high-performance computing cluster whose name shall be revealed after the anonymous period. A single V100 GPU was used.

P.1 GENERAL DISCUSSION

Initialization. For the variable $X$ with initial value $X_0$, any initialization with full rank is in fact valid, including a non-orthogonal $X_0$ that does not start on the manifold. This is due to the structure-preserving property of our algorithm (both SGD and Adam versions; see Apdx. M): after just one iteration, $X$ will stay on the manifold. Here we suggest one way of obtaining an orthogonal $X_0$, used in Saxe et al. (2013). After randomly generating an $n$-by-$m$ matrix with i.i.d. standard normal entries, we perform a QR decomposition. Denote the orthogonal factor by $Q$ and the diagonal matrix of signs of the diagonal entries of $R$ by $\tilde R$. Then we can take $X_0 = Q\tilde R$, which satisfies $X_0^\top X_0 = I$. For the variables $U$, $Z$ in Stiefel SGD and $U$, $Z$, $p$, $q$ in Stiefel Adam, we use zero matrices as initialization, since zero matrices always satisfy the constraints (28).

No need to tune learning rates separately for constrained and unconstrained parameters. For problems with both constrained and unconstrained parameters (Stiefel and Euclidean spaces), for example the ViT test in Sec. 3, existing methods using projection, retraction, and intermediate matrices (see Sec. 1.1) require separate adjustments of the learning rates for the two types of parameters. As an illustration, in projected SGD and Adam on the Stiefel manifold (Li et al., 2020), the learning rates for constrained and unconstrained parameters may differ by a factor of tens. This can be intuitively understood as the result of applying different optimization methods to different parameters. However, using various learning rates may lead to divergence of the algorithms. In contrast, our Stiefel method can use the same learning rate for both the constrained and unconstrained parameters. Because the numerical methods for these two types are derived in the same way, the effort of tuning hyperparameters is reduced while good performance is retained.
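The initialization described above can be written in a few lines of NumPy. This is a sketch with a hypothetical helper name `stiefel_init`; `np.linalg.qr` returns the reduced factorization by default, so `Q` is already $n$-by-$m$, and multiplying its columns by the signs of $\mathrm{diag}(R)$ realizes $X_0 = Q\tilde R$. Zero momenta and second moments are included to mirror the text.

```python
import numpy as np

def stiefel_init(n, m, rng):
    """X0 = Q R~ with A = Q R the reduced QR of a Gaussian n-by-m matrix and
    R~ = diag(sign(diag(R))); the sign fix removes QR's column-sign ambiguity."""
    A = rng.standard_normal((n, m))
    Q, R = np.linalg.qr(A)                # reduced mode: Q is n-by-m
    return Q * np.sign(np.diag(R))        # scales column j of Q by sign(R_jj)

rng = np.random.default_rng(7)
n, m = 50, 8
X0 = stiefel_init(n, m, rng)
U0, Z0 = np.zeros((n, m)), np.zeros((m, m))   # zero momenta: X0^T U0 = 0, Z0 + Z0^T = 0
p0, q0 = np.zeros((m, m)), np.zeros((n, m))   # zero second moments for Stiefel Adam
print(np.abs(X0.T @ X0 - np.eye(m)).max())
```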
In practice, the learning rate needs to be tuned for Stiefel SGD, and the momentum parameter can be set to 0.9 for most machine learning tasks. For Stiefel Adam, we recommend a learning rate of $10^{-3}$ and $(\beta_1, \beta_2) = (0.9, 0.999)$.

Training. All training uses the same scheduler: the learning rate increases linearly from 0 to the target learning rate over 5 epochs, followed by cosine annealing (Loshchilov and Hutter, 2016) with no restart and a minimum learning rate of 0.01 times the maximum learning rate. Label smoothing with parameter 0.1 is used. Weight decay is 5e-5 unless specified otherwise. All the hyperparameters used for the optimizers are listed in Tab. 4.

Footnote 6: Adam-type methods are not included since they are not suitable for this problem.

Tab. 4 (optimizer hyperparameters):
Stiefel SGD (Algo. 1): $\eta_{orth} = \eta_{non\text{-}orth} = 0.15$, $\mu = 0.9$
Projected Stiefel SGD (Li et al., 2020): $\eta_{orth} = 1$, $\eta_{non\text{-}orth} = 0.1$, $\mu = 0.9$
Regularizer SGD: $\eta = 0.1$, $\mu = 0.9$, penalty weight = 1e-6
Momentumless Stiefel SGD (Wen and Yin, 2013): $\eta_{orth} = 0.1$; for non-orth parameters, $\eta_{non\text{-}orth} = 0.1$, $\mu = 0.9$
Stiefel Adam (Algo. 2): $\eta_{orth} = \eta_{non\text{-}orth} = 0.001$, $\beta_1 = 0.9$, $\beta_2 = 0.999$
Projected Stiefel Adam (Li et al., 2020): $\eta_{orth} = 0.001$, $\eta_{non\text{-}orth} = 0.01$, $\beta_1 = 0.9$, $\beta_2 = 0.999$
Regularizer Adam: $\eta = 0.001$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, penalty weight = 1e-6
SGD: $\eta_{orth} = 0.1$, $\mu = 0.9$
Adam (Kingma and Ba, 2015): $\eta = 0.001$, $\beta_1 = 0.9$, $\beta_2 = 0.999$
AdamW (Loshchilov and Hutter, 2017): $\eta = 0.001$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, weight decay = 0.2
Although our special retraction on the tangent bundle (polar retraction for position and no extra handling for momentum) is cheap, it must be used with our specially designed algorithm and cannot be applied to other algorithms, even though the position would still stay on the Stiefel manifold. The reason is that losing the structure of the momentum leads to slow convergence. Fig. 6 shows that projected Stiefel SGD with our tangent-bundle retraction converges more slowly than not only our algorithm but also the original projected Stiefel SGD.



It helped computer vision applications prior to the deep learning era as well (e.g., Liu et al., 2003).
Their setting was deterministic, but an extension to stochastic gradients is straightforward.
The problem is harder when $n > m$, which will thus be our focus. For the $n = m$ case, see Sec. O.
It differs from the seminal work (Wen and Yin, 2013) as our setup (and the Lagrange multiplier) is dynamical.
$u$ stands for machine precision, which is a very small number.
The time recorded excludes the part spent computing gradients.
Both the algorithm in Li et al. (2020) and its code have $O(n^2m)$ computational complexity per iteration. However, this can be improved by reordering the matrix products using associativity. Also, this $O(n^2m)$ is the complexity under a fixed number of iterations for approximating the Cayley map, according to their setup. It becomes $O(nm^2(1 + \log(1/u)))$ if the Cayley map is computed to machine precision.
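The associativity remark above can be illustrated with generic $n$-by-$m$ matrices $A$, $B$, $C$ (hypothetical stand-ins, not the exact products in Li et al. (2020)): evaluating $AB^\top C$ left to right forms an $n \times n$ intermediate, while grouping it as $A(B^\top C)$ keeps every intermediate at most $n \times m$. The multiplication counts below are the standard ones for dense products.

```python
import numpy as np

def cost_left(n, m):
    """(A B^T) C: the n-by-n intermediate costs n^2 m multiplications, and so does the next product."""
    return 2 * n * n * m

def cost_right(n, m):
    """A (B^T C): the m-by-m intermediate costs n m^2, and so does the next product."""
    return 2 * n * m * m

rng = np.random.default_rng(8)
n, m = 500, 10
A, B, C = (rng.standard_normal((n, m)) for _ in range(3))
left = (A @ B.T) @ C
right = A @ (B.T @ C)
print(np.allclose(left, right), cost_left(n, m), cost_right(n, m))
```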



mainly approximating the expm in $\phi_1$ by forward Euler. The result is concretely summarized in Algo. 1. In order to match commonly used notation in machine learning, we rescale the parameters: learning rate $\eta := \frac{1-\exp(-\gamma h)}{\gamma}h$, momentum parameter $\mu := \exp(-\gamma h)$, $Z := Y \big/ \frac{1-\exp(-\gamma h)}{\gamma}$, and $U := V \big/ \frac{1-\exp(-\gamma h)}{\gamma}$.
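This rescaling is an invertible change of variables for $0 < \mu < 1$: from $\mu = e^{-\gamma h}$ we get $\gamma h = -\log\mu$, and then $\eta = (1-\mu)h/\gamma = (1-\mu)h^2/(\gamma h)$ determines $h$. The following sketch (hypothetical helper names) converts between $(\gamma, h)$ and $(\eta, \mu)$ and applies the momentum rescaling.

```python
import math

def to_ml_params(gamma, h):
    """(friction gamma, step h) -> (learning rate eta, momentum mu), as defined above."""
    mu = math.exp(-gamma * h)
    eta = (1 - mu) / gamma * h
    return eta, mu

def from_ml_params(eta, mu):
    """Inverse map for 0 < mu < 1: gamma*h = -log(mu), h^2 = eta * gamma*h / (1 - mu)."""
    gh = -math.log(mu)
    h = math.sqrt(eta * gh / (1 - mu))
    return gh / h, h

def rescale_momentum(Y, V, gamma, h):
    """Z := Y / ((1 - e^{-gamma h}) / gamma) and U := V / ((1 - e^{-gamma h}) / gamma)."""
    s = (1 - math.exp(-gamma * h)) / gamma
    return Y / s, V / s

gamma, h = from_ml_params(0.15, 0.9)      # e.g. eta = 0.15, mu = 0.9
print(gamma, h, to_ml_params(gamma, h))
```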

Figure 1: Projection Robust Wasserstein Distance (PRW) tested on MNIST and Shakespeare plays. Data points are features extracted by a pre-trained model. The mean optimal transport value is taken over all digit or play pairs; a larger mean optimal transport value means a more effective orthogonal projection. Our method makes PRW more effective by finding the best local minimum (largest optimal transport value) and converging fast.

B.3 THE SPECIAL CASE OF St(n, n) AND ITS RELATION WITH O(n) AND SO(n)

When $n \ge m$, $St(n, m) \cong O(n)/O(n-m)$; when $n > m$ strictly, we also have $St(n, m) \cong SO(n)/SO(n-m)$. However, in the special case $n = m$, $St(n, n) \cong O(n)$ is not connected, unlike the $n > m$ cases. Since our optimizer is based on the discretization of continuous dynamics, it cannot make jumps and thus only optimizes on a connected component of $St(n, m)$. This means that in the special case $n = m$ it is, to be precise, optimizing on $SO(n)$ but not $O(n)$, whereas this complication does not arise when $n > m$.

C DETAILS ABOUT THE PER-ITERATION COMPLEXITY AND COMPUTATIONAL COST FOR OUR ALGORITHMS

The most costly operation in our algorithms is the multiplication of $n \times m$ matrices (note $n > m$). The computation of the matrix exponential and of the inverse matrix square root is carefully designed to deal only with matrices of dimension $m \times m$ (see Apdx. M.1 and M.2), and thus costs at most $O(m^3)$ at each step (in particular, forward Euler only has complexity $O(m^2)$, while the Cayley map is $O(m^3)$).

Algorithm 3: Matrix root and matrix root inversion (Eq. (2.6) in Higham (1997))
Input: symmetric $m$-by-$m$ matrix $A$, tol
Initialization: $Y_0 = A$, $Z_0 = I_{m\times m}$, $k = 0$
while $\|Y_k^2 - A\| \ge \mathrm{tol}$ do

Figure2: Test errors of Stiefel-ViT on CIFAR. We can see that our optimizer has the best test accuracy in both cases. Meanwhile, when our optimizer is used, orthogonality only within each head performs better than orthogonality across heads.

Figure 3: Comparison of exact manifold-preserving methods on the leading eigenvalue problem. The algorithms are run on CPU, with the learning rate of each one tuned to its best performance. The upper two figures show that our Stiefel SGD optimizer has the fastest convergence and the best manifold preservation. The lower two figures experimentally confirm that we have the lowest $O(nm^2)$ complexity per iteration, which matches our complexity analysis in Apdx. C and Table 2. Notice that our algorithm and Momentumless Stiefel SGD have almost coinciding curves when $m \to \infty$ in (c) and when $n \to \infty$ in (d), while both algorithms have much lower curves than Projected Stiefel SGD when $m \to \infty$. This means that our algorithm has a constant factor similar to that of Momentumless Stiefel (S)GD (Wen and Yin, 2013), indicating that the introduction of momentum in our algorithm is almost 'free'; meanwhile, Projected Stiefel SGD (Li et al., 2020) has a much larger prefactor. Please see Apdx. C for more discussion.

Figure 4: Comparison of different canonical-type metrics (Euclidean and canonical metrics are tested) and different ways to approximate the matrix exponential in Apdx. M.2.

Figure 6: Convergence and manifold preservation when our low-cost retraction (polar retraction for position and no extra handling for momentum) is applied to projected Stiefel SGD (Li et al., 2020). Although our design also helps projected Stiefel SGD preserve the manifold structure with only machine-precision error, its convergence to the minimum is slower than that of both our algorithm and the original projected Stiefel SGD.

t. the Stiefel variable $U$ (Jordan et al., 2022), and is thus computationally extra-challenging for Stiefel optimizers. Lin et al. (2020) proposed an effective method based on alternating between a full Sinkhorn step (Cuturi, 2013), given the current projection, and an optimization step that improves the projection. In particular, they developed a Riemannian optimizer with projection and retraction (RGAS) and its adaptive learning rate version (RAGAS). We replace RGAS and RAGAS by our optimizer and others, and test on the hardest experiments in Lin et al. (2020), namely MNIST and Shakespeare. Results are in Fig. 1. Our method is observed to find the largest value of PRW and thus the best projection among those tested, which implies the best performance. Details of the setup are in Apdx. P.2.

3.2 HOW COULD ORTHOGONALITY IMPROVE VANILLA VISION TRANSFORMER (VIT)?

This section explores the possibility of making self-attention in Transformer models (Vaswani et al., 2017) orthogonal. The intuition is, if we interpret each attention head as a characterization of interac-

to trainable parameters. The three input matrices Q, K and V all have dimension sequence_length × $d_{model}$; $d_k$ and $d_v$ are usually smaller than $d_{model}$.

For orthogonality only within each head, we require that $W^Q_i, W^K_i \in \mathrm{St}(d_{model}, d_k)$. This needs $d_{model} \ge d_k$, which holds in most cases. For orthogonality across heads, we need $d_{model} \ge n_{head} d_k$, which is satisfied in many popular models, and require $\mathrm{Concat}(W^Q_i,\ i = 1, \dots, n_{head})$ to be in $\mathrm{St}(d_{model}, n_{head} d_k)$. This imposes not only 'orthogonality only within head' but also extra cross-head orthogonality.
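The difference between the two constraints can be checked numerically. This is a minimal NumPy sketch, assuming the ViT dimensions reported in the appendix model-structure paragraph ($d_{model} = 384$, $n_{head} = 12$, $d_k = 32$) and using hypothetical helper names `random_stiefel` and `on_stiefel`. Per-head Stiefel matrices drawn independently satisfy the within-head constraint but generally not the cross-head one; a single Stiefel matrix split into heads satisfies both.

```python
import numpy as np


def random_stiefel(rng, n, p):
    """Sample a point of St(n, p) via a reduced QR factorization of a Gaussian matrix."""
    Q, _ = np.linalg.qr(rng.standard_normal((n, p)))
    return Q


def on_stiefel(W, tol=1e-10):
    """True if W has orthonormal columns, i.e. W^T W = I up to tol."""
    return np.allclose(W.T @ W, np.eye(W.shape[1]), atol=tol)


d_model, n_head, d_k = 384, 12, 32  # assumed ViT sizes
rng = np.random.default_rng(0)

# (1) Orthogonality only within head: each W^Q_i lies in St(d_model, d_k)
#     independently, so columns of different heads may overlap.
within = [random_stiefel(rng, d_model, d_k) for _ in range(n_head)]
concat_within = np.concatenate(within, axis=1)  # d_model x (n_head * d_k)

# (2) Orthogonality across heads: the concatenation itself lies in
#     St(d_model, n_head * d_k), which requires d_model >= n_head * d_k.
assert d_model >= n_head * d_k
concat_across = random_stiefel(rng, d_model, n_head * d_k)
across = np.split(concat_across, n_head, axis=1)  # recover the per-head blocks

within_ok = all(on_stiefel(W) for W in within)
across_heads_ok = all(on_stiefel(W) for W in across)
print(within_ok, on_stiefel(concat_within))        # within-head only
print(across_heads_ok, on_stiefel(concat_across))  # within-head and cross-head
```

With these sizes $n_{head} d_k = d_{model}$, so the cross-head constraint makes the concatenated matrix square and orthogonal.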

A QUICK SUMMARY OF OPTIMIZERS WE EXPERIMENTALLY COMPARED TO

A summary of pros and cons of existing optimizers.

Details for hyperparameters of optimizers in the PRW test. η stands for the learning rate, µ for the momentum parameter, and β for the parameter that adjusts the stepsize automatically.

How to choose between Adam-Stiefel and momentum SGD-Stiefel. Generally, we recommend momentum SGD-Stiefel for traditional optimization problems and for small-scale neural networks whose hyperparameters can be carefully tuned, and Adam-Stiefel as the default for large-scale neural networks.

P.2 DETAILS OF PRW

For the MNIST experiment, we use a pretrained CNN to extract 128-dim. features of the images in MNIST. For the Shakespeare's plays experiment, we compute the PRW distances between all pairs of items in a corpus of eight Shakespeare plays embedded into a 200-dim. space using word2vec. For each digit or play pair, we use the extracted features as the point sets {x_i} and {y_i}. We choose the dimension of the target space of the projection to be k = 2 in both experiments. The mean optimal transport values are taken over all digit or play pairs at a given iteration. All settings are the same as in the original paper, Lin et al. (2020), except that early termination is removed.⁶ The hyperparameters for each optimizer are listed in Tab. 3; they are carefully tuned for each optimizer to obtain the best performance. For projected Stiefel SGD/Adam (Li et al., 2020), we tune the learning rates for constrained and unconstrained parameters separately. For regularizer SGD/Adam, the regularizer scaling parameter is also carefully chosen. The largest learning rate for momentumless Stiefel SGD (Wen and Yin, 2013) is applied, but the momentum methods still converge faster. Additional experimental results can be seen in Fig. 2.

Model structure. The images are cut into 4×4 patches. $d_{model} = d_{feedforward} = 384$, $n_{head} = 12$, $d_q = d_k = d_v = 32$ (which also equals $d_{model}/n_{head}$, so the 'orthogonality across heads' constraint can be applied; see Section 3.2). The 'classification token' in Dosovitskiy et al. (2020) is used. The total number of layers of our ViT model is 7.
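For the PRW experiments described in Apdx. P.2, the alternation between a full Sinkhorn step and a projection-improving step (Lin et al., 2020) can be sketched as follows. This is a minimal sketch under stated assumptions: it uses the standard definition $\mathrm{PRW} = \max_{U \in \mathrm{St}(d,k)} \min_{\pi} \sum_{ij} \pi_{ij} \|U^\top (x_i - y_j)\|^2$ with uniform weights, an entropic (log-domain) Sinkhorn solver with regularization `eps`, and a plain Riemannian gradient ascent step followed by a polar retraction in place of any of the optimizers compared in the paper. All function names, step sizes and the toy data are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def sinkhorn_plan(C, eps, n_iter=200):
    """Entropic OT plan between two uniform empirical measures (log-domain Sinkhorn)."""
    n, m = C.shape
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        f = eps * (log_a - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - C) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / eps)


def polar_retraction(A):
    """Polar factor A (A^T A)^{-1/2}, computed from a thin SVD."""
    P, _, Qt = np.linalg.svd(A, full_matrices=False)
    return P @ Qt


def projected_cost(diffs, U):
    """C_ij = ||U^T (x_i - y_j)||^2 for all pairs (i, j)."""
    return np.sum((diffs @ U) ** 2, axis=-1)


def prw_alternating(x, y, k, eta=0.05, eps=1.0, n_outer=100, seed=0):
    """Alternate a Sinkhorn step (U fixed) with one Riemannian ascent step on U."""
    d = x.shape[1]
    U = polar_retraction(np.random.default_rng(seed).standard_normal((d, k)))
    diffs = x[:, None, :] - y[None, :, :]  # shape (n, m, d), entries x_i - y_j
    for _ in range(n_outer):
        pi = sinkhorn_plan(projected_cost(diffs, U), eps)
        # V_pi = sum_ij pi_ij (x_i - y_j)(x_i - y_j)^T; objective is Tr(U^T V_pi U)
        V = np.einsum("ij,ijd,ije->de", pi, diffs, diffs)
        G = 2.0 * V @ U                                # Euclidean gradient in U (pi fixed)
        rgrad = G - U @ (0.5 * (U.T @ G + G.T @ U))    # project onto tangent space at U
        U = polar_retraction(U + eta * rgrad)          # ascent step, then back to St(d, k)
    C = projected_cost(diffs, U)
    value = float(np.sum(sinkhorn_plan(C, eps) * C))
    return U, value


rng = np.random.default_rng(1)
x = rng.standard_normal((40, 5))
y = rng.standard_normal((40, 5)) + np.array([3.0, 0.0, 0.0, 0.0, 0.0])
U, value = prw_alternating(x, y, k=2)
print(U.shape, value)
```

Working in the log domain keeps the Sinkhorn potentials finite even when projected costs are large relative to `eps`; the target dimension k = 2 mirrors the experiments above.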

Details for hyperparameters of optimizers in ViT training. η stands for the learning rate, µ for the momentum parameter, and β₁, β₂ are the parameters of Adam.


(a) shows that, for Projected Stiefel SGD, performing more inner iterations, which lead to a more accurate Cayley transform, gives a better optimized function value; (b) shows that applying QR retractions (i.e., projecting the variable) more frequently provides better accuracy; (c) shows that a more accurate matrix root inversion for the polar retraction in our algorithm 4 leads to a more accurate solution. Performing 8 iterations of matrix root inversion (this curve coincides with the one for 100 iterations) gives perfect accuracy for our method, whereas in (a) 16 iterations of the Cayley transform still give an error ∼$10^6$ times bigger.
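An iterative matrix root inversion for the polar retraction $X (X^\top X)^{-1/2}$ can be sketched as below. The excerpt does not specify the authors' solver, so this sketch swaps in the standard coupled Newton-Schulz iteration (Higham), which converges quadratically when $\|I - X^\top X\|_2 < 1$, i.e. when X has only drifted slightly off the manifold. Function names and the drift size are assumptions.

```python
import numpy as np


def inv_sqrt_newton_schulz(M, n_iter=8):
    """Coupled Newton-Schulz iteration returning Z ~ M^{-1/2}.

    Assumes M is symmetric positive definite with ||I - M||_2 < 1, where the
    iteration converges quadratically (Y_k -> M^{1/2}, Z_k -> M^{-1/2}).
    """
    I = np.eye(M.shape[0])
    Y, Z = M.copy(), I.copy()
    for _ in range(n_iter):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y, Z = Y @ T, T @ Z
    return Z


def polar_retraction_iterative(X, n_iter=8):
    """Approximate polar retraction X (X^T X)^{-1/2}; only an m x m root is computed."""
    return X @ inv_sqrt_newton_schulz(X.T @ X, n_iter)


def orth_error(X):
    return np.linalg.norm(X.T @ X - np.eye(X.shape[1]))


rng = np.random.default_rng(0)
n, m = 200, 10
X0, _ = np.linalg.qr(rng.standard_normal((n, m)))
X = X0 + 1e-3 * rng.standard_normal((n, m))       # a point that drifted off St(n, m)

X_iter = polar_retraction_iterative(X, n_iter=8)
P, _, Qt = np.linalg.svd(X, full_matrices=False)  # reference: exact polar factor
X_svd = P @ Qt

err_before, err_after = orth_error(X), orth_error(X_iter)
print(f"before: {err_before:.1e}, after 8 iterations: {err_after:.1e}")
```

Apart from forming $X^\top X$ and the final product, which cost $O(nm^2)$, all work is on m × m matrices, so the retraction stays cheap when m is much smaller than n.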

ACKNOWLEDGMENTS

We thank Tomoki Ohsawa for insightful discussion. We are grateful for partial support by NSF DMS-1847802 (LK, YW and MT), NSF ECCS-1936776 (MT), and Cullen-Peck Scholar Award (LK, YW and MT).

availability

https://github.

I XYV SYSTEM IS STRUCTURE PRESERVING

Theorem 8 (Constrained optimization with unconstrained dynamics). As long as the initial condition of (8) satisfies $X(0)^\top X(0) = I_{m\times m}$, $Y(0) + Y(0)^\top = 0_{m\times m}$, $X(0)^\top V(0) = 0_{m\times m}$, the dynamics automatically satisfy the same constraints, i.e., $X(t)^\top X(t) = I_{m\times m}$, $Y(t) + Y(t)^\top = 0_{m\times m}$, $X(t)^\top V(t) = 0_{m\times m}$.

Proof. By the uniqueness of solutions of ODEs, the following ODE for $(X^\top X, Y^\top + Y, X^\top V)$, derived from Eq. (8), has a unique solution if we view $V^\top V$ as an independent variable. Given the initial condition $(X_0^\top X_0, Y_0^\top + Y_0, X_0^\top V_0) = (I, 0, 0)$, the unique solution is $(X^\top X, Y^\top + Y, X^\top V) \equiv (I, 0, 0)$. Hence the constraints in Eq. (28) are preserved by the continuous dynamics Eq. (8).
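The constraints $Y + Y^\top = 0$ and $X^\top V = 0$ in Theorem 8 describe a split of the momentum into a part along X and a part orthogonal to it. As a small numerical illustration, assuming (our reading; the defining equations are not reproduced in this excerpt) that a tangent momentum Q is split as $Q = XY + V$ with $Y = X^\top Q$ and $V = (I - XX^\top) Q$, the sketch below checks both constraints and shows that V vanishes in the square case n = m, as used for SO(n). The helper names are hypothetical.

```python
import numpy as np


def decompose_momentum(X, Q):
    """Split Q = X @ Y + V with Y = X^T Q and V = (I - X X^T) Q (assumed form)."""
    Y = X.T @ Q
    V = Q - X @ Y
    return Y, V


def random_tangent(rng, X):
    """Random tangent vector at X in St(n, m): X @ (skew part) + normal-space part."""
    n, m = X.shape
    B = rng.standard_normal((m, m))
    Z = rng.standard_normal((n, m))
    return X @ (B - B.T) + (Z - X @ (X.T @ Z))


rng = np.random.default_rng(0)
results = {}
for n, m in [(8, 3), (6, 6)]:  # a genuine Stiefel case, then the square case n = m
    X, _ = np.linalg.qr(rng.standard_normal((n, m)))
    Q = random_tangent(rng, X)
    Y, V = decompose_momentum(X, Q)
    results[(n, m)] = (X, Y, V)
    print((n, m), np.linalg.norm(Y + Y.T), np.linalg.norm(X.T @ V), np.linalg.norm(V))
```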

J PROOF OF THEOREM 3

Proof. We assume the initial condition $(X_0, Y_0, V_0)$ satisfies constraint (28).

For Eq. (9), we check $X^\top(t)X(t) \equiv I$, $Y(t) + Y^\top(t) \equiv 0$ and $X^\top(t)V(t) \equiv 0$. Using the conclusion $Y + Y^\top \equiv 0$ that we have just proved and the initial condition $X_0^\top X_0 = I$, we have $X^\top X \equiv I$.

For Eq. (10), the exact solution of $\phi_2$ is given by $X(t) = X(0)$, $Y(t) = Y(0)$, and with

For Eq. (11), we need to check $\frac{d}{dt}\left(X^\top(t)X(t)\right) = 0$ and $\frac{d}{dt}\left(X^\top(t)V(t)\right) = 0$. By the uniqueness of solutions of ODEs, the following ODE for $(X^\top X, X^\top V, V^\top V)$, derived from Eq. (11), has a unique solution. We can see that given initial condition. So the constraints $X^\top X = I$, $X^\top V = 0$ are preserved by the continuous dynamics Eq. (11).

K PROOF OF THEOREM 4

Proof. Eq. (12) is structure preserving. We use the idea from Tao and Ohsawa (2020). Assume the initial values $X_0$, $Y_0$ and $V_0$ satisfy the constraint, i.e., $X_0^\top X_0 = I$, $Y_0 + Y_0^\top = 0$ and $X_0^\top V_0 = 0$. For the first step, which updates Y, the derivative of Y is always skew-symmetric by its special form, so $Y_h$ is also skew-symmetric and $Y_h + Y_h^\top = 0$. The skew-symmetry of $Y_h$ gives us that

so all 3 conditions in Eq. (28) are satisfied, meaning Eq. (12) is structure preserving.

Eq. (14) is structure preserving. Given initial condition $X_0^\top X_0 = I$ and

Eq. (13) is structure preserving. First, we show that the discretization is a first-order approximation of the exact solution. We get $V_h$ and $X^\dagger$ by forward Euler. Since $X_0$ and $V_0$ satisfy the constraint (28), we have

), which is indeed a first-order approximation of the exact solution. Next, we show that the numerical discretization is structure preserving. For the constraint $X^\top V = 0$, we can check that

'gradient' and 2-order momentum:

Proof. We will use the same notation as in Algorithm 2. Assume

Our initialization satisfies these conditions. We prove in the following that these conditions are still satisfied after the map $\phi_3 \circ \phi_1 \circ \phi_2$.

Step 1: Due to the skew-symmetry of $f_i$, $p_{i+1}$ is symmetric.

Step 2: … $U_i = 0$ and $X_i^\top g_i = 0$, we can check

Step 3: … $= I$ no longer holds at this stage, but it is satisfied again once the algorithm finishes.

Step 4: … $X_{i+\frac{1}{2}}$ is full rank, which is always true since $X_i$ is full rank and η is small.

Since we have X

So we have proved that all the constraints are satisfied again, which means

Our method for Stiefel optimization can naturally be applied to the special orthogonal group SO(n) (Apdx. B.3), and is defined as follows

Tao and Ohsawa (2020) proposed an efficient algorithm based on the Lie group structure for optimization on SO(n), although it cannot be generalized to the Stiefel case. Our method recovers the same integrator as Tao and Ohsawa (2020) while using a different approach, applied to a family of metrics. In greater detail, for SO(n), the canonical-type metric (2) degenerates to the Euclidean metric up to constant scaling, i.e., g(∆

which results in the following position-momentum (X, Q) dynamics

with the same initialization as (7). Next, apply the same Y, V decomposition of Q. Since $I - XX^\top = 0$, we have $V = 0$, i.e., the momentum Q lies purely in the Y-direction. Hence the dynamics equivalent to (37) are the following

Note that when we take b = −1 (namely, a = 1/2), (38) is identical to the continuous dynamics in Tao and Ohsawa (2020), although Tao and Ohsawa (2020) use an intrinsic form, which is coordinate-free but not explicit enough.

The numerical integrator corresponding to (S)GD is defined in the same way as in Section M.4 and is summarized in Alg. 4. Note that Step 5 is optional in the n = m case: it gives the identity map if no arithmetic error due to machine precision is incurred. However, this step is beneficial under low precision (see Apdx. M.1).

Algorithm 4: Momentum SGD on SO(n) (generalization of Tao and Ohsawa (2020))

In addition, we also have an adaptive version, which was absent in Tao and Ohsawa (2020):

O.1 ADAM ON SO(n)

In this section, we extend the above optimizer to an Adam version. Note that it can be derived either as a special case of the Stiefel optimizer or from the structure of the Lie algebra so(n), i.e., left-trivialization

Method: Hyperparameters
Stiefel SGD (ours): η = 0.1, µ = 0.9
Stiefel Adam (ours): η = 0.001, β₁ = 0.9, β₂ = 0.999
Projected Stiefel SGD Li et al.

Q ADDITIONAL EXPERIMENT: SYSTEMATIC TESTS ON THE LEADING EIGENVALUE PROBLEM

In this section, we consider the leading eigenvalue problem: given an n-by-n matrix A, find the m largest eigenvalues of A. This problem is at the core of many data science tasks, where n can be very large but m remains small. Due to its relative simplicity and the existence of an exact solution, a large number of experiments can be run to systematically investigate the convergence, manifold preservation, and time consumption of various approaches. In particular, this problem can be formulated as an optimization problem on St(n, m) as follows:

$$\max_{X \in \mathrm{St}(n,m)} \operatorname{Tr}(X^\top A X).$$

The idea is that one seeks an optimal m-dimensional subspace, represented by X through the m orthonormal basis vectors in $\mathbb{R}^n$ given by its m columns, so that the eigenvalues projected onto this subspace sum up to a maximum value.

Such a formulation is not unique. For example, Tao and Ohsawa (2020) consider applying their momentum-accelerated SO(n) optimizer to $\arg\min_{R \in SO(n)} \operatorname{Tr}(E^\top R^\top A R E)$, where the constant padding matrix is $E := [I_{m\times m}; 0_{(n-m)\times m}]$. However, their setting is computationally less efficient than our Stiefel simplification. In fact, our formulation is particularly suitable when m is small but n is very large.

Our test uses a fixed matrix A generated by $A = (\Xi + \Xi^\top)/2/\sqrt{n}$, where Ξ is an instance of an n-by-n matrix with i.i.d. random normal elements. The exact solution is obtainable from a matrix decomposition, so the error can be quantified. As mentioned above, several aspects of our algorithm, including convergence speed, manifold preservation, and computational complexity, are examined in Fig. 3 and compared with other exact manifold-preserving methods. We conclude that our (S)GD-Stiefel method has the fastest convergence rate and exact manifold preservation, with the lowest computational complexity.

Additionally, in Fig. 4, we test the influence of different canonical-type metrics, i.e., different a in Def. 1, and of the three different ways to 'compute' the matrix exponential in Apdx. M.2, i.e., the matrix exponential itself, the Cayley map, and forward Euler. No significant difference is found in the leading eigenvalue test; as a result, the most commonly used canonical metric and the cheapest option, forward Euler, are chosen as defaults in all the other experiments. The hyperparameters used for each algorithm are listed in Tab. 5.

We also show experimentally, in Fig. 5, that better manifold preservation leads to better convergence, for both our Algorithm 1 and Projected Stiefel SGD (Li et al., 2020). Moreover, this experiment shows that our (inner) iterative solver is quadratically convergent (as opposed to the linearly convergent Cayley approximation in Li et al. (2020)), and therefore fewer iterations are needed to reach machine precision.
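The test setup above can be reproduced in a few lines. This is a minimal NumPy sketch, assuming the trace formulation $\max_{X \in \mathrm{St}(n,m)} \operatorname{Tr}(X^\top A X)$, whose optimal value equals the sum of the m largest eigenvalues (Ky Fan's theorem). The solver is plain momentumless Riemannian gradient ascent with an SVD-based polar retraction, a simple baseline rather than our momentum Stiefel optimizer; the step size, iteration count and helper names are assumptions.

```python
import numpy as np


def make_test_matrix(n, seed=0):
    """A = (Xi + Xi^T) / 2 / sqrt(n), Xi with i.i.d. standard normal entries."""
    Xi = np.random.default_rng(seed).standard_normal((n, n))
    return (Xi + Xi.T) / 2 / np.sqrt(n)


def polar(A):
    """Polar retraction onto the Stiefel manifold via a thin SVD."""
    P, _, Qt = np.linalg.svd(A, full_matrices=False)
    return P @ Qt


def riemannian_ascent_trace(A, m, eta=0.2, n_iter=3000, seed=1):
    """Momentumless Riemannian gradient ascent on f(X) = Tr(X^T A X), X in St(n, m)."""
    n = A.shape[0]
    X = polar(np.random.default_rng(seed).standard_normal((n, m)))
    for _ in range(n_iter):
        G = 2.0 * A @ X                                # Euclidean gradient (A symmetric)
        rgrad = G - X @ (0.5 * (X.T @ G + G.T @ X))    # projection onto the tangent space
        X = polar(X + eta * rgrad)                     # step, then retract to St(n, m)
    return X


n, m = 200, 5
A = make_test_matrix(n)
X = riemannian_ascent_trace(A, m)
exact = np.sort(np.linalg.eigvalsh(A))[-m:].sum()  # sum of the m largest eigenvalues
found = float(np.trace(X.T @ A @ X))
print(f"exact {exact:.6f}  found {found:.6f}")
```

Each iteration costs $O(n^2 m)$ for the product with a dense A plus $O(nm^2)$ for the tangent projection and retraction, which is why the Stiefel formulation is attractive when m is much smaller than n.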

