REGULARIZED MUTUAL INFORMATION NEURAL ESTIMATION

Abstract

With the variational lower bound of mutual information (MI), the estimation of MI can be cast as an optimization task solved via stochastic gradient descent. In this work, we start by showing how the Mutual Information Neural Estimator (MINE) searches for the optimal function T that maximizes the Donsker-Varadhan representation. With our synthetic dataset, we directly observe the neural network outputs during optimization to investigate why MINE succeeds or fails: we discover the drifting phenomenon, where the constant term of T shifts throughout the optimization process, and analyze the instability caused by the interaction between the logsumexp operation and insufficient batch sizes. Next, through theoretical and experimental evidence, we propose a novel lower bound that effectively regularizes the neural network to alleviate the problems of MINE. We also introduce an averaging strategy that produces an unbiased estimate by utilizing multiple batches, mitigating the batch size limitation. Finally, we show that $L_2$ regularization achieves significant improvements in both discrete and continuous settings.

1. INTRODUCTION

Identifying the relationship between two variables of interest is one of the great linchpins of mathematics, statistics, and machine learning (Goodfellow et al., 2014; Ren et al., 2015; He et al., 2016; Vaswani et al., 2017). Not surprisingly, this problem is closely tied to measuring the strength of that relationship. One of the fundamental approaches is information-theoretic measurement, namely the estimation of mutual information (MI). Recently, Belghazi et al. (2018) proposed a neural network-based MI estimator, called the Mutual Information Neural Estimator (MINE). Due to its differentiability and applicability, it has motivated several lines of research, such as loss functions bridging the gap between latent variables and representations (Chen et al., 2016; Belghazi et al., 2018; Oord et al., 2018; Hjelm et al., 2019), and methodologies identifying the relationship between input, output, and hidden variables (Tishby & Zaslavsky, 2015; Shwartz-Ziv & Tishby, 2017; Saxe et al., 2019). Although many works have demonstrated its computational tractability and usefulness, intriguing questions about the MI estimator itself remain unanswered:
• How does the neural network inside MINE behave when estimating MI?
• Why does the MINE loss diverge in some cases? Where does the instability originate from?
• Can we obtain a more stable estimate in small batch size settings?
• Why do the values of the individual terms in the MINE loss keep shifting even after the estimated MI converges? Are there any side effects of this phenomenon?
This study attempts to answer these questions by designing a synthetic dataset that makes the network outputs interpretable. Through careful observation, we dissect the Donsker-Varadhan representation (DV) term by term and conclude that the instability and the drifting are caused by the interplay between stochastic gradient descent-based optimization and the theoretical properties of DV. Based on these insights, we extend DV to derive a novel lower bound for MI estimation, which mitigates the aforementioned problems and circumvents the batch size limitation by maintaining a history of network outputs. We furthermore examine the $L_2$ regularizer form of our bound in detail and analyze how various hyper-parameters impact the estimation of MI and its dynamics during the optimization process. Finally, we demonstrate that our method, called ReMINE, performs favorably against existing estimators in multiple settings.

Definition of Mutual Information

The mutual information between two random variables $X$ and $Y$ is defined as
$$I(X; Y) = D_{\mathrm{KL}}(P_{XY} \,\|\, P_X \otimes P_Y) = \mathbb{E}_{P_{XY}}\!\left[\log \frac{dP_{XY}}{d(P_X \otimes P_Y)}\right],$$
where $P_{XY}$ and $P_X \otimes P_Y$ are the joint distribution and the product of the marginals, respectively, and $D_{\mathrm{KL}}$ is the Kullback-Leibler (KL) divergence. Without loss of generality, we consider $P_{XY}$ and $P_X \otimes P_Y$ as distributions on a compact domain $\Omega \subset \mathbb{R}^d$.

Lemma 1. (Donsker-Varadhan representation (DV))
$$I(X; Y) = \sup_{T:\Omega\to\mathbb{R}} \; \mathbb{E}_{P_{XY}}[T] - \log\!\left(\mathbb{E}_{P_X \otimes P_Y}[e^T]\right), \quad (2)$$
where both expectations $\mathbb{E}_{P_{XY}}[T]$ and $\mathbb{E}_{P_X \otimes P_Y}[e^T]$ are finite. MINE estimates MI by parameterizing $T$ with a neural network (the statistics network) and maximizing the right-hand side of Eq. (2) over mini-batches. However, as the second term in Eq. (2) leads to biased gradient estimates with a limited number of samples, MINE uses exponential moving averages over mini-batches to alleviate this problem. To further improve the sampling efficiency of MINE, Lin et al. (2019) propose DEMINE, which partitions the samples into train and test sets.

Other representations based on f-divergences have also been proposed by Nguyen et al. (2010) and Nowozin et al. (2016); these produce unbiased estimates and hence eliminate the need for additional techniques.

Lemma 2. (Nguyen, Wainwright, and Jordan representation (NWJ))
$$I(X; Y) = \sup_{T:\Omega\to\mathbb{R}} \; \mathbb{E}_{P_{XY}}[T] - \mathbb{E}_{P_X \otimes P_Y}[e^{T-1}],$$
where the bound is tight when $T = \log(dP/dQ) + 1$.

Nevertheless, if the MI is too large, these estimators exhibit large bias or variance (McAllester & Stratos, 2018; Song & Ermon, 2020). To balance between the two, Poole et al. (2019) design a new estimator $I_\alpha$ that interpolates between Contrastive Predictive Coding (Oord et al., 2018) and NWJ. Yet, these methods concentrate on various stabilization techniques rather than revealing the dynamics inside the black box. In this paper, we focus on the DV representation and provide an intuitive understanding of the inner mechanisms of neural network-based estimators. Based on this analysis, we introduce a new regularization term for MINE, which effectively remedies its weaknesses both theoretically and practically.
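To make the two bounds concrete, the following is a minimal PyTorch sketch of the quantities above: the DV estimate of Eq. (2), the NWJ estimate, and a MINE-style surrogate loss whose gradient uses an exponential moving average of $\mathbb{E}[e^T]$. The critic architecture, the hidden size, the moving-average rate, and within-batch shuffling for marginal samples are illustrative assumptions, not the exact setup used in this paper.

```python
import math
import torch
import torch.nn as nn

class Critic(nn.Module):
    """Statistics network T(x, y): maps a pair of samples to a scalar score."""
    def __init__(self, x_dim, y_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim + y_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, y):
        return self.net(torch.cat([x, y], dim=-1)).squeeze(-1)

def dv_estimate(t_joint, t_marginal):
    """Donsker-Varadhan estimate (Eq. 2): E_PXY[T] - log E_PX*PY[e^T].
    The second term uses logsumexp for numerical stability."""
    log_mean_exp = torch.logsumexp(t_marginal, dim=0) - math.log(t_marginal.shape[0])
    return t_joint.mean() - log_mean_exp

def nwj_estimate(t_joint, t_marginal):
    """NWJ estimate: E_PXY[T] - E_PX*PY[e^(T-1)]; its second term is a plain
    mean, so minibatch gradients are unbiased."""
    return t_joint.mean() - torch.exp(t_marginal - 1.0).mean()

def mine_loss(t_joint, t_marginal, ema, rate=0.99):
    """MINE-style surrogate loss: the minibatch gradient of log E[e^T] is biased,
    so e^T is divided by a detached running average of E[e^T] instead."""
    e_t = torch.exp(t_marginal).mean()
    ema = rate * ema + (1.0 - rate) * e_t.detach()
    loss = -(t_joint.mean() - e_t / ema)  # ema is a constant w.r.t. the parameters
    return loss, ema

# Usage: marginal samples are formed by shuffling y within the batch.
critic = Critic(x_dim=1, y_dim=1)
x, y = torch.randn(64, 1), torch.randn(64, 1)
t_joint = critic(x, y)
t_marginal = critic(x, y[torch.randperm(y.shape[0])])
print(dv_estimate(t_joint, t_marginal).item(), nwj_estimate(t_joint, t_marginal).item())
```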

3. HOW DOES MINE ESTIMATE?

Before going any further, we first observe the statistics network outputs of MINE during the optimization process using our novel synthetic dataset, and identify and analyze the following phenomena:
• Drifting phenomenon (Fig. 1a), where the estimates of $\mathbb{E}_{P_{XY}}[T]$ and $\log(\mathbb{E}_{P_X \otimes P_Y}[e^T])$ drift in parallel even after the MI estimate converges.
• Exploding network outputs (Fig. 1d), where smaller batch sizes cause the network outputs to explode, whereas training with larger batch sizes reduces the variance of the MI estimates (Fig. 2a).
• Bimodal distribution of the outputs (Fig. 2b), where the network not only classifies the input samples but also clusters the network outputs as the MI estimate converges.
Based on these observations, we analyze the inner workings of MINE and explain how the batch size affects MI estimation.
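As a rough illustration of the kind of diagnostic used here, the sketch below logs the two DV terms separately during training, reusing the `Critic`, `mine_loss`, and the logsumexp computation from the earlier sketch. If the difference of the two curves (the MI estimate) flattens out while both curves keep moving together, that is the drifting phenomenon; rapidly growing per-sample outputs T(x, y) point to the small-batch instability. The toy data generator `sample_joint`, the optimizer settings, and the logging cadence are hypothetical placeholders, not the paper's synthetic dataset or experimental setup.

```python
import math
import torch

def sample_joint(batch_size):
    """Hypothetical stand-in for the paper's synthetic dataset: returns a batch
    of (x, y) pairs drawn from some joint distribution P_XY."""
    x = torch.randn(batch_size, 1)
    y = x + 0.5 * torch.randn(batch_size, 1)  # correlated toy pair, for illustration
    return x, y

critic = Critic(x_dim=1, y_dim=1)            # reused from the previous sketch
optimizer = torch.optim.Adam(critic.parameters(), lr=1e-4)
ema = torch.tensor(1.0)
history = []

for step in range(5000):
    x, y = sample_joint(64)
    t_joint = critic(x, y)
    t_marginal = critic(x, y[torch.randperm(y.shape[0])])
    loss, ema = mine_loss(t_joint, t_marginal, ema)   # reused from the previous sketch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        with torch.no_grad():
            first = t_joint.mean().item()                          # E_PXY[T]
            second = (torch.logsumexp(t_marginal, dim=0)
                      - math.log(t_marginal.shape[0])).item()      # log E[e^T]
            # Track both terms and their difference; parallel drift of the two
            # terms with a flat difference is the phenomenon described above.
            history.append((step, first, second, first - second))
```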



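The batch-size sensitivity discussed above can be seen even without training a network: the plug-in estimate $\log(\frac{1}{n}\sum_i e^{T_i})$ of the second DV term is biased downward by Jensen's inequality, and the bias shrinks as the batch grows; the same finite-sample effect underlies the biased gradients mentioned in the background. In the self-contained sketch below, the Gaussian scores are a toy stand-in for statistics network outputs, an assumption made purely for illustration.

```python
import math
import torch

torch.manual_seed(0)

def log_mean_exp(t):
    """Plug-in estimate of log E[e^T] from a sample of scores t."""
    return (torch.logsumexp(t, dim=0) - math.log(t.shape[0])).item()

population = torch.randn(1_000_000) * 2.0   # toy stand-in for network outputs T
true_value = log_mean_exp(population)       # close to the true log E[e^T]

for batch_size in [8, 64, 512, 4096]:
    estimates = []
    for _ in range(200):
        idx = torch.randint(0, population.shape[0], (batch_size,))
        estimates.append(log_mean_exp(population[idx]))
    mean_est = sum(estimates) / len(estimates)
    # Small batches systematically underestimate log E[e^T] (Jensen's inequality),
    # which in turn inflates the corresponding DV mutual-information estimate.
    print(f"batch={batch_size:5d}  mean estimate={mean_est:.3f}  true={true_value:.3f}")
```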

