PGASL: PREDICTIVE AND GENERATIVE ADVERSARIAL SEMI-SUPERVISED LEARNING FOR IMBALANCED DATA

Abstract

Modern machine learning techniques often suffer from class imbalance, where only a small amount of data is available for the minority classes. Classifiers trained on an imbalanced dataset, although highly accurate on the majority classes, can perform poorly on the minority classes. This is problematic when the minority classes are also important. Generative Adversarial Networks (GANs) have been proposed for generating artificial minority examples to balance training. We propose PGASL, a class-imbalanced semi-supervised learning algorithm that can be efficiently trained on unlabeled and class-imbalanced data. In this work, we use a predictive network, trained adversarially against the discriminator, to correct the discriminator's predictions on the unlabeled dataset. Experiments on text datasets show that PGASL outperforms state-of-the-art class-imbalanced learning algorithms by including both the predictive network and the generator.

1. INTRODUCTION

In many real-world applications such as medical data analysis (Cameron et al., 2010), data is often imbalanced: patients suffering from a certain disease typically constitute only a small portion of the population. Training a machine learning model on imbalanced data is challenging, since the classifier may be biased towards the majority classes and perform poorly on the minority classes. Even when overall classification performance is good, this bias is undesirable for tasks in which the minority classes are also important, and addressing class imbalance has become increasingly important in many fields. Another feature of such data is that usually only a small amount of it is labeled; for example, patients who visit hospitals but are under-diagnosed for a certain disease cannot be labeled as negative and thus remain unlabeled. Recently, many deep-neural-network-based semi-supervised learning (SSL) algorithms have shown the ability to utilize unlabeled data to improve performance. However, most works focus on only one of these challenges, and class-imbalanced semi-supervised learning remains under-explored. In this paper we consider the rare disease detection task in healthcare (Yu et al., 2019), a typical imbalanced semi-supervised machine learning task with binary outputs. Since GANs have had great success in learning near-real data distributions through adversarial training, we propose a three-player GAN model that can generate artificial minority data as well as extract minority data from the unlabeled dataset, which helps handle the class imbalance issue. We let the discriminator output prediction probabilities for each class, and we then use a predictive network, which has the same structure as the discriminator but is trained adversarially, to correct predictions of samples on the unlabeled dataset.
Finally, we use a generator to generate minority samples so that the discriminator is trained on more balanced data and is less likely to be biased towards the majority class. Our main contributions include:

• We introduce a novel semi-supervised three-player GAN model for imbalanced data.

• We conduct experiments on binary text classification datasets to benchmark PGASL against previous class-imbalance learning methods. Results demonstrate that our proposed method efficiently utilizes unlabeled data, handles imbalanced datasets, and outperforms prior works.

• To the best of our knowledge, this is the first GAN-based algorithm that handles the class imbalance issue while utilizing the unlabeled dataset at the same time.
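The three-player setup described above can be sketched schematically as follows. The classes and update logic here are illustrative stubs only, not the paper's actual network architectures or objectives: a discriminator D that classifies, a predictive network P that mirrors D on the unlabeled set through an adversarial game, and a generator G that supplies artificial minority samples to balance D's training batches.

```python
import random

class Player:
    """Stub network: predicts a constant minority-class probability."""
    def __init__(self, p):
        self.p = p
    def predict(self, batch):
        return [self.p] * len(batch)

class Generator:
    """Stub generator: emits random d-dimensional 'minority' vectors."""
    def __init__(self, d, seed=0):
        self.d, self.rng = d, random.Random(seed)
    def sample(self, n):
        return [[self.rng.random() for _ in range(self.d)] for _ in range(n)]

def training_step(D, P, G, labeled_x, unlabeled_x):
    fake_minority = G.sample(len(labeled_x))   # G balances D's minibatch
    d_batch = labeled_x + fake_minority        # D trains on real + generated data
    d_out = D.predict(d_batch)
    p_out = P.predict(unlabeled_x)             # P mirrors D on the unlabeled set,
    d_on_u = D.predict(unlabeled_x)            # trained adversarially against it
    return d_out, p_out, d_on_u

D, P, G = Player(0.3), Player(0.7), Generator(d=4)
d_out, p_out, d_on_u = training_step(D, P, G, [[0.0] * 4] * 2, [[1.0] * 4] * 3)
```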

2. RELATED WORK

Semi-supervised learning (SSL) techniques have been proposed to improve performance by utilizing unlabeled data. Pseudo-labels are used as if they were true labels on unlabeled datasets (Lee et al., 2013), and the models are trained in a manner of entropy minimization (Grandvalet & Bengio, 2004). Consistency regularization aims at learning similar distributions under local perturbations of the unlabeled dataset (Park et al., 2018; Miyato et al., 2018). Several data augmentation techniques, combined with pseudo-labels, are used to create local perturbations (Berthelot et al., 2019a; Sohn et al., 2020; Berthelot et al., 2019b). In this paper we highlight two GAN-based approaches that our work mainly follows. SSL-GANs use the discriminator to classify samples into different classes instead of merely distinguishing fake samples from real ones. One way to perform adversarial training is to directly minimize the entropy of the discriminator's outputs, which helps distinguish real samples from fake samples (Springenberg, 2015). Alternatively, one can add an extra class indicating fake samples so that the discriminator makes class predictions and recognizes fake samples at the same time (Salimans et al., 2016; Odena, 2016). Other works use an additional network to play the role of distinguishing the fake samples (Mullick et al., 2019). PAN (Hu et al., 2021) replaces the generator with a classifier that mirrors the discriminator's classifications on the unlabeled dataset, helping to classify samples that the discriminator can hardly classify. Although PAN was originally proposed for PU learning, it extends naturally to semi-supervised learning. We show in this paper that keeping the generator in PAN yields better performance in the class-imbalanced learning setting.
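The extra-class idea of Salimans et al. (2016) can be illustrated with a minimal sketch: the discriminator emits K real-class logits plus one additional logit for the "fake" class, so a single softmax yields both a real/fake score and class probabilities. The function below is an illustrative reading of that scheme, not code from any of the cited works.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def discriminator_readout(logits_k_plus_1):
    """logits_k_plus_1: K real-class logits followed by one fake-class logit."""
    probs = softmax(logits_k_plus_1)
    p_fake = probs[-1]              # probability the sample is generated
    p_real = 1.0 - p_fake
    # renormalize over the K real classes for the classification decision
    class_given_real = [p / p_real for p in probs[:-1]]
    return p_fake, class_given_real

# K = 2 real classes plus one fake-class logit
p_fake, cls = discriminator_readout([2.0, 0.5, -1.0])
```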
The SSL methods discussed above often assume a balanced dataset and may perform poorly on an imbalanced one, especially when the unlabeled dataset is also imbalanced (He & Garcia, 2009; Sun et al., 2009). Researchers have proposed several techniques for class-imbalanced learning. Re-sampling techniques create a balanced dataset (Chawla et al., 2002; Barandela et al., 2003; Engelmann & Lessmann, 2021b); however, oversampling may cause overfitting while undersampling may cause information loss (Cao et al., 2019). To avoid such issues, researchers have also proposed to balance training by introducing regularization terms instead of creating a balanced dataset (Huang et al., 2016; Jamal et al., 2020; Cui et al., 2019). Recently, several class-imbalanced semi-supervised learning (CISSL) techniques have been proposed to improve performance on imbalanced datasets. Refining the pseudo-labels through an iterative algorithm can reduce bias (Kim et al., 2020; Wei et al., 2021). Introducing an auxiliary balanced classifier (ABC) consisting of a single layer, together with a 0/1 mask, can balance training (Lee et al., 2021).
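As a concrete illustration of the simplest re-sampling technique mentioned above, the sketch below performs random oversampling: minority samples are duplicated until the classes are balanced. It also makes the overfitting risk visible, since the same minority points are simply repeated. The function name and interface are illustrative, not from any cited work.

```python
import random

def random_oversample(data, seed=0):
    """data: list of (x, y) pairs with y in {0, 1}.
    Returns a class-balanced list by duplicating minority samples."""
    rng = random.Random(seed)
    pos = [d for d in data if d[1] == 1]
    neg = [d for d in data if d[1] == 0]
    minority, majority = (pos, neg) if len(pos) <= len(neg) else (neg, pos)
    # draw duplicates (with replacement) until both classes have equal size
    extra = [rng.choice(minority) for _ in range(len(majority) - len(minority))]
    return majority + minority + extra

# 1 positive vs. 3 negatives -> balanced to 3 vs. 3
balanced = random_oversample([(0.1, 1), (0.2, 0), (0.3, 0), (0.4, 0)])
```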

3.1. PROBLEM SETTING

Suppose we have a binary labeled dataset L = {(x_i, y_i), i = 1, ..., N_l} where x_i is a d-dimensional feature vector and y_i ∈ {0, 1} is the corresponding binary label. Let N_l = N_p + N_n, where N_p is the number of positive samples and N_n is the number of negative samples. We also have an unlabeled dataset U = {u_i, i = 1, ..., N_u} where u_i is also a d-dimensional feature vector but the corresponding label is unknown. We denote the label ratio as ρ = N_l / (N_l + N_u). In this paper we consider the case ρ ≈ 0.1, where most of the labels are expensive to obtain. We also consider the extremely class-imbalanced situation in which the imbalance ratio γ = N_p / (N_p + N_n) is very small. We assume the labeled and unlabeled datasets share the same distribution (e.g., the same imbalance ratio). Given the labeled and unlabeled datasets in training, we aim to learn a classifier f : R^d → {0, 1} that performs well on an imbalanced test set.
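The two quantities in this setting can be computed directly from the data. The helper below (an illustrative name, not from the paper) returns the label ratio ρ = N_l / (N_l + N_u) and the imbalance ratio γ = N_p / (N_p + N_n):

```python
def dataset_ratios(labels, n_unlabeled):
    """labels: list of 0/1 labels for the labeled set L;
    n_unlabeled: size N_u of the unlabeled set U."""
    n_l = len(labels)
    n_p = sum(labels)                    # positive (minority) count N_p
    n_n = n_l - n_p                      # negative (majority) count N_n
    rho = n_l / (n_l + n_unlabeled)      # label ratio
    gamma = n_p / (n_p + n_n)            # imbalance ratio
    return rho, gamma

# Example: 100 labeled samples (5 positive), 900 unlabeled
rho, gamma = dataset_ratios([1] * 5 + [0] * 95, 900)
print(rho, gamma)  # 0.1 0.05
```

A label ratio near 0.1 with γ = 0.05 corresponds to the regime studied in this paper: labels are scarce and the positive class is rare even within the labeled set.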

