PGASL: PREDICTIVE AND GENERATIVE ADVERSARIAL SEMI-SUPERVISED LEARNING FOR IMBALANCED DATA

Abstract

Modern machine learning techniques often suffer from class imbalance, where only a small amount of data is available for minority classes. Classifiers trained on an imbalanced dataset can perform poorly on minority classes even though they achieve high accuracy on majority classes. This is problematic when the minority classes are also important. Generative Adversarial Networks (GANs) have been proposed for generating artificial minority examples to balance the training. We propose PGASL, a class-imbalanced semi-supervised learning algorithm which can be efficiently trained on unlabeled and class-imbalanced data. In this work, we use a predictive network which is trained adversarially against the discriminator to correct predictions on the unlabeled dataset. Experiments on text datasets show that PGASL, by including both a predictive network and a generator, outperforms state-of-the-art class-imbalanced learning algorithms.

1. INTRODUCTION

In many real-world applications such as medical data analysis (Cameron et al., 2010), data is often imbalanced, since patients suffering from a certain disease typically make up a very small portion of the population. It is often challenging to train a machine learning model on imbalanced data, since the classifier may become biased towards the majority classes, leading to poor performance on the minority classes. Although the classifier may have good overall performance, this is undesirable for tasks in which the minority classes are also important. Addressing class imbalance has become increasingly important in many fields. Another feature of this kind of data is that usually only a small amount of it is labeled. For example, patients who visit hospitals but are under-diagnosed for a certain disease cannot be labeled as negative and thus remain unlabeled. Recently, many deep neural network based semi-supervised learning (SSL) algorithms have shown their ability to utilize unlabeled data to improve performance. However, most works focus on only one of the above challenges, and class-imbalanced semi-supervised learning is still underexplored. In this paper we consider the rare disease detection task in healthcare (Yu et al., 2019), which is a typical imbalanced semi-supervised machine learning task with binary outputs. As GANs have had great success in learning realistic data distributions through adversarial training, we propose a three-player GAN model which can generate artificial minority data as well as extract minority data from the unlabeled dataset, which helps handle the class imbalance issue. We let the discriminator output prediction probabilities for each class, and we then use a predictive network, which has the same structure as the discriminator but is trained adversarially, to correct predictions of samples on the unlabeled dataset. Finally, we use a generator to generate minority samples so that the discriminator is trained on more balanced data and is less likely to be biased towards the majority class.

Our main contributions include:
• We introduce a novel semi-supervised three-player GAN model for imbalanced data.
• We conduct experiments on binary text classification datasets to benchmark PGASL against previous class-imbalanced learning methods. Results demonstrate that our proposed method can efficiently utilize unlabeled data and handle imbalanced datasets, and that it outperforms prior works.
• To the best of our knowledge, this is the first GAN-based algorithm which can handle the class imbalance issue while utilizing the unlabeled dataset at the same time.

2. RELATED WORK

Semi-supervised learning (SSL) techniques have been proposed to improve performance by utilizing unlabeled data. Pseudo-labels are used as if they were true labels on unlabeled datasets (Lee et al., 2013), and the models are trained in the manner of entropy minimization (Grandvalet & Bengio, 2004). Consistency regularization aims at learning similar distributions for local perturbations of the unlabeled dataset (Park et al., 2018; Miyato et al., 2018). Several data augmentation techniques are used in combination with pseudo-labels to create local perturbations (Berthelot et al., 2019a; Sohn et al., 2020; Berthelot et al., 2019b). In this paper we want to highlight two GAN-based methods which our paper mainly follows. SSL-GANs use the discriminator to classify samples from different classes instead of only distinguishing fake samples from real samples. To perform adversarial training, one way is to directly minimize the entropy of the discriminator outputs, which can help distinguish real samples from fake samples (Springenberg, 2015). Alternatively, one can use an additional class which indicates fake samples, so that the discriminator can make predictions and recognize fake samples at the same time (Salimans et al., 2016; Odena, 2016). Other works use an additional network to play the role of distinguishing the fake samples (Mullick et al., 2019). PAN (Hu et al., 2021) replaces the generator with a classifier which mirrors the discriminator in classification on the unlabeled dataset, to help classify those samples which can hardly be classified by the discriminator. Although PAN was originally proposed for PU learning, it can be extended to semi-supervised learning naturally. We show in this paper that keeping the generator in PAN can give better performance in the class-imbalanced learning setting.
The SSL methods discussed above often assume a balanced dataset and may perform poorly on an imbalanced dataset, especially when the unlabeled dataset is also imbalanced (He & Garcia, 2009; Sun et al., 2009). Researchers have proposed several techniques for class-imbalanced learning. Re-sampling techniques have been proposed to create a balanced dataset (Chawla et al., 2002; Barandela et al., 2003; Engelmann & Lessmann, 2021b). However, oversampling may cause overfitting while undersampling may cause information loss (Cao et al., 2019). To avoid such issues, researchers have also proposed to balance the training by introducing regularization terms instead of creating a balanced dataset (Huang et al., 2016; Jamal et al., 2020; Cui et al., 2019). Recently, researchers have proposed several class-imbalanced semi-supervised learning (CISSL) techniques to improve performance on imbalanced datasets. Refining the pseudo-labels through an iterative algorithm can reduce bias (Kim et al., 2020; Wei et al., 2021). Introducing a single-layer auxiliary balanced classifier (ABC) together with a 0/1 mask can balance training (Lee et al., 2021).

3.1. PROBLEM SETTING

Suppose we have a binary labeled dataset $L = \{(x_i, y_i), i = 1, \dots, N_l\}$ where $x_i$ is a $d$-dimensional feature vector and $y_i \in \{0, 1\}$ is the corresponding binary label. Let $N_l = N_p + N_n$ where $N_p$ is the number of positive data points and $N_n$ is the number of negative data points. We also have an unlabeled dataset $U = \{u_i, i = 1, \dots, N_u\}$ where $u_i$ is also a $d$-dimensional feature vector but the corresponding label is unknown. We denote the label ratio as $\rho = \frac{N_l}{N_l + N_u}$. In this paper we consider the case $\rho \approx 0.1$, where most of the labels are expensive to obtain. We also consider the extremely class-imbalanced situation where the imbalance ratio $\gamma = \frac{N_p}{N_p + N_n}$ is very small. We assume the labeled dataset and the unlabeled dataset share the same distribution (e.g. have the same imbalance ratio). Given the labeled dataset and the unlabeled dataset in training, we aim at learning a classifier $f : \mathbb{R}^d \to \{0, 1\}$ which performs well on an imbalanced test set.
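As a small worked example of the two ratios defined in this section, the following sketch computes the label ratio and the imbalance ratio from class counts. The function names and the counts are hypothetical and chosen only for illustration; they are not taken from the experiments.

```python
def label_ratio(n_labeled: int, n_unlabeled: int) -> float:
    """Label ratio rho = N_l / (N_l + N_u)."""
    return n_labeled / (n_labeled + n_unlabeled)


def imbalance_ratio(n_pos: int, n_neg: int) -> float:
    """Imbalance ratio gamma = N_p / (N_p + N_n), computed on labeled data."""
    return n_pos / (n_pos + n_neg)


# Hypothetical counts (illustration only): 50 labeled positives,
# 950 labeled negatives and 9000 unlabeled points.
n_pos, n_neg, n_unlabeled = 50, 950, 9000
rho = label_ratio(n_pos + n_neg, n_unlabeled)
gamma = imbalance_ratio(n_pos, n_neg)
print(f"rho = {rho:.2f}, gamma = {gamma:.2f}")
```

With these counts one in ten points is labeled and one in twenty labeled points is positive, which is the kind of regime the problem setting describes.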

3.2. GENERATIVE ADVERSARIAL NETWORK

First proposed in Goodfellow et al. (2014), generative adversarial networks (GANs) have shown an excellent ability to generate realistic samples which follow the real data distribution. The model simultaneously trains a generative network (often called the generator), which generates samples that can hardly be distinguished from the real data in order to fool the discriminator, and a discriminator, which tells whether a sample is 'real' or 'fake' (generated by the generator). A GAN is often formulated as a minimax game: $\min_G \max_D V(D, G) = \mathbb{E}_{x \sim P_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim P_z(z)}[\log(1 - D(G(z)))]$. GANs can be trained to learn the distribution of minority classes and can then be used to generate samples that oversample the minority classes to create a balanced dataset (Engelmann & Lessmann, 2021b). Although GANs are mostly used for generative tasks, researchers have also demonstrated their good performance on classification tasks by letting the discriminator make predictions for different classes rather than only distinguish fake samples from real samples. Unlike resampling techniques, which create a balanced dataset first and apply classification algorithms afterward, letting the discriminator make predictions allows the generator to learn more suitable distributions, as the two are trained together. It is also easier to adjust the regions of samples which we would like the generator to generate, which is important for imbalanced learning. It is often not desirable for the generator to generate samples only from minority classes, as this may cause overfitting, and generating samples from the majority class also helps classification. Denote the ratio of positive to negative samples we expect the generator to generate by $\rho_G$. The choice of $\rho_G$ is still under investigation. In this paper, we treat $\rho_G$ as a hyper-parameter and use grid search to choose its best value.
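To make the minimax objective concrete, the sketch below estimates $V(D, G)$ from a handful of discriminator outputs. The function name `gan_value` and the numeric outputs are illustrative assumptions; in practice the values would come from a trained discriminator evaluated on real and generated batches.

```python
import math


def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples.
    d_fake: discriminator outputs D(G(z)) on generated samples.
    All values are probabilities strictly between 0 and 1.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# A discriminator that cannot tell real from fake outputs 0.5 everywhere,
# which gives V = log(0.5) + log(0.5) = -log(4).
confused = gan_value([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# A discriminator that separates the two sets well obtains a larger value.
confident = gan_value([0.9, 0.95, 0.8], [0.1, 0.05, 0.2])

print(confused, confident)
```

The discriminator tries to push this value up, while the generator tries to push it down by making $D(G(z))$ large.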

3.3. PROPOSED MODEL ARCHITECTURE

The architecture of our framework is shown in Figure 1. We aim at balancing the training through minority oversampling. To this end, we propose a GAN model structure which can generate artificial minority samples as well as extract minority samples from the unlabeled dataset, so that more minority samples are fed into the classifier. The model contains three components: a generator G(•) which takes random noise as input and learns to generate artificial samples from the real data distribution, a discriminator D(•) which makes predictions, and a predictive network C(•) which mirrors the discriminator but is trained adversarially. We use the same network structure for both D(•) and C(•) except for the dropout rate. We found in practice that letting C(•) have a larger dropout rate gives better performance. In this paper we consider binary classification problems, and therefore the predictions from both D(•) and C(•) have three classes, with 0 as negative, 1 as positive and 2 as 'fake'. In particular, we consider the rare disease scenario where the positive class is the minority. The discriminator D(•) first learns to classify samples from the labeled dataset correctly using standard cross-entropy minimization for supervised learning. The supervised loss is $\ell_{sup}(x, y) = -\mathbb{E}_{(x, y) \in L} \log p_D(y \mid x, y < 2)$, which minimizes the negative log probability of the label given the labeled data. Then D(•) learns to recognize the artificial samples generated by the generator G(•) by classifying the generated samples as 'fake' (2). This corresponds to minimizing the fake loss $\ell_{fake}(z) = -\mathbb{E}_{x \in G(z)} \log p_D(2 \mid x)$. In the total loss of D(•), $\lambda$ is a hyper-parameter for balancing the supervised loss and the unsupervised losses. To perform adversarial training, the predictive network C(•) learns to shrink the distance between the predictions from D(•) and C(•) by maximizing the second unsupervised loss $\ell^*_{un}$, so that when an equilibrium is reached, C(•) can make predictions with a large margin on the unlabeled dataset.
Unlike PAN, we keep the generator in order to handle the class imbalance issue. Denote the set of positive samples as $P$ and the set of negative samples as $N^*$. To train the generator, we sample $n$ data points from $N^*$ and $\rho_G n$ data points from $P$, where $\rho_G$ is a hyperparameter. In practice, $\rho_G$ should be greater than 1 if positive is the minority class. Denote the set of the $(\rho_G + 1)n$ data points as $L_G$. We use the last layer of D(•), denoted $D_h$, as features and match the features of generated samples to the features of labeled samples. This is achieved by minimizing the following feature mapping loss: $\ell_{fm} = \| \mathbb{E}_{x \in L_G} D_h(x) - \mathbb{E}_{t \in G(z_1)} D_h(t) \|_2^2$. In addition, we use the pull-away term (Zhao et al., 2016) $\ell_{pt} = \frac{1}{N(N-1)} \sum_{i=1}^{N} \sum_{j \neq i} \frac{f(x_i)^T f(x_j)}{\|f(x_i)\| \|f(x_j)\|}$.
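The two generator losses can be sketched on small lists of feature vectors as follows. This is a minimal illustration in plain Python, assuming the features have already been extracted. The helper names are hypothetical, and the pull-away term follows the formula as written in this section (an average cosine similarity), whereas Zhao et al. (2016) additionally square each cosine similarity.

```python
import math


def mean_vector(rows):
    """Component-wise mean of a list of equal-length feature vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def feature_mapping_loss(real_feats, fake_feats):
    """Squared L2 distance between the mean real and mean generated features."""
    mu_real = mean_vector(real_feats)
    mu_fake = mean_vector(fake_feats)
    return sum((a - b) ** 2 for a, b in zip(mu_real, mu_fake))


def pull_away_term(feats):
    """Average cosine similarity over all ordered pairs i != j."""
    n = len(feats)
    norms = [math.sqrt(sum(v * v for v in f)) for f in feats]
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                dot = sum(a * b for a, b in zip(feats[i], feats[j]))
                total += dot / (norms[i] * norms[j])
    return total / (n * (n - 1))


# Toy features (illustration only).
fm = feature_mapping_loss([[1.0, 0.0], [3.0, 0.0]], [[0.0, 0.0], [0.0, 2.0]])
collapsed = pull_away_term([[1.0, 0.0], [2.0, 0.0]])   # parallel vectors
diverse = pull_away_term([[1.0, 0.0], [0.0, 1.0]])     # orthogonal vectors
print(fm, collapsed, diverse)
```

Parallel generated features give the largest pull-away value and orthogonal ones give zero, so minimizing the term discourages the generator from producing near-identical samples.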

4.1. DATA SOURCE AND PRE-PROCESSING

Algorithm 1 PGASL for semi-supervised learning
Input: labeled dataset $L$, unlabeled dataset $U$, batch size $m$, hyperparameters $\lambda$, $\rho_G$
Output: D
1: for k steps do
2: Train D
• Sample a mini-batch $\{(x_i, y_i) \in L, i = 1, \dots, m\}$ from the labeled dataset and compute the supervised loss $\ell_{sup}(D(x_i), y_i)$.
• Sample a mini-batch $\{x_i \in U, i = 1, \dots, m\}$ from the unlabeled dataset and compute the first unsupervised loss $\ell^D_{un}(D(x_i))$.
• Sample a mini-batch $\{x_i \in U, i = 1, \dots, m\}$ from the unlabeled dataset and compute the second unsupervised loss $\ell^*_{un}(D(x_i))$.
• Sample noise $z \sim \mathcal{N}(0, 1)$ and compute the fake loss $\ell_{fake}(D(G(z)))$.
• Update D(•) by minimizing $\ell_D = \ell_{sup} + \lambda(\ell^D_{un} + \ell^*_{un}) + \ell_{fake}$.
3: Train C
• Sample a mini-batch $\{x_i \in U, i = 1, \dots, m\}$ from the unlabeled dataset and update C(•) by maximizing the loss $\ell^*_{un}(D(x_i))$.
4: Train G
• Sample a mini-batch $\{(x^*_i, 0) \in N^*, i = 1, \dots, m/\rho_G\}$, sample a mini-batch $\{(x^*_i, 1) \in P, i = 1, \dots, m\}$, sample noise $z \sim \mathcal{N}(0, 1)$, and compute the feature mapping loss $\ell_{fm}$ and the pull-away term $\ell_{pt}$.
• Update G(•) by minimizing $\ell_G = \ell_{fm} + \ell_{pt}$.
5: end for

In this section, we test the proposed method PGASL and compare it with state-of-the-art algorithms. In this paper we consider the rare disease scenario, which has binary outputs with the positive class being the minority class. We first created imbalanced versions of two text classification datasets, Amazon polarity reviews (McAuley & Leskovec, 2013) and Yelp polarity reviews (Zhang et al., 2015), using various imbalance ratios $\gamma$. We kept all $N_n$ negative data points and then randomly selected $N_p = \frac{\gamma}{1-\gamma} N_n$ positive data points. For the purpose of semi-supervised learning, we randomly selected 10% of the data as labeled data and kept the remaining 90% as unlabeled data for all datasets. Since the main focus of this paper is semi-supervised learning rather than representation learning, we first used pretrained sentence transformers (Reimers & Gurevych, 2019) to preprocess all the text into sentence embeddings. The classification algorithms are then applied to the same pretrained sentence embeddings. We summarize the information of each dataset in Table 1.

To evaluate our method, we mainly use three metrics: the area under the receiver operating characteristic curve (ROC-AUC), the area under the precision-recall curve (PR-AUC), and the Brier score. The ROC-AUC score is the area under the curve of the true positive rate against the false positive rate over all decision thresholds, and it evaluates how well a model ranks predictions in general. A more suitable metric for imbalanced data is the PR-AUC score, which can be viewed as the average of the precision scores calculated at each recall threshold. We also use the Brier score, which measures the accuracy of probabilistic predictions. We pick these three metrics because they do not require a decision threshold and can be applied directly to predicted probabilities.
This is important since a threshold of 0.5 usually does not work well for imbalanced datasets. We compare the proposed method with the following baselines on the above metrics:
• CWGAN+xgboost: This method first uses CWGAN (Engelmann & Lessmann, 2021a) to oversample the imbalanced datasets and then trains xgboost (Chen & Guestrin, 2016) on the oversampled, balanced training datasets in a fully supervised learning setting.
• SGAN (Odena, 2016): This method uses a GAN to learn a discriminator that outputs class labels and utilizes unlabeled data in a semi-supervised setting. We use the same generator training for PGASL and SGAN.
• PAN (Hu et al., 2021): This method trains a classifier adversarially against the discriminator on the unlabeled dataset to improve performance. It was originally proposed for PU learning; we extend it to semi-supervised learning in this paper.
• ABC (Lee et al., 2021): This method attaches a classifier to the last layer of a backbone algorithm and uses a 0/1 mask to balance training. We use the same discriminator as in PGASL as the backbone algorithm for ABC.
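The dataset construction described in this section (keep all negatives, subsample positives to reach a target imbalance ratio, then label 10% of the result) can be sketched as follows. This is a minimal illustration and not the authors' preprocessing code. The function name, the rounding of $N_p$ and the use of integer ids in place of sentence embeddings are assumptions.

```python
import random


def make_imbalanced_split(pos, neg, gamma, labeled_frac=0.1, seed=0):
    """Build an imbalanced labeled/unlabeled split.

    Keeps every negative example, samples N_p = gamma / (1 - gamma) * N_n
    positives (rounded, an assumption), and marks `labeled_frac` of the
    shuffled data as labeled. Labels of the unlabeled part are dropped.
    """
    rng = random.Random(seed)
    n_pos = round(gamma / (1.0 - gamma) * len(neg))
    if n_pos > len(pos):
        raise ValueError("not enough positive examples for this gamma")
    data = [(x, 1) for x in rng.sample(pos, n_pos)] + [(x, 0) for x in neg]
    rng.shuffle(data)
    n_labeled = round(labeled_frac * len(data))
    labeled = data[:n_labeled]
    unlabeled = [x for x, _ in data[n_labeled:]]
    return labeled, unlabeled


# Toy ids standing in for sentence embeddings (illustration only).
positives = list(range(1000))
negatives = list(range(1000, 1950))          # 950 negatives
labeled, unlabeled = make_imbalanced_split(positives, negatives, gamma=0.05)
print(len(labeled), len(unlabeled))
```

With 950 negatives and $\gamma = 0.05$, the formula gives 50 positives, so the split contains 1000 points in total.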

4.3. TRAINING DETAILS

We use fully-connected neural networks with leaky ReLU activation (Maas et al., 2013) and dropout (Srivastava et al., 2014) for D(•) and C(•) in PGASL. For a fair comparison we use the same network structure for the networks in PAN, the discriminator in SGAN and the backbone in ABC. For each method we use grid search to find the hyper-parameters that give the best PR-AUC score, which we consider the most important metric. Then, to reduce the effect of randomness in training, for each method we run 10 repeated experiments with the best hyper-parameters from the grid search and compare the average scores. We use PyTorch (Paszke et al., 2019) for implementation and the Adam optimizer (Kingma & Ba, 2014) for training.

4.4. RESULTS AND ANALYSIS

The ROC-AUC scores of each algorithm on the test datasets are summarized in Table 2. Except on the Yelp dataset with $\gamma = 0.05$, PGASL has the best scores. Note that all the ROC-AUC scores are large, even for the extremely imbalanced datasets. This is because the ROC-AUC score can often give an overly optimistic view of a classifier's performance on imbalanced datasets. We believe the PR-AUC score is more suitable in the class-imbalanced setting and better measures the performance of the algorithms. The PR-AUC scores are summarized in Table 3. We also summarize the Brier scores in Table 4. Note that, unlike ROC-AUC and PR-AUC, the Brier score works like a loss function, with 0 representing perfect accuracy. PGASL outperforms SGAN since PGASL has an additional predictive network which increases the prediction accuracy on the unlabeled dataset during training. PGASL outperforms PAN since PGASL keeps the generator to generate more balanced data, which benefits class-imbalanced learning. Moreover, PGASL outperforms the oversampling technique CWGAN+xgboost and the state-of-the-art class-imbalanced semi-supervised learning algorithm ABC.
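The gap between ROC-AUC and PR-AUC on imbalanced data can be illustrated with a toy example. The sketch below uses plain Python stand-ins for the three metrics: ROC-AUC as a pairwise ranking statistic, PR-AUC approximated by average precision, and the Brier score. The scores are made up for illustration, and the average precision helper assumes there are no tied scores.

```python
def roc_auc(y_true, y_score):
    """Probability that a random positive is scored above a random negative."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(y_true, y_score):
    """Mean precision at the rank of each positive (assumes no tied scores)."""
    ranked = sorted(zip(y_score, y_true), key=lambda t: -t[0])
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank
    return total / hits


def brier_score(y_true, y_prob):
    """Mean squared difference between predicted probability and label."""
    return sum((p - y) ** 2 for y, p in zip(y_true, y_prob)) / len(y_true)


# One positive among 99 negatives; four negatives are scored above it.
neg_scores = [0.9, 0.8, 0.7, 0.65] + [0.5 - 0.005 * i for i in range(95)]
y_true = [1] + [0] * 99
y_score = [0.6] + neg_scores

auc = roc_auc(y_true, y_score)
ap = average_precision(y_true, y_score)
print(f"ROC-AUC = {auc:.3f}, average precision = {ap:.3f}")
```

Here the single positive outranks 95 of the 99 negatives, so ROC-AUC is about 0.96, while the average precision is only 0.2 because four negatives appear before the first positive.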

5. CONCLUSIONS

In this paper, we proposed a novel GAN-based algorithm called PGASL. PGASL is a three-player model which has an additional predictive network. Unlike other GAN-based algorithms, where adversarial training is performed only between the discriminator and the generator, the predictive network in PGASL is trained adversarially against the discriminator to boost performance on the unlabeled dataset. We also showed that PGASL outperforms other GAN-based models as well as state-of-the-art class-imbalanced semi-supervised learning algorithms on imbalanced datasets. In future work we will further improve the performance of PGASL by developing more robust training for the generator.



D(•) has two unsupervised losses. Minimizing the first unsupervised loss $\ell^D_{un}(x) = -\mathbb{E}_{x \in U} \log(1 - p_D(2 \mid x))$ allows D(•) to recognize the unlabeled data by maximizing the probability that a sample from the unlabeled dataset is not 'fake'. Minimizing the second unsupervised loss $\ell^*_{un}(x) = -\mathbb{E}_{x \in U}[\log p_D(x) - \log p_C(x)]$ allows D(•) to maximize the distance between the predictions from D(•) and C(•) on the unlabeled dataset. In this paper, we also use the KL divergence to measure the distance, as in PAN (Hu et al., 2021). Then the total loss for D(•) is $\ell_D = \ell_{sup} + \ell_{fake} + \lambda(\ell^D_{un} + \ell^*_{un})$.
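The discriminator losses can be sketched on three-class probability vectors (0 negative, 1 positive, 2 'fake') as follows. This is only one possible reading of the formulas, not the authors' implementation. In particular, conditioning on $y < 2$ is implemented as renormalizing over the two real classes, and the second unsupervised loss is taken to be the negative KL divergence between the outputs of D(•) and C(•). Both choices, and all function names, are assumptions.

```python
import math

NEG, POS, FAKE = 0, 1, 2  # class indices of the three-way output


def mean(values):
    values = list(values)
    return sum(values) / len(values)


def sup_loss(probs_d, labels):
    """Supervised loss: -E log p_D(y | x, y < 2).
    ASSUMPTION: the conditioning renormalizes over the two real classes."""
    return -mean(math.log(p[y] / (p[NEG] + p[POS]))
                 for p, y in zip(probs_d, labels))


def fake_loss(probs_d_generated):
    """Fake loss: -E log p_D(2 | x) on generated samples."""
    return -mean(math.log(p[FAKE]) for p in probs_d_generated)


def unsup_real_loss(probs_d_unlabeled):
    """First unsupervised loss: -E log(1 - p_D(2 | x)) on unlabeled samples."""
    return -mean(math.log(1.0 - p[FAKE]) for p in probs_d_unlabeled)


def unsup_adv_loss(probs_d, probs_c):
    """Second unsupervised loss.
    ASSUMPTION: read as -KL(p_D || p_C); needs strictly positive probabilities."""
    def kl(p, q):
        return sum(pk * (math.log(pk) - math.log(qk)) for pk, qk in zip(p, q))
    return -mean(kl(p, q) for p, q in zip(probs_d, probs_c))


def discriminator_loss(sup, fake, un_real, un_adv, lam):
    """Total loss: sup + fake + lambda * (un_real + un_adv)."""
    return sup + fake + lam * (un_real + un_adv)


p = [0.25, 0.25, 0.5]  # toy output of D, used for every term (illustration)
sup = sup_loss([p], [POS])
fake = fake_loss([p])
un_real = unsup_real_loss([p])
un_adv = unsup_adv_loss([p], [p])  # identical outputs: zero distance
total = discriminator_loss(sup, fake, un_real, un_adv, lam=0.5)
print(total)
```

Under this reading, D(•) lowers the second term by moving its predictions away from those of C(•), while C(•) raises it by moving its predictions towards those of D(•).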

Figure 1: Framework of proposed architecture on semi-supervised learning. The utilities of D(•), C(•) and G(•) are colored in red, purple and blue respectively.

The pull-away term increases the diversity of generated samples. Note that the functionality of C(•) is similar to that of G(•). The difference is that G(•) generates artificial data that is not from the training dataset, while C(•) extracts data from the unlabeled dataset. Both C(•) and G(•) contribute to creating more positive samples for D(•) to balance training.

3.4. TRAINING OF PGASL

Algorithm 1 gives the training procedure of PGASL using stochastic optimization algorithms. The loss functions are defined in Section 3.3. The algorithm trains D(•), C(•) and G(•) alternately. In practice, one can choose to run more iterations for one component in one epoch to boost performance.
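The alternating schedule can be sketched as a small driver loop. This is a schematic under stated assumptions: `train_d`, `train_c` and `train_g` are hypothetical callables that each perform one update of the corresponding network, and the `iters` argument is an illustrative way to give one component more updates per outer step.

```python
def train_pgasl(steps, train_d, train_c, train_g, iters=(1, 1, 1)):
    """Alternate updates of D, C and G for `steps` outer steps.

    iters = (n_d, n_c, n_g) is the number of updates each component
    receives per outer step.
    """
    for _step in range(steps):
        for update, n_updates in zip((train_d, train_c, train_g), iters):
            for _ in range(n_updates):
                update()


# Record the update order with stand-in callables (illustration only).
log = []
train_pgasl(
    steps=2,
    train_d=lambda: log.append("D"),
    train_c=lambda: log.append("C"),
    train_g=lambda: log.append("G"),
    iters=(2, 1, 1),  # two discriminator updates per outer step
)
print(log)
```

In a real implementation each callable would sample its mini-batches, compute the losses of its step and apply one optimizer update.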

We can see that the discriminator D(•) successfully learns the correct classification on the labeled dataset and that the generator G(•) learns to generate diverse samples with features similar to the real data. The predictive network C(•), trained adversarially, is beaten by D(•) at the beginning of training but eventually reaches an equilibrium.

Figure 2: Training losses for D(•), C(•) and G(•).

Figure 3: t-SNE visualizations, where red and green points represent samples from the labeled dataset while blue and purple points represent artificial samples from the generator.

Table 1: Number of samples of each class in each dataset. Columns: labeled neg, labeled pos, unlabeled neg, unlabeled pos, test neg, test pos.

Table 2: Test ROC-AUCs of the PGASL model and benchmark models. The scores are averaged over 10 experiments.

Table 3: Test PR-AUCs of the PGASL model and benchmark models. The scores are averaged over 10 experiments.

Table 4: Test Brier scores of the PGASL model and benchmark models. The scores are averaged over 10 experiments. Smaller scores indicate better performance.

