COLD RAO-BLACKWELLIZED STRAIGHT-THROUGH GUMBEL-SOFTMAX GRADIENT ESTIMATOR

Abstract

The problem of estimating the gradient of an expectation over discrete random variables arises in many applications: learning with discrete latent representations, training neural networks with quantized weights and activations, conditional computation blocks, etc. This work is motivated by the development of the Gumbel-Softmax family of estimators, which are based on approximating argmax with a temperature-parametrized softmax. The state of the art in this family, the Gumbel-Rao estimator, uses internal MC samples to reduce the variance and appears to improve with lower temperatures. We show that it possesses a zero-temperature limit with a surprisingly simple closed form, defining our new estimator called ZGR. It has favorable bias and variance properties, is easy to implement, computationally inexpensive, and is obviously free of the temperature hyperparameter. Furthermore, it decomposes as the average of the straight-through estimator and the DARN estimator, two basic estimators that do not perform well on their own. Unlike them, ZGR is shown to be unbiased for the class of quadratic functions of categorical variables. Experiments thoroughly validate the method.

1. INTRODUCTION

Discrete variables and discrete structures are important in machine learning. For example, the semantic hashing idea (Salakhutdinov & Hinton, 2009) is to learn a binary vector representation such that semantically similar instances (e.g. images or text documents) have similar bit representations. This allows for quick retrieval via a lookup table or via nearest neighbor search with respect to the Hamming distance of binary encodings. Recent work employs variational autoencoders (VAEs) with binary latent states (Shen et al., 2018; Dadaneh et al., 2020; Ñanculef et al., 2020). Another example is neural networks with discrete (binary or quantized) weights and activations. They allow for low-latency and energy-efficient inference, which is particularly important for edge devices. Recent results indicate that quantized networks can achieve competitive accuracy with better efficiency in various applications (Nie et al., 2022). VAEs and quantized networks are two diverse examples that motivate our development and the experimental benchmarks. Other potential applications include conditional computation (Bengio et al., 2013a; Yang et al., 2019; Bulat et al., 2021), reinforcement learning (Yin et al., 2019), learning task-specific tree structures for agglomerative neural networks (Choi et al., 2018), neural architecture search (Chang et al., 2019) and more. The learning problem in the presence of stochastic variables is usually formulated as minimization of the expected loss. Gradient-based optimization requires the gradient of the expectation with respect to the probabilities of the random variables (or the parameters of networks inferring those probabilities). Unbiased gradient estimators have been developed (Williams, 1992; Grathwohl et al., 2018; Tucker et al., 2017; Gu et al., 2016). These estimators work even for non-differentiable losses; however, their high variance is the main limitation.
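As an illustration of the unbiased estimators mentioned above, the classic REINFORCE (score-function) estimator can be written in a few lines. The sketch below (our own toy example, not code from this paper) verifies its unbiasedness for a small categorical distribution by exact enumeration instead of sampling:

```python
import numpy as np

def reinforce_grad(f, probs, x):
    """Single-sample REINFORCE estimate of d E[f(x)] / d p.

    g_k = f(x) * d log p(x) / d p_k = f(x) * 1[x == k] / p_k
    (gradient w.r.t. the probability vector; handling the simplex
    constraint, e.g. via logits, is a separate step).
    """
    g = np.zeros_like(probs)
    g[x] = f(x) / probs[x]
    return g

probs = np.array([0.2, 0.5, 0.3])
f = lambda x: float(x ** 2)          # any scalar loss of the discrete state

# Exact expectation of the estimator: sum over states of p(x) * g(x).
expected_g = sum(probs[x] * reinforce_grad(f, probs, x) for x in range(3))

# True gradient: E[f] = sum_k p_k f(k), so d E[f] / d p_k = f(k).
true_g = np.array([f(k) for k in range(3)])
assert np.allclose(expected_g, true_g)
```

In practice the expectation is of course replaced by one or a few Monte Carlo samples, and it is exactly the per-sample variance of `reinforce_grad` that motivates the variance-reduction methods discussed next.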
More recent advances (Yin et al., 2019; Kool et al., 2020; Dong et al., 2020; 2021; Dimitriev & Zhou, 2021b;a) reduce the variance by using several cleverly coupled samples. However, the hierarchical (or deep) case is not addressed satisfactorily. There is experimental evidence that the variance grows substantially with the depth of the network, leading to very poor performance in training (e.g., Shekhovtsov et al. 2020, Fig. C.6, C.7). Furthermore, existing extensions of coupled sampling methods to networks with L dependent layers (Dong et al., 2020; Yin et al., 2019) apply their base method in every layer, requiring several complete forward samples per layer. The computation complexity thus grows quadratically with the number of layers, as summarized in Table 1. A different family of methods, fitting the practical needs of deep models better, exploits continuation arguments. It includes ST variants (Bengio et al., 2013b; Shekhovtsov & Yanush, 2021; Pervez et al., 2020) and Gumbel-Softmax variants (to be discussed below). These methods assume the loss function to be differentiable and try to estimate the derivative with respect to the parameters of a discrete distribution from the derivative of the loss function. Such estimators can be easily incorporated into back-propagation by adjusting the forward and backward passes locally for every discrete variable. They are, in general, biased because the derivative of the loss need not be relevant for the discrete expectation. The rationale, though, is that it may be possible to obtain a low-variance estimate at the price of a small bias, e.g. for a sufficiently smooth loss function (Shekhovtsov & Yanush, 2021).
Gumbel-Softmax (Jang et al., 2017) and the concurrently developed Concrete relaxation (Maddison et al., 2017) enable differentiability through discrete variables by relaxing them to real-valued variables with a distribution approximating the original discrete distribution. The tightness of the relaxation is controlled by the temperature parameter t > 0. The bias can be reduced by decreasing the temperature, but the variance grows as O(1/t) (Shekhovtsov, 2021). The Gumbel-Softmax Straight-Through (GS-ST) heuristic (Jang et al., 2017) uses discrete samples on the forward pass, but relaxed ones on the backward pass, reducing side biases (see below). The Gumbel-Rao (GR) estimator (Paulus et al., 2021) is a recent improvement of GS-ST, which can substantially reduce its variance by a local expectation. However, the local expectation results in an intractable integral in multiple variables, which is approximated by sampling. The experiments of Paulus et al. (2021) suggest that this estimator performs better at lower temperatures, which require more MC samples. The computational cost therefore appears to be a major limitation. Inspired by the performance of GR at low temperatures, we analyze its behavior for temperatures close to and in the limit of absolute zero. Note that because the underlying relaxation becomes discrete in the cold limit, leading to an explosion of variance in both GS and GS-ST, these estimators do not have a zero-temperature limit. It is therefore not obvious that GR would have one. We prove that it does and denote this limit estimator ZGR. In the case of binary variables we give an asymptotic series expansion of GR around t = 0 and show that ZGR has a simple analytic expression, matching the already known DARN(1/2) estimator by Gregor et al. (2014) (also re-discovered as importance-reweighted ST by Pervez et al. 2020). In the general categorical case, we obtain the analytic expression of ZGR and show that it is a new estimator with the simple formula (ST + DARN)/2.
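For reference, the Gumbel-Softmax relaxation and the GS-ST forward pass discussed above can be sketched in a few lines of NumPy (function names are ours; autodiff frameworks additionally route the backward pass through the relaxed sample):

```python
import numpy as np

def gumbel_softmax_sample(logits, t, rng):
    """Relaxed sample: softmax((logits + Gumbel noise) / t)."""
    g = rng.gumbel(size=logits.shape)    # standard Gumbel(0, 1) noise
    z = (logits + g) / t
    z = z - z.max()                      # shift for numerical stability
    y = np.exp(z)
    return y / y.sum()

rng = np.random.default_rng(0)
logits = np.array([1.0, 0.0, -1.0])

y_soft = gumbel_softmax_sample(logits, t=0.5, rng=rng)   # relaxed (backward path)
x_hard = np.eye(3)[y_soft.argmax()]                      # GS-ST forward: 1-hot sample

assert np.isclose(y_soft.sum(), 1.0)
assert x_hard.sum() == 1
```

As t decreases, `y_soft` concentrates near a vertex of the simplex, illustrating why the relaxation becomes discrete in the cold limit while the gradient variance grows as O(1/t).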
We show that ZGR is unbiased for all quadratic functions of categorical variables and experimentally show that it achieves a useful bias-variance tradeoff. We also contribute a refined experimental comparison of the GS family with unbiased estimators.
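To illustrate the unbiasedness claim on a quadratic, the sketch below checks by exact enumeration that the binary form of ZGR recovers the exact gradient of E[x1*x2] for two independent Bernoulli variables. We use the reading ZGR = DARN(1/2), i.e. g_i = f'(x) / (2 p(x_i)) with p(x_i) the probability of the observed value; this toy check and all names in it are ours, not code from the paper:

```python
import numpy as np
from itertools import product

p = np.array([0.3, 0.7])                 # p_i = P(x_i = 1), independent bits
f = lambda x: x[0] * x[1]                # quadratic loss; E[f] = p[0] * p[1]
df = lambda x, i: x[1 - i]               # d f / d x_i at the sampled point

def zgr_binary(x, i):
    """Binary ZGR = DARN(1/2): f'(x) / (2 * P(x_i = observed value))."""
    p_obs = p[i] if x[i] == 1 else 1 - p[i]
    return df(x, i) / (2 * p_obs)

def expected_grad(i):
    """Exact expectation of the estimator over all 4 joint states."""
    total = 0.0
    for x in product([0, 1], repeat=2):
        px = np.prod([p[j] if x[j] else 1 - p[j] for j in range(2)])
        total += px * zgr_binary(np.array(x), i)
    return total

# True gradients: d(p1*p2)/dp1 = p2 and d(p1*p2)/dp2 = p1.
assert np.isclose(expected_grad(0), p[1])
assert np.isclose(expected_grad(1), p[0])
```

The same enumeration with a non-quadratic f (e.g. f(x) = x1*x2*x1 with more variables, or any cubic interaction) would expose a nonzero bias, which is where the bias-variance tradeoff mentioned above comes in.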

2. BACKGROUND

Let x be a categorical random variable taking values in K = {0, . . . , K − 1} with probabilities p(x; η) parametrized by η. Let φ(x) ∈ R^d be an embedding of the discrete state x in a vector space. Categorical variables are usually represented using the 1-hot embedding, in which case d = K. The value of x itself will still be used as an index. For binary and quantized variables we adopt the embedding φ(x) = x, in which case d = 1.
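A minimal illustration of the two embeddings just described (function names are ours):

```python
import numpy as np

K = 4

def one_hot(x, K):
    """Categorical embedding phi(x) in R^K: 1-hot vector, so d = K."""
    e = np.zeros(K)
    e[x] = 1.0
    return e

# Categorical: state x = 2 maps to a 1-hot vector of dimension d = K = 4.
assert one_hot(2, K).tolist() == [0.0, 0.0, 1.0, 0.0]

# Binary / quantized: the identity embedding phi(x) = x, so d = 1.
phi = lambda x: x
assert phi(1) == 1
```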



Table 1: Computation complexity of estimators in a hierarchical network with L dependency layers of K-way categorical variables. All methods require a backward pass.

Method                                 | Diff. loss | Forward samples | Complexity
ST, GS, GS-ST, ZGR                     | X          | 1               | O(K)
GR-MC(M)                               | X          | 1               | O(MK)
REINFORCE                              |            | 1               | O(K)
RF(M) (Kool et al., 2019)              |            | M ≥ 2           | O(K)
ARSM (Yin et al., 2019)                |            | O(K²L²)         | O(1)
DisARM-* (Dong et al., 2020; 2021)     |            | O(2L²)          | O(K)
CARMS(M) (Dimitriev & Zhou, 2021b;a)   |            | O(ML²)*         | O(M²K²)*

