HYBRID DISCRIMINATIVE-GENERATIVE TRAINING VIA CONTRASTIVE LEARNING

Abstract

Contrastive learning and supervised learning have both seen significant progress and success. However, thus far they have largely been treated as two separate objectives, brought together only by having a shared neural network. In this paper we show that, through the perspective of hybrid discriminative-generative training of energy-based models, we can make a direct connection between contrastive learning and supervised learning. Beyond presenting this unified view, we show that our specific choice of approximation of the energy-based loss significantly improves energy-based models and contrastive learning based methods in confidence calibration, out-of-distribution detection, adversarial robustness, generative modeling, and image classification tasks. In addition to significantly improved performance, our method eliminates SGLD training and therefore does not suffer from its training instability. Our evaluations also demonstrate that our method performs better than, or on par with, state-of-the-art hand-tailored methods on each task.

1. INTRODUCTION

In the past few years, the field of deep learning has seen significant progress. Example successes include large-scale image classification (He et al., 2016; Simonyan & Zisserman, 2014; Srivastava et al., 2015; Szegedy et al., 2016) on the challenging ImageNet benchmark (Deng et al., 2009). The common objective for solving supervised machine learning problems is to minimize the cross-entropy loss, defined as the cross entropy between a target distribution and a categorical distribution called Softmax, which is parameterized by the model's real-valued outputs, known as logits. The target distribution usually consists of one-hot labels. There has been a continuing effort to improve upon the cross-entropy loss; various methods have been proposed, motivated by different considerations (Hinton et al., 2015; Müller et al., 2019; Szegedy et al., 2016). Recently, contrastive learning has achieved remarkable success in representation learning. Contrastive learning yields good representations and enables efficient training on downstream tasks; an incomplete list includes image classification (Chen et al., 2020a;b; Grill et al., 2020; He et al., 2019; Tian et al., 2019; Oord et al., 2018), video understanding (Han et al., 2019), and knowledge distillation (Tian et al., 2019). Many different training approaches have been proposed to learn such representations, usually relying on visual pretext tasks. Among them, state-of-the-art contrastive methods (He et al., 2019; Chen et al., 2020a;c) are trained by reducing the distance between representations of different augmented views of the same image ('positive pairs') and increasing the distance between representations of augmented views from different images ('negative pairs'). Despite the success of the two objectives, they have been treated as separate, brought together only by having a shared neural network.
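For concreteness, the two losses discussed above can be sketched in NumPy. This is a simplified sketch only (function names and the temperature value are illustrative, not the paper's notation); real contrastive methods such as SimCLR additionally involve augmentations, projection heads, and large batches of negatives:

```python
import numpy as np

def softmax_cross_entropy(logits, label):
    """Cross entropy between a one-hot target at `label` and Softmax(logits)."""
    z = logits - logits.max()                 # subtract max for numerical stability
    log_probs = z - np.log(np.exp(z).sum())   # log Softmax of the logits
    return -log_probs[label]

def info_nce(anchor, positive, negatives, temperature=0.1):
    """InfoNCE-style contrastive loss: pull the positive pair together,
    push negative pairs apart, using cosine similarity."""
    def cos(a, b):
        return float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))
    sims = np.array([cos(anchor, positive)] +
                    [cos(anchor, n) for n in negatives]) / temperature
    # InfoNCE is softmax cross-entropy with the positive pair at index 0
    return softmax_cross_entropy(sims, 0)
```

Writing the contrastive loss in terms of the same softmax cross-entropy used for classification already hints at the connection this paper makes precise.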
In this paper, to show a direct connection between contrastive learning and supervised learning, we consider the energy-based interpretation of models trained with the cross-entropy loss, building on Grathwohl et al. (2019). We propose a novel objective that consists of a term for the conditional of the label given the input (the classifier) and a term for the conditional of the input given the label. We optimize the classifier term in the usual way. Different from Grathwohl et al. (2019), we approximately optimize the second conditional, over the input, with a contrastive learning objective instead of a Monte-Carlo sampling-based approximation. In doing so, we provide a unified view of existing practice. Our work takes inspiration from Ng & Jordan (2002), who showed that classifiers trained with a generative loss (i.e., optimizing p(x|y), with x the input and y the classification label) can outperform classifiers of the same expressiveness trained with a discriminative loss (i.e., optimizing p(y|x)). Later it was shown that hybrid discriminative-generative training can get the best of both worlds (Raina et al., 2004). The work by Ng & Jordan (2002) was done in the (simpler) context of Naive Bayes and Logistic Regression; our work can be seen as lifting it into today's context of training deep neural net classifiers. Our empirical evaluation shows that our method improves both the confidence calibration and the classification accuracy of the learned classifiers, beating state-of-the-art methods. Despite its simplicity, our method outperforms competitive baselines in out-of-distribution (OOD) detection on all tested datasets. On hybrid generative-discriminative modeling tasks (Grathwohl et al., 2019), our method obtains superior performance without needing to run computationally expensive SGLD steps.
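The hybrid objective described above can be sketched as a weighted sum of the two terms. The function names and the weighting hyperparameter `alpha` below are illustrative, not the paper's notation; the contrastive term stands in for the SGLD/Monte-Carlo approximation of the generative conditional:

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max()                 # subtract max for numerical stability
    return z - np.log(np.exp(z).sum())

def hybrid_loss(logits, label, anchor, positive, negatives,
                alpha=1.0, temperature=0.1):
    """Sketch of a hybrid discriminative-generative loss.

    Discriminative term: softmax cross-entropy on the classifier logits,
    optimizing p(y|x). Generative surrogate: an InfoNCE-style contrastive
    loss over one positive and several negatives, replacing the
    sampling-based estimate of the generative conditional."""
    discriminative = -log_softmax(logits)[label]
    def cos(a, b):
        return float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))
    sims = np.array([cos(anchor, positive)] +
                    [cos(anchor, n) for n in negatives]) / temperature
    contrastive = -log_softmax(sims)[0]       # positive pair sits at index 0
    return discriminative + alpha * contrastive
```

Setting `alpha = 0` recovers plain supervised training, which makes explicit that the contrastive term is an added generative regularizer rather than a replacement for the classifier loss.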
Our method learns significantly more robust classifiers than supervised training and achieves highly competitive results compared with hand-tailored adversarial robustness algorithms. The contributions of this paper can be summarized as follows: (i) To the best of our knowledge, we are the first to reveal the connection between contrastive learning and supervised learning; we connect the two objectives through energy-based models. (ii) Building on this insight, we present a novel framework for hybrid generative-discriminative modeling via contrastive learning. (iii) Our method eliminates SGLD and therefore does not suffer from the training instability of energy-based models. We empirically show that our method improves confidence calibration, OOD detection, adversarial robustness, generative modeling, and classification accuracy, performing on par with or better than state-of-the-art energy-based models and contrastive learning algorithms for each task.

2. RELATED WORK

Our work falls into the category of hybrid generative-discriminative models. Ng & Jordan (2002); Raina et al. (2004); Lasserre et al. (2006); Larochelle & Bengio (2008); Tu (2007); Lazarow et al. (2017) compare and study the connections and differences between discriminative and generative models, and show that hybrid generative-discriminative models can outperform purely discriminative and purely generative models. Our work differs in that we propose an effective training approach in the context of deep neural networks. By using contrastive learning to optimize the generative model, our method achieves state-of-the-art performance on a wide range of tasks. Energy-based models (EBMs) have been shown to be derivable from classifiers in supervised learning in the work of Xie et al. (2016); Du & Mordatch (2019), who reinterpret the logits to define a class-conditional EBM p(x|y). Our work builds heavily on JEM (Grathwohl et al., 2019), which reveals that one can reinterpret the logits obtained from classifiers to define the EBMs p(x) and p(x, y), and shows this leads to significant improvements in OOD detection, calibration, and robustness while retaining compelling classification accuracy.
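Concretely, JEM's reinterpretation of the logits can be written as follows, where f_θ(x)[y] denotes the y-th logit and Z(θ) the (intractable) normalizing constant:

```latex
p_\theta(x, y) = \frac{\exp\big(f_\theta(x)[y]\big)}{Z(\theta)},
\qquad
p_\theta(x) = \sum_{y} \frac{\exp\big(f_\theta(x)[y]\big)}{Z(\theta)}.
```

Taking the ratio p_θ(x, y) / p_θ(x) recovers the standard Softmax classifier p_θ(y|x), in which Z(θ) cancels; only the generative terms require approximating Z(θ), which JEM does with SGLD sampling and which we instead approximate with a contrastive objective.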
Our method differs in that we optimize our generative term via contrastive learning, obtaining the performance of state-of-the-art canonical EBM algorithms (Grathwohl et al., 2019) without running computationally expensive and slow SGLD (Welling & Teh, 2011) at every iteration. Concurrent to our work, Winkens et al. (2020) propose to pretrain with a contrastive loss and then finetune with a joint supervised and contrastive loss, showing that the SimCLR loss improves likelihood-based OOD detection. Tack et al. (2020) also demonstrate that contrastive learning improves OOD detection and calibration. Our work differs in that, instead of contrastive representation pre-training followed by supervised fine-tuning, we use the contrastive loss to approximate a hybrid discriminative-generative model. We also empirically demonstrate that our method enjoys broader applicability by applying it to generative modeling, calibration, and adversarial robustness.

3.1. SUPERVISED LEARNING

In supervised learning, given a data distribution p(x) and a label distribution p(y|x) with C categories, a classification problem is typically addressed using a parametric function, f θ : R D → R C , which maps each input x to C real-valued numbers known as logits. These logits are used to parameterize a categorical distribution via the Softmax function.




