MOMENTUM STIEFEL OPTIMIZER, WITH APPLICATIONS TO SUITABLY-ORTHOGONAL ATTENTION, AND OPTIMAL TRANSPORT

Abstract

The problem of optimization on the Stiefel manifold, i.e., minimizing functions of (not necessarily square) matrices that satisfy orthogonality constraints, has been extensively studied. Yet, we propose a new approach based on, for the first time, an interplay between thoughtfully designed continuous and discrete dynamics. It leads to a gradient-based optimizer with intrinsically added momentum. This method exactly preserves the manifold structure but does not require additional operations to keep the momentum in the changing (co)tangent space, and thus has low computational cost and pleasant accuracy. Its generalization to adaptive learning rates is also demonstrated. Notable performance is observed in practical tasks. For instance, we found that placing orthogonal constraints on the attention heads of a trained-from-scratch Vision Transformer (Dosovitskiy et al., 2020) could markedly improve its performance when our optimizer is used, and that it is better for each head to be made orthogonal within itself but not necessarily to other heads. This optimizer also makes the useful notion of Projection Robust Wasserstein Distance (Paty and Cuturi, 2019; Lin et al., 2020) for high-dimensional optimal transport even more effective.

1. INTRODUCTION

Matrices that satisfy orthogonality constraints play important roles in various areas including machine learning. A range of studies has shown this both theoretically (Saxe et al., 2013; Xiao et al., 2018) and experimentally: for example, orthogonality can boost the performance of many architectures¹, e.g., MLPs (Cisse et al., 2017), CNNs and ResNets (Li et al., 2020; Wang et al., 2020), RNNs (Arjovsky et al., 2016; Wisdom et al., 2016; Helfrich et al., 2018), and Transformers (Zhang et al., 2021). The training of such models amounts to optimization under orthogonality constraints, whose applications to machine learning, however, are not restricted to improving deep learning models. For example, such optimization helps construct optimal low-dimensional projections of high-dimensional data, which can be used for, e.g., robust and efficient approximation of the Wasserstein distance (Paty and Cuturi, 2019; Lin et al., 2020). Therefore, this article focuses on this non-Euclidean / Riemannian optimization problem.

In fact, matrices with orthonormal columns constitute a Riemannian manifold known as the Stiefel manifold. Given integers n ≥ m > 0, it is defined as
$$\mathrm{St}(n, m) := \{ X \in \mathbb{R}^{n \times m} : X^\top X = I_{m \times m} \}$$
(see Apdx. B.3 for the n = m special case, which is almost SO(n)). Then, given a $C^1$ function $f : \mathrm{St}(n, m) \to \mathbb{R}$, we consider, under gradient oracle, the smooth optimization problem
$$\min_{X \in \mathrm{St}(n, m)} f(X), \quad \text{which is equivalent to} \quad \min_{X \in \mathbb{R}^{n \times m},\, X^\top X = I_{m \times m}} f(X). \tag{1}$$

Due to the importance of this problem, a wide range of approaches have been proposed. Methods based on regularizers, which only approximate the constraints (and hence the manifold structure), have been popular in the machine learning literature (Cisse et al., 2017; Bansal et al., 2018; Wang et al., 2020; Zhang et al., 2021). Meanwhile, continued efforts have been made to construct algorithms that exactly preserve the constraints (i.e., the manifold structure), although challenges such as computational scalability and how to add momentum remain; see the paragraph after next for how our method addresses them, and Sec. 1.1 for more discussion of existing milestones.

Our strategy toward a good Stiefel optimizer is the following: first, formulate a variational principle and use it to construct an optimizing dynamics in continuous time, which is described by a system
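To make the constraint in Eq. (1) concrete, the following NumPy sketch maps an arbitrary matrix onto St(n, m) via the QR decomposition and performs one projection-based (retraction) gradient step. This is only an illustration of feasible, constraint-preserving updates; it is not the momentum optimizer proposed in this paper, and the helper name `stiefel_project` is ours.

```python
import numpy as np

def stiefel_project(A):
    """Map an n x m matrix (n >= m) onto St(n, m) via the QR decomposition;
    the Q factor has orthonormal columns."""
    Q, R = np.linalg.qr(A)
    # Fix column signs (so R has nonnegative diagonal) to make the map deterministic.
    Q = Q * np.sign(np.sign(np.diag(R)) + 0.5)
    return Q

rng = np.random.default_rng(0)
X = stiefel_project(rng.standard_normal((5, 3)))
# X satisfies the Stiefel constraint X^T X = I_m up to round-off:
print(np.allclose(X.T @ X, np.eye(3)))

# One constraint-preserving gradient step for the toy objective
# f(X) = -tr(X^T A X) (A symmetric), by stepping then retracting onto St(5, 3):
A = rng.standard_normal((5, 5))
A = A + A.T                      # symmetric test matrix
grad = -2 * A @ X                # Euclidean gradient of f at X
X_next = stiefel_project(X - 0.1 * grad)
print(np.allclose(X_next.T @ X_next, np.eye(3)))
```

Regularizer-based methods would instead penalize ‖X⊤X − I‖ and only approximately satisfy the constraint; the projection above keeps the iterate exactly on the manifold (up to floating-point error) at every step.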



¹ It helped computer vision applications prior to the deep learning era as well (e.g., Liu et al., 2003).

Code availability: https://github.

