PLATEAU IN MONOTONIC LINEAR INTERPOLATION - A "BIASED" VIEW OF LOSS LANDSCAPE FOR DEEP NETWORKS

Abstract

Monotonic linear interpolation (MLI), the phenomenon that the loss and accuracy are monotonic on the line connecting a random initialization with the minimizer it converges to, is commonly observed in the training of neural networks. Such a phenomenon may seem to suggest that optimizing neural networks is easy. In this paper, we show that the MLI property is not necessarily related to the hardness of the optimization problem, and that empirical observations on MLI for deep neural networks depend heavily on the biases. In particular, we show that linearly interpolating both weights and biases leads to very different influences on the final output, and that when different classes have different last-layer biases on a deep network, there is a long plateau in both the loss and the accuracy interpolation (which existing theory of MLI cannot explain). We also show, using a simple model, how the last-layer biases for different classes can become different even on a perfectly balanced dataset. Empirically, we demonstrate that similar intuitions hold on practical networks and realistic datasets.

We first formally define the linear interpolation between the network at initialization and the network after training, and then describe the notations used in the paper.

Linear interpolation: Consider a network with parameters θ ∈ R^p. Suppose the network is initialized with parameters θ^(0) and converges to θ^(T). A linear interpolation is constructed by setting the parameters θ^[α] = (1 - α)θ^(0) + αθ^(T) for α ∈ [0, 1]. The loss interpolation curve γ_loss(α) : [0, 1] → R is defined so that γ_loss(α) is the training loss of the network at θ^[α]. Similarly, the error interpolation curve γ_error(α) : [0, 1] → [0, 1] is defined with γ_error(α) as the training error of the network at θ^[α]. Here, the training error is simply the fraction of training samples that are classified incorrectly by the network.

Notations: We use [k] to denote the set {1, 2, · · · , k}.
We use N(0, δ^2) to denote the Gaussian distribution with mean zero and variance δ^2. We use ∥·∥ to denote the ℓ2 norm for a vector or the spectral norm for a matrix. For any non-zero vector v, we use v̄ to denote v/∥v∥. We use O(·), Θ(·), Ω(·) to hide the dependency on constant factors and use Õ(·), Θ̃(·), Ω̃(·) to further hide the dependency on poly-logarithmic factors. For any time t, we use θ^(t), f^(t) to denote the parameters and the network at time t. For any α ∈ [0, 1], we use θ^[α], f^[α] to denote the α-interpolation point, which means θ^[α] := (1 - α)θ^(0) + αθ^(T) and f^[α] is the network with parameters θ^[α].
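The interpolation curves defined above can be sketched in a few lines of code. The following is a minimal illustration (not the paper's implementation); the toy linear classifier, its parameters and the `predict` helper are made-up for the example.

```python
import numpy as np

def interpolate_params(theta_init, theta_final, alpha):
    """theta^[alpha] = (1 - alpha) * theta^(0) + alpha * theta^(T), parameter-wise."""
    return [(1 - alpha) * p0 + alpha * pT for p0, pT in zip(theta_init, theta_final)]

def error_curve(theta_init, theta_final, predict, X, y, num_points=50):
    """gamma_error(alpha) evaluated on an evenly spaced grid of alpha in [0, 1]."""
    alphas = np.linspace(0.0, 1.0, num_points)
    errors = [np.mean(predict(interpolate_params(theta_init, theta_final, a), X) != y)
              for a in alphas]
    return alphas, np.array(errors)

# Toy usage: a linear classifier whose "trained" weights classify two points perfectly.
theta0 = [np.zeros((2, 2))]   # all-zero "initialization" for simplicity
thetaT = [np.eye(2)]          # made-up "trained" parameters
X = np.array([[1.0, 0.0], [0.0, 1.0]])
y = np.array([0, 1])
predict = lambda theta, X: (X @ theta[0].T).argmax(axis=1)
alphas, errs = error_curve(theta0, thetaT, predict, X, y)
print(errs[0], errs[-1])      # 0.5 at alpha = 0 (ties pick class 0), 0.0 at alpha = 1
```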

1. INTRODUCTION

Deep neural networks can often be optimized using simple gradient-based methods, despite the objectives being highly nonconvex. Intuitively, this suggests that the loss landscape must have nice properties that allow efficient optimization. To understand these properties, Goodfellow et al. (2014) studied the linear interpolation between a random initialization and the local minimum found after training. They observed that the loss interpolation curve is monotonic and approximately convex (see the MNIST curve in Figure 1) and concluded that these tasks are easy to optimize. However, more recent empirical work (Frankle, 2020) observed that for deep neural networks on more complicated datasets, both the loss and the error curves have a long plateau along the interpolation path, i.e., the loss and error remain high until close to the optimum (see the CIFAR-10 curve in Figure 1).

Does the long plateau along the linear interpolation suggest these tasks are harder to optimize? Not necessarily: the hardness of an optimization problem need not be related to the shape of its interpolation curves (see the examples in Appendix A). In this paper, we give the first theory that explains the plateau in both the loss and error interpolations. We attribute the plateau to simple causes, such as the bias terms, the network initialization scale and the network depth, which are not necessarily related to the difficulty of optimization. Note that there are many different theories for the optimization of overparametrized neural networks, in particular the neural tangent kernel (NTK) analysis (Jacot et al., 2018; Du et al., 2018; Allen-Zhu et al., 2019; Arora et al., 2019) and the mean-field analysis (Chizat & Bach, 2018; Mei et al., 2018). However, they do not explain the plateau in both the loss and error interpolations.
In the NTK regime, the network output is nearly linear in the parameters, so the loss interpolation curve is monotonically decreasing and convex: there is no plateau in the loss interpolation. The mean-field regime often uses a smaller initialization on a homogeneous neural network (as considered in Chizat & Bach (2018); Mei et al. (2018)). In this case, the interpolated network output is essentially a scaled version of the network output at the minimum and makes the same label predictions: there is no plateau in the error interpolation curve.

1.1. OUR RESULTS

Our theoretical results consist of two parts. In the first part (see Section 3), we give a plausible explanation for the plateau in both the loss and error curves.

Claim 1 (informal). If a deep network has a relatively small initialization, and its last-layer biases are significantly different for different classes, then both the loss and error curves have a plateau. The plateau is longer for a deeper network.

We formalize this claim in Theorem 1. For intuition, consider an r-layer neural network that only has biases in the last layer, and consider Xavier initialization (Glorot & Bengio, 2010), which typically gives a small output and zero biases. If we consider the α-interpolation point (with coefficient α for the minimum and 1 - α for the initialization), then the weight "signal" from the minimum scales as α^r (as it is the product of r layers) while the bias scales as α. As illustrated in Figure 2 (right), when r is large and there is a difference in the biases, the bias dominates, which creates a plateau in the error. For the loss, one can also show that the weight signal is near 0 for small α, so the network output is dominated by the biases and the loss cannot beat the random guessing at initialization. Note that this explanation for the plateau has no implication for the hardness of optimization.
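The α^r-versus-α scaling in Claim 1's intuition can be checked exactly on a toy linear network with an idealized zero initialization (all numbers below are made up for illustration): with θ^(0) = 0, the α-interpolation point has weights αV_i and last-layer bias αb, so the weight signal carries a factor α^r while the bias carries only α.

```python
import numpy as np

r = 6                                                     # network depth (illustrative)
rng = np.random.default_rng(0)
V = [rng.standard_normal((4, 4)) / 2 for _ in range(r)]   # made-up "trained" weights
b = np.array([0.0, 0.0, 0.0, 1.0])                        # made-up "trained" last-layer bias
x = rng.standard_normal(4)

def output(alpha):
    h = x
    for Vi in V:
        h = (alpha * Vi) @ h    # each interpolated layer contributes one factor of alpha
    return h + alpha * b

alpha = 0.3
signal = output(alpha) - alpha * b    # weight signal at the interpolation point
full_signal = output(1.0) - b         # weight signal of the "trained" network
print(np.allclose(signal, alpha**r * full_signal))   # True: signal scales as alpha^r
```

With r = 6 and α = 0.3, the signal is suppressed by α^r ≈ 7 × 10^-4 while the bias only shrinks to 0.3 of its trained value.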
However, why would the last-layer biases be different for different classes, especially when the biases are initialized as zeros and all classes are balanced? In the second part (see Section 4), we focus on a simple model that we call the r-homogeneous-weight network. This is a two-layer network whose i-th output is ⟨W_{i,:}, x⟩^r + b_i, where x ∈ R^d is the network input, W_{i,:} ∈ R^d is the weight vector and b_i ∈ R is the bias (see Figure 2 (left)). Our simple model simulates a depth-r ReLU/linear network with bias on the output layer, in the sense that the signal is r-homogeneous while the bias is 1-homogeneous in the parameters. Under this model we show:

Claim 2 (informal). For the r-homogeneous-weight network on a simple balanced dataset, the class that is learned last has the largest bias.

Here, a class is learned when all the samples in this class are classified correctly with good confidence. We essentially show that once a class is learned, the bias associated with this class starts decreasing, so the class that is learned last ends up with the largest bias. We formalize this claim in Theorem 2.

In Section 5, we verify these ideas empirically on fully-connected networks for MNIST (Deng, 2012) and Fashion-MNIST (Xiao et al., 2017), and on VGG-16 (Simonyan & Zisserman, 2014) for CIFAR-10 and CIFAR-100 (Krizhevsky et al., 2009). We first show that if we train a neural network without any biases, the error curve has a much shorter plateau or no plateau at all. For networks that are trained normally with biases, we design a homogeneous interpolation scheme for the biases that makes both biases and weights r-homogeneous; this interpolation indeed significantly shortens the plateau in the error. We also show that decreasing the initialization scale or increasing the network depth produces a longer plateau in both the error and loss curves.
Finally, we show that the bias is correlated with the ordering in which the classes are being learned for small datasets, which suggests that even though the model we consider in the convergence analysis is simple, it captures some of the behavior in practice. 

1.2. RELATED WORKS

There are two major lines of work studying interpolation between different points for neural networks: one on monotonic linear interpolation, which interpolates between the initial network and the learned network, and the other on mode connectivity, which connects two learned networks.

Monotonic linear interpolation. Goodfellow et al. (2014) first studied the linear interpolation between the network at initialization and the network after training on MNIST. Frankle (2020) extended the experiments to modern networks on CIFAR-10 and ImageNet and found that, though the loss/error is still monotonically non-increasing along the path, it remains high until close to the optimum. Lucas et al. (2021) showed that MLI holds when the network output curve along the interpolation path is close to linear (measured by Gaussian length). However, the Gaussian length can only be formally controlled in the NTK regime.

Mode connectivity. Mode connectivity considers the interpolation between two learned networks (modes) found by SGD. In general, a linear interpolation between two different local minima crosses regions of high loss (Goodfellow et al., 2014). Surprisingly, Draxler et al. (2018) and Garipov et al. (2018) observed that local minima found by SGD from different initializations can be connected via a piece-wise linear path of low loss. Frankle et al. (2020) and Fort et al. (2020) observed that local minima trained from the same initialization can also be connected by a linear path. Freeman & Bruna (2016); Venturi et al. (2018); Nguyen (2019; 2021); Kuditipudi et al. (2019); Shevchenko & Mondelli (2020); Nguyen et al. (2021) gave several theoretical explanations for this phenomenon.

3. PLATEAU FOR LOSS AND ERROR INTERPOLATIONS

We prove that a long plateau exists in the loss and error curves of fully-connected networks when the initialization is small and the network is deep. The detailed proof can be found in Appendix B.3.

We consider an r-layer fully-connected neural network with r at least three. Given input x ∈ R^{n_0}, the network output is

g(x) := V_r σ(V_{r-1} · · · σ(V_1 x) · · ·) + b,    (1)

where V_i ∈ R^{n_i × n_{i-1}} for each layer i ∈ [r], σ is the ReLU activation, and b is the last-layer bias vector. For the biases, we initialize them as zeros and assume that after training there exists a gap between the largest bias and the second largest, which also holds empirically (see Figure 8). Note this bias gap is essential for the plateau in the error interpolation: if all the biases were equal in the trained network, the logits for different classes would differ only by the weight signal, and the interpolated network would make the same label predictions as the trained network.

Assumption 1 (Bias Gap). Choosing i* ∈ arg max_{i ∈ [k]} b_i^(T), we have b_{i*}^(T) - max_{i ≠ i*} b_i^(T) > 0.

Without loss of generality, we assume that b_k^(T) > max_{i ∈ [k-1]} b_i^(T). We denote ∆_min := b_k^(T) - max_{i ∈ [k-1]} b_i^(T) and ∆_max := b_k^(T) - min_{i ∈ [k-1]} b_i^(T). Then, we show that both the loss and error interpolation curves have a long plateau in Theorem 1.

Theorem 1. Suppose the network is defined as in Equation (1) and suppose the weights satisfy ∥V_i^(0)∥ ≤ δ and ∥V_i^(T)∥ ≤ V_max for all layers i ∈ [r]. On a k-class balanced dataset whose inputs have ℓ2 norm at most 1, if Assumption 1 holds, then for any ϵ > 0, as long as δ < min{ϵ^{1/r}/r, 1/r^2, 1/(2e^2)^{r-2}}, there exist α_1 = δ/∆_min, α_2 = (1/(1 + √δ))^{r/(r-1)} (∆_min/(2V_max^r))^{1/(r-1)} and α_3 = ϵ^{1/r}/V_max such that

1. for all α ∈ [α_1, α_2], the error is 1 - 1/k;

2. for all α ∈ [0, α_3], we have log k - 2eϵ ≤ (1/N) L(V^[α], b^[α]) ≤ log k + α∆_max + 2eϵ, where N is the number of training examples.

The above theorem shows that for all α ∈ [α_1, α_2], the error remains at 1 - 1/k, the same as random guessing.
We skip the very short initial region [0, δ/∆_min], since there the bias is very small and the error can be unpredictable due to the randomness in the initial weights. When the initialization scale δ is small, the error plateau region is roughly [0, (∆_min/(2V_max^r))^{1/(r-1)}].

Intuitively, the plateau in the error curve appears because, for a small initialization, the output at the α-interpolation point is close to

α^r V_r^(T) σ(V_{r-1}^(T) · · · σ(V_1^(T) x) · · ·) + α b^(T).

When α is not large enough, α^r is much smaller than α, so for every class i ≠ k the first term (the signal part) cannot overcome the bias gap α(b_k^(T) - b_i^(T)). This implies that all samples are predicted as class k and the error is 1 - 1/k. We also show that the average loss cannot be lower than log k - 2eϵ when α ≤ ϵ^{1/r}/V_max. Note that a small random initialization achieves a loss of approximately log k. Usually the bias gap ∆_max in practice is not very large, so the loss curve remains nearly flat over this interpolation region. Again, the loss plateau becomes longer as the depth r increases, because the weight signal remains near 0 for a larger range of α.
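The mechanism behind the error plateau can be reproduced with synthetic numbers (all of them made up, not from the paper's experiments): model the logits at the α-interpolation point as α^r · signal + α · bias, with a zero initialization and a bias gap of 1 for the last class.

```python
import numpy as np

k, r, n = 10, 8, 500
rng = np.random.default_rng(1)
labels = rng.integers(0, k, n)
signal = rng.standard_normal((n, k))
signal[np.arange(n), labels] += 10.0    # "trained" signal: large margin on the true class
bias = np.zeros(k)
bias[k - 1] = 1.0                       # "trained" last-layer bias, largest for class k

def interp_error(alpha):
    logits = alpha**r * signal + alpha * bias
    return np.mean(logits.argmax(axis=1) != labels)

print(interp_error(0.3))   # ~ 1 - 1/k: alpha * bias dominates, all samples predicted class k
print(interp_error(1.0))   # 0.0: the "trained" network fits the data
```

At α = 0.3 the signal is scaled by α^8 ≈ 6.6 × 10^-5, so the 0.3 bias on class k decides every prediction, and the error equals the fraction of samples whose label is not k.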

4. TRAINING DYNAMICS FOR CREATING A BIAS GAP

In this section, we explain how the gradient flow dynamics generates a bias gap on a balanced dataset by analyzing a simple model. Below, we first define the network model, the training dataset and the optimization procedure for our analysis.

r-homogeneous-weight network: We consider a two-layer, k-output neural network with activation function σ(z) := z^r, where r is a positive constant that is at least three. As illustrated in Figure 2 (left), given input x ∈ R^d, the i-th output f_i(x) is ⟨W_{i,:}, x⟩^r + b_i, where the weight vector W_{i,:} ∈ R^d is the i-th row of the weight matrix W ∈ R^{k×d} and b_i ∈ R is the i-th entry of the vector b ∈ R^k. In the output f_i(x), we call ⟨W_{i,:}, x⟩^r the signal, since it is input-dependent, and call b_i the bias.
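The forward pass of this model is a one-liner; the sketch below uses arbitrary hand-checkable numbers (k = 2, d = 3), not values from the paper.

```python
import numpy as np

def f_homog(W, b, x, r=3):
    """Outputs of the r-homogeneous-weight network: f_i(x) = <W_{i,:}, x>^r + b_i."""
    return (W @ x) ** r + b

# Tiny hand-checkable example (values are arbitrary)
W = np.array([[2.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])
b = np.array([0.5, -0.5])
x = np.array([1.0, 1.0, 0.0])
print(f_homog(W, b, x))   # [2^3 + 0.5, 1^3 - 0.5] = [8.5, 0.5]
```

Note the homogeneity: scaling W by c scales the signal by c^r, while scaling b by c scales the bias contribution only by c, which is exactly the asymmetry exploited in the interpolation analysis.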

Dataset:

We consider a k-class balanced dataset, with k a constant. We denote the whole dataset by S and the subset for each class i ∈ [k] by S_i. Each subset S_i has exactly N/k samples, and each sample x ∈ R^d is independently generated as v_i + ξ, where the noise ξ ∼ N(0, (σ^2/d) I). To differentiate the noise terms among different samples, we denote the noise associated with sample x by ξ_x. We assume all v_i's are orthonormal; without loss of generality, we assume v_i = e_i for each class i. Here, we assume orthogonal features to facilitate the convergence analysis beyond the NTK regime, following previous works (Allen-Zhu & Li, 2020; Ge et al., 2021).

Optimization: We initialize each entry of the weight matrix W by independently sampling from the Gaussian distribution N(0, δ^2) and then taking the absolute value. Our analysis can be trivially generalized to standard Gaussian initialization (without taking the absolute value) when r is an even integer. We initialize all bias terms as zeros. We use the cross-entropy loss

L(W, b) = Σ_{i ∈ [k]} Σ_{x ∈ S_i} -log( exp(f_i(x)) / Σ_{j ∈ [k]} exp(f_j(x)) ),

and run gradient flow on (k/N) L(W, b) for time T. Our analysis can also be extended to gradient descent with a small step size.

Next, we show that running gradient flow from a small initialization converges to a model with zero error and a constant bias gap.

Theorem 2. Suppose the neural network, dataset and optimization procedure are as defined in Section 4. Suppose the initialization scale δ ≤ Θ(1), the noise level σ ≤ Θ(1), the dimension d ≥ Θ(1/δ^{2r-2}) and the number of samples N ≥ Θ(1/δ^{r-1}). Then, with probability at least 0.99 over the initialization, there exists a time T = Θ(log(1/δ)/δ^{r-2}) such that

1. zero error: for all distinct i, j ∈ [k] and all x ∈ S_i, f_i^(T)(x) ≥ f_j^(T)(x) + Ω(1);

2. bias gap: b_{i*}^(T) - max_{i ≠ i*} b_i^(T) ≥ Ω(1) with i* = arg max_{i ∈ [k]} b_i^(T).

Due to the space limit, we only give a proof sketch here and leave the detailed proof to Appendix C.
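The dataset above is easy to instantiate; the sketch below uses made-up sizes (k = 3, d = 50, N = 30, σ = 0.1) just to show the sampling recipe x = e_i + ξ with ξ ∼ N(0, (σ²/d) I).

```python
import numpy as np

def make_dataset(k, d, N, sigma, rng):
    """Balanced dataset from Section 4: N/k samples per class, x = e_i + xi."""
    X, y = [], []
    for i in range(k):
        for _ in range(N // k):
            x = np.zeros(d)
            x[i] = 1.0                                        # feature v_i = e_i
            x += rng.normal(0.0, sigma / np.sqrt(d), size=d)  # per-coordinate std sigma/sqrt(d)
            X.append(x)
            y.append(i)
    return np.array(X), np.array(y)

rng = np.random.default_rng(0)
X, y = make_dataset(k=3, d=50, N=30, sigma=0.1, rng=rng)
print(X.shape, np.bincount(y))   # (30, 50), 10 samples per class
```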
Since our dataset is perfectly balanced, it might seem surprising that gradient flow learns diverse biases. We can compute the time derivative of the bias,

ḃ_i = 1 - (k/N) Σ_{x ∈ S} u_i(x),

where u_i(x) is the softmax output for class i, that is, exp(f_i(x)) / Σ_{i′ ∈ [k]} exp(f_{i′}(x)). At the beginning, all logits are small, so u_i(x) ≈ 1/k and ḃ_i ≈ 0. If all the samples were learned at the same time, we would have u_i(x) ≈ 1 and u_i(x′) ≈ 0 for x ∈ S_i, x′ ∈ S \ S_i, which again leads to ḃ_i ≈ 0. On the other hand, consider what happens if all samples in one class (say, class i) are learned before any sample in any other class (say, class j) is learned. In this case we have

ḃ_i = 1 - (k/N) Σ_{x ∈ S_i} u_i(x) - (k/N) Σ_{x ∈ S \ S_i} u_i(x) ≈ 1 - (k/N)·(N/k)·1 - (k/N)·(N(k-1)/k)·(1/k) = -(k-1)/k,

ḃ_j = 1 - (k/N) Σ_{x ∈ S_i} u_j(x) - (k/N) Σ_{x ∈ S \ S_i} u_j(x) ≈ 1 - (k/N)·(N/k)·0 - (k/N)·(N(k-1)/k)·(1/k) = 1/k,

where for any learned sample x ∈ S_i we have u_i(x) ≈ 1 and u_j(x) ≈ 0, while for any not-yet-learned sample x ∈ S \ S_i we have u_i(x), u_j(x) ≈ 1/k. The above calculation shows that b_i starts to decrease while all the other bias terms increase. Generalizing this intuition, we show that b_{i′} starts to decrease whenever class i′ is learned, so the class that is learned last ends up with the largest bias.

As the weights are initialized randomly, by standard anti-concentration one can argue that there is a gap between the W_{i,i}^(0)'s. Without loss of generality, we assume W_{1,1}^(0) > W_{2,2}^(0) > · · · > W_{k,k}^(0). The initial difference in the weights leads to different classes being learned at different times. We show this by induction on the following hypothesis throughout training:

Proposition 1 (Induction Hypothesis).
In the same setting as Theorem 2, with probability at least 0.99 over the initialization, there exist time points 0 =: s_1 < t_1 < s_2 < t_2 < · · · < s_{k-1} < t_{k-1} < s_k := T with t_i - s_i = Θ(log(1/δ)/δ^{r-2}) and s_{i+1} - t_i = Θ(1) for i ∈ [k-1], such that for any t ∈ [s_i, s_{i+1}]:

1. (classes not yet learned) for any classes j, j′ ≥ i+1, we have (1) b_j^(t) ≥ max_{i′ ∈ [k]} b_{i′}^(t) - O(δ^r), (2) b_j^(t) - b_{j′}^(t) ≤ O(δ^r) and (3) W_{j,j}^(t) ≤ O(δ);

2. (classes already learned) for any class j ≤ i-1, we have (1) b_j^(t) ≤ max_{i′ ∈ [k]} b_{i′}^(t) - Ω(1), (2) f_j^(t)(x) ≥ f_{i′}^(t)(x) + Ω(1) for i′ ≠ j, x ∈ S_j and (3) W_{j,j}^(t) ≥ Ω(1);

3. (parameters movement) (1) for any j ∈ [k], Θ(δ) = W_{j,j}^(0) < W_{j,j}^(t), (2) for any distinct j, j′ ∈ [k], 0 < W_{j,j′}^(t) ≤ O(δ) and (3) for any j, j′ ∈ [k] and any x ∈ S_{j′}, ⟨W_{j,:}^(t), ξ_x⟩ ≤ min{O(δ), W_{j,j′}^(t)}.

This proposition shows that gradient flow learns the k classes one by one, from class 1 to class k. More precisely, each class i is learned during the time interval [s_i, s_{i+1}]. All the not-yet-learned classes j ≥ i+1 have biases close to the maximum and small weights W_{j,j}^(t). All the already-learned classes j ≤ i-1 have small biases and large weights W_{j,j}^(t). For the parameters movement, all the diagonal entries W_{j,j}^(t) are larger than at initialization, all the off-diagonal entries W_{j,j′}^(t) remain only O(δ), and the correlation between the weights and the noise terms also remains small. Since class k is learned last, b_k^(T) is the largest bias. See an illustration of this learning process in Figure 3.

Although we consider a simple neural network and data distribution, the analysis of the training dynamics is still non-trivial. There are three major challenges in our proof: (1) How to ensure that class i+1 is learned much later than class i? (2) For any class j that has not been learned, how to maintain that its bias is close to the maximum? (3) For any learned class j, how to maintain a large gap between its bias and the top bias?
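The back-of-the-envelope calculation of ḃ_i and ḃ_j above (one class learned, the rest not) can be verified with exact rational arithmetic; k = 5 and 20 samples per class are arbitrary choices.

```python
from fractions import Fraction

# Scenario: class i fully learned, no other class learned yet.
#   Learned samples x in S_i:      u_i(x) ~= 1, u_j(x) ~= 0 (j != i)
#   Unlearned samples x in S\S_i:  u_i(x) ~= u_j(x) ~= 1/k
k = 5
N = 20 * k                      # any balanced N works
n_i = N // k                    # |S_i|
b_i_dot = 1 - Fraction(k, N) * (n_i * 1 + (N - n_i) * Fraction(1, k))
b_j_dot = 1 - Fraction(k, N) * (n_i * 0 + (N - n_i) * Fraction(1, k))
print(b_i_dot)   # -(k-1)/k = -4/5: the learned class's bias decreases
print(b_j_dot)   # 1/k = 1/5: every other bias increases
```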
Next, we give the proof ideas for these questions. Since all the off-diagonal entries and the correlations with noise terms in W^(t) are negligible, in our proof we can essentially focus on the movement of the diagonal entries W_{i,i}^(t) and the biases.

Lower bounding s_{i+1} - s_i. During time [s_i, t_i], the dynamics of W_{i,i}^(t) is similar to that in the tensor power method (Allen-Zhu & Li, 2020; Ge et al., 2021). The initial gap between W_{i,i}^(0) and W_{j,j}^(0) ensures that when W_{i,i}^(t_i) rises to a small constant, W_{j,j}^(t_i) is still O(δ) for all j ≥ i+1. Then, after the constant time s_{i+1} - t_i, W_{j,j}^(s_{i+1}) is still O(δ), since the growth rate of W_{j,j}^(t) is merely O(δ^{r-1}).

Bias for classes that are not learned. For j ≥ i+1, we maintain that b_j^(t) ≥ max_{i′ ∈ [k]} b_{i′}^(t) - O(δ^r). First, we use the lemma below to show that the biases of any two classes j, j′ ≥ i+1 remain close.

Lemma 1 (Coupling Biases). Assuming W_{j′,j′}, W_{j,j} ≤ O(δ) and b_{j′}, b_j ≥ max_{i′ ∈ [k]} b_{i′} - O(δ^r), we have ḃ_{j′} - ḃ_j > 0 if b_{j′} - b_j ≤ -µδ^r, and ḃ_{j′} - ḃ_j < 0 if b_{j′} - b_j ≥ +µδ^r, for some positive constant µ.

Second, we show that any class j′ ≤ i that is already learned or being learned cannot have a bias much larger than any not-yet-learned class j ≥ i+1.

Lemma 2 (Bias Gap Control I). For any distinct j′, j ∈ [k], if W_{j′,j′} ≥ W_{j,j}, W_{j,j} ≤ O(δ), b_{j′} - b_j ≥ O(δ^r) and b_j ≥ max_{i′ ∈ [k]} b_{i′} - O(δ^r), we have ḃ_{j′} - ḃ_j < 0.

Bias for learned classes. At time s_{j+1}, we can prove that 1 - u_j^(s_{j+1})(x) ≤ C_1 for all x ∈ S_j and b_j^(s_{j+1}) - b_k^(s_{j+1}) ≤ -C_2. By the lemma below, we can then ensure that b_j^(t) - b_k^(t) ≤ -C_2 for any t ≥ s_{j+1}.

Lemma 3 (Bias Gap Control II). There exist small positive constants C_1, C_2 such that for any j ∈ [k-1] and any x ∈ S_j, if 1 - u_j(x) ≤ C_1, W_{k,k} ≤ O(δ) and b_j - b_k ≥ -C_2, we have ḃ_j - ḃ_k < -Ω(1).
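The sequential-learning picture of Proposition 1 can be observed in a tiny simulation (our illustration, not the paper's proof or experiments): k = 3 classes, one noiseless sample x = e_c per class, and plain gradient descent standing in for gradient flow. With a diagonal initialization the off-diagonal weights stay exactly zero, so the dynamics reduce to the diagonal entries and the biases.

```python
import numpy as np

k, r, lr, steps = 3, 3, 0.005, 6000
W = np.diag([0.5, 0.35, 0.2])   # made-up diagonal init; class 3 (index 2) starts smallest
b = np.zeros(k)
X = np.eye(k)                   # the sample for class c is the basis vector e_c

for _ in range(steps):
    pre = X @ W.T                              # pre[c, i] = <W_i, x_c>
    logits = pre**r + b                        # f_i(x_c) = <W_i, x_c>^r + b_i
    u = np.exp(logits - logits.max(axis=1, keepdims=True))
    u /= u.sum(axis=1, keepdims=True)          # softmax outputs u_i(x_c)
    err = u - np.eye(k)                        # dL/df_i(x_c) for cross-entropy
    W -= lr * (err * r * pre**(r - 1)).T @ X   # chain rule through <W_i, x_c>^r
    b -= lr * err.sum(axis=0)

final_logits = (X @ W.T) ** r + b
print(final_logits.argmax(axis=1))   # every class classified correctly
print(b.argmax())                    # the class learned last (index 2) ends with the largest bias
```

The class with the smallest initial diagonal weight is learned last (tensor-power-like growth), and its bias keeps increasing while the earlier classes' biases decrease, matching the lemmas above.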

4.1. PLATEAU AND MONOTONICITY FOR r-HOMOGENEOUS-WEIGHT NETWORK

Now, assuming the network at initialization and after training satisfies the properties described in Theorem 2 and Proposition 1, we can prove a tighter bound on the plateau region and also show the monotonicity of the error and loss curves. See the complete proofs in Appendix B.1 and Appendix B.2.

As in Assumption 1, we use ∆_i to denote the bias gap b_k^(T) - b_i^(T) for i ∈ [k-1], and denote ∆_min := min_{i ∈ [k-1]} ∆_i and ∆_max := max_{i ∈ [k-1]} ∆_i. For the weights, we denote W_min := min_{i ∈ [k-1]} W_{i,i}^(T) and W_max := max_{i ∈ [k]} W_{i,i}^(T). We also denote R_min := min_{i ∈ [k-1]} ∆_i/[W_{i,i}^(T)]^r and R_max := max_{i ∈ [k-1]} ∆_i/[W_{i,i}^(T)]^r. Below, we show the plateau and monotonicity of the loss and error interpolations in Theorem 3.

Theorem 3. Suppose the neural network, dataset and optimization procedure are as defined in Section 4. Suppose the network at initialization and after training satisfies the properties described in Theorem 2 and Proposition 1. For any ϵ ∈ (0, 1), suppose δ ≤ min(Θ(ϵ^{1/r}), Θ(R_min^{1/(r-1)} ∆_min^{1/r}), Θ((W_min/W_max)^{2r/(r-2)})). There exist α_1 = δ/∆_min, α_2 = (1/(1 + O(√δ)))^{r/(r-1)} R_min^{1/(r-1)}, α_3 = ϵ^{1/r}/W_max and α_4 = (1 + O(δ))^{1/(r-1)} (R_max/r)^{1/(r-1)} such that

1. for all α ∈ [α_1, α_2], the error is 1 - 1/k; for all α ∈ [α_1, 1], the error is non-increasing;

2. for all α ∈ [0, α_3], we have log k - eϵ ≤ (1/N) L(W^[α], b^[α]) ≤ log k + α∆_max + eϵ; for all α ∈ [α_4, 1], the loss is strictly monotonically decreasing.

The analysis of the plateau is very similar to that in Theorem 1, since an r-homogeneous-weight network resembles a depth-r fully-connected network with only last-layer biases, in the sense that the weights are r-homogeneous while the (last-layer) bias is 1-homogeneous. For the error plateau, we prove a tighter bound on the right boundary α_2 than in Theorem 1.
We also show that the error is non-increasing for α ∈ [δ/∆_min, 1] by arguing that once a sample is correctly classified at an interpolated point α′ ≥ δ/∆_min, it remains correctly classified for any α ≥ α′. As in Theorem 1, we can show that the loss is no smaller than log k - eϵ when α ≤ ϵ^{1/r}/W_max. To show the monotonicity of the loss after α_4, we show that f_i^[α](x) - f_j^[α](x) is increasing in α for i ≠ j and x ∈ S_i. In summary, for α ∈ [α_1, α_2], the signal is smaller than the bias gap and the error remains at 1 - 1/k. Before α_3, the signal is very small and the loss remains large; after α_4, the signal starts to overcome the bias gap and the loss decreases. See an illustration in Figure 2 (right).

5. EXPERIMENTS

In this section, we empirically show that the intuitions from our simple theoretical model also apply to more realistic datasets and architectures. First, we show that bias plays an important role in creating the plateau in the error interpolation, as predicted by Theorem 1. We then demonstrate the influence of the initialization scale and the network depth (also see Theorem 1). Finally, we show that, similar to Proposition 1, the class that is learned last often has a larger bias. Due to space constraints, we only show the results on MNIST and CIFAR-100 in this section; similar results also hold on Fashion-MNIST and CIFAR-10 (see Appendix D).

Unless specified otherwise, we use a depth-10, width-1024 fully-connected ReLU network (FCN10) for MNIST and VGG-16 (without batch normalization) for CIFAR-100. We use Kaiming initialization (He et al., 2015) for the weights and set all bias terms to zero. For FCN10 on MNIST, we use a small initialization by scaling the weights of each layer by (0.001)^{1/10}, so the output is scaled by 0.001. We train each network using SGD for 100 epochs. See more experimental settings in Appendix D. We linearly interpolate using 50 evenly spaced points between the network at initialization and the network at the end of training, and evaluate the error and loss on the training set. For each setting, we repeat the experiment three times with different random seeds and plot the mean and standard deviation.

Role of bias in creating the plateau. We demonstrate the importance of the bias using two experiments. In the first experiment, we compare the loss/error interpolation curves of networks equipped with biases in all layers (all bias), with a bias only in the output layer (last bias), and with no bias at all (no bias). Figure 4 shows that networks with all bias and last bias have a much longer error plateau than networks without bias. The three bias settings have similar loss interpolation curves.
By our theory, the bias dominates the signal at the beginning of the interpolation because the bias term scales as α while the signal scales as α^r. In the second experiment, to correct this discrepancy, we interpolate the bias at the h-th layer (the input is at the 0-th layer) as b_h^[α] = (1 - α)^h b_h^(0) + α^h b_h^(T) = α^h b_h^(T). We call this the homogeneous interpolation, as now the terms involving biases and weights all carry α^r coefficients. We compare this with the normal interpolation that interpolates the bias terms linearly. Figure 5 shows that for networks with all bias or last bias, homogeneous interpolation significantly reduces the plateau in the error interpolation, but does not affect the loss interpolation.

Role of initialization scale and network depth. Our theory suggests that with a smaller initialization, the signal magnitude in the initial interpolation phase is smaller, which can create a longer plateau in both the loss and the error interpolation. We compare networks under initialization scales 1, 0.1, 0.01 and 0.001, where scale 1 corresponds to the standard Kaiming initialization. For another initialization scale β, we rescale each layer by the same factor so that the output is rescaled by β. According to Figure 6, a smaller initialization indeed creates a longer plateau in the loss and error interpolations. With a deeper network, the signal grows more slowly in the initial interpolation phase, which can potentially create a longer plateau in both the loss and the error interpolation. We compare FCN4, FCN6, FCN8 and FCN10 on MNIST, and VGG11, VGG13, VGG16 and VGG19 on CIFAR-100. According to Figure 7, deeper networks do have a longer plateau in the loss and error interpolations.
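The homogeneous interpolation of the biases described above can be sketched as follows (with a zero bias initialization, so the (1 - α)^h term vanishes). The intuition: the bias at hidden layer h gets coefficient α^h, and passing through the remaining r - h interpolated layers multiplies it by roughly α^(r-h), so every bias contribution scales as α^r, like the weight signal. The numbers below are arbitrary.

```python
import numpy as np

def interpolate_bias(b_trained, h, alpha, homogeneous=True):
    if homogeneous:
        return alpha**h * b_trained    # homogeneous interpolation: alpha^h coefficient
    return alpha * b_trained           # normal linear interpolation (zero init)

b_T = np.array([0.0, 1.0])             # made-up trained bias at layer h = 4
print(interpolate_bias(b_T, h=4, alpha=0.5))                      # [0. 0.0625]
print(interpolate_bias(b_T, h=4, alpha=0.5, homogeneous=False))   # [0. 0.5]
```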

6. CONCLUSION

Our theory suggests that the plateau in the loss/error interpolation curves may be attributed to simple causes, and it is unclear whether these causes are related to the difficulty or easiness of optimization. In our experiments, although training succeeds in all settings, the loss and error interpolation curves can be easily manipulated by changing the initialization scale, the network depth and the bias terms. Therefore, we believe one needs to look at structures more complicated than linear interpolation to understand why optimization succeeds for deep neural networks. Though our theory requires a small initialization, we also observe a plateau on CIFAR-100 with the standard initialization, which suggests that the useful signal is still a high-order term in α. We also observe that sometimes the ordering of the biases does not exactly follow the order in which the classes are learned. We believe this is partially due to the correlation between features of different classes, and we offer a preliminary explanation in Appendix D.5. We leave a thorough study of these problems to future work.

ACKNOWLEDGEMENT

This work is supported by NSF Award DMS-2031849, CCF-1845171 (CAREER), CCF-1934964 (Tripods) and a Sloan Research Fellowship.

REPRODUCIBILITY STATEMENT

For our theoretical results, we list all the assumptions and state the theorems in the main text, and leave the complete proofs of all claims to the Appendix. For our experimental results, we describe the detailed experiment settings in the Appendix and also upload the source code as supplementary material.

A EXAMPLES FOR THE DISCONNECTION BETWEEN LINEAR INTERPOLATION SHAPE AND OPTIMIZATION DIFFICULTY

We give two examples that illustrate the disconnection between the linear interpolation shape and the optimization difficulty. In Section A.1, we show a function that is NP-hard to optimize, but has a convex and monotonically decreasing loss interpolation. Then in Section A.2, we give a function that is easy to optimize, but has a non-monotonic loss interpolation.

A.1 HARD FUNCTION WITH CONVEX LOSS INTERPOLATION

For any symmetric third-order tensor T ∈ R^{d×d×d}, our goal is to minimize

f(x, z) = -T(x, x, x) + ∥x∥^4 + z^4,    (2)

where x ∈ R^d and z ∈ R. It is known that finding the spectral norm of a symmetric third-order tensor (that is, max_{v ∈ R^d, ∥v∥=1} T(v, v, v)) is NP-hard (Hillar & Lim, 2013). We prove that minimizing f(x, z) is also NP-hard by reducing the tensor spectral norm problem to it.

Proposition 2. Minimizing f(x, z) as defined in Eqn. 2 is NP-hard.

Proof. For any non-zero tensor T, let (x*, z*) be a minimizer of f(x, z); it is easy to verify that T(x*, x*, x*) > 0. We show that x̄* := x*/∥x*∥ must be a solution to max_{v ∈ R^d, ∥v∥=1} T(v, v, v). For the sake of contradiction, assume there exists v* with unit norm such that T(v*, v*, v*) > T(x̄*, x̄*, x̄*). It is then easy to verify that f(∥x*∥ v*, z*) < f(x*, z*), which contradicts the optimality of (x*, z*).

Next, we prove that, starting from a certain initialization, the loss along the linear interpolation path is convex and monotonically decreasing. Note that assuming T has unit Frobenius norm does not affect the NP-hardness of the problem, and our initialization is oblivious to the tensor T.

Proposition 3. Assume ∥T∥_F = 1. Suppose we start from an initialization (x^(0), z^(0)) with x^(0) = 0 and |z^(0)| > 3√2/4. Let (x*, z*) be a minimizer of f(x, z) as defined in Eqn. 2. Then the loss interpolation curve γ(α) := f((1 - α)x^(0) + αx*, (1 - α)z^(0) + αz*) is convex and monotonically decreasing for α ∈ [0, 1].

Proof. We first prove that at any minimizer (x*, z*) we must have z* = 0; otherwise, we could set z to zero to further decrease the loss. Starting from an initialization (x^(0), z^(0)) with x^(0) = 0, each interpolation point satisfies x^[α] = αx* and z^[α] = (1 - α)z^(0).
Therefore, we have

$$\gamma(\alpha) = f(x^{[\alpha]}, z^{[\alpha]}) = -T(x^{[\alpha]}, x^{[\alpha]}, x^{[\alpha]}) + \|x^{[\alpha]}\|^4 + (z^{[\alpha]})^4 = -\alpha^3 T(x^*, x^*, x^*) + \alpha^4 \|x^*\|^4 + (1-\alpha)^4 (z^{(0)})^4.$$

To prove the convexity of $\gamma(\alpha)$ on $[0, 1]$, we only need to show $\gamma''(\alpha) > 0$ for $\alpha \in [0, 1]$. We have

$$\gamma''(\alpha) = -6\alpha T(x^*, x^*, x^*) + 12\alpha^2 \|x^*\|^4 + 12(1-\alpha)^2 (z^{(0)})^4 = -6\alpha \|x^*\|^3 T(\bar{x}^*, \bar{x}^*, \bar{x}^*) + 12\alpha^2 \|x^*\|^4 + 12(1-\alpha)^2 (z^{(0)})^4,$$

where $\bar{x}^* := x^*/\|x^*\|$. Since the formula for $\gamma''(\alpha)$ involves both $T(\bar{x}^*, \bar{x}^*, \bar{x}^*)$ and $\|x^*\|$, we first work out the relation between these two quantities. Suppose $T(\bar{x}^*, \bar{x}^*, \bar{x}^*) = p > 0$; then $\|x^*\|$ must equal $\frac{3p}{4}$, because $-\|x^*\|^3 p + \|x^*\|^4$ is minimized at $\|x^*\| = \frac{3p}{4}$.

Next, we prove $\gamma''(\alpha) > 0$ for $\alpha \in (2/3, 1]$ and $\alpha \in [0, 2/3]$ separately. When $\alpha \in (2/3, 1]$, we have $12\alpha^2 \|x^*\|^4 > 6p\alpha \|x^*\|^3 = 6\alpha \|x^*\|^3 T(\bar{x}^*, \bar{x}^*, \bar{x}^*)$, so $\gamma''(\alpha) > 0$. When $\alpha \in [0, 2/3]$, we have $12(1-\alpha)^2 (z^{(0)})^4 \geq \frac{4}{3}(z^{(0)})^4$. Since $\|T\|_F = 1$, we know $T(\bar{x}^*, \bar{x}^*, \bar{x}^*) \leq 1$ and $\|x^*\| \leq 3/4$. Therefore,

$$6\alpha \|x^*\|^3 T(\bar{x}^*, \bar{x}^*, \bar{x}^*) \leq 6 \cdot \frac{2}{3} \cdot \left(\frac{3}{4}\right)^3 \cdot 1 = \frac{27}{16}.$$

Hence, if $|z^{(0)}| > \frac{3\sqrt{2}}{4}$ (so that $\frac{4}{3}(z^{(0)})^4 > \frac{27}{16}$), we have $\gamma''(\alpha) > 0$.
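The algebra above can be checked numerically. The sketch below builds a random symmetric tensor with unit Frobenius norm, finds an approximate best rank-one direction by projected gradient ascent (exact maximization is NP-hard, so this heuristic is only for illustration), forms $x^* = \frac{3p}{4}\,v$ as in the proof, and verifies that $\gamma(\alpha)$ is decreasing with nonnegative second differences.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
# Random symmetric third-order tensor, normalized to unit Frobenius norm.
T = rng.standard_normal((d, d, d))
T = sum(T.transpose(p) for p in
        [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]) / 6.0
T /= np.linalg.norm(T)

def T3(v):                      # T(v, v, v)
    return np.einsum('ijk,i,j,k->', T, v, v, v)

def f(x, z):                    # Eqn. 2
    return -T3(x) + np.linalg.norm(x) ** 4 + z ** 4

# Heuristic search for a good unit direction (not guaranteed optimal).
best_v, best_p = None, -np.inf
for _ in range(20):
    v = rng.standard_normal(d)
    v /= np.linalg.norm(v)
    for _ in range(200):
        g = 3 * np.einsum('ijk,j,k->i', T, v, v)   # gradient of T(v, v, v)
        v = v + 0.1 * g
        v /= np.linalg.norm(v)
    if T3(v) > best_p:
        best_v, best_p = v, T3(v)

p = best_p                       # approximate spectral value, 0 < p <= ||T||_F = 1
x_star = (3 * p / 4) * best_v    # ||x*|| = 3p/4, as in the proof
z0 = 3 * np.sqrt(2) / 4 + 0.1    # initialization with |z0| > 3*sqrt(2)/4

alphas = np.linspace(0.0, 1.0, 201)
gamma = np.array([f(a * x_star, (1 - a) * z0) for a in alphas])
# decreasing, and convex up to discretization (nonnegative second differences)
print(np.all(np.diff(gamma) <= 1e-9), np.all(np.diff(gamma, 2) >= -1e-9))
```

Note that the check only needs $0 < p \leq 1$ and the stated condition on $z^{(0)}$, so it succeeds even when the heuristic finds a suboptimal direction.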

A.2 EASY FUNCTION WITH NON-MONOTONIC LOSS INTERPOLATION

In this section, we give an easy-to-optimize function that nevertheless has a non-monotonic loss interpolation curve. We consider the loss function

$$f(x, y) = \begin{cases} 0 & \text{if } x = y = 0, \\ \left(1 - \frac{y}{3\sqrt{x^2+y^2}}\right)\left((x^2+y^2)^2 - 2(x^2+y^2)\right) & \text{otherwise}, \end{cases} \qquad (3)$$

where $x, y \in \mathbb{R}$. We can also re-parameterize $f(x, y)$ using an angle $\theta \in [0, 2\pi)$ and a length $r \in [0, \infty)$ as $h(\theta, r) = (1 - \frac{\sin(\theta)}{3})(r^4 - 2r^2)$. Next, we prove that, starting from any non-zero point, gradient flow converges to the global minimizer.

Proposition 4. Starting from any non-zero initialization, gradient flow on $f(x, y)$ as defined in Eqn. 3 converges to the global minimizer $(0, -1)$.

Proof. We know the unique minimizer of $f(x, y)$ is $(0, -1)$ by considering its equivalent form $h(\theta, r)$: $r^4 - 2r^2$ is minimized at $r = 1$, and $1 - \frac{\sin(\theta)}{3}$ is maximized at $\theta = \frac{3\pi}{2}$. Besides the minimizer $(0, -1)$, the other stationary point is at $(0, 0)$. For any point $(x, y)$ different from $(0, -1)$ and $(0, 0)$: if $x^2 + y^2 \neq 1$, the gradient along the radial direction is non-zero; if $\frac{y}{\sqrt{x^2+y^2}} \neq -1$, the gradient along the tangential direction is non-zero. It is also easy to verify that, starting from a non-zero point, gradient flow does not converge to $(0, 0)$, so it must converge to $(0, -1)$.

It is also easy to prove that gradient descent with an appropriate step size converges to an $\epsilon$-neighborhood of the global minimizer within $\mathrm{poly}(1/\epsilon)$ iterations. This is because the gradient norm is at least $\mathrm{poly}(\epsilon)$ at any non-zero point outside the $\epsilon$-neighborhood of the global minimizer. Starting from an initialization $(x, y)$ with $x^2 + y^2 = \Theta(1)$, the smoothness along the training trajectory is also bounded by a constant. Next, we prove that, starting from certain initializations, the loss interpolation between the initialization and the global minimizer is non-monotonic.
We prove this by identifying two points along the interpolation path such that the point closer to the minimizer has a higher loss than the point farther from it.

Published as a conference paper at ICLR 2023

Proposition 5. Suppose we start from an initialization $(x_0, y_0) = (r\sin(\beta), r\cos(\beta))$ with $r \geq 1$ and $\beta \in [-\pi/3, \pi/3]$. Consider the loss interpolation curve $\gamma(\alpha) = f((1-\alpha)x_0 + \alpha x^*, (1-\alpha)y_0 + \alpha y^*)$ with $(x^*, y^*) = (0, -1)$ and $f(\cdot, \cdot)$ defined in Eqn. 3. Then there exist $0 \leq \alpha_1 < \alpha_2 \leq 1$ such that $\gamma(\alpha_2) - \gamma(\alpha_1) \geq \frac{5}{32}$.

Proof. We prove that for any $\beta \in [-\pi/3, \pi/3]$ and any $r \geq 1$, the loss interpolation from $(r\sin(\beta), r\cos(\beta))$ to $(0, -1)$ is non-monotonic. In particular, we show that two points along the linear interpolation satisfy

$$f(\sin(\beta/2)\cos(\beta/2),\, -\sin(\beta/2)\sin(\beta/2)) - f(\sin(\beta), \cos(\beta)) \geq \frac{5}{32},$$

where $(\sin(\beta/2)\cos(\beta/2), -\sin(\beta/2)\sin(\beta/2))$ is the midpoint between $(\sin(\beta), \cos(\beta))$ and $(0, -1)$. Next, we separately upper bound $f(\sin(\beta), \cos(\beta))$ and lower bound $f(\sin(\beta/2)\cos(\beta/2), -\sin(\beta/2)\sin(\beta/2))$. We have

$$\max_{\beta\in[-\pi/3,\pi/3]} f(\sin(\beta), \cos(\beta)) \leq f(0, 1) = -\frac{2}{3}$$

and

$$\min_{\beta\in[-\pi/3,\pi/3]} f(\sin(\beta/2)\cos(\beta/2), -\sin(\beta/2)\sin(\beta/2)) \geq f(\sin(\pi/6)\cos(\pi/6), -\sin(\pi/6)\sin(\pi/6)) = \left(1 + \frac{1}{2}\cdot\frac{1}{3}\right)\left(\left(\frac{1}{2}\right)^4 - 2\left(\frac{1}{2}\right)^2\right) = -\frac{49}{96}.$$

Therefore, we have

$$f(\sin(\beta/2)\cos(\beta/2), -\sin(\beta/2)\sin(\beta/2)) - f(\sin(\beta), \cos(\beta)) \geq -\frac{49}{96} + \frac{2}{3} = \frac{5}{32}.$$
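Propositions 4 and 5 can be made concrete with a small numerical experiment. The sketch below (a toy illustration with numerical gradients, not one of the paper's experiments) runs gradient descent on $f$ from an initialization covered by Proposition 5, confirms convergence to $(0, -1)$, and then evaluates the loss along the straight line to the minimizer, where the midpoint has a higher loss than the starting point.

```python
import numpy as np

def f(x, y):
    """Loss from Eqn. 3: easy to optimize but with non-monotonic interpolation."""
    r2 = x * x + y * y
    if r2 == 0.0:
        return 0.0
    return (1.0 - y / (3.0 * np.sqrt(r2))) * (r2 * r2 - 2.0 * r2)

def grad(x, y, eps=1e-6):
    # numerical gradient via central differences
    gx = (f(x + eps, y) - f(x - eps, y)) / (2 * eps)
    gy = (f(x, y + eps) - f(x, y - eps)) / (2 * eps)
    return gx, gy

beta = np.pi / 3
x0, y0 = np.sin(beta), np.cos(beta)   # (r sin(beta), r cos(beta)) with r = 1
x, y = x0, y0
for _ in range(5000):
    gx, gy = grad(x, y)
    x, y = x - 0.01 * gx, y - 0.01 * gy
print(x, y)                            # near the global minimizer (0, -1)

# loss along the straight line from the initialization to (0, -1)
alphas = np.linspace(0.0, 1.0, 101)
losses = [f((1 - a) * x0, (1 - a) * y0 - a) for a in alphas]
print(losses[50] - losses[0])          # >= 5/32: the curve is non-monotonic
```

Here the interpolation first rises from $-5/6$ at $\alpha = 0$ to $-49/96$ at $\alpha = 1/2$, then falls to the global minimum $-4/3$ at $\alpha = 1$, matching the bound in Proposition 5.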

B PROOF FOR PLATEAU AND MONOTONICITY

We first consider the r-homogeneous-weight model. We prove the plateau and monotonicity properties for the error interpolation (Theorem 4) in Section B.1. We then prove the plateau and monotonicity properties for the loss interpolation (Theorem 5) in Section B.2. Theorem 3 is a simple combination of Theorem 4 and Theorem 5. Finally, we give the plateau analysis for the fully-connected neural networks (Theorem 1) in Section B.3.

B.1 ERROR INTERPOLATION FOR r-HOMOGENEOUS-WEIGHT MODEL

Theorem 4 (Error Interpolation). Suppose the network at initialization and after training satisfies the properties described in Theorem 2 and Induction Hypothesis 1. Suppose $\delta \leq \min(O(1),\, O(R_{\min}^{\frac{1}{r-1}}\Delta_{\min}^{1/r}),\, O((\frac{W_{\min}}{W_{\max}})^{\frac{2r}{r-2}}))$. There exist $\alpha_1 = \frac{\delta}{\Delta_{\min}}$ and $\alpha_2 = (\frac{1}{1+O(\sqrt{\delta})})^{\frac{r}{r-1}} R_{\min}^{\frac{1}{r-1}}$ such that

1. for all $\alpha \in [\alpha_1, \alpha_2]$, the error is $1 - 1/k$;
2. for all $\alpha \in [\alpha_1, 1]$, the error is non-increasing.

Proof of Theorem 4. This theorem directly follows from Lemma 4 and Lemma 5. □

Next, we separately prove the initial plateau in Lemma 4 and the monotonicity in Lemma 5.

Lemma 4 (Error Plateau). In the same setting as in Theorem 4, there exist $\alpha_1 = \frac{\delta}{\Delta_{\min}}$ and $\alpha_2 = (\frac{1}{1+O(\sqrt{\delta})})^{\frac{r}{r-1}} R_{\min}^{\frac{1}{r-1}}$ such that for any interpolation point with $\alpha \in [\alpha_1, \alpha_2]$, the error is $1 - 1/k$. Moreover, we have $f_i^{[\alpha]}(e_j) < f_k^{[\alpha]}(e_j)$ for all $j \in [k]$ and all $i \neq k$.

In the proof of Lemma 4, we show that for interpolation points $\alpha \in [\alpha_1, \alpha_2]$, the bias term dominates, and all samples are classified as class $k$, the class with the largest bias.

Proof of Lemma 4. We only need to show that for all $\alpha \in [\alpha_1, \alpha_2]$, we have $f_i^{[\alpha]}(x) < f_k^{[\alpha]}(x)$ for all $x \in S$ and all $i \neq k$, which immediately implies that the error is $1 - 1/k$. Without loss of generality, assume $x \in S_j$, where $j$ may equal $i$ or $k$.

Case $\alpha \in [\alpha_1, \frac{\sqrt{\delta}}{W_{\min}}]$. If $\alpha_1 = \frac{\delta}{\Delta_{\min}} \geq \frac{\sqrt{\delta}}{W_{\min}}$, we only need to consider the case $\alpha \in [\frac{\sqrt{\delta}}{W_{\min}}, \alpha_2]$, so here we assume $\frac{\delta}{\Delta_{\min}} < \frac{\sqrt{\delta}}{W_{\min}}$. We can lower bound $f_k^{[\alpha]}(x) - f_i^{[\alpha]}(x)$ as

$$f_k^{[\alpha]}(x) - f_i^{[\alpha]}(x) = \langle W_{k,:}^{[\alpha]}, x\rangle^r + b_k^{[\alpha]} - \langle W_{i,:}^{[\alpha]}, x\rangle^r - b_i^{[\alpha]} = \langle W_{k,:}^{[\alpha]}, x\rangle^r + b_k^{[\alpha]} - (W_{i,j}^{[\alpha]} \pm O(\delta))^r - b_i^{[\alpha]} \geq \alpha\Delta_i - \left(W_{i,j}^{(0)} + \alpha W_{i,j}^{(T)} + O(\delta)\right)^r,$$

where the second equality uses $|\langle W_{i,:}^{[\alpha]}, \xi_x\rangle| \leq O(\delta)$ and the inequality uses $\langle W_{k,:}^{[\alpha]}, x\rangle \geq 0$. To prove $f_k^{[\alpha]}(x) - f_i^{[\alpha]}(x) > 0$ for $\alpha \in [\frac{\delta}{\Delta_{\min}}, \frac{\sqrt{\delta}}{W_{\min}}]$, we only need to prove

$$\frac{\delta}{\Delta_{\min}}\Delta_i - \left(W_{i,j}^{(0)} + \frac{\sqrt{\delta}}{W_{\min}} W_{i,j}^{(T)} + O(\delta)\right)^r > 0.$$

Since $\Delta_i \geq \Delta_{\min}$, we know $\frac{\delta}{\Delta_{\min}}\Delta_i \geq \delta$.
Due to the full accuracy of the trained network, we know $\langle W_{i,:}^{(T)}, x\rangle \geq \Delta_i^{1/r}$ for $x \in S_i$, which implies $W_{i,i}^{(T)} \geq \Omega(\Delta_i^{1/r})$ because $\Delta_i \geq \Omega(1)$ and $|\langle W_{i,:}^{(T)}, \xi_x\rangle| \leq O(\delta) \leq O(1)$. We then have

$$\left(W_{i,j}^{(0)} + \frac{\sqrt{\delta}}{W_{\min}} W_{i,j}^{(T)} + O(\delta)\right)^r \leq \left(O(\delta) + \frac{\sqrt{\delta}W_{\max}}{W_{\min}}\right)^r \leq \left(\frac{\sqrt{\delta}W_{\max}}{rW_{\min}} + \frac{\sqrt{\delta}W_{\max}}{W_{\min}}\right)^r \leq e\left(\frac{\sqrt{\delta}W_{\max}}{W_{\min}}\right)^r,$$

where the second inequality assumes $\delta$ is small enough that $O(\delta) \leq \frac{\sqrt{\delta}W_{\max}}{rW_{\min}}$. The right-hand side is less than $\delta$ as long as $\delta \leq O((\frac{W_{\min}}{W_{\max}})^{\frac{2r}{r-2}})$, which completes this case.

Case $\alpha \in [\frac{\sqrt{\delta}}{W_{\min}}, \alpha_2]$. Similarly as above, we only need to show that $\alpha\Delta_i - (W_{i,j}^{(0)} + \alpha W_{i,j}^{(T)} + O(\delta))^r > 0$ for $i \neq k$ and $j \in [k]$. Since $W_{i,j}^{(0)} \leq O(\delta)$ and $\alpha \geq \sqrt{\delta}/W_{\min}$, we have $W_{i,j}^{(0)} \leq O(\sqrt{\delta}\alpha W_{\min})$. Therefore,

$$W_{i,j}^{(0)} + \alpha W_{i,j}^{(T)} + O(\delta) \leq (1 + O(\sqrt{\delta}))\,\alpha W_{i,i}^{(T)}.$$

Therefore, we have

$$\alpha\Delta_i - \left(W_{i,j}^{(0)} + \alpha W_{i,j}^{(T)} + O(\delta)\right)^r \geq \alpha\Delta_i - (1 + O(\sqrt{\delta}))^r\,\alpha^r\,(W_{i,i}^{(T)})^r > 0,$$

where the last inequality assumes $\alpha \leq \alpha_2 := (\frac{1}{1+O(\sqrt{\delta})})^{\frac{r}{r-1}} R_{\min}^{\frac{1}{r-1}}$ with $R_{\min} = \min_{i\in[k-1]} \Delta_i/[W_{i,i}^{(T)}]^r$. □

Next, we show that the error is non-increasing for $\alpha \in [\alpha_1, 1]$ by proving that once a sample is classified correctly, it remains so.

Lemma 5 (Error Monotonicity). In the same setting as in Theorem 4, there exists $\alpha_1 = \frac{\delta}{\Delta_{\min}}$ such that the error is non-increasing for $\alpha \in [\alpha_1, 1]$.

Proof of Lemma 5. We first show that samples from class $k$ are correctly classified on the whole range $[\alpha_1, 1]$. Second, we show that any other sample, once correctly classified, remains so. Combining these two cases proves the monotonicity of the error.

Class $k$. We first show that every $x \in S_k$ is classified correctly for any $\alpha \in [\alpha_1, 1]$. According to Lemma 4, we know $f_k^{[\alpha_1]}(x) > f_i^{[\alpha_1]}(x)$ for any $i \neq k$, so we only need to prove that $f_k^{[\alpha]}(x) - f_i^{[\alpha]}(x)$ is increasing for $\alpha \in [\alpha_1, 1]$. Expanding,

$$f_k^{[\alpha]}(x) - f_i^{[\alpha]}(x) = \left((1-\alpha)\langle W_{k,:}^{(0)}, x\rangle + \alpha\langle W_{k,:}^{(T)}, x\rangle\right)^r - \left((1-\alpha)\langle W_{i,:}^{(0)}, x\rangle + \alpha\langle W_{i,:}^{(T)}, x\rangle\right)^r + \alpha\left(b_k^{(T)} - b_i^{(T)}\right),$$

which is increasing in $\alpha$: the first term increases since $\langle W_{k,:}^{(T)}, x\rangle - \langle W_{k,:}^{(0)}, x\rangle \geq \Omega(1)$, the last term increases since $b_k^{(T)} \geq b_i^{(T)}$, and the middle term contributes negligibly since $\langle W_{i,:}^{(0)}, x\rangle, \langle W_{i,:}^{(T)}, x\rangle \leq O(\delta)$ for $x \in S_k$.

Other classes.
For any class $i \neq k$, from Lemma 4 we know its samples are classified incorrectly for $\alpha \in [\alpha_1, \alpha_2]$. We prove that once a sample becomes correctly classified at some $\alpha' \in (\alpha_2, 1]$, it remains so for $\alpha \in [\alpha', 1]$. Concretely, we show that at any $\alpha$, for any $x \in S_i$, if $f_i^{[\alpha]}(x) > f_j^{[\alpha]}(x)$ for all $j \neq i$, then $\frac{\partial}{\partial\alpha}(f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)) > 0$. Expanding $f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)$, we have

$$f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x) = \langle W_{i,:}^{[\alpha]}, x\rangle^r + b_i^{[\alpha]} - \langle W_{j,:}^{[\alpha]}, x\rangle^r - b_j^{[\alpha]} = \langle W_{i,:}^{[\alpha]}, x\rangle^r - \langle W_{j,:}^{[\alpha]}, x\rangle^r - \alpha\left(b_j^{(T)} - b_i^{(T)}\right).$$

Since $f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x) > 0$, we have $\langle W_{i,:}^{[\alpha]}, x\rangle^r > \alpha(b_j^{(T)} - b_i^{(T)})$, where we use $\langle W_{j,:}^{[\alpha]}, x\rangle \geq 0$. Computing $\frac{\partial}{\partial\alpha}(f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x))$, we have

$$\frac{\partial}{\partial\alpha}\left(f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)\right) = \frac{\partial}{\partial\alpha}\left[\left((1-\alpha)\langle W_{i,:}^{(0)}, x\rangle + \alpha\langle W_{i,:}^{(T)}, x\rangle\right)^r - \left((1-\alpha)\langle W_{j,:}^{(0)}, x\rangle + \alpha\langle W_{j,:}^{(T)}, x\rangle\right)^r + \alpha\left(b_i^{(T)} - b_j^{(T)}\right)\right]$$
$$\geq r\left(\langle W_{i,:}^{(0)}, x\rangle + \alpha\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)\right)^{r-1}\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right) - \left(b_j^{(T)} - b_i^{(T)}\right) - O(\delta^r),$$

where the inequality uses $\langle W_{j,:}^{(0)}, x\rangle, \langle W_{j,:}^{(T)}, x\rangle \leq O(\delta)$.

If $b_j^{(T)} - b_i^{(T)} \leq 0$, we only need to prove

$$r\left(\langle W_{i,:}^{(0)}, x\rangle + \alpha\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)\right)^{r-1}\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right) - O(\delta^r) > 0,$$

which holds since $\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle \geq \Omega(1)$ and $\langle W_{i,:}^{[\alpha]}, x\rangle \geq \Omega(1)$.

If $b_j^{(T)} - b_i^{(T)} > 0$, we have

$$\frac{\partial}{\partial\alpha}\left(f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)\right) > (1 - O(\delta^r))\,\frac{r\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)}{\langle W_{i,:}^{(0)}, x\rangle + \alpha\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)}\cdot\alpha\left(b_j^{(T)} - b_i^{(T)}\right) - \left(b_j^{(T)} - b_i^{(T)}\right),$$

where the last inequality uses $\langle W_{i,:}^{[\alpha]}, x\rangle^r > \alpha(b_j^{(T)} - b_i^{(T)})$. Therefore, to prove $\frac{\partial}{\partial\alpha}(f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)) > 0$, we only need to prove

$$(1 - O(\delta^r))\,\frac{r\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)}{\langle W_{i,:}^{(0)}, x\rangle + \alpha\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)} \geq \frac{1}{\alpha}.$$

We have

$$(1 - O(\delta^r))\,\frac{r\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)}{\langle W_{i,:}^{(0)}, x\rangle + \alpha\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)} \geq (1 - O(\delta^r))\,\frac{r\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)}{2\alpha\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)} \geq \frac{1}{\alpha}.$$
The first inequality requires $\langle W_{i,:}^{(0)}, x\rangle \leq \frac{\alpha}{1+\alpha}\langle W_{i,:}^{(T)}, x\rangle$. Since $\alpha \geq \alpha_2 = (\frac{1}{1+O(\sqrt{\delta})})^{\frac{r}{r-1}} R_{\min}^{\frac{1}{r-1}}$, we can lower bound $\frac{\alpha}{1+\alpha}$ as follows,

$$\frac{\alpha}{1+\alpha} \geq \frac{1}{2}\left(\frac{1}{1+O(\sqrt{\delta})}\right)^{\frac{r}{r-1}} R_{\min}^{\frac{1}{r-1}} \geq \frac{1}{8} R_{\min}^{\frac{1}{r-1}},$$

where the first inequality uses $\alpha \leq 1$ and the second uses $1 + O(\sqrt{\delta}) \leq 2$ and $r \geq 2$. So we have $\frac{\alpha}{1+\alpha}\langle W_{i,:}^{(T)}, x\rangle \geq \frac{1}{8} R_{\min}^{\frac{1}{r-1}}\Delta_{\min}^{1/r}$. Therefore, we only need $\delta \leq O(R_{\min}^{\frac{1}{r-1}}\Delta_{\min}^{1/r})$ to ensure $\langle W_{i,:}^{(0)}, x\rangle \leq \frac{\alpha}{1+\alpha}\langle W_{i,:}^{(T)}, x\rangle$. □

B.2 LOSS INTERPOLATION FOR r-HOMOGENEOUS-WEIGHT MODEL

In this section, we give a proof of Theorem 5.

Theorem 5 (Loss Interpolation). Suppose the network at initialization and after training satisfies the properties described in Theorem 2 and Induction Hypothesis 1. For any $\epsilon \in (0, 1)$, suppose $\delta \leq O(\epsilon^{1/r})$. There exist $\alpha_3 = \frac{\epsilon^{1/r}}{W_{\max}}$ and $\alpha_4 = (1 + O(\delta))^{\frac{1}{r-1}}(\frac{R_{\max}}{r})^{\frac{1}{r-1}}$ such that

1. for all $\alpha \in [0, \alpha_3]$, we have $\log k - e\epsilon \leq \frac{1}{N}L(W^{[\alpha]}, b^{[\alpha]}) \leq \log k + \alpha\Delta_{\max} + e\epsilon$;
2. for all $\alpha \in [\alpha_4, 1]$, the loss is monotonically decreasing.

Proof of Theorem 5. This theorem directly follows from Lemma 6 and Lemma 7. □

Next, we prove the initial loss plateau in Lemma 6 and the monotonicity in Lemma 7.

Lemma 6 (Loss Plateau). In the same setting as in Theorem 5, for any $\epsilon > 0$, there exists $\alpha_3 = \frac{\epsilon^{1/r}}{W_{\max}}$ such that for all $\alpha \in [0, \alpha_3]$,

$$N(\log k - e\epsilon) \leq L(W^{[\alpha]}, b^{[\alpha]}) \leq N(\log k + \alpha\Delta_{\max} + e\epsilon).$$

We show that for $\alpha \in [0, \alpha_3]$, the contribution of the weights $W^{[\alpha]}$ is negligible and the bias dominates, which gives both a lower bound and an upper bound on the loss.

Proof of Lemma 6. Since $\alpha \leq \alpha_3 = \frac{\epsilon^{1/r}}{W_{\max}}$ and $\delta \leq O(\epsilon^{1/r})$, we have

$$\langle W_{i,:}^{[\alpha]}, x\rangle^r = \left(\langle W_{i,:}^{(0)}, x\rangle + \alpha\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right)\right)^r \leq \left(\left(1 + \frac{1}{r}\right)\epsilon^{1/r}\right)^r \leq e\epsilon$$

for all $i \in [k]$ and $x \in S$. We can divide the dataset $S$ into $N/k$ disjoint subsets $\{P_l\}_{l=1}^{N/k}$, where each $P_l$ contains exactly one sample from each class. Next, we bound the total loss of each subset $P_l$.
Without loss of generality, consider the subset $P_1$, and let $x^{(i)}$ denote its sample from class $i$. For convenience, denote the total loss of the samples in $P_1$ by $L_1(W^{[\alpha]}, b^{[\alpha]})$.

Lower bounding $L_1(W^{[\alpha]}, b^{[\alpha]})$. We have

$$L_1(W^{[\alpha]}, b^{[\alpha]}) = \sum_{i\in[k]}\log\frac{\sum_{j\in[k]}\exp(f_j^{[\alpha]}(x^{(i)}))}{\exp(f_i^{[\alpha]}(x^{(i)}))} = \log\prod_{i\in[k]}\frac{\sum_{j\in[k]}\exp(f_j^{[\alpha]}(x^{(i)}))}{\exp(f_i^{[\alpha]}(x^{(i)}))} \geq \log\left(\frac{k}{\sum_{i\in[k]}\frac{\exp(f_i^{[\alpha]}(x^{(i)}))}{\sum_{j\in[k]}\exp(f_j^{[\alpha]}(x^{(i)}))}}\right)^{k},$$

where the last inequality uses the HM-GM inequality. We can then upper bound $\sum_{i\in[k]}\frac{\exp(f_i^{[\alpha]}(x^{(i)}))}{\sum_{j\in[k]}\exp(f_j^{[\alpha]}(x^{(i)}))}$ as follows,

$$\sum_{i\in[k]}\frac{\exp(f_i^{[\alpha]}(x^{(i)}))}{\sum_{j\in[k]}\exp(f_j^{[\alpha]}(x^{(i)}))} = \sum_{i\in[k]}\frac{\exp\left(\langle W_{i,:}^{[\alpha]}, x^{(i)}\rangle^r + \alpha b_i\right)}{\sum_{j\in[k]}\exp\left(\langle W_{j,:}^{[\alpha]}, x^{(i)}\rangle^r + \alpha b_j\right)} \leq \sum_{i\in[k]}\frac{\exp(\alpha b_i + e\epsilon)}{\sum_{j\in[k]}\exp(\alpha b_j)} = \exp(e\epsilon)\,\frac{\sum_{i\in[k]}\exp(\alpha b_i)}{\sum_{j\in[k]}\exp(\alpha b_j)} = \exp(e\epsilon).$$

Plugging this back into the lower bound of $L_1(W^{[\alpha]}, b^{[\alpha]})$, we have

$$L_1(W^{[\alpha]}, b^{[\alpha]}) \geq k\log\frac{k}{\exp(e\epsilon)} = k(\log k - e\epsilon).$$

Upper bounding $L_1(W^{[\alpha]}, b^{[\alpha]})$. We have

$$L_1(W^{[\alpha]}, b^{[\alpha]}) = \sum_{i\in[k]}\log\frac{\sum_{j\in[k]}\exp(f_j^{[\alpha]}(x^{(i)}))}{\exp(f_i^{[\alpha]}(x^{(i)}))} \leq \sum_{i\in[k]}\log\frac{\sum_{j\in[k]}\exp(\alpha b_j + e\epsilon)}{\exp(\alpha b_i)} \leq k\log\left(k\exp(\alpha\Delta_{\max} + e\epsilon)\right) = k(\log k + \alpha\Delta_{\max} + e\epsilon).$$

The same analysis applies to every subset $P_l$, so we have

$$N(\log k - e\epsilon) \leq L(W^{[\alpha]}, b^{[\alpha]}) \leq N(\log k + \alpha\Delta_{\max} + e\epsilon).$$

□
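The plateau behavior proved above can be visualized on a toy instantiation of the $r$-homogeneous-weight model $f_i(x) = \langle W_{i,:}, x\rangle^r + b_i$. The sketch below uses hand-picked "trained" parameters with distinct last-layer biases (an illustrative assumption rather than the result of actual training): early in the interpolation the bias term dominates, giving error $1 - 1/k$ and loss near $\log k$, and the error is non-increasing thereafter.

```python
import numpy as np

rng = np.random.default_rng(1)
k, r, delta = 5, 3, 1e-3
W0 = np.abs(rng.standard_normal((k, k))) * delta   # small positive init
b0 = np.zeros(k)
WT = np.eye(k) * 1.5                               # "trained" weights: W_{i,i} = Theta(1)
bT = np.linspace(0.0, 1.0, k)                      # distinct biases; class k largest
X = np.eye(k)                                      # one sample per class: x = e_j
labels = np.arange(k)

def curves(alpha):
    W = (1 - alpha) * W0 + alpha * WT
    b = (1 - alpha) * b0 + alpha * bT
    logits = (X @ W.T) ** r + b                    # logits[j, i] = <W_i, e_j>^r + b_i
    err = np.mean(np.argmax(logits, axis=1) != labels)
    z = logits - logits.max(axis=1, keepdims=True)
    loss = np.mean(np.log(np.exp(z).sum(axis=1)) - z[np.arange(k), labels])
    return err, loss

alphas = np.linspace(0.0, 1.0, 201)
errs, losses = zip(*(curves(a) for a in alphas))
print(errs[10], losses[10])   # early: error = 1 - 1/k, loss close to log k
print(errs[-1], losses[-1])   # at alpha = 1: zero error
```

The plateau ends roughly where $\alpha^r (W_{i,i}^{(T)})^r$ overtakes the bias gap $\alpha\Delta_i$, matching the $\alpha_2 \sim R_{\min}^{1/(r-1)}$ scaling in Lemma 4.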

Next, we show that when $\alpha$ is reasonably large, $f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)$ is increasing for all $j \neq i$, which implies that the loss is decreasing.

Lemma 7 (Loss Monotonicity). In the same setting as in Theorem 5, there exists $\alpha_4 = (1 + O(\delta))^{\frac{1}{r-1}}(\frac{R_{\max}}{r})^{\frac{1}{r-1}}$ such that the loss is monotonically decreasing for $\alpha \in [\alpha_4, 1]$.

Proof of Lemma 7. To prove that the loss is monotonically decreasing, we only need to show that for any $i \in [k]$ and any $x \in S_i$, $f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)$ is monotonically increasing for all $j \neq i$. As in Lemma 5, it is easy to prove that for $x \in S_k$, $f_k^{[\alpha]}(x) - f_j^{[\alpha]}(x)$ with $j \neq k$ monotonically increases for $\alpha \in [0, 1]$, so we focus on the other classes. For $i \neq k$, we show that $\frac{\partial}{\partial\alpha}(f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)) > 0$ for $x \in S_i$ when $\alpha \geq \alpha_4$:

$$\frac{\partial}{\partial\alpha}\left(f_i^{[\alpha]}(x) - f_j^{[\alpha]}(x)\right) = \frac{\partial}{\partial\alpha}\left[\left((1-\alpha)\langle W_{i,:}^{(0)}, x\rangle + \alpha\langle W_{i,:}^{(T)}, x\rangle\right)^r - \left((1-\alpha)\langle W_{j,:}^{(0)}, x\rangle + \alpha\langle W_{j,:}^{(T)}, x\rangle\right)^r + \alpha\left(b_i^{(T)} - b_j^{(T)}\right)\right]$$
$$\geq r\left((1-\alpha)\langle W_{i,:}^{(0)}, x\rangle + \alpha\langle W_{i,:}^{(T)}, x\rangle\right)^{r-1}\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right) - \left(b_k^{(T)} - b_i^{(T)}\right) - O(\delta^r)$$
$$\geq r\alpha^{r-1}\langle W_{i,:}^{(T)}, x\rangle^{r-1}\left(\langle W_{i,:}^{(T)}, x\rangle - \langle W_{i,:}^{(0)}, x\rangle\right) - \left(b_k^{(T)} - b_i^{(T)}\right) - O(\delta^r)$$
$$\geq r\alpha^{r-1}\left(1 - O\left(\frac{\delta}{\Delta_{\min}^{1/r}}\right)\right)\langle W_{i,:}^{(T)}, x\rangle^r - \left(b_k^{(T)} - b_i^{(T)}\right)(1 + O(\delta^r)) > 0,$$

where the second-to-last inequality uses $\langle W_{i,:}^{(0)}, x\rangle/\langle W_{i,:}^{(T)}, x\rangle \leq O(\delta/\Delta_{\min}^{1/r})$. The last inequality requires

$$r\alpha^{r-1} \geq (1 + O(\delta))\,\frac{b_k^{(T)} - b_i^{(T)}}{(W_{i,i}^{(T)})^r},$$

which is satisfied as long as $\alpha \geq (1 + O(\delta))^{\frac{1}{r-1}}(\frac{R_{\max}}{r})^{\frac{1}{r-1}}$, where $R_{\max} = \max_{i\in[k-1]} \Delta_i/[W_{i,i}^{(T)}]^r$. □

B.3 PLATEAU FOR DEEP FULLY-CONNECTED NETWORKS

In this section, we consider fully-connected neural networks as defined in Section 3 and prove that both the error and loss curves have a plateau. We restate Theorem 1 as follows.

Theorem 1. Suppose the network is defined as in Equation (1) and suppose the weights satisfy $\|V_i^{(0)}\| \leq \delta$, $\|V_i^{(T)}\| \leq V_{\max}$ for all layers $i \in [r]$.
On a $k$-class balanced dataset whose inputs have $\ell_2$ norm at most 1, if Assumption 1 holds, then for any $\epsilon > 0$, as long as $\delta < \min(\frac{\epsilon^{1/r}}{r}, \frac{1}{r^2}, (\frac{1}{2e})^{\frac{2}{r-2}})$, there exist $\alpha_1 = \frac{\delta}{\Delta_{\min}}$, $\alpha_2 = (\frac{1}{1+\sqrt{\delta}})^{\frac{r}{r-1}}(\frac{\Delta_{\min}}{2V_{\max}^r})^{\frac{1}{r-1}}$ and $\alpha_3 = \frac{\epsilon^{1/r}}{V_{\max}}$ such that

1. for all $\alpha \in [\alpha_1, \alpha_2]$, the error is $1 - 1/k$;
2. for all $\alpha \in [0, \alpha_3]$, we have $\log k - 2e\epsilon \leq \frac{1}{N}L(\{V_i^{[\alpha]}\}, b^{[\alpha]}) \leq \log k + \alpha\Delta_{\max} + 2e\epsilon$, where $N$ is the number of training examples.

Proof of Theorem 1. This theorem directly follows from Lemma 8 and Lemma 9. □

We separately prove the plateau of the error interpolation in Lemma 8 and the plateau of the loss interpolation in Lemma 9. For convenience, we denote $h(x) := V_r\sigma(V_{r-1}\cdots\sigma(V_1 x)\cdots)$ in the proofs.

Lemma 8. In the setting of Theorem 1, there exist $\alpha_1 = \frac{\delta}{\Delta_{\min}}$ and $\alpha_2 = (\frac{1}{1+\sqrt{\delta}})^{\frac{r}{r-1}}(\frac{\Delta_{\min}}{2V_{\max}^r})^{\frac{1}{r-1}}$ such that the error is $1 - 1/k$ for any interpolation point $\alpha \in [\alpha_1, \alpha_2]$.

Proof of Lemma 8. Recall that the network output under input $x$ is $g(x) := V_r\sigma(V_{r-1}\cdots\sigma(V_1 x)\cdots) + b$. As in the proof of Lemma 4, we only need to show that for all $\alpha \in [\alpha_1, \alpha_2]$, we have $g_i^{[\alpha]}(x) < g_k^{[\alpha]}(x)$ for all $i \neq k$ and all samples $x$, which immediately implies that the error is $1 - 1/k$.

Case $\alpha \in [\alpha_1, \frac{\sqrt{\delta}}{V_{\max}}]$. If $\alpha_1 = \frac{\delta}{\Delta_{\min}} \geq \frac{\sqrt{\delta}}{V_{\max}}$, we only need to consider the case $\alpha \in [\frac{\sqrt{\delta}}{V_{\max}}, \alpha_2]$, so here we assume $\frac{\delta}{\Delta_{\min}} < \frac{\sqrt{\delta}}{V_{\max}}$. We can lower bound $g_k^{[\alpha]}(x) - g_i^{[\alpha]}(x)$ as

$$g_k^{[\alpha]}(x) - g_i^{[\alpha]}(x) = h_k^{[\alpha]}(x) + b_k^{[\alpha]} - h_i^{[\alpha]}(x) - b_i^{[\alpha]} \geq \alpha\Delta_{\min} - 2\prod_{j\in[r]}\left\|(1-\alpha)V_j^{(0)} + \alpha V_j^{(T)}\right\| \geq \alpha\Delta_{\min} - 2(\delta + \alpha V_{\max})^r,$$

where the first inequality holds because $b_k^{[\alpha]} - b_i^{[\alpha]} \geq \alpha\Delta_{\min}$ and $|h_k^{[\alpha]}(x)|, |h_i^{[\alpha]}(x)| \leq \prod_{j\in[r]}\|(1-\alpha)V_j^{(0)} + \alpha V_j^{(T)}\|$. The second inequality uses $\|(1-\alpha)V_j^{(0)} + \alpha V_j^{(T)}\| \leq (1-\alpha)\|V_j^{(0)}\| + \alpha\|V_j^{(T)}\| \leq \delta + \alpha V_{\max}$. Since $\alpha \in [\frac{\delta}{\Delta_{\min}}, \frac{\sqrt{\delta}}{V_{\max}}]$, we have

$$g_k^{[\alpha]}(x) - g_i^{[\alpha]}(x) \geq \frac{\delta}{\Delta_{\min}}\Delta_{\min} - 2\left(\delta + \frac{\sqrt{\delta}}{V_{\max}}V_{\max}\right)^r \geq \delta - 2\left(\left(1 + \frac{1}{r}\right)\sqrt{\delta}\right)^r \geq \delta - 2e\delta^{r/2} > 0,$$

where the second inequality assumes $\delta \leq 1/r^2$ and the last inequality assumes $\delta < (\frac{1}{2e})^{\frac{2}{r-2}}$.

Case $\alpha \in [\frac{\sqrt{\delta}}{V_{\max}}, \alpha_2]$. Similarly as above, we only need to show that $\alpha\Delta_{\min} - 2(\delta + \alpha V_{\max})^r > 0$. Since $\alpha \geq \frac{\sqrt{\delta}}{V_{\max}}$, we have $\delta \leq \sqrt{\delta}\alpha V_{\max}$. Therefore,

$$\alpha\Delta_{\min} - 2(\delta + \alpha V_{\max})^r \geq \alpha\Delta_{\min} - 2\left((1 + \sqrt{\delta})\alpha V_{\max}\right)^r > 0,$$

where the second inequality holds as long as $\alpha \leq \alpha_2 := (\frac{1}{1+\sqrt{\delta}})^{\frac{r}{r-1}}(\frac{\Delta_{\min}}{2V_{\max}^r})^{\frac{1}{r-1}}$. □

Next, we show that for $\alpha \in [0, \frac{\epsilon^{1/r}}{V_{\max}}]$, the loss cannot change by much. As in Lemma 6, we prove that the signal is very small and the logits are dominated by the bias term, which gives lower and upper bounds on the loss.

Lemma 9. In the setting of Theorem 1, there exists $\alpha_3 = \frac{\epsilon^{1/r}}{V_{\max}}$ such that for all $\alpha \in [0, \alpha_3]$,

$$\log k - 2e\epsilon \leq \frac{1}{N}L(\{V_i^{[\alpha]}\}, b^{[\alpha]}) \leq \log k + \alpha\Delta_{\max} + 2e\epsilon,$$

where $N$ is the number of samples.

Proof of Lemma 9. Since $\alpha \leq \alpha_3 = \frac{\epsilon^{1/r}}{V_{\max}}$ and $\delta \leq \frac{\epsilon^{1/r}}{r}$, we have $\|h^{[\alpha]}(x)\| \leq (\delta + \alpha V_{\max})^r \leq e\epsilon$ for all inputs $x$. Then, as in the proof of Lemma 6, we can show

$$\log k - 2e\epsilon \leq \frac{1}{N}L(\{V_i^{[\alpha]}\}, b^{[\alpha]}) \leq \log k + \alpha\Delta_{\max} + 2e\epsilon,$$

where we pick up an additional factor of 2 in front of $e\epsilon$ because the signal can now be positive or negative. Here $N$ is the number of samples. □
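The key inequality behind Lemmas 8 and 9 — that the signal part $h^{[\alpha]}(x)$ is bounded by the product of interpolated layer norms, and hence by $(\delta + \alpha V_{\max})^r$ — can be checked directly. The sketch below uses random matrices rescaled to prescribed spectral norms (an illustrative assumption, not the paper's trained networks) and verifies the bound for a depth-$r$ ReLU network.

```python
import numpy as np

rng = np.random.default_rng(2)
r, d, delta, Vmax = 4, 20, 1e-2, 2.0

def layer(scale):
    """Random d x d matrix rescaled so its spectral norm equals `scale`."""
    V = rng.standard_normal((d, d))
    return V * (scale / np.linalg.norm(V, 2))

V0 = [layer(delta) for _ in range(r)]   # ||V_i^(0)|| <= delta
VT = [layer(Vmax) for _ in range(r)]    # ||V_i^(T)|| <= Vmax

def h(Vs, x):
    # depth-r network without biases: V_r sigma(V_{r-1} ... sigma(V_1 x))
    for V in Vs[:-1]:
        x = np.maximum(V @ x, 0.0)      # ReLU is 1-Lipschitz and fixes 0
    return Vs[-1] @ x

x = rng.standard_normal(d)
x /= np.linalg.norm(x)                  # input with l2 norm 1
ok = []
for alpha in np.linspace(0.0, 1.0, 11):
    Va = [(1 - alpha) * A + alpha * B for A, B in zip(V0, VT)]
    ok.append(np.linalg.norm(h(Va, x)) <= (delta + alpha * Vmax) ** r + 1e-9)
print(all(ok))  # prints True
```

The bound holds by the triangle inequality on each interpolated layer norm together with the 1-Lipschitzness of the ReLU, exactly as used in the proofs above.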

C PROOF OF TRAINING DYNAMICS

In this section, we give the complete proof of Theorem 2.

Theorem 2. Suppose the neural network, dataset, and optimization procedure are as defined in Section 4. Suppose the initialization scale $\delta \leq \Theta(1)$, noise level $\sigma \leq \Theta(1)$, dimension $d \geq \Theta(1/\delta^{2r-2})$, and number of samples $N \geq \Theta(1/\delta^{r-1})$. Then, with probability at least 0.99 over the initialization, there exists a time $T = \Theta(\log(1/\delta)/\delta^{r-2})$ such that

1. zero error: for all distinct $i, j \in [k]$ and all $x \in S_i$, $f_i^{(T)}(x) \geq f_j^{(T)}(x) + \Omega(1)$;
2. bias gap: $b_{i^*}^{(T)} - \max_{i\neq i^*} b_i^{(T)} \geq \Omega(1)$ with $i^* = \arg\max_{i\in[k]} b_i^{(T)}$.

Proof of Theorem 2. This theorem directly follows from Proposition 1. □

We consider the $r$-homogeneous-weight network as defined in Section 4. Our simple model simulates a depth-$r$ ReLU/linear network with bias on the output layer, in the sense that the weight signal is $r$-homogeneous while the bias is 1-homogeneous in the parameters. Next, we prove Proposition 1, deferring the proofs of the supporting lemmas to Section C.1. Throughout the proof of Proposition 1, we restate each lemma where it is used, for the reader's convenience.

Proposition 1 (Induction Hypothesis). In the same setting as Theorem 2, with probability at least 0.99 over the initialization, there exist time points $0 =: s_1 < t_1 < s_2 < t_2 < \cdots < s_{k-1} < t_{k-1} < s_k := T$ with $t_i - s_i = \Theta(\log(1/\delta)/\delta^{r-2})$ and $s_{i+1} - t_i = \Theta(1)$ for $i \in [k-1]$, such that for any $t \in [s_i, s_{i+1}]$,

1. (classes not yet learned) for any classes $j, j' \geq i+1$, we have (1) $b_j^{(t)} \geq \max_{i'\in[k]} b_{i'}^{(t)} - O(\delta^r)$, (2) $|b_j^{(t)} - b_{j'}^{(t)}| \leq O(\delta^r)$ and (3) $W_{j,j}^{(t)} \leq O(\delta)$;

2. (classes already learned) for any class $j \leq i-1$, we have (1) $b_j^{(t)} \leq \max_{i'\in[k]} b_{i'}^{(t)} - \Omega(1)$, (2) $f_j^{(t)}(x) \geq f_{i'}^{(t)}(x) + \Omega(1)$ for $i' \neq j$, $x \in S_j$ and (3) $W_{j,j}^{(t)} \geq \Omega(1)$;

3. (parameters movement) (1) for any $j \in [k]$, $\Theta(\delta) = W_{j,j}^{(0)} < W_{j,j}^{(t)}$, (2) for any distinct $j, j' \in [k]$, $0 < W_{j,j'}^{(t)} \leq O(\delta)$ and (3) for any $j, j' \in [k]$ and any $x \in S_{j'}$, $|\langle W_{j,:}^{(t)}, \xi_x\rangle| \leq \min(O(\delta), W_{j,j'}^{(t)})$.

Proof of Proposition 1. Throughout the proof, we assume the conditions of Theorem 2 hold in all the lemmas without stating them explicitly. At initialization, we have the following properties with probability at least 0.99.

Lemma 10 (Initialization). With probability at least 0.99 over the initialization, we have

1. for all $j, j' \in [k]$, $W_{j,j'}^{(0)} = \Theta(\delta)$;
2. for all distinct $j, j' \in [k]$, $|W_{j,j}^{(0)} - W_{j',j'}^{(0)}| = \Theta(\delta)$;
3. for all $x \in S$, $\|\xi_x\| \leq O(\sigma)$;
4. for all distinct $x, x' \in S$, $|\langle\xi_x, \xi_{x'}\rangle| \leq O(\frac{\sqrt{\log N}}{\sqrt{d}})$;
5. for all $j \in [k]$ and all $x \in S$, $|\langle\xi_x, e_j\rangle|, |\langle\xi_x, W_{j,:}^{(0)}\rangle| \leq O(\frac{\sqrt{\log N}}{\sqrt{d}})$.

Without loss of generality, we assume $W_{1,1}^{(0)} > W_{2,2}^{(0)} > \cdots > W_{k,k}^{(0)}$. It is not hard to verify that the induction hypothesis holds at the initialization.⁴ For any $i \in [k-1]$, assuming the induction hypothesis holds on $[0, s_i]$, we now prove that it continues to hold on $[s_i, s_{i+1}]$. We first prove the first two properties of Proposition 1 and leave the last one to the end. The learning of $x \in S_i$ can be divided into four stages:

1. Stage 0, for $t \in [s_i, t_i]$ with $t_i - s_i = O(\log(1/\delta)/\delta^{r-2})$. During this stage, $W_{i,i}$ grows to a small constant $\mu_0$.

2. Stage 1, for $t \in [t_i, t_i^{(w)}]$ with $t_i^{(w)} - t_i = O(1)$. In this stage, $W_{i,i}^{(t)}$ grows from $\mu_0$ to a large constant $\mu_1$.

3. Stage 2, for $t \in [t_i^{(w)}, t_i^{(u)}]$ with $t_i^{(u)} - t_i^{(w)} = O(1)$. At the end of this stage, we have $\min_{x\in S_i} u_i^{(t)}(x) \geq 1 - \mu_2$ for a small constant $\mu_2$.

4. Stage 3, for $t \in [t_i^{(u)}, t_i^{(b)}]$ with $t_i^{(b)} - t_i^{(u)} = O(1)$, where $t_i^{(b)} = s_{i+1}$. During this stage, $b_i^{(t)} - b_k^{(t)}$ decreases to $-\mu_3$, with $\mu_3$ a positive constant.

Next, we consider these four stages one by one.

Stage 0. We show that $W_{i,i}^{(t)}$ increases faster than $W_{i+1,i+1}^{(t)}$, so that $W_{i,i}^{(t)}$ reaches a constant while $W_{i+1,i+1}^{(t)}$ is still $O(\delta)$. We use the following lemma to characterize the growth rates of $W_{i+1,i+1}^{(t)}$ and $W_{i,i}^{(t)}$.

⁴We will maintain a stronger bound on $\langle W_{j,:}^{(t)}, \xi_x\rangle$ by proving $|\langle W_{j,:}^{(t)}, \xi_x\rangle| \leq O(\sqrt{\log N}\,\sigma\delta)$, which implies $|\langle W_{j,:}^{(t)}, \xi_x\rangle| \leq O(\delta)$ as long as $\sigma \leq O(1/\sqrt{\log N})$.

Lemma 11. For any $j \in [k]$, we have

$$\frac{\partial W_{j,j}^{(t)}}{\partial t} \geq -O\left(\delta^{r-1}\frac{\sqrt{\log N}\,\sigma}{\sqrt{d}}\right).$$

If $\min_{x\in S_j}(1 - u_j(x)) \geq \Omega(1)$, we further have

$$\left(1 - O(\sqrt{\log N}\,\sigma)\right)\frac{k}{N}\sum_{x\in S_j}(1 - u_j(x))\, r W_{j,j}^{r-1} \leq \frac{\partial W_{j,j}^{(t)}}{\partial t} \leq \left(1 + O(\sqrt{\log N}\,\sigma)\right)\frac{k}{N}\sum_{x\in S_j}(1 - u_j(x))\, r W_{j,j}^{r-1}.$$

It is not hard to verify that $\min_{x\in S_i}(1 - u_i^{(t)}(x))$ and $\min_{x\in S_{i+1}}(1 - u_{i+1}^{(t)}(x))$ are both $\Omega(1)$, so we have

$$\frac{\partial W_{i,i}^{(t)}}{\partial t} \geq \left(1 - O(\sqrt{\log N}\,\sigma)\right)\frac{k}{N}\sum_{x\in S_i}\left(1 - u_i^{(t)}(x)\right) r\left(W_{i,i}^{(t)}\right)^{r-1},$$
$$\frac{\partial W_{i+1,i+1}^{(t)}}{\partial t} \leq \left(1 + O(\sqrt{\log N}\,\sigma)\right)\frac{k}{N}\sum_{x\in S_{i+1}}\left(1 - u_{i+1}^{(t)}(x)\right) r\left(W_{i+1,i+1}^{(t)}\right)^{r-1}.$$

We can upper bound $1 - u_{i+1}^{(t)}(x)$ for any $x \in S_{i+1}$ as follows,

$$1 - u_{i+1}^{(t)}(x) = \frac{\sum_{i'\in[k]}\exp(f_{i'}^{(t)}(x)) - \exp(f_{i+1}^{(t)}(x))}{\sum_{i'\in[k]}\exp(f_{i'}^{(t)}(x))} \leq \frac{\sum_{i'\in[k]}\exp(b_{i'}^{(t)}) - \exp(b_{i+1}^{(t)})}{\sum_{i'\in[k]}\exp(b_{i'}^{(t)})}\left(1 + O(\delta^r)\right),$$

where the inequality uses $|\langle W_{i',:}^{(t)}, x\rangle| \leq O(\delta)$ for every $i' \in [k]$.
We can lower bound $1 - u_i^{(t)}(x)$ for $x \in S_i$ as follows,

$$1 - u_i^{(t)}(x) = \frac{\sum_{i'\in[k]}\exp(f_{i'}^{(t)}(x)) - \exp(f_i^{(t)}(x))}{\sum_{i'\in[k]}\exp(f_{i'}^{(t)}(x))} \geq \frac{\sum_{i'\in[k]}\exp(b_{i'}^{(t)}) - \exp(b_i^{(t)})}{\sum_{i'\in[k]}\exp(b_{i'}^{(t)})}\left(1 - O(\mu_0^r)\right) \geq \frac{\sum_{i'\in[k]}\exp(b_{i'}^{(t)}) - \exp(b_{i+1}^{(t)})}{\sum_{i'\in[k]}\exp(b_{i'}^{(t)})}\left(1 - O(\mu_0^r) - O(\delta^r)\right).$$

The first inequality uses $|\langle W_{i',:}^{(t)}, x\rangle| \leq O(\mu_0)$ for every $i' \in [k]$. The second inequality uses $b_i^{(t)} \leq b_{i+1}^{(t)} + O(\delta^r)$, which is guaranteed by the following lemma.

Lemma 2 (Bias Gap Control I). For any distinct $j', j \in [k]$, if $W_{j',j'} \geq W_{j,j}$, $W_{j,j} \leq O(\delta)$, $b_{j'} - b_j \geq O(\delta^r)$ and $b_j \geq \max_{i'\in[k]} b_{i'} - O(\delta^r)$, then $\dot{b}_{j'} - \dot{b}_j < 0$.

According to Lemma 10, we know there exists a constant $C > 1$ such that $W_{i,i}^{(0)} \geq C W_{i+1,i+1}^{(0)}$. Choose a constant $S$ such that $S^{\frac{1}{r-2}} = \sqrt{C}$, so that $W_{i,i}^{(0)} \geq S^{\frac{1}{r-2}}\sqrt{C}\, W_{i+1,i+1}^{(0)}$. Choosing $\mu_0$ as a small constant, $\sigma \leq O(1/\sqrt{\log N})$ and $\delta \leq O(1)$, we have

$$\left(1 + O(\sqrt{\log N}\,\sigma)\right)\max_{x\in S_{i+1}}\left(1 - u_{i+1}^{(t)}(x)\right) \leq S\left(1 - O(\sqrt{\log N}\,\sigma)\right)\min_{x\in S_i}\left(1 - u_i^{(t)}(x)\right).$$

We can also lower bound $1 - u_i^{(t)}(x)$ for $x \in S_i$ by a constant:

$$1 - u_i^{(t)}(x) \geq \frac{\sum_{i'\in[k]}\exp(b_{i'}^{(t)}) - \exp(b_i^{(t)})}{\sum_{i'\in[k]}\exp(b_{i'}^{(t)})}\left(1 - O(\mu_0^r)\right) \geq \frac{\exp(b_{i+1}^{(t)})}{\sum_{i'\in[k]}\exp(b_{i'}^{(t)})}\left(1 - O(\mu_0^r)\right) \geq \Omega(1),$$

where the last inequality holds because $\mu_0$ is a small constant and $b_{i+1}^{(t)} \geq \max_{i'\in[k]} b_{i'}^{(t)} - O(\delta^r)$.

Lemma 12 (Adapted from Lemma C.19 in Allen-Zhu & Li (2020)). Let $r \geq 3$ be a constant and let $\{W_{i,i}^{(t)}, W_{j,j}^{(t)}\}_{t\geq 0}$ be two positive sequences updated as

$$\frac{\partial W_{i,i}^{(t)}}{\partial t} \geq C_t\left(W_{i,i}^{(t)}\right)^{r-1} \text{ for some } C_t = \Theta(1), \qquad \frac{\partial W_{j,j}^{(t)}}{\partial t} \leq S C_t\left(W_{j,j}^{(t)}\right)^{r-1} \text{ for some } S = \Theta(1).$$

Suppose $W_{i,i}^{(0)} \geq W_{j,j}^{(0)} S^{\frac{1}{r-2}}(1 + \Omega(1))$. Then for every $A = O(1)$, letting $t_i$ be the first time such that $W_{i,i}^{(t_i)} \geq A$, we have $W_{j,j}^{(t_i)} \leq O(W_{j,j}^{(0)})$.

Then, according to Lemma 12, we know that there exists $t_i = O(\log(1/\delta)/\delta^{r-2})$ such that $W_{i,i}^{(t_i)} = \mu_0$ and $W_{i+1,i+1}^{(t_i)} \leq O(\delta)$.
By a similar argument, we also know $W_{j,j}^{(t_i)} \leq O(\delta)$ for any $j \geq i+1$.

Stage 1. In this stage, we show that $W_{i,i}^{(t)}$ grows to a large constant $\mu_1$ within constant time. Since $W_{i,i}^{(t)} \leq \mu_1$ and $b_i^{(t)} \leq b_{i+1}^{(t)} + O(\delta^r)$, we have $1 - u_i^{(t)}(x) \geq \Omega(1)$ for all $x \in S_i$. This further implies

$$\frac{\partial W_{i,i}^{(t)}}{\partial t} \geq \left(1 - O(\sqrt{\log N}\,\sigma)\right)\frac{k}{N}\sum_{x\in S_i}\left(1 - u_i^{(t)}(x)\right) r\left(W_{i,i}^{(t)}\right)^{r-1} \geq \Omega(1),$$

where the second inequality also uses $W_{i,i}^{(t)} \geq \mu_0$. Since the growth rate is at least a constant, $W_{i,i}^{(t)}$ grows to $\mu_1$ in constant time. For any $j \geq i+1$, since the growth rate of $W_{j,j}^{(t)}$ is merely $O(\delta^{r-1})$, $W_{j,j}^{(t)}$ remains $O(\delta)$ throughout Stage 1.

Stage 2. In this stage, we prove that $u_i^{(t)}(x)$ grows to $1 - \mu_2$ for every $x \in S_i$, with $\mu_2$ a small constant. We use the following lemma to characterize the growth rate of $f_i^{(t)}(x) - f_j^{(t)}(x)$.

Lemma 13. For any $x \in S_i$ and any $j \neq i$, if $1 - u_i(x) \geq \Omega(1)$, we have

$$\frac{\partial}{\partial t}\left(f_i(x) - f_j(x)\right) \geq \Omega(W_{i,i}^{2r-2}) - O(1).$$

Since $u_i^{(t)}(x) \leq 1 - \mu_2$, we know $1 - u_i^{(t)}(x) \geq \Omega(1)$. For any $j \neq i$, we have

$$\frac{\partial}{\partial t}\left(f_i^{(t)}(x) - f_j^{(t)}(x)\right) \geq \Omega\left(\left(W_{i,i}^{(t)}\right)^{2r-2}\right) - O(1) \geq \Omega(1),$$

where the last inequality holds because $W_{i,i}^{(t)} \geq \mu_1$ with $\mu_1$ a large enough constant. The next lemma guarantees that at the beginning of Stage 2 we have $b_i^{(t)} - b_j^{(t)} \geq -O(1)$, which then implies $f_i^{(t)}(x) - f_j^{(t)}(x) \geq -O(1)$.

Lemma 14 (Bias Gap Control III). For any distinct $i, j \in [k]$, if $W_{i,i} \leq O(1)$, $W_{j,j} \leq O(\delta)$ and $b_i - b_j \leq -O(1)$, then $\dot{b}_i - \dot{b}_j > 0$.

Let $C$ be a constant such that $f_i^{(t)}(x) - f_j^{(t)}(x) \geq C$ for every $j \neq i$ implies $u_i(x) \geq 1 - \mu_2$. Since at the beginning of Stage 2 we have $f_i^{(t)}(x) - f_j^{(t)}(x) \geq -O(1)$, within constant time we have $f_i^{(t)}(x) - f_j^{(t)}(x) \geq C$ for every $j \neq i$, and hence $u_i(x) \geq 1 - \mu_2$.

Lemma 15 (Accuracy Monotonicity).
Given any positive constant $C_2$, there exists a positive constant $C_1$ such that for all distinct $i, j \in [k]$, as long as $W_{i,i} \geq C_1$ and $f_i(x) - f_j(x) \leq C_2$ for any $x \in S_i$, we have $\frac{\partial(f_i(x) - f_j(x))}{\partial t} > 0$.

According to Lemma 15, by choosing $\mu_1$ large enough, we can ensure that $f_i^{(t)}(x) - f_j^{(t)}(x) \geq C$ and $u_i(x) \geq 1 - \mu_2$ throughout the rest of training. Note that once $W_{i,i}^{(t)}$ rises to $\mu_1$, it stays at least $\mu_1 - O(\delta)$ throughout training, according to the gradient lower bound in Lemma 11.

Stage 3. In this stage, we prove that within constant time we have $b_i^{(t)} - b_k^{(t)} \leq -\mu_3$. The following lemma shows that $b_i^{(t)} - b_k^{(t)}$ decreases at least at a constant rate.

Lemma 3 (Bias Gap Control II). There exist small positive constants $C_1, C_2$ such that for any $j \in [k-1]$ and any $x \in S_j$, if $1 - u_j(x) \leq C_1$, $W_{k,k} \leq O(\delta)$ and $b_j - b_k \geq -C_2$, then $\dot{b}_j - \dot{b}_k < -\Omega(1)$.

Choosing $\mu_2 = C_1$ and $\mu_3 = C_2$, with $C_1, C_2$ from Lemma 3, we know that $b_i^{(s_{i+1})} - b_k^{(s_{i+1})} = -\mu_3$. By Lemma 3, we also know that for any $t \geq s_{i+1}$, we have $b_i^{(t)} - b_k^{(t)} \leq -\mu_3$. The following lemma shows that $b_k^{(t)}$ is close to the maximum bias.

Lemma 1 (Coupling Biases). Assuming $W_{j',j'}, W_{j,j} \leq O(\delta)$ and $b_{j'}, b_j \geq \max_{i'\in[k]} b_{i'} - O(\delta^r)$, we have $\dot{b}_{j'} - \dot{b}_j > 0$ if $b_{j'} - b_j \leq -\mu\delta^r$, and $\dot{b}_{j'} - \dot{b}_j < 0$ if $b_{j'} - b_j \geq +\mu\delta^r$, for some positive constant $\mu$.

Combining Lemma 1 and Lemma 2, we know that throughout training $b_k^{(t)} \geq \max_{i'\in[k]} b_{i'}^{(t)} - O(\delta^r)$. Therefore, we have $b_i^{(t)} - \max_{i'\in[k]} b_{i'}^{(t)} \leq -\Omega(1)$ for $t \geq s_{i+1}$.

Finally, let us bound the movement of the different parameters. Monotonicity of diagonal terms: for any $j \in [k]$, according to the gradient lower bound in Lemma 11, the diagonal entries can only decrease at a negligible rate, which gives $\Theta(\delta) = W_{j,j}^{(0)} < W_{j,j}^{(t)}$. Bounding non-diagonal terms: we use the following lemma to prove that $0 < W_{j,j'}^{(t)} \leq O(\delta)$ for $j \neq j'$.

Lemma 16. For any $j \neq j'$, we have $T\,\dot{W}_{j,j'} \leq O(\delta)$.
Furthermore, there exists an absolute constant $\mu > 0$ such that if $0 < W_{j,j'} < \frac{\mu\delta}{\log^{\frac{1}{r-2}}(1/\delta)}$, we have $T\,\dot{W}_{j,j'} \geq -\frac{\mu\delta}{2\log^{\frac{1}{r-2}}(1/\delta)}$.

The first property in Lemma 16 guarantees that the growth rate is so small that the total increase within time $T$ is only $O(\delta)$, which implies $W_{j,j'}^{(t)} \leq O(\delta)$. The second property guarantees that once $W_{j,j'}^{(t)}$ falls below $\frac{\mu\delta}{\log^{1/(r-2)}(1/\delta)}$, its decrease rate is so small that the total decrease within time $T$ is only $\frac{\mu\delta}{2\log^{1/(r-2)}(1/\delta)}$, which implies $W_{j,j'}^{(t)} > 0$. □

C.1 PROOFS OF SUPPORTING LEMMAS

Lemma 10 (Initialization). With probability at least 0.99 over the initialization, we have

1. for all $j, j' \in [k]$, $W_{j,j'}^{(0)} = \Theta(\delta)$;
2. for all distinct $j, j' \in [k]$, $|W_{j,j}^{(0)} - W_{j',j'}^{(0)}| = \Theta(\delta)$;
3. for all $x \in S$, $\|\xi_x\| \leq O(\sigma)$;
4. for all distinct $x, x' \in S$, $|\langle\xi_x, \xi_{x'}\rangle| \leq O(\frac{\sqrt{\log N}}{\sqrt{d}})$;
5. for all $j \in [k]$ and all $x \in S$, $|\langle\xi_x, e_j\rangle|, |\langle\xi_x, W_{j,:}^{(0)}\rangle| \leq O(\frac{\sqrt{\log N}}{\sqrt{d}})$.

Without loss of generality, we assume $W_{1,1}^{(0)} > W_{2,2}^{(0)} > \cdots > W_{k,k}^{(0)}$.

Proof of Lemma 10. Recall that each $W_{j,j'}^{(0)}$ is independently sampled from $N(0, \delta^2)$ before taking the absolute value. By the standard Gaussian concentration inequality, for any $j, j' \in [k]$, with probability at least $1 - \frac{1}{1000k^2}$, $W_{j,j'}^{(0)} \leq O(\delta)$. By the anti-concentration inequality for Gaussian polynomials, for any $j, j' \in [k]$, with probability at least $1 - \frac{1}{1000k^2}$, $W_{j,j'}^{(0)} \geq \Omega(\delta)$. Also by the anti-concentration inequality for Gaussian polynomials, for any distinct $j, j' \in [k]$, with probability at least $1 - \frac{1}{1000k^2}$, $|(W_{j,j}^{(0)})^2 - (W_{j',j'}^{(0)})^2| \geq \Omega(\delta^2)$, which implies $|W_{j,j}^{(0)} - W_{j',j'}^{(0)}| \geq \Omega(\delta)$ assuming $W_{j,j}^{(0)}, W_{j',j'}^{(0)} = \Theta(\delta)$. By the norm concentration of random vectors with independent Gaussian entries, for each $x \in S$, with probability at least $1 - \frac{1}{1000N^2}$, $\|\xi_x\| \leq O(\sigma)$ as long as $d \geq O(\log N)$. By the concentration of standard Gaussian variables, for any distinct $x, x' \in S$, with probability at least $1 - \frac{1}{1000N^2}$, $|\langle\xi_x, \xi_{x'}\rangle| \leq O(\frac{\sqrt{\log N}}{\sqrt{d}})$. Similarly, for any $x$ and any $e_j$, with probability at least $1 - \frac{1}{1000kN}$, $|\langle\xi_x, e_j\rangle| \leq O(\frac{\sqrt{\log N}}{\sqrt{d}})$; and for any $x$ and any $W_{j,:}^{(0)}$, with probability at least $1 - \frac{1}{1000kN}$, $|\langle\xi_x, W_{j,:}^{(0)}\rangle| \leq O(\frac{\sqrt{\log N}}{\sqrt{d}})$. Taking a union bound over all these events, we know that with probability at least 0.99 over the initialization, all five properties hold. □
for all j, j ′ ∈ [k], W i,j = Θ(δ); 2. for all distinct j, j Allen-Zhu & Li (2020) ). Let r ≥ 3 be a constant and let {W (t) i,i , W (t) j,j } t≥0 be two positive sequences updated as ′ ∈ [k], W ∂W (t) i,i ∂t ≥ C t W (t) i,i r-1 for some C t = Θ(1), ∂W (t) j,j ∂t ≤ SC t W (t) j,j r-1 for some S = Θ(1).

Suppose $W^{(0)}_{i,i} \ge W^{(0)}_{j,j}\, S^{\frac{1}{r-2}} (1 + \Omega(1))$. Then for every $A = O(1)$, letting $t_i$ be the first time such that $W^{(t_i)}_{i,i} \ge A$, we have $W^{(t_i)}_{j,j} \le O(W^{(0)}_{j,j})$.

Proof of Lemma 12. This lemma directly follows from Lemma C.19 in Allen-Zhu & Li (2020) by taking the continuous-time limit and setting $k$ as a constant. □

Lemma 1 (Coupling Biases). Assuming $W_{j',j'}, W_{j,j} \le O(\delta)$ and $b_{j'}, b_j \ge \max_{i' \in [k]} b_{i'} - O(\delta^r)$, we have $\dot b_{j'} - \dot b_j > 0$ if $b_{j'} - b_j \le -\mu\delta^r$, and $\dot b_{j'} - \dot b_j < 0$ if $b_{j'} - b_j \ge +\mu\delta^r$, for some positive constant $\mu$.

Proof of Lemma 1. Let us first write down the time derivative of $b_{j'}$:
$$\dot b_{j'} = 1 - \frac{k}{N}\sum_{x \in S} u_{j'}(x) = 1 - \frac{k}{N}\sum_{x \in S} \frac{\exp(\langle W_{j',:}, x\rangle^r + b_{j'})}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})}.$$
For any $x \in S$, we can bound the softmax term as follows:
$$\left| \frac{\exp(\langle W_{j',:}, x\rangle^r + b_{j'})}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})} - \frac{\exp(b_{j'})}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})} \right| \le O(\delta^r),$$
where we use $|\langle W_{j',:}, x\rangle| \le O(\delta) + O(\sqrt{\log N}\,\sigma\delta) \le O(\delta)$, assuming $\sigma \le 1/\sqrt{\log N}$. A similar bound also holds for $\frac{\exp(\langle W_{j,:}, x\rangle^r + b_j)}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})}$.

If $b_{j'} - b_j \ge \mu\delta^r$, we can now upper bound $\dot b_{j'} - \dot b_j$ as follows:
$$\dot b_{j'} - \dot b_j \le \frac{k}{N}\sum_{x \in S} \frac{\exp(b_j) - \exp(b_{j'})}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})} + O(\delta^r) \le \frac{k}{N}\sum_{x \in S_j \cup S_{j'}} \frac{\exp(b_j) - \exp(b_{j'})}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})} + O(\delta^r) \le -\Omega(\mu\delta^r) \cdot \frac{k}{N}\sum_{x \in S_j \cup S_{j'}} \frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})} + O(\delta^r).$$
When $x \in S_j \cup S_{j'}$, we can lower bound $\frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})}$ as follows:
$$\frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(\langle W_{i',:}, x\rangle^r + b_{i'})} = \frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(b_{i'}) \exp(\langle W_{i',:}, x\rangle^r)} \ge \frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(b_{i'})} \cdot \frac{1}{1 + O(\delta^r)} \ge \Omega(1),$$
where the first inequality uses $|\langle W_{i',:}, x\rangle| \le \delta$, and the second inequality assumes $b_j \ge \max_{i' \in [k]} b_{i'} - O(\delta^r)$ and that $\delta$ is at most some small constant.
Therefore, if $b_{j'} - b_j \ge \mu\delta^r$, we have $\dot b_{j'} - \dot b_j \le -\Omega(\mu\delta^r) + O(\delta^r) < 0$, where the second inequality chooses $\mu$ as a large enough constant. Similarly, we can prove that if $b_{j'} - b_j \le -\mu\delta^r$, we have $\dot b_{j'} - \dot b_j \ge \Omega(\mu\delta^r) - O(\delta^r) > 0$. □

Lemma 2 (Bias Gap Control I). For any distinct $j', j \in [k]$, if $W_{j',j'} \ge W_{j,j}$, $W_{j,j} \le O(\delta)$ and $b_{j'} - b_j \ge O(\delta^r)$, $b_j \ge \max_{i' \in [k]} b_{i'} - O(\delta^r)$, we have $\dot b_{j'} - \dot b_j < 0$.

Proof of Lemma 2. We can write down $\dot b_{j'} - \dot b_j$ as follows:
$$\dot b_{j'} - \dot b_j = \Big(1 - \frac{k}{N}\sum_{x \in S} u_{j'}(x)\Big) - \Big(1 - \frac{k}{N}\sum_{x \in S} u_j(x)\Big) = \frac{k}{N}\sum_{x \in S_{j'}} \big(u_j(x) - u_{j'}(x)\big) + \frac{k}{N}\sum_{x \in S \setminus S_{j'}} \big(u_j(x) - u_{j'}(x)\big).$$
We first prove that for any $x \in S_{j'}$, we have $u_j(x) - u_{j'}(x) \le 0$. We can upper bound $f_j(x)$ and lower bound $f_{j'}(x)$ as follows:
$$f_j(x) = \langle W_{j,:}, x\rangle^r + b_j \le O(\delta^r) + b_j, \qquad f_{j'}(x) = \langle W_{j',:}, x\rangle^r + b_{j'} \ge b_{j'}.$$
The bound on $f_j(x)$ holds because $\langle W_{j,:}, x\rangle = W_{j,j'} + \langle W_{j,:}, \xi_x\rangle \le O(\delta) + O(\sqrt{\log N}\,\sigma\delta) \le O(\delta)$. The bound on $f_{j'}(x)$ holds because $\langle W_{j',:}, x\rangle = W_{j',j'} + \langle W_{j',:}, \xi_x\rangle \ge \Omega(\delta) - O(\sqrt{\log N}\,\sigma\delta) > 0$. With the above two bounds, we know that $u_j(x) - u_{j'}(x) \le 0$ as long as $b_{j'} - b_j \ge O(\delta^r)$.

As in the proof of Lemma 1, for each $x \in S \setminus S_{j'}$, we can bound $u_{j'}(x)$ and $u_j(x)$ as follows:
$$\frac{\exp(b_{j'})}{\sum_{i' \in [k]} \exp(f_{i'}(x))} - O(\delta^r) \le u_{j'}(x) \le \frac{\exp(b_{j'})}{\sum_{i' \in [k]} \exp(f_{i'}(x))} + O(\delta^r),$$
$$\frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(f_{i'}(x))} - O(\delta^r) \le u_j(x) \le \frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(f_{i'}(x))} + O(\delta^r).$$
Therefore, if $b_{j'} - b_j \ge \mu\delta^r$, we can further upper bound $\dot b_{j'} - \dot b_j$ as follows:
$$\dot b_{j'} - \dot b_j \le \frac{k}{N}\sum_{x \in S \setminus S_{j'}} \big(u_j(x) - u_{j'}(x)\big) \le \frac{k}{N}\sum_{x \in S \setminus S_{j'}} \frac{\exp(b_j) - \exp(b_{j'})}{\sum_{i' \in [k]} \exp(f_{i'}(x))} + O(\delta^r) \le \frac{k}{N}\sum_{x \in S_j} \frac{\exp(b_j) - \exp(b_{j'})}{\sum_{i' \in [k]} \exp(f_{i'}(x))} + O(\delta^r) \le -\Omega(\mu\delta^r)\, \frac{k}{N}\sum_{x \in S_j} \frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(f_{i'}(x))} + O(\delta^r).$$
As in Lemma 1, we can show that $\frac{\exp(b_j)}{\sum_{i' \in [k]} \exp(f_{i'}(x))} \ge \Omega(1)$, due to $W_{j,j} \le O(\delta)$ and $b_j \ge \max_{i' \in [k]} b_{i'} - O(\delta^r)$.
So finally, we have $\dot b_{j'} - \dot b_j \le -\Omega(\mu\delta^r) + O(\delta^r) < 0$, where the last inequality chooses $\mu$ as a large enough constant. □

Lemma 15 (Accuracy Monotonicity). Given any positive constant $C_2$, there exists a positive constant $C_1$ such that for all distinct $i, j \in [k]$, as long as $W_{i,i} \ge C_1$ and $f_i(x) - f_j(x) \le C_2$ for any $x \in S_i$, we have $\frac{\partial(f_i(x) - f_j(x))}{\partial t} > 0$.

Proof of Lemma 15. Since $f_i(x) - f_j(x) \le C_2$, we know $1 - u_i(x) \ge \Omega(1)$. This immediately implies $\min_{x' \in S_i}(1 - u_i(x')) \ge \Omega(1)$, since $|u_i(x) - u_i(x')| \le O(\delta)$. According to Lemma 13, we can bound $\frac{\partial(f_i(x) - f_j(x))}{\partial t}$ as follows:
$$\frac{\partial(f_i(x) - f_j(x))}{\partial t} \ge \Omega\big(W^{2r-2}_{i,i}\big) - O(1) > 0,$$
where the second inequality holds because $W_{i,i} \ge C_1$ with $C_1$ a large enough constant. □

Lemma 3 (Bias Gap Control II). There exist small positive constants $C_1, C_2$ such that for any $j \in [k-1]$ and any $x \in S_j$, if $1 - u_j(x) \le C_1$, $W_{k,k} \le O(\delta)$ and $b_j - b_k \ge -C_2$, we have $\dot b_j - \dot b_k < -\Omega(1)$.

Proof of Lemma 3. Since $1 - u_j(x) \le C_1$ for some $x \in S_j$, we know $1 - u_j(x') \le C_1 + O(\delta)$ for every $x' \in S_j$. We can write down $\dot b_j - \dot b_k$ as follows:
$$\dot b_j - \dot b_k = \Big(1 - \frac{k}{N}\sum_{x' \in S} u_j(x')\Big) - \Big(1 - \frac{k}{N}\sum_{x' \in S} u_k(x')\Big) = \frac{k}{N}\sum_{x' \in S_j}\big(u_k(x') - u_j(x')\big) + \frac{k}{N}\sum_{x' \in S \setminus S_j}\big(u_k(x') - u_j(x')\big).$$

Lemma 16. For any $j \ne j'$, we have $T\dot W_{j,j'} \le O(\delta)$. Furthermore, there exists an absolute constant $\mu > 0$ such that if $0 < W_{j,j'} < \mu\delta/\log^{\frac{1}{r-2}}(1/\delta)$, we have $T\dot W_{j,j'} \ge -\mu\delta/\big(2\log^{\frac{1}{r-2}}(1/\delta)\big)$.
In Figure 18, although class 9 is learned last, class 7 gets the largest bias after training. Let $S$ be the set of all samples for digits 7, 8, 9 and let $S_7, S_8, S_9$ be the sets of samples for each class. For convenience, we use $u_{i,j}$ to denote $\frac{1}{|S|}\sum_{x \in S_j} u_i(x)$, where $u_i(x)$ is the softmax output for class $i$ under input $x$. Then we can write down the derivatives of the three bias terms:
$$\dot b_7 = \tfrac{1}{3} - u_{7,7} - u_{7,8} - u_{7,9}, \qquad \dot b_8 = \tfrac{1}{3} - u_{8,7} - u_{8,8} - u_{8,9}, \qquad \dot b_9 = \tfrac{1}{3} - u_{9,7} - u_{9,8} - u_{9,9}.$$
According to the per-class loss, we know that $\sum_{x \in S_7} -\log(u_7(x)) < \sum_{x \in S_9} -\log(u_9(x))$, which intuitively implies that $\sum_{x \in S_7} u_7(x) > \sum_{x \in S_9} u_9(x)$, that is, $u_{7,7} > u_{9,9}$. This tends to drive $b_7$ smaller than $b_9$. However, because $u_{9,8} > u_{7,8}$, we actually have $\dot b_9 < \dot b_7$, so eventually $b_9$ becomes smaller than $b_7$. Intuitively, class 9 is more correlated with class 8, so $u_{9,8} > u_{7,8}$.
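The mechanism above can be checked numerically. The following sketch is a toy illustration, not the paper's experiment: the logit values are hypothetical, chosen only so that $u_{7,7} > u_{9,9}$ while $u_{9,8} > u_{7,8}$.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def bias_derivatives(logits, k=3):
    # db_i/dt = 1/k - mean_x u_i(x); the mean over all samples equals
    # sum_j u_{i,j}, with u_{i,j} = (1/|S|) * sum_{x in S_j} u_i(x).
    return 1.0 / k - softmax(logits).mean(axis=0)

# Hypothetical logits (one sample per class; rows: S_7, S_8, S_9;
# columns: classes 7, 8, 9). Class 7 is fit slightly better than class 9
# (u_77 > u_99), but class 9 leaks onto class-8 samples (u_98 > u_78).
logits = np.array([
    [3.1,  0.0, 0.0],   # x in S_7
    [-2.0, 3.0, 1.0],   # x in S_8
    [0.0,  0.0, 3.0],   # x in S_9
])
bdot = bias_derivatives(logits)  # [db7/dt, db8/dt, db9/dt]
```

Even though $u_{7,7} > u_{9,9}$, the cross term makes $\dot b_9 < \dot b_7$, so $b_9$ drifts below $b_7$; note also that the three derivatives always sum to zero, since the softmax outputs sum to one.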



In the Xavier initialization, each entry of the weight matrix $W$ is sampled from $N(0, 1/d)$, so we can think of $\delta = 1/\sqrt{d}$, which is small when the input dimension $d$ is large. 2 This is indeed possible since all samples of one class only differ in their noise terms in our setting. In the analysis, we show that the noise term has a negligible contribution to the network output, so all samples in one class are learned at almost the same time. Note that the initialization condition in Prop. 5 is satisfied with constant probability for a reasonable initialization scheme. For example, if we uniformly sample $(x, y)$ from the set $S = \{(x, y) \in \mathbb{R}^2 \mid x^2 + y^2 \le R\}$ with $R \ge 2$, the condition is satisfied with constant probability.



Figure 1: Loss interpolation curve and error interpolation curve for a four-layer fully-connected network (FCN4) on MNIST and for VGG16 on CIFAR-10.

Figure 2: (Left) Our $r$-homogeneous-weight model with $f_i(x) = \langle W_{i,:}, x\rangle^r + b_i$. (Right) The comparison between the interpolated bias $\alpha(b^{(T)}_k - b^{(T)}_i)$ and the interpolated weight signal $\langle W^{[\alpha]}_{i,:}, x\rangle$.
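A quick numerical sketch of why the bias dominates early in the interpolation, assuming (for illustration) a near-zero initialization so that $f^{[\alpha]}_i(x) \approx \alpha^r \langle W^{(T)}_{i,:}, x\rangle^r + \alpha\, b^{(T)}_i$; the trained signal and bias-gap values below are hypothetical:

```python
import numpy as np

r = 3                 # degree of homogeneity of the weights
signal_T = 2.0        # trained weight signal <W_i, x>   (hypothetical)
bias_gap_T = 1.0      # trained bias gap b_k - b_i       (hypothetical)

alphas = np.linspace(0.0, 1.0, 101)
weight_term = (alphas * signal_T) ** r   # scales like alpha^r
bias_term = alphas * bias_gap_T          # scales like alpha

# For small alpha, the linearly interpolated bias gap dominates the
# r-homogeneous weight signal, which is what produces the plateau;
# the weight term only takes over once alpha is large enough.
```

With these numbers the crossover happens around $\alpha \approx 0.35$; before that, the bias gap decides the predicted class.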

r] and $b \in \mathbb{R}^{n_r}$. Here the activation function $\sigma(\cdot)$ can be either the identity function or the ReLU function. The output layer width equals the number of classes, i.e., $n_r = k$. We use $L(\{V_i\}, b)$ to denote the sum of the cross-entropy loss over all samples.
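A minimal sketch of this model (identity or ReLU activations, bias only on the output layer; the layer sizes used in the comments are arbitrary, not the paper's architecture):

```python
import numpy as np

def forward(Vs, b, x, act="relu"):
    """f(x) = V_r sigma(V_{r-1} ... sigma(V_1 x)) + b, bias only on output."""
    h = x
    for V in Vs[:-1]:
        h = V @ h
        if act == "relu":
            h = np.maximum(h, 0.0)
    return Vs[-1] @ h + b

def loss(Vs, b, X, y, act="relu"):
    """L({V_i}, b): sum of cross-entropy loss over all samples."""
    total = 0.0
    for x, yi in zip(X, y):
        f = forward(Vs, b, x, act)
        f = f - f.max()  # stabilize log-sum-exp
        total += np.log(np.exp(f).sum()) - f[yi]
    return total
```

With the identity activation the network collapses to the matrix product $V_r \cdots V_1 x + b$, which is the linear special case discussed in the text.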

1 and does not change much when the depth increases, so the plateau becomes longer in a deeper network.

Figure 3: The training dynamics of W and b in a four-class example.

Figure 4: Loss and error curves across networks with all bias, last bias and no bias.

Figure 5: Loss and error curves across networks with normal and homogeneous interpolation on bias.

Figure 6: Loss and error curves across networks with different initialization scales.

Figure 7: Loss and error curves across networks with different depths.

Figure 8: Train loss for each class and bias term dynamics on 2-class MNIST and 3-class MNIST.

Bias learning dynamics. Our dynamics analysis in Section 4 shows that gradient descent can learn diverse biases on a balanced dataset by learning different classes at different points in time. In particular, the last-learned class should have the highest bias term. We verify this theory by studying FCN10 with only output bias on balanced 2-class or 3-class MNIST. To separate the learning of different classes, we compute the per-class loss by only considering the examples in that particular class. According to Figure 8, in the 2-class MNIST, the digit 1 is learned last and its bias is larger, which fits our theory. Also, in the 3-class MNIST, class 2 is learned first, class 3 second, and class 1 last; for the learned biases, class 2's bias is the smallest, class 3's is in the middle, and class 1's is the highest.
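The per-class loss is the cross-entropy averaged over only the examples of one class; a minimal sketch (the array shapes and toy values are assumptions, not the paper's code):

```python
import numpy as np

def per_class_loss(logits, labels, k):
    """Mean cross-entropy restricted to the examples of each class."""
    z = logits - logits.max(axis=1, keepdims=True)      # stabilize
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -logp[np.arange(len(labels)), labels]         # per-example loss
    return np.array([nll[labels == c].mean() for c in range(k)])
```

A class whose examples are already fit well has a small per-class loss; in the figure, the last class to reach a small per-class loss is the one that ends up with the largest bias.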

≤ O(δ) for i ̸ = j, so we have W ≤ W max as long as δ ≤ O(∆

, x and the second inequality uses r ≥ 3, (1 -O(δ r )) ≥ 2/3. To prove W

can only start decreasing when it exceeds a large constant, and can only decrease by at most $O(\delta)$ throughout the algorithm. By choosing $\delta \le O(1)$, we can ensure that $W$ for any $t$. For $W^{(t)}_{k,k}$, we know it monotonically increases, since we always have $1 - u^{(t)}_k(x) \ge \Omega(1)$ for $x \in S_k$. This is because $W^{(t)}_{k,k}$ remains as small as $O(\delta)$ throughout the algorithm and $b$

′ ≤ O(δ) through the training. The second property in Lemma 16 guarantees that once W

$W^{(t)}_{j,j'} > \Omega(\delta)$ throughout the training.

Bounding noise correlations: The following lemma shows that the total change of $\langle W^{(t)}_{j,:}, \xi_x\rangle$ within $T$ time is only $O(\sqrt{\log N}\,\sigma\delta)$. Since, at initialization, $|\langle W^{(0)}_{j,:}, \xi_x\rangle| \le O(\sqrt{\log N}\,\sigma\delta)$, we have $|\langle W^{(t)}_{j,:}, \xi_x\rangle| \le O(\sqrt{\log N}\,\sigma\delta)$ throughout the training. Since $W^{(t)}_{j,j'} \ge \Omega(\delta)$, as long as $\sigma \le O(1)$, we also have $|\langle W^{(t)}_{j,:}, \xi_x\rangle| \le W_{j,j'}$ for $x \in S_{j'}$.

Lemma 17. For every $j \in [k]$ and every $x \in S$, we have $|\langle \dot W_{j,:}, \xi_x\rangle| \cdot T \le O(\sqrt{\log N}\,\sigma\delta)$. □

C.1 PROOF OF LEMMAS

Lemma 10 (Initialization). With probability at least 0.99 over the initialization, we have: 1. for all $j, j' \in [k]$, $W^{(0)}_{j,j'} = \Theta(\delta)$; 2. for all distinct $j, j' \in [k]$, $\big|W^{(0)}_{j,j} - W^{(0)}_{j',j'}\big| \ge \Omega(\delta)$; 3. for all $x \in S$, $\|\xi_x\| \le O(\sigma)$; 4. for all distinct $x, x' \in S$, $|\langle \xi_x, \xi_{x'}\rangle| \le O\big(\frac{\sqrt{\log N}}{\sqrt d}\big)$; 5. for all $j \in [k]$ and all $x \in S$, $|\langle \xi_x, e_j\rangle|, |\langle \xi_x, W^{(0)}_{j,:}\rangle| \le O\big(\frac{\sqrt{\log N}}{\sqrt d}\big)$.

Lemma 12 (Adapted from Lemma C.19 in Allen-Zhu & Li (2020)).

$$\langle \dot W_{j,:}, \xi_x\rangle = \frac{k}{N}\sum_{x' \in S_j}\big(1 - u_j(x')\big)\, r\langle W_{j,:}, x'\rangle^{r-1}\langle x', \xi_x\rangle - \frac{k}{N}\sum_{x' \in S \setminus S_j} u_j(x')\, r\langle W_{j,:}, x'\rangle^{r-1}\langle x', \xi_x\rangle.$$
According to Lemma 18, for $x' \in S_j$ we have $(1 - u_j(x'))\langle W_{j,:}, x'\rangle^{r-1} \le O(1)$. We know that $|\langle x, \xi_x\rangle| \le O(\sigma + \sqrt{\log N}/\sqrt d)$. For any $x' \ne x$, we have $|\langle x', \xi_x\rangle| \le O(\cdot)$ as long as $\sigma \le 1$.

$\le O(\delta^{r-1})$, since $|\langle W_{j,:}, x'\rangle| \le O(\delta) + O(\sqrt{\log N}\,\delta\sigma) \le O(\delta)$ assuming $\sigma \le 1/\sqrt{\log N}$. Therefore, we can bound $\langle \dot W_{j,:}, \xi_x\rangle$ as follows:
$$|\langle \dot W_{j,:}, \xi_x\rangle| \le O\Big(\frac{\sigma}{N} + \frac{\sqrt{\log N}}{\sqrt d}\Big).$$
Since $T \le O(\log(1/\delta)/\delta^{r-2})$, $N \ge \log(1/\delta)/\delta^{r-1}$ and $d \ge \log^2(1/\delta)/\delta^{2r-2}$, we know $|\langle \dot W_{j,:}, \xi_x\rangle| \cdot T \le O(\sqrt{\log N}\,\delta)$. □

Lemma 18. For any $i \in [k]$ and $x \in S_i$, if $(1 - u_i(x))\langle W_{i,:}, x\rangle^{r-1} \ge \Theta(1)$, we have $\frac{d}{dt}\big[(1 - u_i(x))\langle W_{i,:}, x\rangle^{r-1}\big] < 0$.

Proof of Lemma 18. We can write $1 - u_i(x)$ as $\frac{\sum_{j \in [k], j \ne i}\exp(f_j(x))}{\sum_{j \in [k], j \ne i}\exp(f_j(x)) + \exp(f_i(x))}$. Next, we prove that for any $j' \ne i$, we have
$$\frac{d}{dt}\left[\frac{\exp(f_{j'}(x))}{\sum_{j \in [k], j \ne i}\exp(f_j(x)) + \exp(f_i(x))}\, \langle W_{i,:}, x\rangle^{r-1}\right] < 0.$$
This derivative can be written as the sum of two terms:
$$\frac{1}{\sum_{j \in [k], j \ne i}\exp(f_j(x) - f_{j'}(x)) + \exp(f_i(x) - f_{j'}(x))}\, \frac{d}{dt}\langle W_{i,:}, x\rangle^{r-1} + \frac{d}{dt}\left[\frac{1}{\sum_{j \in [k], j \ne i}\exp(f_j(x) - f_{j'}(x)) + \exp(f_i(x) - f_{j'}(x))}\right] \langle W_{i,:}, x\rangle^{r-1}.$$

For the first term, we have
$$\frac{1}{\sum_{j \in [k], j \ne i}\exp(f_j(x) - f_{j'}(x)) + \exp(f_i(x) - f_{j'}(x))}\, (r-1)\langle W_{i,:}, x\rangle^{r-2}\langle \dot W_{i,:}, x\rangle \le \frac{1}{\exp(f_i(x) - f_{j'}(x))}\, (r-1)\langle W_{i,:}, x\rangle^{r-2}\langle \dot W_{i,:}, x\rangle.$$
For the second term, we have
$$\frac{d}{dt}\left[\frac{1}{\sum_{j \in [k], j \ne i}\exp(f_j(x) - f_{j'}(x)) + \exp(f_i(x) - f_{j'}(x))}\right] \langle W_{i,:}, x\rangle^{r-1} = -\frac{\sum_{j \in [k], j \ne i}\exp(f_j(x) - f_{j'}(x))\big(\dot f_j(x) - \dot f_{j'}(x)\big) + \exp(f_i(x) - f_{j'}(x))\big(\dot f_i(x) - \dot f_{j'}(x)\big)}{\Big(\sum_{j \in [k], j \ne i}\exp(f_j(x) - f_{j'}(x)) + \exp(f_i(x) - f_{j'}(x))\Big)^2}\, \langle W_{i,:}, x\rangle^{r-1},$$
where the last inequality uses $f_i(x) - f_j(x) \ge \Omega(1)$, $|\dot f_j(x) - \dot f_{j'}(x)| \le O(1)$ and $\dot f_i(x) - \dot f_{j'}(x) \ge r\langle W_{i,:}, x\rangle^{r-1}\langle \dot W_{i,:}, x\rangle - O(1) \ge \Omega(1)$. Combining the bounds on both terms, as long as $\langle W_{i,:}, x\rangle$ is larger than a certain constant (which is guaranteed by $(1 - u_i(x))\langle W_{i,:}, x\rangle^{r-1} \ge \Theta(1)$), we know $\frac{d}{dt}\big[(1 - u_i(x))\langle W_{i,:}, x\rangle^{r-1}\big] < 0$. □

Proof of Lemma 16. We can write down the derivative of $W_{j,j'}$ as follows:
$$\dot W_{j,j'} = \frac{k}{N}\sum_{x \in S_j}\big(1 - u_j(x)\big)\, r\langle W_{j,:}, x\rangle^{r-1}\langle e_{j'}, x\rangle - \frac{k}{N}\sum_{x \in S_{j'}} u_j(x)\, r\langle W_{j,:}, x\rangle^{r-1}\langle e_{j'}, x\rangle - \frac{k}{N}\sum_{x \in S \setminus (S_j \cup S_{j'})} u_j(x)\, r\langle W_{j,:}, x\rangle^{r-1}\langle e_{j'}, x\rangle.$$

The bound on the first term uses $(1 - u_j(x))\langle W_{j,:}, x\rangle^{r-1} \le O(1)$ and $\langle e_{j'}, x\rangle = \pm O\big(\frac{\sigma\sqrt{\log N}}{\sqrt d}\big)$ for $x \in S_j$, where $(1 - u_j(x))\langle W_{j,:}, x\rangle^{r-1} \le O(1)$ is guaranteed by Lemma 18. The bound on the second term uses $\langle W_{j,:}, x\rangle = W_{j,j'} \pm O(\sqrt{\log N}\,\delta\sigma)$ and $\langle e_{j'}, x\rangle = 1 \pm O\big(\frac{\sigma\sqrt{\log N}}{\sqrt d}\big)$ for $x \in S_{j'}$. The bound on the third term uses $\langle W_{j,:}, x\rangle = O(\delta)$ and $\langle e_{j'}, x\rangle = \pm O\big(\frac{\sigma\sqrt{\log N}}{\sqrt d}\big)$ for $x \in S \setminus (S_j \cup S_{j'})$.

To prove the upper bound on the derivative, we have
$$\dot W_{j,j'} \le O\Big(\frac{\sigma\sqrt{\log N}}{\sqrt d}\Big),$$
where we use $W_{j,j'} \pm O(\sqrt{\log N}\,\delta\sigma) \ge 0$. Since $T = O(\log(1/\delta)/\delta^{r-2})$, we have $T\dot W_{j,j'} \le O(\delta)$ as long as $d \ge O\big(\frac{\log N \log^2(1/\delta)}{\delta^{2r-2}}\big)$.

Figure 17: Train loss for each class and bias term dynamics on MNIST{1, 2} and MNIST{2, 3}.

When learning class $i$ during time $[s_i, s_{i+1}]$, the weight $W^{(t)}_{i,i}$ stays small in $[s_i, t_i]$ and then quickly grows large in $[t_i, s_{i+1}]$. As a result, all $x \in S_i$ become classified correctly. During the same time, $b$


First, we upper bound $u_k(x') - u_j(x')$ for every $x' \in S_j$. As in the proof of Lemma 1, for each $x' \in S \setminus S_j$, we can bound $u_j(x')$ and $u_k(x')$ up to $O(\delta^r)$. Therefore, we can upper bound $u_k(x') - u_j(x')$, where the last inequality uses $b_k - b_j \le C_2$. Putting everything together, we can upper bound $\dot b_j - \dot b_k$, where the second inequality holds as long as $C_1, C_2, \delta$ are at most some small constants.

Proof of Lemma 14. We can write down $\dot b_i - \dot b_j$ as above. Next, we lower bound $u_j(x) - u_i(x)$ for every $x \in S$. □

Lemma 17. For every $j \in [k]$ and every $x \in S$, we have $|\langle \dot W_{j,:}, \xi_x\rangle| \cdot T \le O(\sqrt{\log N}\,\sigma\delta)$.

Proof of Lemma 17. For each $j \in [k]$, we can expand $\langle \dot W_{j,:}, \xi_x\rangle$. We show that there exists an absolute constant $\mu > 0$ such that the claimed bound holds, which requires $\dot W_{j,j'} \ge -O(\cdot)$. The second inequality assumes $d$ is large enough. The third inequality chooses $\mu$ as a small enough constant. □

Lemma 11. For any $j \in [k]$, we have a lower bound on $\dot W_{j,j}$.

Lemma 13. For any $x \in S_i$ and any $j \ne i$, if $1 - u_i(x) \ge \Omega(1)$, we have $\frac{\partial(f_i(x) - f_j(x))}{\partial t} \ge \Omega\big(W^{2r-2}_{i,i}\big) - O(1)$.

Proof of Lemma 13. Recall that $f_i(x) = \langle W_{i,:}, x\rangle^r + b_i$, so we have $\dot f_i(x) = r\langle W_{i,:}, x\rangle^{r-1}\langle \dot W_{i,:}, x\rangle + \dot b_i$, where in the last inequality we use $\dot W_{i,i} \ge \Omega(W^{r-1}_{i,i}) \ge \Omega(\delta^r)$ and the bound on $\dot W_{i,:}$. We also have a matching upper bound for $\dot f_j(x)$. Therefore, we have $\frac{\partial(f_i(x) - f_j(x))}{\partial t} \ge \Omega\big(W^{2r-2}_{i,i}\big) - O(1)$.

D ADDITIONAL EXPERIMENTS

In this section, we describe the detailed setting of our experiments and also include additional experimental results.

MNIST & Fashion-MNIST. Unless specified otherwise, we use a depth-10, width-1024 fully-connected ReLU neural network (FCN10) for MNIST and Fashion-MNIST. We use Kaiming initialization for the weights and set all bias terms to zero. We use a small initialization by scaling the weights of each layer by $(0.001)^{1/h}$, so that the output is scaled by 0.001, where $h$ is the network depth. We train the network using SGD with learning rate 0.01 and momentum 0.9 for 100 epochs.
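The small-initialization trick scales each layer's weights by $(0.001)^{1/h}$; since a bias-free ReLU network is positively homogeneous in each layer, this scales the output by exactly 0.001. A minimal NumPy sketch (the depth and width below are arbitrary, not the paper's FCN10):

```python
import numpy as np

def forward(ws, x):
    h = x
    for w in ws[:-1]:
        h = np.maximum(w @ h, 0.0)  # ReLU hidden layers, no biases
    return ws[-1] @ h               # linear output layer

rng = np.random.default_rng(0)
depth, width = 4, 16
# Kaiming initialization: std = sqrt(2 / fan_in)
ws = [rng.standard_normal((width, width)) * np.sqrt(2.0 / width)
      for _ in range(depth)]
c = 0.001
ws_small = [w * c ** (1.0 / depth) for w in ws]  # scale every layer
x = rng.standard_normal(width)
```

Because ReLU commutes with multiplication by a positive scalar, scaling each of the $h$ layers by $c^{1/h}$ scales the output by $c^{h/h} = c$.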

CIFAR-10 & CIFAR-100

We use VGG-16 (without batch normalization) for CIFAR-10 and CIFAR-100. We use Kaiming initialization for the weights and set all bias terms to zero. We run SGD with momentum 0.9 and weight decay 1e-4 for 100 epochs. For the learning rate, we start from 0.01 and reduce it by a factor of 0.1 at the 60th and 90th epochs.

We linearly interpolate using 50 evenly spaced points between the network at initialization and the network at the end of training. We evaluate error and loss on the training set. For each setting, we repeat the experiments three times with different random seeds and plot the mean and standard deviation. Note that in Figure 1, to contrast the convex curve and the plateau curve, we used FCN4 with standard initialization on MNIST, and VGG-16 with 0.001 initialization on CIFAR-10. Our code is based on the implementation from Lucas et al. (2021). Each trial of our experiments can be finished on an Nvidia Tesla P100 within one hour.

Figure 9 shows that on both Fashion-MNIST and CIFAR-10, having biases on the last layer or on all layers creates a longer plateau in the error curve, while not significantly affecting the loss curve. Deeper networks create longer plateaus in both the error and loss curves. See Figure 14 for MNIST and CIFAR-100 with last-layer bias; see Figure 15 for Fashion-MNIST and CIFAR-10 with all biases; see Figure 16 for Fashion-MNIST and CIFAR-10 with last-layer bias.
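The interpolation procedure itself is simple; the following is a self-contained sketch on logistic regression instead of VGG (a toy stand-in so it runs in milliseconds; the synthetic data, learning rate, and step count are assumptions, not the paper's settings):

```python
import numpy as np

def loss_fn(theta, X, y):
    # mean logistic loss, numerically stable via logaddexp
    return np.mean(np.logaddexp(0.0, -y * (X @ theta)))

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
y = np.sign(X @ rng.standard_normal(5))     # separable toy labels
theta0 = 0.001 * rng.standard_normal(5)     # small initialization
theta = theta0.copy()
for _ in range(500):                        # plain gradient descent
    z = X @ theta
    grad = -(y / (1.0 + np.exp(y * z))) @ X / len(y)
    theta -= 0.5 * grad

# gamma_loss(alpha) on 50 evenly spaced points along the line
# theta[alpha] = (1 - alpha) * theta0 + alpha * theta
alphas = np.linspace(0.0, 1.0, 50)
curve = [loss_fn((1 - a) * theta0 + a * theta, X, y) for a in alphas]
```

Since the logistic loss is convex in the parameters, the interpolation curve of this toy model is convex in $\alpha$ and peaks at the initialization endpoint; the plateau phenomenon studied in the paper only appears for deep networks with diverse last-layer biases.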

