THEORETICAL CHARACTERIZATION OF NEURAL NETWORK GENERALIZATION WITH GROUP IMBALANCE

Anonymous

Abstract

Group imbalance is a known problem in empirical risk minimization (ERM), where high average accuracy is accompanied by low accuracy in a minority group. Despite algorithmic efforts to improve the minority-group accuracy, a theoretical generalization analysis of ERM on individual groups remains elusive. By formulating the group imbalance problem with the Gaussian Mixture Model, this paper quantifies the impact of individual groups on the sample complexity, the convergence rate, and the average and group-level testing performance. Although our theoretical framework is centered on binary classification with a one-hidden-layer neural network, to the best of our knowledge, we provide the first theoretical analysis of the group-level generalization of ERM in addition to the commonly studied average generalization performance. Sample insights from our theoretical results include that when all group-level co-variances are in the medium regime and all means are close to zero, the learning performance is most desirable in the sense of a small sample complexity, a fast training rate, and high average and group-level testing accuracy. Moreover, we show that increasing the fraction of the minority group in the training data does not necessarily improve the generalization performance of the minority group. Our theoretical results are validated on both synthetic and empirical datasets such as CelebA and CIFAR-10 in image classification.

1. INTRODUCTION

Training neural networks with empirical risk minimization (ERM) is a common practice to reduce the average loss of a machine learning task evaluated on a dataset. However, recent findings (Blodgett et al., 2016; Tatman, 2017; Hashimoto et al., 2018; Buolamwini & Gebru, 2018; McCoy et al., 2019; Sagawa et al., 2020; Sagawa* et al., 2020; Mehrabi et al., 2021) have shown empirical evidence of a critical challenge of ERM, known as group imbalance, where a well-trained model with high average accuracy may have significant errors on the minority group that infrequently appears in the data. Moreover, the group attributes that determine the majority and minority groups are usually hidden and unknown during training. The training set can be augmented by data augmentation methods (Shorten & Khoshgoftaar, 2019) with varying performance, such as cropping and rotation (Krizhevsky et al., 2012), noise injection (Moreno-Barea et al., 2018), and generative adversarial network (GAN)-based methods (Goodfellow et al., 2014; Bowles et al., 2018; Radford et al., 2016). As ERM is a prominent method and enjoys great empirical success, it is important to characterize the impact of ERM on group imbalance theoretically. However, analyzing the nonconvex ERM problem of neural networks is technically difficult due to the concatenation of nonlinear functions across layers, and existing generalization analyses of ERM often make overly simplistic assumptions and focus only on the average generalization performance. For example, the neural tangent kernel type of analysis (Arora et al., 2019; Allen-Zhu et al., 2019b;a; Cao & Gu, 2019; Chen et al., 2020; Du et al., 2019; Jacot et al., 2018; Zou et al., 2020; Zou & Gu, 2019) linearizes the neural network around the random initialization to remove the nonconvex interactions across layers. The resulting generalization bounds are independent of the feature distribution and cannot be exploited to analyze the impact of individual groups.
Li & Liang (2018) provide a sample complexity analysis when the data come from mixtures of well-separated distributions, but still cannot characterize the learning performance of individual groups. Another line of works (Du et al., 2018a; Ghorbani et al., 2020; Goldt et al., 2020; Li & Liang, 2018; Mei et al., 2018; Mignacco et al., 2020; Yoshida & Okada, 2019) considers one-hidden-layer neural networks, because the ERM problem is already highly nonconvex, and the analytical complexity increases tremendously as the number of hidden layers grows. In these works, the input features are usually assumed to be i.i.d. samples drawn from the standard Gaussian distribution, and this data model cannot differentiate the majority and minority groups.

Contribution: To the best of our knowledge, this paper provides the first theoretical characterization of both the average and group-level generalization of a one-hidden-layer neural network trained by ERM on data generated from a mixture of distributions. This paper considers the binary classification problem with the cross-entropy loss function, with training data generated by a ground-truth neural network with known architecture and unknown weights. The optimization problem is challenging due to the high non-convexity arising from the multi-neuron architecture and the nonlinear sigmoid activation. Assuming the features follow a Gaussian Mixture Model (GMM), where samples of each group are generated from a Gaussian distribution with an arbitrary mean vector and co-variance matrix, this paper quantifies the impact of individual groups on the sample complexity, the training convergence rate, and the average and group-level test error. The training algorithm is gradient descent following a tensor initialization, and it converges linearly.

Our key results include: (1) Medium-range group-level co-variance enhances the learning performance.
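To make the data model concrete, the following is a minimal sketch (not the paper's exact construction) of the GMM setup: each group draws features from its own Gaussian with an arbitrary mean and co-variance, and binary labels come from a hypothetical ground-truth one-hidden-layer sigmoid network with weights `W_star`, `v_star`. All dimensions, group fractions, and parameter values here are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def generate_group_data(n, mean, cov, W_star, v_star):
    """Draw n feature vectors from one Gaussian group and label them
    with a ground-truth one-hidden-layer sigmoid network."""
    X = rng.multivariate_normal(mean, cov, size=n)      # features, shape (n, d)
    hidden = 1.0 / (1.0 + np.exp(-X @ W_star))          # sigmoid hidden units, (n, k)
    p = 1.0 / (1.0 + np.exp(-hidden @ v_star))          # output probability, (n,)
    return X, (rng.random(n) < p).astype(int)           # binary labels in {0, 1}

# Illustrative setting: d = 5 input dimensions, k = 3 hidden neurons,
# a zero-mean majority group and a shifted, higher-variance minority group.
d, k = 5, 3
W_star = rng.standard_normal((d, k))
v_star = rng.standard_normal(k)

X_maj, y_maj = generate_group_data(900, np.zeros(d), np.eye(d), W_star, v_star)
X_min, y_min = generate_group_data(100, 0.5 * np.ones(d), 2.0 * np.eye(d), W_star, v_star)

X = np.vstack([X_maj, X_min])
y = np.concatenate([y_maj, y_min])
```

Note that the group membership (majority vs. minority) is used only to generate the data; it is hidden from the learner, matching the setting where group attributes are unknown during training.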
When a group-level co-variance deviates from the medium regime, the learning performance degrades in terms of higher sample complexity, slower convergence in training, and worse average and group-level generalization performance. As shown in Figure 1(a), we introduce Gaussian augmentation to control the co-variance level of the minority group in the CelebA dataset (Liu et al., 2015). The learned model achieves the highest test accuracy when the co-variance is at the medium level; see Figure 1(b). Another implication is that the diverse performance of different data augmentation methods might partially result from the different group-level co-variances these methods introduce. Furthermore, although our setup does not directly model the batch normalization approach (Ioffe & Szegedy, 2015; Bjorck et al., 2018; Chai et al., 2020; Santurkar et al., 2018), which modifies the mean and variance in each layer to achieve fast and stable convergence, our result provides a theoretical insight that co-variance indeed affects the learning performance.

(2) Group-level mean shifts away from zero hurt the learning performance. When a group-level mean deviates from zero, the sample complexity increases, the algorithm converges more slowly, and both the average and group-level test errors increase. Thus, the learning performance is improved if each distribution is zero-mean. This paper provides a similar theoretical insight, that the data mean affects the learning performance, to practical techniques such as whitening (LeCun et al., 1998) and the pre-processing step of making data zero-mean (LeCun et al., 1998), and to studies of subgroup shift (Koch et al., 2022; Ma et al., 2021) and population shift (Biswas & Mukherjee, 2021; Giguere et al., 2022).

(3) Increasing the fraction of the minority group in the training data does not always improve its generalization performance. The generalization performance is also affected by the mean and co-variance of individual groups. In fact, increasing the fraction of the minority group in the training data can have a completely opposite impact on different datasets.
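The Gaussian augmentation used in the Figure 1 experiment can be sketched as follows: adding i.i.d. zero-mean Gaussian noise with standard deviation `noise_level` to each minority-group sample raises that group's effective co-variance by roughly `noise_level**2 * I`, so sweeping `noise_level` sweeps the group's co-variance level. The function name and parameter values are illustrative, not the paper's exact implementation.

```python
import numpy as np

def gaussian_augment(X, noise_level, seed=None):
    """Add i.i.d. zero-mean Gaussian noise with standard deviation
    noise_level to every sample, increasing the group's effective
    co-variance by approximately noise_level**2 * I."""
    rng = np.random.default_rng(seed)
    return X + noise_level * rng.standard_normal(X.shape)

# Example: augment only the minority group, leaving the majority untouched.
rng = np.random.default_rng(1)
X_min = rng.standard_normal((100, 5))          # stand-in minority-group features
X_min_aug = gaussian_augment(X_min, noise_level=0.3, seed=2)
```

Sweeping `noise_level` from small to large values then traces out the low-, medium-, and high-co-variance regimes whose test accuracies are compared in Figure 1(b).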



Figure 1: Group imbalance experiment. (a) Binary classification on the CelebA dataset using Gaussian augmentation to control the minority-group co-variance. (b) Test accuracy versus the augmented noise level.

