MAX-SLICED BURES DISTANCE FOR INTERPRETING DISCREPANCIES

Abstract

We propose the max-sliced Bures distance, a lower bound on the max-sliced Wasserstein-2 distance, to identify the instances associated with the maximum discrepancy between two samples. The max-slicing can be decomposed into two asymmetric divergences each expressed in terms of an optimal slice or equivalently a 'witness' function that has large magnitude evaluations on a localized subset of instances in one distribution versus the other. We show how witness functions can be used to detect and correct for covariate shift through reweighting and to evaluate generative adversarial networks. Unlike heuristic algorithms for the max-sliced Wasserstein-2 distance that may fail to find the optimal slice, we detail a tractable algorithm that finds the global optimal slice and scales to large sample sizes. As the Bures distance quantifies differences in covariance, we generalize the max-sliced Bures distance by using non-linear mappings, enabling it to capture changes in higher-order statistics. We explore two types of non-linear mappings: positive semidefinite kernels where the witness functions belong to a reproducing kernel Hilbert space, and task-relevant mappings corresponding to a neural network. In the context of samples of natural images, our approach provides an interpretation of the Fréchet Inception distance by identifying the synthetic and natural instances that are either over-represented or under-represented with respect to the other sample. We apply the proposed measure to detect imbalances in class distributions in various data sets and to critique generative models.

1. INTRODUCTION

Divergence measures quantify the dissimilarity between probability distributions. They are fundamental to hypothesis testing and to the estimation and criticism of statistical models, and serve as cost functions for optimizing generative adversarial networks (GANs). Although a multitude of divergences exists, not all of them are interpretable. A divergence is interpretable if it can be expressed in terms of a real-valued witness function ω(·) whose level-sets identify the specific subsets that are not well matched between the distributions, that is, subsets which have much higher or much lower probability under one distribution than under the other. Localizing these discrepancies is useful for understanding and compensating for differences between two samples or distributions, for example to detect covariate shift (Shimodaira, 2000; Quionero-Candela et al., 2009; Lipton et al., 2018) or to evaluate generative models (Heusel et al., 2017). While many divergences can be posed in terms of witness functions, not all witness functions are readily obtained or interpreted. From an information-theoretic perspective, the most natural witness function is the logarithm of the ratio of the densities (Kullback & Leibler, 1951), as in the Kullback-Leibler divergence. Applying other convex functions to the density ratio yields the family of f-divergences (Ali & Silvey, 1966; Rényi, 1961), which includes the Hellinger and Jensen-Shannon divergences, among others. However, without a parametric model, estimating the densities from samples is challenging (Vapnik, 2013). Following Vapnik's advice to "try to avoid solving a more general problem as an intermediate step," previous work has sought to directly model the density ratio via kernel learning (Nguyen et al., 2008; Kanamori et al., 2009; Yamada et al., 2011; 2013; Saito et al., 2018; Lee et al., 2019) or to estimate an f-divergence by optimizing a function from a suitable family (Nguyen et al., 2010), such as a neural network (Nowozin et al., 2016).
Witness functions need not rely on the density ratio. A wide class of divergences called integral probability metrics (IPMs) (Müller, 1997), which include total variation, the Wasserstein-1 distance, maximum mean discrepancy (MMD) (Gretton et al., 2007), and others (Mroueh et al., 2017), seek a witness function that maximizes the distance between the first moments of the witness function evaluations. In these cases the optimal witness function ω(·) has a greater expectation under one distribution than under the other. An IPM between two measures µ and ν is expressed as sup_{ω∈F} |E_{X∼µ}[ω(X)] − E_{Y∼ν}[ω(Y)]| for a given family of functions F. A class of related divergences are the max-sliced Wasserstein-p distances, which seek a linear (Deshpande et al., 2019) or non-linear slicing function (Kolouri et al., 2019) that maximizes the Wasserstein-p distance between the witness function evaluations for the two distributions. However, there are two difficulties with computing the max-sliced Wasserstein distance for two samples. The first is that it is a saddle-point optimization problem whose objective evaluation requires sorting the samples. Previous work has sought to approximate it using a first-moment approximation (Deshpande et al., 2019) or to run a finite number of steps of a local optimizer (Kolouri et al., 2019), without any guarantee of obtaining an optimal witness function. The second difficulty is in the interpretation of the obtained witness function: unlike the density ratio, there is no notion of whether the witness function will take higher values for points associated with one distribution versus the other. To address both of these issues we propose a max-sliced distance that replaces the Wasserstein-2 distance with a second-moment approximation based on the Bures distance (Dowson & Landau, 1982; Gelbrich, 1990). The Bures distance (Bures, 1969; Uhlmann, 1976) is a distance metric between positive semidefinite operators.
It is well-known in quantum information theory (Nielsen & Chuang, 2000; Koltchinskii & Xia, 2015) and machine learning (Brockmeier et al., 2017; Muzellec & Cuturi, 2018; Zhang et al., 2020; Oh et al., 2020; De Meulemeester et al., 2020) .
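To make the IPM concrete: when F is restricted to linear functions of unit norm, the supremum has a closed form, attained at the normalized mean-difference direction (akin to the first-moment approximation of Deshpande et al., 2019). The following NumPy illustration uses our own synthetic data and variable names:

```python
import numpy as np

rng = np.random.default_rng(0)
# Two Gaussian samples that differ only in their first moment along axis 0.
x = rng.normal(size=(500, 5))
x[:, 0] += 2.0
y = rng.normal(size=(400, 5))

# For the family F = {<., w> : ||w||_2 <= 1}, the IPM
# sup_w |E[<X, w>] - E[<Y, w>]| is attained at the normalized
# mean-difference direction, and its value equals ||E[X] - E[Y]||_2.
delta = x.mean(axis=0) - y.mean(axis=0)
w = delta / np.linalg.norm(delta)
ipm = (x @ w).mean() - (y @ w).mean()
```

Here the slice `w` concentrates on axis 0, where the samples actually differ, which is exactly the interpretability the witness-function view provides.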

1.1. CONTRIBUTION

We propose a novel IPM-like divergence measure, the "max-sliced Bures distance", to identify localized regions and instances associated with the maximum discrepancy between two samples. The distance is expressed as the maximal difference between the root mean square (RMS) of the witness function evaluations, sup_{ω∈S} |√(E_{X∼µ}[ω²(X)]) − √(E_{Y∼ν}[ω²(Y)])|, where S is an appropriate family of functions. As |∆| = max{∆, −∆}, the max-sliced Bures distance can be expressed as the maximum of two one-sided max-sliced divergences with optimal witness functions ω_{µ>ν} = argmax_{ω∈S} √(E_µ[ω²(X)]) − √(E_ν[ω²(Y)]) and ω_{µ<ν} = argmax_{ω∈S} √(E_ν[ω²(Y)]) − √(E_µ[ω²(X)]). If the distributions are not well matched, then ω_{µ>ν} has large-magnitude evaluations on a 'localized' subset of µ and smaller-magnitude values under ν, and the opposite holds for ω_{µ<ν}. The two samples {x_i}_{i=1}^m and {y_i}_{i=1}^n can be sorted by the magnitude of the witness function evaluations.¹ Crucially, we detail a tractable optimization procedure that is guaranteed to yield a globally optimal witness function for the one-sided max-sliced Bures divergence. When X = R^d and the first or second moments distinguish the distributions, linear witness functions can be used, S = {ω(·) = ⟨·, w⟩ : w ∈ S^{d−1}}, where S^{d−1} denotes the unit sphere in R^d. The optimal witness function for the one-sided max-sliced Bures divergence, ω_{µ>ν}(·) = ⟨·, w_{µ>ν}⟩, coincides with the direction with the greatest difference in RMS, w_{µ>ν} = argmax_{w∈S^{d−1}} √(wᵀE[XXᵀ]w) − √(wᵀE[YYᵀ]w). This optimization problem depends on the dimension d; after computation of the second-moment matrices, it is independent of the sample sizes m ≥ n. In comparison, the optimal slice for the max-sliced Wasserstein distance may not be obtained, and even gradient ascent to a local optimum requires O(m log m) operations at each function/gradient evaluation.
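In the linear case, the one-sided objective depends on the data only through the second-moment matrices. The sketch below optimizes it by plain projected gradient ascent on the unit sphere; it is our illustration under stated assumptions, not the paper's algorithm with its global-optimality guarantee:

```python
import numpy as np

def one_sided_bures_slice(x, y, steps=500, lr=0.1, seed=0):
    """Projected gradient ascent for
        w* = argmax_{||w||=1} sqrt(w' C_mu w) - sqrt(w' C_nu w),
    where C_mu, C_nu are the (uncentred) second-moment matrices.

    Hypothetical sketch: plain gradient ascent may stall at a local
    optimum, unlike the tractable procedure described in the text.
    """
    c_mu = x.T @ x / len(x)
    c_nu = y.T @ y / len(y)
    w = np.random.default_rng(seed).normal(size=x.shape[1])
    w /= np.linalg.norm(w)
    for _ in range(steps):
        # Gradient of the difference of RMS terms with respect to w.
        g = (c_mu @ w) / np.sqrt(w @ c_mu @ w) - (c_nu @ w) / np.sqrt(w @ c_nu @ w)
        w = w + lr * g
        w /= np.linalg.norm(w)  # project back onto the unit sphere
    return w, np.sqrt(w @ c_mu @ w) - np.sqrt(w @ c_nu @ w)
```

On data where one sample has inflated variance along a single axis, the recovered slice aligns with that axis and the divergence value approaches the difference of the corresponding RMS values.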
Furthermore, the slice that maximizes the max-sliced Wasserstein

¹ Four groups of 'witness points' (top-K instances) can be inspected to identify any discrepancies:
ω²_{µ>ν}(x_{π(1)}) ≥ ··· ≥ ω²_{µ>ν}(x_{π(K)}), where π sorts {x_i}_{i=1}^m to reveal the examples from µ with large ω²_{µ>ν};
ω²_{µ>ν}(y_{σ(1)}) ≥ ··· ≥ ω²_{µ>ν}(y_{σ(K)}), where σ sorts {y_i}_{i=1}^n to find the examples from ν with large ω²_{µ>ν};
ω²_{µ<ν}(x_{π̄(1)}) ≥ ··· ≥ ω²_{µ<ν}(x_{π̄(K)}), where π̄ sorts {x_i}_{i=1}^m to find the examples from µ with large ω²_{µ<ν};
ω²_{µ<ν}(y_{σ̄(1)}) ≥ ··· ≥ ω²_{µ<ν}(y_{σ̄(K)}), where σ̄ sorts {y_i}_{i=1}^n to find the examples from ν with large ω²_{µ<ν};
here π, π̄, σ, σ̄ denote permutations, and ≫ and ≪ denote expected inequalities with a large difference.

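Extracting the footnoted groups of witness points from a fitted linear slice is a simple sort by squared evaluation. A hypothetical helper (our naming, not the paper's):

```python
import numpy as np

def witness_points(x, w, k=5):
    """Indices of the k instances with the largest squared witness
    evaluations omega^2(x_i) = <x_i, w>^2 for a linear slice w.

    Illustrative helper: calling it with w_{mu>nu} on both samples
    yields the first two groups of 'witness points'; w_{mu<nu}
    yields the remaining two.
    """
    scores = (x @ w) ** 2
    order = np.argsort(-scores)  # sort descending by omega^2
    return order[:k], scores[order[:k]]
```

For example, on three points lying along the slice direction, the helper returns them ordered by their squared projections.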