SIMPLE: A GRADIENT ESTIMATOR FOR k-SUBSET SAMPLING

Abstract

k-subset sampling is ubiquitous in machine learning, enabling regularization and interpretability through sparsity. The challenge lies in rendering k-subset sampling amenable to end-to-end learning. This has typically involved relaxing the reparameterized samples to allow for backpropagation, with the risk of introducing high bias and high variance. In this work, we fall back to discrete k-subset sampling on the forward pass. This is coupled with using the gradient with respect to the exact marginals, computed efficiently, as a proxy for the true gradient. We show that our gradient estimator, SIMPLE, exhibits lower bias and variance compared to state-of-the-art estimators, including the straight-through Gumbel estimator when k = 1. Empirical results show improved performance on learning to explain and sparse linear regression. We provide an algorithm for computing the exact ELBO for the k-subset distribution, obtaining significantly lower loss compared to SOTA.

1. INTRODUCTION

k-subset sampling, i.e., sampling a subset of size k out of n variables, is omnipresent in machine learning. It lies at the core of many fundamental problems that rely on learning sparse feature representations of input data, including stochastic high-dimensional data visualization (van der Maaten, 2009), parametric k-nearest neighbors (Grover et al., 2018), learning to explain (Chen et al., 2018), discrete variational auto-encoders (Rolfe, 2017), and sparse regression, to name a few. All such tasks involve optimizing the expectation of an objective function with respect to a latent discrete distribution parameterized by a neural network, an expectation that is typically intractable to compute exactly.

Score-function estimators offer a deceptively simple solution: rewrite the gradient of the expectation as an expectation of the gradient, which can then be estimated from a finite number of samples, yielding an unbiased estimate of the true gradient. Simple as they are, score-function estimators suffer from very high variance, which can interfere with training. This provided the impetus for other, low-variance gradient estimators, chief among them those based on the reparameterization trick, which allows for biased but low-variance gradient estimates. The reparameterization trick, however, does not apply directly to discrete distributions, prompting continuous relaxations, e.g., Gumbel-softmax (Jang et al., 2017; Maddison et al., 2017), that allow for reparameterized gradients w.r.t. the parameters of a categorical distribution. Reparameterizable subset sampling (Xie & Ermon, 2019) generalizes the Gumbel-softmax trick to k-subsets, rendering k-subset sampling amenable to backpropagation, but at the cost of introducing bias into learning through the use of relaxed samples.

In this paper, we set out with the goal of avoiding all such relaxations. Instead, we fall back to discrete sampling on the forward pass.
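To make the contrast between the two families of estimators concrete, they can be written out explicitly (a standard textbook formulation; the notation p_θ, f, z, g, and τ is introduced here for illustration and is not taken from this paper):

```latex
% Score-function (REINFORCE) estimator: unbiased, but high variance
\nabla_\theta \, \mathbb{E}_{z \sim p_\theta}\left[ f(z) \right]
  = \mathbb{E}_{z \sim p_\theta}\left[ f(z) \, \nabla_\theta \log p_\theta(z) \right]

% Gumbel-softmax relaxation: low variance, but biased; the relaxed
% sample \tilde{z} is differentiable in the logits \theta
\tilde{z}_i = \frac{\exp\!\left( (\log \theta_i + g_i)/\tau \right)}
                   {\sum_{j=1}^{n} \exp\!\left( (\log \theta_j + g_j)/\tau \right)},
\qquad g_i \sim \mathrm{Gumbel}(0, 1)
```

The first identity underlies score-function estimators; the second is the continuous relaxation (with temperature τ) whose bias the present work seeks to avoid by keeping the forward-pass samples discrete.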
On the backward pass, we reparameterize the gradient of the loss function with respect to the samples as a function of the exact marginals of the k-subset distribution. Computing exact conditional marginals is, in general, intractable (Roth, 1996). We give an efficient algorithm for computing the k-subset probability, and show that the conditional marginals correspond to its partial derivatives and are therefore tractable for the k-subset distribution. We show that our proposed gradient estimator for the k-subset distribution, coined SIMPLE, is reminiscent of the straight-through (ST) Gumbel estimator when k = 1, with the gradients taken with respect to the unperturbed marginals. We empirically demonstrate that SIMPLE exhibits lower bias and variance than other known gradient estimators, including the ST Gumbel estimator in the case k = 1.
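As a point of reference for the k = 1 case, the standard ST Gumbel estimator mentioned above can be sketched in a few lines of NumPy. This is an illustrative sketch of the baseline estimator, not of SIMPLE itself; the function names, the temperature default τ = 1, and the toy loss gradient are ours:

```python
import numpy as np

rng = np.random.default_rng(0)

def st_gumbel_sample(logits, tau=1.0, rng=rng):
    """Forward pass: a discrete one-hot sample, plus the softmax-relaxed
    probabilities that the backward pass will differentiate instead."""
    g = rng.gumbel(size=logits.shape)            # Gumbel(0, 1) noise
    y = logits + g                               # perturbed logits
    e = np.exp((y - y.max()) / tau)              # numerically stable softmax
    soft = e / e.sum()                           # relaxed sample
    hard = np.eye(len(logits))[np.argmax(y)]     # discrete one-hot sample
    return hard, soft

def st_gumbel_grad(soft, dL_dz, tau=1.0):
    """Backward pass: route dL/dz through the Jacobian of the relaxed
    softmax sample, ignoring the hard argmax (the 'straight-through' part)."""
    # Jacobian of softmax((logits + g)/tau) w.r.t. logits:
    # (diag(soft) - soft soft^T) / tau
    J = (np.diag(soft) - np.outer(soft, soft)) / tau
    return J.T @ dL_dz

logits = np.array([1.0, 0.5, -0.5])
hard, soft = st_gumbel_sample(logits)
# toy loss gradient dL/dz, e.g. L(z) = z[0]
grad = st_gumbel_grad(soft, dL_dz=np.array([1.0, 0.0, 0.0]))
```

SIMPLE differs from this baseline in that, per the discussion above, the backward pass differentiates the exact (unperturbed) marginals of the distribution rather than the noise-perturbed softmax.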

