AUGMENTED SLICED WASSERSTEIN DISTANCES

Abstract

While theoretically appealing, the application of the Wasserstein distance to large-scale machine learning problems has been hampered by its prohibitive computational cost. The sliced Wasserstein distance and its variants improve the computational efficiency through random projection, yet they suffer from low projection efficiency because the majority of projections result in trivially small values. In this work, we propose a new family of distance metrics, called augmented sliced Wasserstein distances (ASWDs), constructed by first mapping samples to higher-dimensional hypersurfaces parameterized by neural networks. The ASWD is derived from a key observation that (random) linear projections of samples residing on these hypersurfaces translate to much more flexible nonlinear projections in the original sample space, so they can capture complex structures of the data distribution. We show that the hypersurfaces can be efficiently optimized by gradient ascent. We provide the condition under which the ASWD is a valid metric and show that this condition can be satisfied by an injective neural network architecture. Numerical results demonstrate that the ASWD significantly outperforms other Wasserstein variants on both synthetic and real-world problems.

1. INTRODUCTION

Comparing samples from two probability distributions is a fundamental problem in statistics and machine learning. Optimal transport (OT) theory (Villani, 2008) provides a powerful and flexible framework for comparing degenerate distributions by accounting for the metric of the underlying spaces. The Wasserstein distance, which arises from optimal transport theory, has become an increasingly popular choice in various machine learning domains ranging from generative models to transfer learning (Gulrajani et al., 2017; Arjovsky et al., 2017; Kolouri et al., 2019b; Lee et al., 2019; Cuturi & Doucet, 2014; Claici et al., 2018; Courty et al., 2016; Shen et al., 2018; Patrini et al., 2018). Despite its favorable properties, such as robustness to disjoint supports and numerical stability (Arjovsky et al., 2017), the Wasserstein distance suffers from high computational complexity, especially when the sample size is large. Moreover, the Wasserstein distance is itself the solution of an optimization problem, so it is non-trivial to integrate into an end-to-end training pipeline of deep neural networks unless the solver of that optimization problem can be made differentiable. Recent advances in computational optimal transport focus on alternative OT-based metrics that are computationally efficient and differentiable (Peyré & Cuturi, 2019). The Sinkhorn distance (Cuturi, 2013) and its variants (Altschuler et al., 2017; Dessein et al., 2018; Lin et al., 2019) introduce entropy regularization to smooth the optimal transport problem; as a result, iterative matrix-scaling algorithms can provide significantly faster solutions with improved sample complexity (Genevay et al., 2019). An alternative approach is to approximate the Wasserstein distance by slicing, i.e., linearly projecting, the distributions to be compared.
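To make the matrix-scaling idea concrete, the following is a minimal numpy sketch of the Sinkhorn iterations for entropy-regularized OT between two discrete histograms. This is an illustration, not any of the cited implementations; the function name, the regularization strength, and the toy cost matrix are our own choices.

```python
import numpy as np

def sinkhorn_distance(a, b, C, reg=0.1, n_iters=200):
    """Entropy-regularized OT between histograms a and b with cost matrix C,
    solved by Sinkhorn's iterative matrix-scaling algorithm (Cuturi, 2013)."""
    K = np.exp(-C / reg)                # Gibbs kernel from the cost matrix
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)               # scale columns to match marginal b
        u = a / (K @ v)                 # scale rows to match marginal a
    P = np.diag(u) @ K @ np.diag(v)     # regularized transport plan
    return np.sum(P * C)                # transport cost under the plan

# two histograms over 5 bins, squared-distance ground cost
a = np.ones(5) / 5
b = np.array([0.1, 0.1, 0.2, 0.3, 0.3])
x = np.arange(5, dtype=float)
C = (x[:, None] - x[None, :]) ** 2
d = sinkhorn_distance(a, b, C)
```

Because each iteration is a pair of elementwise divisions and matrix-vector products, the whole computation is differentiable and can be unrolled inside a training pipeline.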
The sliced Wasserstein distance (SWD) (Bonneel et al., 2015) is defined as the expected value of the Wasserstein distances between one-dimensional random projections of high-dimensional distributions. The SWD shares similar theoretical properties with the Wasserstein distance (Bonnotte, 2013) and is computationally efficient, since the Wasserstein distance in one-dimensional space has a closed-form solution based on sorting. Deshpande et al. (2019) extend the sliced Wasserstein distance to the max-sliced Wasserstein distance (max-SWD) by finding the single projection direction that maximizes the distance between projected samples. Nguyen et al. (2020) propose the distributional sliced Wasserstein distance (DSWD), which finds a distribution of projections that maximizes the expected distance over these projections. The subspace robust Wasserstein distance extends the idea of slicing to projecting distributions onto linear subspaces (Paty & Cuturi, 2019). However, the linear nature
of these projections usually leads to low projection efficiency of the resulting metrics in high-dimensional spaces (Deshpande et al., 2019; Liutkus et al., 2019; Kolouri et al., 2019a). More recently, there is growing interest in slice-based Wasserstein distances with nonlinear projections, and growing evidence that they improve projection efficiency, reducing the number of projections needed to capture the structure of the data distribution (Kolouri et al., 2019a; Nguyen et al., 2020). Kolouri et al. (2019a) extend the connection between the sliced Wasserstein distance and the Radon transform (Abraham et al., 2017) to define generalized sliced Wasserstein distances (GSWDs) by utilizing generalized Radon transforms (GRTs). It is shown in (Kolouri et al., 2019a) that the GSWD is a metric if and only if the adopted GRT is injective. Injective GRTs are also used to extend the DSWD to the distributional generalized sliced Wasserstein distance (DGSWD) (Nguyen et al., 2020). However, both the GSWD and the DGSWD are restricted to the limited class of injective GRTs, whose defining functions are circular functions or a finite number of harmonic polynomials with odd degrees (Kuchment, 2006; Ehrenpreis, 2003). The results reported in (Kolouri et al., 2019a; Nguyen et al., 2020) show impressive performance of the GSWD and the DGSWD, yet both require one to specify a particular defining function from the aforementioned limited class of candidates. The selection of the defining function is usually task-dependent and requires domain knowledge, and the impact of different defining functions on performance remains unclear. One variant of the GSWD (Kolouri et al., 2019a), the GSWD-NN, generates projections directly with neural network outputs to remove the limitations of slicing distributions with predefined GRTs.
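The baseline construction that all of these variants build on, the SWD estimated by Monte Carlo over random directions with the closed-form one-dimensional solution based on sorting, can be sketched as follows. This is a minimal numpy illustration assuming equal sample sizes; the names and parameters are ours.

```python
import numpy as np

def sliced_wasserstein(X, Y, n_projections=50, seed=0):
    """Monte Carlo estimate of the sliced 2-Wasserstein distance between two
    empirical distributions X, Y of shape (n, d): project onto random unit
    directions and solve each 1-d problem in closed form by sorting.
    Assumes X and Y contain the same number of samples."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    theta = rng.normal(size=(n_projections, d))        # random directions
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # unit sphere
    Xp, Yp = X @ theta.T, Y @ theta.T                  # (n, n_projections)
    # 1-d W2^2 pairs sorted projections; average over all slices
    sq = (np.sort(Xp, axis=0) - np.sort(Yp, axis=0)) ** 2
    return np.sqrt(sq.mean())

rng = np.random.default_rng(1)
X = rng.normal(0.0, 1.0, size=(500, 10))
Y = rng.normal(3.0, 1.0, size=(500, 10))   # mean-shifted Gaussian
swd = sliced_wasserstein(X, Y)
```

The sorting step costs O(n log n) per slice, which is what makes the SWD so much cheaper than solving the full d-dimensional OT problem.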
In the GSWD-NN, the number of projections equals the number of nodes in the neural network's output layer and is therefore fixed; a different neural network is needed whenever one wants to change the number of projections. No random projection is involved in the resulting GSWD-NN either, since the projection results are determined by the neural network's weights. Moreover, the GSWD-NN is a pseudo-metric, since it uses a vanilla neural network, rather than the Radon transform or GRTs, as its push-forward operator. Therefore, the GSWD-NN does not fit into the theoretical framework of the GSWD and does not inherit its geometric properties. In this paper, we present the augmented sliced Wasserstein distance (ASWD), a distance metric constructed by first mapping samples to hypersurfaces in an augmented space, which enables flexible nonlinear slicing of data distributions for improved projection efficiency (see Figure 1). Our main contributions include: (i) We exploit the capacity of nonlinear projections employed in the ASWD by constructing injective mappings with arbitrary neural networks; (ii) We prove that the ASWD is a valid distance metric; (iii) We provide a mechanism by which the hypersurface onto which high-dimensional distributions are projected can be optimized by gradient ascent.



Figure 1: (a) and (b) are visualizations of projections for the ASWD and the SWD between two 2-dimensional Gaussians. (c) and (d) are distance histograms for the ASWD and the SWD between two 100-dimensional Gaussians. Figure 1(a) shows that the injective neural network embedded in the ASWD learns data patterns (in the X-Y plane) and produces well-separated projected values (Z-axis) between distributions in a random projection direction. The high projection efficiency of the ASWD is evident in Figure 1(c), as almost all random projection directions in a 100-dimensional space lead to significant distances between 1-dimensional projections. In contrast, random linear mappings in the SWD often produce closer 1-d projections (Z-axis) (Figure 1(b)); as a result, a large percentage of random projection directions in the 100-d space result in trivially small distances (Figure 1(d)), leading to a low projection efficiency in high-dimensional spaces.
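The ASWD construction described above can be sketched in a few lines of numpy: samples are mapped to an augmented space by an injective mapping of the form phi(x) = [x; f(x)], where concatenation with the input itself guarantees injectivity for any network f, and linear slicing in the augmented space then acts as a nonlinear projection of the original samples. For illustration the network weights below are fixed and random; in the ASWD the hypersurface (i.e. the network) is optimized by gradient ascent, and this toy two-layer ReLU network merely stands in for a trained injective architecture.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def augment(X, W1, W2):
    """Injective augmentation phi(x) = [x; f(x)]: the input x is preserved
    in the output, so phi is injective regardless of the network f."""
    fX = relu(X @ W1) @ W2                      # toy 2-layer ReLU network
    return np.concatenate([X, fX], axis=1)      # (n, d + d_out)

def augmented_sliced_wasserstein(X, Y, W1, W2, n_projections=50, seed=0):
    """Slice linearly in the augmented space; a direction theta = [t1; t2]
    induces the nonlinear projection x . t1 + f(x) . t2 in input space."""
    Xa, Ya = augment(X, W1, W2), augment(Y, W1, W2)
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(n_projections, Xa.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    sq = (np.sort(Xa @ theta.T, axis=0) - np.sort(Ya @ theta.T, axis=0)) ** 2
    return np.sqrt(sq.mean())

rng = np.random.default_rng(2)
d, h = 2, 8
W1 = rng.normal(size=(d, h))
W2 = rng.normal(size=(h, 3))
X = rng.normal(0.0, 1.0, size=(300, d))
Y = rng.normal(2.0, 1.0, size=(300, d))
aswd = augmented_sliced_wasserstein(X, Y, W1, W2)
```

Compared with the plain SWD sketch, the only change is the augmentation step before slicing; the closed-form sorting-based 1-d solution is reused unchanged, which is why the ASWD retains the SWD's computational profile.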

