SPARSITY-CONSTRAINED OPTIMAL TRANSPORT

Abstract

Regularized optimal transport (OT) is now increasingly used as a loss or as a matching layer in neural networks. Entropy-regularized OT can be computed using the Sinkhorn algorithm, but it leads to fully-dense transportation plans, meaning that all sources are (fractionally) matched with all targets. To address this issue, several works have investigated quadratic regularization instead. This regularization preserves sparsity and leads to unconstrained and smooth (semi) dual objectives, which can be solved with off-the-shelf gradient methods. Unfortunately, quadratic regularization does not give direct control over the cardinality (number of nonzeros) of the transportation plan. We propose in this paper a new approach for OT with explicit cardinality constraints on the transportation plan. Our work is motivated by an application to sparse mixtures of experts, where OT can be used to match input tokens such as image patches with expert models such as neural networks. Cardinality constraints ensure that at most k tokens are matched with an expert, which is crucial for computational performance reasons. Despite the nonconvexity of cardinality constraints, we show that the corresponding (semi) dual problems are tractable and can be solved with first-order gradient methods. Our method can be thought of as a middle ground between unregularized OT (recovered when k is small enough) and quadratically-regularized OT (recovered when k is large enough). The smoothness of the objectives increases as k increases, giving rise to a trade-off between convergence speed and sparsity of the optimal plan.

1. INTRODUCTION

Optimal transport (OT) distances (a.k.a. Wasserstein or earth mover's distances) are a powerful computational tool to compare probability distributions and have found widespread use in machine learning (Solomon et al., 2014; Kusner et al., 2015; Arjovsky et al., 2017). While OT distances exhibit a unique ability to capture the geometry of the data, their applicability has been largely hampered by their high computational cost. Indeed, computing OT distances involves a linear program, which takes super-cubic time to solve using state-of-the-art network-flow algorithms (Kennington & Helgason, 1980; Ahuja et al., 1988). In addition, these algorithms are challenging to implement and are not GPU or TPU friendly. An alternative approach consists in solving the so-called semi-dual using (stochastic) subgradient methods (Carlier et al., 2015) or quasi-Newton methods (Mérigot, 2011; Kitagawa et al., 2019). However, the semi-dual is a nonsmooth, piecewise-linear function, which can lead to slow convergence in practice. For all these reasons, the machine learning community has now largely switched to regularized OT. Popularized by Cuturi (2013), entropy-regularized OT can be computed using the Sinkhorn algorithm (Sinkhorn & Knopp, 1967) and is differentiable w.r.t. its inputs, enabling OT as a differentiable loss (Cuturi, 2013; Feydy et al., 2019) or as a layer in a neural network (Genevay et al., 2019; Sarlin et al., 2020; Sander et al., 2022). A disadvantage of entropic regularization, however, is that it leads to fully-dense transportation plans. This is problematic in applications where it is undesirable to (fractionally) match all sources with all targets, e.g., for interpretability or for computational cost reasons. To address this issue, several works have investigated quadratic regularization instead (Dessein et al., 2018; Blondel et al., 2018; Lorenz et al., 2021).
This regularization preserves sparsity and leads to unconstrained and smooth (semi) dual objectives, solvable with off-the-shelf algorithms. Unfortunately, it does not give direct control over the cardinality (number of nonzeros) of the transportation plan.

[Figure 1: … with uniform source and target distributions. The unregularized OT plan is maximally sparse and contains at most m + n − 1 nonzero elements. On the contrary, with entropy-regularized OT, plans are always fully dense, meaning that all points are fractionally matched with one another (nonzeros of a transportation plan are indicated by small squares). Squared 2-norm (quadratically) regularized OT preserves sparsity, but the number of nonzero elements cannot be directly controlled. Our proposed sparsity-constrained OT allows us to set a maximum number of nonzeros k per column. It recovers unregularized OT in the limit case k = 1 (Proposition 4) and quadratically-regularized OT when k is large enough. It can be computed using solvers such as LBFGS or ADAM.]

In this paper, we propose a new approach for OT with explicit cardinality constraints on the columns of the transportation plan. Our work is motivated by an application to sparse mixtures of experts, in which we want each token (e.g., a word or an image patch) to be matched with at most k experts (e.g., multilayer perceptrons). This is critical for computational performance, since the cost of processing a token is proportional to the number of experts selected for it. Despite the nonconvexity of cardinality constraints, we show that the corresponding dual and semi-dual problems are tractable and can be solved with first-order gradient methods. Our method can be thought of as a middle ground between unregularized OT (recovered when k is small enough) and quadratically-regularized OT (recovered when k is large enough).
We empirically show that the dual and semi-dual are increasingly smooth as k increases, giving rise to a trade-off between convergence speed and sparsity. The rest of the paper is organized as follows.

• We review related work in §2 and existing work on OT with convex regularization in §3.
• We propose in §4 a framework for OT with nonconvex regularization, based on the dual and semi-dual formulations. We study the weak duality and the primal interpretation of these formulations.
• We apply our framework in §5 to OT with cardinality constraints. We show that the dual and semi-dual formulations are tractable and that the smoothness of the objective increases as k increases. We show that our approach is equivalent to using squared k-support norm regularization in the primal.
• We validate our framework in §6 and in Appendix A through a variety of experiments.

Notation and convex analysis tools. Given a matrix T ∈ R^{m×n}, we denote its columns by t_j ∈ R^m for j ∈ [n]. We denote the non-negative orthant by R^m_+ and the non-positive orthant by R^m_−. We denote the probability simplex by △^m := {p ∈ R^m_+ : ⟨p, 1⟩ = 1} and, for b > 0, we write b△^m for the scaled simplex {t ∈ R^m_+ : ⟨t, 1⟩ = b}. The convex conjugate of a function f : R^m → R ∪ {∞} is defined by f*(s) := sup_{t ∈ dom(f)} ⟨s, t⟩ − f(t). It is well known that f* is convex, even if f is not. When the maximizer is unique, the gradient of the conjugate is ∇f*(s) = argmax_{t ∈ dom(f)} ⟨s, t⟩ − f(t); when it is not unique, the argmax yields a subgradient. We denote the indicator function of a set C by δ_C, i.e., δ_C(t) = 0 if t ∈ C and δ_C(t) = ∞ otherwise. We denote the Euclidean projection onto a set C by proj_C(s) := argmin_{t ∈ C} ‖s − t‖²_2; the projection is unique when C is convex, but may not be when C is nonconvex. We use [·]_+ to denote the non-negative part, evaluated element-wise. Given a vector s ∈ R^m, we use s_[i] to denote its i-th largest value, i.e., s_[1] ≥ ⋯ ≥ s_[m].

2. RELATED WORK

Sparse optimal transport. OT with arbitrary strongly convex regularization is studied by Dessein et al. (2018) and Blondel et al. (2018) . More specifically, quadratic regularization was studied in the discrete (Blondel et al., 2018; Roberts et al., 2017) and continuous settings (Lorenz et al., 2021) . Although it is known that quadratic regularization leads to sparse transportation plans, it does not enable explicit control of the cardinality (maximum number of nonzero elements), as we do. In this work, we study the nonconvex regularization case and apply it to cardinality-constrained OT. Sparse projections. In this paper, we use k-sparse projections as a core building block of our framework. Sparse projections on the simplex and on the non-negative orthant were studied by Kyrillidis et al. (2013) and Bolte et al. (2014) , respectively. These studies were later extended to more general sets (Beck & Hallak, 2016) . On the application side, sparse projections on the simplex were used for structured prediction (Pillutla et al., 2018; Blondel et al., 2020) , for marginalizing over discrete variables (Correia et al., 2020) and for Wasserstein K-means (Fukunaga & Kasai, 2021) . Sparse mixture of experts (MoE). In contrast to usual deep learning models where all parameters interact with all inputs, a sparse MoE model activates only a small part of the model ("experts") in an input-dependent manner, thus reducing the overall computational cost of the model. Sparse MoEs have been tremendously successful in scaling up deep learning architectures in tasks including computer vision (Riquelme et al., 2021) , natural language processing (Shazeer et al., 2017; Lewis et al., 2021; Lepikhin et al., 2021; Roller et al., 2021; Fedus et al., 2022b; Clark et al., 2022) , speech processing (You et al., 2022) , and multimodal learning (Mustafa et al., 2022) . 
In addition to reducing computational cost, sparse MoEs have also shown other benefits, such as an enhancement in adversarial robustness (Puigcerver et al., 2022) . See Fedus et al. (2022a) for a recent survey. Crucial to a sparse MoE model is its routing mechanism that decides which experts get which inputs. Transformer-based MoE models typically route individual tokens (embedded words or image patches). To balance the assignments of tokens to experts, recent works cast the assignment problem as entropy-regularized OT (Kool et al., 2021; Clark et al., 2022) . We go beyond entropy-regularized OT and show that sparsity-constrained OT yields a more natural and effective router.

3. OPTIMAL TRANSPORT WITH CONVEX REGULARIZATION

In this section, we review OT with convex regularization, which also includes the unregularized case. For a comprehensive survey on computational OT, see Peyré & Cuturi (2019).

Primal formulation. We focus throughout this paper on OT between discrete probability distributions a ∈ △^m and b ∈ △^n. Rather than performing a pointwise comparison of the distributions, OT distances compute the minimal effort, according to some ground cost, for moving the probability mass of one distribution to the other. Recent applications of OT in machine learning typically add regularization on the transportation plan T. In this section, we apply convex regularization Ω : R^m_+ → R_+ ∪ {∞} separately on the columns t_j ∈ R^m_+ of T and consider the primal formulation

P_Ω(a, b, C) := min_{T ∈ U(a,b)} ⟨T, C⟩ + Σ_{j=1}^n Ω(t_j),   (1)

where C ∈ R^{m×n}_+ is a cost matrix and U(a, b) := {T ∈ R^{m×n}_+ : T1_n = a, T^⊤1_m = b} denotes the transportation polytope.

Dual and semi-dual formulations. Let us denote

Ω*_+(s) := (Ω + δ_{R^m_+})*(s) = max_{t ∈ R^m_+} ⟨s, t⟩ − Ω(t)   (2)
Ω*_b(s) := (Ω + δ_{b△^m})*(s) = max_{t ∈ b△^m} ⟨s, t⟩ − Ω(t).   (3)

The dual and semi-dual corresponding to (1) can then be written (Blondel et al., 2018) as

D_Ω(a, b, C) := max_{α ∈ R^m, β ∈ R^n} ⟨α, a⟩ + ⟨β, b⟩ − Σ_{j=1}^n Ω*_+(α + β_j 1_m − c_j)   (4)

and

S_Ω(a, b, C) := max_{α ∈ R^m} ⟨α, a⟩ − P*_Ω(α, b, C) = max_{α ∈ R^m} ⟨α, a⟩ − Σ_{j=1}^n Ω*_{b_j}(α − c_j),   (5)

where P*_Ω denotes the conjugate in the first argument. When Ω is convex (which also includes the unregularized case Ω = 0), by strong duality, we have P_Ω(a, b, C) = D_Ω(a, b, C) = S_Ω(a, b, C) for all a ∈ △^m, b ∈ △^n and C ∈ R^{m×n}_+.

Table 1: Regularizers and their conjugates (case γ = 1; θ and τ are normalization constants chosen so that the corresponding solution sums to b).

Ω(t)                                                 | Ω*_+(s)                        | Ω*_b(s)
Unregularized: 0                                     | δ_{R^m_−}(s)                   | b max_{i∈[m]} s_i
Negentropy: ⟨t, log t⟩                               | Σ_{i=1}^m e^{s_i − 1}          | b log Σ_{i=1}^m e^{s_i} − b log b
Squared 2-norm: ½‖t‖²_2                              | ½ Σ_{i=1}^m [s_i]²_+           | ½ Σ_{i=1}^m 1{s_i ≥ θ}(s_i² − θ²)
Sparsity-constrained (top-k): ½‖t‖²_2 + δ_{B_k}(t)   | ½ Σ_{i=1}^k [s_[i]]²_+         | ½ Σ_{i=1}^k 1{s_[i] ≥ τ}(s²_[i] − τ²)
Sparsity-constrained (top-1): ½‖t‖²_2 + δ_{B_1}(t)   | ½ max_{i∈[m]} [s_i]²_+         | b max_{i∈[m]} s_i − ½ b²

Computation. 
With Ω = 0 (without regularization), (2) becomes the indicator function of the non-positive orthant, leading to the constraints α_i + β_j ≤ c_{i,j}. The dual (4) is then a constrained linear program, and the most commonly used algorithm is the network-flow solver. On the other hand, (3) becomes a max operator, leading to the so-called c-transform β_j = min_{i∈[m]} c_{i,j} − α_i for all j ∈ [n]. The semi-dual (5) is then unconstrained, but it is a nonsmooth, piecewise-linear function. The key advantage of introducing strongly convex regularization is that it makes the corresponding (semi) dual easier to solve. Indeed, (2) and (3) become "soft" constraints and max operators. In particular, when Ω is Shannon's negentropy Ω(t) = γ⟨t, log t⟩, where γ controls the regularization strength, (2) and (3) rely on the exponential and log-sum-exp operations. It is well known that the primal (1) can then be solved using Sinkhorn's algorithm (Cuturi, 2013), which amounts to a block coordinate ascent scheme w.r.t. α ∈ R^m and β ∈ R^n in the dual (4). As pointed out in Blondel et al. (2018), the semi-dual is smooth (i.e., with Lipschitz gradients), but the dual is not. When Ω is the quadratic regularization Ω(t) = (γ/2)‖t‖²_2, then, as shown in Blondel et al. (2018, Table 1), (2) and (3) rely on the squared ReLU and on the projection onto the simplex. However, a block coordinate ascent scheme in the dual (4) is empirically shown to converge slowly. Instead, Blondel et al. (2018) propose to use LBFGS to solve both the dual and the semi-dual. Both the dual and the semi-dual are smooth (Blondel et al., 2018), i.e., with Lipschitz gradients. For both types of regularization, when γ → 0, we recover unregularized OT.
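To make the c-transform concrete, here is a minimal NumPy sketch on toy data (the function names are ours, not from any OT library):

```python
import numpy as np

def c_transform(alpha, C):
    """c-transform: beta_j = min_i (C_ij - alpha_i), for all columns j."""
    return np.min(C - alpha[:, None], axis=0)

def semidual_unreg(alpha, a, b, C):
    """Unregularized semi-dual value <alpha, a> + <c_transform(alpha), b>."""
    return alpha @ a + c_transform(alpha, C) @ b

C = np.array([[1.0, 2.0],
              [3.0, 0.0]])
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
beta = c_transform(np.zeros(2), C)            # -> [1.0, 0.0]
val = semidual_unreg(np.zeros(2), a, b, C)    # -> 0.5
```

Since the semi-dual is a maximization problem, the value at any α (here α = 0) is a lower bound on the unregularized OT cost.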

Recovering a plan.

If Ω is strictly convex, the unique optimal solution T* of (1) can be recovered from an optimal solution (α*, β*) of the dual (4) by

t*_j = ∇Ω*_+(α* + β*_j 1_m − c_j) ∀j ∈ [n],   (6)

or from an optimal solution α* of the semi-dual (5) by

t*_j = ∇Ω*_{b_j}(α* − c_j) ∀j ∈ [n].   (7)

If Ω is convex but not strictly so, recovering T* is more involved; see Appendix B.2 for details.

4. OPTIMAL TRANSPORT WITH NONCONVEX REGULARIZATION

In this section, we again focus on the primal formulation (1), but now study the case where the regularization Ω : R^m_+ → R_+ ∪ {∞} is nonconvex.

Concavity. It is well known that the conjugate function is always convex, even when the original function is not. As a result, even if the conjugate expressions (2) and (3) involve nonconcave maximization problems in the variable t, they induce convex functions of the variable s. We can therefore make the following elementary remark: the dual (4) and the semi-dual (5) are concave maximization problems, even if Ω is nonconvex. This means that we can solve them to arbitrary precision as long as we know how to compute the conjugate expressions (2) and (3). This is hard in general, but we will see in §5 a setting in which these expressions can be computed exactly. We remark that the identity S_Ω(a, b, C) = max_{α∈R^m} ⟨α, a⟩ − P*_Ω(α, b, C) still holds even when Ω is nonconvex.

The semi-dual upper-bounds the dual. Of course, if Ω is nonconvex, only weak duality holds, i.e., the dual (4) and the semi-dual (5) are lower bounds on the primal (1). The next proposition clarifies that the semi-dual is in fact an upper bound on the dual (a proof is given in Appendix B.1).
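The elementary remark above is easy to verify numerically: on a grid, the conjugate of even a nonconvex function is a pointwise maximum of affine functions of s, hence convex. A small self-contained sketch (the particular nonconvex f is arbitrary, chosen by us for illustration):

```python
import numpy as np

# A nonconvex function f on a grid of t values.
t = np.linspace(0.0, 2.0, 401)
f = np.sin(3.0 * t) + 0.5 * t ** 2

# Conjugate f*(s) = max_t (s * t - f(t)), evaluated on a grid of slopes s.
s = np.linspace(-2.0, 4.0, 301)
f_star = np.max(s[:, None] * t[None, :] - f[None, :], axis=1)

# f* is a pointwise max of affine functions of s, hence convex:
# its discrete second differences are nonnegative (up to float error).
assert np.all(np.diff(f_star, 2) >= -1e-9)
```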

Proposition 1. Weak duality

Let Ω : R^m_+ → R_+ ∪ {∞} (potentially nonconvex). For all a ∈ △^m, b ∈ △^n and C ∈ R^{m×n}_+,

D_Ω(a, b, C) ≤ S_Ω(a, b, C) ≤ P_Ω(a, b, C).

Therefore, if the goal is to approximately compute P_Ω(a, b, C), which involves an intractable nonconvex problem in general, it may be advantageous to use S_Ω(a, b, C) as a proxy, rather than D_Ω(a, b, C). However, for the specific choice of Ω in §5, we will see that D_Ω(a, b, C) and S_Ω(a, b, C) actually coincide, i.e., D_Ω(a, b, C) = S_Ω(a, b, C) ≤ P_Ω(a, b, C).

Recovering a plan. Often, the goal is not to compute the quantity P_Ω(a, b, C) itself, but rather the associated OT plan. If Ω is nonconvex, this is again intractable due to the nonconvex nature of the problem. As an approximation, given an optimal solution (α*, β*) of the dual or an optimal solution α* of the semi-dual, we propose to recover a transportation plan with (6) and (7), just as we would in the convex case. The following proposition clarifies that the transportation plan T* obtained this way solves a convex relaxation of the original nonconvex problem (1). A proof is given in Appendix B.3.

Proposition 2. Primal interpretation

Let Ω : R^m_+ → R_+ ∪ {∞} (potentially nonconvex). For all a ∈ △^m, b ∈ △^n and C ∈ R^{m×n}_+,

D_Ω(a, b, C) = min_{T ∈ R^{m×n} : T1_n = a, T^⊤1_m = b} ⟨T, C⟩ + Σ_{j=1}^n Ω**_+(t_j) = P_{Ω**}(a, b, C)

S_Ω(a, b, C) = min_{T ∈ R^{m×n} : T1_n = a} ⟨T, C⟩ + Σ_{j=1}^n Ω**_{b_j}(t_j) = P_{Ω**}(a, b, C).

In the above, f** denotes the biconjugate of f, the tightest convex lower bound of f. When Ω is nonconvex, deriving an expression for Ω**_+ and Ω**_{b_j} can be challenging in general. Fortunately, for the choice of Ω in §5, we are able to do so. When a function is convex and closed, its biconjugate is the function itself. As a result, if Ω is a convex and closed function, we recover P_Ω(a, b, C) = D_Ω(a, b, C) = S_Ω(a, b, C) for all a ∈ △^m, b ∈ △^n and C ∈ R^{m×n}_+.

Summary: proposed method. To approximately solve the primal OT objective (1) when Ω is nonconvex, we propose to solve the dual (4) or the semi-dual (5), which by Proposition 1 lower-bound the primal. We do so by solving the concave maximization problems (4) and (5) with gradient-based algorithms, such as LBFGS (Liu & Nocedal, 1989) or ADAM (Kingma & Ba, 2015). When a transportation plan is needed, we recover it from (6) and (7), as we would with convex Ω.
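This recipe can be sketched with off-the-shelf solvers. The sketch below runs SciPy's LBFGS on the semi-dual, instantiated for concreteness with the quadratic regularization of §3, whose conjugate gradient is a simplex projection; swapping in a different Ω only changes the conjugate and its gradient. This is our illustrative code under those assumptions, not the authors' implementation:

```python
import numpy as np
from scipy.optimize import minimize

def proj_simplex(s, b=1.0):
    """Euclidean projection of s onto {t >= 0, sum(t) = b} (sort-based)."""
    u = np.sort(s)[::-1]
    css = np.cumsum(u) - b
    rho = np.nonzero(u - css / np.arange(1, len(s) + 1) > 0)[0][-1]
    return np.maximum(s - css[rho] / (rho + 1), 0.0)

def neg_semidual(alpha, a, b, C, gamma):
    """Negative semi-dual and its gradient, quadratic regularization."""
    val, grad = alpha @ a, a.copy()
    for j in range(C.shape[1]):
        t = proj_simplex((alpha - C[:, j]) / gamma, b[j])  # grad of conj.
        val -= (alpha - C[:, j]) @ t - 0.5 * gamma * (t @ t)
        grad -= t
    return -val, -grad

rng = np.random.default_rng(0)
m, n, gamma = 5, 4, 0.5
C = rng.random((m, n))
a, b = np.full(m, 1 / m), np.full(n, 1 / n)
res = minimize(neg_semidual, np.zeros(m), args=(a, b, C, gamma),
               jac=True, method="L-BFGS-B")
# Recover the plan column-by-column from the semi-dual solution, as in (7).
T = np.stack([proj_simplex((res.x - C[:, j]) / gamma, b[j])
              for j in range(n)], axis=1)
```

By construction, each column of T sums to b_j exactly, while the row sums match a only up to the solver's tolerance.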

5. QUADRATICALLY-REGULARIZED OT WITH SPARSITY CONSTRAINTS

In this section, we build upon §4 to develop a regularized OT formulation with sparsity constraints.

Formulation. Formally, given t ∈ R^m, let us define the ℓ0 pseudo-norm by ‖t‖_0 := |{j ∈ [m] : t_j ≠ 0}|, i.e., the number of nonzero elements in t. For k ∈ {1, …, m}, we denote the ℓ0 level sets by B_k := {t ∈ R^m : ‖t‖_0 ≤ k}. Our goal in this section is then to approximately solve the following quadratically-regularized optimal transport problem with cardinality constraints on the columns of T:

min_{T ∈ U(a,b), T ∈ B_k × ⋯ × B_k} ⟨T, C⟩ + (γ/2)‖T‖²_2,   (8)

where γ > 0 controls the regularization strength and where k is assumed large enough to make the problem feasible. Problem (8) is a special case of (1) with the nonconvex regularization

Ω = (γ/2)‖·‖²_2 + δ_{B_k}.   (9)

We can therefore apply the methodology outlined in §4. If the cardinality constraints need to be applied to the rows instead of the columns, we simply transpose the problem.

Computation. We recall that in order to solve the dual (4) or the semi-dual (5), the main quantities we need to compute are the conjugate expressions (2) and (3), as well as their gradients. While this is intractable for general nonconvex Ω, with the choice of Ω in (9) we obtain

Ω*_+(s) = max_{t ∈ R^m_+ ∩ B_k} ⟨s, t⟩ − ½‖t‖²_2   (10)
Ω*_b(s) = max_{t ∈ b△^m ∩ B_k} ⟨s, t⟩ − ½‖t‖²_2,

where, without loss of generality, we assumed γ = 1; when γ ≠ 1, we can simply use the property (γf)* = γf*(·/γ). From the envelope theorem (Rockafellar & Wets, 2009, Theorem 10.31), the gradients are given by the corresponding argmax problems and we obtain

∇Ω*_+(s) = argmax_{t ∈ R^m_+ ∩ B_k} ⟨s, t⟩ − ½‖t‖²_2 = proj_{R^m_+ ∩ B_k}(s)
∇Ω*_b(s) = argmax_{t ∈ b△^m ∩ B_k} ⟨s, t⟩ − ½‖t‖²_2 = proj_{b△^m ∩ B_k}(s).

Therefore, computing an optimal solution t* reduces to the k-sparse projections of s onto the non-negative orthant and onto the simplex (scaled by b > 0), respectively. When t* is not unique (i.e., s contains ties), the argmax is set-valued.
We discuss this situation in more detail in Appendix B.2. Fortunately, despite the nonconvexity of the set B_k, both k-sparse projections can be computed exactly (Kyrillidis et al., 2013; Bolte et al., 2014; Beck & Hallak, 2016) by composing the unconstrained projection onto the original set with a top-k operation:

proj_{R^m_+ ∩ B_k}(s) = proj_{R^m_+}(topk(s)) = [topk(s)]_+   (11)
proj_{b△^m ∩ B_k}(s) = proj_{b△^m}(topk(s)) = [topk(s) − τ1_m]_+,   (12)

for some normalization constant τ ∈ R such that the solution sums to b. Here, topk(s) is defined such that [topk(s)]_i = s_i if s_i is among the top-k elements of s, and [topk(s)]_i = −∞ otherwise. The k-sparse projection onto the simplex is also known as top-k sparsemax (Pillutla et al., 2018; Blondel et al., 2020; Correia et al., 2020). Plugging these solutions back into ⟨s, t⟩ − ½‖t‖²_2, we obtain the expressions in Table 1 (a proof is given in Appendix B.4). Computing (11) or (12) requires a top-k sort and the projection of a vector of size at most k. A top-k sort can be computed in O(m log k) time, proj_{R^m_+} simply amounts to the non-negative part [·]_+, and the constant τ needed for proj_{b△^m} can be computed in O(k) time (Michelot, 1986; Duchi et al., 2008) by reusing the top-k sort. We have thus obtained an efficient way to compute the conjugates (2) and (3). The total complexity per LBFGS or ADAM iteration is O(mn log k).

Recovering a plan. Assuming no ties in α* + β*_j 1_m − c_j or in α* − c_j, the corresponding column of the transportation plan is uniquely determined by ∇Ω*_+(α* + β*_j 1_m − c_j) or ∇Ω*_{b_j}(α* − c_j), respectively. From (11) and (12), this column belongs to B_k. In case of ties, ensuring that the plan belongs to U(a, b) requires solving a system of linear equations, as detailed in Appendix B.2. Unfortunately, the columns may fail to belong to B_k in this situation.

Biconjugates and primal interpretation. 
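The composition top-k-then-project in (11) and (12) can be sketched as follows (our NumPy sketch, with −∞ placeholders exactly as in the definition of topk; function names are ours):

```python
import numpy as np

def topk_mask(s, k):
    """Keep the k largest entries of s, set the rest to -inf."""
    s = np.asarray(s, dtype=float)
    out = np.full_like(s, -np.inf)
    idx = np.argpartition(s, -k)[-k:]
    out[idx] = s[idx]
    return out

def proj_orthant_k(s, k):
    """k-sparse projection onto the non-negative orthant: [topk(s)]_+ ."""
    return np.maximum(topk_mask(s, k), 0.0)

def proj_simplex_k(s, k, b=1.0):
    """k-sparse projection onto {t >= 0, sum(t) = b}: project topk(s)."""
    z = topk_mask(s, k)
    u = np.sort(z[np.isfinite(z)])[::-1]          # the k kept values
    css = np.cumsum(u) - b
    rho = np.nonzero(u - css / np.arange(1, len(u) + 1) > 0)[0][-1]
    tau = css[rho] / (rho + 1)
    return np.maximum(z - tau, 0.0)               # -inf entries become 0

s = np.array([3.0, 1.0, 2.0, 0.5])
p1 = proj_orthant_k(s, 2)          # -> [3., 0., 2., 0.]
p2 = proj_simplex_k(s, 2, b=1.0)   # -> [1., 0., 0., 0.]
```

The normalization constant τ is found by the usual sort-based simplex-projection routine, applied only to the k retained entries.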
As discussed in §4 and Proposition 2, the biconjugates Ω**_+ and Ω**_b allow us to formally define which primal objective the transportation plans obtained by (6) and (7) optimally solve when Ω is nonconvex. Fortunately, for the Ω defined in (9), we are able to derive an explicit expression. Let us define the squared k-support norm by

Ψ(t) := ½ min_{λ ∈ R^m} Σ_{i=1}^m t_i²/λ_i   s.t.   ⟨λ, 1⟩ = k, 0 < λ_i ≤ 1 ∀i ∈ [m].

The k-support norm is known to be the tightest convex relaxation of the ℓ0 pseudo-norm over the ℓ2 unit ball, and can be computed in O(m log m) time (Argyriou et al., 2012; McDonald et al., 2016). We then have the following proposition, proved in Appendix B.6.
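For intuition, Ψ can be evaluated with the sort-based characterization of the k-support norm from Argyriou et al. (2012); the sketch below is our code under that characterization (not the authors'), and it recovers ½‖t‖₁² at k = 1 and ½‖t‖₂² at k = m:

```python
import numpy as np

def sq_k_support(t, k):
    """Squared k-support norm Psi(t) = 0.5 * ||t||_{k-sp}^2 (sort-based,
    following the characterization in Argyriou et al., 2012)."""
    z = np.sort(np.abs(np.asarray(t, dtype=float)))[::-1]
    for r in range(k):
        tail = z[k - r - 1:].sum() / (r + 1)
        upper = z[k - r - 2] if k - r - 2 >= 0 else np.inf
        if upper > tail >= z[k - r - 1]:
            head = z[: k - r - 1]
            return 0.5 * (head @ head + (r + 1) * tail ** 2)
    return 0.5 * float(z @ z)  # not reached for valid 1 <= k <= len(t)

# k = 1 gives 0.5 * ||t||_1^2; k = m gives 0.5 * ||t||_2^2.
```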

Proposition 3. Biconjugate and primal interpretation

With Ω defined in (9), we have:

Ω**_+(t) = Ψ(t) if t ∈ R^m_+, and Ω**_+(t) = ∞ otherwise;
Ω**_b(t) = Ψ(t) if t ∈ b△^m, and Ω**_b(t) = ∞ otherwise.

Therefore, with Ω defined in (9), we have for all a ∈ △^m, b ∈ △^n and C ∈ R^{m×n}_+,

D_Ω(a, b, C) = S_Ω(a, b, C) = P_Ψ(a, b, C) ≤ P_Ω(a, b, C).

The last inequality is an equality if there are no ties in α* + β*_j 1_m − c_j or in α* − c_j for all j ∈ [n].

In other words, our dual and semi-dual approaches based on the nonconvex Ω are equivalent to using the convex relaxation Ψ as regularization in the primal. We believe that the biconjugate expressions in Proposition 3 are of independent interest and could be useful in other works. For instance, they show that top-k sparsemax can alternatively be viewed as an argmax regularized with Ψ.

Limit cases and smoothness. Let T* be the solution of quadratically-regularized OT (without cardinality constraints). If k ≥ ‖t*_j‖_0 for all j ∈ [n], then the constraint ‖t_j‖_0 ≤ k in (8) is vacuous, and our formulation therefore recovers the quadratically-regularized one. Since Ω is strongly convex in this case, both conjugates Ω*_+ and Ω*_b are smooth (i.e., have Lipschitz gradients), thanks to the duality between strong convexity and smoothness (Hiriart-Urruty & Lemaréchal, 1993). On the other extreme, when k = 1, we have the following (a proof is given in Appendix B.7).

Proposition 4. Limit cases

With Ω defined in (9) and k = 1, we have for all s ∈ R^m:

Ω*_+(s) = (1/(2γ)) max_{i∈[m]} [s_i]²_+   and   Ω*_b(s) = b max_{i∈[m]} s_i − (γ/2)b².

We then have, for all a ∈ △^m, b ∈ △^n and C ∈ R^{m×n}_+,

D_Ω(a, b, C) = S_Ω(a, b, C) = S_0(a, b, C) + (γ/2)‖b‖²_2 = P_0(a, b, C) + (γ/2)‖b‖²_2.

When m < n, it is infeasible to satisfy both the marginal constraints and the 1-sparsity constraints. Proposition 4 shows that our (semi) dual formulations reduce to unregularized OT (up to a constant) in this "degenerate" case. As illustrated in Figure 2, the conjugates Ω*_+ and Ω*_b become increasingly smooth as k increases. We therefore interpolate between unregularized OT (k small enough) and quadratically-regularized OT (k large enough), with the dual and semi-dual being increasingly smooth as k increases.

6. EXPERIMENTS

6.1. SOLVER AND OBJECTIVE COMPARISON

We compared two solvers, LBFGS (Liu & Nocedal, 1989) and ADAM (Kingma & Ba, 2015), for maximizing the dual and semi-dual objectives of our sparsity-constrained OT. Results are provided in Figure 3. Compared to ADAM, LBFGS is the more convenient option, as it does not require tuning a learning-rate hyperparameter. In addition, LBFGS empirically converges faster than ADAM in the number of iterations (first row of Figure 3). That being said, when a proper learning rate is chosen, we find that ADAM converges as fast as or faster than LBFGS in wall-clock time (second row of Figure 3). In addition, Figure 3 shows that the dual and semi-dual objectives are very close to each other toward the end of the optimization process. This empirically confirms Proposition 3, which states that the dual and the semi-dual are equal at their optimum. We have seen that a greater k leads to a smoother objective landscape (Figure 2). It is known that a smoother objective theoretically allows faster convergence. We validate this empirically in Appendix A.2, where we observe that a greater k leads to faster convergence.

6.2. SPARSE MIXTURES OF EXPERTS

We applied sparsity-constrained OT to vision sparse mixture-of-experts (V-MoE) models for large-scale image recognition (Riquelme et al., 2021). A V-MoE model replaces a few dense feedforward layers MLP : x ∈ R^d ↦ MLP(x) ∈ R^d in Vision Transformer (ViT) (Dosovitskiy et al., 2021) with sparsely-gated mixture-of-experts layers

MoE(x) := Σ_{r=1}^n Gate_r(x) · MLP_r(x),   (14)

where Gate_r : R^d → R_+ is a sparse gating function and the feedforward layers {MLP_r}_{r=1}^n are the experts. In practice, only the experts MLP_r(·) corresponding to a nonzero gate value Gate_r(x) are computed; in this case, we say that the token x is routed to expert r. Given a minibatch of m tokens {x_1, …, x_m}, we apply our sparsity-constrained OT to match tokens with experts, so that the number of tokens routed to any expert is bounded. Following Clark et al. (2022), we backpropagate the gradient only through the combining weights Gate_r(x), not through the OT algorithm (details in Appendix A.5), as this strategy accelerates the backward pass of V-MoEs. Using this routing strategy, we train the B/32 and B/16 variants of the V-MoE model, i.e., the "Base" variants of V-MoE with 32 × 32 and 16 × 16 patches, respectively. Hyperparameters of these architectures are provided in Riquelme et al. (2021, Appendix B.5). We train on the JFT-300M dataset (Sun et al., 2017), a large-scale dataset containing more than 305 million images. We then perform 10-shot transfer learning on the ImageNet dataset (Deng et al., 2009). Additional V-MoE and experimental details are given in Appendix A.5. Table 2 summarizes the validation accuracy on JFT-300M and the 10-shot accuracy on ImageNet. Compared to baseline methods, our sparsity-constrained approach yields the highest accuracy with both architectures on both benchmarks.
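The sparse combination in the MoE layer, where only experts with a nonzero gate are evaluated, can be sketched as follows (a toy sketch with fixed linear maps standing in for the expert MLPs; all names are ours):

```python
import numpy as np

def moe_forward(x, gates, experts):
    """Combine expert outputs; only experts with a nonzero gate run."""
    out = np.zeros_like(x)
    for r, g in enumerate(gates):
        if g > 0:  # sparsity: unselected experts are skipped entirely
            out += g * experts[r](x)
    return out

# Toy example: 3 "experts" as fixed linear maps, and a 1-sparse gate.
d = 4
rng = np.random.default_rng(0)
Ws = [rng.standard_normal((d, d)) for _ in range(3)]
experts = [lambda x, W=W: W @ x for W in Ws]
x = rng.standard_normal(d)
gates = np.array([0.0, 1.0, 0.0])  # only expert 1 is active
y = moe_forward(x, gates, experts)
```

The computational saving comes precisely from the skipped experts: the cost of processing a token is proportional to its number of nonzero gates.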

7. CONCLUSION

We presented a dual and semi-dual framework for OT with general nonconvex regularization. We applied this framework to obtain a tractable lower bound to approximately solve an OT formulation with cardinality constraints on the columns of the transportation plan. We showed that this framework is formally equivalent to using squared k-support norm regularization in the primal. Moreover, it interpolates between unregularized OT (recovered when k is small enough) and quadratically-regularized OT (recovered when k is large enough). The (semi) dual objectives were shown to be increasingly smooth as k increases, enabling the use of gradient-based algorithms such as LBFGS or ADAM. We illustrated our framework on a variety of tasks; see §6 and Appendix A. For the training of mixture-of-experts models in large-scale computer vision tasks, we showed that direct control of sparsity improves accuracy, compared to top-k and Sinkhorn baselines. Beyond empirical performance, sparsity constraints may lead to more interpretable transportation plans, and the integer-valued hyperparameter k may be easier to tune than the real-valued parameter γ.

A.2 SOLVER COMPARISON WITH AN INCREASED CARDINALITY

We have seen that an increased k increases the smoothness of the optimization problem (Figure 2). This suggests that solvers may converge faster with larger k. We show this empirically in Figure 6, where we measure the gradient norm at each iteration of the solver and compare the cases k = 2 and k = 4.

A.3 COLOR TRANSFER

We apply our sparsity-constrained formulation to the classical OT application of color transfer (Pitié et al., 2007). We follow exactly the same experimental setup as Blondel et al. (2018, Section 6). Figure 7 shows the results obtained with our sparsity-constrained approach. Similar to well-studied alternatives, ours yields visually pleasing results.

[Figure 7: color transfer results for the source and reference images, comparing squared ℓ2-norm regularization (γ = 0.1) with sparsity-constrained OT (k = 2); all plans are between 99.2% and 99.3% sparse.]

A.4 TRANSPORT WITH CAPACITY-CONSTRAINED SUPPLIERS

We follow Amos et al. (2022) to set up a synthetic transport problem between 100 supply locations and 10,000 demand locations worldwide. Transport costs are set to the spherical distance between the demand and supply locations. This transportation problem can be solved via entropy-regularized optimal transport as in Amos et al. (2022); we visualize the entropy-regularized transport plan in panel (a) of Figure 8. Building upon the setting of Amos et al. (2022), we additionally assume that each supplier has a limited supplying capacity, i.e., each supplier can transport goods to at most a prescribed number of locations. This constraint is conceivable, for instance, when suppliers operate with a limited workforce and cannot meet all requested orders. We incorporate this constraint into our sparsity-constrained optimal transport formulation by specifying k as the capacity limit. Panel (b) of Figure 8 shows the transportation plan obtained with a supplying capacity of k = 100 (each supplier can transport goods to at most 100 demand locations). Comparing panels (a) and (b) of Figure 8, we see that the derived plans differ visibly in a few ways. For instance, with the capacity constraint on suppliers, demand locations in Europe import goods from more supply locations in North America than without the capacity constraint. The same holds for demand locations in Pacific islands: without the capacity constraint, they mostly rely on suppliers in North America; with the capacity constraint, additional suppliers in South America are used.

Published as a conference paper at ICLR 2023

[Figure 8: The top panel shows the transportation plan without a capacity limit on supply: each supplying location can transport to as many demand locations as needed; it is derived from entropy-regularized optimal transport. The bottom panel shows the transportation plan with a capacity limit on supply: each supplying location can meet demands up to a fixed capacity; this plan is derived by sparsity-constrained optimal transport with k = 100.]

A.5 V-MOE EXPERIMENT

Our experiment is based on the vision MoE (V-MoE) architecture (Riquelme et al., 2021), which replaces a few MLP layers of the Vision Transformer (Dosovitskiy et al., 2021) by MoE layers. In this subsection, we review the background of V-MoE, with a focus on the router, which decides which experts get which input tokens. We first introduce some notation used throughout this subsection. Let {x_1, …, x_m} ⊂ R^d be a minibatch of m tokens in R^d, and let X ∈ R^{m×d} be the corresponding matrix whose rows are tokens. Let W ∈ R^{d×n} be a learnable matrix of expert weights, where each column of W is a learnable feature vector of an expert. Common to different routing mechanisms is a token-expert affinity matrix Π := XW ∈ R^{m×n}: its (i, j)-th entry is an inner-product similarity score between the i-th token and the j-th expert.

The TopK router. To route tokens to experts, the TopK router of Riquelme et al. (2021) computes a sparse gating matrix Γ that has at most κ nonzeros per row, through a function top_κ : R^n → R^n that sets all but the largest κ values to zero:

Γ := top_κ(softmax(Π + σε)) ∈ R^{m×n}, with Π = XW.   (15)

Note that the integer κ is not to be confused with the k used in the main text: κ here refers to the number of selected experts for each token, and it can in general differ from the cardinality constraint k used in the main text. The noise vector ε ∼ N(0, I) in (15) is injected into the token-expert affinity matrix XW, with σ ∈ R controlling the strength of the noise. In practice, σ is set to 1/n during training and to 0 at inference. To ensure that all experts are sufficiently trained, the gating matrix Γ in (15) is regularized by auxiliary losses that encourage experts to take a similar number of tokens in a minibatch. A detailed description of these auxiliary losses is presented in Riquelme et al. (2021, Section A). For efficient hardware utilization, Riquelme et al.
(2021) allocate a buffer capacity to experts, which specifies the maximum number of tokens each expert can process in a minibatch. With a specified buffer capacity and a computed gating matrix Γ, the TopK router goes over the rows of Γ and assigns each token to its top-chosen expert as long as the chosen expert's capacity is not full. This procedure is described in Algorithm 1 of Riquelme et al. (2021, Section C.1). Finally, the outputs of the experts are linearly combined using the gating matrix Γ as in (14).

The S-BASE router. Clark et al. (2022) cast the token-expert matching problem as an entropy-regularized OT problem, solved using the Sinkhorn algorithm. This approach, dubbed the Sinkhorn-BASE (S-BASE) router, was originally designed for language MoEs that take text as input. In this work, we adapt it to vision MoEs. In direct parallel to the TopK gating matrix in (15), the gating matrix of entropy-regularized OT is set to

Γ_ent := top_κ(Π_ent) ∈ R^{m×n}, (16)

where

Π_ent := argmin_{T ∈ U(a,b)} ⟨T, −Π⟩ + ⟨T, log T⟩, with a = 1_m and b = (m/n)1_n. (17)

The optimal plan Π_ent in (17) can be obtained using the Sinkhorn algorithm (Sinkhorn & Knopp, 1967). Note that while we formulated optimal transport problems with non-negative cost matrices in the main text, the values in the cost matrix C = −Π in (17) can be both positive and negative, following Clark et al. (2022). Since Π_ent is a dense matrix, a heuristic is needed to select only κ experts to form the gating matrix Γ_ent; this is achieved by the top_κ in (16). With a computed gating matrix Γ_ent, the S-BASE router assigns each token to its top-chosen expert in the same way as the TopK router. This process allocates to each expert an amount of tokens up to the upper bound specified by the buffer capacity, as in the case of TopK. As in Clark et al. (2022), we linearly combine the outputs of the experts using a softmax matrix softmax(Π).
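For illustration, the Sinkhorn iterations used to compute a plan such as Π_ent can be sketched as follows. This is a minimal textbook version with a fixed number of iterations and a generic cost matrix, not the S-BASE implementation:

```python
import numpy as np

def sinkhorn_plan(C, a, b, gamma=1.0, num_iters=1000):
    """Entropy-regularized OT plan via Sinkhorn: T = diag(u) K diag(v),
    with K = exp(-C / gamma) and u, v from alternating marginal updates."""
    K = np.exp(-C / gamma)
    u = np.ones_like(a)
    for _ in range(num_iters):
        v = b / (K.T @ u)                 # match column marginals
        u = a / (K @ v)                   # match row marginals
    return u[:, None] * K * v[None, :]
```

The resulting plan is fully dense (all entries strictly positive), which is exactly why a top_κ heuristic is needed to extract a sparse gating matrix from it.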
In this way, the backward pass of gradient-based training does not go through the Sinkhorn algorithm, and can be faster and more numerically stable.¹

The sparsity-constrained router. We cast the token-expert matching problem as a sparsity-constrained OT problem. With a prescribed buffer capacity k, our goal is to upper-bound the number of tokens assigned to each expert by k. This amounts to adding a cardinality constraint to each column of the gating matrix:

Γ_sparse := argmin_{T ∈ U(a,b), T ∈ B_k × ··· × B_k} ⟨T, −Π_softmax⟩ + ½‖T‖²₂, with a = 1_m and b = (m/n)1_n, (18)

where Π_softmax = softmax(XW). The purpose of the softmax function here is to obtain a cost matrix whose values all have the same sign. Otherwise, if a cost matrix contains both positive and negative values, the plan obtained from sparsity-constrained optimal transport may be zero at all entries corresponding to positive values in the cost matrix, so as to minimize the objective. In that case, the columns of the transportation plan may contain far fewer than k nonzeros, an undesirable situation that under-uses the buffer capacity of experts. Note, however, that this was not an issue in the S-BASE router, whose cost matrix can contain both positive and negative values (Clark et al., 2022), because the values of the transportation plan yielded by the Sinkhorn algorithm are strictly positive. The sparse transportation plan Γ_sparse in (18) allocates to each expert an amount of tokens up to k. As in the S-BASE router, we linearly combine the outputs of the experts using the matrix Π_softmax. To approximate Γ_sparse, we optimize its semi-dual proxy as introduced in Section 4. We do so using the ADAM optimizer with a learning rate of 10^-2 for 50 steps.

V-MoE architecture. We use the S-BASE router and our proposed sparsity-constrained router as drop-in replacements for the TopK router in otherwise standard V-MoE architectures (Riquelme et al., 2021).
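A minimal sketch of semi-dual optimization for the sparsity-constrained problem follows. The paper optimizes the semi-dual with ADAM; here, as an illustration only, we use plain gradient ascent, and `topk_simplex_map` implements the map [topk(s) − τ]_+ (with τ chosen so the entries sum to b) that recovers each column of the plan from the dual variable:

```python
import numpy as np

def topk_simplex_map(s, k, b):
    """Compute [topk(s) - tau]_+ with tau chosen so the entries sum to b:
    the k largest entries of s are projected onto the simplex scaled by b."""
    idx = np.argsort(s)[-k:]                     # indices of the k largest entries
    u = np.sort(s[idx])[::-1]                    # those entries, descending
    css = np.cumsum(u) - b
    rho = np.nonzero(u - css / (np.arange(k) + 1.0) > 0)[0][-1]
    tau = css[rho] / (rho + 1.0)                 # simplex-projection threshold
    t = np.zeros_like(s)
    t[idx] = np.maximum(s[idx] - tau, 0.0)
    return t

def semidual_ascent(C, a, b, k, lr=0.1, steps=200):
    """Plain gradient ascent on the semi-dual (illustrative stand-in for ADAM);
    the plan is recovered column-by-column from the dual variable alpha."""
    m, n = C.shape
    alpha = np.zeros(m)
    T = np.zeros_like(C)
    for _ in range(steps):
        T = np.stack([topk_simplex_map(alpha - C[:, j], k, b[j]) for j in range(n)],
                     axis=1)
        alpha += lr * (a - T.sum(axis=1))        # gradient of the semi-dual
    return T
```

By construction, every column of the returned plan has at most k nonzeros and sums exactly to b_j; only the row marginals depend on how well the dual has converged.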
We focus on the V-MoE B/32 and B/16 architectures, which use 32 × 32 and 16 × 16 patches, respectively. We place MoEs on every other layer, which is the Every-2 variant in Riquelme et al. (2021). We fix the total number of experts to n = 32 for all experiments. In the TopK and S-BASE routers, we assign 2 experts to each token, that is, κ = 2 in (15) and (16). The buffer capacity is set to n/κ = 32/2 = 16, that is, each expert can take at most 16 tokens. To match this setting, we use k = 16 in (18) for our sparsity-constrained router.

Upstream training and evaluation. We follow the same training strategy as Riquelme et al. (2021) to train B/32 and B/16 models on JFT-300M, with hyperparameters reported in Riquelme et al. (2021, Table 8). JFT-300M has around 305M training and 50,000 validation images. Since the labels of JFT-300M are organized hierarchically, an image may be associated with multiple labels. We report model performance by precision@1, checking whether the predicted class with the highest probability is one of the true labels of the image.

Downstream transfer to ImageNet. For downstream evaluations, we perform 10-shot linear transfer on ImageNet (Deng et al., 2009). Specifically, with a JFT-trained V-MoE model, we freeze the model up to its penultimate layer, re-initialize its last layer, and train the last layer on ImageNet.

Comparing the speed of routers. We note that the sparsity-constrained router is slightly slower than the baseline routers. One reason is the topk function used in the k-sparse projection steps. To further speed up the sparsity-constrained router, an option is to use an approximated version of topk (Chern et al., 2022). Another way to accelerate the sparsity-constrained router is to explore different optimizers. Currently, we run the ADAM optimizer for 50 steps with a learning rate of 10^-2. We suspect that with a more careful tuning of the optimizer, one could reduce the number of steps without harming the performance. Variants of accelerated gradient-based methods (An et al., 2022) may also be applicable.

Table 5: Steps of the EM-like algorithm used to estimate the centers of the clusters with each method. For the sparsity-constrained method, the E-step solves (4), (5) with Ω(t) = ½‖t‖²₂ + δ_{B_k}(t). In all cases, the cost matrix C^(s) is the squared distance between each data point and the current estimate of the centers at a given step s, i.e., [C^(s)]_{i,j} = ‖x_i − µ_j^(s)‖²₂. The vector e_l ∈ R^n is a canonical basis vector with the l-th entry being 1 and all other entries being 0.

A.6 SOFT BALANCED CLUSTERING

OT viewpoint. Suppose we want to cluster m data points x_1, ..., x_m ∈ R^d into n clusters with centroids µ_1, ..., µ_n ∈ R^d. We let X ∈ R^{d×m} be a matrix that contains the data points x_1, ..., x_m as columns. Similarly, we let µ ∈ R^{d×n} be a matrix of centroids. Adding regularization moves the optimal plan away from the vertices of the polytope. This corresponds to a "soft" balanced K-Means, in which we replace "hard" cluster memberships with "soft" ones. We can again alternate between minimization w.r.t. T (solving a regularized OT problem) and minimization w.r.t. µ. In the case of the squared Euclidean distance, the closed-form solution for the latter is µ_j ∝ Σ_{i=1}^m t_{i,j} x_i for all j ∈ [n]. When Ω is nonconvex, we propose to solve the (semi) dual as discussed in the main text.

Results on MNIST. MNIST contains grayscale images of handwritten digits, with a resolution of 28 × 28 pixels. The dataset is split into 60,000 training and 10,000 test images.

B.1 WEAK DUALITY (PROPOSITION 1)

Using the inequality min_u max_v f(u, v) ≥ max_v min_u f(u, v) twice, we have

P_Ω(a, b, C) = min_{T∈R^{m×n}_+} max_{α∈R^m} max_{β∈R^n} L(T, α, β)
  ≥ max_{α∈R^m} min_{T∈R^{m×n}_+} max_{β∈R^n} L(T, α, β)
  ≥ max_{α∈R^m} max_{β∈R^n} min_{T∈R^{m×n}_+} L(T, α, β).

For the first inequality, we have

max_{α∈R^m} min_{T∈R^{m×n}_+} max_{β∈R^n} L(T, α, β)
  = max_{α∈R^m} ⟨α, a⟩ + min_{T∈R^{m×n}_+} max_{β∈R^n} ⟨β, b⟩ + Σ_{j=1}^n ⟨t_j, c_j − α − β_j 1_m⟩ + Ω(t_j)
  = max_{α∈R^m} ⟨α, a⟩ + Σ_{j=1}^n min_{t_j ∈ b_j Δ^m} ⟨t_j, c_j − α⟩ + Ω(t_j)
  = max_{α∈R^m} ⟨α, a⟩ − Σ_{j=1}^n Ω*_{b_j}(α − c_j)
  = S_Ω(a, b, C).

For the second inequality, we have

max_{α∈R^m} max_{β∈R^n} min_{T∈R^{m×n}_+} L(T, α, β)
  = max_{α∈R^m, β∈R^n} ⟨α, a⟩ + ⟨β, b⟩ + Σ_{j=1}^n min_{t_j∈R^m_+} ⟨t_j, c_j − α − β_j 1_m⟩ + Ω(t_j)
  = max_{α∈R^m, β∈R^n} ⟨α, a⟩ + ⟨β, b⟩ − Σ_{j=1}^n Ω*_+(α + β_j 1_m − c_j)
  = D_Ω(a, b, C).

To summarize, we showed P_Ω(a, b, C) ≥ S_Ω(a, b, C) ≥ D_Ω(a, b, C).

B.2 DUAL-PRIMAL LINK

When the solution of the maximum below is unique, t_j can be uniquely determined for j ∈ [n] from

t_j = ∇Ω*_+(α + β_j 1_m − c_j) = argmax_{t_j∈R^m_+} ⟨α + β_j 1_m − c_j, t_j⟩ − Ω(t_j), (19)

or

t_j = ∇Ω*_{b_j}(α − c_j) = argmax_{t_j∈b_jΔ^m} ⟨α − c_j, t_j⟩ − Ω(t_j).

See Table 1 for examples. When the maximum is not unique, t_j is jointly determined by

t_j ∈ ∂Ω*_+(α + β_j 1_m − c_j) = argmax_{t_j∈R^m_+} ⟨α + β_j 1_m − c_j, t_j⟩ − Ω(t_j)

or

t_j ∈ ∂Ω*_{b_j}(α − c_j) = argmax_{t_j∈b_jΔ^m} ⟨α − c_j, t_j⟩ − Ω(t_j),

where ∂ indicates the subdifferential, together with the primal feasibility T ∈ U(a, b), or more explicitly T ∈ R^{m×n}_+, T1_n = a and Tᵀ1_m = b. This also implies ⟨T, 1_m 1_nᵀ⟩ = 1.

Unregularized case.

When Ω = 0, for the dual, we have ∂Ω*_+(s_j) = argmax_{t_j∈R^m_+} ⟨s_j, t_j⟩. We note that the problem is coordinate-wise separable, with

argmax_{t_{i,j}∈R_+} t_{i,j} · s_{i,j} = ∅ if s_{i,j} > 0;  R_+ if s_{i,j} = 0;  {0} if s_{i,j} < 0.

With s_{i,j} = α_i + β_j − c_{i,j}, we therefore obtain t_{i,j} ≥ 0 if α_i + β_j = c_{i,j} and t_{i,j} = 0 if α_i + β_j < c_{i,j}, since s_{i,j} > 0 is dual infeasible. We can therefore use α and β to identify the support of T. The size of that support is at most m + n − 1 (Peyré & Cuturi, 2019, Proposition 3.4). Using the marginal constraints T1_n = a and Tᵀ1_m = b, we can therefore form a system of linear equations of size m + n to recover T.

Likewise, for the semi-dual, with s_j = α − c_j, we have

t_j ∈ ∂Ω*_{b_j}(s_j) = argmax_{t_j∈b_jΔ^m} ⟨s_j, t_j⟩ = conv({v_1, ..., v_{|S_j|}}), where S_j := {v_1, ..., v_{|S_j|}} = {b_j e_i : i ∈ argmax_{i∈[m]} s_{i,j}}.

Let us gather v_1, ..., v_{|S_j|} as the columns of a matrix V_j ∈ R^{m×|S_j|}. There exists w_j ∈ Δ^{|S_j|} such that t_j = V_j w_j. Using the primal feasibility, we can solve with respect to w_j for j ∈ [n]. This leads to a (potentially underdetermined) system of linear equations with Σ_{j=1}^n |S_j| unknowns and m + n equations.

Squared k-support norm. We now discuss Ω = Ψ, as defined in (13). When the maximum is unique (no ties), t_j is uniquely determined by (19). We now discuss the case of ties. For the dual, with s_j = α + β_j 1_m − c_j, we have

t_j ∈ ∂Ω*_+(s_j) = argmax_{t_j∈R^m_+} ⟨s_j, t_j⟩ − Ω(t_j) = conv({v_1, ..., v_{|S_j|}}) := conv({[u_j]_+ : u_j ∈ S_j}),

where S_j := topk(s_j) is the set containing all possible top-k vectors (a singleton if there are no ties in s_j, meaning that there is only one possible top-k vector). For the semi-dual, with s_j = α − c_j, we have

t_j ∈ ∂Ω*_{b_j}(s_j) = argmax_{t_j∈b_jΔ^m} ⟨s_j, t_j⟩ − Ω(t_j) = conv({v_1, ..., v_{|S_j|}}) := conv({[u_j − τ_j]_+ : u_j ∈ S_j}),

where τ_j is such that Σ_{i=1}^k [s_{[i],j} − τ_j]_+ = b_j.
Again, we can combine these conditions with the primal feasibility T ∈ U(a, b) to obtain a system of linear equations. Unfortunately, in the case of ties, ensuring that T ∈ U(a, b) by solving this system may cause t_j ∉ B_k. Another situation causing t_j ∉ B_k is when k is set to a smaller value than the maximum number of nonzero elements in the columns of the primal LP solution.
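The support-identification rule for the unregularized case above (t_{i,j} can be nonzero only where α_i + β_j = c_{i,j}) can be sketched with a toy helper; the function name and the example values are ours, assuming dual-feasible α and β:

```python
import numpy as np

def support_from_duals(C, alpha, beta, tol=1e-8):
    """Candidate support of T in the unregularized case: entries (i, j)
    where the dual constraint is tight, alpha_i + beta_j = c_ij."""
    return np.abs(alpha[:, None] + beta[None, :] - C) <= tol
```

Combined with the marginal constraints T1_n = a and Tᵀ1_m = b, this Boolean mask determines the linear system from which T can be recovered.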

B.3 PRIMAL INTERPRETATION (PROPOSITION 2)

For the semi-dual, we have

S_Ω(a, b, C) = max_{α∈R^m} ⟨α, a⟩ − Σ_{j=1}^n Ω*_{b_j}(α − c_j)
  = max_{α∈R^m, µ_j : µ_j = α − c_j} ⟨α, a⟩ − Σ_{j=1}^n Ω*_{b_j}(µ_j)
  = max_{α∈R^m, µ_j∈R^m} min_{T∈R^{m×n}} ⟨α, a⟩ − Σ_{j=1}^n Ω*_{b_j}(µ_j) + Σ_{j=1}^n ⟨t_j, µ_j − α + c_j⟩
  = min_{T∈R^{m×n}} ⟨T, C⟩ + max_{α∈R^m} ⟨α, a − T1_n⟩ + Σ_{j=1}^n max_{µ_j∈R^m} ⟨t_j, µ_j⟩ − Ω*_{b_j}(µ_j)
  = min_{T∈R^{m×n} : T1_n = a} ⟨T, C⟩ + Σ_{j=1}^n Ω**_{b_j}(t_j),

where we used that strong duality holds, since the conjugate is always convex, even if Ω is not. Likewise, for the dual, we have

D_Ω(a, b, C) = max_{α∈R^m, β∈R^n} ⟨α, a⟩ + ⟨β, b⟩ − Σ_{j=1}^n Ω*_+(α + β_j 1_m − c_j)
  = max_{α, β, µ_j : µ_j = α + β_j 1_m − c_j} ⟨α, a⟩ + ⟨β, b⟩ − Σ_{j=1}^n Ω*_+(µ_j)
  = max_{α∈R^m, β∈R^n, µ_j∈R^m} min_{T∈R^{m×n}} ⟨α, a⟩ + ⟨β, b⟩ − Σ_{j=1}^n Ω*_+(µ_j) + Σ_{j=1}^n ⟨t_j, µ_j − α − β_j 1_m + c_j⟩
  = min_{T∈R^{m×n}} ⟨T, C⟩ + max_{α∈R^m} ⟨α, a − T1_n⟩ + max_{β∈R^n} ⟨β, b − Tᵀ1_m⟩ + Σ_{j=1}^n max_{µ_j∈R^m} ⟨t_j, µ_j⟩ − Ω*_+(µ_j)
  = min_{T∈R^{m×n} : T1_n = a, Tᵀ1_m = b} ⟨T, C⟩ + Σ_{j=1}^n Ω**_+(t_j).

B.4 CLOSED-FORM EXPRESSIONS (TABLE 1)

The expressions for the unregularized, negentropy and quadratic cases are provided in (Blondel et al., 2018, Table 1). We therefore focus on the top-k case. Plugging (11) back into ⟨s, t⟩ − ½‖t‖²₂, we obtain

Ω*_+(s) = Σ_{i=1}^k [s_[i]]_+ s_[i] − ½[s_[i]]²_+ = Σ_{i=1}^k [s_[i]]²_+ − ½[s_[i]]²_+ = ½ Σ_{i=1}^k [s_[i]]²_+.

Plugging (12) back into ⟨s, t⟩ − ½‖t‖²₂, we obtain

Ω*_b(s) = Σ_{i=1}^k [s_[i] − τ]_+ s_[i] − ½ Σ_{i=1}^k [s_[i] − τ]²_+
  = Σ_{i=1}^k 1{s_[i] ≥ τ}(s_[i] − τ)s_[i] − ½ Σ_{i=1}^k 1{s_[i] ≥ τ}(s_[i] − τ)²
  = ½ Σ_{i=1}^k 1{s_[i] ≥ τ}(s²_[i] − τ²).

Conjugate of Ψ. The conjugate of Ψ for all s ∈ R^m is the squared k-support dual norm Ψ*(s) = ½ Σ_{i=1}^k |s|²_[i].

Proof. This result was proved in previous works (Argyriou et al., 2012; McDonald et al., 2016). We include an alternative proof for completeness (see the lemmas below). Using Ψ(t) = ½ψ*(2t) gives the desired result.

Derivation of Ω*_b and Ω**_b. Recall that Ω*_b(s) := max_{t∈bΔ^m ∩ B_k} ⟨s, t⟩ − ½‖t‖²₂. Using Lemma 1, Lemma 2 and (12), we obtain

Ω*_b(s) = min_{τ∈R, ξ∈R^m} τb + Ψ*(ξ) s.t. ξ ≥ s − τ1_m.

Note that ξ is not constrained to be non-negative because it is squared in the objective.

Proof of proposition. We recall that Ψ is defined in (13). From Proposition 2, we have D_Ω(a, b, C) = S_Ω(a, b, C) = P_Ψ(a, b, C). From Proposition 1 (weak duality), we have P_Ψ(a, b, C) ≤ P_Ω(a, b, C). Assuming no ties in α + β_j 1_m − c_j or in α − c_j for all j ∈ [n], we know that t_j ∈ B_k for all j ∈ [n]. Furthermore, from (13), we have for all t_j ∈ B_k that Ω(t_j) = Ψ(t_j) = ½‖t_j‖²₂. Therefore, without any ties, we have P_Ψ(a, b, C) = P_Ω(a, b, C).
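The closed form Ω*_+(s) = ½ Σ_{i=1}^k [s_[i]]²_+ can be checked numerically against a brute-force maximization over all supports of size k; this is illustrative code of ours, not from the paper:

```python
import numpy as np
from itertools import combinations

def omega_star_plus_closed_form(s, k):
    """Closed form: 0.5 times the sum of the k largest entries of [s]_+ squared."""
    top = np.sort(np.maximum(s, 0.0))[-k:]
    return 0.5 * np.sum(top ** 2)

def omega_star_plus_brute_force(s, k):
    """max of <s, t> - 0.5 * ||t||^2 over t >= 0 with at most k nonzeros;
    on a fixed support the coordinate-wise maximizer is t_i = [s_i]_+."""
    m = len(s)
    best = 0.0
    for supp in combinations(range(m), k):
        t = np.zeros(m)
        t[list(supp)] = np.maximum(s[list(supp)], 0.0)
        best = max(best, s @ t - 0.5 * t @ t)
    return best
```

The two functions agree because the best size-k support simply collects the k largest positive entries of s, each contributing s_i² − ½s_i² = ½s_i².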

B.7 LIMIT CASES (PROPOSITION 4)

In the limit case k = 1 with Ω defined in (9), we have

Ω*_+(s) = max_{t∈R^m_+ ∩ B_1} ⟨s, t⟩ − (γ/2)‖t‖²₂ = max_{i∈[m]} max_{t∈R_+} s_i t − (γ/2)t² = (1/(2γ)) max_{i∈[m]} [s_i]²_+.

Likewise, we have

Ω*_b(s) = max_{t∈bΔ^m ∩ B_1} ⟨s, t⟩ − (γ/2)‖t‖²₂ = b max_{i∈[m]} s_i − (γ/2)b².

We therefore get

D_Ω(a, b, C) = max_{α∈R^m, β∈R^n} ⟨α, a⟩ + ⟨β, b⟩ − (1/(2γ)) Σ_{j=1}^n max_{i∈[m]} [α_i + β_j − c_{i,j}]²_+.

From Proposition 3, we have D_Ω = S_Ω = S_0 + (γ/2)‖b‖²₂ = P_0 + (γ/2)‖b‖²₂.
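The k = 1 closed form for Ω*_+ can be sanity-checked numerically; this is an illustrative sketch of ours, and the grid bound of 10 is an arbitrary assumption that covers the maximizer t = s_i/γ for the tested inputs:

```python
import numpy as np

def omega_star_plus_k1(s, gamma=1.0):
    """Closed form for k = 1: Omega*_+(s) = max_i [s_i]_+^2 / (2 * gamma)."""
    return np.max(np.maximum(s, 0.0)) ** 2 / (2.0 * gamma)

def omega_star_plus_k1_brute(s, gamma=1.0, grid=2001):
    """Brute force over a grid: max over i and t >= 0 of s_i*t - (gamma/2)*t^2."""
    ts = np.linspace(0.0, 10.0, grid)
    return max(np.max(si * ts - 0.5 * gamma * ts ** 2) for si in s)
```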



¹ Personal communications with the authors of Clark et al. (2022).



Figure1: OT formulation comparison (m = n = 20 points), with squared Euclidean distance cost, and with uniform source and target distributions. The unregularized OT plan is maximally sparse and contains at most m+n-1 nonzero elements. On the contrary, with entropy-regularized OT, plans are always fully dense, meaning that all points are fractionally matched with one another (nonzeros of a transportation plan are indicated by small squares). Squared 2-norm (quadratically) regularized OT preserves sparsity but the number of nonzero elements cannot be directly controlled. Our proposed sparsity-constrained OT allows us to set a maximum number of nonzeros k per column. It recovers unregularized OT in the limit case k = 1 (Proposition 4) and quadratically-regularized OT when k is large enough. It can be computed using solvers such as LBFGS or ADAM.

Here, C ∈ R^{m×n} is a cost matrix and U(a, b) := {T ∈ R^{m×n}_+ : T1_n = a, Tᵀ1_m = b} is the transportation polytope, which can be interpreted as the set of all joint probability distributions with marginals a and b. It includes the Birkhoff polytope as a special case when m = n and a = b = 1_m/m.
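As a small illustration of this definition, membership in U(a, b) can be checked numerically; the independent coupling a bᵀ always belongs to it (function name and example values are ours):

```python
import numpy as np

def in_transportation_polytope(T, a, b, tol=1e-9):
    """Check T in U(a, b): nonnegative entries, row sums a, column sums b."""
    return bool((T >= -tol).all()
                and np.allclose(T.sum(axis=1), a, atol=tol)
                and np.allclose(T.sum(axis=0), b, atol=tol))

# The independent coupling a b^T lies in U(a, b) for any marginals a, b.
a = np.array([0.5, 0.5])
b = np.array([0.25, 0.25, 0.5])
T_indep = np.outer(a, b)
```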

Figure 2: Increasing k increases smoothness. Let s = (s 1 , s 2 , 0, 1). We visualize Ω * + (s), defined in (10) and derived in Table 1, when varying s 1 and s 2 . It can be interpreted as a relaxation of the indicator function of the non-positive orthant. The conjugate Ω * b (s) (not shown) can be interpreted as a relaxed max operator, scaled by b. In both cases, the smoothness increases when k increases.

Figure 3: Solver comparison for the semi-dual and dual formulation of sparsity-constrained OT with k = 2 on different datasets. Columns correspond to datasets used in Figure 1, Figure 5, and Figure 7.

OT between Gaussian and bi-Gaussian distributions. Similar to Figure 4, we show transportation plans between a Gaussian source marginal and a mixture-of-two-Gaussians target marginal in Figure 5. We set the source distribution to P(Y = y) := φ(y; 16, 5)/c_Y, where c_Y := Σ_{y=0}^{31} φ(y; 16, 5) is the normalizing constant; we set the target distribution to P(Z = z) := (φ(z; 8, 5) + φ(z; 24, 5))/c_Z, where c_Z := Σ_{z=0}^{31} (φ(z; 8, 5) + φ(z; 24, 5)) is the normalizing constant. Apart from that, we use the same settings as Figure 4.
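The discretized marginals used here can be reproduced in a few lines; as a sketch, we take φ to be an unnormalized Gaussian density, which is all the normalization requires:

```python
import numpy as np

def phi(z, mean, scale):
    """Unnormalized Gaussian density (our assumption for phi)."""
    return np.exp(-0.5 * ((z - mean) / scale) ** 2)

grid = np.arange(32)
a = phi(grid, 16, 5) / phi(grid, 16, 5).sum()   # Gaussian source marginal
b_raw = phi(grid, 8, 5) + phi(grid, 24, 5)
b = b_raw / b_raw.sum()                         # bi-Gaussian target marginal
```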

Figure 5: OT between Gaussian and bi-Gaussian distributions.

Figure 6: Solvers converge faster with an increased k. We measure the gradient norm at each iteration of LBFGS applied to the semi-dual formulations (top row) and the dual formulations (bottom row) of different datasets. Since the gradient norm should go to zero, we see that the LBFGS solver converges faster with an increased k.

Figure 7: Result comparison on the color transfer task. The sparsity indicated below each image shows the percentage of nonzeros in the transportation plan. For a fair comparison, we use k = 2 for the sparsity-constrained formulation and the regularization weight γ = 0.1 for the squared 2-norm formulation, to produce comparably sparse transportation plans.

Figure 8: Plans obtained from the supply-demand transportation task. Blue lines show the transportation plan from the supply locations (yellow dots) to the demand locations (blue dots). The top panel shows the transportation plan without a capacity limit on supply: each supplying location can transport to as many demand locations as needed. This plan is derived from entropy-regularized optimal transport. The bottom panel shows the transportation plan with a capacity limit on supply: each supplying location can meet demands up to a fixed capacity. The plan in this case is derived by sparsity-constrained optimal transport with k = 100.

The K-Means algorithm can be viewed as an OT problem with only one marginal constraint, where [C]_{i,j} = ‖x_i − µ_j‖²₂ and a = 1_m/m. Lloyd's algorithm corresponds to alternating minimization w.r.t. T (updating centroid memberships) and w.r.t. µ (updating centroid positions). This viewpoint suggests two generalizations. The first one consists in using two marginal constraints:

min_{T∈R^{m×n}_+ : T1_n = a, Tᵀ1_m = b; µ∈R^{d×n}} ⟨T, C⟩ = min_{T∈U(a,b), µ∈R^{d×n}} ⟨T, C⟩.

This is useful in certain applications to impose a prescribed size on each cluster (e.g., b = 1_n/n) and is sometimes known as balanced or constrained K-Means (Ng, 2000). The second generalization consists in introducing convex regularization Ω:

min_{T∈U(a,b), µ∈R^{d×n}} ⟨T, C⟩ + Σ_{j=1}^n Ω(t_j).
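Under this viewpoint, one Lloyd iteration can be written as building the (hard) optimal plan under the single marginal constraint and then updating the centroids; this is an illustrative sketch with our own function name:

```python
import numpy as np

def lloyd_step(X, mu):
    """One Lloyd iteration viewed as OT with a single marginal constraint:
    the optimal T puts mass 1/m on each point's nearest centroid (E-step),
    then each centroid moves to the barycenter of its points (M-step)."""
    m = X.shape[0]
    C = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(-1)   # squared distances
    T = np.zeros_like(C)
    T[np.arange(m), C.argmin(axis=1)] = 1.0 / m           # row marginal a = 1/m
    w = T.sum(axis=0)                                     # mass per cluster
    new_mu = np.where(w[:, None] > 0,
                      (T.T @ X) / np.maximum(w, 1e-12)[:, None],
                      mu)                                  # keep empty clusters
    return new_mu, T
```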

USEFUL LEMMAS

Lemma 1 (conjugate of the squared k-support norm). Let us define the squared k-support norm for all t ∈ R^m by

Ψ(t) := ½ min_{λ} Σ_{i=1}^m t_i²/λ_i s.t. ⟨λ, 1⟩ = k, 0 < λ_i ≤ 1, ∀i ∈ [m].

Its conjugate, for all s ∈ R^m, is the squared k-support dual norm Ψ*(s) = ½ Σ_{i=1}^k |s|²_[i].

Proof. Let ψ(s) := Σ_{i=1}^k s²_[i] denote the sum of the k largest squared entries of s. By Lapin et al. (2015, Lemma 1), we have for all s ∈ R^m

ψ(s) = min_{v∈R, ξ∈R^m_+} kv + ⟨ξ, 1⟩ s.t. ξ_i ≥ s_i² − v, ∀i ∈ [m].

We introduce Lagrange multipliers λ ∈ R^m_+ for the inequality constraints but keep the non-negativity constraint on ξ explicit, giving the Lagrangian kv + ⟨ξ, 1⟩ + ⟨λ, s∘s − v1 − ξ⟩. Using strong duality, we have

ψ(s) = max_{λ∈R^m_+} Σ_{i=1}^m λ_i s_i² + min_{v∈R} v(k − ⟨λ, 1⟩) + min_{ξ∈R^m_+} ⟨ξ, 1 − λ⟩
     = max_{λ∈R^m_+} ⟨λ, s∘s⟩ s.t. ⟨λ, 1⟩ = k, λ_i ≤ 1 ∀i ∈ [m].

Lemma 2. For all s ∈ R^m and b > 0,

max_{t∈bΔ^m} ⟨s, t⟩ − Ψ(t) = min_{τ∈R, ξ∈R^m} Ψ*(ξ) + τb s.t. ξ ≥ s − τ1.

Proof sketch. Introducing a multiplier µ ∈ R^m_+ for the constraint ξ ≥ s − τ1, the right-hand side equals

min_{τ∈R, ξ∈R^m} max_{µ∈R^m_+} Ψ*(ξ) + τb + ⟨µ, s − τ1 − ξ⟩.

Exchanging min and max (strong duality), minimizing over τ forces ⟨µ, 1⟩ = b, and minimizing over ξ gives −max_{ξ} ⟨µ, ξ⟩ − Ψ*(ξ) = −Ψ**(µ) = −Ψ(µ). We are left with max_{µ∈bΔ^m} ⟨s, µ⟩ − Ψ(µ), which is the left-hand side.

D_Ω(a, b, C) = max_{α∈R^m, β∈R^n} ⟨α, a⟩ + ⟨β, b⟩ − Σ_{j=1}^n Ω*_+(α + β_j 1_m − c_j).

Here, τ is such that the entries [s_[i] − τ]_+ sum to b (§5), where s_[i] denotes the i-th largest entry of the vector s ∈ R^m. The top-k and top-1 expressions above assume no ties in s.

Performance of the V-MoE B/32 and B/16 architectures with different routers. The few-shot experiments are averaged over 5 different seeds (Appendix A.5).

This newly initialized layer is trained on 10 examples per ImageNet class (10-shot learning).

We did not use this approximated topk (Chern et al., 2022) in this study; it may be especially useful on large models like B/16, where the number of tokens is large.


ACKNOWLEDGMENTS

We thank Carlos Riquelme, Antoine Rolet and Vlad Niculae for feedback on a draft of this paper, as well as Aidan Clark and Diego de Las Casas for discussions on the Sinkhorn-Base router. We are also grateful to Basil Mustafa, Rodolphe Jenatton, André Susano Pinto and Neil Houlsby for feedback throughout the project regarding MoE experiments. We thank Ryoma Sato for a fruitful email exchange regarding strong duality and ties.

A EXPERIMENTAL DETAILS AND ADDITIONAL EXPERIMENTS

A.1 ILLUSTRATIONS

OT between 2D points. In Figure 1, we visualize the transportation plans between 2D points. These transportation plans are obtained from different formulations of optimal transport, whose properties we recall in Table 3. The details of this experiment are as follows. We draw 20 samples from a Gaussian distribution N([0 0]ᵀ, [1 0; 0 1]) as source points; we draw 20 samples from a different Gaussian distribution N([4 4]ᵀ, [1 −0.8; −0.6 1]) as target points. The cost matrix in R^{20×20} contains the Euclidean distances between source and target points. The source and target marginals a and b are both probability vectors filled with values 1/20. On the top row of Figure 1, blue lines linking source points and target points indicate nonzero values in the transportation plan obtained from each optimal transport formulation. These transportation plans are shown in the second row. Figure 1 confirms that, by varying k in our sparsity-constrained formulation, we control the columnwise sparsity of the transportation plan.

OT between Gaussian distributions. Let φ(z; m, s) denote an unnormalized Gaussian density with mean m and scale s, where z, m, s are all real scalars with s ≠ 0. The source distribution is set to P(Y = y) := φ(y; 10, 4)/c_Y with normalizing constant c_Y := Σ_{y=0}^{31} φ(y; 10, 4); the target distribution is set to P(Z = z) := φ(z; 16, 5)/c_Z with normalizing constant c_Z := Σ_{z=0}^{31} φ(z; 16, 5). The cost matrix C contains normalized squared Euclidean distances between source and target locations. By setting k = 2 in our sparsity-constrained OT formulation, we obtain a transportation plan that contains at most two nonzeros per column (right-most panel of Figure 4).

Table 6: Clustering results on MNIST using different algorithms. We report the average cost on the test split (the average distance between each test image and its cluster), and the Kullback-Leibler divergence between the empirical distribution of images per cluster and the expected one (a uniform distribution). We average the results over 20 runs and report confidence intervals at 95%.
The algorithm proposed in §5 achieves the most balanced clustering, and a comparable cost to other OT-based solutions.

As preprocessing, we simply put the pixel values in the range [−1, 1] and "flatten" the images to obtain vectors of 784 elements. We use the training set to estimate the centers of the clusters using different algorithms. We use an EM-like algorithm to estimate the cluster centers in all cases, as described in Table 5 (we perform 50 update steps). In particular, notice that only the E-step changes across the different algorithms, as described in Table 5. Since there are 10 digits, we use 10 clusters. We evaluate the performance on the test set. Since some of the algorithms produce a "soft" clustering (all except K-Means), represented by the matrix T, for each test image i we assign it to the cluster j with the largest value in T = (t_{i,j}). We measure the average cost (i.e., the average squared distance between each image and its selected cluster), and the KL divergence between the empirical distribution of images per cluster and the expected one (a uniform distribution). The centers are initialized from a normal distribution with a mean of 0 and a standard deviation of 10^-3. Algorithms employing an OT-based approach perform 500 iterations to find T, using either the Sinkhorn algorithm (for the negentropy method) or LBFGS (used by the rest of the OT-based methods). We use a sparsity constraint of k = 1.15 · m/n (recall that k is the maximum number of nonzeros per column). Notice that using k = m/n, and assuming that n is a divisor of m, would necessarily require that the number of nonzeros per row is 1; our minimization problem would then be equivalent to that of unregularized OT. Thus, we slightly soften the regularization. Table 6 shows the results of the experiment, averaged over 20 different random seeds. The best cost is achieved by the Soft K-Means algorithm, but the resulting clustering is quite unbalanced, as reported by the KL divergence metric.
On the other hand, all OT-based approaches achieve similar costs, but the algorithm based on §5 obtains a significantly more balanced clustering.
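The EM-like alternation used above can be sketched as follows. This is an illustrative simplification of ours: the E-step is entropy-regularized OT solved by Sinkhorn, and centers are initialized from data points, whereas the experiment initializes centers from N(0, 10^-3) and also considers other regularizers solved by LBFGS:

```python
import numpy as np

def soft_balanced_kmeans(X, n_clusters, gamma=1.0, n_steps=10, rng=None):
    """EM-like soft balanced clustering sketch. E-step: entropy-regularized OT
    with uniform marginals (Sinkhorn). M-step: mu_j proportional to the
    membership-weighted mean sum_i t_ij x_i."""
    rng = rng or np.random.default_rng(0)
    m, _ = X.shape
    a = np.full(m, 1.0 / m)
    b = np.full(n_clusters, 1.0 / n_clusters)
    mu = X[rng.choice(m, size=n_clusters, replace=False)].astype(float)
    T = np.full((m, n_clusters), 1.0 / (m * n_clusters))
    for _ in range(n_steps):
        C = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(-1)  # squared distances
        K = np.exp(-(C - C.min()) / gamma)
        u = np.ones(m)
        for _ in range(500):                                  # E-step: Sinkhorn
            v = b / (K.T @ u)
            u = a / (K @ v)
        T = u[:, None] * K * v[None, :]
        mu = (T.T @ X) / T.sum(axis=0)[:, None]               # M-step
    return mu, T
```

By construction, the soft memberships assign (approximately) equal total mass 1/n to every cluster, which is what drives the balanced behavior reported in Table 6.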

