ON REPRESENTATION LEARNING UNDER CLASS IMBALANCE

Abstract

Unlike carefully curated academic benchmarks, real-world datasets are often highly class-imbalanced, involving training and test sets which contain few examples from certain minority classes. While there is a common understanding that neural network generalization is negatively impacted by imbalance, the source of this problem and its resolution are unclear. Through extensive empirical investigation, we study foundational learning behaviors for various models such as neural networks, gradient-boosted decision trees, and SVMs across a range of domains and find that (1) contrary to conventional wisdom, re-balancing the training set to include a higher proportion of minority samples degrades performance on imbalanced test sets; (2) minority samples are hard to fit, yet algorithms which fit them, such as oversampling, do not improve generalization. Motivated by the observation that re-balancing class-imbalanced training data is ineffective, we show that several existing techniques for improving representation learning are effective in this setting: (3) self-supervised pre-training is insensitive to imbalance and can be used for feature learning before fine-tuning on labels; (4) Bayesian inference is effective because neural networks are especially underspecified under class imbalance; (5) flatness-seeking regularization pulls decision boundaries away from minority samples, especially when we seek minima that are particularly flat on the minority samples' loss.

1. INTRODUCTION

In real-world data collection scenarios, some events are common while others are exceedingly rare. For example, only a minuscule proportion of credit card transactions are fraudulent, and most cancer screenings come back negative. As a result, machine learning systems are routinely trained and deployed on class-imbalanced data, where relatively few samples are associated with certain minority classes while majority classes dominate the datasets. Nonetheless, the vast majority of works exclusively consider class-balanced benchmarks (LeCun, 1998; Krizhevsky, 2009; Deng et al., 2009), including both foundational literature which seeks to understand how and why machine learning algorithms operate and applied methodological literature. In this work, we explore, across various machine learning approaches including neural networks, gradient-boosted decision trees, and SVMs, what makes learning under class imbalance so difficult and the associated implications for best practices in such scenarios. Many of the widely referenced methods for remedying class-imbalance problems rely on modifying how the training data is sampled, such as oversampling or SMOTE (Chawla et al., 2002), and have been shown to be ineffective for neural networks (Buda et al., 2018). A common assumption underpinning these sampling methods is that learning under class imbalance is pathological, perhaps even defaulting to predicting only the majority class on all inputs when imbalance is sufficiently severe, so we must intervene by simulating balanced training. An effect of sampling from minority classes disproportionately more often is that more signal from those classes is injected into model updates, helping to fit the otherwise rarely seen minority samples.
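The naive oversampling scheme discussed above can be sketched as follows. This is a minimal illustration in our own notation, not the exact procedure of any cited method; the dataset sizes and function names are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical imbalanced dataset: 95 majority samples (class 0), 5 minority (class 1).
y = np.array([0] * 95 + [1] * 5)
X = rng.normal(size=(100, 2))

def oversample(X, y, rng):
    """Naively oversample minority classes by drawing with replacement
    until every class matches the majority-class count."""
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    idx = []
    for c in classes:
        class_idx = np.flatnonzero(y == c)
        idx.append(rng.choice(class_idx, size=target, replace=True))
    idx = np.concatenate(idx)
    return X[idx], y[idx]

X_bal, y_bal = oversample(X, y, rng)
print(np.bincount(y_bal))  # [95 95]: each class now appears equally often
```

Each minority sample is repeated roughly 19 times on average here, which injects more minority-class signal into updates but, as argued above, need not improve generalization.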
To tease out exactly why oversampling is ineffective, we begin by studying the relationship between imbalances seen at train and test time, and we investigate whether poor generalization under class imbalance can really be explained by failures of optimization. We find that while minority samples are hard to fit, this optimization phenomenon has little explanatory power regarding generalization, as fitting them does not affect performance. Additionally, we find that, on the one hand, re-balancing training data to include more minority samples can negatively impact generalization under imbalanced testing, while on the other hand, gathering more majority samples to increase the size of the dataset degrades generalization as well.

Following our investigation into the role of dataset imbalances in generalization, we show why self-supervised learning (SSL), Bayesian inference, and flatness-seeking regularizers are particularly well-suited for deep learning in class-imbalanced settings. Self-supervised learning algorithms are less sensitive to the proportion of samples in various classes as they do not make use of label information, so we can learn better feature representations via SSL before fine-tuning, even on the same data but with labels. Previous works have found that the number of large eigenvalues of the Hessian is related to the number of classes (Sagun et al., 2017; Papyan, 2020), but these works were conducted strictly on balanced data. Examining such properties on imbalanced datasets, we observe that neural networks trained in these settings are significantly more underdetermined by the data. In cases where many parameter settings, and induced functions, are compatible with the data, Bayesian Neural Networks (BNNs) can represent our uncertainty for improved accuracy (Wilson & Izmailov, 2020; Shwartz-Ziv et al., 2022), and we find that this capability of BNNs is especially advantageous in the class-imbalanced regime.
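The benefit of representing uncertainty over many compatible parameter settings can be illustrated with a minimal sketch of a deep-ensemble-style approximation to Bayesian model averaging: average the predictive distributions of several independently trained models rather than committing to one solution. Toy logistic-regression "members" stand in for neural networks here; the data, function names, and hyperparameters are all illustrative, not the paper's setup:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
w_true = rng.normal(size=5)
y = (X @ w_true > 0).astype(float)  # linearly separable toy labels

def fit_logistic(X, y, seed, steps=300, lr=0.5):
    """Fit one ensemble member by full-batch gradient descent,
    starting from a different random initialization per seed."""
    init_rng = np.random.default_rng(seed)
    w = init_rng.normal(size=X.shape[1])
    for _ in range(steps):
        p = 1 / (1 + np.exp(-(X @ w)))
        w -= lr * X.T @ (p - y) / len(y)
    return w

members = [fit_logistic(X, y, seed=s) for s in range(5)]

# Average the predictive probabilities across members, not the weights:
# this is the Monte Carlo form of Bayesian model averaging.
probs = np.mean([1 / (1 + np.exp(-(X @ w))) for w in members], axis=0)
preds = (probs > 0.5).astype(float)
acc = (preds == y).mean()
```

Averaging in probability space is the key design choice: averaging weights of independently trained models generally produces a poor predictor, while averaging predictions marginalizes over the solutions the data failed to distinguish.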
Finally, whereas neural network decision boundaries tend to hug minority samples in order to expand the margins from majority data points which occur more frequently in training data, we can counteract this behavior with Sharpness-Aware Minimization (SAM) (Foret et al., 2020), and we can further improve margins with respect to minority samples by increasing flatness on their corresponding loss functions. In summary, our work questions the motivation of orthodox sampling methods and proposes new directions by which improved representation learning can benefit classifiers in class-imbalanced settings.
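As a rough illustration of SAM's two-step update (ascend to a worst-case perturbation within a radius-rho ball around the current weights, then descend using the gradient computed at that perturbed point), consider the following sketch on a toy logistic-regression loss. The data and hyperparameters are illustrative, not the paper's experimental setup:

```python
import numpy as np

def loss_and_grad(w, X, y):
    """Binary cross-entropy with logits X @ w; returns (loss, gradient)."""
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    loss = -np.mean(y * np.log(p + 1e-12) + (1 - y) * np.log(1 - p + 1e-12))
    grad = X.T @ (p - y) / len(y)
    return loss, grad

def sam_step(w, X, y, lr=0.1, rho=0.05):
    """One SAM update: take a normalized ascent step of size rho,
    then apply the perturbed-point gradient at the original weights."""
    _, g = loss_and_grad(w, X, y)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # worst-case perturbation
    _, g_adv = loss_and_grad(w + eps, X, y)        # gradient at perturbed weights
    return w - lr * g_adv

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = (X @ np.array([1.5, -2.0, 0.5]) > 0).astype(float)

w = np.zeros(3)
for _ in range(200):
    w = sam_step(w, X, y)
final_loss, _ = loss_and_grad(w, X, y)
```

Because the descent direction is evaluated at the locally worst-case point, the resulting minima are flat in a neighborhood of the solution, which is the property we exploit to push decision boundaries away from minority samples.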

2. RELATED WORKS

A long line of research has been conducted on imbalanced classification. There are several general approaches to address this problem: (1) Re-sampling the data: In early ensemble learning studies, boosting and bagging algorithms were adjusted to account for imbalanced data by re-sampling. Traditionally, re-sampling involves oversampling minority class samples by simply copying them (Guo & Viktor, 2004; Chawla et al., 2002; Han et al., 2005), or undersampling majority classes by removing samples (Drummond et al., 2003; Hu et al., 2020; Ando & Huang, 2017; Buda et al., 2018), so that minority and majority class samples appear equally frequently in the training process. (2) Loss re-weighting: Loss re-weighting assigns different weights to majority and minority classes, thus reducing optimization difficulty under class imbalance (Cui et al., 2019; Huang et al., 2019a). For instance, one may scale the loss by inverse class frequency (He & Garcia, 2009) or re-weight it using the effective number of samples (Cui et al., 2019). As an alternative approach, one may focus on hard examples by down-weighting the loss of well-classified examples (Lin et al., 2017) or dynamically re-scaling the cross-entropy loss based on the difficulty of classifying a sample (Ryou et al., 2019). Bertsimas et al. (2018) propose to encourage larger margins for rare classes, while Goh & Sim (2010) learn robust features for classifying minority classes using class-uncertainty information which approximates Bayesian methods. (3) Two-stage fine-tuning and meta-learning approaches: Two-stage methods separate the training process into representation learning and classifier learning (Liu et al., 2019; Ouyang et al., 2016; Kang et al., 2019; Bansal et al., 2021). In the first stage, the data is unmodified, and no re-sampling or re-weighting is used, to train good representations.
In the second stage, the classifier is balanced by freezing the backbone and fine-tuning the last layers with re-sampling techniques or by learning to debias the confidence. These methods assume that the bias towards majority classes exists only in the classifier layer or that tweaking the classifier layer can correct the underlying biases. 
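The inverse-class-frequency re-weighting described in approach (2) can be sketched as follows. The variable names and the four-sample toy example are our own, chosen only to make the weighting visible:

```python
import numpy as np

def weighted_cross_entropy(logits, labels, class_weights):
    """Cross-entropy where each sample's loss is scaled by its class weight."""
    shifted = logits - logits.max(axis=1, keepdims=True)       # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_sample = -log_probs[np.arange(len(labels)), labels]
    return np.mean(class_weights[labels] * per_sample)

labels = np.array([0, 0, 0, 1])                 # class 0 is 3x more frequent
counts = np.bincount(labels)                    # [3, 1]
class_weights = counts.sum() / (len(counts) * counts)  # inverse frequency: [2/3, 2]

logits = np.zeros((4, 2))                       # uniform predictions
loss = weighted_cross_entropy(logits, labels, class_weights)
# With uniform logits every sample incurs log(2) of loss; the weights average
# to 1, so the total is unchanged, but the single minority sample now
# contributes 3x as much to the loss (and gradient) as each majority sample.
```

The effect is analogous to oversampling in expectation: gradient contributions are re-balanced across classes without changing which samples the model sees.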

Several works have also inspected representations learned under class imbalance. Kang et al. (2019) find that representations learned on class-imbalanced training data via supervised learning perform better when the linear head is fine-tuned on balanced samples. Yang & Xu (2020) instead examine the effect of self- and semi-supervised training on imbalanced data and conclude that imbalanced labels are significantly more useful when accompanied by auxiliary data for semi-supervised learning. Kotar et al. (2021); Yang & Xu (2020); Liu et al. (2021) make the observation that self-supervised pre-training is insensitive to imbalance in the upstream training data. These works study SSL pre-training for the purpose of transfer learning, sometimes using linear probes to evaluate the quality of

