LEARNING CONSISTENT DEEP GENERATIVE MODELS FROM SPARSE DATA VIA PREDICTION CONSTRAINTS

Abstract

We develop a new framework for learning variational autoencoders and other deep generative models that balances generative and discriminative goals. Our framework optimizes model parameters to maximize a variational lower bound on the likelihood of observed data, subject to a task-specific prediction constraint that prevents model misspecification from leading to inaccurate predictions. We further enforce a consistency constraint, derived naturally from the generative model, that requires predictions on reconstructed data to match those on the original data. We show that these two contributions (prediction constraints and consistency constraints) lead to promising image classification performance, especially in the semi-supervised scenario where category labels are sparse but unlabeled data is plentiful. Our approach enables advances in generative modeling to directly boost semi-supervised classification performance, an ability we demonstrate by augmenting deep generative models with latent variables capturing spatial transformations.

1. INTRODUCTION

We develop broadly applicable methods for learning flexible models of high-dimensional data, like images, that are paired with (discrete or continuous) labels. We are particularly interested in semi-supervised learning (Zhu, 2005; Oliver et al., 2018) from data that is sparsely labeled, a common situation in practice due to the cost or privacy concerns associated with data annotation. Given a large and sparsely labeled dataset, we seek a single probabilistic model that simultaneously makes good predictions of labels and provides a high-quality generative model of the high-dimensional input data. Strong generative models are valuable because they allow incorporation of domain knowledge, can address partially missing or corrupted data, and can be visualized to improve interpretability.

Prior approaches for the semi-supervised learning of deep generative models include methods based on variational autoencoders (VAEs) (Kingma et al., 2014; Siddharth et al., 2017), generative adversarial networks (GANs) (Dumoulin et al., 2017; Kumar et al., 2017), and hybrids of the two (Larsen et al., 2016; de Bem et al., 2018; Zhang et al., 2019). While these all allow sampling of data, a major shortcoming of these approaches is that they do not adequately use labels to inform the generative model. Furthermore, GAN-based approaches lack the ability to evaluate the learned probability density function, which can be important for tasks such as model selection and anomaly detection.

This paper develops a framework for training prediction-constrained variational autoencoders (PC-VAEs) that minimize application-motivated loss functions in the prediction of labels, while simultaneously learning high-quality generative models of the raw data. Our approach is inspired by the prediction-constrained framework recently proposed for learning supervised topic models of "bag of words" count data (Hughes et al., 2018), but differs in four major ways.
First, we develop scalable algorithms for learning a much larger and richer family of deep generative models. Second, we capture uncertainty in latent variables rather than simply using point estimates. Third, we allow more flexible specification of loss functions. Finally, we show that the generative model structure leads to a natural consistency constraint that is vital for semi-supervised learning from very sparse labels. Our experiments demonstrate that consistent prediction-constrained (CPC) VAE training leads to prediction performance competitive with state-of-the-art discriminative methods on fully-labeled datasets, and outperforms these baselines on semi-supervised datasets where labels are rare.

[Figure: scatter-plot panels of predicted labels on the half-moon task for VAE-then-MLP, PC-VAE, CPC-VAE, M2, M2 (C = 14), and CPC-VAE (C = 14), with per-panel accuracy shown in each panel's corner; plot data omitted.]

2. BACKGROUND: DEEP GENERATIVE MODELS AND SEMI-SUPERVISION

We now describe VAEs as deep generative models and review previous methods for semi-supervised learning (SSL) of VAEs, highlighting weaknesses that we later improve upon. We assume all SSL tasks provide two training datasets: an unsupervised (or unlabeled) dataset D_U of N feature vectors x, and a supervised (or labeled) dataset D_S containing M pairs (x, y) of features x and label y ∈ Y. Labels are often sparse (N ≫ M) and can be discrete or continuous.

2.1. UNSUPERVISED GENERATIVE MODELING WITH THE VAE

The variational autoencoder (Kingma & Welling, 2014) is an unsupervised model with two components: a generative model and an inference model. The generative model defines for each example a joint distribution p_θ(x, z) over "features" (observed vector x ∈ R^D) and "encodings" (hidden vector z ∈ R^C). The "inference model" of the VAE defines an approximate posterior q_φ(z | x), which is trained to be close to the true posterior (q_φ(z | x) ≈ p_θ(z | x)) but much easier to evaluate. As in Kingma & Welling (2014), we assume the following conditional independence structure:

p_θ(x, z) = N(z | 0, I_C) · F(x | µ_θ(z), σ_θ(z)),   q_φ(z | x) = N(z | µ_φ(x), σ_φ(x)).   (1)

The likelihood F is often multivariate normal, but other distributions may give robustness to outliers. The (deterministic) functions µ_θ and σ_θ, with trainable parameters θ, define the mean and covariance of the likelihood. Given any observation x, the posterior of z is approximated as normal with mean µ_φ(x) and (diagonal) covariance σ_φ(x), parameterized by φ. These functions can be represented as multi-layer perceptrons (MLPs), convolutional neural networks (CNNs), or other (deep) neural networks.

We would ideally learn generative parameters θ by maximizing the marginal likelihood of the features x, integrating out the latent variable z. Since this is intractable, we instead maximize a variational lower bound:

max_{θ,φ} Σ_{x∈D} L_VAE(x; θ, φ),   L_VAE(x; θ, φ) = E_{q_φ(z|x)}[ log p_θ(x, z) − log q_φ(z | x) ] ≤ log p_θ(x).   (2)

This expectation can be evaluated via Monte Carlo samples from the inference model q_φ(z | x). Gradients with respect to θ and φ can be similarly estimated via the reparameterization "trick" of representing q_φ(z | x) as a linear transformation of standard normal variables (Kingma & Welling, 2014). Throughout this paper, we denote variational parameters by φ. Because the factorization of q changes for more complex models, we write φ_{z|x} to denote the parameters specific to the factor q(z | x).
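As a concrete illustration of the bound in Eq. (2), the following NumPy sketch estimates the single-example ELBO via Monte Carlo samples drawn with the reparameterization trick. The linear maps standing in for µ_φ, µ_θ and the fixed log-variances are hypothetical untrained placeholders, not part of the paper's method.

```python
import numpy as np

rng = np.random.default_rng(0)
D, C = 4, 2  # feature dim, code dim

# Hypothetical "networks": random linear maps in place of trained MLPs.
W_enc = rng.normal(size=(C, D)) * 0.1  # stands in for mu_phi
W_dec = rng.normal(size=(D, C)) * 0.1  # stands in for mu_theta
log_sigma_x = np.zeros(D)              # decoder log-std (unit variance here)
log_sigma_z = np.zeros(C)              # encoder log-std (fixed for simplicity)

def gaussian_log_pdf(v, mu, log_sigma):
    """Log density of a diagonal Gaussian, summed over dimensions."""
    return -0.5 * np.sum(((v - mu) / np.exp(log_sigma)) ** 2
                         + 2.0 * log_sigma + np.log(2.0 * np.pi))

def elbo_estimate(x, n_samples=1):
    """Monte Carlo estimate of L_VAE(x) = E_q[log p(x,z) - log q(z|x)]."""
    mu_z = W_enc @ x  # mean of q(z|x)
    total = 0.0
    for _ in range(n_samples):
        eps = rng.normal(size=C)
        z = mu_z + np.exp(log_sigma_z) * eps        # reparameterization trick
        log_px_z = gaussian_log_pdf(x, W_dec @ z, log_sigma_x)  # log F(x|z)
        log_pz = gaussian_log_pdf(z, np.zeros(C), np.zeros(C))  # N(z|0, I_C)
        log_qz = gaussian_log_pdf(z, mu_z, log_sigma_z)         # log q(z|x)
        total += log_px_z + log_pz - log_qz
    return total / n_samples

x = rng.normal(size=D)
print("ELBO estimate:", elbo_estimate(x, n_samples=100))
```

Averaging over more samples reduces the variance of the estimate; in practice a single sample per example suffices when averaging gradients over minibatches.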

2.2. TWO-STAGE SSL: MAXIMIZE FEATURE LIKELIHOOD THEN TRAIN PREDICTOR

One way to employ the VAE for a semi-supervised task is a two-stage "VAE-then-MLP" approach. First, we train a VAE to maximize the unsupervised likelihood bound (2) on all observed features x (both labeled D_S and unlabeled D_U). Second, we define a label-from-code predictor ŷ_w(z) that maps each learned code representation z to a predicted label y ∈ Y. We use an MLP with weights w, though any predictor could do. Let ℓ_S(y, ŷ) be a loss function, such as cross-entropy, appropriate for the prediction task. We train the predictor to minimize the expected loss:

min_w Σ_{(x,y)∈D_S} E_{q_φ(z|x)}[ ℓ_S(y, ŷ_w(z)) ].

Importantly, this second stage uses only the small labeled dataset and relies on the fixed parameters φ from stage one.
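To make the second stage concrete, here is a hedged NumPy sketch of stage two only. The encoder outputs (per-example means mu with a fixed standard deviation) are synthetic stand-ins for a stage-one q_φ(z | x), and the predictor is the simplest possible ŷ_w, a linear softmax classifier trained by gradient descent on cross-entropy over sampled codes; all data and dimensions are toy values chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
C, K, M = 2, 2, 100  # code dim, number of classes, size of labeled set D_S

# Synthetic stand-in for stage one: pretend q_phi(z|x) gives each labeled
# example a mean mu (codes for class k cluster near the unit vector e_k)
# and a fixed diagonal std. phi is frozen; only w is trained below.
labels = rng.integers(0, K, size=M)
mu = np.eye(K)[labels][:, :C] + 0.2 * rng.normal(size=(M, C))
sigma = 0.1

W = np.zeros((K, C))  # weights w of the label-from-code predictor y_hat_w(z)

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

onehot = np.eye(K)[labels]
for step in range(200):
    z = mu + sigma * rng.normal(size=mu.shape)  # one sample from q(z|x) per example
    probs = softmax(z @ W.T)                    # predicted label distribution
    grad = (probs - onehot).T @ z / M           # grad of mean cross-entropy wrt W
    W -= 0.5 * grad                             # update w only; phi stays fixed

acc = (softmax(mu @ W.T).argmax(axis=1) == labels).mean()
print(f"stage-two training accuracy: {acc:.2f}")
```

Sampling z inside the loop approximates the expectation over q_φ(z | x) in the stage-two objective; because φ is never updated, any mismatch between the unsupervised codes and the labels cannot be corrected, which is the weakness the half-moon figure illustrates.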



Figure 1: Predictions from SSL VAE methods on the half-moon binary classification task, with accuracy in the lower corner of each panel. Each dot indicates a 2-dim. feature vector, colored by predicted binary label. Top: 6 labeled examples (diamond markers), 994 unlabeled. Bottom: 100 labeled, 900 unlabeled. The first 4 columns use C = 2 encoding dimensions, the last 2 use C = 14. M2 (Kingma et al., 2014) classification accuracy deteriorates when model capacity increases from C = 2 to C = 14, especially with only 6 labels (a drop from 98.1% to 80.6% accuracy). In contrast, our CPC-VAE approach is reliable at any model capacity, as it better aligns generative and discriminative goals.

