UNDERSTANDING TRAIN-VALIDATION SPLIT IN META-LEARNING WITH NEURAL NETWORKS

Abstract

The goal of meta-learning is to learn a good prior model from a collection of tasks such that the learned prior can adapt quickly to new tasks without requiring much data from them. A common practice in meta-learning is to perform a train-validation split on each task, where the training set is used for adapting the model parameters to that specific task and the validation set is used for learning a prior model that is shared across all tasks. Despite its success and popularity in multitask learning and few-shot learning, the understanding of the train-validation split is still limited, especially when neural network models are used. In this paper, we study the benefit of the train-validation split for classification problems with neural network models trained by gradient descent. For first-order model-agnostic meta-learning (FOMAML), we prove that when the noise in the training samples is large, the train-validation split is necessary to learn a good prior model, while the train-train method fails. We validate our theory by conducting experiments on both synthetic and real datasets. To the best of our knowledge, this is the first work towards a theoretical understanding of the train-validation split in meta-learning with neural networks.

2. RELATED WORK

Optimization and generalization guarantees for meta-learning. A number of recent works have studied optimization guarantees for meta-learning algorithms. Finn & Levine (2017) proved the universality of gradient-based meta-learning. Wang et al. (2020) studied the global optimality conditions for MAML with a nonconvex objective. Fallah et al. (2020) studied the convergence guarantee for MAML with a nonconvex loss function and proposed Hessian-Free MAML, which enjoys the same theoretical guarantees as MAML without accessing second-order information. Finn et al. (2019); Balcan et al. (2019); Khodak et al. (2019); Denevi et al. (2019; 2018a) studied online meta-learning using online convex optimization.
Another series of works has studied the generalization error and sample complexity of meta-learning methods. Specifically, Amit & Meir (2018) extended the PAC-Bayes argument to the meta-learning setting and established a generalization error bound for meta-learning.

1. INTRODUCTION

In recent years, meta-learning has gained increasing popularity and has been successfully applied to a wide range of problems, including few-shot learning (Ren et al., 2018; Li et al., 2017; Rusu et al., 2018; Snell et al., 2017), reinforcement learning (Gupta et al., 2018b;a), neural machine translation (Gu et al., 2018), and neural architecture search (NAS) (Liu et al., 2018; Real et al., 2019). A popular idea is to formulate meta-learning as a bi-level optimization problem, where the inner level computes the parameter adaptation to each task, while the outer level minimizes the meta-training loss. Such a bi-level formulation has proved empirically effective for learning new tasks quickly from only a few examples with the aid of past experience. Following this idea, meta-learning algorithms such as model-agnostic meta-learning (MAML) (Finn et al., 2017) have achieved remarkable success in many applications.

Due to the nature of bi-level optimization, meta-learning algorithms can often take advantage of a train-validation split in the dataset, so that the inner and outer levels of the algorithm use different data points (Finn et al., 2017; Rajeswaran et al., 2019; Bai et al., 2021; Fallah et al., 2020). It is believed that the train-validation split helps the meta-learning algorithm achieve better performance. There have been several attempts to understand the importance of the train-validation split in meta-learning for linear models (Wang et al., 2021; Bai et al., 2021; Saunshi et al., 2021). Specifically, Wang et al. (2021) showed that when learning linear models, the train-train method performs much worse than the train-validation method if the sample size is small and the noise is large; they also showed that the train-train method performs well on linear models when the sample size is large enough. Bai et al. (2021) considered the linear centroid model introduced in Denevi et al.
(2018b) and showed that the train-validation method outperforms the train-train method in an agnostic setting, while in the realizable noiseless setting, the train-train method can asymptotically achieve a strictly smaller mean squared error than the train-validation method as the sample size and dimension go to infinity at a fixed ratio. Saunshi et al. (2021) took a representation learning perspective and demonstrated that the train-validation split encourages the learned representation to be low-rank, while the train-train method encourages high-rank representations. However, all these works focus on the linear regression setting, and the advantage of the train-validation split remains elusive for meta-learning with neural networks. Based on the above observations, we raise the following question:

How does the train-validation split affect meta-learning with neural networks?

In this paper, we answer this question via a case study of few-shot binary classification using a two-layer convolutional neural network. We consider a learning problem where the data contain large noise and only a limited number of samples are available. Under this setting, we theoretically compare the performance of first-order MAML (FOMAML (Finn et al., 2017)), a simplification of MAML that ignores the Hessian terms, with a train-validation split (the train-validation method) and without one (the train-train method). We summarize our contributions as follows:

1. We show that under our setting, despite the complex bi-level structure of the FOMAML loss and its non-convex landscape, both the train-validation method and the train-train method are guaranteed to train a two-layer CNN to a global minimum of the training loss with high probability.

2. We also demonstrate that there is a significant performance gap on new test data.
Specifically, we show that the neural network trained by the train-validation method achieves a test loss that decreases exponentially fast as the number of training tasks increases. On the other hand, we show that the train-train method can at best achieve a constant-level test loss.

3. Our study demonstrates the importance of the train-validation split in learning neural networks. To the best of our knowledge, this is the first theoretical work studying the train-validation split of meta-learning with neural networks. Notably, the learning problem we consider is linearly realizable, for which Bai et al. (2021) showed that the train-train method can asymptotically achieve a better MSE than the train-validation method under a linear model. However, our results give the opposite conclusion for learning CNNs: the train-validation method still outperforms the train-train method even for linearly realizable learning problems. Therefore, our results indicate that the train-validation split may have a more significant advantage when learning complicated prediction models.

4. We perform experiments on both synthetic and real datasets with neural networks as the backbone model to justify our theoretical results. In particular, even when the data and the neural network structure do not meet our theoretical assumptions, the experimental results still corroborate our theory to a certain extent, demonstrating the practical value of our analysis.

Notation. For an integer $k$, we denote $[k] = \{1, 2, \ldots, k\}$. Given two sequences $\{x_n\}$ and $\{y_n\}$ with $y_n > 0$, we write $x_n = O(y_n)$ if $|x_n|/y_n$ is upper bounded by a constant for all $n$. Similarly, we write $x_n = \Omega(y_n)$ if $|x_n|/y_n$ is lower bounded by a positive constant. We write $x_n = \Theta(y_n)$ if $x_n = O(y_n)$ and $x_n = \Omega(y_n)$. Finally, we use $\widetilde{O}(\cdot)$, $\widetilde{\Omega}(\cdot)$, and $\widetilde{\Theta}(\cdot)$ to hide logarithmic factors.

Saunshi et al. (2020) analyzed the sample complexity of Reptile (Nichol et al., 2018) under a nonconvex setting.
There is also a stream of work studying generalization in meta-learning from the representation learning perspective (Du et al., 2020; Tripuraneni et al., 2020; Denevi et al., 2018a). These works do not employ a train-validation split; instead, they assume some underlying low-rank constraints on the representation. Chen et al. (2022); Huang et al. (2022); Wang et al. (2022) also studied generalization error bounds for meta-learning in the over-parameterized regime.

Feature learning in over-parameterized neural networks. A series of recent works has studied the feature learning dynamics of neural networks. For example, Allen-Zhu & Li (2020a) studied a sparse coding model and showed that adversarial training can help CNN filters learn "pure" dictionary bases. Allen-Zhu & Li (2020b) studied the impact of ensembling and knowledge distillation on the feature learning process. Zou et al. (2021) demonstrated the generalization gap between Adam and stochastic gradient descent through the lens of feature learning. Cao et al. (2022) studied feature learning and noise memorization in learning two-layer CNNs and showed a phase transition between benign and harmful overfitting.

3. PROBLEM SETUP

In this section, we introduce our data model, neural network model, loss function, and the details of the FOMAML algorithm with the train-train and train-validation methods.

We first introduce our data model. In meta-learning, the goal is to train a model based on the data from $K$ tasks, so that the trained model can learn a new task efficiently. Theoretical analysis of meta-learning thus requires careful modeling of (i) the relation among different tasks, and (ii) the data distribution for each specific task. To achieve this, we suppose that the data distribution $D_k$ for the $k$-th task is defined based on a vector $\nu_k$, i.e., $D_k = D(\nu_k)$, and that the vectors $\nu_1, \ldots, \nu_K$ are independently drawn from a distribution $\Pi$. Under this framework, the data points of a new task are generated by first sampling a vector $\tilde\nu$ from $\Pi$, and then sampling data from the distribution $D(\tilde\nu)$. In the following, we present the detailed definitions of $\Pi$ and $D(\tilde\nu)$.

Definition 3.1 (Distribution of tasks). Let $\nu, z_1, \ldots, z_M \in \mathbb{R}^d$ be fixed vectors, where $z_1, \ldots, z_M$ are orthogonal to $\nu$. A vector $\tilde\nu$ is generated from $\Pi$ by (i) randomly picking a vector $z$ from $\{z_1, \ldots, z_M\}$, and (ii) setting $\tilde\nu = \nu + z$.

Definition 3.1 is based on a set of fixed vectors $\nu, z_1, \ldots, z_M \in \mathbb{R}^d$. Here $\nu$ captures the feature shared by all tasks, and $z_1, \ldots, z_M$ form a dictionary of possible features unique to each specific task. It is important to note that we focus on the setting where the number of observed tasks $K \ll M$. Under this setting, with high probability, different tasks will use different unique features.

Definition 3.2 (Distribution of data). Given a vector $\tilde\nu \in \mathbb{R}^d$, each data point $(x, y)$ with $x = [x^{(1)\top}, x^{(2)\top}]^\top \in \mathbb{R}^{2d}$ and $y \in \{-1, 1\}$ is generated from $D(\tilde\nu)$ as follows:
1. The label $y$ is assigned as $+1$ or $-1$ with equal probability.
2. A noise vector $\xi$ is generated from $N(0, \sigma_\xi^2 \cdot (I - P))$, where $P \in \mathbb{R}^{d \times d}$ is the projection operator onto $\mathrm{span}(\{\nu, z_1, \ldots, z_M\})$.
3.
One of $x^{(1)}, x^{(2)}$ is randomly selected and assigned as $y \cdot \tilde\nu$; the other is assigned as $\xi$.

For $k \in [K]$, we denote by $S_k = \{(x_{k,i}, y_{k,i})\}_{i=1}^n$ the set of independent samples from the $k$-th observed task. We consider data inputs consisting of two patches, $x_{k,i} = [x_{k,i}^{(1)\top}, x_{k,i}^{(2)\top}]^\top$, to match our study of convolutional neural networks. Note also that Definition 3.2 requires the noise in the data input to be sampled from the Gaussian distribution $N(0, \sigma_\xi^2 \cdot (I - P))$, ensuring that the noise vector $\xi$ is orthogonal to the features $\nu, z_1, \ldots, z_M$. It is then easy to check that our data model is linearly realizable: the linear predictor $\theta^* = \|\nu\|_2^{-2} \cdot [\nu^\top, \nu^\top]^\top \in \mathbb{R}^{2d}$ satisfies $\langle \theta^*, x \rangle = y$ for all $(x, y)$ drawn from our data distribution. Our motivation to study this linearly realizable setting is that Bai et al. (2021) proved that the train-train method is strictly better than the train-validation method when learning linear models in the realizable setting. In contrast, in this paper we aim to show that for the linearly realizable data model defined in Definitions 3.1 and 3.2, the train-validation method can still outperform the train-train method for FOMAML in learning two-layer CNNs.

We study a two-layer CNN with $m$ hidden neurons whose second-layer weights are frozen as $\pm 1$'s. We use the Huberized-ReLU (Chatterji et al., 2021) activation function, defined as
$$\sigma(z) = \begin{cases} 0, & z < 0, \\ z^2/(2h), & z \in [0, h], \\ z - h/2, & \text{otherwise}, \end{cases}$$
where we set $h = 1/2$ in our analysis. Let $W$ denote the collection of all weights of our network. For a data input $x = [x^{(1)\top}, x^{(2)\top}]^\top$, we consider the CNN $f(W, x) = F_{+1}(W_{+1}, x) - F_{-1}(W_{-1}, x)$, where
$$F_j(W_j, x) = \sum_{r=1}^m \sum_{p=1}^2 \sigma(\langle w_{j,r}, x^{(p)} \rangle), \quad j \in \{-1, 1\}.$$
Here $w_{j,r}$ denotes the $r$-th convolution filter with second-layer weight $j$, and $W_j$ denotes the collection of $w_{j,1}, \ldots, w_{j,m}$. We consider the cross-entropy loss.
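To make Definitions 3.1 and 3.2 concrete, the following sketch generates tasks and data points under one illustrative instantiation: the function names, the choice of $\nu$ as a standard basis vector, and the way the $z_i$ are orthogonalized are our own conventions, and the paper only requires the properties stated in the definitions.

```python
import numpy as np

def make_features(d, M, sigma_s, rng):
    """Shared feature nu (unit norm) and M task-specific features z_i
    orthogonal to nu. Illustrative construction, not the paper's."""
    nu = np.zeros(d); nu[0] = 1.0
    Z = rng.normal(0.0, sigma_s, size=(M, d))
    Z[:, 0] = 0.0                      # make each z_i orthogonal to nu
    return nu, Z

def sample_task(nu, Z, n, sigma_xi, rng):
    """Draw one task vector nu + z (Definition 3.1) and n data points
    from the corresponding distribution (Definition 3.2)."""
    d = nu.shape[0]
    nu_k = nu + Z[rng.integers(len(Z))]
    # projection onto span({nu, z_1, ..., z_M}); noise lives in the complement
    B = np.vstack([nu, Z])
    P = B.T @ np.linalg.pinv(B @ B.T) @ B
    X, y = [], []
    for _ in range(n):
        label = rng.choice([-1, 1])
        xi = (np.eye(d) - P) @ rng.normal(0.0, sigma_xi, size=d)
        patches = [label * nu_k, xi]
        if rng.random() < 0.5:         # randomly place the feature patch
            patches.reverse()
        X.append(np.concatenate(patches)); y.append(label)
    return np.array(X), np.array(y)
```

As a quick check of linear realizability, the predictor $\theta^* = \|\nu\|_2^{-2}[\nu^\top, \nu^\top]^\top$ should satisfy $\langle \theta^*, x \rangle = y$ on every generated sample.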
The loss on a data point $(x, y)$ is given by $L(W, x, y) = \ell(y \cdot f(W, x))$, where $\ell(z) = \log(1 + \exp(-z))$. The loss on a collection of data points $S$ is defined as $L(W, S) = \frac{1}{|S|} \sum_{(x,y) \in S} L(W, x, y)$. Following the data model given in Definitions 3.1 and 3.2, we define the test loss achieved by a CNN with weights $W$ as $L_{\mathrm{test}}(W) := \mathbb{E}_{\tilde\nu \sim \Pi, (x,y) \sim D(\tilde\nu)}\, \ell(y \cdot f(W, x))$.

We implement the FOMAML algorithm (Finn et al., 2017) to train the neural network. The train-train and train-validation methods are given as follows.

Train-train: for each task $k$, we use all of the samples for adapting the parameters in the inner-loop updates. Specifically, the meta objective is to minimize
$$L_{\text{tr-tr}}(W, \{S_k\}_{k=1}^K) = \frac{1}{K} \sum_{k=1}^K L(\widehat W(W, S_k), S_k),$$
where $\widehat W(W, S_k)$ represents the weights of the network after $J$ gradient descent steps (w.r.t. the loss $L(\cdot, S_k)$) starting from $W$ with step size $\gamma$. The FOMAML algorithm updates the CNN weights using the following gradient update rule with step size $\eta$:
$$W^{(t+1)} = W^{(t)} - \eta \cdot \frac{1}{K} \sum_{k=1}^K \nabla_W L(W, S_k)\big|_{W = \widehat W(W^{(t)}, S_k)}. \tag{3.1}$$

Train-validation: for each task $k$, the samples are split into a training set $S_k^{\mathrm{tr}}$ and a validation set $S_k^{\mathrm{val}}$, and the meta objective is to minimize
$$L_{\text{tr-val}}(W, \{S_k\}_{k=1}^K) = \frac{1}{K} \sum_{k=1}^K L(\widehat W(W, S_k^{\mathrm{tr}}), S_k^{\mathrm{val}}),$$
where $\widehat W(W, S_k^{\mathrm{tr}})$ represents the weights of the network after $J$ gradient descent steps (w.r.t. the loss $L(\cdot, S_k^{\mathrm{tr}})$) starting from $W$ with step size $\gamma$. For the train-validation method, the FOMAML algorithm implements the following outer-loop update rule:
$$W^{(t+1)} = W^{(t)} - \eta \cdot \frac{1}{K} \sum_{k=1}^K \nabla_W L(W, S_k^{\mathrm{val}})\big|_{W = \widehat W(W^{(t)}, S_k^{\mathrm{tr}})}. \tag{3.2}$$
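The two update rules differ only in which samples the inner-loop adaptation and the outer-loop gradient see. A minimal numpy sketch, using a linear predictor with the cross-entropy loss as a stand-in for the paper's two-layer CNN (the function names and the per-task train/validation tuples are our own conventions):

```python
import numpy as np

def loss_grad(w, X, y):
    """Cross-entropy loss l(z) = log(1 + exp(-z)) and its gradient for a
    linear predictor; a stand-in for the paper's two-layer CNN."""
    z = y * (X @ w)
    loss = np.mean(np.log1p(np.exp(-z)))
    grad = X.T @ (-y / (1.0 + np.exp(z))) / len(y)
    return loss, grad

def inner_adapt(w, X, y, J, gamma):
    """J inner-loop gradient descent steps on one task's data."""
    for _ in range(J):
        _, g = loss_grad(w, X, y)
        w = w - gamma * g
    return w

def fomaml_step(w, tasks, J, gamma, eta, split=True):
    """One FOMAML outer step. split=True: train-validation (Eq. 3.2);
    split=False: train-train (Eq. 3.1), which pools all task samples."""
    meta_grad = np.zeros_like(w)
    for (X_tr, y_tr, X_val, y_val) in tasks:
        if split:
            w_k = inner_adapt(w, X_tr, y_tr, J, gamma)
            _, g = loss_grad(w_k, X_val, y_val)   # outer grad on validation set
        else:
            X_all = np.vstack([X_tr, X_val])
            y_all = np.concatenate([y_tr, y_val])
            w_k = inner_adapt(w, X_all, y_all, J, gamma)
            _, g = loss_grad(w_k, X_all, y_all)   # outer grad on the same data
        meta_grad += g
    return w - eta * meta_grad / len(tasks)
```

Because only the first-order gradient at the adapted parameters is used, no second-order (Hessian) information appears, matching FOMAML.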

4. MAIN RESULTS

In this section, we present our theoretical results for the train-train and train-validation methods. We first introduce the following condition.

Condition 4.1. There exists $\sigma_s > 0$ such that $(1/2) \cdot \sigma_s \sqrt{d} \le \|z_i\|_2 \le (3/2) \cdot \sigma_s \sqrt{d}$ for $i \in [M]$, and $|\langle z_i, z_j \rangle| \le O(\sigma_s^2 \cdot \sqrt{d \log(d)})$ for all $i \ne j$.

Condition 4.1 specifies the properties of the task-specific features in Definition 3.1. Here we require the task-specific features to have weak correlations with each other, which is standard in dictionary learning. Moreover, it is easy to see that such $z_1, \ldots, z_M$ exist as long as $M = O(\mathrm{poly}(d))$: with a simple Gaussian concentration argument, one can show that $M$ independent zero-mean Gaussian random vectors with covariance matrix $\sigma_s^2 (I - \nu\nu^\top / \|\nu\|_2^2)$ are orthogonal to $\nu$ and satisfy Condition 4.1 with high probability.

Under Condition 4.1 and our data model in Definitions 3.1 and 3.2, a neural network must learn the shared feature $\nu$ to make accurate predictions on a new task. On the other hand, if a neural network only utilizes the task-specific features or the noise in the observed data to fit the labels, then it can achieve a good training loss but will have poor prediction accuracy on new tasks. Following this intuition, we aim to construct a concrete setting under which the train-train and train-validation methods give different test losses while both achieving small training losses. The details of our constructed setting are summarized in the following condition.

Condition 4.2. $\|\nu\|_2 = 1$, $\sigma_\xi = d^{-1/2} \cdot \mathrm{polylog}(d)$, $\sigma_s = d^{-1/2} / \mathrm{polylog}(d)$, $n = \Theta(1)$, $K = \mathrm{polylog}(d)$, $m = \mathrm{polylog}(d)$, $\Omega(d^{1/2}) \le M \le d/2$.

Condition 4.2 defines an over-parameterized setting where $d \gg Kn$. Moreover, under Condition 4.2, the norms of the shared and task-specific features $\|\nu\|_2, \|z_1\|_2, \ldots, \|z_M\|_2$ are all $O(1)$. In comparison, with high probability, the norms of the noise vectors in the data are $\Theta(\mathrm{polylog}(d))$.
Therefore, Condition 4.2 defines a setting with relatively large noise. The FOMAML algorithm involves inner- and outer-loop training of the neural network. We specify our hyper-parameter configuration in the following condition.

Condition 4.3. In the FOMAML algorithm defined in Eq. (3.1) and Eq. (3.2), we initialize the CNN weights $W^{(0)}$ by Gaussian random initialization with standard deviation $\sigma_0 = d^{-1/2}$. We set the inner-loop step size $\gamma = \mathrm{polylog}(d)$ and the outer-loop step size $\eta = 1/\mathrm{polylog}(d)$. We run $T = \mathrm{poly}(d)$ outer-loop iterations. Within each outer-loop iteration, we run $J = 5$ inner-loop gradient descent steps.

In Condition 4.3, we use a slightly larger step size for the inner loop than for the outer loop. This ensures that the inner-loop updates make a difference, and it matches meta-learning practice. Moreover, we set $J = 5$ to demonstrate that our analysis applies to the case of multiple inner-loop iterations; the exact value of $J$ is not essential, and our analysis easily extends to other values of $J = \Theta(1)$.

We are now in a position to state our main theoretical results. The theorem below gives the training and test loss guarantees for the train-train method.

Theorem 4.4. Under Conditions 4.1, 4.2 and 4.3, suppose that one uses the train-train method to train the neural network. Then with probability at least $1 - (Kn)^{-10}$,
1. the training loss is small: $\min_{t \in [T]} L_{\text{tr-tr}}(W^{(t)}, \{S_k\}_{k=1}^K) \le O(1/\mathrm{poly}(d))$;
2. the test loss is large: $\min_{t \in [T]} L_{\mathrm{test}}(W^{(t)}) = \Omega(1)$.

Theorem 4.4 demonstrates that FOMAML with the train-train method can successfully train the CNN to minimize the training loss. However, it also shows that the train-train method fails on new tasks and can only achieve a constant-level test loss. In comparison, we have the following theorem for the train-validation method.

Theorem 4.5. Under Conditions 4.1, 4.2 and 4.3, suppose that one uses the train-validation method to train the neural network.
Then with probability at least $1 - (Kn)^{-10}$,
1. the training loss is small: $\min_{t \in [T]} L_{\text{tr-val}}(W^{(t)}, \{S_k\}_{k=1}^K) \le O(1/\mathrm{poly}(d))$;
2. the test loss is also small: there exists a constant $c > 0$ such that $L_{\mathrm{test}}(W^{(T)}) = O(\exp(-K^c))$.

Although both methods achieve a $1/\mathrm{poly}(d)$ training loss, the train-validation method achieves a test loss of $O(\exp(-K^c))$, compared with the constant test loss $\Omega(1)$ of the train-train method. Thus, Theorems 4.4 and 4.5 show that for our data model, it is necessary to perform a train-validation split to achieve good performance on test data.

We also compare FOMAML with the vanilla supervised learning framework, where one simply combines all data from the $K$ tasks and uses gradient descent to minimize the overall cross-entropy loss. In fact, the results of Theorem 4.4 still hold for $J = 0$, in which case the algorithm in Eq. (3.1) reduces to gradient descent for vanilla supervised learning. Therefore, vanilla supervised learning can also only achieve a $\Theta(1)$ test loss. Clearly, our results imply that FOMAML with a train-validation split can significantly outperform vanilla supervised learning on our data model.

Comparison with Bai et al. (2021). Recently, Bai et al. (2021) studied the importance of the train-validation split in meta-learning under the linear centroid framework. Specifically, they considered linear ridge regression in the inner loop and ridgeless linear regression in the outer loop, and showed that the train-train method strictly outperforms the train-validation method in the realizable setting. As discussed in Section 3, our data model is linearly realizable, and thus by Bai et al. (2021), the train-train method should outperform the train-validation method when learning our data model with linear ridge/ridgeless regression. On the contrary, Theorems 4.4 and 4.5 demonstrate that the train-validation method still significantly outperforms the train-train method for FOMAML in learning CNNs.
We believe the key reason behind this difference is that the CNN model has much higher expressive power, and can therefore more easily overfit the noise in the training data when trained with FOMAML. The train-validation split greatly helps the feature learning of the neural network by separating the samples used in the inner and outer loops. Therefore, our results indicate that the train-validation split may have a more significant advantage when learning complicated prediction models.
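As a numerical sanity check of the claim in Section 4 that Gaussian task features satisfy Condition 4.1 with high probability, one can sample $M$ vectors from $N(0, \sigma_s^2 (I - \nu\nu^\top/\|\nu\|_2^2))$ and verify the norm and correlation bounds directly. The dimensions below and the constant $3$ in the correlation bound are illustrative choices of ours.

```python
import numpy as np

rng = np.random.default_rng(0)
d, M, sigma_s = 2000, 40, 1.0

nu = np.zeros(d); nu[0] = 1.0
# Sample from N(0, sigma_s^2 (I - nu nu^T / ||nu||^2)):
# draw isotropic Gaussians, then project out the nu-component.
Z = rng.normal(0.0, sigma_s, size=(M, d))
Z -= np.outer(Z @ nu, nu)

# Norm concentration: (1/2) sigma_s sqrt(d) <= ||z_i||_2 <= (3/2) sigma_s sqrt(d)
norms = np.linalg.norm(Z, axis=1)
assert np.all(norms >= 0.5 * sigma_s * np.sqrt(d))
assert np.all(norms <= 1.5 * sigma_s * np.sqrt(d))

# Weak pairwise correlations: |<z_i, z_j>| = O(sigma_s^2 sqrt(d log d))
G = Z @ Z.T
off_diag = np.abs(G - np.diag(np.diag(G)))
assert off_diag.max() <= 3.0 * sigma_s**2 * np.sqrt(d * np.log(d))
```

All three assertions pass comfortably at these sizes, illustrating the concentration argument behind the existence claim.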

5. EXPERIMENTS

5.1. EXPERIMENTAL SETUPS

Synthetic data. We generate synthetic data to test our theory. For data generation we choose: $d = 1000$, $K = 343$, $n = 10$, $\sigma_\xi = 10.42$, $\sigma_s = 0.00066$, $\|\nu\|_2 = 1$. For the neural network we choose: $m = 18$, $\sigma_0 = 0.032$. Finally, we choose the following parameters for inner- and outer-level optimization: $\gamma = 0.001$, $J = 5$, $\eta = 0.0001$.

Real-world Data. We further justify our theoretical findings on two real-world datasets, RainbowMNIST and miniImagenet, discussed as follows.
• RainbowMNIST. Following Yao et al. (2021), RainbowMNIST is a 10-way meta-learning dataset built upon the original MNIST dataset, where each task is constructed by applying one combination of image transformations (e.g., coloring, rotation) to the original data. Here, 40 and 16 combinations are used for meta-training and meta-testing, respectively.
• miniImagenet. Following the traditional meta-learning setting (Finn & Levine, 2017; Snell et al., 2017), the miniImagenet dataset is split into meta-training, meta-validation, and meta-testing classes, where 64/16/20 classes are used for meta-training/validation/testing. We adopt the traditional N-way, K-shot setting to split the training and validation sets in our experiments, with N = 5 and K = 1 in this paper (i.e., 5-way, 1-shot learning).

Backbones and Hyperparameters. For all real-world datasets, following Finn & Levine (2017); Snell et al. (2017), we adopt the standard four-block convolutional network as the base learner, with Huberized-ReLU as the activation function. The number of inner-loop steps is set to 5. The inner-loop and outer-loop learning rates are set to 0.01 and 0.001 (miniImagenet) and 0.1 and 0.01 (RainbowMNIST), respectively. We report the average accuracy with 95% confidence intervals over all meta-testing tasks. For synthetic data, we follow our theoretical analysis and report the cross-entropy loss with 95% confidence intervals over all meta-testing tasks.

5.2. RESULTS

Comparison between Different Activation Functions. We first conduct experiments to compare the performance of the Huberized-ReLU activation function used in our theoretical analysis with the traditional ReLU function. The results are reported in Table 1, indicating that Huberized-ReLU performs similarly to ReLU; it is thus a reasonable replacement that simplifies our analysis.

Comparison between Different Numbers of Inner-Loop Steps. The results are reported in Table 2. According to the results, we observe that the train-validation split performs much better than the train-train method when more inner-loop steps are used. This is expected, since fewer inner-loop steps correspond to less overfitting to the training set during inner-loop optimization; with fewer inner-loop steps, using the training set in the outer-loop optimization can still contribute to the optimization.

Comparison between Different Learning Rates. The results in Table 3 show that the train-validation split performs better when the inner-loop learning rate is larger than the outer-loop learning rate, which corroborates our theoretical results (Condition 4.3). On synthetic data, the train-train method performs better when the outer-loop learning rate is larger than the inner-loop learning rate. This is not surprising, since a smaller inner-loop learning rate corresponds to less overfitting, similar to the phenomenon observed with fewer inner-loop optimization steps (Table 2).

6. OVERVIEW OF PROOF TECHNIQUES

Published as a conference paper at ICLR 2023

From here onward, we use the superscript $(t, \tau)$ on $w_{j,r}$ to denote the $t$-th outer-loop iteration and the $\tau$-th inner-loop iteration. However, according to Eq. (3.1) and Eq. (3.2), the inner-loop updates for a CNN filter $w_{j,r}$ are task-specific, so we cannot rigorously use the notation $w_{j,r}^{(t,\tau)}$ without referring to a specific task $k$. Fortunately, all our analyses are based on inner products of the form $\langle w_{j,r}, x_{k,i}^{(p)} \rangle$. We can therefore use $\langle w_{j,r}^{(t,\tau)}, x_{k,i}^{(p)} \rangle$ to denote the inner product after $t$ outer-loop updates followed by $\tau$ inner-loop updates, where the inner-loop updates only use samples from task $k$. With this notation, we give the following two key lemmas, which highlight the main differences between the train-train and train-validation methods.

Lemma 6.1. Suppose one uses the train-train method to train the neural network. Let $T^{(1)}_{k,i}$ be the first iteration such that $\max_{r \in [m]} \langle w_{j,r}^{(t,0)}, \xi_{k,i} \rangle \ge \Omega(1/\mathrm{polylog}(d))$ for $k \in [K]$, $i \in [n]$, and $j = y_{k,i}$. Then, under the same conditions as in Theorem 4.4, for any $t \le T^{(1)}_{k,i}$,
1. for any $k \in [K]$, $r \in [m]$, and $j \in \{-1, 1\}$, we have $j \langle w_{j,r}^{(t,J)}, \nu_k \rangle \le |\langle w_{j,r}^{(t,0)}, \nu_k \rangle| \cdot (O(1) \cdot \gamma)^J$;
2. for $r = \arg\max_{r' \in [m]} \langle w_{j,r'}^{(t,0)}, \xi_{k,i} \rangle$, we have $\langle w_{j,r}^{(t,J)}, \xi_{k,i} \rangle \ge \langle w_{j,r}^{(t,0)}, \xi_{k,i} \rangle \cdot \Omega(1) \cdot (\gamma \cdot \Theta(\mathrm{polylog}(d)))^J$.

Lemma 6.2. Suppose one uses the train-validation method to train the neural network. Let $T^{(2)}_j$ be the first iteration such that $\max_{r \in [m]} j \langle w_{j,r}^{(t,0)}, \nu \rangle \ge \Omega(1/\mathrm{polylog}(d))$ for $j \in \{-1, 1\}$. Then, under the same conditions as in Theorem 4.5, for any $t \le T^{(2)}_j$,
1. for any $k \in [K]$ and $r = \arg\max_{r' \in [m]} j \langle w_{j,r'}^{(t,0)}, \nu \rangle$, we have $j \langle w_{j,r}^{(t,J)}, \nu_k \rangle = \Omega(1) \cdot j \langle w_{j,r}^{(t,0)}, \nu_k \rangle \cdot (\Theta(1) \cdot \gamma)^J$;
2. for any $r \in [m]$, $k \in [K]$, $i \in I^{\mathrm{val}}_k$, and $j \in \{-1, 1\}$, we have $\langle w_{j,r}^{(t,J)}, \xi_{k,i} \rangle \le O\big(\max\big\{ \langle w_{j,r}^{(t,0)}, \xi_{k,i} \rangle,\ J \cdot O(d^{-1/2}) \big\}\big)$.

The above two lemmas give some intuition for the better generalization power of the train-validation method.
According to Lemma 6.1, for the train-train method the $J$ inner-loop gradient steps amplify the noise inner products more than the feature inner product. On the contrary, by Lemma 6.2, for the train-validation method the $J$ inner-loop gradient steps amplify the feature inner product while having little impact on the noise inner products.

In the rest of this section, we mainly sketch the proof of Theorem 4.5 for the train-validation method. Based on Lemma 6.2, the inner-loop updates in the train-validation method prioritize feature learning over noise memorization. Our further study of the outer-loop training procedure gives the following lemma.

Lemma 6.3. Under the same conditions as in Theorem 4.5, let $T^{(2)} = \max_j T^{(2)}_j$, where $T^{(2)}_j$ is defined in Lemma 6.2. Then $\max_{r \in [m]} \langle w_{j,r}^{(T^{(2)},0)}, \xi_{k,i} \rangle = O(d^{-1/2})$ for any $k \in [K]$, $i \in I^{\mathrm{val}}_k$, and $j \in \{-1, 1\}$.

We remind the reader that $T^{(2)}$ (defined in Lemma 6.2) represents the time taken for the feature inner product to grow to $\Omega(1/\mathrm{polylog}(d))$. Therefore, Lemma 6.3 is essentially still a comparison between the growth rates of the feature and noise inner products: it shows that when the feature inner product grows to $\widetilde\Theta(1)$, the noise inner products still remain at their initialization order, which is $O(d^{-1/2})$. The next lemma provides a key result for the convergence of the training loss.

Lemma 6.4. Let $T = \mathrm{poly}(d)$ be the total number of iterations. Under the same conditions as in Theorem 4.5, we have
$$\min_{t \in [T^{(2)},\, T-1]} \sum_{(k,i) \in \Psi} -\ell'^{(t,J)}_{k,i} \le O\Big(\frac{1}{T\eta}\Big), \quad \text{where } \ell'^{(t,J)}_{k,i} = \ell'\big(y_{k,i} \cdot f(\widehat W(W^{(t)}, S^{\mathrm{tr}}_k), x_{k,i})\big).$$

By definition, the feature inner product has grown to $\Omega(1/\mathrm{polylog}(d))$ at time $T^{(2)}$. After $T^{(2)}$, we conduct a more careful study of the training process, as some data points may already be well-fitted (with a small loss) and no longer contribute significantly to training.
Our next lemma shows that with the train-validation method, after $T^{(2)}$ the feature inner product keeps growing, reaching at least $\Omega(K^c)$ for some $c > 0$ at the end of training.

Lemma 6.5. Under the same conditions as in Theorem 4.5, suppose one runs the train-validation method for a total of $T = \mathrm{poly}(d)$ iterations. Then $\max_{r \in [m]} j \langle w_{j,r}^{(T,0)}, \nu \rangle \ge \Omega(K^c)$ for all $j \in \{-1, 1\}$ and some $c > 0$.

Lemma 6.5 shows that at the end of training, the neural network has sufficiently learned the shared feature $\nu$. Given a fresh task, we should then expect a small test loss, because the shared feature $\nu$ is also present in the newly sampled task. We are now ready to present the proof of Theorem 4.5.

Proof of Theorem 4.5. By definition,
$$L_{\text{tr-val}}(W^{(t)}, \{S_k\}_{k=1}^K) = \frac{1}{K} \sum_{k \in [K]} \frac{1}{|I^{\mathrm{val}}_k|} \sum_{i \in I^{\mathrm{val}}_k} \ell^{(t,J)}_{k,i}, \quad \text{where } \ell^{(t,J)}_{k,i} = \ell\big(y_{k,i} \cdot f(\widehat W(W^{(t)}, S^{\mathrm{tr}}_k), x_{k,i})\big).$$
Then by Lemma 6.4 and the property of the cross-entropy loss that $-\ell'(x) \ge \exp(-x)/2 \ge \ell(x)/2$ for $x > 0$, we have
$$\min_{t \in [T]} L_{\text{tr-val}}(W^{(t)}, \{S_k\}_{k=1}^K) \le \min_{t \in [T^{(2)},\, T-1]} \Big( -2 \sum_{k \in [K]} \sum_{i \in I^{\mathrm{val}}_k} \ell'^{(t,J)}_{k,i} \Big) = O\Big(\frac{1}{T\eta}\Big) = O\Big(\frac{1}{\mathrm{poly}(d)}\Big).$$
This proves the first part of Theorem 4.5.

For the second part, consider a new data point $(x, y)$. By Definitions 3.1 and 3.2, the new data input $x$ consists of two patches, one of which is a noise vector $\xi$ generated from $N(0, \sigma_\xi^2 \cdot (I - P))$. Denote by $E$ the event that $|\langle \xi_{k,i}, \xi \rangle| \le d^{-1/4}$ and $|\langle w_{j,r}^{(0,0)}, \xi \rangle| \le d^{3/2}$ for all $k \in [K]$, $i \in [n]$, $r \in [m]$, and $j \in \{-1, 1\}$, where $\xi$ is the noise vector of the new data input $x$. Then by Gaussian concentration, we have $P(E) \ge 1 - O(\exp(-d^{1/4}))$. We decompose $L_{\mathrm{test}}(W^{(T)})$ into two parts:
$$L_{\mathrm{test}}(W^{(T)}) = \mathbb{E}\,\ell(y f(W^{(T)}, x)) = \underbrace{\mathbb{E}\big[\mathbb{1}(E)\, \ell(y f(W^{(T)}, x))\big]}_{I_1} + \underbrace{\mathbb{E}\big[\mathbb{1}(E^c)\, \ell(y f(W^{(T)}, x))\big]}_{I_2}.$$
From this decomposition, it is clear that bounding $I_1$ is the main task, since $P(E^c)$ is exponentially small, which makes $I_2$ small. Indeed, one can bound $I_2 = O(\mathrm{poly}(d)) \cdot \exp(-0.5 d^{1/4})$.
Moreover, under the event $E$, it holds that $F_{-y}(W_{-y}^{(T)}, x) < \log(2)$ and $F_y(W_y^{(T)}, x) \ge \sigma\big(\max_{r \in [m]} y \cdot \langle w_{y,r}^{(T,0)}, \nu \rangle\big)$. Applying Lemma 6.5, we have
$$\ell\big(y f(W^{(T)}, x)\big) = \ell\big(F_y(W_y^{(T)}, x) - F_{-y}(W_{-y}^{(T)}, x)\big) \le \ell\big(K^c - \log(2)\big) \le 2\exp(-K^c),$$
where the last inequality follows from the definition of $\ell(\cdot)$ and the inequality $\log(1+x) \le x$ for $x \ge 0$. Therefore $I_1 \le 2\exp(-K^c)$. Combining the bounds on $I_1$ and $I_2$ yields
$$L_{\mathrm{test}}(W^{(T)}) \le 2\exp(-K^c) + O(\mathrm{poly}(d)) \cdot \exp(-0.5 d^{1/4}) = O\big(\exp(-K^c)\big),$$
which proves the theorem.

7. CONCLUSION AND FUTURE WORK

In this work, we studied the FOMAML algorithm applied to a classification problem with a two-layer CNN trained by gradient descent. We proved that although both the train-train and train-validation methods can achieve a small training loss, when the noise is large and the number of samples is limited, a train-validation split is necessary for good generalization. It is of interest to extend our results to other types of data models (for example, natural language data) and deeper neural network architectures. We would also like to extend our analysis to other popular meta-learning algorithms such as MAML (with the Hessian term), iMAML (Rajeswaran et al., 2019), Meta-MiniBatchProx (Zhou et al., 2019), Reptile (Nichol et al., 2018), and closed-form solvers (Bertinetto et al., 2018).

A COMPARISON WITH BAI ET AL. (2021)

In this section, we point out some differences between our experimental setup and that of Bai et al. (2021), which also studies the importance of a train-validation split in meta-learning. The main conclusion of Bai et al. (2021) is that the train-train method can outperform the train-validation method asymptotically under a linear centroid model. They also ran experiments using a CNN backbone to support their theory (see Table 1 of Bai et al. (2021)). This may seem to contradict our experiments at first sight, since we showed in Table 1 that the train-validation method outperforms the train-train method with a CNN backbone. We would like to point out that this is due to the different choices of loss functions and optimization algorithms used by us and by Bai et al. (2021), as explained below.

A.1 REGULARIZER IN INNER-LOOP

Let us focus on the train-train method. Our loss function is given by

$$L_{\text{tr-tr}}\big(W, \{S_k\}_{k=1}^K\big) = \frac{1}{K}\sum_{k=1}^K L\big(\widehat{W}(W, S_k), S_k\big),$$

where $\widehat{W}(W, S_k)$ denotes the weights of the network after $J$ gradient descent steps (with respect to the loss $L(\cdot, S_k)$) starting from $W$. In contrast, Bai et al. (2021) use the loss function

$$L_{\text{tr-tr}}\big(W, \{S_k\}_{k=1}^K\big) = \frac{1}{K}\sum_{k=1}^K L\big(\widehat{W}(W, S_k), S_k\big), \tag{A.1}$$

where

$$\widehat{W}(W, S_k) = \arg\min_{W'}\; L(W', S_k) + \lambda \|W' - W\|_F^2, \tag{A.2}$$

and $\lambda$ is a regularization parameter. This regularizer was popularized by Rajeswaran et al. (2019) and Zhou et al. (2019). In practice, however, many meta-learning methods do not use this regularizer in the inner loop.
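The two inner-loop adaptations can be contrasted on a toy quadratic per-task loss $L(w, S_k) = \|A_k w - b_k\|^2$ (a minimal sketch in our own notation, not the paper's model; for a quadratic loss the proximal problem of Eq. (A.2) has a closed form):

```python
import numpy as np

def adapt_gd(w, A, b, steps=5, lr=0.01):
    """Our inner loop: J plain gradient-descent steps starting from the prior w."""
    w = w.copy()
    for _ in range(steps):
        grad = 2 * A.T @ (A @ w - b)   # gradient of ||A w - b||^2
        w -= lr * grad
    return w

def adapt_prox(w, A, b, lam=1.0):
    """Proximal inner loop of Eq. (A.1)-(A.2):
    argmin_{W'} ||A W' - b||^2 + lam * ||W' - w||^2.
    Setting the gradient to zero gives (A^T A + lam I) W' = A^T b + lam w."""
    d = w.shape[0]
    return np.linalg.solve(A.T @ A + lam * np.eye(d), A.T @ b + lam * w)

rng = np.random.default_rng(0)
A, b = rng.normal(size=(8, 3)), rng.normal(size=8)
w0 = np.zeros(3)
w_gd, w_prox = adapt_gd(w0, A, b), adapt_prox(w0, A, b)
```

Note how the proximal solution is pulled back toward the prior `w0` by `lam`, whereas the gradient-descent adaptation stays close to the prior only because the number of steps `J` is small.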

A.2 EXPERIMENT ALGORITHMS

To compare the performance of the train-validation and train-train methods, Bai et al. (2021) used iMAML (Rajeswaran et al., 2019) for the train-validation method and Meta-MiniBatchProx (Zhou et al., 2019) for the train-train method. Both algorithms were developed to target the loss function given by Eq. (A.1). In particular, iMAML uses the implicit function theorem to calculate the gradient $\nabla_W \widehat{W}(W, S_k)$, and then uses this information to perform gradient descent on $W$ with respect to the loss function given by Eq. (A.1). iMAML is similar to MAML in the sense that both algorithms perform gradient descent on $W$ with respect to their respective loss functions $L_{\text{tr-tr}}$. However, the structure of Meta-MiniBatchProx is different from that of iMAML and MAML. It is more closely related to Reptile (Nichol et al., 2018), another first-order meta-learning algorithm. Instead of performing gradient descent in the outer loop, Reptile updates $W$ at each step as a convex combination of $W$ and $\widehat{W}(W, S_k)$:

$$W \leftarrow W + \epsilon \cdot \frac{1}{K}\sum_{k=1}^K \big(\widehat{W}(W, S_k) - W\big), \tag{A.3}$$

for some $\epsilon \in (0, 1)$, and Meta-MiniBatchProx replaces $\widehat{W}(W, S_k)$ above with the proximal solution of Eq. (A.2). In our experiments, we used FOMAML for both the train-validation and train-train methods. We believe that using the same algorithm for both makes the comparison fairer.
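The Reptile-style outer update of Eq. (A.3) can be sketched as follows (our simplified notation; Meta-MiniBatchProx has the same form with the proximal solution in place of the gradient-descent-adapted weights):

```python
import numpy as np

def reptile_outer_step(W, adapted_weights, eps=0.1):
    """One outer-loop update of Eq. (A.3): move the prior W toward the
    average of the per-task adapted weights, instead of taking an outer
    gradient step. `adapted_weights` is a list of W_hat(W, S_k)."""
    direction = np.mean([Wk - W for Wk in adapted_weights], axis=0)
    return W + eps * direction

W = np.zeros(4)
adapted = [np.ones(4), 3 * np.ones(4)]           # two tasks' adapted weights
W_new = reptile_outer_step(W, adapted, eps=0.5)  # average offset is 2, step size 0.5
```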

B ADDITIONAL EXPERIMENTS

B.1 COMPARISON WITH REPTILE (NICHOL ET AL., 2018)

We compare FOMAML (using the train-validation method) with Reptile, another first-order meta-learning algorithm. Different from FOMAML, Reptile does not require a train-validation split of the data set. In addition, the outer loop of Reptile does not perform gradient descent, but uses a convex combination of the unadapted weights and the task-adapted weights (see Eq. (A.3)). The results are summarized in Table 4. We see that on both RainbowMNIST and miniImagenet, FOMAML with train-validation outperforms Reptile by a small margin. The results in Table 5 show that FOMAML with train-validation outperforms FOMAML with train-train by a large margin under all three neural network structures.
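For contrast with the Reptile update above, the FOMAML outer step with a train-validation split can be sketched on a toy quadratic per-task loss (our own illustrative code and naming, not the paper's implementation): adapt on the training split, then apply the gradient of the validation loss, evaluated at the adapted weights, directly to the prior.

```python
import numpy as np

def grad(w, A, b):
    """Gradient of the toy per-task loss ||A w - b||^2."""
    return 2 * A.T @ (A @ w - b)

def fomaml_outer_step(W, tasks, inner_steps=3, inner_lr=0.01, outer_lr=0.01):
    """One FOMAML outer step. `tasks` is a list of ((A_tr, b_tr), (A_val, b_val))
    pairs, i.e. each task comes with a train split and a validation split."""
    g = np.zeros_like(W)
    for (A_tr, b_tr), (A_val, b_val) in tasks:
        w = W.copy()
        for _ in range(inner_steps):        # inner loop on the training split
            w -= inner_lr * grad(w, A_tr, b_tr)
        # first-order approximation: apply the validation gradient at the
        # adapted weights to W, ignoring d w / d W (no Hessian term)
        g += grad(w, A_val, b_val)
    return W - outer_lr * g / len(tasks)
```

The train-train variant is obtained by passing the same split as both the training and validation data of each task.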

B.4 EFFECT OF THE NUMBER OF SAMPLES PER TASK

We now discuss performance as a function of the number of examples per task. In Table 6, we report the performance on miniImagenet when the number of examples per class is 1, 3, or 5, corresponding to the 5-way 1-shot, 5-way 3-shot, and 5-way 5-shot settings. We observe that as the number of samples per class increases, the performance of both the train-train and train-validation methods improves.

C PRELIMINARY CALCULATIONS

In this section, we present some calculations that will be useful for our later derivations. Recall that $\ell(z) = \log(1+\exp(-z))$, which gives $\ell'(z) = -(1+\exp(z))^{-1}$. Then we can compute the partial derivative of the loss function on the samples from task $k$ as

$$\frac{\partial}{\partial w_{j,r}} L(W, S_k) = \frac{1}{n}\sum_{i=1}^n \frac{\partial}{\partial w_{j,r}} \ell\big(y_{k,i}\cdot f(W, x_{k,i})\big) = \frac{1}{n}\sum_{i=1}^n \ell'\big(y_{k,i}\cdot f(W, x_{k,i})\big)\, y_{k,i}\, j\, \frac{\partial}{\partial w_{j,r}} F_j(W_j, x_{k,i}) = \frac{1}{n}\sum_{i=1}^n \ell'\big(y_{k,i}\cdot f(W, x_{k,i})\big)\, y_{k,i}\, j \sum_{p=1}^2 \sigma'\big(\langle w_{j,r}, x^{(p)}_{k,i}\rangle\big)\, x^{(p)}_{k,i}.$$

Similarly, if we only evaluate the loss on the training set of $S_k$, we have

$$\frac{\partial}{\partial w_{j,r}} L(W, S^{\mathrm{tr}}_k) = \frac{1}{n_1}\sum_{i=1}^{n_1} \ell'\big(y_{k,i}\cdot f(W, x_{k,i})\big)\, y_{k,i}\, j \sum_{p=1}^2 \sigma'\big(\langle w_{j,r}, x^{(p)}_{k,i}\rangle\big)\, x^{(p)}_{k,i}.$$

Let $t \ge 0$ and $\tau \in [0, J]$ be the indices of the outer loop and inner loop, respectively. We remind the reader of our notation: $\ell'^{(t,J)}_{k,i} = \ell'\big(y_{k,i}\cdot f(\widehat{W}(W^{(t)}, S_k), x_{k,i})\big)$ when referring to the train-train method, and $\ell'^{(t,J)}_{k,i} = \ell'\big(y_{k,i}\cdot f(\widehat{W}(W^{(t)}, S^{\mathrm{tr}}_k), x_{k,i})\big)$ when referring to the train-validation method. At outer step $t$, using the samples from task $k$, the inner-loop update of the feature inner product is

$$j\langle w^{(t,\tau+1)}_{j,r}, \nu_k\rangle = j\langle w^{(t,\tau)}_{j,r}, \nu_k\rangle - \frac{\gamma}{n}\sum_{i=1}^n \ell'^{(t,\tau)}_{k,i}\,\sigma'\big(\langle w^{(t,\tau)}_{j,r}, \nu_k\rangle y_{k,i}\big)\,\|\nu_k\|_2^2. \tag{C.1}$$

For the noise $\xi_{k,i}$, the inner-loop update is (supposing we use all the samples from task $k$ for the inner-loop optimization)

$$\langle w^{(t,\tau+1)}_{j,r}, \xi_{k,i}\rangle = \langle w^{(t,\tau)}_{j,r}, \xi_{k,i}\rangle - \frac{\gamma}{n}\sum_{i'\ne i} \ell'^{(t,\tau)}_{k,i'}\, y_{k,i'}\, j\,\sigma'\big(\langle w^{(t,\tau)}_{j,r}, \xi_{k,i'}\rangle\big)\,\langle \xi_{k,i'}, \xi_{k,i}\rangle - \frac{\gamma}{n}\,\ell'^{(t,\tau)}_{k,i}\, y_{k,i}\, j\,\sigma'\big(\langle w^{(t,\tau)}_{j,r}, \xi_{k,i}\rangle\big)\,\|\xi_{k,i}\|_2^2, \tag{C.2}$$

where the first sum goes from $i'=1$ to $i'=n$, omitting the term with $i'=i$. If we consider a noise $\xi_{k,i}$ from the validation set of some task, i.e. $k\in[K]$ and $i\in I^{\mathrm{val}}_k$, and we only use the training set of task $k$ for the inner-loop updates, then

$$\langle w^{(t,\tau+1)}_{j,r}, \xi_{k,i}\rangle = \langle w^{(t,\tau)}_{j,r}, \xi_{k,i}\rangle - \frac{\gamma}{n_1}\sum_{i'=1}^{n_1} \ell'^{(t,\tau)}_{k,i'}\, y_{k,i'}\, j\,\sigma'\big(\langle w^{(t,\tau)}_{j,r}, \xi_{k,i'}\rangle\big)\,\langle \xi_{k,i'}, \xi_{k,i}\rangle. \tag{C.3}$$

If we consider a noise $\xi_{k,i}$ from the training set of a task, i.e. $k\in[K]$ and $i\in I^{\mathrm{tr}}_k$, and we only use the training set of task $k$ for the inner-loop updates, then

$$\langle w^{(t,\tau+1)}_{j,r}, \xi_{k,i}\rangle = \langle w^{(t,\tau)}_{j,r}, \xi_{k,i}\rangle - \frac{\gamma}{n_1}\sum_{i'\ne i} \ell'^{(t,\tau)}_{k,i'}\, y_{k,i'}\, j\,\sigma'\big(\langle w^{(t,\tau)}_{j,r}, \xi_{k,i'}\rangle\big)\,\langle \xi_{k,i'}, \xi_{k,i}\rangle - \frac{\gamma}{n_1}\,\ell'^{(t,\tau)}_{k,i}\, y_{k,i}\, j\,\sigma'\big(\langle w^{(t,\tau)}_{j,r}, \xi_{k,i}\rangle\big)\,\|\xi_{k,i}\|_2^2, \tag{C.4}$$

where the sum goes from $i'=1$ to $i'=n_1$, omitting the term with $i'=i$. Using only the training set of task $k$, the feature inner-product update is

$$j\langle w^{(t,\tau+1)}_{j,r}, \nu_k\rangle = j\langle w^{(t,\tau)}_{j,r}, \nu_k\rangle - \frac{\gamma}{n_1}\sum_{i=1}^{n_1} \ell'^{(t,\tau)}_{k,i}\,\sigma'\big(\langle w^{(t,\tau)}_{j,r}, \nu_k\rangle y_{k,i}\big)\,\|\nu_k\|_2^2. \tag{C.5}$$

Next we look at the outer-loop updates. For the train-train method, the noise update is

$$\langle w^{(t+1,0)}_{j,r}, \xi_{k,i}\rangle = \langle w^{(t,0)}_{j,r}, \xi_{k,i}\rangle - \frac{\eta}{Kn}\sum_{(k',i')\ne(k,i)} \ell'^{(t,J)}_{k',i'}\, y_{k',i'}\, j\,\sigma'\big(\langle w^{(t,J)}_{j,r}, \xi_{k',i'}\rangle\big)\,\langle \xi_{k',i'}, \xi_{k,i}\rangle - \frac{\eta}{Kn}\,\ell'^{(t,J)}_{k,i}\, y_{k,i}\, j\,\sigma'\big(\langle w^{(t,J)}_{j,r}, \xi_{k,i}\rangle\big)\,\|\xi_{k,i}\|_2^2, \tag{C.6}$$

where the sum is over all $k'\in[K]$ and $i'\in[n]$ except $(k',i')=(k,i)$; hence the number of terms in the sum is $Kn-1$. And

$$\ell'^{(t,\tau)}_{k,i} = \ell'\Big(y_{k,i}\sum_{j=\pm 1} j \sum_{r=1}^m \big[\sigma\big(\langle w^{(t,\tau)}_{j,r}, \xi_{k,i}\rangle\big) + \sigma\big(\langle w^{(t,\tau)}_{j,r}, \nu_k\rangle y_{k,i}\big)\big]\Big). \tag{C.7}$$

The arguments of $\sigma'(\cdot)$ in Eq. (C.6) are given by Eq. (C.2). The outer-loop feature update is

$$j\langle w^{(t+1,0)}_{j,r}, \nu\rangle = j\langle w^{(t,0)}_{j,r}, \nu\rangle - \frac{\eta}{Kn}\sum_{k,i} \ell'^{(t,J)}_{k,i}\,\sigma'\big(\langle w^{(t,J)}_{j,r}, \nu_k\rangle y_{k,i}\big)\,\|\nu\|_2^2, \tag{C.8}$$

where the arguments of $\sigma'(\cdot)$ are given by Eq. (C.1), and the arguments of $\sigma(\cdot)$ in Eq. (C.7) are given by Eq. (C.2) and Eq. (C.1). For the train-validation method, consider a noise $\xi_{k,i}$ from the validation set of the $k$-th task, i.e. $i\in I^{\mathrm{val}}_k$. We have

$$\langle w^{(t+1,0)}_{j,r}, \xi_{k,i}\rangle = \langle w^{(t,0)}_{j,r}, \xi_{k,i}\rangle - \frac{\eta}{Kn_2}\sum_{i'>n_1,\,(k',i')\ne(k,i)} \ell'^{(t,J)}_{k',i'}\, y_{k',i'}\, j\,\sigma'\big(\langle w^{(t,J)}_{j,r}, \xi_{k',i'}\rangle\big)\,\langle \xi_{k',i'}, \xi_{k,i}\rangle - \frac{\eta}{Kn_2}\,\ell'^{(t,J)}_{k,i}\, y_{k,i}\, j\,\sigma'\big(\langle w^{(t,J)}_{j,r}, \xi_{k,i}\rangle\big)\,\|\xi_{k,i}\|_2^2, \tag{C.9}$$

where $n_2 = n - n_1$ is the number of validation samples per task, the sum is over all $k'\in[K]$ and $i'\in I^{\mathrm{val}}_{k'}$ except $(k',i')=(k,i)$, and the arguments of $\sigma'(\cdot)$ are given by Eq. (C.3). If we consider a noise $\xi_{k,i}$ from the training set, i.e. $k\in[K]$ and $i\in I^{\mathrm{tr}}_k$, then the outer-loop update is given by

$$\langle w^{(t+1,0)}_{j,r}, \xi_{k,i}\rangle = \langle w^{(t,0)}_{j,r}, \xi_{k,i}\rangle - \frac{\eta}{Kn_2}\sum_{k',\,i'>n_1} \ell'^{(t,J)}_{k',i'}\, y_{k',i'}\, j\,\sigma'\big(\langle w^{(t,J)}_{j,r}, \xi_{k',i'}\rangle\big)\,\langle \xi_{k',i'}, \xi_{k,i}\rangle, \tag{C.10}$$

where the sum is over all $k'\in[K]$ and $i'\in I^{\mathrm{val}}_{k'}$, and the arguments of $\sigma'(\cdot)$ are given by Eq. (C.4). The outer-loop update of the feature using the train-validation method is similar to the train-train case:

$$j\langle w^{(t+1,0)}_{j,r}, \nu\rangle = j\langle w^{(t,0)}_{j,r}, \nu\rangle - \frac{\eta}{Kn_2}\sum_{k,\,i>n_1} \ell'^{(t,J)}_{k,i}\,\sigma'\big(\langle w^{(t,J)}_{j,r}, \nu_k\rangle y_{k,i}\big)\,\|\nu\|_2^2, \tag{C.11}$$

where the arguments of $\sigma'(\cdot)$ in Eq. (C.11) are given by Eq. (C.5). Note that the $\ell'$ terms in Eq. (C.9), Eq. (C.10) and Eq. (C.11) are given by Eq. (C.7), but the arguments of $\sigma(\cdot)$ in Eq. (C.7) are now given by Eq. (C.3) and Eq. (C.5).
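As a quick numeric sanity check (ours, not part of the paper) of the derivative identity $\ell'(z) = -(1+\exp(z))^{-1}$ used throughout these updates:

```python
import numpy as np

def loss(z):
    """l(z) = log(1 + exp(-z)), computed stably via logaddexp(0, -z)."""
    return np.logaddexp(0.0, -z)

def dloss(z):
    """Closed-form derivative: l'(z) = -(1 + exp(z))^{-1}."""
    return -1.0 / (1.0 + np.exp(z))

z = np.linspace(-3, 3, 7)
fd = (loss(z + 1e-6) - loss(z - 1e-6)) / 2e-6   # central finite difference
```

The finite-difference estimate `fd` agrees with `dloss(z)` to several digits, and $-\ell'(z) \in (0, 1)$ for all $z$, which is the boundedness fact $-\ell'(\cdot) \le 1$ invoked repeatedly in the proofs.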

D PRELIMINARY LEMMAS

We will work under the following parameters for the rest of the proof.

Condition D.1. $\|\nu\|_2 = 1$, $\sigma_0 = \frac{1}{\sqrt{d}}$, $\sigma_s = \frac{1}{\sqrt{d}\,\log(d)^{0.4}}$, $\sigma_\xi = \frac{\log(d)^{0.5}}{\sqrt{d}}$, $n = \Theta(1)$, $K = \log(d)^{0.5}$, $m = \log(d)^{0.2}$, $\eta = \frac{1}{\log(d)^{5}}$, $\gamma = \log(d)^{0.4}$, $J = 5$, $h = \frac{1}{2}$, $M = d/2$.

We will also assume that our data split is symmetric: for each $k\in[K]$ and $j\in\{-1,1\}$,

$$\big|\{i : i\in I^{\mathrm{tr}}_k,\, y_{k,i}=j\}\big| > 0, \qquad \big|\{i : i\in I^{\mathrm{val}}_k,\, y_{k,i}=j\}\big| > 0.$$

Lemma D.2. With probability at least $1-(Kn)^{-10}$, the following hold:
1. $\|\xi_{k,i}\|_2^2 = \Theta(\sigma_\xi^2 d)$ for any $k\in[K]$ and $i\in[n]$.
2. $\max\big\{|\langle \xi_{k,i}, \xi_{k',i'}\rangle| : k,k'\in[K],\ i,i'\in[n],\ (k,i)\ne(k',i')\big\} \le O\big(\sigma_\xi^2\sqrt{d\log(Kn)}\big)$.
3. $\max\big\{|\langle w^{(0,0)}_{j,r}, \nu\rangle| : r\in[m]\big\} = \Theta\big(\sigma_0\sqrt{\log(mKn)}\big)$ for any $j\in\{-1,1\}$.
4. $\max\big\{|\langle w^{(0,0)}_{j,r}, \xi_{k,i}\rangle| : r\in[m]\big\} = \Theta\big(\sigma_0\sigma_\xi\sqrt{d\log(mKn)}\big)$ for any $j\in\{-1,1\}$, $k\in[K]$, $i\in[n]$.
5. $\Omega(\sigma_0) \le \max_{r\in[m]} j\langle w^{(0,0)}_{j,r}, \nu\rangle \le O\big(\sigma_0\sqrt{\log(mKn)}\big)$ for any $j\in\{-1,1\}$.
6. $\Omega\big(\sigma_0\sigma_\xi\sqrt{d}\big) \le \max_{r\in[m]} \langle w^{(0,0)}_{j,r}, \xi_{k,i}\rangle \le O\big(\sigma_0\sigma_\xi\sqrt{d\log(mKn)}\big)$ for any $j\in\{-1,1\}$, $k\in[K]$, $i\in[n]$.
7. $\max\big\{|\langle w^{(0,0)}_{j,r}, z_k\rangle| : r\in[m]\big\} = \Theta\big(\sigma_0\sigma_s\sqrt{d\log(mKn)}\big)$ for any $j\in\{-1,1\}$, $k\in[K]$.
8. $\Omega\big(\sigma_0\sigma_s\sqrt{d}\big) \le \max_{r\in[m]} j\langle w^{(0,0)}_{j,r}, z_k\rangle \le O\big(\sigma_0\sigma_s\sqrt{d\log(mKn)}\big)$ for any $j\in\{-1,1\}$, $k\in[M]$.

Proof. By Lemma G.1, with probability at least $1-\delta/3$,

$$\tfrac{1}{2}\sigma_\xi^2 d \le \|\xi_{k,i}\|_2^2 \le \tfrac{3}{2}\sigma_\xi^2 d, \qquad |\langle \xi_{k,i}, \xi_{k',i'}\rangle| \le 2\sigma_\xi^2\sqrt{d\,\log\!\big(12(Kn)^2/\delta\big)},$$

for all $k,k'\in[K]$ and $i,i'\in[n]$ with $(k,i)\ne(k',i')$. By Lemma G.2, with probability at least $1-\delta/3$,

$$|\langle w^{(0,0)}_{j,r}, \nu\rangle| \le \sqrt{2\log(24m/\delta)}\,\sigma_0, \qquad |\langle w^{(0,0)}_{j,r}, \xi_{k,i}\rangle| \le \sqrt{2\log(24mKn/\delta)}\,\sigma_0\sigma_\xi\sqrt{d},$$

for all $r\in[m]$, $j\in\{-1,1\}$, $k\in[K]$ and $i\in[n]$. Moreover,

$$\frac{\sigma_0}{2} \le \max_{r\in[m]} j\langle w^{(0,0)}_{j,r}, \nu\rangle \le \sqrt{2\log(24m/\delta)}\,\sigma_0, \qquad \frac{\sigma_0\sigma_\xi\sqrt{d}}{4} \le \max_{r\in[m]} \langle w^{(0,0)}_{j,r}, \xi_{k,i}\rangle \le \sqrt{2\log(24mKn/\delta)}\,\sigma_0\sigma_\xi\sqrt{d},$$

for all $j\in\{-1,1\}$, $k\in[K]$ and $i\in[n]$. Using similar ideas, we get that with probability at least $1-\delta/3$,

$$|\langle w^{(0,0)}_{j,r}, z_k\rangle| = \big|\big\langle w^{(0,0)}_{j,r}, z_k/\|z_k\|_2\big\rangle\big|\,\|z_k\|_2 \le \frac{3}{2}\,\sigma_s\sqrt{d}\,\big|\big\langle w^{(0,0)}_{j,r}, z_k/\|z_k\|_2\big\rangle\big| \le \frac{3}{2}\sqrt{2\log(12m/\delta)}\,\sigma_0\sigma_s\sqrt{d},$$

for any $k\in[K]$, $r\in[m]$, and $j\in\{-1,1\}$. The second inequality is by Condition 4.1, and the last inequality is again by Lemma G.2. Moreover,

$$\frac{\sigma_0\sigma_s\sqrt{d}}{4} \le \max_{r\in[m]} j\langle w^{(0,0)}_{j,r}, z_k\rangle \le \frac{3}{2}\sqrt{2\log(12m/\delta)}\,\sigma_0\sigma_s\sqrt{d}.$$

Combining the above results, all of these events hold simultaneously with probability at least $1-\delta$. Taking $\delta = (Kn)^{-10}$ gives the desired result.

Remark D.3. Since we will use the estimates in Lemma D.2 repeatedly in the rest of our proofs, we will not restate the high-probability event in our theorems; it should be understood that our theorems hold on the high-probability event of Lemma D.2.

Remark D.4. Under Condition D.1, the term $\log(mKn) = \mathrm{polyloglog}(d)$. Its presence or absence will not affect any of our proofs.

Lemma D.5. Let $\{x_t\}$ and $\{y_t\}$ be two positive sequences updated as

$$x_{t+1} \ge x_t(1+A), \qquad y_{t+1} \le y_t(1+B),$$

where $x_0 = o(1)$, $y_0 = o(1)$, $A = o(1)$, and $B = o(1)$. For any $D = O(1)$, let $T$ be the first iteration at which $x_t \ge D$. Then $y_T \le O(G y_0)$ provided that $x_0 \ge O\big(G^{-A/B}\big)$.

Proof. Consider the two sequences with the inequalities replaced by equalities; note that this does not affect the conclusion. Then $x_t = x_0(1+A)^t$ and $y_t = y_0(1+B)^t$. For simplicity, let us also assume that $x_T = D$ exactly. Then $\frac{D}{x_0} = (1+A)^T$, which entails

$$T = \frac{\log(D) + \log(1/x_0)}{\log(1+A)} = O\Big(\frac{\log(1/x_0)}{A}\Big). \tag{D.1}$$

Let $T'$ be the first iteration such that $y_{T'} \ge G y_0$. Again, for simplicity, suppose that $y_{T'} = G y_0$. Using the same procedure, we obtain

$$T' = \Omega\Big(\frac{\log(G)}{B}\Big). \tag{D.2}$$

Applying our assumption $x_0 \ge O\big(G^{-A/B}\big)$ to Eq. (D.1) and Eq. (D.2), we conclude that $T < T'$, and hence $y_T \le G y_0$.
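Lemma D.5 can be illustrated numerically. In this sketch the constants `A`, `B`, `D`, `G` are ours, chosen so that the lemma's condition on $x_0$ holds: the fast sequence reaches $D$ before the slow sequence has grown by a factor $G$.

```python
# Toy simulation of Lemma D.5: x grows at rate A, y at a much smaller rate B.
A, B, D, G = 0.01, 0.001, 1.0, 2.0
x0 = D * G ** (-A / B) * 1.1   # slightly above the threshold D * G^{-A/B}
y0 = 1e-3
x, y, t = x0, y0, 0
while x < D:                   # run until the first iteration T with x_t >= D
    x *= 1 + A
    y *= 1 + B
    t += 1
growth = y / y0                # growth factor of y at time T; should stay below G
```

Here $A/B = 10$, so the threshold is $D \cdot 2^{-10} \approx 10^{-3}$; starting just above it, $y$ grows by a factor slightly below $G = 2$ by the time $x$ reaches $D$.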

E TRAIN-TRAIN METHOD

E.1 PHASE I

In this section, we will show that under Condition D.1, our neural network will memorize the noise and will not learn the feature. Let Ξ (t) k,i = max w (t,0) j,r , ξ k,i : j = y k,i , r ∈ [m] and Λ (t) j = max j w (t,0) j,r , ν : r ∈ [m] . Let T (1) k,i be the first iteration such that Ξ (t) k,i ≥ m -1/2 1 + γσ 2 ξ d n -J = Θ(1) and let T (1) = max k,i T (1) k,i and ( k, î) = arg max k,i T k,i . So T (1) is the time for the slowest learned noise to have an inner product of size Θ(1). And this is witnessed by the noise vector ξ k,î . Lemma E.1 (Restatement of Lemma 6.1). Under Condition D.1 and Condition 4.1, if one uses the train-train method, then for any t ≤ T (1) 1. For any k ∈ [K], r ∈ [m], and j ∈ {-1, 1} we have j w (t,J) j,r , ν k ≤ | w (t,0) j,r , ν k | (1 + γO(1)) J . 2. For j = y k,î and r = arg max r ∈[m] w (t,0) j,r , ξ k,î , w (t,J) j,r , ξ k,î ≥ w (t,0) j,r , ξ k,î 1 + γ n Ω(1)Θ(σ 2 ξ d) J . Proof. The proof for both parts relies on the following hypothesis which we will verify inductively later: max r j w (t,0) j,r , z k = o(1)Λ (t) j , for all k ∈ [K], j ∈ {-1, 1}. (E.1) - (t,τ ) k,î = Θ(1), for all τ ∈ [J]. (E.2) j w (t,τ ) j,r , ν k = o(1), for all τ ∈ [J], k ∈ [K], r ∈ [m], j ∈ {-1, 1}. (E.3) σ ( w (t,τ ) j,r , ξ k,î ) = Θ(1) w (t,τ ) j,r , ξ k,î , for all τ ∈ [J] if j = y k,î and w (t,τ ) j,r , ξ k,î > 0. (E. 4) Let us first suppose that the above hypothesis hold. By Eq. (C.1), we know j w (t,τ ) j,r , ν k is an increasing sequence in τ . Without loss of generality, suppose that j w (t,0) j,r , ν k > 0. Then Hypothesis E.3 implies σ ( w (t,τ ) j,r , ν k j) = 2j w (t,τ ) j,r , ν k . Then using the fact that -(•) ≤ 1, we can apply Eq. (C.1) repeatedly to get j w (t,J) j,r , ν k ≤ j w (t,0) j,r , ν k (1 + γO(1)) J , (E.5) where we have also used ν k 2 2 = ν 2 2 + z k 2 2 = 1 + o(1) . This proves part (1) of the lemma. Let j = y k,î . Define r(t) = arg max r ∈[m] w (t,0) j,r , ξ k,î . By Eq. 
(C.2), we have w (t,τ +1) j,r(t) , ξ k,î ≥ w (t,τ ) j,r(t) , ξ k,î - γ n i =î O(1)| ξ k,i , ξ k,î | - γ n (t,τ ) k,î σ ( w (t,τ ) j,r(t) , ξ k,î ) ξ k,î 2 2 ≥ w (t,τ ) j,r(t) , ξ k,î -γO(1)Θ(σ 2 ξ √ d log(Kn)) + γ n Θ(1)σ ( w (t,τ ) j,r(t) , ξ k,î )Θ(σ 2 ξ d) , (E.6) where the first inequality is because -(•) ≤ 1 and σ (•) ≤ 1. And we have used Lemma D.2 and Hypothesis E.2 to get the second inequality. Let us compare the size of the second and third term on the right hand side of Eq. (E.6). At (t, τ ) = (0, 0), if we leave out γ in both terms and consider, the third term is of size Θ(σ 0 σ 3 ξ d 3/2 /n) = Θ(log(d) 1.5 / √ d), whereas the second term always has size O(σ 2 ξ √ d log(Kn)) = O(log(d)/ √ d)O(polyloglog(d)) . Hence the third term will dominate the second term at (t, τ ) = (0, 0). Note that the size of the second term does not change as t and τ increase, meaning that if at time (t ≥ 0, τ ≥ 0) we have w (t,τ ) j,r(t) , ξ k,î = Ω(1)Ξ (0) k,î , then it always holds that the third term dominates the second term on the right hand side of Eq. (E.6), which implies w (t,τ +1) j,r(t) , ξ k,î ≥ w (t,τ ) j,r(t) , ξ k,î + γ n Ω(1)σ ( w (t,τ ) j,r(t) , ξ k,î )Θ(σ 2 ξ d) . (E.7) Applying Eq. (E.7) J times at t = 0 and using Hypothesis E.4, we obtain w (0,J) j,r(0) , ξ k,î ≥ w (0,0) j,r(0) , ξ k,î 1 + γ n Ω(1)Θ(σ 2 ξ d) J . (E.8) By Eq. (C.6), we have w (t+1,0) j,r(t) , ξ k,î ≥ w (t,0) j,r(t) , ξ k,î - η Kn (k,i) =( k,î) O(1)| ξ k,i , ξ k,î | + η Kn Θ(1)σ ( w (t,J) j,r(t) , ξ k,î ) ξ k,î 2 2 ≥ w (t,0) j,r(t) , ξ k,î -ηO(1)Θ(σ 2 ξ √ d log(Kn)) + η Kn Θ(1)σ ( w (t,J) j,r(t) , ξ k,î )Θ(σ 2 ξ d) , (E.9) where we have used Hypothesis E.2 and the fact that -(•) ≤ 1 and σ (•) ≤ 1 to get first inequality. Let us compare the second and third term on the right hand side of Eq. (E.9). At t = 0, using Eq. (E.8) we have that . Thus the third term dominates the second term in Eq. (E.9) at t = 0. 
And we get σ ( w (0,J) j,r(0) , ξ k,î ) ≥ Θ(1) w (0,0) j,r(0) , ξ k,î 1 + γ n Ω(1)Θ(σ 2 ξ d) J . (E. w (1,0) j,r(1) , ξ k,î ≥ w (1,0) j,r(0) , ξ k,î ≥ w (0,0) j,r(0) , ξ k,î 1 + η Kn 1 + γ n Ω(1)Θ(σ 2 ξ d) J Θ(σ 2 ξ d) = w (0,0) j,r(0) , ξ k,î 1 + Ω(1) ησ 2 ξ d Kn γσ 2 ξ d n J where the first inequality in the first line is by definition of r(t). We note that the size of the second term on the right hand side of Eq. (E.9) does not change as t and τ increase. Thus, if w (t,0) j,r , ξ k,î = Ω(1)Ξ (0) k,î for some t > 0, Eq. (E.10) implies that σ ( w (t,J) j,r , ξ k,î ) = Ω(1)σ ( w (0,J) j,r , ξ k,î ). Then we have that the third term will dominate the second term on the right hand side of Eq. (E.9) at t > 0, which yields w (t+1,0) j,r(t) , ξ k,î > w (t,0) j,r(t) , ξ k,î since the third term on the right hand side of Eq. (E.9) is positive. We have shown that w (t+1,0) j,r(t+1) , ξ k,î ≥ w (t+1,0) j,r(t) , ξ k,î > w (t,0) j,r(t) , ξ k,î . Then using the same derivation of Eq. (E.8), we obtain w (t,J) j,r(t) , ξ k,î ≥ w (t,0) j,r(t) , ξ k,î 1 + γ n Ω(1)Θ(σ 2 ξ d) J . for any t ≤ T (1) . This proves the second part of the lemma. We have the following theorem characterizing the size of Λ T (1) j . Theorem E.2. Under Condition D.1 and Condition 4.1, if one uses the train-train method, then for any t ≤ T (1) 1. Λ t j = Θ(1)Λ 0 j = O(d -1/2 ), for j ∈ {-1, 1}. 2. For any k ∈ [K], i ∈ [n], j = -y k,i and r ∈ [m], we have w (t,0) j,r , ξ k,i ≤ O(d -1/2 ) . Proof. Recall that the outer-loop for the feature is given by j w (t+1,0) j,r , ν k = j w (t,0) j,r , ν k - η Kn k ,i (t,J) k ,i σ ( w (t,J) j,r , ν k y k ,i ) ν k , ν k , (E.11) where the summation is over all k ∈ [K] and i ∈ 1) > 0, we get that j w (t,0) j,r , ν k is an increasing sequence in t for any k ∈ [K] and r ∈ [m]. And Eq. (C.8) shows that j w (t,0) j,r , ν is an increasing sequence in t. We have [n]. 
Since ν k , ν k = ν 2 + z k , z k ≥ 1 -o( σ ( w (t,J) j,r ν k y k,i ) ≤ 2j w (t,J) j,r ν k ≤ 2j w (t,0) j,r , ν k (1 + γO(1)) J = O(1)j w (t,0) j,r , ν (1 + γO(1)) J , where is first inequality is by Hypothesis E.3 and our assumption that j w (t,0) j,r , ν k > 0. The second inequality is by Lemma E.1. And the last equality is due to our Hypothesis E.1. Without loss of generality, consider j = 1. Let r(t) = arg max r ∈[m] w (t,0) 1,r , ν . We can upper bound the growth of the feature in the outer-loop by Λ (t+1) 1 = w (t+1,0) 1,r(t+1) , ν ≤ w (t,0) 1,r(t+1) , ν 1 + ηO(1) 1 + γΘ(1) J = w (t,0) 1,r(t+1) , ν 1 + O(1)ηγ J ≤ w (t,0) 1,r(t) , ν 1 + O(1)ηγ J = Λ (t) 1 1 + O(1)ηγ J , (E.12) where we have also used the fact that -(•) < 1 by definition. Note that Eq. (E.12) also holds for j = -1, i.e., Λ (t+1) -1 ≤ Λ (t) -1 1 + O(1)ηγ J using the same derivation. Next, we want to find a lower bound on the learning speed of the noise. After that, we wish to show that the lower bound for the noise actually grows much faster than the upper bound of the memorization of the feature, which implies that noise will be learnt sufficiently well before our neural network can pick up any learning on the feature. Consider j = y k,î . Denote r (t) = arg max r∈[m] w (t,0) j,r , ξ k,î . Plugging Lemma E.1 and Hypothesis E.4 into Eq. (E.9) we have the following upper bound hold for all t ≤ T (1) w (t+1,0) j,r (t+1) , ξ k,î ≥ w (t+1,0) j,r (t) , ξ k,î ≥ w (t,0) j,r (t) , ξ k,î 1 + Ω(1) ησ 2 ξ d Kn γσ 2 ξ d n J , (E.13) where the first inequality is by definition of r (t). Now we can compare the learning speed of the feature and noise: Λ (t+1) j ≤ Λ (t) j 1 + O(1)ηγ J , (E.14) Ξ (t+1) k,î ≥ Ξ (t) k,î 1 + Ω(1) ηγ J K σ 2 ξ d n J+1 . (E.15) We can apply Lemma D.5 once we check all of its conditions are satisfied. We have Λ (0) j = Θ(d -1/2 ) = o(1), Ξ (0) k,î = Θ(d -1/2 ) = o(1). Let G = 2, A = O(1)ηγ J and B = Ω(1) ηγ J K σ 2 ξ d n J+1 , we have Ξ (0) k,î = Θ(d -1/2 ) > G -A/B = 2 -log(d) 5.5 . 
Therefore, by Lemma D.5, for any t ≤ T (1) , we have Λ (t) j ≤ 2Λ j . This proves part (1) of our theorem. We now prove part (2) of our theorem. Let k ∈ [K], i ∈ [n], r ∈ [m] and j = -y k,i . By Eq. (C.6) w (t+1,0) j,r , ξ k,i ≤ w (t,0) j,r , ξ k,i - η Kn (k ,i ) =(k,i) (t,J) k ,i y k ,i jσ ( w (t,J) j,r , ξ k ,i ) ξ k ,i , ξ k,i ≤ w (t,0) j,r , ξ k,i + η Kn (k ,i ) =(k,i) O(1)| ξ k ,i , ξ k,i | ≤ w (t,0) j,r , ξ k,i + ηO(1)Θ(σ 2 ξ √ d log(Kn)) , (E.16) where the second inequality is because both -(•) and σ (•) are bounded by 1. Using Eq. (E.15) and the definition of T (1) , we can calculate that T (1) = O(1) under Condition D.1. Therefore, for any t ≤ T (1) -1 we have max r∈[m] w (t+1,0) j,r , ξ k,i ≤ max r∈[m] w (0,0) j,r , ξ k,i + tηO(1)Θ(σ 2 ξ √ d log(Kn)) ≤ O(σ 0 σ ξ √ d log(mKn)) + T (1) ηO(1)Θ(σ 2 ξ √ d log(Kn)) = O(d -1/2 ) , where we have used Lemma D.2 to get the second inequality. It remains to verify Hypothesis E.1-E.4. Let us suppose that Hypothesis E.1-E.4 hold for all t < T (1) . Then we have Λ (t+1) j ≤ 2Λ j , and w (t+1,0) j,r , ξ k,i ≤ O(d -1/2 ) for any k ∈ [K], i ∈ [n], r ∈ [m] and j = -y k,i by Theorem E.2. Proof of Hypothesis E.1. For each task-specific feature z k , we have at t = 0 using Lemma D.2 Ω(σ 0 σ s √ d) ≤ max r j w (0,0) j,r , z k ≤ O(σ 0 σ s √ d log(mKn)) , max r | w (0,0) j,r , ν | = Θ(σ 0 log(mKn)) . Since σ s √ d = o(1) , we see that Hypothesis E.1 holds at t = 0. We want to prove that it holds at t + 1. For any k ∈ [K], r ∈ [m] we have j w (t+1,0) j,r , z k = j w (t,0) j,r , z k - η Kn k =k n i=1 (t,J) k ,i σ ( w (t,J) j,r , ν k y k ,i ) z k , z k - η Kn n i=1 (t,J) k,i σ ( w (t,J) j,r , ν k y k,i ) z k ≤ j w (t,0) j,r , z k - η Kn k =k n i=1 (t,J) k ,i σ ( w (t,J) j,r , ν k y k ,i )Θ(σ 2 s √ d log(d)) - η Kn n i=1 (t,J) k,i σ ( w (t,J) j,r , ν k y k,i )Θ(σ 2 s d) ≤ j w (t,0) j,r , z k - η Kn K k =1 n i=1 (t,J) k ,i σ ( w (t,J) j,r , ν k y k ,i )Θ(σ 2 s d) , (E.17) where we used Condition 4.1 to get the first inequality. 
The last inequality is because log(d) d. Suppose that at time t + 1, there exists some r ∈ [m] and k ∈ [K] such that j w (t+1,0) j,r , z k ≥ 2 max r j w (0,0) j,r , z k ≥ Ω(σ s σ 0 √ d) . Then, by Eq. (E.17) we get - η Kn t t =0 k,i (t ,J) k,i σ ( w (t ,J) j,r , ν k y k,i )Θ(σ 2 s d) ≥ 2 max r j w (0,0) j,r , z k -j w (0,0) j,r , z k ≥ Θ(σ s σ 0 √ d) , - η Kn t t =0 k,i (t,J) k,i σ ( w (t,J) j,r , ν k y k,i ) ≥ Θ σ 0 σ s √ d = Θ log(d) 0.2 √ d . Here the second inequality is by Lemma D.2. But then, Eq. (C.8) implies j w (t+1,0) j,r , ν = j w (0,0) j,r , ν - η Kn t t =0 k,i (t,J) k,i σ ( w (t,J) j,r , ν k y k,i ) ≥ min r∈[m] j w (0,0) j,r , ν + Θ log(d) 0.2 √ d ≥ -Θ(σ 0 log(mKn)) + Θ log(d) 0.2 √ d = Θ(log(d) 0.1 )Λ j . We used Lemma D.2 to get the third line and last line. This is a contradiction, since we should have j w (t+1,0) j,r , ν ≤ Λ (t+1) j ≤ 2Λ (0) j . Therefore for all k ∈ [K] we conclude max r j w (t+1,0) j,r , z k ≤ 2 max r j w (0,0) j,r , z k = o(1) max r j w (0,0) j,r , ν = o(1) max r j w (t+1,0) j,r , ν . This proves that Hypothesis E.1 holds at time t + 1. Proof of Hypothesis E.2. Without loss of generality, let us assume that y k,î = 1. By our estimates in Lemma D.2, Hypothesis E.2 clearly holds at initialization. Suppose it holds at time t. Recall Eq. (C.7) and that j=±1 m r=1 j σ( w (t+1,τ ) j,r , ξ k,î )+σ( w (t+1,τ ) j,r , ν k ) ≤ m r=1 σ( w (t+1,τ ) 1,r , ξ k,î ) + σ( w (t+1,τ ) 1,r , ν k ) , (E.18) by the non-negativity of σ(•). For any r ∈ [m], we have j w (t+1,0) j,r , ν k = j w (t+1,0) j,r , z k + j w (t+1,0) j,r , ν ≤ arg max r ∈[m] j w (t+1,0) j,r , z k + arg max r ∈[m] j w (t+1,0) j,r , ν = (1 + o(1)) arg max r ∈[m] j w (t+1,0) j,r , ν = O(1)Λ (t+1) j ≤ 2Λ (0) j , where we used Hypothesis E.1 to get the third line. Then, by Eq. (E.5) we have that for any τ ∈ [J] w (t+1,τ ) 1,r , ν k ≤ w (t+1,0) 1,r , ν k (1 + γO(1)) τ = O(1)Λ (t+1) 1 (1 + γO(1)) τ = O(1)Λ (0) 1 . 
(E.19) This implies m r=1 σ( w (t+1,τ ) 1,r , ν k ) ≤ O(1)mΛ (0) 1 = o(1) . We also have that m r=1 σ( w (t+1,0) 1,r , ξ k,î ) ≤ mσ(Ξ (t+1) k,î ) = o(1) . where the inequality is due to t + 1 < T (1) which implies Ξ (t+1) k,î < m -1/2 1 + γσ 2 ξ d n -J . Hence, by Eq. (E.18) we get that - (t+1,0) k,î = -(o(1)) = Ω(1). Now, suppose that - (t+1,τ ) k,î = Ω(1) for some 0 ≤ τ ≤ J -1. We wish to show that -  (t+1,τ +1) 1,r , ξ k,î ) ≤ mσ Ξ (t+1) k,î 1 + γ n O(1)Θ(σ 2 ξ d) τ +1 ≤ mσ m -1/2 1 + γσ 2 ξ d n -J 1 + γ n O(1)Θ(σ 2 ξ d) J ≤ O(1) , where the second inequality is because τ + 1 ≤ J. This implies - (t+1,τ +1) k,î = -(O(1)) = Ω (1) and we are done. Proof of Hypothesis E.3. By our estimates in Lemma D.2, Hypothesis E.3 clearly holds at initialization. Suppose it holds at time t. Using Eq. (E.19), we immediately have that for any τ ∈ [J], r ∈ [m] and k ∈ [K], j w (t+1,τ ) j,r , ν k ≤ O(1)Λ (0) j = o(1) . Proof of Hypothesis E.4. By our estimates in Lemma D.2, Hypothesis E.4 clearly holds at initialization. Suppose it holds at time t. Suppose that w (t+1,0) j,r , ξ k,î > 0 for j = y k,î and some r ∈ [m]. Recall that by the assumption of t + 1 ≤ T (1) , we have w (t+1,0) j,r , ξ k,î ≤ Ξ (t+1) k,î ≤ m -1/2 1 + γσ 2 ξ d n -J = o(1) . Hence, by definition of σ (•) we have σ ( w (t+1,0) j,r , ξ k,î ) = 2 w (t+1,0) j,r , ξ k,i . While w (t+1,τ ) j,r , ξ k,î = o(1), we have σ ( w (t+1,τ ) j,r , ξ k,î ) = 2 w (t+1,τ ) j,r , ξ k,î . By Eq. (C.2) we have w (t+1,τ +1) j,r , ξ k,î ≤ w (t+1,τ ) j,r , ξ k,î + γ n i =î O(1)| ξ k,i , ξ k,î | - γ n (t+1,τ ) k,î σ ( w (t+1,τ ) j,r , ξ k,î ) ξ k,î 2 2 = w (t+1,τ ) j,r , ξ k,î + γO(1)O(σ 2 ξ √ d log(Kn)) + γ n O(1)σ ( w (t+1,τ ) j,r , ξ k,î )Θ(σ 2 ξ d) = w (t+1,τ ) j,r , ξ k,î 1 + γ n O(1)Θ(σ 2 ξ d) + O(d -1/2 ) ≤ w (t+1,0) j,r , ξ k,î 1 + γ n O(1)Θ(σ 2 ξ d) τ +1 + O(d -1/2 )τ 1 + γ n O(1)Θ(σ 2 ξ d) τ = w (t+1,0) j,r , ξ k,î 1 + γ n O(1)Θ(σ 2 ξ d) τ +1 + O(d -1/2 ) ≤ Ξ (t+1) k,î 1 + γ n O(1)Θ(σ 2 ξ d) τ +1 + O(d -1/2 ) . 
(E.20) We used -(•) ≤ 1 and σ (•) ≤ 1 to get the first line. We used Lemma D.2 for the second line. It is straightforward to check recursively that w (t+1,τ ) j,r , ξ k,î ≤ o(1) for τ ∈ [J -1] and hence Eq. (E.20) holds for τ ∈ [J -1]. In the meantime, this also implies σ ( w (t+1,τ ) j,r , ξ k,î ) = 2 w (t+1,τ ) j,r , ξ k,î for τ ∈ [J -1]. Then Eq. (E.20) implies w (t+1,J) j,r , ξ k,î ≤ Ξ (t+1) k,î 1 + γ n O(1)Θ(σ 2 ξ d) J + O(d -1/2 ) ≤ O(m -1/2 ) = o(1) . Therefore, σ ( w (t+1,J) j,r , ξ k,î ) = 2 w (t+1,J) j,r , ξ k,î . We note that by our definition, T (1) is the time taken for the slowest-learnt noise to be memorized by our neural network. Other noises are learned faster. Our next theorem guarantees that other noises has almost the same inner product as the slowest-learnt noise at time T (1) . Theorem E.3. Under Condition D.1 and Condition 4.1, suppose one uses the train-train method. For any k ∈ [K], i ∈ [n] with (k, i) = ( k, î), let t ∈ [T k,i , T (1) ], then we have Ξ (t) k,i = Θ(1) . More precisely, it holds that Ξ (t) k,i ≥ Ω m -1/2 1 + γσ 2 ξ d n -J . And Ξ (t) k,i ≤ O m -1/2 1 + γσ 2 ξ d n -J+0.5 . Proof. We first show that Ξ (t) k,i ≥ Ω m -1/2 1 + γσ 2 ξ d n -J = Ω(1). By definition of T (1) k,i , we know that Ξ (t) k,i ≥ m -1/2 1 + γσ 2 ξ d n -J when t = T (1) k,i . Suppose there exists t ∈ [T (1) k,i , T (1) ] such that it is the first iteration such that Ξ (t) k,i < m -1/2 1 + γσ 2 ξ d n -J . If such t does not exist, then we automatically have Ξ (t) k,i ≥ m -1/2 1 + γσ 2 ξ d n -J for all t ∈ [T k,i , T (1) ]. By definition of t we have Ξ (t-1) k,i ≥ m -1/2 1 + γσ 2 ξ d n -J . By Eq. (E.9), we have the following lower bound on Ξ (t) k,i Ξ (t) k,i ≥ Ξ (t-1) k,i -ηO(1)Θ(σ 2 ξ √ d log(Kn)) ≥ m -1/2 1 + γσ 2 ξ d n -J -O(d -1/2 ) = Ω m -1/2 1 + γσ 2 ξ d n -J , where in the first inequality we used -(•) ≤ 1 and σ (•) ≤ 1 and Lemma D.2. The second and the third line are direct calculations under Condition D.1. 
Using the same proof as Hypothesis E.2, we have that - (t,τ ) k,i = Ω(1) for any τ ∈ [J]. Using a similar derivation of Eq. (E.13), we found that Ξ (t+1) k,i ≥ Ξ (t) k,i 1 + Ω(1) ησ 2 ξ d Kn γσ 2 ξ d n J ≥ Ξ (t-1) k,i -O(d -1/2 ) 1 + Ω(1) ησ 2 ξ d Kn γσ 2 ξ d n J > Ξ (t-1) k,i , where the last inequality is because Ξ (t-1) k,i ≥ m -1/2 1 + γσ 2 ξ d n -J = Ω(1) O(d -1/2 ), 1+Ω(1) ησ 2 ξ d Kn γσ 2 ξ d n J = Ω(polylog(d)) O( 1). This shows that even if Ξ (t) k,i < Ξ (t-1) k,i at some t, we still have Ξ (t) k,i = Ω m -1/2 1+ γσ 2 ξ d n -J = Ω(1) and Ξ (t+1) k,i > Ξ (t-1) k,i . Hence, Ξ (t) k,i ≥ Ω(1) for all t ∈ [T (1) k,i , T (1) ] . Next, we need to show that Ξ (t) k,i ≤ O m -1/2 1 + γσ 2 ξ d n -J+0.5 . Let t = inf t ∈ (T k,i , T (1) ] : Ξ ( t) k,i ≥ m -1/2 1 + γσ 2 ξ d n -J+0.5 , ∀ t ∈ [t, T (1) ] , where the inf of an empty set is defined as ∞. If t = ∞ then we are done, because that implies .5 . By definition of t , we know Ξ (T (1) ) k,i < m -1/2 1 + γσ 2 ξ d n -J+0.5 . Assume t is finite. We first need to show that Ξ (t ) k,i = O m -1/2 1 + γσ 2 ξ d n -J+0 Ξ (t -1) k,i < m -1/2 1 + γσ 2 ξ d n -J+0.5 . By Eq. (C.6), we have Ξ (t ) k,i ≤ Ξ (t -1) k,i + O(d -1/2 ) + η Kn Θ(σ 2 ξ d) < m -1/2 1 + γσ 2 ξ d n -J+0.5 + O(d -1/2 ) + Θ ησ 2 ξ d Kn ≤ 1.5 m -1/2 1 + γσ 2 ξ d n -J+0.5 , where we used (•) ≤ 1, σ (•) ≤ 1 and Lemma D.2 in the first inequality. The last inequality is because O(d -1/2 ) + Θ ησ 2 ξ d Kn = o(1)m -1/2 1 + γσ 2 ξ d n -J+0 .5 by a direct calculation. We now show that - 1) ]. Without loss of generality, assume (t,J) k,i = O(exp(-log(d) 0.6 )) for all t ∈ [t , T y k,i = 1. Let r(t) = arg max r∈[m] w (t,0) 1,r , ξ k,i . Using similar proof of Hypothesis E.2, we have that - (t,τ ) k,i = Θ(1) for τ ∈ [J -1]. Since w (t,0) 1,r(t) , ξ k,i ≥ m -1/2 1 + γσ 2 ξ d n -J+0 .5 , we apply Eq. (E.7) J times to get w (t,J) 1,r(t) , ξ k,i ≥ Θ m -1/2 1 + γσ 2 ξ d n 0.5 . 
Then we have - (t,J) k,i = - j=±1 m r =1 j σ( w (t,J) j,r , ξ k,i ) + σ( w (t,J) j,r , ν k ) ≤ -σ( w (t,J) 1,r(t) , ξ k,i ) - m r =1 σ( w (t,J) -1,r , ξ k,i ) + σ( w (t,J) -1,r , ν k ) ≤ -Θ m -1/2 1 + γσ 2 ξ d n 0.5 -o(1) -o(1) ≤ O exp -Θ(1) log(d) 0.6 . (E.21) The second line is because -(•) is non-increasing. To justify the third line, we first note that the second part of Theorem E.2 implies w (t,0) -1,r , ξ k,i = O(d -1/2 ) for all r ∈ [m]. And w (t,τ +1) -1,r , ξ k,i = w (t,τ ) -1,r , ξ k,i + γ n i =i (t,τ ) k,i y k,i σ ( w (t,τ ) -1,r , ξ k,i ) ξ k,i , ξ k,i + γ n (t,τ ) k,i σ ( w (t,τ ) -1,r , ξ k,i ) ξ k,i 2 2 ≤ w (t,τ ) -1,r , ξ k,i + γ n i =i O(1)| ξ k,i , ξ k,i | ≤ w (t,τ ) -1,r , ξ k,i + O(d -1/2 ) , where we used -(•) ≤ 1, σ (•) ≤ 1 and Lemma D.2. This implies w (t,J) -1,r , ξ k,i ≤ w (t,0) -1,r , ξ k,i + J O(d -1/2 ) = O(d -1/2 ) , since J = O(1) and w (t,0) -1,r , ξ k,i = O(d -1/2 ). Hence, m r =1 σ( w (t,J) -1,r , ξ k,i ) ≤ m O(d -1/2 ) = o(1) . By Eq. (C.1) and Eq. (E.11) we get that w (t,J) -1,r , ν k ≤ w (t,0) -1,r , ν k ≤ w (0,0) -1,r , ν k = O(d -1/2 log(m)) , where the last equality is due to Lemma D.2. Thus, m r=1 σ( w (t,J) -1,r , ν k ) ≤ m O(d -1/2 log(m)) = o(1) . This justifies the third line of Eq. (E.21). Using Eq. (E.15), one can directly calculate that T (1) = O(polylog(d)). Then, using Eq. (C.6) we have Ξ (t+1) k,i ≤ Ξ (t) k,i + O(d -1/2 ) - η Kn (t,J) k,i Θ(σ 2 ξ d) ≤ Ξ (t) k,i + O(d -1/2 ) + η Kn Θ(σ 2 ξ d)O exp -log(d) 0.6 , Ξ (T (1) ) k,i ≤ Ξ (t ) k,i + (T (1) -t ) O(d -1/2 ) + η Kn Θ(σ 2 ξ d)O exp -log(d) 0.6 ≤ Ξ (t ) k,i + T (1) O(d -1/2 ) + η Kn Θ(σ 2 ξ d)O exp -log(d) 0.6 = Ξ (t ) k,i + O(d -1/2 ) + Θ(1)O exp -log(d) 0.6 ≤ 2Ξ (t ) k,i ≤ 3 m -1/2 1 + γσ 2 ξ d n -J+0.5 . In the second last line, we can absorb T (1) into O(d -1/2 ). And T (1) η Kn Θ(σ 2 ξ d) = Θ(1). To get the last line, recall that Ξ (t ) k,i ≥ m -1/2 1 + γσ 2 ξ d n -J+0.5 = Ω(1/polylog(d)). And Θ(1)O exp - log(d) 0.6 = o(1/polylog(d)) . 
This concludes our proof. In short, we have shown that in Phase I the feature is not learned well by any neuron, as characterized by the inequality $\Lambda^{(T^{(1)})}_j \le O(d^{-1/2})$ for $j \in \{-1, 1\}$, while for each noise vector $\xi_{k,i}$ there exists a weight $w^{(T^{(1)},0)}_{j,r}$ with $j = y_{k,i}$ whose inner product with the noise is large: $\langle w^{(T^{(1)},0)}_{j,r}, \xi_{k,i}\rangle = \Theta(1)$.
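The Phase I picture can be illustrated by simulating the two growth recursions of Eq. (E.14)-(E.15). The rate constants below are illustrative stand-ins of our own choosing for the feature rate $O(\eta\gamma J)$ and the much larger noise rate, not the exact values under Condition D.1.

```python
# Toy simulation of Phase I: the noise inner product Xi grows at per-step
# rate B_rate while the feature inner product Lambda grows at a much smaller
# rate A_rate, so Xi reaches Theta(1) while Lambda remains O(d^{-1/2}).
d = 10 ** 6
A_rate = 0.002           # stand-in for the feature factor O(1) * eta * gamma * J
B_rate = 0.5             # stand-in for the noise factor in Eq. (E.15)
Lam = Xi = d ** -0.5     # both inner products start at Theta(d^{-1/2})
t = 0
while Xi < 1.0:          # stop once the noise is memorized (inner product Theta(1))
    Lam *= 1 + A_rate
    Xi *= 1 + B_rate
    t += 1
```

After the loop, `Xi` has reached size $1$ while `Lam` is still within a constant factor of its initialization $d^{-1/2}$, mirroring Theorem E.2.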

E.2 PHASE II

We have already proven that at the end of Phase I, all of the noise vectors are memorized by the network with an inner product of size $\Theta(1)$, whereas the feature is barely memorized, with an inner product of size $O(d^{-1/2})$. After Phase I, there is no guarantee that $-\ell'^{(t,J)}_{k,i} = \Omega(1)$. For this reason, the learning of both the feature and the noise will slow down. The difficulty of the analysis at this stage is that we no longer have $-\ell'^{(t,J)}_{k,i} = \Theta(1)$: the value of $-\ell'^{(t,J)}_{k,i}$ shrinks as the inner products $\langle w^{(t,J)}_{j,r}, \xi_{k,i}\rangle$ grow. Moreover, if we examine Eq. (C.6), we observe that the second term on the right-hand side could have a non-trivial effect on the growth of the noise inner product if the third term becomes exponentially smaller than the second term because of $\ell'^{(t,J)}_{k,i}$. But the sign of each summand in the second term is undetermined, which makes the analysis even harder. We first make some claims that will help prove our theorems, and we will verify these claims at the end of this section. Let $T = \mathrm{poly}(d)$ be the total run time. We have the following claims under Condition D.1 and Condition 4.1.

Claim E.4. For any $k \in [K]$, $i \in [n]$, $r \in [m]$, and $t \in [T^{(1)}, T]$, let $j = -y_{k,i}$; then $\langle w^{(t,0)}_{j,r}, \xi_{k,i}\rangle = O(d^{-1/2})$ and $\langle w^{(t,J)}_{j,r}, \xi_{k,i}\rangle = O(d^{-1/2})$.

Claim E.5. For any $j \in \{-1, 1\}$ and $t \in [T^{(1)}, T]$, we have $\Lambda^{(t)}_j \le O(d^{-1/4})$.

Claim E.6. For any $j \in \{-1, 1\}$, $k \in [K]$, and $t \in [T^{(1)}, T]$, we have $\max_r j\langle w^{(t,0)}_{j,r}, z_k\rangle = o(1)\Lambda^{(t)}_j$.

Lemma E.7. Let $t \in [T^{(1)}, T]$, $k \in [K]$, $i \in [n]$, $j = y_{k,i}$, and $r = \arg\max_{r' \in [m]} \langle w^{(t,0)}_{j,r'}, \xi_{k,i}\rangle$. Suppose Claims E.4, E.5, and E.6 hold at $t$. Then

$$\langle w^{(t,0)}_{j,r}, \xi_{k,i}\rangle \ge \Omega\Big(m^{-1/2}\Big(1 + \frac{\gamma\sigma_\xi^2 d}{n}\Big)^{-J}\Big), \qquad \langle w^{(t,J)}_{j,r}, \xi_{k,i}\rangle \ge \Omega(m^{-1/2}).$$

Proof. The proof of the first inequality is very similar to the first part of the proof of Theorem E.3, so we omit it and focus on the second inequality. We first recall from Eq.
(C.2) that
$$\langle \mathbf{w}^{(t,\tau+1)}_{j,r}, \xi_{k,i} \rangle = \langle \mathbf{w}^{(t,\tau)}_{j,r}, \xi_{k,i} \rangle - \frac{\gamma}{n} \sum_{i' \ne i} \ell'^{(t,\tau)}_{k,i'}\, y_{k,i'}\, j\, \sigma'(\langle \mathbf{w}^{(t,\tau)}_{j,r}, \xi_{k,i'} \rangle)\, \langle \xi_{k,i'}, \xi_{k,i} \rangle - \frac{\gamma}{n}\, \ell'^{(t,\tau)}_{k,i}\, \sigma'(\langle \mathbf{w}^{(t,\tau)}_{j,r}, \xi_{k,i} \rangle)\, \|\xi_{k,i}\|_2^2 \ge \langle \mathbf{w}^{(t,\tau)}_{j,r}, \xi_{k,i} \rangle - O(d^{-1/2}) - \frac{\gamma}{n}\, \ell'^{(t,\tau)}_{k,i}\, \sigma'(\langle \mathbf{w}^{(t,\tau)}_{j,r}, \xi_{k,i} \rangle)\, \|\xi_{k,i}\|_2^2. \qquad (E.22)$$
If $\langle \mathbf{w}^{(t,0)}_{j,r}, \xi_{k,i} \rangle > m^{-1/2}$, then we have $\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle \ge \langle \mathbf{w}^{(t,0)}_{j,r}, \xi_{k,i} \rangle - J\, O(d^{-1/2}) = \Omega(m^{-1/2})$, where the last equality is because $m^{-1/2} \gg J\, O(d^{-1/2})$. By the first part of this lemma, we have $\langle \mathbf{w}^{(t,0)}_{j,r}, \xi_{k,i} \rangle \ge \Omega\big( m^{-1/2} (1 + \gamma \sigma_\xi^2 d / n)^{-J} \big)$. Let us consider
$$\Omega\Big( m^{-1/2} \Big(1 + \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{-(\tau'+1)} \Big) \le \langle \mathbf{w}^{(t,0)}_{j,r}, \xi_{k,i} \rangle \le O\Big( m^{-1/2} \Big(1 + \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{-\tau'} \Big)$$
for some $\tau' \in [J-1]$. Using a similar argument to the proof of Hypothesis E.2, it can be shown that $-\ell'^{(t,\tau)}_{k,i} = \Omega(1)$ for all $\tau \le \tau'$ if $\langle \mathbf{w}^{(t,0)}_{j,r}, \xi_{k,i} \rangle \le O\big( m^{-1/2} (1 + \gamma \sigma_\xi^2 d / n)^{-\tau'} \big)$. Applying Eq. (E.22) $\tau'+1$ times, we get
$$\langle \mathbf{w}^{(t,\tau'+1)}_{j,r}, \xi_{k,i} \rangle \ge \langle \mathbf{w}^{(t,0)}_{j,r}, \xi_{k,i} \rangle\, \Omega\Big( \Big(1 + \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{\tau'+1} \Big),$$
which implies $\langle \mathbf{w}^{(t,\tau'+1)}_{j,r}, \xi_{k,i} \rangle \ge \Omega(m^{-1/2})$ if $\langle \mathbf{w}^{(t,0)}_{j,r}, \xi_{k,i} \rangle \ge \Omega\big( m^{-1/2} (1 + \gamma \sigma_\xi^2 d / n)^{-(\tau'+1)} \big)$. Lastly, Eq. (E.22) also implies $\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle \ge \langle \mathbf{w}^{(t,\tau'+1)}_{j,r}, \xi_{k,i} \rangle - (J - (\tau'+1))\, O(d^{-1/2}) = \Omega(m^{-1/2})$. This proves the lemma.

Published as a conference paper at ICLR 2023

Lemma E.8. Suppose Claims E.4, E.5, and E.6 hold at $t$. Let $k \in [K]$, $i \in [n]$. If $\Xi^{(t)}_{k,i} = \Theta\big( m^{-1/2} (1 + \gamma \sigma_\xi^2 d / n)^{-J+1} \big)$, then $-\ell'^{(t,J)}_{k,i} = O(d^{-\log(d)^{0.3}})$.

Proof. Without loss of generality, assume $y_{k,i} = 1$. Let $r(t) = \arg\max_{r \in [m]} \langle \mathbf{w}^{(t,0)}_{1,r}, \xi_{k,i} \rangle$. The argument of $\ell'^{(t,J)}_{k,i}$ satisfies
$$\sum_{j = \pm 1} \sum_{r=1}^m j \big[ \sigma(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle) + \sigma(\langle \mathbf{w}^{(t,J)}_{j,r}, \nu_k \rangle) \big] \ge \sigma(\langle \mathbf{w}^{(t,J)}_{1,r(t)}, \xi_{k,i} \rangle) - \sum_{r=1}^m \big[ \sigma(\langle \mathbf{w}^{(t,J)}_{-1,r}, \xi_{k,i} \rangle) + \sigma(\langle \mathbf{w}^{(t,J)}_{-1,r}, \nu_k \rangle) \big].$$
Using a similar argument to the proof of Hypothesis E.2, it can be shown that $-\ell'^{(t,\tau)}_{k,i} = \Omega(1)$ for all $\tau \in [J-1]$ under Claims E.5 and E.6. We then apply Eq.
(E.22) $J-1$ times to get
$$\langle \mathbf{w}^{(t,J)}_{1,r(t)}, \xi_{k,i} \rangle \ge \big( \langle \mathbf{w}^{(t,0)}_{1,r(t)}, \xi_{k,i} \rangle - J\, O(d^{-1/2}) \big) \Big(1 + \Theta(1) \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{J} = \big( \Xi^{(t)}_{k,i} - J\, O(d^{-1/2}) \big) \Big(1 + \Theta(1) \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{J} = \Omega(1)\, \Xi^{(t)}_{k,i} \Big(1 + \Theta(1) \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{J} = \Omega(1)\, m^{-1/2}\, \frac{\gamma \sigma_\xi^2 d}{n}.$$
The third equality is because by assumption we have $\Xi^{(t)}_{k,i} = \Theta\big( m^{-1/2} (1 + \gamma \sigma_\xi^2 d / n)^{-J+1} \big) \gg J\, O(d^{-1/2})$. On the other hand, Claim E.4 tells us that $\langle \mathbf{w}^{(t,J)}_{-1,r}, \xi_{k,i} \rangle = O(d^{-1/2})$ for all $r \in [m]$. Hence,
$$\sum_{r=1}^m \sigma(\langle \mathbf{w}^{(t,J)}_{-1,r}, \xi_{k,i} \rangle) = m\, \sigma(O(d^{-1/2})) = o(1).$$
By Eq. (C.1) and Eq. (E.11), we get that $\langle \mathbf{w}^{(t,J)}_{-1,r}, \nu_k \rangle \le \langle \mathbf{w}^{(t,0)}_{-1,r}, \nu_k \rangle \le \langle \mathbf{w}^{(0,0)}_{-1,r}, \nu_k \rangle = O(d^{-1/2})$. Thus,
$$\sum_{r=1}^m \sigma(\langle \mathbf{w}^{(t,J)}_{-1,r}, \nu_k \rangle) = m\, \sigma(O(d^{-1/2})) = o(1).$$
We now have
$$-\ell'^{(t,J)}_{k,i} \le -\ell'\Big( \sigma(\langle \mathbf{w}^{(t,J)}_{1,r(t)}, \xi_{k,i} \rangle) - o(1) - o(1) \Big) = O(1)\, \exp\Big( -\Omega(1)\, m^{-1/2}\, \frac{\gamma \sigma_\xi^2 d}{n} \Big) = O(d^{-\log(d)^{0.3}}).$$
This proves the lemma.

Theorem E.9. If Claims E.4, E.5, and E.6 hold for all $T^{(1)} \le t \le T-1$, then
$$\Xi^{(T)}_{k,i} \le O\Big( m^{-1/2} \Big(1 + \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{-J+1} \Big), \quad \text{for all } k \in [K],\ i \in [n].$$

Proof. Set $A(q) = m^{-1/2} (1 + \gamma \sigma_\xi^2 d / n)^{-q}$ for $q \in [J]$. We proved in Theorem E.3 that $\Omega(A(J)) \le \Xi^{(T^{(1)})}_{k,i} \le O(A(J - 0.5))$ for all $k \in [K]$, $i \in [n]$. Define
$$t(k,i) = \inf\big\{ t < T : \Xi^{(t')}_{k,i} > A(J-1),\ \forall t' > t \big\},$$
where the infimum over an empty set is defined as $\infty$. In words, $t(k,i)$ is the last iteration with $\Xi^{(t)}_{k,i} \le A(J-1)$. We have shown in the proof of Theorem E.3 that $\Xi^{(t(k,i)+1)}_{k,i} < 1.5\, A(J-1)$. Similarly, define
$$\bar t(k,i) = \inf\big\{ t(k,i) \le t \le T : \Xi^{(t')}_{k,i} > 4\, A(J-1),\ \forall t' > t \big\},$$
where again the infimum over an empty set is defined as $\infty$, so that $\bar t(k,i)$ is the last iteration with $\Xi^{(t)}_{k,i} \le 4\, A(J-1)$. Our goal is to show that $\bar t(k,i) = \infty$ for any $k \in [K]$ and $i \in [n]$. We will prove this by contradiction. Suppose that $\bar t(k,i) < \infty$. Then, for any $t \in [t(k,i)+1, \bar t(k,i)]$, we have $-\ell'^{(t,J)}_{k,i} = O(d^{-\log(d)^{0.3}})$ by Lemma E.8.
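The rate appearing in Lemma E.8 above can be sanity-checked numerically: for the logistic loss, $-\ell'(x) = 1/(1+e^x) \le e^{-x}$, and since $d^{-\log(d)^{0.3}} = \exp(-\log(d)^{1.3})$ (natural logarithm), a margin of size $\log(d)^{1.3}$ produces exactly that scale. In the sketch below the margin value is a hypothetical stand-in for the $\Omega(m^{-1/2} \gamma \sigma_\xi^2 d / n)$ quantity in the proof:

```python
import math

def neg_ell_prime(x):
    # logistic loss ell(x) = log(1 + exp(-x)); its derivative gives -ell'(x) = 1/(1 + exp(x))
    return 1.0 / (1.0 + math.exp(x))

d = 10_000
margin = math.log(d) ** 1.3          # hypothetical stand-in for the proved margin lower bound
val = neg_ell_prime(margin)          # derivative at that margin
bound = d ** -(math.log(d) ** 0.3)   # the d^{-log(d)^{0.3}} rate, equal to exp(-log(d)^{1.3})

print(val, bound)
```

Since $1/(1+e^x) < e^{-x}$ for all $x$, the computed derivative sits strictly below the stated rate.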
Recall that (with $\bar t(k,i)$ denoting the last iteration with $\Xi^{(t)}_{k,i} \le 4\, A(J-1)$)
$$\Xi^{(t+1)}_{k,i} \ge \Xi^{(t)}_{k,i} + \frac{\eta}{Kn} \sum_{(k',i') \ne (k,i)} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn}\, \ell'^{(t,J)}_{k,i}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle)\, \Theta(\sigma_\xi^2 d) \ge \Xi^{(t)}_{k,i} + \frac{\eta}{Kn} \sum_{(k',i') \ne (k,i)} \ell'^{(t,J)}_{k',i'}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn}\, \ell'^{(t,J)}_{k,i}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle)\, \Theta(\sigma_\xi^2 d), \qquad (E.23)$$
where the second inequality is because $\sigma'(\cdot) \le 1$. And similarly we have
$$\Xi^{(t+1)}_{k,i} \le \Xi^{(t)}_{k,i} - \frac{\eta}{Kn} \sum_{(k',i') \ne (k,i)} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn}\, \ell'^{(t,J)}_{k,i}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle)\, \Theta(\sigma_\xi^2 d) \le \Xi^{(t)}_{k,i} - \frac{\eta}{Kn} \sum_{(k',i') \ne (k,i)} \ell'^{(t,J)}_{k',i'}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn}\, \ell'^{(t,J)}_{k,i}\, \Theta(\sigma_\xi^2 d), \qquad (E.24)$$
where we have again used $\sigma'(\cdot) \le 1$ to get the second inequality. We have the following bound due to Lemma E.8:
$$-\sum_{t=t(k,i)+1}^{\bar t(k,i)} \frac{\eta}{Kn}\, \ell'^{(t,J)}_{k,i}\, \Theta(\sigma_\xi^2 d) = \sum_{t=t(k,i)+1}^{\bar t(k,i)} \frac{\eta}{Kn}\, \Theta(\sigma_\xi^2 d)\, O(d^{-\log(d)^{0.3}}) \le \frac{\eta \sigma_\xi^2 d}{Kn}\, T\, O(d^{-\log(d)^{0.3}}) = o(d^{-1}), \qquad (E.25)$$
where the last equality is because $T = O(\mathrm{poly}(d))$. Let
$$(k', i') = \arg\max_{(\hat k, \hat i) \ne (k,i)} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{\hat k, \hat i}.$$
Suppose it holds that
$$\sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} < d^{1/4}. \qquad (E.26)$$
Then using Eq. (E.24) we get
$$\Xi^{(\bar t(k,i)+1)}_{k,i} \le \Xi^{(t(k,i)+1)}_{k,i} - \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{\bar t(k,i)} \sum_{(\hat k, \hat i) \ne (k,i)} \ell'^{(t,J)}_{\hat k, \hat i}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{\bar t(k,i)} \ell'^{(t,J)}_{k,i}\, \Theta(\sigma_\xi^2 d) \le \Xi^{(t(k,i)+1)}_{k,i} + \eta\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)})\, d^{1/4} + o(d^{-1}) \le 2\, \Xi^{(t(k,i)+1)}_{k,i} \le 3\, A(J-1).$$
The second inequality is due to Eq. (E.25), Eq. (E.26), and the definition of $(k', i')$. The third inequality is because $\eta\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)})\, d^{1/4} + o(d^{-1}) = O(d^{-1/4})$ and $\Xi^{(t(k,i)+1)}_{k,i} \ge A(J-1) \gg O(d^{-1/4})$. Thus Eq. (E.26) contradicts $\Xi^{(\bar t(k,i)+1)}_{k,i} \ge 4\, A(J-1)$, and we must have $\bar t(k,i) = \infty$. Next, we will show that if Eq. (E.26) does not hold, then we also have $\bar t(k,i) = \infty$. Let us suppose that
$$\sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \ge d^{1/4}.$$
Since $-\ell'^{(t,J)}_{k',i'} > 0$ for every $t$, the partial sum $\sum_{t=t(k,i)+1}^{t_0} -\ell'^{(t,J)}_{k',i'}$ is increasing in $t_0$.
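The stopping times used in this proof are "last crossing" indices: the last iteration at which the trace $\Xi^{(t)}_{k,i}$ sits at or below a threshold, with the infimum of an empty set read as $\infty$. A minimal helper makes the definition concrete; the trace values below are hypothetical:

```python
def last_time_at_or_below(xs, A):
    # inf{ t : xs[t'] > A for all t' > t }, with inf(empty set) = infinity.
    # For a finite trace this is the last index whose value is <= A,
    # or -1 if the whole trace already sits strictly above A.
    idx = [t for t, x in enumerate(xs) if x <= A]
    return idx[-1] if idx else -1

trace = [0.2, 0.9, 0.4, 1.1, 1.3, 1.2]   # toy trace of Xi^{(t)}
t_cross = last_time_at_or_below(trace, 1.0)
print(t_cross)  # 2: after step 2 the trace stays above the threshold
```

The contradiction argument shows the analogous crossing of $4A(J-1)$ never happens, i.e. the "last crossing" is at infinity.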
Let $t(k,i) + 1 < t_0 < \bar t(k,i)$ be the first iteration such that
$$\sum_{t=t(k,i)+1}^{t_0} -\ell'^{(t,J)}_{k',i'} \ge d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \ge d^{-1/8 + 1/4} = d^{1/8}.$$
Since each $-\ell'^{(t,J)}_{k',i'} < 1$, we get that
$$\sum_{t=t(k,i)+1}^{t_0} -\ell'^{(t,J)}_{k',i'} = \sum_{t=t(k,i)+1}^{t_0-1} -\ell'^{(t,J)}_{k',i'} \; - \; \ell'^{(t_0,J)}_{k',i'} \le 1 + d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \le 2\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'}, \qquad (E.27)$$
where the second and third inequalities are by the definition of $t_0$. We can sum Eq. (E.23) over time to obtain
$$\sum_{t=t(k,i)+1}^{t_0} \Xi^{(t+1)}_{k',i'} \ge \sum_{t=t(k,i)+1}^{t_0} \Xi^{(t)}_{k',i'} + \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{t_0} \sum_{(\hat k, \hat i) \ne (k',i')} \ell'^{(t,J)}_{\hat k, \hat i}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{t_0} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 d) \ge \sum_{t=t(k,i)+1}^{t_0} \Xi^{(t)}_{k',i'} + \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{\bar t(k,i)} \sum_{(\hat k, \hat i) \ne (k',i')} \ell'^{(t,J)}_{\hat k, \hat i}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{t_0} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 d) \ge \sum_{t=t(k,i)+1}^{t_0} \Xi^{(t)}_{k',i'} + \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{\bar t(k,i)} Kn\, \ell'^{(t,J)}_{k',i'}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{t_0} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 d),$$
where the second inequality is because $t_0 < \bar t(k,i)$. We have used the definition of $(k', i')$ to get the last inequality. Then, subtracting $\sum_{t=t(k,i)+1}^{t_0} \Xi^{(t)}_{k',i'}$ from both sides yields
$$\Xi^{(t_0+1)}_{k',i'} - \Xi^{(t(k,i)+1)}_{k',i'} \ge -\Theta\Big( \frac{\eta \sigma_\xi^2 \sqrt{d \log(Kn)}}{Kn} \Big)\, Kn \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} + \Theta\Big( \frac{\eta \sigma_\xi^2 d}{Kn \sqrt m} \Big) \sum_{t=t(k,i)+1}^{t_0} -\ell'^{(t,J)}_{k',i'} \ge -\Theta(d^{-1/2}) \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} + \Theta(1)\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} = \Theta(1)\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \ge \Theta(d^{1/8}), \qquad (E.28)$$
where we used Lemma E.7 (which gives $\sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k',i'} \rangle) \ge \Omega(m^{-1/2})$) to get the first inequality, Eq. (E.27) and the definition of $t_0$ to get the second, and the assumption $\sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \ge d^{1/4}$ for the last inequality. It is then straightforward to check that $\Xi^{(t)}_{k',i'} \ge \Theta(d^{1/8})$ for all $t_0 + 1 \le t \le \bar t(k,i)$. To see this, let us sum Eq.
(E.23) over time to obtain, for any $t_0 + 1 \le \tilde t \le \bar t(k,i)$,
$$\sum_{t=t_0+1}^{\tilde t - 1} \Xi^{(t+1)}_{k',i'} \ge \sum_{t=t_0+1}^{\tilde t - 1} \Xi^{(t)}_{k',i'} + \frac{\eta}{Kn} \sum_{t=t_0+1}^{\tilde t - 1} \sum_{(\hat k, \hat i) \ne (k',i')} \ell'^{(t,J)}_{\hat k, \hat i}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn} \sum_{t=t_0+1}^{\tilde t - 1} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 d) \ge \sum_{t=t_0+1}^{\tilde t - 1} \Xi^{(t)}_{k',i'} + \frac{\eta}{Kn} \sum_{t=t_0+1}^{\tilde t - 1} \sum_{(\hat k, \hat i) \ne (k',i')} \ell'^{(t,J)}_{\hat k, \hat i}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) \ge \sum_{t=t_0+1}^{\tilde t - 1} \Xi^{(t)}_{k',i'} + \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{\bar t(k,i)} \sum_{(\hat k, \hat i) \ne (k',i')} \ell'^{(t,J)}_{\hat k, \hat i}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) \ge \sum_{t=t_0+1}^{\tilde t - 1} \Xi^{(t)}_{k',i'} + \frac{\eta}{Kn} \sum_{t=t(k,i)+1}^{\bar t(k,i)} Kn\, \ell'^{(t,J)}_{k',i'}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}).$$
Subtracting the common terms and telescoping, we get
$$\Xi^{(\tilde t)}_{k',i'} \ge \Xi^{(t_0+1)}_{k',i'} - \Theta\Big( \frac{\eta \sigma_\xi^2 \sqrt{d \log(Kn)}}{Kn} \Big)\, Kn \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \ge \Theta(1)\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} - \Theta(d^{-1/2}) \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} = \Theta(1)\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \ge \Theta(d^{1/8}),$$
where the second inequality in the first display is because we are removing a positive term from the sum, the third is because $\tilde t - 1 < \bar t(k,i)$ and $t_0 > t(k,i)$, and the fourth is by the definition of $(k', i')$. The second-to-last inequality above is by the second-to-last line of Eq. (E.28), and the last inequality follows Eq. (E.28). Then, using a similar proof to that of Lemma E.8, we have $-\ell'^{(t,J)}_{k',i'} = O(\exp(-\Theta(d^{1/8})))$ for all $t_0 + 1 \le t \le \bar t(k,i)$. This yields
$$\sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} = \sum_{t=t(k,i)+1}^{t_0} -\ell'^{(t,J)}_{k',i'} + \sum_{t=t_0+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \le 2\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} + \sum_{t=t_0+1}^{\bar t(k,i)} O(\exp(-\Theta(d^{1/8}))) \le 2\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} + T \cdot O(\exp(-\Theta(d^{1/8}))) \le 2\, d^{-1/8} \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} + o(d^{-1}) \le (2\, d^{-1/8} + o(d^{-1})) \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'}, \qquad (E.29)$$
where we used Eq. (E.27) to get the second line and $T = O(\mathrm{poly}(d))$ to get the fourth line. To obtain the last inequality, we recall the assumption $\sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} > d^{1/4}$, which implies $o(d^{-1}) < o(d^{-1}) \sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'}$. Eq. (E.29) is a contradiction of the form $x \le A x$ for some $x \ge 0$ and $A < 1$, which forces $x = 0$. This implies that under the assumption $\sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} > d^{1/4}$, we should still have $\bar t(k,i) = \infty$.
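The convergence statement proved next turns a bound on $\sum_{k,i} -\ell'$ into a bound on the training loss via the elementary sandwich $-\ell'(x) \ge e^{-x}/2 \ge \ell(x)/2$ for $x > 0$, where $\ell$ is the logistic loss. A quick numeric check of this sandwich on a few sample points:

```python
import math

def ell(x):
    # logistic loss: ell(x) = log(1 + exp(-x))
    return math.log1p(math.exp(-x))

def neg_ell_prime(x):
    # -ell'(x) = 1 / (1 + exp(x))
    return 1.0 / (1.0 + math.exp(x))

for x in [0.01, 0.5, 1.0, 3.0, 10.0]:
    # -ell'(x) >= exp(-x)/2 >= ell(x)/2 holds for every x > 0
    assert neg_ell_prime(x) >= math.exp(-x) / 2 >= ell(x) / 2
print("sandwich holds on the sampled points")
```

The first inequality uses $1 + e^{-x} \le 2$ for $x \ge 0$, and the second uses $\log(1+u) \le u$.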
Combining this with our previous result, which says $\bar t(k,i) = \infty$ when $\sum_{t=t(k,i)+1}^{\bar t(k,i)} -\ell'^{(t,J)}_{k',i'} \le d^{1/4}$, we conclude that $\bar t(k,i) = \infty$.

Theorem E.10. Under Condition D.1 and Condition 4.1, if one uses the train-train method, then $\min_{t \in [T]} L_{\text{tr-tr}}(\mathbf{W}^{(t)}) \le \Theta(T^{-1} \eta^{-1})$.

Proof. By Theorem E.9, we know that
$$\Xi^{(T)}_{k,i} \le O\Big( m^{-1/2} \Big(1 + \frac{\gamma \sigma_\xi^2 d}{n}\Big)^{-J+1} \Big) = O(1), \quad \text{for all } k \in [K] \text{ and } i \in [n].$$
Denote $\Xi^{(t)} = \sum_{k,i} \Xi^{(t)}_{k,i}$. It then follows that $\Xi^{(T)} \le Kn\, O(1) = O(1)$. Moreover, if we sum Eq. (E.23) over all $k, i$, we get
$$\Xi^{(t+1)} \ge \Xi^{(t)} + \frac{\eta (Kn - 1)}{Kn} \sum_{(k,i)} \ell'^{(t,J)}_{k,i}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) - \frac{\eta}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i}\, \Omega\Big(\frac{1}{\sqrt m}\Big)\, \Theta(\sigma_\xi^2 d) \ge \Xi^{(t)} - \Omega(1)\, \frac{\eta}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i}\, \Omega\Big(\frac{1}{\sqrt m}\Big)\, \Theta(\sigma_\xi^2 d),$$
where the $\Omega(1/\sqrt m)$ term is by Lemma E.7. We can then sum over $t$ to obtain
$$\sum_{t=T^{(1)}}^{T-1} \Xi^{(t+1)} \ge \sum_{t=T^{(1)}}^{T-1} \Xi^{(t)} - \Omega(1)\, \frac{\eta}{Kn} \sum_{t=T^{(1)}}^{T-1} \sum_{k,i} \ell'^{(t,J)}_{k,i}\, \frac{\Theta(\sigma_\xi^2 d)}{\sqrt m},$$
so that
$$\Xi^{(T)} - \Xi^{(T^{(1)})} \ge -\Omega\Big( \frac{\eta \sigma_\xi^2 d}{Kn \sqrt m} \Big) \sum_{t=T^{(1)}}^{T-1} \sum_{k,i} \ell'^{(t,J)}_{k,i}, \qquad \eta \sum_{t=T^{(1)}}^{T-1} \sum_{k,i} -\ell'^{(t,J)}_{k,i} \le O\Big( \frac{Kn \sqrt m}{\sigma_\xi^2 d} \Big)\, O(1) = O(1). \qquad (E.30)$$
The last inequality is because $\Xi^{(T)} - \Xi^{(T^{(1)})} \le O(1)$. And we can bound the minimum over $t$ by
$$\min_{t \in [T^{(1)}, T-1]} \sum_{k,i} -\ell'^{(t,J)}_{k,i} \le O\Big( \frac{1}{(T - T^{(1)}) \eta} \Big) = O\Big( \frac{1}{T \eta} \Big),$$
where the second equality is because $T^{(1)} = \mathrm{polylog}(d) = o(T)$. Write $\ell'^{(t,J)}_{k,i} = \ell'(\alpha(t, J, k, i))$; that is, $\alpha(t, J, k, i)$ denotes the argument of $\ell'^{(t,J)}_{k,i}$. Then the loss is related to $\ell'^{(t,J)}_{k,i}$ by
$$L_{\text{tr-tr}}(\mathbf{W}^{(t)}) = \sum_{k,i} \ell(\alpha(t, J, k, i)).$$
Let $t^* = \arg\min_t \sum_{k,i} -\ell'^{(t,J)}_{k,i}$. By the definitions of $\ell(\cdot)$ and $\ell'(\cdot)$, we can bound $-\ell'(x) \ge \exp(-x)/2 \ge \ell(x)/2$ for $x > 0$. This can be used to bound our loss by
$$\min_{t \in [T]} L_{\text{tr-tr}}(\mathbf{W}^{(t)}) \le L_{\text{tr-tr}}(\mathbf{W}^{(t^*)}) = \sum_{k,i} \ell(\alpha(t^*, J, k, i)) \le -2 \sum_{k,i} \ell'(\alpha(t^*, J, k, i)) \le O\Big( \frac{1}{T \eta} \Big).$$
This proves our theorem. It remains to prove Claims E.4, E.5, and E.6.

Proof of Claim E.4. Without loss of generality, assume $y_{k,i} = 1$. By Eq.
(E.16), we have the following upper bound:
$$\langle \mathbf{w}^{(t+1,0)}_{-1,r}, \xi_{k,i} \rangle \le \langle \mathbf{w}^{(t,0)}_{-1,r}, \xi_{k,i} \rangle - \frac{\eta}{Kn} \sum_{(k',i') \ne (k,i)} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{-1,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) \le \langle \mathbf{w}^{(t,0)}_{-1,r}, \xi_{k,i} \rangle - \frac{\eta}{Kn} \sum_{(k',i')} \ell'^{(t,J)}_{k',i'}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{-1,r}, \xi_{k',i'} \rangle)\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) \le \langle \mathbf{w}^{(t,0)}_{-1,r}, \xi_{k,i} \rangle - \frac{\eta}{Kn}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) \sum_{(k',i')} \ell'^{(t,J)}_{k',i'},$$
where the second inequality is because we are adding a non-negative term to the sum, and the last inequality is because $\sigma'(\cdot) \le 1$. Suppose all of our claims hold for $t \le T'$, where $T' \le T - 1$. Using a similar derivation to Eq. (E.30), we have
$$-\frac{\eta}{Kn}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) \sum_{t=T^{(1)}}^{T'} \sum_{(k',i')} \ell'^{(t,J)}_{k',i'} \le \frac{K\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)})}{(\sigma_\xi^2 d)^J\, \gamma^{J-1}\, n^{J-2}}.$$
Then we get
$$\langle \mathbf{w}^{(T'+1,0)}_{-1,r}, \xi_{k,i} \rangle \le \langle \mathbf{w}^{(T^{(1)},0)}_{-1,r}, \xi_{k,i} \rangle - \frac{\eta}{Kn}\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)}) \sum_{t=T^{(1)}}^{T'} \sum_{(k',i')} \ell'^{(t,J)}_{k',i'} \le \langle \mathbf{w}^{(T^{(1)},0)}_{-1,r}, \xi_{k,i} \rangle + \frac{K\, \Theta(\sigma_\xi^2 \sqrt{d \log(Kn)})}{(\sigma_\xi^2 d)^J\, \gamma^{J-1}\, n^{J-2}} \le O(d^{-1/2}) + O(d^{-1/2}) = O(d^{-1/2}),$$
where the third inequality is due to part (2) of Theorem E.2. This proves our claim.

Proof of Claim E.5. We will show that for any $T' \le T = O(\mathrm{poly}(d))$, we have $\Lambda^{(T')}_j = O(d^{-1/4})$ for $j \in \{-1, 1\}$. We have the following upper bound:
$$\Lambda^{(t+1)}_j \le \Lambda^{(t)}_j - \frac{\eta}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i}\, \Lambda^{(t)}_j\, \gamma^J\, O(1) \le \Lambda^{(t)}_j \Big( 1 - O(1)\, \frac{\eta \gamma^J}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i} \Big).$$
Suppose that there exists some $\hat t \le T'$ such that $\Lambda^{(\hat t)}_j \ge d^{-1/4}$ for the first time. Then
$$d^{-1/4} \le \Lambda^{(\hat t)}_j \le \Lambda^{(T^{(1)})}_j \prod_{t=T^{(1)}}^{\hat t - 1} \Big( 1 - O(1)\, \frac{\eta \gamma^J}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i} \Big), \qquad \frac{d^{-1/4}}{\Lambda^{(T^{(1)})}_j} \le \prod_{t=T^{(1)}}^{\hat t - 1} \Big( 1 - O(1)\, \frac{\eta \gamma^J}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i} \Big),$$
so that
$$\prod_{t=T^{(1)}}^{\hat t - 1} \Big( 1 - O(1)\, \frac{\eta \gamma^J}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i} \Big) \ge \Omega(d^{1/4}), \qquad \sum_{t=T^{(1)}}^{\hat t - 1} \log\Big( 1 - O(1)\, \frac{\eta \gamma^J}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i} \Big) \ge \Omega(\log(d)),$$
and hence, using $\log(1+x) \le x$,
$$\sum_{t=T^{(1)}}^{\hat t - 1} -O(1)\, \frac{\eta \gamma^J}{Kn} \sum_{k,i} \ell'^{(t,J)}_{k,i} \ge \Omega(\log(d)), \qquad \sum_{t=T^{(1)}}^{\hat t - 1} \sum_{k,i} -\ell'^{(t,J)}_{k,i} \ge \Omega\Big( \frac{\log(d)\, Kn}{\eta \gamma^J} \Big).$$
The third inequality above is because, by the first part of Theorem E.2, we have $\Lambda^{(T^{(1)})}_j = O(d^{-1/2})$, and hence $d^{-1/4} / \Lambda^{(T^{(1)})}_j = \Omega(d^{1/4})$. Note that in the proof of Theorem E.10, we have that
$$\sum_{t=T^{(1)}}^{T} \sum_{k,i} -\ell'^{(t,J)}_{k,i} \le \frac{K^2\, \gamma^{J-1}}{\eta (\sigma_\xi^2 d)^J} = o\Big( \frac{\log(d)\, Kn}{\eta \gamma^J} \Big),$$
which creates a contradiction.
Hence, for all $t \le T$, we have $\Lambda^{(t)}_j \le d^{-1/4}$. This proves the claim.

Proof of Claim E.6. By our result in Phase I, Claim E.6 holds at $T^{(1)}$. The rest of the proof is similar to the proof of Hypothesis E.1, and we omit it here.

We have shown that for $T = \mathrm{poly}(d)$, the feature is not learned well if one uses the train-train method. Thus the test loss will be large, which is characterized by the following theorem.

Theorem E.11. Under Condition D.1 and Condition 4.1, if one uses the train-train method, then
$$\min_{t \in [T]} L_{\text{test}}(\mathbf{W}^{(t)}) = \Omega(1).$$

Proof. Sample a new example $(x, y)$ from the distribution. Since we use the convolutional structure, we can assume that the first patch is $x^{(1)} = y \cdot (\nu + z)$ and the second patch is $x^{(2)} = \xi$. Clearly, $\langle \xi_{k,i}, \xi \rangle$ follows a Gaussian distribution with mean zero and standard deviation $\sigma_\xi \cdot \|\xi_{k,i}\|_2$. By Lemma D.2, we further know that $\|\xi_{k,i}\|_2 \le 2 \sigma_\xi \sqrt d$. Therefore,
$$\mathbb{P}\big( |\langle \xi_{k,i}, \xi \rangle| \ge d^{-1/4} \big) \le 2 \exp\Big( -\frac{1}{8 \sigma_\xi^4 d^{3/2}} \Big) \le 2 \exp(-d^{1/4}). \qquad (E.31)$$
Similarly, we have that
$$\mathbb{P}\big( |\langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle| \ge d^{-1/4} \big) \le 2 \exp\Big( -\frac{1}{8 \sigma_0^2 \sigma_\xi^2 d^{3/2}} \Big) \le 2 \exp(-d^{1/4}). \qquad (E.32)$$
Denote by $\mathcal{E}$ the event that $|\langle \xi_{k,i}, \xi \rangle| \le d^{-1/4}$ and $|\langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle| \le d^{-1/4}$ for all $k \in [K]$, $i \in [n]$, $r \in [m]$, and $j \in \{-1, 1\}$. Applying a union bound, we have that $\mathbb{P}(\mathcal{E}) \ge 1 - 2(nK + m)\exp(-d^{1/4}) \ge 1/2$. We can divide $L_{\text{test}}(\mathbf{W})$ into two parts:
$$L_{\text{test}}(\mathbf{W}) = \mathbb{E}\, \ell\big( y f(\mathbf{W}, x) \big) = \mathbb{E}\big[ \mathbf{1}(\mathcal{E})\, \ell( y f(\mathbf{W}, x) ) \big] + \mathbb{E}\big[ \mathbf{1}(\mathcal{E}^c)\, \ell( y f(\mathbf{W}, x) ) \big] \ge \mathbb{E}\big[ \mathbf{1}(\mathcal{E})\, \ell( y f(\mathbf{W}, x) ) \big]. \qquad (E.33)$$
Let $T' \le T$, and suppose that $T' \ge T^{(1)}$. When event $\mathcal{E}$ holds, for any $j \in \{-1, 1\}$ we have
$$\langle \mathbf{w}^{(T',0)}_{j,r}, \xi \rangle = \langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle - \frac{\eta}{Kn_2} \sum_{t=0}^{T'-1} \sum_{(k,i) \in \Psi} \ell'^{(t,J)}_{k,i}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle)\, \langle \xi, \xi_{k,i} \rangle \le \langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle - \frac{\eta}{Kn_2} \sum_{t=0}^{T'-1} \sum_{(k,i) \in \Psi} \ell'^{(t,J)}_{k,i}\, \sigma'(\langle \mathbf{w}^{(t,J)}_{j,r}, \xi_{k,i} \rangle)\, d^{-1/4} \le \langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle - \frac{\eta}{Kn_2} \sum_{t=0}^{T'-1} \sum_{(k,i) \in \Psi} \ell'^{(t,J)}_{k,i}\, d^{-1/4} = \langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle - \frac{\eta d^{-1/4}}{Kn_2} \sum_{t=0}^{T^{(2)}-1} \sum_{(k,i) \in \Psi} \ell'^{(t,J)}_{k,i} - \frac{\eta d^{-1/4}}{Kn_2} \sum_{t=T^{(2)}}^{T'-1} \sum_{(k,i) \in \Psi} \ell'^{(t,J)}_{k,i} \le \langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle + \eta d^{-1/4}\, T^{(2)} + \frac{\eta d^{-1/4}}{Kn_2}\, O(1) \le \langle \mathbf{w}^{(0,0)}_{j,r}, \xi \rangle + O(d^{-1/4}) + O(d^{-1/4}) = O(d^{-1/4}), \qquad (E.34)$$
where the second line is by Eq. (E.31).
The third line is by $0 \le \sigma'(\cdot) \le 1$. The third-to-last line is due to $-\ell'(\cdot) \le 1$, $T^{(2)} = \mathrm{polylog}(d)$, and
$$\sum_{t=T^{(2)}}^{T'-1} \sum_{(k,i) \in \Psi} -\ell'^{(t,J)}_{k,i} = O(1).$$
The second-to-last line is by Eq. (E.32). We note that Eq. (E.34) holds for all $j \in \{-1, 1\}$, $r \in [m]$, and $y \in \{-1, 1\}$. Using the same procedure, one can show that $j \langle \mathbf{w}^{(T',0)}_{j,r}, z \rangle \le O(d^{-1/4})$ for all $r \in [m]$ and $j \in \{-1, 1\}$. Now, without loss of generality, assume $y = 1$. We have
$$\ell\big( y \big[ F_{+1}(\mathbf{W}^{(T')}_{+1}, x) - F_{-1}(\mathbf{W}^{(T')}_{-1}, x) \big] \big) = \ell\big( F_{+1}(\mathbf{W}^{(T')}_{+1}, x) - F_{-1}(\mathbf{W}^{(T')}_{-1}, x) \big) \ge \ell\big( F_{+1}(\mathbf{W}^{(T')}_{+1}, x) \big) = \ell\Big( \sum_{r=1}^m \big[ \sigma(\langle \mathbf{w}^{(T',0)}_{1,r}, \nu + z \rangle) + \sigma(\langle \mathbf{w}^{(T',0)}_{1,r}, \xi \rangle) \big] \Big) \ge \ell\Big( \sum_{r=1}^m \big[ \sigma\big( O(d^{-1/4}) + O(d^{-1/4}) \big) + \sigma\big( O(d^{-1/4}) \big) \big] \Big) = \ell\big( m\, O(d^{-1/2}) \big) = \Omega(1), \qquad (E.35)$$
where the second line is because $\ell(\cdot)$ is nonincreasing and the fourth line is due to Claim E.5. Finally, plugging Eq. (E.35) into Eq. (E.33) gives
$$L_{\text{test}}(\mathbf{W}^{(T')}) \ge \mathbb{E}\big[ \mathbf{1}(\mathcal{E})\, \ell( y f(\mathbf{W}^{(T')}, x) ) \big] \ge \mathbb{E}\big[ \mathbf{1}(\mathcal{E})\, \Omega(1) \big] = \Omega(\mathbb{P}(\mathcal{E})) = \Omega(1).$$
Since $T' \in [T^{(1)}, T]$ is arbitrary, we get that
$$\min_{t \in [T^{(1)}, T]} L_{\text{test}}(\mathbf{W}^{(t)}) = \Omega(1).$$
However, the above argument also works for $T' < T^{(1)}$ using our proofs from Phase I. Hence,
$$\min_{t \in [T]} L_{\text{test}}(\mathbf{W}^{(t)}) = \Omega(1),$$
which concludes our theorem.
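The test-loss lower bound above rests on the concentration fact that a fresh noise patch is nearly orthogonal to every memorized training noise, so a network that only memorized noise produces an $o(1)$ logit on new data. A seeded toy simulation of this effect (the dimension and task/sample counts are hypothetical stand-ins):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4096
K, n = 5, 10                                  # hypothetical task and per-task sample counts

train_noise = rng.normal(0, d ** -0.5, (K * n, d))   # memorized patches, each with norm ~ 1
fresh_noise = rng.normal(0, d ** -0.5, d)            # noise patch of a new test example

overlaps = np.abs(train_noise @ fresh_noise)
# each overlap concentrates at scale d^{-1/2}, far below the Theta(1) self-overlap
print(overlaps.max())
```

Even the largest of the $Kn$ overlaps stays at the $d^{-1/2}$ scale, mirroring the event $\mathcal{E}$ used in the proof.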

F TRAIN-VALIDATION METHOD

Contrary to the train-train method, we will show in this section that under Condition D.1, the feature is learned by our neural network when one uses the train-validation method. Abusing notation, let
$$\Xi^{(t)}_{k,i} = \max\big\{ \langle \mathbf{w}^{(t,0)}_{1,r}, \xi_{k,i} \rangle : r \in [m] \big\}, \qquad \Gamma^{(t)}_{k,i} = \max\big\{ \langle \mathbf{w}^{(t,0)}_{-1,r}, \xi_{k,i} \rangle : r \in [m] \big\},$$
and $\Lambda^{(t)}_j = \max\{ j \langle \mathbf{w}^{(t,0)}_{j,r}, \nu \rangle : r \in [m] \}$ for $j \in \{-1, 1\}$. Define $C = J \gamma \sigma_\xi^2 \sqrt{d \log(Kn)}$, $\bar\Xi^{(t)}_{k,i} = \max\{C, \Xi^{(t)}_{k,i}\}$, and $\bar\Gamma^{(t)}_{k,i} = \max\{C, \Gamma^{(t)}_{k,i}\}$. Let $\Psi = \{(k,i) : k \in [K], i > n_1\}$, $\Psi_+ = \{(k,i) : k \in [K], i > n_1, y_{k,i} = 1\}$, and $\Psi_- = \{(k,i) : k \in [K], i > n_1, y_{k,i} = -1\}$. So $\Psi$ represents all of the samples from the validation sets of all tasks, and $\Psi_j$ represents the samples from the validation sets of all tasks with label $j$.

We first present a series of useful lemmas that hold for all $t \ge 0$. These lemmas show that the task-specific directions $z_k$ are learned much more slowly than the shared feature $\nu$.

Lemma F.1. Under Condition D.1 and Condition 4.1, for any $t \ge 0$, $k \in [K]$, and $j \in \{-1, 1\}$, if one uses the train-validation method, then we have
$$\max_r j \langle \mathbf{w}^{(t,0)}_{j,r}, z_k \rangle = o(1)\, \Lambda^{(t)}_j.$$

Proof. Without loss of generality, let $j = 1$ and fix $k \in [K]$. Denote $r(t) = \arg\max_{r \in [m]} \langle \mathbf{w}^{(t,0)}_{1,r}, z_k \rangle$ and define
$$D(t) := \frac{\langle \mathbf{w}^{(t,0)}_{1,r(t)}, z_k \rangle}{\Lambda^{(t)}_1}.$$
At $t = 0$, by Lemma D.2 we have
$$D(0) = \frac{\langle \mathbf{w}^{(0,0)}_{1,r(0)}, z_k \rangle}{\Lambda^{(0)}_1} \le \frac{\Theta(\sigma_s \sigma_0 \sqrt{d \log(mKn)})}{\Theta(\sigma_0)} = \Theta(\sigma_s \sqrt{d \log(mKn)}) = o(1).$$
Define $D_0 := \Theta(\sigma_s \sqrt{d \log(mKn)})$. We show that for any $t' > 0$, we have $D(t') \le 2 D_0 = o(1)$. Observe that, similar to Eq. (E.17), for any $t \ge 0$ we have
$$\langle \mathbf{w}^{(t+1,0)}_{1,r}, z_k \rangle = \langle \mathbf{w}^{(t,0)}_{1,r}, z_k \rangle - \frac{\eta}{K n_2} \sum_{(k',i) \in \Psi} \ell'^{(t,J)}_{k',i}\, \sigma'(y_{k',i} \langle \mathbf{w}^{(t,J)}_{1,r}, \nu_{k'} \rangle)\, \langle z_{k'}, z_k \rangle \le \langle \mathbf{w}^{(t,0)}_{1,r}, z_k \rangle - \frac{\eta}{K n_2} \sum_{(k',i) \in \Psi} \ell'^{(t,J)}_{k',i}\, \sigma'(y_{k',i} \langle \mathbf{w}^{(t,J)}_{1,r}, \nu_{k'} \rangle)\, \Theta(\sigma_s^2 d), \qquad (F.1)$$
where we used Condition 4.1 for the inequality, since $\langle z_{k'}, z_k \rangle \le \max( \|z_{k'}\|_2^2, \|z_k\|_2^2 ) = \Theta(\sigma_s^2 d)$. Comparing Eq. (C.11) and Eq.
(F.1), we can roughly understand why this lemma is true: for any r ∈ [m], the growth of w (t,0) 1,r , ν is at least a factor of 1/(σ 2 s d) = ω(1) faster than that of w (t,0) 1,r , z k . We can compute w (t ,0) 1,r(t ) , z k ≤ t -1 t=0 w (t+1,0) 1,r(t ) , z k -w (t,0) 1,r(t ) , z k + w (0,0) 1,r(t ) , z k ≤ t -1 t=0 w (t+1,0) 1,r(t ) , z k -w (t,0) 1,r(t ) , z k + w (0,0) 1,r(t ) , z k ≤ t -1 t=0 Θ(σ 2 s d) w (t+1,0) 1,r(t ) , ν -w (t,0) 1,r(t ) , ν + w (0,0) 1,r(t ) , z k = t -1 t=0 Θ(σ 2 s d) w (t+1,0) 1,r(t ) , ν -w (t,0) 1,r(t ) , ν + w (0,0) 1,r(t ) , z k = Θ(σ 2 s d) w (t ,0) 1,r(t ) , ν -w (0,0) 1,r(t ) , ν + w (0,0) 1,r(t ) , z k ≤ Θ(σ 2 s d) Λ (t ) 1 -min r∈[m] w (0,0) 1,r , ν + w (0,0) 1,r(0) , z k ≤ Θ(σ 2 s d) Λ (t ) 1 + Θ(σ 0 log(mKn)) + D 0 Λ (0) 1 , where the first and second lines are by triangle inequality. The third line is by comparing Eq. (C.11) and Eq. (F.1). The fourth line is by the fact that w and definition of r(0). The last line is by Lemma D.2. Then we have w (t ,0) 1,r(t ) , z k Λ (t ) 1 ≤ Θ(σ 2 s d) Λ (t ) 1 + Θ(σ 0 log(mKn)) + D 0 Λ (0) 1 Λ (t ) 1 ≤ Θ(σ 2 s d)(1 + log(mKn)) + D 0 < 2D 0 , where we used Λ (t) j is increasing in t to get the second inequality. The third inequality is by direct calculation and comparison:  D 0 = Θ(σ s √ d log(mKn)), σ 2 s d = (σ s √ d) 2 = σ s √ d/polylog(d) = o(1)D 0 , d -1/3 ≤ w (t,0) 1,r , ν = w (0,0) 1,r , ν - η Kn 2 t-1 t =0 (k,i)∈Ψ (t ,J) k,i σ ( w (t ,J) j,r , ν k y k,i ) ν d -1/3 -w (0,0) 1,r , ν ≤ - η Kn 2 t-1 t =0 (k,i)∈Ψ (t ,J) k,i σ ( w (t ,J) j,r , ν k y k,i ) ν 2 2 , d -1/3 -O(σ 0 log(mKn)) ≤ - η Kn 2 t-1 t =0 (k,i)∈Ψ (t ,J) k,i σ ( w (t ,J) j,r , ν k y k,i ) ν 2 2 , Ω(d -1/3 ) ≤ - η Kn 2 t-1 t =0 (k,i)∈Ψ (t ,J) k,i σ ( w (t ,J) j,r , ν k y k,i ) ν 2 2 . Recall that from Condition D.1 and Condition 4.1 w (0,0) 1,r , ν ≤ O(d -1/2 ) O(d -1/3 ) ≤ -η Kn 2 t-1 t =0 (k,i)∈Ψ (t ,J) k,i σ ( w (t ,J) j,r , ν k y k,i ) ν 2 2 . 
This implies
$$\langle \mathbf{w}^{(0,0)}_{1,r}, \nu \rangle = -o(1)\, \frac{\eta}{K n_2} \sum_{t'=0}^{t-1} \sum_{(k,i) \in \Psi} \ell'^{(t',J)}_{k,i}\, \sigma'(y_{k,i} \langle \mathbf{w}^{(t',J)}_{j,r}, \nu_k \rangle)\, \|\nu\|_2^2.$$
Hence,
$$\langle \mathbf{w}^{(t,0)}_{1,r}, \nu \rangle = -\Theta(1)\, \frac{\eta}{K n_2} \sum_{t'=0}^{t-1} \sum_{(k,i) \in \Psi} \ell'^{(t',J)}_{k,i}\, \sigma'(y_{k,i} \langle \mathbf{w}^{(t',J)}_{j,r}, \nu_k \rangle)\, \|\nu\|_2^2. \qquad (F.2)$$
On the other hand, from Eq. (F.1) and Lemma D.2 we have
$$\langle \mathbf{w}^{(t,0)}_{1,r}, z_k \rangle \le \langle \mathbf{w}^{(0,0)}_{1,r}, z_k \rangle - \frac{\eta}{K n_2} \sum_{t'=0}^{t-1} \sum_{(k',i) \in \Psi} \ell'^{(t',J)}_{k',i}\, \sigma'(y_{k',i} \langle \mathbf{w}^{(t',J)}_{j,r}, \nu_{k'} \rangle)\, \Theta(\sigma_s^2 d) \le O(\sigma_s \sigma_0 \sqrt{d \log(mKn)}) - \frac{\eta}{K n_2} \sum_{t'=0}^{t-1} \sum_{(k',i) \in \Psi} \ell'^{(t',J)}_{k',i}\, \sigma'(y_{k',i} \langle \mathbf{w}^{(t',J)}_{j,r}, \nu_{k'} \rangle)\, \Theta(\sigma_s^2 d) \le -O(1)\, \frac{\eta}{K n_2} \sum_{t'=0}^{t-1} \sum_{(k',i) \in \Psi} \ell'^{(t',J)}_{k',i}\, \sigma'(y_{k',i} \langle \mathbf{w}^{(t',J)}_{j,r}, \nu_{k'} \rangle)\, \Theta(\sigma_s^2 d) = O(\sigma_s^2 d)\, \langle \mathbf{w}^{(t,0)}_{1,r}, \nu \rangle = o(1)\, \langle \mathbf{w}^{(t,0)}_{1,r}, \nu \rangle,$$
where the third line is a direct comparison of the sizes of the first and second terms of the second line, and the second-to-last equality is due to Eq. (F.2).

Lemma F.3. Under Condition D.1 and Condition 4.1, for any $j \in \{-1, 1\}$ and $r \in [m]$, if at time $t$ we have $j \langle \mathbf{w}^{(t,0)}_{j,r}, \nu \rangle < d^{-1/3}$, then for any $k \in [K]$,
$$j \langle \mathbf{w}^{(t,0)}_{j,r}, z_k \rangle < d^{-1/3}.$$

Proof. Without loss of generality, consider $j = 1$. Let $k \in [K]$, $r \in [m]$, and suppose $\langle \mathbf{w}^{(t,0)}_{1,r}, \nu \rangle < d^{-1/3}$ at time $t$. By Eq. (C.11), we have
$$d^{-1/3} > \langle \mathbf{w}^{(t,0)}_{1,r}, \nu \rangle = \langle \mathbf{w}^{(0,0)}_{1,r}, \nu \rangle + S, \qquad \text{where } S := -\frac{\eta}{K n_2} \sum_{t'=0}^{t-1} \sum_{(k,i) \in \Psi} \ell'^{(t',J)}_{k,i}\, \sigma'(y_{k,i} \langle \mathbf{w}^{(t',J)}_{j,r}, \nu_k \rangle)\, \|\nu\|_2^2 \ \ge 0.$$
Hence,
$$S < d^{-1/3} - \langle \mathbf{w}^{(0,0)}_{1,r}, \nu \rangle \le d^{-1/3} - \min_{r' \in [m]} \langle \mathbf{w}^{(0,0)}_{1,r'}, \nu \rangle \le d^{-1/3} + O(d^{-1/2}) < 2\, d^{-1/3}, \qquad (F.3)$$
where the third inequality is by Lemma D.2. Meanwhile, using Eq.
(F.1) and Lemma D.2 we obtain
$$\langle \mathbf{w}^{(t,0)}_{1,r}, z_k \rangle \le \langle \mathbf{w}^{(0,0)}_{1,r}, z_k \rangle - \frac{\eta}{K n_2} \sum_{t'=0}^{t-1} \sum_{(k',i) \in \Psi} \ell'^{(t',J)}_{k',i}\, \sigma'(y_{k',i} \langle \mathbf{w}^{(t',J)}_{1,r}, \nu_{k'} \rangle)\, \Theta(\sigma_s^2 d) \le \langle \mathbf{w}^{(0,0)}_{1,r}, z_k \rangle + 2\, d^{-1/3}\, \Theta(\sigma_s^2 d) \le O(\sigma_0 \sigma_s \sqrt{d \log(mKn)}) + 2\, d^{-1/3}\, \Theta(\sigma_s^2 d) < d^{-1/3},$$
where the second inequality is by Eq. (F.3). The last inequality is because the first term in the second-to-last expression is $O(d^{-1/2})$ and $\Theta(\sigma_s^2 d) = o(1)$. This proves our lemma.

Lemma F.4. Under Condition D.1 and Condition 4.1, for any $j \in \{-1, 1\}$, $r \in [m]$, and $k \in [K]$, if at time $t$ we have $j \langle \mathbf{w}^{(t,0)}_{j,r}, \nu \rangle \ge d^{-1/3}$, then
$$j \langle \mathbf{w}^{(t,J)}_{j,r}, z_k \rangle = O(\sigma_s^2 d)\, j \langle \mathbf{w}^{(t,J)}_{j,r}, \nu \rangle,$$
where $\mathbf{w}^{(t,J)}_{j,r}$ on both sides refers to the weights $\mathbf{w}^{(t,0)}_{j,r}$ after $J$ inner-loop update steps using samples from $S^{\text{tr}}_k$.

Proof. Without loss of generality, let $j = 1$. Using samples from $S^{\text{tr}}_k$, similar to Eq. (C.5), we have
$$\langle \mathbf{w}^{(t,\tau+1)}_{1,r}, \nu \rangle = \langle \mathbf{w}^{(t,\tau)}_{1,r}, \nu \rangle - \frac{\gamma}{n_1} \sum_{i=1}^{n_1} \ell'^{(t,\tau)}_{k,i}\, \sigma'(y_{k,i} \langle \mathbf{w}^{(t,\tau)}_{j,r}, \nu_k \rangle)\, \|\nu\|_2^2, \qquad (F.4)$$
and
$$\langle \mathbf{w}^{(t,\tau+1)}_{1,r}, z_k \rangle = \langle \mathbf{w}^{(t,\tau)}_{1,r}, z_k \rangle - \frac{\gamma}{n_1} \sum_{i=1}^{n_1} \ell'^{(t,\tau)}_{k,i}\, \sigma'(y_{k,i} \langle \mathbf{w}^{(t,\tau)}_{j,r}, \nu_k \rangle)\, \|z_k\|_2^2.$$
This implies
$$\langle \mathbf{w}^{(t,J)}_{1,r}, \nu \rangle = \langle \mathbf{w}^{(t,0)}_{1,r}, \nu \rangle - \frac{\gamma}{n_1} \sum_{\tau=0}^{J-1} \sum_{i=1}^{n_1} \ell'^{(t,\tau)}_{k,i}\, \sigma'(y_{k,i} \langle \mathbf{w}^{(t,\tau)}_{j,r}, \nu_k \rangle)\, \|\nu\|_2^2,$$
$$\langle \mathbf{w}^{(t,J)}_{1,r}, z_k \rangle = \langle \mathbf{w}^{(t,0)}_{1,r}, z_k \rangle - \frac{\gamma}{n_1} \sum_{\tau=0}^{J-1} \sum_{i=1}^{n_1} \ell'^{(t,\tau)}_{k,i}\, \sigma'(y_{k,i} \langle \mathbf{w}^{(t,\tau)}_{j,r}, \nu_k \rangle)\, \|z_k\|_2^2.$$
Recalling that $\|\nu\|_2 = \Theta(1)$ and $\|z_k\|_2^2 = \Theta(\sigma_s^2 d)$, and comparing the two equations above together with Lemma F.2, we conclude that $\langle \mathbf{w}^{(t,J)}_{1,r}, z_k \rangle = O(\sigma_s^2 d)\, \langle \mathbf{w}^{(t,J)}_{1,r}, \nu \rangle$.

The following corollary is an immediate result of Lemma F.2 and Lemma F.4.

Corollary F.5. Fix $j \in \{-1, 1\}$. Then for any $r \in [m]$, if at time $t$ we have $j \langle \mathbf{w}^{(t,0)}_{j,r}, \nu \rangle \ge d^{-1/3}$, then for any $k \in [K]$,
$$j \langle \mathbf{w}^{(t,J)}_{j,r}, z_k \rangle = O(\sigma_s^2 d)\, j \langle \mathbf{w}^{(t,J)}_{j,r}, \nu \rangle = O(1/\mathrm{polylog}(d))\, j \langle \mathbf{w}^{(t,J)}_{j,r}, \nu \rangle,$$
where $\mathbf{w}^{(t,J)}_{j,r}$ on both sides refers to the weights $\mathbf{w}^{(t,0)}_{j,r}$ after $J$ inner-loop update steps using samples from $S^{\text{tr}}_k$.

Lemma F.6. Under Condition D.1 and Condition 4.1, for any $j \in \{-1, 1\}$, $r \in [m]$, and $k \in [K]$, if we are using the samples from $S^{\text{tr}}_k$, then for any $t \ge 0$ we have
$$j \langle \mathbf{w}^{(t,J)}_{j,r}, \nu \rangle \le 2 \max\big\{ \Lambda^{(t)}_j, J\gamma \big\}.$$

Proof. Without loss of generality, consider $j = 1$. Let $r \in [m]$ and $k \in [K]$. Using samples from $S^{\text{tr}}_k$, by Eq.
(F.4), we have
$$\langle \mathbf{w}^{(t,\tau+1)}_{1,r}, \nu \rangle \le \langle \mathbf{w}^{(t,\tau)}_{1,r}, \nu \rangle + \gamma, \quad \text{for any } \tau \in [J-1],$$
where we have used the facts that $-\ell'(\cdot) \le 1$ and $\sigma'(\cdot) \le 1$. Repeating this $J$ times, we get
$$\langle \mathbf{w}^{(t,J)}_{1,r}, \nu \rangle \le \langle \mathbf{w}^{(t,0)}_{1,r}, \nu \rangle + J\gamma \le \Lambda^{(t)}_1 + J\gamma \le 2 \max\big\{ \Lambda^{(t)}_1, J\gamma \big\}.$$
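The bound in Lemma F.6 uses nothing beyond $-\ell'(\cdot) \le 1$ and $\sigma'(\cdot) \le 1$: each inner-loop step can move the feature inner product by at most $\gamma$. A scalar simulation of $J$ inner-loop steps illustrates the drift cap; the step size, step count, and the capped activation derivative below are hypothetical stand-ins:

```python
import math

def neg_ell_prime(x):
    return 1.0 / (1.0 + math.exp(x))       # lies in (0, 1), so |ell'| <= 1

def sigma_prime(x):
    return min(max(2 * x, 0.0), 1.0)       # smoothed-ReLU-style derivative, capped at 1

gamma, J = 0.3, 8                          # hypothetical inner-loop step size and step count
h0 = 0.05                                  # initial <w, nu>
h = h0
for _ in range(J):
    # one inner-loop step along the feature direction (||nu||^2 = Theta(1) absorbed)
    h = h + gamma * neg_ell_prime(h) * sigma_prime(h)

print(h, h0 + J * gamma)
```

Since each increment is at most $\gamma \cdot 1 \cdot 1$, the iterate never exceeds $h_0 + J\gamma$, matching the lemma.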

F.1 PHASE I

Let T (2) j be the first iteration such that Λ (t) j ≥ m -1/2 (1 + γ) -J = Θ(1) and let T (2) = max j T j . Without loss of generality, let us assume that T (2) -1 ≤ T (2) 1 . Recall that we assumed T  w (t,J) 1,r , ν k = Ω(1) w (t,0) 1,r , ν (1 + Θ(1)γ) J . 2. For any r ∈ [m], (k, i) ∈ Ψ, and j ∈ {-1, 1} w (t,J) j,r , ξ k,i ≤ O max w (t,0) j,r , ξ k,i , C . Proof. The proof of the first part relies on the following hypothesis which we will verify inductively later. - (t,τ ) k,i = Θ(1), for all τ ∈ [J], k ∈ [K], i ∈ [n] such that y k,i = 1 . (F.5) σ ( w (t,τ ) 1,r , ν k ) = Θ(1) w (t,τ ) 1,r , ν k , for all τ ∈ [J], k ∈ [K] if w (t,τ ) 1,r , ν k > 0 . (F.6) By Eq. (C.5), we know w And Eq. (C.11) shows that w (t,τ ) 1,r , ν k is increasing in τ for any r ∈ [m] and k ∈ [K]. Moreover, we have w (t+1,0) 1,r , ν k = w (t,0) 1,r , ν k - η Kn (k ,i)∈Ψ (t,J) k ,i σ ( w (t,J) 1,r , ν k y k ,i ) ν k , ν k , which shows that w (t,0) 1,r , ν k is increasing in t for all r ∈ [m], k ∈ [K], since ν k , ν k = ν 2 2 + z k , z k > 0. (t,0) 1,r , ν is increasing in t for all r ∈ [m]. Let r(t) = arg max r∈[m] w (t,0) 1,r , ν . Recall that under Lemma D.2, we have w (0,0) 1,r(0) , ν = Ω(σ 0 ) > 0. Since this is increasing in t, we have w (t,0) 1,r(t) , ν ≥ w (t,0) 1,r(0) , ν ≥ w (0,0) 1,r(0) , ν > 0 , where the first inequality is by definition of r(t). By Lemma F.1, we have for any k ∈ [K], w (t,0) 1,r(t) , ν k = w (t,0) 1,r(t) , ν + w (t,0) 1,r(t) , z k ≥ (1 -o(1)) w (t,0) 1,r(t) , ν > 0 . And by monotonicity in τ , we have w (t,τ ) 1,r(t) , ν k > 0 for all τ ∈ [J]. From Eq. (C.5) we can compute w (t,τ +1) 1,r(t) , ν k = w (t,τ ) 1,r(t) , ν k - γ n 1 i≤n1,y k,i =1 (t,τ ) k,i σ ( w (t,τ ) 1,r(t) , ν k ) ν k 2 2 = w (t,τ ) 1,r(t) , ν k + γΘ(1) w (t,τ ) 1,r(t) , ν k = w (t,τ ) 1,r(t) , ν k (1 + Θ(1)γ) . where we have used Hypothesis F.5 F.6 and ν k 2 = Θ(1) to get the second line. 
Applying this repeatedly to get w (t,J) 1,r(t) , ν k = w (t,0) 1,r(t) , ν k (1 + Θ(1)γ) J = Ω(1) w (t,0) 1,r(t) , ν (1 + Θ(1)γ) J , (F.7) where we have used Lemma F.1 to get the second equality. This proves the first part of the lemma. Consider some (k, i) ∈ Ψ. Without loss of generality we will prove the second part of the lemma for j = 1, since the proof for j = -1 is the same. By Eq. (C.3), for any r ∈ [m] we have w (t,τ +1) 1,r , ξ k,i ≤ w (t,τ ) 1,r , ξ k,i + γΘ(1)Θ(σ 2 ξ √ d log(Kn)) , where we used -(•) ≤ 1, σ (•) ≤ 1 and Lemma D.2. This implies w (t,J) 1,r , ξ k,i ≤ w (t,0) 1,r , ξ k,i + JγΘ(1)Θ(σ 2 ξ √ d log(Kn)) ≤ O max w (t,0) 1,r , ξ k,i , C . (F.8) The second inequality is by definition of C. This proves the second part of the lemma. Remark F.8. Eq. (F.8) is not a surprising result. Since ξ k,i comes from the validation set and the inner-loop updates only uses samples from the training set, we expect that after J steps of inner-loop updates, the inner product should remain as the same order as before the J steps inner-loop updates. Theorem F.9 (Restatement of Lemma 6.3). Under Condition D.1 and Condition 4.1, if one uses the train-validation method, then for any t ≤ T (2) , (k, i) ∈ Ψ, Ξ (t) k,i ≤ O(1) Ξ (0) k,i = O(d -1/2 ) . And Γ (t) k,i ≤ O(1) Γ (0) k,i = O(d -1/2 ) . Proof. We will only prove our first statement that involves Ξ k,i , because the proof of the second statement that involves Γ (t) k,i will be exactly the same. Similar to the proof of Theorem E.2, we will give lower bound on the growth of Λ (t) 1 and upper bound on the growth of Ξ (t) k,i . We can upper bound the growth of Λ (t) 1 using Lemma F.7. We first note that since w (t,J) 1,r(t) , ν k > 0, we have -w (t,J) 1,r(t) , ν k < 0. Thus, if (k, i) ∈ Ψ -we have σ ( w (t,J) 1,r(t) , ν k y k,i ) = 0. Then by Eq. 
(C.11), we get Λ (t+1) 1 ≥ w (t+1,0) 1,r(t) , ν = w (t,0) 1,r(t) , ν - η Kn 2 (k,i)∈Ψ (t,J) k,i σ ( w (t,J) 1,r(t) , ν k y k,i ) = w (t,0) 1,r(t) , ν - η Kn 2 (k,i)∈Ψ+ (t,J) k,i Θ(1) w (t,J) 1,r(t) , ν k = w (t,0) 1,r(t) , ν + ηΩ(1) w (t,0) 1,r(t) , ν (1 + Θ(1)γ) J = w (t,0) 1,r(t) , ν (1 + Ω(1)ηγ J ) = Λ (t) 1 (1 + Ω(1)ηγ J ) , (F.9) where we used σ ( w (t,J) 1,r(t) , ν k y k,i ) = 0 if (k, i) ∈ Ψ -, and Hypothesis F.5 F.6 to get the second equality. We used Lemma F.7 to get the third equality. Next, let (k, i) ∈ Ψ and r ∈ [m]. We will derive an upper bound on the growth of Ξ (t) k,i . By Eq. (C.9), we have  w (t+1,0) 1,r , ξ k,i ≤ w (t,0) 1,r , ξ k,i + η Kn 2 (k ,i )∈Ψ\{(k,i)} Θ(σ 2 ξ √ d log(Kn)) + η Kn 2 σ ( w (t,J) 1,r , ξ k,i )Θ(σ 2 ξ d) ≤ w (t,0) 1,r , ξ k,i + ηΘ(σ 2 ξ √ d log(Kn)) + η Kn 2 Θ(σ 2 ξ d) w (t,J) 1,r , ξ k,i ≤ w (t,0) 1,r , ξ k,i + ηΘ(σ 2 ξ √ d log(Kn)) + η Kn 2 Θ(σ 2 ξ d)O max w (t,0) 1,r , ξ k,i , C , (t) k,i = w (t,0) 1,r(t) , ξ k,i . Let us compare the size of the second and third term on the right hand side of Eq. (F.10). The third term outweighs the second term at least by a factor of Θ(Jγσ 2 ξ d/(Kn 2 )) = log(d) 0.9 , due to the presence of the constant C in the third term. Therefore, the second term is dominated by the third term for all r ∈ [m] and for all t ≥ 0. So Eq. (F.10) becomes w (t+1,0) 1,r , ξ k,i ≤ w (t,0) 1,r , ξ k,i + η Kn 2 Θ(σ 2 ξ d)O max w (t,0) 1,r , ξ k,i , C ≤ max w (t,0) 1,r , ξ k,i , C + η Kn 2 Θ(σ 2 ξ d)O max w (t,0) 1,r , ξ k,i , C ≤ max Ξ (t) k,i , C + η Kn 2 Θ(σ 2 ξ d)O max Ξ (t) k,i , C = Ξ (t) k,i 1 + O(1) ησ 2 ξ d Kn 2 . Since the above inequality holds for all r ∈ [m] on the left hand side, it holds for r(t + 1) in particular. This gives Ξ (t+1) k,i ≤ Ξ (t) k,i 1 + O(1) ησ 2 ξ d Kn 2 . And it is straightforward to check by definition of .11) We are now ready to prove our theorem. Compare Eq. (F.9) and Eq. (F.11), we can use Lemma D.5 Ξ (t) k,i that Ξ (t+1) k,i ≤ Ξ (t) k,i 1 + O(1) ησ 2 ξ d Kn 2 . 
( to get an upper bound on Ξ (T (2) ) k,i . We take A = Ω(1)ηγ J , B = O(1) ησ 2 ξ d Kn2 , G = 2, D = m -1/2 γ -J in Lemma D.5. We see that since Λ (0) 1 = Θ(d -1/2 ) > G -A/B = 2 -log(d) 1.5 under Condition D.1, Lemma D.5 tells us that Ξ (T (2) ) k,i ≤ O( Ξ k,i ) . This proves our main theorem. It remains to verify our previous hypothesis. Note that if we suppose all of our hypothesis hold at time t < T (2) , then Theorem F.9 implies that for all (k, i) ∈ Ψ we have Ξ (t+1) k,i ≤ O( Ξ (0) k,i ). By Lemma D.2, it is clear that Hypothesis F.5 and F.6 hold at initialization. Proof of Hypothesis F.5. The proof will be similar to the proof of Hypothesis E.2. Let k ∈ [K] and i ∈ [n] be such that y k,i = 1. Recall Eq. (C.7) and that j=±1 m r=1 j σ( w (t+1,τ ) j,r , ξ k,i )+σ( w (t+1,τ ) j,r , ν k ) ≤ m r=1 σ( w (t+1,τ ) 1,r , ξ k,i ) + σ( w (t+1,τ ) 1,r , ν k ) . (F.12) Assume our Hypothesis hold at time t. We have m r=1 σ( w (t+1,0) 1,r , ξ k,i ) ≤ m r=1 σ Ξ (t+1) k,i ≤ mσ O( Ξ (0) k,i ) = o(1) , (F.13) where the first inequality is by definition of Ξ (t+1) k,i and the second inequality is because Ξ (t+1) k,i ≤ O( Ξ (0) k,i ). The third equality is because m = polylog(d) and σ O( Ξ (0) k,i ) = O(d -1 ). For any r ∈ [m], we have w (t+1,0) 1,r , ν k = w (t+1,0) 1,r , ν + w (t+1,0) 1,r , z k ≤ Λ (t+1) 1 + max r ∈[m] w (t+1,0) 1,r , z k ≤ O(1)Λ (t+1) 1 ≤ O(1)Λ (T (2) ) 1 = O(m -1/2 (1 + γ) -J ) , where the second line is by definition of Λ (t+1) 1 . The third line is by Lemma F.1. The fourth line is because t + 1 ≤ T (2) and Λ (t) 1 is increasing in t. The last line is by definition of T (2) . Thus m r=1 σ( w (t+1,τ ) 1,r , ν k ≤ m • O(m -1/2 (1 + γ) -J ) = o(1) . (F.14) Combining Eq. (C.7), Eq. (F.12), Eq. (F.13), and Eq. (F.14), we have - (t+1,0) k,i ≥ -(o(1) + o(1)) = Ω(1) . Now, suppose that - (t+1,τ ) k,i = Ω(1) for some 0 ≤ τ ≤ J -1. We wish to show that - (t+1,τ +1) k,i = Ω(1). By Eq. 
(F.8), we know that for i > n 1 , we have  w (t+1,τ +1) 1,r , ξ k,i ≤ O max w (t+1,0) 1,r , ξ k,i , C ≤ O max Ξ (t+1) k,i , C = O Ξ (t+1) k,i (t+1,τ +1) 1,r , ξ k,i ) ≤ m r=1 σ O( Ξ (t+1) k,i ) ≤ mσ O( Ξ (0) k,i ) = o(1) , (F.15) where the second inequality is because Ξ (t+1) k,i ≤ O( Ξ (0) k,i ). If i ∈ [n 1 ], then by Eq. (C.4) we can bound w (t+1,τ +1) 1,r , ξ k,i ≤ w (t+1,τ ) 1,r , ξ k,i + γΘ(σ 2 ξ √ d log(Kn)) + γ n 1 Θ(σ 2 ξ d)σ ( w (t+1,τ ) 1,r , ξ k,i ) ≤ w (t+1,τ ) 1,r , ξ k,i 1 + γΘ(σ 2 ξ d)/n 1 + γΘ(σ 2 ξ √ d log(Kn)) ≤ max w (t+1,τ ) 1,r , ξ k,i , C 1 + γΘ(σ 2 ξ d)/n 1 , where we used -(•) ≤ 1 and σ (•) ≤ 1 and Lemma D.2 in the first inequality. And we used σ (x) ≤ 2|x| for any x ∈ R to get the second inequality. This implies max w (t+1,τ +1) 1,r , ξ k,i , C ≤ max w (t+1,τ ) 1,r , ξ k,i , C 1 + γΘ(σ 2 ξ d)/n 1 ≤ max w (t+1,0) 1,r , ξ k,i , C 1 + γΘ(σ 2 ξ d)/n 1 τ +1 = Ξ (t+1) k,i 1 + γΘ(σ 2 ξ d)/n 1 τ +1 ≤ O Ξ (0) k,i 1 + γΘ(σ 2 ξ d)/n 1 τ +1 = O(d -1/2 ) , where the third line is by definition of Ξ (t+1) k,i . The fourth line is because Ξ (t+1) k,i ≤ O( Ξ k,i ). And the last equality is by direct calculation. Thus w (t+1,τ +1) 1,r , ξ k,i ≤ max w (t+1,τ +1) 1,r , ξ k,i , C ≤ O(d -1/2 ) . This gives m r=1 σ( w (t+1,τ +1) 1,r , ξ k,i ) ≤ m r=1 σ O(d -1/2 ) = m • σ O(d -1/2 ) = o(1) . Using Eq. (C.5), we have w (t+1,τ +1) 1,r , ν k ≤ w (t+1,τ ) 1,r , ν k + γΘ(1) w (t+1,τ ) 1,r , ν k ≤ w (t+1,τ ) 1,r , ν k (1 + Θ(1)γ) ≤ w (t+1,0) 1,r , ν k (1 + Θ(1)γ) τ +1 ≤ w (t+1,0) 1,r , ν k (1 + Θ(1)γ) J ≤ O(m -1/2 (1 + γ) -J )(1 + Θ(1)γ) J = O(m -1/2 ) , (F.16) where we used (•) ≤ 1, σ (x) ≤ 2|x| for any x ∈ R and ν k 2 = Θ(1) in the first inequality. The fifth inequality is by t + 1 ≤ T (2) and the definition of T (2) . Then we have m r=1 σ( w (t+1,τ +1) 1,r , ν k ≤ m • O(m -1 ) = O(1) . (F.17) The inequality is because σ(O(m -1/2 )) = O(m -1 ). Combining Eq. (C.7), Eq. (F.12), Eq. (F.15), and Eq. (F.17) we obtain - (t+1,τ +1) k,i ≥ -(O(1) + o(1)) = Ω(1) . 
This proves the hypothesis.

Proof of Hypothesis F.6. By Eq. (F.16), we have ⟨w^{(t+1,τ)}_{1,r}, ν_k⟩ < O(m^{−1/2}). Thus if ⟨w^{(t+1,τ)}_{1,r}, ν_k⟩ > 0, then σ'(⟨w^{(t+1,τ)}_{1,r}, ν_k⟩) = 2⟨w^{(t+1,τ)}_{1,r}, ν_k⟩, by the definition of σ'(·).

Proof. Without loss of generality, assume y_{k₁,i₁} = y_{k₂,i₂} = j = 1. We can essentially reuse the proof of Hypothesis F.5 to get

∑_{r=1}^m σ(⟨w^{(t,J)}_{1,r}, ξ_{k_q,i_q}⟩) = o(1), (F.19a)
∑_{r=1}^m σ(⟨w^{(t,J)}_{−1,r}, ξ_{k_q,i_q}⟩) = o(1), (F.19b)

for q ∈ {1, 2}. Moreover, since j⟨w^{(t,τ)}_{j,r}, ν_k⟩ is increasing in t and τ for any j ∈ {−1, 1}, r ∈ [m], k ∈ [K], we get that ⟨w^{(t,J)}_{−1,r}, ν_k⟩ ≤ ⟨w^{(0,0)}_{−1,r}, ν_k⟩ ≤ O(d^{−1/2}). This implies ∑_{r=1}^m σ(⟨w^{(t,J)}_{−1,r}, ν_{k_q}⟩) = o(1) for q ∈ {1, 2}. By Eq. (C.7), it suffices to show that

|∑_{r=1}^m σ(⟨w^{(t,J)}_{1,r}, ν_{k₁}⟩) − σ(⟨w^{(t,J)}_{1,r}, ν_{k₂}⟩)| ≤ O(1).

We divide the weights into two sets. Let G = {r ∈ [m] : ⟨w^{(t,0)}_{1,r}, ν⟩ ≥ d^{−1/3}} and H = {r ∈ [m] : r ∉ G}. Let us first consider r ∈ G. By Corollary F.5, we have ⟨w^{(t,J)}_{1,r}, z_{k_q}⟩ = O(σ²_s d)⟨w^{(t,J)}_{1,r}, ν⟩ for q ∈ {1, 2}, which implies ⟨w^{(t,J)}_{1,r}, ν_{k_q}⟩ = (1 + O(σ²_s d))⟨w^{(t,J)}_{1,r}, ν⟩. Hence

∑_{r∈G} σ(⟨w^{(t,J)}_{1,r}, ν_{k_q}⟩) ≤ ∑_{r∈G} σ(⟨w^{(t,J)}_{1,r}, ν⟩ + O(σ²_s d) · 2max{Λ^{(t)}_1, Jγ})
≤ ∑_{r∈G} σ(⟨w^{(t,J)}_{1,r}, ν⟩) + ∑_{r∈G} O(1)O(σ²_s d) · 2max{Λ^{(t)}_1, Jγ}
≤ ∑_{r∈G} σ(⟨w^{(t,J)}_{1,r}, ν⟩) + mO(σ²_s d) · 2max{Λ^{(t)}_1, Jγ}
≤ ∑_{r∈G} σ(⟨w^{(t,J)}_{1,r}, ν⟩) + O(1), (F.20)

where in the second line we used σ(x + ε) ≤ σ(x) + O(1)ε for ε ≥ 0, and the third line is due to |G| ≤ m. If r ∈ H, then by Lemma F.3 we have ⟨w^{(t,0)}_{1,r}, z_{k_q}⟩ < d^{−1/3}. Then by Eq. (C.5) we have

⟨w^{(t,τ+1)}_{1,r}, ν_{k_q}⟩ ≤ ⟨w^{(t,τ)}_{1,r}, ν_{k_q}⟩(1 + O(1)γ), for τ ∈ [J − 1].
Hence ⟨w^{(t,J)}_{1,r}, ν_{k_q}⟩ ≤ ⟨w^{(t,0)}_{1,r}, ν_{k_q}⟩(1 + O(1)γ)^J ≤ O(d^{−1/3}).

⟨w^{(t+1,0)}_{1,r}, ξ_{k,i}⟩ ≤ Ξ^{(t)}_1 − Θ(ησ²_ξ√d/(Kn²)) ∑_{(k',i')∈Ψ₊} (−ℓ'^{(t,J)}_1) Ξ^{(t)}_1 + Θ(ησ²_ξ√d/(Kn²)) ∑_{(k',i')∈Ψ₋} o(1) Ξ^{(t)}_1
≤ Ξ^{(t)}_1 − Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(t,J)}_1) Ξ^{(t)}_1 + o(d^{−1/2}) Ξ^{(t)}_1
= Ξ^{(t)}_1 (1 − Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(t,J)}_1) + o(d^{−1/2}))
= Ξ^{(t)}_1 (1 − Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(t,J)}_1)),

where we used Eq. (F.23) to get the second inequality, and Eq. (F.23) and Eq. (F.24) to get the fourth inequality. The last equality is because Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(t,J)}_1) = Θ(1)ω(d^{−1/8}) dominates o(d^{−1/2}). Maximizing the left-hand side over r ∈ [m] and (k, i) ∈ Ψ₊, we get

Ξ^{(t+1)}_1 ≤ Ξ^{(t)}_1 (1 − Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(t,J)}_1)). (F.25)

By Theorem F.9, we have Ξ^{(T^{(2)})}_1 = d^{−1/3}. Let T' be the first iteration such that Ξ^{(T')}_1 ≥ d^{−1/4}. Then, using Eq. (F.25), we get

Ξ^{(T')}_1 ≤ Ξ^{(T'−1)}_1 (1 − Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(T'−1,J)}_1)) ≤ d^{−1/4}(1 − Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(T'−1,J)}_1)) ≤ d^{−1/4}(1 + Θ(ησ²_ξ d/(Kn²))) = Θ(d^{−1/4}),

where the second inequality is by the definition of T', and the third inequality is due to −ℓ'(·) ≤ 1. We have

d^{−1/4} ≤ Ξ^{(T')}_1 ≤ Ξ^{(T^{(2)})}_1 ∏_{t=T^{(2)}}^{T'−1} (1 − Θ(ησ²_ξ d/(Kn²))(−ℓ'^{(t,J)}_1)),

… Λ^{(t')}_1 ≥ log(d)^{1.1}. By Corollary F.5, we have

L_tr-val(W^{(t)}) ≤ L_tr-val(W^{(t')}) = ∑_{(k,i)∈Ψ} ℓ(α(t', J, k, i)) ≤ −2∑_{(k,i)∈Ψ} ℓ'(α(t', J, k, i)) ≤ O(1/(Tη)),

where the last inequality is by Lemma F.20. This proves our theorem.

Theorem F.22 (Restatement of Theorem 4.5 (2)). Let W = W^{(T)} be our trained weights using the train-validation method. Then the test loss is small: L_test(W) = o(1/polylog(d)).

Proof. Sample a new example (x, y) from the distribution. Since we use the convolutional structure, we can assume that the first patch is x^{(1)} = y · (ν + z) and the second patch is x^{(2)} = ξ. Clearly ⟨ξ_{k,i}, ξ⟩ follows a Gaussian distribution with mean zero and standard deviation σ_ξ · ‖ξ_{k,i}‖₂. By Lemma D.2, we further know that ‖ξ_{k,i}‖₂ ≤ 2σ_ξ√d. In the following, we bound I₁ and I₂ (defined in Eq. (F.32)) respectively.
Bounding I₁: When event E holds, for any j ∈ {−1, 1}, we have

ℓ(yf(W, x)) ≤ log(1 + exp(F_{−1}(W_{−1}, x))) ≤ 1 + F_{−1}(W_{−1}, x) ≤ 2 + O(‖ξ‖₂²), (F.36)

where the first inequality is due to the fact that σ(·) ≥ 0 (so that F_{+1}(W_{+1}, x) ≥ 0), and the second inequality is by the property of the cross-entropy loss, i.e., log(1 + exp(x)) ≤ 1 + x for all x ≥ 0. Then we further have that

I₂ ≤ √(E[1(E^c)]) · √(E[ℓ(yf(W, x))²]).



1) Train-validation: for each task k, we denote by I^{tr}_k = {1, …, n₁} the training data indices, and by I^{val}_k = {n₁ + 1, …, n} the validation data indices. We then use S^{tr}_k = {(x_{k,i}, y_{k,i})}_{i∈I^{tr}_k} as the training data set, and S^{val}_k = {(x_{k,i}, y_{k,i})}_{i∈I^{val}_k} as the validation data set. The meta-objective of the train-validation method is to minimize
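For concreteness, the indexing above can be sketched in code. This is an illustrative NumPy snippet with our own function and variable names (not from the paper's implementation):

```python
import numpy as np

def train_val_split(task_data, n1):
    """Split one task's n examples into S_tr (the first n1 examples) and
    S_val (the remaining n - n1), mirroring the index sets
    I_tr = {1, ..., n1} and I_val = {n1 + 1, ..., n} above."""
    x, y = task_data
    return (x[:n1], y[:n1]), (x[n1:], y[n1:])

# Hypothetical example: one task with n = 6 examples in d = 4 dimensions.
rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))
y = np.array([1, -1, 1, 1, -1, -1])
(S_tr, S_val) = train_val_split((x, y), n1=4)
```

The inner loop then adapts on `S_tr` only, while the meta-objective is evaluated on `S_val`.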

Figure 1: Comparison of test loss over time of FOMAML trained with train-validation and with train-train. The test loss for the train-validation method decreases almost monotonically, whereas the test loss for the train-train method first decreases and then increases due to overfitting.

B.3 FOMAML WITH DIFFERENT NEURAL NETWORKS

In this section, we show that FOMAML with train-validation still outperforms FOMAML with train-train on a CNN with more layers and on ResNet. The experiments are run on miniImagenet. The results are recorded in Table 5. We see that FOMAML with train-validation outperforms FOMAML with train-train by a large margin under all three neural network structures.

Table 5: Performance w.r.t. different neural networks.
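The two FOMAML variants compared here differ only in which split the outer gradient is evaluated on. Below is a minimal, self-contained sketch (our own simplified stand-in: a linear least-squares learner replaces the networks in Table 5, and all names are illustrative). Passing `split="val"` gives train-validation; `split="train"` gives train-train:

```python
import numpy as np

def fomaml_step(w, tasks, split="val", inner_lr=0.1, outer_lr=0.05,
                inner_steps=3, n1=5):
    """One FOMAML meta-update. split='val' evaluates the outer gradient on
    the held-out validation half (train-validation); split='train' reuses
    the training half (train-train). Each task is a pair (x, y) of inputs
    and targets for squared-error regression."""
    meta_grad = np.zeros_like(w)
    for x, y in tasks:
        xtr, ytr = x[:n1], y[:n1]
        xout, yout = (x[n1:], y[n1:]) if split == "val" else (xtr, ytr)
        w_task = w.copy()
        for _ in range(inner_steps):  # inner loop: adapt on S_tr only
            w_task -= inner_lr * xtr.T @ (xtr @ w_task - ytr) / len(ytr)
        # First-order MAML: take the loss gradient at the *adapted*
        # parameters (no second-order terms), evaluated on xout/yout.
        meta_grad += xout.T @ (xout @ w_task - yout) / len(yout)
    return w - outer_lr * meta_grad / len(tasks)
```

Iterating `fomaml_step` over a collection of tasks moves the prior `w` toward a point from which a few inner steps fit each task well.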

2. Under Condition D.1 and Condition 4.1, the following estimates hold with probability at least 1 − (Kn)^{−10}:
1.

ξ k,î ) = O(1). Because that would imply -

τ ∈ [J − 1]. Here we have used −ℓ'(·) ≤ 1, σ'(·) ≤ 1 and Lemma D.2 to get the inequality. If ⟨w

and we are done. Now that we know the noise inner products are upper bounded by O(m^{−1/2}(1 + γσ²_ξ d/n)^{J+1}), we can provide a convergence guarantee based on this upper bound.

Theorem E.10 (Restatement of Theorem 4.4 (1)). Under Condition D.1 and Condition 4.1, let T = poly(d) be the total number of iterations. Then min_{t∈[T]}

Theorem E.11 (Restatement of Theorem 4.4 (2)). Under Condition 4.1 and Condition D.1, the test loss is large throughout the whole training process: min_{t∈[T]}

, ν⟩ is increasing in t for any r ∈ [m]. The sixth line is by the definition of Λ^{(t')}_1.

√(σ²_s d log(mKn)) = σ_s√d · √(log(mKn)) = √(log(mKn))/polylog(d) = o(1)D₀. This proves the lemma.

Lemma F.2. Under Condition D.1 and Condition 4.1, for any j ∈ {−1, 1}, r ∈ [m], if at time t we have j⟨w^{(t,0)}_{j,r}, ν⟩ ≥ d^{−1/3}, then for any k ∈ [K]

Proof. Without loss of generality, assume j = 1, and suppose ⟨w^{(t,0)}_{1,r}, ν⟩ ≥ d^{−1/3} for some r ∈ [m]. By Eq. (C.11), we have

Under Condition D.1 and Condition 4.1, for any j ∈ {−1, 1}, r ∈ [m] and k ∈ [K], if at time t we have j⟨w

‖ν‖₂ = 1 and ‖z_k‖₂² = O(σ²_s d) by Condition 4.1. Since we are given ⟨w^{(t,0)}_{1,r}, z_k⟩ = O(σ²_s d)⟨w^{(t,0)}

Lemma F.7 (Restatement of Lemma 6.2). Under Condition D.1 and Condition 4.1, if one uses the train-validation method, then for any t ≤ T^{(2)}:
1. For any k ∈ [K] and r = argmax_{r'∈[m]} ⟨w^{(t,0)}_{1,r'}, ν⟩

where the second and third line are by definition of Ξ

, ν⟩ for any k ∈ [K]. Moreover, for t ≥ t', Λ^{(t)}_1 ≥ log(d)^{1.1}. By Eq. (C.7), for (k, i) ∈ Ψ₊ we have

… + ∑_{r=1}^m σ(⟨w^{(t,J)}_{−1,r}, ξ_{k,i}⟩) + σ(⟨w^{(t,J)}_{−1,r}, ν_k⟩) ≤ −Θ(1)log(d)^{1.1} + o(1) = −Θ(1)log(d)^{1.1},

so that −ℓ'^{(t,J)}_{k,i} ≤ O(1)exp(−Θ(1)log(d)^{1.1}), where the second inequality is due to Eq. (F.19), and the last inequality is by a property of the function ℓ'(·). This implies

Λ^{(T)}_1 ≤ log(d)^{1.1} + ηT · O(1)exp(−Θ(1)log(d)^{1.1}) ≤ O(1)log(d)^{1.1} + o(1) = O(1)log(d)^{1.1},

where the second to last inequality is because ηT = O(poly(d)/polylog(d)) = O(poly(d)) and O(poly(d))exp(−Θ(1)log(d)^{1.1}) = o(1). This proves our theorem.

Proof of Claim F.10. Suppose our claim holds for all t ∈ [T^{(2)}, T' − 1] where T' ≤ T. Then by Theorem F.19, we have that Λ^{(T')}_j ≤ log(d)^{1.2} for j ∈ {−1, 1}. Let (k, i) ∈ Ψ₋. Denote r(t) = argmax_{r∈[m]} ⟨w^{(t,0)}_{1,r}, ξ_{k,i}⟩. From Eq. (C.9), we get that ⟨w …; we used Lemma D.2 in the first line, and the third line is a direct calculation using Condition D.1. For t ≥ T^{(2)}, … the inequality is by the definition of T^{(2)} and the monotonicity of Λ. By the definitions of ℓ(·) and ℓ'(·) we can bound −ℓ'(x) ≥ exp(−x)/2 ≥ ℓ(x)/2 for x > 0. This can be used to bound our loss by min_{t∈[T]}
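The elementary inequality −ℓ'(x) ≥ exp(−x)/2 ≥ ℓ(x)/2 invoked here can be checked directly from the cross-entropy loss ℓ(x) = log(1 + exp(−x)), the explicit form implied by the loss properties used elsewhere in this appendix:

```latex
% For \ell(x) = \log(1 + e^{-x}) and x > 0:
-\ell'(x) = \frac{e^{-x}}{1 + e^{-x}} \;\ge\; \frac{e^{-x}}{2}
  \qquad \text{since } e^{-x} \le 1 \text{ for } x \ge 0,
\qquad
\ell(x) = \log(1 + e^{-x}) \;\le\; e^{-x}
  \;\Longrightarrow\; \frac{e^{-x}}{2} \;\ge\; \frac{\ell(x)}{2}.
% Chaining the two gives -\ell'(x) \ge e^{-x}/2 \ge \ell(x)/2, which converts
% bounds on the loss derivatives into bounds on the loss itself.
```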

P(|⟨ξ_{k,i}, ξ⟩| ≥ d^{−1/4}) ≤ 2exp(−d^{1/4}). Let E be the event that |⟨ξ_{k,i}, ξ⟩| ≤ d^{−1/4} and |⟨w^{(0,0)}_{j,r}, ξ⟩| ≤ d^{3/2} for all k ∈ [K], i ∈ [n], r ∈ [m] and j ∈ {−1, 1}. Applying a union bound, we have that P(E) ≥ 1 − 2(nK + m)exp(−d^{1/4}). We can divide L_test(W) into two parts:

L_test(W) = E[ℓ(yf(W, x))] = E[1(E)ℓ(yf(W, x))] + E[1(E^c)ℓ(yf(W, x))] =: I₁ + I₂. (F.32)

|⟨w_{j,r}, ξ⟩| ≤ |⟨w^{(0,0)}_{j,r}, ξ⟩| + ηd^{−1/4}T^{(2)} + … ≤ … + O(d^{−1/4}) + O(d^{−1/4}) = O(d^{−1/4}), (F.33)

where the second line is by Eq. (F.30), the third line is by 0 ≤ σ'(·) ≤ 1, the third to last line is due to −ℓ'(·) ≤ 1, T^{(2)} = polylog(d) and Eq. (F.29), and the second to last line is by Eq. (F.31). We note that Eq. (F.33) holds for all j ∈ {−1, 1}, r ∈ [m], and y ∈ {−1, 1}. Now without loss of generality we assume y = 1. Then

F_{−1}(W_{−1}, x) ≤ ∑_{r=1}^m σ(⟨w_{−1,r}, ν + z⟩) + mσ(O(d^{−1/4})) ≤ ∑_{r=1}^m σ(O(d^{−1/4})) + mσ(O(d^{−1/4})) ≤ log 2,

where the second line is by Eq. (F.33), the third line is by the monotonicity of j⟨w^{(t,0)}_{j,r}, ν⟩ and j⟨w^{(t,0)}_{j,r}, z⟩ in t, and the fourth line is by Lemma D.2. Thus

ℓ(yf(W, x)) = ℓ(F_{+1}(W_{+1}, x) − F_{−1}(W_{−1}, x)) ≤ log(1 + exp(−log(d)^{0.1} + log 2)) ≤ 2exp(−log(d)^{0.1}),

where the fourth line is by Theorem F.18, the fifth line is by the definition of T^{(3)}, and the last line is by log(1 + x) ≤ x for all x ≥ 0. Therefore we have that

I₁ ≤ 2exp(−log(d)^{0.1}). (F.34)

Bounding I₂: Next we bound the second term I₂. For any j ∈ {−1, 1} and r ∈ [m], we can bound

|⟨w_{j,r}, ξ⟩| ≤ |⟨w^{(0,0)}_{j,r}, ξ⟩| + η∑_t ∑_{(k,i)∈Ψ} (−ℓ'^{(t,J)}_{k,i})σ'(⟨w^{(t,J)}_{j,r}, ξ_{k,i}⟩)O(‖ξ‖₂) + … = O(‖ξ‖₂), (F.35)

where the second line is by |⟨ξ, ξ_{k,i}⟩| ≤ ‖ξ_{k,i}‖₂‖ξ‖₂ = O(‖ξ‖₂), the third line is by 0 ≤ σ'(·) ≤ 1, and the third to last line is due to −ℓ'(·) ≤ 1, T^{(2)} = polylog(d) and … = O(‖ξ‖₂). We note that Eq. (F.35) holds for all j ∈ {−1, 1}, r ∈ [m], and y ∈ {−1, 1}. Now again without loss of generality we assume y = 1 and have that

F_{−1}(W_{−1}, x) ≤ ∑_{r=1}^m σ(⟨w_{−1,r}, ν + z⟩) + mσ(O(‖ξ‖₂)) ≤ ∑_{r=1}^m σ(O(d^{−1/2})) + mσ(O(‖ξ‖₂)) ≤ 1 + O(‖ξ‖₂²),

where the second line is by Eq. (F.35), the third line is by the monotonicity of j⟨w^{(t,0)}_{j,r}, ν⟩ and j⟨w^{(t,0)}_{j,r}, z⟩ in t, and the fourth line is by Lemma D.2. Thus

ℓ(yf(W, x)) = ℓ(F_{+1}(W_{+1}, x) − F_{−1}(W_{−1}, x)) ≤ log(1 + exp(F_{−1}(W_{−1}, x))) ≤ 2 + O(‖ξ‖₂²), as in Eq. (F.36).

I₂ ≤ √(P(E^c)) · √(8 + O(1)E[‖ξ‖₂⁴]) ≤ O(poly(d))exp(−0.5d^{1/4}),

where the first inequality is by the Cauchy–Schwarz inequality together with Eq. (F.36), and the second inequality is by the facts that 8 + O(1)E[‖ξ‖₂⁴] ≤ O(poly(d)) and P(E^c) ≤ 2(nK + m)exp(−d^{1/4}). Plugging the bounds on I₁ and I₂ into Eq. (F.32) gives

L_test(W) ≤ 2exp(−log(d)^{0.1}) + O(poly(d))exp(−0.5d^{1/4}) = O(exp(−log(d)^{0.1})),

which completes the proof.

Performance comparison between different activation functions.

Performance comparison of the number of optimization steps in the inner-loop.

Performance w.r.t. the inner-loop learning rate and the outer-loop learning rate.

Performance comparison between FOMAML and Reptile.

TIME EVOLUTION OF TEST LOSS OF FOMAML

We plot the time evolution of the test loss of models trained by FOMAML with train-validation and with train-train on our synthetic data set. The results are illustrated in Figures 1a and 1b for train-validation and train-train, respectively.


Performance comparison w.r.t. the number of samples per class on miniImagenet.

for all (k, i) ∈ Ψ₊. Then by our assumption, −ℓ'^{(t,J)}_1 = ω(d^{−1/8}). By Eq. (C.9) we have that for any r ∈ [m] and (k, i) ∈ Ψ₊

ACKNOWLEDGEMENTS

We thank the anonymous reviewers and area chair for their helpful comments. ZC and QG are supported in part by the National Science Foundation IIS-2008981 and the Sloan Research Fellowship.


Published as a conference paper at ICLR 2023

F.2 PHASE II

We have shown that when the feature inner product has grown to Θ(1), the noise inner products remain at Θ(d^{−1/2}), which is exactly the opposite of Theorem E.2. Next we will show that this difference is maintained at least until Λ^{(t)}_j has grown to O(log(d)^{0.1}). For (k, i) ∈ Ψ, let

Ξ^{(t)}_{k,i} = max{ max_{r∈[m]} ⟨w^{(t,0)}_{y_{k,i},r}, ξ_{k,i}⟩, d^{−1/3} },

and

Γ^{(t)}_{k,i} = max{ max_{r∈[m]} ⟨w^{(t,0)}_{−y_{k,i},r}, ξ_{k,i}⟩, C }.

Let T^{(3)}_j be the first iteration such that Λ^{(t)}_j ≥ log(d)^{0.1}, and let T^{(3)} = max{T^{(3)}_{−1}, T^{(3)}_1}. Recall that T = poly(d) is the total number of iterations. We first present two claims under Condition D.1 and Condition 4.1.

Claim F.10. For any t ≥ T^{(2)} and any (k, i) ∈ Ψ, we have Γ^{(t)}_{k,i} ≤ Θ(d^{−1/2}).

Claim F.11. For any t ≤ T^{(3)}_j and (k, i) ∈ Ψ_j, we have …

Remark F.12. Under Claim F.10, it holds that Γ^{(t)}_{k,i} … Γ^{(t)}_{k',i'} for all (k, i), (k', i') ∈ Ψ.

Remark F.13. By Theorem F.9, Claim F.10 holds at time T^{(2)}. Also by Theorem F.9, Claim F.11 holds for all t ≤ T^{(2)}.

Note that from Eq. (C.11), the growth of j⟨w^{(t,0)}_{j,r}, ν⟩ depends on the size of j⟨w^{(t,J)}_{j,r}, ν_k⟩. We first need a lemma that gives a lower bound on j⟨w^{(t,J)}_{j,r}, ν_k⟩ for t ≥ T^{(2)}. The following lemma is of a similar form to Lemma E.7.

Lemma F.14. For any …

Proof. For simplicity, take j = 1; our proof does not depend on the choice of j, so this is without loss of generality. By Eq. (C.5), we know that ⟨w … By Eq. (C.11) and the definition of r(t), we also have that ⟨w^{(t,0)}_{1,r(t)}, ν⟩ is an increasing function of t. By the definition of T^{(2)}, for any t ≥ T^{(2)} it holds that ⟨w^{(t,0)}_{1,r(t)}, ν⟩ ≥ m^{−1/2}(1 + γ)^{−J}. Then by Lemma F.2, we have ⟨w …. Let us suppose that Eq. (F.18) is true for some τ. Under our assumption that Ξ^{(t)}_{k,i} ≤ O(d^{−1/4}), we can apply the same proof as in Hypothesis F.5 to get that −ℓ'^{(t,τ')}_{k,i} = Ω(1) for any τ' ∈ [τ]. Then using the same derivation as Eq.
(F.7), we obtain … This concludes the lemma.

The next lemma shows that as long as the noise vectors have not been learned well by our network, the losses from different samples are essentially due to the feature and hence are of similar sizes.

Lemma F.15. Let j ∈ {−1, 1}. Under Condition D.1 and Condition 4.1, while …

… where the second line is due to Eq. (F.20) and Eq. (F.21). Since the above holds for q ∈ {1, 2}, we get … And using Eq. (C.7), we conclude that …

Lemma F.16. Under Claim F.11, for any t ≤ T^{(3)}_j, we have −ℓ' …

Proof. Without loss of generality, let us consider j = 1. Let (k, i) ∈ Ψ₊. By Eq. (C.7), Eq. (F.19) and Eq. (F.22) we have that … where in the first line we used Λ^{(t)}_1 < log(d)^{0.1}. Using Lemma F.6, we have that for all r ∈ [m], ⟨w^{(t,J)}_{1,r}, ν⟩ ≤ 2max{log(d)^{0.1}, Jγ} = 2Jγ. Then by a direct calculation using Condition D.1, we get …

Remark F.17. The −1/8 exponent on d can be replaced by any c < 0. We choose −1/8 for convenience in the later proof.

… where we note the second inequality in Remark F.12. It is convenient to define … where we get the third inequality by taking the logarithm of the second inequality. In the fourth inequality, we used log(1 + x) ≤ O(1)x for x ∈ [0, 1). By the second line of Eq. (F.9) we have … To get the second inequality, we use Lemma F.15 and absorb all the constants into Ω(m^{−1/2}). Summing over t, we obtain … We used −ℓ'(·) ≤ 1 to get the last inequality. By a direct computation, we have … Combining this fact with Eq. (F.26) and Eq. (F.27), we conclude that T^{(3)}_1 < T. If … ≥ d^{−1/4}, then using exactly the same argument, we obtain T^{(3)}_{−1} < T. This proves our claim.

Theorem F.18 (Restatement of Lemma 6.5). Under Claim F.10, we have T^{(3)} < T.

Proof. We will only prove T^{(3)}_1 < T, since the proof of T^{(3)}_{−1} < T will be the same. For the sake of contradiction, let us assume that T^{(3)}_1 ≥ T. Then, for t ≤ T, we have that Λ^{(t)}_1 ≤ log(d)^{0.1}. From the proof of Lemma F.16, we have … for all (k, i) ∈ Ψ₊. By the second line of Eq.
(F.9), … where the second inequality is by the definition of T^{(2)}, and the third inequality is because … ≤ log(d)^{0.1}. Thus, we must have T > T^{(3)}_1, and the proof of T > T^{(3)}_{−1} is the same. This proves our theorem; it will be useful for our convergence result later.

Theorem F.18 implies that before training ends, the feature will be memorized by the neurons with an inner product of size at least log(d)^{0.1}. As we shall see later, this already guarantees a small test loss. Our next theorem makes sure that the feature inner product will not grow too big by the end of the training.

Theorem F.19. Under Claim F.10, we have …

Proof. Without loss of generality, let j = 1. Let t' be the first iteration such that …, where the second line is by the definition of r(t). We get the third line using the second line of Eq. (F.9). We have used σ'(·) ≤ 1 to get the last line. By the definition of r(t), this means … And by the definition of t', we have … Plugging Eq. (F.29) into Eq. (F.28), we have … where the third inequality is by Theorem F.9. Since this relation holds for all r ∈ [m], we get … This proves our claim.

Lemma F.20 (Restatement of Lemma 6.4). Let T = poly(d) be the total number of iterations. Under Condition D.1 and Condition 4.1, we have …

Proof. By Theorem F.19, we have Λ^{(T)}_j ≤ log(d)^{1.2} for j ∈ {−1, 1}. Using the same derivation as Eq. (F.29), we have … And we can bound the minimum over t by … where the second equality is because … L_tr-val(W^{(t)}) ≤ Θ(T^{−1}η^{−1}).

Proof. Denote −ℓ' …

