TRADE: A SIMPLE SELF-ATTENTION-BASED DENSITY ESTIMATOR

Abstract

We present TraDE, a self-attention-based architecture for auto-regressive density estimation with continuous and discrete-valued data. Our model is trained using a penalized maximum likelihood objective, which ensures that samples from the density estimate resemble the training data distribution. The use of self-attention means that the model need not retain conditional sufficient statistics during the auto-regressive process beyond what is needed for each covariate. On standard tabular and image data benchmarks, TraDE produces significantly better density estimates than existing approaches such as normalizing-flow estimators and recurrent auto-regressive models. However, log-likelihood on held-out data only partially reflects how useful these estimates are in real-world applications. To evaluate density estimators systematically, we present a suite of tasks, such as regression using generated samples, out-of-distribution detection, and robustness to noise in the training data, and demonstrate that TraDE works well in these scenarios.

1. INTRODUCTION

Density estimation involves estimating a probability density p(x) given independent, identically distributed (iid) samples from it. This is a versatile and important problem: it allows one to generate synthetic data and to perform novelty and outlier detection, and it is an important subroutine in applications of graphical models. Deep neural networks are a powerful function class, and learning complex distributions with them is a promising direction; this has led to a resurgence of interest in the classical problem of density estimation.

One of the more popular techniques for density estimation is to sample data from a simple reference distribution and then learn a (sequence of) invertible transformations that adapt it to the target distribution. Flow-based methods (Durkan et al., 2019b) employ this idea with great success. A more classical approach is to decompose p(x) via the chain rule into conditional probabilities p(x_{i+1} | x_1, ..., x_i) and fit these conditionals to the data (Murphy, 2013). One may even employ implicit generative models to sample from p(x) directly, perhaps without the ability to compute density estimates; this is the case with Generative Adversarial Networks (GANs), which reign supreme for image synthesis via sampling (Goodfellow et al., 2014; Karras et al., 2017).

Implementing the above methods, however, requires special care. The normalizing transform requires the network to be invertible with an efficiently computable Jacobian. Auto-regressive models using recurrent networks are difficult to scale to high-dimensional data due to the need to store a potentially high-dimensional conditional sufficient statistic (and also due to vanishing gradients). Generative models can be difficult to train, and GANs lack a closed-form density model. Much of the current work is devoted to mitigating these issues. The main contributions of this paper include:

1. We introduce TraDE, a simple but novel auto-regressive density estimator that uses self-attention in place of a recurrent neural network
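To illustrate why self-attention suits auto-regressive conditioning, the sketch below (not the paper's implementation; a single attention head in plain NumPy, with identity projections purely for brevity) shows that a causally masked attention layer produces an output at position i that depends only on inputs 1, ..., i, which is exactly the dependency structure needed to model p(x_{i+1} | x_1, ..., x_i):

```python
import numpy as np

def causal_self_attention(x):
    """Single-head self-attention with a causal (lower-triangular) mask.

    Output i is a convex combination of inputs x_1..x_i only, so the
    layer respects the auto-regressive ordering. Query/key/value
    projections are identity maps here purely for illustration.
    """
    n, d = x.shape
    scores = x @ x.T / np.sqrt(d)            # (n, n) attention logits
    mask = np.triu(np.ones((n, n)), k=1)     # 1s strictly above the diagonal
    scores = np.where(mask == 1, -np.inf, scores)  # block attention to the future
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)  # row-wise softmax
    return weights @ x

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))
y = causal_self_attention(x)

# Perturb only the last input: all earlier outputs must be unchanged,
# because the mask prevents positions 1..n-1 from attending to position n.
x2 = x.copy()
x2[-1] += 1.0
y2 = causal_self_attention(x2)
assert np.allclose(y[:-1], y2[:-1])
```

Unlike a recurrent network, which must compress the entire prefix x_1, ..., x_i into a fixed-size hidden state, the masked attention layer accesses every previous covariate directly at each position.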



Figure 1: TraDE is well suited to density estimation of Transformers. Left: Bumblebee (true density), Right: density estimated from data.
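The chain-rule factorization used by auto-regressive estimators, log p(x) = sum_i log p(x_i | x_1, ..., x_{i-1}), can be checked numerically on a toy discrete distribution (the probability table below is illustrative only, not from the paper):

```python
import numpy as np

# Toy joint distribution over two binary variables: joint[a, b] = p(x1=a, x2=b).
joint = np.array([[0.1, 0.3],
                  [0.2, 0.4]])

p_x1 = joint.sum(axis=1)                 # marginal p(x1)
p_x2_given_x1 = joint / p_x1[:, None]    # conditional p(x2 | x1)

# Chain rule: log p(x1, x2) = log p(x1) + log p(x2 | x1).
x1, x2 = 1, 0
log_joint = np.log(joint[x1, x2])
log_chain = np.log(p_x1[x1]) + np.log(p_x2_given_x1[x1, x2])
assert np.isclose(log_joint, log_chain)
```

An auto-regressive model fits each conditional factor in turn; the product of the learned conditionals is then a properly normalized density estimate of the joint.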

