DEPTH SEPARATION WITH MULTILAYER MEAN-FIELD NETWORKS

Abstract

Depth separation, that is, explaining why a deeper network is more powerful than a shallower one, has been a major problem in deep learning theory. Previous results often focus on representation power. For example, Safran et al. (2019) constructed a function that is easy to approximate using a 3-layer network but not approximable by any 2-layer network. In this paper, we show that this separation is in fact algorithmic: one can efficiently learn the function constructed by Safran et al. (2019) using an overparameterized network with polynomially many neurons. Our result relies on a new way of extending the mean-field limit to multilayer networks, and on a decomposition of the loss that factors out the error introduced by discretizing infinite-width mean-field networks.

1. INTRODUCTION

One of the mysteries in deep learning theory is why we need deeper networks. In early attempts, researchers showed that deeper networks can represent functions that are hard for shallow networks to approximate (Eldan & Shamir, 2016; Telgarsky, 2016; Poole et al., 2016; Daniely, 2017; Yarotsky, 2017; Liang & Srikant, 2017; Safran & Shamir, 2017; Poggio et al., 2017; Safran et al., 2019; Malach & Shalev-Shwartz, 2019; Vardi & Shamir, 2020; Venturi et al., 2022; Malach et al., 2021). In particular, the seminal works of Eldan & Shamir (2016) and Safran et al. (2019) constructed a simple function ($f^*(x) = \mathrm{ReLU}(1-\|x\|)$) that can be computed by a 3-layer neural network but cannot be approximated by a 2-layer network. However, these results concern only the representation power of neural networks; they do not guarantee that training a deep neural network from a reasonable initialization can actually learn such functions. In this paper, we prove that one can train a neural network that approximates $f^*(x) = \mathrm{ReLU}(1-\|x\|)$ to any desired accuracy, which gives an algorithmic separation between the power of 2-layer and 3-layer networks. To analyze the training dynamics, we develop a new framework that generalizes the mean-field analysis of neural networks (Chizat & Bach, 2018; Mei et al., 2018) to multiple layers. As a result, all the layer weights can change significantly during training (unlike in many previous works on the neural tangent kernel or with fixed lower-layer representations). Our analysis also gives a decomposition of the loss that allows us to decouple the training of multiple layers. In the remainder of the paper, we first introduce our new framework for multilayer mean-field analysis, and then give our main result and techniques. We discuss several related works on the algorithmic aspect of depth separation in Section 1.3.
Similar to standard mean-field analysis, we first consider the infinite-width dynamics in Section 3, then we discuss our new ideas in discretizing the result to a polynomial-size network (see Section 4).

1.1. MULTI-LAYER MEAN-FIELD FRAMEWORK

We propose a new way to extend mean-field analysis to multiple layers. For simplicity, we state it for 3-layer networks here; see Appendix A for the general framework. In short, we break the middle layer into two linear layers and restrict the size of the layer in between. More precisely, we consider

$$f(x) = \frac{1}{m_2}\, a_2^\top \sigma(W_2 F(x)), \qquad F(x) = \frac{1}{m_1}\, A_1 \sigma(W_1 x),$$

where $W_1 \in \mathbb{R}^{m_1\times d}$, $A_1 \in \mathbb{R}^{D\times m_1}$, $W_2 \in \mathbb{R}^{m_2\times D}$ and $a_2 \in \mathbb{R}^{m_2}$ are the parameters, and $F(x) \in \mathbb{R}^D$ represents the hidden feature. See Figure 1 for an illustration. Later we will refer to the map $x \mapsto F(x)$ as the first layer and $F(x) \mapsto f(x)$ as the second layer, even though each of them is actually a two-layer network. In the infinite-width limit, we fix the hidden feature dimension $D$ and let the numbers of neurons $m_1, m_2$ go to infinity. This gives the infinite-width network

$$f(x) = \mathbb{E}_{(a_2,w_2)\sim\mu_2}\, a_2\,\sigma(w_2\cdot F(x)), \qquad F_i(x) = \mathbb{E}_{(a_1,w_1)\sim\mu_{1,i}}\, a_1\,\sigma(w_1\cdot x), \quad \forall i\in[D],$$

where $(\mu_{1,i})_{i\in[D]}$ are distributions over $\mathbb{R}^{1+d}$ with a shared marginal distribution over $w_1$, and $\mu_2$ is a distribution over $\mathbb{R}^{1+D}$. Note that, unlike the formulation in Nguyen & Pham (2020), here the hidden layers are described by distributions of neurons and hence are automatically invariant under permutations of neurons, which is one of the most important properties of mean-field networks. Choosing $\mu_1, \mu_2$ to be empirical distributions over finitely many neurons recovers a finite-width network. In fact, we do so in most parts of the paper, so that our results apply to finite-width networks with polynomially many neurons. The network can be viewed as a 3-layer network whose intermediate weight matrix $W_2 A_1$ is low rank (its rank is at most $D$). This is reminiscent of the bottleneck structure used in ResNet (He et al., 2016) and has also been used, for other purposes, in previous theoretical analyses such as Allen-Zhu & Li (2020).

Learner network. Now we are ready to introduce the specific network that we use to learn the target function.
We set $D = 1$ and couple $a_1$ with $w_1$:

$$F(x) = F(x;\mu_1) := \mathbb{E}_{w\sim\mu_1}\{\|w\|\,\sigma(w\cdot x)\}, \qquad f(x) = f(x;\mu_2,\mu_1) := \mathbb{E}_{(w_2,b_2)\sim\mu_2}\,\sigma(w_2 F(x;\mu_1) + b_2). \tag{1}$$

Here $\sigma$ is the ReLU activation, and $\mu_1 \in \mathcal{P}(\mathbb{R}^d)$ and $\mu_2 \in \mathcal{P}(\mathbb{R}^2)$ are distributions encoding the weights of the first and second hidden layers, respectively. We multiply each first-layer neuron by $\|w\|$ to make $F$ more regular; this 2-homogeneous parameterization is also used in Li et al. (2020) and Wang et al. (2020). In most parts of the paper, $\mu_1$ and $\mu_2$ are empirical distributions over polynomially many neurons; we use $\mu_1, \mu_2$ to unify the notation for infinite- and finite-width networks. Restricting the intermediate layer to a single dimension ($D = 1$) is sufficient, since one can learn $x \mapsto \alpha\|x\|$ for some $\alpha\in\mathbb{R}$ with the first layer $F(x)$ and $\alpha\|x\| \mapsto \sigma(1-\|x\|)$ with the second layer. For the network that computes $F(x)$, we do not need a bias term, as the intended function is homogeneous in $x$. Although we restrict the first-layer output to be nonnegative, this does not restrict the representation power of the network, since the second-layer weights can be either positive or negative. For the second layer, even though a single neuron is sufficient, we follow the framework and overparameterize the network.
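To make the parameterization (1) concrete, here is a minimal NumPy sketch in which $\mu_1$ and $\mu_2$ are empirical distributions over the rows of hypothetical arrays `W1` (shape $(m_1, d)$), `w2` and `b2` (shape $(m_2,)$). The function names and sizes are illustrative assumptions, not the authors' code. The check at the end illustrates the 2-homogeneity of $F$: scaling every first-layer weight by $c > 0$ scales $F$ by $c^2$.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def first_layer(X, W1):
    """F(x; mu1) = E_{w ~ mu1} ||w|| relu(w . x), with mu1 uniform over the rows of W1.

    X has shape (n, d); W1 has shape (m1, d). Returns an array of shape (n,).
    """
    norms = np.linalg.norm(W1, axis=1)               # ||w|| for each first-layer neuron
    return (relu(X @ W1.T) * norms).mean(axis=1)     # average over first-layer neurons


def learner(X, W1, w2, b2):
    """f(x; mu2, mu1) = E_{(w2, b2) ~ mu2} relu(w2 F(x) + b2), with hidden dimension D = 1."""
    F = first_layer(X, W1)                           # shape (n,)
    return relu(np.outer(F, w2) + b2).mean(axis=1)   # average over second-layer neurons


# Illustrative sizes (not the ones required by the theorem).
rng = np.random.default_rng(0)
n, d, m1, m2 = 5, 8, 64, 16
X = rng.standard_normal((n, d))
W1 = rng.standard_normal((m1, d))
w2 = rng.standard_normal(m2)
b2 = np.full(m2, 0.1)

F = first_layer(X, W1)
F_scaled = first_layer(X, 3.0 * W1)                  # rescale every first-layer neuron by c = 3
print(np.allclose(F_scaled, 9.0 * F))                # 2-homogeneity: F scales by c^2 = 9
print(learner(X, W1, w2, b2).shape)
```

Because $F \ge 0$ in this sketch, the sign flexibility needed to represent $\sigma(1-\|x\|)$ has to come from the second-layer weights `w2`, as noted above.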

1.2. MAIN RESULT AND OUR TECHNIQUES

Our main result applies the framework in the previous section to the function constructed in Safran et al. (2019) (see details in Section 2). Informally, we prove:

Theorem 1.1 (Main result, Informal). Given the learner network defined in (1) with input dimension $d$, for any $\varepsilon > 0$, we can choose layer widths $m_1 = \mathrm{poly}(d, 1/\varepsilon)$ and $m_2 = \Theta(1)$ so that, with probability at least $1 - 1/\mathrm{poly}(d, 1/\varepsilon)$ over the random initialization, running a simple variant of gradient flow reduces the loss $L := \mathbb{E}_x (f^*(x) - f(x))^2/2$ to $\varepsilon$ within time $T = \mathrm{poly}(d, 1/\varepsilon)$.

This result shows that one can train a multilayer neural network to learn the function $\mathrm{ReLU}(1-\|x\|)$, which cannot be approximated by any 2-layer network. Some technical details are caused by the heavy-tailed input distribution chosen in Safran et al. (2019); we discuss them in Section 2.

To prove such a result, we first characterize the infinite-width dynamics (see Section 3). In particular, we show that in the infinite-width dynamics, the first layer always computes a multiple of $\|x\|$, while the second layer behaves like a single neuron. However, it is often difficult to discretize such an infinite-width analysis to a polynomial-width network. The main difficulty is the potential amplification of error in the network: if at the beginning the first layer is $\delta$-close to computing a multiple of $\|x\|$, this $\delta$ can potentially grow exponentially during training (Mei et al., 2018). Given the large polynomial training time of our dynamics, such exponential growth would not be acceptable. To fix this issue, we partition the analysis into two phases, and for the time-consuming second phase, we rely on a decomposition of the loss function:

$$L := \frac{1}{2}\mathbb{E}_{x\sim\mathcal{D}}(f^*(x) - f(x))^2 \approx \frac{1}{2}\mathbb{E}_x(f^*(x) - \bar f(x))^2 + \frac{\bar w_2^2}{2}\mathbb{E}_x(\bar F(x) - F(x))^2. \tag{2}$$

Here $\bar F(x)$ is a multiple of $\|x\|$ that is close to the actual first-layer output $F(x)$, and $\bar f(x)$ is the output of the network when the first layer is replaced by $\bar F(x)$, that is, when the first layer exactly computes a multiple of $\|x\|$ (see (5) for the precise definition). The first term therefore measures the loss given a perfect first layer, while the second term measures the difference between the first-layer output and a multiple of $\|x\|$. We show that the gradients of these two terms do not interfere with each other, at least approximately. Therefore, we can view the training process as doing two things simultaneously: minimizing the loss given a good first-layer representation (reducing the first term) and making the first-layer output closer to a multiple of $\|x\|$ (reducing the second term). We believe this decomposition highlights how the lower layer of the network receives useful gradient information to learn a good representation for this particular objective.

Algorithmic aspect of depth separation

There have been other works that add algorithmic insights into depth separation. Allen-Zhu & Li (2020) showed that multilayer quadratic networks can learn certain target functions in a hierarchical way that cannot be learned by any kernel method or shallow neural network. Our work deals with more standard neural network architectures and target functions. A concurrent work, Safran & Lee (2021), considers a problem similar to ours: they show that GD on a certain three-layer network can learn the ball indicator, which is not approximable by any two-layer network. Conceptually, the main difference between our results lies in the training dynamics: the first layer of Safran & Lee (2021) is fixed, while we train both layers. This leads to very different training dynamics and proof techniques.

Overparameterized neural networks. One line of work studies the optimization of overparameterized neural networks by coupling the training dynamics to kernel regression with the neural tangent kernel (NTK) (e.g., Jacot et al., 2018; Allen-Zhu et al., 2018b; Du et al., 2018). However, in the NTK regime a neural network behaves like a kernel method, and several lower bounds have been established for this regime (Yehudai & Shamir, 2019; Wei et al., 2019; Ghorbani et al., 2019; 2020). Our training dynamics are not in the NTK regime, as all the weights change significantly. Another line of work studies the optimization of overparameterized neural networks in the mean-field limit (Mei et al., 2018; Chizat & Bach, 2018; Nitanda & Suzuki, 2017; Wei et al., 2019; Rotskoff & Vanden-Eijnden, 2018; Sirignano & Spiliopoulos, 2020). Chizat et al. (2019) showed that in the mean-field regime, unlike in the NTK regime, the parameters can move away from their initialization and learn useful features. However, most existing works require an exponential or infinite number of neurons and do not provide a polynomial convergence rate. See Appendix A for more discussion.
Multi-layer mean-field. Although mean-field analysis has been successful for the optimization of two-layer overparameterized networks, it is not easy to extend it to multilayer networks, since the width of the intermediate layer goes to infinity.

Mildly overparameterized neural networks. Recently, many works have considered learning certain target functions with mildly overparameterized (polynomial-size) networks (Allen-Zhu et al., 2018a; Allen-Zhu & Li, 2019; Bai & Lee, 2019; Dyer & Gur-Ari, 2019; Woodworth et al., 2020; Bai et al., 2020; Huang & Yau, 2020; Chen et al., 2020; Li et al., 2020; Wang et al., 2020; Zhou et al., 2021). These works differ from typical mean-field analyses, which usually consider infinite-width networks, and from typical NTK analyses, where the network behaves like a kernel method. Our work is in a similar direction, but we need new insights to extend the discretization to our new multilayer framework.

2. PRELIMINARIES

In this section, we discuss the additional technical conditions on the input distribution in Safran et al. (2019) and how we deal with them in the training process.

Notations. For a vector $x$, we let $\|x\|$ denote its Euclidean norm. For a nonzero vector $v$, we write $\bar v := v/\|v\|$. We use $a = b \pm c$ as shorthand for $a \in [b - |c|, b + |c|]$. For a distribution $\mu$, we write $v \in \mu$ as shorthand for $v \in \operatorname{supp}(\mu)$. We use $\mathbb{E}_x$ as shorthand for $\mathbb{E}_{x\sim\mathcal{D}}$ when it is clear from the context. Other notations are mostly standard. We usually use $v_1$ and $w_1$ to denote a first-layer neuron, and $(v_2, r_2)$ and $(w_2, b_2)$ to denote a second-layer neuron. Keeping two sets of notation for neurons is intentional: when taking expectations over neurons, we use $w_1$ and $(w_2, b_2)$; when considering a single neuron, we use $v_1$ and $(v_2, r_2)$.

Target Function and Input Distribution

The target function we consider is $f^*(x) = \sigma(1-\|x\|)$, where $\sigma:\mathbb{R}\to\mathbb{R}$ is the ReLU activation. To describe the input distribution, we first define

$$\varphi(x) := \left(\frac{R_d}{\|x\|}\right)^{d/2} J_{d/2}(2\pi R_d\|x\|), \qquad R_d = \frac{1}{\sqrt{\pi}}\,\Gamma\!\left(\frac{d}{2}+1\right)^{1/d},$$

where $J_\nu$ is the Bessel function of the first kind of order $\nu$. Let $\alpha, \beta > 0$ be the universal constants from Safran et al. (2019) (cf. the proof of their Theorem 5). We assume the inputs $x\in\mathbb{R}^d$ are sampled from the distribution $\mathcal{D}$ with density $x \mapsto (\sqrt{d}\beta\alpha)^d\,\varphi^2(\sqrt{d}\beta\alpha x)$. Eldan & Shamir (2016) and Safran et al. (2019) verified that this is indeed a valid probability distribution. Note also that $\mathcal{D}$ is spherically symmetric; for more properties of $\mathcal{D}$, see Appendix B.2. By Theorem 5 of Safran et al. (2019), no two-layer network of width $\mathrm{poly}(d, 1/\varepsilon)$ can approximate $f^*$ to accuracy $\varepsilon$ in $L^2(\mathcal{D})$. This distribution is heavy-tailed in the sense that $\mathbb{E}_{x\sim\mathcal{D}}[\|x\|^2]$ is undefined. Such a heavy-tailed distribution is mostly needed for proving the lower bound; our training result holds for most reasonable spherically symmetric distributions.
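For readers who want to evaluate $\varphi$ numerically, here is a small standard-library sketch (helper names are hypothetical). It uses the power series of $J_\nu$ with the factor $\|x\|^{-d/2}$ cancelled analytically, so it is valid at $\|x\| = 0$, but it is only accurate for moderate $\|x\|$, since the alternating series loses precision for large arguments. Two sanity checks follow from the definitions: $R_d$ is the radius of the ball of volume 1 in $\mathbb{R}^d$, and $\varphi$ is the Fourier transform (with the convention $\hat g(\xi) = \int g(x)e^{-2\pi i x\cdot\xi}\,dx$) of that ball's indicator, so $\varphi(0) = 1$, $\int\varphi^2 = 1$ by Plancherel, and for $d = 1$ one gets $\varphi(x) = \sin(\pi x)/(\pi x)$. The constants $\alpha, \beta$ of Safran et al. (2019) are not reproduced here, so the density takes the overall scale $\sqrt{d}\beta\alpha$ as an input.

```python
import math


def unit_volume_radius(d):
    """R_d = Gamma(d/2 + 1)^(1/d) / sqrt(pi), computed through log-gamma."""
    return math.exp(math.lgamma(d / 2 + 1) / d) / math.sqrt(math.pi)


def ball_volume(d, r):
    """Volume of the Euclidean ball of radius r in R^d."""
    return math.pi ** (d / 2) * r ** d / math.gamma(d / 2 + 1)


def phi(r, d, terms=60):
    """phi(x) = (R_d / r)^(d/2) J_{d/2}(2 pi R_d r) as a function of r = ||x|| >= 0.

    Series form with the r^(-d/2) factor cancelled (nu = d/2):
        sum_k (-1)^k pi^(2k+nu) R^(2k+2nu) r^(2k) / (k! Gamma(k+nu+1)).
    Only accurate for moderate r because of cancellation in the alternating series.
    """
    nu = d / 2
    R = unit_volume_radius(d)
    total = 0.0
    for k in range(terms):
        log_coef = ((2 * k + nu) * math.log(math.pi) + (2 * k + 2 * nu) * math.log(R)
                    - math.lgamma(k + 1) - math.lgamma(k + nu + 1))
        total += (-1) ** k * math.exp(log_coef) * r ** (2 * k)
    return total


def density(r, d, scale):
    """Density of D at any x with ||x|| = r; scale = sqrt(d) * beta * alpha is supplied by the caller."""
    return scale ** d * phi(scale * r, d) ** 2


print(ball_volume(7, unit_volume_radius(7)))
print(phi(0.0, 5), phi(0.3, 1), math.sin(0.3 * math.pi) / (0.3 * math.pi))
```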

Training Algorithm and Main Result

We use gradient flow with clipping on the MSE loss to train a polynomial-size network. We write the loss as

$$L = L(\mu_1, \mu_2) = \frac{1}{2}\mathbb{E}_{x\sim\mathcal{D}}(f^*(x) - f(x))^2 =: \mathbb{E}_x L(x). \tag{3}$$

Define $S(x) := (f^*(x) - f(x))\,\mathbb{E}_{w_2,b_2}\{\sigma'(w_2F(x) + b_2)\,w_2\}$. One can verify that the dynamics of the neurons are given by

$$\begin{aligned}
\dot v_1 &= \mathbb{E}_{x\sim\mathcal{D}}\,\Pi_{R_{v_1}}\!\left[S(x)\left(\bar v_1\,\sigma(v_1\cdot x) + \|v_1\|\,\sigma'(v_1\cdot x)\,x\right)\right],\\
\dot v_2 &= \mathbb{E}_{x\sim\mathcal{D}}\,\Pi_{R_{v_2}}\!\left[(f^*(x) - f(x))\,\sigma'(v_2F(x) + r_2)\,F(x)\right],\\
\dot r_2 &= \mathbb{E}_{x\sim\mathcal{D}}\,\Pi_{R_{r_2}}\!\left[(f^*(x) - f(x))\,\sigma'(v_2F(x) + r_2)\right],
\end{aligned}\tag{4}$$

where $\Pi_R$ denotes the projection onto the ball of radius $R$, and $R_{v_1} = \Theta(d)$, $R_{v_2} = \Theta(d^3)$ and $R_{r_2} = \Theta(1)$ are the projection thresholds. We add this gradient clipping because without it the gradients are not well-defined, due to the heavy tail of $\mathcal{D}$. Gradient clipping is also widely used in practice to avoid exploding gradients (Pascanu et al., 2013; Zhang et al., 2020). In fact, we believe our optimization result would still hold without gradient clipping for a general spherically symmetric distribution $\mathcal{D}$, as long as it is more regular.

To initialize the learner network, we use $\mathrm{Unif}(\sigma_1 S^{d-1})$ for the first-layer weights $w_1$ and $\mathcal{N}(0, \sigma_2^2)$ for the second-layer weights $w_2$, and we set all second-layer biases $b_2$ to $\sigma_r$, where $\sigma_1, \sigma_2, \sigma_r$ are small positive real numbers. We initialize $w_1$ on the sphere instead of with a Gaussian only for technical convenience. We initialize the bias to a small positive value so that all second-layer neurons are active at initialization, which avoids zero gradients.

Now we are ready to give our main result. It shows that gradient flow on a polynomial-size learner network (1), defined in our mean-field framework, can efficiently learn $f^*(x) = \sigma(1-\|x\|)$, which is not approximable by any two-layer network (Safran et al., 2019).

Theorem 2.1 (Main result).
Consider the learner network defined in (1) with the initialization described above, and suppose we run gradient flow with clipping (4) on the loss (3), assuming the flow exists on this finite-width network. Then, for any $\varepsilon > 0$, we can choose $m_1 = \mathrm{poly}_{m_1}(d, 1/\varepsilon)$, $m_2 = \Theta(1)$, $\sigma_1 = 1/\mathrm{poly}_{\sigma_1}(d, 1/\varepsilon)$, $\sigma_2 = 1/\mathrm{poly}_{\sigma_2}(d, 1/\varepsilon)$, $\sigma_r = \Theta(1)$, $R_{v_1} = \Theta(d)$, $R_{v_2} = \Theta(d^3)$ and $R_{r_2} = \Theta(1)$ so that, with probability at least $1 - 1/\mathrm{poly}(d, 1/\varepsilon)$ over the random initialization, the loss satisfies $L \le \varepsilon$ within time $T = \mathrm{poly}(d, 1/\varepsilon)$.
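The clipped dynamics (4) and the initialization above can be sketched as follows. This is a minimal NumPy illustration under several assumptions that are not part of the paper: the expectation over $\mathcal{D}$ is replaced by an average over a given sample `X` (sampling from the heavy-tailed $\mathcal{D}$ is not addressed, and a Gaussian stand-in is used), gradient flow is replaced by explicit Euler steps of size `dt`, the thresholds are illustrative numbers, and only the second-layer update is shown. Following (4), the projection is applied to each per-sample term before averaging; for the scalar parameters $v_2, r_2$, projecting onto the ball of radius $R$ is just clipping to $[-R, R]$. All function names are hypothetical.

```python
import numpy as np


def project_ball(g, R):
    """Pi_R: Euclidean projection of a vector g onto the closed ball of radius R."""
    g = np.asarray(g, dtype=float)
    norm = np.linalg.norm(g)
    return g if norm <= R else g * (R / norm)


def init_params(d, m1, m2, sigma1, sigma2, sigma_r, rng):
    """w1 ~ Unif(sigma1 S^{d-1}), w2 ~ N(0, sigma2^2), and every bias b2 = sigma_r."""
    W1 = rng.standard_normal((m1, d))
    W1 *= sigma1 / np.linalg.norm(W1, axis=1, keepdims=True)  # normalized Gaussians are uniform on the sphere
    w2 = sigma2 * rng.standard_normal(m2)
    b2 = np.full(m2, float(sigma_r))
    return W1, w2, b2


def second_layer_step(X, y, W1, w2, b2, R_v2, R_r2, dt):
    """One Euler step of the v2 and r2 equations in (4), with E_x replaced by a sample mean."""
    F = (np.maximum(X @ W1.T, 0.0) * np.linalg.norm(W1, axis=1)).mean(axis=1)  # F(x), shape (n,)
    pre = np.outer(F, w2) + b2                       # v2 F(x) + r2 for every (sample, neuron)
    resid = y - np.maximum(pre, 0.0).mean(axis=1)    # f*(x) - f(x)
    active = (pre > 0).astype(float)                 # sigma'(v2 F(x) + r2)
    g_v2 = resid[:, None] * active * F[:, None]      # per-sample terms inside Pi_{R_v2}
    g_r2 = resid[:, None] * active                   # per-sample terms inside Pi_{R_r2}
    dv2 = np.clip(g_v2, -R_v2, R_v2).mean(axis=0)    # project each term, then average over x
    dr2 = np.clip(g_r2, -R_r2, R_r2).mean(axis=0)
    return w2 + dt * dv2, b2 + dt * dr2


rng = np.random.default_rng(0)
d, m1, m2, n = 10, 128, 8, 1000
W1, w2, b2 = init_params(d, m1, m2, sigma1=0.1, sigma2=0.01, sigma_r=0.5, rng=rng)
X = rng.standard_normal((n, d))                       # stand-in spherically symmetric sample, not D
y = np.maximum(1.0 - np.linalg.norm(X, axis=1), 0.0)  # f*(x) = relu(1 - ||x||)
w2_new, b2_new = second_layer_step(X, y, W1, w2, b2, R_v2=d ** 3, R_r2=1.0, dt=0.1)
```

The first-layer equation in (4) would apply `project_ball` to each per-sample vector term in the same way before averaging.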

3. THE INFINITE-WIDTH DYNAMICS

Our proof consists of analyzing the dynamics of the infinite-width mean-field network and controlling the discretization error. In this section, we characterize the infinite-width dynamics. For ease of presentation, we pretend in this section that there is no projection and that the gradients are well-defined, and defer the discussion of the projections to Section 4. First, note that both the input distribution $\mathcal{D}$ and the infinite-width network are spherically symmetric; that is, for any $x, x'\in\mathbb{R}^d$ with $\|x\| = \|x'\|$, the densities and the function values at $x$ and $x'$ are the same. Any spherically symmetric $g:\mathbb{R}^d\to\mathbb{R}$ can be characterized by a function $h:[0,\infty)\to\mathbb{R}$ satisfying $h(\|x\|) = g(x)$; for convenience, we abuse notation and also use $g:\mathbb{R}\to\mathbb{R}$ to denote $h$. Assuming that the distribution $\mu_1$ of the first-layer neurons is spherically symmetric, which is true at least at initialization, we can describe the first layer by a simple function using the following lemma, whose proof can be found in Appendix B.3.

Lemma 3.1. Let $\mu$ be a spherically symmetric distribution. Then

$$\mathbb{E}_{w\sim\mu}\|w\|\,\sigma(w\cdot x) = \frac{C_\Gamma\,\mathbb{E}_{w\sim\mu}\|w\|^2}{\sqrt{d}}\,\|x\|, \qquad\text{where } C_\Gamma := \frac{\Gamma(d/2)\sqrt{d}}{2\sqrt{\pi}\,\Gamma((d+1)/2)}.$$

Note that $C_\Gamma \to 1/\sqrt{2\pi}$ as $d\to\infty$, so $C_\Gamma$ is bounded uniformly in $d$. This lemma implies that, in the infinite-width limit, $F(x) = \alpha\|x\|$ for some real $\alpha > 0$, at least at initialization. This suggests defining the infinite-width approximation as

$$\alpha := \frac{C_\Gamma}{\sqrt{d}}\,\mathbb{E}_{w_1\sim\mu_1}\|w_1\|^2, \qquad \bar F(x) := \alpha\|x\|, \qquad \bar f(x) := \mathbb{E}_{(w_2,b_2)\sim\mu_2}\,\sigma(w_2\bar F(x) + b_2). \tag{5}$$

Note that (5) is well-defined whether or not $\mu_1$ is infinite-width, though only in the infinite-width case does one have $F = \bar F$. In Section 4, in the discretization part of the proof, we show that $F \approx \bar F$ throughout the entire process.
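Lemma 3.1 is easy to sanity-check numerically. The sketch below (NumPy; names are hypothetical) computes $C_\Gamma$ through log-gamma to avoid overflow for large $d$, and compares both sides of the lemma by Monte Carlo for a standard Gaussian $\mu$, which is spherically symmetric. Sample sizes and tolerances are arbitrary choices for the illustration.

```python
import math

import numpy as np


def c_gamma(d):
    """C_Gamma = Gamma(d/2) sqrt(d) / (2 sqrt(pi) Gamma((d+1)/2))."""
    log_ratio = math.lgamma(d / 2) - math.lgamma((d + 1) / 2)
    return math.exp(log_ratio) * math.sqrt(d) / (2 * math.sqrt(math.pi))


rng = np.random.default_rng(0)
d, m = 20, 200_000
W = rng.standard_normal((m, d))          # neurons w ~ N(0, I_d): a spherically symmetric mu
x = rng.standard_normal(d)
norms = np.linalg.norm(W, axis=1)

lhs = np.mean(norms * np.maximum(W @ x, 0.0))                          # E ||w|| relu(w . x)
rhs = c_gamma(d) * np.mean(norms ** 2) / math.sqrt(d) * np.linalg.norm(x)
print(lhs, rhs)
print(c_gamma(1), c_gamma(10 ** 6), 1 / math.sqrt(2 * math.pi))
```

As a hand check, for $d = 1$ the "sphere" is $\{\pm1\}$ and $\mathbb{E}\,\sigma(wx) = |x|/2$, which matches $C_\Gamma/\sqrt{1} = 1/2$.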
For the infinite-width network, one can expect that, thanks to symmetry, as long as $\mu_1$ is spherically symmetric at time $t$, no first-layer neuron changes its direction, and the change in norm is uniform, i.e., it does not depend on the direction $\bar v_1$ (see Appendix B.4 for the proof). As a result, $\mu_1$ remains spherically symmetric. Formally, one can show that, for any spherically symmetric $g:\mathbb{R}^d\to\mathbb{R}$,

$$\mathbb{E}_x\{g(x)\,\sigma(v\cdot x)\} = \frac{C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{g(x)\|x\|\}\,\|v\|, \qquad \mathbb{E}_x\{g(x)\,\sigma'(v\cdot x)\,x\} = \frac{C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{g(x)\|x\|\}\,\bar v,$$

where $\bar v = v/\|v\|$. Again, the proofs of these two identities can be found in Appendix B.3. Applying them to $\dot v_1$ with $g \equiv S$ gives $\dot v_1 = \frac{2C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{S(x)\|x\|\}\,v_1$. As a result, $\mu_1$ is always a uniform distribution over some sphere. Moreover, we have

$$\dot\alpha = \mathbb{E}_{w_1}\left\langle\frac{\partial\alpha}{\partial w_1}, \frac{dw_1}{dt}\right\rangle = \frac{4C_\Gamma^2}{d}\,\mathbb{E}_x\{S(x)\|x\|\}\,\mathbb{E}_{w_1}\|w_1\|^2 = \frac{4C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{S(x)\|x\|\}\,\alpha.$$

Hence the dynamics of the first layer reduce to a single real number $\alpha$: the output of the first layer depends only on $\alpha$ and $x$, and the dynamics of $\alpha$ depend only on $\alpha$ rather than on every single neuron $w_1$. In other words, we do not need to track the actual dynamics of $w_1$ in the infinite-width case. We will later show that the spread of the second layer is always small, so the second layer can be approximated by $\alpha\|x\| \mapsto \sigma(\bar w_2\alpha\|x\| + \bar b_2)$, where $(\bar w_2, \bar b_2) = \mathbb{E}(w_2, b_2)$. Combining these observations, one can characterize the dynamics of the entire network using three quantities: $\alpha$, $\bar w_2$ and $\bar b_2$.

We close this section with another interpretation of $\bar F$, which will be handy in Section 4.2. In the idealized case, $F$ should be spherically symmetric, so it is natural to define the "idealized" $F$ as its average over the sphere, that is, $\bar F(x) := \mathbb{E}_{x'\in\|x\|S^{d-1}} F(x')$.
Note that in Lemma 3.1 the expectation is taken over the neurons, while here it is taken over the inputs. However, similar to the proof of Lemma 3.1, one can still show that

$$\mathbb{E}_{x'\in\|x\|S^{d-1}} F(x') = \mathbb{E}_{w\sim\mu_1}\,\mathbb{E}_{x'\in\|x\|S^{d-1}}\|w\|^2\,\sigma(\bar w\cdot x') = \frac{C_\Gamma\,\mathbb{E}_{w\sim\mu_1}\|w\|^2}{\sqrt{d}}\,\|x\| = \alpha\|x\|.$$

In other words, the two definitions of $\bar F$ agree. In some sense, this means that the infinite-width network can be interpreted as a symmetrization of the actual finite-width network.
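The sphere-average identity holds for any $\mu_1$, including an empirical distribution over finitely many neurons, which is what makes $\bar F$ useful for finite-width networks. Here is a Monte Carlo sketch (NumPy; hypothetical names and arbitrary sizes): a finite-width $F$ built from 50 fixed, non-symmetric neurons is not a multiple of $\|x\|$ pointwise, but its average over the sphere $\|x\|S^{d-1}$ matches $\alpha\|x\|$ with $\alpha$ from (5).

```python
import math

import numpy as np


def c_gamma(d):
    """C_Gamma = Gamma(d/2) sqrt(d) / (2 sqrt(pi) Gamma((d+1)/2))."""
    log_ratio = math.lgamma(d / 2) - math.lgamma((d + 1) / 2)
    return math.exp(log_ratio) * math.sqrt(d) / (2 * math.sqrt(math.pi))


def first_layer(X, W1):
    """Finite-width F(x) = mean over neurons of ||w|| relu(w . x); X has shape (n, d)."""
    return (np.maximum(X @ W1.T, 0.0) * np.linalg.norm(W1, axis=1)).mean(axis=1)


rng = np.random.default_rng(1)
d, m1, n, radius = 10, 50, 100_000, 2.0
W1 = rng.standard_normal((m1, d))                                      # a fixed set of 50 neurons
alpha = c_gamma(d) / math.sqrt(d) * np.mean(np.sum(W1 ** 2, axis=1))  # alpha as in (5)

U = rng.standard_normal((n, d))
Xs = radius * U / np.linalg.norm(U, axis=1, keepdims=True)   # uniform samples on radius * S^{d-1}
F_vals = first_layer(Xs, W1)

sphere_avg = F_vals.mean()               # Monte Carlo estimate of bar F(x) for ||x|| = radius
print(sphere_avg, alpha * radius)
print(F_vals.min(), F_vals.max())        # F itself still varies over the sphere
```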

4. DISCRETIZING THE DYNAMICS WITH POLYNOMIAL-SIZE NETWORK

In this section, we show how to discretize the infinite-width dynamics to obtain our main result. Simulation results are shown in Figure 2: even though the network has finite width, at every time step the function $f(x)$ is close to a function of the form $x \mapsto \sigma(\bar w_2\alpha\|x\| + \bar b_2)$, and the second-layer weights remain well concentrated throughout training. Let $\delta_2 := \max_{(v_2,r_2),(v_2',r_2')\in\mu_2}\|(v_2, r_2) - (v_2', r_2')\|$ be the spread of the second layer, and recall that $(\bar w_2, \bar b_2) := \mathbb{E}_{(w_2,b_2)\sim\mu_2}(w_2, b_2)$. We split the training procedure into two stages. In Stage 1, $\bar w_2$ decreases to $-\mathrm{poly}(d)\delta_2$; we show that once this holds, the projection operators in (4) can be ignored (that is, the corresponding terms never exceed the thresholds; see Lemma 4.1). In Stage 2, we show that the network fits the target function in polynomial time.

Figure 2: Simulation of a finite-width network with widths $m_1 = 512$, $m_2 = 128$ and input dimension $d = 100$. One can observe that $f \approx \bar f$ indeed holds and that the second-layer neurons are concentrated around $(\bar w_2, \bar b_2)$, which matches our theoretical analysis.
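The two summary statistics used to split the stages, the mean $(\bar w_2, \bar b_2)$ and the spread $\delta_2$ of the second layer, are straightforward to compute for a finite-width network. A minimal NumPy sketch (the function name is hypothetical):

```python
import numpy as np


def second_layer_stats(w2, b2):
    """Return (mean, spread) for second-layer neurons (w2_i, b2_i) in R^2.

    mean   = (bar w2, bar b2) = E_{(w2, b2) ~ mu2} (w2, b2)
    spread = delta_2 = max over pairs of ||(w2_i, b2_i) - (w2_j, b2_j)||
    """
    P = np.stack([np.asarray(w2, dtype=float), np.asarray(b2, dtype=float)], axis=1)  # (m2, 2)
    pairwise = np.linalg.norm(P[:, None, :] - P[None, :, :], axis=-1)                 # (m2, m2)
    return P.mean(axis=0), pairwise.max()


mean, delta2 = second_layer_stats([0.0, 3.0, 1.5], [0.0, 4.0, 2.0])
print(mean, delta2)
```

The pairwise computation costs $O(m_2^2)$, which is harmless here since $m_2 = \Theta(1)$ in Theorem 2.1.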

4.1. STAGE 1: REMOVING THE PROJECTIONS

Our first step shows that after a short period of training, it is safe to ignore the projection operators in (4). To see why the projections can be ignored in certain circumstances, first note that if $f \approx \bar f$, the second-layer neurons concentrate around their mean, $\bar b_2 = \Theta(1)$ and $\bar w_2 < 0$, then $f \approx \sigma(\bar w_2\alpha\|x\| + \bar b_2)$ vanishes outside $\{\|x\| \le \Theta(1/|\bar w_2\alpha|)\}$, so the gradients also vanish for such large $x$. Meanwhile, by upper bounding the norms of the gradients, one can show that the projections can only be triggered when $\|x\|$ is large. As a result, when $f$ decreases sufficiently fast, $f(x)$ reaches 0 before $\|x\|$ becomes too large. Formally, we have the following lemma, whose proof can be found in Appendix C.

Lemma 4.1. Choose the projection thresholds $R_{v_1} = \Theta(d)$, $R_{v_2} = \Theta(d^3)$ and $R_{r_2} = \Theta(1)$ in (4), and suppose that $\alpha = \Theta(1/\sqrt{d})$. Then the projection operators in $\dot r_2$, $\dot v_1$ and $\dot v_2$ are no longer activated if, respectively, all second-layer weights are nonpositive, $-\bar w_2 > C\delta_2$ for some large constant $C$, and $-\bar w_2 \ge C/R_{v_2}$ for some large constant $C$.

Based on this lemma, we further split Stage 1 into three substages. We define $T_{1.1}$ as the first time all second-layer weights become negative, and $T_{1.2}$ and $T_{1.3}$ as the first times $|\bar w_2|$ reaches $\Theta(d)\delta_2$ and $\Theta(1/R_{v_2})$, respectively; these are the end times of Stages 1.1, 1.2 and 1.3. We require $|\bar w_2|$ to be $\Theta(d)\delta_2$ rather than $\Theta(1)\delta_2$ at the end of Stage 1.2 so that the starting state of Stage 1.3 is more regular. By definition and Lemma 4.1, after each substage one more projection can be ignored, and all of them can be ignored after Stage 1. The main lemma of Stage 1 is as follows; recall that $R_{v_1}, R_{v_2}, R_{r_2}$ are the clipping thresholds.

Lemma 4.2 (Stage 1, informal). Define the end time of Stage 1 as $T_1 := \inf\{t \ge 0 : -\bar w_2(t) = C_1/R_{v_2}\}$ for some large constant $C_1$.
Under the assumptions of Theorem 2.1, we have $T_1 \le \mathrm{poly}(d, 1/\varepsilon)$, and the following conditions hold throughout Stage 1.
(a) Approximation error of the first layer. For each $v_1 \in \mu_1$, both the tangent movement and the radial spread are controlled: $\|\bar v_1(t) - \bar v_1(0)\| \le \delta^{(1)}_{1,T}(t)$ and $\|v_1\|^2 = (1 \pm \delta^{(1)}_{1,R}(t))\,\mathbb{E}\|w_1\|^2$, where $\delta^{(1)}_{1,T}$ and $\delta^{(1)}_{1,R}$ are two processes that remain small.
(b) Spread of the second layer. For any $(v_2, r_2), (v_2', r_2') \in \mu_2$, $\|(v_2, r_2) - (v_2', r_2')\|$ is small.
(c) Regularity conditions. $r_2 = \Theta(1)$ for all $(v_2, r_2) \in \mu_2$, $|\bar w_2| = O(1/R_{v_2}) = O(1/d^3)$ and $\alpha = \Theta(\sqrt{d}/R_{v_1}) = \Theta(1/\sqrt{d})$.

The first two conditions mean that the approximation $f(x) \approx \sigma(\bar w_2\alpha\|x\| + \bar b_2)$ is valid throughout Stage 1, and the third describes the shape of $f$ in Stage 1. To maintain these conditions, we use the so-called continuity argument, which can be viewed as a continuous version of mathematical induction; see Appendix B.1 for an explanation of this technique.
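The approximation $f(x) \approx \sigma(\bar w_2\alpha\|x\| + \bar b_2)$ has a simple exact core: at any radius where every second-layer neuron is active, the average of the ReLUs equals the ReLU of the average, because each pre-activation $w_2\bar F + b_2$ is positive and linear in $(w_2, b_2)$. The sketch below (NumPy; the names and numbers are illustrative assumptions) checks this for a concentrated second layer with $\bar w_2 < 0 < \bar b_2$, and also checks that both functions vanish beyond $\|x\| = \bar b_2/(|\bar w_2|\alpha)$ once every neuron is inactive.

```python
import numpy as np


def f_bar(r, alpha, w2, b2):
    """bar f at ||x|| = r: E_{(w2, b2) ~ mu2} relu(w2 * alpha * r + b2)."""
    return np.maximum(np.outer(r, w2) * alpha + b2, 0.0).mean(axis=1)


def collapsed(r, alpha, w2, b2):
    """Single-neuron approximation relu(bar w2 * alpha * r + bar b2)."""
    return np.maximum(np.mean(w2) * alpha * r + np.mean(b2), 0.0)


rng = np.random.default_rng(0)
alpha = 0.5
w2 = -1.0 + 0.01 * rng.standard_normal(16)    # concentrated around bar w2 = -1
b2 = 1.0 + 0.01 * rng.standard_normal(16)     # concentrated around bar b2 = 1
r_support = np.mean(b2) / (abs(np.mean(w2)) * alpha)   # collapsed output vanishes beyond this radius

r_inner = np.linspace(0.0, 1.5, 50)           # every neuron is active here
r_outer = np.linspace(2.5, 10.0, 50)          # every neuron is inactive here
print(np.max(np.abs(f_bar(r_inner, alpha, w2, b2) - collapsed(r_inner, alpha, w2, b2))))
print(r_support)
```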

Published as a conference paper at ICLR 2023

With the approximation $F(x) \approx \alpha\|x\|$ and the fact that $f(x)\,\sigma'(v_2F(x) + r_2) = f(x)$ for most $x$, we can rewrite the dynamics of $v_2$ as $\dot v_2 \approx \mathbb{E}_x\,\Pi_{R_{v_2}}[(f^*(x) - f(x))\,\alpha\|x\|]$. Since $f$ is much flatter than $f^*$, $f$ is still $\Omega(1)$ in the region $\|x\| \ge 1$ where $f^*$ vanishes. As a result, the right-hand side is always negative; in fact, we show that it is $-\Theta(\alpha\log d)$. Recall that $T_{1.2}$ is the time $|\bar w_2|$ reaches $\Theta(d\delta_2)$. If $\delta_2$ remains roughly constant, the time needed for Stages 1.1 and 1.2 is proportional to the initial $\delta_2$, which we can make small by choosing a small enough $\sigma_2$. This also helps control the movement of $v_1$ and $r_2$ in Stages 1.1 and 1.2, as their dynamics depend on $|w_2|$. One also needs to show that $\delta_2$ cannot increase too much during Stages 1.1 and 1.2, in order to maintain the approximation $f(x) \approx \sigma(\bar w_2F(x) + \bar b_2)$. Intuitively, this is because for inputs with small $\|x\|$, the gradient $\nabla_{v_2}L(x)$ does not depend on $(v_2, r_2)$ itself, while inputs with a large norm cannot contribute much to the gradient because of gradient clipping. As a result, the dynamics of $v_2$ are approximately uniform across neurons in Stages 1.1 and 1.2, so the distance between different $(v_2, r_2), (v_2', r_2')$ stays small.

The same method does not work in Stage 1.3, as the target value of $\bar w_2$ no longer depends on $\delta_2$, and we need a finer analysis of the first layer. Recall that after Stage 1.2 the projection in $\dot v_1$ can be ignored. Therefore, we can decompose $\dot v_1$ along the radial and tangent directions:

$$\dot v_1 = \mathrm{Rad}(\dot v_1) + \mathrm{Tan}(\dot v_1) = \langle\dot v_1, \bar v_1\rangle\,\bar v_1 + (I - \bar v_1\bar v_1^\top)\,\dot v_1 = 2\,\mathbb{E}_x\{S(x)\,\sigma(v_1\cdot x)\}\,\bar v_1 + \|v_1\|\,\mathbb{E}_x\{S(x)\,\sigma'(v_1\cdot x)\,(I - \bar v_1\bar v_1^\top)\,x\}.$$

Then we write $S(x) \approx (f^*(x) - f(x))\,\bar w_2 = (f^*(x) - \bar f(x))\,\bar w_2 + (\bar f(x) - f(x))\,\bar w_2$. The terms involving $f^* - \bar f$ are essentially what one expects in the infinite-width dynamics: for them, the radial movement is uniform and the tangent movement is 0.
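We then bound the terms involving $\bar f - f$ using the radial spread and tangent movement of the first layer, obtaining

$$\frac{d}{dt}\left(\delta^{(1)}_{1,R} + \delta^{(1)}_{1,T}\right) \lessapprox \frac{O(1)}{d^{2.5}}\left(\delta^{(1)}_{1,R} + \delta^{(1)}_{1,T}\right)$$

(cf. Lemma C.16). Although this bound allows the error to grow exponentially fast (as $\exp(t/d^{2.5})$), it suffices, since Stage 1.3 only takes $O(d^{1.5})$ time: by Grönwall's inequality, over this period the error is multiplied by at most $\exp(O(d^{1.5})/d^{2.5}) = \exp(O(1/d)) = O(1)$.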
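The radial part of the decomposition of $\dot v_1$ rests on a pointwise identity: for the per-sample term $g(x) = S(x)\left(\bar v_1\sigma(v_1\cdot x) + \|v_1\|\sigma'(v_1\cdot x)\,x\right)$, one has $\langle g(x), \bar v_1\rangle = 2S(x)\,\sigma(v_1\cdot x)$, because $\sigma'(z)z = \sigma(z)$ for the ReLU (Euler's identity for the 2-homogeneous neuron $\|v\|\sigma(v\cdot x)$). A small NumPy sketch of the decomposition and of this identity, with random stand-in values for $v_1$, $x$ and $S(x)$ (names are hypothetical):

```python
import numpy as np


def radial_tangent(g, v):
    """Split g into Rad(g) = <g, v_bar> v_bar and Tan(g) = (I - v_bar v_bar^T) g."""
    v_bar = v / np.linalg.norm(v)
    rad = np.dot(g, v_bar) * v_bar
    return rad, g - rad


def per_sample_term(v, x, s):
    """S(x) (v_bar relu(v . x) + ||v|| relu'(v . x) x) for a single input x, with S(x) = s."""
    z = np.dot(v, x)
    v_bar = v / np.linalg.norm(v)
    return s * (v_bar * max(z, 0.0) + np.linalg.norm(v) * float(z > 0) * x)


rng = np.random.default_rng(0)
d = 6
v, s = rng.standard_normal(d), rng.standard_normal()
x = rng.standard_normal(d)
if np.dot(v, x) < 0:
    x = -x                                   # make the neuron active so the check is non-trivial
g = per_sample_term(v, x, s)
rad, tan = radial_tangent(g, v)
print(np.dot(rad, v / np.linalg.norm(v)), 2 * s * max(np.dot(v, x), 0.0))
```

Averaging these per-sample terms over $x$ gives the radial term $2\,\mathbb{E}_x\{S(x)\sigma(v_1\cdot x)\}\,\bar v_1$ of the decomposition.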

4.2. STAGE 2: FITTING THE TARGET FUNCTION

The goal of Stage 2 is for the gradient flow to converge to a point with loss at most $\varepsilon$ in polynomial time. The main difficulty in this stage is that we need to bound the approximation error of the first layer more carefully, as Stage 2 is potentially long and the brute-force estimates used in Stage 1 are too loose towards the end of training. We write $\tilde F := F/\alpha$ and measure the approximation error using $\|\tilde F|_{S^{d-1}} - 1\|_{L^\infty}$ and $\|\tilde F - \|\cdot\|\|_{L^2}^2$. Strictly speaking, for the $L^2$ error we only consider those $x$ with $\|x\| \le \Theta(1/|\bar w_2\alpha|) = \mathrm{poly}(d)$, since otherwise it can be ill-defined. This is valid because, as discussed earlier, $f$ vanishes for large $x$. In Stage 2, $\mathbb{E}_x$ always means $\mathbb{E}_{\|x\|\le\Theta(1/|\bar w_2\alpha|)}$, and for simplicity of presentation we usually do not state this explicitly. The main result of Stage 2 is as follows.
(a) Approximation error of the first layer. Both $\|\tilde F - \|\cdot\|\|_{L^2}$ and $\|\tilde F|_{S^{d-1}} - 1\|_{L^\infty}$ are small.
(b) Spread of the second layer. $\max_{(v_2,r_2),(v_2',r_2')\in\mu_2}\|(v_2, r_2) - (v_2', r_2')\|$ does not grow.
(c) Regularity conditions. The shape of $f$ is similar to the one shown in Figure 2.

As mentioned, the main technical challenge is to bound the approximation error of the first layer. The overall strategy is to first show that the $L^2$ error barely grows in Stage 2, and then show that, as long as the $L^2$ error is small, the $L^\infty$ error can also be controlled. Unlike in Stage 1, $|\bar w_2\alpha|$ is fairly large in Stage 2, so the first layer receives some signal from the loss function. Intuitively, this signal should push the first layer closer to a multiple of $\|x\|$, since that is what the global optimum would do. Formally, we first show the approximation

$$L \approx \frac{1}{2}\mathbb{E}_x(f^*(x) - \bar f(x))^2 + \frac{\bar w_2^2}{2}\mathbb{E}_x(\bar F(x) - F(x))^2, \tag{6}$$

in the sense that the gradients $\nabla_{v_1}$ of both sides are approximately the same, where $\bar f(x) := \mathbb{E}_{(w_2,b_2)\sim\mu_2}\,\sigma(w_2\bar F(x) + b_2)$.
The first term of (6) measures the distance between the target function and the infinite-width network, and the second term measures the approximation error of the first layer. In some sense, one can view this formula as a bias-variance decomposition for discretizing mean-field networks. With this approximation in hand, we then show that, thanks to the 2-homogeneity of $F$, the first term, after a suitable normalization, does not affect the approximation error of the first layer. Meanwhile, since we are following the gradient flow, the second term can only decrease the approximation error. To establish (6), we first decompose the loss function as

$$L = \frac{1}{2}\mathbb{E}_x(f^*(x) - \bar f(x))^2 + \frac{1}{2}\mathbb{E}_x(\bar f(x) - f(x))^2 + \mathbb{E}_x(f^*(x) - \bar f(x))(\bar f(x) - f(x)) =: L_1 + L_2 + L_3.$$

We claim that $L_2$ is approximately the second term of (6) and that $L_3$ is approximately 0. Let $X_1$ be the largest spherically symmetric set on which $v_2F(x) + r_2 > 0$ for all $(v_2, r_2) \in \mu_2$. We show that the inputs outside $X_1$ contribute little. Therefore, we can rewrite $L_2$ as

$$L_2 \approx \frac{1}{2}\mathbb{E}_{X_1}\left(\mathbb{E}_{w_2,b_2}(w_2F(x) + b_2) - \mathbb{E}_{w_2,b_2}(w_2\bar F(x) + b_2)\right)^2 = \frac{\bar w_2^2}{2}\mathbb{E}_{X_1}(\bar F(x) - F(x))^2 \approx \frac{\bar w_2^2}{2}\mathbb{E}_x(\bar F(x) - F(x))^2.$$

Similarly, we can rewrite $L_3$ as $L_3 \approx \bar w_2\,\mathbb{E}_x\{(f^*(x) - \bar f(x))(\bar F(x) - F(x))\}$. Recall from Section 3 that $\bar F(x) = \mathbb{E}_{x'\in\|x\|S^{d-1}}F(x')$. With this in mind, one can verify that, for any spherically symmetric function $g:\mathbb{R}^d\to\mathbb{R}$, $\mathbb{E}_x\{g(x)F(x)\} = \mathbb{E}_x\{g(x)\bar F(x)\}$. Setting $g = f^* - \bar f$ gives $L_3 \approx 0$. Combining these two estimates yields (6). Provided that the $L^2$ error is always small, we show that, up to some higher-order terms,

$$\frac{d}{dt}\tilde F(\bar x) \lesssim O(d^3)\,\|\tilde F - \|\cdot\|\|_{L^2}^2, \qquad \forall\,\bar x\in S^{d-1}.$$

In words, the rate of change of $\tilde F$ on the unit sphere is bounded by the squared $L^2$ error. Hence $\|\tilde F|_{S^{d-1}} - 1\|_{L^\infty}$ remains small as long as we choose $m_1$ large enough that $\tilde F|_{S^{d-1}}$ is close to 1 at initialization.
This should not be a surprise since, after all, $\tilde F(\bar x) = 1$ for all $\bar x \in S^{d-1}$ in the infinite-width dynamics. The formal proof of the above argument can be found in Section D.2. Given that the approximation error can be controlled, one can then derive a convergence rate using the infinite-width dynamics; see Section D.3 for details.
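The two estimates for $L_2$ and $L_3$ in this subsection can be written out as a short derivation. It assumes, as in the restriction to $X_1$, that every second-layer neuron is active at both $F(x)$ and $\bar F(x)$, so that the second-layer ReLU acts linearly; under that assumption the step from $\bar f - f$ to $\bar F - F$ is exact, and the only approximation is dropping the inputs outside $X_1$. Here $r = \|x\|$, and the last line uses that $\mathcal{D}$ is spherically symmetric, so conditioned on $\|x\| = r$, the input $x$ is uniform on $rS^{d-1}$:

$$
\begin{aligned}
\tfrac{1}{2}\big(f^*-f\big)^2 &= \tfrac{1}{2}\big(f^*-\bar f\big)^2+\tfrac{1}{2}\big(\bar f-f\big)^2+\big(f^*-\bar f\big)\big(\bar f-f\big),\\
\bar f(x)-f(x) &= \mathbb{E}_{w_2,b_2}\big(w_2\bar F(x)+b_2\big)-\mathbb{E}_{w_2,b_2}\big(w_2F(x)+b_2\big) = \bar w_2\big(\bar F(x)-F(x)\big),\\
\mathbb{E}_x\{g(x)F(x)\} &= \mathbb{E}_{r}\Big[g(r)\,\mathbb{E}_{x'\in rS^{d-1}}F(x')\Big] = \mathbb{E}_x\{g(x)\bar F(x)\}\quad\text{for spherically symmetric } g.
\end{aligned}
$$

Taking expectations in the first line gives $L = L_1 + L_2 + L_3$; substituting the second line gives $L_2 = \frac{\bar w_2^2}{2}\mathbb{E}(\bar F - F)^2$ and $L_3 = \bar w_2\,\mathbb{E}\{(f^*-\bar f)(\bar F - F)\}$ on $X_1$; and the third line with $g = f^* - \bar f$, which depends on $x$ only through $\|x\|$, makes $L_3$ vanish.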

5. CONCLUSION

In this paper, we give a new framework for extending the mean-field limit to multilayer networks, and use this framework to show that three-layer networks can learn a function that is not approximable by two-layer networks. Many open problems remain. For the current objective, the loss is spherically symmetric, so the first-layer neurons do not move much tangentially; what if the target is instead $\sigma(1-\|P_S x\|)$, where $P_S$ is the projection onto some unknown subspace? What about functions that require an intermediate layer of dimension larger than 1? Can one generalize the saddle point analysis to deeper networks? We hope this work will be a starting point for understanding how deep neural networks learn useful features.

A MULTI-LAYER MEAN-FIELD NETWORKS

In this section, we first briefly review existing theories of two-layer mean-field networks, and then introduce our framework for multi-layer mean-field networks.

A.1 TWO-LAYER NETWORKS AND PERMUTATION INVARIANCE

A two-layer network $f$ of width $m$ can usually be represented by
$$f(x; W, a) = \frac1m a^\top\sigma(Wx) = \frac1m\sum_{i=1}^m a_i\sigma(w_i\cdot x),\tag{7}$$
where $W\in\mathbb R^{m\times d}$ is the weight matrix of the hidden layer and $a\in\mathbb R^m$ the vector of output weights. Let $\mu$ be the empirical distribution of $\{(a_i, w_i)\}_{i=1}^m\subset\mathbb R^{d+1}$. Then, we can write
$$f(x;\mu) = \mathbb E_{(a,w)\sim\mu}\{a\sigma(w\cdot x)\}.\tag{8}$$
By allowing $\mu$ to be an arbitrary sufficiently regular distribution over $\mathbb R^{d+1}$, we obtain a neural network, represented by (8), that can contain infinitely many neurons. To describe the gradient flow of this infinite-width network, it suffices to assign a vector field on $\mathbb R^{d+1}$ that describes how each neuron $(a, w)\in\mathbb R^{d+1}$ moves at time $t$. One simple heuristic way to do so is to first compute the gradient in the finite-width case, then replace all summations with expectations as in (8), and treat the gradient as a vector field. We now illustrate the idea in the realizable setting with the MSE loss $L = \frac12\mathbb E_x(f^*(x) - f(x))^2$. The theory can be generalized to much more general settings and can be formally justified using the theory of Wasserstein gradient flows; readers can refer to, for example, Chizat & Bach (2018) and Mei et al. (2018) for details. For a finite-width network (7), the gradient of $L$ w.r.t. a neuron $(a_k, w_k)$ is
$$-m\nabla_{a_k}L = \mathbb E_x\{(f^*(x) - f(x; W, a))\sigma(w_k\cdot x)\},\qquad -m\nabla_{w_k}L = \mathbb E_x\{(f^*(x) - f(x; W, a))a_k\sigma'(w_k\cdot x)x\}.$$
Replacing $f(x; W, a)$ with $f(x;\mu)$ and treating $(a_k, w_k)$ as a generic neuron, we obtain a vector field $\nabla:\mathbb R^{d+1}\to\mathbb R^{d+1}$,
$$-\nabla(a, w) := \mathbb E_x\left\{(f^*(x) - f(x;\mu))\begin{pmatrix}\sigma(w\cdot x)\\ a\sigma'(w\cdot x)x\end{pmatrix}\right\}.$$
At each time $t$, we move the neurons in $\mu$ according to $-\nabla$. One of the most important properties of this mean-field formulation is that it factors out the permutation invariance of neurons. That is, we can permute $(a_1, w_1),\dots,(a_m, w_m)$ without changing the output of the network.
However, when we treat training as an optimization problem over the space of $(a, W)$, i.e., $\mathbb R^m\times\mathbb R^{m\times d}$, permuting the $(a_i, w_i)$ changes $(a, W)$ entirely. On the other hand, if we describe the network using a distribution $\mu$ over $\mathbb R^{d+1}$, then it is automatically permutation invariant. Note that this is not restricted to infinite-width networks: when we choose $\mu$ to be an empirical distribution over finitely many neurons, we recover a finite-width network without breaking the permutation invariance.

Published as a conference paper at ICLR 2023

A.2 MULTI-LAYER MEAN-FIELD NETWORKS

Unfortunately, the above strategy cannot be directly generalized to multi-layer networks. Consider the three-layer network
$$f(x; a, W_2, W_1) = \frac1{m_2}a^\top\sigma\big(W_2h(x; W_1)\big),\qquad h(x; W_1) = \frac1{m_1}\sigma(W_1x),$$
where $a\in\mathbb R^{m_2}$, $W_2\in\mathbb R^{m_2\times m_1}$, $W_1\in\mathbb R^{m_1\times d}$. One can still write
$$f(x; a, W_2, W_1) = \frac1{m_2}\sum_{i=1}^{m_2}a_i\sigma\big(w_{2,i}\cdot h(x; W_1)\big) = \mathbb E_{(a,w_2)\sim\mu_2}\{a\sigma(w_2\cdot h(x; W_1))\}.$$
However, now $\mu_2$ is a distribution over $\mathbb R\times\mathbb R^{m_1}$, and if $m_1\to\infty$, it becomes a distribution over $\mathbb R^\infty$, which is not readily defined. One way to resolve this issue is to view $W_2$ as a function from $[m_2]\times[m_1]$ to $\mathbb R$.

We now present a formulation that does factor out the permutation invariance of neurons; it is built by composing a sequence of vector-valued two-layer networks. As a first step, we consider a two-layer network with $D$-dimensional outputs:
$$f(x; A, W) = \frac1m A\sigma(Wx),\tag{9}$$
where $A\in\mathbb R^{D\times m}$ and $W\in\mathbb R^{m\times d}$. For each index $i\in[D]$, we still have
$$f_i(x; A, W) = \frac1m\sum_{j=1}^m a_{i,j}\sigma(w_j\cdot x) = \mathbb E_{(a,w)\sim\mu_i}\{a\sigma(w\cdot x)\},$$
where $\mu_i$ is the empirical distribution of $\{(a_{i,j}, w_j)\}_{j\in[m]}\subset\mathbb R^{d+1}$. Ranging over $i$ gives the output vector of this network. For two-layer networks with scalar outputs, in order to obtain the mean-field counterpart, it suffices to allow $\mu$ to be a general distribution over $\mathbb R\times\mathbb R^d$. This, however, is not the case for networks with vector outputs, since the $W$ parts of the $\mu_i$ are coupled.
Hence, we need to additionally impose the constraint that all $(\mu_i)_{i\in[D]}$ share the same second marginal, that is, $\pi_2\#\mu_i = \mu_W$ for some distribution $\mu_W$ over $\mathbb R^d$ and all $i\in[D]$, where $\pi_2:\mathbb R\times\mathbb R^d\to\mathbb R^d$ is the projection that takes $(a, w)$ to $w$. Intuitively, this condition says that they share the same first-layer weights $W$. We formalize this idea in the following definition.

Definition A.1. Let $(\mu_i)_{i\in[D]}$ be $D$ sufficiently regular distributions over $\mathbb R\times\mathbb R^d$. We call $(\mu_i)_{i=1}^D$ an admissible configuration of dimension $(D, d)$ if there exists a measure $\mu_W$ over $\mathbb R^d$ such that $\pi_2\#\mu_i = \mu_W$ holds for all $i\in[D]$.

Remark. Note that here, by a neuron, we mean a $(D + d)$-dimensional vector $(a_1,\dots,a_D, w)$. In the finite-width network (9), this corresponds to a row of $W$ and the corresponding column of $A$. This point of view is important when deriving the infinite-width gradient flow since, as in the two-layer case, the vector field at the position of a given neuron can only depend on the other neurons as a whole. ♣

To complement the discussion, we now consider how to obtain a finite-width network with $m$ neurons from a given admissible infinite-width configuration $(\mu_i)_{i\in[D]}$. For a scalar-valued mean-field network characterized by $\mu$, it suffices to draw $m$ samples from $\mu$. For a vector-valued network, the procedure is slightly different. We first sample a weight vector $w$ from the shared marginal $\mu_W$. Then, for each $i\in[D]$, we sample a real number $a_i$ conditionally on $w$. This gives us a neuron $(a_1,\dots,a_D, w)\in\mathbb R^D\times\mathbb R^d$. Repeating this procedure $m$ times yields a finite-width network with $m$ neurons. We formally define two-layer vector-valued mean-field networks as follows.

Definition A.2. Given an admissible configuration $(\mu_i)_{i\in[D]}$, the two-layer vector-valued network it defines is
$$F(x;\mu_1,\dots,\mu_D) = \big(F_1(x;\mu_1),\dots,F_D(x;\mu_D)\big),\qquad F_i(x;\mu_i) = \mathbb E_{(a,w)\sim\mu_i}\{a\sigma(w\cdot x)\},\ \forall i\in[D].\tag{10}$$
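The sampling procedure above can be sketched as follows. The shared marginal $\mu_W = \mathcal N(0, I_d/d)$ and the conditionals $a_i\mid w$ are hypothetical choices made only for illustration; the point is that all $D$ output coordinates reuse the same sampled $w_j$, which is exactly the coupling that admissibility (Definition A.1) enforces.

```python
import math
import random

def sample_neuron(d, D, rng):
    """Draw one neuron (a_1, ..., a_D, w) from a HYPOTHETICAL admissible configuration."""
    # Shared second marginal mu_W (assumption): w ~ N(0, I_d / d).
    w = [rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)]
    w_norm = math.sqrt(sum(t * t for t in w))
    # Conditionals a_i | w (assumption): a_i = (i + 1) * ||w|| + small noise.
    a = [(i + 1) * w_norm + rng.gauss(0.0, 0.01) for i in range(D)]
    return a, w

def vector_net(x, neurons):
    """F_i(x) = (1/m) sum_j a_{i,j} relu(w_j . x); every output coordinate shares the same w_j."""
    m, D = len(neurons), len(neurons[0][0])
    out = [0.0] * D
    for a, w in neurons:
        h = max(0.0, sum(wk * xk for wk, xk in zip(w, x)))
        for i in range(D):
            out[i] += a[i] * h / m
    return out

rng = random.Random(0)
neurons = [sample_neuron(d=4, D=2, rng=rng) for _ in range(5_000)]
out = vector_net([1.0, 0.0, 0.0, 0.0], neurons)
print(out)  # out[1] is close to 2 * out[0] because a_2 ~ 2 a_1 given the same w
```

Sampling $w$ once per neuron and then each $a_i$ given $w$ is what distinguishes this from drawing $D$ independent scalar-output networks, which would break the shared first layer.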
Now we are ready to define a multi-layer mean-field network. Basically, a multi-layer mean-field network is a composition of a sequence of two-layer vector-valued networks (10).

Definition A.3. Let $L\ge1$ be an integer. Let $D^{(1)},\dots,D^{(L)}$ be a sequence of positive integers and put $D^{(0)} = d$. For each $l\in[L]$, let $(\mu^{(l)}_i)_{i\in[D^{(l)}]}$ be an admissible configuration of dimension $(D^{(l)}, D^{(l-1)})$. The $L$-layer mean-field network $f$ defined by the configuration $\Theta := \big((\mu^{(l)}_i)_{i\in[D^{(l)}]}\big)_{l\in[L]}$ is defined recursively as
$$f(x;\Theta) = F^{(L)}(x;\Theta),\qquad F^{(l)}(x;\Theta) := F\big(F^{(l-1)}(x;\Theta);\mu^{(l)}_1,\dots,\mu^{(l)}_{D^{(l)}}\big),\ \forall l\ge1,\qquad F^{(0)}(x;\Theta) := x,\tag{11}$$
where $F$ is the two-layer mean-field network given by (10).

Example. As an example, consider a three-layer network, which corresponds to composing two blocks in (11). In this case, the corresponding finite-width network is
$$f(x; A_2, W_2, A_1, W_1) = \frac1{m_2}A_2\sigma\Big(W_2\frac1{m_1}A_1\sigma(W_1x)\Big),$$
which is exactly the usual multi-layer network used in practice, except for the normalizing factors $1/m_2$, $1/m_1$ and an additional matrix $A_1\in\mathbb R^{D_1\times m_1}$. This matrix compresses an $m_1$-dimensional feature vector to a $D_1$-dimensional one, where $D_1$ is an integer that does not go to $\infty$. It is reminiscent of the bottleneck structure used in ResNet (He et al., 2016).

Remark. Note that this formulation is indeed invariant under permutations of each layer's neurons. However, it does not factor out all permutation invariance of a deep network. For example, one can permute the rows of $A_1$ and the columns of $W_2$ simultaneously without changing the output of the network. This corresponds to permuting the entries of the hidden feature $F^{(1)}$. We believe it is neither necessary nor useful to factor out this symmetry since, after all, even in the two-layer case, we do not permute the entries of the input $x$. ♣

Finally, we consider the problem of formulating the mean-field gradient flow so that it matches the usual gradient flow.
The idea is simple: we compute the gradient in the finite-width setting and then replace summations with integrals. For ease of presentation, we consider a three-layer network and the MSE loss. Again, this framework can be easily generalized to deeper networks and other loss functions. We write
$$f(x) = f(x; a, W_2, V_1, W_1) = \frac1{m_2}a^\top\sigma\big(W_2F(x; V_1, W_1)\big),\qquad F(x) = F(x; V_1, W_1) = \frac1{m_1}V_1\sigma(W_1x),$$
$$L = L(a, W_2, V_1, W_1) = \frac12\mathbb E_x\big(f^*(x) - f(x; a, W_2, V_1, W_1)\big)^2,$$
where $a\in\mathbb R^{m_2}$, $W_2\in\mathbb R^{m_2\times D}$, $V_1\in\mathbb R^{D\times m_1}$, $W_1\in\mathbb R^{m_1\times d}$; below, $w_{2,j}$ is the $j$-th row of $W_2$, $v_{1,i}$ the $i$-th column of $V_1$, and $w_{1,i}$ the $i$-th row of $W_1$. We have
$$-m_2\nabla_{a_i}L = \mathbb E_x\{(f^*(x) - f(x))\sigma(w_{2,i}\cdot F(x))\},\quad\forall i\in[m_2],$$
$$-m_2\nabla_{w_{2,i}}L = \mathbb E_x\{(f^*(x) - f(x))a_i\sigma'(w_{2,i}\cdot F(x))F(x)\},\quad\forall i\in[m_2],$$
$$-m_1\nabla_{v_{1,i}}L = \mathbb E_x\Big\{(f^*(x) - f(x))\frac1{m_2}\sum_{j=1}^{m_2}a_j\sigma'(w_{2,j}\cdot F(x))w_{2,j}\sigma(w_{1,i}\cdot x)\Big\},\quad\forall i\in[m_1],$$
$$-m_1\nabla_{w_{1,i}}L = \mathbb E_x\Big\{(f^*(x) - f(x))\frac1{m_2}\sum_{j=1}^{m_2}a_j\sigma'(w_{2,j}\cdot F(x))\langle w_{2,j}, v_{1,i}\rangle\sigma'(w_{1,i}\cdot x)x\Big\},\quad\forall i\in[m_1].$$
Replacing summations with integrals, we obtain
$$-\nabla_{(a,w_2)} = \mathbb E_x\left\{(f^*(x) - f(x))\begin{pmatrix}\sigma(w_2\cdot F(x))\\ a\sigma'(w_2\cdot F(x))F(x)\end{pmatrix}\right\},$$
$$-\nabla_{(v_1,w_1)} = \mathbb E_x\left\{(f^*(x) - f(x))\,\mathbb E_{(a,w_2)\sim\mu_2}\left[a\sigma'(w_2\cdot F(x))\begin{pmatrix}\sigma(w_1\cdot x)w_2\\ \langle w_2, v_1\rangle\sigma'(w_1\cdot x)x\end{pmatrix}\right]\right\}.\tag{12}$$
Namely, at each time $t$, we update the second-layer neurons $(a, w_2)$ with $-\nabla_{(a,w_2)}$ and the first-layer neurons $(v_1, w_1)$ with $-\nabla_{(v_1,w_1)}$. Note that, unlike many other multi-layer mean-field frameworks, we do not introduce any notion of paths. The dynamics of each first-layer neuron depends on the second layer only as a whole, since we take the expectation over $\mu_2$ in (12). The same is also true for second-layer neurons. In some sense, the additional matrix $V_1$ decouples the dynamics of the first- and second-layer neurons.
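A small finite-difference check of the closed-form expression for $-m_1\nabla_{v_{1,i}}L$ above, on a toy three-layer network built as a composition of two vector-valued blocks. Assumptions for illustration: softplus replaces $\sigma$ (so the loss is smooth), the expectation over $x$ is an average over five random samples, and all sizes and weights are arbitrary.

```python
import math
import random

def act(z):
    # Softplus: a smooth stand-in for sigma (assumption), written to avoid overflow.
    return z + math.log1p(math.exp(-z)) if z > 0 else math.log1p(math.exp(z))

def dact(z):
    # Derivative of softplus, i.e. the logistic sigmoid, written to avoid overflow.
    return 1.0 / (1.0 + math.exp(-z)) if z >= 0 else math.exp(z) / (1.0 + math.exp(z))

def dot(u, v):
    return sum(p * q for p, q in zip(u, v))

def first_block(x, V1, W1):
    # F(x) = (1/m1) sum_i v_{1,i} act(w_{1,i} . x); V1[i] holds the column v_{1,i} in R^D.
    m1, D = len(W1), len(V1[0])
    F = [0.0] * D
    for v, w in zip(V1, W1):
        h = act(dot(w, x))
        for k in range(D):
            F[k] += v[k] * h / m1
    return F

def network(x, a, W2, V1, W1):
    # f(x) = (1/m2) sum_j a_j act(w_{2,j} . F(x)): two vector-valued blocks composed.
    F = first_block(x, V1, W1)
    return sum(aj * act(dot(w2, F)) for aj, w2 in zip(a, W2)) / len(a)

def loss(data, a, W2, V1, W1):
    # Empirical MSE: E_x is replaced by an average over a finite sample.
    return 0.5 * sum((y - network(x, a, W2, V1, W1)) ** 2 for x, y in data) / len(data)

def neg_grad_v1(data, a, W2, V1, W1, i):
    # Closed form for -m1 * grad_{v_{1,i}} L, term by term as in the text.
    out = [0.0] * len(V1[0])
    for x, y in data:
        F = first_block(x, V1, W1)
        res = y - network(x, a, W2, V1, W1)
        h = act(dot(W1[i], x))
        for aj, w2 in zip(a, W2):
            coef = res * aj * dact(dot(w2, F)) * h / len(a) / len(data)
            for k in range(len(out)):
                out[k] += coef * w2[k]
    return out

rng = random.Random(1)

def gauss():
    return rng.gauss(0.0, 1.0)

d, m1, D, m2 = 3, 4, 2, 3
W1 = [[gauss() for _ in range(d)] for _ in range(m1)]
V1 = [[gauss() for _ in range(D)] for _ in range(m1)]
W2 = [[gauss() for _ in range(D)] for _ in range(m2)]
a = [gauss() for _ in range(m2)]
data = [([gauss() for _ in range(d)], gauss()) for _ in range(5)]

i, eps = 2, 1e-6
formula = neg_grad_v1(data, a, W2, V1, W1, i)
finite_diff = []
for k in range(D):
    Vp = [row[:] for row in V1]
    Vm = [row[:] for row in V1]
    Vp[i][k] += eps
    Vm[i][k] -= eps
    dL = (loss(data, a, W2, Vp, W1) - loss(data, a, W2, Vm, W1)) / (2 * eps)
    finite_diff.append(-m1 * dL)  # central difference, rescaled by -m1
print(formula, finite_diff)
```

The two printed vectors should agree to roughly eight digits; the same pattern (perturb one coordinate, compare with the formula) applies to the other three gradient expressions.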

B PRELIMINARIES

B.1 INDUCTION HYPOTHESIS AND CONTINUITY ARGUMENT

We extensively use the continuous-time version of mathematical induction, also called the continuity argument, in our proof. We briefly discuss this technique in this subsection and explain some conventions we employ in writing the proof; one may refer to, for example, Chapter 1.3 of Tao (2006) for details. Similar to the discrete-time induction argument, the goal is to maintain a collection of conditions, which we call the Induction Hypothesis, throughout a period of time (cf. Induction Hypothesis C.2 and Induction Hypothesis D.1). There are mainly two types of conditions.

The first type has the form "a certain process $A_t$ is bounded by another process $B_t$". In the proof, $A_t$ is usually the error we want to control and $B_t$ a non-decreasing process representing the corresponding upper bound. To maintain this type of condition, it suffices to show that $A_t\le B_t$ at initialization and $\dot A_t\le\dot B_t$ as long as the Induction Hypothesis is true. For this type of condition, we usually also have an upper bound for $B_t$, say, $B_t\le B_\infty$. The most rigorous way to maintain these bounds is to argue by contradiction. Let $T$ be the minimum of the time $T_1$ at which the process ends and the time $T_2$ at which this bound is first violated. By definition, the Induction Hypothesis holds for all $t\le T$. Using the Induction Hypothesis, one can then derive an upper bound $T'$ on $T_1$, which in turn leads to an upper bound on $T$. Then, all we need to show is that $B_{T'}$ is smaller than $B_\infty$, so that $T$ is attained by $T_1$ instead of $T_2$. For ease of presentation, for this type of condition, instead of arguing by contradiction explicitly, we will simply show that, provided the Induction Hypothesis is true over $[0, T_1]$, $B_{T_1}\le B_\infty$ holds.

The second type has the form "a certain process $C_t$ is bounded by some value $D$". Here, $C_t$ is usually a quantity related to the shape of the learned function, such as $\bar w_2$ and $\alpha$.
In order to maintain, say, $C_t\le D$, we show that $\dot C_t < 0$ whenever $C_t\in[D - \varepsilon, D]$. Since $C_t$ is continuous, this implies that $C_t$, starting below $D$, can never reach $D$.

B.2 PROPERTIES OF THE INPUT DISTRIBUTION

In this subsection, we derive some basic properties of the input distribution that will be useful in later analysis. The following lemma gives the distribution of $\|x\|$ and its tail bound.

Lemma B.1. Let $x\sim\mathcal D$ and let $\|\mathcal D\|$ denote the distribution of $\|x\|$. We have
$$\|\mathcal D\|(r) = \frac dr J_{d/2}^2\big(2\pi R_d\beta\alpha\sqrt d\,r\big) = O\Big(\frac1{r^2}\Big),\qquad\forall r > 0.$$
As a result, we have the tail bound $\mathbb P[\|x\|\ge R]\le O(1/R)$ for all $R > 0$.

We now give some regularity conditions on the input distribution that will be used in our proof. Roughly speaking, they show that the distribution is heavy-tailed and still has large enough mass on $\|x\|\in[0, 1]$.

Lemma B.2 (Regularity conditions on input distribution). For the input distribution $\mathcal D$, we have
(a) $\mathbb E_{\|x\|\le0.99}\|x\| = \Theta(1)$.
(b) $\mathbb E_{x\sim\mathcal D}f^*(x) = \Omega(1)$.
(c) $\mathbb E_{\|x\|\le\Omega(d)}\|x\|\ge\Omega(\log d)$ and $\mathbb E_{\|x\|\le\mathrm{poly}(d)}\|x\|\le O(\log d)$.

Proof of Lemma B.1. Recall that the density of the input $x$ is $(\beta\alpha\sqrt d)^d\varphi^2(\beta\alpha\sqrt d\,x)$, where $\alpha,\beta > 0$ are the universal constants from Safran et al. (2019) (cf. the proof of Theorem 5),
$$\varphi(x) = \Big(\frac{R_d}{\|x\|}\Big)^{d/2}J_{d/2}(2\pi R_d\|x\|),\quad x\in\mathbb R^d,\qquad R_d = \frac1{\sqrt\pi}\Gamma\Big(\frac d2 + 1\Big)^{1/d} = \Theta(\sqrt d)$$
(Lemma 5 in Eldan & Shamir (2016)), and $J_\nu$ is the Bessel function of the first kind of order $\nu$. Since $\varphi$ only depends on $\|x\|$, we abuse notation and write $\varphi(r)$ for $\varphi(x)$ with $\|x\| = r$. For any test function $g:\mathbb R\to\mathbb R$, we have
$$\mathbb E_{x\sim\mathcal D}[g(\|x\|)] = \int_{\mathbb R^d}g(\|x\|)(\beta\alpha\sqrt d)^d\varphi^2(\beta\alpha\sqrt d\,x)\,\mathrm dx = (\beta\alpha\sqrt d)^d|S^{d-1}|\int_0^\infty g(r)\varphi^2(\beta\alpha\sqrt d\,r)r^{d-1}\,\mathrm dr,$$
where $|S^{d-1}| = 2\pi^{d/2}/\Gamma(d/2)$ is the surface area of the unit sphere $S^{d-1}$. Therefore, since $R_d^d = \pi^{-d/2}\Gamma(d/2 + 1)$, the density of $\|x\|$ at $r$ is
$$(\beta\alpha\sqrt d)^d|S^{d-1}|\varphi^2(\beta\alpha\sqrt d\,r)r^{d-1} = \frac{2\pi^{d/2}(\beta\alpha\sqrt d)^d}{\Gamma(d/2)}\Big(\frac{R_d}{\beta\alpha\sqrt d\,r}\Big)^dJ_{d/2}^2(2\pi R_d\beta\alpha\sqrt d\,r)r^{d-1} = \frac dr J_{d/2}^2(2\pi R_d\beta\alpha\sqrt d\,r) = O\Big(\frac1{r^2}\Big),$$
where, for the last step, we use the fact that $J_\nu(z) = O(1/\sqrt z)$ (Krasikov, 2006) together with $R_d = \Theta(\sqrt d)$. Then, it is easy to see that $\mathbb P(\|x\|\ge R) = O(1/R)$.

Proof of Lemma B.2. (a) The upper bound $\mathbb E_{\|x\|\le0.99}\|x\|\le0.99$ is immediate. For the lower bound, note that $\mathbb E_{\|x\|\le0.99}\|x\|\ge0.1\,\mathbb P(0.1\le\|x\|\le0.99)$.
Hence, it suffices to lower bound $\mathbb P(0.1\le\|x\|\le0.99)$. We have
$$\mathbb P(0.1\le\|x\|\le0.99) = \int_{0.1}^{0.99}\frac dr J_{d/2}^2(2\pi R_d\beta\alpha\sqrt d\,r)\,\mathrm dr\ge\Omega(1)\int_{0.2\pi R_d\beta\alpha\sqrt d}^{1.98\pi R_d\beta\alpha\sqrt d}J_{d/2}^2(r)\,\mathrm dr = \Omega(1),$$
where in the last step we use Lemma 23 in Eldan & Shamir (2016). This implies that $\mathbb E_{\|x\|\le0.99}\|x\| = \Omega(1)$. Together with the upper bound, we have $\mathbb E_{\|x\|\le0.99}\|x\| = \Theta(1)$.

(b) We have
$$\mathbb E_{x\sim\mathcal D}f^*(x) = \mathbb E_{\|x\|\le1}[1 - \|x\|]\ge\mathbb E_{\|x\|\le0.99}[1 - \|x\|]\ge0.01\,\mathbb P(\|x\|\le0.99)\ge0.01\,\mathbb P(0.1\le\|x\|\le0.99) = \Omega(1),$$
where in the last inequality we use the calculation in (a).

(c) The upper bound follows directly from the tail bound $\|\mathcal D\|(r)\le O(1/r^2)$. For the lower bound, recall that the density of $\|x\|$ at $r$ is $\frac dr J_{d/2}^2(2\pi R_d\beta\alpha\sqrt d\,r)$. For notational simplicity, put $R_D = \Theta(d)$. We have
$$\mathbb E_{\|x\|\le R_D}\|x\| = \int_0^{R_D}dJ_{d/2}^2(2\pi R_d\beta\alpha\sqrt d\,r)\,\mathrm dr = \frac d{2\pi R_d\beta\alpha\sqrt d}\int_0^{2\pi R_dR_D\beta\alpha\sqrt d}J_{d/2}^2(r)\,\mathrm dr\ge\Omega(1)\int_{cd}^{cd^2}J_{d/2}^2(r)\,\mathrm dr,$$
where $c$ is a large enough constant. To lower bound $\mathbb E\|x\|$, it therefore suffices to lower bound $\int_{cd}^{cd^2}J_{d/2}^2(r)\,\mathrm dr$. We do so by following a calculation similar to that of Lemma 23 in Eldan & Shamir (2016). From the proof of that lemma, for $x\ge d\ge2$ we have
$$J_{d/2}^2(x)\ge\frac2{\pi x}\cos^2\Big(-\frac{(d+1)\pi}4 + f_{d,x}x\Big) - 3x^{-2},$$
where $f_{d,x}$ is a quantity that depends on $d$ and $x$ and satisfies $0.85\le f_{d,x}\le1.3$. Then, we have
$$\int_{cd}^{cd^2}J_{d/2}^2(x)\,\mathrm dx\ge\int_{cd}^{cd^2}\frac2{\pi x}\cos^2\Big(-\frac{(d+1)\pi}4 + f_{d,x}x\Big)\mathrm dx - \int_{cd}^{cd^2}3x^{-2}\,\mathrm dx = \frac2\pi\int_{cd}^{cd^2}\frac1x\cos^2\Big(-\frac{(d+1)\pi}4 + f_{d,x}x\Big)\mathrm dx - \frac{3(d-1)}{cd^2}.$$
Note that in the proof of Lemma 23 in Eldan & Shamir (2016), it is shown that $\frac{\partial}{\partial x}(f_{d,x}x) = 1 - \frac{d^2 - 1}{4x^2}\le1$; for $x\ge cd$ with $c$ large, this derivative is also nonnegative. Then, since $0.85\le f_{d,x}\le1.3$, we have
$$\frac2\pi\int_{cd}^{cd^2}\frac1x\cos^2\Big(-\frac{(d+1)\pi}4 + f_{d,x}x\Big)\mathrm dx\ge\frac2\pi\int_{cd}^{cd^2}\frac{0.85}{f_{d,x}x}\cos^2\Big(-\frac{(d+1)\pi}4 + f_{d,x}x\Big)\frac{\partial}{\partial x}(f_{d,x}x)\,\mathrm dx = \frac2\pi\int_{f_{d,cd}cd}^{f_{d,cd^2}cd^2}\frac{0.85}z\cos^2\Big(-\frac{(d+1)\pi}4 + z\Big)\mathrm dz\ge\frac{1.7}\pi\int_{1.3cd}^{0.85cd^2}\frac1z\cos^2\Big(-\frac{(d+1)\pi}4 + z\Big)\mathrm dz.$$
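Then, using integration by parts and the fact that $\cos^2\big(z - (d+1)\pi/4\big) = \frac{\partial}{\partial z}\big(z/2 + \sin(2z - (d+1)\pi/2)/4\big)$, we have
$$\int_{1.3cd}^{0.85cd^2}\frac1z\cos^2\Big(-\frac{(d+1)\pi}4 + z\Big)\mathrm dz = \left.\frac{\frac z2 + \frac14\sin\big(2z - \frac{(d+1)\pi}2\big)}z\right|_{1.3cd}^{0.85cd^2} + \int_{1.3cd}^{0.85cd^2}\frac{\frac z2 + \frac14\sin\big(2z - \frac{(d+1)\pi}2\big)}{z^2}\mathrm dz\ge-\frac14\Big(\frac1{0.85cd^2} + \frac1{1.3cd}\Big) + \int_{1.3cd}^{0.85cd^2}\frac1{4z}\mathrm dz = -\frac14\Big(\frac1{0.85cd^2} + \frac1{1.3cd}\Big) + \frac14\ln\frac{0.85cd^2}{1.3cd} = \Omega(\log d).$$
Therefore, we have $\int_{cd}^{cd^2}J_{d/2}^2(x)\,\mathrm dx = \Omega(\log d)$, which implies $\mathbb E_{\|x\|\le\Theta(d)}\|x\| = \Omega(\log d)$.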
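Two ingredients of the proofs of Lemma B.1 and Lemma B.2 can be checked numerically with the Python standard library: the constant identity $2\pi^{d/2}R_d^d/\Gamma(d/2) = d$ that turns the radial density into $\frac dr J_{d/2}^2(\cdot)$, together with $R_d/\sqrt d\to1/\sqrt{2\pi e}$ (consistent with $R_d = \Theta(\sqrt d)$, by Stirling's formula), and the antiderivative and lower bound used in the integration by parts. The values $d = 20$ and $c = 2$ are arbitrary choices for this sketch.

```python
import math

def log_R_d(d):
    # R_d = pi^{-1/2} * Gamma(d/2 + 1)^{1/d}, kept in log form to avoid overflow for large d.
    return -0.5 * math.log(math.pi) + math.lgamma(d / 2 + 1) / d

def radial_prefactor(d):
    # 2 pi^{d/2} R_d^d / Gamma(d/2): the constant in front of J_{d/2}^2(.) / r in Lemma B.1.
    return math.exp(math.log(2) + (d / 2) * math.log(math.pi) + d * log_R_d(d) - math.lgamma(d / 2))

def G(z, d):
    # Antiderivative of cos^2(z - (d+1) pi / 4) used in the integration by parts.
    return z / 2 + math.sin(2 * z - (d + 1) * math.pi / 2) / 4

def antiderivative_error(d, zs, h=1e-5):
    # Max |G'(z) - cos^2(z - (d+1) pi / 4)| over zs, with G' by central differences.
    return max(abs((G(z + h, d) - G(z - h, d)) / (2 * h)
                   - math.cos(z - (d + 1) * math.pi / 4) ** 2) for z in zs)

def integral_and_bound(d, c=2.0, n=200_000):
    # Midpoint rule for int_{1.3cd}^{0.85cd^2} cos^2(z - (d+1) pi / 4) / z dz,
    # returned together with the closed-form lower bound from the text.
    lo, hi = 1.3 * c * d, 0.85 * c * d * d
    step = (hi - lo) / n
    total = 0.0
    for k in range(n):
        z = lo + (k + 0.5) * step
        total += math.cos(z - (d + 1) * math.pi / 4) ** 2 / z
    bound = -0.25 * (1 / hi + 1 / lo) + 0.25 * math.log(hi / lo)
    return total * step, bound

print([round(radial_prefactor(d), 6) for d in (2, 10, 100, 1000)])  # each equals d
print(math.exp(log_R_d(10_000)) / math.sqrt(10_000), 1 / math.sqrt(2 * math.pi * math.e))
print(integral_and_bound(20))
```

Working with `math.lgamma` instead of `math.gamma` keeps the identity check stable for large $d$, where $\Gamma(d/2)$ alone would overflow a float.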

B.3 PROPERTIES OF SPHERICALLY SYMMETRIC FUNCTIONS AND DISTRIBUTIONS

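In this subsection, we give some useful properties of spherically symmetric functions and distributions, which will serve as tools in our later analysis. Basically, these lemmas allow us to disentangle the input $x$ and the neuron $v$ when integrating against spherically symmetric functions.

Lemma 3.1. Let $\mu$ be a spherically symmetric distribution. We have
$$\mathbb E_{w\sim\mu}\|w\|\sigma(w\cdot x) = \frac{C_\Gamma\,\mathbb E_{w\sim\mu}\|w\|^2}{\sqrt d}\|x\|,\qquad\text{where } C_\Gamma := \frac{\Gamma(d/2)\sqrt d}{2\sqrt\pi\,\Gamma((d+1)/2)}.$$
Note that, as $d\to\infty$, we have $C_\Gamma\to1/\sqrt{2\pi}$, so $C_\Gamma$ is bounded by a universal constant for all $d$.

Lemma B.3. For any spherically symmetric $g:\mathbb R^d\to\mathbb R$ and $v\in\mathbb R^d$, we have $\mathbb E_x\{g(x)\sigma(v\cdot x)\} = \frac{C_\Gamma}{\sqrt d}\mathbb E_x\{g(x)\|x\|\}\|v\|$.

Corollary B.4. Let $g:\mathbb R^d\to\mathbb R$ be a spherically symmetric function. We have $\mathbb E_x\{g(x)F(x)\} = \alpha\,\mathbb E_x\{g(x)\|x\|\}$.

Lemma B.5. Let $g:\mathbb R^d\to\mathbb R$ be a spherically symmetric function. Then, for any $v\in\mathbb R^d$, we have $\mathbb E_{x\sim\mathcal D}\{g(x)\sigma'(v\cdot x)x\} = \frac{C_\Gamma}{\sqrt d}\mathbb E_{x\sim\mathcal D}\{g(x)\|x\|\}\hat v$, where $\hat v = v/\|v\|$.

Proof of Lemma 3.1. For simplicity, put $g(x) = \mathbb E_{w\sim\mu}\|w\|\sigma(w\cdot x)$. Since $\sigma$ is 1-homogeneous and $\mu$ is spherically symmetric, we have
$$g(x) = \int_{\mathbb R^d}\|w\|^2\sigma(\hat w\cdot x)\mu(w)\,\mathrm dw = \int_0^\infty\int_{S^{d-1}}r^2\sigma(\hat w\cdot x)\mu(r\hat w)r^{d-1}\,\mathrm d\sigma_{d-1}(\hat w)\,\mathrm dr = \int_0^\infty r^{d+1}\mu(r)\,\mathrm dr\int_{S^{d-1}}\sigma(\hat w\cdot x)\,\mathrm d\sigma_{d-1}(\hat w).$$
For the first factor, note that
$$\int_{\mathbb R^d}\|w\|^2\mu(w)\,\mathrm dw = \int_0^\infty\int_{S^{d-1}}r^2\mu(r\hat w)r^{d-1}\,\mathrm d\sigma_{d-1}(\hat w)\,\mathrm dr = \frac{2\pi^{d/2}}{\Gamma(d/2)}\int_0^\infty r^{d+1}\mu(r)\,\mathrm dr.$$
Hence, $\int_0^\infty r^{d+1}\mu(r)\,\mathrm dr = \frac{\Gamma(d/2)}{2\pi^{d/2}}\int_{\mathbb R^d}\|w\|^2\mu(w)\,\mathrm dw = \frac{\Gamma(d/2)}{2\pi^{d/2}}\mathbb E_{w\sim\mu}\|w\|^2$. We now compute the second factor. Since this integral is invariant under rotations of $x$, we have
$$\int_{S^{d-1}}\sigma(\hat w\cdot x)\,\mathrm d\sigma_{d-1}(\hat w) = \|x\|\int_{S^{d-1}}\sigma(\hat w_1)\,\mathrm d\sigma_{d-1}(\hat w) = \frac{\|x\|}2\int_{S^{d-1}}|\hat w_1|\,\mathrm d\sigma_{d-1}(\hat w).$$
Define $I = \int_{\mathbb R^d}|w_1|e^{-\|w\|^2}\,\mathrm dw$. We have
$$I = \int_{\mathbb R^d}|w_1|\prod_{i=1}^de^{-w_i^2}\,\mathrm dw = \int_{-\infty}^\infty|w_1|e^{-w_1^2}\,\mathrm dw_1\prod_{i=2}^d\int_{-\infty}^\infty e^{-w_i^2}\,\mathrm dw_i = \pi^{(d-1)/2}.$$
We also have
$$I = \int_{S^{d-1}}\int_0^\infty r|\hat w_1|e^{-r^2}r^{d-1}\,\mathrm dr\,\mathrm d\sigma_{d-1}(\hat w) = \int_0^\infty e^{-r^2}r^d\,\mathrm dr\int_{S^{d-1}}|\hat w_1|\,\mathrm d\sigma_{d-1}(\hat w) = \frac{\Gamma((d+1)/2)}2\int_{S^{d-1}}|\hat w_1|\,\mathrm d\sigma_{d-1}(\hat w).$$
Therefore,
$$\int_{S^{d-1}}\sigma(\hat w\cdot x)\,\mathrm d\sigma_{d-1}(\hat w) = \frac{\|x\|}2\int_{S^{d-1}}|\hat w_1|\,\mathrm d\sigma_{d-1}(\hat w) = \frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\|x\|.\tag{13}$$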
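Thus,
$$g(x) = \frac{\Gamma(d/2)}{2\pi^{d/2}}\mathbb E_{w\sim\mu}\|w\|^2\,\frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\|x\| = \frac{C_\Gamma\,\mathbb E_{w\sim\mu}\|w\|^2}{\sqrt d}\|x\|.$$

Proof of Lemma B.3. We compute
$$\mathbb E_{x\sim\mathcal D}\{g(x)\sigma(v\cdot x)\} = \int_{\mathbb R^d}g(x)\sigma(v\cdot x)\mathcal D(x)\,\mathrm dx = \int_0^\infty\int_{S^{d-1}}g(r\hat x)\sigma(v\cdot(r\hat x))\mathcal D(r\hat x)r^{d-1}\,\mathrm d\sigma_{d-1}(\hat x)\,\mathrm dr = \int_0^\infty\int_{S^{d-1}}g(r)\sigma(v\cdot\hat x)\mathcal D(r)r^d\,\mathrm d\sigma_{d-1}(\hat x)\,\mathrm dr = \int_0^\infty g(r)\mathcal D(r)r^d\,\mathrm dr\int_{S^{d-1}}\sigma(v\cdot\hat x)\,\mathrm d\sigma_{d-1}(\hat x) = \int_0^\infty g(r)\mathcal D(r)r^d\,\mathrm dr\,\frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\|v\|,$$
where the last step comes from (13). (Note that the integral is now taken w.r.t. $\hat x$ instead of $\hat w$.) For the first factor, note that
$$\mathbb E_{x\sim\mathcal D}\{g(x)\|x\|\} = \int_{\mathbb R^d}g(x)\|x\|\mathcal D(x)\,\mathrm dx = \int_0^\infty\int_{S^{d-1}}g(r)\mathcal D(r)r^d\,\mathrm d\sigma_{d-1}(\hat x)\,\mathrm dr = \frac{2\pi^{d/2}}{\Gamma(d/2)}\int_0^\infty g(r)\mathcal D(r)r^d\,\mathrm dr.$$
Thus,
$$\mathbb E_{x\sim\mathcal D}\{g(x)\sigma(v\cdot x)\} = \mathbb E_{x\sim\mathcal D}\{g(x)\|x\|\}\Big(\frac{2\pi^{d/2}}{\Gamma(d/2)}\Big)^{-1}\frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\|v\| = \frac{C_\Gamma}{\sqrt d}\mathbb E_{x\sim\mathcal D}\{g(x)\|x\|\}\|v\|.$$

Proof of Corollary B.4. By Lemma B.3, we have
$$\mathbb E_x\{g(x)F(x)\} = \mathbb E_x\big\{g(x)\,\mathbb E_{w\sim\mu_1}\{\|w\|\sigma(w\cdot x)\}\big\} = \mathbb E_{w\sim\mu_1}\big\{\|w\|\,\mathbb E_x\{g(x)\sigma(w\cdot x)\}\big\} = \frac{C_\Gamma\,\mathbb E_{w\sim\mu_1}\|w\|^2}{\sqrt d}\mathbb E_x\{g(x)\|x\|\} = \alpha\,\mathbb E_x\{g(x)\|x\|\}.$$

Proof of Lemma B.5. Let $R := 2\hat v\hat v^\top - I_d$ be the reflection w.r.t. $v$. Since $R$ is orthogonal, we have $R\#\mathcal D = \mathcal D$, $g(Rx) = g(x)$ and $Rv = v$. Hence,
$$\mathbb E_{x\sim\mathcal D}\{g(x)\sigma'(v\cdot x)x\} = \frac12\Big(\mathbb E_{x\sim\mathcal D}\{g(x)\sigma'(v\cdot x)x\} + \mathbb E_{x\sim R\#\mathcal D}\{g(x)\sigma'(v\cdot x)x\}\Big) = \frac12\mathbb E_{x\sim\mathcal D}\{g(x)\sigma'(v\cdot x)x + g(Rx)\sigma'(v\cdot Rx)Rx\} = \frac12\mathbb E_{x\sim\mathcal D}\{g(x)\sigma'(v\cdot x)x + g(Rx)\sigma'(Rv\cdot x)Rx\} = \frac12\mathbb E_{x\sim\mathcal D}\{g(x)\sigma'(v\cdot x)(x + Rx)\}.$$
Note that $x + Rx = 2\hat v\hat v^\top x = 2\langle\hat v, x\rangle\hat v$. Hence, using $\sigma'(z)z = \sigma(z)$,
$$\mathbb E_{x\sim\mathcal D}\{g(x)\sigma'(v\cdot x)x\} = \mathbb E_{x\sim\mathcal D}\{g(x)\sigma(\hat v\cdot x)\}\hat v = \frac{C_\Gamma}{\sqrt d}\mathbb E_{x\sim\mathcal D}\{g(x)\|x\|\}\hat v,$$
where the second identity comes from Lemma B.3.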
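A Monte Carlo sanity check of Lemma 3.1 and Lemma B.5, together with the limit $C_\Gamma\to1/\sqrt{2\pi}$. As assumptions for illustration, both $\mu$ and the input distribution are replaced by the standard Gaussian on $\mathbb R^5$ (which is spherically symmetric), $\sigma$ is the ReLU, and $g(x) = e^{-\|x\|^2/4}$.

```python
import math
import random

def C_Gamma(d):
    # C_Gamma = Gamma(d/2) sqrt(d) / (2 sqrt(pi) Gamma((d+1)/2)), via lgamma for stability.
    return math.exp(math.lgamma(d / 2) - math.lgamma((d + 1) / 2)) * math.sqrt(d) / (2 * math.sqrt(math.pi))

def norm(x):
    return math.sqrt(sum(t * t for t in x))

def check_lemma_3_1(d, x, n, rng):
    # mu = N(0, I_d) (assumption; any spherically symmetric mu works), so E||w||^2 = d.
    mc = 0.0
    for _ in range(n):
        w = [rng.gauss(0.0, 1.0) for _ in range(d)]
        mc += norm(w) * max(0.0, sum(p * q for p, q in zip(w, x))) / n
    exact = C_Gamma(d) * d / math.sqrt(d) * norm(x)
    return mc, exact

def check_lemma_b_5(d, v, g, n, rng):
    # x ~ N(0, I_d) stands in for D; g depends on ||x|| only.
    lhs = [0.0] * d
    g_norm = 0.0                                   # Monte Carlo estimate of E[g(x) ||x||]
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        r = norm(x)
        g_norm += g(r) * r / n
        if sum(p * q for p, q in zip(v, x)) > 0:   # ReLU: sigma'(v . x) = 1{v . x > 0}
            for k in range(d):
                lhs[k] += g(r) * x[k] / n
    rhs = [C_Gamma(d) / math.sqrt(d) * g_norm * t / norm(v) for t in v]
    return lhs, rhs

rng = random.Random(0)
mc, exact = check_lemma_3_1(5, [2.0, 0.0, 0.0, 0.0, 0.0], 100_000, rng)
lhs, rhs = check_lemma_b_5(5, [1.0, 2.0, 0.0, -1.0, 0.5],
                           lambda r: math.exp(-r * r / 4), 100_000, rng)
print(C_Gamma(10**6), 1 / math.sqrt(2 * math.pi))
print(mc, exact)
print([round(t, 3) for t in lhs], [round(t, 3) for t in rhs])
```

Note that the right-hand side of Lemma B.5 points along $\hat v$, not $v$: the left-hand side is unchanged when $v$ is rescaled, which the check reflects by dividing by `norm(v)`.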

B.4 THE INFINITE-WIDTH NETWORK REMAINS SPHERICALLY SYMMETRIC

In this subsection, we show that the infinite-width network remains spherically symmetric throughout the whole process. Clearly, $\mu_1$ is spherically symmetric at initialization. Now, assume that it is spherically symmetric at time $t$. We claim that $v_1$ does not move tangentially and that its radial speed does not depend on its direction $\hat v_1$; that is, $\dot v_1 = h(v_1)\hat v_1$ for some spherically symmetric function $h$. By our induction hypothesis, $S$ is also spherically symmetric at time $t$. For brevity, write $\Phi_{v_1}(x) := S(x)\big(\hat v_1\sigma(v_1\cdot x) + \|v_1\|\sigma'(v_1\cdot x)x\big)$, so that $\dot v_1 = \mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)]$. Let $T := 2\hat v_1\hat v_1^\top - I_d$ be the reflection w.r.t. $v_1$. Clearly, $Tv_1 = v_1$. Moreover, $T$ preserves norms and, as a result, $S(Tx) = S(x)$, $T\#\mathcal D = \mathcal D$ and $\Pi\circ T = T\circ\Pi$. Hence, we have
$$\dot v_1 = \frac12\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)] + \frac12\mathbb E_{x\sim T\#\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)].$$
For the second term, we have
$$\mathbb E_{x\sim T\#\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)] = \mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}\big[S(x)\big(\hat v_1\sigma(v_1\cdot Tx) + \|v_1\|\sigma'(v_1\cdot Tx)Tx\big)\big] = \mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}\big[S(x)\big(\hat v_1\sigma(v_1\cdot x) + \|v_1\|\sigma'(v_1\cdot x)Tx\big)\big] = \mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[T\Phi_{v_1}(x)] = T\,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)],$$
where the third equality uses $T\hat v_1 = \hat v_1$. Thus, since $\frac12(I + T) = \hat v_1\hat v_1^\top$,
$$\dot v_1 = \frac12(I + T)\,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)] = \big\langle\hat v_1,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)]\big\rangle\hat v_1.$$
Namely, $\dot v_1 = h(v_1)\hat v_1$, where $h(v_1) = \big\langle\hat v_1,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)]\big\rangle$. Now, we show that $h$ is spherically symmetric to complete the proof. Let $R$ be an arbitrary rotation matrix. We have
$$h(Rv_1) = \Big\langle R\hat v_1,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}\big[S(x)\big(R\hat v_1\sigma(Rv_1\cdot x) + \|v_1\|\sigma'(Rv_1\cdot x)x\big)\big]\Big\rangle = \Big\langle R\hat v_1,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}\big[S(x)\big(R\hat v_1\sigma(v_1\cdot R^\top x) + \|v_1\|\sigma'(v_1\cdot R^\top x)RR^\top x\big)\big]\Big\rangle = \Big\langle R\hat v_1,\mathbb E_{x\sim R^\top\#\mathcal D}\Pi_{R_{v_1}}\big[S(x)\big(R\hat v_1\sigma(v_1\cdot x) + \|v_1\|\sigma'(v_1\cdot x)Rx\big)\big]\Big\rangle = \big\langle R\hat v_1, R\,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)]\big\rangle = \big\langle\hat v_1,\mathbb E_{x\sim\mathcal D}\Pi_{R_{v_1}}[\Phi_{v_1}(x)]\big\rangle = h(v_1).$$
Thus, h is spherically symmetric.
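The symmetrization step relies only on elementary properties of the reflection $T = 2\hat v_1\hat v_1^\top - I_d$: it fixes $v_1$, preserves norms, and satisfies $\frac12(I + T)X = \langle\hat v_1, X\rangle\hat v_1$ for every vector $X$. A quick numerical check on random vectors in $\mathbb R^6$ (the dimension and the vectors are arbitrary):

```python
import math
import random

def reflection(v):
    # T = 2 v_hat v_hat^T - I_d: the reflection that fixes the line spanned by v.
    n = math.sqrt(sum(t * t for t in v))
    u = [t / n for t in v]
    d = len(v)
    T = [[2 * u[i] * u[j] - (1.0 if i == j else 0.0) for j in range(d)] for i in range(d)]
    return T, u

def matvec(M, x):
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in M]

rng = random.Random(0)
v1 = [rng.gauss(0.0, 1.0) for _ in range(6)]
X = [rng.gauss(0.0, 1.0) for _ in range(6)]   # stands in for the expectation vector E_x Pi[...]
T, u = reflection(v1)
TX = matvec(T, X)
half_sum = [(xi + ti) / 2 for xi, ti in zip(X, TX)]           # (1/2)(I + T) X
proj = [sum(p * q for p, q in zip(u, X)) * ui for ui in u]    # <v_hat, X> v_hat
print(max(abs(p - q) for p, q in zip(half_sum, proj)))        # ~1e-16
```

In words, averaging a vector with its reflection keeps exactly its component along $\hat v_1$, which is why only the radial part of the velocity survives.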

C STAGE 1

The goal of Stage 1 is for all $v_2$ to decrease to $-\Theta(1/R_{v_2})$, so that we can ignore all projection operators in $\dot r_2$, $\dot v_1$ and $\dot v_2$. We split Stage 1 into three substages, in which $v_2$ decreases to $0$, $-\mathrm{poly}(d)\delta_2$ and $-\Theta(1/R_{v_2})$, respectively. By Lemma C.3, at the end of each substage, one more projection operator can be ignored. We also show that, in Stage 1, the approximation error of the first layer and the spread of the second layer cannot grow too much. First, for the initialization, a standard concentration argument gives the following lemma.

Lemma C.1 (Initialization). We choose $m_1 = \mathrm{poly}(d, 1/\varepsilon)$, $m_2 = \Theta(1)$, $\sigma_1 = 1/\sqrt d$, $\sigma_2 = 1/\mathrm{poly}(d, 1/\varepsilon)$, and $\sigma_r$ to be a small constant. We initialize $w_1\sim\mathrm{Unif}(\sigma_1S^{d-1})$ for $\mu_1$, and $w_2\sim\mathcal N(0,\sigma_2^2)$ and $b_2 = \sigma_r$ for $\mu_2$. Given $\delta_{1,I} = 1/\mathrm{poly}_1(d, 1/\varepsilon)$, we choose $m_1$ sufficiently large so that, at initialization, with probability at least $1 - 1/\mathrm{poly}(d)$, $\big\|F|_{S^{d-1}} - 1\big\|_{L^\infty}\le\delta_{1,I}$. We also choose $\sigma_2 = \delta_{1,I}/d^7$. With probability at least $1 - 1/\mathrm{poly}(d)$, we have $\max_{w_2}|w_2|\le O(\log d)\sigma_2$.

Then, we formally state the Induction Hypothesis we are going to maintain for Stage 1.

Induction Hypothesis C.2 (Stage 1). We define $T_1 := \inf\{t\ge0: -\bar w_2(t) = \Theta(1)/R_{v_2}\}$ for some large constant. Define $\delta^{(1)}_{1,T}$, $\delta^{(1)}_{1,R}$, $\delta^{(1)}_2$ as
$$\delta^{(1)}_{1,T} = \max\Big\{\delta^{(1)}_{1,T}(0),\ \max_{v_1\in\mu_1}\|\hat v_1(t) - \hat v_1(0)\|\Big\},\quad\delta^{(1)}_{1,R} = \max\Big\{\delta^{(1)}_{1,R}(0),\ \max_{v_1\in\mu_1}\frac{\big|\|v_1\|^2 - \mathbb E_{w_1}\|w_1\|^2\big|}{\mathbb E_{w_1}\|w_1\|^2}\Big\},\quad\delta^{(1)}_2 = \max\Big\{\delta^{(1)}_2(0),\ \max_{(v_2,r_2),(v_2',r_2')}\|(v_2,r_2) - (v_2',r_2')\|\Big\}$$
in Stage 1.1 and Stage 1.2, and
$$\frac{\mathrm d}{\mathrm dt}\delta^{(1)}_{1,T} = \mathrm{ReLU}\Big(\frac{\mathrm d}{\mathrm dt}\max_{v_1\in\mu_1}\|\hat v_1(t) - \hat v_1(0)\|\Big),\quad\frac{\mathrm d}{\mathrm dt}\delta^{(1)}_{1,R} = \mathrm{ReLU}\Big(\frac{\mathrm d}{\mathrm dt}\max_{v_1\in\mu_1}\frac{\big|\|v_1\|^2 - \mathbb E_{w_1}\|w_1\|^2\big|}{\mathbb E_{w_1}\|w_1\|^2}\Big),\quad\frac{\mathrm d}{\mathrm dt}\delta^{(1)}_2 = \mathrm{ReLU}\Big(\frac{\mathrm d}{\mathrm dt}\max_{(v_2,r_2),(v_2',r_2')}\|(v_2,r_2) - (v_2',r_2')\|\Big)$$
in Stage 1.3, with initial values $\delta^{(1)}_{1,T}(0) = \delta^{(1)}_{1,R}(0) = 0$ and $\delta^{(1)}_2(0) = \Theta(\sigma_2\log d)$.
We say that this Induction Hypothesis is true at time $t\in[0, T_1]$ if the following hold.
(a) Approximation error of the first layer. For each $v_1\in\mu_1$, $\|\hat v_1(t) - \hat v_1(0)\|\le\delta^{(1)}_{1,T}$ and $\|v_1\|^2 = (1\pm\delta^{(1)}_{1,R})\,\mathbb E_{w_1\sim\mu_1}\|w_1\|^2$.
(b) Spread of the second layer. For any $(v_2, r_2), (v_2', r_2')\in\mu_2$, $\|(v_2, r_2) - (v_2', r_2')\|\le\delta^{(1)}_2$.
(c) The bias term. For any $(v_2, r_2)\in\mu_2$, $r_2 = \Theta(1)$.
(d) Size of $f$. $|\bar w_2| = O(1/R_{v_2}) = O(1/d^3)$ and $\alpha = \Theta(\sqrt d/R_{v_1}) = \Theta(1/d^{1.5})$.
(e) Bounds for the errors. $\delta^{(1)}_2\le O(d^{1.5}(\log d)\sigma_2)$ and $\delta^{(1)}_{1,R} + \delta^{(1)}_{1,T}\le O(d^7(\log d)\sigma_2 + \delta_{1,I})$.

The next lemma describes when the projection operators can be ignored. Roughly speaking, we first bound the gradients to show that, in order for a projection operator to be triggered, $\|x\|$ must be larger than a certain quantity. Meanwhile, note that $f$, and hence the gradients, vanish for those $x$ with $\|x\|\ge\Theta(1/|\bar w_2\alpha|)$. Hence, as long as $\Theta(1/|\bar w_2\alpha|)$ is smaller than that quantity, we can ignore the projection.

Lemma C.3. Suppose that Induction Hypothesis C.2 is true. The projection operators in $\dot r_2$, $\dot v_1$ and $\dot v_2$ will no longer be activated if, respectively, all second-layer weights are nonpositive, $-\bar w_2 > \Theta(1)\delta^{(1)}_2$ for some large constant, and $-\bar w_2\ge\Theta(1)/R_{v_2}$ for some large constant.

Remark. Though we only need $-\bar w_2$ to be $\Theta(1)\delta^{(1)}_2$ to ignore the projection operator in $\dot v_1$, we will actually define the end of Stage 1.2 to be the time when $-\bar w_2$ becomes $\mathrm{poly}(d)\delta^{(1)}_2$. … and $-\bar w_2 = \Theta(1/d^3)$. For the errors, we have $\delta^{(1)}_2\le O(d^{1.5}\log d\,\sigma_2)$ and $\delta^{(1)}_{1,R} + \delta^{(1)}_{1,T}\le O(\delta_{1,I})$.

Proof of Lemma C.3. First, note that when all $v_2$ are nonpositive, we have $f = O(1)$. Since we choose $R_{r_2}$ to be a large constant, this implies that the projection operator in $\dot r_2$ will not be activated. When $-\bar w_2 > \Theta(1)\delta^{(1)}_2$, we have $f(x)\le\sigma(c\bar w_2\alpha\|x\| + O(1))$ for some small constant $c > 0$. As a result, $f$ vanishes on $\{\|x\|\ge(-c\bar w_2\alpha)^{-1}\}$.
Then, for those $x$ with $\|x\|\le(-c\bar w_2\alpha)^{-1}$, the gradient w.r.t. $v_1$ can be bounded as
$$\|\nabla_{v_1}L(x)\|\le O(1)|\bar w_2|\|x\|\|v_1\|\le O(1)|\bar w_2|\|v_1\|\frac1{|\bar w_2|\alpha}\le O(d).$$
Since we choose $R_{v_1} = \Theta(d)$ with a large constant, this implies that the projection operator in $\dot v_1$ will not be triggered. Finally, for $\dot v_2$ and those $x$ with $\|x\|\le(-c\bar w_2\alpha)^{-1}$, we have $|\nabla_{v_2}L(x)|\le O(1)\alpha\|x\|\le O(1)/|\bar w_2|$. By assumption, $|\bar w_2|\ge\Theta(1)/R_{v_2}$ for some large constant. Hence, this inequality implies that the projection operator in $\dot v_2$ will not be triggered.
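The first-layer part of Lemma C.1 (a wide enough first layer makes $F$ uniformly close to its infinite-width value on the sphere) can be illustrated with a small simulation. Following the proof of Corollary B.4, we assume $F(x) = \mathbb E_{w\sim\mu_1}\|w\|\sigma(w\cdot x)$ with $\alpha = C_\Gamma\mathbb E\|w\|^2/\sqrt d$, so that the infinite-width $F/\alpha$ equals 1 on $S^{d-1}$ by Lemma 3.1. The choice $d = 5$, the two widths, and the 20 random test directions are arbitrary, and a finite set of directions only approximates the $L^\infty$ norm.

```python
import math
import random

def C_Gamma(d):
    return math.exp(math.lgamma(d / 2) - math.lgamma((d + 1) / 2)) * math.sqrt(d) / (2 * math.sqrt(math.pi))

def unit(d, rng):
    # Uniform direction on S^{d-1} via a normalized Gaussian vector.
    z = [rng.gauss(0.0, 1.0) for _ in range(d)]
    n = math.sqrt(sum(t * t for t in z))
    return [t / n for t in z]

def sup_error(d, m1, n_dirs, rng):
    # max over random unit x of |F(x) / alpha - 1|, where (assumption) the width-m1 layer is
    # F(x) = (1/m1) sum_i ||w_i|| relu(w_i . x) with w_i ~ Unif(sigma_1 S^{d-1}).
    sigma1 = 1 / math.sqrt(d)
    W = [[sigma1 * t for t in unit(d, rng)] for _ in range(m1)]
    alpha = C_Gamma(d) * sigma1 ** 2 / math.sqrt(d)   # E||w||^2 = sigma1^2 exactly
    worst = 0.0
    for _ in range(n_dirs):
        x = unit(d, rng)
        F = sum(sigma1 * max(0.0, sum(p * q for p, q in zip(w, x))) for w in W) / m1
        worst = max(worst, abs(F / alpha - 1))
    return worst

rng = random.Random(0)
errors = {m1: sup_error(5, m1, 20, rng) for m1 in (200, 20_000)}
print(errors)  # the sup error typically shrinks roughly like 1 / sqrt(m1)
```

The paper's choice $m_1 = \mathrm{poly}(d, 1/\varepsilon)$ plays the role of the width here; the simulation says nothing about the precise polynomial.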

C.1 STAGE 1.1

The goal of Stage 1.1 is to make sure that all second-layer weights $v_2$ become nonpositive, that is, $T_{1.1} := \inf\{t\ge0: v_2\le0\ \forall(v_2, r_2)\in\mu_2\}$. As a result, at the end of Stage 1.1, $f$ is $O(1)$ and, by Lemma C.3, the projection operator in $\dot r_2$ can be ignored. Since this stage only takes a very small amount of time, we control the first-layer error by directly bounding the movement of $v_1$. For the second layer, we bound the movement of the bias term in the same brute-force way. For the second-layer weights, we show that positive $v_2$'s decrease faster than negative ones, so the spread does not increase.

Lemma C.5. Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\le T_{1.1}$. Then the following hold. …

Remark. In fact, (c) holds whenever $\alpha = \Omega(1/d^{1.5})$ and $v_2F(x) + r_2\ge\Theta(1)$ for any $(v_2, r_2)\in\mu_2$ and $x\in\{\|x\|\le d^{1.5}\}$, which is always true throughout Stage 1. This estimate will also be used in Stage 1. …
$$\dot v_2 = \mathbb E_{\|x\|\le1}\{(f^*(x) - f(x))F(x)\} - \mathbb E_{\|x\|\ge1}\Pi_{R_{v_2}}\big[f(x)\sigma'(v_2F(x) + r_2)F(x)\big].$$
Note that the first term does not depend on $v_2$, and, for the second term, $\sigma'(v_2F(x) + r_2) = 1$ whenever $v_2\ge0$. As a result, all positive $v_2$ move with the same speed, which is more negative than the speed of any $v_2 < 0$. Thus, $\max_{w_2}w_2 - \min_{w_2}w_2$ is non-increasing.

(c) Clearly, $\mathbb E_{\|x\|\le1}\{(f^*(x) - f(x))F(x)\} = O(\alpha)$. For the second term, first note that for any $x$ with $\|x\|\le d^{1.5}$, we have
$$f(x)F(x)\le O\big((1 + \max_{w_2}w_2\,\alpha\|x\|)\,\alpha\|x\|\big)\le R_{v_2}\qquad\text{and}\qquad f(x)\ge\Theta(1) - \max_{w_2}|w_2|\alpha\|x\| = \Theta(1).$$
As a result,
$$\mathbb E_{\|x\|\ge1}\Pi_{R_{v_2}}\big[f(x)\sigma'(v_2F(x) + r_2)F(x)\big]\ge\Theta(\alpha)\,\mathbb E_{1\le\|x\|\le d^{1.5}}\|x\| = \Theta((\log d)\alpha).$$
Thus, $\dot v_2\le-\Theta(\log d/d^{1.5})$.

Proof of Lemma C.6. By Lemma C.5, it takes at most $O(d^{1.5}\delta^{(1)}_2(0))$ time for all $v_2$ to become nonpositive. Within this amount of time, $r_2$ changes by at most $O(d^{1.5}\delta^{(1)}_2(0))$. Since the spread of $w_2$ does not increase, this implies $\delta$ …

C.2 STAGE 1.2

The goal of Stage 1.2 is to make sure that $-\bar w_2\ge d\delta^{(1)}_2(T_{1.1})$. Namely, $T_{1.2} := \inf\{t\ge T_{1.1}: -\bar w_2 = d\delta^{(1)}_2(T_{1.1})\}$. We will also show that $\delta^{(1)}_2(T_{1.2}) = O(\delta^{(1)}_2(T_{1.1}))$, so $\delta^{(1)}_2(T_{1.2})/|\bar w_2| = O(1/d)$ at the end of Stage 1.2. Moreover, by Lemma C.3, at the end of Stage 1.2, the projection operator in $\dot v_1$ will no longer be activated. In this subsection, we also show that $r_2$ remains $\Theta(1)$ throughout Stage 1. The first-layer error is again controlled in a brute-force way. For the second-layer spread, we show that, since $|v_2|$ is small, $\sigma'(v_2F(x) + r_2) = 1$ for most $x$ and, as a result, the change of $(v_2, r_2)$ is approximately uniform.

Lemma C.7. Suppose that Induction Hypothesis C.2 is true at time $t$. Then, for any $(v_2, r_2)\in\mu_2$, $\dot r_2 > 0$ when $r_2\le\mathbb Ef^*/2$ and $\dot r_2 < 0$ when $r_2\ge2\,\mathbb Ef^*$. As a result, $r_2 = \Theta(1)$ throughout Stage 1.

Lemma C.8 (Spread of the second layer). Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\le T_{1.2}$. Then, for any $(v_2, r_2), (v_2', r_2')\in\mu_2$, we have
$$\frac{\mathrm d}{\mathrm dt}\|(v_2, r_2) - (v_2', r_2')\|^2\le O(d^{2.5})\big(\delta^{(1)}_2\big)^2.$$
By this lemma, the error $\delta^{(1)}_2$ can grow exponentially fast, with a large rate. However, it will not blow up: since $\dot v_2\le-\Theta(\log d/d^{1.5})$, the time needed for Stage 1.2 is much shorter than $1/d^{2.5}$.

Lemma C.9 (Main lemma of Stage 1.2). Stage 1.2 takes at most $O(d^{2.5}\delta^{(1)}_2(T_{1.1}))$ time. At the end of Stage 1.2, we have $-v_2\ge\Theta(d)\delta^{(1)}_2(T_{1.1})$ for any $(v_2, r_2)\in\mu_2$. For the errors, the spread of the second layer is $(1 + o(1))\delta^{(1)}_2(T_{1.1})$, and both $\delta^{(1)}_{1,R}(T_{1.2})$ and $\delta^{(1)}_{1,T}(T_{1.2})$ can be bounded by $O(d^4\delta^{(1)}_2(T_{1.1}))$.

Proof of Lemma C.7. We write
$$\dot r_2 = \mathbb E_x\{(f^*(x) - f(x))\sigma'(v_2F(x) + r_2)\} = \mathbb E_xf^*(x) - \mathbb E_x\{f(x)\sigma'(v_2F(x) + r_2)\}.$$
Since the spread of $b_2$ is $o(1)$, when $r_2\le\mathbb E_xf^*(x)/2 = \Theta(1)$, the RHS is a positive constant. In other words, $r_2$ will keep growing. Meanwhile, since the second term can be bounded as $\mathbb E_x\{f(x)\sigma'(v_2F(x) + r_2)\}\ge\mathbb E_{\|x\|\le d^2}\{f(x)\}\ge(1 - o(1))\bar b_2$, when $r_2\ge2\,\mathbb Ef^*(x)$, $\dot r_2$ becomes a negative constant and $r_2$ decreases. Combining these two cases completes the proof.

Proof of Lemma C.8. Since $|v_2|\le d\delta^{(1)}_2(T_{1.1})$, $F(x) = \Theta(\alpha)\|x\|$ and $r_2 = \Theta(1)$, we have $v_2F(x) + r_2 > 0$ for all $x$ with $\|x\|\le\Theta(\sqrt d/\delta^{(1)}_2(T_{1.1}))$. Hence, we can rewrite $\dot v_2$ as
$$\dot v_2 = \mathbb E_{\|x\|\le\Theta(\sqrt d/\delta^{(1)}_2(T_{1.1}))}\Pi_{R_{v_2}}\big[(f^*(x) - f(x))F(x)\big] - \mathbb E_{\|x\|\ge\Theta(\sqrt d/\delta^{(1)}_2(T_{1.1}))}\Pi_{R_{v_2}}\big[f(x)\sigma'(v_2F(x) + r_2)F(x)\big].$$
The first term does not depend on $v_2$ and, by the tail bound, the second term can be bounded by $O(R_{v_2}\delta^{(1)}_2(T_{1.1})/\sqrt d)$. Similarly, for $\dot r_2$, we have
$$\dot r_2 = \mathbb E_{\|x\|\le\Theta(\sqrt d/\delta^{(1)}_2(T_{1.1}))}\{f^*(x) - f(x)\}\pm O\big(\delta^{(1)}_2(T_{1.1})/\sqrt d\big).$$
Hence, for any $(v_2, r_2), (v_2', r_2')\in\mu_2$, we have
$$\frac{\mathrm d}{\mathrm dt}\|(v_2, r_2) - (v_2', r_2')\|^2\le(v_2 - v_2')\,O\Big(\frac{R_{v_2}\delta^{(1)}_2(T_{1.1})}{\sqrt d}\Big) + (r_2 - r_2')\,O\Big(\frac{\delta^{(1)}_2(T_{1.1})}{\sqrt d}\Big)\le O(d^{2.5})\big(\delta^{(1)}_2\big)^2.$$

Proof of Lemma C.9. Recall from Lemma C.5 that $\dot v_2\le-\Theta(\log d/d^{1.5})$, whence Stage 1.2 takes at most $O(d^{2.5}\delta^{(1)}_2(T_{1.1}))$ time. By Lemma C.8, we have
$$\delta^{(1)}_2(T_{1.2})^2\le\delta^{(1)}_2(T_{1.1})^2\exp\big(O(d^5)\delta^{(1)}_2(T_{1.1})\big)\le(1 + o(1))\,\delta^{(1)}_2(T_{1.1})^2.$$
For $v_1$, similar to the proof of Lemma C.6, both $\delta^{(1)}_{1,R}(T_{1.2})$ and $\delta^{(1)}_{1,T}(T_{1.2})$ can be bounded by $O(d^4\delta^{(1)}_2(T_{1.1}))$.

C.3 STAGE 1.3

The goal of Stage 1.3 is to make sure that $-\bar w_2 = \Theta(1/R_{v_2})$ for some large constant, so that, by Lemma C.3, the projection operator in $\dot v_2$ can be ignored. That is, we define $T_{1.3} := \inf\{t\ge T_{1.2}: -\bar w_2(t) = \Theta(1/R_{v_2})\}$. This stage takes longer than the previous ones, so we need finer, less brute-force ways to control the errors.
For the first layer, we show that the tangential movement is almost zero and the radial movement is approximately uniform. For the second layer, we show that the spread $\delta$ … For the errors, the spread of the second layer is $O(\delta^{(1)}_2(T_{1.2}))$ and the first-layer errors are $O\big(\delta^{(1)}_{1,R}(T_{1.2}) + \delta^{(1)}_{1,T}(T_{1.2}) + \delta_{1,I} + \log(d)\,\delta^{(1)}_2(T_{1.2})\big)$.

Proof. Since $\dot v_2 = -\Omega(\log d/d^{1.5})$ and $R_{v_2} = \Theta(d^3)$, Stage 1.3 takes at most $O(1/d^{1.5})$ time. Within this amount of time, by Lemma C.18, we have
$$\big(\delta^{(1)}_2(T_{1.3})\big)^2\le\big(\delta^{(1)}_2(T_{1.2})\big)^2\exp\Big(\frac{O(1)}{d^{2.5}}\cdot\frac1{d^{1.5}}\Big) = (1 + o(1))\big(\delta^{(1)}_2(T_{1.2})\big)^2.$$
For the first layer, by Lemma C.16, we have
$$\delta^{(1)}_{1,R}(T_{1.3}) + \delta^{(1)}_{1,T}(T_{1.3})\le\delta^{(1)}_{1,R}(T_{1.2}) + \delta^{(1)}_{1,T}(T_{1.2}) + \frac{O(1)}{d^3}\delta_{1,I} + O\Big(\log(d)\,\delta^{(1)}_2\exp\Big(\frac{O(1)}{d^{2.5}}\cdot\frac1{d^{1.5}}\Big)\Big) = O\Big(\delta^{(1)}_{1,R}(T_{1.2}) + \delta^{(1)}_{1,T}(T_{1.2}) + \frac{\delta_{1,I}}{d^3} + \log(d)\,\delta^{(1)}_2(T_{1.2})\Big).$$
Finally, by Lemma C.17, we have $\alpha(T_{1.3}) = (1 + o(1))\alpha(T_{1.2})$.

C.3.1 ESTIMATIONS RELATED TO $\sigma'(v_2F(x) + r_2)$

First, we need some helper results to handle $\sigma'(v_2F(x) + r_2)$. The conditions under which they hold are mild and are always satisfied throughout the entire training procedure, and we will use these results in later stages, too. First, we show that wherever the value of $\sigma'(v_2F(x) + r_2)$ can differ across different $(v_2, r_2)$, the function value must be small. Note that the error here depends on the ratio $\delta_2/|\bar w_2|$; this is why we need $|\bar w_2|$ to be $\Theta(d)\delta_2$ instead of merely $\Theta(1)\delta_2$ at the end of Stage 1.2.

Lemma C.11. Suppose that $r_2 = \Theta(1)$ and $-v_2\ge\Omega(\delta_2)$ for any $(v_2, r_2)\in\mu_2$, where $\delta_2$ is the spread of the second layer. If $v_2F(x) + r_2 = 0$ for some $(v_2, r_2)\in\mu_2$, then $v_2'F(x) + r_2'\le O\big((|\bar w_2|^{-1} + 1)\delta_2\big)$ for all $(v_2', r_2')\in\mu_2$.

Remark. It is not necessary that there really exists a $(v_2, r_2)\in\mu_2$ with $v_2F(x) + r_2 = 0$.
As long as $v_2'F(x)+r_2' \le 0$ and $v_2''F(x)+r_2'' \ge 0$ for some $(v_2',r_2'),(v_2'',r_2'')\in\mu_2$, by continuity there always exists a point $(v_2,r_2)$ between $(v_2',r_2')$ and $(v_2'',r_2'')$ such that $v_2F(x)+r_2 = 0$. Moreover, this point lies within the spread of the second layer, so the lemma still applies. ♣

Next, we show that we can absorb $\sigma'$ into $f^*$ and $f$.

Lemma C.12. Suppose that the hypothesis of Lemma C.11 holds and all second-layer neurons are activated on $\{\|x\|\le1\}$. Then, for any $(v_2,r_2)\in\mu_2$ and $x\in\mathbb{R}^d$, we have $f^*(x)\sigma'(v_2F(x)+r_2) = f^*(x)$ and $f(x)\sigma'(v_2F(x)+r_2) = f(x)\pm O((|\bar w_2|^{-1}+1)\delta_2)$. As a corollary, we have $f(x) = \sigma(v_2F(x)+r_2)\pm O((|\bar w_2|^{-1}+1)\delta_2)$ and $f(x) = \sigma(\bar w_2F(x)+\bar b_2)\pm O((|\bar w_2|^{-1}+1)\delta_2)$.

As a corollary of Lemma C.11, the measure of the set on which $\sigma'(v_2F(x)+r_2)$ can differ for different $(v_2,r_2)$ is also small. Here we also use the fact that those $x$ lie at radius around $\Theta(1/|\bar w_2\alpha|)$, together with the tail bound $\|D\|(r) \le O(1/r^2)$.

Lemma C.13. Suppose that Induction Hypothesis C.2 is true at time $t$. For any $(v_2,r_2),(v_2',r_2')\in\mu_2$, we have $\mathbb{E}_x\{|\sigma'(v_2F(x)+r_2)-\sigma'(v_2'F(x)+r_2')|\} \le O(\alpha\delta_2^{(1)})$.

Proof of Lemma C.11. For any $(v_2',r_2')\in\mu_2$, we can write
$$v_2'F(x)+r_2' = \underbrace{v_2F(x)+r_2}_{=0}+(v_2'-v_2)F(x)+(r_2'-r_2) = \frac{v_2'-v_2}{v_2}\big(\underbrace{v_2F(x)+r_2}_{=0}-r_2\big)+(r_2'-r_2) = r_2\,\frac{v_2-v_2'}{v_2}+(r_2'-r_2).$$
Since $r_2 = \Theta(1)$, $|v_2-v_2'|,|r_2-r_2'| \le \delta_2$, and $|v_2| \ge \Omega(|\bar w_2|)$ (because $-v_2 \ge \Omega(\delta_2)$ and $|\bar w_2| \le |v_2|+\delta_2$), the last expression can be bounded by $O((|\bar w_2|^{-1}+1)\delta_2)$.

Proof of Lemma C.12. Since $f^*$ vanishes outside the unit ball and all second-layer neurons are activated on $\{\|x\|\le1\}$, we always have $f^*(x)\sigma'(v_2F(x)+r_2) = f^*(x)$. Now consider $f(x)\sigma'(v_2F(x)+r_2)$. If $v_2F(x)+r_2 > 0$, we are done. If $v_2'F(x)+r_2' < 0$ for all $(v_2',r_2')\in\mu_2$, then both $f(x)\sigma'(v_2F(x)+r_2)$ and $f(x)$ are $0$. Therefore, it suffices to consider the case where $v_2F(x)+r_2 \le 0$ while $f(x) > 0$.
In this case some neuron is active at $x$ while $v_2F(x)+r_2 \le 0$, so by Lemma C.11 (and the remark following it) every term of $f(x) = \mathbb{E}_{(v_2',r_2')}\sigma(v_2'F(x)+r_2')$ is small, and we have $f(x) \le O((|\bar w_2|^{-1}+1)\delta_2)$.

Proof of Lemma C.13. Since the norm and direction of $x$ are independent, it suffices to fix a direction $\bar x$ and consider
$$\mathbb{E}_{r\sim\|D\|}\{|\sigma'(v_2rF(\bar x)+r_2)-\sigma'(v_2'rF(\bar x)+r_2')|\}.$$
For notational simplicity, define $h(v_2,r_2,r) = v_2rF(\bar x)+r_2$. The integrand is nonzero iff the signs of $h(v_2,r_2,r)$ and $h(v_2',r_2',r)$ differ. To bound the length of the interval on which the signs can differ, we write
$$h(v_2,r_2,r) = \bar w_2rF(\bar x)+\bar b_2+(v_2-\bar w_2)rF(\bar x)+(r_2-\bar b_2) = \big(\bar w_2\pm O(\delta_2^{(1)})\big)rF(\bar x)+\bar b_2\pm O(\delta_2^{(1)}).$$
Therefore, the length of this interval can be bounded by $O(\delta_2^{(1)}/(\bar w_2^2\alpha))$. Moreover, this interval is located at $r = \Theta(1/|\bar w_2\alpha|)$, whence the density on it is $O(\bar w_2^2\alpha^2)$. Thus, the measure of this interval is $O(\alpha\delta_2^{(1)})$.
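The proof of Lemma C.13 is one-dimensional at heart: along a fixed direction $\bar x$, each pre-activation $v_2rF(\bar x)+r_2$ is affine in the radius $r$, so two second-layer neurons can disagree on $\sigma'$ only for radii between their two roots. The Python sketch below is our own illustration, not code from the paper: it computes that interval for two hypothetical neurons $(v_2,r_2)$ with $v_2<0<r_2$ and an assumed value `F_dir` standing for $F(\bar x)>0$, and checks it against a brute-force grid scan.

```python
def relu_grad(z):
    """Derivative of ReLU with the convention sigma'(0) = 0."""
    return 1.0 if z > 0 else 0.0


def root(v2, r2, F_dir):
    """Radius where v2 * r * F_dir + r2 crosses zero (assumes v2 < 0 < r2, F_dir > 0)."""
    return -r2 / (v2 * F_dir)


def disagreement_interval(n1, n2, F_dir):
    """Interval of radii on which the two neurons' sigma' values differ.

    Each neuron (v2, r2) with v2 < 0 < r2 is active exactly for r < root,
    so the two activation indicators can only differ between the two roots.
    """
    a, b = root(*n1, F_dir), root(*n2, F_dir)
    return min(a, b), max(a, b)


def brute_force(n1, n2, F_dir, r_max=5.0, n=5000):
    """Midpoint-grid scan of (0, r_max); returns radii where sigma' differs, and the step."""
    dr = r_max / n
    hits = []
    for k in range(n):
        r = (k + 0.5) * dr
        g1 = relu_grad(n1[0] * r * F_dir + n1[1])
        g2 = relu_grad(n2[0] * r * F_dir + n2[1])
        if g1 != g2:
            hits.append(r)
    return hits, dr


if __name__ == "__main__":
    n1, n2, F_dir = (-1.0, 1.0), (-1.1, 1.05), 0.5  # hypothetical neurons and F(x_bar)
    print(disagreement_interval(n1, n2, F_dir))
```

Moving the two neurons closer together shrinks the interval roughly in proportion, which is the mechanism behind the $O(\delta_2^{(1)}/(\bar w_2^2\alpha))$ length bound; the density factor $O(\bar w_2^2\alpha^2)$ is not modeled in this sketch.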

C.3.2 ESTIMATIONS FOR THE FIRST LAYER

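Before we control the error growth, we need a lemma that relates the approximation error to the tangent movement and the radial spread of the first layer.

Lemma C.14. Suppose that the tangent movement and radial spread of the first-layer neurons can be bounded as $\|\bar v_1(t)-\bar v_1(0)\| \le \delta_{1,T}$ and $\|v_1\|^2 = (1\pm\delta_{1,R})\,\mathbb{E}_{w_1}\|w_1\|^2$. Then $F(x;\mu_1) = \big(1\pm O(\delta_{1,I}+\sqrt d\,\delta_{1,R}+\sqrt d\,\delta_{1,T})\big)\alpha\|x\|$.

As a simple corollary, we have the following.

Corollary C.15. Suppose that Induction Hypothesis C.2 is true at time $t$. Then
$$|\tilde f(x)-f(x)| \le O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\alpha\|x\|\big).$$
As a result, we have
$$\mathbb{E}_x\{|\tilde f(x)-f(x)|\,\|x\|\} \le O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\alpha\,\mathbb{E}\|x\|^2\big) \le O\big(\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)}\big).$$

Now, we control the error of the first layer: by Lemma C.16, if Induction Hypothesis C.2 is true at time $t\in[T_{1.2},T_{1.3}]$, then
$$\frac{d}{dt}\big(\delta_{1,R}^{(1)}+\delta_{1,T}^{(1)}\big) \le O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\big)+O\big(\log(d)\,\delta_2^{(1)}\big) \le \frac{O(1)}{d^{2.5}}\big(\delta_{1,R}^{(1)}+\delta_{1,T}^{(1)}\big)+\frac{O(1)}{d^3}\delta_{1,I}+O\big(\log(d)\,\delta_2^{(1)}\big).$$

Finally, we estimate the radial speed of $v_1$ to estimate the magnitude of $\alpha$ at the end of Stage 1.

Lemma C.17. Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\in[T_{1.2},T_{1.3}]$. Then $\frac{d}{dt}\|v_1\|^2 = \Theta\big(\frac{\log d}{\sqrt d}\bar w_2\big)\|v_1\|^2$.

Proof of Lemma C.14. Define $N^2 = \mathbb{E}_{w_1}\|w_1\|^2$. Let $\mu_1'$ be the distribution obtained by setting the norm of every neuron in $\mu_1$ to $N$. We have
$$F(x;\mu_1) = \mathbb{E}_{w_1\sim\mu_1}\{(1\pm\delta_{1,R})N^2\sigma(\bar w_1\cdot x)\} = F(x;\mu_1')\pm O(\delta_{1,R}N^2\|x\|).$$
Let $\mu_1''$ be the distribution obtained by moving each $\bar v_1(t)$ back to $\bar v_1(0)$ in $\mu_1'$. Then
$$F(x;\mu_1') = N^2\,\mathbb{E}_{w_1\sim\mu_1(0)}\{\sigma(\bar w_1\cdot x)\}\pm O(\delta_{1,T}N^2\|x\|) = F(x;\mu_1'')\pm O(\delta_{1,T}N^2\|x\|).$$
Finally, note that
$$F(x;\mu_1'') = \frac{N_t^2}{N_0^2}F(x;\mu_1(0)) = \frac{N_t^2}{N_0^2}(1\pm\delta_{1,I})\alpha_0\|x\| = (1\pm\delta_{1,I})\alpha_t\|x\|.$$
Since $\alpha = \frac{C_\Gamma}{\sqrt d}N^2$, the errors $O(\delta_{1,R}N^2\|x\|)$ and $O(\delta_{1,T}N^2\|x\|)$ equal $O(\sqrt d\,\delta_{1,R}\alpha\|x\|)$ and $O(\sqrt d\,\delta_{1,T}\alpha\|x\|)$. Combining these bounds completes the proof.

Proof of Lemma C.16. First, we decompose $\dot v_1$ along the radial and tangent directions:
$$\mathrm{Rad}(\dot v_1) := \langle\dot v_1,\bar v_1\rangle\bar v_1 = 2\,\mathbb{E}_x\{S(x)\sigma(v_1\cdot x)\}\,\bar v_1,\qquad \mathrm{Tan}(\dot v_1) := (I-\bar v_1\bar v_1^\top)\dot v_1 = \|v_1\|\,\mathbb{E}_x\{S(x)\sigma'(v_1\cdot x)(I-\bar v_1\bar v_1^\top)x\}.$$
Note that $\dot v_1 = \mathrm{Rad}(\dot v_1)+\mathrm{Tan}(\dot v_1)$.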
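By Lemma C.12, we have
$$\mathrm{Rad}(\dot v_1) = 2\bar w_2\,\mathbb{E}_x\{(f^*(x)-f(x))\sigma(v_1\cdot x)\}\,\bar v_1\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big),$$
$$\mathrm{Tan}(\dot v_1) = \|v_1\|\,\bar w_2\,\mathbb{E}_x\{(f^*(x)-f(x))\sigma'(v_1\cdot x)(I-\bar v_1\bar v_1^\top)x\}\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big).$$
For the radial term, by Lemma B.3 and Corollary C.15, we have
$$\mathrm{Rad}(\dot v_1) = 2\bar w_2\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\sigma(v_1\cdot x)\}\,\bar v_1+2\bar w_2\,\mathbb{E}_x\{(\tilde f(x)-f(x))\sigma(v_1\cdot x)\}\,\bar v_1\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big)$$
$$= \frac{2C_\Gamma\bar w_2}{\sqrt d}\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\|x\|\}\,v_1\pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\,\|v_1\|\big)\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big).$$
Therefore,
$$\frac{d}{dt}\|v_1\|^2 = 2\langle v_1,\mathrm{Rad}(\dot v_1)\rangle = \frac{4C_\Gamma\bar w_2}{\sqrt d}\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\|x\|\}\,\|v_1\|^2\pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\,\|v_1\|^2\big)\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|^2\big).$$
For any $v_1,v_1'\in\mu_1$ with $\|v_1\|\ge\|v_1'\|$, we have
$$\frac{d}{dt}\frac{\|v_1\|^2-\|v_1'\|^2}{\|v_1'\|^2} = \frac{\frac{d}{dt}(\|v_1\|^2-\|v_1'\|^2)}{\|v_1'\|^2}-\frac{\|v_1\|^2-\|v_1'\|^2}{\|v_1'\|^2}\cdot\frac{\frac{d}{dt}\|v_1'\|^2}{\|v_1'\|^2}$$
$$= \frac{4C_\Gamma\bar w_2}{\sqrt d}\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\|x\|\}\,\frac{\|v_1\|^2-\|v_1'\|^2}{\|v_1'\|^2}\pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\big)\pm O\big(\log(d)\,\delta_2^{(1)}\big)$$
$$\quad-\frac{\|v_1\|^2-\|v_1'\|^2}{\|v_1'\|^2}\cdot\frac{4C_\Gamma\bar w_2}{\sqrt d}\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\|x\|\}\pm\frac{\|v_1\|^2-\|v_1'\|^2}{\|v_1'\|^2}\,O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\big)\pm\frac{\|v_1\|^2-\|v_1'\|^2}{\|v_1'\|^2}\,O\big(\log(d)\,\delta_2^{(1)}\big)$$
$$= \pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\big)\pm O\big(\log(d)\,\delta_2^{(1)}\big).$$
Now we consider the tangent movement. By Lemma B.5 and Corollary C.15, we have
$$\mathrm{Tan}(\dot v_1) = \|v_1\|\,\bar w_2\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\sigma'(v_1\cdot x)(I-\bar v_1\bar v_1^\top)x\}+\|v_1\|\,\bar w_2\,\mathbb{E}_x\{(\tilde f(x)-f(x))\sigma'(v_1\cdot x)(I-\bar v_1\bar v_1^\top)x\}\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big)$$
$$= \pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\,\|v_1\|\big)\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big),$$
where the first term vanishes by Lemma B.5 because $f^*-\tilde f$ is a radial function. As a result,
$$\frac{d}{dt}\bar v_1 = \frac{\mathrm{Tan}(\dot v_1)}{\|v_1\|} = \pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\big)\pm O\big(\log(d)\,\delta_2^{(1)}\big).$$
Combining these two bounds completes the proof.

Proof of Lemma C.17. By the proof of Lemma C.16, we have
$$\mathrm{Rad}(\dot v_1) = \frac{2C_\Gamma\bar w_2}{\sqrt d}\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\|x\|\}\,v_1\pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\,\|v_1\|\big)\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big)$$
$$= \Theta\left(\frac{\log d}{\sqrt d}\bar w_2\right)v_1\pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\,\|v_1\|\big)\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|\big).$$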

Recall that $\delta_2^{(1)}\le|\bar w_2|/d$. Hence,
$$\frac{d}{dt}\|v_1\|^2 = \Theta\left(\frac{\log d}{\sqrt d}\bar w_2\right)\|v_1\|^2\pm O\big((\delta_{1,I}+\sqrt d\,\delta_{1,R}^{(1)}+\sqrt d\,\delta_{1,T}^{(1)})\,|\bar w_2|\,\|v_1\|^2\big)\pm O\big(\log(d)\,\delta_2^{(1)}\|v_1\|^2\big) = \Theta\left(\frac{\log d}{\sqrt d}\bar w_2\right)\|v_1\|^2,$$
since both error terms are of lower order than $\frac{\log d}{\sqrt d}|\bar w_2|\,\|v_1\|^2$.

C.3.3 ESTIMATIONS FOR THE SECOND LAYER

Now, we bound the growth of the spread of the second layer. Readers may first check the proof of Lemma D.14, which is essentially a simpler version of this result in which we do not need to deal with the projections; there, we show that the spread never grows. Here, the error comes from the projection.

Lemma C.18. Suppose that Induction Hypothesis C.2 is true at time $t$. Then $\frac{d}{dt}(\delta_2^{(1)})^2\le\frac{O(1)}{d^{2.5}}(\delta_2^{(1)})^2$.

Proof. Let $(v_2,r_2),(v_2',r_2')\in\mu_2$ and define $h_2(x) = v_2F(x)+r_2$ and $h_2'(x) = v_2'F(x)+r_2'$. We write
$$\dot v_2 = \mathbb{E}_{\|x\|\le1}\{(f^*(x)-f(x))F(x)\}-\mathbb{E}_{\|x\|\ge1}\Pi_{R_{v_2}}[f(x)\sigma'(h_2(x))F(x)] =: T_1(\dot v_2)+T_2(\dot v_2).$$
$T_1$ does not depend on $v_2$. For $T_2$, note that $\Pi_{R_{v_2}}[f(x)\sigma'(h_2(x))F(x)] = \Pi_{R_{v_2}/F(x)}[f(x)]\,\sigma'(h_2(x))F(x)$. Similarly, for $\dot r_2$, we have
$$\frac{d}{dt}(r_2-r_2')^2 = -2\,\mathbb{E}_{\|x\|\ge1}\{f(x)(\sigma'(h_2(x))-\sigma'(h_2'(x)))(r_2-r_2')\}$$
$$= -2\,\mathbb{E}_{\|x\|\ge1}\{\Pi_{R_{v_2}/F(x)}[f(x)](\sigma'(h_2(x))-\sigma'(h_2'(x)))(r_2-r_2')\}-2\,\mathbb{E}_{\|x\|\ge1}\{(f(x)-\Pi_{R_{v_2}/F(x)}[f(x)])(\sigma'(h_2(x))-\sigma'(h_2'(x)))(r_2-r_2')\}.$$
Combining these two equations, we obtain
$$\frac{d}{dt}\big((v_2-v_2')^2+(r_2-r_2')^2\big) = -2\,\mathbb{E}_{\|x\|\ge1}\{\Pi_{R_{v_2}/F(x)}[f(x)](\sigma'(h_2(x))-\sigma'(h_2'(x)))(h_2(x)-h_2'(x))\}-2\,\mathbb{E}_{\|x\|\ge1}\{(f(x)-\Pi_{R_{v_2}/F(x)}[f(x)])(\sigma'(h_2(x))-\sigma'(h_2'(x)))(r_2-r_2')\}.$$
Since $\sigma'$ is non-decreasing and $\Pi_{R_{v_2}/F(x)}[f(x)]\ge0$, the first term is nonpositive. For the second term, by Lemma C.11 and Lemma C.13, it can be bounded by
$$\max_{x:\,\mathrm{sgn}(h_2(x))\ne\mathrm{sgn}(h_2'(x))}f(x)\times\mathbb{E}_x\{|\sigma'(h_2(x))-\sigma'(h_2'(x))|\}\times|r_2-r_2'|\le O\left(\frac{\alpha(\delta_2^{(1)})^3}{|\bar w_2|}\right)\le\frac{O(1)}{d^{2.5}}(\delta_2^{(1)})^2.$$
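Lemmas C.8 and C.18 are differential inequalities of the form $\frac{d}{dt}y\le c\,y$ for $y = (\delta_2^{(1)})^2$, and the proofs of Lemmas C.9 and C.10 turn them into multiplicative bounds $y(T)\le y(0)e^{cT}$, which are $1+o(1)$ whenever $cT = o(1)$ (for example, a rate $O(1)/d^{2.5}$ over a stage of length $O(1/d^{1.5})$). The sketch below integrates the extremal case $\dot y = c\,y$ with forward Euler and compares it with the Grönwall bound; the values of `d`, `c`, `T` and `y0` are illustrative assumptions, not constants from the paper.

```python
import math


def euler_growth(y0, c, T, steps=10_000):
    """Forward Euler for dy/dt = c*y (the extremal case of dy/dt <= c*y) on [0, T]."""
    dt = T / steps
    y = y0
    for _ in range(steps):
        y += c * y * dt
    return y


def gronwall_bound(y0, c, T):
    """Groenwall: dy/dt <= c*y on [0, T] implies y(T) <= y(0) * exp(c*T)."""
    return y0 * math.exp(c * T)


d = 10.0
c = 1.0 / d**2.5  # stands in for the O(1)/d^{2.5} rate of Lemma C.18
T = 1.0 / d**1.5  # stands in for the O(1/d^{1.5}) length of Stage 1.3
y0 = 1e-12        # hypothetical initial squared spread

if __name__ == "__main__":
    # Both ratios are about 1 + 1e-4, i.e. the spread grows by a (1 + o(1)) factor.
    print(euler_growth(y0, c, T) / y0, gronwall_bound(y0, c, T) / y0)
```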

D STAGE 2

The goal of Stage 2 is for gradient flow to converge to a point with loss $\varepsilon$. As in Stage 1, we maintain a set of induction hypotheses.

Induction Hypothesis D.1. Define $T_2 := \inf\{t\ge T_1 : L = \varepsilon\}$. Define $\delta_{1,L^2}^{(2)}$, $\delta_{1,L^\infty}^{(2)}$, $\delta_2^{(2)}$ by
$$\frac{d}{dt}\delta_{1,L^2}^{(2)} = \mathrm{ReLU}\left(\frac{d}{dt}\big\|\hat F-\|\cdot\|\big\|_{L^2}\right),\qquad\frac{d}{dt}\delta_{1,L^\infty}^{(2)} = \mathrm{ReLU}\left(\frac{d}{dt}\big\|\hat F|_{S^{d-1}}-1\big\|_{L^\infty}\right),\qquad\frac{d}{dt}\delta_2^{(2)} = 0,$$
with initial values satisfying
$$\Theta\left(\frac{d^{17}}{\varepsilon}\right)(\delta_{1,L^\infty}^{(2)})^2\le\delta_{1,L^2}^{(2)}\le\Theta\left(\frac{\varepsilon}{d^6}\right)\delta_{1,L^\infty}^{(2)},\qquad\delta_{1,L^2}^{(2)}\le O\left(\frac{\varepsilon^2}{d^7}\right),\qquad\delta_{1,L^\infty}^{(2)}(T_1)\le O\left(\frac{\varepsilon}{d^{14}}\right),\qquad\delta_2^{(2)}\le O\left(\frac{\varepsilon^2}{d^{10}}\right).$$
For any $t\in[T_1,T_2]$, we say that this Induction Hypothesis is true if the following hold.
(a) Error of the first layer. $\|\hat F-\|\cdot\|\|_{L^2}\le\delta_{1,L^2}^{(2)}$ and $\|\hat F|_{S^{d-1}}-1\|_{L^\infty}\le\delta_{1,L^\infty}^{(2)}$.
(b) Spread of the second layer. $\|(v_2,r_2)-(v_2',r_2')\|\le\delta_2^{(2)}$ for all $(v_2,r_2),(v_2',r_2')\in\mu_2$.
(c) Regularity conditions. $\bar b_2\le1-\Theta(\sqrt\varepsilon)$, $\bar w_2\alpha\ge-1+\Theta(\sqrt\varepsilon)$, $|\bar w_2|\le d$, $|\bar w_2|\ge\Theta(1/d^3)$, and $\alpha\ge\Theta(1/d^{1.5})$.
(d) Bounds for the errors. $\delta_{1,L^\infty}^{(2)} = O(\delta_{1,L^\infty}^{(2)}(T_1))$ and $\delta_{1,L^2}^{(2)} = O(\delta_{1,L^2}^{(2)}(T_1))$.

The main lemma for Stage 2 is Lemma D.2. The rest of this section is organized as follows. In Section D.1, we collect some auxiliary results that will be used later. In Section D.2, we show that Induction Hypothesis D.1 remains true throughout Stage 2. (See also Section B.1 for a discussion of the techniques used and some conventions.) Then, we derive a lower bound on the convergence rate in Section D.3.

Recall that, in Stage 2, we can ignore the projection operators, whence the dynamics of the neurons are given by
$$\dot v_1 = \mathbb{E}_x\{S(x)(\bar v_1\sigma(v_1\cdot x)+\|v_1\|\sigma'(v_1\cdot x)x)\},\qquad\dot v_2 = \mathbb{E}_x\{(f^*(x)-f(x))\sigma'(v_2F(x)+r_2)F(x)\},\qquad\dot r_2 = \mathbb{E}_x\{(f^*(x)-f(x))\sigma'(v_2F(x)+r_2)\}.$$
Now, we derive the equations that describe the dynamics of $\alpha$, $F$, and the loss $L$.

Lemma D.3 (Dynamics of α). In Stage 2, we have $\dot\alpha = \frac{4C_\Gamma}{\sqrt d}\mathbb{E}_{x'}\{S(x')F(x')\}$.

Lemma D.4 (Dynamics of F).
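In Stage 2, for each fixed $x$, we have
$$\frac{d}{dt}F(x) = 4\,\mathbb{E}_{x'}\{S(x')\,\mathbb{E}_{w_1}\{\sigma(w_1\cdot x')\sigma(w_1\cdot x)\}\}+\mathbb{E}_{x'}\{S(x')\,\mathbb{E}_{w_1}\{\|w_1\|^2\sigma'(w_1\cdot x')\sigma'(w_1\cdot x)\langle(I-\bar w_1\bar w_1^\top)x',x\rangle\}\}.$$
Note that in the above lemma, we decompose $\frac{d}{dt}F(x)$ into two terms: the first corresponds to the radial movement of $v_1$ and the second to its tangent movement.

Lemma D.5 (Dynamics of L). Define $\bar W_2(x) = \mathbb{E}_{w_2,b_2}\{\sigma'(w_2F(x)+b_2)w_2\}$. In Stage 2, we have $\frac{d}{dt}L = -\mathbb{E}_{w_2,b_2,w_1}\|\nabla_{w_2,b_2,w_1}\|^2$, where
$$\nabla_{w_2,b_2,w_1} := \mathbb{E}_x\left\{(f^*(x)-f(x))\begin{pmatrix}\sigma'(w_2F(x)+b_2)F(x)\\ \sigma'(w_2F(x)+b_2)\\ 2\bar W_2(x)\sigma(w_1\cdot x)\\ \|w_1\|\bar W_2(x)\sigma'(w_1\cdot x)(I-\bar w_1\bar w_1^\top)x\end{pmatrix}\right\}.$$
The entries of $\nabla_{w_2,b_2,w_1}$ correspond to the movement of $v_2$, the movement of $r_2$, the radial movement of $v_1$, and the tangent movement of $v_1$, respectively. The proofs of these three lemmas are as follows.

Proof of Lemma D.3. Recall that $\alpha := \frac{C_\Gamma}{\sqrt d}\mathbb{E}_{w_1}\|w_1\|^2$. Hence, $\dot\alpha = \frac{2C_\Gamma}{\sqrt d}\mathbb{E}_{w_1}\langle w_1,\dot w_1\rangle$. We compute
$$\langle\dot v_1,v_1\rangle = \mathbb{E}_x\{S(x)(\sigma(v_1\cdot x)\langle\bar v_1,v_1\rangle+\|v_1\|\sigma'(v_1\cdot x)\langle x,v_1\rangle)\} = 2\,\mathbb{E}_x\{S(x)\|v_1\|\sigma(v_1\cdot x)\},$$
where we used $\langle\bar v_1,v_1\rangle = \|v_1\|$ and $\sigma'(z)z = \sigma(z)$. Hence,
$$\dot\alpha = \frac{4C_\Gamma}{\sqrt d}\mathbb{E}_{w_1}\mathbb{E}_x\{S(x)\|w_1\|\sigma(w_1\cdot x)\} = \frac{4C_\Gamma}{\sqrt d}\mathbb{E}_x\{S(x)\,\mathbb{E}_{w_1}\{\|w_1\|\sigma(w_1\cdot x)\}\} = \frac{4C_\Gamma}{\sqrt d}\mathbb{E}_x\{S(x)F(x)\}.$$

Proof of Lemma D.4. First, we write
$$\frac{d}{dt}F(x) = \frac{d}{dt}\mathbb{E}_{w_1}\{\|w_1\|^2\sigma(\bar w_1\cdot x)\} = \mathbb{E}_{w_1}\left\{\frac{d\|w_1\|^2}{dt}\sigma(\bar w_1\cdot x)\right\}+\mathbb{E}_{w_1}\left\{\|w_1\|^2\frac{d}{dt}\sigma(\bar w_1\cdot x)\right\}.$$
By the proof of Lemma D.3, the first term is $4\,\mathbb{E}_{x'}\{S(x')\,\mathbb{E}_{w_1}\{\sigma(w_1\cdot x')\sigma(w_1\cdot x)\}\}$. For the second term, we compute
$$\frac{d}{dt}\sigma(\bar v_1\cdot x) = \sigma'(v_1\cdot x)\left\langle\frac{(I-\bar v_1\bar v_1^\top)\dot v_1}{\|v_1\|},x\right\rangle = \sigma'(v_1\cdot x)\left\langle\mathbb{E}_{x'}\{S(x')\sigma'(v_1\cdot x')(I-\bar v_1\bar v_1^\top)x'\},x\right\rangle = \mathbb{E}_{x'}\{S(x')\sigma'(v_1\cdot x')\sigma'(v_1\cdot x)\langle(I-\bar v_1\bar v_1^\top)x',x\rangle\}.$$
Hence, the second term is
$$\mathbb{E}_{w_1}\left\{\|w_1\|^2\frac{d}{dt}\sigma(\bar w_1\cdot x)\right\} = \mathbb{E}_{x'}\{S(x')\,\mathbb{E}_{w_1}\{\|w_1\|^2\sigma'(w_1\cdot x')\sigma'(w_1\cdot x)\langle(I-\bar w_1\bar w_1^\top)x',x\rangle\}\}.$$
Combining these two terms completes the proof.

Proof of Lemma D.5.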
First, we write
$$\frac{d}{dt}f(x) = \mathbb{E}_{w_2,b_2}\{\sigma'(w_2F(x)+b_2)\dot w_2F(x)\}+\mathbb{E}_{w_2,b_2}\{\sigma'(w_2F(x)+b_2)\dot b_2\}+\bar W_2(x)\frac{d}{dt}F(x) =: T_1\Big[\frac{d}{dt}f(x)\Big]+T_2\Big[\frac{d}{dt}f(x)\Big]+T_3\Big[\frac{d}{dt}f(x)\Big].$$
Note that $\frac{d}{dt}L = -\sum_{i=1}^3\mathbb{E}_x\{(f^*(x)-f(x))\,T_i[\frac{d}{dt}f(x)]\}$. We now compute these three terms separately. We have
$$\mathbb{E}_x\{(f^*(x)-f(x))\,T_1[\tfrac{d}{dt}f(x)]\} = \mathbb{E}_{w_2,b_2}\{\mathbb{E}_x\{(f^*(x)-f(x))\sigma'(w_2F(x)+b_2)F(x)\}\,\dot w_2\} = \mathbb{E}_{w_2,b_2}\{\mathbb{E}_x\{(f^*(x)-f(x))\sigma'(w_2F(x)+b_2)F(x)\}^2\},$$
$$\mathbb{E}_x\{(f^*(x)-f(x))\,T_2[\tfrac{d}{dt}f(x)]\} = \mathbb{E}_{w_2,b_2}\{\mathbb{E}_x\{(f^*(x)-f(x))\sigma'(w_2F(x)+b_2)\}\,\dot b_2\} = \mathbb{E}_{w_2,b_2}\{\mathbb{E}_x\{(f^*(x)-f(x))\sigma'(w_2F(x)+b_2)\}^2\}.$$
Meanwhile, for $T_3$, by Lemma D.4, we have
$$\mathbb{E}_x\{(f^*(x)-f(x))\,T_3[\tfrac{d}{dt}f(x)]\} = \mathbb{E}_x\{(f^*(x)-f(x))\bar W_2(x)\tfrac{d}{dt}F(x)\}$$
$$= 4\,\mathbb{E}_x\{(f^*(x)-f(x))\bar W_2(x)\,\mathbb{E}_{x'}\{S(x')\,\mathbb{E}_{w_1}\{\sigma(w_1\cdot x')\sigma(w_1\cdot x)\}\}\}+\mathbb{E}_x\{(f^*(x)-f(x))\bar W_2(x)\,\mathbb{E}_{x'}\{S(x')\,\mathbb{E}_{w_1}\{\|w_1\|^2\sigma'(w_1\cdot x')\sigma'(w_1\cdot x)\langle(I-\bar w_1\bar w_1^\top)x',x\rangle\}\}\}$$
$$= 4\,\mathbb{E}_{w_1}\{\mathbb{E}_x\{S(x)\sigma(w_1\cdot x)\}^2\}+\mathbb{E}_{w_1}\big\|\mathbb{E}_x\{S(x)\|w_1\|\sigma'(w_1\cdot x)(I-\bar w_1\bar w_1^\top)x\}\big\|^2.$$
Combining these together completes the proof.
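Two elementary facts drive the radial/tangent computations in Lemma C.16 and Lemmas D.3 to D.5. First, for ReLU, $\sigma'(z)z = \sigma(z)$, so the gradient of a neuron's contribution $\|w\|\sigma(w\cdot x)$ has inner product $2\|w\|\sigma(w\cdot x)$ with $w$; this is the factor $2$ in $\langle\dot v_1,v_1\rangle$. Second, the reflection $R = 2\bar w\bar w^\top-I$ from the proof of Lemma B.5 preserves $\|x\|$ and $w\cdot x$ while negating the tangent component $(I-\bar w\bar w^\top)x$, so tangent terms with a radial integrand cancel between $x$ and $Rx$. The pure-Python sketch below checks both facts, together with the orthogonal split into $\mathrm{Rad}$ and $\mathrm{Tan}$, on random vectors; the helper names, dimension and seed are our own choices.

```python
import math
import random


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


def relu(z):
    return max(z, 0.0)


def relu_grad(z):
    return 1.0 if z > 0 else 0.0  # convention sigma'(0) = 0


def neuron_feature(w, x):
    """One first-layer neuron's contribution: ||w|| relu(w.x) = ||w||^2 relu(w_bar.x)."""
    return norm(w) * relu(dot(w, x))


def neuron_feature_grad(w, x):
    """Closed-form gradient in w: w_bar relu(w.x) + ||w|| relu'(w.x) x."""
    n, z = norm(w), dot(w, x)
    return [wi / n * relu(z) + n * relu_grad(z) * xi for wi, xi in zip(w, x)]


def rad_tan(u, w):
    """Split u into Rad(u) = <u, w_bar> w_bar and Tan(u) = (I - w_bar w_bar^T) u."""
    n = norm(w)
    w_bar = [wi / n for wi in w]
    c = dot(u, w_bar)
    rad = [c * wi for wi in w_bar]
    return rad, [ui - ri for ui, ri in zip(u, rad)]


def reflect(x, w):
    """R x with R = 2 w_bar w_bar^T - I, i.e. reflection through the line spanned by w."""
    c = 2.0 * dot(x, w) / dot(w, w)
    return [c * wi - xi for wi, xi in zip(w, x)]


if __name__ == "__main__":
    rng = random.Random(0)
    w = [rng.gauss(0, 1) for _ in range(6)]
    x = [rng.gauss(0, 1) for _ in range(6)]
    print(dot(neuron_feature_grad(w, x), w), 2 * neuron_feature(w, x))
```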

D.1.2 ERROR-RELATED ESTIMATIONS

We collect some error-related estimates here. Most of them have been proved in Stage 1, except that here we use $|\bar w_2|\ge\Theta(1/d^3)$ to replace $(|\bar w_2|^{-1}+1)$ with $O(d^3)$. We repeat the statements here for easier reference.

Lemma D.6. Suppose that Induction Hypothesis D.1 is true at time $t$. If $v_2F(x)+r_2 = 0$ for some $(v_2,r_2)\in\mu_2$, then $v_2'F(x)+r_2'\le O(d^3\delta_2^{(2)})$ for all $(v_2',r_2')\in\mu_2$.

Proof. See Lemma C.11.

Lemma D.7. Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for any $(v_2,r_2)\in\mu_2$ and $x\in\mathbb{R}^d$, we have $f^*(x)\sigma'(v_2F(x)+r_2) = f^*(x)$ and $f(x)\sigma'(v_2F(x)+r_2) = f(x)\pm O(d^3\delta_2^{(2)})$. As a corollary, we have $f(x) = \sigma(v_2F(x)+r_2)\pm O(d^3\delta_2^{(2)})$ and $f(x) = \sigma(\bar w_2F(x)+\bar b_2)\pm O(d^3\delta_2^{(2)})$.

Proof. See Lemma C.12.

Lemma D.8. Suppose that Induction Hypothesis D.1 is true at time $t$. Then $\|\tilde f-f\|_{L^2}\le O(|\bar w_2\alpha|\,\delta_{1,L^2}^{(2)})$.

Proof. Since $\sigma$ is 1-Lipschitz, we have
$$|\tilde f(x)-f(x)| = \big|\mathbb{E}_{w_2,b_2}\{\sigma(w_2\tilde F(x)+b_2)-\sigma(w_2F(x)+b_2)\}\big|\le O(|\bar w_2|\,|\tilde F(x)-F(x)|).$$
Thus, since $\tilde F(x)-F(x) = \alpha(\|x\|-\hat F(x))$,
$$\|\tilde f-f\|_{L^2}^2\le O(\bar w_2^2\alpha^2)\,\big\|\hat F-\|\cdot\|\big\|_{L^2}^2\le O\big(\bar w_2^2\alpha^2(\delta_{1,L^2}^{(2)})^2\big).$$

D.2 MAINTAINING THE INDUCTION HYPOTHESIS

In this section, we show that Induction Hypothesis D.1 is true throughout Stage 2. See Section B.1 for discussion and conventions on the techniques used here.

D.2.1 ERROR OF THE FIRST LAYER

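Recall that we can decompose the loss as
$$L = \frac12\mathbb{E}_x(f^*(x)-\tilde f(x))^2+\frac12\mathbb{E}_x(\tilde f(x)-f(x))^2+\mathbb{E}_x\{(f^*(x)-\tilde f(x))(\tilde f(x)-f(x))\} =: L_1+L_2+L_3.$$
As discussed in the main text, the goal is to show that $L_2\approx\frac{\bar w_2^2}{2}\mathbb{E}(\tilde F(x)-F(x))^2$ and $L_3\approx0$, so that $L$ decomposes into two terms: the first captures the difference between the target function $f^*$ and the infinite-width network $\tilde f$, and the second measures the approximation error between $\tilde F$ and $F$. We will show in Lemma D.11 that, as one may expect, $L_1$ does not affect $\hat F$. Estimating the gradients of $L_2$ and $L_3$ is slightly more complicated. First, we need to introduce the following partition of the input space.

Proof of Lemma D.11. For fixed $x\in\mathbb{R}^d$, we write
$$\frac{d}{dt}\hat F(x) = \frac{d}{dt}\frac{F(x)}{\alpha} = \frac{\dot F(x)}{\alpha}-\hat F(x)\frac{\dot\alpha}{\alpha} = -\frac1\alpha\mathbb{E}_{w_1}\langle\nabla_{w_1}F(x),\nabla_{w_1}L\rangle+\hat F(x)\frac1\alpha\mathbb{E}_{w_1}\langle\nabla_{w_1}\alpha,\nabla_{w_1}L\rangle.$$
First, we consider $L_1$. For each $v_1\in\mu_1$, we have
$$\nabla_{v_1}L_1 = -\mathbb{E}_x\{(f^*(x)-\tilde f(x))\nabla_{v_1}\tilde f(x)\} = -\frac{2C_\Gamma}{\sqrt d}\mathbb{E}_x\{(f^*(x)-\tilde f(x))\,\mathbb{E}_{w_2,b_2}\{\sigma'(w_2\alpha\|x\|+b_2)w_2\}\|x\|\}\,v_1 =: C_{\mathrm{Tmp},1}v_1.$$
Meanwhile, note that
$$\langle\nabla_{v_1}F(x),v_1\rangle = \langle\nabla_{v_1}(\|v_1\|^2\sigma(\bar v_1\cdot x)),v_1\rangle = \langle\nabla_{v_1}(\|v_1\|^2)\sigma(\bar v_1\cdot x),v_1\rangle = 2\|v_1\|^2\sigma(\bar v_1\cdot x),\qquad\langle\nabla_{v_1}\alpha,v_1\rangle = \frac{C_\Gamma}{\sqrt d}\langle\nabla_{v_1}\|v_1\|^2,v_1\rangle = \frac{2C_\Gamma}{\sqrt d}\|v_1\|^2.$$
Hence,
$$\frac{d}{dt}\hat F(x)\Big|_{L_1} := -\frac1\alpha\mathbb{E}_{w_1}\langle\nabla_{w_1}F(x),\nabla_{w_1}L_1\rangle+\hat F(x)\frac1\alpha\mathbb{E}_{w_1}\langle\nabla_{w_1}\alpha,\nabla_{w_1}L_1\rangle = -\frac{2C_{\mathrm{Tmp},1}}{\alpha}\mathbb{E}_{w_1}\{\|w_1\|^2\sigma(\bar w_1\cdot x)\}+C_{\mathrm{Tmp},1}\hat F(x)\frac1\alpha\frac{2C_\Gamma}{\sqrt d}\mathbb{E}_{w_1}\|w_1\|^2 = -\frac{2C_{\mathrm{Tmp},1}}{\alpha}F(x)+2C_{\mathrm{Tmp},1}\hat F(x) = 0.$$
Namely, $L_1$ does not affect $\hat F$. Now we consider $L_2$. By Lemma D.10, we have
$$\frac{d}{dt}\hat F(x)\Big|_{L_2} := -\frac1\alpha\mathbb{E}_{w_1}\langle\nabla_{w_1}F(x),\nabla_{w_1}L_2\rangle+\hat F(x)\frac1\alpha\mathbb{E}_{w_1}\langle\nabla_{w_1}\alpha,\nabla_{w_1}L_2\rangle$$
$$= -\frac1\alpha\frac{\bar w_2^2}{2}\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}F(x),\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\tilde F(x')-F(x'))^2\Big\rangle+\frac1\alpha\frac{\bar w_2^2}{2}\hat F(x)\,\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}\alpha,\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\tilde F(x')-F(x'))^2\Big\rangle\pm O\left(\frac{\sqrt d\,(\delta_{X_2}^{(2)})^2}{\alpha}\|x\|\right).$$
Note that we can rewrite the $\nabla_{w_1}F(x)$ in the first term as $(\nabla_{w_1}\alpha)\hat F(x)+\alpha\nabla_{w_1}\hat F(x)$, so that part of it cancels with the second term.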
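Then, we get
$$\frac{d}{dt}\hat F(x)\Big|_{L_2} = -\frac{\bar w_2^2}{2}\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}\hat F(x),\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\tilde F(x')-F(x'))^2\Big\rangle\pm O\left(\frac{\sqrt d\,(\delta_{X_2}^{(2)})^2}{\alpha}\|x\|\right).$$
For $L_3$, we can simply merge it into the error term of $\frac{d}{dt}\hat F(x)|_{L_2}$.

Proof of Lemma D.12. By Lemma D.11, we have
$$\frac12\frac{d}{dt}\big\|\hat F-\|\cdot\|\big\|_{L^2}^2 = \mathbb{E}_x\Big\{(\hat F(x)-\|x\|)\frac{d}{dt}\hat F(x)\Big\} = -\frac{\bar w_2^2}{2}\mathbb{E}_x\Big\{(\hat F(x)-\|x\|)\,\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}\hat F(x),\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\tilde F(x')-F(x'))^2\Big\rangle\Big\}\pm\mathbb{E}_x\Big\{|\hat F(x)-\|x\||\,O\Big(\frac{\sqrt d\,(\delta_{X_2}^{(2)})^2}{\alpha}\|x\|\Big)\Big\}.$$
The second term can be bounded by $O\big(d^5\,\delta_{1,L^2}^{(2)}(\delta_{X_2}^{(2)})^2\big)$. The first term is equal to
$$\mathrm{Tmp} := -\frac{\bar w_2^2}{4}\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}\mathbb{E}_x(\hat F(x)-\|x\|)^2,\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\alpha\|x'\|-F(x'))^2\Big\rangle.$$
To complete the proof, it suffices to show that this is nonpositive. For each $w_1$, we have
$$\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\alpha\|x'\|-F(x'))^2 = \mathbb{E}_{x'\in X_1}(\hat F(x')-\|x'\|)^2\,\nabla_{w_1}\alpha^2+\alpha^2\,\mathbb{E}_{x'\in X_1}\nabla_{w_1}(\hat F(x')-\|x'\|)^2.$$
Since the distribution of $x$ is spherically symmetric, $\mathbb{E}_{x'\in X_1}\nabla_{w_1}(\hat F(x')-\|x'\|)^2$ and $\mathbb{E}_x\nabla_{w_1}(\hat F(x)-\|x\|)^2$ have the same direction. Hence,
$$\mathrm{Tmp}\le-\frac{\bar w_2^2}{4}\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}\mathbb{E}_x(\hat F(x)-\|x\|)^2,\ \nabla_{w_1}\alpha^2\,\mathbb{E}_{x'\in X_1}(\hat F(x')-\|x'\|)^2\Big\rangle = -\frac{C_\Gamma}{\sqrt d}\bar w_2^2\alpha\,\mathbb{E}_{x'\in X_1}(\hat F(x')-\|x'\|)^2\,\mathbb{E}_x\mathbb{E}_{w_1}\big\langle\nabla_{w_1}(\hat F(x)-\|x\|)^2,w_1\big\rangle.$$
Then, we compute
$$\big\langle\nabla_{w_1}(\hat F(x)-\|x\|)^2,w_1\big\rangle = 2(\hat F(x)-\|x\|)\Big\langle\frac{\nabla_{w_1}F(x)}{\alpha}-\hat F(x)\frac{\nabla_{w_1}\alpha}{\alpha},w_1\Big\rangle = 2(\hat F(x)-\|x\|)\Big(\frac{2\|w_1\|^2\sigma(\bar w_1\cdot x)}{\alpha}-\hat F(x)\frac1\alpha\frac{2C_\Gamma}{\sqrt d}\|w_1\|^2\Big).$$
Taking the expectation over $w_1$, one can see that this is $0$. Thus, $\mathrm{Tmp}\le0$.

Proof of Lemma D.13. Recall from Lemma D.11 that
$$\frac{d}{dt}\hat F(x) = -\frac{\bar w_2^2}{2}\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}\hat F(x),\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\tilde F(x')-F(x'))^2\Big\rangle\pm O\left(\frac{\sqrt d\,(\delta_{X_2}^{(2)})^2}{\alpha}\|x\|\right).$$
For the first term, we have
$$\|\nabla_{w_1}\hat F(x)\|\le\frac{\|\nabla_{w_1}F(x)\|}{\alpha}+\hat F(x)\frac{\|\nabla_{w_1}\alpha\|}{\alpha}\le O\left(\frac{\|w_1\|\,\|x\|}{\alpha}\right),$$
$$\Big\|\nabla_{w_1}\mathbb{E}_{x'\in X_1}(\tilde F(x')-F(x'))^2\Big\|\le2\,\mathbb{E}_{x'\in X_1}\big\{|\tilde F(x')-F(x')|\,(\|\nabla_{w_1}\tilde F(x')\|+\|\nabla_{w_1}F(x')\|)\big\}\le O(1)\,\mathbb{E}_{x'\in X_1}\{|\tilde F(x')-F(x')|\,\|x'\|\}\,\|w_1\|\le O\left(\frac{\delta_{1,L^2}^{(2)}}{\sqrt{\alpha|\bar w_2|}}\|w_1\|\right).$$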
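Thus,
$$\Big|\frac{d}{dt}\hat F(x)\Big|\le O\left(\bar w_2^2\,\mathbb{E}_{w_1}\Big\{\frac{\|w_1\|\,\|x\|}{\alpha}\cdot\frac{\delta_{1,L^2}^{(2)}}{\sqrt{\alpha|\bar w_2|}}\|w_1\|\Big\}\right)+O\left(\frac{\sqrt d\,(\delta_{X_2}^{(2)})^2}{\alpha}\|x\|\right)\le O\left(\frac{\sqrt d\,|\bar w_2|^{1.5}}{\sqrt\alpha}\delta_{1,L^2}^{(2)}\|x\|\right)+O\left(\frac{\sqrt d\,(\delta_{X_2}^{(2)})^2}{\alpha}\|x\|\right)\le O\big(d^3\delta_{1,L^2}^{(2)}+d^2(\delta_{X_2}^{(2)})^2\big)\|x\|,$$
where the last step uses $|\bar w_2|\le d$ and $\alpha\ge\Theta(1/d^{1.5})$.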
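The first part of the proof of Lemma D.11 shows that $\nabla_{v_1}L_1 = C_{\mathrm{Tmp},1}v_1$ is the same multiple of $v_1$ for every first-layer neuron, so $L_1$ only rescales all first-layer neurons by a common factor. Since $F$ and $\alpha = \frac{C_\Gamma}{\sqrt d}\mathbb{E}_{w_1}\|w_1\|^2$ are both quadratic in the first-layer norms, $\hat F = F/\alpha$ is unchanged by such a rescaling. The sketch below checks this for an empirical distribution of random neurons; `C_GAMMA = 1.0` is an arbitrary placeholder (it cancels in $\hat F$), and the step size and coefficient are hypothetical.

```python
import math
import random

C_GAMMA = 1.0  # placeholder value; it cancels in F_hat = F / alpha


def feature(neurons, x):
    """F(x) = mean over first-layer neurons of ||w||^2 * relu(w_bar . x)."""
    total = 0.0
    for w in neurons:
        n = math.sqrt(sum(a * a for a in w))
        z = sum(a * b for a, b in zip(w, x)) / n
        total += n * n * max(z, 0.0)
    return total / len(neurons)


def alpha(neurons, d):
    """alpha = (C_Gamma / sqrt(d)) * mean ||w||^2."""
    mean_sq = sum(sum(a * a for a in w) for w in neurons) / len(neurons)
    return C_GAMMA / math.sqrt(d) * mean_sq


def feature_hat(neurons, x):
    return feature(neurons, x) / alpha(neurons, len(x))


def step_with_common_gradient(neurons, c, lr):
    """One gradient step when every neuron's gradient is c * w (the form of grad L_1)."""
    return [[(1.0 - lr * c) * a for a in w] for w in neurons]


if __name__ == "__main__":
    rng = random.Random(1)
    neurons = [[rng.gauss(0, 1) for _ in range(5)] for _ in range(200)]
    x = [rng.gauss(0, 1) for _ in range(5)]
    new = step_with_common_gradient(neurons, c=-0.3, lr=0.1)
    print(feature_hat(neurons, x), feature_hat(new, x))  # equal up to rounding
```

A non-uniform update, for instance one that rescales only some neurons, would change $\hat F$ in general; this is why the $L_2$ part of the gradient, which is not of the form $C\,v_1$, is the one that drives $\hat F$.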

D.2.2 SPREAD OF THE SECOND LAYER

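Lemma D.14. Suppose that Induction Hypothesis D.1 is true at time $t$. Then for any $(v_2,r_2),(v_2',r_2')\in\mu_2$, $\frac{d}{dt}\|(v_2,r_2)-(v_2',r_2')\|^2\le0$. In words, the spread of the second layer never grows.

Therefore,
$$f(x) = f^*(x)\pm|1-\bar b_2|\pm O\left(\frac{|-1-\bar w_2\alpha|}{|\bar w_2\alpha|}\right)\pm O\big(\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)}\big).$$
Thus,
$$L = \frac12\mathbb{E}_x(f^*(x)-f(x))^2\le\frac12\left(|1-\bar b_2|+O\left(\frac{|-1-\bar w_2\alpha|}{|\bar w_2\alpha|}\right)+O\big(\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)}\big)\right)^2\le O\left((1-\bar b_2)^2+\frac{(-1-\bar w_2\alpha)^2}{\bar w_2^2\alpha^2}+\big(\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)}\big)^2\right).$$

Since $L\ge\varepsilon$ and $(1-\bar b_2)^2 = \delta^2$, by Lemma D.15, we have
$$\frac{(-1-\bar w_2\alpha)^2}{\bar w_2^2\alpha^2}\ge\Omega(\varepsilon)-O(\delta^2)-O\big((\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)})^2\big)\ge\Omega(\varepsilon).$$
Since $\bar w_2\alpha\ge-1$, this implies $\bar w_2\alpha\ge-1+\Omega(|\bar w_2|\alpha\sqrt\varepsilon)$. In fact, this implies $\bar w_2\alpha\ge-1+\Omega(\sqrt\varepsilon)$ even when $|\bar w_2|\alpha$ is $o(1)$, as, in that case, $\bar w_2\alpha\ge-1+\Omega(\sqrt\varepsilon)$ holds directly. Hence,
$$\sigma(\bar w_2\alpha\|x\|+\bar b_2)\ge\sigma\big((-1+\Omega(\sqrt\varepsilon))\|x\|+1-\delta\big) = \sigma\big(1-\|x\|+\Omega(\sqrt\varepsilon)\|x\|-\delta\big).$$
Thus,
$$\dot{\bar b}_2\le\mathbb{E}_x\big\{f^*(x)-\sigma\big(1-\|x\|+\Omega(\sqrt\varepsilon)\|x\|-\delta\big)\big\}+O\big(\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)}\big)\le\mathbb{E}_{\|x\|\le1}\big\{1-\|x\|-\big(1-\|x\|+\Omega(\sqrt\varepsilon)\|x\|-\delta\big)\big\}+O\big(\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)}\big) = -\Omega(\sqrt\varepsilon)+\delta+O\big(\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)}\big).$$
As long as the constant in $\delta = \Theta(\sqrt\varepsilon)$ is sufficiently small, this implies $\dot{\bar b}_2<0$ when $\bar b_2 = 1-\delta$.

Proof of Lemma D.17. Now we estimate the coefficient of the first term. Suppose that $\bar w_2\alpha = -1+\delta$ for some $\delta\le\Theta(\sqrt\varepsilon)$ with a sufficiently small constant. Then, by Lemma D.15, we have $(1-\bar b_2)^2\ge\Omega(\varepsilon)-O(\delta^2) = \Omega(\varepsilon)$. Hence, $\bar b_2\le1-\Theta(\sqrt\varepsilon)$. Also note that $\bar w_2\alpha = \Theta(1)$ implies that it suffices to consider $x$ with $\|x\| = \Theta(1)$. As a result, we have
$$\sigma(\bar w_2\alpha\|x\|\hat F(\bar x)+\bar b_2) = \sigma(\bar w_2\alpha\|x\|+\bar b_2)\pm O(\delta_{1,L^\infty}^{(2)})\le\sigma\big(1-\|x\|-\Theta(\sqrt\varepsilon)\big)+O(\delta_{1,L^\infty}^{(2)}).$$
Then, we bound the coefficient as
$$\mathbb{E}_x\big\{\big(f^*(x)-\sigma(\bar w_2\alpha\|x\|\hat F(\bar x)+\bar b_2)\big)F(x)\big\}\ge\mathbb{E}_{\|x\|\le1}\big\{\big(\Theta(\sqrt\varepsilon)-O(\delta_{1,L^\infty}^{(2)})\big)F(x)\big\}\ge\Omega(\alpha\sqrt\varepsilon).$$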
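$$2\bar w_2\,\mathbb{E}_x\{(f^*(x)-f(x))\sigma(w_1\cdot x)\} = 2\bar w_2\,\mathbb{E}_x\{(f^*(x)-\tilde f(x))\sigma(w_1\cdot x)\}+2\bar w_2\,\mathbb{E}_x\{(\tilde f(x)-f(x))\sigma(w_1\cdot x)\}$$
$$= \frac{2C_\Gamma\bar w_2}{\sqrt d}\mathbb{E}_x\{(f^*(x)-\tilde f(x))\|x\|\}\,\|w_1\|\pm2|\bar w_2|\,\|w_1\|\,\|\tilde f-f\|_{L^2}\sqrt{\mathbb{E}_{x\in X_2}\|x\|^2} = \frac{2C_\Gamma\bar w_2}{\sqrt d}\mathbb{E}_x\{(f^*(x)-\tilde f(x))\|x\|\}\,\|w_1\|\pm O\big(|\bar w_2|^{1.5}\alpha^{0.5}\|w_1\|\,\delta_{1,L^2}^{(2)}\big).$$
Repeating the above procedure, we can replace the $\tilde f$ in the first term with $f$. Therefore, $[\nabla_{w_2,b_2,w_1}]_{1:3} = \cdots$.

For $[\nabla_{w_2,b_2,w_1}]_3$, we have
$$\mathbb{E}_{w_1}[\nabla_{w_2,b_2,w_1}]_3^2 = \frac{4C_\Gamma^2\bar w_2^2}{d}\mathbb{E}_x\{(f^*(x)-f(x))\|x\|\}^2\,\mathbb{E}_{w_1}\|w_1\|^2\pm O\left(\big(\delta_{1,L^2}^{(2)}+d^3\delta_2^{(2)}\big)\bar w_2^2\,\mathbb{E}\|w_1\|^2\,\frac{\log(d)}{\sqrt d}\right) = \mathbb{E}_x\{(f^*(x)-f(x))\|x\|\}^2\,\frac{4C_\Gamma\bar w_2^2}{\sqrt d}\alpha\pm O\big(\delta_{1,L^2}^{(2)}+d^3\delta_2^{(2)}\big).$$

First, for those $x\in\{\|x\|\le1\}$, we have $f^*(x) = -\|x\|+1$ and $f(x) = \bar w_2F(x)+\bar b_2 = \bar w_2\alpha\|x\|+\bar b_2+\bar w_2\alpha(\hat F(x)-\|x\|)$. Hence, we have
$$\mathbb{E}_{\|x\|\le1}\{(f^*(x)-f(x))(-\|x\|+1-(\alpha\bar w_2\|x\|+\bar b_2))\} = \mathbb{E}_{\|x\|\le1}(f^*(x)-f(x))^2+\mathbb{E}_{\|x\|\le1}\{(f^*(x)-f(x))\bar w_2\alpha(\hat F(x)-\|x\|)\} = \mathbb{E}_{\|x\|\le1}(f^*(x)-f(x))^2\pm O(|\bar w_2\alpha|\,\delta_{1,L^2}^{(2)}).$$
Then, for $x\in\{\|x\|\ge1\}$, note that $-\|x\|+1\le0$ and $f^*(x) = 0$. Therefore, we have … where the second equality comes from Lemma D.7 and Induction Hypothesis D.1. Combining these two cases, we obtain
$$\langle\nabla,\tilde\nabla\rangle\ge A\,\mathbb{E}_x(f^*(x)-f(x))^2-O(|\bar w_2\alpha|\,\delta_{1,L^2}^{(2)})-O(d^3\delta_2^{(2)}).$$
Finally, note that $A\ge\alpha$. Thus, $\|\nabla\|\ge\Omega(\alpha L)-O(\delta$ …

Thus, it takes at most $O(d^3/\varepsilon)$ time for $L$ to reach $\varepsilon$.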



We say that a quantity $a$ is $\mathrm{poly}(d,1/\varepsilon)$ if it is bounded by $C(d/\varepsilon)^C$ for some universal constant $C>0$ that may change from line to line.

Though gradient flow, strictly speaking, is not a proper algorithm, it is common to use it as a surrogate for gradient descent in theoretical analysis. See Appendix E for a discussion of how to convert the argument into one for gradient descent.

Strictly speaking, the result in Safran et al. (2019) requires $\varepsilon = O(1/d^6)$. Even in that regime, our algorithm learns the function using $\mathrm{poly}(d)$ neurons, which is not achievable by any two-layer network; therefore, it is still a valid separation.

As in standard mean-field arguments, we rescale the gradients by $m$ so that they do not vanish as $m\to\infty$. In most gradient calculations, this is equivalent to using the formal rule $\partial_v\mathbb{E}_w g(w) = \partial_v g(v)$.

For ease of presentation, here we talk about function values instead of gradients. Strictly speaking, this is incorrect, as a small function value does not necessarily imply a small gradient. The ideas, however, are essentially the same. See Section D.2 for the actual proof.

Here, $w_i\in\mathbb{R}^d$ denotes the $i$-th row of $W$. Later we will use the notation $v_i, a_i$ to denote the $i$-th row or column of the corresponding matrix; whether it is a row or a column can be easily inferred from the dimensions. The general rule is that if $V\in\mathbb{R}^{D\times m}$, where $m$ is the number of neurons, then $v_i\in\mathbb{R}^D$ is the $i$-th column, and if $W\in\mathbb{R}^{m\times D}$, then $w_i\in\mathbb{R}^D$ is the $i$-th row.

Our focus is on factoring out the permutation invariance; in this paper, essentially all distributions are empirical distributions over finitely many neurons, with respect to which the integral is just a summation and is always well-defined. We leave the task of identifying specific regularity conditions to future work.

Recall that the surface area of the unit sphere $S^{d-1}\subset\mathbb{R}^d$ is $\int d\sigma_{d-1} = 2\pi^{d/2}/\Gamma(d/2)$.
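For reference, the constant in the last footnote can be evaluated directly. The sketch below uses `math.gamma` and `math.lgamma` from the Python standard library and recovers the familiar values $2\pi$ (circle) and $4\pi$ (sphere in $\mathbb{R}^3$). It also evaluates $\mathbb{E}_{u\sim\mathrm{Unif}(S^{d-1})}\sigma(u_1) = \Gamma(d/2)/(2\sqrt\pi\,\Gamma((d+1)/2))\approx1/\sqrt{2\pi d}$, the factor that turns $\mathbb{E}\,\sigma(\bar w_1\cdot x)$ over uniform directions into a multiple of $\|x\|$; whether it matches the paper's exact normalization of $C_\Gamma$ is an assumption on our part.

```python
import math


def sphere_surface_area(d):
    """Surface area of the unit sphere S^{d-1} in R^d: 2 * pi^(d/2) / Gamma(d/2)."""
    return 2.0 * math.pi ** (d / 2) / math.gamma(d / 2)


def mean_relu_first_coordinate(d):
    """E[max(u_1, 0)] for u uniform on S^{d-1}, i.e. Gamma(d/2) / (2 sqrt(pi) Gamma((d+1)/2)).

    Computed through lgamma so that it stays finite for large d.
    """
    log_ratio = math.lgamma(d / 2) - math.lgamma((d + 1) / 2)
    return math.exp(log_ratio) / (2.0 * math.sqrt(math.pi))


if __name__ == "__main__":
    for d in (2, 3, 10, 100):
        print(d, sphere_surface_area(d), mean_relu_first_coordinate(d))
```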
Note that we define these $\delta$'s to be upper bounds on the corresponding quantities instead of the quantities themselves. The only reason we define these $\delta$'s in such a roundabout way is to make the proof easier to write rigorously. See the footnote in Induction Hypothesis D.1, where this type of definition plays a more important technical role, for further discussion.

The first two conditions actually follow directly from the definition of the $\delta$'s; we repeat them here only for easier reference. The actual result we need to prove for these $\delta$'s is condition (e), which says that these $\delta$'s always remain small.

As mentioned in the footnote in Induction Hypothesis C.2, these $\delta$'s are defined as upper bounds on the corresponding errors. This gives a certain degree of freedom in choosing their initial values. By Lemma C.4, we can choose the parameters so that the errors at the beginning of Stage 2 are arbitrarily small, so these conditions can indeed be satisfied. The first condition, which requires the $L^2$ error to be bounded above and below in terms of the $L^\infty$ error, may seem strange at first sight. The only reason we need it is to merge some second-order error terms into first-order ones.
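Induction Hypothesis D.1 defines each error proxy through $\frac{d}{dt}\delta = \mathrm{ReLU}(\frac{d}{dt}\,\mathrm{error})$, which is the "upper bounds rather than the quantities themselves" device described in these footnotes: $\delta$ accumulates only the increases of the error, so it is non-decreasing and, when started at or above the initial error, always dominates it. A discrete-time sketch of this construction on a made-up error trajectory (the function name and the numbers are hypothetical):

```python
def running_upper_bound(errors, delta0):
    """Discrete analogue of d/dt delta = ReLU(d/dt error), started at delta0.

    Returns one delta value per sample of `errors`. Because
    delta - error = (delta0 - error0) + sum(max(step, 0) - step) >= delta0 - error0,
    the result dominates the error whenever delta0 >= errors[0].
    """
    deltas = [delta0]
    for prev, cur in zip(errors, errors[1:]):
        deltas.append(deltas[-1] + max(cur - prev, 0.0))
    return deltas


if __name__ == "__main__":
    errs = [0.10, 0.12, 0.09, 0.15, 0.11, 0.11]  # hypothetical error trajectory
    print(running_upper_bound(errs, delta0=0.10))
```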



Figure 1: Difference between the framework of Nguyen & Pham (2020) (left) and our framework (right).


Figure 2: Simulation results. The left figure shows the loss during training; each vertical dashed line corresponds to a time point plotted in the other two figures. The center figure depicts the shape of $f$ at certain steps. The right figure shows the values of the second-layer neurons at certain steps. One can observe that $f\approx\tilde f$ indeed holds and that the second-layer neurons are concentrated around $(\bar w_2,\bar b_2)$, which matches our theoretical analysis. The simulation is performed on a finite-width network with widths $m_1 = 512$, $m_2 = 128$ and input dimension $d = 100$.
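To make the limiting picture behind Figure 2 concrete, the sketch below builds the idealized network by hand rather than by training: first-layer directions are sampled uniformly on the sphere so that $F(x)\approx\alpha\|x\|$, and all second-layer mass sits at a single point with $\bar w_2\alpha = -1$, $\bar b_2 = 1$, which gives $\sigma(1-\|x\|) = f^*(x)$. The width, dimension, seed, the use of unit-norm first-layer weights, and the helper `alpha_exact` (our own normalization, which may differ from the paper's $C_\Gamma$) are all assumptions; this illustrates the target representation only, not the training procedure or the simulation behind Figure 2.

```python
import math
import random


def relu(z):
    return max(z, 0.0)


def sample_unit_vectors(m, d, rng):
    """Draw m first-layer directions uniformly from the unit sphere S^{d-1}."""
    out = []
    for _ in range(m):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        n = math.sqrt(sum(a * a for a in g))
        out.append([a / n for a in g])
    return out


def first_layer(W1, x):
    """F(x) = mean_i ||w_i||^2 relu(w_bar_i . x); every ||w_i|| = 1 in this sketch."""
    return sum(relu(sum(a * b for a, b in zip(w, x))) for w in W1) / len(W1)


def alpha_exact(d):
    """alpha such that E relu(u . x) = alpha * ||x|| for u ~ Unif(S^{d-1})."""
    log_ratio = math.lgamma(d / 2) - math.lgamma((d + 1) / 2)
    return math.exp(log_ratio) / (2.0 * math.sqrt(math.pi))


def three_layer(W1, w2, b2, x):
    """f(x) = relu(w2 * F(x) + b2): all second-layer mass at the mean (w2, b2)."""
    return relu(w2 * first_layer(W1, x) + b2)


if __name__ == "__main__":
    rng = random.Random(0)
    d, m = 10, 20_000  # hypothetical sizes; wide to keep Monte Carlo error small
    W1 = sample_unit_vectors(m, d, rng)
    w2, b2 = -1.0 / alpha_exact(d), 1.0  # w2 * alpha = -1, b2 = 1
    for radius in (0.0, 0.5, 1.0, 1.5):
        x = [radius] + [0.0] * (d - 1)
        print(radius, three_layer(W1, w2, b2, x), relu(1.0 - radius))
```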

Lemma 4.3 (Stage 2, informal). Define the end time of Stage 2 as $T_2 := \inf\{t\ge T_1 : L = \varepsilon\}$. Under the assumptions of Theorem 2.1, we have $T_2-T_1\le\mathrm{poly}(d,1/\varepsilon)$, and the following conditions hold throughout Stage 2: (a) Approximation error of the first layer.

Proof of Lemma B.5. Define $R = \bar v\bar v^\top-(I_d-\bar v\bar v^\top) = 2\bar v\bar v^\top-I_d$. That is, $R$ is the reflection matrix associated with $\bar v$. Since $D$ is spherically symmetric, we have $R_\#D = D$. For the same reason, $g\circ R = g$. Moreover, by construction, $R\bar v = \bar v$. Hence,

present the main lemma of Stage 1. One can see that, by properly choosing the parameters, the errors can be made arbitrarily small without affecting the final values of $\alpha$ and $\bar w_2$. To prove the main lemma, it suffices to combine Lemma C.6, Lemma C.9 and Lemma C.10.

Lemma C.4 (Main lemma of Stage 1). Induction Hypothesis C.2 is true throughout Stage 1. Stage 1 takes at most $O(d^4\sigma^2+1/d^{1.5})$ time. At the end of Stage 1, we have α = Θ(1/d 1

(a) $\|\dot v_1\|\le R_{v_1}$ and $|\dot r_2|\le R_{r_2}$.
(b) $\max_{w_2}w_2-\min_{w_2}w_2$ is non-increasing.
(c) For any positive second-layer weight $v_2$, we have $\dot v_2\le-\Theta(\log d/d^{1.5})$.

2 and Stage 1.3. ♣

Lemma C.6 (Main lemma of Stage 1.1). Stage 1.1 takes at most $O(d^{1.5}\delta_2^{(1)}(0))$ time. At the end of Stage 1.1, all second-layer weights $v_2$ are non-positive. Hence, $f = O(1)$ and, by Lemma C.3, the projection operator in $\dot r_2$ can no longer be activated. For the errors, we have δ … First, we decompose $v_2$ as

). Finally, the change of v 1 can be bounded by O(d 2.5 δ

too fast.

Lemma C.10 (Main lemma of Stage 1.3). Stage 1.3 takes at most $O(1/d^{1.5})$ time. At the end of Stage 1.3, we have $-\bar w_2 = \Theta(1/R_{v_2})$ and $\alpha = \Theta(\sqrt d/R_{v_1})$.

Lemma D.2 (Stage 2). Induction Hypothesis D.1 is true throughout Stage 2, and Stage 2 takes at most $O(d^3/\varepsilon)$ time.

Finally, we prove Lemma D.2 in Section D.4.

D.1 AUXILIARY LEMMAS

D.1.1 THE DYNAMICS OF F, f AND L

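Proof of Lemma D.16. By Lemma D.7, for any $(v_2,r_2)\in\mu_2$, we have $\dot r_2 = \mathbb{E}_x\{f^*(x)-\sigma(\bar w_2F(x)+\bar b_2)\}\pm O(d^3\delta_2^{(2)})$. By Induction Hypothesis D.1 and the Lipschitzness of $\sigma$, we have $\sigma(\bar w_2F(x)+\bar b_2) = \sigma(\bar w_2\alpha\|x\|\hat F(\bar x)+\bar b_2) = \sigma(\bar w_2\alpha\|x\|+\bar b_2)\pm O(\delta_{1,L^\infty}^{(2)})$. Hence, $\dot r_2 = \mathbb{E}_x\{f^*(x)-\sigma(\bar w_2\alpha\|x\|+\bar b_2)\}\pm O(\delta_{1,L^\infty}^{(2)}+d^3\delta_2^{(2)})$.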

By Lemma D.3 and Lemma D.7, we have $\frac{d}{dt}(\bar w_2\alpha) = \mathbb{E}_x\{(f^*(x)-\sigma(\bar w_2\alpha\|x\|\hat F(\bar x)+\bar b_2))F(x)\}$ …

Proof of Lemma D.18. By Lemma D.3 and Lemma D.7, we have $\dot w_2 = \mathbb{E}_x\{(f^*(x)-f(x))F(x)\}\pm d^3\log d\,\delta$ … $\mathbb{E}_x\{(f^*(x)-f(x))F(x)\}\,\bar w_2\pm O(d^{2.5}(\log d)\,\delta\ldots)$. Also recall that $\bar w_2^2\ll\alpha$ at $T_1$. Thus, throughout Stage 2, we always have $|\alpha-\frac{2C_\Gamma}{\sqrt d}\bar w_2^2|\ll1/d$. Since $|\bar w_2\alpha|\le1$, this implies $|\bar w_2|\le O(d^{1/6})\le d$.

Proof of Lemma D.19. Recall from the proof of Lemma D.18 that $|\alpha-\frac{2C_\Gamma}{\sqrt d}\bar w_2^2|\ll1/d$. Hence, when $\alpha = \Theta(1/d^{1.5})$, we have $|\bar w_2|\le O(1/d)$. The estimates in Stage 1, mutatis mutandis, show that both $\alpha$ and $|\bar w_2|$ will grow in this case.

D.3 CONVERGENCE RATE

Recall from Lemma D.5 that $\frac{d}{dt}L = -\mathbb{E}_{w_2,b_2,w_1}\|\nabla_{w_2,b_2,w_1}\|^2$, where
$$\nabla_{w_2,b_2,w_1} := \mathbb{E}_x\left\{(f^*(x)-f(x))\begin{pmatrix}\sigma'(w_2F(x)+b_2)F(x)\\ \sigma'(w_2F(x)+b_2)\\ 2\bar W_2(x)\sigma(w_1\cdot x)\\ \|w_1\|\bar W_2(x)\sigma'(w_1\cdot x)(I-\bar w_1\bar w_1^\top)x\end{pmatrix}\right\}.$$

Lemma D.20. Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have …

Lemma D.21. Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have $\|\nabla\|\ge\Omega(\alpha L)-O(\delta$ …

Lemma D.22 (Stage 2). Suppose that Induction Hypothesis D.1 is true throughout Stage 2. Then $T_2-T_1\le O(d^3/\varepsilon)$.

Proof of Lemma D.20. Since we only need to lower bound the norm of $\nabla_{w_2,b_2,w_1}$, we can safely ignore the last entry and only consider the first three entries. By Lemma D.7, we have
$$[\nabla_{w_2,b_2,w_1}]_{1:3} = \mathbb{E}_x\left\{(f^*(x)-f(x))\begin{pmatrix}F(x)\\ 1\\ 2\bar w_2\sigma(w_1\cdot x)\end{pmatrix}\right\}\pm O\left(d^3\delta_2^{(2)}\begin{pmatrix}\alpha\log(d)\\ 1\\ \bar w_2\|w_1\|\log(d)\end{pmatrix}\right).$$
Furthermore, we have
$$\mathbb{E}_x\{(f^*(x)-f(x))F(x)\} = \mathbb{E}_x\{(f^*(x)-f(x))\alpha\|x\|\}+\mathbb{E}_x\{(f^*(x)-f(x))\alpha(\hat F(x)-\|x\|)\} = \mathbb{E}_x\{(f^*(x)-f(x))\alpha\|x\|\}\pm O(\alpha\delta_{1,L^2}^{(2)}).$$
Meanwhile, for $[\nabla_{w_2,b_2,w_1}]_3$, by Lemma B.3 and Lemma D.8, we have

$\|w_1\|\log(d)$. Now, we estimate the expected norm of $[\nabla_{w_2,b_2,w_1}]_{1:3}$. First, we have $[\nabla_{w_2,b_2,w_1}]$

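Proof of Lemma D.21. For notational simplicity, put $A := \alpha^2+4C_\Gamma\sqrt{\cdots}$ … By Induction Hypothesis D.1, $\|\tilde\nabla\|\le O(1)$. Hence, in order to lower bound $\|\nabla\|$, it suffices to lower bound $\langle\nabla,\tilde\nabla\rangle$. We have
$$\langle\nabla,\tilde\nabla\rangle = A\,\mathbb{E}_x\{(f^*(x)-f(x))(-\|x\|+1-(\alpha\bar w_2\|x\|+\bar b_2))\}.$$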

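$$\mathbb{E}_{\|x\|\ge1}\{(f^*(x)-f(x))(-\|x\|+1-(\alpha\bar w_2\|x\|+\bar b_2))\} = -\mathbb{E}_{\|x\|\ge1}\{f(x)(-\|x\|+1-(\alpha\bar w_2\|x\|+\bar b_2))\}\ge\mathbb{E}_{\|x\|\ge1}\{f(x)(\alpha\bar w_2\|x\|+\bar b_2)\}.$$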

Thus, for any $T\in[T_1,T_2]$, $L(T)\le\left(\Omega(d^{-3})(T-T_1)+\frac{1}{L(T_1)}\right)^{-1}$.
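The final step of Stage 2 combines $\frac{d}{dt}L = -\mathbb{E}\|\nabla_{w_2,b_2,w_1}\|^2$ with the lower bound of Lemma D.21 and $\alpha\ge\Theta(1/d^{1.5})$, which (ignoring the error terms) yields a differential inequality of the form $\dot L\le-cL^2/d^3$. Its comparison solution is $L(T) = \big(c(T-T_1)/d^3+1/L(T_1)\big)^{-1}$, which reaches $\varepsilon$ after time $\frac{d^3}{c}\big(\frac1\varepsilon-\frac{1}{L(T_1)}\big)\le\frac{d^3}{c\varepsilon}$. The sketch below integrates the extremal ODE with forward Euler and compares it with this closed form; `c` stands in for the unspecified $\Omega(\cdot)$ constant, and the other numbers are placeholders.

```python
def closed_form_loss(L0, c, d, t):
    """Solution of dL/dt = -c * L^2 / d^3 with L(0) = L0, evaluated at time t."""
    return 1.0 / (c * t / d**3 + 1.0 / L0)


def euler_loss(L0, c, d, t, steps=200_000):
    """Forward Euler for the same ODE on [0, t]."""
    dt = t / steps
    L = L0
    for _ in range(steps):
        L -= c * L * L / d**3 * dt
    return L


def time_to_reach(eps, L0, c, d):
    """Exact hitting time of eps for the comparison ODE; always at most d^3 / (c * eps)."""
    return d**3 / c * (1.0 / eps - 1.0 / L0)


if __name__ == "__main__":
    L0, c, d, eps = 1.0, 2.0, 4.0, 0.01  # placeholder values
    T = time_to_reach(eps, L0, c, d)
    print(T, d**3 / (c * eps), closed_form_loss(L0, c, d, T), euler_loss(L0, c, d, T))
```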

Many works have tried to address this issue and generalize mean-field analysis to deep networks; see, e.g., Nguyen & Pham (2020); Pham & Nguyen (2021); Araújo et al. (2019); Sirignano & Spiliopoulos (2021); Fang et al. (2021); Lu et al. (2020); Ding et al. (2021) and references therein. Unlike most existing works, our multilayer mean-field framework keeps a finite hidden feature dimension while the number of neurons can go to infinity, so that each layer becomes a distribution of neurons. See Section 1.1 and Appendix A for more discussion.

and then generalize it to handle the infinite-width case by replacing the index sets $[m_2]$, $[m_1]$ by two general index sets $I_2$, $I_1$ that can potentially be uncountable. For example, we can choose $I_1 = I_2 = \mathbb{R}$. This is the strategy employed by Nguyen & Pham (2020). (See Pham & Nguyen (2021) for a more accessible version of this paper.) The drawback of this formulation is that, with the introduction of index sets, the permutation invariance is no longer factored out. Although this formulation still makes it possible to obtain global convergence results for infinite-width networks, it becomes less useful when we want to analyze a finite-width network, as it becomes essentially the same as the usual matrix formulation.
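The point about permutation invariance can be made concrete for the three-layer networks analyzed here, $f(x) = \mathbb{E}_{(w_2,b_2)\sim\mu_2}\sigma(w_2F(x)+b_2)$ with $F(x) = \mathbb{E}_{w_1\sim\mu_1}\|w_1\|\sigma(w_1\cdot x)$: each layer interacts with the other only through the scalar feature $F(x)$, so the output depends on the neurons only through their empirical distributions. The sketch below (finite widths, random parameters, our own helper names) checks that shuffling the neurons within either layer leaves the output unchanged up to floating-point rounding; an index-set formulation keeps track of exactly the labels that this check shows to be irrelevant.

```python
import math
import random


def relu(z):
    return max(z, 0.0)


def first_layer_feature(W1, x):
    """F(x) = mean over first-layer neurons of ||w|| * relu(w . x)."""
    total = 0.0
    for w in W1:
        n = math.sqrt(sum(a * a for a in w))
        total += n * relu(sum(a * b for a, b in zip(w, x)))
    return total / len(W1)


def network(W1, layer2, x):
    """f(x) = mean over second-layer neurons (w2, b2) of relu(w2 * F(x) + b2)."""
    F = first_layer_feature(W1, x)
    return sum(relu(w2 * F + b2) for w2, b2 in layer2) / len(layer2)


if __name__ == "__main__":
    rng = random.Random(3)
    W1 = [[rng.gauss(0, 1) for _ in range(6)] for _ in range(50)]
    layer2 = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(20)]
    x = [rng.gauss(0, 1) for _ in range(6)]
    shuffled = W1[:]
    rng.shuffle(shuffled)
    print(network(W1, layer2, x), network(shuffled, layer2, x))
```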

we are ready to control the error of the first layer.

Lemma C.16. Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\in[T_{1.2},T_{1.3}]$.

ACKNOWLEDGEMENT

This work is supported by NSF Award DMS-2031849, CCF-1845171 (CAREER), CCF-1934964 (Tripods) and a Sloan Research Fellowship.


Lemma D.9. Define … Then, we partition the input space into … In words, $X_1$ is the largest spherically symmetric set on which all second-layer neurons are activated, and $X_1\cup X_2$ is the largest spherically symmetric set on which at least one second-layer neuron is activated. Suppose that Induction Hypothesis D.1 is true at time $t$. Then the following hold. (a) $f^*$ vanishes on $X_2\cup X_3$, i.e., $R_1\ge1$; $f$ vanishes on $X_3$; and … $X_2$. As a corollary, we have … $X_2$ on $X_2$.

The above lemma implies that … We formally establish this approximation in the following lemma.

Lemma D.10 (Gradient of $L_2$ and $L_3$). Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for each $v_1\in\mu_1$, we have …

Now, we are ready to derive the equation that governs the dynamics of $\hat F$. Note that this lemma implies that, at least approximately, the dynamics of $\hat F$ depend only on $L_2$.

Lemma D.11 (Dynamics of $\hat F$). Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for each fixed $x$, we have …

Then, we show that the signal term in $\frac{d}{dt}\hat F(x)$ can only decrease the $L^2$ error, which is intuitively true since, after all, $L_2$ is the (rescaled) $L^2$ error. As a result, the $L^2$ error barely grows.

Lemma D.12 ($L^2$ approximation error). Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have …

Finally, we show that the change of $\hat F|_{S^{d-1}}$ depends on the $L^2$ error. As a result, as long as the $L^2$ error is small, the $L^\infty$ error cannot grow too fast.

Lemma D.13 ($L^\infty$ approximation error). Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for any $x\in S^{d-1}$, we have …

The proofs of these lemmas are as follows.

Proof of Lemma D.9. (a) This follows directly from the construction of the partition and Induction Hypothesis D.1. (b) First, we write … where the last equality comes from the fact that $f$ vanishes on $\{\|x\|\ge\Omega(-\bar b_2/(\alpha|\bar w_2|))\}$. Similarly, for any $(v_2,r_2)\in\mu_2$, we have … Hence, for any $R>0$ and $x\in RS^{d-1}$, we have … In other words, $R_1\ge R-\frac{\delta_{\mathrm{Tmp}}}{-\bar w_2\alpha}$ and $R_2\le R+\frac{\delta_{\mathrm{Tmp}}}{-\bar w_2\alpha}$.
Thus,To complete the proof, it suffices to invoke Lemma B.1.(c) Note that by the definition of R 2 , for any x 0 ∈ R 2 S d-1 , we have f (x 0 ) = 0. Hence, for any x ∈ X 2 , there exists some x 0 with f (x 0 ) = 0 and ∥x -Proof of Lemma D.10. Since both f * and f vanishes on X 3 , it suffices to consider X 1 and X 2 . Recall that that all second layer neurons are activated on X 1 . Hence,where the last equality comes from Corollary B.4. Now, we bound the influence of X 2 . Note that both ∇ v1 f (x) and ∇ v1 f (x) are bounded by O(| w2 | ∥v 1 ∥ ∥x∥). Recall from Lemma D.9 that f ≤ O(δX2 ). Therefore,The proof for ∇ v1 L 3 | X2 is the same.Proof. Let (v 2 , r 2 ), (v ′ 2 , r ′ 2 ) ∈ µ 2 be two second layer neurons. For notational convenience, we define h
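To make the radius-based partition in Lemma D.9 concrete, the following is a minimal Python sanity check (not a proof). It assumes, hypothetically, that the regions are the shells $X_1 = \{\|x\| < R_1\}$, $X_2 = \{R_1 \le \|x\| < R_2\}$, and $X_3 = \{\|x\| \ge R_2\}$; the exact definitions of $R_1, R_2$ and the boundary conventions are not reproduced here. It samples points and checks that $f^*(x) = \mathrm{ReLU}(1 - \|x\|)$ vanishes on $X_2 \cup X_3$ when $R_1 \ge 1$, and fails to when $R_1 < 1$. The helper names `region`, `random_point`, and `f_star_vanishes_outside_x1` are illustrative only.

```python
import math
import random


def norm(x: list[float]) -> float:
    return math.sqrt(sum(xi * xi for xi in x))


def f_star(x: list[float]) -> float:
    """Target function f*(x) = ReLU(1 - ||x||)."""
    return max(0.0, 1.0 - norm(x))


def region(x: list[float], r1: float, r2: float) -> int:
    """Index of the shell containing x under the hypothetical convention
    X_1 = {||x|| < r1}, X_2 = {r1 <= ||x|| < r2}, X_3 = {||x|| >= r2}."""
    n = norm(x)
    if n < r1:
        return 1
    if n < r2:
        return 2
    return 3


def random_point(d: int, radius: float, rng: random.Random) -> list[float]:
    """Point with a uniformly random direction on S^{d-1}, scaled to `radius`."""
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    s = norm(g)
    return [radius * gi / s for gi in g]


def f_star_vanishes_outside_x1(r1: float, r2: float, d: int = 10,
                               n: int = 1000, seed: int = 0) -> bool:
    """Sampling check: is f* zero at every sampled point of X_2 and X_3?"""
    rng = random.Random(seed)
    for _ in range(n):
        x = random_point(d, rng.uniform(0.0, 2.0 * r2), rng)
        if region(x, r1, r2) != 1 and f_star(x) != 0.0:
            return False
    return True


print(f_star_vanishes_outside_x1(r1=1.2, r2=1.5))  # True, since R_1 >= 1
print(f_star_vanishes_outside_x1(r1=0.8, r2=1.5))  # False: points with 0.8 <= ||x|| < 1 lie in X_2
```

Because $f^*$ is supported on the closed unit ball, any shell starting at radius $R_1 \ge 1$ sees only zeros, which is the equivalence stated in part (a).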

D.2.3 REGULARITY CONDITIONS

As mentioned earlier, we will mainly use the continuity argument to maintain the regularity conditions, so the problem reduces to estimating the derivative on the boundary. As an example, suppose that $b_2 = 1 - \delta$ for some small $\delta > 0$. Since Lemma D.15 upper bounds the loss in terms of $1 - b_2$ and $-1 - w_2\alpha$, $|-1 - w_2\alpha|$ must be large; otherwise we would have $L < \varepsilon$. We can then use the fact that $|-1 - w_2\alpha|$ is large to estimate the derivative. The proofs for the other regularity conditions are similar, except for the one for $|w_2|$, which is in the same spirit as those for the first-layer errors.

Lemma D.15. Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have […]

Lemma D.16. Suppose that Induction Hypothesis D.1 is true at time $t$ and $b_2 = 1 - \Theta(\sqrt{\varepsilon})$. Then $\frac{d}{dt} b_2 < 0$.

Lemma D.17. Suppose that Induction Hypothesis D.1 is true at time $t$ and […]

The proofs of this subsubsection are gathered below.

Proof of Lemma D.15. For any $x \in \mathbb{R}^d$, by Lemma D.7 and the Lipschitz continuity of $\sigma$, we have, for any […] By Induction Hypothesis D.1, for any $x \in X_1 \cup X_2$, we have […]

Recall that $\delta_{X_2} := O(1)\, d^{4.5}(\delta_2)$. For simplicity, we choose […]. We choose $\delta^{(2)}_{1,L^2}(T_1)$ and […]. Note that this is possible because $\delta_{1,L^2}(T_1)$ and $\delta^{(2)}_{1,L^\infty}(T_1)$ can be chosen to be arbitrarily polynomially small. When this is true, we have […] Thus, by induction, within $O(d^3/\varepsilon)$ time, these two errors can grow by at most […] $O(\delta_{1,L^\infty}(T_1))$, respectively.
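In its simplest scalar form, the continuity argument used to maintain the regularity conditions reduces to the following standard claim. This is a generic sketch: the threshold $\beta$ stands in for a boundary such as $1 - \Theta(\sqrt{\varepsilon})$ in Lemma D.16, and the actual argument tracks several coupled quantities at once.

```latex
\textbf{Claim.} Let $b:[0,T]\to\mathbb{R}$ be differentiable and fix a threshold $\beta$.
Suppose $b(0) < \beta$ and $\tfrac{d}{dt}b(t) < 0$ whenever $b(t) = \beta$.
Then $b(t) < \beta$ for all $t \in [0,T]$.

\textbf{Proof.} Suppose not, and let $t^* := \inf\{t \in [0,T] : b(t) \ge \beta\}$.
By continuity, $t^* > 0$, $b(t^*) = \beta$, and $b(t) < \beta$ for all $t \in [0,t^*)$. Hence
\[
  \frac{d}{dt}b(t^*) = \lim_{t \uparrow t^*} \frac{b(t^*) - b(t)}{t^* - t} \ge 0,
\]
which contradicts $\tfrac{d}{dt}b(t^*) < 0$. \qed
```

This is why it suffices to estimate the derivative only on the boundary $b_2 = 1 - \Theta(\sqrt{\varepsilon})$, as Lemma D.16 does.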

E FROM GRADIENT FLOW TO GRADIENT DESCENT

Converting the above gradient flow argument into a gradient descent argument is standard, provided that we can generate fresh samples at each iteration. First, by choosing a sufficiently small step size, we can ensure that within each step, the difference between gradient descent and gradient flow is inverse-polynomially small. Since our argument is built on the induction hypotheses, we do not need to worry about the accumulation of errors; moreover, our estimates can tolerate errors of inverse-polynomial size. Then, at each step of gradient descent, we generate sufficiently many (but still polynomially many) samples to ensure that, with high probability, the difference between the population gradient and the finite-sample gradient is sufficiently small. Since the process takes only polynomially many iterations, the total number of samples needed is polynomial.
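The recipe above (a small step size plus a fresh batch at every step) can be illustrated on a toy problem. The sketch below is a stand-in, not the network analyzed in this paper: it assumes a linear model with population loss $L(w) = \mathbb{E}[\langle w - w^*, x\rangle^2]/2$ and $x \sim N(0, I_d)$, whose population gradient is $w - w^*$, so gradient flow has the closed form $w(t) = w^* + (w_0 - w^*)e^{-t}$. The functions `flow_solution` and `online_gd` and all numerical values are hypothetical choices for illustration.

```python
import math
import random


def flow_solution(w0, w_star, t):
    """Exact gradient flow for the toy population loss
    L(w) = E[<w - w_star, x>^2] / 2 with x ~ N(0, I_d).
    Its population gradient is w - w_star, so w(t) = w_star + (w0 - w_star) * e^{-t}."""
    decay = math.exp(-t)
    return [ws + (a - ws) * decay for a, ws in zip(w0, w_star)]


def online_gd(w0, w_star, eta, steps, batch, seed=0):
    """Gradient descent that draws a fresh batch of samples at every step."""
    rng = random.Random(seed)
    w = list(w0)
    d = len(w)
    for _ in range(steps):
        grad = [0.0] * d
        for _ in range(batch):
            x = [rng.gauss(0.0, 1.0) for _ in range(d)]
            # Per-sample gradient of <w - w_star, x>^2 / 2 is <w - w_star, x> * x.
            residual = sum((wi - si) * xi for wi, si, xi in zip(w, w_star, x))
            for j in range(d):
                grad[j] += residual * x[j] / batch
        w = [wi - eta * gj for wi, gj in zip(w, grad)]
    return w


w0 = [1.0, -1.0, 0.5, 0.0, 2.0]
w_star = [0.0] * 5
eta, steps = 0.01, 200  # total time T = eta * steps = 2
w_flow = flow_solution(w0, w_star, eta * steps)
w_gd = online_gd(w0, w_star, eta, steps, batch=500)
gap = max(abs(a - b) for a, b in zip(w_flow, w_gd))
print(f"max coordinate gap between GD and gradient flow: {gap:.4f}")
```

Shrinking `eta` (with `steps = T / eta`) reduces the per-step discretization error, and enlarging `batch` reduces the sampling error; these are the two error sources the paragraph above controls.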

