DEMI: DISCRIMINATIVE ESTIMATOR OF MUTUAL INFORMATION

Abstract

Estimating mutual information between continuous random variables is often intractable and extremely challenging for high-dimensional data. Recent progress has leveraged neural networks to optimize variational lower bounds on mutual information. Although promising for this difficult problem, variational methods have been shown, both theoretically and empirically, to have serious statistical limitations: 1) many methods struggle to produce accurate estimates when the underlying mutual information is either low or high; 2) the resulting estimators may suffer from high variance. We propose an alternative approach based on training a classifier that provides the probability that a data sample pair is drawn from the joint distribution rather than from the product of its marginal distributions. Moreover, we establish a direct connection between mutual information and the average log odds estimate produced by the classifier on a test set, leading to a simple and accurate estimator of mutual information. We show theoretically that our method and other variational approaches are equivalent at their optima, while our method sidesteps the variational bound. Empirical results demonstrate the high accuracy of our approach and the advantages of our estimator in the context of representation learning.

1. INTRODUCTION

Mutual information (MI) measures the information that two random variables share. MI quantifies the statistical dependency, linear and non-linear, between two variables. This property has made MI a crucial measure in machine learning. In particular, recent work in unsupervised representation learning has built on optimizing MI between latent representations and observations (Chen et al., 2016; Zhao et al., 2018; Oord et al., 2018; Hjelm et al., 2018; Tishby & Zaslavsky, 2015; Alemi et al., 2018; Ver Steeg & Galstyan, 2014). Maximization of MI has long been a default method for multi-modality image registration (Maes et al., 1997), especially in medical applications (Wells III et al., 1996), though in most such work the dimensionality of the random variables is very low. Here, coordinate transformations on images are varied to maximize their MI.

Estimating MI from finite data samples is challenging and is intractable for most continuous probability distributions. Traditional MI estimators (Suzuki et al., 2008; Darbellay & Vajda, 1999; Kraskov et al., 2004; Gao et al., 2015) do not scale well to modern machine learning problems with high-dimensional data. This impediment has motivated the construction of variational bounds for MI (Nguyen et al., 2010; Barber & Agakov, 2003); in recent years this has led to maximization procedures that use deep learning architectures to parameterize the space of functions, exploiting the expressive power of neural networks (Song & Ermon, 2019; Belghazi et al., 2018; Oord et al., 2018; Mukherjee et al., 2020). Unfortunately, optimizing lower bounds on MI has serious statistical limitations. Specifically, McAllester & Stratos (2020) showed that any high-confidence distribution-free lower bound cannot exceed O(log N), where N is the number of samples. This implies that if the underlying MI is high, it cannot be accurately and reliably estimated by variational methods like MINE (Belghazi et al., 2018).
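To make the O(log N) cap concrete, consider the InfoNCE-style contrastive bound of Oord et al. (2018): for a batch of K samples, each per-sample term of the estimate is capped at log K, so even the optimal critic (the exact log density ratio, available in closed form for a Gaussian pair) cannot report an MI above log K. The following sketch is purely illustrative and is not the method proposed in this paper; the Gaussian setup and batch size are assumptions chosen so the true MI (≈ 3.1 nats) exceeds log K (≈ 2.08 nats):

```python
import numpy as np

rng = np.random.default_rng(3)
rho, K, trials = 0.999, 8, 2000  # highly dependent pair, small batch

def log_ratio(x, y):
    # Exact log p(x, y) / (p(x) p(y)) for a standard bivariate Gaussian
    # with correlation rho (the optimal InfoNCE critic).
    return (-0.5 * np.log(1 - rho**2)
            - (x**2 - 2 * rho * x * y + y**2) / (2 * (1 - rho**2))
            + 0.5 * (x**2 + y**2))

estimates = []
for _ in range(trials):
    x = rng.normal(size=K)
    y = rho * x + np.sqrt(1 - rho**2) * rng.normal(size=K)
    scores = log_ratio(x[:, None], y[None, :])  # K x K critic matrix
    # InfoNCE term per anchor: f(x_i, y_i) - log mean_j exp(f(x_i, y_j)).
    # Each term is bounded above by log K, regardless of the true MI.
    estimates.append(np.mean(np.diag(scores)
                             - np.log(np.mean(np.exp(scores), axis=1))))
mi_nce = np.mean(estimates)

mi_true = -0.5 * np.log(1 - rho**2)  # closed-form Gaussian MI, ~3.11 nats
print(mi_nce, np.log(K), mi_true)   # estimate saturates near log K < true MI
```

Running this shows the contrastive estimate pinned near log K ≈ 2.08 while the true MI is ≈ 3.11, a small-scale analogue of the O(log N) limitation discussed above.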
Song & Ermon (2019) further categorized the state-of-the-art variational methods into "generative" and "discriminative" approaches, depending on whether they estimate the probability densities or the density ratios. They showed that the "generative" approaches perform poorly when the underlying MI is small and the "discriminative" approaches perform poorly when MI is large; moreover, certain approaches like MINE (Belghazi et al., 2018) are prone to high variance.

We propose a simple discriminative approach that avoids the limitations of previous discriminative methods based on variational bounds. Instead of estimating densities or attempting to predict one data variable from another, our method estimates the likelihood that a sample is drawn from the joint distribution versus the product of the marginal distributions. A similar classifier-based approach was used by Lopez-Paz & Oquab (2017) for "two-sample testing," i.e., hypothesis tests about whether two samples come from the same distribution. If the two distributions are the joint and the product of the marginals, then the test is one of independence. A generalization of this work was used by Sen et al. (2017) to test for conditional independence. We show that accurate performance on this classification task provides an estimate of the log odds. This can greatly simplify the MI estimation task in comparison with generative approaches: estimating a single likelihood ratio may be easier than estimating three distributions (the joint and the two marginals). Moreover, classification tasks are generally amenable to deep learning, while density estimation remains challenging in many cases. Our approach avoids estimating the partition function, which induces large variance in most discriminative methods (Song & Ermon, 2019). Our empirical results bear out these conceptual advantages.
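The classification idea can be sketched in a few lines. The example below is an illustration under simplifying assumptions, not the architecture or estimator specification used in this paper: a one-dimensional Gaussian pair (so the true MI is known in closed form), hand-crafted quadratic features, and an off-the-shelf logistic classifier. With balanced classes, the classifier's logit on a held-out joint sample approximates the pointwise log odds log p(x, y)/(p(x)p(y)), and its test-set average approximates MI:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n, rho = 20000, 0.8

# Paired samples from a correlated bivariate Gaussian (the joint p(x, y)).
x = rng.normal(size=n)
y = rho * x + np.sqrt(1 - rho**2) * rng.normal(size=n)

# Unpaired samples: permuting y simulates the product of marginals p(x)p(y).
y_perm = rng.permutation(y)

def features(a, b):
    # Quadratic features, so a linear classifier can represent the
    # Gaussian log density ratio, which is quadratic in (x, y).
    return np.column_stack([a, b, a * b, a**2, b**2])

X = np.vstack([features(x, y), features(x, y_perm)])
z = np.concatenate([np.ones(n), np.zeros(n)])  # z = 1: joint, z = 0: marginals

clf = LogisticRegression(max_iter=1000).fit(X, z)

# Average logit on fresh joint pairs estimates E[log p(x,y)/(p(x)p(y))] = MI.
xt = rng.normal(size=5000)
yt = rho * xt + np.sqrt(1 - rho**2) * rng.normal(size=5000)
mi_hat = clf.decision_function(features(xt, yt)).mean()

mi_true = -0.5 * np.log(1 - rho**2)  # closed-form Gaussian MI, ~0.51 nats
print(mi_hat, mi_true)
```

Note that, unlike the contrastive bound, nothing here caps the average logit at log(batch size); the estimator is only as good as the learned log-odds function.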
Our approach, like other sampling-based methods such as MINE, combines the given joint ("paired") data with derived "unpaired" data that captures the product of the marginal distributions p(x)p(y). The unpaired data can be synthesized via permutations or resampling of the paired data. This construction, which synthesizes unpaired data and then defines a metric that encourages paired data points to map closer than unpaired data in the latent space, has previously been used in other machine learning applications, such as audio-video and image-text joint representation learning (Harwath et al., 2016; Chauhan et al., 2020). Recent contrastive learning approaches (Tian et al., 2019; Hénaff et al., 2019; Chen et al., 2020; He et al., 2020) further leverage a machine learning model to differentiate paired and unpaired data, mostly in the context of unsupervised representation learning. Simonovsky et al. (2016) used paired and unpaired data in conjunction with a classifier-based loss function for patch-based image registration.

This paper is organized as follows. In Section 2, we derive our approach to estimating MI. Section 2.4 discusses connections to related approaches, including MINE. This is followed by empirical evaluation in Section 3. Our experimental results on synthetic and real image data demonstrate the advantages of the proposed discriminative classification-based MI estimator, which achieves higher accuracy than the state-of-the-art variational approaches and a good bias/variance tradeoff.
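The unpaired-data construction mentioned above rests on a simple observation: permuting one variable across samples leaves both marginals unchanged while destroying the joint dependence, yielding draws that approximate p(x)p(y). A minimal sketch (the toy data and the correlation check are illustrative assumptions, not part of the method):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 10000

# Paired data with strong dependence between x and y.
x = rng.normal(size=n)
y = x + 0.1 * rng.normal(size=n)

# "Unpaired" data: permuting y preserves its marginal distribution
# exactly but breaks the pairing, approximating p(x)p(y).
y_unpaired = rng.permutation(y)

print(np.corrcoef(x, y)[0, 1])           # near 1.0: dependent pairs
print(np.corrcoef(x, y_unpaired)[0, 1])  # near 0.0: dependence destroyed
```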

2. METHODS

Let x ∈ X and y ∈ Y be two random variables generated by a joint distribution p : X × Y → R⁺. Mutual information (MI),

I(x; y) ≜ E_{p(x,y)} [ log ( p(x, y) / ( p(x) p(y) ) ) ],    (1)

is a measure of dependence between x and y. Let D = {(x_i, y_i)}_{i=1}^{n} be a set of n independent and identically distributed (i.i.d.) samples from p(x, y). The law of large numbers implies

Î_p(D) ≜ (1/n) Σ_{i=1}^{n} log ( p(x_i, y_i) / ( p(x_i) p(y_i) ) ) → I(x; y) as n → ∞,    (2)

which suggests a simple estimation strategy via sampling. Unfortunately, the joint distribution p(x, y) is often unknown, and therefore the estimate in Eq. (2) cannot be explicitly computed. Here we develop an approach to accurately approximating the estimate Î_p(D) based on discriminative learning. In our development, we will find it convenient to define a Bernoulli random variable z ∈ {0, 1} and to "lift" the distribution p(x, y) to the product space X × Y × {0, 1}. We thus define a family of

