EFFICIENT, STABLE, AND ANALYTIC DIFFERENTIATION OF THE SINKHORN LOSS

Abstract

Optimal transport and the Wasserstein distance have become indispensable building blocks of modern deep generative models, but their computational cost greatly limits their application in statistical machine learning models. Recently, the Sinkhorn loss, as an approximation to the Wasserstein distance, has gained massive popularity, and much work has been done on its theoretical properties. To embed the Sinkhorn loss into gradient-based learning frameworks, efficient algorithms for both the forward and backward passes of the Sinkhorn loss are required. In this article, we first demonstrate issues with the widely used Sinkhorn's algorithm, and show that the L-BFGS algorithm is a potentially better candidate for the forward pass. Then we derive an analytic form of the derivative of the Sinkhorn loss with respect to the input cost matrix, which results in an efficient backward algorithm. We rigorously analyze the convergence and stability properties of the advocated algorithms, and use various numerical experiments to validate the performance of the proposed methods.

1. INTRODUCTION

Optimal transport (OT, Villani, 2009) is a powerful tool to characterize the transformation of probability distributions, and has become an indispensable building block of generative modeling. At the core of OT is the Wasserstein distance, which measures the difference between two distributions. For example, the Wasserstein generative adversarial network (WGAN, Arjovsky et al., 2017) uses the 1-Wasserstein distance as the loss function to minimize the difference between the data distribution and the model distribution, and a huge number of related works have emerged afterwards. Despite its various appealing theoretical properties, one major barrier to the wide application of OT is the difficulty of computing the Wasserstein distance. For two discrete distributions, OT solves a linear programming problem with $nm$ variables, where $n$ and $m$ are the numbers of Diracs that define the two distributions. Assuming $n = m$, standard linear programming solvers for OT have a complexity of $O(n^3 \log n)$ (Pele & Werman, 2009), which quickly becomes formidable as $n$ gets large, except for some special cases (Peyré et al., 2019). To resolve this issue, many approximate solutions to OT have been proposed, among which the Sinkhorn loss has gained massive popularity (Cuturi, 2013). The Sinkhorn loss can be viewed as an entropic-regularized Wasserstein distance, which adds a smooth penalty term to the original objective function of OT. The Sinkhorn loss is attractive as its optimization problem can be efficiently solved, at least in exact arithmetic, via Sinkhorn's algorithm (Sinkhorn, 1964; Sinkhorn & Knopp, 1967), which merely involves matrix-vector multiplications and some minor operations. Therefore, it is especially suited to modern computing hardware such as graphics processing units (GPUs). Recent theoretical results show that Sinkhorn's algorithm has a computational complexity of $O(n^2 \varepsilon^{-2})$ to output an $\varepsilon$-approximation to the unregularized OT (Dvurechensky et al., 2018).
Many existing works on the Sinkhorn loss focus on its theoretical properties, for example Mena & Niles-Weed (2019) and Genevay et al. (2019). In this article, we are mostly concerned with the computational aspect. Since modern deep generative models mostly rely on the gradient-based learning framework, it is crucial to use the Sinkhorn loss with differentiation support. One simple and natural method to enable the Sinkhorn loss in back-propagation is to unroll Sinkhorn's algorithm, adding every iteration to the auto-differentiation computing graph (Genevay et al., 2018; Cuturi et al., 2019). However, this approach is typically costly when the number of iterations is large. Instead, in this article we have derived an analytic expression for the derivative of the Sinkhorn loss based on quantities computed in the forward pass, which greatly simplifies the back-propagation of the Sinkhorn loss. More importantly, one critical pain point of the Sinkhorn loss, though typically ignored in theoretical studies, is that Sinkhorn's algorithm is numerically unstable (Peyré et al., 2019). We show in numerical experiments that even for very simple settings, Sinkhorn's algorithm can quickly lose precision. Various stabilized versions of Sinkhorn's algorithm, though showing better stability, still suffer from slow convergence in these cases. In this article, we have rigorously analyzed the solution to the Sinkhorn optimization problem, and have designed both forward and backward algorithms that are provably efficient and stable. The main contributions of this article are as follows:
• We have derived an analytic expression for the derivative of the Sinkhorn loss, which can be efficiently computed in back-propagation.
• We have rigorously analyzed the advocated forward and backward algorithms for the Sinkhorn loss, and show that they have desirable efficiency and stability properties.
• We have implemented the Sinkhorn loss as an auto-differentiable function in the PyTorch and JAX frameworks, using the analytic derivative obtained in this article. The code to reproduce the results in this article is available at https://1drv.ms/u/s!ArsORq8a24WmoFjNQtZYE_BERzDQ.
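As a small illustration of the stability issue discussed above: in double precision, the Gibbs kernel $e^{-\lambda M}$ underflows to exact zeros once $\lambda M_{ij}$ exceeds roughly 745, which breaks the naive scaling iteration. The sketch below contrasts this with the standard log-domain form of the Sinkhorn iteration, written on the dual potentials $(f, g)$ with log-sum-exp. This is a generic textbook stabilization, not the algorithm advocated in this article, and the helper name `sinkhorn_log` is our own:

```python
import numpy as np
from scipy.special import logsumexp

def sinkhorn_log(M, a, b, lam, n_iter=500):
    """Log-domain Sinkhorn iterations on the dual potentials (f, g).
    The Gibbs kernel exp(-lam * M) is never formed explicitly, so a
    large lam cannot underflow it to exact zeros."""
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        # Equivalent to u = a / (K v), v = b / (K^T u) in log space,
        # with u = exp(lam * f), v = exp(lam * g)
        f = (log_a - logsumexp(lam * (g[None, :] - M), axis=1)) / lam
        g = (log_b - logsumexp(lam * (f[:, None] - M), axis=0)) / lam
    return np.exp(lam * (f[:, None] + g[None, :] - M))

M = np.array([[0.0, 2.0],
              [2.0, 0.0]])
a = b = np.array([0.5, 0.5])

print(np.exp(-1000.0 * M).min())   # 0.0: the naive kernel underflows at lam = 1000
T = sinkhorn_log(M, a, b, lam=1000.0)
print(T.sum(axis=0))               # still matches b despite the huge lam
```

As noted above, such stabilized iterations avoid overflow/underflow but can still converge slowly for large $\lambda$, which is part of the motivation for the algorithms studied in this article.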

2. THE (SHARP) SINKHORN LOSS AS APPROXIMATE OT

Throughout this article we focus on discrete OT problems. Denote by $\Delta_n = \{w \in \mathbb{R}_+^n : w^T \mathbf{1}_n = 1\}$ the $n$-dimensional probability simplex, and let $\mu = \sum_{i=1}^n a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^m b_j \delta_{y_j}$ be two discrete probability measures supported on data points $\{x_i\}_{i=1}^n$ and $\{y_j\}_{j=1}^m$, respectively, where $a = (a_1, \ldots, a_n)^T \in \Delta_n$, $b = (b_1, \ldots, b_m)^T \in \Delta_m$, and $\delta_x$ is the Dirac measure at position $x$. Define $\Pi(a, b) = \{T \in \mathbb{R}_+^{n \times m} : T \mathbf{1}_m = a, T^T \mathbf{1}_n = b\}$, and let $M \in \mathbb{R}^{n \times m}$ be a cost matrix with entries $M_{ij}$, $i = 1, \ldots, n$, $j = 1, \ldots, m$. Without loss of generality we assume that $n \ge m$, as their roles can be exchanged. Then OT can be characterized by the following optimization problem,

$$W(M, a, b) = \min_{P \in \Pi(a, b)} \langle P, M \rangle, \qquad (1)$$

where $\langle A, B \rangle = \mathrm{tr}(A^T B)$. An optimal solution to (1), denoted as $P^*$, is typically called an optimal transport plan, and can be viewed as a joint distribution whose marginals coincide with $\mu$ and $\nu$. The optimal value $W(M, a, b) = \langle P^*, M \rangle$ is then called the Wasserstein distance between $\mu$ and $\nu$ if the cost matrix $M$ satisfies some suitable conditions (Proposition 2.2 of Peyré et al., 2019). As is introduced in Section 1, solving the optimization problem (1) can be difficult even for moderate $n$ and $m$. One approach to regularizing the optimization problem is to add an entropic penalty term to the objective function, leading to the entropic-regularized OT problem (Cuturi, 2013):

$$\tilde{S}_\lambda(M, a, b) = \min_{T \in \Pi(a, b)} S_\lambda(T) := \min_{T \in \Pi(a, b)} \langle T, M \rangle - \lambda^{-1} h(T), \qquad (2)$$

where $h(T) = \sum_{i=1}^n \sum_{j=1}^m T_{ij}(1 - \log T_{ij})$ is the entropy term. The new objective function $S_\lambda(T)$ is $\lambda^{-1}$-strongly convex on $\Pi(a, b)$, so (2) has a unique global solution, denoted as $T_\lambda^*$. In this article, $T_\lambda^*$ is referred to as the Sinkhorn transport plan. The entropic-regularized Wasserstein distance, also known as the Sinkhorn distance or Sinkhorn loss in the literature (Cuturi, 2013), is then defined as $S_\lambda(M, a, b) = \langle T_\lambda^*, M \rangle$. To simplify the notation, we omit the subscript $\lambda$ in $T_\lambda^*$ hereafter when no confusion is caused. It is worth noting that in the literature, $S_\lambda$ and $\tilde{S}_\lambda$ are sometimes referred to as the sharp and regularized Sinkhorn losses, respectively. The following proposition from Luise et al. (2018) suggests that $S_\lambda$ achieves a faster rate at approximating the Wasserstein distance than $\tilde{S}_\lambda$. For this reason, in this article we focus on the sharp version, and simply call $S_\lambda$ the Sinkhorn loss for brevity.

Proposition 1 (Luise et al., 2018). There exist constants $C_1, C_2 > 0$ such that for any $\lambda > 0$, $|S_\lambda(M, a, b) - W(M, a, b)| \le C_1 e^{-\lambda}$ and $|\tilde{S}_\lambda(M, a, b) - W(M, a, b)| \le C_2 / \lambda$. The constants $C_1$ and $C_2$ are independent of $\lambda$, but depend on $\mu$ and $\nu$.
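The two approximation rates in Proposition 1 can be checked numerically on a toy problem: solve the OT linear program (1) exactly, run log-domain Sinkhorn iterations for several values of $\lambda$, and compare the sharp and regularized losses against $W$. The sketch below does this with SciPy; the helper names and iteration counts are our own illustrative choices:

```python
import numpy as np
from scipy.optimize import linprog
from scipy.special import logsumexp

def exact_wasserstein(M, a, b):
    """Solve the OT linear program (1) with scipy's linprog (HiGHS)."""
    n, m = M.shape
    A_eq = np.zeros((n + m, n * m))
    for i in range(n):
        A_eq[i, i * m:(i + 1) * m] = 1.0   # row sums:    T 1_m   = a
    for j in range(m):
        A_eq[n + j, j::m] = 1.0            # column sums: T^T 1_n = b
    res = linprog(M.ravel(), A_eq=A_eq, b_eq=np.concatenate([a, b]),
                  bounds=(0, None), method="highs")
    return res.fun

def sinkhorn_losses(M, a, b, lam, n_iter=2000):
    """Return the (sharp, regularized) Sinkhorn losses via log-domain
    iterations on the dual potentials (f, g)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = (np.log(a) - logsumexp(lam * (g[None, :] - M), axis=1)) / lam
        g = (np.log(b) - logsumexp(lam * (f[:, None] - M), axis=0)) / lam
    log_T = lam * (f[:, None] + g[None, :] - M)
    T = np.exp(log_T)
    sharp = np.sum(T * M)                  # S_lam     = <T*, M>
    h = np.sum(T * (1.0 - log_T))          # entropy term h(T)
    return sharp, sharp - h / lam          # (S_lam, S~_lam)

x = np.array([0.0, 1.0, 2.0]); y = np.array([0.3, 1.2, 2.4])
M = (x[:, None] - y[None, :]) ** 2
a = b = np.full(3, 1.0 / 3)
W = exact_wasserstein(M, a, b)
for lam in (1.0, 5.0, 20.0):
    S, S_reg = sinkhorn_losses(M, a, b, lam)
    print(lam, abs(S - W), abs(S_reg - W))   # both gaps shrink as lam grows
```

On this example the sharp gap $|S_\lambda - W|$ collapses much faster than the regularized gap $|\tilde{S}_\lambda - W|$, consistent with the $e^{-\lambda}$ versus $1/\lambda$ rates in Proposition 1.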

