MAX-SLICED BURES DISTANCE FOR INTERPRETING DISCREPANCIES

Abstract

We propose the max-sliced Bures distance, a lower bound on the max-sliced Wasserstein-2 distance, to identify the instances associated with the maximum discrepancy between two samples. The max-slicing can be decomposed into two asymmetric divergences each expressed in terms of an optimal slice or equivalently a 'witness' function that has large magnitude evaluations on a localized subset of instances in one distribution versus the other. We show how witness functions can be used to detect and correct for covariate shift through reweighting and to evaluate generative adversarial networks. Unlike heuristic algorithms for the max-sliced Wasserstein-2 distance that may fail to find the optimal slice, we detail a tractable algorithm that finds the global optimal slice and scales to large sample sizes. As the Bures distance quantifies differences in covariance, we generalize the max-sliced Bures distance by using non-linear mappings, enabling it to capture changes in higher-order statistics. We explore two types of non-linear mappings: positive semidefinite kernels where the witness functions belong to a reproducing kernel Hilbert space, and task-relevant mappings corresponding to a neural network. In the context of samples of natural images, our approach provides an interpretation of the Fréchet Inception distance by identifying the synthetic and natural instances that are either over-represented or under-represented with respect to the other sample. We apply the proposed measure to detect imbalances in class distributions in various data sets and to critique generative models.

1. INTRODUCTION

Divergence measures quantify the dissimilarity between probability distributions. They are fundamental to hypothesis testing and to the estimation and criticism of statistical models, and serve as cost functions for optimizing generative adversarial neural networks (GANs). Although a multitude of divergences exists, not all of them are interpretable. A divergence is interpretable if it can be expressed in terms of a real-valued witness function ω(•) whose level-sets identify the specific subsets that are not well matched between the distributions, specifically, subsets which have much higher or much lower probability under one distribution versus the other. Localizing these discrepancies is useful for understanding and compensating for differences between two samples or distributions, for example, to detect covariate shift (Shimodaira, 2000; Quionero-Candela et al., 2009; Lipton et al., 2018) or to evaluate generative models (Heusel et al., 2017).

While many divergences can be posed in terms of witness functions, not all witness functions are readily obtained or interpreted. From an information-theoretic perspective, the most natural witness function is the logarithm of the ratio of the densities (Kullback & Leibler, 1951), as in the Kullback-Leibler divergence. Applying other convex functions to the density ratio constitutes the family of f-divergences (Ali & Silvey, 1966; Rényi, 1961), which includes the Hellinger, Jensen-Shannon, and others. However, without a parametric model, estimating the densities from samples is challenging (Vapnik, 2013). Following Vapnik's advice to "try to avoid solving a more general problem as an intermediate step," previous work has sought to directly model the density ratio via kernel learning (Nguyen et al., 2008; Kanamori et al., 2009; Yamada et al., 2011; 2013; Saito et al., 2018; Lee et al., 2019) or to estimate an f-divergence by optimizing a function from a suitable family (Nguyen et al., 2010), such as a neural network (Nowozin et al., 2016).
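To make the idea of an interpretable witness function concrete, the following minimal sketch (not part of the paper's method) computes the log density ratio ω(x) = log p(x) − log q(x) for two known one-dimensional Gaussians. Its sign localizes where p places more or less mass than q; all names and the choice of distributions are illustrative.

```python
import numpy as np
from scipy.stats import norm

# Two known densities: p = N(0, 1) and q = N(1, 1). In practice only
# samples are available and the ratio must be estimated, which is the
# difficulty the paragraph above describes.
p = norm(loc=0.0, scale=1.0)
q = norm(loc=1.0, scale=1.0)

def omega(x):
    """KL witness function: log of the density ratio p(x) / q(x)."""
    return p.logpdf(x) - q.logpdf(x)

# For these Gaussians omega(x) = 0.5 - x in closed form, so omega > 0
# where p dominates (x < 0.5) and omega < 0 where q dominates (x > 0.5);
# its level sets localize the discrepancy between the two distributions.
x = np.linspace(-3.0, 4.0, 8)
w = omega(x)
```

The average of ω over samples from p estimates the KL divergence KL(p‖q) itself, which is why the log density ratio is the natural witness from an information-theoretic perspective.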

