REGULARIZED MUTUAL INFORMATION NEURAL ESTIMATION

Abstract

With the variational lower bound of mutual information (MI), the estimation of MI can be understood as an optimization task via stochastic gradient descent. In this work, we start by showing how the Mutual Information Neural Estimator (MINE) searches for the optimal function T that maximizes the Donsker-Varadhan representation. With our synthetic dataset, we directly observe the neural network outputs during optimization to investigate why MINE succeeds or fails: we discover the drifting phenomenon, where the constant term of T shifts throughout the optimization process, and analyze the instability caused by the interaction between the logsumexp operation and insufficient batch size. Next, through theoretical and experimental evidence, we propose a novel lower bound that effectively regularizes the neural network to alleviate the problems of MINE. We also introduce an averaging strategy that produces an unbiased estimate by utilizing multiple batches to mitigate the batch size limitation. Finally, we show that L2 regularization achieves significant improvements in both discrete and continuous settings.

1. INTRODUCTION

Identifying a relationship between two variables of interest is one of the great linchpins in mathematics, statistics, and machine learning (Goodfellow et al., 2014; Ren et al., 2015; He et al., 2016; Vaswani et al., 2017). Not surprisingly, this problem is closely tied to measuring the relationship between two variables. One of the fundamental approaches is information-theoretic measurement, namely the estimation of mutual information (MI). Recently, Belghazi et al. (2018) proposed a neural network-based MI estimator, called the Mutual Information Neural Estimator (MINE). Due to its differentiability and applicability, it has motivated several lines of research, such as loss functions bridging the gap between latent variables and representations (Chen et al., 2016; Belghazi et al., 2018; Oord et al., 2018; Hjelm et al., 2019), and methodologies identifying the relationship between input, output, and hidden variables (Tishby & Zaslavsky, 2015; Shwartz-Ziv & Tishby, 2017; Saxe et al., 2019). Although many works have shown its computational tractability and usefulness, many intriguing questions about the MI estimator itself remain unanswered.
• How does the neural network inside MINE behave when estimating MI?
• Why does the MINE loss diverge in some cases? Where does the instability originate from?
• Can we make a more stable estimate in small batch size settings?
• Why does the value of each term in the MINE loss keep shifting even after the estimated MI converges? Are there any side effects of this phenomenon?
This study attempts to answer these questions by designing a synthetic dataset on which the network outputs are interpretable. Through careful observation, we dissect the Donsker-Varadhan representation (DV) term by term and conclude that the instability and the drifting are caused by the interplay between stochastic gradient descent based optimization and the theoretical properties of DV.
Based on these insights, we extend DV to derive a novel lower bound for MI estimation, which mitigates the aforementioned problems and circumvents the batch size limitation by maintaining a history of network outputs. We furthermore examine the L2 regularizer form of our bound in detail and analyze how various hyper-parameters impact the estimation of MI and its dynamics during the optimization process. Finally, we demonstrate that our method, called ReMINE, performs favorably against other existing estimators in multiple settings.

Definition of Mutual Information

The mutual information between two random variables X and Y is defined as

I(X; Y) = D_KL(P_XY || P_X ⊗ P_Y) = E_{P_XY}[log (dP_XY / d(P_X ⊗ P_Y))], (1)

where P_XY is the joint distribution, P_X ⊗ P_Y is the product of the marginal distributions, and D_KL is the Kullback-Leibler (KL) divergence. Without loss of generality, we consider P_XY and P_X ⊗ P_Y as distributions on a compact domain Ω ⊂ R^d.

Variational Mutual Information Estimation. Recent works on MI estimation focus on training a neural network to represent a tight variational MI lower bound, of which there are several types. Although these methods are known to have statistical limitations (McAllester & Stratos, 2018), they are widely employed for their versatility (Hjelm et al., 2019; Veličković et al., 2018; Polykovskiy et al., 2018; Ravanelli & Bengio, 2018; François-Lavet et al., 2019). One of the most commonly used is the Donsker-Varadhan representation, which was first used in Belghazi et al. (2018) to estimate MI through neural networks.

Lemma 1. (Donsker-Varadhan representation (DV))

I(X; Y) = sup_{T: Ω→R} E_{P_XY}[T] - log(E_{P_X⊗P_Y}[e^T]), (2)

where both expectations E_{P_XY}[T] and E_{P_X⊗P_Y}[e^T] are finite. However, as the second term in Eq. (2) leads to biased gradient estimates with a limited number of samples, MINE uses exponential moving averages over mini-batches to alleviate this problem. To further improve the sampling efficiency of MINE, Lin et al. (2019) propose DEMINE, which partitions the samples into train and test sets. Other representations based on f-measures are proposed by Nguyen et al. (2010); Nowozin et al. (2016), which produce unbiased estimates and hence eliminate the need for additional techniques.

Lemma 2. (Nguyen, Wainwright, and Jordan representation (NWJ))

I(X; Y) = sup_{T: Ω→R} E_{P_XY}[T] - E_{P_X⊗P_Y}[e^{T-1}],

where the bound is tight when T = log(dP/dQ) + 1. Nevertheless, if the MI is too large, estimators exhibit large bias or variance (McAllester & Stratos, 2018; Song & Ermon, 2020). To balance between the two, Poole et al. (2019) design a new estimator I_α that interpolates between Contrastive Predictive Coding (Oord et al., 2018) and NWJ.
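As a concrete sanity check, both bounds can be evaluated numerically on a toy discrete problem where the density ratio is known in closed form. The sketch below is ours, not the paper's code: the near-optimal critic `log_ratio`, the alphabet size, and the sample counts are illustrative choices.

```python
import numpy as np

def dv_bound(t_joint, t_marginal):
    # Donsker-Varadhan: E_P[T] - log E_Q[e^T]
    return t_joint.mean() - np.log(np.mean(np.exp(t_marginal)))

def nwj_bound(t_joint, t_marginal):
    # NWJ: E_P[T] - E_Q[e^{T-1}]
    return t_joint.mean() - np.mean(np.exp(t_marginal - 1.0))

rng = np.random.default_rng(0)
N = 4                         # X uniform on {0..N-1}, Y = X
true_mi = np.log(N)

def log_ratio(x, y):
    # log(dP_XY / d(P_X ⊗ P_Y)): log N on the diagonal, -inf (clipped) elsewhere
    return np.where(x == y, np.log(N), -50.0)

x = rng.integers(0, N, 100000)
y_joint = x                                # joint samples: y equals x
y_marg = rng.integers(0, N, 100000)        # marginal samples: y independent of x

dv = dv_bound(log_ratio(x, y_joint), log_ratio(x, y_marg))
# DV is tight at log_ratio + any constant; NWJ is tight at log_ratio + 1
nwj = nwj_bound(log_ratio(x, y_joint) + 1.0, log_ratio(x, y_marg) + 1.0)
print(dv, nwj, true_mi)
```

With the near-optimal critics, both bounds land close to the true MI of log 4, up to sampling noise.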
Yet, these methods concentrate on various stabilization techniques rather than revealing the dynamics inside the black box. In this paper, we focus on the DV representation and provide an intuitive understanding of the inner mechanisms of neural network-based estimators. Based on this analysis, we introduce a new regularization term for MINE that effectively remedies its weaknesses both theoretically and practically.

3. HOW DOES MINE ESTIMATE?

Before going any further, we first observe the statistics network output of MINE during the optimization process using our novel synthetic dataset, and identify and analyze the following phenomena:
• Drifting phenomenon (Fig. 1a), where the estimates of E_{P_XY}[T] and log(E_{P_X⊗P_Y}[e^T]) drift in parallel even after the MI estimate converges.
• Exploding network outputs (Fig. 1d), where smaller batch sizes cause the network outputs to explode, while training with a larger batch size reduces the variance of MI estimates (Fig. 2a).
• Bimodal distribution of the outputs (Fig. 2b), where the network not only classifies input samples but also clusters the network outputs as the MI estimate converges.
Based on these observations, we analyze the inner workings of MINE and examine how batch size affects MI estimation.

3.1. EXPERIMENT SETTINGS

Dataset. We design a one-hot discrete dataset with uniform distribution U(1, N) to estimate I(X; X) = log N with MINE, while easily discerning samples of the joint distribution (X, X) from the marginal distribution X ⊗ X. We use a one-hot representation to increase the input dimension, resulting in more network weights to train. In this paper, we use N = 16.
Network settings. We design a simple statistics network T that takes a concatenated vector of dimension N × 2 = 32 as input. The input passes through two fully connected layers with ReLU activation, with widths 32-256-1. The last layer outputs a single scalar with no bias and no activation. We use stochastic gradient descent (SGD) with learning rate 0.1 to optimize the statistics network.
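A minimal sketch of this dataset, written in NumPy rather than the authors' training code; `sample_batch` and its signature are our own naming.

```python
import numpy as np

def sample_batch(batch_size, N=16, rng=None):
    """One-hot discrete dataset: joint pairs (x, x) and marginal pairs (x, y)."""
    rng = rng if rng is not None else np.random.default_rng()
    eye = np.eye(N)
    xj = rng.integers(0, N, batch_size)                      # joint: y equals x
    xm = rng.integers(0, N, batch_size)                      # marginal: x and y
    ym = rng.integers(0, N, batch_size)                      # sampled independently
    joint = np.concatenate([eye[xj], eye[xj]], axis=1)       # (batch_size, 2N)
    marginal = np.concatenate([eye[xm], eye[ym]], axis=1)    # (batch_size, 2N)
    return joint, marginal

joint, marginal = sample_batch(100, N=16, rng=np.random.default_rng(0))
print(joint.shape, marginal.shape)   # (100, 32) (100, 32)
```

The concatenated 32-dimensional vectors are exactly what the 32-256-1 statistics network described above would consume.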

3.2. OBSERVATIONS

We can observe the drifting phenomenon in Fig. 1a, where the statistics of the network output are adrift even after the convergence of the MINE loss. This phenomenon will be analyzed in more detail with theoretical results in Section 4. This section focuses on the relationship between batch size and logsumexp, and on the classifying nature of the MINE loss.
Batch size limitation. MINE estimates in a batch-wise manner, i.e., MINE uses the samples inside a single batch when estimating E_{P_XY}[T] and log(E_{P_X⊗P_Y}[e^T]). Consider the empirical DV

Î(X; Y) = sup_{T_θ: Ω→R} E^{(n)}_{P_XY}[T_θ] - log E^{(n)}_{P_X⊗P_Y}[e^{T_θ}], (4)

where E^{(n)}_P is an empirical average over a batch of size n. The variance of Î(X; Y) therefore increases as the batch size decreases. The observation in Fig. 2a is consistent with the batch size limitation problem (McAllester & Stratos, 2018; Song & Ermon, 2020), which shows that MINE must have a batch size proportional to the exponential of the true MI to control the variance of the estimation.
Exploding network outputs. We can understand the output explosion problem in detail by comparing Fig. 1b and Fig. 1d. During optimization, the network outputs of joint samples are increased by the first term of Eq. (4), where the gradient of each network output is multiplied by the inverse of the batch size. On the other hand, the outputs of marginal samples are decreased by the second term of Eq. (4), which concentrates the gradient on the maximum output. Note that the second term is dominated by the maximum network output due to logsumexp, which is a smooth approximation of max. As a single batch is sampled from the true underlying distribution, joint-case samples may or may not exist among the marginal samples. If one exists, the joint sample output dominates the term and is decreased accordingly, while the other non-joint sample outputs are only slightly decreased.
In summary, the second term acts as an occasional restriction on the increase of joint sample network outputs.¹ The second term poses a problem when the batch size is not large enough. With a reduced sample size, joint samples that dominate the second term become rare. When no joint sample exists in a batch, the marginal sample network outputs decrease much faster than in the opposite case, and the joint sample network outputs are restricted more rarely; thus the network outputs diverge in both directions (Fig. 1d), and the second term oscillates between two extreme values depending on whether a joint case occurred (Fig. 1c). This leads to numerical instability and estimation failure.
Bimodal distribution of the outputs. We furthermore observe the network outputs directly, as the averaging in both terms of DV obscures how the statistics network acts on each sample. From the neural network's viewpoint, whether a joint-case pair was realized from the joint or the marginal distribution is indistinguishable. Therefore, the statistics network has no choice but to return the same output value for both, as can be seen in Fig. 1b, indicating that the network can only separate joint cases from non-joint cases. This provides a clue that the network is solving a classification task, isolating joint samples from marginal samples, even though the statistics network is only provided with samples from the joint and marginal distributions. We observe the distribution of network outputs in detail for the case where only marginal samples are fed to the statistics network in Fig. 2b. It stands to reason that the network outputs follow a particular distribution, as each network output estimates the log-likelihood ratio between the joint and marginal distributions plus a constant (Lemma 3). Through this, we can view the estimated MI as a sample average; hence Fig. 2a resembles Gaussian noise by the Central Limit Theorem (CLT).
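The dominance of the maximum output in the second term is easy to reproduce in isolation. The snippet below uses hand-picked output values j = log 16 and j̄ = -20 (illustrative stand-ins, not values from an actual trained network):

```python
import numpy as np

def log_mean_exp(t):
    # Numerically stable log(mean(exp(t)))
    m = t.max()
    return m + np.log(np.mean(np.exp(t - m)))

N = 16
j, jbar = np.log(N), -20.0     # converged joint / non-joint outputs (illustrative)

# Batch WITH a joint case: the single large output dominates the term.
with_joint = np.array([j] + [jbar] * 9)        # batch size 10
# Batch WITHOUT a joint case: the term collapses to jbar.
without_joint = np.full(10, jbar)

print(log_mean_exp(with_joint))     # ≈ j - log(10): set by the one joint output
print(log_mean_exp(without_joint))  # = jbar: the term drops to the other extreme
```

The gap between the two printed values is exactly the oscillation of the second term seen in Fig. 1c when joint cases appear only intermittently.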
Let us continue by concentrating on each network output. For the one-hot discrete dataset, there is no distinction between the log-likelihood ratios of samples in the same class: denote the network output as j for the joint case and j̄ for the non-joint case. This explains the classifying nature of the statistics network, and why there are exactly two clusters in Fig. 2b. Also, as j̄ approaches -∞, exp(j̄) nears 0, so exp(j) is a few orders of magnitude larger than exp(j̄) (see Fig. 2b). As mentioned above, the few joint cases dominate the second term, so the second term is inherently noisier than the first. Note that the effectiveness of conventional remedies, such as applying an exponential moving average to the second term (Belghazi et al., 2018) or clipping the network output values to restrict their magnitude (Song & Ermon, 2020), can also be understood through the analysis above. In addition, we cannot interpret the network outputs directly as the log-likelihood ratio due to the unregularized outputs, i.e., the drifting problem. We look into this fundamental limitation of MINE in more detail in the next section.

4. THE PROPOSED METHOD: REMINE

We examine the drifting problem of Fig. 1a in detail to introduce a novel loss with a regularizer that constrains the magnitude of network outputs and enables the use of network outputs from multiple batches. All proofs are provided in the Appendix. First, let us concentrate on the optimal function T*, which can be directly derived from the DV representation. We start with a result of Belghazi et al. (2018).
Lemma 3. (Optimal Function T*) Let P and Q be distributions on Ω ⊂ R^d. For some constant C* ∈ R, there exists an optimal function T* = log(dP/dQ) + C* such that D_KL(P||Q) = E_P[T*] - log(E_Q[e^{T*}]).
Note that Mukherjee et al. (2020) directly utilize this fact to model the statistics network. We extend the result above and show that C* can be any real number while T* remains optimal.
Lemma 4. (Family of Optimal Functions) For any constant C ∈ R, T = log(dP/dQ) + C satisfies D_KL(P||Q) = E_P[T] - log(E_Q[e^T]).
This explains the drifting phenomenon we encountered in Fig. 1a: C* drifts because there is no penalty distinguishing different values of C*. Drifting can be stopped by freezing the network weights (Lin et al., 2019), but is there a direct way to regularize C so that the neural network concentrates on finding a single solution rather than a family of solutions?
Theorem 5. (ReMINE Loss Function) Let d be a distance function on R.
For any constant C ∈ R,

D_KL(P||Q) = sup_{T: Ω→R} E_P[T] - log(E_Q[e^T]) - d(log(E_Q[e^T]), C).

Note that for the optimal T* = log(dP/dQ) + C, E_{P_XY}[T*] = I(X; Y) + C and log(E_{P_X⊗P_Y}[e^{T*}]) = C. Based on Theorem 5, we propose a novel loss function by adding a term d(log(E_{P_X⊗P_Y}[e^T]), C) that regularizes the drifting of C*. The details of the ReMINE algorithm are as follows.

Algorithm 1: ReMINE
θ ← Initialize network parameters, K ← moving average window size, i ← 0
repeat
  Draw J samples from the joint distribution: (x^(1)_i, y^(1)_i), ..., (x^(J)_i, y^(J)_i) ~ P_XY
  Draw M samples from the marginal distribution: (x^(1)_i, y^(1)_i), ..., (x^(M)_i, y^(M)_i) ~ P_X ⊗ P_Y
  Evaluate the lower bound:
    Ê_{P_XY} ← (1/J) Σ_{j=1}^{J} T_θ(x^(j)_i, y^(j)_i)
    Ê_{P_X⊗P_Y} ← log((1/M) Σ_{m=1}^{M} e^{T_θ(x^(m)_i, y^(m)_i)})
    ν(θ) ← Ê_{P_XY} - Ê_{P_X⊗P_Y} - d(Ê_{P_X⊗P_Y}, C)
  Update the statistics network parameters: θ ← θ + ∇_θ ν(θ)
  Estimate MI based on the last window W = [max(0, i - K + 1), min(K, i)] of size K:
    Î(X; Y) = (1/(J·|W|)) Σ_{w∈W} Σ_{j=1}^{J} T_θ(x^(j)_w, y^(j)_w) - log((1/(M·|W|)) Σ_{w∈W} Σ_{m=1}^{M} e^{T_θ(x^(m)_w, y^(m)_w)})
  Next iteration: i ← i + 1
until convergence

ReMINE differs from other MINE-like methods that rely on single-batch estimates, as ReMINE can utilize all the previous network outputs after the convergence of Î(X; Y) to establish a final estimate. To demonstrate why these MINE-like methods cannot use the network outputs from multiple batches, we first describe two basic averaging strategies and then show that both produce a biased estimate when the statistics network T is drifting.
Theorem 6. (Estimation Bias caused by Drifting) The two averaging strategies below produce a biased MI estimate when the drifting problem occurs.
1. Macro-averaging (similar to that of Poole et al. (2019)): establish a single estimate through the average of the estimated MI from each batch.
2. Micro-averaging (our method): calculate the DV representation using the average of the individual network outputs.

Let us now analyze the gradient dynamics of ReMINE with d(a, b) = λ(a - b)². Let f = E^{(n)}_{P_XY}[T_θ], g = log E^{(n)}_{P_X⊗P_Y}[e^{T_θ}], and C = 0. Then ∇Î(X; Y) = ∇f - ∇g - 2λg∇g = ∇f - ∇g(1 + 2λg). As previously discussed, f increases each of the B joint sample outputs with step size (γ/B)∇T_θ, where γ is the learning rate. On the other hand, g approximates the maximum marginal sample output, so its gradient becomes close to γ(1 + 2λg)∇T_θ for the maximum output.
We can break the dynamics of g into two parts, depending on whether joint samples exist in the batch used for g. If not, the maximum marginal sample output is reduced with step size γ(1 + 2λg), which, unlike in MINE, is adaptively adjusted by the size of g. Hence, ReMINE regularizes the maximum output to be centered around -1/(2λ), preventing the network outputs from diverging to -∞. If any joint sample exists, its network output will be large enough to dominate g. The step size is again γ(1 + 2λg), which regularizes the joint sample network outputs more strongly as they increase. This, too, helps avoid the output explosion toward +∞, as shown in Fig. 3.
Impact of C. Fig. 4a shows the impact of changing C on MI estimation, with the same settings as Fig. 1a. We observe that the newly added regularizer pushes the log(E_{P_X⊗P_Y}[e^T]) term to converge towards C as expected, without losing the ability to estimate MI.
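The objective is straightforward to implement. Below is a minimal NumPy sketch using the squared distance d(a, b) = λ(a - b)²; the function names and the tiny example values are our own, illustrative choices.

```python
import numpy as np

def log_mean_exp(t):
    m = t.max()
    return m + np.log(np.mean(np.exp(t - m)))    # stable log(E[e^T])

def remine_objective(t_joint, t_marginal, lam=0.1, C=0.0):
    """ReMINE objective: E_P[T] - log E_Q[e^T] - lam * (log E_Q[e^T] - C)^2.
    Maximize this (or minimize its negation) with a gradient-based optimizer;
    lam = 0 recovers the plain (empirical) DV objective of MINE."""
    g = log_mean_exp(t_marginal)
    return t_joint.mean() - g - lam * (g - C) ** 2

# On MINE's plateau, shifting every output by a constant changes nothing;
# the ReMINE regularizer breaks this tie:
t_j, t_m = np.array([1.0, 1.0]), np.array([0.0, 0.0])
print(remine_objective(t_j, t_m, lam=0.0))           # MINE value: 1.0
print(remine_objective(t_j + 5, t_m + 5, lam=0.0))   # still 1.0 (plateau)
print(remine_objective(t_j + 5, t_m + 5, lam=0.1))   # penalized: -1.5
```

The last two lines show the mechanism in miniature: the DV value is invariant to the shared shift, while the regularizer penalizes any solution whose second term strays from C.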

Impact of λ.

As observed in Section 3.2, the network outputs of joint and non-joint cases converge to j and j̄, respectively. Using this, we visualize the effect of the proposed method by drawing the loss surfaces of MINE and ReMINE for the one-hot discrete dataset in Fig. 5:

L_MINE = E_{P_XY}[T] - log(E_{P_X⊗P_Y}[e^T]) = j - log(p e^j + (1 - p) e^{j̄}), (5)
L_ReMINE = E_{P_XY}[T] - log(E_{P_X⊗P_Y}[e^T]) - d(log(E_{P_X⊗P_Y}[e^T]), C) (6)
         = j - log(p e^j + (1 - p) e^{j̄}) - λ(log(p e^j + (1 - p) e^{j̄}) - C)². (7)

We again observe the drifting phenomenon, as the MINE loss surface has a plateau of equally palatable solutions. The regularization term successfully warps the loss surface so that it has a single solution. As λ increases, the loss surface becomes steeper, resulting in sporadic spikes at each gradient step.

5. EXPERIMENTS

Network settings. We consider a joint architecture, which concatenates the inputs (x, y) and passes them through three fully connected layers with ReLU activation (excluding the output layer) with widths 40-256-256-1, the same network as used in Poole et al. (2019). We use the Adam optimizer with learning rate 5 × 10^-4, β1 = 0.9, and β2 = 0.999.
Comparison to state-of-the-arts. As mentioned in Fig. 4b, we can remove the log(E_{P_X⊗P_Y}[e^T]) term by choosing C = 0. As discussed in Section 3.2, the second term is inherently noisy; hence, using all the terms of ReMINE during optimization but dropping the second term during estimation can effectively reduce noise. We call this trick ReMINE-J. To verify the quality of the lower bounds, we compare ReMINE and ReMINE-J to InfoNCE (Oord et al., 2018), JS (Hjelm et al., 2019; Poole et al., 2019), MINE (Belghazi et al., 2018), NWJ (Nguyen et al., 2010), SMILE (Song & Ermon, 2020), SMILE+JS (which estimates with SMILE and optimizes with JS), TUBA (Barber & Agakov, 2004; Poole et al., 2019), and I_α (Poole et al., 2019). For a fair comparison, ReMINE also uses the macro-averaging strategy, the same as the other methods.
Our methods show comparable or better estimation performance with less variance than the others, as shown in Fig. 6. Exact values of bias, variance, and mean squared error with respect to the true MI for each estimator are given in the Appendix.
Comparison on self-consistency tests. To compare the stability of DV-based estimators, we conduct self-consistency tests on ReMINE, MINE, SMILE, and SMILE+JS. For type 1, every estimator successfully returns values between the theoretical bounds with an increasing trend. For type 2, only the ReMINE and SMILE+JS estimates are close to the ideal value. For type 3, none of the estimators works well. However, ReMINE shows smaller variance than MINE and has similar stability to SMILE+JS.

6. CONCLUSION

In this paper, we studied how the neural network inside MINE handles the MI estimation problem. We delved into the drifting problem, where the two terms of DV continue to fluctuate together even after the MI estimate converges, and the explosion problem, where the network outputs become unstable due to properties of the second term of DV when the batch size is small. Based on this analysis, we penalized the objective function to obtain a unique solution by using L2 regularization. Despite its simplicity, the proposed loss together with the micro-averaging strategy mitigates the drifting, exploding, and batch size limitation problems. Further, ReMINE enables us to directly interpret the network output values as the log-likelihood ratio between the joint and marginal distributions, and performs favorably against state-of-the-art methods. However, further investigation is needed on the impact of optimizers on the batch size limitation, and on why DV-based estimators fail some of the self-consistency tests.

7. APPENDIX: PROOFS

7.1. PROOF OF FAMILY OF OPTIMAL FUNCTIONS

Theorem. For any constant C ∈ R, T = log(dP/dQ) + C satisfies D_KL(P||Q) = E_P[T] - log(E_Q[e^T]).
Proof. Suppose that T = log(dP/dQ) + C. We can write T = (T* - C*) + C by Lemma 3 in the manuscript. Therefore,

E_P[T] = E_P[T* - C* + C] = E_P[T*] - C* + C,
log(E_Q[e^T]) = log(E_Q[e^{T* - C* + C}]) = log(e^{C - C*} E_Q[e^{T*}]) = (C - C*) + log(E_Q[e^{T*}]).

Since E_P[T] - log(E_Q[e^T]) = E_P[T*] - log(E_Q[e^{T*}]), the function T is also optimal.
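The invariance in this proof can be checked numerically on a small discrete example; P and Q below are arbitrary distributions drawn purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 8
P = rng.dirichlet(np.ones(N))      # arbitrary discrete distribution P
Q = rng.dirichlet(np.ones(N))      # arbitrary discrete distribution Q
kl = np.sum(P * np.log(P / Q))     # D_KL(P||Q)

def dv_value(C):
    T = np.log(P / Q) + C          # a member of the family of optimal functions
    return np.sum(P * T) - np.log(np.sum(Q * np.exp(T)))

values = [dv_value(C) for C in (-3.0, 0.0, 7.0)]
print(values, kl)                  # all three equal D_KL(P||Q)
```

Every choice of C yields the same DV value, which is exactly the plateau the drifting constant wanders along during training.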

7.2. PROOF OF REMINE LOSS FUNCTION

Theorem. Let d be a distance function on R. For any constant C ∈ R,

D_KL(P||Q) = sup_{T: Ω→R} E_P[T] - log(E_Q[e^T]) - d(log(E_Q[e^T]), C).

Proof. i) For any T,

E_P[T] - log(E_Q[e^T]) - d(log(E_Q[e^T]), C) ≤ E_P[T] - log(E_Q[e^T]).

Therefore, sup_{T: Ω→R} E_P[T] - log(E_Q[e^T]) - d(log(E_Q[e^T]), C) ≤ D_KL(P||Q).
ii) By Lemma 3, there exists T* = log(dP/dQ) + C such that D_KL(P||Q) = E_P[T*] - log(E_Q[e^{T*}]), and

log(E_Q[e^{T*}]) = log(E_Q[e^C (dP/dQ)]) = log(∫ e^C (dP/dQ) dQ) = C.

Therefore,

sup_{T: Ω→R} E_P[T] - log(E_Q[e^T]) - d(log(E_Q[e^T]), C) ≥ E_P[T*] - log(E_Q[e^{T*}]) - d(log(E_Q[e^{T*}]), C) = D_KL(P||Q).

Combining i) and ii) finishes the proof.
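Step ii) hinges on the identity log(E_Q[e^{T*}]) = C, which can also be verified numerically (again with arbitrary discrete P and Q chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(2)
N = 8
P = rng.dirichlet(np.ones(N))          # arbitrary discrete P
Q = rng.dirichlet(np.ones(N))          # arbitrary discrete Q
C = 2.5

T_star = np.log(P / Q) + C             # optimal function from Lemma 3
second_term = np.log(np.sum(Q * np.exp(T_star)))
kl = np.sum(P * np.log(P / Q))
dv = np.sum(P * T_star) - second_term

print(second_term)                     # equals C, so d(second_term, C) = 0
print(dv)                              # equals D_KL(P||Q)
```

At this T* the regularizer vanishes, so the supremum with the extra penalty term is still attained at the KL divergence.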

8. APPENDIX: ADDITIONAL EXPLANATIONS

Additional explanations for Fig. 4b. As the N-dimensional one-hot discrete dataset is uniform, we can easily calculate the likelihood ratio of joint and non-joint case samples. There are N² possible samples, so P_X ⊗ P_Y = 1/N² for every sample. For joint-case samples, P_XY = 1/N, and P_XY = 0 for non-joint-case samples. Hence, the likelihood ratio is N for joint cases and 0 for non-joint cases. These values are consistent with the experimental results, where j converges to log N and j̄ keeps decreasing. Nonetheless, as exp(j̄) gets closer to zero, the second term of the ReMINE loss has less influence; hence the decrease of j̄ slows to a halt as it reaches -1/(2λ) = -5.
We can explain the same result from the perspective of j and j̄. As observed in Section 3.2, the network output values of the joint and non-joint cases converge to j and j̄, respectively. Since the dataset is uniform, the probability p of a joint case appearing among the marginal samples is 1/N. Therefore, we can analyze the values of j and j̄ after convergence as follows: as iteration i → ∞,

E_{P_XY}[T(i)] = j → I(X; Y) + C = log N + C, (19)
log(E_{P_X⊗P_Y}[e^{T(i)}]) = log(p e^j + (1 - p) e^{j̄}) = log((1/N) e^j + ((N - 1)/N) e^{j̄}) → C, (20)

where T(i) is the statistics network at iteration i. Combining Eq. (19) and Eq. (20) gives (1/N) e^{log N + C} + ((N - 1)/N) e^{j̄} → e^C, and hence e^{j̄} → 0. In summary, j converges to log N + C and e^{j̄} to 0, as shown in Fig. 4b. Note that j and j̄ serve as a back-of-the-envelope calculation that lets us estimate network outputs easily in discrete settings.
What happens if the batch size is small? When the batch size is 1 and C = 0, the ReMINE loss changes its characteristics as follows.
• A joint case occurs. As the samples are indistinguishable,

L = j - log e^j - d(j, 0) = -λj², (23)

which is maximized when j = 0.
• A non-joint case occurs.

L = j - log e^{j̄} - d(log e^{j̄}, 0) = j - j̄ - λj̄². (24)

The latter quadratic term in j̄ is maximized when j̄ = -1/(2λ).
If the statistics network converges in both cases for our one-hot discrete dataset,

Î(X; Y) = E^{(n)}_{P_XY}[T_θ] - log E^{(n)}_{P_X⊗P_Y}[e^{T_θ}] = 0 - log(p e^0 + (1 - p) e^{-1/(2λ)}) → -log p as λ → 0⁺.

As p = 1/N, Î(X; Y) → log N. Intuitively, at smaller batch sizes, joint cases cannot occur among the marginal samples, as mentioned in Section 3.2. Hence, E_{P_XY}[T] and log(E_{P_X⊗P_Y}[e^T]) behave differently compared to the larger batch size. The regularizer penalizes the two terms in different ways. Joint cases among the marginal samples can only contribute through Eq. (23), so E_{P_XY}[T] → 0. Moreover, as λ gets smaller, log(E_{P_X⊗P_Y}[e^T]) is regularized less, so that it can converge to -Î(X; Y). In contrast, since MINE has no regularization term, namely λ = 0, there is no way for the joint cases among the marginal samples to influence T, hence failing to estimate MI as shown in Fig. 8b.
Impact of λ with batch size. We inspect the relationship between batch size and λ in detail. Fig. 8 shows that imposing regularization reduces noise in the large-batch-size regime. However, in the small-batch-size regime, log(E_{P_X⊗P_Y}[e^T]) cannot attain the required nonzero value, hence failing to estimate the MI value. The effects of the ReMINE loss in the two regimes mix in between.
Visualizing network outputs on 1-D Gaussian. As this dataset does not let us label joint and non-joint samples explicitly, we visualize the network outputs on a 2-D plane. We use the same experiment settings as Section 3.1, only changing the input dimension to 2. We can see in Fig. 9 that the network outputs of the overlapping region remain near 0, which indicates that the likelihood is equal between the joint and marginal distributions. The other regions are separated by the sign of their outputs: a positive network output means that the joint distribution is more likely than the marginal distribution to sample that data point, and vice versa. Quantitative results are shown in Fig. 10 and Table 1; we omit values greater than 100 in Table 1.
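The λ → 0⁺ limit derived in the batch-size-1 analysis above can be traced numerically (N = 16 as in Section 3.1; the specific λ grid is an illustrative choice):

```python
import numpy as np

N = 16
p = 1.0 / N          # probability of a joint case among marginal samples
estimates = {}
for lam in (1.0, 0.1, 0.01):
    # Converged batch-size-1 solution: j = 0 and jbar = -1/(2*lam)
    estimates[lam] = -np.log(p * np.exp(0.0) + (1 - p) * np.exp(-1.0 / (2 * lam)))
print(estimates)
print(np.log(N))     # the target -log(p) = log(N), approached as lam -> 0+
```

The estimate increases monotonically toward log N = log 16 as λ shrinks, matching the claim that weaker regularization lets the second term reach -Î(X; Y).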
We additionally show results from KL (Jiao et al., 2018), Mixed KSG (Gao et al., 2017), CCMI (Mukherjee et al., 2020), TNCE (Oord et al., 2018; Poole et al., 2019), and ReMINE-L1 (our method with L1 regularization). Both ReMINE and ReMINE-J show comparable or better performance than the other methods. Note that the L1 regularizer also suffers from the explosion problem, as its gradient is not adaptively adjusted by the magnitude of the network outputs, as discussed in Section 5.1.



¹Loosely speaking, the first term slowly increases many joint sample network outputs, whereas the second term quickly decreases a few joint sample network outputs.



Figure 1: Training T for 3000 iterations with batch size 100 [(a), (b)] and 10 [(c), (d)]. For (a) and (c), colored lines represent the estimates of I(X; Y), E_{P_XY}[T], and log(E_{P_X⊗P_Y}[e^T]) for each batch. We also add the ideal MI for (a) and an exponential moving average (EMA) of Î(X; Y) with span 10 (decay rate 2/11) for (c). We can see that MINE fails to generate a meaningful estimate at batch size 10. The outputs of the neural network for each sample are shown in (b) and (d). As there are 100 samples per iteration for (b) and 10 for (d), the total number of outputs is 300,000 and 30,000, respectively. Note that a base-10 logarithmic scale is used on the y-axis for (c) and (d).

Figure 2: (a) Training T with different batch sizes 100, 200, 400, and 800. We show the histogram of Î(X; Y) over the final 1000 batches. (b) Histogram of network outputs of marginal samples at different iterations. The probability of a joint case occurring is 1/N, hence the proportionally small area of joint cases. A black vertical line is drawn at 0 as a visual aid.


Figure 3: (a) Training T with the same settings as Fig. 1d with ReMINE. We can see that ReMINE successfully avoids both the drifting and the network output explosion problem. (b) Same setting as (a), but using the optimizer from Section 5.2. We suspect Adam shows better performance as it accumulates previous gradients. (c) Same setting as (b), except that we used the micro-averaging strategy with a sliding window of size 1000. (d) Training T with the same settings as Fig. 1d, but with the same optimizer as (b). The outputs are more stable than in Fig. 1d, but it still fails to make a stable estimate.

Figure 4: (a) Applying ReMINE for 1500 iterations with different Cs and λ = 0.1. (b) Histogram of network outputs for the marginal samples at different iterations. We set C = 0 and λ = 0.1. exp(j) and exp(j̄) converge to N and 0, respectively; these are the likelihood ratios between the joint and marginal distributions for the joint and non-joint cases. We can now directly interpret the network outputs, and E(j) converges directly to the ideal MI log N, thanks to the regularization of C.

Figure 5: Comparing the loss surfaces for varying λ. For each batch, we averaged the network outputs of joint samples to estimate j, and of the non-joint cases among marginal samples to estimate j̄.

Figure 6: Estimation performance on 20-D Gaussian. Similar to Poole et al. (2019), we increase ρ every 4000 iterations. The estimated MI (light) and a smoothed estimate with an exponential moving average (dark) are plotted for each method, and theoretical bounds are plotted as dotted lines.

Figure 7: Self-consistency tests on CIFAR-10 (Krizhevsky et al., 2009). We report the average result of 10 repeated runs. Dotted lines indicate theoretical bounds for type 1, and the ideal ratio for types 2 and 3.

Figure 8: (a) Comparing the joint and non-joint case outputs of ReMINE with λ = 0.01 and batch size 1 for 5000 iterations. We can see j → 0 and j̄ → -1/(2λ) = -50. (b) Comparing the joint and non-joint case outputs of MINE for 5000 iterations. As the statistics network struggles to diverge into two different values, it becomes numerically unstable, hence failing in the middle of training. (c) Comparing the estimation performance of different λs over varying batch sizes 1, 2, ..., 100. The dotted line represents the true MI.

Figure 9: Visualizing network outputs on the 1-D correlated Gaussian dataset. The axes represent the paired values x and y of each sample, and the color represents the network output for (a) marginal and (b) joint samples.

Figure 10: Bias, variance, MSE of estimators on 20-D correlated Gaussian dataset


Table 1: Bias, variance, and MSE of estimators on the 20-D correlated Gaussian dataset

7.3. PROOF OF ESTIMATION BIAS CAUSED BY DRIFTING

Theorem. (Estimation Bias caused by Drifting) The two averaging strategies below produce a biased MI estimate when the drifting problem occurs.
1. Macro-averaging (similar to that of Poole et al. (2019)): establish a single estimate through the average of the estimated MI from each batch.
2. Micro-averaging (our method): calculate the DV representation using the average of the individual network outputs.
Proof. Let T_ij and T̄_ij denote the outputs for the jth joint and marginal sample of the ith batch, respectively, let T*_ij and T̄*_ij be the corresponding outputs without drifting, and let C_i be the drifting constant of each batch, so that T_ij = T*_ij + C_i and T̄_ij = T̄*_ij + C_i. With B batches of size N each,
1. Macro-averaging: Î_macro = (1/B) Σ_i [(1/N) Σ_j T_ij - log((1/N) Σ_j e^{T̄_ij})].
2. Micro-averaging: Î_micro = (1/(BN)) Σ_{i,j} T_ij - log((1/(BN)) Σ_{i,j} e^{T̄_ij}).
For micro-averaging, substituting T_ij = T*_ij + C_i gives Î_micro = (1/(BN)) Σ_{i,j} T*_ij + C̄ - log((1/B) Σ_i e^{C_i} (1/N) Σ_j e^{T̄*_ij}), where C̄ = (1/B) Σ_i C_i. Since the log-mean-exp of the C_i exceeds C̄ whenever the C_i are not all equal (Jensen's inequality), the drifting constants do not cancel and the estimate is biased.
Additional experiments on the self-consistency test. We report the performance of other variational bound methods (JS, I_α, InfoNCE) in Fig. 12. I_α and JS often yield unstable MI estimates, as shown in the type 2 experiment. On the other hand, InfoNCE estimates MI quite reliably but also fails on type 3. To observe behavior on a different dataset, we also used MNIST (LeCun et al., 1998). As shown in Fig. 13, the tests yielded similar results.
Figure 13: Self-consistency tests on MNIST.
To observe behavior with a different statistics network, we used a modified ResNet18 (He et al., 2016) that outputs a single scalar. As shown in Fig. 14, SMILE becomes more unstable, but there are no significant differences in the other variational bounds. This experiment shows that the network size has a small impact on the validity of this test on CIFAR-10.
Effectiveness on the Conditional Mutual Information Estimation Task. We compare the performance of various estimators on the conditional MI (CMI) estimation task. As baselines, we chose CCMI (Mukherjee et al., 2020), MINE (Belghazi et al., 2018), and the KSG estimator (Kraskov et al., 2004). The experiment uses the Experiment 1 setting in Mukherjee et al. (2020). We refer to the supplementary of Mukherjee et al.
(2020) for hyper-parameter settings such as network structures and optimizer parameters. We only changed the objective function of MINE to test our method. As shown in Fig. 15, ReMINE reaches comparable performance without changing the objective into a classification loss. ReMINE also produces stable estimates compared to MINE.
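The micro-averaging bias from Section 7.3 can be illustrated with a short simulation. Here a drift-free critic already outputs the exact log-likelihood ratio for I(X; X) with X uniform on {0, ..., N-1}, and we add a per-batch drifting constant C_i; the linear drift schedule is an arbitrary illustrative choice.

```python
import numpy as np

rng = np.random.default_rng(0)
N, B, n = 4, 50, 100                    # alphabet size, number of batches, batch size
true_mi = np.log(N)
C = np.linspace(-1.0, 1.0, B)           # drifting constant per batch (illustrative)

joint_out, marg_out = [], []
for i in range(B):
    x = rng.integers(0, N, n)
    y = rng.integers(0, N, n)
    t_joint = np.full(n, np.log(N))                   # drift-free optimal outputs
    t_marg = np.where(x == y, np.log(N), -30.0)       # log-ratio; -30 stands in for -inf
    joint_out.append(t_joint + C[i])                  # drifted outputs
    marg_out.append(t_marg + C[i])

joint_out = np.concatenate(joint_out)
marg_out = np.concatenate(marg_out)

# Micro-averaging across batches with unequal C_i:
micro = joint_out.mean() - np.log(np.mean(np.exp(marg_out)))
print(micro, true_mi)   # micro underestimates: the Jensen gap log E[e^C] - E[C] > 0
```

Even with an otherwise perfect critic, pooling raw outputs across drifted batches shifts the estimate below log N, which is why ReMINE must first pin down the drifting constant before the multi-batch micro-average becomes usable.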

