EFFICIENT DISCRETE MULTI-MARGINAL OPTIMAL TRANSPORT REGULARIZATION

Abstract

Optimal transport has emerged as a powerful tool for a variety of problems in machine learning, where it is frequently used to enforce distributional constraints. In this context, existing methods often use a Wasserstein metric between two distributions, or else apply barycenter-based approaches when more than two distributions are considered. In this paper, we leverage multi-marginal optimal transport (MMOT), taking advantage of a procedure that computes a generalized earth mover's distance as a sub-routine. We show that not only is our algorithm computationally more efficient than other barycenter-based distance methods, but it has the additional advantage that the gradients used for backpropagation can be computed efficiently during the forward pass itself, leading to substantially faster model training. We provide technical details about this new regularization term and its properties, and we present experimental demonstrations of faster runtimes compared to standard Wasserstein-style methods. Finally, on a range of experiments designed to assess effectiveness at enforcing fairness, we demonstrate that our method compares favorably with alternatives.

1. INTRODUCTION

Optimal transport (OT) is now prevalent in many problem settings, including information retrieval (Balikas et al., 2018; Yurochkin et al., 2019), image processing (Bonneel et al., 2014), statistical machine learning, and more recently, ethics and fairness research (Kwegyir-Aggrey et al., 2021; Lokhande et al., 2020a). OT is well-suited to tasks where the dissimilarity between two or more probability distributions must be quantified; its success was made possible by dramatic improvements in algorithms (Cuturi, 2013; Solomon et al., 2015) that allow commonly used functionals to be optimized efficiently. In practice, OT is often used to estimate and minimize the distance between certain (data-derived) distributions, using an appropriately defined loss functional. When one seeks to operate on more than two distributions, however, newer constructions are necessary to effectively estimate distances and transports. To this end, a well-studied idea in the literature is the "barycenter," identified by minimizing the pairwise distances between itself and all other given distributions. The desired proxy distance is then defined as the sum of the distances to the barycenter.

Computing barycenters. Assuming that a suitably regularized form of the optimal transport loss is utilized, the pairwise distance calculation, by itself, can be efficient; in fact, in some cases, Sinkhorn iterations can be used (Cuturi, 2013). On the other hand, to minimize distances to the mean, most algorithms operate by repeatedly estimating the barycenter and those pairwise distances, and then either using a "coupling" strategy to push points toward the barycenter or summing over all pairwise distances. As the number of distributions grows, robustness issues can be exacerbated (Alvarez-Esteban et al., 2008) and the procedure becomes expensive (e.g., for 50 distributions, 50 bins).

A potential alternative.
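Before turning to that alternative, the pairwise Sinkhorn sub-routine mentioned above (Cuturi, 2013) can be sketched in a few lines. The snippet below is a generic illustration of entropy-regularized OT between two discrete histograms, not the procedure proposed in this paper; the function name and default parameters are our own choices.

```python
import numpy as np

def sinkhorn(a, b, C, eps=0.1, n_iters=200):
    """Entropy-regularized OT between histograms a and b with cost matrix C.

    Alternately rescales the Gibbs kernel K = exp(-C / eps) so that the
    resulting plan P matches the row marginals a and column marginals b.
    Returns the transport plan P and the cost <P, C>.
    """
    K = np.exp(-C / eps)                 # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)                # enforce column marginals
        u = a / (K @ v)                  # enforce row marginals
    P = u[:, None] * K * v[None, :]
    return P, np.sum(P * C)
```

Each iteration costs only two matrix-vector products, which is what makes the pairwise distance computation cheap relative to the repeated barycenter estimation described above.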
Multi-marginal optimal transport (MMOT) is related to the task described above, but to some extent the two literatures have developed in parallel. In particular, MMOT focuses on identifying a joint distribution whose marginals are the input distributions over which we wish to measure dissimilarity. The definition naturally extends the two-marginal formulation, and recent work has explored a number of applications (Pass, 2015). But the MMOT computation can be quite difficult, and only very recently have practical algorithms been identified.
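To make the definition concrete, a discrete MMOT instance with N marginals of n bins each can be written directly as a linear program over the joint tensor P, with one constraint per marginal entry. The brute-force sketch below (our own illustration, exponential in N, usable only at toy sizes) is not the efficient algorithm developed in this paper; it only spells out the problem being solved.

```python
import itertools
import numpy as np
from scipy.optimize import linprog

def mmot_lp(marginals, cost):
    """Solve a discrete MMOT instance as a linear program (toy sizes only).

    marginals: list of N histograms, each of length n.
    cost: array with N axes of length n; cost[i1, ..., iN] is the cost
          of placing mass at that joint bin.
    Returns the optimal joint distribution P and the MMOT value.
    """
    N, n = len(marginals), len(marginals[0])
    idx = list(itertools.product(range(n), repeat=N))
    c = np.array([cost[i] for i in idx])
    # Marginal constraints: for each distribution k and bin j, the mass of
    # all joint entries with i_k == j must equal marginals[k][j].
    A, b = [], []
    for k in range(N):
        for j in range(n):
            A.append([1.0 if i[k] == j else 0.0 for i in idx])
            b.append(marginals[k][j])
    res = linprog(c, A_eq=np.array(A), b_eq=np.array(b), bounds=(0, None))
    return res.x.reshape((n,) * N), res.fun
```

The number of variables is n^N, which is exactly why generic LP solvers do not scale and specialized MMOT algorithms are needed.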

