RAO-BLACKWELLIZING THE STRAIGHT-THROUGH GUMBEL-SOFTMAX GRADIENT ESTIMATOR

Abstract

Gradient estimation in models with discrete latent variables is a challenging problem, because the simplest unbiased estimators tend to have high variance. To counteract this, modern estimators either introduce bias, rely on multiple function evaluations, or use learned, input-dependent baselines. Thus, there is a need for estimators that require minimal tuning, are computationally cheap, and have low mean squared error. In this paper, we show that the variance of the straight-through variant of the popular Gumbel-Softmax estimator can be reduced through Rao-Blackwellization without increasing the number of function evaluations. This provably reduces the mean squared error. We empirically demonstrate that this leads to variance reduction, faster convergence, and generally improved performance in two unsupervised latent variable models.

1. INTRODUCTION

Models with discrete latent variables are common in machine learning. Discrete random variables provide an effective way to parameterize multi-modal distributions, and some domains naturally have latent discrete structure (e.g., parse trees in NLP). Thus, discrete latent variable models can be found across a diverse set of tasks, including conditional density estimation, generative text modelling (Yang et al., 2017), multi-agent reinforcement learning (Mordatch & Abbeel, 2017; Lowe et al., 2017), and conditional computation (Bengio et al., 2013; Davis & Arel, 2013). The majority of these models are trained to minimize an expected loss using gradient-based optimization, so the problem of gradient estimation for discrete latent variable models has received considerable attention in recent years.

Existing estimation techniques can be broadly categorized into two groups, based on whether they require one loss evaluation (Glynn, 1990; Williams, 1992; Bengio et al., 2013; Mnih & Gregor, 2014; Chung et al., 2017; Maddison et al., 2017; Jang et al., 2017; Grathwohl et al., 2018) or multiple loss evaluations (Gu et al., 2016; Mnih & Rezende, 2016; Tucker et al., 2017) per estimate. These estimators reduce variance by introducing bias or by increasing the computational cost, with the overall goal of reducing the total mean squared error. Because loss evaluations are costly in modern deep learning, single-evaluation estimators are particularly desirable. This family of estimators can be further categorized into those that relax the discrete randomness in the forward pass of the model (Maddison et al., 2017; Jang et al., 2017; Paulus et al., 2020) and those that leave the loss computation unmodified (Glynn, 1990; Williams, 1992; Bengio et al., 2013; Chung et al., 2017; Mnih & Gregor, 2014; Grathwohl et al., 2018).
Estimators that do not modify the loss computation are preferred, because they avoid the accumulation of errors in the forward pass and allow the model to exploit the sparsity of discrete computation. Thus, there is a particular need for single-evaluation estimators that do not modify the loss computation. In this paper we introduce such a method: a Rao-Blackwellization scheme for the straight-through variant of the Gumbel-Softmax estimator (Jang et al., 2017; Maddison et al., 2017), which comes at minimal cost and does not increase the number of function evaluations.

The straight-through Gumbel-Softmax estimator (ST-GS, Jang et al., 2017) is a lightweight state-of-the-art single-evaluation estimator based on the Gumbel-Max trick (see Maddison et al., 2014, and references therein). The ST-GS uses the argmax over Gumbel random variables to generate a discrete random outcome in the forward pass, and computes derivatives via backpropagation through a tempered softmax of the same Gumbel sample. Our Rao-Blackwellization scheme is based on the key insight that many configurations of Gumbels correspond to the same discrete random outcome, and that these can be marginalized over with Monte Carlo estimation. By design, there is no need to re-evaluate the loss, and the additional cost of our estimator is linear only in the number of Gumbels needed for a single forward pass. As we show, the Rao-Blackwell theorem implies that our estimator has lower mean squared error than the vanilla ST-GS.

We demonstrate the effectiveness of our estimator on unsupervised parsing on the ListOps dataset (Nangia & Bowman, 2018) and on a variational autoencoder loss (Kingma & Welling, 2013; Rezende et al., 2014). In practice, our estimator trains faster and achieves better test set performance. The magnitude of the improvement depends on several factors, but is particularly pronounced at small batch sizes and low temperatures.
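The forward/backward split of ST-GS described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the paper's implementation: the loss `f` and its gradient `df_dd` are hypothetical stand-ins for a model's autodiff machinery, and `theta` denotes the logits of the discrete distribution.

```python
import numpy as np

def softmax_tau(x, tau):
    """Tempered softmax, shifted by max(x) for numerical stability."""
    e = np.exp((x - x.max()) / tau)
    return e / e.sum()

def st_gumbel_softmax(theta, tau, f, df_dd, rng):
    """One-sample straight-through Gumbel-Softmax (ST-GS) gradient estimate.

    Forward: a hard one-hot sample D = one_hot(argmax(theta + G)).
    Backward: the loss gradient df/dD at the hard sample, chained through
    the Jacobian of softmax_tau evaluated at the *same* Gumbel sample G.
    """
    n = theta.shape[0]
    G = -np.log(-np.log(rng.uniform(size=n)))   # i.i.d. Gumbel(0, 1) noise
    D = np.eye(n)[np.argmax(theta + G)]         # discrete forward pass
    s = softmax_tau(theta + G, tau)
    J = (np.diag(s) - np.outer(s, s)) / tau     # Jacobian of the tempered softmax
    return f(D), df_dd(D) @ J                   # loss value, gradient estimate

rng = np.random.default_rng(0)
a = np.array([1.0, 2.0, 3.0])                   # coefficients of a toy linear loss
val, grad = st_gumbel_softmax(np.zeros(3), 0.5, lambda d: a @ d, lambda d: a, rng)
```

Because the rows of the softmax Jacobian sum to zero, every ST-GS gradient estimate sums to zero, mirroring the fact that the discrete distribution is invariant to adding a constant to the logits.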

2. BACKGROUND

For clarity, we consider the following simplified scenario. Let $D \sim p_\theta$ be a discrete random variable $D \in \{0,1\}^n$ in a one-hot encoding, $\sum_i D_i = 1$, with distribution given by $p_\theta(D) \propto \exp(D^T \theta)$ where $\theta \in \mathbb{R}^n$. Given a continuously differentiable $f : \mathbb{R}^{2n} \to \mathbb{R}$, we wish to minimize

$$\min_\theta \; \mathbb{E}[f(D, \theta)], \qquad (1)$$

where the expectation is taken over all of the randomness. In general $\theta$ may be computed with some neural network, so our aim is to derive estimators of the total derivative of the expectation with respect to $\theta$ for use in stochastic gradient descent. This framework covers most simple discrete latent variable models, including variational autoencoders (Kingma & Welling, 2013; Rezende et al., 2014).

The REINFORCE estimator (Glynn, 1990; Williams, 1992) is unbiased (under certain smoothness assumptions) and given by

$$\nabla_{\mathrm{REINF}} := f(D, \theta)\,\frac{\partial \log p_\theta(D)}{\partial \theta} + \frac{\partial f(D, \theta)}{\partial \theta}. \qquad (2)$$

Without careful use of control variates (Mnih & Gregor, 2014; Tucker et al., 2017; Grathwohl et al., 2018), the REINFORCE estimator tends to have prohibitively high variance. To simplify the exposition, we assume henceforth that $f(D, \theta) = f(D)$ does not depend on $\theta$, because the dependence of $f(D, \theta)$ on $\theta$ is accounted for by the second term of (2), which is shared by most estimators and generally has low variance.

One strategy for reducing the variance is to introduce bias through a relaxation (Jang et al., 2017; Maddison et al., 2017). Define the tempered softmax $\mathrm{softmax}_\tau : \mathbb{R}^n \to \mathbb{R}^n$ by $\mathrm{softmax}_\tau(x)_i = \exp(x_i/\tau) / \sum_{j=1}^n \exp(x_j/\tau)$. The relaxations are based on the observation that the sampling of $D$ can be reparameterized using Gumbel random variables and the zero-temperature limit of the tempered softmax under the coupling

$$D = \lim_{\tau \to 0} S_\tau; \qquad S_\tau = \mathrm{softmax}_\tau(\theta + G), \qquad (3)$$

where $G$ is a vector of i.i.d. $G_i \sim \mathrm{Gumbel}$ random variables.
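The coupling (3) is easy to verify numerically. The NumPy sketch below, with toy logits chosen purely for illustration, checks that $\mathrm{argmax}(\theta + G)$ samples from the distribution $\propto \exp(\theta)$ and that the tempered softmax concentrates on that argmax as $\tau \to 0$.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([1.0, 0.0, -1.0])   # toy logits (chosen for illustration)
N = 20000

# G_i = -log(-log(U_i)) is a standard Gumbel sample (inverse-CDF method).
G = -np.log(-np.log(rng.uniform(size=(N, 3))))
K = np.argmax(theta + G, axis=1)     # Gumbel-Max samples of D

freq = np.bincount(K, minlength=3) / N
target = np.exp(theta) / np.exp(theta).sum()   # p_theta(D_i = 1)
print(freq, target)                  # the two should roughly agree

# At a near-zero temperature, softmax_tau(theta + G) places almost all of
# its mass on the argmax coordinate of theta + G.
s = np.exp((theta + G[0] - (theta + G[0]).max()) / 1e-3)
s /= s.sum()
assert np.argmax(s) == K[0]
```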
At finite temperatures, $S_\tau$ is known as a Gumbel-Softmax (GS) (Jang et al., 2017) or Concrete (Maddison et al., 2017) random variable, and the relaxed loss $\mathbb{E}[f(S_\tau, \theta)]$ admits the following reparameterization gradient estimator for $\tau > 0$:¹

$$\nabla_{\mathrm{GS}} := \frac{\partial f(S_\tau)}{\partial S_\tau}\,\frac{d\,\mathrm{softmax}_\tau(\theta + G)}{d\theta}. \qquad (4)$$
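Since $\nabla_{\mathrm{GS}}$ is a pathwise estimator, for a fixed Gumbel draw it must agree with finite differences of $\theta \mapsto f(\mathrm{softmax}_\tau(\theta + G))$. A small NumPy check, using a hypothetical quadratic loss and toy values chosen for illustration:

```python
import numpy as np

def softmax_tau(x, tau):
    e = np.exp((x - x.max()) / tau)   # max-shift for numerical stability
    return e / e.sum()

def softmax_tau_jacobian(x, tau):
    # d softmax_tau(x)_i / d x_j = (1/tau) * s_i * (delta_ij - s_j)
    s = softmax_tau(x, tau)
    return (np.diag(s) - np.outer(s, s)) / tau

# Hypothetical quadratic loss f(s) = ||s - c||^2 with gradient 2 (s - c).
c = np.array([0.2, 0.5, 0.3])
f = lambda s: float(np.sum((s - c) ** 2))
df_ds = lambda s: 2.0 * (s - c)

rng = np.random.default_rng(1)
theta = np.array([0.5, -0.3, 0.1])
tau = 0.7
G = -np.log(-np.log(rng.uniform(size=3)))   # one shared Gumbel draw

s = softmax_tau(theta + G, tau)
grad_gs = df_ds(s) @ softmax_tau_jacobian(theta + G, tau)   # per-sample GS gradient

# For fixed G, grad_gs must match central finite differences of
# g(theta) = f(softmax_tau(theta + G)).
eps = 1e-6
fd = np.array([(f(softmax_tau(theta + G + eps * e, tau))
                - f(softmax_tau(theta + G - eps * e, tau))) / (2 * eps)
               for e in np.eye(3)])
```

In a real model this Jacobian-vector product would be computed by the framework's automatic differentiation rather than written out by hand.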



¹ For a function $f(x_1, x_2)$, $\partial f(z_1, z_2)/\partial x_1$ denotes the partial derivative (e.g., a gradient vector) of $f$ in its first argument, evaluated at $(z_1, z_2)$. For a function $g(\theta)$, $dg/d\theta$ denotes the total derivative of $g$ in $\theta$. For example, $d\,\mathrm{softmax}_\tau(\theta + G)/d\theta$ is the Jacobian of the tempered softmax evaluated at the random vector $\theta + G$.

