SIMPLE: A GRADIENT ESTIMATOR FOR k-SUBSET SAMPLING

Abstract

k-subset sampling is ubiquitous in machine learning, enabling regularization and interpretability through sparsity. The challenge lies in rendering k-subset sampling amenable to end-to-end learning. This has typically involved relaxing the reparameterized samples to allow for backpropagation, at the risk of introducing high bias and high variance. In this work, we fall back to discrete k-subset sampling on the forward pass. This is coupled with using the gradient with respect to the exact marginals, computed efficiently, as a proxy for the true gradient. We show that our gradient estimator, SIMPLE, exhibits lower bias and variance compared to state-of-the-art estimators, including the straight-through Gumbel estimator when k = 1. Empirical results show improved performance on learning to explain and sparse linear regression. We provide an algorithm for computing the exact ELBO for the k-subset distribution, obtaining significantly lower loss than the state of the art.

1. INTRODUCTION

k-subset sampling, i.e., sampling a subset of size k out of n variables, is omnipresent in machine learning. It lies at the core of many fundamental problems that rely upon learning sparse feature representations of input data, including stochastic high-dimensional data visualization (van der Maaten, 2009), parametric k-nearest neighbors (Grover et al., 2018), learning to explain (Chen et al., 2018), discrete variational auto-encoders (Rolfe, 2017), and sparse regression, to name a few. All such tasks involve optimizing the expectation of an objective function with respect to a latent discrete distribution parameterized by a neural network, an expectation that is typically intractable. Score-function estimators offer a deceptively simple solution: rewrite the gradient of the expectation as an expectation of the gradient, which can subsequently be estimated from a finite number of samples, yielding an unbiased estimate of the gradient. Simple as they are, score-function estimators suffer from very high variance, which can interfere with training. This provided the impetus for other, low-variance gradient estimators, chief among them those based on the reparameterization trick, which allows for biased but low-variance gradient estimates. The reparameterization trick, however, does not apply directly to discrete distributions, prompting continuous relaxations, e.g., Gumbel-softmax (Jang et al., 2017; Maddison et al., 2017), that allow for reparameterized gradients w.r.t. the parameters of a categorical distribution. Reparameterizable subset sampling (Xie & Ermon, 2019) generalizes the Gumbel-softmax trick to k-subsets, rendering k-subset sampling amenable to backpropagation at the cost of introducing bias through the use of relaxed samples. In this paper, we set out with the goal of avoiding all such relaxations. Instead, we fall back to discrete sampling on the forward pass.
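As background, the k = 1 relaxation discussed above can be sketched in a few lines. This is a generic illustration of the Gumbel-softmax trick, not code from this paper; the function name and parameter choices are ours:

```python
import numpy as np

def gumbel_softmax(logits, tau=1.0, rng=None):
    """Relaxed one-hot sample from a categorical distribution (the k = 1 case).

    Perturbs the logits with Gumbel(0, 1) noise and applies a
    temperature-scaled softmax; as tau -> 0 the output approaches a
    discrete one-hot sample, but at any tau > 0 it remains continuous.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Gumbel(0, 1) noise via the inverse CDF applied to uniform samples
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    y = (logits + g) / tau
    y = y - y.max()  # subtract the max for numerical stability
    e = np.exp(y)
    return e / e.sum()

logits = np.array([1.0, 0.5, -0.5])
sample = gumbel_softmax(logits, tau=0.5)
# The relaxed sample lies in the simplex (entries sum to 1) but is not
# discrete; this relaxation is exactly what we avoid on the forward pass.
```
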
On the backward pass, we reparameterize the gradient of the loss function with respect to the samples as a function of the exact marginals of the k-subset distribution. Computing exact conditional marginals is, in general, intractable (Roth, 1996). We give an efficient algorithm for computing the k-subset probability, and show that the conditional marginals correspond to partial derivatives, and are therefore tractable for the k-subset distribution. We show that our proposed gradient estimator for the k-subset distribution, coined SIMPLE, is reminiscent of the straight-through (ST) Gumbel estimator when k = 1, with the gradients taken with respect to the unperturbed marginals. We empirically demonstrate that SIMPLE exhibits lower bias and variance than other known gradient estimators, including the ST Gumbel estimator in the case k = 1 (Figure 1; details in Section 5.1). We include an experiment on the task of learning to explain (L2X) using the BEERADVOCATE dataset (McAuley et al., 2012), where the goal is to select the subset of words that best explains the model's classification of a user's review. We also include an experiment on the task of stochastic sparse linear regression, where the goal is to learn the best sparse model, and show that we are able to recover the Kuramoto-Sivashinsky equation. Finally, we develop an efficient computation of the exact variational evidence lower bound (ELBO) for the k-subset distribution, which, when used in conjunction with SIMPLE, leads to state-of-the-art discrete sparse VAE learning.

Contributions. In summary, we propose replacing relaxed sampling on the forward pass with discrete sampling.
On the backward pass, we use the gradient with respect to the exact conditional marginals as a proxy for the true gradient, and give an algorithm for computing them efficiently. We empirically demonstrate that discrete samples on the forward pass, coupled with exact conditional marginals on the backward pass, lead to a new gradient estimator, SIMPLE, with lower bias and variance than other known gradient estimators. We also provide an efficient computation of the exact ELBO for the k-subset distribution, leading to state-of-the-art discrete sparse VAE learning.
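As a concrete sketch of the forward pass, an exact discrete sample from the k-subset-conditioned distribution can be drawn coordinate by coordinate, using a suffix dynamic program over the remaining budget of ones. This is a minimal illustration under the factorized Bernoulli distribution of Section 2, not the paper's implementation; the function name is ours:

```python
import numpy as np

def sample_k_subset(p, k, rng=None):
    """Exact sample from p(z | sum_i z_i = k), z_i ~ Bernoulli(p_i) independent.

    Samples z_1, ..., z_n left to right; at each step, the probability of
    z_i = 1 is re-weighted by the probability that the remaining
    coordinates can still spend the remaining budget r.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = len(p)
    # S[i, j] = P(z_i + ... + z_{n-1} = j), filled by a suffix DP.
    S = np.zeros((n + 1, k + 1))
    S[n, 0] = 1.0
    for i in range(n - 1, -1, -1):
        for j in range(k + 1):
            S[i, j] = (1 - p[i]) * S[i + 1, j]
            if j > 0:
                S[i, j] += p[i] * S[i + 1, j - 1]
    z = np.zeros(n, dtype=int)
    r = k  # remaining budget of ones
    for i in range(n):
        if r > 0:
            # P(z_i = 1 | budget r and the suffix constraint)
            prob_one = p[i] * S[i + 1, r - 1] / S[i, r]
            z[i] = int(rng.uniform() < prob_one)
            r -= z[i]
    return z
```

Every sample has Hamming weight exactly k by construction: when the number of remaining positions equals the remaining budget, `prob_one` evaluates to 1 and the coordinate is forced on.
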

2. PROBLEM STATEMENT AND MOTIVATION

We consider models described by the equations

$$\theta = h_v(x), \qquad z \sim p_\theta\Big(z \;\Big|\; \sum_i z_i = k\Big), \qquad \hat{y} = f_u(z, x), \tag{1}$$

where $x \in X$ and $\hat{y} \in Y$ denote feature inputs and target outputs, respectively, $h_v : X \to \Theta$ and $f_u : Z \times X \to Y$ are smooth, parameterized maps, and $\theta$ are logits inducing a distribution over the latent binary vector $z$. The induced distribution $p_\theta(z)$ is defined as

$$p_\theta(z) = \prod_{i=1}^{n} p_{\theta_i}(z_i), \quad \text{with } p_{\theta_i}(z_i = 1) = \mathrm{sigmoid}(\theta_i) \text{ and } p_{\theta_i}(z_i = 0) = 1 - \mathrm{sigmoid}(\theta_i). \tag{2}$$

The goal of our stochastic latent layer is not to simply sample from $p_\theta(z)$, which would yield samples with a Hamming weight between 0 and n (i.e., with an arbitrary number of ones). Instead, we are interested in sampling from the distribution restricted to samples with a Hamming weight of k, for any given k. That is, we are interested in sampling from the conditional distribution $p_\theta(z \mid \sum_i z_i = k)$. Conditioning the distribution $p_\theta(z)$ on this k-subset constraint introduces intricate dependencies between each of the $z_i$'s. The probability of sampling any given k-subset vector $z$, therefore, becomes

$$p_\theta\Big(z \;\Big|\; \sum_i z_i = k\Big) = \frac{p_\theta(z) \cdot \mathbb{1}\big\{\sum_i z_i = k\big\}}{p_\theta\big(\sum_i z_i = k\big)},$$

where $\mathbb{1}\{\cdot\}$ denotes the indicator function. In other words, the probability of sampling each k-subset is re-normalized by $p_\theta(\sum_i z_i = k)$, the probability of sampling exactly k items from the unconstrained distribution induced by the encoder $h_v$. The quantity

$$p_\theta\Big(\sum_i z_i = k\Big) = \sum_z p_\theta(z) \cdot \mathbb{1}\Big\{\sum_i z_i = k\Big\}$$

appears to be intractable. We show that not to be the case, providing a tractable algorithm for computing it. Given a set of samples D, we are concerned with learning the parameters $\omega = (v, u)$ of the architecture in (1) through minimizing the training error L, which is the expected loss

$$L(x, y; \omega) = \mathbb{E}_{z \sim p_\theta(z \mid \sum_i z_i = k)}\big[\ell(f_u(z, x), y)\big] \quad \text{with } \theta = h_v(x).$$
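To illustrate the tractability claim, $p_\theta(\sum_i z_i = k)$ can be computed by an O(nk) dynamic program over the count of ones, and the conditional marginals then follow from a leave-one-out identity. The sketch below is our own minimal illustration under the sigmoid parameterization of Eq. (2), not the paper's algorithm (which additionally recovers the marginals as partial derivatives of the probability computation); the function names are ours:

```python
import numpy as np

def k_subset_prob(p, k):
    """P(sum_i z_i = k) for independent z_i ~ Bernoulli(p_i).

    Standard O(nk) dynamic program: dp[j] holds the probability of
    having exactly j ones among the variables processed so far.
    """
    dp = np.zeros(k + 1)
    dp[0] = 1.0
    for pi in p:
        # RHS is evaluated on the old dp before assignment.
        dp[1:] = dp[1:] * (1 - pi) + dp[:-1] * pi
        dp[0] *= 1 - pi
    return dp[k]

def conditional_marginals(theta, k):
    """mu_i = p_theta(z_i = 1 | sum_j z_j = k), with p_i = sigmoid(theta_i).

    Leave-one-out identity: P(z_i = 1, sum = k) equals
    p_i * P(sum of the remaining variables = k - 1).
    """
    p = 1.0 / (1.0 + np.exp(-theta))
    total = k_subset_prob(p, k)
    return np.array([p[i] * k_subset_prob(np.delete(p, i), k - 1) / total
                     for i in range(len(p))])
```

A useful sanity check is that the marginals of the k-subset-conditioned distribution always sum to exactly k, since every admissible sample contains exactly k ones.
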



Figure 1: A comparison of the bias and variance of the gradient estimators (left) and the mean and standard deviation of the cosine distance between a single-sample gradient estimate and the exact gradient (right). We used the cosine distance, defined as 1 - cosine similarity, in place of the Euclidean distance, as we only care about the direction of the gradient, not its magnitude. The bias, variance, and error were estimated using a sample of size 10,000. The details of this experiment are provided in Section 5.1.

