THEORETICAL CHARACTERIZATION OF NEURAL NETWORK GENERALIZATION WITH GROUP IMBALANCE

Anonymous

Abstract

Group imbalance has been a known problem in empirical risk minimization (ERM), where the achieved high average accuracy is accompanied by low accuracy in a minority group. Despite algorithmic efforts to improve the minority-group accuracy, a theoretical generalization analysis of ERM on individual groups remains elusive. By formulating the group imbalance problem with the Gaussian Mixture Model, this paper quantifies the impact of individual groups on the sample complexity, the convergence rate, and the average and group-level testing performance. Although our theoretical framework is centered on binary classification with a one-hidden-layer neural network, to the best of our knowledge, we provide the first theoretical analysis of the group-level generalization of ERM in addition to the commonly studied average generalization performance. Sample insights of our theoretical results include that when all group-level co-variances are in the medium regime and all means are close to zero, the learning performance is most desirable in the sense of a small sample complexity, a fast training rate, and high average and group-level testing accuracy. Moreover, we show that increasing the fraction of the minority group in the training data does not necessarily improve the generalization performance of the minority group. Our theoretical results are validated on both synthetic and empirical datasets, such as CelebA and CIFAR-10, in image classification.

1. INTRODUCTION

Training neural networks with empirical risk minimization (ERM) is a common practice to reduce the average loss of a machine learning task evaluated on a dataset. However, recent findings (Blodgett et al., 2016; Tatman, 2017; Hashimoto et al., 2018; Buolamwini & Gebru, 2018; McCoy et al., 2019; Sagawa et al., 2020; Sagawa* et al., 2020; Mehrabi et al., 2021) have shown empirical evidence about a critical challenge of ERM, known as group imbalance, where a well-trained model that has high average accuracy may have significant errors on the minority group that infrequently appears in the data. Moreover, the group attributes that determine the majority and minority groups are usually hidden and unknown during training. The training set can be augmented by data augmentation methods (Shorten & Khoshgoftaar, 2019) with varying performance, such as cropping and rotation (Krizhevsky et al., 2012) , noise injection (Moreno-Barea et al., 2018) , and generative adversarial network (GAN)-based methods (Goodfellow et al., 2014; Bowles et al., 2018; Radford et al., 2016) . As ERM is a prominent method and enjoys great empirical success, it is important to characterize the impact of ERM on group imbalance theoretically. However, the technical difficulty of analyzing the nonconvex ERM problem of neural networks results from the concatenation of nonlinear functions across layers, and the existing generalization analyses of ERM often make overly simplistic assumptions and only focus on the average generalization performance. For example, the neural tangent kernel type of analysis (Arora et al., 2019; Allen-Zhu et al., 2019b; a; Cao & Gu, 2019; Chen et al., 2020; Du et al., 2019; Jacot et al., 2018; Zou et al., 2020; Zou & Gu, 2019) linearizes the neural network around the random initialization to remove the nonconvex interactions across layers. The generalization bounds are independent of the feature distribution and cannot be exploited to analyze the impact of individual groups. 
Li & Liang (2018) provides a sample complexity analysis when the data come from a mixture of well-separated distributions, but it still cannot characterize the learning performance of individual groups. Another line of works (Du et al., 2018a; Ghorbani et al., 2020; Goldt et al., 2020; Li & Liang, 2018; Mei et al., 2018; Mignacco et al., 2020; Yoshida & Okada, 2019) considers one-hidden-layer neural networks because the ERM problem is already highly nonconvex, and the analytical complexity increases tremendously when the number of hidden layers increases. In these works, the input features are usually assumed to be i.i.d. samples drawn from the standard Gaussian distribution, and this data model cannot differentiate the majority and minority groups.

Contribution: To the best of our knowledge, this paper provides the first theoretical characterization of both the average and group-level generalization of a one-hidden-layer neural network trained by ERM on data generated from a mixture of distributions. This paper considers the binary classification problem with the cross-entropy loss function, with training data generated by a ground-truth neural network with known architecture and unknown weights. The optimization problem is challenging due to the high non-convexity resulting from the multi-neuron architecture and the non-linear sigmoid activation. Assuming the features follow a Gaussian Mixture Model (GMM), where the samples of each group are generated from a Gaussian distribution with an arbitrary mean vector and co-variance matrix, this paper quantifies the impact of individual groups on the sample complexity, the training convergence rate, and the average and group-level test error. The training algorithm is gradient descent following a tensor initialization and converges linearly. Our key results include:

(1) Medium-range group-level co-variance enhances the learning performance.
When a group-level co-variance deviates from the medium regime, the learning performance degrades in terms of higher sample complexity, slower convergence in training, and worse average and group-level generalization performance. As shown in Figure 1(a), we introduce Gaussian augmentation to control the co-variance level of the minority group in the CelebA dataset (Liu et al., 2015). The learned model achieves the highest test accuracy when the co-variance is at the medium level; see Figure 1(b). Another implication is that the diverse performance of different data augmentation methods might partially result from the different group-level co-variance introduced by these methods. Furthermore, although our setup does not directly model the batch normalization approach (Ioffe & Szegedy, 2015; Bjorck et al., 2018; Chai et al., 2020; Santurkar et al., 2018) that modifies the mean and variance in each layer to achieve fast and stable convergence, our result provides a theoretical insight that co-variance indeed affects the learning performance.

(2) Group-level mean shifts from zero hurt the learning performance. When a group-level mean deviates from zero, the sample complexity increases, the algorithm converges more slowly, and both the average and group-level test errors increase. Thus, the learning performance is improved if each distribution is zero-mean. This provides a theoretical insight for practical tricks such as whitening (LeCun et al., 1998), subgroup shift (Koch et al., 2022; Ma et al., 2021), population shift (Biswas & Mukherjee, 2021; Giguere et al., 2022), and the pre-processing step of making data zero-mean (LeCun et al., 1998): the data mean affects the learning performance.

(3) Increasing the fraction of the minority group in the training data does not always improve its generalization performance. The generalization performance is also affected by the mean and co-variance of individual groups.
In fact, increasing the fraction of the minority group in the training data can have completely opposite impacts in different datasets.

2. RELATED WORK

Improving the minority-group performance with known group attributes. With known group attributes, distributionally robust optimization (DRO) (Sagawa* et al., 2020) minimizes the worst-group training loss instead of solving ERM. DRO is more computationally expensive than ERM and does not always outperform ERM in the minority-group test error. Spurious correlations (Sagawa et al., 2020) can be viewed as one cause of group imbalance, where strong associations between labels and irrelevant features exist in the training samples. Different from approaches that address spurious correlations, such as down-sampling the majority group (Japkowicz & Stephen, 2002; Haixiang et al., 2017; Buda et al., 2018), up-weighting the minority group (Shimodaira, 2000; Byrd & Lipton, 2019), and removing spurious features (Garg et al., 2019; Elhabian et al., 2008; Zemel et al., 2013), this paper requires neither the special model of spurious correlations nor any group attribute information.

Fairness in machine learning has received a lot of interest recently (Barocas & Selbst, 2016), and a substantial body of work has been developed to enhance fairness under various notions (Dwork et al., 2012; Feldman et al., 2015; Hardt et al., 2016; Kleinberg et al., 2017; Kearns et al., 2018; Chen et al., 2018; Makhlouf et al., 2021; Li et al., 2021). For example, DRO maximizes the welfare of the worst group, satisfying the fairness notion of Rawls (2001). Different from the majority of these works, this paper solves ERM directly without group attribute information. Moreover, this paper focuses on characterizing the generalization performance of ERM as a function of the input distribution but does not attempt to evaluate fairness across groups.

Generalization performance with standard Gaussian input for one-hidden-layer neural networks.
(Brutzkus & Globerson, 2017; Du et al., 2018b; Ge et al., 2018; Liang et al., 2018; Li & Yuan, 2017; Shamir, 2018; Safran & Shamir, 2018; Tian, 2017) consider infinite training samples. (Zhong et al., 2017b;a) characterize the sample complexity of fully connected neural networks with smooth activation functions. Zhang et al. (2019; 2020b) extend the analysis to the non-smooth ReLU activation for fully-connected and convolutional neural networks, respectively. Fu et al. (2020) analyzes the cross-entropy loss function for binary classification problems. Zhang et al. (2020a) analyzes the generalizability of graph neural networks for both regression and binary classification problems.

Theoretical characterization of the learning performance with other input distributions for one-hidden-layer neural networks. Yoshida & Okada (2019) analyzes the training loss for a single Gaussian with an arbitrary co-variance. Mignacco et al. (2020) quantifies the SGD evolution when trained on the Gaussian mixture model. When the hidden layer contains only one neuron, Du et al. (2018a) analyzes rotationally invariant distributions. With an infinite number of neurons and an infinite input dimension, Mei et al. (2018) analyzes the generalization error based on a mean-field analysis for distributions such as Gaussian mixtures with the same mean. Ghorbani et al. (2020) considers inputs with low-dimensional structures. No sample complexity is provided in any of these works.

Notations: Z is a matrix, and z is a vector. z_i denotes the i-th entry of z, and Z_{i,j} denotes the (i, j)-th entry of Z. [K] denotes the set of integers from 1 to K. I_d and e_i represent the identity matrix in R^{d×d} and the i-th standard basis vector, respectively. δ_i(Z) denotes the i-th largest singular value of Z, and the matrix norm is ∥Z∥ = δ_1(Z). f(x) = O(g(x)) (or Ω(g(x)), Θ(g(x))) means that f(x) increases at most at the rate of, at least at the rate of, or in the order of g(x), respectively.

3. PROBLEM FORMULATION AND ALGORITHM

We consider the classification problem with an unbalanced dataset using fully connected neural networks over n independent training examples {(x_i, y_i)}_{i=1}^n drawn from a data distribution. The learning algorithm minimizes the empirical risk function via gradient descent (GD). In what follows, we present the data model and the neural network model considered in this paper.

Data Model. Let x ∈ R^d and y ∈ R denote the input feature and label, respectively. We consider an unbalanced dataset that consists of L (L ≥ 2) groups of data, where the feature x in group l (l ∈ [L]) is drawn from a multivariate Gaussian distribution with mean µ_l ∈ R^d and covariance Σ_l ∈ R^{d×d}. Specifically, x follows the Gaussian mixture model (GMM) (Pearson, 1894; Titterington et al., 1985; Hsu & Kakade, 2013; Vempala & Wang, 2004; Moitra & Valiant, 2010; Regev & Vijayaraghavan, 2017), denoted as x ∼ Σ_{l=1}^L λ_l N(µ_l, Σ_l), where λ_l ∈ (0, 1) is the probability of sampling from distribution l and represents the expected fraction of group-l data, with Σ_{l=1}^L λ_l = 1. Group l is defined as a minority group if λ_l is less than 1/L. We use Ψ = {λ_l, µ_l, Σ_l, ∀l} to denote all parameters of the mixture model.¹ We consider binary classification with label y generated by a ground-truth neural network with unknown weights W* = [w*_1, ..., w*_K] ∈ R^{d×K} and sigmoid activation² function φ(x) = 1/(1 + exp(−x)), where

P(y = 1 | x) = H(W*, x) := (1/K) Σ_{j=1}^K φ(w*_j^⊤ x).  (1)

Learning model. Learning is performed over a neural network that has the same architecture as in (1), i.e., a one-hidden-layer fully connected neural network³ with weights denoted by W ∈ R^{d×K}.
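As a concrete illustration of the data model, the GMM features and the labels from (1) can be simulated as below. This sketch is our own addition, not the paper's code; all function and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_gmm_data(n, lams, mus, Sigmas, W_star):
    """Draw n features from x ~ sum_l lam_l N(mu_l, Sigma_l) and labels
    y ~ Bernoulli(H(W*, x)), where H(W*, x) = (1/K) sum_j phi(w*_j^T x)."""
    d, K = W_star.shape
    groups = rng.choice(len(lams), size=n, p=lams)  # latent group of each sample
    X = np.empty((n, d))
    for l in range(len(lams)):
        idx = groups == l
        X[idx] = rng.multivariate_normal(mus[l], Sigmas[l], size=int(idx.sum()))
    H = 1.0 / (1.0 + np.exp(-X @ W_star))           # phi(w*_j^T x_i), shape n x K
    p = H.mean(axis=1)                              # P(y = 1 | x_i) as in Eq. (1)
    y = (rng.random(n) < p).astype(float)
    return X, y, groups
```

For instance, with L = 2, d = 5, and K = 3 (the sizes used later in the synthetic experiments), calling `sample_gmm_data(1000, [0.8, 0.2], [np.zeros(5), np.ones(5)], [np.eye(5)]*2, rng.normal(size=(5, 3)))` produces a dataset in which group 2 is the minority.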
Given n training samples {(x_i, y_i)}_{i=1}^n, where x_i follows the GMM model and y_i is generated from (1), we aim to find the model weights by minimizing the nonconvex empirical risk f_n(W):

min_{W ∈ R^{d×K}} f_n(W) := (1/n) Σ_{i=1}^n ℓ(W; x_i, y_i),

where ℓ(W; x_i, y_i) is the cross-entropy loss function, i.e.,

ℓ(W; x_i, y_i) = −y_i · log(H(W, x_i)) − (1 − y_i) · log(1 − H(W, x_i)).

Note that for any permutation matrix P, W P corresponds to permuting the neurons of a network with weights W. Therefore, H(W, x) = H(W P, x) and f_n(W P) = f_n(W). The estimation is considered successful if one finds any column permutation of W*. The average generalization performance of a learned model W is evaluated by the average risk f̄(W) = E_{x ∼ Σ_{l=1}^L λ_l N(µ_l, Σ_l)} ℓ(W; x, y), and the generalization performance on group l is evaluated by the group-l risk f̄_l(W) = E_{x ∼ N(µ_l, Σ_l)} ℓ(W; x, y).

Training Algorithm. Our algorithm starts from an initialization W_0 ∈ R^{d×K} computed by the tensor initialization method (Subroutine 1 in Section B.1) and then updates the iterates W_t using gradient descent with step size⁴ η_0. The computational complexity of the tensor initialization is O(Knd), and the per-iteration complexity of the gradient step is also O(Knd). We defer the details of Algorithm 1 to Section B of the supplementary material.
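To make the ERM objective concrete, the following sketch (our own illustration, not the paper's Algorithm 1) runs plain gradient descent on f_n(W) with the cross-entropy loss; the tensor initialization is replaced by an arbitrary starting point W0 for simplicity, so no recovery guarantee is implied.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def risk(W, X, y):
    """Empirical risk f_n(W): average cross entropy of H(W, x) = (1/K) sum_j phi(w_j^T x)."""
    H = np.clip(sigmoid(X @ W).mean(axis=1), 1e-12, 1 - 1e-12)
    return np.mean(-y * np.log(H) - (1 - y) * np.log(1 - H))

def erm_gd(X, y, W0, step=0.2, iters=300):
    """Minimize f_n(W) by gradient descent from W0 (a stand-in for tensor init)."""
    n, _ = X.shape
    K = W0.shape[1]
    W = W0.copy()
    for _ in range(iters):
        Z = sigmoid(X @ W)                       # phi(w_j^T x_i), n x K
        H = np.clip(Z.mean(axis=1), 1e-12, 1 - 1e-12)
        dH = (H - y) / (H * (1 - H))             # d loss_i / d H_i for cross entropy
        # chain rule: dH_i/dw_j = (1/K) phi'(w_j^T x_i) x_i, with phi' = phi(1 - phi)
        G = X.T @ (dH[:, None] * Z * (1 - Z)) / (n * K)
        W -= step * G
    return W
```

On data generated from the ground-truth model, the empirical risk of the returned iterate is lower than that of the starting point, reflecting the descent behavior (not the linear rate, which relies on proper initialization).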

4. MAIN THEORETICAL RESULTS

We will formally present our main theory below, and the insights are summarized in Section 4.1. For convenience of presentation, some quantities are defined here, all of which can be viewed as constants. Define σ_max = max_{l∈[L]} ∥Σ_l∥^{1/2} and σ_min = min_{l∈[L]} ∥Σ_l^{−1}∥^{−1/2}. Let τ = σ_max/σ_min. We assume τ = Θ(1), indicating that σ_max and σ_min are of the same order. Let δ_i(W*) denote the i-th largest singular value of W*. Let κ = δ_1(W*)/δ_K(W*), and define η = Π_{i=1}^K δ_i(W*)/δ_K(W*).

¹ In practice, Ψ can be estimated by the EM algorithm (Redner & Walker, 1984) and the moment-based method (Hsu & Kakade, 2013). The EM algorithm returns model parameters within Euclidean distance O((d/n)^{1/2}) when the number of mixture components L is known. When L is unknown, one usually over-specifies an estimate L′ > L, and the estimation error of the EM algorithm then scales as O((d/n)^{1/4}). Please refer to (Ho & Nguyen, 2016; Ho et al., 2020; Dwivedi et al., 2020a;b) for details.
² The results can be generalized to any activation function φ with bounded φ, φ′, and φ″, where φ′ is even. Examples include tanh and erf.
³ All the weights in the second layer are assumed to be fixed to facilitate the analysis. This is a standard assumption in theoretical generalization analysis (Zhang et al., 2019; Fu et al., 2020; Zhang et al., 2020a).
⁴ Algorithm 1 employs a constant step size. One can potentially speed up the convergence, i.e., reduce v, by using a variable step size. We leave the corresponding theoretical analysis for future work.

Theorem 1.
There exist ε_0 ∈ (0, 1/4) and positive value functions B(Ψ) (sample complexity parameter), q(Ψ) (convergence rate parameter), and E_w(Ψ), E(Ψ), E_l(Ψ) (generalization parameters) such that, as long as the sample size n satisfies

n ≥ n_sc := poly(ε_0^{−1}, κ, η, τ, K, δ_1(W*)) · B(Ψ) · d log² d,

then with probability at least 1 − d^{−10}, the iterates {W_t}_{t=1}^T returned by Algorithm 1 with step size η_0 = O((Σ_{l=1}^L λ_l (∥µ_l∥ + ∥Σ_l∥^{1/2})²)^{−1}) converge linearly, up to a statistical error, to a critical point Ŵ_n with rate of convergence v, i.e.,

∥W_t − Ŵ_n∥_F ≤ v(Ψ)^t ∥W_0 − Ŵ_n∥_F + (η_0 ξ / (1 − v(Ψ))) √(dK log n / n),  where v(Ψ) = 1 − K^{−2} q(Ψ),

and ξ ≥ 0 is the upper bound on the entry-wise additive noise in the gradient computation. Moreover, there exists a permutation matrix P* such that

∥Ŵ_n − W* P*∥_F ≤ E_w(Ψ) · poly(κ, η, τ, δ_1(W*)) · Θ(K^{5/2} (1 + ξ)) · √(d log n / n).

The average population risk f̄ and the group-l risk f̄_l satisfy

f̄ ≤ E(Ψ) · poly(κ, η, τ, δ_1(W*)) · Θ(K^{5/2} (1 + ξ)) · √(d log n / n),  (10)
f̄_l ≤ E_l(Ψ) · poly(κ, η, τ, δ_1(W*)) · Θ(K^{5/2} (1 + ξ)) · √(d log n / n).

The closed-form expressions of B, q, E_w, E, and E_l are given in Section D of the supplementary material and skipped here. The quantitative impact of the GMM parameters Ψ on the learning performance varies in different regimes and can be derived from Theorem 1. The following corollary summarizes the impact of Ψ on the learning performance in some sample regimes.

Table 1: Impact of GMM parameters on the learning performance in sample regimes. The last two columns vary λ_l under constant ∥Σ_j∥'s and equal ∥µ_j∥'s for all groups.

|                              | Σ_l changes, ∥Σ_l∥ = o(1) | Σ_l changes, ∥Σ_l∥ = Ω(1) | µ_l changes          | λ_l changes, ∥Σ_l∥ = σ²_min | λ_l changes, ∥Σ_l∥ = σ²_max |
| B(Ψ), sample compl. n_sc     | O(∥Σ_l∥^{−3})             | O(∥Σ_l∥³)                 | O(poly(∥µ_l∥))       | O(1/(1+λ_l)²)               | O(1) − Θ(1)/(1+λ_l)²        |
| conv. rate v(Ψ) ∝ −q(Ψ)      | 1 − Θ(∥Σ_l∥³)             | 1 − Θ(1/(1+∥Σ_l∥))        | 1 − Θ(1/(∥µ_l∥²+1))  | Θ(1/(1+λ_l))                | 1 − Θ(1/(1+λ_l))            |
| E_w(Ψ), ∥Ŵ_n − W*P∥_F        | O(1) − Θ(∥Σ_l∥³)          | O(√∥Σ_l∥)                 | O(1 + ∥µ_l∥)         | O(1/(1+√λ_l))               | O(1 + λ_l)                  |
| E(Ψ), average risk f̄         | O(1) − Θ(∥Σ_l∥³)          | O(∥Σ_l∥)                  | O(1 + ∥µ_l∥²)        | O(1/(1+λ_l))                | O(1) − Θ(1)/(1+λ_l)         |
| E_l(Ψ), group-l risk f̄_l     | O(1) − Θ(∥Σ_l∥³)          | O(∥Σ_l∥)                  | O(1 + ∥µ_l∥²)        | O(1/(1+√λ_l))               | O(1 + λ_l)                  |

Corollary 1. When we vary one parameter of group l for any l ∈ [L] of the GMM model Ψ and fix all the others, the learning performance degrades, in the sense that the sample complexity n_sc, the convergence rate v, ∥Ŵ_n − W* P∥_F, the average risk f̄, and the group-l risk f̄_l all increase (details summarized in Table 1), as long as any of the following conditions holds:
(i) ∥Σ_l∥ approaches 0;
(ii) ∥Σ_l∥ increases from some constant;
(iii) ∥µ_l∥ increases from 0;
(iv) λ_l decreases, provided that ∥Σ_l∥ = σ²_min, i.e., group l has the smallest group-level co-variance, where the ∥Σ_j∥ are all constants and ∥µ_i∥ = ∥µ_j∥ for all i, j ∈ [L];
(v) λ_l increases, provided that ∥Σ_l∥ = σ²_max, i.e., group l has the largest group-level co-variance, where the ∥Σ_j∥ are all constants and ∥µ_i∥ = ∥µ_j∥ for all i, j ∈ [L].

To the best of our knowledge, Theorem 1 provides the first characterization of the sample complexity, learning rate, and generalization performance under the Gaussian mixture model. It is also the first to characterize the per-group generalization performance in addition to the average generalization.

4.1. THEORETICAL INSIGHTS

We summarize the crucial implications of Theorem 1 and Corollary 1 as follows.

(P1). Training convergence and generalization guarantee. The iterates W_t converge to a critical point Ŵ_n linearly, and the distance between Ŵ_n and W* P* is O(√(d log n / n)) for a certain permutation matrix P*. When the computed gradients contain noise, there is an additional error term of O(ξ √(d log n / n)), where ξ is the noise level (ξ = 0 in the noiseless case). Moreover, the average risk over all groups and the risk of each individual group are both O((1 + ξ) √(d log n / n)).

(P2). Sample complexity. For a given GMM, the sample complexity is Θ(d log² d), where d is the feature dimension. This result is of the same order as the sample complexity for standard Gaussian input in (Fu et al., 2020) and (Zhong et al., 2017b). Our bound is almost order-wise optimal with respect to d because the degree of freedom is dK. The additional multiplier of log² d results from the concentration bounds in the proof technique. We focus on the dependence on the feature dimension d and treat the network width K as a constant; the sample complexity in (Fu et al., 2020) and (Zhong et al., 2017b) is also d · poly(K, log d).

(P3). Learning performance is improved at a medium regime of group-level co-variance. On the one hand, when ∥Σ_l∥ is Ω(1), the learning performance degrades as ∥Σ_l∥ increases, in the sense that the sample complexity n_sc, the convergence rate v, the estimation error of W*, the average risk f̄, and the group-l risk f̄_l all increase. This is due to the saturation of the loss and the gradient when the samples have a large magnitude. On the other hand, when ∥Σ_l∥ is o(1), the learning performance also degrades as ∥Σ_l∥ approaches zero. The intuition is that in this regime, the input data concentrate on a few vectors, and the optimization problem does not have a benign landscape.

(P4).
Increasing the fraction of the minority group data does not always improve the generalization; the performance also depends on the mean and co-variance of individual groups. Take ∥Σ_j∥ = Θ(1) for all groups j, with ∥µ_j∥ the same for all j, as an example (columns 5 and 6 of Table 1). When ∥Σ_l∥ is the smallest among all groups, increasing λ_l improves the learning performance. When ∥Σ_l∥ is the largest among all groups, increasing λ_l actually degrades the performance. The intuition is that, from (P3), the learning performance is enhanced at a medium regime of group-level co-variance. Thus, increasing the fraction of a group with a medium level of co-variance improves the performance, while increasing the fraction of a group with large co-variance degrades the learning performance. Similarly, when augmenting the training data, an augmentation method that introduces a medium level of variance could improve the learning performance, while an augmentation method that introduces a significant level of variance could hurt the learning performance.

(P5). Group-level mean shifts from zero degrade the learning performance. The learning performance degrades as ∥µ_l∥ increases. An intuitive explanation of the degradation is that some training samples have a significantly large magnitude such that the sigmoid function saturates.
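The saturation intuition behind (P3) and (P5) can be checked numerically: once |w^⊤x| is large (e.g., due to a large group-level mean or co-variance), the sigmoid derivative φ′(z) = φ(z)(1 − φ(z)) is exponentially small, so such samples contribute almost no gradient signal. This small check is our own illustration, not part of the paper.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_deriv(z):
    """phi'(z) = phi(z) * (1 - phi(z)), the factor appearing in every gradient term."""
    s = sigmoid(z)
    return s * (1.0 - s)

# As the pre-activation magnitude grows, the derivative collapses toward zero,
# so the gradient contribution of large-magnitude samples saturates.
mags = np.array([0.0, 1.0, 5.0, 10.0, 50.0])
derivs = sigmoid_deriv(mags)
```

Here `derivs` is strictly decreasing in the magnitude: it peaks at 0.25 for z = 0 and drops below 10^{-4} already at |z| = 10.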

4.2. PROOF IDEA AND TECHNICAL NOVELTY

Different from analyses based on generalized linear models, our paper handles additional technical challenges of nonconvex optimization due to the multi-neuron architecture, the GMM model, and a more complicated activation and loss. The main idea of the proof is to show that, with a sufficiently large n, the nonconvex empirical risk f_n(W) is almost convex in a small neighborhood around W* (or any permutation W* P). Then, if W_0 is initialized in any of these local regions, gradient-based iterates can be proved to converge to W* (or W* P). The idea of the tensor initialization is to first find quantities (see Q_j in (14) in the supplementary material) that are proven to be functions of tensors of w*_i. The method then approximates these quantities numerically using the training samples and applies a tensor decomposition method to the estimated quantities to obtain W_0, which is an estimate of W*. With a sufficiently large number of training samples n, the estimate W_0 can be proved to lie in the local convex region. The full proof is in Section D of the supplementary material.

Our algorithmic and analytical framework is built upon recent works on the generalization analysis of one-hidden-layer neural networks, see, e.g., (Zhong et al., 2017b; Zhang et al., 2019; Fu et al., 2020; Zhang et al., 2020a; 2021b), which assume that x_i follows the standard Gaussian distribution and cannot be directly extended to the GMM. This paper makes new technical contributions in the following aspects. First, we characterize the local convex region near W* for the GMM model, while existing results only hold for standard Gaussian data. Second, new tools, including matrix concentration bounds, are developed to explicitly quantify the impact of Ψ on the sample complexity. Third, we design and analyze new tensors for the mixture model to initialize properly, while the previous tensor methods in (Zhong et al., 2017b; Zhang et al., 2019; Fu et al., 2020; Zhang et al., 2020a) rely on a rotation-invariance property that only holds for zero-mean Gaussians.

5.1. EXPERIMENTS ON SYNTHETIC DATASETS

We first verify the theoretical bounds in Theorem 1 on synthetic data. Each entry of W* ∈ R^{d×K} is generated from N(0, 1). The training data {(x_i, y_i)}_{i=1}^n are generated using the GMM model and (1). If not otherwise specified, L = 2, d = 5, and K = 3.⁶ To reduce the computational time, we randomly initialize near W* instead of computing the tensor initialization.⁷

Sample complexity. We first study the impact of d on the sample complexity. Let µ_1 = 1 in R^d and µ_2 = 0. Let Σ_1 = Σ_2 = I and λ_1 = λ_2 = 0.5. We randomly initialize M times and let W

Convergence analysis. We next study the convergence rate of Algorithm 1. Figure 3(a) shows the impact of ∥µ_l∥. λ_1 = λ_2 = 0.5, µ_1 = −µ_2 = C · 1 for a positive C, and Σ_1 = Σ_2 = Λ^⊤ D Λ. Here Λ is generated by computing the left-singular vectors of a d × d random matrix drawn from the Gaussian distribution, and D = diag(1, 1.1, 1.2, 1.3, 1.4). n = 1 × 10⁴. Algorithm 1 always converges linearly when ∥µ_1∥ changes. Moreover, as ∥µ_1∥ increases, Algorithm 1 converges more slowly. Figure 3(b) shows the impact of the variance of the Gaussian mixture model. λ_1 = λ_2 = 0.5, µ_1 = 1, µ_2 = −1, Σ_1 = Σ_2 = Σ = σ² · Λ^⊤ D Λ. n = 5 × 10⁴. We change ∥Σ∥ by changing σ. Among the values we test, Algorithm 1 converges fastest when ∥Σ∥ = 1. The convergence rate slows down when ∥Σ∥ increases or decreases from 1. All results are consistent with the predictions of Corollary 1. We then study the impact of K on the convergence rate.

Generalization performance. We evaluate the impact of the mean/co-variance of the minority group on the generalization. n = 2 × 10⁴. Let λ_1 = 0.8, λ_2 = 0.2, µ_1 = −1, Σ_1 = I. First, we let µ_2 = μ̄_2 · 1 and Σ_2 = I. Figure 4(c) shows that both the average risk and the group-2 risk increase as μ̄_2 increases, consistent with (P5). Then we set λ_1 = λ_2 = 0.5, µ_1 = 1, µ_2 = 2 · 1, Σ_1 = I, and Σ_2 = σ²_2 · I. Figure 4(b) indicates that both the average and the group-2 risk first decrease and then increase as σ_2 increases, consistent with (P3). Next, we study the impact of increasing the fraction of the minority group. µ_1 = µ_2 = 0. Let group 2 be the minority group. In Figure 5(a), Σ_1 = 10 · I and Σ_2 = I, so the minority group has a smaller level of co-variance; when λ_2 increases from 0 to 0.5, both the average and group-2 risks decrease. In Figure 5(b), Σ_1 = I and Σ_2 = 10 · I, so the minority group has a higher level of co-variance; when λ_2 increases from 0 to 0.3, both the average and group-2 risks increase. As predicted by insight (P4), increasing λ_2 does not necessarily improve the generalization of group 2.

⁶ Like Zhong et al. (2017b); Zhang et al. (2019); Fu et al. (2020), we consider a small-sized network in the synthetic experiments to reduce the computational time, especially for computing the sample complexity in Figure 2. Our results hold for large networks too.
⁷ The existing methods based on tensor initialization all use random initialization in synthetic experiments to reduce the computational time; see Fu et al. (2020); Zhang et al. (2019; 2020a; 2021b;a) as examples. We compare tensor initialization and local random initialization numerically in Section B.1 of the supplementary material and show that they have the same performance.
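The covariance construction Σ = σ² Λ^⊤ D Λ used in the convergence experiments can be reproduced as follows (our own sketch). Since Λ is orthogonal (left-singular vectors of a random Gaussian matrix), the eigenvalues of Σ are σ² times the diagonal of D, so ∥Σ∥ = 1.4 σ².

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5

# Lambda: left-singular vectors of a random d x d Gaussian matrix (orthogonal)
Lam, _, _ = np.linalg.svd(rng.normal(size=(d, d)))
D = np.diag([1.0, 1.1, 1.2, 1.3, 1.4])

def make_cov(sigma):
    """Sigma = sigma^2 * Lam^T D Lam, with operator norm sigma^2 * 1.4."""
    return sigma**2 * Lam.T @ D @ Lam
```

Scaling σ thus moves ∥Σ∥ up and down without changing the eigenvalue spread, which is exactly how the experiments vary the co-variance level.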

5.2. IMAGE CLASSIFICATION ON DATASET CELEBA

We choose the attribute "blonde hair" as the binary classification label. ResNet-9 (He et al., 2016) is selected as the learning model here because it has been applied in many simple computer vision tasks (Wu et al., 2018; Dutta et al., 2020). To study the impact of co-variance, we pick 4000 female (majority) and 1000 male (minority) images and implement Gaussian data augmentation to create 300 additional images for the male group. Specifically, we select 300 out of the 1000 male images and add i.i.d. noise drawn from N(0, δ²) to every entry. The test set includes 500 male and 500 female images. Figure 1 shows that when δ² increases, i.e., when the co-variance of the minority group increases, both the minority-group and average test accuracy first increase and then decrease, coinciding with our Insight (P3). Then we fix the total number of training data to be 5000 and vary the fractions of the two groups. From Figure 6(a) and (b), we observe opposite trends when we increase the fraction of the minority group in the training data, with the male being the minority and with the female being the minority, respectively. This is consistent with Insight (P4). Due to the space limit, our results on the CIFAR-10 dataset are deferred to Section A in the supplementary material.

6. CONCLUSION

This paper provides a novel theoretical framework for characterizing neural network generalization with group imbalance. The group imbalance is formulated using the Gaussian mixture model, and this paper explicitly quantifies the impact of each group on the sample complexity, the convergence rate, and the average and group-level generalization. The learning performance is enhanced when the group-level co-variance is in a medium regime and the group-level mean is close to zero. Moreover, increasing the fraction of the minority group does not guarantee improved group-level generalization. Our results are limited to one-hidden-layer neural networks for binary classification problems. One future direction is to extend the analysis to neural networks with multiple hidden layers and to multi-class classification. Because of the concatenation of nonlinear activation functions, the analysis of the landscape of the empirical risk and the design of a proper initialization are more challenging and require the development of new tools. Like many existing works, our sample complexity analysis is based on a sufficient condition for training success, although it is already almost order-wise optimal. Another future direction is to formally characterize the information-theoretic lower bound on the sample complexity. We see no ethical or immediate negative societal consequence of our work.

We begin our Appendix here. Section A provides more experiment results as a supplement to Section 5. Section B introduces the algorithm, especially the tensor initialization, in detail. Section C includes some definitions and properties as preliminaries to our proofs. Section D gives the proof of Theorem 1 and Corollary 1, followed by Sections E, F, and G, which prove the three key lemmas about local convexity, linear convergence, and tensor initialization, respectively.

A MORE EXPERIMENT RESULTS

We present our experiment results on the empirical datasets CelebA (Liu et al., 2015) and CIFAR-10 in this section. To be more specific, we evaluate the impact of the variance levels introduced by different data augmentation methods on the learning performance. We also evaluate the impact of the minority-group fraction in the training data on the learning performance. All the experiments are reported in the format of "mean ± 2 × standard deviation" with a random seed equal to 10. We implement our experiments on an NVIDIA GeForce RTX 2070 Super GPU and a workstation with 8 cores of 3.40GHz Intel i7 CPU.

A.1 TESTS ON CELEBA

In addition to the Gaussian augmentation method in Figure 1(b), we also evaluate the performance of data augmentation by cropping, shown in Figure 7(a). The setup is exactly the same as that for Gaussian augmentation, except that we augment the data by cropping instead of adding Gaussian noise. Specifically, to generate an augmented image, we randomly crop an image to a size of w × w × 3 and then resize it back to 224 × 224 × 3. One can observe that the minority-group and average test accuracy first increase and then decrease as w increases, which is in accordance with Insight (P3).

Figure 7: The test accuracy on the CelebA dataset with the data augmentation method of cropping.
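Both augmentation schemes used in these experiments can be sketched in a few lines of NumPy. This is our own minimal illustration (a real pipeline would use an image library with proper interpolation; the nearest-neighbor resize below is a stand-in), operating on normalized images with pixel values in [0, 1].

```python
import numpy as np

rng = np.random.default_rng(10)

def gaussian_augment(img, delta):
    """Add i.i.d. N(0, delta^2) noise to every entry of a normalized image."""
    return np.clip(img + rng.normal(0.0, delta, img.shape), 0.0, 1.0)

def crop_resize_augment(img, w, out=224):
    """Randomly crop a w x w x 3 patch, then resize back to out x out x 3
    (nearest-neighbor resize as a stand-in for proper interpolation)."""
    H, W = img.shape[:2]
    top = rng.integers(0, H - w + 1)
    left = rng.integers(0, W - w + 1)
    patch = img[top:top + w, left:left + w]
    rows = (np.arange(out) * w // out).astype(int)
    cols = (np.arange(out) * w // out).astype(int)
    return patch[rows][:, cols]
```

Varying δ (or w) controls how much variance the augmentation injects into the minority group, which is the knob studied in Figures 1 and 7.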

A.2 TESTS ON CIFAR-10

Group 1 contains images with the attributes "bird," "cat," "deer," "dog," "frog," and "horse." Group 2 contains "airplane" images. In this setting, Group 1 has a larger variance. Because each image in CIFAR-10 only has one attribute, we consider the binary classification setting where all images in Group 1 are labeled as "animal" and all images in Group 2 are labeled as "airplane." This is a special scenario in which the group label is also the classification label. Note that our results hold for general setups where group labels and classification labels are irrelevant, as in our previous results on CelebA. LeNet-5 (Lecun et al., 1998) is selected as the learning model. We first pick 8000 animal images (majority) and 2000 airplane images (minority). We select 1000 out of the 2000 airplane images to implement data augmentation, including both Gaussian augmentation and random cropping. For Gaussian augmentation, we add i.i.d. Gaussian noise drawn from N(0, δ²) to each entry; here the noise is added to the raw image, where the pixel value ranges from 0 to 255, whereas in the CelebA experiment (Figure 1(b)) the noise is added to the image after normalization, where the pixel value ranges from 0 to 1. For random cropping, we randomly crop the image to a size of w × w × 3 and then resize it back to 32 × 32 × 3. Figure 8 shows that when δ or w increases, i.e., the variance introduced by either augmentation method increases, both the minority-group and average test accuracy first increase and then decrease, which is consistent with our Insight (P3). Then we fix the total number of training data at 5000 and vary the fractions of the two groups. One can see opposite trends in Figure 9 if we increase the fraction of the minority group, with the airplane being the minority versus the animal being the minority, which reflects our Insight (P4).
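The Gaussian augmentation step above can be sketched as follows. The clipping back to the valid pixel range is our assumption, since the text only specifies the noise distribution N(0, δ²) applied to raw pixels.

```python
import numpy as np

def gaussian_augment(img, delta, rng=None):
    """Add i.i.d. N(0, delta^2) noise to each entry of a raw uint8 image
    (pixel values in [0, 255]) and clip back to the valid range; the
    clipping is our illustrative choice."""
    rng = np.random.default_rng() if rng is None else rng
    noisy = img.astype(np.float32) + rng.normal(0.0, delta, size=img.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)  # a CIFAR-sized image
aug = gaussian_augment(img, delta=10.0, rng=rng)
print(aug.shape, aug.dtype)
```

Increasing `delta` raises the variance injected into the minority group, the quantity swept in Figure 8.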

B ALGORITHM

We first introduce new notation to be used in this part and summarize the key notions in Table 2. We write f(x) ≲ (≳) g(x) if f(x) ≤ (≥) Θ(g(x)). The gradient and the Hessian of a function f(W) are denoted by ∇f(W) and ∇²f(W), respectively. A ⪰ 0 means A is a positive semi-definite (PSD) matrix. A^{1/2} denotes a matrix such that A = (A^{1/2})². The outer product of vectors z_i ∈ R^{n_i}, i ∈ [l], is defined as T = z_1 ⊗ ··· ⊗ z_l ∈ R^{n_1×···×n_l} with T_{j_1···j_l} = (z_1)_{j_1} ··· (z_l)_{j_l}. Given a tensor T ∈ R^{n_1×n_2×n_3} and matrices A ∈ R^{n_1×d_1}, B ∈ R^{n_2×d_2}, C ∈ R^{n_3×d_3}, the (i_1, i_2, i_3)-th entry of the tensor T(A, B, C) is given by

T(A, B, C)_{i_1,i_2,i_3} = Σ_{i'_1=1}^{n_1} Σ_{i'_2=1}^{n_2} Σ_{i'_3=1}^{n_3} T_{i'_1,i'_2,i'_3} A_{i'_1,i_1} B_{i'_2,i_2} C_{i'_3,i_3}.

Table 2: Summary of notation.
λ_l, µ_l, Σ_l, l ∈ [L] — the fraction, mean, and covariance of the l-th component of the Gaussian mixture distribution, respectively.
d, n, K — the feature dimension, the number of training samples, and the number of neurons, respectively.
W*, W_t — W* is the ground-truth weight matrix; W_t is the updated weight matrix in the t-th iteration.
f_n, f, ℓ — f_n is the empirical risk function; f is the average (population) risk function; ℓ is the cross-entropy loss function.
Ψ, σ_max, σ_min, τ — Ψ denotes our Gaussian mixture model (λ_l, µ_l, Σ_l, ∀l); σ_max = max_{l∈[L]} ‖Σ_l‖^{1/2}; σ_min = min_{l∈[L]} ‖Σ_l^{-1}‖^{-1/2}; τ = σ_max/σ_min.
δ_i(W*), η, κ, i ∈ [K] — δ_i(W*) is the i-th largest singular value of W*; η and κ are two functions of W*.
ρ(u, σ), Γ(Ψ), D_m(Ψ) — functions of the Gaussian mixture distribution Ψ used to develop Theorem 1.
ν_i, ξ — ν_i is the gradient noise; ξ is the upper bound of the noise level.
Q_j, j = 1, 2, 3 — tensors used in the initialization.
B(Ψ) — a parameter appearing in the sample complexity bound (6).
v(Ψ), q(Ψ) — v(Ψ) is the convergence rate (7); q(Ψ) is a parameter in the definition of v(Ψ) (8).
E_w(Ψ), E, E_l — generalization parameters: E_w(Ψ) appears in the error bound of the model (9); E(Ψ) and E_l(Ψ) characterize the average risk (10) and the group-l risk (11), respectively.

The method starts from an initialization W_0 ∈ R^{d×K} computed by the tensor initialization method (Subroutine 1) and then updates the iterates W_t using gradient descent with step size η_0. To model the inaccuracy in computing the gradient, i.i.d. zero-mean noise matrices {ν_i}_{i=1}^n ⊂ R^{d×K} with bounded magnitude |(ν_i)_{jk}| ≤ ξ (j ∈ [d], k ∈ [K]) for some ξ ≥ 0 are added in (13) when computing the gradient of the loss in (3).

Algorithm 1 Our proposed learning algorithm
1: Input: training data {(x_i, y_i)}_{i=1}^n, the step size η_0 = O( ( Σ_{l=1}^L λ_l (‖µ_l‖_∞ + ‖Σ_l^{1/2}‖)² )^{-1} ), and the total number of iterations T
2: Initialization: W_0 ← tensor initialization method via Subroutine 1
3: Gradient descent: for t = 0, 1, ···, T − 1,
   W_{t+1} = W_t − η_0 · (1/n) Σ_{i=1}^n ( ∇ℓ(W_t; x_i, y_i) + ν_i ) = W_t − η_0 ( ∇f_n(W_t) + (1/n) Σ_{i=1}^n ν_i )
4: Output: W_T

Our tensor initialization method in Subroutine 1 is extended from Janzamin et al. (2014) and Zhong et al. (2017b). The idea is to compute quantities (Q_j in (14)) that are tensors of the w*_i and then apply a tensor decomposition method to estimate the w*_i. Because Q_j can only be estimated from training samples, tensor decomposition does not return w*_i exactly but provides a close approximation, and this approximation is used as the initialization for Algorithm 1. Because the existing method for tensor construction only applies to the standard Gaussian distribution, we exploit the relationship between probability density functions and tensor expressions developed in Janzamin et al. (2014) to design tensors suitable for the Gaussian mixture model. Formally,

Definition 1.
For j = 1, 2, 3, we define

Q_j := E_{x ∼ Σ_{l=1}^L λ_l N(µ_l, Σ_l)} [ y · (−1)^j p^{−1}(x) ∇^{(j)} p(x) ],   (14)

where p(x), the probability density function of the Gaussian mixture model, is defined as

p(x) = Σ_{l=1}^L λ_l (2π)^{−d/2} |Σ_l|^{−1/2} exp( −(1/2)(x − µ_l)^⊤ Σ_l^{−1} (x − µ_l) ).   (15)

If the Gaussian mixture model is symmetric, the symmetric distribution can be written as

x ∼ Σ_{l=1}^{L/2} λ_l ( N(µ_l, Σ_l) + N(−µ_l, Σ_l) )  if L is even;
x ∼ λ_1 N(0, Σ_1) + Σ_{l=2}^{(L−1)/2} λ_l ( N(µ_l, Σ_l) + N(−µ_l, Σ_l) )  if L is odd.   (16)

Q_j is a j-th-order tensor of the w*_i; e.g., Q_3 = (1/K) Σ_{i=1}^K E_{x ∼ Σ_{l=1}^L λ_l N(µ_l, Σ_l)}[ϕ′′′(w*_i^⊤ x)] w*_i^{⊗3}. These quantities cannot be directly computed from (14) but can be estimated by sample means, denoted by Q̂_j (j = 1, 2, 3), from the samples {x_i, y_i}_{i=1}^n. The following assumption guarantees that these tensors are nonzero and can thus be leveraged to estimate W*.

Assumption 1. The Gaussian mixture model in (16) satisfies the following conditions: 1. Q_1 and Q_3 are nonzero. 2. If the distribution is not symmetric, then Q_2 is nonzero.

Assumption 1 is a very mild assumption. Moreover, as indicated in Janzamin et al. (2014), in the rare case that some of the quantities Q_i (i = 1, 2, 3) are zero, one can construct higher-order tensors in a similar way as in Definition 1 and then estimate W* from those higher-order tensors. Subroutine 1 describes the tensor initialization method, which estimates the direction and the magnitude of w*_j, j ∈ [K], separately. The direction vector is denoted by w̄*_j = w*_j/‖w*_j‖, and the magnitude ‖w*_j‖ is denoted by z_j. Lines 2-6 estimate the subspace U spanned by {w*_1, ···, w*_K} using Q̂_2 or, in the case that Q_2 = 0, a second-order tensor obtained by projecting Q̂_3. Lines 7-8 estimate w̄*_j by employing the KCL algorithm (Kuleshov et al., 2015). Lines 9-10 estimate the magnitudes z_j via z = arg min_{α ∈ R^K} (1/2)‖Q̂_1 − Σ_{j=1}^K α_j w̄*_j‖². Finally, the returned estimate of W* is used as the initialization W_0 for Algorithm 1.
The computational complexity of Subroutine 1 is O(Knd) based on similar calculations as those in Zhong et al. (2017b) .
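The noisy gradient update in step 3 of Algorithm 1 can be sketched as follows for a one-hidden-layer sigmoid network with the cross-entropy loss, using the gradient form in (83). The synthetic data, the uniform noise model for ν_i, the random initialization near W* (instead of Subroutine 1), and all parameter values are our illustrative choices.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def network(W, X):
    # H(W, x) = (1/K) sum_j phi(w_j^T x), phi = sigmoid
    return sigmoid(X @ W).mean(axis=1)

def empirical_risk(W, X, y):
    H = np.clip(network(W, X), 1e-9, 1.0 - 1e-9)
    return -(y * np.log(H) + (1.0 - y) * np.log(1.0 - H)).mean()

def noisy_gd(X, y, W0, eta0, T, xi=0.0, rng=None):
    # W_{t+1} = W_t - eta0 * (grad f_n(W_t) + (1/n) sum_i nu_i), where
    # each entry of nu_i is bounded by xi (uniform noise is our choice
    # of bounded zero-mean noise).
    rng = np.random.default_rng() if rng is None else rng
    n, d = X.shape
    K = W0.shape[1]
    W = W0.copy()
    for _ in range(T):
        S = sigmoid(X @ W)                      # (n, K): phi(w_j^T x_i)
        H = S.mean(axis=1)                      # (n,)
        # d ell / d w_j = -(1/K) (y - H)/(H(1 - H)) phi'(w_j^T x) x
        coef = -(y - H) / (H * (1.0 - H)) / K   # (n,)
        grad = X.T @ (coef[:, None] * S * (1.0 - S)) / n   # (d, K)
        noise = rng.uniform(-xi, xi, size=(n, d, K)).mean(axis=0)
        W = W - eta0 * (grad + noise)
    return W

rng = np.random.default_rng(0)
d, K, n = 5, 2, 500
W_star = rng.uniform(-0.5, 0.5, size=(d, K))
X = rng.normal(size=(n, d))
y = (rng.random(n) < network(W_star, X)).astype(float)

W0 = W_star + 0.01 * rng.normal(size=(d, K))   # start in a local region around W*
loss0 = empirical_risk(W0, X, y)
W_hat = noisy_gd(X, y, W0, eta0=0.1, T=300, xi=1e-3, rng=rng)
lossT = empirical_risk(W_hat, X, y)
```

Starting inside the locally convex region, the empirical risk decreases along the iterates despite the bounded gradient noise.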

B.1 NUMERICAL EVALUATION OF TENSOR INITIALIZATION

Figure 10 shows the accuracy of the model returned by Algorithm 1. Here n = 2 × 10^5, d = 50, K = 2, λ_1 = λ_2 = 0.5, µ_1 = −0.3 · 1, and µ_2 = 0. We compare the tensor initialization with a random initialization in a local region {W ∈ R^{d×K} : ‖W − W*‖_F ≤ ϵ}. Each entry of W* is selected from [−0.1, 0.1] uniformly. Tensor initialization in Subroutine 1 returns an initial point close to one permutation of W*, with a relative error of 0.65. If the random initialization is also close to W*, e.g., ϵ = 0.1, then the gradient descent algorithm converges to a critical point from both initializations, and the linear convergence rate is the same. We also test a random initialization with each entry drawn from N(0, 25). This initialization is sufficiently far from W*, and the algorithm does not converge. On a MacBook Pro with an Intel(R) Core(TM) i5-7360U CPU at 2.30GHz and MATLAB 2017a, it takes 5.52 seconds to compute the tensor initialization. Thus, to reduce the computational time, we use a random initialization with ϵ = 0.1 in the experiments instead of computing the tensor initialization.

C PRELIMINARIES OF THE MAIN PROOF

In this section, we introduce some definitions and properties that will be used to prove the main results. First, we define sub-Gaussian random variables and the sub-Gaussian norm.

Definition 2. We say X is a sub-Gaussian random variable with sub-Gaussian norm K > 0 if (E|X|^p)^{1/p} ≤ K√p for all p ≥ 1. In addition, the sub-Gaussian norm of X, denoted ‖X‖_{ψ_2}, is defined as ‖X‖_{ψ_2} = sup_{p≥1} p^{−1/2} (E|X|^p)^{1/p}.

Then we define the following three quantities. ρ(u, σ) is motivated by the ρ parameter for the standard Gaussian distribution in Zhong et al. (2017b), and we generalize it to a Gaussian with an arbitrary mean and covariance. We define the new quantities Γ(Ψ) and D_m(Ψ) for the Gaussian mixture model.

Definition 3. (ρ-function). Let z ∼ N(u, I_d) ∈ R^d. Define α_q(i, u, σ) = E_{z_i ∼ N(u_i,1)}[ϕ′(σ · z_i) z_i^q] and β_q(i, u, σ) = E_{z_i ∼ N(u_i,1)}[ϕ′²(σ · z_i) z_i^q], ∀q ∈ {0, 1, 2}, where z_i and u_i are the i-th entries of z and u, respectively. Define ρ(u, σ) as

ρ(u, σ) = min_{i,j∈[d], j≠i} { (u_j² + 1)(β_0(i, u, σ) − α_0(i, u, σ)²), β_2(i, u, σ) − α_2(i, u, σ)²/(u_i² + 1) }.   (18)

Definition 4. (Γ-function). With (18) and κ, η defined in Section 3, we define

Γ(Ψ) = Σ_{l=1}^L (λ_l / (τKκ²η)) · (‖Σ_l^{−1}‖^{−1} / σ_max²) · ρ( W*^⊤µ_l / (δ_K(W*)‖Σ_l^{−1}‖^{−1/2}), δ_K(W*)‖Σ_l^{−1}‖^{−1/2} ),   (19)

D_m(Ψ) = Σ_{l=1}^L λ_l ( ‖µ_l‖ / ‖Σ_l^{−1}‖^{−1/2} + 1 )^m.

The ρ-function is defined to compute a lower bound on the Hessian of the population risk with Gaussian input. The Γ-function is the weighted sum of the ρ-function under the Gaussian mixture distribution; it is positive and upper bounded by a small value, and Γ goes to zero if all ‖µ_l‖ or all σ_l go to infinity. The D-function is a normalized parameter for the means and variances; it is lower bounded by 1, and it is an increasing function of ‖µ_l‖ and a decreasing function of σ_l.

Property 1. Given W* = UV ∈ R^{d×K}, where U ∈ R^{d×K} is an orthogonal basis of the column space of W*.
For any µ ∈ R^d, we can find an orthogonal decomposition of µ based on the column space of W*, i.e., µ = µ^U + µ^{U⊥}. If we consider the recovery problem of the FCN with a dataset drawn from the Gaussian mixture model, in which x_i ∼ N(µ_h, Σ_h) for some h ∈ [L], the problem is equivalent to the recovery problem of the FCN with x_i ∼ N(µ^U_h, Σ_h). Hence, we can assume without loss of generality that µ_l belongs to the column space of W* for all l ∈ [L].

Proof:

From (1) and (3), the recovery problem can be formulated as the minimization over W of a loss of the form g(W^⊤x_i, y_i); i.e., the data enter the objective only through the products W^⊤x_i. Any x_i ∼ N(µ_h, Σ_h) can be written as x_i = z + µ_h, where z ∼ N(0, Σ_h). Therefore,

W*^⊤x_i = W*^⊤(z + µ_h) = W*^⊤(z + µ^U_h + µ^{U⊥}_h) = W*^⊤(z + µ^U_h),

where the final step holds because W*^⊤µ^{U⊥}_h = 0. So the problem is equivalent to the recovery problem of the FCN with x_i ∼ N(µ^U_h, Σ_h).

Recall that the gradient noise ν_i ∈ R^{d×K} is zero-mean, and each of its entries is bounded in magnitude by ξ > 0.

Property 2. ‖ν_i‖_F is a sub-Gaussian random variable with its sub-Gaussian norm bounded by ξ√(dK).

Proof: Since |(ν_i)_{jk}| ≤ ξ for all j ∈ [d], k ∈ [K], we have ‖ν_i‖_F ≤ ξ√(dK) deterministically, so (E‖ν_i‖_F^p)^{1/p} ≤ (E|ξ√(dK)|^p)^{1/p} = ξ√(dK) ≤ ξ√(dK) · √p for all p ≥ 1.

We state some general properties of the ρ-function defined in Definition 3 in the following.

Property 3. ρ(u, σ) in Definition 3 satisfies the following properties:

1. (Positive) ρ(u, σ) > 0 for any u ∈ R^d and σ ≠ 0.

2. (Finite limit point for zero mean) ρ(u, σ) converges to a positive function of σ as the u_i go to 0, i.e., lim_{u_i→0} ρ(u, σ) := C_m(σ).

3. (Finite limit point for zero variance) When all u_i ≠ 0 (i ∈ [d]), ρ(u/σ, σ) converges to a strictly positive real function of u as σ goes to 0, i.e., lim_{σ→0} ρ(u/σ, σ) := C_s(u). When u_i = 0 for some i ∈ [d], lim_{σ→0} ρ(u/σ, σ) = 0.

4. (Lower bound function of the mean) When everything else except |u_i| is fixed, ρ( W*^⊤u/(σδ_K(W*)), σδ_K(W*) ) is lower bounded by a strictly positive real function, L_m( (ΛW*)^⊤Λu/(σδ_K(W*)), σδ_K(W*) ), which is monotonically decreasing as |u_i| increases.

5. (Lower bound function of the variance) When everything else except σ is fixed, ρ( W*^⊤u/(σδ_K(W*)), σδ_K(W*) ) is lower bounded by a strictly positive real function, L_s( W*^⊤u/(σδ_K(W*)), σδ_K(W*) ), which satisfies the following conditions: (a) there exists ζ_s′ > 0 such that σ^{−1} L_s( W*^⊤u/(σδ_K(W*)), σδ_K(W*) ) is an increasing function of σ when σ ∈ (0, ζ_s′); (b) there exists ζ_s > 0 such that L_s( W*^⊤u/(σδ_K(W*)), σδ_K(W*) ) is a decreasing function of σ when σ ∈ (ζ_s, +∞).

Proof:

(1) From Cauchy Schwarz's inequality, we have E zi∼N (ui,1) [ϕ ′ (σ • z i )] ≤ E zi∼N (ui,1) [ϕ ′2 (σ • z i )] (22) E zi∼N (ui,1) [ϕ ′ (σ • z i )z i • z i ] ≤ E zi∼N (ui,1) [ϕ ′2 (σ • z i )z 2 i ] • E zi∼N (ui,1) [z 2 i ] = E zi∼N (ui,1) [ϕ ′2 (σ • z i )z 2 i ] • u 2 i + 1 (23) The equalities of the ( 22) and ( 23) hold if and only if ϕ ′ is a constant function. Since that ϕ is the sigmoid function, the equalities of ( 22) and ( 23) cannot hold. By the definition of ρ(u, σ) in Definition 3, we have β 0 (i, u, σ) -α 2 0 (i, u, σ) > 0, β 2 (i, u, σ) - α 2 2 (i, u, σ) u 2 i + 1 > 0. (25) Therefore, ρ(u, σ) > 0 (26) (2) We can derive that lim ui→0 ( u 2 j σ 2 + 1) β 0 (i, u, σ) -α 2 0 (i, u, σ) = lim ui→0 ( u 2 j σ 2 + 1) ∞ -∞ ϕ ′2 (σ • z i )(2π) -1 2 exp(- ∥z i -u i ∥ 2 2 )dz i -( ∞ -∞ ϕ ′ (σ • z i )(2π) -1 2 exp(- ∥z i -u i ∥ 2 2 )dz i ) 2 =( u 2 j σ 2 + 1) ∞ -∞ ϕ ′2 (σ • z i )(2π) -1 2 exp(- ∥z i ∥ 2 2 )dz i -( ∞ -∞ ϕ ′ (σ • z i )(2π) -1 2 exp(- ∥z i ∥ 2 2 )dz i ) 2 , where the first step is by Definition 3, and the second step comes from the limit laws. Similarly, we also have lim ui→0 β 2 (i, u, σ) - 1 u 2 i + 1 α 2 2 (i, u, σ) = lim ui→0 ∞ -∞ ϕ ′2 (σ • z i )z 2 i (2π) -1 2 exp(- ∥z i -u i ∥ 2 2 )dz i -( 1 u 2 i + 1 ∞ -∞ ϕ ′ (σ • z i )z 2 i (2π) -1 2 exp(- ∥z i -u i ∥ 2 2 )dz i ) 2 = ∞ -∞ ϕ ′2 (σ • z i )z 2 i (2π) -1 2 exp(- ∥z i ∥ 2 2 )dz i -( ∞ -∞ ϕ ′ (σ • z i )z 2 i (2π) -1 2 exp(- ∥z i ∥ 2 2 )dz i ) 2 Since that ( 27) and ( 28) are positive due to Jensen's inequality, we can derive that ρ(u, σ) converges to a positive value function of σ as u i goes to 0, i.e. 
lim u→0 ρ(u, σ) := C m (σ) (3) When all u i ̸ = 0 (i ∈ [d]), lim σ→0 β 2 (i, u σ , σ) - 1 u 2 i σ 2 + 1 α 2 2 (i, u σ , σ) = lim σ→0 ∞ -∞ ϕ ′2 (σ • z i )z 2 i (2π) -1 2 exp(- ∥z i -ui σ ∥ 2 2 )dz i - 1 u 2 i σ 2 + 1 ∞ -∞ ϕ ′ (σ • z i )z 2 i (2π) -1 2 exp(- ∥z i -ui σ ∥ 2 2 )dz i 2 = lim σ→0 ∞ -∞ ϕ ′2 (u i • x i ) u 2 i σ 2 x 2 i (2π σ 2 u 2 i ) -1 2 exp(- ∥x i -1∥ 2 2 σ 2 u 2 i )dx i - 1 u 2 i σ 2 + 1 ∞ -∞ ϕ ′ (u i • x i ) u 2 i σ 2 x 2 i (2π σ 2 u 2 i ) -1 2 exp(- ∥x i -1∥ 2 2 σ 2 u 2 i )dx i 2 z i = u i σ x i = lim σ→0 ϕ ′2 (u i ) u 2 i σ 2 - 1 u 2 i σ 2 + 1 (ϕ ′ (u i ) u 2 i σ 2 ) 2 = lim σ→0 ϕ ′2 (u i ) u 2 i σ 2 1 - u 2 i σ 2 1 + u 2 i σ 2 = lim σ→0 ϕ ′2 (u i ) 1 1 + σ 2 u 2 i =ϕ ′2 (u i ) (30) The first step of (30) comes from Definition 3. The second step and the last three steps are derived from some basic mathematical computation and the limit laws. The third step of (30) is by the fact that the Gaussian distribution goes to a Dirac delta function when σ goes to 0. Then the integral will take the value when x i = 1. 
Similarly, we can obtain the following lim σ→0 β 0 (i, u σ , σ) -α 2 0 (i, u σ , σ) = lim σ→0 ∞ -∞ ϕ ′2 (σ • z i )(2π) -1 2 exp(- ∥z i -ui σ ∥ 2 2 )dz i - ∞ -∞ ϕ ′ (σ • z i )(2π) -1 2 exp(- ∥z i -ui σ ∥ 2 2 )dz i 2 =ϕ ′2 (u i ) -ϕ ′2 (u i ) = 0 (31) lim σ→0 ∂ ∂σ β 0 (i, u σ , σ) -α 2 0 (i, u σ , σ) = lim σ→0 ∂ ∂σ ∞ -∞ ϕ ′2 (x i )(2πσ 2 ) -1 2 exp(- ∥x i -u i ∥ 2 2σ 2 )dx i - ∞ -∞ ϕ ′ (x i )(2πσ 2 ) -1 2 exp(- ∥x i -u i ∥ 2 2σ 2 )dx i 2 x i = σ • z i = lim σ→0 ∞ -∞ ϕ ′2 (x i )(2πσ 2 ) -1 2 exp(- ∥x i -u i ∥ 2 2σ 2 )(-σ -1 + ∥x i -u i ∥ 2 σ -2 )dx i -2 ∞ -∞ ϕ ′ (x i )(2πσ 2 ) -1 2 exp(- ∥x i -u i ∥ 2 2σ 2 )dx i • ∞ -∞ ϕ ′ (x i )(2πσ 2 ) -1 2 exp(- ∥x i -u i ∥ 2 2σ 2 )(-σ -1 + ∥x i -u i ∥ 2 σ -2 )dx i = lim σ→0 ϕ ′2 (u i ) -σ -2ϕ ′ (u i ) ϕ ′ (u i ) -σ = lim σ→0 ϕ ′2 (u i ) σ = +∞ Therefore, by L'Hopital's rule and ( 31), (32), we have lim σ→0 ( u 2 j σ 2 + 1)(β 0 (i, u σ , σ) -α 0 (i, u σ , σ)) = lim σ→0 u 2 i 2σ ∂ ∂σ (β 0 (i, u σ , σ) -α 0 (i, u σ , σ)) = + ∞ Combining ( 33) and (30), we can derive that ρ( u σ , σ) converges to a positive value function of u as σ goes to 0, i.e. lim σ→0 ρ( u σ , σ) := C s (u). ( ) When u i = 0 for some i ∈ [d], lim σ→0 ( u 2 i σ 2 + 1)(β 0 (j, u σ , σ) -α 2 (j, u σ , σ )) = 0 by (31). Then from the Definition 3, we have lim σ→0 ρ( u σ , σ) = 0 (35) (4) We can define L m ( (ΛW * ) ⊤ Λu σδ K (W * ) , σδ K (W * )) as L m ( (ΛW * ) ⊤ Λu σδ K (W * ) , σδ K (W * )) = min vi∈[0,ui] ρ( (ΛW * ) ⊤ Λv σ l δ K (W * ) , σδ K (W * )) : v j = u j for all j ̸ = i (36) Then by this definition, we have 0 < L m ( (ΛW * ) ⊤ Λu σδ K (W * ) , σδ K (W * )) ≤ ρ( (ΛW * ) ⊤ Λu σ l δ K (W * ) , σδ K (W * )) Meanwhile, for any 0 ≤ u ′ i ≤ u * i , since that [0, u ′ i ] ⊂ [0, u * i ], we can obtain L m ( (ΛW * ) ⊤ Λu σδ K (W * ) , σδ K (W * ))| ui=u ′ i ≥ L m ( (ΛW * ) ⊤ Λu σδ K (W * ) , σδ K (W * ))| ui=u * i (38) Hence, L m ( (ΛW * ) ⊤ Λu σδ K (W * ) , σδ K (W * ) ) is a strictly positive real function which is monotonically decreasing. 
(5) Therefore, we only need to show the condition (a). When (W * ⊤ u) i ̸ = 0 for all i ∈ [K], lim σ→0 ρ( W * ⊤ u σδ K (W * ) , σδ K (W * )) = C s (u) > 0. Therefore, there exists ζ s > 0, such that when 0 < σ < ζ s , ρ( W * ⊤ u σδ K (W * ) , σδ K (W * )) > C s (W * ⊤ u) 2 . ( ) Then we can define L s ( W * ⊤ u σδ K (W * ) , σδ K (W * )) := C s (W * ⊤ u) 2ζ s σ 2 (41) such that σ -1 L s ( W * ⊤ u σδ K (W * ) , σδ K (W * )) is an increasing function of σ below ρ( W * ⊤ u σδ K (W * ) , σδ K (W * )). When (W * ⊤ u) i = 0 for some i ∈ [K], then lim σ→0 ρ( W * ⊤ u σδ K (W * ) , σδ K (W * )) = 0. ( ) We can define L s ( W * ⊤ u σδ K (W * ) , σδ K (W * )) = σ • min vi∈[ui,ζ s ′ ] ρ( W * ⊤ v σδ K (W * ) , σδ K (W * )) : v j ̸ = u j for all j ̸ = i (43) Then, σ -1 L s ( W * ⊤ u σδ K (W * ) , σδ K (W * )) = min vi∈[ui,ζ s ′ ] ρ( W * ⊤ v σδ K (W * ) , σδ K (W * )) : v j = u j for all j ̸ = i (44) For any 0 ≤ u ′ i ≤ u * i < ζ s ′ , since that [u * i , ζ s ′ ] ⊂ [u ′ i , ζ s ′ ], we can obtain σ -1 L s ( W * ⊤ u σδ K (W * ) , σδ K (W * ))| ui=u ′ i ≤ σ -1 L s ( W * ⊤ u σδ K (W * ) , σδ K (W * ))| ui=u * i (45) Therefore, we can derive that σ -1 L s ( W * ⊤ u σδ K (W * ) , σδ K (W * )) is monotonically increasing. Following the steps in (4), we can have that σ -1 L s ( W * ⊤ u σδ K (W * ) , σδ K (W * )) is a strictly positive real function which is upper bounded by ρ( W * ⊤ u σδ K (W * ) , σδ K (W * )). In conclusion, condition (a) is proved. For condition (b), since that ζ s > 0, ρ( W * ⊤ u σ l δ K (W * ) , σδ K (W * ) ) is continuous and positive, we can obtain ρ( W * ⊤ v σδ K (W * ) , σδ K (W * )) σ=ζs > 0 Then condition (b) can be easily proved as in (4). We then characterize the order of the ρ function in different cases as follows. Property 4. To specify the order with regard to the distribution parameters, ρ(u, σ) in Definition 3 satisfies the following properties, 1. (Small variance) lim σ→0 + ρ(u, σ) = Θ(σ 4 ).

2. (Large variance) For any ϵ > 0, lim_{σ→∞} ρ(u, σ) ≥ Θ(1/σ^{3+ϵ}).

3. (Large mean)

For any ϵ > 0, lim µ→∞ ρ(u, σ) ≥ Θ(e -∥u∥ 2 2 ) 1 ∥u∥ 3+ϵ . Proof: (1) β 0 (i, u, σ) -α 0 (i, u, σ) 2 =E z∼N (µ,1) [ϕ ′ 2 (σ • z)] -(E z∼N (µ,1) [ϕ ′ (σ • z)]) 2 = ∞ -∞ ϕ ′ 2 (σ • z) 1 √ 2π e -(z-µ) 2 2 dz -( ∞ -∞ ϕ ′ (σ • z) 1 √ 2π e -(z-µ) 2 2 dz) 2 = ∞ -∞ ( 1 4 - t 2 16 + t 4 96 • • • ) 2 1 √ 2πσ e -(t-µσ) 2 2σ 2 dt -( ∞ -∞ ( 1 4 - t 2 16 + t 4 96 + • • • ) 1 √ 2πσ e -(t-µσ) 2 2σ 2 dt) 2 =( 1 16 - 1 32 (µ 2 σ 2 + σ 2 ) + 7 768 (3σ 4 + 6µ 2 σ 4 + µ 4 σ 4 ) + • • • ) -( 1 4 - µ 2 σ 2 + σ 2 16 + 3σ 4 + 6µ 2 σ 4 + µ 4 σ 4 192 + • • • ) 2 = 1 128 σ 4 + µ 2 σ 4 64 + o(σ 4 ), as σ → 0 + . ( ) The first step of ( 47) is by Definition 3. The second step and the last steps come from some basic mathematical computation. The third step is from Taylor expansion. Hence, lim σ→0 + (β 0 (i, u, σ) -α 0 (i, u, σ) 2 ) = 1 128 σ 4 + µ 2 σ 4 64 + o(σ 4 ) Similarly, we can obtain β 2 (i, u, σ) - α 2 (i, u, σ) 2 µ 2 + 1 =E z∼N (0,1) [ϕ ′ 2 (σ • z)z 2 ] - (E z∼N (0,1) [ϕ ′ (σ • z)z 2 ]) 2 µ 2 + 1 = ∞ -∞ ϕ ′ 2 (σ • z)z 2 1 √ 2π e -(z-µ) 2 2 dz - 1 µ 2 + 1 ( ∞ -∞ ϕ ′ (σ • z)z 2 1 √ 2π e -(z-µ) 2 2 dz) 2 = ∞ -∞ ( t 4σ - t 3 16σ + t 5 96σ • • • ) 2 1 √ 2πσ e -(t-µσ) 2 2σ 2 dt - 1 µ 2 + 1 ( ∞ -∞ ( t 2 4σ 2 - t 4 16σ 2 + t 6 96σ 2 + • • • ) 1 √ 2πσ e -(t-µσ) 2 2σ 2 dt) 2 =( 1 + µ 2 16 - 3σ 2 + 6µ 2 σ 2 + µ 4 σ 2 32 + • • • ) - 1 µ 2 + 1 ( 1 + µ 2 4 - 15σ 2 + 45µ 2 σ 2 + 15µ 4 σ 2 + µ 6 σ 2 32 + • • • ) 2 = 9 64 σ 2 + 33 64 µ 2 σ 2 + 13 64 µ 4 σ 2 + 1 64 µ 6 σ 2 + o(σ 2 ), as σ → 0 + Hence, lim σ→0 + (β 2 (i, u, σ) - α 2 (i, u, σ) 2 µ 2 + 1 ) = 9 64 σ 2 + o(σ 2 ) Therefore, lim σ→0 + ρ(u, σ) = min j∈[d],uj ̸ =µ {(u 2 j + 1)} 1 128 σ 4 (2) Note that by some basic mathematical derivation, ∞ -∞ ϕ ′ 2 (σ • z) 1 √ 2π e -(z-µ) 2 2 dz = ∞ -∞ 1 (e σ•z + e -σ•z + 2) 2 1 √ 2π e -(z-µ) 2 2 dz ≥ 2 ∞ 0 1 16e 2σ•z 1 √ 2π e -(z+|µ|) 2 2 dz = 1 8 e 2|µ|σ+2σ 2 ∞ 0 1 √ 2π e -(z+2σ) 2 2 dz = 1 8 √ 2π e 2|µ|σ+2σ 2 ∞ |µ|+2σ e -t 2 2 dt (52) We then provide the following Claim with its proof to give a lower bound 
for (52). Claim: ∞ |µ|+2σ e -t 2 2 dt > e -2|µ|σ-2σ 2 -k1 log σ for k 1 > 1. Proof: Let f (σ) = ∞ |µ|+2σ e -t 2 2 dt -e -2|µ|σ-2σ 2 -k1 log σ . ( ) Then, f ′ (σ) = e -2σ 2 ((2|µ| + 4σ + k 1 σ )σ -k1 -2e -1 2 µ 2 ). It can be easily verified that for a given |µ| ≥ 0, f ′ (σ) < 0 when σ is large enough if k 1 > 1. Combining that lim σ→∞ f (σ) = 0, we have f (σ) > 0 when σ is large enough by showing the contradiction in the following: Suppose there is a strictly increasing function f (x) > 0 with lim x→∞ f (x) = 0 when x is large enough. Then there exists x 0 > 0 such that for any ϵ > 0, f (x) < ϵ for x > x 0 . Pick ϵ = f (x 0 ) > 0, then for x 1 > x 0 , f (x 1 ) > f (x 0 ) = ϵ. Contradiction! Similarly, we also have ∞ -∞ ϕ ′ (σ • z) 1 √ 2π e -z 2 2 dz = ∞ -∞ 1 e σ•z + e -σ•z + 2 1 √ 2π e -(z-µ) 2 2 dz ≤ 2 ∞ 0 1 e σ•z 1 √ 2π e -(z-µ) 2 2 dz = e |µ|σ+ 1 2 σ 2 ∞ 0 2 √ 2π e -(z+|µ|+σ) 2 2 dz = 2 √ 2π e |µ|σ+ 1 2 σ 2 ∞ |µ|+σ e -t 2 2 dt, and the Claim: ∞ |µ|+σ e -t 2 2 dt < e -|µ|σ-1 2 σ 2 -k2 log σ for k 2 ≤ 1 to give an upper bound for (55). Therefore, combining (52, 55) and two claims, we have that for any ϵ > 0, β 0 (i, u, σ) -α 0 (i, u, σ) 2 ≥ 1 8 √ 2π 1 σ k1 - 1 2π 1 σ 2k2 ≳ 1 σ 1+ϵ (The above inequality holds for any 2k 2 > k 1 where k 1 > 1 and k 2 ≤ 1.) Similarly, ∞ -∞ ϕ ′ 2 (σ • z)z 2 1 √ 2π e -z 2 2 dz = ∞ -∞ z 2 (e σ•z + e -σ•z + 2) 2 1 √ 2π e -(z-µ) 2 2 dz ≥ 2 ∞ 0 z 2 16e 2σ•z 1 √ 2π e -(z+|µ|) 2 2 dz = 1 8 √ 2π e |µ|σ+2σ 2 ∞ 2|µ|+2σ (t -2σ) 2 e -t 2 2 dt (57) Claim: ∞ |µ|+2σ (t -2σ) 2 e -t 2 2 dt ≥ e -2|µ|σ-2σ 2 -k1 log σ if k 1 > 3. Proof: Let f (σ) = ∞ |µ|+2σ (t -2σ) 2 e -t 2 2 dt -e -2|µ|σ-2σ 2 -k1 log σ . ( ) f ′ (σ) = 8σ ∞ |µ|+2σ e -t 2 2 dt + e -2|µ|σ-2σ 2 (4σ 1-k1 + k 1 σ -1-k1 + 2|µ|σ -k1 -4e -1 2 µ 2 ). (59) We need f ′ (σ) < 0 when σ is large enough. Since that f ′ (σ) → 0, f ′′ (σ) → 0 when σ is large, we need f ′′ (σ) > 0 and f ′′′ (σ) < 0 recursively. 
Hence, f ′′′ (σ) =e -2|µ|σ-2σ 2 (64σ 3-k1 + 96µσ 2-k1 + 16(3k 1 -3 + µ 2 )σ 1-k1 + 8µ(-µ 2 -3 + 6k 1 )σ -k1 + 4k 1 (3k 1 + µ 2 )σ -1-k1 + 2k 1 (1 + k 1 )(µ + 2)σ -2-k1 + k 1 (1 + k 1 )(2 + k 1 )σ -3-k1 -16e -1 2 µ 2 ) < 0 (60) requires k 1 > 3. Similarly, we have ∞ -∞ ϕ ′ (σ • z)z 2 1 √ 2π e -z 2 2 dz ≤ 2 ∞ 0 1 e σ•z 1 √ 2π z 2 e -z 2 2 dz = 2 √ 2π e 1 2 σ 2 ∞ σ (t -σ) 2 e -t 2 2 dt (61) and the Claim: ∞ σ (t -σ) 2 e -t 2 2 dt < e -σ 2 2 -k2 log σ . Hence, β 2 (i, u, σ) - α 2 (i, u, σ) 2 µ 2 + 1 ≥ 1 8 √ 2π 1 σ k1 - 2 π(µ 2 + 1) 1 σ 2k2 ≳ 1 σ 3.1 (The above inequality holds for any 2k 2 > k 1 where k 1 > 3 and k 2 < 3.) Therefore, by combining ( 56) and ( 62), for any ϵ > 0 lim σ→∞ ρ(u, σ) ≥ Θ( 1 σ 3+ϵ ). (3) Let σ be fixed. For any ϵ > 0, following the steps in (2), we can obtain ∞ -∞ ϕ ′ 2 (σ • z) 1 √ 2π e -(z-µ) 2 2 dz = ∞ -∞ 1 (e σ•z + e -σ•z + 2) 2 1 √ 2π e -(z-µ) 2 2 dz ≥ 2 ∞ 0 1 16e 2σ•z 1 √ 2π e -(z+|µ|) 2 2 dz = 1 8 √ 2π e 2|µ|σ+2σ 2 ∞ |µ|+2σ e -t 2 2 dt ≥ 1 8 √ 2π e -µ 2 2 1 µ 1+ϵ (64) ∞ -∞ ϕ ′ (σ • z) 1 √ 2π e -(z-µ) 2 2 dz = ∞ -∞ 1 e σ•z + e -σ•z + 2 1 √ 2π e -(z-µ) 2 2 dz ≤ 2 ∞ 0 1 e σ•z 1 √ 2π e -(z-µ) 2 2 dz = 2 √ 2π e -µ 2 2 1 µ 1-ϵ (65) Similarly, ∞ -∞ ϕ ′ 2 (σ • z)z 2 1 √ 2π e -(z-µ) 2 2 dz ≥ 1 8 √ 2π e -µ 2 2 1 µ 3+ϵ (66) ∞ -∞ ϕ ′ (σ • z)z 2 1 √ 2π e -(z-µ) 2 2 dz ≤ 2 √ 2π e -µ 2 2 1 µ 3-ϵ (67) We can conclude that lim µ→∞ ρ(u, σ) ≥ Θ(e -∥u∥ 2 2 ) 1 ∥u∥ 3+ϵ . Property 5. 
If a function f (x) is an even function, then E x∼N (µ,Σ) [f (x)] = E x∼ 1 2 N (µ,Σ)+ 1 2 N (-µ,Σ) [f (x)] (68) Proof: Denote g(x) = f (x)(2π|Σ| 2 ) -d 2 exp(- 1 2 (x -µ)Σ -1 (x -µ)) By some basic mathematical computation, E x∼N (µ,Σ) [f (x)] = x∈R d g(x)dx = ∞ -∞ • • • ∞ -∞ g(x 1 , • • • , x d )dx 1 • • • dx d = ∞ -∞ • • • ∞ -∞ -∞ ∞ g(x 1 , x 2 , • • • , x d )d(-x 1 )dx 2 • • • dx d = ∞ -∞ • • • ∞ -∞ g(-x 1 , x 2 • • • , x d )dx 1 dx 2 • • • dx d = x∈R d g(-x)dx = x∈R d f (x)(2π|Σ| 2 ) -d 2 exp(- 1 2 (x + µ)Σ -1 (x + µ)) = E x∼N (-µ,Σ) [f (x)] Therefore, we have E x∼N (µ,Σ) [f (x)] = E x∼ 1 2 N (µ,Σ)+ 1 2 N (-µ,Σ) [f (x)] (71) Property 6. Under Gaussian Mixture Model x ∼ L l=1 λ l N (µ l , Σ l ) where Σ l = diag(σ 2 l1 , • • • , σ 2 ld ) , we have the following upper bound. E x∼ L l=1 λ l N (µ l ,Σ l ) [(u ⊤ x) 2t ] ≤ (2t -1)!!||u|| 2t L l=1 λ l (||µ l || + ∥Σ 1 2 l ∥) 2t Proof: Note that E x∼ L l=1 λ l N (µ l ,Σ l ) [(u ⊤ x) 2t ] = L l=1 λ l E x∼N (µ l ,Σ l ) [(u ⊤ x) 2t ] = L l=1 λ l E y∼N (u ⊤ µ l ,u ⊤ Σ l u) [y 2t ], (73) where the last step is by that u ⊤ x ∼ N (u ⊤ µ, u ⊤ Σ l u) for x ∼ N (µ l , Σ l ). By some basic mathematical computation, we know E y∼N (u ⊤ µ l ,u ⊤ Σ l u) [y 2t ] = ∞ -∞ (y -u ⊤ µ l + u ⊤ µ l ) 2t 1 2πu ⊤ Σ l u e - (y-u ⊤ µ l ) 2 2u ⊤ Σ l u dy = ∞ -∞ 2t p=0 2t p (u ⊤ µ l ) 2t-p (y -u ⊤ µ l ) p 1 2πu ⊤ Σ l u e - (y-u ⊤ µ l ) 2 2u ⊤ Σ l u dy = 2t p=0 2t p (u ⊤ µ l ) 2t-p • 0 , p is odd (p -1)!!(u ⊤ Σ l u) p 2 , p is even ≤ 2t p=0 2t p |u ⊤ µ l | 2t-p (p -1)!!|u ⊤ Σ l u| p 2 ≤(2t -1)!!(|u ⊤ µ l | + |u ⊤ Σ l u| 1 2 ) 2t ≤(2t -1)!!∥u∥ 2t (∥µ l ∥ + ∥Σ∥ 1 2 ) 2t , where the second step is by the Binomial theorem. Hence, E x∼ L l=1 λ l N (µ l ,Σ l ) [(u ⊤ x) 2t ] ≤ (2t -1)!!||u|| 2t L l=1 λ l (||µ l || + ∥Σ 1 2 l ∥) 2t Property 7. With the Gaussian Mixture Model, we have E x∼ L l=1 λ l N (µ l ,Σ l ) [||x|| 2t ] ≤ d t (2t -1)!! 
L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2t Proof: E x∼ L l=1 λ l N (µ l ,Σ l ) [||x|| 2t 2 ] =E x∼ L l=1 λ l N (µ l ,Σ l ) [( d i=1 x 2 i ) t ] =E x∼ L l=1 λ l N (µ l ,Σ l ) [d t ( d i=1 x 2 i d ) t ] ≤E x∼ L l=1 λ l N (µ l ,Σ l ) [d t d i=1 x 2t i d ] =d t-1 d i=1 L j=1 ∞ -∞ (x i -µ ji + µ ji ) 2t λ j 1 √ 2πσ ji exp(- (x i -µ ji ) 2 2σ 2 ji )dx i =d t-1 d i=1 L j=1 2t k=1 2t k λ j |µ ji | 2t-k • 0 , k is odd (k -1)!!σ k ji , k is even ≤d t-1 d i=1 L j=1 2t k=1 2t k λ j |µ ji | 2t-k σ k j • (2t -1)!! =d t-1 d i=1 L j=1 λ j (|µ ji | + σ ji ) 2t (2t -1)!! ≤d t (2t -1)!! L l=1 λ l (∥µ∥ + ∥Σ 1 2 l ∥) 2t In the 3rd step, we apply Jensen inequality because f (x) = x t is convex when x ≥ 0 and t ≥ 1. In the 4th step we apply the Binomial theorem and the result of k-order central moment of Gaussian variable. Property 8. Under the Gaussian Mixture Model x ∼ L l=1 λ l N (µ l , Σ l ) where Σ l = Λ ⊤ l D l Λ l , we have the following upper bound. E x∼ L l=1 λ l N (µ l ,Σ l ) [(u ⊤ x) 2t ] ≤ (2t -1)!!||u|| 2t L l=1 λ l (||µ l || + ∥Σ 1 2 l ∥) 2t Proof: If x ∼ N (µ l , Σ l ), then u ⊤ x ∼ N (u ⊤ µ l , u ⊤ Σ l u) = N ((Λ l u) ⊤ Λ l µ l , (Λ l u) ⊤ D l (Λ l u)) . By Property 6, we have E x∼N (µ l ,Σ l ) [(u ⊤ x) 2t ] ≤ (2t -1)!!∥u∥ 2t (∥µ l ∥ + ∥Σ 1 2 l ∥) 2t Then we can derive the final result. Property 9. The population risk function f (W ) is defined as f (W ) = E x∼ L l=1 λ l N (µ l ,Σ l ) [f n (W )] =E x∼ L l=1 λ l N (µ l ,Σ l ) 1 n n i=1 ℓ(W ; x i , y i ) =E x∼ L l=1 λ l N (µ l ,Σ l ) [ℓ(W ; x i , y i )] For any permutation matrix P , where {π(j)} K j=1 is the indices permuted by P , we have H(W P , x) = 1 K π * (j) ϕ(w π(j) ⊤ x) = 1 K K j=1 ϕ(w j ⊤ x) = H(W , x) Therefore, f (W ) = f (W P ) (82) Based on (1) and (3), we can derive its gradient and Hessian as follows. 
∂ℓ(W ; x, y) ∂w j = - 1 K y -H(W ) H(W )(1 -H(W )) ϕ ′ (w ⊤ j x)x = ζ(W ) • x (83) ∂ 2 ℓ(W ; x, y) ∂w j ∂w l = ξ j,l • xx ⊤ (84) ξ j,l (W ) = 1 K 2 ϕ ′ (w ⊤ j x)ϕ ′ (w ⊤ l x) H(W ) 2 +y-2y•H(W ) H 2 (W )(1-H(W )) 2 , j ̸ = l 1 K 2 ϕ ′ (w ⊤ j x)ϕ ′ (w ⊤ l x) H(W ) 2 +y-2y•H(W ) H 2 (W )(1-H(W )) 2 -1 K ϕ ′′ (w ⊤ j x) y-H(W ) H(W )(1-H(W )) , j = l (85) Property 10. With D m (Ψ defined in definition 5, we have (i) D m (Ψ)D 2m (Ψ) ≤ D 3m (Ψ) (ii) D m (Ψ) 2 ≤ D 2m (Ψ) Proof: To prove (86), we can first compare the terms L i=1 λ i a i L i=1 λ i a 2 i and L i=1 λ i a 3 i , where a i ≥ 1, i ∈ [L] and L i=1 λ i = 1. L i=1 λ i a 3 i - L i=1 λ i a i L i=1 λ i a 2 i = L i=1 λ i a i • a 2 i - L j=1 λ j a 2 j = L i=1 λ i a i • (1 -λ i )a 2 i - 1≤j≤L,j̸ =i λ j a 2 j = L i=1 λ i a i • 1≤j≤L,j̸ =i λ j a 2 i - 1≤j≤L,j̸ =i λ j a 2 j = L i=1 λ i a i • 1≤j≤L,j̸ =i λ j (a 2 i -a 2 j ) = 1≤i,j≤L,i̸ =j λ i λ j a i (a 2 i -a 2 j ) + λ i λ j a j (a 2 j -a 2 i ) = 1≤i,j≤L,i̸ =j λ i λ j (a i -a j ) 2 (a i + a j ) ≥ 0 The second to last step is because we can find the pairwise terms λ i a i • λ j (a 2 i -a 2 j ) and λ j a j • λ i (a 2 j -a 2 i ) in the summation that can be putted together. From (88), we can obtain L i=1 λ i a i L i=1 λ i a 2 i ≤ L i=1 λ i a 3 i (89) Combining ( 89) and the definition of D m (Ψ) in ( 5), we can derive (86). Similarly, to prove (87), we can first compare the terms ( L i=1 λ i a i ) 2 and L i=1 λ i a 2 i , where a i ≥ 1, i ∈ [L] and L i=1 λ i = 1. L i=1 λ i a 2 i -( L i=1 λ i a i ) 2 = L i=1 λ i a i • a i - L j=1 λ j a j = L i=1 λ i a i • (1 -λ i )a i - 1≤j≤L,j̸ =i λ j a j = L i=1 λ i a i • 1≤j≤L,j̸ =i λ j a i - 1≤j≤L,j̸ =i λ j a j = L i=1 λ i a i • 1≤j≤L,j̸ =i λ j (a i -a j ) = 1≤i,j≤L,i̸ =j λ i λ j a i (a i -a j ) + λ i λ j a j (a j -a i ) = 1≤i,j≤L,i̸ =j λ i λ j (a i -a j ) 2 ≥ 0 (90) The derivation of ( 90) is close to (88). By (90) we have ( L i=1 λ i a i ) 2 ≤ L i=1 λ i a 2 i (91) Combining ( 91) and the definition of D m (Ψ) in ( 5), we can derive (87).
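The quantities α_q, β_q, and ρ(u, σ) from Definition 3 can also be checked numerically. The following sketch estimates ρ by Monte Carlo for the sigmoid activation and verifies the positivity claim of Property 3, item 1; the sample size, seed, and the choice of u are our illustrative assumptions.

```python
import numpy as np

def sigmoid_prime(t):
    s = 1.0 / (1.0 + np.exp(-t))
    return s * (1.0 - s)

def rho(u, sigma, n_samples=500_000, seed=0):
    """Monte Carlo estimate of rho(u, sigma) in Definition 3 for the
    sigmoid activation: alpha_q = E[phi'(sigma z) z^q] and
    beta_q = E[phi'(sigma z)^2 z^q] with z ~ N(u_i, 1)."""
    rng = np.random.default_rng(seed)
    d = len(u)
    a0, b0, a2, b2 = (np.empty(d) for _ in range(4))
    for i in range(d):
        z = rng.normal(u[i], 1.0, n_samples)
        p = sigmoid_prime(sigma * z)
        a0[i], b0[i] = p.mean(), (p ** 2).mean()
        a2[i], b2[i] = (p * z ** 2).mean(), (p ** 2 * z ** 2).mean()
    terms = []
    for i in range(d):
        for j in range(d):
            if j == i:
                continue
            terms.append((u[j] ** 2 + 1) * (b0[i] - a0[i] ** 2))
            terms.append(b2[i] - a2[i] ** 2 / (u[i] ** 2 + 1))
    return min(terms)

# Property 3, item 1: rho(u, sigma) > 0 for sigma != 0, because the
# Cauchy-Schwarz inequalities (22)-(23) are strict for non-constant phi'.
r = rho(np.array([0.5, -0.5, 1.0]), sigma=1.0)
```

The estimate stays strictly positive (up to Monte Carlo error), consistent with the strict Cauchy-Schwarz argument in the proof of Property 3.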

D PROOF OF THEOREM 1 AND COROLLARY 1

Theorem 1 is built upon three lemmas. Lemma 1 shows that with O(dK^5 log² d) samples, the empirical risk function is strongly convex in the neighborhood of W*. Lemma 2 shows that if initialized in the convex region, the gradient descent algorithm converges linearly to a critical point Ŵ_n, which is close to W*. Lemma 3 shows that the tensor initialization method in Subroutine 1 initializes W_0 ∈ R^{d×K} in the local convex region. Theorem 1 follows naturally by combining these three lemmas. This proof approach is built upon that in Fu et al. (2020). One of our major technical contributions is extending Lemmas 1 and 2 to the Gaussian mixture model, while the results in Fu et al. (2020) only apply to the standard Gaussian model. The second major contribution is a new tensor initialization method for the Gaussian mixture model such that the initial point is in the convex region (see Lemma 3). Both contributions require the development of new tools, and our analyses are much more involved than those for the standard Gaussian due to the complexity introduced by the Gaussian mixture model. To present these lemmas, the Euclidean ball B(W*P*, r) is used to denote the neighborhood of W*P*, where r is the radius of the ball.

$$\mathcal{B}(W^*P^*, r) = \{ W \in \mathbb{R}^{d\times K} : \|W - W^*P^*\|_F \le r \} \tag{92}$$

The radius of the convex region is

$$r := \Theta\Bigg( \frac{C_3\,\epsilon_0 \sum_{l=1}^L \lambda_l \|\Sigma_l^{-1}\|^{-1} \frac{\tau_K}{\eta\kappa^2}\, \rho\Big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12}},\ \delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12} \Big)}{K^{\frac72} \Big( \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^4 \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^8 \Big)^{\frac14}} \Bigg) \tag{93}$$

with some constant $C_3 > 0$. In what follows, $\rho(\cdot,\cdot)$ abbreviates $\rho\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}},\ \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$.

Lemma 1. (Local strong convexity) Consider the classification model with FCN (1) and the sigmoid activation function. There exist constants $C_1, C_2 > 0$ such that as long as the sample size satisfies

$$n \ge C_1 \epsilon_0^{-2} \Big( \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2 \Big)^2 \Big( \sum_{l=1}^L \lambda_l \|\Sigma_l^{-1}\|^{-1} \frac{\tau_K}{\eta\kappa^2}\, \rho(\cdot,\cdot) \Big)^{-2} dK^5 \log^2 d \tag{94}$$

for some $\epsilon_0 \in (0, \frac14)$, then for any fixed permutation matrix $P \in \mathbb{R}^{K\times K}$ and all $W \in \mathcal{B}(W^*P, r)$, we have

$$\Omega\Big( \frac{1 - 2\epsilon_0}{K^2} \sum_{l=1}^L \lambda_l \|\Sigma_l^{-1}\|^{-1} \frac{\tau_K}{\eta\kappa^2}\, \rho(\cdot,\cdot) \Big) \cdot I_{dK} \preceq \nabla^2 f_n(W) \preceq C_2 \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2 \cdot I_{dK} \tag{95}$$

with probability at least $1 - d^{-10}$.

Lemma 2. (Linear convergence of gradient descent) Assume the conditions in Lemma 1 hold. Given any fixed permutation matrix $P \in \mathbb{R}^{K\times K}$, if the local convexity of $\mathcal{B}(W^*P, r)$ holds, then there exists a critical point $\widehat{W}_n \in \mathcal{B}(W^*P, r)$ such that

$$\|\widehat{W}_n - W^*P\|_F \le O\Bigg( \frac{K^{\frac52} \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2\,(1+\xi)}{\sum_{l=1}^L \lambda_l \|\Sigma_l^{-1}\|^{-1} \frac{\tau_K}{\eta\kappa^2}\, \rho(\cdot,\cdot)} \sqrt{\frac{d \log n}{n}} \Bigg) \tag{96}$$

If the initial point $W_0 \in \mathcal{B}(W^*P, r)$, then gradient descent converges linearly to $\widehat{W}_n$, i.e.,

$$\|W_t - \widehat{W}_n\|_F \le \Bigg( 1 - \Omega\Bigg( \frac{\sum_{l=1}^L \lambda_l \|\Sigma_l^{-1}\|^{-1} \frac{\tau_K}{\eta\kappa^2}\, \rho(\cdot,\cdot)}{K^2 \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2} \Bigg) \Bigg)^t \|W_0 - \widehat{W}_n\|_F \tag{97}$$

Lemma 3. (Tensor initialization) If the sample size satisfies

$$n \ge \kappa^8 K^4 \tau^{12} D^6(\Psi) \cdot d \log^2 d \tag{98}$$

then the output $W_0 \in \mathbb{R}^{d\times K}$ of the Tensor Initialization Method in Subroutine 1 satisfies

$$\|W_0 - W^*P^*\| \lesssim \kappa^6 K^3 \tau^6 D^6(\Psi) \sqrt{\frac{d \log n}{n}}\, \|W^*\| \tag{99}$$

with probability at least $1 - n^{-\Omega(\delta_1^4)}$ for a specific permutation matrix $P^* \in \mathbb{R}^{K\times K}$.
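The contraction in (97) is the familiar geometric decay of gradient descent on a locally strongly convex and smooth objective. The toy sketch below (a strongly convex quadratic, not the paper's FCN objective; all names are illustrative) shows the error shrinking by a fixed factor $1 - \mu/L_{\mathrm{smooth}}$ per step, which plays the role of the rate inside the parentheses of (97):

```python
import numpy as np

# Toy illustration of linear (geometric) convergence of gradient descent
# on a strongly convex quadratic f(w) = 0.5 * w^T A w - b^T w.
rng = np.random.default_rng(0)
d = 20
M = rng.standard_normal((d, d))
A = M @ M.T + np.eye(d)            # strongly convex: eigenvalues >= 1
b = rng.standard_normal(d)
w_star = np.linalg.solve(A, b)     # unique critical point

eigs = np.linalg.eigvalsh(A)
mu, L_smooth = eigs[0], eigs[-1]   # strong convexity / smoothness constants
step = 1.0 / L_smooth

w = np.zeros(d)
errs = []
for _ in range(200):
    w = w - step * (A @ w - b)     # gradient step
    errs.append(np.linalg.norm(w - w_star))

rate = 1.0 - mu / L_smooth         # contraction factor, strictly below 1
# each step contracts the error by at least the factor `rate`
assert all(errs[t + 1] <= rate * errs[t] + 1e-12 for t in range(len(errs) - 1))
```

The worse the conditioning $L_{\mathrm{smooth}}/\mu$ (the analogue of the ratio of the Hessian bounds in (95)), the closer the rate is to 1 and the slower the convergence, which is exactly the dependence that $q(\Psi)$ captures later in this proof.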

Proof of Theorem 1

From Lemma 2 and Lemma 3, we know that if n is sufficiently large such that the initialization W 0 by the tensor method is in the region B(W * P , r), then the gradient descent method converges to a critical point W n that is sufficiently close to W * . To achieve that, one sufficient condition is ||W 0 -W * P * || F ≤ √ K||W 0 -W * P * || ≤ κ 6 K 7 2 • τ 6 D 6 (Ψ) d log n n ||W * P || ≤ C 3 ϵ 0 Γ(Ψ)σ 2 max K 7 2 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 4 (100) where the first inequality follows from ||W || F ≤ √ K||W || for W ∈ R d×K , the second inequality comes from Lemma 3, and the third inequality comes from the requirement to be in the region B(W * P , r). That is equivalent to the following condition n ≥C 0 ϵ -2 0 • τ 12 κ 12 K 14 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 2 • (δ 1 (W * )) 2 D 6 (Ψ)Γ(Ψ) -2 σ -4 max • d log 2 d ( ) where C 0 = max{C 4 , C -2 3 }. By Definition 5, we can obtain L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 2 ≤ D 4 (Ψ)D 8 (Ψ)σ 6 max ( ) From Property 10, we have that D 4 (Ψ)D 8 (Ψ)D 6 (Ψ) ≤ D 12 (Ψ) D 12 (Ψ) = D 12 (Ψ) Plugging ( 102), ( 103) into ( 101), we have n ≥ C 0 ϵ -2 0 • κ 12 K 14 (σ max δ 1 (W * )) 2 τ 12 Γ(Ψ) -2 D 12 (Ψ) • d log 2 d (104) Considering the requirements on the sample complexity in ( 94), (98), and ( 104), ( 104) shows a sufficient number of samples. Taking the union bound of all the failure probabilities in Lemma 1, and 3, (104) holds with probability 1 -d -10 . By Property 3.4, ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) can be lower bounded by positive and monotonically decreasing functions L m ( (Λ l W * ) ⊤ μl δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) when everything else except | μl(i) | is fixed, or L s ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) when everything else except ∥Σ 1 2 l ∥ is fixed. 
Then, by replacing the lower bound of ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) with these two functions in Γ(Ψ), we can have an upper bound of (σ max δ 1 (W * )) 2 τ 12 Γ(Ψ) -2 D 12 (Ψ), denoted as B(Ψ). To be more specific, when everything else except | μl(i) | is fixed, L m ( (Λ l W * ) ⊤ μl δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) is plugged in B(Ψ). Then since that D 12 (Ψ) and L m ( (Λ l W * ) ⊤ μl δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) are both increasing function of | μl(i) |, B(Ψ) is an increasing function of | μl(i) |. When everything else except ∥Σ l ∥ < ζ s ′ and go to 0, two decreasing functions of ∥Σ 1 2 l ∥ is fixed, if ∥Σ 1 2 l ∥ = σ max > ζ s , then σ 2 max τ 12 D 12 (Ψ) is an increasing function of ∥Σ 1 2 l ∥. Since that L s ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) is a decreasing function, L s ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) -2 is an increasing function of ∥Σ 1 2 l ∥, σ 2 max L s ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) -2 and D 12 (Ψ) will be the dominant term of B(Ψ). Therefore, B(Ψ) increases to infinity as all ∥Σ 1 2 l ∥'s go to 0. In sum, we can define a universe B(Ψ) as: B(Ψ) =                            (σ max δ 1 (W * )) 2 τ 12 L l=1 λ l ∥Σ -1 l ∥ -1 ησ 2 max L m ( (Λ l W * ) ⊤ μl δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) -2 •D 12 (Ψ), if S is fixed (σ max δ 1 (W * )) 2 τ 12 L l=1 λ l ∥Σ -1 l ∥ -1 ησ 2 max L s ( (Λ l W * ) ⊤ μl δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) -2 •D 12 (Ψ), if M is fixed (σ max δ 1 (W * )) 2 τ 12 L l=1 λ l ∥Σ -1 l ∥ -1 ησ 2 max ρ( (Λ l W * ) ⊤ μl δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) -2 •D 12 (Ψ), otherwise (105) where L m , L s and D 12 are defined in ( 38), ( 43) and Definition 5, respectively. 
Hence, we have

$$n \ge \mathrm{poly}(\epsilon_0^{-1}, \kappa, \eta, \tau_K)\, B(\Psi) \cdot d \log^2 d \tag{106}$$

Similarly, by replacing $\rho\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$ with $L_m\big( \frac{(\Lambda_l W^*)^\top \tilde{\mu}_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$ when everything else except $|\tilde{\mu}_{l(i)}|$ is fixed, or with $L_s\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$ (or $\|\Sigma_l^{-1}\| L_s\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$ when $\|\Sigma_l^{-1}\|^{-1} \ge 1$) when everything else except $\|\Sigma_l^{\frac12}\|$ is fixed, (97) can likewise be converted into another feasible upper bound. We denote the resulting convergence rate by $v = 1 - K^{-2} q(\Psi)$. Since $q(\Psi)$ is a ratio between the smallest and the largest singular values of $\nabla^2 f(W^*)$, we have $q(\Psi) \in (0, 1)$, and hence $1 - K^{-2} q(\Psi) \in (0, 1)$ for any $K \ge 1$.

When everything else except $|\tilde{\mu}_{l(i)}|$ is fixed: since $L_m\big( \frac{(\Lambda_l W^*)^\top \tilde{\mu}_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$ is monotonically decreasing and $\sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2$ is increasing as $|\tilde{\mu}_{l(i)}|$ increases, $v$ is an increasing function of $|\tilde{\mu}_{l(i)}|$ that approaches $1$. Similarly, when everything else except $\|\Sigma_l^{\frac12}\|$ is fixed and $\|\Sigma_l^{\frac12}\| \ge \max\{1, \zeta_s\}$, the term $\frac{1}{\sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2}$ decreases to $0$ as $\|\Sigma_l\|$ increases. We replace $\rho\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$ by $\|\Sigma_l^{-1}\| L_s\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$, and then

$$\|\Sigma_l^{-1}\|^{-1} \cdot \|\Sigma_l^{-1}\| L_s\Big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12} \Big) = L_s\Big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12} \Big) \tag{107}$$

is a decreasing function that is no larger than $\rho\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$. Therefore, $v$ is an increasing function of $\|\Sigma_l^{\frac12}\|$ that approaches $1$ when $\|\Sigma_l^{\frac12}\| \ge \max\{1, \zeta_s\}$. When everything else is fixed and all $\|\Sigma_l^{\frac12}\| \le \zeta_{s'}$ go to $0$, all $L_s\big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-1/2} \big)$ decrease, and all $\frac{\|\Sigma_l^{-1}\|^{-1}}{\sum_{l=1}^L \lambda_l (\|\mu_l\|_\infty + \|\Sigma_l^{\frac12}\|)^2}$ decrease to $0$. Therefore, $v$ increases to $1$.
q(Ψ) can then be defined as q(Ψ) =                                                        Ω L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 Lm( (Λ l W * ) ⊤ μl δ K (W * )∥Σ -1 l ∥ -1 2 ,δ K (W * )∥Σ -1 l ∥ -1 2 ) L l=1 λ l (∥µ l ∥+∥Σ 1 2 l ∥) 2 ), if S is fixed Ω L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 Ls( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 ,δ K (W * )∥Σ -1 l ∥ -1 2 ) L l=1 λ l (∥µ l ∥+∥Σ 1 2 l ∥) 2 , if M is fixed and all ∥Σ 1 2 l ∥ ≤ ζ s ′ Ω λ l 1 ητ K κ 2 Ls( W * ⊤ µ i δ K (W * )∥Σ -1 i ∥ -1 2 ,δ K (W * )∥Σ -1 i ∥ -1 2 )+ l̸ =i r(λ l ,µ l ,Σ l ,W * ) L l=1 λ l (∥µ l ∥+∥Σ 1 2 l ∥) 2 , if M is fixed and one ∥Σ 1 2 i ∥ ≥ max{1, ζ s } Ω L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 ,δ K (W * )∥Σ -1 l ∥ -1 2 ) L l=1 λ l (∥µ l ∥+∥Σ 1 2 l ∥) 2 , otherwise . ( ) where r(λ l , µ l , Σ l , W * ) = λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ). Note that here the ρ(•) function is defined in Definition 3. L m (•) and L s (•) are defined in ( 38) and ( 43), respectively. The bound of ∥ W n -W * P ∥ F is directly from (96). We can derive that E w (Ψ) = O( L j=1 λ l (∥µ j ∥ + ∥Σ 1 2 j ∥) 2 L j=1 λ l ∥Σ -1 j ∥ -1 ρ( W * ⊤ µj δ K (W * )∥Σ -1 j ∥ -1 2 , δ K (W * )∥Σ -1 j ∥ -1 2 ) ) E(Ψ) = O( L j=1 λ l (∥µ j ∥ + ∥Σ 1 2 j ∥) 2 L j=1 λ l ∥Σ -1 j ∥ -1 ρ( W * ⊤ µj δ K (W * )∥Σ -1 j ∥ -1 2 , δ K (W * )∥Σ -1 j ∥ -1 2 ) ) (110) E l (Ψ) = O( L j=1 λ l (∥µ j ∥ + ∥Σ 1 2 j ∥) 2 (∥µ l ∥ + ∥Σ l ∥ 1 2 ) L j=1 λ l ∥Σ -1 j ∥ -1 ρ( W * ⊤ µj δ K (W * )∥Σ -1 j ∥ -1 2 , δ K (W * )∥Σ -1 j ∥ -1 2 ) ) (111) The discussion of the monotonicity of E w (Ψ), E(Ψ) and E l (Ψ) can follow the analysis of q(Ψ). We finish our proof of Theorem 1 here. The parameters B(Ψ), q(Ψ), E w (Ψ), E(Ψ), and E l (Ψ) can be found in 105, 108, 109, 110, and 111, respectively.

Proof of Corollary 1:

The monotonicity analysis has been included in the proof of Theorem 1. In this part, we specify our proof for the results in Table 1 . For simplicity, we denote ρ l = ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ). When everything else except ∥Σ l ∥ 1 2 is fixed, if ∥Σ l ∥ = o(1) , by some basic mathematical computation, then we have n sc =C 0 ϵ -2 0 • η 2 τ 12 κ 16 K 14 L l=1 λ l (∥ μl ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥ μl ∥ + ∥Σ 1 2 l ∥) 8 1 2 (δ 1 (W * )) 2 D 6 (Ψ) • ( 1 L l=1 λ l ∥Σ -1 l ∥ -1 ρ l ) 2 • d log 2 d ≲poly(ϵ -1 0 , η, τ, κ, K, δ 1 (W * )) • d log 2 d • O(λ L 1 ∥Σ 1 2 L ∥ 6 ) (112) v(Ψ) = 1 - L l=1 λ l ∥Σ -1 l ∥ -1 ηκ 2 ρ l K 2 ( L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 ) ≤ 1 - λ l K 2 ηκ 2 τ K Θ(∥Σ l ∥ 3 ) (113) ∥ W n -W * P * ∥ ≤ O( K 5 2 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 (1 + ξ) L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) d log n n ) ≲ poly(η, κ, τ, δ K (W * )) d log n n K 2 (1 + ξ) • O(1 -∥Σ l ∥ 3 ) (114) fl (W t ) = fl (W t ) -fl (W * ) ≤ E K k=1 ∂( fl (W t ) ∂ wk ) ⊤ (w t(k) -w * k ) ≤ ∥W t -W * P * ∥(∥µ l ∥ + ∥Σ l ∥ 1 2 ) ≲ O L j=1 √ λ l (∥µ j ∥ + ∥Σ j ∥ 1 2 ) L j=1 λ j ∥Σ -1 j ∥ -1 ρ j (∥µ j ∥ + ∥Σ j ∥ 1 2 ) • d log n n ηκ 2 K 2 (1 + ξ) ≲ poly(η, κ, τ, δ K (W * )) d log n n K 2 (1 + ξ) • O( 1 1 + ∥Σ l ∥ 3 ) ≲ poly(η, κ, τ, δ K (W * )) d log n n K 2 (1 + ξ) • O(1) -Θ(∥Σ l ∥ 3 ), The first inequality of ( 115) is by the Mean Value Theorem. The second inequality of ( 115) is from Property 8, and the third inequality is derived from (96, 97). The last inequality is obtained by the condition that ∥Σ l ∥ = o(1). 
We can similarly have f (W t ) ≤ E K k=1 ∂( f (W t ) ∂ wk ) ⊤ (w t(k) -w * k ) ≲ poly(η, κ, τ, δ K (W * )) d log n n K 2 (1 + ξ) • O( 1 1 + ∥Σ l ∥ 3 ) ≲ poly(η, κ, τ, δ K (W * )) d log n n K 2 (1 + ξ) • O(1) -Θ(∥Σ l ∥ 3 ) If ∥Σ l ∥ 1 2 = Ω(1), we have n sc ≲ poly(ϵ -1 0 , η, τ, κ, K, δ 1 (W * )) • d log 2 d • O(∥Σ l ∥ 3 ) (117) v(Ψ) ≤ 1 - 1 K 2 τ K ηκ 2 Θ( 1 1 + ∥Σ l ∥ ) (118) ∥ W n -W * P * ∥ F ≲ poly(η, τ, κ, δ K (W * )) d log n n K 5 2 (1 + ξ) • ∥Σ l ∥ (119) fl (W t ) ≲ poly(η, τ, κ, δ K (W * )) d log n n K 2 (1 + ξ) • ∥Σ l ∥ (120) f (W t ) ≲ poly(η, τ, κ, δ K (W * )) d log n n K 2 (1 + ξ) • ∥Σ l ∥ When everything is fixed except ∥µ l ∥, by combining ( 94) and ( 101), we have n sc ≲ poly(ϵ -1 0 , η, τ, κ, K, δ 1 (W * )) • d log 2 d • O(∥µ l ∥ 4 ), if ∥µ l ∥ ≤ 1 O(∥µ l ∥ 12 ), if ∥µ l ∥ ≥ 1 (122) v(Ψ) ≤ 1 - 1 K 2 τ K ηκ 2 Θ( 1 1 + ∥µ l ∥ 2 ) (123) ∥ W n -W * P * ∥ F ≲ poly(η, τ, κ, δ K (W * )) d log n n K 5 2 (1 + ξ) • (1 + ∥µ l ∥) (124) fl (W t ) ≲ poly(η, τ, κ, δ K (W * )) d log n n K 2 (1 + ξ) • (1 + ∥µ l ∥ 2 ) (125) f (W t ) ≲ poly(η, τ, κ, δ K (W * )) d log n n K 2 (1 + ξ) • (1 + ∥µ l ∥ 2 ) (126) When everything else is fixed except λ 1 , λ 2 , • • • , λ L , where ∥Σ j ∥ = Ω(1), j ∈ [L] and ∥µ j ∥ = ∥µ i ∥, i, j ∈ [L], if ∥Σ l ∥ ≤ ∥Σ j ∥, j ∈ [L], we have n sc ≲poly(ϵ -1 0 , η, κ, K, δ 1 (W * )) • d log 2 d • (a 1 λ 2 l + a 2 λ 3 2 l + a 3 λ l + a 4 λ 1 2 l + a 5 ) ( L j=1 λ j ρ j ) 2 ≤poly(ϵ -1 0 , η, κ, K, δ 1 (W * )) • d log 2 d • a 5 ( L j=1 λ j ρ j ) 2 ≲poly(ϵ -1 0 , η, κ, K, δ 1 (W * )) • d log 2 d • O((1 + λ l ) -2 ) where a 1 = (∥µ l ∥+∥Σ l ∥ 1 2 ) 12 /∥Σ l ∥ 3 , a 2 = (∥µ l ∥+∥Σ 1 2 l ∥) 8 ( j̸ =l λ j (∥µ j ∥+∥Σ j ∥ 1 2 ) 8 ) 1 2 /∥Σ l ∥ 3 , a 3 = (∥µ l ∥/∥Σ l ∥ 1 2 + 1) 6 ( j̸ =l λ j (∥µ j ∥ + ∥Σ j ∥ 1 2 ) 4 j̸ =l λ j (∥µ j ∥ + ∥Σ j ∥ 1 2 ) 8 ) 1 2 + (∥µ l ∥ + ∥Σ l ∥ 1 2 ) 6 j̸ =l λ j (∥µ j ∥/∥Σ j ∥ 1 2 + 1) 6 , a 4 = j̸ =l λ j (∥µ j ∥/∥Σ j ∥ 1 2 + 1) 6 (∥µ l ∥ + ∥Σ l ∥ 1 2 ) 2 ( j̸ =l λ j (∥µ j ∥ + ∥Σ j ∥ 1 2 ) 8 ) 1 2 , a 5 = ( j̸ =l λ j (∥µ j ∥ 
+ ∥Σ j ∥ 1 2 ) 4 j̸ =l λ j (∥µ j ∥ + ∥Σ j ∥ 1 2 ) 8 ) 1 2 • j̸ =l λ j (∥µ j ∥/∥Σ j ∥ 1 2 + 1) 6 . The second step of (127) is by a i = O(a 5 ), i = 1, 2, 3, 4. v ≤ 1 K 2 ητ K κ 2 Θ( 1 1 + λ l ) (128) ∥ W n -W * P ∥ F ≤ poly(η, κ, , τ, δ 1 (W * )) • d log n n K 5 2 (1 + ξ) • O( 1 1 + √ λ l ) (129) fl (W t ) ≤ poly(η, κ, , τ, δ 1 (W * )) • d log n n K 2 (1 + ξ) • O( 1 1 + √ λ l ) (130) f (W t ) ≤ poly(η, κ, , τ, δ 1 (W * )) • d log n n K 2 (1 + ξ) • O( 1 1 + λ l ) (131) If ∥Σ l ∥ ≥ ∥Σ j ∥, j ∈ [L], we can similarly derive that n sc ≲poly(ϵ -1 0 , η, κ, K, δ 1 (W * )) • d log 2 d • (a 1 λ 2 l + a 2 λ 3 2 l + a 3 λ l + a 4 λ 1 2 l + a 5 ) ( L j=1 λ j ρ j ) 2 ≲poly(ϵ -1 0 , η, κ, K, δ 1 (W * )) • d log 2 d • (O(1) -Θ((1 + λ l ) -2 )) (132) v ≤ 1 - 1 K 2 ητ K κ 2 Θ( 1 1 + λ l ) (133) ∥ W n -W * P ∥ F ≤ poly(η, κ, , τ, δ 1 (W * )) • d log n n K 5 2 (1 + ξ) • O(1 + λ l ) (134) fl (W t ) ≤ poly(η, κ, , τ, δ 1 (W * )) • d log n n K 2 (1 + ξ) • O(1 + λ l ) (135) f (W t ) ≤ poly(η, κ, , τ, δ 1 (W * )) • d log n n K 2 (1 + ξ) • (O(1) - Θ(1) 1 + λ l ) E PROOF OF LEMMA 1 We first state some lemmas used in proof in Section E.1 and describe the proof in Section E.2. The proofs of these lemmas are provided in Section E.3 to E.7 in sequence. The proof idea mainly follows from Fu et al. (2020) . Lemma 6 shows the Hessian ∇ 2 f (W ) of the population risk function is smooth. Lemma 7 illustrates that ∇ 2 f (W ) is strongly convex in the neighborhood around µ * . Lemma 8 shows the Hessian of the empirical risk function ∇ 2 f n (W * ) is close to its population risk ∇ 2 f (W * ) in the local convex region. Summing up these three lemmas, we can derive the proof of Lemma 1. Lemma 4 is used in the proof of Lemma 7. Lemma 5 is used in the proof of Lemma 8. The analysis of the Hessian matrix of the population loss in Fu et al. (2020) and Zhong et al. (2017b) can not be extended to the Gaussian mixture model. 
To address this, we develop new tools that exploit properties of symmetric distributions and even functions. Our approach also applies to other activations such as tanh or erf. Moreover, directly applying the existing matrix concentration inequalities from these works to bound the error between the empirical loss and the population loss yields a loose sample-complexity bound that cannot reflect the influence of each component of the Gaussian mixture distribution. We develop a new version of Bernstein's inequality (see (208)) so that the final bound is $O(d \log^2 d)$. Mei et al. (2016) showed that the landscape of the empirical risk is close to that of the population risk when the number of samples is sufficiently large, for the special case $K = 1$. Focusing on Gaussian mixture models, our result in Lemma 8 explicitly shows how the parameters of the input distribution, including the proportion, mean, and variance of each component, affect the error bound between the empirical loss and the population loss.
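One of the symmetry tools used below (Property 5 in this paper) is that an even function has the same expectation under $N(\mu, \Sigma)$ and under the symmetric mixture $\frac12 N(\mu, \Sigma) + \frac12 N(-\mu, \Sigma)$, since $\mathbb{E}_{N(-\mu,\Sigma)}[f(x)] = \mathbb{E}_{N(\mu,\Sigma)}[f(-x)] = \mathbb{E}_{N(\mu,\Sigma)}[f(x)]$ for even $f$. The Monte Carlo sketch below (with an arbitrary even test function chosen purely for illustration) makes this concrete:

```python
import numpy as np

# Monte Carlo check: for an even function f (f(x) = f(-x)),
#   E_{N(mu, Sigma)}[f(x)] = E_{(1/2)N(mu, Sigma) + (1/2)N(-mu, Sigma)}[f(x)].
rng = np.random.default_rng(1)
d, n = 4, 400_000
mu = rng.standard_normal(d)
A = rng.standard_normal((d, d))
Sigma_half = 0.5 * A                      # x = mu + Sigma_half @ z, z ~ N(0, I)

def f(x):
    # an even, bounded test function of a linear form of x
    z = x @ np.array([1.0, -2.0, 0.5, 1.5])
    return z ** 2 / (1.0 + z ** 2)        # f(x) = f(-x)

z = rng.standard_normal((n, d))
x_plus = mu + z @ Sigma_half.T            # samples from N(mu, Sigma)
x_minus = -mu + z @ Sigma_half.T          # samples from N(-mu, Sigma)
e_plus = f(x_plus).mean()                 # expectation under N(mu, Sigma)
e_mix = 0.5 * f(x_plus).mean() + 0.5 * f(x_minus).mean()  # mixture expectation
assert abs(e_plus - e_mix) < 0.01         # equal up to Monte Carlo error
```

This is what lets the analysis reduce expectations under a general Gaussian component to expectations under a symmetric mixture, to which Lemma 4 applies.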

E.1 USEFUL LEMMAS IN THE PROOF OF LEMMA 1

Lemma 4. With $\rho(\mu, \sigma)$ defined in Definition 3 and an arbitrary matrix $R = (r_1, \cdots, r_k) \in \mathbb{R}^{d\times k}$,

$$\mathbb{E}_{x\sim \frac12 N(\mu, I_d) + \frac12 N(-\mu, I_d)} \Big( \sum_{i=1}^k r_i^\top x \cdot \phi'(\sigma \cdot x_i) \Big)^2 \ge \rho(\mu, \sigma) \|R\|_F^2$$

Lemma 5. With the FCN model (1) and the Gaussian mixture model, for any permutation matrix $P$ and some constant $C_{12} > 0$, we have

$$\mathbb{E}_{x\sim \sum_{l=1}^L \lambda_l N(\mu_l, \Sigma_l)} \sup_{W \neq W' \in \mathcal{B}(W^*P, r)} \frac{\|\nabla^2 \ell(W, x) - \nabla^2 \ell(W', x)\|}{\|W - W'\|_F} \le C_{12}\, d^{\frac32} K^{\frac52} \sqrt{\sum_{l=1}^L \lambda_l (\|\mu_l\|_\infty + \|\Sigma_l\|)^2 \sum_{l=1}^L \lambda_l (\|\mu_l\|_\infty + \|\Sigma_l\|)^4}$$

Lemma 6. (Hessian smoothness of population loss) In the FCN model (1), for any permutation matrix $P$ and some constant $C_5 > 0$, we have

$$\|\nabla^2 f(W) - \nabla^2 f(W^*P)\| \le C_5 K^{\frac32} \Big( \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^4 \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^8 \Big)^{\frac14} \|W - W^*P\|_F$$

Lemma 7. (Local strong convexity of population loss) In the FCN model (1), for any permutation matrix $P$, if $\|W - W^*P\|_F \le r$ for an $\epsilon_0 \in (0, \frac14)$, then for some constant $C_4 > 0$,

$$\frac{4(1 - \epsilon_0)}{K^2} \sum_{l=1}^L \lambda_l \|\Sigma_l^{-1}\|^{-1} \frac{\tau_K}{\eta\kappa^2}\, \rho\Big( \frac{W^{*\top}\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12}}, \delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac12} \Big) \cdot I_{dK} \preceq \nabla^2 f(W) \preceq C_4 \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2 \cdot I_{dK}$$

Lemma 8. In the FCN model (1), for any permutation matrix $P$, as long as $n \ge C' dK \log dK$ for some constant $C' > 0$, we have

$$\sup_{W \in \mathcal{B}(W^*P, r)} \|\nabla^2 f_n(W) - \nabla^2 f(W)\| \le C_6 \sum_{l=1}^L \lambda_l (\|\mu_l\| + \|\Sigma_l^{\frac12}\|)^2 \sqrt{\frac{dK \log n}{n}}$$

with probability at least $1 - d^{-10}$ for some constant $C_6 > 0$.

E.2 PROOF OF LEMMA 1

From Lemma 7 and 8, with probability at least 1 -d -10 , ∇ 2 f n (W ) ⪰ ∇ 2 f (W ) -||∇ 2 f (W ) -∇ 2 f n (W )|| • I ⪰ Ω (1 -ϵ 0 ) K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) • I -O C 6 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 dK log n n • I (142) As long as the sample complexity is set to satisfy C 6 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 dK log n n ≤ ϵ 0 K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) • I i.e., n ≥C 1 ϵ -2 0 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 2 • L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) • I -2 dK 5 log 2 d for some constant C 1 > 0, then we have the lower bound of the Hessian with probability at least 1 -d -10 . ∇ 2 f n (W ) ⪰ Ω 1 -2ϵ 0 K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) • I (145) By ( 140) and ( 141), we can also derive the upper bound as follows, ||∇ 2 f n (W )|| ≤ ||∇ 2 f (W )|| + ||∇ 2 f n (W ) -∇ 2 f (W )|| ≤ C 4 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 + C 6 • 1=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 dK log n n ≤ C 2 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 for some constant C 2 > 0. Combining ( 145) and ( 146), we have Ω 1 -2ϵ 0 K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) • I ⪯∇ 2 f n (W ) ⪯ C 2 L l=1 λ l (|| μl || ∞ + ∥Σ 1 2 l ∥) 2 • I with probability at least 1 -d -10 .
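The first matrix inequality in this proof is the standard eigenvalue-perturbation fact $\lambda_{\min}(A + E) \ge \lambda_{\min}(A) - \|E\|$ for symmetric matrices (a consequence of Weyl's inequality): it is how $\nabla^2 f_n$ inherits strong convexity from $\nabla^2 f$ once $\|\nabla^2 f_n - \nabla^2 f\|$ is small enough. A quick numeric sketch (illustrative only, with random symmetric matrices standing in for the two Hessians):

```python
import numpy as np

# Weyl-type bound used above: for symmetric A ("population Hessian") and a
# symmetric perturbation E ("empirical minus population Hessian"),
#   lambda_min(A + E) >= lambda_min(A) - ||E||_2.
rng = np.random.default_rng(2)
for _ in range(100):
    m = 8
    B = rng.standard_normal((m, m))
    A = (B + B.T) / 2                        # symmetric matrix
    C = 0.1 * rng.standard_normal((m, m))
    E = (C + C.T) / 2                        # symmetric perturbation
    lam_min_A = np.linalg.eigvalsh(A)[0]     # eigvalsh returns ascending order
    lam_min_AE = np.linalg.eigvalsh(A + E)[0]
    spec_norm_E = np.linalg.norm(E, 2)       # spectral norm of the perturbation
    assert lam_min_AE >= lam_min_A - spec_norm_E - 1e-10
```

In the proof, $\lambda_{\min}(A)$ is the $\Omega(\cdot)$ lower bound from Lemma 7 and $\|E\|$ is the concentration error from Lemma 8, which is why the sample size is chosen to make the latter at most an $\epsilon_0$ fraction of the former.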

E.3 PROOF OF LEMMA 4

Following the proof idea in Lemma D.4 of Zhong et al. (2017b) , we have E x∼ 1 2 N (µ,I d 1 2 N (-µ,I d ) ( k i=1 r ⊤ i x • ϕ ′ (σ • x i )) 2 = A 0 + B 0 A 0 = E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) k i=1 r ⊤ i x • ϕ ′2 (σ • x i ) • xx ⊤ r i B 0 = E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) i̸ =l r ⊤ i ϕ ′ (σ • x i )ϕ ′ (σ • x l ) • xx ⊤ r l In A 0 , we know that E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) x j = 0. Therefore, by some basic mathematical computation, A 0 = k i=1 E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) r ⊤ i ϕ ′2 (σ • x i ) x 2 i e i e ⊤ i + j̸ =i x i x j (e i e ⊤ j + e j e ⊤ i ) + j̸ =i l̸ =i x j x l e j e ⊤ l r i = k i=1 E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) r ⊤ i ϕ ′2 (σ • x i ) x 2 i e i e ⊤ i + j̸ =i x 2 j e j e ⊤ j r i = k i=1 E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) [ϕ ′2 (σ • x i )x 2 i ]r ⊤ i e i e ⊤ i r i + j̸ =i E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) [x 2 j ]E x∼ 1 2 N (µ,I)+ 1 2 N (-µ,I) [ϕ ′2 (σ • x i )]r ⊤ i e j e ⊤ j r i = k i=1 r 2 ii β 2 (i, µ, σ) + k i=1 j̸ =i r 2 ij β 0 (i, µ, σ)(1 + µ 2 j ) In B 0 , α 1 (i, µ, σ) = E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) (x i ϕ ′ (x i )) = 0. By the equation in Page 30 of Zhong et al. (2017b) , we have B 0 = k i̸ =l E x∼ 1 2 N (µ,I d )+ 1 2 N (-µ,I d ) r ⊤ i ϕ ′ (σ • x i )ϕ ′ (σ • x l ) x 2 i e i e ⊤ i + x 2 l e l e ⊤ l + x i x l (e i e ⊤ l + e l e ⊤ i ) + j̸ =i x j x l e j e ⊤ l + j̸ =l x j x i e j e ⊤ i + j̸ =i,l j ′ ̸ =i,l x j x j ′ e j e ⊤ j ′ r l = i̸ =l r ii r li α 2 (i, µ, σ)α 0 (l, µ, σ) + i̸ =l r ij r lj α 0 (i, µ, σ)α 0 (l, µ, σ)(1 + µ 2 j ) Therefore, A 0 + B 0 = k i=1 r ii α 2 (i, µ, σ) 1 + µ 2 i + l̸ =i r li α 0 (l, µ, σ) 1 + µ 2 i 2 - k i=1 r 2 ii α 2 (i, µ, σ) 1 + µ 2 i - k i=1 l̸ =i r 2 li α 0 (l, µ, σ) 2 (1 + µ 2 i ) + k i=1 r 2 ii β 2 (i, µ, σ) + k i=1 j̸ =i r 2 ij β 0 (i, µ, σ)(1 + µ 2 j ) ≥ k i=1 r 2 ii β 2 (i, µ, σ) - α 2 2 (i, µ, σ) 1 + µ 2 i + k i=1 j̸ =i r 2 ij β 0 (i, µ, σ) -α 2 0 (i, µ, σ) (1 + µ 2 j ) ≥ ρ(µ, σ)||R|| 2 F (153) E.4 PROOF OF LEMMA 5 Following the equation ( 92) in Lemma 8 of Fu et al. 
(2020) and by ( 85) ||∇ 2 ℓ(W ) -∇ 2 ℓ(W ′ )|| ≤ K j=1 K l=1 |ξ j,l (W ) -ξ j,l (W ′ )| • ||xx ⊤ || (154) By Lagrange's inequality, we have |ξ j,l (W ) -ξ j,l (W ′ )| ≤ (max k |T j,k,l |) • ||x|| • √ K||W -W ′ || F From Lemma 6, we know max k |T j,k,l | ≤ C 7 By Property 7, we have E x∼ L l=1 λ l N (µ l ,Σ l ) [||x|| 2t ||] ≤ d t (2t -1)!! L l=1 λ l (∥µ l ∥ ∞ + ∥Σ l ∥) 2t Therefore, for some constant C 12 > 0 E x∼ L l=1 λ l N (µ l ,Σ l ) [ sup W ̸ =W ′ ||∇ 2 ℓ(W ) -∇ 2 ℓ(W ′ )|| ||W -W ′ || F ] ≤ K 5 2 E[||x|| 3 2 ] ≤K 5 2 d L l=1 λ l (∥µ∥ ∞ + ∥Σ l ∥) 2 3d 2 L l=1 λ l (∥µ l ∥ ∞ + ∥Σ l ∥) 4 =C 12 • d 3 2 K 5 2 L l=1 λ l (∥µ l ∥ ∞ + ∥Σ l ∥) 2 L l=1 λ l (∥µ l ∥ ∞ + ∥Σ l ∥) 4 (158) E.5 PROOF OF LEMMA 6 Let a = (a ⊤ 1 , • • • , a ⊤ K ) ⊤ ∈ R dK . Let ∆ j,l ∈ R d×d be the (j, l)-th block of ∇ 2 f (W ) - ∇ 2 f (W * P ) ∈ R dK×dK . By definition, ||∇ 2 f (W ) -∇ 2 f (W * P )|| = max ||a||=1 K j=1 K l=1 a ⊤ j ∆ j,l a l ( ) Denote P = (p 1 , • • • , p K ) ∈ R K×K . By the mean value theorem and (85), ∆ j,l = ∂ 2 f (W ) ∂w j ∂w l - ∂ 2 (W * P ) ∂w * j ∂w * l = E x∼ L l=1 λ l N (µ l ,σ 2 l I d ) [(ξ j,l (W ) -ξ j,l (W * P )) • xx ⊤ ] = E x∼ L l=1 λ l N (µ l ,Σ l ) [ K k=1 ∂ξ j,l (W ′ ) ∂w ′ k , w k -W * p k • xx ⊤ ] = E x∼ L l=1 λ l N (µ l ,Σ l ) [ K k=1 ⟨T j,l,k • x, w k -W * p k ⟩ • xx ⊤ ] (160) where W ′ = γW + (1 -γ)W * P for some γ ∈ (0, 1) and T j,l,k is defined such that ∂ξ j,l (W ′ ) ∂w ′ k = T j,l,k • x ∈ R d . Then we provide an upper bound for ξ j,l . Since that y = 1 or 0, we first compute the case in which y = 1. From ( 85) we can obtain ξ j,l (W ) = 1 K 2 ϕ ′ (w ⊤ j x)ϕ ′ (w ⊤ l x) • 1 H 2 (W ) , j ̸ = l 1 K 2 ϕ ′ (w ⊤ j x)ϕ ′ (w ⊤ l x) • 1 H 2 (W ) -1 K ϕ ′′ (w ⊤ j x) • 1 H(W ) , j = l We can bound ξ j,l (W ) by bounding each component of ( 161). 
Note that we have 1 K 2 ϕ ′ (w ⊤ j x)ϕ ′ (w ⊤ l x) • 1 H 2 (W ) ≤ 1 K 2 ϕ(w ⊤ j x)ϕ(w ⊤ l x)(1 -ϕ(w ⊤ j x))(1 -ϕ(w ⊤ l x)) 1 K 2 ϕ(w ⊤ j x)ϕ(w ⊤ l x) ≤ 1 (162) 1 K ϕ ′′ (w ⊤ j x) • 1 H(W ) ≤ 1 K ϕ(w ⊤ j x)(1 -ϕ(w ⊤ j x))(1 -2ϕ(w ⊤ j x)) 1 K ϕ(w ⊤ j x) ≤ 1 (163) where ( 162) holds for any j, l ∈ [K]. The case y = 0 can be computed with the same upper bound by substituting 161), ( 162) and ( 163). Therefore, there exists a constant C 9 > 0, such that (1 -H(W )) = 1 K K j=1 (1 -ϕ(w ⊤ j x)) for H(W ) in ( |ξ j,l (W )| ≤ C 9 We then need to calculate T j,l,k . Following the analysis of ξ j,l (W ), we only consider the case of y = 1 here for simplicity. T j,l,k = -2 K 3 H 3 (W ′ ) ϕ ′ (w ′ ⊤ j x)ϕ ′ (w ′ ⊤ l x)ϕ ′ (w ′ ⊤ k x) , where j, l, k are not equal to each other (165) T j,j,k = -2 K 3 H 3 (W ′ ) ϕ ′ (w ′ ⊤ j x)ϕ ′ (w ′ ⊤ j x)ϕ ′ (w ′ ⊤ k x) + 1 K 2 H 2 (W ′ ) ϕ ′′ (w ′ ⊤ j x)ϕ ′ (w ′ ⊤ k x), j ̸ = k -2 K 3 H 3 (W ′ ) (ϕ ′ (w ′ ⊤ j x)) 3 + 3 K 2 H 2 (W ′ ) ϕ ′′ (w ′ ⊤ j x)ϕ ′ (w ′ ⊤ j x) - ϕ ′′′ (w ′ ⊤ j x) KH(W ′ ) , j = k (166) a ⊤ j ∆ j,l a l = E x∼ L l=1 N (µ l ,Σ l ) [( K k=1 T j,l,k ⟨x, w k -W * p k ⟩) • (a ⊤ j x)(a ⊤ l x)] ≤ E x∼ L N (µ l ,Σ l ) [ K k=1 T 2 j,k,l ] • E[ K k=1 (⟨x, w k -W * p k ⟩ (a ⊤ j x)(a ⊤ l x)) 2 ] ≤ E x∼ L l=1 N (µ l ,Σ l ) [ K k=1 T 2 j,k,l ] K k=1 E((w k -W * p k ) ⊤ x) 4 • E[(a ⊤ j x) 4 (a ⊤ l x) 4 ] ≤ C 8 E x∼ L l=1 N (µ l ,Σ l ) [ K k=1 T 2 j,k,l ] K k=1 ||w k -W * p k || 2 2 • ||a j || 2 2 • ||a l || 2 2 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 4 (167) for some constant C 8 > 0. All the three inequalities of ( 167) are derived from Cauchy-Schwarz inequality. 
Note that we have -2 K 3 H 3 (W ) (ϕ ′ (w ⊤ j x)) 2 ϕ ′ (w ⊤ k x) ≤ 2ϕ 2 (w ⊤ j x)(1 -ϕ(w ⊤ j x)) 2 ϕ(w ⊤ k x)(1 -ϕ(w ⊤ k x)) K 3 1 K 3 ϕ 2 (w ⊤ j x)ϕ(w ⊤ k x) = 2(1 -ϕ(w ⊤ j x)) 2 (1 -ϕ(w ⊤ k x)) ≤ 2 (168) -2 K 3 H 3 (W ) ϕ ′ (w ⊤ j x)ϕ ′ (w ⊤ l x)ϕ ′ (w ⊤ k x) ≤ 2 (169) 3 K 2 H 2 (W ) ϕ ′′ (w ⊤ j x)ϕ ′ (w ⊤ k x) ≤ 3ϕ(w ⊤ j x)(1 -ϕ(w ⊤ j x))(1 -2ϕ(w ⊤ j x))ϕ(w ⊤ k x)(1 -ϕ(w ⊤ k x)) K 2 1 K 2 ϕ(w ⊤ j x)ϕ(w ⊤ k x) = 3(1 -ϕ(w ⊤ j x))(1 -2ϕ(w ⊤ j x))(1 -ϕ(w ⊤ k x)) ≤ 3 (170) ϕ ′′′ (w ⊤ j x) KH(W ) ≤ ϕ(w ⊤ j x)(1 -ϕ(w ⊤ j x))(1 -6ϕ(w ⊤ j x) + 6ϕ 2 (w ⊤ j x)) K 1 K ϕ(w ⊤ j x) ≤ 1 Therefore, by combining ( 165), ( 166) and ( 168) to (171), we have |T j,l,k | ≤ C 7 ⇒ T 2 j,l,k ≤ C 2 7 , ∀j, l, k ∈ [K], for some constants C 7 > 0. By ( 159), ( 160), ( 167), ( 172) and the Cauchy-Schwarz's Inequality, we have ∥∇ 2 f (W ) -∇ 2 f (W * P )∥ ≤C 8 C 2 7 K||W -W * P || F L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 4 • max ||a||=1 K j=1 K l=1 ||a j || 2 ||a l || 2 ≤C 8 C 2 7 K • ||W -W * P || F • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 4 • K j=1 ||a j || 2 ≤C 8 C 2 7 K 3 • ||W -W * P || F • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 4 Hence, we have ||∇ 2 f (W ) -∇ 2 f (W * P )|| ≤C 5 K 3 2 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 4 ||W -W * P || F (174) for some constant C 5 > 0. E.6 PROOF OF LEMMA 7 From Fu et al. (2020) , we know ∇ 2 f (W * P ) ⪰ min ||a||=1 4 K 2 E x∼ L l=1 λ l N (µ l ,Σ l ) K j=1 ϕ ′ (w * π * (j) ⊤ x)(a ⊤ π * (j) x) 2 • I dK = min ||a||=1 4 K 2 E x∼ L l=1 λ l N (µ l ,Σ l ) K j=1 ϕ ′ (w * j ⊤ x)(a ⊤ j x) 2 • I dK (175) with a = (a ⊤ 1 , • • • , a ⊤ K ) ⊤ ∈ R dK , where P is a specific permutation matrix and {π * (j)} K j=1 is the indices permuted by P . 
Similarly, ∇ 2 f (W * P ) ⪯ max ||a||=1 a ⊤ ∇ 2 f (W * )a • I dK ⪯ C 4 • max ||a||=1 E x∼ L l=1 λ l N (µ l ,Σ l ) K j=1 (a ⊤ π * (j) x) 2 • I dK = C 4 • max ||a||=1 E x∼ L l=1 λ l N (µ l ,Σ l ) K j=1 (a ⊤ j x) 2 • I dK for some constant C 4 > 0. By applying Property 8, we can derive the upper bound in (176) as C 4 • E x∼ L l=1 λ l N (µ l ,Σ l ) K j=1 (a ⊤ j x) 2 • I dK ⪯ C 4 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 • I dK To find a lower bound for (175), we can first transfer the expectation of the Gaussian Mixture Model to the weight sum of the expectations over general Gaussian distributions. min ||a||=1 E x∼ L l=1 λ l N (µ l ,Σ l ) K j=1 ϕ ′ (w * j ⊤ x)(a ⊤ j x) 2 = min ||a||=1 L l=1 λ l E x∼N (µ l ,Σ l ) K j=1 ϕ ′ (w * j ⊤ x)(a ⊤ j x) 2 Denote U ∈ R d×k as the orthogonal basis of W * . For any vector a i ∈ R d , there exists two vectors b i ∈ R K and c i ∈ R d-K such that a i = U b i + U ⊥ c i ( ) where U ⊥ ∈ R d×(d-K) denotes the complement of U . We also have U ⊤ ⊥ µ l = 0 by Property 1. Plugging (179) into RHS of (178), and then we have E x∼N (µ l ,Σ l ) K i=1 a ⊤ i x • ϕ ′ (w * i ⊤ x) 2 =E x∼N (µ l ,Σ l ) K i=1 (U b i + U ⊥ c i ) ⊤ x • ϕ ′ (w * i ⊤ x) 2 = A + B + C (180) A = E x∼N (µ l ,Σ l ) K i=1 b ⊤ i U ⊤ x • ϕ ′ (w * i ⊤ x) 2 (181) C = E x∼N (µ l ,Σ l ) 2 K i=1 c ⊤ i U ⊤ ⊥ x • ϕ ′ (w * i ⊤ x) • K i=1 b ⊤ i U ⊤ x • ϕ ′ (w * i ⊤ x) = K i=1 K j=1 E x∼N (µ l ,Σ l ) 2c ⊤ i U ⊤ ⊥ x E x∼N (µ l ,Σ l ) b ⊤ i U ⊤ x • ϕ ′ (w * i ⊤ x)ϕ ′ (w * j ⊤ x) = K i=1 K j=1 2c ⊤ i U ⊤ ⊥ µ l E x∼N (µ l ,Σ l ) b ⊤ i U ⊤ x • ϕ ′ (w * i ⊤ x)ϕ ′ (w * j ⊤ x) = 0 (182) where the last step is by U ⊤ ⊥ µ l = 0 by Property 1. 
B =E x∼N (µ l ,Σ l ) ( K i=1 c ⊤ i U ⊤ ⊥ x • ϕ ′ (w * i ⊤ x)) 2 =E x∼N (µ l ,Σ l ) [(t ⊤ s) 2 ] by defining t = k i=1 ϕ ′ (w * i ⊤ x)c i ∈ R d-K and s = U ⊤ ⊥ x = K i=1 E[t 2 i s 2 i ] + i̸ =j E[t i t j s i s j ] = K i=1 E[t 2 i ] d k=1 (U ⊥ ) 2 ik σ 2 lk + K i=1 E[t 2 i ](U ⊤ ⊥ µ l ) 2 i + i̸ =j E[t i t j ](U ⊤ ⊥ µ l ) i • (U ⊤ ⊥ µ l ) j =E[ d-K i=1 t 2 i • d k=1 (U ⊥ ) 2 ik σ 2 lk ] + E[(t ⊤ U ⊤ ⊥ µ l ) 2 ] = E[ d-K i=1 t 2 i • d k=1 (U ⊥ ) 2 ik σ 2 lk ] (183) The last step is by U ⊤ ⊥ µ l = 0. The 4th step is because that s i is independent of t i , thus E[t i t j s i s j ] = E[t i t j ]E[s i s j ] E[s i s j ] = (U ⊤ ⊥ µ l ) i • (U ⊤ ⊥ µ l ) j , if i ̸ = j (U ⊤ ⊥ µ l ) 2 i + d k=1 (U ⊥ ) 2 ik σ 2 lk , if i = j (184) Since k i=1 r ⊤ i x • ϕ ′ (σ • x i ) 2 is an even function for any r i ∈ R d , i ∈ [k], so from Property 5 we have E x∼N (µ l ,Σ l ) ( k i=1 r ⊤ i x•ϕ ′ (σ•x i )) 2 = E x∼ 1 2 N (µ l ,Σ l )+ 1 2 N (-µ l ,Σ l ) ( k i=1 r ⊤ i x•ϕ ′ (σ•x i )) 2 (185) Combining Lemma 4 and Property 5, we next follow the derivation for the standard Gaussian distribution in Page 36 of Zhong et al. (2017b) and generalize the result to a Gaussian distribution with an arbitrary mean and variance as follows. 
A = E x∼N (µ l ,Σ l ) K i=1 b ⊤ i U ⊤ x • ϕ ′ (w * i ⊤ x) 2 ≥ (2π) -K 2 |U ⊤ Σ l U | -1 2 K i=1 b ⊤ i z • ϕ ′ (v i ⊤ z) 2 exp - 1 2 ∥Σ -1 l ∥∥z -U ⊤ µ l ∥ 2 dz = (2π) -K 2 |U ⊤ Σ l U | -1 2 K i=1 b ⊤ i V † ⊤ s • ϕ ′ (s i ) 2 exp - 1 2 ∥Σ -1 l ∥∥V † ⊤ s -U ⊤ µ l ∥ 2 det(V † ) ds ≥ (2π) -K 2 |U ⊤ Σ l U | -1 2 k i=1 b ⊤ i V † ⊤ s • ϕ ′ (s i ) 2 exp - ∥Σ -1 l ∥∥s -V ⊤ U ⊤ µ l ∥ 2 2δ 2 K (W * ) det(V † ) ds ≥ (2π) -K 2 |U ⊤ Σ l U | -1 2 k i=1 b ⊤ i V † ⊤ (δ K (W * )∥Σ -1 l ∥ -1 2 )g • ϕ ′ (δ K (W * )∥Σ -1 l ∥ -1 2 • g i ) 2 • exp - ||g - √ ∥Σ -1 l ∥W * ⊤ µ l δ K (W * ) || 2 2 det(V † ) ∥Σ -1 l ∥ -K 2 δ K K (W * )dg = ∥Σ -1 l ∥ -1 τ K η E g ( K i=1 (b ⊤ i V † ⊤ δ K (W * ))g • ϕ ′ (∥Σ -1 l ∥ -1 2 δ K (W * ) • g i )) 2 ≥ ∥Σ -1 l ∥ -1 τ K κ 2 η ρ( W * ⊤ µ l ∥Σ -1 l ∥ -1 2 δ K (W * ) , ∥Σ -1 l ∥ -1 2 δ K (W * ))||b|| 2 . (186) The second step is by letting z = U ⊤ x ∼ N (U ⊤ µ l , U ⊤ ΣU ), y ⊤ U ⊤ Σ -1 l U y ≤ ∥Σ -1 l ∥∥y∥ 2 for any y ∈ R K . The third step is by letting s = V ⊤ z. The last to second step follows from g = s ∥Σ -1 l ∥ -1 2 δ K (W * ) , where g ∼ N ( W * ⊤ µ l ∥Σ -1 l ∥ -1 2 δ K (W * ) , I K ) and the last inequality is by Lemma 4. Similarly, we extend the derivation in Page 37 of Zhong et al. (2017b) for the standard Gaussian distribution to a general Gaussian distribution as follows. B = d k=1 (U ⊥ ) 2 ik σ 2 lk E x∼N (µ l ,Σ l ) [||t|| 2 ] ≥ ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l ∥Σ -1 l ∥ -1 2 |δ K (W * ) , ∥Σ -1 l ∥ -1 2 δ K (W * ))||c|| 2 (187) Combining ( 180) -( 183), ( 186) and (187), we have min ||a||=1 E x∼N (µ l ,Σ l ) ( k i=1 a ⊤ i x•ϕ ′ (w * i ⊤ x)) 2 ≥ ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ). 
For the Gaussian Mixture Model x ∼ L l=1 N (µ l , Σ), we have min ||a||=1 E x∼ L l=1 λ l N (µ l ,Σ l ) ( k i=1 a ⊤ i x • ϕ ′ (w * i ⊤ x)) 2 ≥ L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) (189) Therefore, 4 K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) • I dK ⪯∇ 2 f (W * P ) ⪯ C 4 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 • I dK From ( 139) in Lemma 6, since that we have the condition ∥W -W * P ∥ F ≤ r and (93), we can obtain ||∇ 2 f (W ) -∇ 2 f (W * P )|| ≤C 5 K 3 2 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 4 L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 8 1 4 ||W -W * P || F ≤ 4ϵ 0 K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ), where ϵ 0 ∈ (0, 1 4 ). Then we have ||∇ 2 f (W )|| ≥ ||∇ 2 f (W * P )|| -||∇ 2 f (W ) -∇ 2 f (W * P )|| ≥ 4(1 -ϵ 0 ) K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) (192) ||∇ 2 f (W )|| ≤ ||∇ 2 f (W * )|| + ||∇ 2 f (W ) -∇ 2 f (W * P )|| ≤ C 4 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 ∥) 2 + 4 K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 2 l ∥ , δ K (W * )∥Σ -1 2 l ∥) ≲ C 4 • L l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 (193) The last inequality of (193) holds since C 4 • l=1 λ l (∥µ l ∥ + ∥Σ 1 2 l ∥) 2 = Ω(max l {∥Σ l ∥}), 4 K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) = O( max l {∥Σ l ∥} K 2 ) and Ω(max l {∥Σ l ∥}) ≥ O( max l {∥Σ l ∥} K 2 ). Combining ( 192) and (193), we have 4(1 -ϵ 0 ) K 2 L l=1 λ l ∥Σ -1 l ∥ -1 ητ K κ 2 ρ( W * ⊤ µ l δ K (W * )∥Σ -1 l ∥ -1 2 , δ K (W * )∥Σ -1 l ∥ -1 2 ) • I ⪯∇ 2 f (W ) ⪯ C 4 • L l=1 λ l (∥µ l ∥ + σ l ) 2 • I E.7 PROOF OF LEMMA 8 Let N ϵ be the ϵ-covering number of the Euclidean ball B(W * P , r). It is known that log N ϵ ≤ dK log( 3r ϵ ) from Vershynin (2010). Let W ϵ = {W 1 , ..., W Nϵ } be the ϵ-cover set with N ϵ elements. 
For any $W\in B(W^*_P,r)$, let $j(W)=\arg\min_{j\in[N_\epsilon]}\|W-W_j\|_F$, so that $\|W-W_{j(W)}\|_F\le\epsilon$ for all $W\in B(W^*_P,r)$. Then for any $W\in B(W^*_P,r)$, we have
\[
\begin{aligned}
\|\nabla^2 f_n(W)-\nabla^2 f(W)\|\ \le\ &\frac{1}{n}\Big\|\sum_{i=1}^{n}\big[\nabla^2\ell(W;x_i)-\nabla^2\ell(W_{j(W)};x_i)\big]\Big\|\\
&+\Big\|\frac{1}{n}\sum_{i=1}^{n}\nabla^2\ell(W_{j(W)};x_i)-\mathbb{E}_{x\sim\sum_{l}\lambda_l\mathcal{N}(\mu_l,\Sigma_l)}\big[\nabla^2\ell(W_{j(W)};x)\big]\Big\|\\
&+\Big\|\mathbb{E}_{x\sim\sum_{l}\lambda_l\mathcal{N}(\mu_l,\Sigma_l)}\big[\nabla^2\ell(W_{j(W)};x)\big]-\mathbb{E}_{x\sim\sum_{l}\lambda_l\mathcal{N}(\mu_l,\Sigma_l)}\big[\nabla^2\ell(W;x)\big]\Big\|.
\end{aligned}
\]
Hence, we have
\[
\mathbb{P}\Big(\sup_{W\in B(W^*_P,r)}\|\nabla^2 f_n(W)-\nabla^2 f(W)\|\ge t\Big)\le\mathbb{P}(A_t)+\mathbb{P}(B_t)+\mathbb{P}(C_t),
\]
where $A_t$, $B_t$ and $C_t$ are defined as
\[
A_t=\Big\{\sup_{W\in B(W^*_P,r)}\frac{1}{n}\Big\|\sum_{i=1}^{n}\big[\nabla^2\ell(W;x_i)-\nabla^2\ell(W_{j(W)};x_i)\big]\Big\|\ge\frac{t}{3}\Big\},\tag{197}
\]
\[
B_t=\Big\{\sup_{W\in B(W^*_P,r)}\Big\|\frac{1}{n}\sum_{i=1}^{n}\nabla^2\ell(W_{j(W)};x_i)-\mathbb{E}\big[\nabla^2\ell(W_{j(W)};x)\big]\Big\|\ge\frac{t}{3}\Big\},\tag{198}
\]
\[
C_t=\Big\{\sup_{W\in B(W^*_P,r)}\Big\|\mathbb{E}\big[\nabla^2\ell(W_{j(W)};x)\big]-\mathbb{E}\big[\nabla^2\ell(W;x)\big]\Big\|\ge\frac{t}{3}\Big\}.\tag{199}
\]
Then we bound $\mathbb{P}(A_t)$, $\mathbb{P}(B_t)$, and $\mathbb{P}(C_t)$ separately.

1) Upper bound on $\mathbb{P}(B_t)$. By Lemma 6 in Fu et al. (2020), we obtain
\[
\Big\|\frac{1}{n}\sum_{i=1}^{n}\nabla^2\ell(W;x_i)-\mathbb{E}\big[\nabla^2\ell(W;x)\big]\Big\|\le 2\sup_{v\in V_{\frac{1}{4}}}\Big\langle v,\Big(\frac{1}{n}\sum_{i=1}^{n}\nabla^2\ell(W;x_i)-\mathbb{E}\big[\nabla^2\ell(W;x)\big]\Big)v\Big\rangle,
\]
where $V_{\frac{1}{4}}$ is a $\frac{1}{4}$-cover of the unit-Euclidean-norm ball $B(0,1)$ with $\log|V_{\frac{1}{4}}|\le dK\log 12$. Taking the union bound over $W_\epsilon$ and $V_{\frac{1}{4}}$, we have
\[
\mathbb{P}(B_t)\le\mathbb{P}\Big(\sup_{W\in W_\epsilon,\,v\in V_{\frac{1}{4}}}\Big|\frac{1}{n}\sum_{i=1}^{n}G_i\Big|\ge\frac{t}{6}\Big)\le\exp\Big(dK\big(\log\tfrac{3r}{\epsilon}+\log 12\big)\Big)\sup_{W\in W_\epsilon,\,v\in V_{\frac{1}{4}}}\mathbb{P}\Big(\Big|\frac{1}{n}\sum_{i=1}^{n}G_i\Big|\ge\frac{t}{6}\Big),
\]
where $G_i=\big\langle v,\big(\nabla^2\ell(W;x_i)-\mathbb{E}[\nabla^2\ell(W;x)]\big)v\big\rangle$ and $\mathbb{E}[G_i]=0$. Here $v=(u_1^\top,\cdots,u_K^\top)^\top\in\mathbb{R}^{dK}$, and
\[
|G_i|=\Big|\sum_{j=1}^{K}\sum_{l=1}^{K}\xi_{j,l}\,u_j^\top xx^\top u_l-\mathbb{E}\big[\xi_{j,l}\,u_j^\top xx^\top u_l\big]\Big|\le C_9\Big(\sum_{j=1}^{K}(u_j^\top x)^2+\sum_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^2\big]\Big)\tag{202}
\]
for some $C_9>0$. The first step of (202) is by (84). The last step is by (164) and the Cauchy-Schwarz inequality.

\[
\begin{aligned}
\mathbb{E}\big[|G_i|^p\big]&\le\sum_{l=1}^{p}\binom{p}{l}C_9\,\mathbb{E}\Big[\Big(\sum_{j=1}^{K}(u_j^\top x)^2\Big)^l\Big]\Big(\sum_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^2\big]\Big)^{p-l}\\
&=\sum_{l=1}^{p}\binom{p}{l}C_9\,\mathbb{E}\Big[\sum_{l_1+\cdots+l_K=l}\frac{l!}{\prod_{j=1}^{K}l_j!}\prod_{j=1}^{K}(u_j^\top x)^{2l_j}\Big]\Big(\sum_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^2\big]\Big)^{p-l}\\
&=\sum_{l=1}^{p}\binom{p}{l}C_9\sum_{l_1+\cdots+l_K=l}\frac{l!}{\prod_{j=1}^{K}l_j!}\prod_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^{2l_j}\big]\Big(\sum_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^2\big]\Big)^{p-l}\\
&=C_9\sum_{l=1}^{p}\binom{p}{l}\Big(\sum_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^2\big]\Big)^l\Big(\sum_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^2\big]\Big)^{p-l}\\
&=C_9\Big(\sum_{j=1}^{K}\mathbb{E}\big[(u_j^\top x)^2\big]\Big)^p\le C_9\Big(\sum_{j=1}^{K}1!!\,\|u_j\|^2\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^p\le C_9\Big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^p,
\end{aligned}\tag{203}
\]
where the first step is by the triangle inequality and the binomial theorem, and the second step comes from the multinomial theorem. The second-to-last inequality in (203) results from Property 8. The last inequality is because $v\in V_{\frac{1}{4}}$ implies $\sum_{j=1}^{K}\|u_j\|^2=\|v\|^2\le 1$. Then
\[
\mathbb{E}\big[\exp(\theta G_i)\big]=1+\theta\mathbb{E}[G_i]+\sum_{p=2}^{\infty}\frac{\theta^p\mathbb{E}\big[|G_i|^p\big]}{p!}\le 1+\sum_{p=2}^{\infty}\frac{|e\theta|^p}{p^p}C_9\Big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^p\le 1+C_9|e\theta|^2\Big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^2,\tag{204}
\]
where the first inequality holds from $p!\ge(\frac{p}{e})^p$ and (203), and the last inequality holds provided that
\[
\max_{p\ge 2}\Bigg\{\frac{\frac{|e\theta|^{p+1}}{(p+1)^{p+1}}\Big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^{p+1}}{\frac{|e\theta|^{p}}{p^{p}}\Big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^{p}}\Bigg\}\le\frac{1}{2}.\tag{205}
\]
Note that the quantity inside the maximization in (205) achieves its maximum at $p=2$ because it is monotonically decreasing in $p$. Therefore, (205) holds if $\theta\le\frac{27}{4e\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2}$. Then
\[
\mathbb{P}\Big(\frac{1}{n}\sum_{i=1}^{n}G_i\ge\frac{t}{6}\Big)=\mathbb{P}\Big(\exp\Big(\theta\sum_{i=1}^{n}G_i\Big)\ge\exp\Big(\frac{n\theta t}{6}\Big)\Big)\le e^{-\frac{n\theta t}{6}}\prod_{i=1}^{n}\mathbb{E}\big[\exp(\theta G_i)\big]\le\exp\Big(C_{10}\theta^2 n\Big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^2-\frac{n\theta t}{6}\Big)
\]
for some constant $C_{10}>0$. The first inequality follows from Markov's inequality.
When $\theta=\min\Big\{\frac{t}{12C_{10}\big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2\big)^2},\ \frac{27}{4e\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2}\Big\}$, we have a modified Bernstein inequality for the Gaussian Mixture Model as follows:
\[
\mathbb{P}\Big(\frac{1}{n}\sum_{i=1}^{n}G_i\ge\frac{t}{6}\Big)\le\exp\Bigg(\max\Big\{-\frac{C_{10}nt^2}{144\big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2\big)^2},\ -\frac{C_{11}nt}{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2}\Big\}\Bigg)
\]
for some constant $C_{11}>0$. We can obtain the same bound for $\mathbb{P}(-\frac{1}{n}\sum_{i=1}^{n}G_i\ge\frac{t}{6})$ by replacing $G_i$ with $-G_i$. Therefore, we have
\[
\mathbb{P}\Big(\Big|\frac{1}{n}\sum_{i=1}^{n}G_i\Big|\ge\frac{t}{6}\Big)\le 2\exp\Bigg(\max\Big\{-\frac{C_{10}nt^2}{144\big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2\big)^2},\ -\frac{C_{11}nt}{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2}\Big\}\Bigg).
\]
Thus, as long as
\[
t\ge C_6\max\Bigg\{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2\sqrt{\frac{dK\log\frac{36r}{\epsilon}+\log\frac{4}{\delta}}{n}},\ \ \sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2\cdot\frac{dK\log\frac{36r}{\epsilon}+\log\frac{4}{\delta}}{n}\Bigg\}\tag{209}
\]
for some large constant $C_6>0$, we have $\mathbb{P}(B_t)\le\frac{\delta}{2}$.

2) Upper bound on $\mathbb{P}(A_t)$ and $\mathbb{P}(C_t)$. From Lemma 5, we can obtain
\[
\begin{aligned}
\sup_{W\in B(W^*_P,r)}\big\|\mathbb{E}\big[\nabla^2\ell(W_{j(W)};x)\big]-\mathbb{E}\big[\nabla^2\ell(W;x)\big]\big\|
&\le\sup_{W\in B(W^*_P,r)}\frac{\big\|\mathbb{E}\big[\nabla^2\ell(W_{j(W)};x)\big]-\mathbb{E}\big[\nabla^2\ell(W;x)\big]\big\|}{\|W-W_{j(W)}\|_F}\cdot\sup_{W\in B(W^*_P,r)}\|W-W_{j(W)}\|_F\\
&\le C_{12}\,d^{\frac{3}{2}}K^{\frac{5}{2}}\sqrt{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^4}\cdot\epsilon.
\end{aligned}\tag{210}
\]
Therefore, $C_t$ cannot occur if
\[
t\ge C_{12}\,d^{\frac{3}{2}}K^{\frac{5}{2}}\sqrt{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^4}\cdot\epsilon.\tag{211}
\]
We bound $\mathbb{P}(A_t)$ as follows.

\[
\begin{aligned}
\mathbb{P}\Big(\sup_{W\in B(W^*_P,r)}\frac{1}{n}\Big\|\sum_{i=1}^{n}\big[\nabla^2\ell(W_{j(W)};x_i)-\nabla^2\ell(W;x_i)\big]\Big\|\ge\frac{t}{3}\Big)
&\le\frac{3}{t}\,\mathbb{E}\Big[\sup_{W\in B(W^*_P,r)}\frac{1}{n}\Big\|\sum_{i=1}^{n}\big[\nabla^2\ell(W_{j(W)};x_i)-\nabla^2\ell(W;x_i)\big]\Big\|\Big]\\
&=\frac{3}{t}\,\mathbb{E}\Big[\sup_{W\in B(W^*_P,r)}\big\|\nabla^2\ell(W_{j(W)};x)-\nabla^2\ell(W;x)\big\|\Big]\\
&\le\frac{3}{t}\,\mathbb{E}\Big[\sup_{W\in B(W^*_P,r)}\frac{\|\nabla^2\ell(W_{j(W)};x)-\nabla^2\ell(W;x)\|}{\|W-W_{j(W)}\|_F}\cdot\sup_{W\in B(W^*_P,r)}\|W-W_{j(W)}\|_F\Big]\\
&\le C_{12}\,d^{\frac{3}{2}}K^{\frac{5}{2}}\sqrt{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^4}\cdot\frac{\epsilon}{t},
\end{aligned}\tag{212}
\]
where the first inequality is by Markov's inequality, and the last inequality comes from Lemma 5. Thus, taking
\[
t\ge C_{12}\,d^{\frac{3}{2}}K^{\frac{5}{2}}\sqrt{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^4}\cdot\frac{\epsilon}{\delta}\tag{213}
\]
ensures that $\mathbb{P}(A_t)\le\frac{\delta}{2}$.

3) Final step. Let $\epsilon=\cdots$ for some constant $C_{13}>0$, where $P$ is a permutation matrix.

Proof: Similar to the idea of the proof of Lemma 8, we adopt an $\epsilon$-covering net of the ball $B(W^*,r)$ to build a relationship between any arbitrary point in the ball and the points in the covering set. We can then divide the distance between $\nabla\tilde f_n(W)$ and $\nabla\tilde f(W)$ into three parts, similar to (195). Events (218) to (220) can be derived in a similar way as (197) to (199), with "$\nabla^2$" replaced by "$\nabla$". Then we need to bound $\mathbb{P}(A'_t)$, $\mathbb{P}(B'_t)$ and $\mathbb{P}(C'_t)$ respectively, where $A'_t$, $B'_t$ and $C'_t$ are defined below. Note that $\nabla\tilde f_n(W)=\nabla f_n(W)+\frac{1}{n}\sum_{i=1}^{n}\nu_i$ and $\nabla\tilde f(W)=\nabla f(W)+\mathbb{E}[\nu_i]=\nabla f(W)$. Moreover,
\[
\zeta(W)=\begin{cases}
-\dfrac{1}{K}\cdot\dfrac{1}{H(W)}\,\phi'(w_j^\top x),\quad |\zeta(W)|\le\dfrac{\phi(w_j^\top x)\big(1-\phi(w_j^\top x)\big)}{K\cdot\frac{1}{K}\phi(w_j^\top x)}\le 1, & y=1,\\[8pt]
\dfrac{1}{K}\cdot\dfrac{1}{1-H(W)}\,\phi'(w_j^\top x),\quad |\zeta(W)|\le\dfrac{\phi(w_j^\top x)\big(1-\phi(w_j^\top x)\big)}{K\cdot\frac{1}{K}\big(1-\phi(w_j^\top x)\big)}\le 1, & y=0.
\end{cases}
\]
Then we have an upper bound of $G'_i$.
Following the idea of (203) and (204), and by $v\in V_{\frac{1}{2}}$, we have
\[
|G'_i|=\big|\zeta\, v^\top x-\mathbb{E}\big[\zeta\, v^\top x\big]\big|\le|v^\top x|+\mathbb{E}\big[|v^\top x|\big],
\]
\[
\mathbb{E}\big[|G'_i|^p\big]\le O\Big(\Big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big)^{\frac{p}{2}}\Big),\tag{224}
\]
\[
\mathbb{E}\big[\exp(\theta G'_i)\big]\le 1+O\Big(|e\theta|^2\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\Big),\tag{225}
\]
where (225) holds if $\theta\le\frac{27}{4e\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{1/2}\|)^2}$. The bound on $\mathbb{P}(B'_t)$ then follows the derivation of (201) and (206) to (209). For $\mathbb{P}(A'_t)$ and $\mathbb{P}(C'_t)$, the Lipschitz bound ends with
\[
\cdots\le C_9\,\frac{\|x\|^2\sqrt{K}\,\|W-W'\|_F}{\|W-W'\|_F}\le 3C_9\sqrt{Kd}\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2.\tag{228}
\]
The first inequality is by (83), the second inequality is by the Mean Value Theorem, the third step is by (164), and the last inequality is by Property 7. Therefore, following the steps in part (2) of Lemma 8, we can conclude that $C'_t$ cannot occur if
\[
t\ge 3C_9\sqrt{Kd}\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2\cdot\epsilon,\tag{229}
\]
and, from (213) in Lemma 8, $\mathbb{P}(A'_t)\le\frac{\delta}{2}$ once
\[
t\ge 18C_9\sqrt{Kd}\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2\cdot\frac{\epsilon}{\delta}.
\]
From Lemma 7 and Lemma 9, combining (235), (236) and (237), where the second-to-last step of (237) comes from the triangle inequality and the last step follows from the fact $\nabla\tilde f(W^*_P)=0$, we have
\[
\frac{4}{K^2}\sum_{l=1}^{L}\lambda_l\|\Sigma_l^{-1}\|^{-1}\eta\tau_K\kappa^2\,\rho\Big(\frac{{W^*}^\top\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac{1}{2}}},\ \delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac{1}{2}}\Big)\|\widehat W_n-W^*_P\|_F^2\le\frac{1}{2}\,\mathrm{vec}(\widehat W_n-W^*_P)^\top\nabla^2 f_n(W')\,\mathrm{vec}(\widehat W_n-W^*_P),
\]
\[
\|\widehat W_n-W^*_P\|_F\le O\Bigg(\frac{K^{\frac{5}{2}}\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2(1+\xi)}{\sum_{l=1}^{L}\lambda_l\|\Sigma_l^{-1}\|^{-1}\eta\tau_K\kappa^2\,\rho\Big(\frac{{W^*}^\top\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac{1}{2}}},\ \delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac{1}{2}}\Big)}\sqrt{\frac{d\log n}{n}}\Bigg).\tag{238}
\]
Therefore, we have concluded that there indeed exists a critical point $\widehat W_n$ in $B(W^*_P,r)$. Then we show the linear convergence of Algorithm 1 as below. By the update rule, we have
\[
W^{t+1}-\widehat W_n=W^t-\eta_0\Big(\nabla f_n(W^t)+\frac{1}{n}\sum_{i=1}^{n}\nu_i\Big)-\Big(\widehat W_n-\eta_0\nabla\tilde f_n(\widehat W_n)\Big)=\Big(I-\eta_0\int_0^1\nabla^2 f_n\big(W(\gamma)\big)\,d\gamma\Big)\big(W^t-\widehat W_n\big)-\frac{\eta_0}{n}\sum_{i=1}^{n}\nu_i,
\]
where $W(\gamma)=\gamma\widehat W_n+(1-\gamma)W^t$ for $\gamma\in(0,1)$.
Since $W(\gamma)\in B(W^*_P,r)$, by Lemma 1 we have
\[
H_{\min}\cdot I\ \preceq\ \nabla^2 f_n\big(W(\gamma)\big)\ \preceq\ H_{\max}\cdot I,
\]
where
\[
H_{\min}=\Omega\Big(\frac{1}{K^2}\sum_{l=1}^{L}\lambda_l\|\Sigma_l^{-1}\|^{-1}\eta\tau_K\kappa^2\,\rho\Big(\frac{{W^*}^\top\mu_l}{\delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac{1}{2}}},\ \delta_K(W^*)\|\Sigma_l^{-1}\|^{-\frac{1}{2}}\Big)\Big),\qquad H_{\max}=C_4\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2.
\]
Therefore,
\[
\|W^{t+1}-\widehat W_n\|_F\le\Big\|I-\eta_0\int_0^1\nabla^2 f_n\big(W(\gamma)\big)\,d\gamma\Big\|\cdot\|W^t-\widehat W_n\|_F+\Big\|\frac{\eta_0}{n}\sum_{i=1}^{n}\nu_i\Big\|_F\le(1-\eta_0 H_{\min})\|W^t-\widehat W_n\|_F+\Big\|\frac{\eta_0}{n}\sum_{i=1}^{n}\nu_i\Big\|_F.
\]
By setting $\eta_0=1/H_{\max}$, Algorithm 1 converges linearly to $\widehat W_n$ up to a statistical error.

G PROOF OF LEMMA 3

We need Lemma 10 to Lemma 14, which are stated in Section G.1, for the proof of Lemma 3. Section G.2 summarizes the proof of Lemma 3. The proofs of Lemma 10 to Lemma 12 are provided in Sections G.3 to G.5. Lemma 13 and Lemma 14 are cited from Zhong et al. (2017b). Although Zhong et al. (2017b) considers the standard Gaussian distribution, the proofs of Lemmas 13 and 14 hold for any data distribution, so these two lemmas can be applied here directly. The tensor initialization in Zhong et al. (2017b) only holds for the standard Gaussian distribution; we instead exploit a more general definition of tensors from Janzamin et al. (2014) for the tensor initialization in our algorithm, and we develop new error bounds for the initialization.
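The contraction above can be checked numerically on a toy problem. The sketch below runs noisy gradient descent on a hypothetical quadratic whose Hessian eigenvalues lie in $[H_{\min},H_{\max}]$ (stand-ins for the bounds from Lemma 1; the dimension and noise level are also placeholders, not the paper's setting), and the distance to the minimizer decays geometrically down to a noise floor:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 10
H_min, H_max = 0.5, 2.0
# Toy quadratic f(w) = 0.5 * w^T A w, whose Hessian A satisfies
# H_min * I <= A <= H_max * I, mimicking the bound on the empirical Hessian.
A = np.diag(rng.uniform(H_min, H_max, size=d))
eta0 = 1.0 / H_max           # step size suggested by the analysis
noise_scale = 1e-3           # stands in for the gradient noise (1/n) sum_i nu_i

w = rng.normal(size=d)       # W^0; the minimizer is w = 0
dists = [np.linalg.norm(w)]
for t in range(200):
    grad = A @ w + noise_scale * rng.normal(size=d)
    w = w - eta0 * grad
    dists.append(np.linalg.norm(w))
# dists contracts roughly like (1 - H_min/H_max)^t until it hits
# a noise floor on the order of the injected noise.
```

The geometric factor $(1-H_{\min}/H_{\max})$ and the residual floor mirror the two terms in the recursion above.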



Here, poly(∥µ_l∥) is ∥µ_l∥⁴ for ∥µ_l∥ ≤ 1 and ∥µ_l∥¹² for ∥µ_l∥ > 1. In Figure 6(a), when the minority fraction is less than 0.01, the minority group distribution is almost removed from the Gaussian mixture model in the analysis. Then the O(1) constants in the last column of Table 1 have some minor changes, and the order-wise analyses do not reflect the minor fluctuations in this regime.

Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. The CIFAR-10 dataset. www.cs.toronto.edu/~kriz/cifar.html

By mild, we mean that, given L, if Assumption 1 is not met for some Ψ₀, there exists an infinite number of Ψ′ in any neighborhood of Ψ₀ such that Assumption 1 holds for Ψ′.



Figure 1: Group imbalance experiment. (a) Binary classification on CelebA dataset using Gaussian augmentation to control the minority group co-variance. (b) Test accuracy against the augmented noise level.

denote the output of Algorithm 1 in the m-th trial. Let W̄ₙ denote the mean of all W⁽ᵐ⁾ₙ, and let V_W = Σ_{m=1}^{M} ∥W⁽ᵐ⁾ₙ − W̄ₙ∥²_F / M denote the variance. An experiment is successful if V_W ≤ 10⁻³ and fails otherwise. M is set to 20. For each pair of d and n, 20 independent sets of W* and the corresponding training samples are generated. Figure 2(a) shows the success rate of these independent experiments. A black block means that all the experiments fail, while a white block means that they all succeed. The sample complexity is indeed almost linear in d, as predicted by (6).
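The success criterion above can be sketched in a few lines (a minimal illustration; `weights` stands for the hypothetical learned matrices W⁽ᵐ⁾ₙ from M independent trials):

```python
import numpy as np

def trial_variance(weights):
    """V_W = (1/M) * sum_m ||W_m - W_bar||_F^2 over the M trials."""
    W = np.stack(weights)                      # shape (M, d, K)
    W_bar = W.mean(axis=0)                     # element-wise mean over trials
    return float(np.mean([np.linalg.norm(Wm - W_bar, 'fro') ** 2 for Wm in W]))

def experiment_succeeds(weights, tol=1e-3):
    """Success criterion used in the text: V_W <= 10^{-3}."""
    return trial_variance(weights) <= tol
```

Running this over a grid of (d, n) pairs and averaging the boolean outcomes yields the success-rate map shown in Figure 2(a).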

Figure 2: The sample complexity (a) when the feature dimension changes, (b) when one mean changes, (c) when one co-variance changes.

We next study the impact of the GMM parameters on the sample complexity. In Figure 2(b), Σ₁ = Σ₂ = I, and we let µ₁ = µ · 1 and µ₂ = −1, with ∥µ₁∥ varying from 0 to 5. Figure 2(b) shows that the sample complexity increases as the mean increases. In Figure 2(c), we fix µ₁ = 1, µ₂ = −1, and let Σ₁ = σ²I and Σ₂ = I, with σ varying from 10⁻¹ to 10¹. The sample complexity increases both when ∥Σ₁∥ increases and when ∥Σ₁∥ approaches zero. All results match the predictions of Corollary 1.
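For reference, the synthetic data for this experiment could be generated as follows (an illustrative sketch; the equal mixing weights, means µ·1 and −1, and co-variances σ²I and I are the assumptions described above):

```python
import numpy as np

def sample_gmm(n, d, mu_scale=1.0, sigma1=1.0, rng=None):
    """Draw n points from an equal-weight two-component GMM:
    group 1 ~ N(mu_scale * 1, sigma1^2 * I), group 2 ~ N(-1, I)."""
    rng = np.random.default_rng(rng)
    labels = rng.integers(0, 2, size=n)          # hidden group attribute
    mu1, mu2 = mu_scale * np.ones(d), -np.ones(d)
    x = np.where(labels[:, None] == 0,
                 mu1 + sigma1 * rng.standard_normal((n, d)),
                 mu2 + rng.standard_normal((n, d)))
    return x, labels
```

Sweeping `mu_scale` reproduces the setting of Figure 2(b), and sweeping `sigma1` that of Figure 2(c).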

Figure 3: (a) The convergence rate with different µ 1 . (b) The convergence rate with different Σ. (c) Convergence rate when the number of neurons K changes.

Figure 5: The test loss (cross entropy loss) of synthetic data with different λ 2 values. (a) Group 2 has a smaller level of co-variance. (b) Group 2 has a larger level of co-variance.

Figure 6: The test accuracy on the CelebA dataset has opposite trends when the minority group fraction increases. (a) The male group is the minority. (b) The female group is the minority.

Figure 8: The test accuracy on the CIFAR-10 dataset with different data augmentation methods: (a) Gaussian noise, (b) cropping.

Tensor Initialization Method
1: Input: Partition the n pairs of data {(xᵢ, yᵢ)}ᵢ₌₁ⁿ into three disjoint subsets D₁, D₂, D₃.
2: if the Gaussian Mixture distribution is not symmetric then
3: Compute Q̂₂ using D₁. Estimate the subspace Û by orthogonalizing the eigenvectors corresponding to the K largest eigenvalues of Q̂₂.
4: else
5: Pick an arbitrary vector α ∈ ℝᵈ, and use D₁ to compute Q̂₃(I_d, I_d, α). Estimate Û by orthogonalizing the eigenvectors corresponding to the K largest eigenvalues of Q̂₃(I_d, I_d, α).
6: end if
7: Compute R̂₃ = Q̂₃(Û, Û, Û) from data set D₂.
8: Employ the KCL algorithm to compute the vectors {v̂ᵢ}ᵢ∈[K], which are estimates of {Û⊤w̄ᵢ*}ᵢ₌₁ᴷ. The direction vectors {w̄ᵢ*}ᵢ₌₁ᴷ can then be approximated by {Ûv̂ᵢ}ᵢ₌₁ᴷ.
9: Compute Q̂₁ from data set D₃.
10: Estimate the magnitudes ẑ by solving the optimization problem

(17).
11: Return: Use ẑⱼÛv̂ⱼ as the j-th column of W₀, j ∈ [K].
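Step 3 of the method can be sketched as follows. This is a hedged illustration: the exact form of Q̂₂ comes from Definition 1 (not reproduced here), so the sketch uses the stand-in Q̂₂ = (1/n)Σᵢ yᵢxᵢxᵢ⊤ merely to show the eigen-decomposition and orthogonalization mechanics:

```python
import numpy as np

def estimate_subspace(X, y, K):
    """Sketch of step 3: form an empirical second-moment matrix Q2_hat and
    take the top-K eigenvectors (orthonormalized) as the subspace estimate.
    Q2_hat = (1/n) * sum_i y_i x_i x_i^T is an illustrative stand-in for
    the paper's Definition 1."""
    n, d = X.shape
    Q2 = (X.T * y) @ X / n                       # d x d symmetric matrix
    eigvals, eigvecs = np.linalg.eigh(Q2)        # eigenvalues in ascending order
    idx = np.argsort(np.abs(eigvals))[::-1][:K]  # K largest in magnitude
    U, _ = np.linalg.qr(eigvecs[:, idx])         # orthogonalize the span
    return U
```

On a planted single-direction model, the recovered column aligns with the true direction, which is the property steps 3 and 5 rely on.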

Figure 10: Comparison between tensor initialization, a random initialization near W * , and an arbitrary random initialization

Definition 5 (D-function). Given the Gaussian Mixture Model and any positive integer m, define D_m(Ψ) as

(97) with probability at least 1 − d⁻¹⁰. Lemma 3 (Tensor initialization). For the classification model, with D₆(Ψ) defined in Definition 5, we have that if the sample size

∥. Hence, B(Ψ) is an increasing function of ∥Σ_l^{1/2}∥. Moreover, when all ∥Σ_l^{1/2}∥

…as $n\ge C'\cdot dK\log dK$, we have
\[
\mathbb{P}\Big(\sup_{W\in B(W^*_P,r)}\|\nabla^2 f_n(W)-\nabla^2 f(W)\|\ge C_6\cdot\cdots\Big)\le\cdots
\]

F PROOF OF LEMMA 2

We first present a lemma used in proving Lemma 2 in Section F.1 and then prove Lemma 2 in Section F.2.

F.1 A USEFUL LEMMA USED IN THE PROOF

Lemma 9. If $r$ is defined in (93) for $\epsilon_0\in(0,\frac{1}{4})$, then with probability at least $1-d^{-10}$, we have¹²
\[
\sup_{W\in B(W^*_P,r)}\|\nabla\tilde f_n(W)-\nabla\tilde f(W)\|\le C_{13}\,K\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\sqrt{\frac{d\log n}{n}}\,(1+\xi).\tag{216}
\]

\[
\sup_{W\in B(W^*_P,r)}\|\nabla\tilde f_n(W)-\nabla\tilde f(W)\|\le\sup_{W\in B(W^*_P,r)}\cdots
\]

\[
A'_t=\Big\{\sup_{W\in B(W^*_P,r)}\frac{1}{n}\Big\|\sum_{i=1}^{n}\big[\nabla\ell(W;x_i)-\nabla\ell(W_{j(W)};x_i)\big]\Big\|\ge\frac{t}{3}\Big\},\tag{218}
\]
\[
B'_t=\Big\{\sup_{W\in B(W^*_P,r)}\Big\|\frac{1}{n}\sum_{i=1}^{n}\nabla\ell(W_{j(W)};x_i)-\mathbb{E}_{x\sim\sum_l\lambda_l\mathcal{N}(\mu_l,\Sigma_l)}\big[\nabla\ell(W_{j(W)};x)\big]\Big\|\ge\frac{t}{3}\Big\},\tag{219}
\]
\[
C'_t=\Big\{\sup_{W\in B(W^*_P,r)}\Big\|\mathbb{E}_{x\sim\sum_l\lambda_l\mathcal{N}(\mu_l,\Sigma_l)}\big[\nabla\ell(W_{j(W)};x)\big]-\mathbb{E}_{x\sim\sum_l\lambda_l\mathcal{N}(\mu_l,\Sigma_l)}\big[\nabla\ell(W;x)\big]\Big\|\ge\frac{t}{3}\Big\}.\tag{220}
\]
Upper bound of $\mathbb{P}(B'_t)$. Applying Lemma 3 in Mei et al. (2016), we have
\[
\Big\|\frac{1}{n}\sum_{i=1}^{n}\nabla\ell(W_{j(W)};x_i)-\mathbb{E}\big[\nabla\ell(W_{j(W)};x)\big]\Big\|\le 2\sup_{v\in V_{\frac{1}{2}}}\Big\langle\frac{1}{n}\sum_{i=1}^{n}\nabla\ell(W_{j(W)};x_i)-\mathbb{E}\big[\nabla\ell(W_{j(W)};x)\big],\,v\Big\rangle.\tag{221}
\]
Define $G'_i=\big\langle v,\nabla\ell(W;x_i)-\mathbb{E}[\nabla\ell(W;x)]\big\rangle$, where $v\in\mathbb{R}^{d}$. To compute $\nabla\ell(W;x_i)$, we use the derivation in Property 9, which yields the upper bound on $\zeta(W)$ in (83).


¹²$\nabla\tilde f_n(W)$ is defined as $\frac{1}{n}\sum_{i=1}^{n}\big(\nabla\ell(W,x_i,y_i)+\nu_i\big)$ in Algorithm 1.

constants $C_{14}>0$ and $C_{15}>0$. Moreover, we can obtain $\mathbb{P}(B'_t)\le\frac{\delta}{2}$ as long as
\[
t\ge C_{13}\max\Big\{\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\cdots\Big\}.
\]
For the upper bound of $\mathbb{P}(A'_t)$ and $\mathbb{P}(C'_t)$, we can first derive
\[
\begin{aligned}
\mathbb{E}_{x\sim\sum_l\lambda_l\mathcal{N}(\mu_l,\Sigma_l)}\Big[\sup_{W\ne W'\in B(W^*_P,r)}\frac{\|\nabla\ell(W;x)-\nabla\ell(W';x)\|}{\|W-W'\|_F}\Big]
&\le\mathbb{E}\Big[\sup_{W\ne W'\in B(W^*_P,r)}\frac{|\zeta(W)-\zeta(W')|\cdot\|x\|}{\|W-W'\|_F}\Big]\\
&\le\mathbb{E}\Big[\sup_{W\ne W'\in B(W^*_P,r)}\frac{\max_{1\le j,l\le K}\{|\xi_{j,l}(W'')|\}\cdot\|x\|^2\sqrt{K}\,\|W-W'\|_F}{\|W-W'\|_F}\Big]\le\cdots
\end{aligned}
\]

By choosing $\epsilon=\cdots\big/\big(\sum_{l=1}^{L}\lambda_l(\|\mu_l\|_\infty+\|\Sigma_l^{\frac{1}{2}}\|)^2\cdot ndK\big)$, $\delta=d^{-10}$ and $t=C_{13}K\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\sqrt{\frac{d\log n}{n}}$, if $n\ge C''\cdot dK\log dK$ for some constant $C''>0$, we have
\[
\mathbb{P}\Big(\sup_{W\in B(W^*_P,r)}\|\nabla\tilde f_n(W)-\nabla\tilde f(W)\|\ge C_{13}K\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\sqrt{\frac{d\log n}{n}}\Big)\le d^{-10}.
\]
Following the proof of Theorem 2 in Fu et al. (2020), we first apply Taylor's expansion of $\tilde f_n(\widehat W_n)$:
\[
\tilde f_n(\widehat W_n)=\tilde f_n(W^*_P)+\big\langle\nabla\tilde f_n(W^*_P),\mathrm{vec}(\widehat W_n-W^*_P)\big\rangle+\frac{1}{2}\,\mathrm{vec}(\widehat W_n-W^*_P)^\top\nabla^2\tilde f_n(W')\,\mathrm{vec}(\widehat W_n-W^*_P).\tag{234}
\]
Here $W'$ is on the straight line connecting $W^*_P$ and $\widehat W_n$. By the fact that $\tilde f_n(\widehat W_n)\le\tilde f_n(W^*_P)$, we have
\[
\frac{1}{2}\,\mathrm{vec}(\widehat W_n-W^*_P)^\top\nabla^2\tilde f_n(W')\,\mathrm{vec}(\widehat W_n-W^*_P)\le-\big\langle\nabla\tilde f_n(W^*_P),\mathrm{vec}(\widehat W_n-W^*_P)\big\rangle.\tag{235}
\]

and
\[
-\big\langle\nabla\tilde f_n(W^*_P),\mathrm{vec}(\widehat W_n-W^*_P)\big\rangle\le\|\nabla\tilde f_n(W^*_P)\|\cdot\|\widehat W_n-W^*_P\|_F\le\big(\|\nabla\tilde f_n(W^*_P)-\nabla\tilde f(W^*_P)\|+\|\nabla\tilde f(W^*_P)\|\big)\cdot\|\widehat W_n-W^*_P\|_F\le O\Big(K\sum_{l=1}^{L}\lambda_l(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2\sqrt{\frac{d\log n}{n}}\,(1+\xi)\Big)\|\widehat W_n-W^*_P\|_F.\tag{237}
\]

Summary of notations

$(\|\mu_l\|+\|\Sigma_l^{\frac{1}{2}}\|)^2$, we obtain
\[
\|W^{t+1}-\widehat W_n\|_F\le\Big(1-\frac{H_{\min}}{H_{\max}}\Big)\|W^t-\widehat W_n\|_F+\frac{\eta_0}{n}\Big\|\sum_{i=1}^{n}\nu_i\Big\|_F.
\]
Therefore, Algorithm 1 converges linearly to the local minimizer with an extra statistical error. By Hoeffding's inequality in Vershynin (2010) and Property 2, with probability at least $1-d^{-10}$ we can derive
\[
\|W^{t}-\widehat W_n\|_F\le\Big(1-\frac{H_{\min}}{H_{\max}}\Big)^{t}\|W^0-\widehat W_n\|_F+\frac{H_{\max}\,\eta_0}{H_{\min}}\cdots
\]

G.1 USEFUL LEMMAS IN THE PROOF

Lemma 10. Let $Q_2$ and $Q_3$ follow Definition 1. Let $S$ be a set of i.i.d. samples generated from the mixed Gaussian distribution $\sum_{l=1}^{L}\lambda_l\mathcal{N}(\mu_l,\Sigma_l)$, and let $\widehat Q_2$, $\widehat Q_3$ be the empirical versions of $Q_2$, $Q_3$ using the data set $S$, respectively. Then with probability at least $1-2n^{-\Omega(\delta_1(W^*)^4 d)}$, we have … if the mixed Gaussian distribution is not symmetric. We also have … for any arbitrary vector $\alpha\in\mathbb{R}^d$, if the mixed Gaussian distribution is symmetric.

Lemma 11. Let $U\in\mathbb{R}^{d\times K}$ be the orthogonal column span of $W^*$. Let $\alpha$ be a fixed unit vector and $\widehat U\in\mathbb{R}^{d\times K}$ denote an orthogonal matrix satisfying $\|\widehat U\widehat U^\top-UU^\top\|\le\frac{1}{4}$. Define $R_3=Q_3(\widehat U,\widehat U,\widehat U)$, where $Q_3$ is defined in Definition 1. Let $\widehat R_3$ be the empirical version of $R_3$ using the data set $S$, where each sample of $S$ is i.i.d. sampled from the mixed Gaussian distribution $\sum_{l=1}^{L}\lambda_l\mathcal{N}(\mu_l,\Sigma_l)$. Then with probability at least $1-n^{-\Omega(\delta^4(W^*))}$, we have …

Lemma 12. Let $\widehat Q_1$ be the empirical version of $Q_1$ using the data set $S$. Then with probability at least $1-2n^{-\Omega(d)}$, we have …

Lemma 13. (Zhong et al. (2017b), Lemma E.6) Let $Q_2$, $Q_3$ be defined in Definition 1 and $\widehat Q_2$, $\widehat Q_3$ be their empirical versions, respectively. Let $U\in\mathbb{R}^{d\times K}$ be the column span of $W^*$. Assume … for non-symmetric distribution cases, and … for symmetric distribution cases and any arbitrary vector $\alpha\in\mathbb{R}^d$. Then after $T=O(\log(\frac{1}{\epsilon}))$ iterations, the output of the Tensor Initialization Method, $\widehat U$, will satisfy …, which implies … if the mixed Gaussian distribution is not symmetric. Similarly, we have …, which implies … if the mixed Gaussian distribution is symmetric.

Lemma 14. (Zhong et al. (2017b), Lemma E.13) Let $U\in\mathbb{R}^{d\times K}$ be the orthogonal column span of $W^*$. Let $\widehat U\in\mathbb{R}^{d\times K}$ be an orthogonal matrix such that …

G.2 PROOF OF LEMMA 3

By the triangle inequality, we have … if the mixed Gaussian distribution is not symmetric, and (256) if the mixed Gaussian distribution is symmetric. Moreover, we have …, in which the first step is by Theorem 3 in Kuleshov et al.
(2015), and the second step is by Lemma 11. By Lemma 14, we have … Therefore, taking the union bound of the failure probabilities in Lemmas 10, 11, and 12, and by …, with probability at least $1-n^{-\Omega(\delta_1^4(W^*))}$

G.3 PROOF OF LEMMA 10

From Assumption 1, if the Gaussian Mixture Model is a symmetric probability distribution as defined in (16), then by Definition 1 we have … Following Zhong et al. (2017b), $\otimes$ is defined such that for any $v\in\mathbb{R}^{d_1}$ and …, where $z_i$ is the $i$-th column of $Z$. By Definition 1, we have …, which is the dominant term of the entire expression, and $y\le 1$. The second step is because the expression can be considered as a normalized weighted summation of $\big((x-\mu_l)\Sigma_l\big)^{\otimes 2}\big(\alpha^\top\Sigma_l^{-1}(x-\mu_l)\big)$, and $(x^\top\alpha)xx^\top$ is its dominant term. Define
\[
S_m(x)=(-1)^m\,\frac{\nabla_x^m p(x)}{p(x)},
\]
where $p(x)$ is the probability density function of the random variable $x$. From Definition 1, we can verify that … Then define $Gp_i$ …, where $\|v\|=1$; then $\mathbb{E}[Gp_i]=0$. Similar to the proofs of (202), (203), and (204) in Lemma 8, we have … Hence, similar to the derivation of (206), we have … with probability at least $1-2n^{-\Omega(\delta_1^4(W^*)d)}$. If the Gaussian Mixture Model is not a symmetric distribution as defined in (16), we have a similar result as follows: … Then define … Similar to the proofs of (202), (203) and (204) in Lemma 8, we have … Hence, similar to the derivation of (206), we have … for some constant …, with probability at least $1-2n^{-\Omega(\delta_1^4(W^*)d)}$.

G.4 PROOF OF LEMMA 11

We consider each component of $y$ … Define $T_i(x):\mathbb{R}^d\to\mathbb{R}^{K\times K\times K}$ such that … We flatten $T_i(x)$ along the first dimension to obtain the function $B_i(x):\mathbb{R}^d\to\mathbb{R}^{K\times K^2}$. Similar to the derivation of the last step of Lemma E.8 in Zhong et al. (2017b), we can obtain $\|T_i(x)\|\le\|B_i(x)\|$. By (260), we have …, where $\|v\|=1$, so $\mathbb{E}[Gr_i]=0$. Similar to the proofs of (202), (203) and (204) in Lemma 8, we have … Hence, similar to the derivation of (206), we have … with probability at least $1-2n^{-\Omega(\delta_1^4(W^*))}$.

G.5 PROOF OF LEMMA 12

From Definition 1, we have … Based on Definition 1, …, where $\|v\|=1$, so $\mathbb{E}[Gq_i]=0$. Similar to the proofs of (202), (203), and (204) in Lemma 8, we have … Hence, similar to the derivation of (206), setting … and $t=\tau^2 D_2(\Psi)\sqrt{\frac{d\log n}{n}}$, we have … with probability at least $1-2n^{-\Omega(d)}$.
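The $\sqrt{d\log n/n}$-type concentration that underlies Lemmas 10-12 can be illustrated numerically. The sketch below estimates the plain first moment $\mathbb{E}[x]$ of a two-component GMM as a stand-in for the empirical quantities $\widehat Q_1$, $\widehat Q_2$, $\widehat Q_3$ (the dimension, means, and equal weights are hypothetical choices for illustration only):

```python
import numpy as np

def gmm_mean_error(n, rng):
    """Error of the empirical mean of a 2-component GMM in R^4."""
    d = 4
    mu1, mu2 = np.ones(d), -np.ones(d)
    comp = rng.integers(0, 2, size=n)                  # equal mixing weights
    x = np.where(comp[:, None] == 0, mu1, mu2) + rng.standard_normal((n, d))
    true_mean = 0.5 * (mu1 + mu2)                      # the zero vector here
    return np.linalg.norm(x.mean(axis=0) - true_mean)

rng = np.random.default_rng(1)
errs = [np.mean([gmm_mean_error(n, rng) for _ in range(30)])
        for n in (100, 10000)]
# Increasing n by 100x shrinks the error by roughly 10x (the 1/sqrt(n) rate).
```

The same 1/√n decay, with dimension-dependent constants, is what the union-bound arguments above quantify for the higher-order tensors.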

