EFFICIENT DISCRETE MULTI-MARGINAL OPTIMAL TRANSPORT REGULARIZATION

Abstract

Optimal transport has emerged as a powerful tool for a variety of problems in machine learning, and it is frequently used to enforce distributional constraints. In this context, existing methods commonly use either a pairwise Wasserstein metric or a barycenter-based approach when more than two distributions are involved. In this paper, we leverage multi-marginal optimal transport (MMOT), where we take advantage of a procedure that computes a generalized earth mover's distance as a sub-routine. We show that not only is our algorithm computationally more efficient than other barycenter-based distance methods, but it has the additional advantage that the gradients used for backpropagation can be computed efficiently during the forward pass itself, which leads to substantially faster model training. We provide technical details about this new regularization term and its properties, and we present experimental demonstrations of faster runtimes when compared to standard Wasserstein-style methods. Finally, on a range of experiments designed to assess effectiveness at enforcing fairness, we demonstrate that our method compares well with alternatives.

1. INTRODUCTION

The use of optimal transport (OT) is now prevalent in many problem settings, including information retrieval (Balikas et al., 2018; Yurochkin et al., 2019), image processing (Bonneel et al., 2014), statistical machine learning, and more recently, ethics and fairness research (Kwegyir-Aggrey et al., 2021; Lokhande et al., 2020a). OT is well-suited for tasks where dissimilarity between two or more probability distributions must be quantified; its success was made possible through dramatic improvements in algorithms (Cuturi, 2013; Solomon et al., 2015) that allow one to efficiently optimize commonly used functionals. In practice, OT is often used to estimate and minimize the distance between certain (data-derived) distributions, using an appropriately defined loss functional. When one seeks to operate on more than two distributions, however, newer constructions are necessary to effectively estimate distances and transports. To this end, a well-studied idea in the literature is the "barycenter," identified by minimizing the sum of pairwise distances between itself and all given distributions. The proxy distance over the d distributions is then defined as the sum of the distances to this barycenter.

Computing barycenters. Assuming that a suitably regularized form of the optimal transport loss is used, the pairwise distance calculation, by itself, can be efficient; in fact, in some cases, Sinkhorn iterations can be used (Cuturi, 2013). On the other hand, to minimize distances to the mean, most algorithms typically operate by repeatedly estimating the barycenter and the pairwise distances, either using a "coupling" strategy to push points toward the barycenter or, in other cases, summing over all pairwise distances. As the number of distributions grows, robustness issues can be exacerbated (Alvarez-Esteban et al., 2008) and the procedure becomes expensive (e.g., for 50 distributions over 50 bins).

A potential alternative.
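Before turning to that alternative, the barycentric proxy described above can be made concrete with a minimal sketch. For 1-D histograms on a shared grid with unit bin spacing, the EMD equals the L1 distance between their CDFs; the histograms and candidate barycenter below are illustrative toy values, not data from our experiments.

```python
def emd_1d(p, q):
    """1-D earth mover's distance between two histograms on a shared
    grid of unit-spaced bins: the L1 distance between their CDFs."""
    dist, cdf_gap = 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        dist += abs(cdf_gap)
    return dist

def barycentric_proxy(dists, bary):
    """Sum of pairwise EMDs from each input distribution to a candidate
    barycenter -- the proxy distance described above."""
    return sum(emd_1d(p, bary) for p in dists)

dists = [[0.5, 0.5, 0.0], [0.0, 0.5, 0.5]]
bary = [0.25, 0.5, 0.25]  # candidate barycenter (here, the bin-wise mean)
print(barycentric_proxy(dists, bary))  # total mass-moving cost to the candidate
```

Note that a full barycenter method must also re-estimate `bary` itself, which is the step that becomes expensive as the number of distributions grows.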
Multi-marginal optimal transport (MMOT) is closely related to the task above, but to some extent the two literatures have developed in parallel. In particular, MMOT focuses on identifying a joint distribution whose marginals are the input distributions over which we wish to measure dissimilarity. The definition naturally extends the two-marginal formulation, and recent work has explored a number of applications (Pass, 2015). But the MMOT computation can be quite difficult, and only very recently have practical algorithms been identified (Lin et al., 2022). Additionally, even if a suitable method for computing an analogous measure of distance were available, minimizing this distance to reduce dissimilarity (push distributions closer to each other) is practically hard if standard interior-point solvers are needed just to compute the distance itself.

Why and where is dissimilarity important? Enforcing distributions to be similar is a generic goal whenever one wishes some outcome of interest to be agnostic to particular groups within the input data. In applications where deep neural network models are trained, it is often a goal to enforce distributional similarity on model outputs. For example, in Jiang et al. (2020), the authors define fairness measures over the probability of the prediction given ground-truth labels. However, these methods are rarely extended to continuous measures over internal neural network activations, mainly due to the strong distributional assumptions needed (product of Gaussians) and the added algorithmic complexity of estimating the barycenter. These issues limit the application of these ideas to only the final outputs of neural network models, where the distribution is typically binomial or multinomial. MMOT solutions might be employed here, but suffer from similar computational limitations.
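To make the MMOT feasibility constraints concrete, the sketch below builds the independence coupling of three toy marginals, which is always a feasible point of the MMOT polytope, and verifies that its marginals recover the inputs. The hard part of MMOT, minimizing a cost over this polytope, is deliberately omitted; all names and values here are illustrative.

```python
import itertools

def marginal(joint, shape, axis):
    """Marginal of a joint mass function given as a dict mapping index
    tuples to mass values."""
    out = [0.0] * shape[axis]
    for idx, mass in joint.items():
        out[idx[axis]] += mass
    return out

p = [0.5, 0.5]
q = [0.25, 0.75]
r = [1.0, 0.0]
# The independence coupling pi[i,j,k] = p[i] * q[j] * r[k] satisfies all
# three marginal constraints, so the MMOT feasible set is never empty.
joint = {(i, j, k): p[i] * q[j] * r[k]
         for i, j, k in itertools.product(range(2), repeat=3)}
shape = (2, 2, 2)
print(marginal(joint, shape, 0))  # recovers p
```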

Contributions. (1) We identify a particular form of the discrete multi-marginal optimal transport problem that admits an extremely fast and numerically robust solution. Exploiting a recent extension of the classical Earth Mover's Distance (EMD) to a higher-dimensional Earth Mover's objective, we show that this construction is equivalent to the discrete MMOT problem with Monge costs. (2) We show that minimizing this global distributional measure harmonizes the input distributions, very similar in spirit to minimizing distances to barycenters (see Figure 1). (3) We prove theoretical properties of our scheme, and show that the gradient can be read off directly from a primal/dual algorithm, alleviating the need for the computationally intensive pairwise couplings required by barycenter approaches. (4) The direct availability of the gradient enables a specific neural network instantiation: with a particular scaffolding provided by differentiable histograms, we can operate directly on network activations (anywhere in the network) to compute/minimize the d-MMOT. We establish via experiments that computing gradients used in backpropagation is fast, due to rapid access to solutions of the dual linear program. We compare with barycenter-like approaches in several settings, including common fairness applications.
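As a small illustration of how a (sub)gradient can be read off during the forward computation, consider only the classical two-distribution 1-D EMD (this sketch is not our d-MMOT algorithm, just the simplest instance of the idea). Writing the distance as the sum of absolute CDF gaps, a subgradient with respect to each bin of the first histogram is the suffix sum of the gap signs, so it falls out of the same pass that computes the distance.

```python
def emd_and_grad(p, q):
    """1-D EMD between histograms p and q (unit bin spacing) together
    with a subgradient w.r.t. p, read off during the forward pass:
    with C_j = CDF_p(j) - CDF_q(j), EMD = sum_j |C_j|, and a
    subgradient is dEMD/dp_i = sum_{j >= i} sign(C_j)."""
    n = len(p)
    signs, cdf_gap, dist = [], 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        dist += abs(cdf_gap)
        signs.append(0.0 if cdf_gap == 0 else (1.0 if cdf_gap > 0 else -1.0))
    # Suffix sums of the signs give the whole gradient in one sweep,
    # with no extra optimization problem to solve.
    grad, running = [0.0] * n, 0.0
    for i in range(n - 1, -1, -1):
        running += signs[i]
        grad[i] = running
    return dist, grad

dist, grad = emd_and_grad([0.5, 0.5, 0.0], [0.0, 0.5, 0.5])
print(dist, grad)
```

A gradient-descent step on `p` along `-grad` then pushes the two histograms toward each other, which is the mechanism behind the harmonization behavior in Figure 1.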

2. RELATED WORK

Despite originating with (Monge, 1781), optimal transport continues to be an active area of research (Villani, 2009). The literature is vast, but we list a few key developments.

Early applications. Starting with (Peleg et al., 1989), the idea of shifting "mass" around within an image was used for comparing images to each other and was applied to image retrieval (Rubner et al., 2000), where the term "Earth Mover's Distance" (EMD) was introduced. EMD has since been widely used in computer vision: e.g., for image warping (Zhang et al., 2011), in supervised settings (Wang & Guibas, 2012), for matching point sets (Cabello et al., 2008), and in scenarios involving histogram comparisons (Ling & Okada, 2007; Wang & Guibas, 2012; Haker et al., 2004).

Modern machine learning. The continuous optimal transport problem (the Monge-Kantorovich problem) was originally presented in (Kantorovich, 1942; Kantorovitch, 1958). While the continuous problem has been studied intensively (Villani, 2021), widespread use of optimal transport within machine learning became possible due to (Cuturi, 2013), which showed that entropic regularization enables fast algorithms for EMD (two distributions with discrete support), and contributed to the success of Wasserstein distances in applications. Consequently, problems including autoencoders (Tolstikhin et al., 2018), GANs (Arjovsky et al., 2017), domain adaptation (Courty et al., 2016), word embeddings (Huang et al., 2016), and classification tasks (Frogner et al., 2015) have benefited from the use of optimal transport.



Figure 1: Starting and ending state of minimizing a multi-marginal OT distance. Each iteration minimizes the generalized Earth Mover's objective, and then updates each histogram in the direction provided by the gradient.

