UNDERSTANDING WHY GENERALIZED REWEIGHTING DOES NOT IMPROVE OVER ERM

Abstract

Empirical risk minimization (ERM) is known to be non-robust in practice to distributional shift, where the training and test distributions are different. A suite of approaches, such as importance weighting and variants of distributionally robust optimization (DRO), have been proposed to solve this problem. But a line of recent work has empirically shown that these approaches do not significantly improve over ERM in real applications with distribution shift. The goal of this work is to obtain a comprehensive theoretical understanding of this intriguing phenomenon. We first posit the class of Generalized Reweighting (GRW) algorithms as a broad category of approaches that iteratively update model parameters based on iterative reweighting of the training samples. We show that when overparameterized models are trained under GRW, the resulting models are close to those obtained by ERM. We also show that adding small regularization that does not greatly affect the empirical training accuracy does not help. Together, our results show that a broad category of what we term GRW approaches are not able to achieve distributionally robust generalization. Our work thus has the following sobering takeaway: to make progress towards distributionally robust generalization, we either have to develop non-GRW approaches, or perhaps devise novel classification/regression loss functions adapted to GRW approaches.

1. INTRODUCTION

It has now been well established that empirical risk minimization (ERM) can empirically achieve high test performance on a variety of tasks, particularly with modern overparameterized models where the number of parameters is much larger than the number of training samples. This strong performance of ERM, however, has been shown to degrade under distributional shift, where the training and test distributions are different (Hovy & Søgaard, 2015; Blodgett et al., 2016; Tatman, 2017). There are two broad categories of distribution shift: domain generalization, where the test distribution contains samples from new domains that did not appear during training; and subpopulation shift, where the training set contains several subgroups and the test distribution weighs these subgroups differently, as in fair machine learning. Various approaches have been proposed to learn models that are robust to distributional shift. The most classical is importance weighting (IW) (Shimodaira, 2000; Fang et al., 2020), which reweights training samples; for subpopulation shift, these weights are typically set so that each subpopulation has the same overall weight in the training objective. The approach most widely used today is distributionally robust optimization (DRO) (Duchi & Namkoong, 2018; Hashimoto et al., 2018), which assumes that the test distribution belongs to a certain uncertainty set of distributions close to the training distribution, and trains on the worst-case distribution in that set. Many variants of DRO have been proposed and are used in practice (Sagawa et al., 2020a; Zhai et al., 2021a;b). While these approaches have been developed for the express purpose of improving over ERM under distribution shift, a line of recent work has empirically shown the negative result that, when used to train overparameterized models, these methods do not improve over ERM.
For IW, Byrd & Lipton (2019) observed that its effect under stochastic gradient descent (SGD) diminishes over training epochs, and ultimately it does not improve over ERM. For variants of DRO, Sagawa et al. (2020a) found that these methods overfit very easily, i.e. their test performance drops to the same low level as ERM after sufficiently many epochs if no regularization is applied. Gulrajani & Lopez-Paz (2021); Koh et al. (2021) compared these methods with ERM on a number of real-world applications, and found that in most cases none of them improves over ERM. This line of empirical results has also been bolstered by some recent theoretical results. Sagawa et al. (2020b) constructed a synthetic dataset where a linear model trained with IW is provably not robust to subpopulation shift. Xu et al. (2021) further proved that under gradient descent (GD) with a sufficiently small learning rate, a linear classifier trained with either IW or ERM converges to the same max-margin classifier, and thus, upon convergence, the two are no different. These previous theoretical results are limited to linear models and to specific approaches such as IW in which the sample weights are fixed during training. They are not applicable to more complex models, or to more general approaches in which the sample weights can change iteratively, which includes most DRO variants. Towards placing the empirical results on a stronger theoretical footing, we define the class of generalized reweighting (GRW) algorithms, which dynamically assign weights to the training samples and iteratively minimize the weighted average of the sample losses. By allowing the weights to vary with iterations, we cover not just static importance weighting, but also the DRO approaches outlined earlier; of course, the GRW class is much broader than just these instances.

Main contributions. We prove that GRW and ERM have (almost) equivalent implicit biases, in the sense that the points they converge to are very close to each other, under a much more general setting than those used in previous work. Thus, GRW cannot improve over ERM because it does not yield a significantly different model. We are the first to extend this line of theoretical results (i) to wide neural networks, (ii) to reweighting methods with dynamic weights, (iii) to regression tasks, and (iv) to methods with L2 regularization. We note that these extensions are technically non-trivial, as they require the result that wide neural networks can be approximated by their linearized counterparts to hold uniformly throughout the iterative process of GRW algorithms. Moreover, we fix the proof of a previous paper (Lee et al., 2019) (see Appendix E), which is a contribution of independent interest. Overall, the important takeaway is that distributionally robust generalization (DRG) cannot be directly achieved by the broad class of GRW algorithms (which includes popular approaches such as importance weighting and most DRO variants). Progress towards this important goal thus requires either going beyond GRW algorithms, or devising novel loss functions adapted to GRW approaches. In Section 6 we discuss some promising future directions, as well as the case of non-overparameterized models and early stopping. Finally, we emphasize that while the models used in our results (linear models and wide neural networks) differ from practical models, they are the models most widely used in existing theory papers, and our results based on them explain the puzzling observations made in previous empirical work, as well as provide valuable insights into how to improve distributionally robust generalization.

2. PRELIMINARIES

Let the input space be $\mathcal{X} \subseteq \mathbb{R}^d$ and the output space be $\mathcal{Y} \subseteq \mathbb{R}$. We assume that $\mathcal{X}$ is a subset of the unit $L_2$ ball of $\mathbb{R}^d$, so that any $x \in \mathcal{X}$ satisfies $\|x\|_2 \le 1$. We have a training set $\{z_i = (x_i, y_i)\}_{i=1}^n$ sampled i.i.d. from an underlying distribution $P$ over $\mathcal{X} \times \mathcal{Y}$. Denote $X = (x_1, \cdots, x_n) \in \mathbb{R}^{d \times n}$ and $Y = (y_1, \cdots, y_n) \in \mathbb{R}^n$. For any function $g: \mathcal{X} \to \mathbb{R}^m$, we overload notation and write $g(X) = (g(x_1), \cdots, g(x_n)) \in \mathbb{R}^{m \times n}$ (except when $m = 1$, in which case $g(X)$ is defined as a column vector). Let the loss function be $\ell: \mathcal{Y} \times \mathcal{Y} \to [0, 1]$. ERM trains a model by minimizing its expected risk $\mathcal{R}(f; P) = \mathbb{E}_{z \sim P}[\ell(f(x), y)]$ via minimizing the empirical risk $\hat{\mathcal{R}}(f) = \frac{1}{n} \sum_{i=1}^n \ell(f(x_i), y_i)$. Under distributional shift, the model is evaluated not on the training distribution $P$, but on a different test distribution $P_{\text{test}}$, so that we care about the expected risk $\mathcal{R}(f; P_{\text{test}})$. A large family of methods designed for such distributional shift is distributionally robust optimization (DRO), which minimizes the expected risk over the worst-case distribution $Q \ll P$ (absolutely continuous w.r.t. $P$) in a ball w.r.t. a divergence $D$ around the training distribution $P$. Specifically, DRO minimizes the expected DRO risk defined as:

$$\mathcal{R}_{D,\rho}(f; P) = \sup_{Q \ll P} \left\{ \mathbb{E}_Q[\ell(f(x), y)] : D(Q \| P) \le \rho \right\} \quad (1)$$

for some $\rho > 0$. Examples include CVaR, $\chi^2$-DRO (Hashimoto et al., 2018), and DORO (Zhai et al., 2021a), among others.

A common category of distribution shift is known as subpopulation shift. Let the data domain contain $K$ groups $\mathcal{D}_1, \cdots, \mathcal{D}_K$. The training distribution $P$ is a distribution over all groups, and the test distribution $P_{\text{test}}$ is a distribution over one of the groups. Let $P_k(z) = P(z \mid z \in \mathcal{D}_k)$ be the conditional distribution over group $k$; then $P_{\text{test}}$ can be any one of $P_1, \cdots, P_K$. The goal is to train a model $f$ that performs well over every group.
There are two common ways to achieve this goal: one is minimizing the balanced empirical risk, an unweighted average of the empirical risks over the groups; the other is minimizing the worst-group risk defined as

$$\mathcal{R}_{\max}(f; P) = \max_{k=1,\cdots,K} \mathcal{R}(f; P_k) = \max_{k=1,\cdots,K} \mathbb{E}_{z \sim P}[\ell(f(x), y) \mid z \in \mathcal{D}_k] \quad (2)$$

3. GENERALIZED REWEIGHTING (GRW)

Various methods have been proposed towards learning models that are robust to distributional shift. Instead of analyzing each of these individually, we consider a large class of what we call Generalized Reweighting (GRW) algorithms, which includes the ones mentioned earlier, and potentially many more. Loosely, GRW algorithms iteratively assign each sample a weight during training (which may vary with the iteration) and iteratively minimize the weighted average risk. Specifically, at iteration $t$, GRW assigns a weight $q_i^{(t)}$ to sample $z_i$, and minimizes the weighted empirical risk:

$$\hat{\mathcal{R}}_{q^{(t)}}(f) = \sum_{i=1}^n q_i^{(t)} \ell(f(x_i), y_i) \quad (3)$$

where $q^{(t)} = (q_1^{(t)}, \cdots, q_n^{(t)})$ and $q_1^{(t)} + \cdots + q_n^{(t)} = 1$. Static GRW assigns to each $z_i = (x_i, y_i)$ a fixed weight $q_i$ that does not change during training, i.e. $q_i^{(t)} \equiv q_i$. A classical instance is importance weighting (IW) (Shimodaira, 2000), where if $z_i \in \mathcal{D}_k$ and the size of $\mathcal{D}_k$ is $n_k$, then $q_i = (K n_k)^{-1}$. Under IW, (3) becomes the balanced empirical risk, in which each group has the same total weight. Note that ERM is also a special case of static GRW. In dynamic GRW, on the other hand, $q^{(t)}$ changes with $t$. For instance, any approach that iteratively upweights samples with high losses in order to help the model learn "hard" samples, such as DRO, is an instance of GRW. When estimating the population DRO risk $\mathcal{R}_{D,\rho}(f; P)$ in Eqn. (1), if $P$ is set to the empirical distribution over the training samples, then $Q \ll P$ implies that $Q$ is also a distribution over the training samples. Thus, DRO methods belong to the broad class of GRW algorithms.
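To make the weighted objective (3) and the IW weights concrete, here is a minimal NumPy sketch. It is illustrative only: the helper names are hypothetical, and `grw_step` instantiates one GRW gradient step for the special case of a linear model with the squared loss.

```python
import numpy as np

def iw_weights(group_ids):
    """Static importance weights: q_i = 1 / (K * n_k) for a sample in group k,
    so that every group gets the same total weight (hypothetical helper)."""
    groups, counts = np.unique(group_ids, return_counts=True)
    K = len(groups)
    n_k = dict(zip(groups, counts))
    return np.array([1.0 / (K * n_k[g]) for g in group_ids])  # sums to 1

def grw_step(theta, A, Y, q, lr):
    """One GRW gradient step on the weighted risk
    R_q(f) = sum_i q_i * 0.5 * (<theta, x_i> - y_i)^2,
    where the rows of A are the inputs x_i (linear-model sketch)."""
    residual = A @ theta - Y           # shape (n,)
    grad = A.T @ (q * residual)        # sum_i q_i (f(x_i) - y_i) x_i
    return theta - lr * grad
```

With uniform weights `q = np.ones(n) / n`, `grw_step` reduces to an ordinary ERM gradient step, illustrating why ERM is a special case of static GRW.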
There are two common ways to implement DRO. One uses Danskin's theorem and chooses $Q$ as the maximizer of $\mathbb{E}_Q[\ell(f(x), y)]$ in each epoch. The other formulates DRO as a bi-level optimization problem, in which the lower level updates the model to minimize the expected risk over $Q$, and the upper level updates $Q$ to maximize it. Both can be seen as instances of GRW. As one popular instance of the latter, Group DRO was proposed by Sagawa et al. (2020a) to minimize (2). Denote the empirical risk over group $k$ by $\hat{\mathcal{R}}_k(f)$, and the model at time $t$ by $f^{(t)}$. Group DRO iteratively sets $q_i^{(t)} = g_k^{(t)} / n_k$ for all $z_i \in \mathcal{D}_k$, where $g_k^{(t)}$ is the group weight, updated as

$$g_k^{(t)} \propto g_k^{(t-1)} \exp\left(\nu \hat{\mathcal{R}}_k(f^{(t-1)})\right) \quad (\forall k = 1, \cdots, K)$$

for some $\nu > 0$, and then normalized so that $q_1^{(t)} + \cdots + q_n^{(t)} = 1$. Sagawa et al. (2020a) then showed (in their Proposition 2) that for convex settings, the Group DRO risk of the iterates converges to the global minimum at rate $O(t^{-1/2})$ if $\nu$ is sufficiently small.
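The Group DRO weight update described above is a multiplicative (exponentiated-gradient) step on the group weights followed by normalization. A minimal sketch, with hypothetical helper names:

```python
import numpy as np

def group_dro_weights(g_prev, group_risks, nu):
    """Group DRO update: g_k proportional to g_k_prev * exp(nu * R_k),
    then normalized. Groups with higher current risk get upweighted."""
    g = g_prev * np.exp(nu * np.asarray(group_risks))
    return g / g.sum()

def sample_weights(g, group_ids, group_sizes):
    """Per-sample weights q_i = g_k / n_k for each z_i in group k;
    these sum to 1 because the g_k sum to 1."""
    return np.array([g[k] / group_sizes[k] for k in group_ids])
```

For example, if group 0 currently has higher empirical risk than group 1, the update shifts mass toward group 0, which is exactly the "upweight hard samples" behavior that makes Group DRO a dynamic GRW method.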

4. THEORETICAL RESULTS FOR REGRESSION

In this section, we study GRW for regression tasks that use the squared loss

$$\ell(\hat{y}, y) = \frac{1}{2}(\hat{y} - y)^2. \quad (5)$$

We will prove that for both linear models and sufficiently wide fully-connected neural networks, the implicit bias of GRW is equivalent to that of ERM, meaning that starting from the same initial point, GRW and ERM will converge to the same point when trained for an infinitely long time. Thus, GRW cannot improve over ERM, as it produces exactly the same model as ERM. We will further show that while regularization can affect this implicit bias, it must be large enough to significantly lower the training performance; otherwise the final model will still be similar to the unregularized ERM model.

4.1. LINEAR MODELS

We first demonstrate our result on simple linear models, to provide readers with a key intuition that we will later apply to neural networks. This intuition draws from results of Gunasekar et al. (2018). Let the linear model be denoted by $f(x) = \langle \theta, x \rangle$, where $\theta \in \mathbb{R}^d$. We consider the overparameterized setting where $d > n$. The update rule of GRW under GD is:

$$\theta^{(t+1)} = \theta^{(t)} - \eta \sum_{i=1}^n q_i^{(t)} \nabla_\theta \ell(f^{(t)}(x_i), y_i) \quad (6)$$

where $\eta > 0$ is the learning rate. For a linear model with the squared loss, this becomes

$$\theta^{(t+1)} = \theta^{(t)} - \eta \sum_{i=1}^n q_i^{(t)} x_i (f^{(t)}(x_i) - y_i) \quad (7)$$

For this training scheme, we can prove that if the training error converges to zero, then the model converges to an interpolator $\theta^*$ (s.t. $\forall i$, $\langle \theta^*, x_i \rangle = y_i$) independent of $q_i^{(t)}$ (proofs in Appendix D):

Theorem 1. If $x_1, \cdots, x_n$ are linearly independent, then under the squared loss, for any GRW such that the empirical training risk $\hat{\mathcal{R}}(f^{(t)}) \to 0$ as $t \to \infty$, $\theta^{(t)}$ converges to an interpolator $\theta^*$ that only depends on $\theta^{(0)}$ and $x_1, \cdots, x_n$, but does not depend on $q_i^{(t)}$.

The proof is based on the following key observation about the update rule (7): $\theta^{(t+1)} - \theta^{(t)}$ is a linear combination of $x_1, \cdots, x_n$ for all $t$, so $\theta^{(t)} - \theta^{(0)}$ always lies in $\mathrm{span}\{x_1, \cdots, x_n\}$, which is an $n$-dimensional linear subspace when $x_1, \cdots, x_n$ are linearly independent. By Cramer's rule, there is exactly one $\tilde{\theta}$ in this subspace such that $\langle \tilde{\theta} + \theta^{(0)}, x_i \rangle = y_i$ for all $i \in \{1, \ldots, n\}$; in other words, the parameter $\theta^* = \tilde{\theta} + \theta^{(0)}$ that interpolates all the data is unique. The proof then follows by showing that $\theta^{(t)} - \theta^{(0)}$, which lies in this subspace, also converges to interpolating the data. Moreover, this argument works for any first-order optimization method whose training risk converges to 0.
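Theorem 1 can be checked numerically: the unique interpolator in $\theta^{(0)} + \mathrm{span}\{x_1,\cdots,x_n\}$ has a closed form, and GD with weights re-drawn randomly at every step (a deliberately erratic GRW, kept bounded away from zero as in Assumption 1) still lands on it. A sketch under synthetic data (rows of `A` play the role of the inputs $x_i$):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 30                        # overparameterized: d > n
A = rng.normal(size=(n, d))         # rows are the inputs x_i (lin. indep. a.s.)
Y = rng.normal(size=n)
theta0 = rng.normal(size=d)

# The unique interpolator in theta0 + span{x_1,...,x_n} (Theorem 1's limit):
theta_star = theta0 + A.T @ np.linalg.solve(A @ A.T, Y - A @ theta0)

def run_grw(steps=30000, lr=0.01, seed=1):
    """GD on the weighted squared loss, with weights re-drawn every step
    but kept >= 1/(2n), mimicking the lower bound in Assumption 1."""
    local_rng = np.random.default_rng(seed)
    theta = theta0.copy()
    for _ in range(steps):
        q = 0.5 * local_rng.dirichlet(np.ones(n)) + 0.5 / n
        residual = A @ theta - Y
        theta -= lr * A.T @ (q * residual)
    return theta

theta_grw = run_grw(seed=1)
theta_grw2 = run_grw(seed=2)   # a different reweighting trajectory
```

Both runs converge to `theta_star` despite using entirely different weight sequences, which is the weight-independence claimed by the theorem.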
We have essentially proved the following sobering result: any GRW algorithm that achieves zero training error produces exactly the ERM model, so it does not improve over ERM. While the various distributional shift methods discussed in the introduction have been shown to satisfy the precondition of convergence to zero training error with overparameterized models and linearly independent inputs (Sagawa et al., 2020a), we provide the following theorem that shows this for the broad class of GRW methods. Specifically, we show this result for any GRW method that satisfies the following assumption, with a sufficiently small learning rate:

Assumption 1. There are constants $q_1, \cdots, q_n$ s.t. $\forall i$, $q_i^{(t)} \to q_i$ as $t \to \infty$, and $\min_i q_i = q^* > 0$.

Theorem 2. If $x_1, \cdots, x_n$ are linearly independent, then there exists $\eta_0 > 0$ such that for any GRW satisfying Assumption 1 with the squared loss, and any $\eta \le \eta_0$, the empirical training risk $\hat{\mathcal{R}}(f^{(t)}) \to 0$ as $t \to \infty$.

Finally, we use a simple experiment to demonstrate this result. The experiment is conducted on a training set of six MNIST images, five of which are of digit 0 and one of digit 1. We use a 784-dimensional linear model and run ERM, importance weighting, and Group DRO. The results, presented in Figure 1, show that the training loss of each method converges to 0, and that the gaps between the model weights of importance weighting, Group DRO, and ERM converge to 0, meaning that all three model weights converge to the same point, whose $L_2$ norm is about 0.63. Figure 1d also shows that the group weights in Group DRO empirically satisfy Assumption 1.

4.2. WIDE NEURAL NETWORKS (WIDE NNS)

Now we study sufficiently wide fully-connected neural networks, extending the analysis of Lee et al. (2019) in the neural tangent kernel (NTK) regime (Jacot et al., 2018). In particular, we study the following network:

$$h^{l+1} = \frac{W^l}{\sqrt{d_l}} x^l + \beta b^l \quad \text{and} \quad x^{l+1} = \sigma(h^{l+1}) \quad (l = 0, \cdots, L) \quad (8)$$

where $\sigma$ is a non-linear activation function, $W^l \in \mathbb{R}^{d_{l+1} \times d_l}$ and $W^L \in \mathbb{R}^{1 \times d_L}$. Here $d_0 = d$. The parameter vector $\theta$ consists of $W^0, \cdots, W^L$ and $b^0, \cdots, b^L$ ($\theta$ is the concatenation of all flattened weights and biases). The final output is $f(x) = h^{L+1}$. The network is initialized as

$$W^{l(0)}_{i,j} \sim N(0, 1), \quad b^{l(0)}_j \sim N(0, 1) \quad (l = 0, \cdots, L-1); \qquad W^{L(0)}_{i,j} = 0, \quad b^{L(0)}_j \sim N(0, 1) \quad (9)$$

We also need the following assumption on the wide neural network:

Assumption 2. $\sigma$ is differentiable everywhere, and both $\sigma$ and its first-order derivative $\sigma'$ are Lipschitz.

Difference from Jacot et al. (2018). Our initialization (9) differs from the original one in Jacot et al. (2018) in the last (output) layer, where we use the zero initialization $W^{L(0)}_{i,j} = 0$ instead of the Gaussian initialization $W^{L(0)}_{i,j} \sim N(0, 1)$. This modification permits us to accurately approximate the NN with its linearized counterpart (11), as we notice that the proofs in Lee et al. (2019) (particularly the proofs of their Theorem 2.1 and their Lemma 1 in Appendix G) are flawed. In Appendix E we explain what went wrong in their proofs and how we fix it with this modification.

Denote the neural network at time $t$ by $f^{(t)}(x) = f(x; \theta^{(t)})$, parameterized by $\theta^{(t)} \in \mathbb{R}^p$, where $p$ is the number of parameters. We use the shorthand $\nabla_\theta f^{(0)}(x) := \nabla_\theta f(x; \theta)|_{\theta = \theta^{(0)}}$. The neural tangent kernel (NTK) of this model is $\Theta^{(0)}(x, x') = \nabla_\theta f^{(0)}(x)^\top \nabla_\theta f^{(0)}(x')$, and the Gram matrix is $\Theta^{(0)} = \Theta^{(0)}(X, X) \in \mathbb{R}^{n \times n}$. For this wide NN, we still have the following NTK theorem:

Lemma 3.
If $\sigma$ is Lipschitz and $d_l \to \infty$ for $l = 1, \cdots, L$ sequentially, then $\Theta^{(0)}(x, x')$ converges in probability to a non-degenerate deterministic limiting kernel $\Theta(x, x')$.

The kernel Gram matrix $\Theta = \Theta(X, X) \in \mathbb{R}^{n \times n}$ is a positive semi-definite symmetric matrix. Denote its largest and smallest eigenvalues by $\lambda_{\max}$ and $\lambda_{\min}$. Note that $\Theta$ is non-degenerate, so we can assume that $\lambda_{\min} > 0$ (which is almost surely true when $d_L \gg n$). Then we have:

Theorem 4. Let $f^{(t)}$ be a wide fully-connected neural network that satisfies Assumption 2 and is trained by any GRW satisfying Assumption 1 with the squared loss. Let $f^{(t)}_{\mathrm{ERM}}$ be the same model trained by ERM from the same initial point. If $d_1 = \cdots = d_L = d$, $\nabla_\theta f^{(0)}(x_1), \cdots, \nabla_\theta f^{(0)}(x_n)$ are linearly independent, and $\lambda_{\min} > 0$, then there exists a constant $\eta_1 > 0$ such that: if $\eta \le \eta_1$, then for any $\delta > 0$, there exists $D > 0$ such that as long as $d \ge D$, with probability at least $(1 - \delta)$ over the random initialization, for any test point $x \in \mathbb{R}^d$ with $\|x\|_2 \le 1$, as $d \to \infty$,

$$\limsup_{t \to \infty} \left| f^{(t)}(x) - f^{(t)}_{\mathrm{ERM}}(x) \right| = O(d^{-1/4}) \to 0$$

Essentially, this theorem implies that on any test point $x$ in the unit ball, the GRW model and the ERM model produce almost the same output, so they have almost the same performance. Note that for simplicity we only prove the result for $d_1 = \cdots = d_L = d \to \infty$, but it can easily be extended to the case where $d_l / d_1 \to \alpha_l$ for $l = 2, \cdots, L$ for some constants $\alpha_2, \cdots, \alpha_L$, and $d_1 \to \infty$. The key to proving this theorem is to consider the linearized neural network of $f^{(t)}(x)$:

$$f^{(t)}_{\mathrm{lin}}(x) = f^{(0)}(x) + \langle \theta^{(t)} - \theta^{(0)}, \nabla_\theta f^{(0)}(x) \rangle \quad (11)$$

which is a linear model w.r.t. $\nabla_\theta f^{(0)}(x)$. If $\nabla_\theta f^{(0)}(x_1), \cdots, \nabla_\theta f^{(0)}(x_n)$ are linearly independent (which is almost surely true when the model is overparameterized so that $\theta$ has a very high dimension), then our key intuition tells us that the linearized NN will converge to the unique interpolator.
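The linearization in Eqn. (11) is just a first-order Taylor expansion of the network output in its parameters around $\theta^{(0)}$. A tiny sketch for a hypothetical one-hidden-layer network in NTK parameterization (with zero-initialized output layer as in the paper's initialization; biases omitted for brevity, and analytic gradients written out by hand):

```python
import numpy as np

def init_params(d, m, seed=0):
    """f(x) = v . tanh(W x / sqrt(d)) / sqrt(m); output layer starts at zero."""
    rng = np.random.default_rng(seed)
    return rng.normal(size=(m, d)), np.zeros(m)

def forward(W, v, x, d, m):
    return v @ np.tanh(W @ x / np.sqrt(d)) / np.sqrt(m)

def grads(W, v, x, d, m):
    """Analytic gradient of f w.r.t. (W, v) at the given parameters."""
    a = np.tanh(W @ x / np.sqrt(d))
    gv = a / np.sqrt(m)                                          # df/dv
    gW = np.outer(v * (1 - a**2) / np.sqrt(m), x / np.sqrt(d))   # df/dW
    return gW, gv

def f_lin(W0, v0, W, v, x, d, m):
    """Linearized network: f0(x) + <theta - theta0, grad f0(x)>, cf. Eqn. (11)."""
    gW, gv = grads(W0, v0, x, d, m)
    return (forward(W0, v0, x, d, m)
            + np.sum((W - W0) * gW) + np.dot(v - v0, gv))
```

At $\theta = \theta^{(0)}$ the two coincide exactly, and for small parameter displacements the gap is second-order, which is the finite-width analogue of what Lemma 5 below controls uniformly over training.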
Then we show that the wide NN can be approximated by its linearized counterpart uniformly throughout training, which is considerably more subtle in our case due to the GRW dynamics. Here we prove the upper bound $O(d^{-1/4})$, though in fact the bound can be improved to $O(d^{-1/2 + \epsilon})$ for any $\epsilon > 0$:

Lemma 5 (Approximation Theorem). Let $f^{(t)}$ be a wide fully-connected neural network satisfying Assumption 2, trained by any GRW satisfying Assumption 1 with the squared loss, and let $f^{(t)}_{\mathrm{lin}}$ be its linearized network trained by the same GRW (i.e. the $q_i^{(t)}$ are the same for both networks for all $i$ and $t$). Under the conditions of Theorem 4, with a sufficiently small learning rate, for any $\delta > 0$, there exist constants $D > 0$ and $C > 0$ such that as long as $d \ge D$, with probability at least $(1 - \delta)$ over the random initialization: for any test point $x \in \mathbb{R}^d$ with $\|x\|_2 \le 1$,

$$\sup_{t \ge 0} \left| f^{(t)}_{\mathrm{lin}}(x) - f^{(t)}(x) \right| \le C d^{-1/4} \quad (12)$$

This lemma essentially says that throughout the GRW training process, on any test point $x$ in the unit ball, the linearized NN and the wide NN produce almost the same output. So far, we have shown that in a regression task, for both linear models and wide NNs, GRW does not improve over ERM.

4.3. WIDE NEURAL NETWORKS, WITH L2 REGULARIZATION

Previous work such as Sagawa et al. (2020a) proposed to improve GRW by adding an L2 penalty to the objective function. In this section, we thus study adding L2 regularization to GRW algorithms:

$$\hat{\mathcal{R}}^\mu_{q^{(t)}}(f) = \sum_{i=1}^n q_i^{(t)} \ell(f(x_i), y_i) + \frac{\mu}{2} \left\| \theta - \theta^{(0)} \right\|_2^2 \quad (13)$$

At first sight, adding regularization seems a natural approach that should make a difference. Indeed, we can easily show that with L2 regularization, the GRW model and the ERM model are different, unlike in the case without regularization.
As a concrete example, when $f$ is a linear model and $\ell$ is convex and smooth, $\hat{\mathcal{R}}^\mu_{q}(f)$ with static GRW is a convex, smooth objective, so under GD with a sufficiently small learning rate, the model converges to the global minimizer (see Appendix D.1). Moreover, the global optimum $\theta^*$ satisfies $\nabla_\theta \hat{\mathcal{R}}^\mu_{q}(f(x; \theta^*)) = 0$, solving which yields

$$\theta^* = \theta^{(0)} + (XQX^\top + \mu I)^{-1} XQ(Y - f^{(0)}(X))$$

which depends on $Q = \mathrm{diag}(q_1, \cdots, q_n)$. So adding L2 regularization at least yields a result different from ERM (though whether it improves over ERM may depend on $q_1, \cdots, q_n$). However, the following result shows that the regularization must be large enough to significantly lower the training performance, or the final model will still be close to the unregularized ERM model. We still denote the largest and smallest eigenvalues of the kernel Gram matrix $\Theta$ by $\lambda_{\max}$ and $\lambda_{\min}$, and use the subscript "reg" to refer to a regularized model (trained by minimizing (13)).

Theorem 6. Suppose there exists $M_0 > 0$ s.t. $\|\nabla_\theta f^{(0)}(x)\|_2 \le M_0$ for all $\|x\|_2 \le 1$. If $\lambda_{\min} > 0$ and $\mu > 0$, then for a wide NN satisfying Assumption 2, and any GRW minimizing the squared loss with a sufficiently small learning rate $\eta$, if $d_1 = \cdots = d_L = d$, $\nabla_\theta f^{(0)}(x_1), \cdots, \nabla_\theta f^{(0)}(x_n)$ are linearly independent, and the empirical training risk of $f^{(t)}_{\mathrm{reg}}$ satisfies

$$\limsup_{t \to \infty} \hat{\mathcal{R}}(f^{(t)}_{\mathrm{reg}}) < \epsilon \quad (14)$$

for some $\epsilon > 0$, then as $d \to \infty$, with probability close to 1 over the random initialization, for any $x$ with $\|x\|_2 \le 1$ we have

$$\limsup_{t \to \infty} \left| f^{(t)}_{\mathrm{reg}}(x) - f^{(t)}_{\mathrm{ERM}}(x) \right| = O(d^{-1/4} + \sqrt{\epsilon}) \to O(\sqrt{\epsilon})$$

where $f^{(t)}_{\mathrm{reg}}$ is trained by regularized GRW and $f^{(t)}_{\mathrm{ERM}}$ by unregularized ERM from the same initial point.
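The closed form for the regularized static-GRW optimum can be checked numerically: gradient descent on the regularized weighted objective reaches the same point, and the point genuinely depends on the weights $Q$ (unlike the unregularized case). A sketch on synthetic data, with rows of `A` as the inputs $x_i$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, mu = 5, 30, 0.1
A = rng.normal(size=(n, d))
Y = rng.normal(size=n)
theta0 = rng.normal(size=d)
q = rng.dirichlet(np.ones(n))        # static GRW weights
Q = np.diag(q)

# Closed-form minimizer of the L2-regularized weighted squared loss:
theta_closed = theta0 + np.linalg.solve(A.T @ Q @ A + mu * np.eye(d),
                                        A.T @ Q @ (Y - A @ theta0))

# GD on the same convex objective converges to the same point.
theta = theta0.copy()
for _ in range(50000):
    grad = A.T @ (q * (A @ theta - Y)) + mu * (theta - theta0)
    theta -= 0.02 * grad

# Uniform (ERM) weights give a *different* regularized optimum:
q_u = np.ones(n) / n
theta_erm_reg = theta0 + np.linalg.solve(A.T @ np.diag(q_u) @ A + mu * np.eye(d),
                                         A.T @ np.diag(q_u) @ (Y - A @ theta0))
```

This illustrates the text's point: with $\mu > 0$ the solution depends on $Q$, but Theorem 6 shows the dependence washes out as the regularization becomes too small to hurt training performance.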
This theorem essentially says that if the regularization is too small and the training error is close to zero, then the regularized GRW model still produces outputs very close to those of the ERM model on any test point $x$ in the unit ball; thus, a small regularization makes little difference. The proof again starts by analyzing linearized NNs and showing that regularization does not help there (Appendix D.4.2). Then, we prove a new approximation theorem for L2-regularized GRW, connecting wide NNs to linearized NNs uniformly throughout training (Appendix D.4.1). With regularization, we no longer need Assumption 1 (which was used to prove the convergence of GRW) for the new approximation theorem, because with regularization GRW converges naturally. To empirically demonstrate this result, we run the same experiment as in Section 4.1 but with L2 regularization. The results are presented in Figure 2. When the regularization is small, the training losses still converge to 0, and the three model weights still converge to the same point. In contrast, with large regularization, the training loss does not converge to 0, and the three model weights converge to different points. This shows that the regularization must be large enough to lower the training performance in order to make a significant difference to the implicit bias.

5. THEORETICAL RESULTS FOR CLASSIFICATION

Now we consider classification, where $\mathcal{Y} = \{+1, -1\}$. The key difference from regression is that classification losses have no finite minimizers: the loss converging to zero means that the model weight "explodes" to infinity instead of converging to a finite point. We focus on the canonical logistic loss:

$$\ell(\hat{y}, y) = \log(1 + \exp(-\hat{y} y)) \quad (16)$$

5.1. LINEAR MODELS

We first consider training the linear model $f(x) = \langle \theta, x \rangle$ with GRW under gradient descent with the logistic loss. As noted earlier, in this setting Byrd & Lipton (2019) made the empirical observation that importance weighting does not improve over ERM. Xu et al. (2021) then proved that for importance weighting algorithms, as $t \to \infty$, $\|\theta^{(t)}\|_2 \to \infty$ and $\theta^{(t)} / \|\theta^{(t)}\|_2$ converges to a unit vector that does not depend on the sample weights, so IW does not improve over ERM. To extend this theoretical result to the broad class of GRW algorithms, we prove two results. First, Theorem 7 shows that for the logistic loss and any GRW algorithm satisfying the weaker assumption

Assumption 3. For all $i$, $\liminf_{t \to \infty} q_i^{(t)} > 0$,

if the training error converges to 0 and the direction of the model weight converges to a fixed unit vector, then this unit vector must be the max-margin classifier defined as

$$\hat{\theta}_{\mathrm{MM}} = \arg\max_{\theta: \|\theta\|_2 = 1} \; \min_{i=1,\cdots,n} y_i \cdot \langle \theta, x_i \rangle$$

Second, Theorem 8 shows that for any GRW satisfying Assumption 1, the training error converges to 0 and the direction of the model weight converges, so GRW does not improve over ERM.

Theorem 7. If $x_1, \cdots, x_n$ are linearly independent, then for the logistic loss and any GRW satisfying Assumption 3, if as $t \to \infty$ the empirical training risk $\hat{\mathcal{R}}(f^{(t)})$ converges to 0 and $\theta^{(t)} / \|\theta^{(t)}\|_2 \to u$ for some unit vector $u$, then $u = \hat{\theta}_{\mathrm{MM}}$.

This result is an extension of Soudry et al. (2018), and says that all GRW methods (including ERM) make the model converge to the same point $\hat{\theta}_{\mathrm{MM}}$, which does not depend on $q_i^{(t)}$. In other words, the sample weights do not affect the implicit bias. Thus, for any GRW method satisfying only the weak Assumption 3, as long as the training error converges to 0 and the model weight direction converges, GRW does not improve over ERM. We next show that any GRW satisfying Assumption 1 does have its model weight direction converge, and its training error converge to 0.

Theorem 8.
For any loss $\ell$ that is convex, $L$-smooth in $\hat{y}$, and strictly monotonically decreasing to zero as $y\hat{y} \to +\infty$, and any GRW satisfying Assumption 1, denote $F(\theta) = \sum_{i=1}^n q_i \ell(\langle \theta, x_i \rangle, y_i)$. If $x_1, \cdots, x_n$ are linearly independent, then with a sufficiently small learning rate $\eta$:

(i) $F(\theta^{(t)}) \to 0$ as $t \to \infty$;

(ii) $\|\theta^{(t)}\|_2 \to \infty$ as $t \to \infty$;

(iii) Let $\theta_R = \arg\min_\theta \{F(\theta) : \|\theta\|_2 \le R\}$. Then $\theta_R$ is unique for any $R$ such that $\min_{\|\theta\|_2 \le R} F(\theta) < \min_i q_i \ell(0, y_i)$. And if $\lim_{R \to \infty} \frac{\theta_R}{R}$ exists, then $\lim_{t \to \infty} \frac{\theta^{(t)}}{\|\theta^{(t)}\|_2}$ also exists and the two are equal.

This result is an extension of Theorem 1 of Ji et al. (2020). It is easy to show that the logistic loss satisfies the conditions of the above theorem, and that $\lim_{R \to \infty} \theta_R / R = \hat{\theta}_{\mathrm{MM}}$. Thus, Theorems 7 and 8 together imply that all GRW satisfying Assumption 1 (including ERM) have the same implicit bias (see Appendix D.5.3). We also verify these results empirically (see Appendix C).

Remark. It is impossible to extend these results to wide NNs in the manner of Theorem 4, because for a neural network, if $\|\theta^{(t)}\|_2$ goes to infinity, then $\|\nabla_\theta f\|_2$ also goes to infinity, whereas for a linear model the gradient is constant. Consequently, the gap between the neural network and its linearized counterpart "explodes" under gradient descent, so there can be no approximation theorem like Lemma 5 connecting wide NNs to their linearized counterparts. Thus, we consider regularized GRW, for which $\theta^{(t)}$ converges to a finite point and an approximation theorem exists.
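The weight-independence of the implicit bias can be illustrated numerically. For two training points of equal norm with opposite labels, the max-margin direction is known in closed form (along $x_1 - x_2$), and gradient descent on the *weighted* logistic loss approaches that same direction for different weightings. A small sketch (hedged: convergence of the direction is only logarithmic in $t$, so the runs below get close to, not exactly onto, $\hat{\theta}_{\mathrm{MM}}$):

```python
import numpy as np

# Two points of equal norm with labels +1 / -1; the max-margin direction
# is then exactly (x1 - x2) / ||x1 - x2||.
x1 = np.array([1.0, 0.5])
x2 = np.array([-0.5, -1.0])
X = np.stack([x1, x2])
y = np.array([1.0, -1.0])
theta_mm = (x1 - x2) / np.linalg.norm(x1 - x2)

def grw_logistic_direction(q, steps=300000, lr=1.0):
    """GD on the q-weighted logistic loss; returns the normalized iterate.
    Gradient of log(1+exp(-m_i)) w.r.t. theta is -y_i x_i sigmoid(-m_i)."""
    theta = np.zeros(2)
    for _ in range(steps):
        margins = y * (X @ theta)
        s = 1.0 / (1.0 + np.exp(margins))     # sigmoid(-margin)
        theta += lr * X.T @ (q * y * s)
    return theta / np.linalg.norm(theta)

dir_weighted = grw_logistic_direction(np.array([0.7, 0.3]))
dir_uniform = grw_logistic_direction(np.array([0.5, 0.5]))
```

Both directions end up close to `theta_mm` even though one run heavily upweights the first sample, consistent with Theorems 7 and 8.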

5.2. WIDE NEURAL NETWORKS, WITH L2 REGULARIZATION

Consider minimizing the regularized weighted empirical risk (13) with $\ell$ being the logistic loss. As in the regression case, with L2 regularization GRW methods have implicit biases different from ERM, for the same reasons as in Section 4.3. And similarly, we can show that for GRW methods to be sufficiently different from ERM, the regularization must be large enough to significantly lower the training performance. Specifically, in the following theorem we show that if the regularization is too small to lower the training performance, then a wide neural network trained with regularized GRW and the logistic loss will still be very close to the max-margin linearized neural network

$$f_{\mathrm{MM}}(x) = \langle \hat{\theta}_{\mathrm{MM}}, \nabla_\theta f^{(0)}(x) \rangle \quad \text{where} \quad \hat{\theta}_{\mathrm{MM}} = \arg\max_{\|\theta\|_2 = 1} \; \min_{i=1,\cdots,n} y_i \cdot \langle \theta, \nabla_\theta f^{(0)}(x_i) \rangle$$

Note that $f_{\mathrm{MM}}$ does not depend on $q_i^{(t)}$. Moreover, using the result of the previous section, we can show that a linearized neural network trained with unregularized ERM converges to $f_{\mathrm{MM}}$:

Theorem 9. Suppose there exists $M_0 > 0$ such that $\|\nabla_\theta f^{(0)}(x)\|_2 \le M_0$ for every test point $x$. For a wide NN satisfying Assumption 2, and for any GRW satisfying Assumption 1 with the logistic loss, if $d_1 = \cdots = d_L = d$, $\nabla_\theta f^{(0)}(x_1), \cdots, \nabla_\theta f^{(0)}(x_n)$ are linearly independent, and the learning rate is sufficiently small, then for any $\delta > 0$ there exists a constant $C > 0$ such that: with probability at least $(1 - \delta)$ over the random initialization, as $d \to \infty$, for any $\epsilon \in (0, \frac{1}{4})$, if the empirical training error satisfies $\limsup_{t \to \infty} \hat{\mathcal{R}}(f^{(t)}_{\mathrm{reg}}) < \epsilon$, then for any test point $x$ such that $|f_{\mathrm{MM}}(x)| > C \cdot (-\log \epsilon^2)^{-1/2}$, $f^{(t)}_{\mathrm{reg}}(x)$ has the same sign as $f_{\mathrm{MM}}(x)$ when $t$ is sufficiently large.

This result says that at any test point $x$ on which the max-margin linear classifier classifies with a margin of $\Omega((-\log \epsilon^2)^{-1/2})$, the neural network makes the same prediction. And as $\epsilon$ decreases, this confidence threshold also becomes lower.
Similar to Theorem 6, this theorem characterizes how the gap between the regularized GRW model and the unregularized ERM model scales with $\epsilon$. It justifies the empirical observation in Sagawa et al. (2020a) that with large regularization, some GRW algorithms can maintain high worst-group test performance, at the cost of a significant drop in training accuracy. On the other hand, if the regularization is small and the model achieves nearly perfect training accuracy, then its worst-group test performance will still drop significantly.

6.1. DISTRIBUTIONALLY ROBUST GENERALIZATION AND FUTURE DIRECTIONS

A large body of prior work has focused on distributionally robust optimization, but we show that these methods have (almost) the same implicit biases as ERM. In other words, distributionally robust optimization (DRO) does not necessarily achieve better distributionally robust generalization (DRG). Our results pinpoint a critical bottleneck in current distribution shift research, and we argue that a deeper understanding of DRG is crucial for developing better distributionally robust training algorithms. Here we discuss three promising future directions for improving DRG.

The first approach is data augmentation and pretraining on large datasets. Our theoretical findings suggest that the implicit bias of GRW is determined by the training samples and the initial point, but not by the sample weights. Thus, to improve DRG, we can either obtain more training samples, or start from a better initial point, as proposed in two recent papers (Wiles et al., 2022; Sagawa et al., 2022).

The second approach (for classification) is to go beyond the class of (iterative) sample reweighting based GRW algorithms, for instance via logit adjustment (Menon et al., 2021), which makes a classifier have larger margins on smaller groups to improve its generalization on those groups. An early approach by Cao et al. (2019) adds an O(n_k^{-1/4}) additive adjustment term to the logits output by the classifier. Following this spirit, Menon et al. (2021) proposed the LA-loss, which also adds an additive adjustment term to the logits. Ye et al. (2020) proposed the CDT-loss, which adds a multiplicative adjustment term by dividing the logits of different classes by different temperatures. Kini et al. (2021) proposed the VS-loss, which includes both additive and multiplicative adjustment terms, and they showed that only the multiplicative term affects the implicit bias, while the additive term only affects optimization, a fact that can be easily derived from our Theorem 8. Finally, Li et al. (2021a) proposed AutoBalance, which optimizes the adjustment terms with a bi-level optimization framework.

The third approach is to stay within the class of GRW algorithms, but to change the classification/regression loss function to one suited to GRW. A recent paper (Wang et al., 2022) showed that for linear classifiers, one can make the implicit bias of GRW depend on the sample weights by replacing the exponentially-tailed logistic loss with the following polynomially-tailed loss:

ℓ_{α,β}(ŷ, y) = { ℓ_left(ŷy),                 if ŷy < β
               { 1 / [ŷy − (β − 1)]^α,        if ŷy ≥ β        (19)

This result can be extended to any GRW satisfying Assumption 1 using our Theorem 8. The reason why loss (19) works is that it changes lim_{R→∞} θ_R / R, and the new limit depends on the sample weights.
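As a concrete illustration, here is a minimal NumPy sketch of the polynomially-tailed loss in Eqn. (19), instantiated with the choices used in our experiments (α = 1, β = 0, and ℓ_left taken to be the logistic loss shifted so the two pieces meet continuously at the threshold); the function name and vectorized form are our own:

```python
import numpy as np

def poly_tailed_loss(margin, alpha=1.0, beta=0.0):
    """Polynomially-tailed loss of Eqn. (19) as a function of the margin ŷ·y.

    For margin < beta we use a "left" part -- here the logistic loss,
    shifted so the two pieces meet continuously at margin == beta; for
    margin >= beta the loss decays polynomially as
    1 / (margin - (beta - 1))**alpha instead of exponentially.
    """
    margin = np.asarray(margin, dtype=float)
    # shift the logistic part so that left(beta) == right(beta) == 1
    shift = 1.0 - np.log1p(np.exp(-beta))
    left = np.log1p(np.exp(-margin)) + shift
    right = 1.0 / (margin - (beta - 1.0)) ** alpha
    return np.where(margin < beta, left, right)
```

With α = 1 and β = 0 the loss equals 1 at zero margin and decays like 1/(margin + 1), i.e. much more slowly than the exponential tail of the logistic loss; by Theorem 8 it is this tail behavior that changes the implicit bias.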

6.2. LIMITATIONS

Like most theory papers, our work makes some strong assumptions. The two main ones are: (i) the model is a linear model or a sufficiently wide fully-connected neural network; (ii) the model is trained for a sufficiently long time, i.e. without early stopping. Regarding (i), Chizat et al. (2019) argued that NTK neural networks fall into the "lazy training" regime and the results might not transfer to general neural networks. However, this class of neural networks has been widely studied in recent years and has provided considerable insights into the behavior of general neural networks, which are hard to analyze otherwise. Regarding (ii), on some easy tasks, when early stopping is applied, existing algorithms for distributional shift can do better than ERM (Sagawa et al., 2020a). However, as demonstrated in Gulrajani & Lopez-Paz (2021) and Koh et al. (2021), in real applications these methods still cannot significantly improve over ERM even with early stopping, so early stopping is not a universal solution. Thus, though our results inevitably rely on some strong assumptions, we believe they provide important insights into the problems of existing methods and directions for future work, which are significant contributions to the study of distributional shift problems.

7. CONCLUSION

In this work, we posit a broad class of what we call Generalized Reweighting (GRW) algorithms, including popular approaches such as importance weighting and variants of Distributionally Robust Optimization (DRO), that were designed for learning models that are robust to distributional shift. We show that when used to train overparameterized linear models or wide neural networks, even this very broad class of GRW algorithms does not improve over ERM, because the two have the same implicit biases. We also show that regularization does not help unless it is large enough to significantly lower the average training performance. Our results thus suggest that to make progress towards learning models that are robust to distributional shift, we have to either go beyond this broad class of GRW algorithms, or design new losses specifically targeted to this class.

A RELATED WORK

A.1 SUBPOPULATION SHIFT

In this work, we mainly focus on the subpopulation shift problem, which has two main applications: group fairness and long-tailed learning (learning with class imbalance). In both applications, the dataset can be divided into several subgroups, and this work considers minimizing the worst-group risk, defined as the maximum risk over any group.

Group Fairness. Group fairness refers to the scenario where the dataset contains several "groups" (such as demographic groups), and a model is considered fair if its per-group performances meet certain criteria (a "fairness" notion). Group fairness in machine learning was first studied in Hardt et al. (2016) and Zafar et al. (2017), where the model was required to perform equally well over all groups. Many papers have since proposed a variety of fairness notions, such as equal opportunity and statistical parity. Among them, Hashimoto et al. (2018) studied the worst-group risk, which is the criterion we consider in this work.

Long-tailed Learning. Long-tailed learning refers to the scenario where different classes have different sizes, and there are usually some "minority classes" with extremely few samples that are much harder to learn than the other classes. Using GRW methods such as importance weighting for long-tailed learning is a very old idea that dates back to Xie & Manski (1989). However, Byrd & Lipton (2019) recently found that the effect of importance weighting for long-tailed learning diminishes as training proceeds, which led to a line of recent work on how to improve generalization in long-tailed learning (Cao et al., 2019; Menon et al., 2021; Ye et al., 2020; Kim & Kim, 2020; Kini et al., 2021). Most of these papers share a common idea: force the model to have larger margins on smaller groups, so that it generalizes better on those groups. Self-supervised learning has also been used in long-tailed learning. For instance, Liu et al. (2022) found that self-supervised learning can achieve good performance in long-tailed learning, Wang et al. (2021) used contrastive learning, and Li et al. (2021b) used self-distillation.

A.2 DOMAIN GENERALIZATION

Domain generalization is the second common type of distribution shift. In domain generalization and the related domain adaptation, a model is tested on a different domain than the one it is trained on. The most common idea in domain generalization is invariant learning, which learns a feature extractor that is invariant across domains, usually by matching the feature distributions of different domains. Since we have no access to the target domain, in invariant learning we assume access to multiple domains in the training set, and we learn a feature extractor with small variance across these domains. Examples include CORAL (Sun & Saenko, 2016), DANN (Ganin et al., 2016), MMD (Li et al., 2018) and IRM (Arjovsky et al., 2019). However, Gulrajani & Lopez-Paz (2021) and Koh et al. (2021) empirically showed that most of these methods cannot do better than standard ERM, and Rosenfeld et al. (2021) theoretically proved that IRM cannot do better than ERM unless the number of training domains is greater than the number of independent features.

One problem with invariant learning methods is that they do not necessarily align the classes. For a source domain P and a target domain Q, even if we have successfully learned a feature extractor Φ such that Φ(P) ≈ Φ(Q), there is no guarantee that Φ maps samples in P and Q from the same class to the same location in the feature space. In the worst case, Φ can map the positive samples in P and the negative samples in Q to the same location and vice versa, in which case 100% accuracy on P means 0% accuracy on Q. The goal of class alignment is to make sure that samples from the same class are mapped together, and far away from the other classes. For example, Tzeng et al. (2015) used soft labels to align the classes, and Long et al. (2016) proposed further techniques in this direction.

A.3 WIDE NEURAL NETWORKS

Our analysis is based on the theory of wide neural networks and the neural tangent kernel (Jacot et al., 2018), and specifically we use linearized neural networks to approximate such wide neural networks following Lee et al. (2019). There is some criticism of this line of work, e.g. Chizat et al. (2019) argued that infinitely wide neural networks fall into the "lazy training" regime and the results might not transfer to general neural networks. Nonetheless, such wide neural networks have been widely studied in recent years, since they provide considerable insights into the behavior of more general neural networks, which are typically intractable to analyze otherwise.

A.4 COMPARISON WITH HU ET AL. (2018)

The prior work of Hu et al. (2018) also proved that GRW is equivalent to ERM under certain conditions. However, our work is substantially different from theirs. Hu et al. (2018) proved that in classification with the zero-one loss, GRW methods such as DRSL are equivalent to ERM, in the sense that the minimizer of the DRSL risk is also a minimizer of the average risk. This does not mean that DRSL and ERM always converge to the same point, as there could be multiple minimizers. Their result relies on the zero-one loss, which induces a monotonic linear relationship between the DRSL risk and the average risk. Moreover, their result only concerns the relationship between the two minimizers; they did not prove that DRSL and ERM actually converge to these global minima. In our results, on the other hand, we first show that without regularization, GRW and ERM converge to the exact same point, so that they have equivalent implicit biases, which is a much stronger result. We then show that even with regularization, if the regularization is not large enough, GRW still converges to a point very close to the one ERM converges to. Our results do not depend on the loss function, and work for both the squared loss for regression and the logistic loss for classification (and can be extended to other losses). Instead, they depend on the optimization method (which must be first-order, i.e. gradient-based) as well as the model architecture (linear or wide neural network), since we need to explicitly prove that both GRW and ERM reach the global minima when trained with a small learning rate for sufficiently long. In short, Hu et al. (2018) proved the equivalence between GRW and ERM under the zero-one loss via a monotonic relationship between the two risk functions, while our results focus on optimization and training dynamics, and prove that GRW and ERM have almost equivalent implicit biases.

B EXTENSION TO MULTI-DIMENSIONAL REGRESSION / MULTI-CLASS CLASSIFICATION

In our results, we assume that f : R^d → R for simplicity, but our results can easily be extended to the case where f : R^d → R^k. For most of our results, the proof consists of two major components: (i) the linearized neural network converges to some point (interpolator, max-margin classifier, etc.); (ii) the wide fully-connected neural network can be approximated by its linearized counterpart. For both components, the extension is simple and straightforward. For (i), the proof only relies on the smoothness of the objective function and the upper quadratic bound it entails, and the function is still smooth when its output becomes multi-dimensional. For (ii), we can prove that sup_t ‖f(x) − f_lin(x)‖₂ = O(d^{-1/4}) in exactly the same way. Thus, all of our results hold for multi-dimensional regression and multi-class classification. In particular, for the multi-class cross-entropy loss, using Theorem 8 we can show that under any GRW satisfying Assumption 1, the direction of the weight of a linear classifier converges to the following max-margin classifier:

θ_MM = argmax_{‖θ‖₂ = 1} min_{i=1,...,n} ( f(x_i)_{y_i} − max_{y ≠ y_i} f(x_i)_y )

which is still independent of the q_i.
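For reference, the per-sample multi-class margin appearing in the max-margin objective above can be computed as follows (a small NumPy sketch; the helper name is ours):

```python
import numpy as np

def multiclass_margin(logits, labels):
    """Per-sample margin f(x_i)_{y_i} - max_{y != y_i} f(x_i)_y."""
    n = logits.shape[0]
    true_logit = logits[np.arange(n), labels]
    others = logits.copy()
    others[np.arange(n), labels] = -np.inf   # mask out the true class
    return true_logit - others.max(axis=1)

logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 3.0,  2.5]])
labels = np.array([0, 1])
margins = multiclass_margin(logits, labels)   # array([1.5, 0.5])
```

A sample is classified correctly exactly when its margin is positive, and θ_MM maximizes the smallest of these margins over the unit sphere.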

C MORE EXPERIMENTS

We run ERM, importance weighting and Group DRO on the training set of 6 MNIST images used in Section 4.1, with both the logistic loss and the polynomially-tailed loss (Eqn. (19), with α = 1, β = 0 and ℓ_left being the logistic loss shifted to make the overall loss function continuous), for 10 million epochs (we run for many more epochs here because convergence is very slow). The results are shown in Figure 3. From the plots we can see that:
• For either loss function, the training loss of each method converges to 0.
• In contrast to the theory, which predicts that the norm of the ERM model goes to infinity and all models converge to the max-margin classifier, the weight of the ERM model gets stuck at some point, and the norms of the gaps between the normalized model weights also get stuck. The reason is that the training loss becomes so small that it underflows to zero in floating-point representation, so the gradient also becomes zero and training halts due to limited numerical precision.
• However, we can still observe a fundamental difference between the logistic loss and the polynomially-tailed loss. For the logistic loss, the norm of the gap between importance weighting (or Group DRO) and ERM converges to around 0.06 when training stops, while for the polynomially-tailed loss the norm is larger than 0.22 and keeps growing, which shows that for the polynomially-tailed loss the normalized model weights do not converge to the same point.
• For either loss, the group weights of Group DRO still empirically satisfy Assumption 1.
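The last observation can be understood from the form of the Group DRO group-weight update (Sagawa et al., 2020a): each group's weight is multiplied by the exponential of its loss and renormalized, so once every group's training loss approaches zero the update becomes trivial and the weights converge, i.e. Assumption 1 holds. A minimal sketch (the function name, step size, and synthetic loss trajectories below are ours):

```python
import numpy as np

def group_dro_weights(group_loss_seq, eta_q=0.1):
    """Exponentiated-gradient group-weight update used by Group DRO.

    At each step, each group's weight is multiplied by exp(eta_q * loss_g)
    and the weights are renormalized, up-weighting high-loss groups.
    Returns the full trajectory of the weights q^(t).
    """
    n_groups = len(group_loss_seq[0])
    q = np.full(n_groups, 1.0 / n_groups)
    history = [q.copy()]
    for losses in group_loss_seq:
        q = q * np.exp(eta_q * np.asarray(losses, dtype=float))
        q = q / q.sum()
        history.append(q.copy())
    return np.array(history)

# Synthetic per-group losses that decay to zero, as in our interpolating runs:
losses = [[1.0 * 0.5 ** t, 0.2 * 0.5 ** t] for t in range(200)]
traj = group_dro_weights(losses)
# Once the losses vanish, the multiplier is exp(0) = 1, so q^(t) converges
# (Assumption 1), with the harder group ending up with the larger weight.
```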

D PROOFS

In this paper, for any matrix A, we use ‖A‖₂ to denote its spectral norm and ‖A‖_F to denote its Frobenius norm.

D.1 BACKGROUND

ON SMOOTHNESS

A first-order differentiable function f over D is called L-smooth for L > 0 if

f(y) ≤ f(x) + ⟨∇f(x), y − x⟩ + (L/2)‖y − x‖₂²,  ∀x, y ∈ D   (21)

which is also called the upper quadratic bound. If f is second-order differentiable and D is a convex set, then f being L-smooth is equivalent to

vᵀ∇²f(x)v ≤ L,  ∀‖v‖₂ = 1, ∀x ∈ D

A classical result in convex optimization is the following:

Theorem 10. If f(x) is convex and L-smooth with a unique finite minimizer x*, and is minimized by gradient descent x_{t+1} = x_t − η∇f(x_t) starting from x_0 with learning rate η ≤ 1/L, then

f(x_T) ≤ f(x*) + (1/(ηT))‖x_0 − x*‖₂²

which also implies that x_T converges to x* as T → ∞.

D.2 PROOFS FOR SUBSECTION 4.1

D.2.1 PROOF OF THEOREM 1

Using the key intuition, the weight update rule (7) implies that θ^(t+1) − θ^(t) ∈ span{x_1, ..., x_n} for all t, which further implies that θ^(t) − θ^(0) ∈ span{x_1, ..., x_n} for all t. By Cramer's rule, in this n-dimensional subspace there exists one and only one θ* such that θ* − θ^(0) ∈ span{x_1, ..., x_n} and ⟨θ*, x_i⟩ = y_i for all i. Then we have

‖Xᵀ(θ^(t) − θ*)‖₂ = ‖(Xᵀθ^(t) − Y) − (Xᵀθ* − Y)‖₂ ≤ ‖Xᵀθ^(t) − Y‖₂ + ‖Xᵀθ* − Y‖₂ → 0   (24)

because ‖Xᵀθ − Y‖₂² = 2n R̂(f(x; θ)). On the other hand, let s_min be the smallest singular value of X. Since X is full-rank, s_min > 0, and ‖Xᵀ(θ^(t) − θ*)‖₂ ≥ s_min ‖θ^(t) − θ*‖₂. This shows that ‖θ^(t) − θ*‖₂ → 0. Thus, θ^(t) converges to this unique θ*.
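The argument in D.2.1 can be checked numerically: running plain ERM gradient descent on a small overparameterized least-squares problem, the iterates interpolate the data and never leave the affine subspace θ^(0) + span{x_1, ..., x_n}. A self-contained sketch with synthetic data (the dimensions, step size, and iteration count are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 50                       # overparameterized: d >> n
X = rng.standard_normal((n, d))    # rows are the inputs x_i
Y = rng.standard_normal(n)
theta0 = rng.standard_normal(d)
theta = theta0.copy()

eta = 0.01
for _ in range(20000):             # ERM gradient descent on the squared loss
    theta -= eta * X.T @ (X @ theta - Y) / n

# theta interpolates the training data ...
assert np.allclose(X @ theta, Y, atol=1e-6)
# ... and theta - theta0 stays in span{x_1, ..., x_n}: projecting it onto
# the row space of X leaves it unchanged, so theta is the unique
# interpolator in the affine subspace theta0 + span{x_1, ..., x_n}.
P = X.T @ np.linalg.pinv(X @ X.T) @ X   # projector onto span of the rows
assert np.allclose(P @ (theta - theta0), theta - theta0, atol=1e-6)
```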

D.2.2 PROOF OF THEOREM 2

To help readers follow the proof more easily, we first prove the result for static GRW, where q_i^(t) = q_i for all t, and then prove the result for dynamic GRW satisfying q_i^(t) → q_i as t → ∞. Static GRW. We first prove the result for all static GRW such that min_i q_i = q* > 0.

We will use the smoothness results introduced in Appendix D.1. Denote A = Σ_{i=1}^n ‖x_i‖₂². The weighted empirical risk of the linear model f(x) = ⟨θ, x⟩ is

F(θ) = Σ_{i=1}^n q_i (x_iᵀθ − y_i)²   (25)

whose Hessian is

∇²_θ F(θ) = 2 Σ_{i=1}^n q_i x_i x_iᵀ

So for any unit vector v ∈ R^d, since q_i ∈ [0, 1],

vᵀ∇²_θ F(θ)v = 2 Σ_{i=1}^n q_i (x_iᵀv)² ≤ 2 Σ_{i=1}^n q_i ‖x_i‖₂² ≤ 2A

which implies that F(θ) is 2A-smooth. Thus, we have the following upper quadratic bound: for any θ₁, θ₂ ∈ R^d,

F(θ₂) ≤ F(θ₁) + ⟨∇_θ F(θ₁), θ₂ − θ₁⟩ + A‖θ₂ − θ₁‖₂²   (28)

Denote g(θ^(t)) = Xᵀθ^(t) − Y ∈ R^n. We can see that ‖√Q g(θ^(t))‖₂² = F(θ^(t)), where √Q = diag(√q₁, ..., √q_n). Thus, ∇F(θ^(t)) = 2XQg(θ^(t)). The update rule of a static GRW with gradient descent and the squared loss is:

θ^(t+1) = θ^(t) − η Σ_{i=1}^n q_i x_i (f^(t)(x_i) − y_i) = θ^(t) − ηXQg(θ^(t))

Substituting θ₁ and θ₂ in (28) with θ^(t) and θ^(t+1) yields

F(θ^(t+1)) ≤ F(θ^(t)) − 2η g(θ^(t))ᵀQXᵀXQg(θ^(t)) + A‖ηXQg(θ^(t))‖₂²   (30)

Since x₁, ..., x_n are linearly independent, XᵀX is positive definite. Denote its smallest eigenvalue by λ_min > 0. Moreover, ‖Qg(θ^(t))‖₂ ≥ √q* ‖√Q g(θ^(t))‖₂ = √(q* F(θ^(t))), so we have g(θ^(t))ᵀQXᵀXQg(θ^(t)) ≥ q* λ_min F(θ^(t)). Thus, using ‖Qg(θ^(t))‖₂² ≤ ‖√Q g(θ^(t))‖₂² = F(θ^(t)) and ‖X‖_F² = A,

F(θ^(t+1)) ≤ F(θ^(t)) − 2ηq*λ_min F(θ^(t)) + Aη²‖X‖_F² ‖Qg(θ^(t))‖₂² ≤ (1 − 2ηq*λ_min + A²η²) F(θ^(t))

Let η₀ = q*λ_min / A². For any η ≤ η₀, we have F(θ^(t+1)) ≤ (1 − ηq*λ_min)F(θ^(t)) for all t, which implies that lim_{t→∞} F(θ^(t)) = 0. Since min_i q_i = q* > 0, the empirical training risk must also converge to 0.

Dynamic GRW. Now we prove the result for all dynamic GRW satisfying Assumption 1. By Assumption 1, for any ε > 0 there exists t_ε such that for all t ≥ t_ε and all i, q_i^(t) ∈ (q_i − ε, q_i + ε). This is because for each i there exists t_i such that q_i^(t) ∈ (q_i − ε, q_i + ε) for all t ≥ t_i, and we can take t_ε = max{t₁, ..., t_n}. Denote the largest and smallest eigenvalues of XᵀX by λ_max and λ_min; since X is full-rank, λ_min > 0. Define ε = min{q*/3, (q*λ_min)²/(12λ_max²)}, which also fixes t_ε. We still denote Q = diag(q₁, ..., q_n). For t ≥ t_ε, the update rule of a dynamic GRW with gradient descent and the squared loss is:

θ^(t+1) = θ^(t) − ηXQ_ε^(t)(Xᵀθ^(t) − Y)

where Q_ε^(t) = Q^(t), and we use the subscript to indicate that ‖Q_ε^(t) − Q‖₂ < ε. Then, note that we can rewrite Q_ε^(t) as Q_ε^(t) = Q̃_{3ε}^(t) · √Q, where Q̃_{3ε}^(t) is diagonal with entries in [√(q_i − 3ε), √(q_i + 3ε)], as long as ε ≤ q*/3. This is because q_i + ε ≤ √((q_i + 3ε)q_i) and q_i − ε ≥ √((q_i − 3ε)q_i) for all ε ≤ q_i/3, and q_i ≥ q*. Thus, we have

θ^(t+1) = θ^(t) − ηX Q̃_{3ε}^(t) √Q g(θ^(t))   (34)

Again, substituting θ₁ and θ₂ in (28) with θ^(t) and θ^(t+1) yields

F(θ^(t+1)) ≤ F(θ^(t)) − 2η g(θ^(t))ᵀQ XᵀX Q̃_{3ε}^(t) √Q g(θ^(t)) + A‖ηX Q̃_{3ε}^(t) √Q g(θ^(t))‖₂²   (35)

Then, note that

|g(θ^(t))ᵀQ XᵀX (Q̃_{3ε}^(t) − √Q) √Q g(θ^(t))| ≤ ‖Qg(θ^(t))‖₂ ‖XᵀX‖₂ ‖Q̃_{3ε}^(t) − √Q‖₂ ‖√Q g(θ^(t))‖₂ ≤ λ_max √(3ε) F(θ^(t))   (36)

where the last step uses ‖Qg(θ^(t))‖₂ ≤ ‖√Q g(θ^(t))‖₂ and the following fact: for all ε < q_i/3,

√(q_i + 3ε) − √q_i ≤ √(3ε) and √q_i − √(q_i − 3ε) ≤ √(3ε)   (37)

And as proved before, we also have

g(θ^(t))ᵀQXᵀXQg(θ^(t)) ≥ q*λ_min F(θ^(t))   (38)

Since ε ≤ (q*λ_min)²/(12λ_max²), i.e. λ_max√(3ε) ≤ q*λ_min/2, we have

g(θ^(t))ᵀQ XᵀX Q̃_{3ε}^(t) √Q g(θ^(t)) ≥ (q*λ_min − λ_max√(3ε)) F(θ^(t)) ≥ (1/2) q*λ_min F(θ^(t))   (39)

Thus,

F(θ^(t+1)) ≤ F(θ^(t)) − ηq*λ_min F(θ^(t)) + Aη²‖X Q̃_{3ε}^(t)‖₂² ‖√Q g(θ^(t))‖₂² ≤ (1 − ηq*λ_min + A²η²(1 + 3ε)) F(θ^(t)) ≤ (1 − ηq*λ_min + 2A²η²) F(θ^(t))

for all ε ≤ 1/3. Let η₀ = q*λ_min/(4A²). For any η ≤ η₀, we have F(θ^(t+1)) ≤ (1 − ηq*λ_min/2)F(θ^(t)) for all t ≥ t_ε, which implies that lim_{t→∞} F(θ^(t)) = 0. Thus, the empirical training risk converges to 0.
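The conclusion of Theorems 1 and 2 is easy to verify numerically: static GRW with very different sample weights drives the weighted risk to zero and ends at the same interpolator as ERM. A small sketch with synthetic data (the weights, step size, and step counts below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 4, 30                      # overparameterized linear regression
X = rng.standard_normal((n, d))
Y = rng.standard_normal(n)
theta0 = rng.standard_normal(d)

def run_static_grw(q, steps=50000, eta=0.005):
    """Gradient descent on F(theta) = sum_i q_i (x_i^T theta - y_i)^2."""
    theta = theta0.copy()
    for _ in range(steps):
        theta = theta - eta * X.T @ (q * (X @ theta - Y))
    return theta

erm    = run_static_grw(np.full(n, 1.0 / n))             # uniform weights
skewed = run_static_grw(np.array([0.7, 0.1, 0.1, 0.1]))  # heavy reweighting

# Both drive the training risk to zero and, as the theory predicts,
# converge to the SAME interpolator regardless of the sample weights q_i.
assert np.allclose(X @ erm, Y, atol=1e-5)
assert np.allclose(X @ skewed, Y, atol=1e-5)
assert np.allclose(erm, skewed, atol=1e-4)
```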

D.3 PROOFS FOR SUBSECTION 4.2

D.3.1 PROOF OF LEMMA 3

Note that the first L layers (except the output layer) of the original NTK formulation and our new formulation are the same, so we still have the following proposition:

Proposition 11 (Proposition 1 in Jacot et al. (2018)). If σ is Lipschitz and d_l → ∞ for l = 1, ..., L sequentially, then for all l = 1, ..., L, the distribution of a single element of h^l converges in probability to a zero-mean Gaussian process with covariance Σ^l defined recursively by:

Σ¹(x, x′) = (1/d₀) xᵀx′ + β²
Σˡ(x, x′) = E_f[σ(f(x))σ(f(x′))] + β²   (41)

where f is sampled from a zero-mean Gaussian process with covariance Σ^{l−1}.

Now we show that for an infinitely wide neural network with L ≥ 1 hidden layers, Θ^(0) converges in probability to the following non-degenerate deterministic limiting kernel:

Θ(x, x′) = E_{f∼Σ^L}[σ(f(x))σ(f(x′))] + β²   (42)

Consider the output layer h^{L+1} = (1/√d_L) W^L σ(h^L) + βb^L. We can see that for any parameter θ_i before the output layer,

∇_{θ_i} h^{L+1} = diag(σ̇(h^L)) (W^L/√d_L) ∇_{θ_i} h^L → 0

And for W^L and b^L, we have

∇_{W^L} h^{L+1} = (1/√d_L) σ(h^L) and ∇_{b^L} h^{L+1} = β   (44)

Then we can obtain (42) by the law of large numbers.

D.3.2 PROOF OF LEMMA 5

We will use the following shorthand in the proof:

g(θ^(t)) = f^(t)(X) − Y,  J(θ^(t)) = ∇_θ f(X; θ^(t)) ∈ R^{p×n},  Θ^(t) = J(θ^(t))ᵀJ(θ^(t))

For any ε > 0, there exists t_ε such that for all t ≥ t_ε and all i, q_i^(t) ∈ (q_i − ε, q_i + ε). As in (34), we can rewrite Q_ε^(t) = Q^(t) = Q̃_{3ε}^(t) · √Q, where Q = diag(q₁, ..., q_n). The update rule of a GRW with gradient descent and the squared loss for the wide neural network is:

θ^(t+1) = θ^(t) − ηJ(θ^(t))Q^(t)g(θ^(t))

and for t ≥ t_ε it can be rewritten as

θ^(t+1) = θ^(t) − ηJ(θ^(t)) Q̃_{3ε}^(t) √Q g(θ^(t))   (47)

First, we will prove the following theorem:

Theorem 12. There exist constants M > 0 and ε₀ > 0 such that for all ε ∈ (0, ε₀], η ≤ η* and any δ > 0, there exist R₀ > 0, D > 0 and B > 1 such that for any d ≥ D, the following (i) and (ii) hold with probability at least (1 − δ) over random initialization when applying gradient descent with learning rate η:

(i) For all t ≤ t_ε,

‖g(θ^(t))‖₂ ≤ Bᵗ R₀   (48)

Σ_{j=1}^t ‖θ^(j) − θ^(j−1)‖₂ ≤ ηM R₀ Σ_{j=1}^t B^{j−1} < M Bᵗ R₀ / (B − 1)   (49)

(ii) For all t ≥ t_ε,

‖√Q g(θ^(t))‖₂ ≤ (1 − ηq*λ_min/3)^{t−t_ε} B^{t_ε} R₀   (50)

Σ_{j=t_ε+1}^t ‖θ^(j) − θ^(j−1)‖₂ ≤ η√(1+3ε) M B^{t_ε} R₀ Σ_{j=t_ε+1}^t (1 − ηq*λ_min/3)^{j−t_ε−1} < 3√(1+3ε) M B^{t_ε} R₀ / (q*λ_min)   (51)

Proof. The proof is based on the following lemma:

Lemma 13 (Local Lipschitzness of the Jacobian). Under Assumption 2, there is a constant M > 0 such that for any C₀ > 0 and any δ > 0, there exists D such that: if d ≥ D, then with probability at least (1 − δ) over random initialization, for any x with ‖x‖₂ ≤ 1 and all θ, θ̃ ∈ B(θ^(0), C₀),

‖∇_θ f(x; θ) − ∇_θ f(x; θ̃)‖₂ ≤ (M/d^{1/4}) ‖θ − θ̃‖₂
‖∇_θ f(x; θ)‖₂ ≤ M
‖J(θ) − J(θ̃)‖_F ≤ (M/d^{1/4}) ‖θ − θ̃‖₂
‖J(θ)‖_F ≤ M   (52)

where B(θ^(0), R) = {θ : ‖θ − θ^(0)‖₂ < R}. The proof of this lemma can be found in Appendix D.3.3.

Note that for any x, f^(0)(x) = βb^L where b^L is sampled from the standard Gaussian distribution. Thus, for any δ > 0, there exists a constant R₀ such that with probability at least (1 − δ/3) over random initialization,

‖g(θ^(0))‖₂ < R₀   (53)

And by Proposition 3, there exists D₂ ≥ 0 such that for any d ≥ D₂, with probability at least (1 − δ/3),

‖Θ − Θ^(0)‖_F ≤ q*λ_min/3   (54)

Let M be the constant in Lemma 13. Let ε₀ = (q*λ_min)²/(108M⁴). Let B = 1 + η*M², and C₀ = M B^{t_ε} R₀/(B−1) + 3√(1+3ε) M B^{t_ε} R₀/(q*λ_min). By Lemma 13, there exists D₁ > 0 such that for any d ≥ D₁, with probability at least (1 − δ/3), (52) is true for all θ, θ̃ ∈ B(θ^(0), C₀). By the union bound, with probability at least (1 − δ), (52), (53) and (54) are all true. Now we assume all of them hold, and prove (48) and (49) by induction. (48) is true for t = 0 due to (53), and (49) is trivially true for t = 0. Suppose (48) and (49) are true for t; then for t + 1 we have

‖θ^(t+1) − θ^(t)‖₂ ≤ η‖J(θ^(t))Q^(t)‖₂‖g(θ^(t))‖₂ ≤ η‖J(θ^(t))Q^(t)‖_F‖g(θ^(t))‖₂ ≤ η‖J(θ^(t))‖_F‖g(θ^(t))‖₂ ≤ MηBᵗR₀   (55)

so (49) is also true for t + 1. And we also have

‖g(θ^(t+1))‖₂ = ‖g(θ^(t)) − ηJ(θ̃^(t))ᵀJ(θ^(t))Q^(t)g(θ^(t))‖₂ ≤ ‖I − ηJ(θ̃^(t))ᵀJ(θ^(t))Q^(t)‖₂‖g(θ^(t))‖₂ ≤ (1 + η‖J(θ̃^(t))‖_F‖J(θ^(t))‖_F)‖g(θ^(t))‖₂ ≤ (1 + η*M²)‖g(θ^(t))‖₂ ≤ B^{t+1}R₀   (56)

where θ̃^(t) is some linear interpolation between θ^(t) and θ^(t+1). Therefore, (48) and (49) are true for all t ≤ t_ε, which implies ‖√Q g(θ^{(t_ε)})‖₂ ≤ ‖g(θ^{(t_ε)})‖₂ ≤ B^{t_ε}R₀, so (50) is true for t = t_ε, and (51) is obviously true for t = t_ε.

Now, let us prove (ii) by induction. Note that for t ≥ t_ε we have the alternative update rule (47). If (50) and (51) are true for t, then for t + 1 there is

‖θ^(t+1) − θ^(t)‖₂ ≤ η‖J(θ^(t))Q̃_{3ε}^(t)‖₂‖√Qg(θ^(t))‖₂ ≤ η‖J(θ^(t))Q̃_{3ε}^(t)‖_F‖√Qg(θ^(t))‖₂ ≤ η√(1+3ε)‖J(θ^(t))‖_F‖√Qg(θ^(t))‖₂ ≤ Mη√(1+3ε)(1 − ηq*λ_min/3)^{t−t_ε}B^{t_ε}R₀   (57)

So (51) is true for t + 1. And we also have

‖√Qg(θ^(t+1))‖₂ = ‖√Qg(θ^(t)) − η√QJ(θ̃^(t))ᵀJ(θ^(t))Q^(t)g(θ^(t))‖₂ ≤ ‖I − η√QJ(θ̃^(t))ᵀJ(θ^(t))Q̃_{3ε}^(t)‖₂‖√Qg(θ^(t))‖₂ ≤ ‖I − η√QJ(θ̃^(t))ᵀJ(θ^(t))Q̃_{3ε}^(t)‖₂ (1 − ηq*λ_min/3)^{t−t_ε} B^{t_ε} R₀   (58)

where θ̃^(t) is some linear interpolation between θ^(t) and θ^(t+1). Now we prove that

‖I − η√QJ(θ̃^(t))ᵀJ(θ^(t))Q̃_{3ε}^(t)‖₂ ≤ 1 − ηq*λ_min/3   (59)

For any unit vector v ∈ R^n, we have

vᵀ(I − η√QΘ√Q)v = 1 − ηvᵀ√QΘ√Qv   (60)

and ‖√Qv‖₂ ∈ [√q*, 1], so for any η ≤ η*, vᵀ(I − η√QΘ√Q)v ∈ [0, 1 − ηλ_min q*], which implies ‖I − η√QΘ√Q‖₂ ≤ 1 − ηλ_min q*. Thus,

‖I − η√QJ(θ̃^(t))ᵀJ(θ^(t))√Q‖₂ ≤ ‖I − η√QΘ√Q‖₂ + η‖√Q(Θ − Θ^(0))√Q‖₂ + η‖√Q(J(θ^(0))ᵀJ(θ^(0)) − J(θ̃^(t))ᵀJ(θ^(t)))√Q‖₂
≤ 1 − ηλ_min q* + η‖Θ − Θ^(0)‖_F + η‖J(θ^(0))ᵀJ(θ^(0)) − J(θ̃^(t))ᵀJ(θ^(t))‖_F
≤ 1 − ηλ_min q* + ηq*λ_min/3 + η(M²/d^{1/4})(‖θ^(t) − θ^(0)‖₂ + ‖θ̃^(t) − θ^(0)‖₂)
≤ 1 − ηq*λ_min/2   (61)

for all d ≥ max{D₁, D₂, (12M²C₀/(q*λ_min))⁴}, which implies

‖I − η√QJ(θ̃^(t))ᵀJ(θ^(t))Q̃_{3ε}^(t)‖₂ ≤ 1 − ηq*λ_min/2 + η‖√QJ(θ̃^(t))ᵀJ(θ^(t))(Q̃_{3ε}^(t) − √Q)‖₂ ≤ 1 − ηq*λ_min/2 + ηM²√(3ε) ≤ 1 − ηq*λ_min/3  (due to (37))   (62)

for all ε ≤ ε₀. Thus, (50) is also true for t + 1. In conclusion, (50) and (51) are true with probability at least (1 − δ) for all d ≥ D = max{D₁, D₂, (12M²C₀/(q*λ_min))⁴}.

Returning to the proof of Lemma 5. Choose and fix an ε such that ε < min{ε₀, (1/3)(q*λ_min/(3λ_max + q*λ_min))²}, where ε₀ is defined by Theorem 12. Then, t_ε is also fixed. There exists D ≥ 0 such that for any d ≥ D, with probability at least (1 − δ), Theorem 12 and Lemma 13 are true and

‖Θ − Θ^(0)‖_F ≤ q*λ_min/3   (63)

which immediately implies

‖Θ^(0)‖₂ ≤ ‖Θ‖₂ + ‖Θ − Θ^(0)‖_F ≤ λ_max + q*λ_min/3   (64)

We still denote B = 1 + η*M² and C₀ = MB^{t_ε}R₀/(B−1) + 3√(1+3ε)MB^{t_ε}R₀/(q*λ_min). Theorem 12 ensures that θ^(t) ∈ B(θ^(0), C₀) for all t. Then we have

‖I − η√QΘ^(0)√Q‖₂ ≤ ‖I − η√QΘ√Q‖₂ + η‖√Q(Θ − Θ^(0))√Q‖₂ ≤ 1 − ηλ_min q* + ηq*λ_min/3 = 1 − 2ηq*λ_min/3   (65)

so it follows that

‖I − η√QΘ^(0)Q̃_{3ε}^(t)‖₂ ≤ ‖I − η√QΘ^(0)√Q‖₂ + η‖√QΘ^(0)(Q̃_{3ε}^(t) − √Q)‖₂ ≤ 1 − 2ηq*λ_min/3 + η(λ_max + q*λ_min/3)√(3ε)   (66)

Thus, for all ε < (1/3)(q*λ_min/(3λ_max + q*λ_min))², there is

‖I − η√QΘ^(0)Q̃_{3ε}^(t)‖₂ ≤ 1 − ηq*λ_min/3   (67)

The update rule of the GRW for the linearized neural network is:

θ_lin^(t+1) = θ_lin^(t) − ηJ(θ^(0))Q^(t)g_lin(θ^(t))   (68)

where we use the subscript "lin" to denote the linearized neural network, and with a slight abuse of notation denote g_lin(θ^(t)) = g(θ_lin^(t)).

First, let us consider the training data X. Denote Δ_t = g_lin(θ^(t)) − g(θ^(t)). We have

g_lin(θ^(t+1)) − g_lin(θ^(t)) = −ηJ(θ^(0))ᵀJ(θ^(0))Q^(t)g_lin(θ^(t))
g(θ^(t+1)) − g(θ^(t)) = −ηJ(θ̃^(t))ᵀJ(θ^(t))Q^(t)g(θ^(t))   (69)

where θ̃^(t) is some linear interpolation between θ^(t) and θ^(t+1). Thus,

Δ_{t+1} − Δ_t = η(J(θ̃^(t))ᵀJ(θ^(t)) − J(θ^(0))ᵀJ(θ^(0)))Q^(t)g(θ^(t)) − ηJ(θ^(0))ᵀJ(θ^(0))Q^(t)Δ_t   (70)

By Lemma 13, we have

‖J(θ̃^(t))ᵀJ(θ^(t)) − J(θ^(0))ᵀJ(θ^(0))‖_F ≤ ‖J(θ̃^(t)) − J(θ^(0))‖_F‖J(θ^(t))‖_F + ‖J(θ^(0))‖_F‖J(θ^(t)) − J(θ^(0))‖_F ≤ 2M²C₀ d^{-1/4}   (71)

which implies that for all t < t_ε,

‖Δ_{t+1}‖₂ ≤ ‖I − ηJ(θ^(0))ᵀJ(θ^(0))Q^(t)‖₂‖Δ_t‖₂ + η‖J(θ̃^(t))ᵀJ(θ^(t)) − J(θ^(0))ᵀJ(θ^(0))‖_F‖g(θ^(t))‖₂ ≤ (1 + ηM²)‖Δ_t‖₂ + 2ηM²C₀BᵗR₀d^{-1/4} ≤ B‖Δ_t‖₂ + 2ηM²C₀BᵗR₀d^{-1/4}   (72)

Therefore, we have

B^{−(t+1)}‖Δ_{t+1}‖₂ ≤ B^{−t}‖Δ_t‖₂ + 2ηM²C₀B^{−1}R₀d^{-1/4}   (73)

Since Δ₀ = 0, it follows that for all t ≤ t_ε,

‖Δ_t‖₂ ≤ 2tηM²C₀B^{t−1}R₀d^{-1/4}   (74)

and in particular we have

‖√QΔ_t‖₂ ≤ ‖Δ_t‖₂ ≤ 2t_ε ηM²C₀B^{t_ε−1}R₀d^{-1/4}   (75)

For t ≥ t_ε, we have the alternative update rule (47). Thus,

√QΔ_{t+1} − √QΔ_t = η√Q(J(θ̃^(t))ᵀJ(θ^(t)) − J(θ^(0))ᵀJ(θ^(0)))Q̃_{3ε}^(t)√Qg(θ^(t)) − η√QJ(θ^(0))ᵀJ(θ^(0))Q̃_{3ε}^(t)√QΔ_t   (76)

Let A = I − η√QJ(θ^(0))ᵀJ(θ^(0))Q̃_{3ε}^(t) = I − η√QΘ^(0)Q̃_{3ε}^(t). Then, we have

√QΔ_{t+1} = A√QΔ_t + η√Q(J(θ̃^(t))ᵀJ(θ^(t)) − J(θ^(0))ᵀJ(θ^(0)))Q̃_{3ε}^(t)√Qg(θ^(t))   (77)

Let γ = 1 − ηq*λ_min/3 < 1. Combining with Theorem 12 and (67), the above leads to

‖√QΔ_{t+1}‖₂ ≤ ‖A‖₂‖√QΔ_t‖₂ + η‖√Q(J(θ̃^(t))ᵀJ(θ^(t)) − J(θ^(0))ᵀJ(θ^(0)))Q̃_{3ε}^(t)‖₂‖√Qg(θ^(t))‖₂ ≤ γ‖√QΔ_t‖₂ + η‖J(θ̃^(t))ᵀJ(θ^(t)) − J(θ^(0))ᵀJ(θ^(0))‖_F √(1+3ε) γ^{t−t_ε}B^{t_ε}R₀ ≤ γ‖√QΔ_t‖₂ + 2ηM²C₀√(1+3ε) γ^{t−t_ε}B^{t_ε}R₀ d^{-1/4}   (78)

This implies that

γ^{−(t+1)}‖√QΔ_{t+1}‖₂ ≤ γ^{−t}‖√QΔ_t‖₂ + 2ηM²C₀√(1+3ε) γ^{−1−t_ε}B^{t_ε}R₀d^{-1/4}   (79)

Combining with (75), it implies that for all t ≥ t_ε,

‖√QΔ_t‖₂ ≤ 2γ^{t−t_ε}ηM²C₀B^{t_ε}R₀ ( t_ε B^{−1} + √(1+3ε)γ^{−1}(t − t_ε) ) d^{-1/4}   (80)

Next, we consider an arbitrary test point x such that ‖x‖₂ ≤ 1. Denote δ_t = f_lin^(t)(x) − f^(t)(x). Then we have

f_lin^(t+1)(x) − f_lin^(t)(x) = −η∇_θf(x; θ^(0))ᵀJ(θ^(0))Q^(t)g_lin(θ^(t))
f^(t+1)(x) − f^(t)(x) = −η∇_θf(x; θ̃^(t))ᵀJ(θ^(t))Q^(t)g(θ^(t))   (81)

which yields

δ_{t+1} − δ_t = η(∇_θf(x; θ̃^(t))ᵀJ(θ^(t)) − ∇_θf(x; θ^(0))ᵀJ(θ^(0)))Q^(t)g(θ^(t)) − η∇_θf(x; θ^(0))ᵀJ(θ^(0))Q^(t)Δ_t   (82)

For t ≤ t_ε, we have

|δ_t| ≤ η Σ_{s=0}^{t−1} ‖∇_θf(x; θ̃^(s))ᵀJ(θ^(s)) − ∇_θf(x; θ^(0))ᵀJ(θ^(0))‖₂‖g(θ^(s))‖₂ + η Σ_{s=0}^{t−1} ‖∇_θf(x; θ^(0))‖₂‖J(θ^(0))‖_F‖Δ_s‖₂ ≤ 2ηM²C₀d^{-1/4} Σ_{s=0}^{t−1} BˢR₀ + ηM² Σ_{s=0}^{t−1} (2sηM²C₀B^{s−1}R₀d^{-1/4})   (83)

So we can see that there exists a constant C₁ such that |δ_t| ≤ C₁d^{-1/4} for all t ≤ t_ε. Then, for t > t_ε, we have

|δ_t| − |δ_{t_ε}| ≤ η Σ_{s=t_ε}^{t−1} ‖(∇_θf(x; θ̃^(s))ᵀJ(θ^(s)) − ∇_θf(x; θ^(0))ᵀJ(θ^(0)))Q̃_{3ε}^(s)‖₂‖√Qg(θ^(s))‖₂ + η Σ_{s=t_ε}^{t−1} ‖∇_θf(x; θ^(0))ᵀJ(θ^(0))Q̃_{3ε}^(s)‖₂‖√QΔ_s‖₂ ≤ 2ηM²C₀d^{-1/4}√(1+3ε) Σ_{s=t_ε}^{t−1} γ^{s−t_ε}B^{t_ε}R₀ + ηM²√(1+3ε) Σ_{s=t_ε}^{t−1} 2γ^{s−t_ε}ηM²C₀B^{t_ε}R₀ ( t_ε B^{−1} + √(1+3ε)γ^{−1}(s − t_ε) ) d^{-1/4}   (84)

Note that Σ_{t=0}^∞ tγᵗ is finite as long as γ ∈ (0, 1). Therefore, there is a constant C such that for any t, |δ_t| ≤ Cd^{-1/4} with probability at least (1 − δ) for any d ≥ D.

D.3.3 PROOF OF LEMMA 13

We will use the following theorem on the singular values of random Gaussian matrices:

Theorem 14 (Corollary 5.35 in Vershynin (2010)). If A ∈ R^{p×q} is a random matrix whose entries are independent standard normal random variables, then for every t ≥ 0, with probability at least 1 − 2exp(−t²/2),

√p − √q − t ≤ s_min(A) ≤ s_max(A) ≤ √p + √q + t

By this theorem, and noting that W^L is a vector, we can see that for any δ, there exist D > 0 and M₁ > 0 such that if d ≥ D, then with probability at least (1 − δ), for all θ ∈ B(θ^(0), C₀), we have

‖W^l‖₂ ≤ 3√d (∀ 0 ≤ l ≤ L−1) and ‖W^L‖₂ ≤ C₀ ≤ 3d^{1/4}   (86)

as well as

‖βb^l‖₂ ≤ M₁√d (∀ l = 0, ..., L)   (87)

Now we assume that (86) and (87) are true. Then, for any x such that ‖x‖₂ ≤ 1,

‖h¹‖₂ = ‖(1/√d₀)W⁰x + βb⁰‖₂ ≤ (1/√d₀)‖W⁰‖₂‖x‖₂ + ‖βb⁰‖₂ ≤ (3/√d₀ + M₁)√d
‖h^{l+1}‖₂ = ‖(1/√d)W^l x^l + βb^l‖₂ ≤ (1/√d)‖W^l‖₂‖x^l‖₂ + ‖βb^l‖₂ (∀ l ≥ 1)
‖x^l‖₂ = ‖σ(h^l) − σ(0_l) + σ(0_l)‖₂ ≤ L₀‖h^l‖₂ + |σ(0)|√d (∀ l ≥ 1)

where L₀ is the Lipschitz constant of σ and σ(0_l) = (σ(0), ..., σ(0)) ∈ R^{d_l}. By induction, there exists an M₂ > 0 such that ‖x^l‖₂ ≤ M₂√d and ‖h^l‖₂ ≤ M₂√d for all l = 1, ..., L.

Denote α^l = ∇_{h^l} f(x) = ∇_{h^l} h^{L+1}. For all l = 1, ..., L, we have

α^l = diag(σ̇(h^l)) (W^l/√d)ᵀ α^{l+1}

where |σ̇(x)| ≤ L₀ for all x ∈ R since σ is L₀-Lipschitz, α^{L+1} = 1, and ‖α^L‖₂ = ‖diag(σ̇(h^L)) (W^L/√d)ᵀ‖₂ ≤ 3L₀ d^{-1/4}. Then, we can easily prove by induction that there exists an M₃ > 1 such that ‖α^l‖₂ ≤ M₃ d^{-1/4} for all l = 1, ..., L (note that this is not true for L + 1 because α^{L+1} = 1). For l = 0, ∇_{W⁰}f(x) = (1/√d₀) x⁰(α¹)ᵀ, so ‖∇_{W⁰}f(x)‖₂ ≤ (1/√d₀)‖x⁰‖₂‖α¹‖₂ ≤ (1/√d₀) M₃ d^{-1/4}. And for any l = 1, ..., L, ∇_{W^l}f(x) = (1/√d) x^l(α^{l+1})ᵀ, so ‖∇_{W^l}f(x)‖₂ ≤ (1/√d)‖x^l‖₂‖α^{l+1}‖₂ ≤ M₂M₃. (Note that if M₃ > 1, then ‖α^{L+1}‖₂ ≤ M₃; and since d ≥ 1, there is ‖α^l‖₂ ≤ M₃ for l ≤ L.) Moreover, for l = 0, ..., L, ∇_{b^l}f(x) = βα^{l+1}, so ‖∇_{b^l}f(x)‖₂ ≤ βM₃. Thus, if (86) and (87) are true, then there exists an M₄ > 0 such that ‖∇_θf(x)‖₂ ≤ M₄/√n. And since ‖x_i‖₂ ≤ 1 for all i, we have ‖J(θ)‖_F ≤ M₄.

Next, we consider the difference in ∇_θf(x) between θ and θ̃. Let f̃, W̃, b̃, x̃, h̃, α̃ be the function and the values corresponding to θ̃. There is

‖h¹ − h̃¹‖₂ = ‖(1/√d₀)(W⁰ − W̃⁰)x + β(b⁰ − b̃⁰)‖₂ ≤ (1/√d₀)‖W⁰ − W̃⁰‖₂‖x‖₂ + β‖b⁰ − b̃⁰‖₂ ≤ (1/√d₀ + β)‖θ − θ̃‖₂
‖h^{l+1} − h̃^{l+1}‖₂ = ‖(1/√d)W^l(x^l − x̃^l) + (1/√d)(W^l − W̃^l)x̃^l + β(b^l − b̃^l)‖₂ ≤ (1/√d)‖W^l‖₂‖x^l − x̃^l‖₂ + (1/√d)‖W^l − W̃^l‖₂‖x̃^l‖₂ + β‖b^l − b̃^l‖₂ ≤ 3‖x^l − x̃^l‖₂ + (M₂ + β)‖θ − θ̃‖₂ (∀ l ≥ 1)
‖x^l − x̃^l‖₂ = ‖σ(h^l) − σ(h̃^l)‖₂ ≤ L₀‖h^l − h̃^l‖₂ (∀ l ≥ 1)

By induction, there exists an M₅ > 0 such that ‖x^l − x̃^l‖₂ ≤ M₅‖θ − θ̃‖₂ for all l. For α^l, we have α^{L+1} = α̃^{L+1} = 1, and for all l ≥ 1,

‖α^l − α̃^l‖₂ = ‖diag(σ̇(h^l))(W^l/√d)ᵀα^{l+1} − diag(σ̇(h̃^l))(W̃^l/√d)ᵀα̃^{l+1}‖₂
≤ ‖diag(σ̇(h^l))(W^l/√d)ᵀ(α^{l+1} − α̃^{l+1})‖₂ + ‖diag(σ̇(h^l))((W^l − W̃^l)/√d)ᵀα̃^{l+1}‖₂ + ‖diag(σ̇(h^l) − σ̇(h̃^l))(W̃^l/√d)ᵀα̃^{l+1}‖₂
≤ 3L₀‖α^{l+1} − α̃^{l+1}‖₂ + M₃L₀d^{-1/2}‖θ − θ̃‖₂ + 3M₃M₅L₁d^{-1/4}‖θ − θ̃‖₂   (90)

where L₁ is the Lipschitz constant of σ̇. In particular, for l = L, though α̃^{L+1} = 1, since ‖W̃^L‖₂ ≤ 3d^{1/4}, (90) is still true. By induction, there exists an M₆ > 0 such that ‖α^l − α̃^l‖₂ ≤ (M₆/d^{1/4})‖θ − θ̃‖₂ for all l ≥ 1 (note that this is also true for l = L + 1).

Thus, if (86) and (87) are true, then for all θ, θ̃ ∈ B(θ^(0), C₀) and any x such that ‖x‖₂ ≤ 1, we have

‖∇_{W⁰}f(x) − ∇_{W̃⁰}f̃(x)‖₂ = (1/√d₀)‖x(α¹)ᵀ − x(α̃¹)ᵀ‖₂ ≤ (1/√d₀)‖α¹ − α̃¹‖₂ ≤ (1/√d₀)(M₆/d^{1/4})‖θ − θ̃‖₂

and for l = 1, ..., L, we have

‖∇_{W^l}f(x) − ∇_{W̃^l}f̃(x)‖₂ = (1/√d)‖x^l(α^{l+1})ᵀ − x̃^l(α̃^{l+1})ᵀ‖₂ ≤ (1/√d)(‖x^l‖₂‖α^{l+1} − α̃^{l+1}‖₂ + ‖x^l − x̃^l‖₂‖α̃^{l+1}‖₂) ≤ (M₂M₆ d^{-1/4} + M₅M₃ d^{-1/2})‖θ − θ̃‖₂

Moreover, for any l = 0, ..., L, there is

‖∇_{b^l}f(x) − ∇_{b̃^l}f̃(x)‖₂ = β‖α^{l+1} − α̃^{l+1}‖₂ ≤ β(M₆/d^{1/4})‖θ − θ̃‖₂

Overall, we can see that there exists a constant M₇ > 0 such that ‖∇_θf(x) − ∇_θf̃(x)‖₂ ≤ (M₇/(√n d^{1/4}))‖θ − θ̃‖₂, so that ‖J(θ) − J(θ̃)‖_F ≤ (M₇/d^{1/4})‖θ − θ̃‖₂.
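Theorem 14 is easy to sanity-check numerically (the dimensions and t below are arbitrary):

```python
import numpy as np

# Empirical check of Theorem 14: every singular value of a p x q standard
# Gaussian matrix lies in [sqrt(p)-sqrt(q)-t, sqrt(p)+sqrt(q)+t], except
# with probability at most 2*exp(-t^2/2).
rng = np.random.default_rng(0)
p, q, t = 4000, 100, 5.0
A = rng.standard_normal((p, q))
s = np.linalg.svd(A, compute_uv=False)   # singular values of A
lo = np.sqrt(p) - np.sqrt(q) - t
hi = np.sqrt(p) + np.sqrt(q) + t
assert lo <= s.min() and s.max() <= hi
```

With t = 5 the failure probability bound is 2·exp(−12.5) ≈ 7·10⁻⁶, so the check essentially always passes.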
D.3.4 PROOF OF THEOREM 4

First of all, for a linearized neural network (11), if we view $\{\nabla_\theta f^{(0)}(x_i)\}_{i=1}^n$ as the inputs and $\{y_i - f^{(0)}(x_i) + \langle \theta^{(0)}, \nabla_\theta f^{(0)}(x_i) \rangle\}_{i=1}^n$ as the targets, then the model becomes a linear model. So by Theorem 2 we have the following corollary:

Corollary 15. If $\nabla_\theta f^{(0)}(x_1), \dots, \nabla_\theta f^{(0)}(x_n)$ are linearly independent, then there exists $\eta_0 > 0$ such that for any GRW satisfying Assumption 1 and any $\eta \le \eta_0$, $\theta^{(t)}$ converges to the same interpolator $\theta^*$ that does not depend on the $q_i$.

Let $\eta_1 = \min\{\eta_0, \eta^*\}$, where $\eta_0$ is defined in Corollary 15 and $\eta^*$ is defined in Lemma 5. Let $f^{(t)}_{\mathrm{lin}}(x)$ and $f^{(t)}_{\mathrm{linERM}}(x)$ be the linearized neural networks of $f^{(t)}(x)$ and $f^{(t)}_{\mathrm{ERM}}(x)$, respectively. By Lemma 5, for any $\delta > 0$, there exist $D > 0$ and a constant $C$ such that
$$\sup_{t \ge 0} \left|f^{(t)}_{\mathrm{lin}}(x) - f^{(t)}(x)\right| \le C d^{-1/4} \qquad \text{and} \qquad \sup_{t \ge 0} \left|f^{(t)}_{\mathrm{linERM}}(x) - f^{(t)}_{\mathrm{ERM}}(x)\right| \le C d^{-1/4}. \tag{94}$$
By Corollary 15, we have $\lim_{t \to \infty} |f^{(t)}_{\mathrm{lin}}(x) - f^{(t)}_{\mathrm{linERM}}(x)| = 0$. Summing the above yields
$$\limsup_{t \to \infty} \left|f^{(t)}(x) - f^{(t)}_{\mathrm{ERM}}(x)\right| \le 2 C d^{-1/4},$$
which is the result we want.

D.4 PROOFS FOR SUBSECTION 4.3

D.4.1 A NEW APPROXIMATION THEOREM

Lemma 16 (Approximation Theorem for Regularized GRW). For a wide fully-connected neural network $f$, denote $J(\theta) = \nabla_\theta f(X; \theta) \in \mathbb{R}^{p \times n}$ and $g(\theta) = \nabla_{\hat y} \ell(f(X; \theta), Y) \in \mathbb{R}^n$. Suppose the loss function satisfies $\nabla_\theta g(\theta) = J(\theta) U(\theta)$ for any $\theta$, where $U(\theta)$ is a positive semi-definite diagonal matrix whose elements are uniformly bounded. Then for any GRW that minimizes the regularized weighted empirical risk (13) with a sufficiently small learning rate $\eta$, for a sufficiently large $d$, with high probability over the random initialization, on any test point $x$ such that $\|x\|_2 \le 1$,
$$\sup_{t \ge 0} \left|f^{(t)}_{\mathrm{linreg}}(x) - f^{(t)}_{\mathrm{reg}}(x)\right| \le C d^{-1/4},$$
where both $f^{(t)}_{\mathrm{linreg}}$ and $f^{(t)}_{\mathrm{reg}}$ are trained by the same regularized GRW from the same initial point.

First of all, with some simple linear algebra, we can prove the following proposition:

Proposition 17. For any positive definite symmetric matrix $H \in \mathbb{R}^{n \times n}$, denote its largest and smallest eigenvalues by $\lambda_{\max}$ and $\lambda_{\min}$. Then, for any positive semi-definite diagonal matrix $Q = \operatorname{diag}(q_1, \dots, q_n)$, $HQ$ has $n$ eigenvalues that all lie in $[\min_i q_i \cdot \lambda_{\min},\ \max_i q_i \cdot \lambda_{\max}]$.

Proof. $H$ is a positive definite symmetric matrix, so there exists a full-rank $A \in \mathbb{R}^{n \times n}$ such that $H = A^\top A$. First, any eigenvalue of $A Q A^\top$ is also an eigenvalue of $A^\top A Q$: for any eigenvalue $\lambda$ of $A Q A^\top$ there exists $v \neq 0$ with $A Q A^\top v = \lambda v$; multiplying both sides by $A^\top$ on the left yields $A^\top A Q (A^\top v) = \lambda (A^\top v)$, and when $\lambda \neq 0$ we have $A^\top v \neq 0$ since $\lambda v \neq 0$ (when $\lambda = 0$, $A Q A^\top$ is singular, so $A^\top A Q$ is singular as well). Second, by the condition, the eigenvalues of $A^\top A$ (equivalently of $A A^\top$) all lie in $[\lambda_{\min}, \lambda_{\max}]$ with $\lambda_{\min} > 0$, so for any unit vector $v$, $\|A^\top v\|_2^2 \in [\lambda_{\min}, \lambda_{\max}]$. Thus $v^\top A Q A^\top v = \|Q^{1/2} A^\top v\|_2^2 \in [\lambda_{\min} \min_i q_i,\ \lambda_{\max} \max_i q_i]$, which implies that the eigenvalues of $A Q A^\top$ all lie in this interval. Thus, the eigenvalues of $HQ = A^\top A Q$ are all in $[\lambda_{\min} \min_i q_i,\ \lambda_{\max} \max_i q_i]$.

Proof of Lemma 16. By the stated condition, assume without loss of generality that the elements of $U(\theta)$ are in $[0, 1]$ for all $\theta$, and let $\eta \le (\mu + \lambda_{\min} + \lambda_{\max})^{-1}$. (If the elements of $U(\theta)$ are bounded by $[0, C]$, then we can let $\eta \le (\mu + C\lambda_{\min} + C\lambda_{\max})^{-1}$ and prove the result in the same way.) With the $L^2$ penalty, the update rule of the GRW for the neural network is
$$\theta^{(t+1)} = \theta^{(t)} - \eta J(\theta^{(t)}) Q^{(t)} g(\theta^{(t)}) - \eta\mu (\theta^{(t)} - \theta^{(0)}), \tag{98}$$
and the update rule for the linearized neural network is
$$\theta^{(t+1)}_{\mathrm{lin}} = \theta^{(t)}_{\mathrm{lin}} - \eta J(\theta^{(0)}) Q^{(t)} g(\theta^{(t)}_{\mathrm{lin}}) - \eta\mu (\theta^{(t)}_{\mathrm{lin}} - \theta^{(0)}).$$
By Proposition 11, $f(x; \theta^{(0)})$ converges in probability to a zero-mean Gaussian process. Thus, for any $\delta > 0$, there exists a constant $R_0 > 0$ such that with probability at least $1 - \delta/3$, $\|g(\theta^{(0)})\|_2 < R_0$. Let $M$ be as defined in Lemma 13. Denote $A = \eta M R_0$, and let $C_0 = \frac{4A}{\eta\mu}$ in Lemma 13. By Lemma 13, there exists $D_1$ such that for all $d \ge D_1$, with probability at least $1 - \delta/3$, (52) is true.

Similar to the proof of Proposition 17, we can show that for arbitrary $\tilde\theta$, all non-zero eigenvalues of $J(\theta^{(0)}) Q^{(t)} U(\tilde\theta) J(\theta^{(0)})^\top$ are eigenvalues of $J(\theta^{(0)})^\top J(\theta^{(0)}) Q^{(t)} U(\tilde\theta)$. This is because for any $\lambda \neq 0$, if $J(\theta^{(0)}) Q^{(t)} U(\tilde\theta) J(\theta^{(0)})^\top v = \lambda v$, then $J(\theta^{(0)})^\top J(\theta^{(0)}) Q^{(t)} U(\tilde\theta) (J(\theta^{(0)})^\top v) = \lambda (J(\theta^{(0)})^\top v)$, and $J(\theta^{(0)})^\top v \neq 0$ since $\lambda v \neq 0$, so $\lambda$ is also an eigenvalue of $J(\theta^{(0)})^\top J(\theta^{(0)}) Q^{(t)} U(\tilde\theta)$. On the other hand, by Proposition 3, $J(\theta^{(0)})^\top J(\theta^{(0)}) Q^{(t)} U(\tilde\theta)$ converges in probability to $\Theta Q^{(t)} U(\tilde\theta)$, whose eigenvalues are all in $[0, \lambda_{\max}]$ by Proposition 17. So there exists $D_2$ such that for all $d \ge D_2$, with probability at least $1 - \delta/3$, the eigenvalues of $J(\theta^{(0)}) Q^{(t)} U(\tilde\theta) J(\theta^{(0)})^\top$ are all in $[0, \lambda_{\max} + \lambda_{\min}]$ for all $t$. By the union bound, with probability at least $1 - \delta$, all three of the above events hold, which we will assume in the rest of this proof.
First, we need to prove that there exists $D_0$ such that for all $d \ge D_0$, $\sup_{t \ge 0} \|\theta^{(t)} - \theta^{(0)}\|_2$ is bounded with high probability. Denote $a_t = \theta^{(t)} - \theta^{(0)}$. By (98) we have
$$a_{t+1} = (1 - \eta\mu) a_t - \eta \left[J(\theta^{(t)}) - J(\theta^{(0)})\right] Q^{(t)} g(\theta^{(t)}) - \eta J(\theta^{(0)}) Q^{(t)} \left[g(\theta^{(t)}) - g(\theta^{(0)})\right] - \eta J(\theta^{(0)}) Q^{(t)} g(\theta^{(0)}),$$
which implies
$$\|a_{t+1}\|_2 \le \left\|(1 - \eta\mu) I - \eta J(\theta^{(0)}) Q^{(t)} U(\tilde\theta^{(t)}) J(\tilde\theta^{(t)})^\top\right\|_2 \|a_t\|_2 + \eta \|J(\theta^{(t)}) - J(\theta^{(0)})\|_F \|g(\theta^{(t)})\|_2 + \eta \|J(\theta^{(0)})\|_F \|g(\theta^{(0)})\|_2, \tag{101}$$
where $\tilde\theta^{(t)}$ is some linear interpolation between $\theta^{(t)}$ and $\theta^{(0)}$. Our choice of $\eta$ ensures that $\eta\mu < 1$. Now we prove by induction that $\|a_t\|_2 < C_0$. It is true for $t = 0$, so we need to prove that if $\|a_t\|_2 < C_0$, then $\|a_{t+1}\|_2 < C_0$. For the first term on the right-hand side of (101), we have
$$\left\|(1 - \eta\mu) I - \eta J(\theta^{(0)}) Q^{(t)} U(\tilde\theta^{(t)}) J(\tilde\theta^{(t)})^\top\right\|_2 \le (1 - \eta\mu) \left\|I - \tfrac{\eta}{1 - \eta\mu} J(\theta^{(0)}) Q^{(t)} U(\tilde\theta^{(t)}) J(\theta^{(0)})^\top\right\|_2 + \eta \|J(\theta^{(0)})\|_F \|J(\tilde\theta^{(t)}) - J(\theta^{(0)})\|_F.$$
Since $\eta / (1 - \eta\mu) \le (\lambda_{\min} + \lambda_{\max})^{-1}$ by our choice of $\eta$, we have
$$\left\|I - \tfrac{\eta}{1 - \eta\mu} J(\theta^{(0)}) Q^{(t)} U(\tilde\theta^{(t)}) J(\theta^{(0)})^\top\right\|_2 \le 1. \tag{103}$$
On the other hand, we can use (52) since $\|a_t\|_2 < C_0$, so $\|J(\theta^{(0)})\|_F \|J(\tilde\theta^{(t)}) - J(\theta^{(0)})\|_F \le \frac{M^2}{\sqrt[4]{d}} C_0$. Therefore, there exists $D_3$ such that for all $d \ge D_3$,
$$\left\|(1 - \eta\mu) I - \eta J(\theta^{(0)}) Q^{(t)} U(\tilde\theta^{(t)}) J(\tilde\theta^{(t)})^\top\right\|_2 \le 1 - \frac{\eta\mu}{2}. \tag{104}$$
For the second term, we have
$$\|g(\theta^{(t)})\|_2 \le \|g(\theta^{(t)}) - g(\theta^{(0)})\|_2 + \|g(\theta^{(0)})\|_2 \le \|J(\tilde\theta^{(t)})\|_2 \|U(\tilde\theta^{(t)})\|_2 \|\theta^{(t)} - \theta^{(0)}\|_2 + R_0 \le M C_0 + R_0. \tag{105}$$
And for the third term, we have
$$\eta \|J(\theta^{(0)})\|_F \|g(\theta^{(0)})\|_2 \le \eta M R_0 = A. \tag{106}$$
Combining the three, we have
$$\|a_{t+1}\|_2 \le \left(1 - \frac{\eta\mu}{2}\right) \|a_t\|_2 + \frac{\eta M (M C_0 + R_0)}{\sqrt[4]{d}} + A.$$
So there exists $D_4$ such that for all $d \ge D_4$, $\|a_{t+1}\|_2 \le (1 - \frac{\eta\mu}{2}) \|a_t\|_2 + 2A$. Since $C_0 = \frac{4A}{\eta\mu}$, this shows that if $\|a_t\|_2 < C_0$ is true, then $\|a_{t+1}\|_2 < C_0$ will also be true.

In conclusion, for all $d \ge D_0 = \max\{D_1, D_2, D_3, D_4\}$, $\|\theta^{(t)} - \theta^{(0)}\|_2 < C_0$ holds for all $t$. This also implies that for $C_1 = M C_0 + R_0$, we have $\|g(\theta^{(t)})\|_2 \le C_1$ for all $t$ by (105). Similarly, we can prove that $\|\theta^{(t)}_{\mathrm{lin}} - \theta^{(0)}\|_2 < C_0$ for all $t$.

Second, let $\Delta_t = \theta^{(t)}_{\mathrm{lin}} - \theta^{(t)}$. Then we have
$$\Delta_{t+1} - \Delta_t = \eta \left(J(\theta^{(t)}) Q^{(t)} g(\theta^{(t)}) - J(\theta^{(0)}) Q^{(t)} g(\theta^{(t)}_{\mathrm{lin}}) - \mu \Delta_t\right),$$
which implies
$$\Delta_{t+1} = \left[(1 - \eta\mu) I - \eta J(\theta^{(0)}) Q^{(t)} U(\tilde\theta^{(t)}) J(\tilde\theta^{(t)})^\top\right] \Delta_t + \eta \left(J(\theta^{(t)}) - J(\theta^{(0)})\right) Q^{(t)} g(\theta^{(t)}),$$
where $\tilde\theta^{(t)}$ is some linear interpolation between $\theta^{(t)}$ and $\theta^{(t)}_{\mathrm{lin}}$. By (104), with probability at least $1 - \delta$, for all $d \ge D_0$ we have
$$\|\Delta_{t+1}\|_2 \le \left\|(1 - \eta\mu) I - \eta J(\theta^{(0)}) Q^{(t)} U(\tilde\theta^{(t)}) J(\tilde\theta^{(t)})^\top\right\|_2 \|\Delta_t\|_2 + \eta \|J(\theta^{(t)}) - J(\theta^{(0)})\|_F \|g(\theta^{(t)})\|_2 \le \left(1 - \frac{\eta\mu}{2}\right) \|\Delta_t\|_2 + \eta \frac{M}{\sqrt[4]{d}} C_0 C_1.$$
Again, as $\Delta_0 = 0$, we can prove by induction that for all $t$,
$$\|\Delta_t\|_2 < \frac{2 M C_0 C_1}{\mu} d^{-1/4}.$$
For any test point $x$ such that $\|x\|_2 \le 1$, we have
$$\left|f^{(t)}_{\mathrm{reg}}(x) - f^{(t)}_{\mathrm{linreg}}(x)\right| = \left|f(x; \theta^{(t)}) - f_{\mathrm{lin}}(x; \theta^{(t)}_{\mathrm{lin}})\right| \le \left|f(x; \theta^{(t)}) - f_{\mathrm{lin}}(x; \theta^{(t)})\right| + \left|f_{\mathrm{lin}}(x; \theta^{(t)}) - f_{\mathrm{lin}}(x; \theta^{(t)}_{\mathrm{lin}})\right| \le \left|f(x; \theta^{(t)}) - f_{\mathrm{lin}}(x; \theta^{(t)})\right| + \|\nabla_\theta f(x; \theta^{(0)})\|_2 \|\theta^{(t)} - \theta^{(t)}_{\mathrm{lin}}\|_2 \le \left|f(x; \theta^{(t)}) - f_{\mathrm{lin}}(x; \theta^{(t)})\right| + M \|\Delta_t\|_2.$$
For the first term, note that
$$f(x; \theta^{(t)}) - f(x; \theta^{(0)}) = \nabla_\theta f(x; \tilde\theta^{(t)})^\top (\theta^{(t)} - \theta^{(0)}), \qquad f_{\mathrm{lin}}(x; \theta^{(t)}) - f_{\mathrm{lin}}(x; \theta^{(0)}) = \nabla_\theta f(x; \theta^{(0)})^\top (\theta^{(t)} - \theta^{(0)}),$$
where $\tilde\theta^{(t)}$ is some linear interpolation between $\theta^{(t)}$ and $\theta^{(0)}$. Since $f(x; \theta^{(0)}) = f_{\mathrm{lin}}(x; \theta^{(0)})$,
$$\left|f(x; \theta^{(t)}) - f_{\mathrm{lin}}(x; \theta^{(t)})\right| \le \left\|\nabla_\theta f(x; \tilde\theta^{(t)}) - \nabla_\theta f(x; \theta^{(0)})\right\|_2 \|\theta^{(t)} - \theta^{(0)}\|_2 \le \frac{M}{\sqrt[4]{d}} C_0^2. \tag{114}$$
Thus, we have shown that for all $d \ge D_0$, with probability at least $1 - \delta$, for all $t$ and all $x$,
$$\left|f^{(t)}_{\mathrm{reg}}(x) - f^{(t)}_{\mathrm{linreg}}(x)\right| \le \left(M C_0^2 + \frac{2 M^2 C_0 C_1}{\mu}\right) d^{-1/4} = O(d^{-1/4}),$$
which is the result we need.
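Proposition 17, which the argument above uses to control the spectrum of $J(\theta^{(0)}) Q^{(t)} U(\tilde\theta) J(\theta^{(0)})^\top$, is easy to sanity-check numerically; the matrix size and the weight range below are arbitrary illustrative choices:

```python
import numpy as np

# Check Proposition 17: for positive definite symmetric H and nonnegative
# diagonal Q, the eigenvalues of HQ are real and lie in
# [min_i q_i * lambda_min(H), max_i q_i * lambda_max(H)].
rng = np.random.default_rng(2)
n = 8
B = rng.standard_normal((n, n))
H = B.T @ B + 0.5 * np.eye(n)        # positive definite symmetric
q = rng.uniform(0.1, 2.0, size=n)    # nonnegative diagonal weights
lam = np.linalg.eigvalsh(H)          # sorted eigenvalues of H
ev = np.linalg.eigvals(H @ np.diag(q))
assert np.allclose(ev.imag, 0.0, atol=1e-8)   # HQ has real eigenvalues
ev = np.sort(ev.real)
assert ev[0] >= q.min() * lam[0] - 1e-8
assert ev[-1] <= q.max() * lam[-1] + 1e-8
```

The reality of the spectrum follows from the similarity $HQ \sim Q^{1/2} H Q^{1/2}$ when all $q_i > 0$, which is the same change of variables used in the proof.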

D.4.2 RESULT FOR LINEARIZED NEURAL NETWORKS

Lemma 18. Suppose there exists $M_0 > 0$ such that $\|\nabla_\theta f^{(0)}(x)\|_2 \le M_0$ for every test point $x$. If the gradients $\nabla_\theta f^{(0)}(x_1), \dots, \nabla_\theta f^{(0)}(x_n)$ are linearly independent, and the empirical training risk of $f^{(t)}_{\mathrm{linreg}}$ satisfies $\hat{\mathcal{R}}(f^{(t)}_{\mathrm{linreg}}) < \epsilon$ for some $\epsilon > 0$ and all sufficiently large $t$, then for any $x$ such that $\|x\|_2 \le 1$ we have
$$\limsup_{t \to \infty} \left|f^{(t)}_{\mathrm{linreg}}(x) - f^{(t)}_{\mathrm{linERM}}(x)\right| = O(\sqrt{\epsilon}). \tag{117}$$

First, we can see that under the new weight update rule, $\theta^{(t)} - \theta^{(0)} \in \operatorname{span}\{\nabla_\theta f^{(0)}(x_1), \dots, \nabla_\theta f^{(0)}(x_n)\}$ is still true for all $t$. Let $\theta^*$ be the interpolator in $\operatorname{span}\{\nabla_\theta f^{(0)}(x_1), \dots, \nabla_\theta f^{(0)}(x_n)\}$; then the empirical risk of $\theta$ is
$$\frac{1}{2n} \sum_{i=1}^n \langle \theta - \theta^*, \nabla_\theta f^{(0)}(x_i) \rangle^2 = \frac{1}{2n} \left\|\nabla_\theta f^{(0)}(X)^\top (\theta - \theta^*)\right\|_2^2.$$
Thus, there exists $T > 0$ such that for any $t \ge T$,
$$\left\|\nabla_\theta f^{(0)}(X)^\top (\theta^{(t)} - \theta^*)\right\|_2^2 \le 2 n \epsilon.$$
Let the smallest singular value of $\frac{1}{\sqrt{n}} \nabla_\theta f^{(0)}(X)$ be $s_{\min}$; by linear independence, $s_{\min} > 0$. Note that the column space of $\nabla_\theta f^{(0)}(X)$ is exactly $\operatorname{span}\{\nabla_\theta f^{(0)}(x_1), \dots, \nabla_\theta f^{(0)}(x_n)\}$. Define $H \in \mathbb{R}^{p \times n}$ such that its columns form an orthonormal basis of this subspace; then there exists $G \in \mathbb{R}^{n \times n}$ such that $\nabla_\theta f^{(0)}(X) = H G$, and the smallest singular value of $\frac{1}{\sqrt{n}} G$ is also $s_{\min}$. Since $\theta^{(t)} - \theta^*$ is also in this subspace, there exists $v \in \mathbb{R}^n$ such that $\theta^{(t)} - \theta^* = H v$. Then we have $\sqrt{2 n \epsilon} \ge \|G^\top H^\top H v\|_2 = \|G^\top v\|_2 \ge \sqrt{n}\, s_{\min} \|v\|_2$. Thus, $\|v\|_2 \le \frac{\sqrt{2\epsilon}}{s_{\min}}$, which implies
$$\|\theta^{(t)} - \theta^*\|_2 \le \frac{\sqrt{2\epsilon}}{s_{\min}}.$$
We have already proved in previous results that if we minimize the unregularized risk with ERM, then $\theta$ always converges to the interpolator $\theta^*$. So for any $t \ge T$ and any test point $x$ such that $\|x\|_2 \le 1$, we have
$$\left|f^{(t)}_{\mathrm{linreg}}(x) - f^{(t)}_{\mathrm{linERM}}(x)\right| = \left|\langle \theta^{(t)} - \theta^*, \nabla_\theta f^{(0)}(x) \rangle\right| \le \frac{M_0 \sqrt{2\epsilon}}{s_{\min}}, \tag{120}$$
which implies (117).
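The weighting-independence of the interpolator invoked above (Corollary 15, specialized to a plain overparameterized linear model) can be illustrated with a small simulation. The dimensions, weightings, and learning rate below are illustrative choices, not values from the paper:

```python
import numpy as np

# Two GRW runs with different (fixed, positive) sample weights q on an
# overparameterized linear model, both started from zero: gradient descent
# stays in span{x_i}, so both converge to the same minimum-norm interpolator.
rng = np.random.default_rng(1)
n, p = 5, 20
X = rng.standard_normal((p, n))      # columns are the inputs x_i
y = rng.standard_normal(n)

def grw_gd(q, eta=0.01, steps=40000):
    theta = np.zeros(p)              # same initialization for every weighting
    for _ in range(steps):
        g = X.T @ theta - y          # residuals <theta, x_i> - y_i
        theta -= eta * X @ (q * g)   # weighted gradient step
    return theta

q_uniform = np.full(n, 1.0 / n)
q_skewed = np.arange(1.0, n + 1.0); q_skewed /= q_skewed.sum()
t1, t2 = grw_gd(q_uniform), grw_gd(q_skewed)
t_star = X @ np.linalg.solve(X.T @ X, y)   # min-norm interpolator
assert np.allclose(t1, t2, atol=1e-6)
assert np.allclose(t1, t_star, atol=1e-6)
```

Both weightings land on the same point, which is exactly why reweighting cannot change the learned function in this interpolating regime.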

D.4.3 PROOF OF THEOREM 6

Given that $\hat{\mathcal{R}}(f^{(t)}_{\mathrm{linreg}}) < \epsilon$ for sufficiently large $t$, Lemma 16 implies that
$$\hat{\mathcal{R}}(f^{(t)}_{\mathrm{linreg}}) - \hat{\mathcal{R}}(f^{(t)}_{\mathrm{reg}}) = O\left(d^{-1/4} \sqrt{\epsilon} + d^{-1/2}\right).$$
So for a fixed $\epsilon$, there exists $D > 0$ such that for all $d \ge D$ and sufficiently large $t$,
$$\hat{\mathcal{R}}(f^{(t)}_{\mathrm{reg}}) < \epsilon \implies \hat{\mathcal{R}}(f^{(t)}_{\mathrm{linreg}}) < 2\epsilon. \tag{122}$$
By Lemma 5 and Lemma 16, we have
$$\sup_{t \ge 0} \left|f^{(t)}_{\mathrm{linERM}}(x) - f^{(t)}_{\mathrm{ERM}}(x)\right| = O(d^{-1/4}) \qquad \text{and} \qquad \sup_{t \ge 0} \left|f^{(t)}_{\mathrm{linreg}}(x) - f^{(t)}_{\mathrm{reg}}(x)\right| = O(d^{-1/4}). \tag{123}$$
Lemma 18 together with (123) yields
$$\limsup_{t \to \infty} \left|f^{(t)}_{\mathrm{reg}}(x) - f^{(t)}_{\mathrm{ERM}}(x)\right| = O(d^{-1/4} + \sqrt{\epsilon}).$$
Letting $d \to \infty$ leads to the result we need.

Remark. One might wonder whether $\|\nabla_\theta f^{(0)}(x)\|_2$ will diverge as $d \to \infty$. In fact, in Lemma 13 we have proved that there exists a constant $M$ such that with high probability, for any $d$, $\|\nabla_\theta f^{(0)}(x)\|_2 \le M$ for any $x$ such that $\|x\|_2 \le 1$. Therefore, it is fine to suppose that such an $M_0$ exists.

D.5 PROOFS FOR SUBSECTION 5.1

D.5.1 PROOF OF THEOREM 7

First we need to show that $\hat\theta_{\mathrm{MM}}$ is unique. Suppose both $\theta_1$ and $\theta_2$ maximize $\min_{i=1,\dots,n} y_i \langle \theta, x_i \rangle$ with $\theta_1 \neq \theta_2$ and $\|\theta_1\|_2 = \|\theta_2\|_2 = 1$. Then consider $\theta_0 = \bar\theta / \|\bar\theta\|_2$, where $\bar\theta = (\theta_1 + \theta_2)/2$. Obviously $\|\bar\theta\|_2 < 1$, and for any $i$, $y_i \langle \bar\theta, x_i \rangle = (y_i \langle \theta_1, x_i \rangle + y_i \langle \theta_2, x_i \rangle)/2$, so $y_i \langle \theta_0, x_i \rangle > \min\{y_i \langle \theta_1, x_i \rangle,\, y_i \langle \theta_2, x_i \rangle\}$, which implies
$$\min_{i=1,\dots,n} y_i \langle \theta_0, x_i \rangle > \min\left\{\min_{i=1,\dots,n} y_i \langle \theta_1, x_i \rangle,\ \min_{i=1,\dots,n} y_i \langle \theta_2, x_i \rangle\right\},$$
a contradiction.

Now we start proving the result. Without loss of generality, let $(x_1, y_1), \dots, (x_m, y_m)$ be the samples with the smallest margin to $u$, i.e. $\arg\min_{1 \le i \le n} y_i \langle u, x_i \rangle = \{1, \dots, m\}$, and denote $y_1 \langle u, x_1 \rangle = \dots = y_m \langle u, x_m \rangle = \gamma_u$. Since the training error converges to 0, $\gamma_u > 0$. Note that for the logistic loss, if $y_i \langle \theta, x_i \rangle < y_j \langle \theta, x_j \rangle$, then for any $M > 0$ there exists an $R_M > 0$ such that for all $R \ge R_M$,
$$\frac{\|\nabla_\theta \ell(\langle R\theta, x_i \rangle, y_i)\|}{\|\nabla_\theta \ell(\langle R\theta, x_j \rangle, y_j)\|} > M,$$
which can be shown with some simple calculation. And because the training error converges to 0, we must have $\|\theta^{(t)}\| \to \infty$.

Then, by Assumption 3, when $t$ gets sufficiently large, the contribution of $(x_j, y_j)$ with $j > m$ to $\theta^{(t)}$ is an infinitesimal compared to that of $(x_i, y_i)$ with $i \le m$ (because by Assumption 3 there exists a positive constant $\delta$ such that $q^{(t)}_i > \delta$ for all sufficiently large $t$). Thus, we must have $u \in \operatorname{span}\{x_1, \dots, x_m\}$. Let $u = \alpha_1 y_1 x_1 + \dots + \alpha_m y_m x_m$. Now we show that $\alpha_i \ge 0$ for all $i = 1, \dots, m$. This is because when $t$ is sufficiently large, so that the contribution of $(x_j, y_j)$ with $j > m$ becomes infinitesimal, we have
$$\theta^{(t+1)} - \theta^{(t)} \approx \eta \sum_{i=1}^m \frac{q^{(t)}_i}{1 + \exp(y_i \langle \theta^{(t)}, x_i \rangle)}\, y_i x_i,$$
and since $\|\theta^{(t)}\| \to \infty$ as $t \to \infty$, we have
$$\alpha_i \propto \lim_{T \to \infty} \sum_{t=T_0}^T \frac{q^{(t)}_i}{1 + \exp(y_i \langle \theta^{(t)}, x_i \rangle)} := \lim_{T \to \infty} \alpha_i(T),$$
where $T_0$ is sufficiently large. Here the notation $\alpha_i \propto \lim_{T \to \infty} \alpha_i(T)$ means that $\lim_{T \to \infty} \frac{\alpha_i(T)}{\alpha_j(T)} = \frac{\alpha_i}{\alpha_j}$ for any pair $i, j$ with $\alpha_j \neq 0$. Note that each term in the sum is non-negative. This implies that all of $\alpha_1, \dots, \alpha_m$ have the same sign (or equal 0). On the other hand,
$$\gamma_u \sum_{i=1}^m \alpha_i = \sum_{i=1}^m \alpha_i y_i \langle u, x_i \rangle = \langle u, u \rangle > 0.$$
Thus, $\alpha_i \ge 0$ for all $i$ and at least one of them is positive. Now suppose $u \neq \hat\theta_{\mathrm{MM}}$, which means that $\gamma_u$ is smaller than the margin of $\hat\theta_{\mathrm{MM}}$. Then, for all $i = 1, \dots, m$, there is $y_i \langle u, x_i \rangle < y_i \langle \hat\theta_{\mathrm{MM}}, x_i \rangle$. This implies
$$\langle u, u \rangle = \sum_{i=1}^m \alpha_i y_i \langle u, x_i \rangle < \sum_{i=1}^m \alpha_i y_i \langle \hat\theta_{\mathrm{MM}}, x_i \rangle = \langle \hat\theta_{\mathrm{MM}}, u \rangle,$$
which is a contradiction since $\langle \hat\theta_{\mathrm{MM}}, u \rangle \le \|\hat\theta_{\mathrm{MM}}\|_2 \|u\|_2 = \langle u, u \rangle$. Thus, we must have $u = \hat\theta_{\mathrm{MM}}$.
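The "simple calculation" behind the gradient-ratio claim in the proof above can be checked directly: under the logistic loss, the per-sample gradient magnitude scales as $1/(1 + e^{R m})$ in the margin $m$, so the smaller-margin sample dominates exponentially as $R$ grows. The margin values below are illustrative:

```python
import numpy as np

# Gradient magnitude |dl/dyhat| of the logistic loss l = log(1 + exp(-y*yhat))
# at yhat = R * margin (up to the constant factor ||x||, which cancels in ratios).
def grad_mag(R, margin):
    return 1.0 / (1.0 + np.exp(R * margin))

m_small, m_large = 0.5, 0.8   # y_i<theta, x_i> < y_j<theta, x_j>, illustrative
ratios = [grad_mag(R, m_small) / grad_mag(R, m_large) for R in (1.0, 10.0, 50.0)]
# The ratio grows roughly like exp(R * (m_large - m_small)) and is unbounded in R.
assert ratios[0] < ratios[1] < ratios[2]
assert ratios[2] > 1e6
```

This is precisely why, as $\|\theta^{(t)}\| \to \infty$, only the minimum-margin samples contribute non-negligibly to the update direction.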

D.5.2 PROOF OF THEOREM 8

Denote the largest and smallest eigenvalues of $X^\top X$ by $\lambda_{\max}$ and $\lambda_{\min}$; by the condition, $\lambda_{\min} > 0$. Let $\epsilon = \min\left\{\frac{q^*}{3}, \frac{(q^* \lambda_{\min})^2}{192 \lambda_{\max}^2}\right\}$. Then, similar to the proof in Appendix D.2.2, there exists $t_\epsilon$ such that for all $t \ge t_\epsilon$ and all $i$, $q^{(t)}_i \in (q_i - \epsilon, q_i + \epsilon)$. Denote $Q = \operatorname{diag}(q_1, \dots, q_n)$ and $Q^{(t)} = \operatorname{diag}(q^{(t)}_1, \dots, q^{(t)}_n)$, so that $\|Q^{(t)} - Q\|_2 < \epsilon$ for all $t \ge t_\epsilon$.

First, we prove that $F(\theta)$ is $L$-smooth as long as $\|x_i\|_2 \le 1$ for all $i$. The gradient of $F$ is
$$\nabla F(\theta) = \sum_{i=1}^n q_i \nabla_{\hat y} \ell(\langle \theta, x_i \rangle, y_i)\, x_i.$$
Since $\ell(\hat y, y)$ is $L$-smooth in $\hat y$, we have for any $\theta_1, \theta_2$ and any $i$,
$$\ell(\langle \theta_2, x_i \rangle, y_i) - \ell(\langle \theta_1, x_i \rangle, y_i) \le \nabla_{\hat y} \ell(\langle \theta_1, x_i \rangle, y_i) \left(\langle \theta_2, x_i \rangle - \langle \theta_1, x_i \rangle\right) + \frac{L}{2} \left(\langle \theta_2, x_i \rangle - \langle \theta_1, x_i \rangle\right)^2 = \left\langle \nabla_{\hat y} \ell(\langle \theta_1, x_i \rangle, y_i)\, x_i,\ \theta_2 - \theta_1 \right\rangle + \frac{L}{2} \langle \theta_2 - \theta_1, x_i \rangle^2 \le \left\langle \nabla_{\hat y} \ell(\langle \theta_1, x_i \rangle, y_i)\, x_i,\ \theta_2 - \theta_1 \right\rangle + \frac{L}{2} \|\theta_2 - \theta_1\|_2^2.$$
Thus, we have
$$F(\theta_2) - F(\theta_1) = \sum_{i=1}^n q_i \left[\ell(\langle \theta_2, x_i \rangle, y_i) - \ell(\langle \theta_1, x_i \rangle, y_i)\right] \le \langle \nabla F(\theta_1), \theta_2 - \theta_1 \rangle + \frac{L}{2} \|\theta_2 - \theta_1\|_2^2, \tag{133}$$
which implies that $F(\theta)$ is $L$-smooth. Denote $g(\theta) = \nabla_{\hat y} \ell(f(X; \theta), Y) \in \mathbb{R}^n$; then $\nabla F(\theta^{(t)}) = X Q g(\theta^{(t)})$, and the update rule is
$$\theta^{(t+1)} = \theta^{(t)} - \eta X Q^{(t)} g(\theta^{(t)}). \tag{134}$$
So by the upper quadratic bound, we have
$$F(\theta^{(t+1)}) \le F(\theta^{(t)}) - \eta \left\langle X Q g(\theta^{(t)}),\ X Q^{(t)} g(\theta^{(t)}) \right\rangle + \frac{\eta^2 L}{2} \left\|X Q^{(t)} g(\theta^{(t)})\right\|_2^2. \tag{135}$$

Published as a conference paper at ICLR 2023

Let $\eta_1 = \frac{q^* \lambda_{\min}}{2 L (1 + 3\epsilon) \lambda_{\max}}$. Similar to what we did in Appendix D.2.2 (Eqn. (40)), we can prove that for all $\eta \le \eta_1$, (135) implies that for all $t \ge t_\epsilon$,
$$F(\theta^{(t+1)}) \le F(\theta^{(t)}) - \frac{\eta q^* \lambda_{\min}}{2} \|Q g(\theta^{(t)})\|_2^2 + \frac{\eta^2 L}{2} \|X\|_2^2 (1 + 3\epsilon) \|Q g(\theta^{(t)})\|_2^2 \le F(\theta^{(t)}) - \frac{\eta q^* \lambda_{\min}}{4} \|Q g(\theta^{(t)})\|_2^2 \le F(\theta^{(t)}) - \frac{\eta {q^*}^2 \lambda_{\min}}{4} \|g(\theta^{(t)})\|_2^2. \tag{136}$$
This shows that $F(\theta^{(t)})$ is monotonically non-increasing.

Since $F(\theta) \ge 0$, $F(\theta^{(t)})$ must converge as $t \to \infty$, and we need to prove that it converges to 0. Suppose that $F(\theta^{(t)})$ does not converge to 0; then there exists a constant $C > 0$ such that $F(\theta^{(t)}) \ge 2C$ for all $t$. On the other hand, it is easy to see that there exists $\theta^*$ such that $\ell(\langle \theta^*, x_i \rangle, y_i) < C$ for all $i$. Moreover, (136) implies that $\|g(\theta^{(t)})\|_2 \to 0$ as $t \to \infty$, because we must have $F(\theta^{(t)}) - F(\theta^{(t+1)}) \to 0$. Note that from (134) we have
$$\|\theta^{(t+1)} - \theta^*\|_2^2 = \|\theta^{(t)} - \theta^*\|_2^2 + 2\eta \left\langle X Q^{(t)} g(\theta^{(t)}),\ \theta^* - \theta^{(t)} \right\rangle + \eta^2 \left\|X Q^{(t)} g(\theta^{(t)})\right\|_2^2. \tag{137}$$
Denote $F_t(\theta) = \sum_{i=1}^n q^{(t)}_i \ell(\langle \theta, x_i \rangle, y_i)$. Then $F_t$ is convex because $\ell$ is convex and the $q^{(t)}_i$ are non-negative, and $\nabla F_t(\theta^{(t)}) = X Q^{(t)} g(\theta^{(t)})$. By the lower linear bound $F_t(y) \ge F_t(x) + \langle \nabla F_t(x), y - x \rangle$, we have for all $t \ge t_\epsilon$,
$$\left\langle X Q^{(t)} g(\theta^{(t)}),\ \theta^* - \theta^{(t)} \right\rangle \le F_t(\theta^*) - F_t(\theta^{(t)}) \le F_t(\theta^*) - \frac{2}{3} F(\theta^{(t)}) \le C - \frac{4C}{3} = -\frac{C}{3}, \tag{139}$$
because $q^{(t)}_i \ge q_i - \epsilon \ge \frac{2}{3} q_i$ and $\sum_{i=1}^n q^{(t)}_i = 1$. Since $\|g(\theta^{(t)})\|_2 \to 0$, there exists $T > 0$ such that for all $t \ge T$ and all $\eta \le \eta_0$,
$$\|\theta^{(t+1)} - \theta^*\|_2^2 \le \|\theta^{(t)} - \theta^*\|_2^2 - \frac{\eta C}{3},$$
which means that $\|\theta^{(t)} - \theta^*\|_2^2 \to -\infty$ because $\frac{\eta C}{3}$ is a positive constant. This is a contradiction. Thus, $F(\theta^{(t)})$ must converge to 0, which is result (i).

(i) immediately implies (ii) because $\ell$ is strictly decreasing to 0 by the condition. Now let us prove (iii). First of all, the uniqueness of $\theta_R$ can be easily proved from the convexity of $F(\theta)$. The condition implies that $y_i \langle \theta_R, x_i \rangle > 0$, i.e. $\theta_R$ must classify all training samples correctly. If there were two different minimizers $\theta_R$ and $\theta_R'$ whose norms are at most $R$, then consider $\bar\theta_R = \frac{1}{2}(\theta_R + \theta_R')$. By the convexity of $F$, we know that $\bar\theta_R$ must also be a minimizer, and $\|\bar\theta_R\|_2 < R$. Thus, $F\left(\frac{R}{\|\bar\theta_R\|_2} \bar\theta_R\right) < F(\bar\theta_R)$ and $\left\|\frac{R}{\|\bar\theta_R\|_2} \bar\theta_R\right\|_2 = R$, which contradicts the fact that $\bar\theta_R$ is a minimizer. To prove the rest of (iii), the key is to consider (135).

On one hand, similar to (36) we can prove that for all $t \ge t_\epsilon$,
$$\left|\left\langle X Q^{(t)} g(\theta^{(t)}),\ X (Q^{(t)} - Q) g(\theta^{(t)}) \right\rangle\right| \le \lambda_{\max} \sqrt{3\epsilon}\, \|Q^{(t)} g(\theta^{(t)})\|_2^2.$$
Since we chose $\epsilon = \min\{\frac{q^*}{3}, \frac{(q^* \lambda_{\min})^2}{192 \lambda_{\max}^2}\}$, this inequality implies that
$$\|\nabla F_t(\theta^{(t)})\|_2^2 = \|X Q^{(t)} g(\theta^{(t)})\|_2^2 \ge \lambda_{\min} \|Q^{(t)} g(\theta^{(t)})\|_2^2 \ge \frac{\lambda_{\min} q^*}{2} \|Q^{(t)} g(\theta^{(t)})\|_2^2 \ge 4 \left|\left\langle X Q^{(t)} g(\theta^{(t)}),\ X (Q^{(t)} - Q) g(\theta^{(t)}) \right\rangle\right|. \tag{142}$$
On the other hand, if $\eta \le \eta_2 = \frac{1}{2L}$, we will have
$$\frac{\eta^2 L}{2} \|X Q^{(t)} g(\theta^{(t)})\|_2^2 \le \frac{\eta}{4} \|\nabla F_t(\theta^{(t)})\|_2^2. \tag{143}$$
Combining all the above with (135) yields
$$F(\theta^{(t+1)}) - F(\theta^{(t)}) \le -\frac{\eta}{2} \|\nabla F_t(\theta^{(t)})\|_2^2.$$
Since $\alpha$ is arbitrary, we must have $\lim_{t \to \infty} \frac{\theta^{(t)}}{\|\theta^{(t)}\|_2} = u$ as long as $\eta \le \min\{\eta_1, \eta_2\}$.
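The descent step (136) — a sufficiently small learning rate makes the weighted risk monotonically non-increasing — can be sketched with a fixed-weight GRW on the logistic loss. The data, weights, and step size below are illustrative choices (the step size is simply taken small enough for the smoothness bound to apply on this data):

```python
import numpy as np

# Fixed-weight GRW gradient descent on a weighted logistic risk F; by
# L-smoothness of F, a small step size gives F(theta^{t+1}) <= F(theta^t).
rng = np.random.default_rng(3)
n, p = 20, 5
X = rng.standard_normal((n, p))           # rows are samples x_i
y = rng.choice([-1.0, 1.0], size=n)
q = rng.uniform(0.5, 1.5, size=n); q /= q.sum()   # positive weights, sum 1

def F(theta):
    return np.sum(q * np.log1p(np.exp(-y * (X @ theta))))

theta, vals = np.zeros(p), []
vals.append(F(theta))
for _ in range(200):
    g = -y / (1.0 + np.exp(y * (X @ theta)))  # dl/dyhat for each sample
    theta -= 0.5 * X.T @ (q * g)              # eta = 0.5, small enough here
    vals.append(F(theta))
assert np.all(np.diff(vals) <= 1e-12)         # monotone non-increasing
assert vals[-1] < vals[0]
```

With time-varying weights $Q^{(t)} \to Q$, the proof pays the extra $\langle XQg, X(Q^{(t)}-Q)g\rangle$ term, which (142) shows is dominated once the weights are close to their limits.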

D.5.3 COROLLARY OF THEOREM 8

We can show that the logistic loss satisfies all conditions of Theorem 8 and that $\lim_{R \to \infty} \frac{\theta_R}{R} = \hat\theta_{\mathrm{MM}}$. First of all, for the logistic loss we have
$$\nabla^2_{\hat y} \ell(\hat y, y) = \frac{y^2}{e^{y \hat y} + e^{-y \hat y} + 2} \le \frac{\max_i y_i^2}{4},$$
so $\ell$ is smooth. On the other hand, we have
$$q_j \log\left(1 + \exp\left(-y_j R \langle \hat\theta_R, \nabla_\theta f^{(0)}(x_j) \rangle - M\right)\right) \le F(R \hat\theta_R) \le F(R \hat\theta_{\mathrm{MM}}) \le \log(1 + \exp(-R\gamma + M)),$$
which implies
$$q^* \log\left(1 + \exp\left(-\left(1 - \tfrac{\delta^2}{4}\right) R \gamma - M\right)\right) \le \log(1 + \exp(-R\gamma + M)), \tag{158}$$
and this leads to
$$1 + \exp(-R\gamma + M) \ge \left(1 + \exp\left(-\left(1 - \tfrac{\delta^2}{4}\right) R \gamma - M\right)\right)^{q^*} \ge 1 + q^* \exp\left(-\left(1 - \tfrac{\delta^2}{4}\right) R \gamma - M\right), \tag{159}$$
which is equivalent to
$$-R\gamma + M \ge -\left(1 - \tfrac{\delta^2}{4}\right) R \gamma - M + \log q^*.$$
Thus, we have
$$\delta = O(R^{-1/2}) = O\left((-\log 2\epsilon)^{-1/2}\right). \tag{161}$$
So for any test point $x$, since $\|\nabla_\theta f^{(0)}(x)\|_2 \le M_0$, we have
$$\left|\langle \hat\theta_{\mathrm{MM}} - \hat\theta_R,\ \nabla_\theta f^{(0)}(x) \rangle\right| \le \delta M_0 = O\left((-\log 2\epsilon)^{-1/2}\right).$$
Combined with Lemma 16, we have: with high probability,
$$\limsup_{t \to \infty} \left|R \cdot f_{\mathrm{MM}}(x) - f^{(t)}_{\mathrm{reg}}(x)\right| = O\left(R \, (-\log 2\epsilon)^{-1/2} + d^{-1/4}\right).$$
So there exists a constant $C > 0$ such that: as $d \to \infty$, with high probability, for all $\epsilon \in (0, \frac{1}{4})$, if $|f_{\mathrm{MM}}(x)| > C (-\log 2\epsilon)^{-1/2}$, then $f^{(t)}_{\mathrm{reg}}(x)$ will have the same sign as $f_{\mathrm{MM}}(x)$ for sufficiently large $t$. Note that this $C$ only depends on $n$, $q^*$, $\gamma$, $M$ and $M_0$, so it is a constant independent of $\epsilon$.

Remark. Note that Theorem 9 requires Assumption 1 while Theorem 6 does not, due to the fundamental difference between classification and regression. In regression, the model converges to a finite point. In classification, however, the training loss converging to zero implies that either (i) the direction of the weight is close to the max-margin classifier, or (ii) the norm of the weight is very large. Assumption 1 is used to eliminate possibility (ii). If the regularization parameter $\mu$ were sufficiently large, then a small empirical risk could imply a small weight norm. However, in our theorem we do not assume anything about $\mu$, so Assumption 1 is necessary.
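The smoothness bound used at the start of the corollary can be verified directly against a numerical second derivative (the grid of $\hat y$ values and label choices are illustrative):

```python
import numpy as np

# For the logistic loss l(yhat, y) = log(1 + exp(-y*yhat)), the second
# derivative in yhat is y^2 / (exp(y*yhat) + exp(-y*yhat) + 2) <= y^2 / 4.
yhat = np.linspace(-10, 10, 2001)
for y in (-1.0, 1.0, 2.0):
    second = y**2 / (np.exp(y * yhat) + np.exp(-y * yhat) + 2.0)
    h = 1e-4
    loss = lambda t: np.log1p(np.exp(-y * t))
    num = (loss(yhat + h) - 2 * loss(yhat) + loss(yhat - h)) / h**2
    assert np.allclose(second, num, atol=1e-4)     # closed form matches
    assert np.all(second <= y**2 / 4 + 1e-12)      # bound holds everywhere
```

The maximum $y^2/4$ is attained at $\hat y = 0$, which is why the Lipschitz-smoothness constant of the weighted logistic risk scales with $\max_i y_i^2 / 4$.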

E A NOTE ON THE PROOFS IN LEE ET AL. (2019)

We have mentioned that the proofs in Lee et al. (2019), particularly the proofs of their Theorem 2.1 and Lemma 1 in their Appendix G, are flawed. In order to fix their proof, we change the network initialization to (9). In this section, we will demonstrate what goes wrong in the proofs of Lee et al. (2019), and how we manage to fix them. For clarity, we are referring to the following version of the paper: https://arxiv.org/pdf/1902.06720v4.pdf. To avoid confusion, in this section we will still use the notations of our paper.

E.1 THEIR PROBLEMS

Lee et al. (2019) claimed in their Theorem 2.1 that under the conditions of our Lemma 5, for any $\delta > 0$, there exist $D > 0$ and a constant $C$ such that for any $d \ge D$, with probability at least $1 - \delta$, the gap between the output of a sufficiently wide fully-connected neural network and the output of its linearized neural network at any test point $x$ can be uniformly bounded by
$$\sup_{t \ge 0} \left|f^{(t)}(x) - f^{(t)}_{\mathrm{lin}}(x)\right| \le C d^{-1/2} \quad \text{(claimed)}, \tag{164}$$
where they used the original NTK formulation and initialization of Jacot et al. (2018):
$$h^{l+1} = \frac{W^l}{\sqrt{d_l}} x^l + \beta b^l, \qquad x^{l+1} = \sigma(h^{l+1}), \qquad \text{with } W^{l(0)}_{i,j} \sim N(0, 1),\ b^{l(0)}_i \sim N(0, 1) \quad (\forall\, l = 0, \dots, L),$$
where $x^0 = x$ and $f(x) = h^{L+1}$. However, in their proof in their Appendix G, they did not directly prove their result for the NTK formulation; instead, they proved another result for the following formulation, which they called the standard formulation:
$$h^{l+1} = W^l x^l + \beta b^l, \qquad x^{l+1} = \sigma(h^{l+1}), \qquad \text{with } W^{l(0)}_{i,j} \sim N\!\left(0, \tfrac{1}{d_l}\right),\ b^{l(0)}_i \sim N(0, 1) \quad (\forall\, l = 0, \dots, L).$$
See their Appendix F for the definition of their standard formulation. In the original formulation, they also included two constants $\sigma_w$ and $\sigma_b$ for the standard deviations, which we omit here for simplicity. Note that the outputs of the NTK formulation and the standard formulation at initialization are exactly the same.
The only difference is that the norm of the weights $W^l$ and the gradient of the model output with respect to $W^l$ differ for each $l$. In their Appendix G, they claimed that if a network with the standard formulation is trained by minimizing the squared loss with gradient descent and learning rate $\bar\eta = \eta / d$, where $\eta$ is our learning rate in Lemma 5 and also their learning rate in their Theorem 2.1, then (164) is true for this network, so it is also true for a network with the NTK formulation, because the two formulations have the same network output. And then they claimed in their equation (S37) that applying learning rate $\bar\eta$ to the standard formulation is equivalent to applying the learning rates
$$\eta^l_W = \frac{d_l}{d_{\max}}\, \eta \qquad \text{and} \qquad \eta^l_b = \frac{\eta}{d_{\max}} \tag{167}$$
to $W^l$ and $b^l$ of the NTK formulation, where $d_{\max} = \max\{d_0, \dots, d_L\}$. To avoid confusion, in the following discussion we will still use the NTK formulation and initialization if not stated otherwise.

Problem 1. Claim (167) is true, but it leads to two problems. The first problem is that $\eta^l_b = O(d_{\max}^{-1})$ since $\eta = O(1)$, while their Theorem 2.1 needs the learning rate to be $O(1)$. Nevertheless, this problem can be simply fixed by modifying their standard formulation to $h^{l+1} = W^l x^l + \beta \sqrt{d_l}\, b^l$, where $b^{l(0)}_i \sim N(0, d_l^{-1})$. The real problem, which is non-trivial to fix, is that by (167), $\eta^0_W = \frac{d_0}{d_{\max}} \eta$. However, note that $d_0$ is a constant since it is the dimension of the input space, while $d_{\max}$ goes to infinity. Consequently, in (167) they were essentially using a very small learning rate for the first layer $W^0$ but a normal learning rate for the rest of the layers, which definitely does not match the claim of their Theorem 2.1.

Problem 2. Another big problem is that the proof of their Lemma 1 in their Appendix G is erroneous, and consequently their Theorem 2.1 is unsound, as it heavily depends on their Lemma 1.
In their Lemma 1, they claimed that for some constant $M > 0$, for any two models with parameters $\theta, \tilde\theta \in B(\theta^{(0)}, C_0)$ for some constant $C_0$, there is
$$\|J(\theta) - J(\tilde\theta)\|_F \le \frac{M}{\sqrt{d}} \|\theta - \tilde\theta\|_2 \quad \text{(claimed)}. \tag{168}$$
Note that the claim stated in their paper is for their standard formulation. Compared to the standard formulation, $\theta$ in the NTK formulation is $\sqrt{d}$ times larger while the Jacobian $J(\theta)$ is $\sqrt{d}$ times smaller, which produces the $\sqrt{d}$ in the denominator on the right-hand side. As a result, their Lemma 1 and Theorem 2.1 cannot be proved without this critical $d^{-1/2}$ factor. Similarly, we can also construct a counterexample where $\theta$ and $\tilde\theta$ only differ in the first row of some $W^l$.

E.2 OUR FIXES

Regarding Problem 1, we can still use an $O(1)$ learning rate for the first layer in the NTK formulation given $\|x\|_2 \le 1$. This is because for the first layer, we have
$$\nabla_{W^0} f(x) = \frac{1}{\sqrt{d_0}} x^0 (\alpha^1)^\top = \frac{1}{\sqrt{d_0}} x (\alpha^1)^\top.$$
For all $l \ge 1$, we have $\|x^l\|_2 = O(d^{1/2})$; however, for $l = 0$, we instead have $\|x^0\|_2 = O(1)$. Thus, we can prove that the norm of $\nabla_{W^0} f(x)$ has the same order as the gradient with respect to any other layer, so there is no need to use a smaller learning rate for the first layer.

Regarding Problem 2, in our formulation (8) and initialization (9), the initialization of the last layer of the NTK formulation is changed from the Gaussian initialization to the zero initialization. This allows us to prove the approximation theorem with rate $O(d^{r-1/2})$ for any $r \in (0, 1/2)$, though we cannot really prove the $O(d^{-1/2})$ bound as originally claimed in (164). So this is how we solve Problem 2.

One caveat of changing the initialization to the zero initialization is whether we can still safely assume that $\lambda_{\min} > 0$, where $\lambda_{\min}$ is the smallest eigenvalue of $\Theta$, the kernel matrix of our new formulation. The answer is yes. In fact, in our Proposition 3 we proved that $\Theta$ is non-degenerate (which means that $\Theta(x, x')$ still depends on $x$ and $x'$), and under the overparameterized setting where $d \gg n$, chances are high that $\Theta$ is full-rank. Hence, we can still assume that $\lambda_{\min} > 0$.

As a final remark, one key reason why we need to initialize $W^L$ as zero is that the dimension of the output space (i.e. the dimension of $h^{L+1}$) is finite, and in our case it is 1. Suppose we allowed the dimension of $h^{L+1}$ to be $d$, going to infinity; then, using the same proof techniques, for the NTK formulation we can prove that $\sup_t \|h^{L+1(t)} - h^{L+1(t)}_{\mathrm{lin}}\|_2 \le C$, i.e. the gap between two vectors of infinite dimension is always bounded by a finite constant. This is the approximation theorem we need for the infinite-dimensional output space.
However, when the dimension of the output space is finite, $\sup_t \|h^{L+1(t)} - h^{L+1(t)}_{\mathrm{lin}}\|_2 \le C$ no longer suffices, so we need to decrease the order of the norm of $W^L$ in order to obtain a smaller upper bound.
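The learning-rate correspondence in (167) can be illustrated on a single linear layer: gradient descent on the NTK parameterization $f(x) = \langle w, x \rangle / \sqrt{d}$ with rate $\eta$ traces out the same function as gradient descent on the standard parameterization $f(x) = \langle v, x \rangle$ (with $v = w/\sqrt{d}$) using rate $\eta / d$. The width, data, and $\eta$ below are illustrative choices:

```python
import numpy as np

# One-layer check of the NTK/standard learning-rate correspondence:
# NTK step with rate eta  ==  standard step with rate eta/d, function-wise.
rng = np.random.default_rng(4)
d, eta = 64, 0.3
x = rng.standard_normal(d); y = 1.7
w = rng.standard_normal(d)        # NTK parameters
v = w / np.sqrt(d)                # equivalent standard parameters

for _ in range(5):
    # NTK: f = <w,x>/sqrt(d); squared-loss grad wrt w is (f - y) x / sqrt(d)
    f = w @ x / np.sqrt(d)
    w = w - eta * (f - y) * x / np.sqrt(d)
    # standard: f = <v,x>; grad wrt v is (f - y) x, but with rate eta/d
    f2 = v @ x
    v = v - (eta / d) * (f2 - y) * x

assert np.allclose(w / np.sqrt(d), v, atol=1e-12)  # identical functions
```

This is exactly the per-layer rescaling that, applied layer by layer, yields $\eta^l_W = (d_l / d_{\max})\eta$ — and hence the vanishing effective learning rate on $W^0$ that Problem 1 points out.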



Our results can be easily extended to the multi-class scenario (see Appendix B).
For distributions $P$ and $Q$, "$Q$ is absolutely continuous with respect to $P$", written $Q \ll P$, means that for any event $A$, $P(A) = 0$ implies $Q(A) = 0$.
$f$ is Lipschitz if there exists a constant $L > 0$ such that for any $x_1, x_2$, $|f(x_1) - f(x_2)| \le L \|x_1 - x_2\|_2$.
Non-degenerate means that $\Theta(x, x')$ depends on $x$ and $x'$ and is not a constant.
For ease of understanding, later we will write this condition as "with a sufficiently small learning rate".
Note that Lemma 13 only depends on the network structure and does not depend on the update rule, so we can use this lemma here.



Figure 1: Experimental results of ERM, importance weighting (IW) and Group DRO (GDRO) with the squared loss on six MNIST images with a linear model. All norms are L 2 norms.

Figure 2: Experimental results of ERM, importance weighting (IW) and Group DRO (GDRO) with L 2 regularization with the squared loss. Left two: µ = 0.1; Right two: µ = 10.

minimized the class-based cross entropy on the target domain while keeping the source and target classifiers close with a residual block, and Motiian et al. (2017) adopted a similarity penalty to keep samples from different classes away from each other.

A.3 IMPLICIT BIAS UNDER THE OVERPARAMETERIZED SETTING

For overparameterized models, there can be many model parameters that all minimize the training loss. In such cases, it is of interest to study the implicit bias of specific optimization algorithms such as gradient descent, i.e., to which minimizer the model parameters will converge (Du et al., 2019; Allen-Zhu et al., 2019). Our results use the NTK formulation of wide neural networks (Jacot et al., 2018).

Figure 3: Experimental results of ERM, importance weighting (IW) and Group DRO (GDRO) with the logistic loss and the polynomially-tailed loss. First row: logistic loss; second row: polynomially-tailed loss. All norms are $L^2$ norms. $\bar\theta$ is the unit vector in the direction of $\theta$.


Now we show how this modification solves Problem 2. The main consequence of changing the initialization of the last layer is that (86) becomes different: instead of $\|W^L\|_2 \le 3\sqrt{d}$, we now have $\|W^L\|_2 \le C_0 \le 3 d^{1/4}$. In fact, for any $r \in (0, 1/2)$, we can prove that $\|W^L\|_2 \le 3 d^r$ for sufficiently large $d$; in our proof we choose $r = 1/4$. Consequently, instead of $\|\alpha^l\|_2 \le M_3$, we can now prove that $\|\alpha^l\|_2 \le M_3 d^{r-1/2}$ for all $l \le L$ by induction. So now we can prove $\|\alpha^l - \tilde\alpha^l\|_2 = O(d^{r-1/2} \|\theta - \tilde\theta\|_2)$ instead of $O(\|\theta - \tilde\theta\|_2)$, because:
• For $l < L$, we now have $\|\alpha^{l+1}\|_2 = O(d^{r-1/2})$ instead of $O(1)$, so we gain the additional $d^{r-1/2}$ factor in the bound.
• For $l = L$, although $\|\alpha^{L+1}\|_2 = 1$, note that $\|W^L\|_2$ now becomes $O(d^r)$ instead of $O(d^{1/2})$, so again we can decrease the bound by a factor of $d^{r-1/2}$.
Then, with this critical $d^{r-1/2}$, we can prove the approximation theorem in the form
$$\sup_{t \ge 0} \left|f^{(t)}(x) - f^{(t)}_{\mathrm{lin}}(x)\right| \le C d^{r-1/2}. \tag{175}$$

studied another type of group fairness called Rawlsian max-min fairness (Rawls, 2001), which does not require equal performance but rather requires high performance on the worst-off group. The subpopulation shift problem we study in this paper is naturally connected to Rawlsian max-min fairness. A large body of recent work has studied how to improve this worst-group performance (Duchi & Namkoong, 2018; Oren et al., 2019; Liu et al., 2021; Zhai et al., 2021a). Recent work, however, observes that these approaches, when used with modern overparameterized models, easily overfit (Sagawa et al., 2020a;b). Apart from group fairness, there are also other notions of fairness, such as individual fairness (Dwork et al., 2012; Zemel et al., 2013) and counterfactual fairness (Kusner et al., 2017), which we do not study in this work.


ACKNOWLEDGMENTS

We acknowledge the support of NSF via OAC-1934584, IIS-1909816, DARPA via HR00112020006, and ARL. 


Then, we prove that $\lim_{R \to \infty} \frac{\theta_R}{R}$ exists and is equal to $\hat\theta_{\mathrm{MM}}$. For the logistic loss, it is easy to show that for any unit vector $\bar\theta \neq \hat\theta_{\mathrm{MM}}$, there exist an $R(\bar\theta) > 0$ and a $\delta(\bar\theta) > 0$ such that $F(R \theta) > F(R \hat\theta_{\mathrm{MM}})$ for all $R \ge R(\bar\theta)$ and $\theta \in B(\bar\theta, \delta(\bar\theta))$. Let $S = \{\theta : \|\theta\|_2 = 1\}$. For any $\epsilon > 0$, $S - B(\hat\theta_{\mathrm{MM}}, \epsilon)$ is a compact set, and for any $\bar\theta \in S - B(\hat\theta_{\mathrm{MM}}, \epsilon)$ there exist $R(\bar\theta)$ and $\delta(\bar\theta)$ as defined above. Thus, the balls $B(\bar\theta, \delta(\bar\theta))$ cover $S - B(\hat\theta_{\mathrm{MM}}, \epsilon)$, so by compactness there is a finite subcover and hence a finite $R_\epsilon$ such that $F(R\theta) > F(R\hat\theta_{\mathrm{MM}})$ for all $R \ge R_\epsilon$ and all $\theta \in S - B(\hat\theta_{\mathrm{MM}}, \epsilon)$. This means $\frac{\theta_R}{R} \in B(\hat\theta_{\mathrm{MM}}, \epsilon)$ for all $R \ge R_\epsilon$; since $\epsilon$ is arbitrary, $\lim_{R \to \infty} \frac{\theta_R}{R}$ exists and is equal to $\hat\theta_{\mathrm{MM}}$. Therefore, by Theorem 8, any GRW satisfying Assumption 1 makes a linear model converge to the max-margin classifier under the logistic loss.

D.6 PROOF OF THEOREM 9

We first consider the regularized linearized neural network $f^{(t)}_{\mathrm{linreg}}$. Since by Proposition 11 $f^{(0)}(x)$ is sampled from a zero-mean Gaussian process, there exists a constant $M > 0$ such that $|f^{(0)}(x_i)| < M$ for all $i$ with high probability. When the linearized neural network is trained by a GRW satisfying Assumption 1 with regularization, since this is convex optimization and the objective function is smooth, we can prove that it converges with a sufficiently small learning rate as $t \to \infty$. First, we derive the lower bound of $R$. By Lemma 16, with a sufficiently large $d$, with high probability $\hat{\mathcal{R}}(\hat f_{\mathrm{linreg}}) < 2\epsilon$, and we then use the convexity of $\ell$. By the definition of $\hat\theta_{\mathrm{MM}}$, there exists $j$ such that $y_j \langle \bar\theta, \nabla_\theta f^{(0)}(x_j) \rangle \le \gamma$, which yields the desired bound on $R$.

In the NTK formulation, $\theta$ is $\sqrt{d}$ times larger, while the Jacobian $J(\theta)$ is $\sqrt{d}$ times smaller. This is also why here we have $\theta, \tilde\theta \in B(\theta^{(0)}, C_0)$ instead of $\theta, \tilde\theta \in B(\theta^{(0)}, C_0 d^{-1/2})$ as for the standard formulation. Therefore, equivalently they were claiming (168) for the NTK formulation. However, their proof of (168) is incorrect. Specifically, the right-hand side of their inequality (S86) is incorrect. Using the notations of our Appendix D.3.3, their (S86) essentially claimed that
$$\|\alpha^l - \tilde\alpha^l\|_2 \le \frac{M}{\sqrt{d}} \|\theta - \tilde\theta\|_2 \tag{169}$$
for any $\theta, \tilde\theta \in B(\theta^{(0)}, C_0)$, where $\alpha^l = \nabla_{h^l} h^{L+1}$ and $\tilde\alpha^l$ is the same gradient for the second model. Note that their (S86) does not have the $\sqrt{d}$ in the denominator which appears in (169); this is because for their standard formulation, $\theta$ is $\sqrt{d}$ times smaller than in the original NTK formulation, while $\|\alpha^l\|_2$ has the same order in the two formulations because all $h^l$ are the same. However, it is actually impossible to prove (169). Consider the following counterexample. Since $\theta$ and $\tilde\theta$ are arbitrarily chosen, we can choose them such that they only differ in $b^l_1$ for some $l$. We can see that $h^{l+1}$ and $\tilde h^{l+1}$ only differ in the first element. Moreover, we have $W^{l+1} = \tilde W^{l+1}$, so we can lower bound $\|\alpha^{l+1} - \tilde\alpha^{l+1}\|_2$ by the difference of two terms. The first term is equal to the contribution of $|\dot\sigma(h^{l+1}_1) - \dot\sigma(\tilde h^{l+1}_1)|$ through $W^{l+1}_1$, the first row of $W^{l+1}$. We know that $\|W^{l+1}_1\|_2 = \Theta(\sqrt{d})$ with high probability, as its elements are sampled from $N(0, 1)$, and in their (S85) they claimed that $\|\alpha^{l+2}\|_2 = O(1)$, which is true. In addition, they assumed that $\dot\sigma$ is Lipschitz; hence the first term can be of order $\|\theta - \tilde\theta\|_2$. On the other hand, suppose that claim (169) were true; then $\|\alpha^{l+2} - \tilde\alpha^{l+2}\|_2 \le \frac{M}{\sqrt{d}} \|\theta - \tilde\theta\|_2$, and the second term on the right-hand side is $O(d^{-1/2} \|\theta - \tilde\theta\|_2)$, because $\|W^{l+1}\|_2 = O(\sqrt{d})$ and $\dot\sigma(x)$ is bounded by a constant as $\sigma$ is Lipschitz. Thus, for a very large $d$, the second term is an infinitesimal compared to the first term, so we can only prove a bound of order $\|\theta - \tilde\theta\|_2$, which differs from (169) because it lacks the critical $d^{-1/2}$ and thus leads to a contradiction. Hence, we cannot prove (169) with the $d^{-1/2}$ factor, and consequently we cannot prove (168).

