COLD RAO-BLACKWELLIZED STRAIGHT-THROUGH GUMBEL-SOFTMAX GRADIENT ESTIMATOR

Abstract

The problem of estimating the gradient of an expectation over discrete random variables arises in many applications: learning with discrete latent representations, training neural networks with quantized weights, activations, conditional blocks, etc. This work is motivated by the development of the Gumbel-Softmax family of estimators, which are based on approximating argmax with a temperature-parametrized softmax. The state of the art in this family, the Gumbel-Rao estimator, uses internal MC samples to reduce the variance and appears to improve with lower temperatures. We show that it possesses a zero-temperature limit with a surprisingly simple closed form, defining our new estimator, called ZGR. It has favorable bias and variance properties, is easy to implement, computationally inexpensive, and is obviously free of the temperature hyperparameter. Furthermore, it decomposes as the average of the straight-through estimator and the DARN estimator, two basic estimators that do not perform well on their own. Unlike them, ZGR is shown to be unbiased for the class of quadratic functions of categorical variables. Experiments thoroughly validate the method.

1. INTRODUCTION

Discrete variables and discrete structures are important in machine learning. For example, the semantic hashing idea (Salakhutdinov & Hinton, 2009) is to learn a binary vector representation such that semantically similar instances (e.g. images or text documents) have similar bit representations. This allows for quick retrieval via a lookup table or via nearest-neighbor search with respect to the Hamming distance of the binary encodings. Recent work employs variational autoencoders (VAEs) with binary latent states (Shen et al., 2018; Dadaneh et al., 2020; Ñanculef et al., 2020). Another example is neural networks with discrete (binary or quantized) weights and activations. They allow for low-latency and energy-efficient inference, which is particularly important for edge devices. Recent results indicate that quantized networks can achieve competitive accuracy with better efficiency in various applications (Nie et al., 2022). VAEs and quantized networks are two diverse examples that motivate our development and the experimental benchmarks. Other potential applications include conditional computation (Bengio et al., 2013a; Yang et al., 2019; Bulat et al., 2021), reinforcement learning (Yin et al., 2019), learning task-specific tree structures for agglomerative neural networks (Choi et al., 2018), neural architecture search (Chang et al., 2019) and more. The learning problem in the presence of stochastic variables is usually formulated as minimization of the expected loss. Gradient-based optimization requires the gradient of the expectation with respect to the probabilities of the random variables (or the parameters of networks inferring those probabilities). Unbiased gradient estimators have been developed (Williams, 1992; Grathwohl et al., 2018; Tucker et al., 2017; Gu et al., 2016). These estimators work even for non-differentiable losses; however, their high variance is the main limitation.
More recent advances (Yin et al., 2019; Kool et al., 2020; Dong et al., 2020; 2021; Dimitriev & Zhou, 2021b; a) reduce the variance by using several cleverly coupled samples. However, the hierarchical (or deep) case is not addressed satisfactorily. There is experimental evidence that the variance grows substantially with the depth of the network, leading to very poor training performance (e.g., Shekhovtsov et al. 2020, Fig. C.6, C.7). Furthermore, existing extensions of coupled sampling methods to networks with L dependent layers (Dong et al., 2020; Yin et al., 2019) apply their base method in every layer, requiring several complete forward samples per layer. The computation complexity thus grows quadratically with the number of layers, as summarized in Table 1.

Table 1: Computation complexity of estimators in a hierarchical network with L dependency layers of K-way categorical variables. All methods require a backward pass. * — educated guess.

Method                               | Unbiased | Forward passes | Cost per cat. variable per pass
ST, GS, GS-ST, ZGR                   |    ✗     | 1              | O(K)
GR-MC(M)                             |    ✗     | 1              | O(MK)
REINFORCE                            |    ✓     | 1              | O(K)
RF(M) (Kool et al., 2019)            |    ✓     | M ≥ 2          | O(K)
ARSM (Yin et al., 2019)              |    ✓     | O(K²L²)        | O(1)
DisARM-* (Dong et al., 2020; 2021)   |    ✓     | O(2L²)         | O(K)
CARMS(M) (Dimitriev & Zhou, 2021b;a) |    ✓     | O(ML²)*        | O(M²K²)*

A different family of methods, better fitting the practical needs of deep models, exploits continuation arguments. It includes ST variants (Bengio et al., 2013b; Shekhovtsov & Yanush, 2021; Pervez et al., 2020) and Gumbel-Softmax variants (discussed below). These methods assume the loss function to be differentiable and try to estimate the derivative with respect to parameters of a discrete distribution from the derivative of the loss function. Such estimators can be easily incorporated into back-propagation by adjusting the forward and backward passes locally for every discrete variable. They are, in general, biased, because the derivative of the loss need not be relevant for the discrete expectation. The rationale, though, is that it may be possible to obtain a low-variance estimate at the price of a small bias, e.g. for a sufficiently smooth loss function (Shekhovtsov & Yanush, 2021). Gumbel-Softmax (Jang et al., 2017) and the concurrently developed Concrete relaxation (Maddison et al., 2017) enable differentiability through discrete variables by relaxing them to real-valued variables with a distribution approximating the original discrete distribution. The tightness of the relaxation is controlled by the temperature parameter t > 0. The bias can be reduced by decreasing the temperature, but the variance then grows as O(1/t) (Shekhovtsov, 2021). The Gumbel-Softmax Straight-Through (GS-ST) heuristic (Jang et al., 2017) uses discrete samples on the forward pass but relaxed ones on the backward pass, reducing side biases (see below).
The Gumbel-Rao (GR) estimator (Paulus et al., 2021) is a recent improvement of GS-ST, which can substantially reduce its variance by a local expectation. However, the local expectation results in an intractable integral in multiple variables, which is approximated by sampling. The experiments of Paulus et al. (2021) suggest that this estimator performs better at lower temperatures, which require more MC samples. The computation cost therefore appears to be a major limitation. Inspired by the performance of GR at low temperatures, we analyze its behavior for temperatures close to and in the limit of absolute zero. Note that because the underlying relaxation becomes discrete in the cold limit, leading to an explosion of variance in both GS and GS-ST, these two estimators do not have a zero-temperature limit. It is therefore not obvious that GR would have one. We prove that it does and denote this limit estimator ZGR. In the case of binary variables we give an asymptotic series expansion of GR around t = 0 and show that ZGR has a simple analytic expression, matching the already known DARN(1/2) estimator by Gregor et al. (2014) (also re-discovered as importance-reweighted ST by Pervez et al. 2020). In the general categorical case, we obtain the analytic expression of ZGR and show that it is a new estimator with the simple form (ST + DARN)/2. We show that ZGR is unbiased for all quadratic functions of categorical variables and experimentally show that it achieves a useful bias-variance tradeoff. We also contribute a refined experimental comparison of the GS family with unbiased estimators.

2. BACKGROUND

Let x be a categorical random variable taking values in K = {0, ..., K−1} with probabilities p(x; η) parametrized by η. Let φ(x) ∈ R^d be an embedding of the discrete state x in a vector space. Categorical variables are usually represented using the 1-hot embedding, in which case d = K. The value of x itself will still be used as an index. For binary and quantized variables we adopt the embedding φ(x) = x, in which case d = 1. Let L : R^d → R be a differentiable loss function. We denote the Jacobian of L with respect to an input φ as J_φ = dL/dφ. It is the transposed gradient of L with respect to φ; nevertheless, we will informally refer to such Jacobians as gradients. For brevity, let us also use the shorthand L(x) = L(φ(x)). The goal is to estimate the gradient of the expected loss

J_η = (d/dη) E[L(x)] = (d/dη) Σ_x L(x) p(x; η).    (1)

In a network with many (dependent) categorical variables it has proven efficient (e.g., Bengio et al. 2013a; Jang et al. 2017; Paulus et al. 2021; Shekhovtsov & Yanush 2021; Pervez et al. 2020) to consider estimators that estimate J_η in (1) based on the gradient J_φ = dL(φ(x))/dφ at a sample x from p(x; η). Such estimators can be easily extended to losses in multiple discrete variables (e.g. defined by a stochastic computation graph) by simply applying the elementary estimator whenever the respective Jacobian is needed in backpropagation. In this case J_η is the gradient in a specific variable or intermediate activation η at a joint sample. Let us review the basic estimators with which we will work theoretically.

REINFORCE We can use the log-derivative trick to rewrite

(d/dη) Σ_x L(x) p(x; η) = Σ_x L(x) dp(x; η)/dη = Σ_x L(x) p(x; η) d log p(x; η)/dη = E[L(x) d log p(x; η)/dη]    (2)

and define the REINFORCE estimate (Williams, 1992) to be J^RF_η = L(x) d log p(x; η)/dη, where x ∼ p(x; η). This estimator is clearly unbiased, as E[J^RF_η] = J_η, but may have a high variance.
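As a concreteness check, the REINFORCE estimator for a single categorical variable with p = softmax(θ) can be sketched in a few lines of NumPy (our illustration, not the authors' code; the function name and the toy loss values are hypothetical). Averaging the estimate over all values of x exactly recovers the true gradient, confirming unbiasedness:

```python
import numpy as np

def reinforce_grad(theta, L, x):
    # J^RF_theta = L(x) * d log p(x; theta)/d theta, with p = softmax(theta).
    # For the softmax parametrization, d log p(x)/d theta_i = [[i == x]] - p_i.
    p = np.exp(theta - theta.max())
    p /= p.sum()
    score = -p
    score[x] += 1.0
    return L(x) * score

# Unbiasedness check by exhaustive expectation over x ~ p(x; theta):
# the expectation must equal the exact gradient d/d theta sum_x L(x) p(x; theta).
theta = np.array([0.2, -1.0, 0.7])
vals = np.array([1.0, -2.0, 0.5])          # arbitrary loss values L(x)
L = lambda x: vals[x]
p = np.exp(theta) / np.exp(theta).sum()
expect = sum(p[x] * reinforce_grad(theta, L, x) for x in range(3))
exact = p * (vals - p @ vals)              # d/d theta_i E[L] = p_i (L(i) - E[L])
assert np.allclose(expect, exact)
```

The high variance mentioned in the text comes from single-sample evaluation; the exhaustive sum above is only feasible for a single variable with few categories.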
ST Let φ̄(η) = E[φ(x)] = Σ_x φ(x) p(x; η) be the mean embedding in R^d under the current distribution of x. The Straight-Through estimator is

J^ST_η = J_φ dφ̄(η)/dη.    (3)

Note that ST estimators of different empirical forms exist. The present definition for categorical variables is the same as, e.g., that of Gu et al. (2016) and is consistent with Hinton (2012) and Shekhovtsov & Yanush (2021) in the binary case.

DARN Gregor et al. (2014) use LR with the baseline b(x) = L(x) + J_φ(φ̄ − φ(x)), which is a first-order Taylor approximation of L about φ(x) evaluated at some point φ̄. This results in the estimator

J^DARN(φ̄)_η = J_φ (φ(x) − φ̄) d log p(x; η)/dη.    (4)

When x is binary, the choice φ̄ = (1/2) Σ_x φ(x), the mean embedding under the uniform distribution, ensures that the estimator is unbiased for any quadratic function. However, for categorical variables no φ̄ with such a property exists. Gu et al. (2016) experimentally tested several heuristic choices for φ̄, including φ̄ = φ̄(η), and found that none performed well in the categorical case.

GS The Gumbel-Softmax (GS) estimator (Jang et al., 2017) is a relaxation of the Gumbel-argmax sampling scheme. Let θ_k = log p(x=k; η) and let G_k ∼ Gumbel(0,1), k ∈ K, where Gumbel(0,1) is the Gumbel distribution with cdf F(u) = e^{−e^{−u}}. Then

x = argmax_k (θ_k + G_k)    (5)

is a sample from p(x; η). The relaxation is obtained by using a softmax instead of the argmax. This construction assumes the one-hot embedding φ and creates relaxed (continuous) samples φ̃ in the simplex Δ^K. Formally, introducing the temperature hyperparameter t, it can be written as

φ̃ = softmax((θ + G)/t) =: softmax_t(θ + G);   J^GS_θ = dL(φ̃)/dθ = (dL(φ̃)/dφ̃)(dφ̃/dθ).

There are two practical concerns. First, the loss function is evaluated at a relaxed sample, which in a large computation graph can offset estimates of all other gradients, even those not related to the discrete variable we relax.
This effect can be mitigated by using a smaller temperature, causing relaxed samples φ to concentrate in a corner of the simplex. However, and this is the second concern, the variance of the estimator grows as O( 1 t ) if t is decreased towards zero (Shekhovtsov, 2021) .
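The Gumbel-argmax construction (5) and its softmax relaxation can be sketched in a few lines of NumPy (our illustration; the function name is hypothetical). The empirical check confirms that the hard samples follow p(x; θ) for any temperature, while only the relaxed sample depends on t:

```python
import numpy as np

def gumbel_argmax_relaxed(theta, t, rng):
    # x = argmax_k(theta_k + G_k) is an exact sample from softmax(theta);
    # phi_tilde = softmax_t(theta + G) is its relaxed (continuous) counterpart.
    g = rng.gumbel(size=theta.shape)         # Gumbel(0, 1) noise
    u = (theta + g) / t
    phi = np.exp(u - u.max())
    phi /= phi.sum()
    return np.argmax(theta + g), phi

rng = np.random.default_rng(0)
theta = np.log(np.array([0.5, 0.3, 0.2]))    # theta_k = log p(x=k)

# Hard samples follow p(x; theta) regardless of t (checked empirically).
g = rng.gumbel(size=(200_000, 3))
x = np.argmax(theta + g, axis=1)
assert np.allclose(np.bincount(x, minlength=3) / 200_000,
                   np.exp(theta), atol=0.01)
```

A single relaxed draw `gumbel_argmax_relaxed(theta, t=0.1, rng)` returns a point near a corner of the simplex, illustrating why small t tightens the relaxation while inflating the gradient variance.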

GS-ST

The Straight-Through Gumbel-Softmax estimator (Jang et al., 2017) is an empirical modification of GS addressing the first concern above. It uses discrete samples in the forward pass but swaps in the Jacobian of the continuous relaxation in the backward pass:

G_k ∼ Gumbel(0,1), k ∈ K;    (7a)
x = argmax_k (θ_k + G_k);    (7b)
φ̃ = softmax_t(θ + G);    (7c)
J^GS-ST_θ = (dL(φ(x))/dφ)(dφ̃/dθ).

Notice that the hard sample φ(x) and the relaxed sample φ̃ are entangled through G. Although x has the law of p(x; η), as desired, and does not bias other variables, there is still a bias in estimating the gradient in θ, which is typically larger than that of GS (see Fig. 1). To make the bias smaller, the temperature t should be decreased; however, the variance then grows as O(1/t) (Shekhovtsov, 2021). Values of t between 0.1 and 1 are used in practice (Jang et al., 2017).

GR Notice that the forward pass in GS-ST is fully determined by x alone, and the value of G that generated that x is needed only in the backward pass. Paulus et al. (2021) proposed that the variance of GS-ST can be reduced by computing the conditional expectation in G|x, leading to the Gumbel-Rao estimator:

J^GR_θ = E_{G|x}[J^GS-ST_θ(G)] = (dL(φ(x))/dφ) E_{G|x}[dφ̃/dθ].    (8)

Because the value of the loss L(x) and its gradient do not depend on the specific realization of G|x, enabling the equality above, the expectation is localized and can be computed in the backward pass. However, this expectation is in multiple variables and is not analytically tractable. Paulus et al. (2021) use Monte Carlo integration with M samples from G|x. In their experiments they report an improvement of the mean squared error of the estimator as the temperature decreases from 1 down to 0.1.
The trend suggests that it would improve even further below t = 0.1, provided that the conditional expectation is approximated accurately enough. This indicates that the variance does not grow as the temperature decreases, in contrast to the O(1/t) asymptote of GS and GS-ST, and so there might be a meaningful cold limit.
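For intuition, the conditional expectation in (8) can be approximated with a transparent, if inefficient, rejection sampler (our sketch, not the authors' implementation; Paulus et al. use an exact conditional-Gumbel scheme, so this only illustrates the quantity being Rao-Blackwellized):

```python
import numpy as np

rng = np.random.default_rng(1)

def sample_gumbel_given_argmax(theta, x, rng):
    # Draw G conditioned on argmax_k(theta_k + G_k) = x by rejection.
    while True:
        g = rng.gumbel(size=theta.shape)
        if np.argmax(theta + g) == x:
            return g

def gr_mc_jacobian(theta, x, t, M, rng):
    # Monte-Carlo estimate of E_{G|x}[ d softmax_t(theta + G) / d theta ].
    J = np.zeros((len(theta), len(theta)))
    for _ in range(M):
        g = sample_gumbel_given_argmax(theta, x, rng)
        u = (theta + g) / t
        phi = np.exp(u - u.max())
        phi /= phi.sum()
        J += (np.diag(phi) - np.outer(phi, phi)) / t   # softmax_t Jacobian
    return J / M

theta = np.log(np.array([0.6, 0.3, 0.1]))
J = gr_mc_jacobian(theta, x=0, t=0.5, M=200, rng=rng)
assert np.allclose(J.sum(axis=0), 0.0, atol=1e-9)  # softmax Jacobian columns sum to 0

# Tower-property check of the conditional sampler: drawing x ~ p and then
# G | x recovers unconditional Gumbel(0,1) noise (mean = Euler-Mascheroni).
samples = [sample_gumbel_given_argmax(theta, rng.choice(3, p=np.exp(theta)), rng)
           for _ in range(20_000)]
assert abs(np.mean(samples) - 0.5772) < 0.02
```

The rejection loop accepts with probability p(x), which is why practical GR-MC implementations use an exact top-down conditional sampler instead.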

3. METHOD

Given the experimental evidence about the GR estimator, we take on the challenge of studying its cold asymptotic behavior, i.e. for t → 0. The temperature-scaled softmax in (7c) approaches a non-differentiable argmax indicator in this limit, and we have to handle the limit of the GR estimator with care to obtain correct results. We first analyze the binary case, where the derivations are substantially simpler. Proofs of all formal claims can be found in Appendix A.

3.1. BINARY CASE

In the case of two categories we can simplify the GS-ST estimator as follows. We assume x ∈ {0, 1} and φ(x) = x. The argmax trick can be expressed as x = [[θ_1 + G_1 ≥ θ_0 + G_0]], where [[·]] is the Iverson bracket. It is convenient to assume that the distribution of x is parametrized so that p(x=1; η) = σ(η), the logistic sigmoid function. This is without loss of generality, because any other parametrization results in just an extra deterministic Jacobian. Recalling that θ_k = log p(x=k; η), we have θ_1 − θ_0 = η. Next, denoting Z = G_1 − G_0, we can write the argmax trick compactly as x = [[η + Z ≥ 0]]. Being the difference of two Gumbel(0,1) variables, Z follows the standard logistic distribution (with cdf σ(z)). The GR estimator of the gradient in η simplifies to

Z ∼ Logistic(0,1);   x = [[η + Z ≥ 0]];   x̃ = σ_t(η + Z);   J^GR_η = (dL(x)/dx) E_{Z|x}[dx̃/dη],    (9)

where σ_t(u) = σ(u/t) is the temperature-scaled logistic sigmoid. Although there is no closed form, we can compute, with a careful limit-integral exchange, the series expansion around t = 0.

Proposition 1. J^GR_η = (L'(x)/p(x)) p_Z(η) [1/2 + (2x − 1) c_1 log(2) t] + O(t²), where p_Z is the logistic density, p_Z(η) = σ(η)σ(−η), and c_1 = 2p − 1.

Corollary 1. In the limit t → 0 the GR estimator becomes the DARN(1/2) estimator.

This establishes an interesting theoretical link but discovers no new method for practical applications, as the DARN estimator is already known. Using the same expansion, we can study the asymptotic bias and variance of GR around t = 0.

Corollary 2. The mean and variance of the GR estimator (9) in the asymptote t → 0 are:

E[J^GR_η] = p(1 − p) [ (1/2)(L'(1) + L'(0)) + (L'(1) − L'(0)) c̃_1 t ] + O(t²),
V[J^GR_η] = (p(1 − p))³ [ (1/4)(a − b)² + (1/2)(a² − b²) c̃_1 t ] + O(t²),

where p = σ(η), a = L'(1)/p, b = L'(0)/(1 − p), and c̃_1 = (2p − 1) log(2).

This allows us to make some predictions, in particular in the case of a linear objective L.
In this case the bias is O(t²) and the squared bias is O(t⁴). Therefore the MSE is determined by the variance alone up to O(t⁴). For a linear objective, the first-order dependence of the variance on t is negative. Therefore the temperature corresponding to the minimum MSE is non-zero.
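The cold limit of Corollary 1 can be verified numerically (our sketch; the grid, truncation and tolerances are our choices): for small t, the conditional expectation E_{Z|x}[dx̃/dη] approaches the DARN(1/2) value (1/2) p_Z(η)/p(x).

```python
import numpy as np

def gr_binary_inner(eta, x, t, N=400_001):
    # E_{Z|x}[ d sigma_t(eta + Z)/d eta ] by trapezoidal integration, where
    # Z ~ Logistic(0,1), x = [[eta + Z >= 0]] and sigma_t(u) = sigma(u/t).
    sig = lambda u: 0.5 * (1.0 + np.tanh(0.5 * u))   # stable logistic sigmoid
    # Integrate over the conditional support, starting exactly at z = -eta.
    z = np.linspace(-eta, -eta + 40.0, N) if x == 1 else np.linspace(-eta - 40.0, -eta, N)
    dz = z[1] - z[0]
    s = sig((eta + z) / t)
    f = (s * (1.0 - s) / t) * sig(z) * (1.0 - sig(z))  # d sigma_t/d eta * p_Z(z)
    integral = np.sum((f[1:] + f[:-1]) * 0.5) * dz
    p_x = sig(eta) if x == 1 else sig(-eta)
    return integral / p_x

eta = 0.4
p = 0.5 * (1.0 + np.tanh(0.5 * eta))
pz_eta = p * (1.0 - p)
for x in (0, 1):
    darn_half = 0.5 * pz_eta / (p if x == 1 else 1.0 - p)  # ZGR = DARN(1/2) value
    assert abs(gr_binary_inner(eta, x, t=0.002) - darn_half) < 5e-3 * darn_half
```

The residual discrepancy at t = 0.002 is dominated by the O(t) term of Proposition 1, consistent with the series expansion.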

3.2. GENERAL CATEGORICAL CASE

In the general categorical case the analysis is more complicated (exchange of the limit and the multivariate integral over G|x) but gives a novel result:

Theorem 1 (ZGR). The Gumbel-Rao estimator for the one-hot embedding φ in the limit of zero temperature is given by

J^ZGR_θ_i = (1/2)(J_φ_i − J_φ_x) p(x=i; η),   if i ≠ x;
J^ZGR_θ_i = −(1/2) Σ_{j≠x} (J_φ_j − J_φ_x) p(x=j; η),   if i = x.

Proposition 2. The ZGR estimator decomposes as J^ZGR = (1/2)(J^ST + J^DARN(φ̄(η))), i.e., with the choice φ̄ = φ̄(η) in DARN.

This is consistent with the expression for the binary case given by Corollary 1, by noting that in this case (ST + DARN(φ̄(η)))/2 is exactly DARN(1/2). As we have defined the ST and DARN estimators for a general embedding, the decomposition of Proposition 2 is a valid expression of ZGR for any embedding (a change of the embedding is just a linear transform). This form shows a surprising connection. Neither the ST nor the DARN estimator performs particularly well in categorical VAEs on its own (Gu et al., 2016; Paulus et al., 2021). However ZGR, being effectively their average, appears superior. It has the following property, truly extending the design principle of the binary DARN(1/2) to the categorical case:

Theorem 2. ZGR is unbiased for quadratic loss functions L with arbitrary coefficients.

Both ST and DARN(φ̄(η)) are unbiased for linear functions, and the theorem shows that their biases for quadratic functions are exactly opposite and cancel in the average. Because the bias of GS-ST, GR and GR-MC is the same and we have shown that ZGR is the limit of GR at t → 0, the following is straightforward.

Corollary 3. GS-ST and GR are asymptotically unbiased for quadratic functions as t → 0.

These results extend to multiple independent discrete variables as follows.

Corollary 4. Let x_1, ..., x_n be independent categorical variables and let L(x_1, ..., x_n) be such that for all i and all configurations x the restriction L(x_i) is a quadratic function. Then ZGR is unbiased.
The unbiasedness for quadratic functions gives some intuition about the applicability limits of ZGR. Namely, if the loss function is reasonably smooth, so that it can be approximated well by a quadratic function, we expect the gradient estimates to be accurate. Compared to ST, which is unbiased for multilinear functions only, we hypothesize that ZGR can capture interactions more accurately.
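Theorem 1 amounts to a few lines of code. The following NumPy sketch (our illustration; the function name `zgr_grad` is ours) implements the ZGR formula for p = softmax(θ) and verifies Theorem 2 by computing the exhaustive expectation over x for a random quadratic loss:

```python
import numpy as np

def zgr_grad(theta, grad_L, x):
    # ZGR estimate of d/d theta E[L(phi(x))] for the one-hot embedding,
    # p = softmax(theta); equals (ST + DARN(mean))/2 by Proposition 2.
    p = np.exp(theta - theta.max())
    p /= p.sum()
    g = grad_L(x)                      # J_phi at the one-hot sample phi(x)
    J = 0.5 * (g - g[x]) * p           # entries i != x of Theorem 1
    J[x] = -J.sum()                    # entry i = x: minus the sum over j != x
    return J

# Check Theorem 2: exact unbiasedness for an arbitrary quadratic loss
# L(phi) = phi^T A phi + b^T phi, by exhaustive expectation over x.
rng = np.random.default_rng(0)
K = 4
A, b = rng.standard_normal((K, K)), rng.standard_normal(K)
grad_L = lambda x: (A + A.T)[:, x] + b           # gradient of L at phi = e_x
theta = rng.standard_normal(K)
p = np.exp(theta - theta.max()); p /= p.sum()
Lx = np.array([A[x, x] + b[x] for x in range(K)])  # L at the one-hot points
exact = p * (Lx - p @ Lx)                        # d E[L]/d theta, exactly
expect = sum(p[x] * zgr_grad(theta, grad_L, x) for x in range(K))
assert np.allclose(expect, exact)
```

Note that before the `J[x] = -J.sum()` line the x-th entry is zero, so the sum runs effectively over j ≠ x, matching the second case of Theorem 1.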

4. EXPERIMENTS

We compare ZGR with Gumbel-Softmax (GS), Straight-Through Gumbel-Softmax (GS-ST) (Jang et al., 2017), Gumbel-Rao with MC samples (GR-MC) (Paulus et al., 2021) and the ST estimator (3). We also compare to REINFORCE with the leave-one-out baseline (Kool et al., 2019) using M ≥ 2 inner samples, denoted RF(M), which is a strong baseline amongst unbiased estimators. In some tests we include ARSM (Yin et al., 2019), which requires more computation than RF(4) but performs worse. See Appendix B.2 for implementation details.

4.1. DISCRETE VARIATIONAL AUTOENCODERS

We follow established benchmarks for evaluating gradient estimators in discrete VAEs. We use MNIST data with a fixed binarization (Yin et al., 2019) and Omniglot data with dynamic binarization (Burda et al., 2016; Dong et al., 2021). We use the same encoder-decoder networks as Yin et al. (2019); Dong et al. (2021), up to the following difference: we embed categorical variables with 2^b states as {−1, 1}^b vectors. This allows varying the number of categorical variables and categories while keeping the decoder size the same. Please see Appendix B for full details.
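The {−1, 1}^b embedding of a 2^b-way categorical state can be written, for example, as follows (a sketch; the helper name is ours):

```python
import numpy as np

def bit_embed(k, b):
    # Map a categorical state k in {0, ..., 2^b - 1} to a {-1, +1}^b vector
    # by reading off its binary digits.
    bits = (k >> np.arange(b)) & 1
    return 2 * bits - 1

# 16 states as 4-dimensional +-1 vectors; all embeddings are distinct.
codes = np.stack([bit_embed(k, 4) for k in range(16)])
assert codes.shape == (16, 4)
assert len({tuple(c) for c in codes}) == 16
assert set(codes.ravel()) == {-1, 1}
```

With this embedding, splitting the same latent bit budget into, e.g., 64 binary variables or 16 four-way variables feeds the decoder an input of identical size.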

4.1.1. ZERO TEMPERATURE LIMIT AND ESTIMATION ACCURACY IN VAE

First, we measure the gradient estimation accuracy at a particular point of VAE training, comparing the GS family of estimators at different temperatures, as in (Paulus et al., 2021, Fig. 2b). The referenced plot shows a steady decrease of the MSE as the temperature decreases down to 0.1, and we were expecting ZGR to achieve the lowest MSE. As the evaluation point we take a VAE model after 100 epochs of training with RF(4). We then measure the bias and variance of all gradient estimators as follows. The loss function is the average ELBO of a fixed random mini-batch of size 200. Let X ∈ R^d be the reference unbiased estimator and Y ∈ R^d the tested estimator. We want to measure the squared bias averaged over the parameters (= dimensions of the gradient), which can be written as b² = (1/d) ‖E[X] − E[Y]‖². We obtain n_1 = 10⁴ independent samples X_i from RF(4) and n_2 = 10⁴ independent samples Y_i from the tested estimator and compute an unbiased estimate of b²:

b̂² = (1/d) ‖μ̂_1 − μ̂_2‖² − V̂_1/n_1 − V̂_2/n_2,

where μ̂_1 is the sample mean of X, V̂_1 is the average (over dimensions) sample variance of X, and μ̂_2, V̂_2 are likewise for Y. The variance of the evaluated estimator is just V̂_2. The results are shown in Fig. 1. All GS-ST and GR estimators share the same bias, in accordance with the theory, but differ in variance. The GS estimator is asymptotically unbiased, but its variance grows as O(1/t). We observe that the variance is several orders of magnitude larger than the squared bias. Respectively, the mean squared error, which is the sum of the variance and the squared bias, is dominated by the variance alone for all methods. This sharply contradicts the MSE analysis in (Paulus et al., 2021, Fig. 2b), which we deem incorrect. Note, however, that the effect of the bias-variance trade-off in learning is not straightforward: common optimization methods use momentum as an effective way of variance reduction, while the bias can be correlated across iterations and may potentially accumulate.
Therefore the MSE is not necessarily indicative of performance, and ZGR still fulfills our expectations of the zero-limit estimator: it has the limiting bias, which is the lowest in the GR/GS-ST family, and a limiting variance which is moderate. In particular, its variance is better than that of GR for temperatures below about 0.1 and is comparable to the variance of RF(4). More tests at different stages of training are provided in Appendix B.4.
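The squared-bias estimate can be computed as follows (our sketch; the function name is ours, and the synthetic check uses Gaussian surrogates with a known per-dimension mean gap of 0.5, i.e. a true average squared bias of 0.25):

```python
import numpy as np

def squared_bias(X, Y):
    # Unbiased estimate of (1/d) * || E[X] - E[Y] ||^2 from samples:
    # X: (n1, d) draws of a reference unbiased estimator,
    # Y: (n2, d) draws of the tested estimator.
    n1, d = X.shape
    n2, _ = Y.shape
    mu1, mu2 = X.mean(0), Y.mean(0)
    v1 = X.var(0, ddof=1).mean()   # average per-dimension sample variance
    v2 = Y.var(0, ddof=1).mean()
    # Subtracting v/n removes the sampling noise of the two means.
    return ((mu1 - mu2) ** 2).sum() / d - v1 / n1 - v2 / n2

rng = np.random.default_rng(0)
d, n = 3, 200_000
X = rng.normal(0.0, 1.0, (n, d))
Y = rng.normal(0.5, 1.0, (n, d))
assert abs(squared_bias(X, Y) - 0.25) < 0.01
```

Without the variance correction, the naive estimate (1/d)‖μ̂_1 − μ̂_2‖² would overestimate the squared bias by V̂_1/n_1 + V̂_2/n_2 in expectation.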

4.1.2. TRAINING PERFORMANCE IN VAE

Next we compare the training performance of several methods in the same setting as prior work. In particular, we use the same Adam optimizer, batch size, learning rate and training duration as Dong et al. (2021); Dimitriev & Zhou (2021b). Full details are given in Appendix B.3. Table 2 presents results for the two datasets and different splittings of the latent bits into discrete variables (from binary to 64-way categorical). We observe the following: 1) ST performs the worst, and we blame its high bias (cf. Fig. 1); 2) ZGR performs no worse than the GR-MC variants; in Fig. 2 we additionally verify that at no other temperature can GR-MC achieve significantly better results (the plot shows the mean over 5 random initializations with confidence intervals ±(max − min)/2 of the 5 runs); 3) ZGR outperforms RF(2) and RF(4), significantly so with more categories. We also measure the (indicative) computation time in Fig. 2 and observe that ZGR is faster than both GR-MC and RF(2), consistent with the theoretical expectations in Table 1. Finally, according to the published results in a similar setup, the recent unbiased methods (Dong et al., 2021; Dimitriev & Zhou, 2021b) appear to improve only marginally over RF with an equal number of samples, i.e. the difference is much smaller than that between ZGR and RF(2). We leave a direct comparison to future work.

4.2. QUANTIZED NEURAL NETWORKS

The mainstream progress in training quantized and binary neural networks, following Hubara et al. (2017), has so far been achieved using empirical variants of the ST estimator (with different clamping rules, etc.) applied to deterministically quantized models, where there is no gradient to be estimated in the first place. A sound training approach is to consider a stochastic relaxation, replacing all discrete weights and activations by discrete random variables, leading in the binary case to stochastic binary networks (Peters & Welling, 2018; Roth et al., 2019; Shekhovtsov & Yanush, 2021). We consider the parameter-efficient stochastic relaxation for quantization of Louizos et al. (2019). In this model the distribution of a quantized weight or activation x is defined by a single real-valued input η via x = ⌊η + z⌉_K, where ⌊·⌉_K rounds to the nearest integer in K and z is an injected noise, such as logistic noise. Therefore x is a discrete integer variable with a distribution determined by η. In Appendix B.5 we give a comprehensive evaluation of the bias and variance of estimators for a single such quantization unit. In a deep network, the pre-activation input η depends on the weights of the current layer as well as on the preceding activations (both stochastically quantized), causing a hierarchical dependence. We train a convolutional network, closely following Louizos et al. (2019), and test on MNIST and Fashion-MNIST. We do not quantize the input (it has 8-bit resolution in the dataset); the first and last weight matrices are quantized to 4 bits. All inner-layer weights and activations are quantized to 2 bits or below. We form two real-valued baselines by applying ReLU or Clamp (x → min(max(x, 0), K − 1)) as activations instead of quantization. We expect that with enough bits, quantized training should achieve at least the performance of the Clamp variant. We evaluate training with logistic injected noise (as in Louizos et al.
(2019)) and triangular noise with density p(z) = max(0, 1 − |z|). For the GS-ST variants we include high temperatures (0.5, 1, 2), as recommended by Louizos et al. (2019). See Appendix B.5 for details of the experimental setup. The results are presented in Table 3. We see that the best results are obtained with the estimators having the least variance, prominently ST and GS-ST(t=2). This suggests that the bias is less detrimental in this application. In Fig. 3 we measure the bias and variance along the training trajectory of ZGR, following the same methodology as in Section 4.1.1, with 10⁴ samples from the reference RF(4) estimator and 10³ samples from the candidate estimators. The two methods with the lowest variance are exactly ST and GS-ST(t=2), while the bias was hard to measure accurately enough to draw any conclusions. More generally, the ranking of the results in Table 3 is quite similar to the ranking of the variances in Fig. 3. In particular, the variance of RF(4) is several orders of magnitude larger than that of the biased estimators, and its test accuracy is completely out of the competition. Regarding the performance of ZGR we observe the following: 1) it outperforms GR-MC with temperature 0.1, more clearly so on Fashion-MNIST (Table B.2), while being cheaper and simpler; 2) it is close in performance to the best results obtained with more biased estimators.
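For reference, the distribution induced by a single quantization unit with logistic noise can be computed in closed form from the noise cdf (our sketch under the round-then-clip reading of the unit, with boundary levels absorbing the tails; the function name is ours):

```python
import numpy as np

def quant_unit_probs(eta, K):
    # Distribution of x = clip(round(eta + z), 0, K-1) with logistic noise z.
    # Interior levels: p(k) = F(k + 1/2 - eta) - F(k - 1/2 - eta), F = sigmoid;
    # the boundary levels 0 and K-1 absorb the tails of the noise.
    F = lambda u: 1.0 / (1.0 + np.exp(-u))
    edges = np.arange(K - 1) + 0.5          # decision thresholds between levels
    cdf = np.concatenate(([0.0], F(edges - eta), [1.0]))
    return np.diff(cdf)

eta, K = 1.3, 4
p = quant_unit_probs(eta, K)
assert np.isclose(p.sum(), 1.0)

# Monte-Carlo check against direct sampling of the unit.
rng = np.random.default_rng(0)
z = rng.logistic(size=500_000)
x = np.clip(np.rint(eta + z), 0, K - 1).astype(int)
assert np.allclose(np.bincount(x, minlength=K) / len(x), p, atol=0.005)
```

These closed-form probabilities are what gradient estimators such as ZGR or RF differentiate through when the unit appears inside a network.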

5. DISCUSSION / CONCLUSION

On the theoretical side, we showed that the GR estimator has a zero-temperature limit, computed this limit, studied its properties and connected it to existing estimators. Although we derived ZGR from the Gumbel-Softmax family, we do not consider it a proper member of this family: the straight-through heuristic in GS-ST disposes of relaxed samples on the forward pass, and the zero-temperature limit disposes of them also on the backward pass, leaving essentially nothing from the Gumbel-Softmax relaxation design. On the other side, we showed that ZGR is unbiased for quadratic functions, generalizing the key property of DARN(1/2) to the categorical case. We believe that such a rationale can be put forward for obtaining improved biased estimators. On the practical side, ZGR is extremely simple, versatile and computationally inexpensive. In VAEs it can replace the GR-MC family completely, reducing the computational burden and hyperparameter tuning. It outperforms state-of-the-art unbiased estimators of comparable computation complexity and outperforms ST by a large margin. In quantized training it performs close to ST and can fairly replace low-temperature GR-MC variants, while unbiased estimators are completely out of the competition. Thus, across the two rather different applications, ZGR is the only estimator which is both computationally cheap and well-performing. While in VAEs unbiased estimators perform well (and we do not need to worry about the bias) and in quantization the simple ST performs well, the above results suggest that there should be cases where ZGR performs significantly better than both RF and ST. This may be the case, for example, when considering a different learning formulation, such as Bayesian learning of quantized weights or learning with the multi-sample objective of Raiko et al. (2015).

A PROOFS

A.1 BINARY CASE

Proposition 1. J^GR_η = (L'(x)/p(x)) p_Z(η) [1/2 + (2x − 1) c_1 log(2) t] + O(t²), where p_Z is the logistic density, p_Z(η) = σ(η)σ(−η), and c_1 = 2p − 1.

Proof. The conditional density p(z|x) is

p(z|x) = p_Z(z) [[η + z ≥ 0]] / p(x=1),  if x = 1;
p(z|x) = p_Z(z) [[η + z < 0]] / p(x=0),  if x = 0.

Let us denote p(x=1) simply by p. The GR estimator expands as

J^GR_η = (L'(1)/p) ∫_{−η}^{∞} σ'_t(η + z) p_Z(z) dz,  if x = 1;
J^GR_η = (L'(0)/(1 − p)) ∫_{−∞}^{−η} σ'_t(η + z) p_Z(z) dz,  if x = 0.

Using the change of variables v = σ_t(η + z), with the inverse z = t logit(v) − η, we have dv = σ'_t(η + z) dz and can write the estimator as

J^GR_η = (L'(x)/p(x)) [ x ∫_{1/2}^{1} p_Z(η − t logit(v)) dv + (1 − x) ∫_{0}^{1/2} p_Z(η − t logit(v)) dv ].

Note that p_Z(η − t logit(v)) is bounded above by the constant sup_z p_Z(z) = 1/4, and a constant is integrable on [0, 1]. By the dominated convergence theorem we can take the limit t → 0 under the integral; in particular, lim_{t→0} p_Z(η − t logit(v)) = p_Z(η) under the integral, so each integral tends to (1/2) p_Z(η). To get a more detailed view, we take the Taylor series expansion of p_Z(η − t logit(v)) and substitute it under the integral. With the help of Mathematica (Wolfram Research, 2021) we obtain

J^GR_η = (L'(x)/p(x)) p_Z(η) [1/2 + (2x − 1) c_1 log(2) t + c_2 (π²/6) t²] + O(t³),

where c_1 = tanh(η/2) = 2p − 1 and c_2 = (1/2)(1 − 3/(cosh(η) + 1)). □

Corollary 1. In the limit t → 0 the GR estimator becomes the DARN(1/2) estimator.

Proof. From the series expansion, the limit t → 0 is

J^GR_η = (1/2) (L'(x)/p(x)) p_Z(η) = (1/2) (L'(x)/p(x)) p(1 − p),

where p = σ(η). It remains to show that this matches DARN as defined in (4). Note that dp(x=1; η)/dη = σ(η)(1 − σ(η)) = p(1 − p) and dp(x=0; η)/dη = −p(1 − p), hence d log p(x; η)/dη = (2x − 1) p(1 − p)/p(x). By expanding the cases x = 1 and x = 0 we verify that (x − x̄) d log p(x; η)/dη = (1/2) p(1 − p)/p(x), where x̄ = 1/2, matching the limit above. □

Corollary 2.
The mean and variance of the GR estimator (9) in the asymptote t → 0 are:

E[J^GR_η] = p(1 − p) [ (1/2)(L'(1) + L'(0)) + (L'(1) − L'(0)) c̃_1 t ] + O(t²),
V[J^GR_η] = (p(1 − p))³ [ (1/4)(a − b)² + (1/2)(a² − b²) c̃_1 t ] + O(t²),

where p = σ(η), a = L'(1)/p, b = L'(0)/(1 − p), and c̃_1 = (2p − 1) log(2).

Proof. The mean of the estimator is computed from the series expansion up to the first order as

p (L'(1)/p) p_Z(η) [1/2 + c̃_1 t + O(t²)] + (1 − p) (L'(0)/(1 − p)) p_Z(η) [1/2 − c̃_1 t + O(t²)]
= p_Z(η) [ (1/2)(L'(1) + L'(0)) + c̃_1 (L'(1) − L'(0)) t ] + O(t²).

Since the GR estimator J^GR_η(x) is a two-valued random variable taking the values J^GR_η(1) and J^GR_η(0) with probabilities p and 1 − p, respectively, we can compute its variance simply as

(J^GR_η(1) − J^GR_η(0))² p(1 − p).

Using that p_Z(η) = p(1 − p), the asymptotic expansion of the variance up to first order in t is

(p(1 − p))³ [ a(1/2 + c̃_1 t) − b(1/2 − c̃_1 t) ]² + O(t²)
= (p(1 − p))³ [ (1/2)(a − b) + (a + b) c̃_1 t ]² + O(t²),

where a = L'(1)/p, b = L'(0)/(1 − p) and c̃_1 = log(2) c_1. The first-order term is

(1/2) (p(1 − p))³ (a² − b²) c̃_1 t.

It can be positive or negative depending on the values of the derivatives and of p. Expanding a, b and c_1 = tanh(η/2) = 2p − 1, we obtain, up to positive constants,

p(1 − p) (L'(1)² (1 − p)² − L'(0)² p²) (2p − 1) t.

We see that for the corner points, where p approaches either 0 or 1, this linear term is negative. In particular, for a linear objective we have L'(1) = L'(0), and the linear term becomes −p(1 − p) L'(1)² (2p − 1)² t, which is non-positive for any p and is zero for p = 1/2. □

A.2 GENERAL CATEGORICAL CASE

This case is significantly more difficult, as we are dealing with multivariate integration in $K$ Gumbel variables. We will make use of the following statistical relationship.

Lemma A.1. Let $G_1, \dots, G_K$ be independent standard Gumbel random variables. Then $Z$ with components $Z_i = G_i - G_K$ for $i = 1, \dots, K-1$ has the multivariate logistic distribution (Malik & Abraham, 1973) with cdf
$$F_Z(z) = \frac{1}{1 + \sum_{i=1}^{K-1} e^{-z_i}}. \quad (29)$$

Proof. The cdf and density of the Gumbel(0,1) distribution are given respectively by
$$F_G(x) = e^{-e^{-x}}; \qquad p_G(x) = e^{-(x + e^{-x})}.$$
The conditional distribution of $Z_i$ given $G_K = y$ is $F_{Z_i|G_K}(z_i|y) = e^{-e^{-(z_i+y)}}$. The conditional joint distribution of $Z$ given $G_K$ is, respectively,
$$F_{Z|G_K}(z|y) = \prod_{i=1}^{K-1} e^{-e^{-(z_i+y)}} = \exp\Big(-\sum_{i=1}^{K-1} e^{-(z_i+y)}\Big) = \exp\big(-e^{-y} S\big), \quad (31)$$
where $S = \sum_{i=1}^{K-1} e^{-z_i}$. The cdf of $Z$ is obtained by computing the expectation of $F_{Z|G_K}$ over $G_K$:
$$F_Z(z) = \int_{-\infty}^{\infty} e^{-e^{-y} S}\, e^{-(y + e^{-y})}\, dy = \int_{-\infty}^{\infty} e^{-y - e^{-y}(1+S)}\, dy = \frac{1}{1+S}\int_{-\infty}^{\infty} (1+S)\, e^{-y - (1+S)e^{-y}}\, dy = \frac{1}{1+S}\int_{-\infty}^{\infty} e^{-v - e^{-v}}\, dv = \frac{1}{1+S},$$
where $v = y - \log(S+1)$ and the last equality follows by recognizing the Gumbel density under the integral.

Theorem 1 (ZGR). The Gumbel-Rao estimator for one-hot embedding $\phi$ in the limit of zero temperature is given by
$$J^{\mathrm{ZGR}}_{\theta_i} = \begin{cases} \tfrac12\big(J_{\phi_i} - J_{\phi_x}\big)\, p(x{=}i;\eta) & \text{if } i \ne x;\\ -\tfrac12 \sum_{j \ne x} \big(J_{\phi_j} - J_{\phi_x}\big)\, p(x{=}j;\eta) & \text{if } i = x.\end{cases} \quad (12)$$

Proof. We can take $J_\phi = \frac{dL(\phi(x))}{d\phi}$ out of the conditional expectation, since it does not depend on $G$ for fixed $x$. Furthermore, $X$ is distributed as $p(x)$ by the sampling procedure, and therefore $P_G(X{=}x) = p(x)$. We can thus rewrite the conditional expectation (8) as
$$\int J_\phi\, [[X(u) = x]]\, \frac{d\phi(u)}{d\theta}\, dF_G(u) \Big/ P_G(X{=}x) = \frac{1}{p(x)}\, J_\phi \int_{\arg\max(\theta+u)=x} \frac{d\phi(u)}{d\theta}\, dF_G(u), \quad (33)$$
where $F_G$ is the joint cdf of $G$. The condition $\arg\max_k(\theta_k + G_k) = x$ can be expressed as
$$\theta_j + G_j - (\theta_x + G_x) \le 0, \quad \forall j \ne x.$$
Let us define $\beta_j = \theta_j - \theta_x$ and $Z_j = G_j - G_x$ for $j = 1, \dots, K$. Note that $\beta_x = Z_x = 0$ by this definition.
Then the constraint can be written as $Z \le -\beta$. The integrand $\phi$ expresses in the variables $\beta, Z$ as $\phi = \mathrm{softmax}_t(\theta + G) = \mathrm{softmax}_t(\beta + Z)$. Let us denote $Z_{\neg x} = (Z_j \mid j \ne x)$. The joint distribution of $Z_{\neg x}$ is the $(K-1)$-variate multivariate logistic distribution (Malik & Abraham, 1973), as detailed in Lemma A.1, with cdf
$$F_{Z_{\neg x}}(z_{\neg x}) = \frac{1}{1 + \sum_{i \ne x} e^{-z_i}}.$$
To simplify notation, we let $Z_x$ have the discrete law with mass 1 at the single point $z_x = 0 = -\beta_x$ and extend $F_{Z_{\neg x}}$ to the full joint $F_Z$ accordingly. We can then rewrite the integral as
$$\Big[\int_{z \le -\beta} \frac{\partial}{\partial\beta}\, \mathrm{softmax}_t(\beta + z)\, dF_Z(z)\Big]\, \frac{\partial\beta}{\partial\theta}. \quad (38)$$
The Jacobian $\frac{\partial}{\partial\beta}\mathrm{softmax}_t(\beta + z)$ is a $K \times K$ matrix with indices $(k, j)$, where the column $j = x$ is zero by definition. Let us consider one component of the above integral for $j \ne x$:
$$I_{k,j} = \int_{z \le -\beta} \frac{\partial}{\partial\beta_j}\, \mathrm{softmax}_t(\beta + z)_k\, dF_Z(z). \quad (39)$$
We want to evaluate its limit for $t \to 0$. We cannot push the limit under the integral in this form; we need to transform it first. To shorten the notation, let us denote $a_i = e^{(z_i + \beta_i)/t}$. We change the variable $z_j$ by the mapping
$$T: z_j \mapsto v_j = (2 + S_j)\, \frac{a_j}{1 + a_j + S_j},$$
where $S_j = \sum_{i \ne x, j} a_i$. This mapping is monotone increasing and one-to-one from $(-\infty, -\beta_j)$ to $(0, 1)$; therefore the constraint $z_j \le -\beta_j$ trivializes. Let $A_k = \mathrm{softmax}_t(\beta + z)_k$. We can rewrite the integrand $\frac{\partial}{\partial\beta_j} A_k\, dz_j$ as follows:
$$\frac{dA_k}{d\beta_j}\, dz_j = \frac{dA_k}{dz_j}\, dz_j \;\; (A_k \text{ depends on } \beta_j \text{ in the same way as on } z_j) = \frac{dA_k}{dz_j}\frac{dz_j}{dv_j}\, dv_j = \frac{dA_k}{dv_j}\, dv_j.$$
We will thus need to evaluate $C_j := \frac{dA_k}{dv_j} = \frac{dA_k}{da_j}\big(\frac{dv_j}{da_j}\big)^{-1}$. For $j = k$ we simply have
$$A_k = \frac{a_j}{1 + a_j + S_j} = \frac{v_j}{2 + S_j}; \qquad C_j = \frac{1}{2 + S_j}.$$
For $j \ne k$ we have
$$\frac{dA_k}{da_j} = \frac{d}{da_j}\, \frac{a_k}{1 + a_j + S_j} = \frac{-a_k}{(1 + a_j + S_j)^2}, \quad (43a)$$
$$\frac{dv_j}{da_j} = (2 + S_j)\Big[\frac{1}{1 + a_j + S_j} - \frac{a_j}{(1 + a_j + S_j)^2}\Big] = \frac{(2 + S_j)(1 + S_j)}{(1 + a_j + S_j)^2}, \quad (43b)$$
$$C_j = \frac{-a_k}{(2 + S_j)(1 + S_j)}.$$
The integral $I_{k,j}$ with the change of variable $z_j \to v_j$ expresses as
$$I_{k,j} = \int_{z_i \le -\beta_i\ \forall i \ne j} \int_0^1 C_j\, f_t(v_j | z_{\neg j})\, dv_j\, dF_{Z_{\neg j}}(z_{\neg j}),$$
where $z_{\neg j} = (z_i \mid i \ne j)$ and $f_t(v_j | z_{\neg j}) = p_{Z_j|Z_{\neg j}}\big(T^{-1}(v_j) \,\big|\, z_{\neg j}\big)$. The dependence of $f_t$ on $t$ is through $T$, while $C_j$ depends on $t$ and $z_{\neg j}$. Note that $f_t$ is a squashed density and is itself not a density. Next we show dominated convergence of $h_t(v_j, z_{\neg j}) = C_j f_t(v_j | z_{\neg j})$ as $t \to 0$: if $h_t(v_j, z_{\neg j})$ converges point-wise and is bounded above by an integrable function, then the limit $t \to 0$ can be taken under the integral. We show a constant bound on $h_t(v_j, z_{\neg j})$ as follows. Note that $|C_j| \le 1$. We then have
$$|h_t(v_j, z_{\neg j})| \le \sup_{v \in (0,1),\, z_{\neg j}} f_t(v | z_{\neg j}) = \sup_z p_{Z_j|Z_{\neg j}}(z_j | z_{\neg j}),$$
which is the supremum of the conditional density of the standard multivariate logistic distribution and equals some constant $c$ independent of $t$. The integral of the constant function $c$ over $(0,1) \times \mathbb{R}^{K-1}$ with respect to the measure $dv_j\, dF_{Z_{\neg j}}(z_{\neg j})$ is $c$. The point-wise limit is as follows. For $z$ satisfying the constraints $z_i + \beta_i \le 0$ strictly for $i \ne j, x$, we have $\lim_{t\to 0} a_i = \lim_{t\to 0} e^{(\beta_i + z_i)/t} = 0$ and $\lim_{t\to 0} S_j = 0$. Therefore
$$\lim_{t\to 0} C_j = \begin{cases} \tfrac12 & \text{if } j = k;\\ 0 & \text{if } j \ne k \wedge k \ne x;\\ -\tfrac12 & \text{if } j \ne k \wedge k = x.\end{cases}$$
The inverse of the mapping $T$ is given by the relations
$$a_j = \frac{v_j (1 + S_j)}{2 + S_j - v_j}; \qquad z_j = -\beta_j + t \log(a_j).$$
It is seen that the limit of $\log(a_j)$ is finite, and therefore $\lim_{t\to 0} T^{-1}(v_j) = -\beta_j$ and
$$\lim_{t\to 0} f_t(v_j | z_{\neg j}) = \lim_{t\to 0} p_{Z_j|Z_{\neg j}}\big(T^{-1}(v_j) | z_{\neg j}\big) = p_{Z_j|Z_{\neg j}}\big(\lim_{t\to 0} T^{-1}(v_j) \,\big|\, z_{\neg j}\big) = p_{Z_j|Z_{\neg j}}(-\beta_j | z_{\neg j}).$$
By the dominated convergence theorem, we can now claim $\lim_{t\to 0} I_{k,j} = 0$ if $j \ne k$ and $k \ne x$.
Otherwise, if $j = k$ or $k = x$,
$$\lim_{t\to 0} I_{k,j} = \int_{z_i \le -\beta_i\ \forall i \ne j} \pm\tfrac12\, p_{Z_j|Z_{\neg j}}(-\beta_j | z_{\neg j})\, dF_{Z_{\neg j}}(z_{\neg j}) = \pm\tfrac12\, \frac{\partial}{\partial z_j} F_Z(z)\Big|_{z = -\beta} \quad (53)$$
$$= \mp\tfrac12\, \frac{\partial}{\partial\beta_j}\, \frac{1}{1 + \sum_{i \ne x} e^{\beta_i}} = \mp\tfrac12\, \frac{\partial}{\partial\beta_j}\, \frac{e^{\beta_x}}{\sum_i e^{\beta_i}} = \mp\tfrac12\, \frac{\partial}{\partial\beta_j}\, \mathrm{softmax}(\beta)_x \quad (54)$$
$$= \pm\tfrac12\, p(x)\, p(j),$$
where the upper sign corresponds to the case $j = k$ and the lower to $k = x$. Let us denote $\hat{I} = \lim_{t\to 0} I$. Multiplying it with the incoming derivative $J_\phi$ on the left, we obtain
$$(J_\phi \hat{I})_j = \tfrac12\big(J_{\phi_j} - J_{\phi_x}\big)\, p(x)\, p(j). \quad (56)$$
Finally, multiplying (56) with the Jacobian $\frac{\partial\beta}{\partial\theta}$ on the right per (38) and with the factor $\frac{1}{p(x)}$ per (33), we obtain
$$J^{\mathrm{ZGR}}_{\theta_i} = \begin{cases} \tfrac12\big(J_{\phi_i} - J_{\phi_x}\big)\, p(i) & \text{if } i \ne x;\\ -\tfrac12 \sum_{j \ne x}\big(J_{\phi_j} - J_{\phi_x}\big)\, p(j) & \text{if } i = x.\end{cases} \quad (57)$$

Proposition 2. The ZGR estimator decomposes as $J^{\mathrm{ZGR}} = \tfrac12\big(J^{\mathrm{ST}} + J^{\mathrm{DARN}(\bar\phi(\eta))}\big)$, i.e., with the choice $\bar\phi = \bar\phi(\eta)$ in DARN.

Proof. Let $p$ denote the vector of probabilities $\big(p(x{=}k;\eta) \mid k = 0, \dots, K-1\big)$. Recall that we have derived ZGR under the assumption of one-hot embedding $\phi$, inherited from GS. In this case $J_\phi\, \phi(i) = J_{\phi_i}$ and $\bar\phi_k = \sum_i \phi(i)_k\, p_i = p_k$. Note that ZGR (57) defines the gradient in the parametrization $\theta$ used in Gumbel-Rao and initially in Gumbel-Softmax, while the ST and DARN estimators were given by us with respect to $\eta$. We need to bring these two to a common basis. We choose to reconstruct $J^{\mathrm{ZGR}}_p$, because both $J^{\mathrm{ST}}_p$ and $J^{\mathrm{DARN}}_p$ are particularly simple:
$$J^{\mathrm{ST}}_{p_i} = J_\phi\, \phi(i) = J_{\phi_i}, \qquad J^{\mathrm{DARN}}_{p_i} = J_\phi\big(\phi(i) - \bar\phi\big)\, [[x{=}i]]/p(x) = \big(J_{\phi_i} - J_\phi\, p\big)\, [[x{=}i]]/p(x).$$
Note that, because $p$ lies in the simplex, gradients in $p$ are defined up to an additive constant applied to all coordinates; any such constant is irrelevant and will not affect the gradient in $\eta$. In order to reconstruct $J^{\mathrm{ZGR}}_p$ we represent $J^{\mathrm{ZGR}}_\theta = J^{\mathrm{ZGR}}_p P$, where $P$ is the Jacobian of softmax, given by $P = \mathrm{diag}(p) - pp^{\mathsf{T}} = \mathrm{diag}(p)(I - \mathbf{1}p^{\mathsf{T}})$. We first note that $J^{\mathrm{ZGR}}_\theta$ satisfies $\sum_i J^{\mathrm{ZGR}}_{\theta_i} = 0$ (as any gradient in $\theta$ should, but not necessarily a stochastic estimator) and therefore
$$J^{\mathrm{ZGR}}_\theta = J^{\mathrm{ZGR}}_\theta\big(I - \mathbf{1}p^{\mathsf{T}}\big) = \big[J^{\mathrm{ZGR}}_\theta\, \mathrm{diag}(p)^{-1}\big]\, P. \quad (61)$$
We obtain
$$J^{\mathrm{ZGR}}_{p_i} = \begin{cases} \tfrac12\big(J_{\phi_i} - J_{\phi_x}\big) & \text{if } i \ne x;\\ -\tfrac12 \sum_{j \ne x}\big(J_{\phi_j} - J_{\phi_x}\big)\, p(j)/p(x) & \text{if } i = x,\end{cases}$$
up to a constant, i.e., up to adding the same number $c$ to all components. We further add the constant $\tfrac12 J_{\phi_x}$ and obtain
$$J^{\mathrm{ZGR}}_{p_i} = \begin{cases} \tfrac12 J_{\phi_i} & \text{if } i \ne x;\\ \tfrac12 J_{\phi_x} - \tfrac12 \sum_{j \ne x}\big(J_{\phi_j} - J_{\phi_x}\big)\, p(j)/p(x) & \text{if } i = x.\end{cases}$$
Subtracting $\tfrac12 J^{\mathrm{ST}}_p$, the remainder is $\tfrac12 J^{\mathrm{RE}}_p$ with
$$J^{\mathrm{RE}}_{p_i} = [[i{=}x]]\, \frac{1}{p(x)} \sum_{j \ne x}\big(J_{\phi_x} - J_{\phi_j}\big)\, p(j).$$
Simplifying $\sum_{j \ne x}\big(J_{\phi_x} - J_{\phi_j}\big)p(j) = \sum_j\big(J_{\phi_x} - J_{\phi_j}\big)p(j) = J_{\phi_x} - \sum_j J_{\phi_j}\, p(j)$, we obtain
$$J^{\mathrm{RE}}_{p_i} = [[i{=}x]]\, \frac{1}{p(x)}\Big(J_{\phi_x} - \sum_j J_{\phi_j}\, p(j)\Big),$$
and we see that $J^{\mathrm{RE}}_p = J^{\mathrm{DARN}}_p$ with $\bar\phi = p = \bar\phi(\eta)$.

Theorem 2. ZGR is unbiased for quadratic loss functions $L$ with arbitrary coefficients.

Proof. Since the ZGR estimator is linear in $L$ (the estimate for a linear combination of two loss functions is the linear combination of the estimates), it is sufficient to prove the claim for the one-hot embedding $\phi$ and some elementary functions forming a basis of all quadratic functions. With the one-hot embedding we have $\bar\phi(\eta)_i = p(x{=}i;\eta) = p_i$. Let us start with a linear monomial $L(x) = \phi(x)_i$. The expected loss is $\mathbb{E}[L(x)] = p(x{=}i;\eta)$ and the true gradient is $J_\eta = \frac{d}{d\eta}\, p(x{=}i;\eta)$. Substituting $J_{\phi_k} = [[k{=}i]]$ in ST, we have
$$J^{\mathrm{ST}}_\eta = \frac{dL(\phi(x))}{d\phi}\, \frac{d\bar\phi(\eta)}{d\eta} = \frac{d\bar\phi(\eta)_i}{d\eta} = J_\eta.$$
This may come as a surprise, but ST for a single categorical variable is exact for a linear loss (zero bias and zero variance). The expectation of $J^{\mathrm{DARN}}$ simplifies as follows for any $\bar\phi$ and a linear loss function, for which $J_\phi$ is constant in $x$:
$$\mathbb{E}[J^{\mathrm{DARN}}_\eta] = \sum_x p(x)\, J_\phi\big(\phi(x) - \bar\phi\big)\, \frac{1}{p(x)}\frac{dp(x;\eta)}{d\eta} = J_\phi \sum_x \big(\phi(x) - \bar\phi\big)\frac{dp(x;\eta)}{d\eta} = J_\phi \sum_x \phi(x)\frac{dp(x;\eta)}{d\eta} - J_\phi \bar\phi\, \frac{d}{d\eta}\sum_x p(x;\eta) = J_\phi\, \frac{d\bar\phi(\eta)}{d\eta}. \quad (69)$$
Substituting $J_{\phi_k} = [[k{=}i]]$ and $\bar\phi(\eta) = p$, we obtain $\mathbb{E}[J^{\mathrm{DARN}}_\eta] = \frac{dp(x{=}i;\eta)}{d\eta} = J_\eta$, reconfirming that DARN is unbiased for linear functions of categorical variables, as expected. It follows that $\tfrac12\big(J^{\mathrm{ST}}_\eta + J^{\mathrm{DARN}}_\eta\big)$ is also unbiased. Let us now consider the elementary quadratic function $L(\phi(x)) = \phi(x)_i^2 - \phi(x)_i$.
For all discrete assignments it is zero; therefore the true gradient of its expected value is zero. We have
$$J_{\phi_k}(x) = \begin{cases} 2\phi(x)_i - 1 & k = i;\\ 0 & k \ne i.\end{cases}$$
Therefore $J^{\mathrm{ST}}_{p_k} = 0$ for $k \ne i$ and $\mathbb{E}[J^{\mathrm{ST}}_{p_i}] = \mathbb{E}[2\phi(x)_i - 1] = 2p_i - 1$. For $J^{\mathrm{DARN}}_{p_i}$ we have
$$J^{\mathrm{DARN}}_{p_i} = \big(J_{\phi_i}(x) - J_\phi(x)\, p\big)\, \frac{1}{p(x)}\, [[x{=}i]] = (1 - p_i)\big(2\phi(x)_i - 1\big)\, \frac{1}{p(x)}\, [[x{=}i]]. \quad (73)$$
Its expectation is $(1 - p_i)\big(2\phi_i(i) - 1\big) = 1 - p_i$. For $k \ne i$ we have
$$J^{\mathrm{DARN}}_{p_k} = \big(J_{\phi_k}(x) - J_\phi(x)\, p\big)\, \frac{1}{p(x)}\, [[x{=}k]] = -J_\phi(x)\, p\, \frac{1}{p(x)}\, [[x{=}k]] = -p_i\big(2\phi(x)_i - 1\big)\, \frac{1}{p(x)}\, [[x{=}k]]. \quad (76)$$
Its expectation is $-p_i\big(2\phi_i(k) - 1\big) = p_i$.

Next we consider a bilinear monomial in $\phi$: $L(\phi(x)) = \phi(x)_1 \phi(x)_2$, where we have taken indices 1 and 2 without loss of generality. It is zero for all discrete assignments, and therefore the gradient of its expectation is zero. We have
$$J_{\phi_1} = \phi_2(x) = [[x{=}2]], \qquad J_{\phi_2} = \phi_1(x) = [[x{=}1]]. \quad (78)$$
For ST we have $J^{\mathrm{ST}}_p = J_\phi$ and
$$\mathbb{E}[J^{\mathrm{ST}}_{p_1}] = p_2, \quad \mathbb{E}[J^{\mathrm{ST}}_{p_2}] = p_1, \quad \mathbb{E}[J^{\mathrm{ST}}_{p_k}] = 0,\ k \ne 1, 2.$$
For the DARN part we have
$$J^{\mathrm{DARN}}_{p_1} = [[x{=}1]]\, \tfrac{1}{p_1}\big(J_{\phi_1} - J_{\phi_1} p_1 - J_{\phi_2} p_2\big),$$
$$J^{\mathrm{DARN}}_{p_2} = [[x{=}2]]\, \tfrac{1}{p_2}\big(J_{\phi_2} - J_{\phi_1} p_1 - J_{\phi_2} p_2\big),$$
$$J^{\mathrm{DARN}}_{p_k} = [[x{=}k]]\, \tfrac{1}{p_k}\big(-J_{\phi_1} p_1 - J_{\phi_2} p_2\big),\ k \ne 1, 2.$$
In the expectation, substituting $J_\phi$:
$$\mathbb{E}[J^{\mathrm{DARN}}_{p_1}] = \phi_2(1) - \phi_2(1) p_1 - \phi_1(1) p_2 = -p_2,$$
$$\mathbb{E}[J^{\mathrm{DARN}}_{p_2}] = \phi_1(2) - \phi_2(2) p_1 - \phi_1(2) p_2 = -p_1,$$
$$\mathbb{E}[J^{\mathrm{DARN}}_{p_k}] = -\phi_2(k) p_1 - \phi_1(k) p_2 = 0,\ k \ne 1, 2.$$
This exactly cancels with ST. The elementary functions we have considered form a basis of the space of all quadratic functions; by the linearity argument, $J^{\mathrm{ZGR}} = \tfrac12\big(J^{\mathrm{ST}} + J^{\mathrm{DARN}}\big)$ is unbiased for all quadratic functions.
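Both the decomposition of Proposition 2 and the unbiasedness of Theorem 2 can be confirmed numerically by exact enumeration over x; the following sketch (our own, with illustrative helper names) evaluates (57) directly and compares it against ½(J_ST + J_DARN) pushed through the softmax Jacobian, then checks the expectation against the true gradient for an arbitrary quadratic loss:

```python
import math

def softmax(theta):
    m = max(theta)
    e = [math.exp(v - m) for v in theta]
    s = sum(e)
    return [v / s for v in e]

def zgr_theta(Jphi, p, x):
    # ZGR gradient in theta, eq. (57): the zero-temperature limit of GR
    K = len(p)
    g = [0.5 * (Jphi[i] - Jphi[x]) * p[i] for i in range(K)]
    g[x] = -0.5 * sum((Jphi[j] - Jphi[x]) * p[j] for j in range(K) if j != x)
    return g

def half_st_darn_theta(Jphi, p, x):
    # (J_ST + J_DARN)/2 in p, pushed to theta via the softmax Jacobian
    # P = diag(p) - p p^T, i.e. (J_p P)_i = p_i (J_p_i - <J_p, p>)
    K = len(p)
    Jbar = sum(Jphi[j] * p[j] for j in range(K))
    Jp = [0.5 * Jphi[i] for i in range(K)]
    Jp[x] += 0.5 * (Jphi[x] - Jbar) / p[x]  # DARN part is supported on i = x
    dot = sum(Jp[j] * p[j] for j in range(K))
    return [p[i] * (Jp[i] - dot) for i in range(K)]

# an arbitrary quadratic loss L(phi) = phi^T A phi + b^T phi on one-hot phi
K = 3
A = [[0.3, -1.2, 0.5], [0.7, 0.1, -0.4], [2.0, 0.0, -1.5]]
b = [0.2, -0.9, 1.1]
theta = [0.4, -0.2, 0.9]
p = softmax(theta)
L = [A[x][x] + b[x] for x in range(K)]  # loss value at phi = e_x
J = [[A[k][x] + A[x][k] + b[k] for k in range(K)] for x in range(K)]  # dL/dphi at e_x

# Proposition 2: (57) equals (ST + DARN)/2 for every outcome x
dec_gap = max(abs(u - v) for x in range(K)
              for u, v in zip(zgr_theta(J[x], p, x), half_st_darn_theta(J[x], p, x)))

# Theorem 2: the expectation of the estimator equals the true gradient
EL = sum(p[x] * L[x] for x in range(K))
true_grad = [p[i] * (L[i] - EL) for i in range(K)]  # d/d theta_i of sum_x p_x L_x
est_grad = [sum(p[x] * zgr_theta(J[x], p, x)[i] for x in range(K)) for i in range(K)]
bias_gap = max(abs(u - v) for u, v in zip(true_grad, est_grad))
print(dec_gap, bias_gap)
```

Both gaps vanish up to floating-point rounding, for every conditioning outcome x and every coordinate.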

B DETAILS OF EXPERIMENTS

Here we give detailed specifications of our experiments. The implementation of all experiments will be made publicly available upon publication. During the review period, we will be happy to answer questions and share the code with reviewers confidentially through the OpenReview platform.

B.1 DATASET

In quantized training we use the MNIST [1] and FashionMNIST [2] datasets. Each contains 60000 training and 10000 test images. We used 54000 images for training and 6000 for validation. In VAE training, following the prior work, we use a decoder with a Bernoulli output layer, which requires binary datasets. MNIST-B is a binarized MNIST with a fixed threshold of 0.5, the same as in Yin et al. (2019). The original Omniglot dataset is of size 105 × 105 and contains binary images. However, the established benchmarks use its down-sampled version (to size 28 × 28), which is then dynamically sampled: binary pixel values are generated with probabilities proportional to the original pixel values (Burda et al., 2016; Dong et al., 2021); we denote this version Omniglot-28-D. We used the down-scaled dataset published by Burda et al. (2016) [3], the same as in the public implementation of Dong et al. (2021). It contains about 24000 training images, which were split into training (90%) and validation (10%) parts; currently we do not use the validation part.

B.2 ESTIMATORS

Our ZGR implementation for a general categorical variable is shown in Figure B.1:

```python
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.distributions import Categorical

def ZGR(p: Tensor) -> Tensor:
    """Returns a categorical sample from p [*, C] (over axis=-1) as a
    one-hot vector, with ZGR gradient.
    """
    index = Categorical(probs=p).sample()
    x = F.one_hot(index, num_classes=p.shape[-1]).to(p)
    logpx = p.log().gather(-1, index.unsqueeze(-1))  # log p(x)
    dx_ST = p
    dx_RE = (x - p.detach()) * logpx
    dx = (dx_ST + dx_RE) / 2
    return x + (dx - dx.detach())  # value of x with backprop through dx
```

Gumbel-Softmax (GS) and Straight-Through Gumbel-Softmax (GS-ST) (Jang et al., 2017) are shipped with PyTorch [4]. For Gumbel-Rao with MC samples (GR) we adopted the public reimplementation by nshepperd [5]. We implemented RF(M) according to Kool et al. (2019, Eq. 8). The part of the computation relevant to the encoder is propagated forward and backward only once. In the decoder we perform as many backward passes as forward passes, as this reduces the variance of the gradient in the decoder parameters.
In quantization our implementation performs a backward pass for each forward pass and is not well optimized. For ARSM (Yin et al., 2019) we made our own reimplementation, cross-checked against the authors' TensorFlow implementation [6]. As with RF(M), we also performed a backward pass for each forward pass.

B.3 VAE

Model. In our model each categorical variable is encoded as a vector of ±1 values corresponding to the bit representation of x, similar to Paulus et al. (2021). There is a fixed total number of hidden bits (192), which are split into several categorical variables, for example 192 1-bit variables or 32 6-bit variables; this way the number of weights in the network stays constant. The network architecture is adopted from Yin et al. (2019): Linear(784,512) → LReLU → Linear(512,256) → LReLU → Linear(256, D*K), where in the last layer we have D K-way categorical units and LReLU has a leak coefficient of 0.2 (the same as in Dong et al. (2021), the default in TensorFlow). The output of the encoder defines the logits of the encoder Bernoulli model q(z_i=1|x), where x is the input binary image and z is the latent discrete state. The decoder has exactly the reverse Linear-LReLU architecture and outputs the logits of the conditionally independent Bernoulli generative model p(x_i=1|z). We optimize the standard evidence lower bound (ELBO) (Kingma & Welling, 2013) with a uniform, non-learned prior distribution p(z). We do not perform any special data-based initializations, such as subtracting the data mean in the encoder as in Dong et al. (2021).

Optimization. In the forward pass all methods produce a sample, from which a stochastic estimate of the gradient with respect to the decoder parameters is readily computed by backpropagation through the decoder. We compute the KL term in the ELBO analytically for a mini-batch and use its exact gradient. The estimation problem (1) occurs for the gradient of the data term with respect to the encoder parameters, where the estimators through discrete variables are applied. All methods, including GS, which optimizes the ELBO with relaxed samples, are evaluated by the correct ELBO with discrete samples. In the VAE experiments we measure the gradient accuracy and the training performance and do not make use of validation or test sets.
First, this is reasonable when comparing the quality of gradient estimators, regardless of generalization. Second, the prior work of Dong et al. (2021) has verified that an improvement in the training ELBO translates into an improvement of the test ELBO and IWAE bounds. Following Dong et al. (2021), we train with Adam at learning rate 10^-4 and batch size 50. Furthermore, we tried to match the training time to that of Dong et al. (2021): for MNIST we perform 500 epochs and for Omniglot-28-D 1000 epochs, roughly equivalent in both cases to their 500K iterations with batch size 50.
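For reference, the analytically computed KL term mentioned above has a particularly simple form when the prior is uniform: for each categorical latent, KL(q ‖ Uniform(K)) = log K − H(q), summed over the D independent variables. A minimal sketch (function names are ours, not the paper's code):

```python
import math

def kl_categorical_uniform(q, eps=1e-12):
    # KL( q || Uniform(K) ) = log K - H(q) = log K + sum_k q_k log q_k
    K = len(q)
    return math.log(K) + sum(qk * math.log(qk + eps) for qk in q if qk > 0)

def kl_term(qs):
    # total KL contribution to the ELBO: sum over independent latent variables
    return sum(kl_categorical_uniform(q) for q in qs)

print(kl_categorical_uniform([0.25, 0.25, 0.25, 0.25]))  # ~0: uniform posterior
print(kl_categorical_uniform([1.0, 0.0, 0.0, 0.0]))      # ~log 4: deterministic posterior
```

Because this term is a closed-form function of the encoder probabilities, its gradient is exact and needs no stochastic estimator.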

B.4 ADDITIONAL VAE EXPERIMENTS

We include some extended results.

Bias-variance analysis

We conducted the bias-variance analysis for the VAE at different training stages. Namely, we trained the model using RF(4) for 1, 10, 100, and 200 epochs and at each stage evaluated the bias and variance of all gradient estimators. The model used 16 categorical variables with 16 categories (64 total latent bits). The results are displayed in Fig. B.2. At the very beginning of training the picture looks substantially different, in that there is some bias reversal in GS-ST and the derived estimators. However, from epoch 10 on, the trends and the relative ordering of the methods stabilize, with only RF(4) slightly overtaking ZGR in variance. The different picture after 1 epoch suggests that it could be beneficial in practice to warm up the training with a few epochs of GS-ST at t = 1 or ST. We leave such tuning to future work.

B.5 QUANTIZED NEURAL NETWORKS

Experimental Setup. In this experiment we train a convolutional network 32C5-MP2-64C5-MP2-512FC-10FC, closely replicating the model evaluated by Louizos et al. (2019) for MNIST. Each activation quantization is preceded by batch normalization (Ioffe & Szegedy, 2015). All gradient estimators work with the same network, parametrization and initialization. In the case of logistic noise, the noise standard deviation is learnable and initialized to 1/3. All methods are trained with the Adam optimizer for 200 epochs. For every method we select the best validation performance via a grid search over the learning rate in {10^-3, 3.3·10^-4, 10^-4}. We used a step-wise learning rate schedule, decreasing the learning rate by a factor of 10 at epochs 100 and 150. The whole procedure is repeated for 3 different initialization seeds and we report the mean test error over seeds and ±(max − min)/2 over seeds. For validation and testing, we evaluate the network in the 'deterministic' mode, turning off all injected noises; this corresponds to the simple deterministic quantized model that would be deployed.

Single Unit Quantization. We include the following toy experiment, which well illustrates the properties of different estimators. We evaluate the bias and variance of all estimators on a simple function of a single quantized variable. Let η be a real-valued parameter and let p(x; η) be given by the stochastic quantization model with K = 4 states and a particular noise type. Given a test function L(x), we can compute the true gradient of E[L(x)]. For each estimator we draw 10^4 samples to compute its mean and standard deviation for each value of η. The results are presented in Fig. B.3. In this plot we show several combinations of loss functions and noises. The test functions are: linear L(x) = x; quadratic L(x) = ½(x − c)²; and sigmoid L(x) = σ(2(x − c)), where c = (K − 1)/2 is chosen for centering. The noises shown refer to the logistic noise with std = 1/3, as used at initialization by Louizos et al.
(2019), and the triangular noise with density p(z) = max(0, 1 − |z|). The bias of the GS family quickly decreases with the temperature. The ZGR estimator achieves the same expected value as GS-ST in the limit of small temperature, illustrated by GS-ST(t=0.1), and a variance comparable to that of GR(t=0.1, M=100). We also verify that ZGR has zero bias for quadratic objectives, as we have shown theoretically.

Additional Comparisons. In Table 3 we left out the GS method after preliminary testing. These results (without confidence bounds) are available in Table B.1. We also performed a full comparison on the FMNIST dataset in Table B.2, which qualitatively agrees with the results for MNIST in Table 3 and confirms that ZGR improves over GR-MC(t=0.1, M=10). The variance of ARSM is substantially higher. The plain ST estimator has a yet smaller variance but a larger bias, which may accumulate during training.
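The bias/variance measurement itself is straightforward to reproduce in miniature: draw many independent gradient estimates and compare their empirical mean and spread against the true gradient. The sketch below is our own toy (a single binary unit with a common form of the ST estimator, not the paper's K = 4 quantizer) and only illustrates the bookkeeping:

```python
import math, random

random.seed(1)

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

def bias_variance(estimator, true_grad, n=100_000):
    # empirical squared bias and variance of a scalar gradient estimator
    samples = [estimator() for _ in range(n)]
    mean = sum(samples) / n
    var = sum((s - mean) ** 2 for s in samples) / (n - 1)
    return (mean - true_grad) ** 2, var

eta = 0.4
p = sigmoid(eta)
# toy objective L(x) = x^2 on x in {0, 1}: E[L] = p, so dE[L]/d eta = p(1-p)
true_grad = p * (1.0 - p)

def st_estimator():
    # straight-through: sample x, backprop as if x were sigma(eta)
    x = 1.0 if random.random() < p else 0.0
    return 2.0 * x * p * (1.0 - p)  # L'(x) * d sigma(eta) / d eta

sq_bias, var = bias_variance(st_estimator, true_grad)
print(sq_bias, var)
```

In this toy, ST is visibly biased: its mean is 2p·p(1−p) rather than p(1−p), so the squared bias is nonzero, which is the kind of effect the figure panels display per estimator and per value of η.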



[1] http://yann.lecun.com/exdb/mnist/
[2] https://github.com/zalandoresearch/fashion-mnist
[3] https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat
[4] PyTorch function torch.functional.gumbel_softmax
[5] https://github.com/nshepperd/gumbel-rao-pytorch
[6] https://github.com/ARM-gradient/ARSM



Figure 3: MNIST classification gradient estimation accuracy and computation cost. Left, middle: average (per parameter) squared bias (resp. variance) in the first layer of the network for different estimators along the training trajectory of ZGR (ternary weights/activations, triangular noise). Right: time of gradient estimation per batch (of size 128) in our implementation. *Our RF implementation is apparently not efficient.

By subtracting p_i from all coordinates of J^DARN_p, we obtain an equivalent form (having an identical derivative in η) in which E[J^DARN_{p_k}] = 0 for all k ≠ i and E[J^DARN_{p_i}] = 1 − 2p_i, which cancels with E[J^ST_{p_i}].

Figure B.1: ZGR implementation in Pytorch for a general categorical variable.

Figure B.2: Gradient estimation accuracy in VAE on MNIST-B. Average (per parameter) squared bias (left) and variance (right) of gradient estimators versus temperature for a model snapshot at a particular iteration of training with RF(4). VAE network with 16 categorical variables with 16 categories.

Figure B.3: Performance of estimators in stochastic quantization with selected combinations of test functions and injected noise for varied input η of the stochastic quantizer. The dashed blue line is the exact gradient. The red line is the mean of the estimator and the red shaded area shows ±1 std. GR with 100 inner samples is able to reduce the variance of GS-ST substantially. ZGR reduces the variance slightly further while keeping the bias equal to the theoretical bias of GR at zero temperature. The variance of ARSM is substantially higher. The plain ST estimator has a yet smaller variance but a larger bias, which may accumulate during training.

VAE training negative ELBO for MNIST binarized at threshold 0.5 and for the down-sampled and dynamically binarized Omniglot. Each value is the mean over 3 random initializations; confidence intervals are ±(max − min)/2 of the 3 runs. Bold results are the three best ones per configuration.

Gradient estimation accuracy in VAE on MNIST-B. Average (per parameter) squared bias (left) and variance (right) of gradient estimators versus temperature for a model snapshot after 100 epochs of training with RF(4). Confidence intervals are 95% empirical intervals of 100 bootstrap samples of the estimates (very narrow for variances).

MNIST classification test error [%] in deterministic mode (no injected noises at test time) for different bit-widths per weight and activation (T denotes ternary). Hyperparameters are selected on the validation set. Reference test errors: ReLU 0.69%, Clamp 0.64%.

REPRODUCIBILITY STATEMENT

Proofs of all formal claims are presented in Appendix A. Details of the experiments are described in Appendix B. The source code of our implementation will be made publicly available upon publication. During the review period, we will be happy to answer questions and share the code with reviewers confidentially through the OpenReview platform. 

