SPARSITY-CONSTRAINED OPTIMAL TRANSPORT

Abstract

Regularized optimal transport (OT) is now increasingly used as a loss or as a matching layer in neural networks. Entropy-regularized OT can be computed using the Sinkhorn algorithm, but it leads to fully-dense transportation plans, meaning that all sources are (fractionally) matched with all targets. To address this issue, several works have investigated quadratic regularization instead. This regularization preserves sparsity and leads to unconstrained and smooth (semi) dual objectives, which can be solved with off-the-shelf gradient methods. Unfortunately, quadratic regularization does not give direct control over the cardinality (number of nonzeros) of the transportation plan. We propose in this paper a new approach for OT with explicit cardinality constraints on the transportation plan. Our work is motivated by an application to sparse mixtures of experts, where OT can be used to match input tokens such as image patches with expert models such as neural networks. Cardinality constraints ensure that at most k tokens are matched with an expert, which is crucial for computational performance reasons. Despite the nonconvexity of cardinality constraints, we show that the corresponding (semi) dual problems are tractable and can be solved with first-order gradient methods. Our method can be thought of as a middle ground between unregularized OT (recovered when k is small enough) and quadratically-regularized OT (recovered when k is large enough). The smoothness of the objectives increases as k increases, giving rise to a trade-off between convergence speed and sparsity of the optimal plan.

1. INTRODUCTION

Optimal transport (OT) distances (a.k.a. Wasserstein or earth mover's distances) are a powerful computational tool to compare probability distributions and have found widespread use in machine learning (Solomon et al., 2014; Kusner et al., 2015; Arjovsky et al., 2017). While OT distances exhibit a unique ability to capture the geometry of the data, their applicability has been largely hampered by their high computational cost. Indeed, computing OT distances involves solving a linear program, which takes super-cubic time using state-of-the-art network-flow algorithms (Kennington & Helgason, 1980; Ahuja et al., 1988). In addition, these algorithms are challenging to implement and are not GPU or TPU friendly. An alternative approach consists in solving the so-called semi-dual using (stochastic) subgradient methods (Carlier et al., 2015) or quasi-Newton methods (Mérigot, 2011; Kitagawa et al., 2019). However, the semi-dual is a nonsmooth, piecewise-linear function, which can lead to slow convergence in practice. For all these reasons, the machine learning community has now largely switched to regularized OT. Popularized by Cuturi (2013), entropy-regularized OT can be computed using the Sinkhorn algorithm (Sinkhorn, 1967) and is differentiable w.r.t. its inputs, enabling OT as a differentiable loss (Cuturi, 2013; Feydy et al., 2019) or as a layer in a neural network (Genevay et al., 2019; Sarlin et al., 2020; Sander et al., 2022). A disadvantage of entropic regularization, however, is that it leads to fully-dense transportation plans. This is problematic in applications where it is undesirable to (fractionally) match all sources with all targets, e.g., for interpretability or for computational cost reasons. To address this issue, several works have investigated quadratic regularization instead (Dessein et al., 2018; Blondel et al., 2018; Lorenz et al., 2021).
This regularization preserves sparsity and leads to unconstrained and smooth (semi) dual objectives, solvable with off-the-shelf algorithms. Unfortunately, it does not give direct control over the cardinality (number of nonzeros) of the transportation plan.
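To make the density issue concrete, the following is a minimal NumPy sketch of the Sinkhorn iterations for entropy-regularized OT (the function name, iteration count, and regularization strength epsilon are illustrative choices, not the paper's implementation). Even on a toy problem where the unregularized optimal plan is sparse, every entry of the returned plan is strictly positive.

```python
import numpy as np

def sinkhorn(C, a, b, epsilon=0.1, num_iters=200):
    """Entropy-regularized OT via Sinkhorn iterations (illustrative sketch).

    C: (m, n) cost matrix; a, b: source and target marginals (each sums to 1).
    Returns a plan P with row sums a and (approximately) column sums b.
    """
    K = np.exp(-C / epsilon)      # Gibbs kernel: strictly positive entries
    u = np.ones_like(a)
    for _ in range(num_iters):
        v = b / (K.T @ u)         # alternate diagonal scaling updates
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

# Toy example: the unregularized plan would put all mass on the diagonal,
# but the entropy-regularized plan is fully dense (all entries > 0).
C = np.array([[0.0, 1.0], [1.0, 0.0]])
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
P = sinkhorn(C, a, b)
```

Because the Gibbs kernel exp(-C / epsilon) has no zero entries, no diagonal rescaling can zero out an entry of the plan; this is the structural reason entropic regularization cannot produce sparse plans, regardless of epsilon.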

