SPARSITY-CONSTRAINED OPTIMAL TRANSPORT

Abstract

Regularized optimal transport (OT) is now increasingly used as a loss or as a matching layer in neural networks. Entropy-regularized OT can be computed using the Sinkhorn algorithm, but it leads to fully-dense transportation plans, meaning that all sources are (fractionally) matched with all targets. To address this issue, several works have investigated quadratic regularization instead. This regularization preserves sparsity and leads to unconstrained and smooth (semi) dual objectives, which can be solved with off-the-shelf gradient methods. Unfortunately, quadratic regularization does not give direct control over the cardinality (number of nonzeros) of the transportation plan. We propose in this paper a new approach for OT with explicit cardinality constraints on the transportation plan. Our work is motivated by an application to sparse mixtures of experts, where OT can be used to match input tokens such as image patches with expert models such as neural networks. Cardinality constraints ensure that at most k tokens are matched with an expert, which is crucial for computational performance reasons. Despite the nonconvexity of cardinality constraints, we show that the corresponding (semi) dual problems are tractable and can be solved with first-order gradient methods. Our method can be thought of as a middle ground between unregularized OT (recovered when k is small enough) and quadratically-regularized OT (recovered when k is large enough). The smoothness of the objectives increases as k increases, giving rise to a trade-off between convergence speed and sparsity of the optimal plan.

1. INTRODUCTION

Optimal transport (OT) distances (a.k.a. Wasserstein or earth mover's distances) are a powerful computational tool to compare probability distributions and have found widespread use in machine learning (Solomon et al., 2014; Kusner et al., 2015; Arjovsky et al., 2017). While OT distances exhibit a unique ability to capture the geometry of the data, their applicability has been largely hampered by their high computational cost. Indeed, computing OT distances involves a linear program, which takes super-cubic time to solve using state-of-the-art network-flow algorithms (Kennington & Helgason, 1980; Ahuja et al., 1988). In addition, these algorithms are challenging to implement and are not GPU or TPU friendly. An alternative approach instead solves the so-called semi-dual using (stochastic) subgradient methods (Carlier et al., 2015) or quasi-Newton methods (Mérigot, 2011; Kitagawa et al., 2019). However, the semi-dual is a nonsmooth, piecewise-linear function, which can lead to slow convergence in practice. For all these reasons, the machine learning community has now largely switched to regularized OT. Popularized by Cuturi (2013), entropy-regularized OT can be computed using the Sinkhorn algorithm (Sinkhorn & Knopp, 1967) and is differentiable w.r.t. its inputs, enabling OT as a differentiable loss (Cuturi, 2013; Feydy et al., 2019) or as a layer in a neural network (Genevay et al., 2019; Sarlin et al., 2020; Sander et al., 2022). A disadvantage of entropic regularization, however, is that it leads to fully-dense transportation plans. This is problematic in applications where it is undesirable to (fractionally) match all sources with all targets, e.g., for interpretability or for computational cost reasons. To address this issue, several works have investigated quadratic regularization instead (Dessein et al., 2018; Blondel et al., 2018; Lorenz et al., 2021).
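To make the entropic approach concrete, here is a minimal NumPy sketch of the Sinkhorn matrix-scaling iterations (function and variable names are ours, not from the paper); it also illustrates the density issue discussed above, since every entry of the resulting plan is strictly positive:

```python
import numpy as np

def sinkhorn(C, a, b, eps=0.2, n_iters=1000):
    """Entropy-regularized OT via Sinkhorn's matrix-scaling iterations.

    C: (m, n) cost matrix; a, b: source and target marginals.
    Returns the transportation plan diag(u) @ K @ diag(v).
    """
    K = np.exp(-C / eps)              # Gibbs kernel, all entries positive
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        v = b / (K.T @ u)             # enforce column marginals
        u = a / (K @ v)               # enforce row marginals
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
C = rng.random((5, 4))
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)
T = sinkhorn(C, a, b)
# The plan matches the marginals but is fully dense: every entry is positive.
print(T.sum(axis=1))                  # ≈ a
print(np.count_nonzero(T), "of", T.size, "entries are nonzero")
```

Since the kernel K is an element-wise exponential, no amount of diagonal scaling can produce exact zeros, which is precisely the motivation for the sparsity-inducing regularizers discussed next.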
This regularization preserves sparsity and leads to unconstrained and smooth (semi) dual objectives, solvable with off-the-shelf algorithms. Unfortunately, it does not give direct control over the cardinality (number of nonzeros) of the transportation plan. In this paper, we propose a new approach for OT with explicit cardinality constraints on the columns of the transportation plan. Our work is motivated by an application to sparse mixtures of experts, in which we want each token (e.g., a word or an image patch) to be matched with at most k experts (e.g., multilayer perceptrons). This is critical for computational performance reasons, since the cost of processing a token is proportional to the number of experts that have been selected for it. Despite the nonconvexity of cardinality constraints, we show that the corresponding dual and semi-dual problems are tractable and can be solved with first-order gradient methods. Our method can be thought of as a middle ground between unregularized OT (recovered when k is small enough) and quadratically-regularized OT (recovered when k is large enough). We empirically show that the dual and semi-dual are increasingly smooth as k increases, giving rise to a trade-off between convergence speed and sparsity. The rest of the paper is organized as follows.
• We review related work in §2 and existing work on OT with convex regularization in §3.
• We propose in §4 a framework for OT with nonconvex regularization, based on the dual and semi-dual formulations. We study the weak duality and the primal interpretation of these formulations.
• We apply our framework in §5 to OT with cardinality constraints. We show that the dual and semi-dual formulations are tractable and that smoothness of the objective increases as k increases. We show that our approach is equivalent to using squared k-support norm regularization in the primal.
• We validate our framework in §6 and in Appendix A through a variety of experiments.
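A key building block behind column-wise cardinality constraints is producing, from a vector of scores, a column with at most k nonzeros that still sums to a prescribed mass. The sketch below is our own illustration (not the paper's implementation; the helper names are ours): it selects the k largest entries and then projects them onto the simplex, which is a standard way to solve such sparse-simplex subproblems:

```python
import numpy as np

def projection_simplex(s, mass=1.0):
    """Euclidean projection of s onto {t >= 0 : sum(t) = mass}."""
    u = np.sort(s)[::-1]                      # sort scores in decreasing order
    css = np.cumsum(u) - mass
    rho = np.nonzero(u * np.arange(1, len(s) + 1) > css)[0][-1]
    theta = css[rho] / (rho + 1.0)            # optimal shift
    return np.maximum(s - theta, 0.0)

def topk_simplex(s, k, mass=1.0):
    """Keep at most k nonzeros: restrict to the k largest entries of s,
    then project that restriction onto the simplex of total mass `mass`."""
    idx = np.argsort(s)[::-1][:k]             # indices of the k largest scores
    t = np.zeros_like(s)
    t[idx] = projection_simplex(s[idx], mass)
    return t

s = np.array([0.1, 0.7, 0.3, 0.9])
t = topk_simplex(s, k=2)
print(t)   # at most 2 nonzeros, summing to 1
```

In a mixture-of-experts reading, `s` would be a column of token-expert affinities and `t` the (at most k-sparse) fractional assignment of that token.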
Notation and convex analysis tools. Given a matrix T ∈ R^{m×n}, we denote its columns by t_j ∈ R^m for j ∈ [n]. We denote the non-negative orthant by R^m_+ and the non-positive orthant by R^m_−. We denote the probability simplex by △^m := {p ∈ R^m_+ : ⟨p, 1⟩ = 1}. We will also use △^m_b to denote the set {t ∈ R^m_+ : ⟨t, 1⟩ = b}. The convex conjugate of a function f : R^m → R ∪ {∞} is defined by f*(s) := sup_{t ∈ dom(f)} ⟨s, t⟩ − f(t). It is well-known that f* is convex (even if f is not). When the supremum is uniquely attained, the gradient of f* is ∇f*(s) = argmax_{t ∈ dom(f)} ⟨s, t⟩ − f(t); when it is not, any maximizer is a subgradient. We denote the indicator function of a set C by δ_C, i.e., δ_C(t) = 0 if t ∈ C and δ_C(t) = ∞ otherwise. We denote the Euclidean projection onto the set C by proj_C(s) := argmin_{t ∈ C} ‖s − t‖²₂. The projection is unique when C is convex, while it may not be when C is nonconvex. We use [·]_+ to denote the non-negative part, evaluated element-wise. Given a vector s ∈ R^m, we use s_[i] to denote its i-th largest value, i.e., s_[1] ≥ ··· ≥ s_[m].
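As a toy numeric illustration of the conjugate identities above (our own example, with the domain restricted to a finite set so the supremum can be enumerated): evaluating f*(s) = sup ⟨s, t⟩ − f(t) also yields a maximizer, which is then a (sub)gradient of f*.

```python
import numpy as np

def conjugate_and_grad(s, f, candidates):
    """Evaluate f*(s) = max_t <s, t> - f(t) over a finite candidate set.

    Returns the conjugate value and the maximizing t, which is a
    (sub)gradient of f* at s by the identity in the text.
    """
    vals = candidates @ s - np.array([f(t) for t in candidates])
    i = np.argmax(vals)
    return vals[i], candidates[i]

# Take f to be the indicator of the simplex vertices {e_1, ..., e_m}
# (zero on each vertex): then f*(s) = max_i s_i, and the gradient of f*
# is the argmax vertex.
m = 4
E = np.eye(m)                      # rows are the vertices e_1, ..., e_m
s = np.array([0.3, 1.2, -0.5, 0.9])
val, grad = conjugate_and_grad(s, lambda t: 0.0, E)
print(val, grad)                   # 1.2 and the second vertex
```

This max-of-linear-functions picture is also why conjugates are piecewise linear (hence nonsmooth) when the domain is finite, echoing the nonsmoothness of the unregularized semi-dual mentioned in the introduction.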



Figure 1: Comparison of OT formulations (m = n = 20 points), with squared Euclidean distance cost and with uniform source and target distributions. The unregularized OT plan is maximally sparse and contains at most m + n − 1 nonzero elements. In contrast, entropy-regularized OT plans are always fully dense, meaning that all points are fractionally matched with one another (nonzeros of a transportation plan are indicated by small squares). Squared 2-norm (quadratically) regularized OT preserves sparsity, but the number of nonzero elements cannot be directly controlled. Our proposed sparsity-constrained OT allows us to set a maximum number of nonzeros k per column. It recovers unregularized OT in the limit case k = 1 (Proposition 4) and quadratically-regularized OT when k is large enough. It can be computed using solvers such as L-BFGS or Adam.
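As a rough indication of what "solvers such as L-BFGS" means for the quadratically-regularized case, the sketch below (our own illustration, assuming SciPy's L-BFGS-B and the standard smooth dual of squared 2-norm regularized OT, with dual variables α, β) recovers a typically sparse plan from the dual solution:

```python
import numpy as np
from scipy.optimize import minimize

def quad_reg_ot(C, a, b, gamma=0.05):
    """Quadratically-regularized OT via its smooth, unconstrained dual.

    Maximizes <alpha, a> + <beta, b> - (1/(2*gamma)) * sum [alpha_i + beta_j - C_ij]_+^2
    with L-BFGS, then reconstructs T_ij = [alpha_i + beta_j - C_ij]_+ / gamma.
    """
    m, n = C.shape

    def neg_dual(x):
        alpha, beta = x[:m], x[m:]
        R = np.maximum(alpha[:, None] + beta[None, :] - C, 0.0)
        val = alpha @ a + beta @ b - (R ** 2).sum() / (2 * gamma)
        grad = np.concatenate([a - R.sum(axis=1) / gamma,
                               b - R.sum(axis=0) / gamma])
        return -val, -grad                     # minimize the negated dual

    res = minimize(neg_dual, np.zeros(m + n), jac=True, method="L-BFGS-B",
                   options={"maxiter": 10000, "ftol": 1e-15, "gtol": 1e-9})
    alpha, beta = res.x[:m], res.x[m:]
    return np.maximum(alpha[:, None] + beta[None, :] - C, 0.0) / gamma

rng = np.random.default_rng(0)
X, Y = rng.random((6, 2)), rng.random((6, 2))
C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
a = np.full(6, 1 / 6)
b = np.full(6, 1 / 6)
T = quad_reg_ot(C, a, b)
print(np.count_nonzero(T > 1e-8), "nonzeros out of", T.size)  # sparse plan
```

The plan is sparse because the element-wise thresholding [·]_+ zeroes out every pair (i, j) whose adjusted cost exceeds the dual potentials; however, as the caption notes, this formulation offers no direct handle on how many nonzeros survive per column.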

