FEATURE SELECTION AND LOW TEST ERROR IN SHALLOW LOW-ROTATION RELU NETWORKS

Abstract

This work establishes low test error of gradient flow (GF) and stochastic gradient descent (SGD) on two-layer ReLU networks with standard initialization scale, in three regimes where key sets of weights rotate little (either naturally due to GF and SGD, or due to an artificial constraint), making use of margins as the core analysis technique. The first regime is near initialization, specifically until the weights have moved by $O(\sqrt{m})$, where $m$ denotes the network width; this is in sharp contrast to the $O(1)$ weight motion allowed by the Neural Tangent Kernel (NTK). Here it is shown that GF and SGD need only a network width and number of samples inversely proportional to the NTK margin, and moreover that GF attains at least the NTK margin itself and in particular escapes bad KKT points of the margin objective, whereas prior work could only establish nondecreasing but arbitrarily small margins. The second regime is the Neural Collapse (NC) setting, where data lie in well-separated groups, and the sample complexity scales with the number of groups; here the contribution over prior work is an analysis of the entire GF trajectory from initialization. Lastly, if the inner-layer weights are constrained to change in norm only and cannot rotate, then GF with large widths achieves globally maximal margins, and its sample complexity scales with their inverse; this is in contrast to prior work, which required infinite width and a tricky dual convergence assumption.

1. INTRODUCTION

A key promise of deep learning is automatic feature learning: standard gradient methods are able to adjust network parameters so that lower layers become meaningful feature extractors, which in turn implies low sample complexity. As a running illustrative (albeit technical) example throughout this work, in the 2-sparse parity problem (cf. Figure 1), networks near initialization require $d^2/\epsilon$ samples to achieve $\epsilon$ test error, whereas powerful optimization techniques are able to learn more compact networks which need only $d/\epsilon$ samples (Wei et al., 2018). It is not clear how to establish this improved feature learning ability with a standard gradient-based optimization method; for example, despite the incredible success of the Neural Tangent Kernel (NTK) in proving various training and test error guarantees (Jacot et al., 2018; Du et al., 2018b; Allen-Zhu et al., 2018; Zou et al., 2018; Arora et al., 2019; Li & Liang, 2018; Ji & Telgarsky, 2020b; Oymak & Soltanolkotabi, 2019), ultimately the NTK corresponds to learning with frozen initial random features. The goal of this work is to establish low test error from random initialization in an intermediate regime where parameters of individual nodes do not rotate much, but where their change in norm leads to selection of certain pre-existing features. This perspective is sufficient to establish the best known sample complexities from random initialization in a variety of scenarios, for instance matching the $d^2/\epsilon$ within-kernel sample complexity with a computationally efficient stochastic gradient descent (SGD) method, and the beyond-kernel $d/\epsilon$ sample complexity with an inefficient gradient flow (GF) method. The different results are tied together through their analyses, which establish not merely low training error but large margins, a classical approach to low sample complexity within overparameterized models (Bartlett, 1996).

These results use standard gradient methods from standard initialization, in contrast to existing works in feature learning, which adjust the optimization method in some way (Shi et al., 2022; Wei et al., 2018), most commonly by training the inner layer for only one iteration (Daniely & Malach, 2020; Abbe et al., 2022; Barak et al., 2022; Damian et al., 2022), and typically not beating the within-kernel $d^2/\epsilon$ sample complexity on the 2-sparse parity problem (cf. Table 1).

Contributions. There are four high-level contributions of this work. The first two consider networks of reasonable width (e.g., $O(d^2)$ for 2-sparse parity), and are the more tractable of the four. In these results, the network parameters can move up to $O(\sqrt{m})$, where $m$ is the width of the network; this is in sharp contrast to the NTK, where weights can only move by $O(1)$. The performance of these first two results is measured in terms of the NTK margin $\gamma_{ntk}$, a quantity formally defined in Assumption 1.2. These first two contributions are as follows.

1. Non-trivial margin KKT points. Prior work established that features converge in a strong sense: features and parameters converge to a KKT point of a natural margin objective (cf. Section 1.1, (Lyu & Li, 2019; Ji & Telgarsky, 2020a)). Those works, however, left open the possibility that the limiting KKT point is arbitrarily bad; instead, Theorem 2.1 guarantees that the limiting GF margin is at least $\gamma_{ntk}/4096$, where $\gamma_{ntk}$ is a distribution-dependent constant.

2. Simultaneous low test error and low computational complexity. Replacing GF with SGD in the preceding approach leads to a computationally efficient method. Applying the resulting guarantees in Theorem 2.3 to the 2-sparse parity problem yields, as detailed in Table 1, a method which saves a factor $d^8$ against prior work with sample complexity $d^2/\epsilon$, and a factor $1/\epsilon$ in computation against work with sample complexity $d^4/\epsilon^2$. Moreover, Theorem 2.3 guarantees that the first gradient step moves parameters by $\sqrt{m}$ and formally exits the NTK.

The second two high-level contributions require intractable widths (e.g., $2^d$), but are able to achieve much better global margins $\gamma_{gl}$, which, as detailed in Sections 1.1 and 1.2, were previously only possible under strong assumptions or unrealistic algorithmic modifications.

3. Neural collapse. Theorem 3.2 establishes low sample complexity in the neural collapse (NC) regime (Papyan et al., 2020), where data are organized in well-separated clusters of common label. By contrast, prior work did not analyze gradient methods from initialization, but instead the relationship between various optimality conditions (Papyan et al., 2020; Yaras et al., 2022; Thrampoulidis et al., 2022). The method of proof is to establish global margin maximization of GF; by contrast, for any type of data, this was previously only proved in the literature under strong assumptions and with modified algorithms (Wei et al., 2018; Chizat & Bach, 2020; Lyu et al., 2021).

4. Global margin maximization for rotation-free networks. To investigate what could be possible, Theorem 3.3 establishes global margin maximization with GF under a restriction that the inner weights can only change in norm and cannot rotate; this analysis suffices to achieve $d/\epsilon$ sample complexity on 2-sparse parity, as in Table 1, and the low-rotation assumption is backed by preliminary empirical evidence in Figure 2.

This introduction concludes with notation and related work, Section 2 collects the KKT point and low computation guarantees, Section 3 collects the global margin guarantees, Section 4 provides concluding remarks and open problems, and the appendices contain full proofs and additional technical discussion.
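The running 2-sparse parity example can be made concrete with a small data generator. The following is a minimal sketch under one standard convention (hypothetical function name; inputs on the $\pm 1$ hypercube, rescaled so that $\|x\| \le 1$ as assumed in Section 1.1):

```python
import numpy as np

def sparse_parity_data(n, d, support=(0, 1), seed=0):
    """2-sparse parity: x is uniform on the hypercube {-1, +1}^d, and the
    label is the product of the two coordinates in `support` (y = +1 iff
    those coordinates agree).  Inputs are rescaled by 1/sqrt(d) so that
    ||x|| <= 1, matching the norm convention used later."""
    rng = np.random.default_rng(seed)
    x = rng.choice([-1.0, 1.0], size=(n, d))
    y = x[:, support[0]] * x[:, support[1]]
    return x / np.sqrt(d), y
```

The label depends only on the two support coordinates, which is what makes the problem a feature selection task: a kernel method pays for all $d$ coordinates, whereas a learned first layer need only attend to two.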

1.1. NOTATION

Architecture and initialization. With the exception of Theorem 3.3, the architecture will be a 2-layer ReLU network of the form $x \mapsto F(x;W) = \sum_j a_j\sigma(v_j^T x) = a^T\sigma(Vx)$, where $\sigma(z) = \max\{0,z\}$ is the ReLU, and where $a\in\mathbb{R}^m$ and $V\in\mathbb{R}^{m\times d}$ have initialization scale roughly matching pytorch defaults: $a \sim \mathcal{N}_m/m^{1/4}$ ($m$ iid Gaussians with variance $1/\sqrt{m}$) and $V\sim \mathcal{N}_{m\times d}/\sqrt{d\sqrt{m}}$ ($m\times d$ iid Gaussians with variance $1/(d\sqrt{m})$); in contrast with pytorch, the layers are approximately balanced. These parameters $(a, V)$ will be collected into a tuple $W = (a,V)\in\mathbb{R}^m\times\mathbb{R}^{m\times d} \equiv \mathbb{R}^{m\times(d+1)}$, and for convenience per-node tuples $w_j = (a_j, v_j)\in\mathbb{R}\times\mathbb{R}^d \equiv \mathbb{R}^{d+1}$ will often be used as well. Given a pair $(x,y)$ with $x\in\mathbb{R}^d$ and $y\in\{\pm 1\}$, the prediction or unnormalized margin mapping is $p(x,y;W) = yF(x;W) = ya^T\sigma(Vx)$; when examples $((x_i,y_i))_{i=1}^n$ are available, a simplified notation $p_i(W) := p(x_i,y_i;W)$ is often used, and moreover define a single-node variant $p_i(w_j) := y_i a_j \sigma(v_j^T x_i)$. Throughout this work, $\|x\|\le 1$, and unmarked norms are Frobenius norms.

SGD and GF. The loss function $\ell$ will be either the exponential loss $\ell_{\exp}(z) := \exp(-z)$ or the logistic loss $\ell_{\log}(z) := \ln(1+\exp(-z))$; the corresponding empirical risk $\mathcal{R}$ is $\mathcal{R}(p(W)) := \frac1n\sum_{i=1}^n \ell(p_i(W))$, which uses $p(W) := (p_1(W),\ldots,p_n(W))\in\mathbb{R}^n$. The two descent methods are
$$W_{i+1} := W_i - \eta\,\hat\partial\,\ell(p_i(W_i)), \qquad \text{stochastic gradient descent (SGD)}, \tag{1.1}$$
$$\dot W_t := \tfrac{\mathrm{d}}{\mathrm{d}t} W_t = -\bar\partial\,\mathcal{R}(p(W_t)), \qquad \text{gradient flow (GF)}, \tag{1.2}$$
where $\hat\partial$ and $\bar\partial$ are appropriate generalizations of subgradients for the present nonsmooth nonconvex setting, detailed as follows. For SGD, $\hat\partial$ will denote any valid element of the Clarke differential (i.e., a measurable selection); for example, $\hat\partial F(x;W) = \big(\sigma(Vx),\ \sum_j a_j \sigma'(v_j^T x) e_j x^T\big)$, where $e_j$ denotes the $j$th standard basis vector, and $\sigma'(v_j^T x_i)\in[0,1]$ is chosen in some consistent and measurable way, for instance as chosen by pytorch.
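The initialization scale above can be sketched as follows (a minimal numpy sketch with illustrative function names, not from any library); note the approximate balance $\mathbb{E}\|a\|^2 = \mathbb{E}\|V\|^2 = \sqrt{m}$:

```python
import numpy as np

def init_two_layer(m, d, seed=0):
    """Initialize W = (a, V): entries of a have variance 1/sqrt(m), entries
    of V have variance 1/(d*sqrt(m)), so the two layers are approximately
    balanced, with E||a||^2 = E||V||_F^2 = sqrt(m)."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal(m) / m**0.25
    V = rng.standard_normal((m, d)) / np.sqrt(d * np.sqrt(m))
    return a, V

def F(x, a, V):
    """F(x; W) = a^T relu(V x)."""
    return a @ np.maximum(V @ x, 0.0)
```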
For GF, $\bar\partial$ will denote the unique minimum-norm element of the Clarke differential; typically, GF is defined as a differential inclusion, which agrees with this minimum-norm Clarke flow almost everywhere, but here the minimum-norm element is equivalently used to define the flow. Details of Clarke differentials and corresponding chain rules are deferred to the now-extensive literature on their use in margin analyses (Lyu & Li, 2019; Ji & Telgarsky, 2020a; Lyu et al., 2021). To reduce clutter, the time indices $t$ (written as $W_t$ or $W(t)$) will often be dropped.
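A single SGD step (1.1) can be sketched as follows, a minimal hypothetical implementation using the Clarke selection $\sigma'(z) = 1[z > 0]$ and the logistic loss; the gradient of $p(x,y;W) = y a^T\sigma(Vx)$ is computed by hand:

```python
import numpy as np

def logistic_grad(z):
    """ell'_log(z) for ell_log(z) = ln(1 + exp(-z)); always negative."""
    return -1.0 / (1.0 + np.exp(z))

def sgd_step(x, y, a, V, eta):
    """One step of (1.1): W <- W - eta * d/dW ell(p(x, y; W)),
    using the Clarke selection sigma'(z) = 1[z > 0]."""
    z = V @ x                                 # pre-activations v_j^T x
    h = np.maximum(z, 0.0)                    # sigma(V x)
    g = logistic_grad(y * (a @ h))            # ell'(p(x, y; W))
    grad_a = y * h                            # d p / d a
    grad_V = np.outer(y * a * (z > 0.0), x)   # d p / d V
    return a - eta * g * grad_a, V - eta * g * grad_V
```

With a small step size, one such step decreases the sampled loss whenever the gradient is nonzero.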

Margins.

To develop the margin notion, first note that $F$ and $p_i$ are 2-homogeneous in $W$, meaning $F(x; cW) = c\,a^T\sigma(cVx) = c^2 F(x;W)$ for any $x\in\mathbb{R}^d$ and $c \ge 0$ (and $p_i(cW) = c^2 p_i(W)$). It follows that $F(x;W) = \|W\|^2 F(x; W/\|W\|)$, and thus $F$ and $p_i$ scale quadratically in $\|W\|$, and it makes sense to define a normalized prediction mapping $\bar p_i$ and margin $\gamma$ as
$$\bar p_i(W) := \frac{p_i(W)}{\|W\|^2}, \qquad \gamma(W) := \min_i \frac{p_i(W)}{\|W\|^2} = \min_i \bar p_i(W).$$
Due to nonsmoothness, $\gamma$ can be hard to work with; thus, following Lyu & Li (2019), define the smoothed margin $\tilde\gamma$ and the normalized smoothed margin $\bar{\tilde\gamma}$ as
$$\tilde\gamma(W) := \ell^{-1}\big(n\mathcal{R}(p(W))\big) = \ell^{-1}\Big(\sum_i \ell(p_i(W))\Big), \qquad \bar{\tilde\gamma}(W) := \frac{\tilde\gamma(W)}{\|W\|^2},$$
where a key result is that $\bar{\tilde\gamma}$ is eventually nondecreasing (Lyu & Li, 2019). These quantities may look complicated and abstract, but note for $\ell_{\exp}$ that $\tilde\gamma(W) = -\ln\sum_i \exp(-p_i(W))$. An interesting technical consideration is that normalization by $\|W\|^2$ can be replaced by $\|a\|\cdot\|V\|$ or $\sum_j \|a_j v_j\|$, as appears throughout the proofs of Theorems 2.1, 3.2 and 3.3.

Corresponding to these definitions, the global max margin assumption is that some shallow network can achieve a good margin almost surely over the distribution.

Assumption 1.1. There exist $\gamma_{gl} > 0$ and parameters $((\alpha_k, \beta_k))_{k=1}^r$ with $\|\alpha\|_1 \le 1$ and $\|\beta_k\|_2 = 1$ so that almost surely over the draw of any pair $(x,y)$, then $y\sum_k \alpha_k \sigma(\beta_k^T x) \ge \gamma_{gl}$. ♢

The $\ell_1$ norm on $\alpha$ is due to 2-homogeneity: for any 2-homogeneous generalized activation $\phi$ and parameters $(w_j)_{j=1}^m$ organized into a matrix $W\in\mathbb{R}^{m\times(d+1)}$, then $\sum_j \phi(w_j^T x)/\|W\|^2 = \sum_j (\|w_j\|/\|W\|)^2 \phi((w_j/\|w_j\|)^T x)$, and in particular $\sum_j \|w_j\|^2/\|W\|^2 = 1$. This use of the $\ell_1$ norm is standard in works studying global margin maximization; see for instance (Chizat & Bach, 2020, Proposition 12, optimality conditions).

[Figure 1: 2-sparse parity data, with red and blue circles respectively denoting negative and positive examples. Red paths correspond to trajectories $|a_j|v_j$ across time with $a_j < 0$, whereas blue paths have $a_j > 0$.]

Near initialization, the features have not changed much, and it is therefore reasonable to consider a second margin definition as linear predictors on top of the random initial features.

Assumption 1.2. There exist $\gamma_{ntk} > 0$ and a weight mapping $\theta:\mathbb{R}^{d+1}\to\mathbb{R}^{d+1}$ with $\theta(w) = 0$ whenever $\|w\|\ge 2$ and $\|\theta(w)\|\le 2$ otherwise, so that almost surely over the draw of $(x,y)$, then $\mathbb{E}_{w\sim\mathcal{N}_\theta}\langle \theta(w), \bar\partial_w p(x,y;w)\rangle \ge \gamma_{ntk}$, where $w = (a,v)\sim\mathcal{N}_\theta$ means $a\sim\mathcal{N}$ and $v\sim\mathcal{N}_d/\sqrt{d}$, and $p(x,y;w) = p(x,y;(a,v)) = ya\sigma(v^T x)$ as before. ♢

Assumption 1.2 may seem overly technical, but by taking an expectation over initial weights, it can not only be seen as an infinite-width linear predictor over the initial features, but moreover it is a deterministic condition and not a random variable depending on the sampled weights. This assumption was originally presented by Nitanda & Suzuki (2019) and later used in (Ji & Telgarsky, 2020b); a follow-up work with similar proof techniques made the choice of using a finite-width assumption which is a random variable (Chen et al., 2019). The 2-sparse parity problem, formally defined in Proposition 1.3 and depicted in Figure 1, allows significantly different estimates for Assumptions 1.1 and 1.2.

Lastly, following (Lyu & Li, 2019) but simplified for the 2-homogeneous case, given examples $((x_i,y_i))_{i=1}^n$, parameters $W = (a,V)$ are a KKT point if there exist Lagrange multipliers $(\lambda_1,\ldots,\lambda_n)$ with $\lambda_i \ge 0$ and $\lambda_i > 0$ only if $p_i(W) = 1$, and moreover, for every $j$, then $a_j = \sum_i \lambda_i y_i \sigma(v_j^T x_i)$ and $v_j \in a_j \sum_i \lambda_i y_i \partial_v \sigma(v_j^T x_i)$, where $\partial_v$ denotes the Clarke differential with respect to $v_j$. Call $W$ a KKT direction if there exists a scalar $r > 0$ so that $rW$ is a KKT point. Lastly, the margin of a KKT point $W$ is $1/\|W\|^2$, and the margin of a KKT direction is the margin of the corresponding KKT point; further details on these notions are deferred to Appendix B.1.
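These margin quantities can be sanity-checked numerically; the following sketch (illustrative helper names) verifies 2-homogeneity, the scale invariance of $\gamma$, and the inequality $\tilde\gamma(W) \le \min_i p_i(W)$ for the exponential loss:

```python
import numpy as np

def preds(X, y, a, V):
    """p_i(W) = y_i a^T relu(V x_i)."""
    return y * (np.maximum(X @ V.T, 0.0) @ a)

def normalized_margin(X, y, a, V):
    """gamma(W) = min_i p_i(W) / ||W||^2, Frobenius norm of W = (a, V)."""
    return preds(X, y, a, V).min() / (np.sum(a**2) + np.sum(V**2))

def smoothed_margin_exp(X, y, a, V):
    """For ell_exp: tilde-gamma(W) = -ln sum_i exp(-p_i(W))."""
    return -np.log(np.sum(np.exp(-preds(X, y, a, V))))
```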

1.2. FURTHER RELATED WORK

Margin maximization. The concept and analytical use of margins in machine learning originated in the classical perceptron convergence analysis of Novikoff (1962). The SGD analysis in Theorem 2.3 is itself based on the perceptron proof; similar perceptron-based proofs appeared before (Ji & Telgarsky, 2020b; Chen et al., 2019), however they required width $1/\gamma_{ntk}^8$, unlike the $1/\gamma_{ntk}^2$ here, and moreover those proofs were in the NTK regime, whereas the proof here is not. Works focusing on the implicit margin maximization or implicit bias of descent methods are more recent. Early works on the coordinate descent side are (Schapire et al., 1997; Zhang & Yu, 2005; Telgarsky, 2013); the proof here of Lemma C.4 roughly follows the proof scheme in (Telgarsky, 2013). More recently, margin maximization properties of gradient descent were established, first showing global margin maximization in linear models (Soudry et al., 2017; Ji & Telgarsky, 2018b), then showing nondecreasing smoothed margins of general homogeneous networks (including multi-layer ReLU networks) (Lyu & Li, 2019), and the aforementioned global margin maximization result for 2-layer networks under dual convergence and infinite width (Chizat & Bach, 2020). The potential functions used here in Theorems 3.2 and 3.3 use ideas from (Soudry et al., 2017; Lyu & Li, 2019; Chizat & Bach, 2020), but also the shallow linear and deep linear proofs of Ji & Telgarsky (2019; 2018a).

Feature learning. There are many works in feature learning; a few which also carry explicit guarantees on 2-sparse parity are summarized in Table 1. An early work with high relevance to the present work is (Wei et al., 2018), which, in addition to establishing that the NTK requires $\Omega(d^2/\epsilon)$ samples whereas $O(d/\epsilon)$ suffice for the global maximum margin solution, also provided a noisy Wasserstein Flow (WF) analysis which achieved the maximum margin solution, albeit using noise, infinite width, and continuous time to aid in local search. The global maximum margin work of Chizat & Bach (2020) was mentioned before, and will be discussed in Section 3. The work of Barak et al. (2022) uses a two-phase algorithm: the first step has a large minibatch and effectively learns the support of the parity in an unsupervised manner, and thereafter only the second layer is trained, a convex problem which is able to identify the signs within the parity; as in Table 1, this work stands alone in terms of the narrow width it can handle. The work of (Abbe et al., 2022) uses a similar two-phase approach, and while it cannot learn precisely the parity, it can learn an interesting class of "staircase" functions, and presents many valuable proof techniques. Another work which operates in two phases and can learn an interesting class of functions is the recent work of (Damian et al., 2022); while it cannot handle 2-sparse parity explicitly, it can handle the Hermite polynomial analog (a product of two Hermite polynomials). Other interesting feature learning works are (Shi et al., 2022; Bai & Lee, 2019; Allen-Zhu & Li, 2020).

2. LOW TEST ERROR WITH MODEST-WIDTH NETWORKS

This section states the aforementioned results for networks of width $\Omega(1/\gamma_{ntk}^2)$, which can be small: as provided by Proposition 1.3, this width is $\Omega(d^2)$ for the 2-sparse parity problem. This section will first give guarantees for GF, establishing via Theorem 2.1 and Corollary 2.2 that non-trivial KKT points are achieved. Similar ideas will then be used to give a fully tractable SGD approach in Theorem 2.3. To start, here is the low test error and large margin guarantee for GF.

Theorem 2.1. Suppose the data distribution satisfies Assumption 1.2 for some $\gamma_{ntk} > 0$, and the GF curve $(W_s)_{s\ge 0}$ uses $\ell\in\{\ell_{\exp},\ell_{\log}\}$ on an architecture of width $m \ge \frac{640\ln(n/\delta)}{\gamma_{ntk}^2}$. Then, with probability at least $1 - 15\delta$, there exists $t$ with $\|W_t - W_0\| = \gamma_{ntk}\sqrt{m}/32$ so that, for all $s\ge t$,
$$\frac{\tilde\gamma(W_s)}{\|a_s\|\cdot\|V_s\|} \ge \frac{\gamma_{ntk}}{2048} \qquad\text{and}\qquad \Pr[p(x,y;W_s)\le 0] \le O\Big(\frac{\ln(n)^3}{n\gamma_{ntk}^2} + \frac{\ln(1/\delta)}{n}\Big),$$
and moreover $\liminf_{s\to\infty}\gamma(W_s) = \liminf_{s\to\infty}\frac{\min_i p_i(W_s)}{2\|a_s\|\cdot\|V_s\|} \ge \frac{\gamma_{ntk}}{4096}$.

Before sketching the proof, one interesting comparison is to a leaky ReLU convergence analysis on a restricted form of linearly separable data due to Lyu et al. (2021). That work, through an extremely technical and impressive analysis, establishes convergence to a solution which is equivalent to the best linear predictor. While the work here does not recover that analysis, since $\gamma_{ntk} = \Omega(\gamma_0)$ where $\gamma_0$ is the linear separability margin (cf. Proposition B.1), the margin and sample complexity achieved here are within a constant factor of those in (Lyu et al., 2021), but via a simpler and more general analysis (dropping the additional data and initialization conditions).

The proof of Theorem 2.1 is provided in full in the appendices, but has the following key components. The main tool powering all results in this section, Lemma B.4, can be roughly stated as follows: gradients at initialization are aligned with a fixed good parameter direction $\theta\in\mathbb{R}^{m\times(d+1)}$ with $\|\theta\|\le 2$, meaning $\langle\theta, \bar\partial p_i(W_0)\rangle \ge \gamma_{ntk}\sqrt{m}/2$, and moreover nearly the same inequality holds with $W_0$ replaced by any $W\in\mathbb{R}^{m\times(d+1)}$ with $\|W - W_0\|\le \gamma_{ntk}\sqrt{m}$. This is a form of Polyak-Łojasiewicz inequality; it guides the gradient flow in a good direction, and is used in a strengthened form to obtain an empirical risk guarantee for GF (large margins and low test error will be discussed shortly). While a version of this inequality has appeared in prior work (Ji & Telgarsky, 2020b), even with adaptations to multi-layer cases (Chen et al., 2019), all prior work had a width dependence of $1/\gamma_{ntk}^8$; many careful refinements here lead to the smaller width $1/\gamma_{ntk}^2$. Overall, as in (Ji & Telgarsky, 2020b), the proof technique is based on the classical perceptron analysis, and the width requirement here matches the width needed by the perceptron with frozen initial features. The proof then continues by establishing large margins, and then by applying a large-margin generalization bound. The margin analysis, surprisingly, is a 2-homogeneous adaptation of a large-margin proof technique for coordinate descent (Telgarsky, 2013), and uses the preceding empirical risk guarantee as a warm start. The generalization analysis follows a new proof technique and may be of independent interest; it appears in full in Lemma C.6. An interesting detail in these proofs is that the margins behave better when normalized with the nonstandard choice $\|a\|\cdot\|V\|$. As discussed above, Theorem 2.1 is complemented by Corollary 2.2, which establishes that GF can sometimes escape bad KKT points.

Corollary 2.2. Let $\gamma_0\in(0,1/4)$ be given, and consider the uniform distribution on the two points $z_1 = (\gamma_0, +\sqrt{1-\gamma_0^2})$ and $z_2 = (\gamma_0, -\sqrt{1-\gamma_0^2})$ with common label $y = +1$.
With probability at least $1 - 2^{1-n}$ over an iid draw from this distribution, for any width $m$, the choice $a_j = 1$ and $v_j = (1, 0)$ for all $j$ is a KKT direction with margin $\gamma_0/2$. On the other hand, with probability at least $1-15\delta$, GF with loss $\ell_{\exp}$ or $\ell_{\log}$ on an iid sample of size $n$ from this data distribution using width $m \ge 2^{50}\ln(n/\delta)^2$ converges to a KKT direction with margin at least $2^{-27}$.

Summarizing, GF achieves at least the constant margin $2^{-27}$, whereas the provided KKT point achieves the arbitrarily small margin $\gamma_0/2$; as such, choosing any $\gamma_0 < 2^{-26}$ guarantees that GF converges to a nontrivial KKT point. While this construction may seem artificial, it is a simplified instance of the neural collapse constructions in Section 3, and by contrast with the results there, is achievable with reasonably small widths.

To close this section, the corresponding SGD guarantee, Theorem 2.3, gives a fully tractable method, and appears in Table 1. Notably, its proof cannot handle the exponential loss, since gradient norms do not seem to concentrate. A notable characteristic is exiting the NTK: choosing the largest allowed step size $\eta := \gamma_{ntk}^2/6400$, it follows that $\|W_1 - W_0\| \ge \delta^4\gamma_{ntk}^3\sqrt{m}/51200$, whereas the NTK regime only permits $\|W_t - W_0\| = O(1)$. Of course, despite exiting the NTK, this sample complexity is still measured in terms of $\gamma_{ntk}$, suggesting many opportunities for future work. A few remarks on the proof are as follows. Interestingly, it is much shorter than the GF proof, as it mainly needs to replicate GF's empirical risk guarantee and then apply a short martingale concentration argument. A key issue, however, is the large squared gradient norm term, which is the source of the large lower bound on $\|W_1 - W_0\|$. A typical optimization analysis technique is to swallow this term by scaling the step size with $1/\sqrt{t}$ or $1/\sqrt{m}$, but here a constant step size is allowed. Instead, controlling the term is possible again using nuances of the perceptron proof technique, which controls the sum $\sum_{i<t}|\ell'(p_i(W_i))|$ appearing when these squared gradients are accumulated.
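The bad KKT direction in Corollary 2.2 can be checked numerically: with $a_j = 1$ and $v_j = (1,0)$, every node outputs $\sigma(\gamma_0) = \gamma_0$ on both points, and $\|W\|^2 = 2m$, so the normalized margin is $m\gamma_0/(2m) = \gamma_0/2$. A sketch with hypothetical values of $\gamma_0$ and $m$:

```python
import numpy as np

gamma0, m = 0.1, 8                               # hypothetical values
z1 = np.array([gamma0,  np.sqrt(1 - gamma0**2)])
z2 = np.array([gamma0, -np.sqrt(1 - gamma0**2)])
X, y = np.stack([z1, z2]), np.ones(2)            # common label y = +1

a = np.ones(m)                                   # a_j = 1
V = np.tile([1.0, 0.0], (m, 1))                  # v_j = (1, 0)

p = y * (np.maximum(X @ V.T, 0.0) @ a)           # p_i(W) = m * gamma0 on both points
margin = p.min() / (np.sum(a**2) + np.sum(V**2)) # = gamma0 / 2, arbitrarily small
```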

3. LOWER TEST ERROR WITH LARGE-WIDTH NETWORKS

This section provides bounds which are more ambitious in terms of test error, but pay a big price: the network widths will be exponentially large, and either the data or the network architecture will have further conditions. Still, these settings will both be able to achieve globally maximal margins, and for instance lead to the improved d/ϵ sample complexity in Table 1 .

3.1. NEURAL COLLAPSE (NC)

The Neural Collapse (NC) setting partitions data into well-separated groups (Papyan et al., 2020); these groups form narrow cones which meet at obtuse angles.

Assumption 3.1. There exist $((\alpha_k,\beta_k))_{k=1}^r$ with $\|\beta_k\| = 1$ and $\alpha_k\in\{\pm 1/r\}$, and $\gamma_{nc} > 0$ and $\epsilon\in(0,\gamma_{nc})$, so that almost surely for any $(x,y)$, for each $(\alpha_k,\beta_k)$ exactly one of the following holds:
• either $x$ lies in a cone around $\beta_k$, meaning $\mathrm{sgn}(\alpha_k)\beta_k^T x y \ge \gamma_{nc}$ and $\frac{\|(I-\beta_k\beta_k^T)x\|}{|\beta_k^T x|} \le \epsilon/2$;
• or $x$ is bounded away from the cone around $\beta_k$, meaning $\mathrm{sgn}(\alpha_k)\beta_k^T x y \le -\epsilon$. ♢

It follows that Assumption 3.1 implies Assumption 1.1 with margin $\gamma_{gl} \ge \gamma_{nc}/r$, but the condition is quite a bit stronger. The corresponding GF result is as follows.

Theorem 3.2. Suppose the data distribution satisfies Assumption 3.1 for some $(r, \gamma_{nc}, \epsilon)$, and consider the GF curve $(W_t)_{t\ge 0}$ with $\ell\in\{\ell_{\exp},\ell_{\log}\}$ and width $m \ge 4\big(\tfrac{2}{\epsilon}\big)^{d-1}\ln\tfrac{r}{\delta}$. Then, with probability at least $1-\delta$, it holds for all large $t$ that
$$\gamma(W_t) \ge \frac{\gamma_{nc}-\epsilon}{2r} \qquad\text{and}\qquad \Pr\big[p(x,y;W_t)\le 0\big] = O\Big(\frac{r^2\ln(n)^3}{n(\gamma_{nc}-\epsilon)^2} + \frac{\ln(1/\delta)}{n}\Big).$$

Before discussing the proof, here are a few remarks. Firstly, the standard NC literature primarily compares different optimality conditions and how they induce NC (Papyan et al., 2020; Yaras et al., 2022; Thrampoulidis et al., 2022); by contrast, Theorem 3.2 analyzes the behavior of a standard descent method on data following NC. Furthermore, Theorem 3.2 does not establish that GF necessarily converges to the NC solution, since Assumption 3.1 allows for scenarios where the globally maximal margin solution disagrees with NC. One such example is to take two points on the surface of the sphere which form an angle just beyond $\pi/2$; in this case, the globally maximal margin solution is equivalent to a single linear predictor, but Assumption 3.1 and Theorem 3.2 still apply.
The similar construction in Corollary 2.2 took those two points and pushed their angle to be just below $\pi$, which made NC the globally maximal margin solution. Overall, the relationship between NC and the behavior of GF is quite delicate.

The proof of Theorem 3.2 hinges on a potential function
$$\Phi(W_t) := \frac14\sum_k |\alpha_k| \ln\Big(\sum_j \phi_{k,j}(w_j)\,\|a_j v_j\|\Big),$$
where $\phi_{k,j}(w_j)$ is near 1 when $v_j/\|v_j\| \approx \beta_k$, and 0 otherwise. This strange-looking potential has derivative scaling with $\gamma_{nc}/r$ and a certain technical factor $Q$; meanwhile, the derivative of $\ln\|W_t\|^2$ scales roughly with $\gamma(W_t)$ and that same factor $Q$. Together, it follows by considering their difference that either $\gamma(W_t)$ must exceed $(\gamma_{nc}-\epsilon)/r$, or mass must concentrate on the NC directions. This suffices to complete the proof, though there are many technical details; for instance, without the NC condition, data can pull gradients in bad directions and this particular potential function can become negative; in other words, the NC condition reduces rotation. The use of $\ln(\cdot)$ may seem bizarre, but it causes the gradient to be self-normalizing; similar self-normalizing ideas were used throughout earlier works on margins outside the NTK (Lyu & Li, 2019; Chizat & Bach, 2020; Ji & Telgarsky, 2020a). This discussion of $\Phi$ will resume after the proof of Theorem 3.3, which uses a similar construction. One technical point of potentially independent interest is once again the use of $\|a_j v_j\|$ as a surrogate for $\|w_j\|^2$ (where $\|w_j\|^2 \ge 2\|a_j v_j\|$); this seems crucial in the proofs, was also used in the proofs of Theorem 2.1, and partially motivated the use of $|a_j|v_j$ when plotting the trajectories in Figure 1. While it is true that these quantities asymptotically balance (Du et al., 2018a), it takes quite a long time, and this more refined norm-like quantity is useful in early phases.
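Assumption 3.1 can be verified numerically on a simplified two-point instance echoing Corollary 2.2 (two unit vectors with common label at an angle just below $\pi$); the constants $\gamma_0$, $\gamma_{nc}$, $\epsilon$ below are hypothetical:

```python
import numpy as np

gamma0, gamma_nc, eps = 0.1, 1.0, 0.05          # hypothetical constants
b1 = np.array([gamma0,  np.sqrt(1 - gamma0**2)])
b2 = np.array([gamma0, -np.sqrt(1 - gamma0**2)])
betas, alphas = [b1, b2], [0.5, 0.5]            # r = 2, alpha_k in {+-1/r}

def satisfies_nc(x, y):
    """True iff (x, y) satisfies exactly one branch of Assumption 3.1
    for every (alpha_k, beta_k)."""
    ok = True
    for alpha, beta in zip(alphas, betas):
        s = np.sign(alpha) * (beta @ x) * y
        in_cone = s >= gamma_nc and \
            np.linalg.norm(x - (beta @ x) * beta) / abs(beta @ x) <= eps / 2
        away = s <= -eps
        ok = ok and (in_cone != away)           # exactly one branch holds
    return ok
```

Each point lies exactly in its own cone (residual 0) and is bounded away from the other cone, since $b_1^T b_2 = 2\gamma_0^2 - 1 < -\epsilon$.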

3.2. GLOBAL MARGIN MAXIMIZATION WITHOUT ROTATION

The final theorem will be on stylized networks where the inner layer is forced not to rotate. Specifically, the networks are of the form $x\mapsto \sum_j a_j\sigma(b_j v_j^T x)$, where $((a_j,b_j))_{j=1}^m$ are trained but the $v_j$ are fixed at initialization; the new scalar parameter $b_j$ is effectively the norm of $v_j$ (though it is allowed to be negative). As a further simplification, $a_j$ and $b_j$ are initialized to have the same norm; this initial balancing is common in many implicit bias proofs, but is impractical and constitutes a limitation to improve in future work. While these are clearly significant technical assumptions, note firstly, as in Figure 2, that low rotation seems to hold empirically, and moreover that the only other works establishing global margin maximization used either a significantly different algorithm with added gradient noise (Wei et al., 2018), or, in the case of (Chizat & Bach, 2020), relied heavily upon infinite width (requiring weights to cover the sphere for all times $t$) and also a dual convergence assumption detailed in the appendices and circumvented here.

Theorem 3.3. Suppose the data distribution satisfies Assumption 1.1 for some $\gamma_{gl} > 0$ with reference architecture $((\alpha_k,\beta_k))_{k=1}^r$. Consider the architecture $x\mapsto\sum_j a_j\sigma(b_j v_j^T x)$ where $((a_j(0), b_j(0)))_{j=1}^m$ are sampled uniformly from the two choices $\pm 1/m^{1/4}$, and $v_j(0)$ is sampled from the unit sphere (e.g., first $v_j' \sim \mathcal{N}_d$, then $v_j(0) := v_j'/\|v_j'\|$), and the width $m$ satisfies $m \ge 4\big(\tfrac{4}{\gamma_{gl}}\big)^{d-1}\ln\tfrac{r}{\delta}$. Then, with probability at least $1-\delta$, for all large $t$, GF on $((a_j,b_j))_{j=1}^m$ with loss $\ell\in\{\ell_{\exp},\ell_{\log}\}$ satisfies
$$\gamma\big((a(t),b(t))\big) \ge \frac{\gamma_{gl}}{2} \qquad\text{and}\qquad \Pr\big[p(x,y;(a(t),b(t)))\le 0\big] = O\Big(\frac{\ln(n)^3}{n\gamma_{gl}^2} + \frac{\ln(1/\delta)}{n}\Big).$$

The proof strategy of Theorem 3.3 follows a simplification of the scheme from Theorem 3.2. First, due to the large width, for each $k$ there must exist a weight $j$ with $\mathrm{sgn}(b_j)v_j \approx \beta_k$.
Second, since the inner layer cannot rotate, the weights can be reordered so that simply $\mathrm{sgn}(b_k)v_k \approx \beta_k$, allowing a simplified potential
$$\tilde\Phi(W_t) := \frac14\sum_k |\alpha_k| \ln\big(a_k^2 + b_k^2\big).$$

[Figure 2: the inner products $\langle v_j(0)/\|v_j(0)\|,\ v_j(t)/\|v_j(t)\|\rangle$ are first calculated, then treated as an empirical distribution, and their CDF is plotted. The overall trend is that as $m$ increases, rotation decreases. While this trend is consistent with the NTK, the rotations are still too large to allow an NTK-style analysis; further experimental details are in Appendix B.2.]

As mentioned after the proof of Theorem 3.2, without Assumption 3.1 it is possible for data to pull weights in bad directions; that is ruled out here via the removal of rotation, and spiritually this situation is ruled out in the proof by Chizat & Bach (2020) via their dual convergence assumption.
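The rotation-free parameterization and its balanced initialization can be sketched as follows (hypothetical helper names; only `a` and `b` would be trained, with `V` frozen):

```python
import numpy as np

def init_rotation_free(m, d, seed=0):
    """a_j(0), b_j(0) uniform on {+-1/m^{1/4}} (balanced in norm);
    v_j(0) uniform on the unit sphere and fixed thereafter."""
    rng = np.random.default_rng(seed)
    a = rng.choice([-1.0, 1.0], size=m) / m**0.25
    b = rng.choice([-1.0, 1.0], size=m) / m**0.25
    V = rng.standard_normal((m, d))
    V /= np.linalg.norm(V, axis=1, keepdims=True)
    return a, b, V

def predict(x, a, b, V):
    """x -> sum_j a_j sigma(b_j v_j^T x); the scalars b_j play the role of
    the (frozen) inner-weight norms, and may be negative."""
    return a @ np.maximum(b * (V @ x), 0.0)
```

Since only $(a, b)$ evolve, the map stays 2-homogeneous in the trained parameters, exactly as in the margin definitions of Section 1.1.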

4. CONCLUDING REMARKS AND OPEN PROBLEMS

Stated technically, this work provides a variety of settings where GF can achieve margins $\gamma_{ntk}$ and $\gamma_{gl}$ (and SGD, in one case, can achieve sample complexity and computation scaling nicely with $\gamma_{ntk}$), and whose behavior can be interpreted as GF and SGD selecting good features and achieving low test error. There are many directions for future work. Figure 2 demonstrated low rotation with 2-sparse parity and mnist; can this be proved, thereby establishing Theorem 3.3 without forcing nodes not to rotate? Looking to Table 1 for 2-sparse parity, the approaches here fail to achieve the lowest width; is there some way to achieve this with SGD and GF, perhaps even via margin analyses? Theorems 2.3 and 2.1 achieve the same sample complexity for SGD and GF, but via drastically different proofs, the GF proof being weirdly complicated; is there a way to make the two more similar? The approaches here are overly concerned with reaching a constant factor of the optimal margins; is there some way to achieve slightly worse margins with the benefit of reduced width and computation? More generally, what is the Pareto frontier of width, samples, and computation in Table 1? The margin analysis here for the logistic loss, namely Theorem 2.1, requires a long warm start phase. Does this reflect practical regimes? Specifically, does good margin maximization and feature learning occur with the logistic loss in this early phase? This issue also appears in prior linear max margin works with the logistic loss. The analyses here work best with $\sum_j \|a_j v_j\|$ in place of $\|W\|^2$; are there more natural choices, and can this choice be used in other aspects of deep learning analysis?

Firstly, in many of the proofs, it is useful to normalize parameters: define $\bar v_j := v_j/\|v_j\|$ and $\bar a_j := \mathrm{sgn}(a_j) = a_j/|a_j|$.
Furthermore, the shorthands $\ell_i(W) := \ell(p_i(W))$ and $\ell_i'(W) := \ell'(p_i(W))$ are used; since $\ell_i'$ is negative, often $|\ell_i'|$ is written. It is annoying to write $|\ell_i'|$ over and over; interestingly, however, these nonnegative derivatives can be transformed into a notion of dual variable, which will be used throughout the proofs. Concretely, define dual variables $(q_i)_{i=1}^n$ via
$$q := \nabla_p\, \ell^{-1}\Big(\sum_i \ell(p_i)\Big) = \frac{\nabla_p \sum_i \ell(p_i)}{\ell'\big(\ell^{-1}(\sum_i \ell(p_i))\big)} = \frac{\nabla_p \sum_i \ell(p_i)}{\ell'(\tilde\gamma(p))},$$
which made use of the inverse function theorem. Correspondingly define $Q := -\ell'(\tilde\gamma(p))$, whereby $-\ell_i' = q_i Q$; for the exponential loss, $Q = \sum_i \exp(-p_i)$ and $\sum_i q_i = 1$, and while these quantities are more complicated for the logistic loss, they eventually satisfy $\sum_i q_i \ge 1$ (Ji & Telgarsky, 2019, Lemma 5.4, first part, which does not depend on linear predictors). Overall, these dual variables match the usual interpretation in margin problems of corresponding to examples of high error, and also relate to the Lagrange multipliers used in the definition of a KKT point.

On the topic of KKT points, further detail on the formalism is as follows. Firstly, (Lyu & Li, 2019) provided a definition for general $L$-homogeneous models, and the version here is equivalent for the simplified choice of 2-homogeneous models of the form $x\mapsto\sum_j a_j\sigma(v_j^T x)$. Given a KKT point $W$, the complementary slackness conditions imply $\min_i p_i(W) \ge 1$, whereby $\gamma(W) = \min_i p_i(W)/\|W\|^2 \ge 1/\|W\|^2$, justifying the choice of $1/\|W\|^2$ as the margin. Lastly, given any arbitrary $W$ (not necessarily a KKT point), the optimality conditions on $a_j$ and $v_j$ hold iff they hold for any rescaling $rW$ (since the term $r$ appears on both sides), and thus $r$ can be adjusted to make the complementary slackness conditions tight, justifying the definition of a KKT direction's margin.

To close, this section will collect various estimates of $\gamma_{ntk}$ and $\gamma_{gl}$.
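For the exponential loss, these dual variables are exactly a softmax over the negated predictions; a quick numerical check (illustrative function name):

```python
import numpy as np

def duals_exp(p):
    """Dual variables for ell_exp: q = softmax(-p) and Q = sum_i exp(-p_i),
    so that -ell'(p_i) = exp(-p_i) = q_i * Q and sum_i q_i = 1."""
    e = np.exp(-p)
    Q = e.sum()        # Q = -ell'(tilde-gamma(p)) = exp(-tilde-gamma(p))
    return e / Q, Q
```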
Firstly, both function classes are universal approximators, and thus the assumption can be made to hold for any prediction problem with pure conditional probabilities (Ji et al., 2020). Next, as a warmup, note the following estimates of $\gamma_{ntk}$ and $\gamma_{gl}$ for linear predictors, with an added estimate showing the value of working with both layers in the definition of $\gamma_{ntk}$.

Proposition B.1. Suppose the data distribution is almost surely linearly separable: there exists $\|\bar u\| = 1$ and $\gamma > 0$ with $yx^\top\bar u \ge \gamma$ almost surely.
1. Choosing $\theta(a, v) := \big(0,\ \operatorname{sgn}(a)\bar u\big)\cdot 1[\|(a,v)\| \le 2]$, then Assumption 1.2 holds with $\gamma_{ntk} \ge \frac{\gamma}{32}$.
2. Choosing $\theta(a, v) := \big(\operatorname{sgn}(\bar u^\top v),\ 0\big)\cdot 1[\|(a,v)\| \le 2]$, then Assumption 1.2 holds with $\gamma_{ntk} \ge \frac{\gamma}{16\sqrt d}$.
3. Choosing $\alpha = (1/2, -1/2)$ and $\beta = (\bar u, -\bar u)$, then Assumption 1.1 holds with $\gamma_{gl} \ge \frac{\gamma}{2}$.

Proof. The proof considers the three settings separately; in each, let $(x, y)$ be a random draw, which almost surely satisfies $\bar u^\top x y \ge \gamma$.
1. To start,
$$\mathbb E_{w\sim\mathcal N}\big\langle \theta(w), \partial_w p(x,y;w)\big\rangle = \mathbb E_{(a,v)\sim\mathcal N}\,|a|\,\bar u^\top x y\,\sigma'(v^\top x)\,1[\|(a,v)\| \le 2] \ge \gamma\,\mathbb E_{(a,v)\sim\mathcal N}\,|a|\,\sigma'(v^\top x)\,1[\|(a,v)\| \le 2].$$
To control the expectation, note that with probability at least 1/2, $1/4 \le |a| \le \sqrt 2$, and thus by rotational invariance,
$$\mathbb E_{(a,v)}\,|a|\,\sigma'(v^\top x)\,1[\|(a,v)\| \le 2] \ge \frac18\,\mathbb E_v\,\sigma'(v^\top x)\,1[\|v\| \le \sqrt 2] \ge \frac18\,\mathbb E_v\,\sigma'(v_1)\,1[\|v\| \le \sqrt 2] \ge \frac1{32}.$$
2. For convenience, write $(a, v) = w$, whereby $w \sim \mathcal N_w$ means $a \sim \mathcal N_a$ and $v \sim \mathcal N_v$. With this out of the way, define an orthonormal matrix $M \in \mathbb R^{d\times d}$ whose first column is $\bar u$, whose second column is $(I - \bar u\bar u^\top)x/\|(I - \bar u\bar u^\top)x\|$, and whose remaining columns are arbitrary subject to $M$ being orthonormal, and note that $M^\top \bar u = e_1$ and $M^\top x = e_1\,\bar u^\top x + e_2 r_2$ where $r_2 := \sqrt{\|x\|^2 - (\bar u^\top x)^2}$.
Then, using rotational invariance of the Gaussian and the change of variables $v \mapsto Mv$,
$$\mathbb E_w\big\langle \theta(w), \partial p(x,y;w)\big\rangle = y\,\mathbb E_{(a,v)}\,\operatorname{sgn}(\bar u^\top v)\,\sigma(v^\top x)\,1[\|w\| \le 2] = y\,\mathbb E_{\|(a,v)\| \le 2}\,\operatorname{sgn}(\bar u^\top M v)\,\sigma(v^\top M^\top x) = \mathbb E_{\|(a,v)\| \le 2}\,y\,\operatorname{sgn}(v_1)\,\sigma(v_1\,\bar u^\top x + v_2 r_2).$$
Writing $v_1\,\bar u^\top x = y\operatorname{sgn}(v_1)\,|v_1|\,\bar u^\top x y$ (using $y^2 = 1$), and then restricting by symmetry to the event $\{y\operatorname{sgn}(v_1) = 1,\ v_2 \ge 0\}$, the four sign combinations give
$$\mathbb E_{\substack{\|(a,v)\| \le 2\\ y\operatorname{sgn}(v_1)=1,\ v_2 \ge 0}}\Big[\sigma(|v_1|\bar u^\top x y + v_2 r_2) - \sigma(-|v_1|\bar u^\top x y + v_2 r_2) + \sigma(|v_1|\bar u^\top x y - v_2 r_2) - \sigma(-|v_1|\bar u^\top x y - v_2 r_2)\Big].$$
Considering cases, the first ReLU argument is always positive, exactly one of the second and third is positive, and the fourth is negative, whereby in either case the bracketed sum equals $2|v_1|\,\bar u^\top x y$, and thus
$$y\,\mathbb E_{\|(a,v)\| \le 2}\,\operatorname{sgn}(v_1)\,\sigma(v^\top M^\top x) = 2\,\mathbb E_{\substack{\|(a,v)\| \le 2\\ y\operatorname{sgn}(v_1)=1}} |v_1|\,\bar u^\top x y \ge 2\gamma\,\mathbb E_{\substack{\|(a,v)\| \le 2\\ y\operatorname{sgn}(v_1)=1}} |v_1| = \gamma\,\Pr\big[\|(a,v)\| \le 2\big]\;\mathbb E\big[\,|v_1|\ \big|\ \|(a,v)\| \le 2\,\big],$$
where $\Pr[\|(a,v)\| \le 2] \ge 1/4$ since (for example) the $\chi^2$ random variables corresponding to $|a|^2$ and $\|v\|^2$ have median less than one, and the conditional expectation term is at least $1/(4\sqrt d)$ by standard Gaussian computations (Blum et al., 2017, Theorem 2.8).

3. It suffices to note that

$$2y\sum_{j=1}^{2}\alpha_j\,\sigma(\beta_j^\top x) = y\,\sigma(\bar u^\top x) - y\,\sigma(-\bar u^\top x) = y\,\bar u^\top x \ge \gamma,$$
using $\sigma(z) - \sigma(-z) = z$; dividing by 2 gives $\gamma_{gl} \ge \gamma/2$.

Next, estimates for $\gamma_{gl}$ and $\gamma_{ntk}$ on 2-sparse parity were stated in the body in Proposition 1.3. The key point is that $\gamma_{ntk}$ scales with $1/d$ whereas $\gamma_{gl}$ scales with $1/\sqrt d$, which suffices to yield the separations in Table 1. The bound on $\gamma_{ntk}$ is also necessarily an upper bound, since otherwise it would be possible to beat the NTK lower bound (Wei et al., 2018).

Proof of Proposition 1.3. This proof shares ideas with (Wei et al., 2018; Ji & Telgarsky, 2020b), though with some adjustments to exactly fit the standard 2-sparse parity setting, and to shorten the proofs. Without loss of generality, due to the symmetry of the data distribution about the origin, suppose the two relevant coordinates are $a = 1$ and $b = 2$, meaning for any $x \in H_d$, the correct label is $y = d\,x_1 x_2$, the rescaled product of the first two coordinates. Both proofs will use the global margin construction (the parameters for $\gamma_{gl}$), given as follows:
$$p(x, y; (\alpha, \beta)) = y\sum_{j=1}^{4}\alpha_j\,\sigma(\beta_j^\top x),\qquad \alpha = \Big(\tfrac14,\,-\tfrac14,\,-\tfrac14,\,\tfrac14\Big),$$
$$\beta_1 := \Big(\tfrac{1}{\sqrt2},\ \tfrac{1}{\sqrt2},\ 0,\dots,0\Big),\quad \beta_2 := \Big(\tfrac{1}{\sqrt2},\ -\tfrac{1}{\sqrt2},\ 0,\dots,0\Big),\quad \beta_3 := \Big(-\tfrac{1}{\sqrt2},\ \tfrac{1}{\sqrt2},\ 0,\dots,0\Big),\quad \beta_4 := \Big(-\tfrac{1}{\sqrt2},\ -\tfrac{1}{\sqrt2},\ 0,\dots,0\Big) \in \mathbb R^d.$$
Note moreover that for any $x \in H_d$, $\beta_j^\top x > 0$ for exactly one $j$, which will be used for both $\gamma_{ntk}$ and $\gamma_{gl}$. The proof now splits into the two different settings, and will heavily use symmetry within $H_d$ and also within $(\alpha, \beta)$.
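The global margin construction above can be verified by brute force over all corners of the hypercube: for every $x \in H_d$, the prediction $y\sum_j \alpha_j\sigma(\beta_j^\top x)$ equals exactly $1/\sqrt{8d}$. A minimal sketch (the dimension $d = 4$ and all variable names are our choices):

```python
import itertools, math

def relu(z):
    return max(z, 0.0)

d = 4
s = 1 / math.sqrt(2)
alpha = [0.25, -0.25, -0.25, 0.25]
beta = [(s, s), (s, -s), (-s, s), (-s, -s)]   # nonzero only on the first two coords

margins = []
for corner in itertools.product([1, -1], repeat=d):
    x = [c / math.sqrt(d) for c in corner]
    y = d * x[0] * x[1]                        # 2-sparse parity label in {+1, -1}
    out = sum(a * relu(b[0]*x[0] + b[1]*x[1]) for a, b in zip(alpha, beta))
    margins.append(y * out)
```

Every corner attains the same value, consistent with the exact symmetry of the construction.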

1. Consider the transport mapping

$$\theta(a, v) = \Big(0,\ \frac{\operatorname{sgn}(a)}{2}\sum_{j=1}^{4}\beta_j\,1[\beta_j^\top v \ge 0]\Big);$$
note that this satisfies the condition $\|\theta(w)\| \le 1$ thanks to the factor $1/2$, since each $\beta_j$ gets a hemisphere, and $(\beta_1, \beta_4)$ together partition the sphere once, and $(\beta_2, \beta_3)$ similarly together partition the sphere once. Now let any $x$ be given, which as above has label $y = d\,x_1x_2$. By rotational symmetry of the data and also of the transport mapping, suppose $\beta_1$ is the unique choice with $\beta_1^\top x > 0$, which implies $y = 1$, and also $\beta_2^\top x = 0 = \beta_3^\top x$, while $\beta_4^\top x = -\beta_1^\top x$. Using these observations, and also rotational invariance of the Gaussian,
$$\mathbb E_{a,v}\big\langle \theta(a,v), \partial p(x,y;w)\big\rangle = \mathbb E_{a,v}\,\frac{|a|}{2}\sum_{j=1}^{4}\beta_j^\top x\,1[\beta_j^\top v \ge 0]\,1[v^\top x \ge 0] = \beta_1^\top x\,\frac{\mathbb E_a|a|}{2}\Big(\mathbb E_v\,1[\beta_1^\top v \ge 0]\,1[v^\top x \ge 0] - \mathbb E_v\,1[-\beta_1^\top v \ge 0]\,1[v^\top x \ge 0]\Big).$$
Now consider $\mathbb E_v\,1[\beta_1^\top v \ge 0]\,1[v^\top x \ge 0]$. A standard Gaussian computation is to introduce a rotation matrix $M$ whose first column is $\beta_1$, whose second column is $(I - \beta_1\beta_1^\top)x/\|(I - \beta_1\beta_1^\top)x\|$, and whose remaining columns are arbitrary subject to orthonormality, which by rotational invariance and the calculation $\beta_1^\top x = \sqrt{2/d}$ gives
$$\mathbb E_v\,1[\beta_1^\top v \ge 0]\,1[v^\top x \ge 0] = \mathbb E_v\,1[v_1 \ge 0]\,1\Big[v_1\beta_1^\top x + v_2\sqrt{1 - (\beta_1^\top x)^2} \ge 0\Big] = \mathbb E_v\,1[v_1 \ge 0]\,1\Big[v_1 + v_2\sqrt{d/2 - 1} \ge 0\Big].$$
Performing a similar calculation for the other term (arising from $\beta_4^\top x$) and plugging all of this back in,
$$\mathbb E_{a,v}\big\langle \theta(a,v), \partial p(x,y;w)\big\rangle = \sqrt{\frac2d}\,\frac{\mathbb E_a|a|}{2}\,\mathbb E_v\,1[v_1 \ge 0]\Big(1\big[v_1 + v_2\sqrt{d/2 - 1} \ge 0\big] - 1\big[-v_1 + v_2\sqrt{d/2 - 1} \ge 0\big]\Big).$$
To finish, a few observations suffice. Whenever $v_1 \ge 0$ (which is enforced by the common first term), then $-v_1 + v_2\sqrt{d/2-1} \le v_1 + v_2\sqrt{d/2-1}$, so the first indicator is 1 whenever the second indicator is 1, thus their difference is nonnegative, and to lower bound the overall quantity, it suffices to assess the probability that $v_1 + v_2\sqrt{d/2-1} \ge 0$ whereas $-v_1 + v_2\sqrt{d/2-1} \le 0$.
To lower bound this event, it suffices to lower bound
$$\Pr\Big[v_1 \ge 0 \ \wedge\ v_2 \ge 0 \ \wedge\ v_1 \ge v_2\sqrt{d/2 - 1}\Big] \ge \Pr[v_1 \ge 1/2]\cdot\Pr\big[0 \le v_2 \le 1/\sqrt d\big].$$
The first term is at least $1/5$, and the second can be calculated by brute force:
$$\Pr\big[0 \le v_2 \le 1/\sqrt d\big] = \frac{1}{\sqrt{2\pi}}\int_0^{1/\sqrt d} e^{-x^2/2}\,dx \ge \frac{1}{\sqrt{2\pi}}\int_0^{1/\sqrt d} e^{-1/(2d)}\,dx \ge \frac{1}{\sqrt{2\pi}}\cdot\frac{1}{\sqrt d}\cdot\frac{1}{e},$$
which completes the proof after similarly lower bounding $\mathbb E_a|a|$ and simplifying the various constants.
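The two probability estimates used above can be double-checked with exact standard normal CDF computations (a sketch; the helper name `norm_cdf` and the grid of dimensions are our choices):

```python
import math

def norm_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1 + math.erf(z / math.sqrt(2)))

p_head = 1 - norm_cdf(0.5)                     # Pr[v1 >= 1/2], about 0.309

gaps = []
for d in (4, 20, 100, 10_000):
    exact = norm_cdf(1 / math.sqrt(d)) - 0.5   # Pr[0 <= v2 <= 1/sqrt(d)]
    lower = math.exp(-1) / math.sqrt(2 * math.pi * d)
    gaps.append(exact - lower)                 # should be nonnegative
```
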

2. Let any

$x \in H_d$ be given, and as above note that $\beta_j^\top x > 0$ for exactly one $j$. By symmetry, suppose it is $\beta_1$, whereby $y = d\,x_1x_2 = 1$, and
$$\gamma_{gl} \ge p(x, y; (\alpha, \beta)) = y\sum_j \alpha_j\,\sigma(\beta_j^\top x) = |\alpha_1|\cdot\beta_1^\top x = \frac14\sqrt{\frac2d} = \frac{1}{\sqrt{8d}}.$$
To close, consider the k-sparse parity problem, the natural k-bit analog of the 2-sparse parity problem: now the target label depends on the product of k unknown input bits, but otherwise the problem is the same, meaning the data distribution is again supported on $H_d := \{\pm 1/\sqrt d\}^d$, and only the support of the distribution is used in the margin analysis.

Proof. Let $\mathcal P(S)$ range over the $2^k$ possible vectors which are $\pm 1/\sqrt k$ on elements of $S$, and $0$ otherwise, whereby every $v \in \mathcal P(S)$ has $\|v\| = 1$. Moreover, for convenience, define the shorthand $\operatorname{sgn}(x) := \operatorname{sgn}(\prod_{i\in S} x_i)$. With this in hand, define a target mapping
$$h(x) := \frac{(-1)^{k/2-1}}{2^k}\sum_{v\in\mathcal P(S)}\operatorname{sgn}(v)\,\sigma(v^\top x),$$
which is of the desired form for Assumption 1.1 with $\alpha_v = \operatorname{sgn}(v)(-1)^{k/2-1}/2^k$ and $\beta_v = v$. Writing $\mathcal P(S)[x; j]$ for those $v \in \mathcal P(S)$ whose signs disagree with those of $x$ on exactly $j$ coordinates of $S$ (whereby $v^\top x = (k-2j)/\sqrt{kd}$ and $\operatorname{sgn}(v)\operatorname{sgn}(x) = (-1)^j$, and the ReLU discards the terms with $j > k/2$),
$$y\,h(x) = (-1)^{k/2-1}\sum_{j=0}^{k}\ \sum_{v\in\mathcal P(S)[x;j]}\frac{\operatorname{sgn}(v)\operatorname{sgn}(x)}{2^k}\,\sigma\Big(\frac{k-2j}{\sqrt{kd}}\Big) = (-1)^{k/2-1}\sum_{j=0}^{k/2}\ \sum_{v\in\mathcal P(S)[x;j]}\frac{(-1)^j}{2^k}\cdot\frac{k-2j}{\sqrt{kd}} = \frac{(-1)^{k/2-1}}{2^k\sqrt{kd}}\sum_{j=0}^{k/2}(-1)^j\binom{k}{j}(k-2j).$$
The inner sum can now be handled via a few binomial tricks from (Graham et al., 1994, Chapter 5), namely absorption (eq. (5.6)), upper negation (eq. (5.14)), and parallel summation (eq. (5.9)):
$$\sum_{j=0}^{k/2}(-1)^j\binom kj(k-2j) = k + k\sum_{j=1}^{k/2}(-1)^{j}\binom{k-1}{j-1}\frac{k-2j}{j} = k + k\sum_{j=1}^{k/2}\binom{j-k-1}{j-1}\frac{2j-k}{j} = k + k\sum_{j=1}^{k/2}\left[\binom{j-k}{j} + \binom{j-k-1}{j-1}\right]$$
$$= k\sum_{j=0}^{k/2}\binom{j-k}{j} + k\sum_{j=0}^{k/2-1}\binom{j-k}{j} = k\binom{k/2+1-k}{k/2} + k\binom{k/2-k}{k/2-1} = k\left[(-1)^{k/2}\binom{k-2}{k/2} + (-1)^{k/2-1}\binom{k-2}{k/2-1}\right].$$
As an elementary simplification of the difference of the two binomial coefficients,
$$\binom{k-2}{k/2-1} - \binom{k-2}{k/2} = \frac{(k-2)!}{(k/2-1)!\,(k/2-1)!} - \frac{(k-2)!}{(k/2-2)!\,(k/2)!} = \frac{4}{k}\binom{k-3}{k/2-1},$$
which combines with the preceding to give
$$\sum_{j=0}^{k/2}(-1)^j\binom kj(k-2j) = k(-1)^{k/2-1}\left[\binom{k-2}{k/2-1} - \binom{k-2}{k/2}\right] = 4(-1)^{k/2-1}\binom{k-3}{k/2-1},$$
which after combining with the original expansion gives
$$y\,h(x) = \frac{(-1)^{k/2-1}}{2^k\sqrt{kd}}\sum_{j=0}^{k/2}(-1)^j\binom kj(k-2j) = \frac{4}{2^k\sqrt{kd}}\binom{k-3}{k/2-1}.$$
To close, for the final estimate, if $k = 4$ then
$$\frac{4}{2^4\sqrt{4d}}\binom{1}{1} = \frac{1}{8\sqrt d} = \frac{1}{2k\sqrt d},$$
and otherwise, if $k > 4$, note firstly via the standard lower bound $\binom{2n}{n} \ge 4^n/(2\sqrt n)$ on the central binomial coefficient that
$$\binom{k-3}{k/2-1} = \frac12\binom{k-2}{k/2-1} \ge \frac{2^{k-3}}{\sqrt{2(k-2)}},$$
and thus
$$\frac{4}{2^k\sqrt{kd}}\binom{k-3}{k/2-1} \ge \frac{1}{2\sqrt{2k(k-2)d}} \ge \frac{1}{k\sqrt{8d}}.$$

[Figure 3: CDFs of norm growth on 2-sparse parity and mnist digits 3 vs 5, with three choices of width, and all other experimental details as in Figure 1, Figure 2, and Appendix B.2. Here, for each width m, the norm growth of node j is measured as √m∥v_j(t) − v_j(0)∥; the √m factor is due to the gradient of v_j scaling with a_j, which initially has magnitude roughly 1/√m, and results in overlapping CDFs. One measure of exiting the NTK regime, though a bit weak, is that most rescaled norms are far beyond 1; an experiment to finer accuracy could be an interesting direction for future work.]
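The closed form for the alternating binomial sum can be confirmed numerically over a range of even $k$ (a quick sanity check; `signed_sum` is our helper name):

```python
import math

def signed_sum(k):
    """Compute sum_{j=0}^{k/2} (-1)^j C(k,j) (k - 2j) for even k."""
    return sum((-1)**j * math.comb(k, j) * (k - 2*j) for j in range(k//2 + 1))

# compare against the closed form 4 (-1)^{k/2-1} C(k-3, k/2-1)
checks = [(k, signed_sum(k), 4 * (-1)**(k//2 - 1) * math.comb(k - 3, k//2 - 1))
          for k in range(4, 21, 2)]
```
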

B.2 EXPERIMENTAL DETAILS

This brief section summarizes various choices used in the experiments behind Figure 1 and Figure 2, and provides an additional companion figure, Figure 3. The mnist data was limited to classes 3 and 5 to give a binary classification problem which is not linearly separable, and was otherwise unmodified. The 2-sparse parity data was uniform over H_d, the corners of the rescaled hypercube as defined in Proposition 1.3, and further described in Figure 1 (e.g., n = 64 samples and d = 20 dimensions). To simulate gradient flow, full-batch gradient descent was used together with the logistic loss. Initially the step size was 0.01, but eventually the mnist plots switched to step size 1, which did not lead to any discernible change. All experiments were run until the empirical logistic risk was approximately 1/n. A companion to Figure 2 from the paper body is to plot the CDFs of norm changes of the inner layer; this is presented here in Figure 3, and as detailed in the caption, it also indicates an exit from the NTK regime, though it is unclear whether quite at the significant level $\|W_t - W_0\| = O(\sqrt m)$ allowed by the theoretical guarantees.

B.3 GAUSSIAN CONCENTRATION

The first concentration inequalities are purely about the initialization.

Lemma B.3. Suppose $a \sim \mathcal N_m/\sqrt m$ and $V \sim \mathcal N_{m\times d}/\sqrt d$.
1. With probability at least $1-\delta$, $\|a\| \le 1 + \sqrt{2\ln(1/\delta)/m}$; similarly, with probability at least $1-\delta$, $\|V\| \le \sqrt m + \sqrt{2\ln(1/\delta)/d}$.

2.

Let examples $(x_1, \dots, x_n)$ be given with $\|x_i\| \le 1$. With probability at least $1 - 4\delta$,
$$\max_i \Big|\sum_j a_j\,\sigma(v_j^\top x_i)\Big| \le \sqrt{4\ln(n/\delta)}.$$

Proof. 1. Rewrite $\tilde a := a\sqrt m$, so that $\tilde a \sim \mathcal N_m$. Since $\tilde a \mapsto \|\tilde a\|/\sqrt m = \|a\|$ is $(1/\sqrt m)$-Lipschitz, then by Gaussian concentration (Wainwright, 2019, Theorem 2.26), with probability at least $1-\delta$,
$$\|a\| = \frac{\|\tilde a\|}{\sqrt m} \le \frac{\mathbb E\|\tilde a\|}{\sqrt m} + \sqrt{\frac{2\ln(1/\delta)}{m}} \le \frac{\sqrt{\mathbb E\|\tilde a\|^2}}{\sqrt m} + \sqrt{\frac{2\ln(1/\delta)}{m}} = 1 + \sqrt{\frac{2\ln(1/\delta)}{m}}.$$
Similarly for $V$, defining $\tilde V := V\sqrt d$ whereby $\tilde V \sim \mathcal N_{m\times d}$, Gaussian concentration grants $\|V\| = \|\tilde V\|/\sqrt d \le \sqrt m + \sqrt{2\ln(1/\delta)/d}$.

2. Fix any example

$x_i$, and constants $\epsilon_1 > 0$ and $\epsilon_2 > 0$ to be optimized at the end of the proof, and define $d_i := d/\|x_i\|^2$ for convenience. By rotational invariance of Gaussians and since $x_i$ is fixed, $\sigma(Vx_i)$ is equal in distribution to $\|x_i\|\sigma(g)/\sqrt d = \sigma(g)/\sqrt{d_i}$ where $g \sim \mathcal N_m$. Meanwhile, $g \mapsto \|\sigma(g)\|/\sqrt{d_i}$ is $(1/\sqrt{d_i})$-Lipschitz with $\mathbb E\|\sigma(g)\| \le \sqrt m$, and so, by Gaussian concentration (Wainwright, 2019, Theorem 2.26),
$$\Pr\big[\|\sigma(Vx_i)\| \ge \epsilon_1 + \sqrt m\big] = \Pr\big[\|\sigma(g)\|/\sqrt{d_i} \ge \epsilon_1 + \sqrt m\big] \le \exp\Big(-\frac{d_i\epsilon_1^2}{2}\Big).$$
Next consider the original expression $a^\top\sigma(Vx_i)$. To simplify handling of the $1/m$ variance of the coordinates of $a$, define another Gaussian $h := a\sqrt m$, and a new constant $c_i := m d_i$ for convenience, whereby $a^\top\sigma(Vx_i)$ is equal in distribution to $h^\top\sigma(g)/\sqrt{c_i}$, since $a$ and $V$ are independent (and thus $h$ and $g$ are independent). Conditioned on $g$, since $\mathbb E h = 0$, then $\mathbb E[h^\top\sigma(g)\,|\,g] = 0$. As such, applying Gaussian concentration to this conditioned random variable, since $h \mapsto h^\top\sigma(g)/\sqrt{c_i}$ is $(\|\sigma(g)\|/\sqrt{c_i})$-Lipschitz, then
$$\Pr\big[h^\top\sigma(g)/\sqrt{c_i} \ge \epsilon_2 \,\big|\, g\big] \le \exp\Big(-\frac{c_i\epsilon_2^2}{2\|\sigma(g)\|^2}\Big).$$
Returning to the original expression, it can now be controlled via the two preceding bounds, conditioning, and the tower property of conditional expectation:
$$\Pr\big[h^\top\sigma(g)/\sqrt{c_i} \ge \epsilon_2\big] \le \mathbb E\Big[\Pr\big[h^\top\sigma(g)/\sqrt{c_i} \ge \epsilon_2 \,\big|\, g\big]\ \Big|\ \|\sigma(g)\|/\sqrt{d_i} \le \epsilon_1 + \sqrt m\Big]\Pr\big[\|\sigma(g)\|/\sqrt{d_i} \le \epsilon_1 + \sqrt m\big] + \Pr\big[\|\sigma(g)\|/\sqrt{d_i} > \epsilon_1 + \sqrt m\big]$$
$$\le \mathbb E\Big[\exp\Big(-\frac{c_i\epsilon_2^2}{2\|\sigma(g)\|^2}\Big)\ \Big|\ \|\sigma(g)\|/\sqrt{d_i} \le \epsilon_1 + \sqrt m\Big] + \exp\big(-d_i\epsilon_1^2/2\big) \le \exp\Big(-\frac{c_i\epsilon_2^2}{4d_i\epsilon_1^2 + 4d_i m}\Big) + \exp\big(-d_i\epsilon_1^2/2\big).$$
As such, choosing $\epsilon_2 := \sqrt{4\ln(n/\delta)\,m d_i/c_i} = \sqrt{4\ln(n/\delta)}$ and $\epsilon_1 := \sqrt{2\ln(n/\delta)/d_i}$ gives
$$\Pr\big[a^\top\sigma(Vx_i) \ge \epsilon_2\big] = \Pr\big[h^\top\sigma(g)/\sqrt{c_i} \ge \epsilon_2\big] \le \frac{\delta}{n} + \frac{\delta}{n},$$
which is a sub-exponential concentration bound.
Union bounding over the reverse inequality and over all n examples and using max i ∥x i ∥ ≤ 1 gives the final bound. Next comes a key tool in all the proofs using γ ntk : guarantees that the infinite-width margin assumptions imply the existence of good finite-width networks. Lemma B.4. Suppose the data distribution satisfies Assumption 1.2 with corresponding θ : R d+1 → R d+1 and γ ntk > 0, and let ((x i , y i )) n i=1 be an iid draw. 1. With probability at least 1 -δ over the draw of (w j ) m j=1 , defining θ j := θ(w j )/ √ m, then min i j θ j , ∂p i (w j ) ≥ γ ntk √ m -32 ln(n/δ). 2. With probability at least 1 -7δ over the draw of W with rows (w j ) m j=1 with m ≥ 2 ln(1/δ), defining rows θ j := θ(w j )/ √ m of θ ∈ R m×(d+1) , then for any W ′ and any R ≥ ∥W -W ′ ∥ and any r θ ≥ 0 and r w ≥ 0, r θ θ + r w W, ∂p i (W ′ ) -r w p i (W ′ ) ≥ γ ntk r θ √ m -r θ 32 ln(n/δ) + 8R + 4 -r w 4 ln(n/δ) + 2R + 2R √ m + 4 √ m , and moreover, writing W = (a, V ), then ∥a∥ ≤ 2 and ∥V ∥ ≤ 2 √ m. For the particular choice r θ := R/8 and r w = 1, if R ≥ 8 and m ≥ (64 ln(n/δ)/γ ntk ) 2 , then r θ θ + W, ∂p i (W ′ ) -p i (W ′ ) ≥ γ ntk r θ √ m 2 -160r 2 θ . Proof. 1. Fix any example (x i , y i ), and define µ := E w θ(w), ∂p i (w) , where µ ≥ γ ntk by assumption. By the various conditions on θ, it holds for any (a, v) := w ∈ R d+1 and corresponding (ā, v) := θ(w) ∈ R d+1 that θ(w), ∂p i (w) ≤ āσ(v T x i ) + v, ax i σ ′ (v T x i ) ≤ |ā| • 1[∥v∥ ≤ 2] • ∥v∥ • ∥x i ∥ + ∥v∥ • |a| • 1[|a| ≤ 2] • ∥x i ∥ ≤ 4. and therefore, by Hoeffding's inequality, with probability at least 1 -δ/n over the draw of m iid copies of this random variable, j θ(w j ), ∂p i (w j ) ≥ mµ -32m ln(n/δ) ≥ mγ ntk -32m ln(n/δ), which gives the desired bound after dividing by √ m, recalling θ j := θ(w j )/ √ m, and union bounding over all n examples. 2. 
First, suppose with probability at least 1 -7δ that the consequences of Lemma B.3 and the preceding part of the current lemma hold, whereby simultaneously ∥a∥ ≤ 2, and ∥V ∥ ≤ 2 √ m, and min i p i (W ) ≥ -4 ln(n/δ), min i j θ j , ∂p i (w j ) ≥ γ ntk √ m -32 ln(n/δ). The remainder of the proof proceeds by separately lower bounding the two right hand terms in r θ θ + r w W, ∂p i (W ′ ) -r w p i (W ′ ) = r θ θ, ∂p i (W ) + θ, ∂p i (W ′ ) -∂p i (W ) + r w W, ∂p i (W ′ ) -r w p i (W ′ ) . For the first term, writing (ā, V) = θ and noting ∥ā∥ ≤ 2 and ∥V∥ ≤ 2, then for any W ′ = (a ′ , V ′ ), θ, ∂p i (W ′ ) -∂p i (W ) ≤ j āj σ(x T i v ′ j ) -σ(v T j x i ) + j x T i vj a ′ j σ ′ (x T v ′ j ) -a j σ ′ (x T v j ) ≤ j ā2 j j σ(x T i v ′ j ) -σ(v T j x i ) 2 + j |x T i vj | • a ′ j σ ′ (x T v ′ j ) -a j σ ′ (x T v ′ j ) + j |x T i vj | • a j σ ′ (x T v ′ j ) -a j σ ′ (x T v ′ j ) ≤ ∥ā∥ • ∥V ′ -V ∥ + ∥a ′ -a∥ • ∥V∥ + ∥a∥ • ∥V∥ ≤ 4R + 4. For the second term, W, ∂p i (W ′ ) -p i (W ′ ) = a, ∂a p i (W ′ ) + V, ∂V p i (W ′ ) -V ′ , ∂V p i (W ′ ) ≤ j a j σ(x T i v ′ j ) + j a ′ j v j -v ′ j , x i σ ′ (x T i v j ) ≤ p i (w) + y i j a j σ(x T i v ′ j ) -σ(x T i v j ) + j a ′ j • ∥v j -v ′ j ∥ ≤ 4 ln(nδ) + ∥a∥ • ∥V -V ′ ∥ + ∥a ′ -a + a∥ • ∥V -V ′ ∥ ≤ 4 ln(nδ) + 4R + R 2 . Multiplying through by r θ and r and combining these inequalities gives, for every i, r θ θ + r w W, ∂p i (W ′ ) -r w p i (W ′ ) ≥ γ ntk r θ √ m -r θ 32 ln(n/δ) + 4R + 4 -r w 4 ln(n/δ) + 4R + R 2 , which establishes the first inequality. For the particular choice r θ := R/8 with R ≥ 8 and r w = 1, and using m ≥ (64 ln(n/δ)/γ ntk ) 2 , the preceding bound simplifies to r θ θ + r w W, ∂p i (W ′ ) -r w p i (W ′ ) ≥ γ ntk r θ √ m -r θ γ ntk √ m 8 + 32r θ + 32r θ - γ ntk √ m 16 + 32r θ + 64r 2 θ ≥ γ ntk r θ √ m 2 -160r 2 θ .
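As an aside, the initialization norm bound of Lemma B.3(1) is easy to probe by simulation; the following sketch (with our own choices of $m$, $\delta$, seed, and trial count) estimates the empirical failure probability of the bound $\|a\| \le 1 + \sqrt{2\ln(1/\delta)/m}$ and checks that it is at most $\delta$:

```python
import math, random

random.seed(0)
m, delta, trials = 400, 0.1, 2000
bound = 1 + math.sqrt(2 * math.log(1 / delta) / m)

violations = 0
for _ in range(trials):
    # a ~ N_m / sqrt(m), so each coordinate has variance 1/m
    a = [random.gauss(0.0, 1.0) / math.sqrt(m) for _ in range(m)]
    if math.sqrt(sum(x * x for x in a)) > bound:
        violations += 1
```

In practice the empirical failure rate is far below $\delta$, reflecting the slack in the Gaussian concentration argument.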

B.4 BASIC PROPERTIES OF L-HOMOGENEOUS PREDICTORS

This subsection collects a few properties of arbitrary L-homogeneous predictors in a setup more general than the rest of the work, used in all large margin calculations. Specifically, suppose general parameters $u_t$ with some unspecified initial condition $u_0$, and thereafter given by the differential equation
$$\dot u_t = -\partial_u R(p(u_t)), \tag{B.1}$$
where $p(u) := (p_1(u), \dots, p_n(u)) \in \mathbb R^n$, $p_i(u) := y_i F(x_i; u)$, and $F(x_i; cu) = c^L F(x_i; u)$ for all $c \ge 0$. The first property is that norms increase once there is a positive margin.

Lemma B.5 (Restatement of (Lyu & Li, 2019, Lemma B.1)). Suppose the setting of eq. (B.1) and also $\ell \in \{\ell_{\exp}, \ell_{\log}\}$. If $R(u_\tau) < \ell(0)/n$, then, for every $t \ge \tau$, $\frac{d}{dt}\|u_t\| > 0$ and $\langle u_t, \dot u_t\rangle > 0$, and moreover $\lim_t \|u_t\| = \infty$.

Proof. Since $R$ is nonincreasing along the gradient flow, it suffices to consider any $u_s$ with $R(u_s) < \ell(0)/n$. To apply (Lyu & Li, 2019, Lemma B.1), first note that both the exponential and logistic losses can be handled, e.g., via the discussion of the assumptions at the beginning of (Lyu & Li, 2019, Appendix A.1). Next, the statement of that lemma is $\frac{d}{ds}\ln\|u_s\| > 0$; note that $\|u_s\| > 0$ (otherwise $u_s = 0$, whereby homogeneity gives $R(u_s) = \ell(0) \ge \ell(0)/n$, a contradiction), whereby $\frac{d}{ds}\ln\|u_s\| > 0$ and the identity $\langle u_s, \dot u_s\rangle = \|u_s\|\frac{d}{ds}\|u_s\|$ imply the main part of the statement; all that remains is to show $\|u_s\| \to \infty$, but this is given by (Lyu & Li, 2019, Lemma B.6).

Next, even without the assumption $R(u_s) < \ell(0)/n$ (which at a minimum requires a two-phase proof, and certain other annoyances), note that once $\|u_s\|$ is large, the gradient can be related to margins, even if they are negative, which will be useful in circumventing the need for dual convergence and other assumptions present in prior work (e.g., as in (Chizat & Bach, 2020)).
We note that while the closest inequalities in the literature require the condition R(u s ) < ℓ(0)/n (Ji & Telgarsky, 2020a, Lemma C.5), those results aim for a more stringent goal, replacing n in the bound below with ln(n); this simpler goal is sufficient in the present work. Lemma B.6 (See also (Ji & Telgarsky, 2020a , Proof of Lemma C.5)). Suppose the setting of eq. (B.1) and also ℓ ∈ {ℓ exp , ℓ log }. Then, for any u and any ((x i , y i )) n i=1 (and corresponding R), u, -n ∂u R(u) L∥u∥ L ≤ Q γ(u) + n ∥u∥ L ≤ Q γ(u) + n ∥u∥ L . Proof. Define v := p(u) for convenience, as well as π(v) = ℓ -1 ( i ℓ(v i )) = γ(u), whereby q = ∇ p π(v), and π is (unconditionally) concave (Ji & Telgarsky, 2020a, Lemma C.8) . Combining these facts, u, -n ∂u R(u) = i -ℓ ′ (v i ) u, ∂u p i (u) = LQ i q i v i = LQ ∇ v π(v), v = LQ ∇ v π(v), v -0 ≤ LQ π(v) -π(0) . Simplifying -π(0) now proceeds separately for ℓ exp and ℓ log : for ℓ exp , then -π(0) = ln( i exp(0)) = ln(n), whereas for ℓ log , then ℓ -1 log (r) = -ln(e r -1), thus -π(0) = ln exp i ln(1 + exp(-0)) -1 = ln exp n ln 2 -1 = ln 2 n -1 ≤ n ln 2. As such, in either case, u, -n ∂u R(u) ≤ LQ π(v) -π(0) ≤ LQ [ γ + n] . Next, since ℓ is strictly decreasing in both cases, then ℓ -1 is strictly decreasing as well, whereby letting k denote the index of any example with v k = min i v i , then additionally using the positivity of ℓ gives γ = ℓ -1   i ℓ(v i )   ≤ ℓ -1 ℓ(v k ) = v k = γ(u)∥u∥ L . Combining these inequalities and dividing by L∥u∥ L gives the desired bounds. Lastly, a key abstract potential function lemma: this potential function is a proxy for mass accumulating on certain weights with good margin, and once it satisfies a few conditions, large margins are implied directly. This is the second component needed to remove dual convergence from (Chizat & Bach, 2020) . Lemma B.7. Suppose the setting of eq. (B.1) with L = 2, and ℓ ∈ {ℓ exp , ℓ log }. Then, unconditionally, lim t t 0 Q s ds = ∞. 
Moreover, if there exists a constant γ > 0, a time τ , and a potential function Φ(u) so that Φ(u τ ) > -∞, and for all t ≥ τ , Φ(u) ≤ 1 L ln ∥u∥ and d dt Φ(u) ≥ 1 n Q(u) γ, then it follows that R(u) → 0, and ∥u∥ → ∞, and lim inf t γ(u t ) ≥ γ. Proof. The unconditional claim t 0 Q s ds → ∞ is shown by considering two cases: either inf s R(u s ) = 0, or R(u s ) > 0 (the case inf s R(u s ) < 0 is not possible since ℓ is nonnegative). 1. First suppose inf s R(u s ) > 0. For both losses, it will be argued that inf s Q s > 0, whereby ∞ 0 Q s ds ≥ ∞ 0 inf r Q r ds = ∞. In the case of ℓ exp , then R(u s ) = 1 n Q s , and inf s Q s > 0 directly. In the case of ℓ log , note ℓ -1 log (r) = -ln(e r -1) and ℓ ′ (z) = -(1 + e z ) -1 , whereby (ℓ ′ log • ℓ -1 log )(r) = -1 1 + (e r -1) -1 = 1 -e r e r -1 + 1 = exp(-r) -1. As such, inf s Q s = inf s -(ℓ ′ log • ℓ -1 log )(n R(u s )) = inf s 1 -exp(-n R(u s )) = 1 -exp(-n inf s R(u s )) > 1 -exp(-0) = 0, meaning inf s Q s > 0 as desired. 2. If inf s R(u s ) = 0, then we can choose a time τ with R(u τ ) < ℓ(0), and it follows that margins increase monotonically and R decreases monotonically (Lyu et al., 2021) , and moreover by Lemma B.5 that ∥u s ∥ is increasing and lim s ∥u s ∥ = ∞. Furthermore, for ℓ exp , then i q i = 1, whereas for ℓ log , R(u s ) < ℓ(0) (which holds for all s ≥ τ Telgarsky, 2019, Lemma 5.4 ). As such, for any s ≥ τ , ) implies i q i (s) ∈ [1, 2] (Ji & d ds ln ∥u s ∥ 2 = 2 i |ℓ ′ i | u s , ∂p i (u s ) ∥u s ∥ 2 = 4Q s i q i (s)p i (u s ) ∥u s ∥ 2 ≤ 8Q s , whereby it follows that ∞ = lim t ln ∥u t ∥ 2 = ∞ τ ∥u s ∥ 2 ds ≤ 8 ∞ τ Q s ds, meaning ∞ τ Q s ds = ∞, whereas τ 0 Q s ds > 0 via the analysis in the preceding case, and together ∞ 0 Q s ds = ∞. Combining these two cases, then ∞ 0 Q s ds = ∞ unconditionally. Now consider the second statement, with Φ, τ, γ given. 
If lim inf t γ t ≥ γ > 0, then lim t γ t is well-defined and positive via nondecreasing margins, and moreover ∥u∥ → ∞ via Lemma B.5, and 0 ≤ lim sup t R(u t ) ≤ lim sup t ℓ(-γ t ∥w t ∥ L ) = 0. Alternatively, suppose contradictorily that lim inf t γ t < γ, and choose any ϵ ∈ (0, γ/4) so that lim inf t γ t < γ -3ϵ; noting that γ t is monotone once there exists some γ s > 0, choose t 1 ≥ τ large enough so that γ s ≤ γ -3ϵ for all s ≥ t 1 . Next, note that ∥u∥ → ∞ even in this situation (which may violate the conditions of Lemma B.5), since the assumptions Φ and the unconditional property ∞ 0 Q s ds = ∞ imply lim inf t 1 L ln ∥u t ∥ ≥ lim inf t Φ(u t ) -Φ(u τ ) + Φ(u τ ) = Φ(u τ ) + lim inf t t τ d ds Φ(u s ) ds ≥ Φ(u τ ) + 1 n lim inf t t τ γQ s ds = ∞, meaning ∥u s ∥ → ∞; henceforth, choose t 2 ≥ t 1 so that so that ∥u s ∥ 2 ≥ n/ϵ for all s ≥ t 2 . It follows by Lemma B.6 and the assumption L = 2 that 0 ≤ lim inf t 1 L ln ∥u t ∥ -Φ(u t ) ≤ 1 L ln ∥u t3 ∥ -Φ(u t3 ) + lim inf t t t3 d ds 1 4 ln ∥u s ∥ 2 -Φ(u s ) ds = 1 L ln ∥u t3 ∥ -Φ(u t3 ) + lim inf t t t3 d ds ⟨u s , us ⟩ 2∥u s ∥ 2 -Φ(u s ) ds ≤ 1 L ln ∥u t3 ∥ -Φ(u t3 ) + lim inf t t t3 Q s (γ s + n) n∥u s ∥ 2 - 1 n Q s γ ds ≤ 1 L ln ∥u t3 ∥ -Φ(u t3 ) + 1 n lim inf t t t3 [-ϵQ] ds = -∞, a contradiction, and since ϵ ∈ (0, γ/4) was arbitrary, it follows that lim inf γ t ≥ γ.

C PROOFS FOR SECTION 2

This section contains proofs for Section 2, all of which have a dependence on γ_ntk rather than γ_gl. The SGD proofs come first, as they are easier and serve as a warmup.

C.1 SGD PROOFS

Before proceeding with the proof of Theorem 2.3, the following technical lemma (little more than an application of Freedman's inequality) provides a sufficient martingale concentration inequality for the test error bound.

Lemma C.1 (Nearly identical to (Ji & Telgarsky, 2020b, Lemma 4.3)). Define $Q(W) := \mathbb E_{x,y}|\ell'(p(x,y;W))|$ and $Q_i(W) := |\ell'(p(x_i,y_i;W))|$. Then $\sum_{i<t}\big(Q(W_i) - Q_i(W_i)\big)$ is a martingale difference sequence, and with probability at least $1-\delta$,
$$\sum_{i<t} Q(W_i) \le 4\sum_{i<t} Q_i(W_i) + 4\ln(1/\delta).$$

Proof. This proof is essentially a copy of one due to Ji & Telgarsky (2020b, Lemma 4.3); that one is stated for the analog of $p_i$ used there, and thus needs to be re-checked. Let $\mathcal F_i := \sigma\big(((x_j, y_j)) : j < i\big)$ denote the σ-field of all information before time $i$, whereby $x_i$ is independent of $\mathcal F_i$, whereas $W_i$ is deterministic after conditioning on $\mathcal F_i$. Consequently, $\mathbb E\big[Q(W_i) - Q_i(W_i)\,\big|\,\mathcal F_i\big] = 0$, whereby $\sum_{i<t}\big(Q(W_i) - Q_i(W_i)\big)$ is a martingale difference sequence. The high probability bound will now follow via a version of Freedman's inequality (Agarwal et al., 2014, Lemma 9). To apply this bound, the conditional variances must be controlled: noting that $|\ell'(z)| \in [0,1]$, then $Q(W_i) - Q_i(W_i) \le 1$, and since $Q_i(W_i) \in [0,1]$, then $Q_i(W_i)^2 \le Q_i(W_i)$, and thus
$$\mathbb E\big[(Q(W_i) - Q_i(W_i))^2\,\big|\,\mathcal F_i\big] = \mathbb E\big[Q_i(W_i)^2\,\big|\,\mathcal F_i\big] - Q(W_i)^2 \le \mathbb E\big[Q_i(W_i)\,\big|\,\mathcal F_i\big] - 0 = Q(W_i).$$
As such, by the aforementioned version of Freedman's inequality (Agarwal et al., 2014, Lemma 9),
$$\sum_{i<t}\big(Q(W_i) - Q_i(W_i)\big) \le (e-2)\sum_{i<t}\mathbb E\big[(Q(W_i) - Q_i(W_i))^2\,\big|\,\mathcal F_i\big] + \ln(1/\delta) \le (e-2)\sum_{i<t} Q(W_i) + \ln(1/\delta),$$
which rearranges to give $(3-e)\sum_{i<t}Q(W_i) \le \sum_{i<t}Q_i(W_i) + \ln(1/\delta)$, which gives the result after multiplying by 4 and noting $4(3-e) \ge 1$.
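The inequality of Lemma C.1 can be stress-tested on a synthetic martingale difference sequence, where each $Q(W_i)$ is replaced by a conditional mean in $[0,1]$ and each $Q_i(W_i)$ by a conditioned Bernoulli draw (a sketch; the uniform choice of means, seed, and all constants are ours):

```python
import math, random

random.seed(0)
delta, t, trials = 0.01, 200, 500

failures = 0
for _ in range(trials):
    lhs = rhs = 0.0
    for _ in range(t):
        mu = random.random()                         # conditional mean in [0, 1]
        q = 1.0 if random.random() < mu else 0.0     # observed Bernoulli(mu) draw
        lhs += mu                                    # plays the role of Q(W_i)
        rhs += q                                     # plays the role of Q_i(W_i)
    if lhs > 4 * rhs + 4 * math.log(1 / delta):
        failures += 1
```

The empirical failure rate should be at most $\delta$; in this easy regime it is essentially zero, again reflecting the slack of the constants.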
With Lemma C.1 and the Gaussian concentration inequalities from Appendix B.3 in hand, the proof of Theorem 2.3 is as follows. Proof of Theorem 2.3. Let (w j ) m j=1 be given with corresponding (ā j , vj ) := θ j := θ(w j )/ √ m (whereby ∥θ j ∥ ≤ 2 by construction), and define r := 10η √ m γ ntk ≤ γ ntk √ m 640 , R := 8r = 80η √ m γ ntk ≤ γ ntk √ m 80 , W := rθ + W 0 , which implies r ≥ 1, and R ≥ 1, and η ≤ R/16. For the remainder of the proof, rule out the 7δ failure probability associated the second part of Lemma B.4, whereby simultaneously for every ∥W ′ -W 0 ∥ ≤ R, min i W , ∂p i (W ′ ) ≥ rγ ntk √ m 2 -160r 2 ≥ rγ ntk √ m 4 = γ 2 ntk m 2560 ≥ ln(t), (C.1) min i θ, ∂p i (W ′ ) ≥ γ ntk √ m -32 ln(n/δ) -4R -4 ≥ γ ntk √ m - γ ntk √ m 8 - γ ntk √ m 10 ≥ γ ntk √ m 2 , (C.2) and also ∥a 0 ∥ ≤ 2 and ∥V 0 ∥ ≤ 2 √ m. The proof now proceeds as follows. Let τ denote the first iteration where ∥W τ -W 0 ∥ ≥ R, whereby τ > 0 and max s<τ ∥W s -W 0 ∥ ≤ R. Assume contradictorily that τ ≤ t; it will be shown that this implies ∥W τ -W 0 ∥ ≤ R. Consider any iteration s < τ . Expanding the square, ∥W s+1 -W ∥ 2 = ∥W s -η ∂ℓ s (W s ) -W ∥ 2 = ∥W s -W ∥ 2 -2η ∂ℓ s (W s ), W s -W + η 2 ∂ℓ s (W s ) 2 = ∥W s -W ∥ 2 + 2ηℓ ′ s (W s ) ∂p s (W s ), W -W s + η 2 ℓ ′ s (W s ) 2 ∂p s (W s ) 2 . By convexity, ∥W s -W 0 ∥ ≤ R, and eq. (C.1), ℓ ′ s (W s ) ∂p s (W s ), W -W s = ℓ ′ s (W s ) ∂p s (W s ), W -p s (W s ) -p s (W s ) ≤ ℓ s ∂p s (W s ), W -p s (W s ) -ℓ s (W s ) ≤ ln(1 + exp(-ln(t))) -ℓ s (W s ), ≤ 1 t -ℓ s (W s ), which combined with the preceding display gives ∥W s+1 -W ∥ 2 ≤ ∥W s -W ∥ 2 + 2η 1 t -ℓ s (W s ) + η 2 ℓ ′ s (W s ) 2 ∂p s (W s ) 2 . Since this inequality holds for any s < τ , then applying the summation s<τ and rearranging gives ∥W τ -W ∥ 2 + 2η s<τ ℓ s (W s ) ≤ ∥W 0 -W ∥ 2 + 2η + s<τ η 2 ℓ ′ s (W s ) 2 ∂p s (W s ) 2 . 
To simplify the last term, using ∥V 0 ∥ ≤ 2 √ m and ∥a 0 ∥ ≤ 2 and ∥W s - W 0 ∥ ≤ R gives ∂p s (W ) 2 = ∥σ(V s x s )∥ 2 + j e j a i,j σ ′ (v T i,j x s )x s 2 ≤ σ(V s x s ) 2 +∥a s ∥ 2 ≤ 2∥V s -V 0 ∥ 2 + 2∥V 0 ∥ 2 + 2∥a s -a 0 ∥ 2 + 2∥a 0 ∥ 2 ≤ 2R 2 + 8m + 8, ≤ 10m, and moreover the first term can be simplified via ∥W τ -W ∥ 2 = ∥W τ -W 0 ∥ 2 -2 W τ -W 0 , W -W 0 + ∥W -W 0 ∥ 2 ≥ ∥W τ -W 0 ∥ 2 -2r∥W τ -W 0 ∥ + ∥W -W 0 ∥ 2 , whereby combining these all gives ∥W τ -W 0 ∥ 2 -2r∥W τ -W 0 ∥ + ∥W -W 0 ∥ 2 + 2η s<τ ℓ s (W s ) ≤ ∥W τ -W ∥ 2 + 2η s<τ ℓ s (W s ) ≤ ∥W 0 -W ∥ 2 + 2η + s<τ η 2 ℓ ′ s (W s ) 2 ∂p s (W s ) 2 ≤ ∥W 0 -W ∥ 2 + 2η + 10η 2 m s<τ |ℓ ′ s (W s )|, which after canceling and rearranging gives ∥W τ -W 0 ∥ 2 + 2η s<τ ℓ s (W s ) ≤ 2r∥W τ -W 0 ∥ + 2η + 10η 2 m s<τ |ℓ ′ s (W s )|. To simplify the last term, note by eq. (C.2) that ∥W τ -W 0 ∥ = sup ∥W ∥≤1 ⟨W, W τ -W 0 ⟩ ≥ 1 2 -θ, W τ -W 0 = η 2 s<τ -θ, ∂ℓ s (W s ) = η 2 s<τ |ℓ ′ s (W s )| θ, ∂p i (W s ) ≥ η 2 s<τ |ℓ ′ s (W s )| γ ntk √ m 2 , (C.3) and thus, by the choice of R, and since ∥W τ -W 0 ∥ ≥ 1 and η ≤ R/16, ∥W τ -W 0 ∥ 2 + 2η s<t ℓ s (W s ) ≤ 2r∥W τ -W 0 ∥ + 2η + 40η √ m∥W τ -W 0 ∥ γ ntk ≤ R 4 + R 8 + R 2 ∥W τ -W 0 ∥. Dropping the term 2η s<t ℓ s (W s ) ≥ 0 and dividing both sides by ∥W τ -W 0 ∥ ≥ R > 0 gives ∥W τ -W 0 ∥ ≤ R 4 + R 8 + R 2 < R, the desired contradiction, thus τ > t and all above derivations hold for all s ≤ t. To finish the proof, combining eq. (C.3) with ∥W t -W 0 ∥ ≤ R = 80η √ m/γ ntk gives s<t |ℓ ′ s (W s )| ≤ 4∥W t -W 0 ∥ ηγ ntk √ m ≤ 320 γ 2 ntk . Lastly, for the generalization bound, defining Q(W ) := E x,y |ℓ ′ (p(x, y; W ))|, discarding an additional δ failure probability, by Lemma C.1, s<t Q(W s ) ≤ 4 ln(1/δ) + 4 s<t |ℓ ′ s (W s )| ≤ 4 ln(1/δ) + 1280 γ 2 ntk . Since 1[p s (W s ) ≤ 0] ≤ 2|ℓ ′ s (W s )|, the result follows. It remains to argue that ∥W 1 -W 0 ∥ is large; to this end, it already holds by instantiating eq. 
(C.3) with $\tau = 1$ that
$$\|W_1 - W_0\| \ge \frac{\eta\,\gamma_{ntk}\,|\ell_0'(W_0)|\,\sqrt m}{4},$$
so it only remains to show that $|\ell_0'(W_0)|$ is not too small. By Lemma B.3, discarding an additional failure probability $\delta$, it holds that $|F(x_0; W_0)| \le 4\ln(1/\delta)$, and therefore
$$|\ell_0'(W_0)| = \frac{1}{1 + \exp(p_0(W_0))} \ge \frac{1}{1 + 1/\delta^4} \ge \frac{\delta^4}{2},$$
which combines to give $\|W_1 - W_0\| \ge \eta\,\gamma_{ntk}\,\delta^4\sqrt m/8$ as desired.
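To illustrate the regime analyzed above, the following self-contained sketch trains a two-layer ReLU network on 2-sparse parity with full-batch gradient descent (a small stand-in for the GF/SGD setups, in the spirit of Appendix B.2); the sizes, seed, and step size are our choices, and the checks only confirm that the empirical risk decreases and the weights move away from initialization:

```python
import math, random

random.seed(1)
d, n, m, lr, steps = 8, 24, 32, 0.5, 120

# 2-sparse parity sample on H_d = {+-1/sqrt(d)}^d, label y = d * x_1 * x_2
X = [[random.choice((1, -1)) / math.sqrt(d) for _ in range(d)] for _ in range(n)]
Y = [d * x[0] * x[1] for x in X]

# initialization scale matching Lemma B.3: a ~ N_m/sqrt(m), rows of V ~ N_d/sqrt(d)
a = [random.gauss(0.0, 1.0) / math.sqrt(m) for _ in range(m)]
V = [[random.gauss(0.0, 1.0) / math.sqrt(d) for _ in range(d)] for _ in range(m)]
a0, V0 = a[:], [row[:] for row in V]

def risk():
    total = 0.0
    for x, y in zip(X, Y):
        f = sum(a[j] * max(0.0, sum(V[j][i] * x[i] for i in range(d)))
                for j in range(m))
        total += math.log(1 + math.exp(-y * f))
    return total / n

risk_init = risk()
for _ in range(steps):                      # full-batch GD as a proxy for GF
    ga, gV = [0.0] * m, [[0.0] * d for _ in range(m)]
    for x, y in zip(X, Y):
        pre = [sum(V[j][i] * x[i] for i in range(d)) for j in range(m)]
        f = sum(a[j] * max(0.0, pre[j]) for j in range(m))
        g = -y / (1 + math.exp(y * f)) / n  # logistic loss derivative, averaged
        for j in range(m):
            ga[j] += g * max(0.0, pre[j])
            if pre[j] > 0:
                for i in range(d):
                    gV[j][i] += g * a[j] * x[i]
    a = [aj - lr * gj for aj, gj in zip(a, ga)]
    V = [[vji - lr * gji for vji, gji in zip(vj, gj)] for vj, gj in zip(V, gV)]

risk_final = risk()
movement = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, a0))
                     + sum((V[j][i] - V0[j][i]) ** 2
                           for j in range(m) for i in range(d)))
```
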

C.2 GF PROOFS

This section culminates in the proof of Theorem 2.1, which is broken into a few main lemmas: Lemma C.3 first controls the empirical risk $R$ similarly to the proof of Theorem 2.3, then Lemma C.4 establishes large margins, whereby Lemma C.6 develops a suitable Rademacher complexity bound; these combine to quickly give the proof of Theorem 2.1. Before proceeding with the main proofs, the following technical lemma is used to convert a bound on $\ell'$ into a bound on $\ell$.

Lemma C.2. For $\ell \in \{\ell_{\log}, \ell_{\exp}\}$, $|\ell'(z)| \le 1/8$ implies $\ell(z) \le 2|\ell'(z)|$.

Proof. If $\ell = \ell_{\exp}$, then $\ell' = -\ell$, and thus $\ell(z) \le 2|\ell'(z)|$ automatically. If $\ell = \ell_{\log}$, the logistic loss, then $|\ell'(z)| = 1/(1+e^z) \le 1/8$ implies $e^{-z} \le 1/7$. By the concavity of $\ln(\cdot)$, since $1 + e^{-z} \le 8/7 \le 7/6$, then
$$\ell(z) = \ln(1 + e^{-z}) \le e^{-z} \le \frac{(7/6)\,e^{-z}}{1 + e^{-z}} \le 2|\ell'(z)|,$$
thus completing the proof.

Next comes the proof of Lemma C.3, which follows the same proof plan as Theorem 2.3.

Lemma C.3. Suppose the data distribution satisfies Assumption 1.2 for some $\gamma_{ntk} > 0$, let time $t$ be given, and suppose the width $m$ satisfies $m \ge \big(640\ln(t/\delta)/\gamma_{ntk}\big)^2$
. Then, with probability at least 1 -7δ, the GF curve (W s ) s∈[0,t] on empirical risk R with loss ℓ ∈ {ℓ log , ℓ exp } satisfies R(W t ) ≤ 1 5t , (training error bound), sup s<t ∥W s -W 0 ∥ ≤ γ ntk √ m 80 , (norm bound). Note that this bound is morally equivalent to the SGD bound in Theorem 2.3 after accounting for the γ 2 ntk "units" arising from the step size. Proof of Lemma C.3. This proof is basically identical to the SGD in Theorem 2.3. Despite this, proceeding with amnesia, let rows (w j ) m j=1 of W 0 be given with corresponding (ā j , vj ) := θ j := θ(w j )/ √ m (whereby ∥θ j ∥ ≤ 2 by construction), and define r := γ ntk √ m 640 , R := 8r = γ ntk √ m 80 , W := rθ + W 0 , with immediate consequences that r ≥ 1 and R ≥ 8. For the remainder of the proof, rule out the 7δ failure probability associated with the second part of Lemma B.4, whereby simultaneously for every ∥W ′ -W 0 ∥ ≤ R, min i W , ∂p i (W ′ ) ≥ rγ ntk √ m 2 -160r 2 ≥ rγ ntk √ m 4 = γ 2 ntk m 2560 ≥ ln(t), (C.4) min i θ, ∂p i (W ′ ) ≥ γ ntk √ m -32 ln(n/δ) -4R -4 ≥ γ ntk √ m - γ ntk √ m 8 - γ ntk √ m 10 ≥ γ ntk √ m 2 . (C.5) The proof now proceeds as follows. Let τ denote the earliest time such that ∥W τ -W 0 ∥ = R; since W s traces out a continuous curve and since R > 0 = ∥W 0 -W 0 ∥, this quantity is well-defined. As a consequence of the definition, sup s<τ ∥W s -W 0 ∥ ≤ R. Assume contradictorily that τ ≤ t; it will be shown that this implies ∥W τ -W 0 ∥ < R. By the fundamental theorem of calculus (and the chain rule for Clarke differentials), convexity of ℓ, and since ∥W s -W 0 ∥ ≤ R holds for s ∈ [0, τ ), which implies eq. (C.4) holds, ∥W τ -W ∥ 2 -∥W 0 -W ∥ 2 = τ 0 d ds ∥W s -W ∥ 2 ds = τ 0 2 Ẇs , W s -W ds = 2 n τ 0 i ℓ ′ i (W s ) ∂p i (W s ), W s -W ds = 2 n τ 0 i ℓ ′ i (W s ) ∂p i (W s ), W -p i (W s ) -p i (W s ) ds ≤ 2 n τ 0 i ℓ i ∂p i (W s ), W -p i (W s ) -ℓ i (W s ) ds ≤ 2 n τ 0 i 1 t -ℓ i (W s ) ds ≤ 2 -2 τ 0 R(W s ) ds. 
To simplify the left hand side,
    ∥W_τ − W̄∥² − ∥W_0 − W̄∥² = ∥W_τ − W_0∥² − 2 ⟨W_τ − W_0, W̄ − W_0⟩ ≥ ∥W_τ − W_0∥² − 2r ∥W_τ − W_0∥,
which after combining, rearranging, and using r ≥ 1 and ∥W_τ − W_0∥ = R ≥ 1 gives
    ∥W_τ − W_0∥² + 2 ∫_0^τ R̂(W_s) ds ≤ 2 + 2r ∥W_τ − W_0∥ ≤ 4r ∥W_τ − W_0∥,
which implies ∥W_τ − W_0∥ ≤ 4r = R/2 < R, a contradiction, since τ was defined as the earliest time with ∥W_τ − W_0∥ = R; this contradicts τ ≤ t. As such, τ > t, and all of the preceding inequalities follow with τ replaced by t.

To obtain an error bound, similarly to the key perceptron argument before, using eq. (C.5),
    ∥W_t − W_0∥ = sup_{∥W∥≤1} ⟨W, W_t − W_0⟩ ≥ (1/2) ⟨θ̄, W_t − W_0⟩ = (1/2) ⟨θ̄, ∫_0^t Ẇ_s ds⟩ = (1/(2n)) ∫_0^t Σ_i |ℓ′_i(W_s)| ⟨θ̄, ∂p_i(W_s)⟩ ds ≥ (γ_ntk √m / (4n)) ∫_0^t Σ_i |ℓ′_i(W_s)| ds,
which implies
    (1/n) ∫_0^t Σ_i |ℓ′_i(W_s)| ds ≤ 4 ∥W_t − W_0∥ / (γ_ntk √m) ≤ 1/20,
and in particular
    inf_{s∈[0,t]} (1/n) Σ_i |ℓ′_i(W_s)| ≤ (1/(tn)) ∫_0^t Σ_i |ℓ′_i(W_s)| ds ≤ 4 ∥W_t − W_0∥ / (t γ_ntk √m) ≤ 1/(20t),
and so there exists k ∈ [0, t] with (1/n) Σ_i |ℓ′_i(W_k)| ≤ 1/(10t). Since this also implies max_i |ℓ′_i(W_k)| ≤ n/(10t) ≤ 1/10, it follows by Lemma C.2 that R̂(W_k) ≤ 1/(5t), and the claim also holds for t′ ≥ t since the empirical risk is nonincreasing along the gradient flow.

Next is the explicit maximum margin guarantee, which was missing from the SGD analysis.

Lemma C.4. Suppose the data distribution satisfies Assumption 1.2 with margin γ_ntk > 0 and parameter mapping θ, and let ((x_i, y_i)) … ∥W′ − W_0∥ ≤ R, then
    min_i ⟨θ̄, ∂p_i(W′)⟩ ≥ γ_ntk √m − (32 ln(n/δ) + 8R + 4) ≥ γ_ntk √m / 2,
where θ̄_j := θ(w_j)/√m as usual, and ∥θ̄∥ ≤ 2; for the remainder of the proof, suppose these bounds, and discard the corresponding 7δ failure probability. Moreover, for any W′ with R̂(W′) < ℓ(0)/n and ∥W′ − W_0∥ ≤ R, as a consequence of the preceding lower bound and also the property Σ_i q_i(W′) ≥ 1 (Ji & Telgarsky, 2019, Lemma 5.4, first part, which does not depend on linear predictors),
    ∥∂̄γ̃(W′)∥ = sup_{∥W∥≤1} ⟨W, ∂̄γ̃(W′)⟩ ≥ (1/2) ⟨θ̄, Σ_i q_i ∂p_i(W′)⟩ = (1/2) Σ_i q_i ⟨θ̄, ∂p_i(W′)⟩ ≥ (γ_ntk √m / 4) Σ_i q_i ≥ γ_ntk √m / 4.
Now consider the given W_τ with R̂(W_τ) < ℓ(0)/n and ∥W_τ − W_0∥ ≤ R/2. Since s ↦ W_s traces out a continuous curve, and since norms grow monotonically and unboundedly after time τ (cf. Lemma B.5), there exists a unique time t ≥ τ with ∥W_t − W_0∥ = R. Furthermore, since R̂ is nonincreasing throughout gradient flow, R̂(W_s) < ℓ(0)/n holds for all s ∈ [τ, t]. Then
    γ̃(W_t) − γ̃(W_τ) = ∫_τ^t ⟨∂̄γ̃(W_s), Ẇ_s⟩ ds = ∫_τ^t ∥∂̄γ̃(W_s)∥ · ∥Ẇ_s∥ ds ≥ (γ_ntk √m / 4) ∫_τ^t ∥Ẇ_s∥ ds ≥ (γ_ntk √m / 4) ∥∫_τ^t Ẇ_s ds∥ = (γ_ntk √m / 4) ∥W_t − W_τ∥ ≥ γ_ntk R √m / 8 ≥ γ²_ntk m / 256.
Since ∥W_0∥ ≤ 3√m, it follows that ∥W_t∥ ≤ 3√m + γ_ntk √m/32 ≤ 4√m, and the normalized margin satisfies
    γ̄(W_t) ≥ γ̃(W_τ)/∥W_t∥² + (1/∥W_t∥²) ∫_τ^t (d/ds) γ̃(W_s) ds ≥ 0 + (γ²_ntk m/256)/(16m) = γ²_ntk / 4096.
Furthermore, it holds that γ̄(W_s) ≥ γ̄(W_t) for all s ≥ t (Lyu & Li, 2019), which completes the proof for W_t under the standard parameterization.
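The proof above leans on two trajectory facts: the empirical risk is nonincreasing along gradient flow, and parameter norms grow once the data is fit (cf. Lemma B.5). Here is a minimal numerical sketch of both, using small-step gradient descent as a stand-in for gradient flow; the network size, data, seed, and step size are all illustrative, not those of the theorem. Labels are chosen to be realized by the network at initialization, so all margins start positive, which is what drives the norm growth.

```python
import numpy as np

rng = np.random.default_rng(5)
m, d, n = 16, 4, 12
a = rng.standard_normal(m) / np.sqrt(m)
V = rng.standard_normal((m, d))
X = rng.standard_normal((n, d))
# labels realized by the network at initialization => all margins p_i > 0
y = np.sign(np.maximum(X @ V.T, 0.0) @ a)

def risk(a, V):
    p = y * (np.maximum(X @ V.T, 0.0) @ a)   # margins p_i
    return float(np.mean(np.log1p(np.exp(-p))))

def grads(a, V):
    H = np.maximum(X @ V.T, 0.0)
    p = y * (H @ a)
    q = -1.0 / (n * (1.0 + np.exp(p)))       # dR/dp_i for the logistic loss
    ga = (q * y) @ H                         # gradient in the outer layer a
    gV = ((q * y)[:, None] * (X @ V.T > 0) * a[None, :]).T @ X  # inner layer V
    return ga, gV

r0 = risk(a, V)
n0 = np.linalg.norm(a) ** 2 + np.linalg.norm(V) ** 2
for _ in range(300):
    ga, gV = grads(a, V)
    a, V = a - 0.05 * ga, V - 0.05 * gV

assert risk(a, V) < r0   # empirical risk decreases along the trajectory
assert np.linalg.norm(a) ** 2 + np.linalg.norm(V) ** 2 > n0  # norms grow once data is fit
print("risk decreased, norm grew")
```

With positive margins, the gradient has positive alignment with the parameters (by 2-homogeneity, ⟨W, ∂p_i⟩ = 2p_i > 0), so each small step increases ∥W∥².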

Now consider the rebalanced parameters W̃_t := (a_t/√γ_ntk, V_t √γ_ntk); since m ≥ 256/γ²_ntk, which means 16 ≤ γ_ntk √m,
    ∥a_t∥ ≤ ∥a_0∥ + ∥a_t − a_0∥ ≤ 2 + R ≤ γ_ntk √m/8 + γ_ntk √m/32 ≤ γ_ntk √m/4,
    ∥V_t∥ ≤ ∥V_0∥ + ∥V_t − V_0∥ ≤ 2√m + R ≤ 3√m,
and thus the rebalanced parameters satisfy
    ∥W̃_t∥ ≤ ∥a_t/√γ_ntk∥ + ∥V_t √γ_ntk∥ ≤ √(γ_ntk m)/4 + 3√(γ_ntk m) ≤ 4√(γ_ntk m).
Moreover, for any (x, y), since
    p(x, y; W̃_t) = Σ_j (a_j(t)/√γ_ntk) σ(√γ_ntk v_j(t)^T x) = Σ_j a_j(t) σ(v_j(t)^T x) = p(x, y; W_t),
then
    γ̄(W̃_t) = min_i p_i(W̃_t)/∥W̃_t∥² = min_i p_i(W_t)/∥W̃_t∥² ≥ (γ²_ntk m/256)/(16 γ_ntk m) ≥ γ_ntk/4096,
and lastly, to complete the proof, note by AM-GM that ∥W̃_t∥² = ∥a_t∥²/γ_ntk + γ_ntk ∥V_t∥² ≥ 2∥a_t∥·∥V_t∥, whereby γ̄(W̃_t) = γ̃(W_t)/∥W̃_t∥² ≤ γ̃(W_t)/(2∥a_t∥·∥V_t∥).

Next comes a margin-based Rademacher complexity bound; as Rademacher complexity has not been used or defined within this work, here is a brief description of the main definition, with further detail deferred to standard references (Shalev-Shwartz & Ben-David, 2014). First, for a given set of vectors V ⊆ R^n, the Rademacher complexity Rad(V) is
    Rad(V) = (1/n) E_ϵ sup_{u∈V} ⟨ϵ, u⟩,
where ϵ ∈ {±1}^n has iid Rademacher coordinates, meaning Pr[ϵ_i = +1] = 1/2 = Pr[ϵ_i = −1]. The set V will typically be the set of outputs of some class of predictors G on a finite sample X = (x_i)_{i=1}^n of size n, using the notation
    G_{|X} = {(g(x_1), …, g(x_n)) : g ∈ G} ⊆ R^n.
The bound below will replace G with a variety of bounded two-layer networks of any width. This bound can be viewed as a strengthening of the proofs of Vardi et al. (2022), as the bound here holds for all widths simultaneously (with no dependence on width), and is normalized by the tighter quantity Σ_j ∥a_j v_j∥.

Lemma C.5.
For any B ≥ 0 and any X = (x_i)_{i=1}^n with ∥x_i∥ ≤ 1,
    Rad({x ↦ F(x; W) : m ≥ 0, W ∈ R^{m×(d+1)}, ∥W∥² ≤ 2B}_{|X})
    ≤ Rad({x ↦ F(x; W) : m ≥ 0, (a, V) = W ∈ R^{m×(d+1)}, ∥a∥·∥V∥ ≤ B}_{|X})
    ≤ Rad({x ↦ F(x; W) : m ≥ 0, (a, V) = W ∈ R^{m×(d+1)}, Σ_j ∥a_j v_j∥ ≤ B}_{|X})
    ≤ 2B/√n.

Proof. The first two inequalities are easier, and follow by set inclusion. In detail, note for any fixed W = (a, V), by Cauchy-Schwarz and AM-GM, that
    Σ_j ∥a_j v_j∥ = Σ_j |a_j| · ∥v_j∥ ≤ ∥a∥ · ∥V∥ ≤ (∥a∥² + ∥V∥²)/2 = ∥W∥²/2,
which implies for each m the inclusions
    {W ∈ R^{m×(d+1)} : ∥W∥² ≤ 2B} ⊆ {(a, V) = W ∈ R^{m×(d+1)} : ∥a∥·∥V∥ ≤ B} ⊆ {(a, V) = W ∈ R^{m×(d+1)} : Σ_j ∥a_j v_j∥ ≤ B},
which in turn implies the first two Rademacher inequalities in the statement, since Rademacher complexity can not decrease with the growth of sets.

For the final inequality, recall the definition of the symmetric convex hull sconv(·) as used throughout Rademacher complexity arguments (Shalev-Shwartz & Ben-David, 2014):
    sconv(S) := {Σ_{j=1}^m p_j u_j : m ≥ 0, p ∈ R^m, ∥p∥_1 ≤ 1, u_j ∈ S}.
Then, recalling the notation ā_j and v̄_j for normalized counterparts to a_j and v_j, the set of vectors U in the final Rademacher term can be written
    U := {x ↦ Σ_j a_j σ(v_j^T x) : m ≥ 0, W ∈ R^{m×(d+1)}, Σ_j ∥a_j v_j∥ ≤ B}_{|X}
    = {x ↦ Σ_j ∥a_j v_j∥ ā_j σ(v̄_j^T x) : m ≥ 0, W ∈ R^{m×(d+1)}, Σ_j ∥a_j v_j∥ ≤ B}_{|X}
    = {x ↦ (Σ_k ∥a_k v_k∥) Σ_j (∥a_j v_j∥ / Σ_k ∥a_k v_k∥) ā_j σ(v̄_j^T x) : m ≥ 0, W ∈ R^{m×(d+1)}, Σ_j ∥a_j v_j∥ ≤ B}_{|X}
    = B · {x ↦ Σ_j p_j σ(v̄_j^T x) : m ≥ 0, W ∈ R^{m×(d+1)}, p ∈ R^m, ∥p∥_1 ≤ 1}_{|X}
    = B · sconv({x ↦ σ(v^T x) : ∥v∥_2 = 1}_{|X}),
and by standard rules of Rademacher complexity (Shalev-Shwartz & Ben-David, 2014),
    Rad(U) ≤ B · Rad(sconv({x ↦ σ(v^T x) : ∥v∥_2 = 1}_{|X})) ≤ 2B · Rad({x ↦ σ(v^T x) : ∥v∥_2 = 1}_{|X}) ≤ 2B/√n.

The large margin generalization bound is now an immediate consequence of Lemma C.5 and a refined margin-based Rademacher bound due to Srebro et al. (2010, Theorem 5).
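The chain of inequalities Σ_j ∥a_j v_j∥ ≤ ∥a∥·∥V∥ ≤ ∥W∥²/2 that drives the set inclusions in the proof above can be checked directly on random weights; a quick sketch (sizes and seed are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
for _ in range(100):
    a = rng.standard_normal(10)        # outer layer weights
    V = rng.standard_normal((10, 4))   # inner layer rows v_j

    sum_ajvj = np.sum(np.abs(a) * np.linalg.norm(V, axis=1))  # sum_j ||a_j v_j||
    mixed = np.linalg.norm(a) * np.linalg.norm(V)             # ||a|| * ||V||_F
    half_sq = 0.5 * (np.linalg.norm(a) ** 2 + np.linalg.norm(V) ** 2)  # ||W||^2 / 2

    assert sum_ajvj <= mixed + 1e-9    # Cauchy-Schwarz on (|a_j|)_j and (||v_j||)_j
    assert mixed <= half_sq + 1e-9     # AM-GM
print("set inclusions in Lemma C.5 are justified")
```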
This bound will use the refined normalized margin γ_1 (cf. eq. (C.6)).

Proof. Combining a refined Rademacher-based margin bound due to Srebro et al. (2010, Theorem 5) with the 2-layer Rademacher complexity estimate from Lemma C.5 gives, with probability at least 1 − δ, for every margin level γ_2 > 0 and for every W ∈ R^{m×(d+1)} with γ_1(W) ≥ γ_2, defining …

Then, for the point z_1, since √(1 − γ_0²) > γ_0, it follows that
    E ⟨θ((a, v)), ∂p(z_1, +1; (a, v))⟩
    = E 1[v_2 ≥ |v_1| ∧ ∥(a, v)∥ ≤ 2] (0 + |a| σ′(γ_0 v_1 + √(1 − γ_0²) v_2) √(1 − γ_0²))
    + E 1[v_2 < 0 ∧ ∥(a, v)∥ ≤ 2] (0 − |a| σ′(γ_0 v_1 + √(1 − γ_0²) v_2) √(1 − γ_0²))
    ≥ E |a| 1[v_2 ≥ |v_1| ∧ ∥(a, v)∥ ≤ 2] √(1 − γ_0²)
    ≥ √(1 − γ_0²) / 16,
where the last step follows from standard Gaussian computations. The case for z_2 is analogous, which establishes that Assumption 1.2 holds with γ_ntk ≥ √(1 − γ_0²)/16, which implies the result by applying Theorem 2.1, using the simplification √(1 − γ_0²) ≥ 1/2 since γ_0 ≤ 1/4, and lastly obtaining convergence to KKT directions via (Lyu & Li, 2019; Ji & Telgarsky, 2020a).

D PROOFS FOR SECTION 3

This section develops the proofs of Theorem 3.2 and Theorem 3.3. Before proceeding, here is a quick sampling bound which implies that there exist ReLUs pointing in good directions at initialization; it is the source of the exponentially large widths in the two statements.

Lemma D.1. … then with probability at least 1 − δ, for every (α_k, β_k) there exists (a_j, v_j) with sgn(α_k) = sgn(a_j) and ∥β_k − v_j∥ ≤ ϵ (equivalently, v_j^T β_k ≥ 1 − ϵ²/2).

Proof. By standard sampling estimates (Ball, 1997, Lemma 2.3), for any fixed k and j,
    Pr[∥v̄_j − β_k∥ ≤ ϵ] ≥ (1/2)(ϵ/2)^{d−1},
and since all ((a_j, v_j))_{j=1}^m are iid,
    Pr[∃j : sgn(α_k) = sgn(a_j) ∧ ∥v̄_j − β_k∥ ≤ ϵ]
    = 1 − Pr[∀j : sgn(α_k) ≠ sgn(a_j) ∨ ∥v̄_j − β_k∥ > ϵ]
    = 1 − Pr[sgn(α_k) ≠ sgn(a_1) ∨ ∥v̄_1 − β_k∥ > ϵ]^m
    = 1 − (1/2 + (1/2)(1 − Pr[∥v̄_1 − β_k∥ ≤ ϵ]))^m
    ≥ 1 − (1 − (ϵ/2)^{d−1}/4)^m
    ≥ 1 − exp(−(m/4)(ϵ/2)^{d−1})
    ≥ 1 − δ/r,
and union bounding over all (β_k)_{k=1}^r gives the first claim; for the alternative form, it suffices to note ∥v̄_j − β_k∥² = 2 − 2 v̄_j^T β_k and to rearrange.

First comes the proof of Theorem 3.2, whose entirety is the construction of a potential Φ and a verification that it satisfies the conditions in Lemma B.7.

Proof of Theorem 3.2. The method of proof is to define a potential Φ as
    Φ(W) := (1/4) Σ_k |α_k| ln Σ_j ϕ_{k,j} ∥a_j v_j∥,
where (heavily dropping time indices and even the argument w_j to reduce clutter)
    ϕ_{k,j}(w_j) := ϕ_{k,j} := ϕ(α_k ā_j σ(v̄_j^T β_k) − (1 − ϵ)∥a_j v_j∥),    ϕ(z) := max{0, min{1, z}},
and to then verify the conditions of Lemma B.7 with τ = 0 and γ := (γ_nc − ϵ)/(2r), where the test error bound follows by Lemma C.6. By the lower bound on m and Lemma D.1, it follows that Φ(W_0) > −∞, and moreover, for any t ≥ 0, by AM-GM,
    (d/dt) ∥a_j v_j∥ = 2 ⟨a_j v_j, ȧ_j v_j + a_j v̇_j⟩ / (2 ⟨a_j v_j, a_j v_j⟩^{1/2}) = −⟨a_j v_j, Σ_i ℓ′_i y_i (v_j σ(v_j^T x_i) + a_j² σ′(v_j^T x_i) x_i)⟩ / (n ∥a_j v_j∥),
while also
    Φ(W_t) ≤ (1/2) Σ_k |α_k| ln Σ_j ϕ_{k,j} ∥w_j∥² ≤ (1/2) Σ_k |α_k| ln ∥W_t∥²


= −(1/n) Σ_i ℓ′_i p_i(w_j) ∥w_j∥² / ∥a_j v_j∥ = −(1/n) Σ_i ℓ′_i y_i ā_j σ(v̄_j^T x_i) ∥w_j∥²,
then (d/dt)Φ(W_t) can be lower bounded as
    (d/dt) Φ(W) = (1/4) Σ_k |α_k| [Σ_j (ϕ_{k,j} (d/dt)∥a_j v_j∥ + ∥a_j v_j∥ (d/dt)ϕ_{k,j})] / [Σ_j ϕ_{k,j} ∥a_j v_j∥]
    ≥ (1/4) Σ_k |α_k| [−(1/n) Σ_i ℓ′_i y_i Σ_j ϕ_{k,j} ā_j σ(v̄_j^T x_i) ∥w_j∥²] / [Σ_j ϕ_{k,j} ∥a_j v_j∥]
    = (Q/(4n)) Σ_k |α_k| [Σ_{i∈S_k} q_i y_i Σ_j ϕ_{k,j} ᾱ_k σ(((v̄_j − β_k) + β_k)^T x_i) ∥w_j∥²] / [Σ_j ϕ_{k,j} ∥a_j v_j∥]
    ≥ (Q/(4n)) Σ_k α_k Σ_{i∈S_k} q_i [Σ_j ϕ_{k,j} (γ_nc − ϵ) 2∥a_j v_j∥] / [Σ_j ϕ_{k,j} ∥a_j v_j∥]
    ≥ (Q/n) · (γ_nc − ϵ)/(2r);
the rest of the proof will establish eq. (D.1). Note moreover that eq. (D.1) has an explicit interpretation as nodes getting trapped in good directions.

The first property is that a_k² = b_k² for all times t; this follows directly from the initial condition a_k(0)² = b_k(0)² = 1/√m, since at any later time t,
    a_k(t)² − b_k(t)² = a_k(t)² − a_k(0)² − b_k(t)² + b_k(0)² = ∫_0^t (a_k ȧ_k − b_k ḃ_k) ds = (1/n) ∫_0^t Σ_i |ℓ′_i| (a_k σ(b_k v̄_k^T x_i) − a_k b_k σ′(b_k v̄_k^T x_i) v̄_k^T x_i) ds = 0,
where the last step uses σ′(z) z = σ(z). This also implies that a_k² + b_k² = 2a_k² = 2|a_k|·|b_k| throughout. Next, for each β_k, choose j so that ∥b̄_j(0) v̄_j − β_k∥ ≤ ϵ = γ_gl/2 and ā_j(0) = sgn(α_k); this holds with probability at least 1 − δ via Lemma D.1, since b̄_j(0) v̄_j is equal in distribution to sampling v̄_j alone. For the rest of the proof, reorder the weights ((a_j, b_j, v̄_j))_{j=1}^m so that each (α_k, β_k) is associated with (a_k, b_k, v̄_k). Moreover, it will be shown later in the proof that ∥b̄_j(t) v̄_j − β_k∥ ≤ ϵ and ā_j(t) = sgn(α_k) in fact hold for all t.
Σ_i |ℓ′_i| y_i |α_k| a_k σ(b_k v̄_k^T x_i) / (a_k² + b_k²)
    = (Q/n) Σ_k Σ_i q_i y_i α_k |a_k| ∥b_k v̄_k∥ σ(b̄_k v̄_k^T x_i) / (2|a_k|·|b_k|)
    = (Q/n) Σ_k Σ_i q_i y_i α_k σ(((b̄_k v̄_k − β_k) + β_k)^T x_i)
    ≥ (Q/n) Σ_k Σ_i q_i y_i α_k σ(β_k^T x_i) − (Q/n) Σ_k Σ_i q_i |α_k| ∥b̄_k v̄_k − β_k∥
    ≥ (Q/n) γ_gl Σ_i q_i − (ϵ Q/n) Σ_k Σ_i q_i |α_k|
    = (Q/n) Σ_i q_i (γ_gl − ϵ) = Q γ_gl/(2n) > 0,
which establishes the desired lower bound on (d/dt)Φ for t ∈ [0, T), but moreover establishes (after integrating over [0, T)) that Φ(W_T) ≥ Φ(W_0) > −∞, which means there can not exist k with a_k(T) = 0, since that would mean b_k(T) = 0 as well (as above) and thereby Φ(W_T) = −∞. Consequently, T = ∞, and (d/dt)Φ ≥ Q γ_gl/(2n) holds for all t, and the proof is complete.
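The balancedness property a_k² = b_k² used above is driven purely by the 1-homogeneity of the ReLU, and it survives time discretization up to O(η) error. Below is a minimal numerical sketch in the scalar parameterization (frozen unit directions v̄_j, trainable scalars a_j, b_j); the data, seed, step size, and iteration count are illustrative assumptions, not the theorem's setting.

```python
import numpy as np

rng = np.random.default_rng(2)
m, d, n = 6, 3, 8
Vbar = rng.standard_normal((m, d))
Vbar /= np.linalg.norm(Vbar, axis=1, keepdims=True)   # frozen unit directions: no rotation
X = rng.standard_normal((n, d))
y = np.sign(rng.standard_normal(n))

# balanced initialization a_j(0)^2 = b_j(0)^2 = 1/sqrt(m), as in the proof
a = np.full(m, m ** -0.25)
b = np.full(m, m ** -0.25)

def grads(a, b):
    Z = X @ Vbar.T                        # z_ij = vbar_j^T x_i
    H = np.maximum(b * Z, 0.0)            # relu(b_j z_ij)
    p = y * (H @ a)                       # margins p_i
    q = -1.0 / (n * (1.0 + np.exp(p)))    # dR/dp_i for the logistic loss
    ga = (q * y) @ H
    gb = ((q * y) @ (Z * (b * Z > 0))) * a
    return ga, gb

eta = 1e-3
for _ in range(1000):
    ga, gb = grads(a, b)
    a, b = a - eta * ga, b - eta * gb

# gradient flow conserves a_j^2 - b_j^2 exactly (since sigma'(z) z = sigma(z));
# small-step gradient descent conserves it up to discretization error
assert np.max(np.abs(a ** 2 - b ** 2)) < 1e-2
print("a_j^2 - b_j^2 approximately conserved")
```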



2∥σ(g)∥ 2 . i q i (W ′ ) ≥ 1 Ji & Telgarsky (2019, Lemma 5.4, first part, which does not depend on linear



As purely technical contributions, this work provides new tools to analyze low-width networks near initialization (cf. Lemmas B.4 and C.4), a new versatile generalization bound technique (cf. Lemma C.5), and a new potential function technique for global margin maximization far from initialization (cf. Lemma B.7 and applications thereof).

(a) Trajectories (|a_j| v_j)_{j=1}^m with m = 16. (b) Trajectories (|a_j| v_j)_{j=1}^m with m = 256.

Figure 1: Two-dimensional projection of n = 64 samples drawn from the 2-sparse parity distribution in d = 20 dimensions (cf. Proposition 1.3), with red and blue circles respectively denoting negative and positive examples. Red paths correspond to trajectories |a_j| v_j across time with a_j < 0, whereas blue paths have a_j > 0.

Rotation for mnist, m ∈ {256, 1024, 4096}.

Figure 2: Cumulative distribution functions (CDFs) for rotation on 2-sparse parity and mnist digits 3 vs 5, with three choices of width; both are run with small step size and full batch gradient descent until the empirical logistic risk is 1/n, and the 2-sparse parity plots match the same invocation which gave Figure 1. To measure rotation, for any given width m, per-node rotations vj (0)

Let k ≥ 4 be an even integer, and consider any k-sparse parity data distribution: inputs are supported on H_d := {±1/√d}^d (as in Proposition 1.3), and for any x ∈ H_d, the label is the product of k fixed coordinates: y := d^{k/2} ∏_{i∈S} x_i with |S| = k. Then Assumption 1.1 holds with γ_gl ≥ 1/(2k√d).
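A quick sketch of this distribution (the hidden support S, sample size, and seed are arbitrary): inputs are uniform on {±1/√d}^d, and the rescaling d^{k/2} makes the product label exactly ±1 while inputs have unit norm.

```python
import numpy as np

rng = np.random.default_rng(3)
d, k, n = 20, 4, 1000
S = rng.choice(d, size=k, replace=False)            # hidden support, |S| = k

# inputs uniform on {+-1/sqrt(d)}^d, labels a rescaled k-coordinate product
X = rng.choice([-1.0, 1.0], size=(n, d)) / np.sqrt(d)
y = d ** (k / 2) * np.prod(X[:, S], axis=1)

assert np.allclose(np.abs(y), 1.0)                  # labels are +-1
assert np.allclose(np.linalg.norm(X, axis=1), 1.0)  # inputs lie on the unit sphere
print("label balance:", float(np.mean(y > 0)))
```

Each coordinate has magnitude 1/√d, so a product of k of them has magnitude d^{−k/2}, which the factor d^{k/2} cancels.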

and moreover Σ_v |α_v| = 1 and ∥β_v∥_2 = 1. Now let any x ∈ H_d with corresponding label y = sgn(x) be given. To develop a first simplification of the margin y g(x), let P(S)[x; j] denote those v ∈ P(S) where the signs of x and v have j disagreements within the support of S, whereby |P(S)[x; j]| = (k choose j) and
    y g(x) = sgn(x) Σ_{v∈P(S)} …


and AM-GM imply γ_1(W) ≥ min_i p_i(W) / (∥a∥·∥V∥) ≥ 2 γ̄(W), as in the proof of Lemma C.5.

Lemma C.6. With probability at least 1 − δ over the draw of ((x_i, y_i))_{i=1}^n, every width m and every choice of weights (a, V) = W ∈ R^{m×(d+1)} with γ_1(W) > 0 (cf. eq. (C.6)) satisfies
    Pr[p(x, y; W) ≤ 0] ≤ O(ln(n)³ / (n γ_1(W)²) + ln(1/δ)/n).
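Since eq. (C.6) is not reproduced here, the sketch below assumes γ_1 normalizes the minimal margin by Σ_j ∥a_j v_j∥ (consistent with Lemma C.5), and checks the displayed nesting of the three normalized margins on a random network whose labels are chosen so that all margins are nonnegative; all sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)
m, d, n = 12, 5, 30
a = rng.standard_normal(m)
V = rng.standard_normal((m, d))
X = rng.standard_normal((n, d))

f = np.maximum(X @ V.T, 0.0) @ a    # network outputs
y = np.sign(f)                      # labels realized by the network
p = y * f                           # unnormalized margins p_i >= 0

sum_ajvj = np.sum(np.abs(a) * np.linalg.norm(V, axis=1))
mixed = np.linalg.norm(a) * np.linalg.norm(V)
W_sq = np.linalg.norm(a) ** 2 + np.linalg.norm(V) ** 2

gamma_1 = p.min() / sum_ajvj        # margin over sum_j ||a_j v_j|| (assumed form of eq. (C.6))
gamma_mixed = p.min() / mixed       # margin over ||a|| * ||V||
gamma_bar = p.min() / W_sq          # margin over ||W||^2

assert gamma_1 >= gamma_mixed - 1e-12
assert gamma_mixed >= 2 * gamma_bar - 1e-12
print("normalized margins nest as claimed")
```

The nesting is just the chain Σ_j ∥a_j v_j∥ ≤ ∥a∥·∥V∥ ≤ ∥W∥²/2 applied to a common nonnegative numerator.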

Let ϵ > 0 and ((α_k, β_k))_{k=1}^r with (α_k, β_k) ∈ R × R^d be given with ∥β_k∥ = 1, and suppose ((a_j, v_j))_{j=1}^m are sampled iid so that Pr[sgn(a_j) = +1] = Pr[sgn(a_j) = −1] = 1/2 and v_j is distributed uniformly on the surface of the unit sphere in R^d (e.g., sample ṽ_j ∼ N_d and choose v_j := ṽ_j/∥ṽ_j∥). If

ln(a_k² + b_k²). Note directly that Φ(W_0) > −∞ by the above application of Lemma D.1 and the choice of ((a_k, b_k))_{k=1}^r, and moreover that it only remains to verify (d/dt)Φ(W_t) ≥ Q γ_gl/(2n). To this end, let T denote the earliest time such that a_k(T) = 0 for some k ∈ {1, …, r}, which also means b_k(T) = 0 for that k; moreover, T is the earliest time with b_{k′}(T) = 0 for any k′ ∈ {1, …, r}, since a_k² = b_k² unconditionally for all t. Then, for any t ∈ [0,

3, as well as the training error analysis in Lemma C.3, were both established with a variant …

Table 1: Performance on 2-sparse parity by a variety of works, loosely organized by technique; see Section 1.2 for details. Briefly, m denotes width, n denotes the total number of samples (across all iterations), and t denotes the number of algorithm iterations. Overall, the table exhibits many tradeoffs, and there is no single best method.

Theorem 2.3. Suppose the data distribution satisfies Assumption 1.2 for some γ_ntk > 0, let time t be given, and suppose the width m and step size η satisfy … Then, with probability at least 1 − 8δ, the SGD iterates (W_s)_{s≤t} with logistic loss ℓ = ℓ_log satisfy

Tina Behnia. Imbalance trouble: Revisiting neural-collapse geometry, 2022. URL https://arxiv.org/abs/2208.05512.
Gal Vardi, Ohad Shamir, and Nathan Srebro. The sample complexity of one-hidden-layer neural networks, 2022. URL https://arxiv.org/abs/2202.06233.

((x_i, y_i))_{i=1}^n be an iid draw. Let (W_s)_{s≥0} denote the GF curve resulting from loss ℓ ∈ {ℓ_log, ℓ_exp}. Suppose the width m satisfies …

Before discussing the proof, a few remarks are in order. Firstly, the final large margin iterate W_t is stated as explicitly achieving some distance from initialization; needing such a claim is unsurprising, as the margin definition requires a lot of motion in a good direction to clear the noise in W_0. In particular, it is unsurprising that moving O(√m) is needed to achieve a good margin, given that the initial weight norm is O(√m); analogously, it is not surprising that Lemma C.3 can not be used to produce a meaningful lower bound on γ̄(W_τ) directly. Lastly, while these comments seem natural when normalizing by ∥W_t∥², the normalization ∥a_t∥·∥V_t∥ does not obviously have these deficiencies.

Another thing to highlight is that the use of Lemma C.3 for a warm start is only needed for ℓ_log, and not for ℓ_exp; it is unclear how this discrepancy translates to practice, where ℓ_log dominates.
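One reason statements transfer between the two losses after a warm start: in the large-margin tail they agree to first order, since ln(1 + u) = u + O(u²) with u = e^{−z}. A quick check (illustrative):

```python
import math

# tail comparison: l_log(z) = ln(1 + e^{-z}) versus l_exp(z) = e^{-z}
for z in [2.0, 5.0, 10.0, 20.0]:
    l_log = math.log1p(math.exp(-z))
    l_exp = math.exp(-z)
    # ln(1+u) <= u, and ln(1+u) >= u - u^2/2, so the ratio tends to 1 from below
    assert l_log <= l_exp
    assert l_log / l_exp >= 1.0 - l_exp / 2
print("logistic and exponential losses agree in the tail")
```

The two losses only differ meaningfully for small or negative margins, which is exactly the regime the warm start (Lemma C.3) is used to escape.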

(1/2) ln ∥W_t∥, whereby it only remains to show dΦ/dt ≥ Q γ/n. To this end, note that if we could show

ACKNOWLEDGMENTS

MT thanks Fanny Yang, the Simons Institute, and the NSF (grant IIS-1750051).


which implies the desired statement.

Thanks to Lemmas C.3 and C.4, the proof of Theorem 2.1 is now immediate.

Proof of Theorem 2.1. As in the statement, define R := γ_ntk √m/32; the analysis now uses two stages. The first stage is handled by Lemma C.3, run until time τ := n, whereby, with probability at least 1 − 7δ, … The second stage now follows from Lemma C.4: since W_τ as above satisfies all the conditions of Lemma C.4, there exists W_t with ∥W_t − W_0∥ = R and γ̃(W_t)/(∥a_t∥·∥V_t∥) ≥ γ_ntk/2048, and since even these mixed-norm margins are nondecreasing (Lyu & Li, 2019, Section H), the claim also holds for all W_s with s ≥ t, and the generalization bound follows from Lemma C.6, using Σ_j ∥a_j v_j∥ ≤ ∥a∥·∥V∥ and min_i p_i(W) ≥ γ̃(W), which holds for both {ℓ_log, ℓ_exp}. Lastly, since ∥W_s∥ → ∞ and since lim_{s→∞} ∥a_s∥/∥V_s∥ = 1 (Du et al., 2018a), it follows that lim_{s→∞} ∥W_s∥²/(∥a_s∥·∥V_s∥) = 2, which gives the final claim.

Lastly, the proof of Corollary 2.2, giving a simple construction where GF escapes bad KKT points.

Proof of Corollary 2.2. First it is shown that the provided choice of (a, V) with a_j = 1 and v_j = (1, 0) is a KKT direction. With probability 1 − 2^{1−n}, both elements of the support of the distribution are sampled; for convenience, reorder the sampled data so that x_1 = z_1 and x_2 = z_2, and the other data are arbitrary (though y_i = +1 for all examples). It will be shown that the choice λ_1 = λ_2 = 1/(2γ_0) and λ_i = 0 for i ≥ 3 is a valid choice of Lagrange multipliers, certifying that (a, V) is a KKT direction.
Firstly, the derivative condition is easy to check, since the Clarke differential is evaluated where the ReLU is differentiable, and it holds directly that …

Lastly, note that p_1(W) = mγ_0 = p_2(W), therefore the rescaling W̃ := W/√(mγ_0) correctly satisfies p_1(W̃) = 1 = p_2(W̃), whereby W̃ is a KKT point with margin mγ_0/∥W∥² = γ_0/2, and W is a KKT direction with margin γ_0/2.

For the GF guarantee, it suffices to provide a quick estimate of γ_ntk and invoke Lemma C.3. Specifically, consider the rather loose but convenient weight mapping …

Fix any j, k, t, and first note … whereby (d/dt)ϕ_{k,j} = 0 when the argument to ϕ is not in [0, 1], and otherwise … Analyzing the two bracketed terms separately, the first (the coefficient of ∥v∥²) is nonnegative, since the term in parentheses is a rescaling of the argument to ϕ within ϕ_{k,j}, which was assumed to lie in [0, 1]. The second bracketed term (the coefficient of a_j²) is more complicated. To start, fix any example (x_i, y_i), define z_i := x_i y_i for convenience, and define an orthogonal decomposition of z_i along β_k …, where in particular c_⊥/c ≤ ϵ/2 by Assumption 3.1. Similarly, define u_j := a_j v_j for convenience, and additionally an orthogonal decomposition ū_j = q β_k + √(1 − q²) u_⊥, which made use of ∥ū_j∥ = ∥v̄_j∥ = 1, and note … With this notation in hand, the second term becomes … as desired: dϕ_{k,j}/dt ≥ 0 for every pair (k, j), meaning eq. (D.1) has been established, and the proof is complete.

To close, the proof of Theorem 3.3.

Proof of Theorem 3.3. As in the proof of Theorem 3.2, the method of proof is to construct a potential function Φ and then verify the conditions of Lemma B.7 with the choices τ = 0 and γ := γ_gl/2 = γ_gl − ϵ, where ϵ := γ_gl/2 throughout the proof, and to then apply Lemma C.6 to obtain the test error bound. Throughout the proof, use W = ((a_j, b_j))_{j=1}^m to denote the full collection of parameters, even in this scalar parameter setting, and define b̄_j := sgn(b_j) in mimicry of ā_j and v̄_j.
To develop Φ, a few other properties must first be checked.

