PROVABLY LEARNING DIVERSE FEATURES IN MULTI-VIEW DATA WITH MIDPOINT MIXUP

Abstract

Mixup is a data augmentation technique that relies on training using random convex combinations of data points and their labels. In recent years, Mixup has become a standard primitive used in the training of state-of-the-art image classification models due to its demonstrated benefits over empirical risk minimization with regard to generalization and robustness. In this work, we try to explain some of this success from a feature learning perspective. We focus our attention on classification problems in which each class may have multiple associated features (or views) that can be used to predict the class correctly. Our main theoretical results demonstrate that, for a non-trivial class of data distributions with two features per class, training a 2-layer convolutional network using empirical risk minimization can lead to learning only one feature for almost all classes while training with a specific instantiation of Mixup succeeds in learning both features for every class. We also show empirically that these theoretical insights extend to the practical settings of image benchmarks modified to have additional synthetic features.

1. INTRODUCTION

Data augmentation techniques have been a mainstay in the training of state-of-the-art models for a wide array of tasks -particularly in the field of computer vision -due to their ability to artificially inflate dataset size and encourage model robustness to various transformations of the data. One such technique that has achieved widespread use is Mixup (Zhang et al., 2018) , which constructs new data points as convex combinations of pairs of data points and their labels from the original dataset. Mixup has been shown to empirically improve generalization and robustness when compared to standard training over different model architectures, tasks, and domains (Liang et al., 2018; He et al., 2019; Thulasidasan et al., 2019; Lamb et al., 2019; Arazo et al., 2019; Guo, 2020; Verma et al., 2021b; Wang et al., 2021) . It has also found applications to distributed private learning (Huang et al., 2021) , learning fair models (Chuang and Mroueh, 2021) , semi-supervised learning (Berthelot et al., 2019b; Sohn et al., 2020; Berthelot et al., 2019a) , self-supervised (specifically contrastive) learning (Verma et al., 2021a; Lee et al., 2020; Kalantidis et al., 2020) , and multi-modal learning (So et al., 2022) . The success of Mixup has instigated several works attempting to theoretically characterize its potential benefits and drawbacks (Guo et al., 2019; Carratino et al., 2020; Zhang et al., 2020; 2021; Chidambaram et al., 2021) . These works have focused mainly on analyzing, at a high-level, the beneficial (or detrimental) behaviors encouraged by the Mixup-version of the original empirical loss for a given task. As such, none of these previous works (to the best of our knowledge) have provided an algorithmic analysis of Mixup training in the context of non-linear models (i.e. neural networks), which is the main use case of Mixup. 
In this paper, we begin this line of work by theoretically separating the full training dynamics of Mixup (with a specific set of hyperparameters) from empirical risk minimization (ERM) for a 2-layer convolutional network (CNN) architecture on a class of data distributions exhibiting a multi-view nature. This multi-view property essentially requires (assuming classification data) that each class in the data is well-correlated with multiple features present in the data. Our analysis is heavily motivated by the recent work of Allen-Zhu and Li (2021) , which showed that this kind of multi-view data can provide a fruitful setting for theoretically understanding the benefits of ensembles and knowledge distillation in the training of deep learning models. We show that Mixup can, perhaps surprisingly, capture some of the key benefits of ensembles explained by Allen-Zhu and Li (2021) despite only being used to train a single model. Main Contributions and Outline. Our main contributions are three-fold. In Sections 2 and 3, we introduce the main ideas behind Mixup and analyze a simple, linearly separable multi-view data distribution which we use to lay the groundwork for our main results. In analyzing this distribution, we motivate the use of a particular setting of Mixup -which we refer to as Midpoint Mixup -in which training is done on the midpoints of data points and their labels. Section 4 contains our main results; we prove that, for a highly noisy class of data distributions with two features per class, minimizing the empirical cross-entropy using gradient descent can lead to learning only one of the features in the data while minimizing the Midpoint Mixup cross-entropy succeeds in learning both features. While our theory focuses on the case of two features/views per class to be consistent with Allen-Zhu and Li (2021) , our techniques can readily be extended to more general multi-view data distributions. 
Lastly, in Section 5, we conduct experiments illustrating that our theoretical insights in Sections 3 and 4 can apply to the training of realistic models on image classification benchmarks. We show for each benchmark that, after modifying the training data to include additional spurious features correlated with the true labels, both Mixup (with standard settings) and Midpoint Mixup outperform ERM on the original test data, with Midpoint Mixup closely approximating the performance of regular Mixup. Related Work. The idea of training on midpoints (or approximate midpoints) is not new; both Guo (2021) and Chidambaram et al. (2021) empirically study settings resembling what we consider in this paper, but they do not develop theory for this kind of training (beyond an information-theoretic result in the latter case). As mentioned earlier, there are also several theoretical works analyzing the Mixup formulation and its variants (Carratino et al., 2020; Zhang et al., 2020; 2021; Chidambaram et al., 2021; Park et al., 2022), but none of these works contain optimization results (which are the focus of this work). Additionally, we note that there are many Mixup-like data augmentation techniques and training formulations that are not (immediately) within the scope of the theory developed in this paper; for example, CutMix (Yun et al., 2019), Manifold Mixup (Verma et al., 2019), Puzzle Mix (Kim et al., 2020), Co-Mixup (Kim et al., 2021), and Noisy Feature Mixup (Lim et al., 2021) are all such variations. Our work is also influenced by the existing large body of work theoretically analyzing the benefits of data augmentation (Bishop, 1995; Dao et al., 2019; Wu et al., 2020; Hanin and Sun, 2021; Rajput et al., 2019; Yang et al., 2022; Wang et al., 2022; Chen et al., 2020; Mei et al., 2021). The most relevant such work to ours is the recent work of Shen et al.
(2022), which also studies the impact of data augmentation on the learning dynamics of a 2-layer network in a setting motivated by that of Allen-Zhu and Li (2021). However, Midpoint Mixup differs significantly from the data augmentation scheme considered in Shen et al. (2022), and consequently our results and setting are also of a different nature (we stick much more closely to the setting of Allen-Zhu and Li (2021)). As such, our work can be viewed as a parallel thread to that of Shen et al. (2022).

2. PRELIMINARIES AND MOTIVATION FOR MIDPOINT MIXUP

We will introduce Mixup in the context of $k$-class classification, although the definitions below easily extend to regression. As a notational convenience, we will use $[k]$ to denote $\{1, 2, \dots, k\}$. Recall that, given a finite dataset $\mathcal{X} \subset \mathbb{R}^d \times [k]$ with $|\mathcal{X}| = N$, we can define the empirical cross-entropy loss $J(g, \mathcal{X})$ of a model $g : \mathbb{R}^d \to \mathbb{R}^k$ as:

$$J(g, \mathcal{X}) = -\frac{1}{N} \sum_{i \in [N]} \log \phi_{y_i}(g(x_i)), \quad \text{where } \phi_y(g(x)) = \frac{\exp(g_y(x))}{\sum_{s \in [k]} \exp(g_s(x))} \tag{2.1}$$

with $\phi$ being the standard softmax function and the notation $g_y, \phi_y$ indicating the $y$-th coordinate functions of $g$ and $\phi$ respectively. Now let us fix a distribution $\mathcal{D}_\lambda$ whose support is contained in $[0, 1]$ and introduce the notation $z_{i,j}(\lambda) = \lambda x_i + (1 - \lambda) x_j$ (using $z_{i,j}$ when $\lambda$ is clear from context), where $(x_i, y_i), (x_j, y_j) \in \mathcal{X}$. Then we may define the Mixup cross-entropy $J_M(g, \mathcal{X}, \mathcal{D}_\lambda)$ as:

$$J_M(g, \mathcal{X}, \mathcal{D}_\lambda) = -\frac{1}{N^2} \sum_{i \in [N]} \sum_{j \in [N]} \mathbb{E}_{\lambda \sim \mathcal{D}_\lambda} \left[ \lambda \log \phi_{y_i}(g(z_{i,j})) + (1 - \lambda) \log \phi_{y_j}(g(z_{i,j})) \right] \tag{2.2}$$

We mention a minor difference between Equation 2.2 and the original formulation of Zhang et al. (2018): they consider the expectation term in Equation 2.2 over $N$ randomly sampled pairs of points from the original dataset $\mathcal{X}$, whereas we explicitly consider mixing all $N^2$ possible pairs of points. This is, however, just to make various parts of our analysis easier to follow; one could also sample $N$ mixed points uniformly, and the analysis would still carry through with an additional high-probability qualifier (the important aspect is the proportions with which different mixed points show up, i.e. mixing across classes versus mixing within a class).
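To make Equation 2.2 concrete, the following minimal NumPy sketch (our own illustration, not the paper's implementation) evaluates the Mixup cross-entropy for an arbitrary model $g$ by mixing all $N^2$ pairs and averaging over draws from $\mathcal{D}_\lambda$:

```python
import numpy as np

def softmax(z):
    # numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mixup_cross_entropy(g, X, y, lam_samples):
    """Monte Carlo estimate of J_M(g, X, D_lambda) over all N^2 pairs,
    averaging over the draws in lam_samples from the mixing distribution.
    g maps an (n, d) array of inputs to an (n, k) array of logits."""
    N = X.shape[0]
    rows = np.arange(N)
    total = 0.0
    for lam in lam_samples:
        # all N^2 mixed points z_ij = lam * x_i + (1 - lam) * x_j
        Z = lam * X[:, None, :] + (1 - lam) * X[None, :, :]
        probs = softmax(g(Z.reshape(N * N, -1))).reshape(N, N, -1)
        log_pi = np.log(probs[rows[:, None], rows[None, :], y[:, None]])  # phi_{y_i}
        log_pj = np.log(probs[rows[:, None], rows[None, :], y[None, :]])  # phi_{y_j}
        total += -(lam * log_pi + (1 - lam) * log_pj).mean()
    return total / len(lam_samples)
```

Note that taking $\mathcal{D}_\lambda$ to be a point mass on $1$ recovers the empirical cross-entropy of Equation 2.1, averaged over the (now redundant) second index.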

3. MOTIVATING MIDPOINT MIXUP: THE LINEAR REGIME

As can be seen from Equation 2.2, the Mixup cross-entropy $J_M(g, \mathcal{X}, \mathcal{D}_\lambda)$ depends heavily on the choice of mixing distribution $\mathcal{D}_\lambda$. Zhang et al. (2018) took $\mathcal{D}_\lambda$ to be $\mathrm{Beta}(\alpha, \alpha)$ with $\alpha$ a hyperparameter. In this work, we will specifically be interested in the case $\alpha \to \infty$, for which the distribution $\mathcal{D}_\lambda$ takes the value $1/2$ with probability 1. We refer to this special case as Midpoint Mixup, and note that it can also be viewed as a case of the Pairwise Label Smoothing strategy introduced by Guo (2021). We will write the Midpoint Mixup loss as $J_{MM}(g, \mathcal{X})$ (here $z_{i,j} = (x_i + x_j)/2$, and there is no $\mathcal{D}_\lambda$ dependence since the mixing is deterministic):

$$J_{MM}(g, \mathcal{X}) = -\frac{1}{2N^2} \sum_{i \in [N]} \sum_{j \in [N]} \left[ \log \phi_{y_i}(g(z_{i,j})) + \log \phi_{y_j}(g(z_{i,j})) \right] \tag{3.1}$$

We focus on this version of Mixup for a few key reasons. Firstly, we will show that $J_{MM}(g, \mathcal{X})$ exhibits the nice property that its global minimizer corresponds to a model in which all of the features in the data are learned equally (in a sense to be made precise below). We will also show that this is not the case for $J_M(g, \mathcal{X}, \mathcal{D}_\lambda)$ when $\mathcal{D}_\lambda$ is any other non-trivial distribution. Additionally, from a technical perspective, the Midpoint Mixup loss lends itself to a much cleaner optimization analysis because the structure of its gradients does not change with each optimization iteration (i.e. we do not need to sample new mixing proportions at each step). This allows us to more easily show how the gradient descent dynamics encourage learning all of the features in the data. That being said, we are not claiming that Midpoint Mixup is a superior practical alternative to standard Mixup; our goal is simply to show that it better accentuates the theoretical benefits of Mixup, while remaining empirically comparable to standard Mixup settings. Full proofs for all of the results presented in the next subsection can be found in Section C of the Appendix.
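Because the mixing is deterministic, Equation 3.1 can be computed exactly rather than estimated. A minimal NumPy sketch (our own illustration, not the paper's code):

```python
import numpy as np

def midpoint_mixup_loss(g, X, y):
    """Exact J_MM(g, X) over all N^2 midpoints z_ij = (x_i + x_j)/2.
    g maps an (n, d) array of inputs to an (n, k) array of logits."""
    N = X.shape[0]
    rows = np.arange(N)
    Z = 0.5 * (X[:, None, :] + X[None, :, :])          # midpoints z_ij
    logits = g(Z.reshape(N * N, -1)).reshape(N, N, -1)
    # stable log-softmax over the class axis
    logits = logits - logits.max(axis=-1, keepdims=True)
    logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    log_pi = logp[rows[:, None], rows[None, :], y[:, None]]  # log phi_{y_i}
    log_pj = logp[rows[:, None], rows[None, :], y[None, :]]  # log phi_{y_j}
    return -(log_pi + log_pj).sum() / (2 * N * N)
```

The $N^2$ midpoints make a direct evaluation quadratic in the dataset size, which is one reason practical implementations (including the experiments in Section 5) instead mix randomly sampled pairs per batch.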

3.1. MIDPOINT MIXUP WITH LINEAR MODELS ON LINEARLY SEPARABLE DATA

To make clear what we mean by feature learning, we first turn our attention to the simple setting of learning linear models $g_y(x) = \langle w_y, x \rangle$ (i.e. one weight vector per class) on linearly separable data, as this setting will serve as a foundation for our main results. Namely, we consider $k$-class classification with a dataset $\mathcal{X}$ of $N$ labeled data points generated according to the following data distribution (with $N$ sufficiently large):

Definition 3.1. [Simple Multi-View Setting] For each class $y \in [k]$, let $v_{y,1}, v_{y,2} \in \mathbb{R}^d$ be orthonormal vectors also satisfying $v_{y,\ell} \perp v_{s,\ell'}$ when $y \neq s$, for any $\ell, \ell' \in [2]$. Each point $(x, y) \sim \mathcal{D}$ is then generated by sampling $y \in [k]$ uniformly and constructing $x$ as:

$$x = \beta_y v_{y,1} + (1 - \beta_y) v_{y,2}, \qquad \beta_y \sim \mathrm{Uni}([0.1, 0.9]) \tag{3.2}$$

Definition 3.1 is multi-view in the following sense: for any class $y$, it suffices (from an accuracy perspective) to learn a model $g$ that has a significant correlation with either the feature vector $v_{y,1}$ or $v_{y,2}$. In this context, one can think of feature learning as corresponding to how positively correlated the weight $w_y$ is with each of the same-class feature vectors $v_{y,1}$ and $v_{y,2}$ (we provide a more rigorous definition in our main results). If one now considers the empirical cross-entropy loss $J(g, \mathcal{X})$, it is straightforward to see that its global minimum can be achieved by just considering models $g$ in which we take $\langle w_y, v_{y,1} \rangle \to \infty$ for every class $y$. This means we can minimize the usual cross-entropy loss without learning both features in the dataset $\mathcal{X}$. However, this is not the case for Midpoint Mixup.
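One concrete instantiation of Definition 3.1 can be sampled as follows (a sketch; taking the features to be standard basis vectors is our own illustrative choice, which satisfies the stated orthogonality conditions whenever $d \ge 2k$):

```python
import numpy as np

def sample_simple_multiview(N, k, d, rng):
    """Draw N labeled points from the distribution of Definition 3.1,
    with v_{y,1}, v_{y,2} taken to be standard basis vectors (d >= 2k)."""
    assert d >= 2 * k
    V = np.eye(d)[:2 * k].reshape(k, 2, d)   # V[y, l] = v_{y, l+1}
    ys = rng.integers(0, k, size=N)          # uniform class labels
    betas = rng.uniform(0.1, 0.9, size=N)    # coefficient beta_y
    X = betas[:, None] * V[ys, 0] + (1 - betas)[:, None] * V[ys, 1]
    return X, ys
```

Note that every sampled point satisfies $\langle x, v_{y,1} \rangle + \langle x, v_{y,2} \rangle = 1$; this mirrored-coefficient structure is what makes equal-correlation models output constant predictions on the original data points.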
Indeed, we show below that a sufficient (and, with extremely high probability, necessary) condition for a linear model $g$ to minimize $J_{MM}$ (when taking its scaling to infinity) is that it has equal correlation with both features for every class (sufficiency also relies on having weaker correlations with other-class features).

Lemma 3.2. A linear model $g$ satisfies

$$\lim_{\gamma \to \infty} J_{MM}(\gamma g, \mathcal{X}) = \inf_h J_{MM}(h, \mathcal{X}) \tag{3.3}$$

if $g$ has the property that, for every class $y$, we have $\langle w_y, v_{y,\ell_1} \rangle = \langle w_s, v_{s,\ell_2} \rangle > 0$ and $\langle w_y, v_{s,\ell_2} \rangle \le 0$ for every $s \neq y$ and $\ell_1, \ell_2 \in [2]$. Furthermore, with probability $1 - \exp(-\Theta(N))$ (over the randomness of $\mathcal{X}$), the condition $\langle w_y, v_{y,\ell_1} \rangle = \langle w_s, v_{s,\ell_2} \rangle$ is necessary for $g$ to satisfy Equation 3.3.

Proof Sketch. The idea is that if $g$ has equal correlation with both features for every class, its predictions will be constant on the original data points, since the coefficients of the two features in each data point are mirrored as per Equation 3.2. Together with the condition $\langle w_y, v_{s,\ell} \rangle \le 0$ (which can be weakened significantly), this implies the softmax output of $g$ on the Midpoint Mixup points will be exactly $1/2$ for each of the two classes being mixed (and $0$ for all other classes), which is optimal.

As mentioned earlier, if we consider $J_M(g, \mathcal{X}, \mathcal{D}_\lambda)$ for any other non-point-mass distribution, we can prove that the analogue of Lemma 3.2 does not hold.

Proposition 3.3. For any distribution $\mathcal{D}_\lambda$ that is not a point mass on $0$, $1$, or $1/2$, and any linear model $g$ satisfying the conditions of Lemma 3.2, we have with probability $1 - \exp(-\Theta(N))$ (over the randomness of $\mathcal{X}$) that there exists an $\epsilon_0 > 0$ depending only on $\mathcal{D}_\lambda$ such that:

$$J_M(g, \mathcal{X}, \mathcal{D}_\lambda) \ge \inf_h J_M(h, \mathcal{X}, \mathcal{D}_\lambda) + \epsilon_0 \tag{3.4}$$

Proof Sketch. In the case of general mixing distributions, we cannot achieve the Mixup-optimal behavior of $\phi_{y_i}(g(z_{i,j}(\lambda))) = \lambda$ for every $\lambda$ if the outputs $g_y$ are constant on the original data points.
Lemma 3.2 outlines the key theoretical benefit of Midpoint Mixup: its global optimizers exist within the class of models that we consider, and such optimizers learn all features in the data equally. Although Lemma 3.2 is stated in the context of linear models, the result naturally carries over to the two-layer neural networks of the type we define in the next section. That being said, Proposition 3.3 is not intended to rule out the possibility that the minimizer of $J_M(g, \mathcal{X}, \mathcal{D}_\lambda)$, when restricted to a specific model class, is a model in which all features are learned near-equally (in fact, we expect this to be the case for any reasonable $\mathcal{D}_\lambda$). Proposition 3.3 is rather intended to motivate the study of Midpoint Mixup as a particularly interesting choice of the mixing distribution $\mathcal{D}_\lambda$.

We now proceed one step further from the above results and show that the feature learning benefit of Midpoint Mixup manifests itself even in the optimization process (when using gradient-based methods). We show that, if a significant separation between feature correlations exists, the Midpoint Mixup gradients correct the separation. For simplicity, we suppose WLOG that $\langle w_y, v_{y,1} \rangle > \langle w_y, v_{y,2} \rangle$. Now letting $\Delta_y = \langle w_y, v_{y,1} - v_{y,2} \rangle$ and using the notation $\nabla_{w_y}$ for $\frac{\partial}{\partial w_y}$, we can prove:

Proposition 3.4. [Mixup Gradient Lower Bound] Let $y$ be any class such that $\Delta_y \ge \log k$, and suppose that both $\langle w_y, v_{y,1} \rangle \ge 0$ and the cross-class orthogonality condition $\langle w_s, v_{u,\ell} \rangle = 0$ hold for all $s \neq u$ and $\ell \in [2]$. Then we have with high probability that:

$$\left\langle -\nabla_{w_y} J_{MM}(g, \mathcal{X}),\, v_{y,2} \right\rangle \ge \Theta\left(\frac{1}{k^2}\right) \tag{3.5}$$

Proof Sketch. The key idea is to analyze the gradient correlation with the direction $v_{y,1} - v_{y,2}$ via a concentration of measure argument.
Proposition 3.4 shows that, assuming nonnegativity of within-class correlations and an orthogonality condition across classes (which we will show to be approximately true in our main results), the feature correlation that is lagging behind for any class $y$ receives a significant gradient when optimizing the Midpoint Mixup loss. On the other hand, we can also prove that this need not be the case for empirical risk minimization:

Proposition 3.5. [ERM Gradient Upper Bound] For every $y \in [k]$, assuming the same conditions as in Proposition 3.4, if $\Delta_y \ge C \log k$ for any $C > 0$, then with high probability we have that:

$$\left\langle -\nabla_{w_y} J(g, \mathcal{X}),\, v_{y,2} \right\rangle \le O\left(\frac{1}{k^{0.1C - 1}}\right) \tag{3.6}$$

Proof Sketch. This follows directly from the form of the gradient for $J(g, \mathcal{X})$.

While Proposition 3.5 demonstrates that training using ERM can possibly fail to learn both features associated with a class due to increasingly small gradients, one can verify that this does not naturally occur in the optimization dynamics of linear models on linearly separable data of the type in Definition 3.1 (see, for example, the related result in Chidambaram et al. (2021)). On the other hand, once we move away from linearly separable data and linear models to more realistic settings, the situation described above does indeed arise, which motivates our main results.
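The gradient quantities in Propositions 3.4 and 3.5 can be computed in closed form for linear models, which makes a small numerical illustration possible. The following sketch (our own illustration under assumed parameter choices, not the formal statement) sets $w_y = (\log k)\, v_{y,1}$, so that the feature $v_{y,2}$ lags, and evaluates the negative gradient components along $v_{y,2}$ for both the ERM and Midpoint Mixup losses on data from Definition 3.1 (features taken as standard basis vectors):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def lagging_feature_gradients(k=10, N=300, seed=0):
    """Components of the negative ERM and Midpoint Mixup gradients along the
    lagging feature v_{y,2}, for a linear model with w_y = log(k) * v_{y,1}."""
    rng = np.random.default_rng(seed)
    d = 2 * k
    V = np.eye(d).reshape(k, 2, d)                 # V[y, l] = v_{y, l+1}
    ys = rng.integers(0, k, size=N)
    betas = rng.uniform(0.1, 0.9, size=N)
    X = betas[:, None] * V[ys, 0] + (1 - betas)[:, None] * V[ys, 1]
    W = np.log(k) * V[:, 0, :]                     # correlates only with v_{y,1}
    y0 = 0
    ind = (ys == y0).astype(float)
    # ERM: <-grad_{w_y} J, v_{y,2}> = (1/N) sum_i (1[y_i=y] - phi_y(x_i)) <x_i, v_{y,2}>
    P = softmax(X @ W.T)
    erm = ((ind - P[:, y0]) @ X) @ V[y0, 1] / N
    # Midpoint Mixup: sum over all N^2 midpoints z_ij
    Z = 0.5 * (X[:, None, :] + X[None, :, :])
    Pm = softmax(Z @ W.T)[..., y0]
    coef = ind[:, None] + ind[None, :] - 2 * Pm
    mm = (np.einsum('ij,ijd->d', coef, Z) / (2 * N * N)) @ V[y0, 1]
    return erm, mm
```

Both components are positive in this configuration, so the lagging feature receives a corrective gradient under both losses; the propositions concern how these components scale with $k$ and the separation $\Delta_y$, which a single numeric instance cannot certify.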

4. ANALYZING MIDPOINT MIXUP TRAINING DYNAMICS ON GENERAL MULTI-VIEW DATA

For our main results, we now consider a data distribution and class of models that are meant to more closely mimic practical situations.

4.1. GENERAL MULTI-VIEW DATA SETUP

We adopt a slightly simplified version of the setting of Allen-Zhu and Li (2021). We still consider the problem of $k$-class classification on a dataset $\mathcal{X}$ of $N$ labeled data points, but our data points are now represented as ordered tuples $x = (x^{(1)}, \dots, x^{(P)})$ of $P$ input patches $x^{(p)}$ with each $x^{(p)} \in \mathbb{R}^d$ (so $\mathcal{X} \subset \mathbb{R}^{Pd} \times [k]$). As was the case in Definition 3.1 and in Allen-Zhu and Li (2021), we assume that the data is multi-view in that each class $y$ is associated with 2 orthonormal feature vectors $v_{y,1}$ and $v_{y,2}$, and we once again consider $N$ and $k$ to be sufficiently large. As mentioned in Allen-Zhu and Li (2021), we could alternatively consider the number of classes $k$ to be fixed (i.e. binary classification) and the number of associated features to be large, and our theory would still translate. We now precisely define the data generating distribution $\mathcal{D}$ that we will focus on for the remainder of the paper (Definition 4.1); in particular, the feature noise patches take the form

$$x^{(p)} = \sum_{j \in [Q]} \sum_{\ell \in [2]} \gamma_{j,\ell}\, v_{s_j,\ell},$$

where the $\gamma_{j,\ell} \in [\delta_3, \delta_4]$ can be arbitrary. Note that there are parts of the data-generating process that we leave underspecified, as our results will work for any choice. Henceforth, we use $\mathcal{X}$ to refer to a dataset consisting of $N$ i.i.d. draws from the distribution $\mathcal{D}$. Our data distribution represents a very low signal-to-noise ratio (SNR) setting in which the true signal for a class exists only in a constant ($2C_P$) number of patches, while the rest of the patches contain low-magnitude noise in the form of other-class features.

We focus on the case of learning the data distribution $\mathcal{D}$ with the same two-layer CNN-like architecture used in Allen-Zhu and Li (2021). We recall that this architecture relies on the following polynomially-smoothed ReLU activation, which we refer to as $\widetilde{\mathrm{ReLU}}$:

$$\widetilde{\mathrm{ReLU}}(x) = \begin{cases} 0 & \text{if } x \le 0 \\ \dfrac{x^\alpha}{\alpha \rho^{\alpha - 1}} & \text{if } x \in [0, \rho] \\ x - \left(1 - \dfrac{1}{\alpha}\right)\rho & \text{if } x \ge \rho \end{cases}$$

The polynomial part of this activation function will be very useful for us in suppressing the feature noise in $\mathcal{D}$.
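The smoothed activation above is straightforward to implement elementwise; a minimal sketch (with $\rho$ and $\alpha$ passed as parameters):

```python
import numpy as np

def smoothed_relu(x, rho, alpha):
    """Polynomially-smoothed ReLU: zero for x <= 0, the polynomial piece
    x^alpha / (alpha * rho^(alpha-1)) on [0, rho], and the shifted linear
    piece x - (1 - 1/alpha) * rho for x >= rho."""
    x = np.asarray(x, dtype=float)
    return np.where(x <= 0, 0.0,
           np.where(x <= rho, x**alpha / (alpha * rho**(alpha - 1)),
                    x - (1 - 1/alpha) * rho))
```

The two pieces meet at $x = \rho$ with common value $\rho/\alpha$ and common slope $1$, so the activation is continuously differentiable there; for small inputs the polynomial piece scales like $x^\alpha$, which is what suppresses the $\Theta(k^{-1.5})$-magnitude feature noise.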
Our full network architecture, which consists of $m$ hidden neurons, can then be specified as follows.

Definition 4.2. [2-Layer Network] We denote our network by $g : \mathbb{R}^{Pd} \to \mathbb{R}^k$. For each $y \in [k]$, we define $g_y$ as follows:

$$g_y(x) = \sum_{r \in [m]} \sum_{p \in [P]} \widetilde{\mathrm{ReLU}}\left(\left\langle w_{y,r}, x^{(p)} \right\rangle\right) \tag{4.1}$$

We will use $w^{(0)}_{y,r}$ to refer to the weights of the network $g$ at initialization (and $w^{(t)}_{y,r}$ after $t$ steps of gradient descent), and similarly $g^t$ to refer to the model after $t$ iterations of gradient descent. We consider the standard choice of Xavier initialization, which, in our setting, corresponds to $w^{(0)}_{y,r} \sim \mathcal{N}(0, \frac{1}{d} I_d)$. For model training, we focus on full-batch gradient descent with a fixed learning rate $\eta$ applied to $J(g, \mathcal{X})$ and $J_{MM}(g, \mathcal{X})$. Once again using the notation $\nabla_{w^{(t)}_{y,r}}$ for $\frac{\partial}{\partial w^{(t)}_{y,r}}$, the updates to the weights of the network $g$ are thus of the form (and analogously for $J$):

$$w^{(t+1)}_{y,r} = w^{(t)}_{y,r} - \eta \nabla_{w^{(t)}_{y,r}} J_{MM}(g, \mathcal{X}) \tag{4.2}$$

In defining our data distribution and model above, we have introduced several hyperparameters. Throughout our results, we make the following assumptions about these hyperparameters.

Assumption 4.3. [Choice of Hyperparameters] We assume that:

$$d = \Omega(k^{20}), \quad P = \Theta(k^2), \quad C_P = \Theta(1), \quad m = \Theta(k), \quad \delta_1, \delta_2 = \Theta(1), \quad \delta_3, \delta_4 = \Theta(k^{-1.5}), \quad \rho = \Theta(1/k), \quad \alpha = 8$$

Discussion of Hyperparameter Choices. We make concrete choices of hyperparameters above for the sake of calculations (and we stress that these are not close to the tightest possible choices), but only the relationships between them are important. Namely, we need $d$ to be a significantly larger polynomial in $k$ than $P$; we need $\delta_3, \delta_4 = o(1)$ but large enough that $P\delta_3 \gg \delta_2$ (to avoid learnability by linear models, as shown below); we need $\alpha$ sufficiently large that the network can suppress the low-magnitude feature noise; and we need $\delta_1, \delta_2 = \Theta(1)$ so that the signal feature coefficients significantly outweigh the noise feature coefficients.
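A minimal forward pass for Definition 4.2 can be sketched as follows (our own illustration; the default $\rho$ and $\alpha$ values are placeholders consistent with Assumption 4.3 only up to constants):

```python
import numpy as np

def srelu(x, rho=0.1, alpha=8):
    # polynomially-smoothed ReLU from the setup
    return np.where(x <= 0, 0.0,
           np.where(x <= rho, x**alpha / (alpha * rho**(alpha - 1)),
                    x - (1 - 1/alpha) * rho))

def g_forward(W, X_patches, rho=0.1, alpha=8):
    """Forward pass of the 2-layer net of Definition 4.2.
    W: (k, m, d) array of weights w_{y,r}; X_patches: (n, P, d).
    Returns logits of shape (n, k): g_y(x) = sum_{r,p} srelu(<w_{y,r}, x^{(p)}>)."""
    # pre-activations for every (sample, class, neuron, patch) combination
    pre = np.einsum('kmd,npd->nkmp', W, X_patches)
    return srelu(pre, rho, alpha).sum(axis=(2, 3))
```

Since $\widetilde{\mathrm{ReLU}}$ is nonnegative, all logits are nonnegative; class predictions come from $\operatorname{argmax}_y g_y(x)$ as usual.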
To convince the reader that our choice of model is not needlessly complicated given the setting, we prove the following result showing that there exist realizations of the distribution $\mathcal{D}$ on which linear classifiers cannot achieve perfect accuracy.

Proposition 4.4. There exists a $\mathcal{D}$ satisfying all of the conditions of Definition 4.1 and Assumption 4.3 such that, with probability at least $1 - k^2 \exp(-\Theta(N/k^2))$, for any classifier $h : \mathbb{R}^{Pd} \to \mathbb{R}^k$ of the form $h_y(x) = \sum_{p \in [P]} \langle w_y, x^{(p)} \rangle$ and any $\mathcal{X}$ consisting of $N$ i.i.d. draws from $\mathcal{D}$, there exists a point $(x, y) \in \mathcal{X}$ and a class $s \neq y$ such that $h_s(x) \ge h_y(x)$.

Proof Sketch. The idea, as originally pointed out by Allen-Zhu and Li (2021), is that there are $\Theta(k^2)$ feature noise patches with coefficients of order $\Theta(k^{-1.5})$. Thus, because the features are orthogonal, these noise patches can influence the classification by a term of order $\Theta(\sqrt{k})$ away from the direction of the true signal. The full proof can be found in Section C of the Appendix.

4.2. MAIN RESULTS

Having established the setting for our main results, we now concretely define the notion of feature learning in our context.

Definition 4.5. [Feature Learning] Let $(x, y) \sim \mathcal{D}$. We say that feature $v_{y,\ell}$ is learned by $g$ if $\operatorname{argmax}_s g_s(x') = y$, where $x'$ is $x$ with all instances of the feature $v_{y,3-\ell}$ replaced by the all-zeros vector.

Our definition of feature learning corresponds to whether the model $g$ is able to correctly classify data points in the presence of only a single signal feature instead of both (generalizing the notion of weight-feature correlation to nonlinear models). By analyzing the gradient descent dynamics of $g$ for the empirical cross-entropy $J$, we can then show the following.

Theorem 4.6. For $k$ and $N$ sufficiently large and the settings stated in Assumption 4.3, the following hold with probability at least $1 - O(1/k)$ after running gradient descent with a step size $\eta = O(1/\mathrm{poly}(k))$ for $O(\mathrm{poly}(k)/\eta)$ iterations on $J(g, \mathcal{X})$ (for sufficiently large polynomials in $k$):

1. (Training accuracy is perfect): For all $(x_i, y_i) \in \mathcal{X}$, we have $\operatorname{argmax}_s g^t_s(x_i) = y_i$.

2. (Only one feature is learned): For $(1 - o(1))k$ classes, there exists exactly one feature that is learned in the sense of Definition 4.5 by the model $g^t$.

Furthermore, the above remains true for all $t = O(\mathrm{poly}(k))$ for any polynomial in $k$.

Proof Sketch. The proof is in spirit very similar to that of Theorem 1 in Allen-Zhu and Li (2021), and relies on many of the tools therein. The main idea is that, with high probability, there exists a separation at initialization between the class-$y$ weight correlations with the features $v_{y,1}$ and $v_{y,2}$. This separation is then amplified throughout training by the polynomial part of $\widetilde{\mathrm{ReLU}}$. Once one feature correlation becomes large enough, the gradient updates to the class-$y$ weights rapidly decrease, leading to the remaining feature not being learned. Theorem 4.6 thus shows that only one feature is learned (in our sense) for the vast majority of classes.
As mentioned, our proof is quite similar to that of Allen-Zhu and Li (2021), but due to simplifications in our setting (no added Gaussian noise, for example) and some different ideas, the proof is much shorter; we hope this makes some of the machinery from Allen-Zhu and Li (2021) accessible to a wider audience. The reason we prove Theorem 4.6 is in fact to highlight the contrast provided by the analogous result for Midpoint Mixup.

Theorem 4.7. For $k$ and $N$ sufficiently large and the settings stated in Assumption 4.3, the following hold with probability $1 - O(1/k)$ after running gradient descent with a step size $\eta = O(1/\mathrm{poly}(k))$ for $O(\mathrm{poly}(k)/\eta)$ iterations on $J_{MM}(g, \mathcal{X})$ (for sufficiently large polynomials in $k$):

1. (Training accuracy is perfect): For all $(x_i, y_i) \in \mathcal{X}$, we have $\operatorname{argmax}_s g_s(x_i) = y_i$.

2. (Both features are learned):

For each class $y \in [k]$, both $v_{y,1}$ and $v_{y,2}$ are learned in the sense of Definition 4.5 by the model $g$. Furthermore, the above remains true for all $t = O(\mathrm{poly}(k))$ for any polynomial in $k$.

Proof Sketch. The core idea of the proof relies on techniques similar to those of Proposition 3.4, but the nonlinear part of the $\widetilde{\mathrm{ReLU}}$ activation introduces a few additional difficulties, since the gradients in the polynomial part are much smaller than those in the linear part of $\widetilde{\mathrm{ReLU}}$. Nevertheless, we show that even these smaller gradients are sufficient for the feature correlation that is lagging behind to catch up in polynomial time. The full proofs of Theorems 4.6 and 4.7 can be found in Section B of the Appendix.

Remark 4.8. Theorems 4.6 and 4.7 show a separation between ERM and Midpoint Mixup with respect to feature learning as we have defined it. They are not results regarding the test accuracy of the trained models on the distribution $\mathcal{D}$; even learning only a single feature per class is sufficient for perfect test accuracy on $\mathcal{D}$. The significance (and our desired interpretation) of these results is that, when the training distribution $\mathcal{D}$ has some additional spurious features compared to the testing distribution, ERM can potentially fail to learn the true signal features, whereas Midpoint Mixup will likely learn all features (including the true signal). One may also interpret the results as generalization that is robust to distribution shift; the test distribution in this case has dropped some features present in the training distribution.

5. EXPERIMENTS

The goal of the results of Sections 3 and 4 was to provide theory (from a feature learning and optimization perspective) for why Mixup has enjoyed success over ERM in many practical settings. The intuition is that, for image classification tasks, one could reasonably expect images from the same class to be generated from a shared set of latent features (much like our data distribution in Definition 4.1), in which case it may be possible to achieve perfect training accuracy by learning a strict subset of these features when doing empirical risk minimization. On the other hand, based on our ideas, we would expect Mixup to learn all such latent features associated with each class (assuming some dependency between them), and thus potentially generalize better. A direct empirical verification of this phenomenon on image datasets is tricky (and a possible avenue for future work) because one would need to clearly define a notion of latent features with respect to the images being considered, which is outside the scope of this work. Instead, we take for granted that such features exist, and attempt to verify whether Mixup is able to learn the "true" features associated with each class better than ERM when spurious features are added. For our experimental setup, we consider training ResNet-18 (He et al., 2015) on versions of Fashion MNIST (FMNIST) (Xiao et al., 2017), CIFAR-10, and CIFAR-100 (Krizhevsky, 2009), modified so that the training data includes additional spurious features correlated with the true labels. If we work under the intuition that images from each class are generated by relatively different latent features, then this modification process corresponds to adding patches of (fixed) spurious features to each class that have a dependency (from the scaling factor $\gamma$) on the original features of the data.
We leave the test data for each dataset unmodified, except for the concatenation of an all-zeros vector of the same shape to each point so that the shape of the test data matches that of the training data (in effect, this penalizes models that learned only the spurious features we concatenated in the training data). This zeroing out of the additional channels is also intended to replicate Definition 4.5 in our experimental setup. While we consider the above setup to be intuitive and resemble our theoretical setting, it is fair to ask why we chose this setup compared to the many possible alternatives. Firstly, we found that using synthetic spurious features (i.e. random orthogonal vectors scaled to have the same norm as the images) as opposed to images from different classes was far too noisy (training error went to 0 immediately); the test errors on each dataset degraded to near-random levels, so it was difficult to make comparisons. Additionally, we found the same to be true if we considered adding spurious features as opposed to concatenating them. For each of our image classification tasks, we train models using Mixup with D λ = Beta(1, 1) (the choice used in Zhang et al. (2018) for CIFAR, which we refer to as Uniform Mixup), Midpoint Mixup, and ERM. Our implementation is in PyTorch (Paszke et al., 2019) and uses the ResNet implementation of Kuang Liu, released under an MIT license. All models were trained for 100 epochs with a batch size of 750, which was the largest feasible size on our compute setup of a single P100 GPU (we use a large batch size to approximate the full batch gradient descent aspect of our theory). For optimization, we use Adam (Kingma and Ba, 2015) with the default hyperparameters of β 1 = 0.9, β 2 = 0.999 and a learning rate of 0.001. 
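The modification described above can be sketched as follows (a hypothetical implementation for illustration: the `patterns` array, the channel-wise concatenation, and the exact use of the scaling factor $\gamma$ are our own assumptions; in the actual setup the spurious features come from images of different classes):

```python
import numpy as np

def concat_spurious(train_X, train_y, test_X, patterns, gamma=0.5):
    """Concatenate a fixed per-class spurious pattern (scaled by gamma) as
    extra channels on each training image, and all-zeros channels of the same
    shape on each test image (so test shapes match training shapes).
    train_X: (n, C, H, W); patterns: (k, C, H, W), one pattern per class."""
    spurious = gamma * patterns[train_y]           # class-dependent extra channels
    train_mod = np.concatenate([train_X, spurious], axis=1)
    test_mod = np.concatenate([test_X, np.zeros_like(test_X)], axis=1)
    return train_mod, test_mod
```

Zeroing the extra channels at test time penalizes a model that relied only on the spurious channels, mirroring the feature-zeroing in Definition 4.5.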
We did a modest amount of hyperparameter tuning in preliminary experiments, comparing Adam and SGD with log-spaced learning rates in the range $[0.001, 0.1]$, and found that Adam with the default hyperparameters almost always worked best. We report our results for each dataset in Table 1 (final test errors on the unmodified test data, mean over 5 runs, along with a 1-standard-deviation range, for Uniform Mixup, Midpoint Mixup, and ERM). From Table 1 we see that Uniform Mixup performs the best in all cases, and that Midpoint Mixup tracks the performance of Uniform Mixup reasonably closely. We stress that the ordering of model performance is unsurprising; a truly fair comparison with Midpoint Mixup would require training on all $N^2$ possible mixed points, which is infeasible in our compute setup (we opt to randomly mix points per batch, as is standard). Our experiments are intended to show that Midpoint Mixup still non-trivially captures the benefits of Mixup in an empirical setting that is far from the asymptotic regime of our theory, while Mixup using standard hyperparameter settings significantly outperforms ERM in the presence of spurious features. A final observation worth making is that Midpoint Mixup performs significantly better than ERM when moving from the 10-class settings of FMNIST and CIFAR-10 to the 100-class setting of CIFAR-100, which is in line with what our theory predicts (a larger number of classes more closely approximates our setting).

6. CONCLUSION

To summarize, the main contributions of this work have been theoretical motivation for an extreme case of Mixup training (Midpoint Mixup), as well as an optimization analysis separating the learning dynamics of a 2-layer convolutional network trained using Midpoint Mixup and empirical risk minimization. Our results show that, for a class of data distributions satisfying the property that there are multiple, dependent features correlated with each class in the data, Midpoint Mixup can outperform ERM (both theoretically and empirically) in learning these features. We hope that the ideas introduced in the theory can be a useful building block for future theoretical investigations into Mixup and related methods in the context of training neural networks. As this work is almost entirely theoretical in nature, we do not anticipate any (direct) negative broader impacts or potential for misuse.

8. REPRODUCIBILITY STATEMENT

All of the results discussed in this paper have accompanying complete proofs in Sections B and C of the Appendix. While the proofs are quite technical in nature, we have tried our best to provide intuitive explanations at each step, and have also included derivations of various calculations and well-known concentration of measure results in Section A of the Appendix. We have also included in the supplementary material the code necessary to run the experiments in Section 5, along with detailed instructions explaining how to recreate each of our experiments.

A SUPPORTING LEMMAS AND CALCULATIONS

In this section we collect several technical lemmas and computations that will be necessary for the proofs of our main results.

A.1 GAUSSIAN CONCENTRATION AND ANTI-CONCENTRATION RESULTS

The following are well-known concentration results for Gaussian random variables; we include proofs for the convenience of the reader.

Proposition A.1. Let X_i ~ N(0, σ_i^2) with i ∈ [m] and let σ = max_i σ_i. Then E[max_i X_i] ≤ σ√(2 log m).

Proof. Let Z = max_i X_i. Then by Jensen's inequality and the MGF of N(0, σ_i^2), we have:

exp(t E[Z]) ≤ E[exp(tZ)] = E[exp(t max_i X_i)] ≤ Σ_i E[exp(tX_i)] = Σ_i exp(t^2 σ_i^2 / 2) ≤ m exp(t^2 σ^2 / 2)
⟹ E[Z] ≤ (log m)/t + tσ^2/2

Minimizing the RHS yields t = √(2 log m)/σ, from which the result follows.

Proposition A.2. Let X_i be as in Proposition A.1. Then P(max_i X_i ≥ t + σ√(2 log m)) ≤ exp(-t^2/(2σ^2)).

Proof. We simply union bound and use the fact that P(X_i ≥ t) ≤ exp(-t^2/(2σ^2)) (Chernoff bound for zero-mean Gaussians) to get:

P(max_i X_i ≥ t + σ√(2 log m)) ≤ Σ_i P(X_i ≥ t + σ√(2 log m)) ≤ m exp(-(t + σ√(2 log m))^2/(2σ^2)) ≤ exp(-t^2/(2σ^2))

Proposition A.3. Let X_1, X_2, ..., X_m be i.i.d. Gaussian variables with mean 0 and variance σ^2. Then we have that:

P(max_i X_i > Θ(σ√(log(m / log(1/δ))))) = 1 - Θ(δ)

Proof. We recall that P(X_i > x) = Θ((σ/x) e^{-x^2/(2σ^2)}); a proof of this fact can be found in Vershynin (2018). We additionally have that P(max_i X_i > x) = 1 - (1 - P(X_i > x))^m. So, from the previous asymptotic characterization of P(X_i > x), choosing x = Θ(σ√(log(m / log(1/δ)))) gives P(X_i > x) = Θ(log(1/δ)/m), from which the result follows.

We will also have need for a recent anti-concentration result due to Chernozhukov et al. (2014), which we restate below.

Proposition A.4 (Theorem 3 (i) of Chernozhukov et al. (2014)). Let X_i ~ N(0, σ^2) for i ∈ [m] with σ^2 > 0. Defining a_m = E[max_i X_i / σ], we then have for every ϵ > 0:

sup_{x∈R} P(|max_i X_i - x| ≤ ϵ) ≤ 4ϵ(1 + a_m)/σ

Corollary A.5.
Applying Proposition A.1 (so that a_m ≤ √(2 log m)), we have sup_{x∈R} P(|max_i X_i - x| ≤ ϵ) ≤ 4ϵ(1 + √(2 log m))/σ.
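As a quick numerical sanity check of Proposition A.1 (a Monte Carlo sketch; the helper name is ours), the empirical mean of max_i X_i for m i.i.d. N(0, σ^2) variables should fall below the bound σ√(2 log m):

```python
import math
import random

def mean_max_gaussian(m, sigma, trials=2000, seed=0):
    """Monte Carlo estimate of E[max_i X_i] for m i.i.d. N(0, sigma^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        total += max(rng.gauss(0.0, sigma) for _ in range(m))
    return total / trials

m, sigma = 64, 1.0
bound = sigma * math.sqrt(2.0 * math.log(m))  # Proposition A.1
est = mean_max_gaussian(m, sigma)             # should satisfy est <= bound
```

For m = 64 the bound is roughly 2.88, while the empirical mean of the maximum is noticeably smaller, consistent with the bound being loose by lower-order terms.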

A.2 GRADIENT CALCULATIONS

Here we collect the gradient calculations used in the proofs of the main results. We recall that we use ∇_{w^(t)_{y,r}} to indicate ∂/∂w^(t)_{y,r}, and that z_{i,j} = (x_i + x_j)/2. Additionally, we will omit parentheses after ReLU when function application is clear.

Calculation A.6. For any (x_i, y_i) ∈ X:

∇_{w^(t)_{y_i,r}} g^{y_i}_t(x_i) = Σ_{p∈[P]} ReLU′⟨w^(t)_{y_i,r}, x^(p)_i⟩ x^(p)_i

Proof.

∇_{w^(t)_{y_i,r}} g^{y_i}_t(x_i) = ∂/∂w^(t)_{y_i,r} Σ_{u∈[m]} Σ_{p∈[P]} ReLU⟨w^(t)_{y_i,u}, x^(p)_i⟩ = Σ_{p∈[P]} ReLU′⟨w^(t)_{y_i,r}, x^(p)_i⟩ x^(p)_i

Calculation A.7. For any (x_i, y_i) ∈ X, if max_{u∈[k]} ⟨w^(t)_{y_i,r}, v_{u,ℓ}⟩ < ρ/(δ_2 - δ_1) and s ≠ y_i, then:

⟨∇_{w^(t)_{y_i,r}} g^{y_i}_t(x_i), v_{y_i,ℓ}⟩ = Σ_{p∈P_{y_i,ℓ}(x_i)} β^α_{i,p} ⟨w^(t)_{y_i,r}, v_{y_i,ℓ}⟩^{α-1} / ρ^{α-1}

⟨∇_{w^(t)_{y_i,r}} g^{y_i}_t(x_i), v_{s,ℓ}⟩ ≤ Θ(P δ^α_4 max_{u≠y_i} ⟨w^(t)_{y_i,r}, v_{u,ℓ}⟩^{α-1} / ρ^{α-1})

Proof. When max_{u∈[k]} ⟨w^(t)_{y_i,r}, v_{u,ℓ}⟩ < ρ/(δ_2 - δ_1), we are in the polynomial part of ReLU for every patch in x_i, since max_{p∈P_{y_i,ℓ}(x_i)} ⟨w^(t)_{y_i,r}, x^(p)_i⟩ < ρ because β_{i,p} ≤ δ_2 - δ_1. The first line then follows from Calculation A.6 and the fact that all of the feature vectors are orthonormal (so only those patches that contain the feature v_{y_i,ℓ} are relevant). The second line follows from the fact that there are at most P - 2C_P feature noise patches containing the vector v_{s,ℓ}, and in each of these patches there are only a constant number of feature vectors (which we do not constrain).

Calculation A.8. For any (x_i, y_i) ∈ X, if ⟨w^(t)_{y_i,r}, v_{y_i,ℓ}⟩ ≥ ρ/δ_1, then:

⟨∇_{w^(t)_{y_i,r}} g^{y_i}_t(x_i), v_{y_i,ℓ}⟩ = Σ_{p∈P_{y_i,ℓ}(x_i)} β_{i,p}

Proof. When ⟨w^(t)_{y_i,r}, v_{y_i,ℓ}⟩ ≥ ρ/δ_1, we necessarily have min_{p∈P_{y_i,ℓ}(x_i)} ⟨w^(t)_{y_i,r}, x^(p)_i⟩ ≥ ρ since β_{i,p} ≥ δ_1, and then the result again follows from Calculation A.6 and the fact that ReLU′ = 1 in the linear regime.

Calculation A.9 (ERM Gradient).

∇_{w^(t)_{y,r}} J(g_t, X) = -(1/N) Σ_{i∈[N]} (1_{y_i=y} - ϕ_y(g(x_i))) ∇_{w^(t)_{y,r}} g^y_t(x_i)

Proof.
First let us observe that:

log ϕ_{y_i}(g_t(x_i)) = g^{y_i}_t(x_i) - log Σ_s exp(g^s_t(x_i))
⟹ ∂ log ϕ_{y_i}(g_t(x_i)) / ∂w_{y,r} = 1_{y_i=y} ∇_{w^(t)_{y,r}} g^y_t(x_i) - ϕ_y(g(x_i)) ∇_{w^(t)_{y,r}} g^y_t(x_i)

Summing (and negating) the above over all points x_i gives the result.

Calculation A.10 (Midpoint Mixup Gradient).

∇_{w^(t)_{y,r}} J_MM(g_t, X) = -(1/(2N^2)) Σ_{i∈[N]} Σ_{j∈[N]} (1_{y_i=y} + 1_{y_j=y} - 2ϕ_y(g_t(z_{i,j}))) ∇_{w^(t)_{y,r}} g^y_t(z_{i,j})

Proof. Follows from applying Calculation A.9 to each part of the summation in J_MM(g, X).
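The softmax log-probability identity underlying Calculation A.9, namely ∂ log ϕ_y(g)/∂g_s = 1_{s=y} - ϕ_s(g), can be verified numerically with finite differences (a self-contained sketch; the helper names are ours):

```python
import math

def softmax(g):
    mx = max(g)  # subtract the max for numerical stability
    e = [math.exp(v - mx) for v in g]
    s = sum(e)
    return [v / s for v in e]

def grad_log_softmax(g, y):
    """Analytic gradient of log softmax_y(g) w.r.t. the logits g:
    the one-hot vector e_y minus softmax(g), the identity used in
    Calculation A.9."""
    p = softmax(g)
    return [(1.0 if s == y else 0.0) - p[s] for s in range(len(g))]

# Central finite-difference check of the analytic gradient.
g, y, eps = [0.3, -1.2, 0.7], 0, 1e-6
analytic = grad_log_softmax(g, y)
fd = []
for s in range(len(g)):
    gp = list(g); gp[s] += eps
    gm = list(g); gm[s] -= eps
    fd.append((math.log(softmax(gp)[y]) - math.log(softmax(gm)[y])) / (2 * eps))
```

The two gradient vectors agree to finite-difference precision; note also that the analytic gradient sums to zero, since the softmax probabilities sum to one.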

B PROOFS OF MAIN RESULTS

This section contains the proofs of the main results in this paper. We have opted to present the proofs in a linear fashion, inlining several claims and their proofs along the way, as we find this to be more readable than the alternative. The proofs of inlined claims are ended with the ■ symbol, while the proofs of the overarching results are ended with the □ symbol. For convenience, we recall the assumptions (as they were stated in the main body) that are used in these results:

Assumption 4.3 (Choice of Hyperparameters). We assume that: d = Ω(k^20), P = Θ(k^2), C_P = Θ(1), m = Θ(k), δ_1, δ_2 = Θ(1), δ_3, δ_4 = Θ(k^{-1.5}), ρ = Θ(1/k), α = 8.

B.1 PROOF OF THEOREM 4.6

Theorem 4.6. For k and N sufficiently large and the settings stated in Assumption 4.3, we have that the following hold with probability at least 1 - O(1/k) after running gradient descent with a step size η = O(1/poly(k)) for O(poly(k)/η) iterations on J(g, X) (for sufficiently large polynomials in k):

1. (Training accuracy is perfect): For all (x_i, y_i) ∈ X, we have argmax_s g^s_t(x_i) = y_i.
2. (Only one feature is learned): For (1 - o(1))k classes, there exists exactly one feature that is learned in the sense of Definition 4.5 by the model g_t.

Furthermore, the above remains true for all t = O(poly(k)) for any polynomial in k.

Proof. We break the proof into two parts. In part one, we prove that (with high probability) each class output g^y_t becomes large (but not too large) on data points belonging to class y and stays small on other data points, which consequently allows us to obtain perfect training accuracy at the end (thereby proving the first half of the theorem).
In part two, we show that (again with high probability) the maximum correlations with the features v_{y,1} and v_{y,2} for a class y are separated at initialization, and that this separation gets amplified over the course of training; due to this separation, one of the feature correlations becomes essentially irrelevant, which will be used to prove the second half of the theorem.

Part I.

In this part, we show that the network output g^{y_i}_t(x_i) reaches and remains Θ(log k) while g^s_t(x_i) = o(1) for all t = O(poly(k)) and s ≠ y_i. These two facts together allow us to control the 1 - ϕ_{y_i}(g_t(x_i)) terms that show up throughout our analysis (see Calculation A.9), while also being sufficient for showing that we get perfect training accuracy. The intuition behind these results is that, when g^{y_i}_t(x_i) > c log k, we have exp(g^{y_i}_t(x_i)) > k^c, so the 1 - ϕ_{y_i}(g_t(x_i)) terms in the gradient updates quickly become small and g^{y_i}_t stops growing. Throughout this part of the proof and the next, we will use the following notation (some of which has been introduced previously) to simplify the presentation:

N_y = {i : i ∈ [N] and y_i = y}
P_{y,ℓ}(x_i) = {p : p ∈ [P] and ⟨x^(p)_i, v_{y,ℓ}⟩ > 0}
B^(t)_{y,ℓ} = {r : r ∈ [m] and ⟨w^(t)_{y,r}, v_{y,ℓ}⟩ ≥ ρ/δ_1}   (B.1)

Here, N_y represents the indices corresponding to class y points, P_{y,ℓ}(x_i) (as used in Definition 4.1) represents the patch support of the feature v_{y,ℓ} in x_i (recall the features are orthonormal), and B^(t)_{y,ℓ} represents the set of class y weights that have achieved a big enough correlation with the feature v_{y,ℓ} to necessarily be in the linear regime of ReLU on all class y points at iteration t. Prior to beginning our analysis of the network outputs g^y_t, we first prove a claim that will serve as the setting for the rest of the proof.

Claim B.1. With probability 1 - O(1/k), all of the following are (simultaneously) true for every class y ∈ [k]:

• |N_y| = Θ(N/k)
• max_{s∈[k], r∈[m], ℓ∈[2]} ⟨w^(0)_{s,r}, v_{y,ℓ}⟩ = O(log k/√d)
• ∀ℓ ∈ [2], max_{r∈[m]} ⟨w^(0)_{y,r}, v_{y,ℓ}⟩ = Ω(1/√d)

In everything that follows, we will always assume the conditions of Claim B.1 unless otherwise stated. We begin by proving a result concerning the size of the softmax outputs ϕ_y(g_t(x)) that we will repeatedly use throughout the rest of the proof.

Claim B.2.
Consider i ∈ N_y and suppose that both max_{s∈[k], r∈[m], ℓ∈[2]} ⟨w^(t)_{s,r}, v_{s,ℓ}⟩ = O(log k) and max_{s≠y, r∈[m], ℓ∈[2]} ⟨w^(t)_{s,r}, v_{y,ℓ}⟩ = O(log(k)/√d) hold true. If we have g^y_t(x_i) ≥ a log k for some a ∈ [0, ∞), then:

1 - ϕ_y(g_t(x_i)) = O(1/k^{a-1}) if a > 1, and Θ(1) otherwise.

Proof of Claim B.2. By assumption, all of the weight-feature correlations are O(log k) at iteration t. Furthermore, for s ≠ y, all of the off-diagonal correlations ⟨w^(t)_{s,r}, v_{y,ℓ}⟩ are O(log(k)/√d). This implies that (using δ_4 = Θ(k^{-1.5}), ρ = Θ(1/k), P = Θ(k^2), and α = 8):

g^s_t(x_i) ≤ O(mP δ^α_4 max_{u≠y} ⟨w^(t)_{s,r}, v_{u,ℓ}⟩^α / ρ^{α-1}) ≤ O(k^{2+α} log(k)^α / k^{1.5α}) = O(log(k)^α / k^3)
⟹ exp(g^s_t(x_i)) ≤ 1 + O(log(k)^α / k^3)   (B.2)

where above we disregarded the constant number (2C_P) of very low order correlations ⟨w^(t)_{s,r}, v_{y,ℓ}⟩ and used the inequality exp(x) ≤ 1 + x + x^2 for x ≤ 1. Now, by the assumption that g^y_t(x_i) ≥ a log k, we have exp(g^y_t(x_i)) ≥ k^a, so:

1 - ϕ_y(g_t(x_i)) ≤ 1 - k^a / (k^a + (k - 1) + o(1)) = (k - 1 + o(1)) / (k^a + (k - 1) + o(1))   (B.3)

from which the result follows. ■

Corollary B.3.
Under the same conditions as Claim B.2, for s ≠ y, we have ϕ_s(g_t(x_i)) = O(1/k^{max(a,1)}).

For ⟨w^(0)_{y,r}, v_{s,ℓ}⟩ > 0, we have for every t at which ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ < ρ/(δ_2 - δ_1) that (using Calculations A.7 and A.9):

-η⟨∇_{w^(t)_{y,r}} J(g, X), v_{s,ℓ}⟩ ≤ -(η/N) Σ_{i∈N_s} ϕ_y(g_t(x_i)) Σ_{p∈P_{s,ℓ}(x_i)} β^α_{i,p} ⟨w^(t)_{y,r}, v_{s,ℓ}⟩^{α-1} / ρ^{α-1}
  + (η/N) Σ_{i∈N_y} (1 - ϕ_y(g_t(x_i))) Θ(P δ^α_4 max_{u≠y, ℓ′∈[2]} ⟨w^(t)_{y,r}, v_{u,ℓ′}⟩^{α-1} / ρ^{α-1})
  - (η/N) Σ_{i∉N_y∪N_s} ϕ_y(g_t(x_i)) Θ(P δ^α_3 min_{u≠y, ℓ′∈[2]} ⟨w^(t)_{y,r}, v_{u,ℓ′}⟩^{α-1} / ρ^{α-1})
  ≤ (η/N) Σ_{i∈N_y} (1 - ϕ_y(g_t(x_i))) Θ(P δ^α_4 max_{u≠y, ℓ′∈[2]} ⟨w^(t)_{y,r}, v_{u,ℓ′}⟩^{α-1} / ρ^{α-1})   (B.4)

Under review as a conference paper at ICLR 2023

Similarly, for ⟨w^(0)_{y,r}, v_{y,ℓ}⟩ > 0, we have for every t at which ⟨w^(t)_{y,r}, v_{y,ℓ}⟩ < ρ/(δ_2 - δ_1) that:

-η⟨∇_{w^(t)_{y,r}} J(g, X), v_{y,ℓ}⟩ ≥ (η/N) Σ_{i∈N_y} (1 - ϕ_y(g_t(x_i))) Σ_{p∈P_{y,ℓ}(x_i)} β^α_{i,p} ⟨w^(t)_{y,r}, v_{y,ℓ}⟩^{α-1} / ρ^{α-1}
  - (η/N) Σ_{i∉N_y} ϕ_y(g_t(x_i)) Θ(P δ^α_4 max_{u≠y, ℓ′∈[2]} ⟨w^(t)_{y,r}, v_{u,ℓ′}⟩^{α-1} / ρ^{α-1})   (B.5)

From Equation B.5, Claim B.2, and Corollary B.3 we get that for t ≤ T_A:

-η⟨∇_{w^(t)_{y,r}} J(g, X), v_{y,ℓ}⟩ ≥ Θ(η ⟨w^(t)_{y,r}, v_{y,ℓ}⟩^{α-1} / (kρ^{α-1}))   (B.6)

where above we also used the fact that |N_y| = Θ(N/k). On the other hand, also using Claim B.2 and Corollary B.3, we have that for all t for which ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ < ρ/(δ_2 - δ_1):

-η⟨∇_{w^(t)_{y,r}} J(g, X), v_{s,ℓ}⟩ ≤ Θ(ηP δ^α_4 max_{u≠y, ℓ′∈[2]} ⟨w^(t)_{y,r}, v_{u,ℓ′}⟩^{α-1} / ρ^{α-1})   (B.7)

Now suppose that ⟨w^(0)_{y,r}, v_{s,ℓ}⟩ is the maximum off-diagonal correlation at initialization.
Then using Equation B.7, we can lower bound the number of iterations T it takes for w (t) y,r , v s,ℓ to grow by a fixed constant C factor from initialization: T Θ    ηP δ α 4 C α-1 w (0) y,r , v s,ℓ α-1 ρ α-1    ≥ (C -1) w (0) y,r , v s,ℓ =⇒ T ≥ Θ    ρ α-1 ηP δ α 4 w (0) y,r , v s,ℓ α-2    = Θ k 1.5α-2 ρ α-1 d α/2-1 η (B.8) As there exists at least one w Having established strong control over the off-diagonal correlations, we are now ready to prove the first half of the main result of this part of the proof -that g y t (x i ) reaches Ω(log k) for all i ∈ N y in O(poly(k)) iterations. In proving this, it will help us to have some control over the network outputs g y t across different points x i and x j at the later stages of training, which we take care of below. Claim B.5. For every y ∈ [k] and all t such that max i∈Ny g y t (x i ) ≥ log k and max r∈[m],s̸ =y,ℓ∈[2] w (t) y,r , v s,ℓ = O log(k)/ √ d , we have max i∈Ny g y t (x i ) = Ω(min i∈Ny g y t (x i )). Proof of Claim B.5. Let j = argmax i∈Ny g y t (x i ). Since g y t (x j ) ≥ log k, we necessarily have that B y,ℓ is non-empty for at least one ℓ ∈ [2] (since mρ = Θ(1)). Only those weights w (t) y,r with r ∈ B y,ℓ for some ℓ ∈ [2] are asymptotically relevant (as any weights not considered can only contribute a O(1) term), and we can write: g y t (x j ) ≤ ℓ∈[2]   p∈P y,ℓ (xi) β i,p   r∈B (t) y,ℓ w (t) y,r , v y,ℓ + o(log k) For any other j ∈ N y , we have that β j,p ≥ δ 1 β i,p /(δ 2 -δ 1 ), from which the result follows.  y,r * , v y,ℓ * ≥ ρ/δ 1 (implying r ∈ B (t) y,ℓ * ) while the off-diagonal correlations still remain within a constant factor of initialization. 
Now we may lower bound the update to w (t) y,r * , v y,ℓ * as (using Calculation A.8): -η∇ w (t) y,r * J(g, X ), v y,ℓ * ≥ η N i∈Ny 1 -ϕ y g t (x i ) p∈P s,ℓ (xi) β i,p - η N i / ∈Ny ϕ y g t (x i ) Θ    P δ α 4 max u̸ =y, ℓ ′ ∈[2] w (t) y,r * , v u,ℓ ′ α-1 ρ α-1    (B.9) So long as max i∈Ny g y t (x i ) < log k (which is necessarily still the case at this point, as again mρ = Θ(1)), we have by the logic of Claim B.2 that we can simplify Equation B.9 to: -η∇ w (t) y,r * J(g, X ), v y,ℓ * ≥ Θ(η/k) (B.10) Where again we used the fact that |N y | = Θ(N/k). Now we can upper bound T y by the number of iterations it takes w Proof of Claim B.7. Let us again consider any class y ∈ [k] and t ≥ T y . The idea is to show that max i∈Ny 1 -ϕ y (g t (x i )) is decreasing rapidly as min i∈Ny g y t (x i ) grows to successive levels of a log k for a > 1. Firstly, following Equation B.9, we can form the following upper bound for the gradient updates to r ∈ B (t) y,ℓ : -η∇ w (t) y,r * J(g, X ), v y,ℓ * ≤ η N i∈Ny 1 -ϕ y g t (x i ) p∈P s,ℓ (xi) β i,p ≤ 1 -min i∈Ny ϕ y g t (x i ) Θ(η/k) (B.11) From Equation B.11 it follows that it takes at least Θ(k log(k)/(mη) iterations (since the correlations must grow at least log(k)/m) from T y for g y t (x i ) to reach 2 log k. Now let T a denote the number of iterations it takes for min i∈Ny g y t (x i ) to cross a log k after crossing (a -1) log k for the first time. For a ≥ 3, we necessarily have that T a = Ω(kT a-1 ) by Claim B.2 and Equation B.11. Let us now further define T f to be the first iteration at which max i∈Ny g y T f (x i ) ≥ f (k) log k for some f (k) = ω(1). By Claim B.5, at this point min i∈Ny g y T f (x i ) = Ω(f (k) log k). However, we have from the above discussion that: T f ≥ Ω(poly(k)) + f (k)-3 a=0 Ω k a log k η ≥ Ω log k k f (k)-2 -1 η(k -1) ≥ ω(poly(k)) (B.12) So max i∈Ny g y t (x i ) = O(log k) for all t = O(poly(k) ). 
An identical analysis also works for the off-diagonal correlations w (t) y,r , v s,ℓ but forming an upper bound using Equation B.4, so we are done. ■ We get the following two corollaries as straightforward consequences of Claim B.7.

Corollary B.8 (Perfect Training Accuracy). We have that there exists a universal constant C such that argmax_s g^s_t(x_i) = y_i for every (x_i, y_i) ∈ X for all t ≥ k^C with t = O(poly(k)).

Corollary B.9 (Softmax Control). We have for all y ∈ [k] and any t = O(poly(k)) (for any polynomial in k) that max_{i∈N_y} Σ_{s≠y} exp(g^s_t(x_i)) = k + o(1).

Corollary B.8 finishes this part of the proof.
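To make the softmax-control estimates of this part concrete, the following sketch evaluates 1 - ϕ_y(g) exactly in the idealized case where g_y = a log k and the remaining k - 1 logits are exactly 0 (an assumption for illustration only; in the proof the other logits are merely o(1)):

```python
def one_minus_softmax_top(a, k):
    """1 - phi_y(g) when g_y = a*log(k) and the other k-1 logits are 0:
    exp(g_y) = k**a, so the expression equals (k-1)/(k**a + k - 1),
    matching the O(1/k**(a-1)) regime of Claim B.2 for a > 1 and the
    Theta(1) regime for a <= 1."""
    return (k - 1) / (k ** a + k - 1)

k = 1000
gap_a2 = one_minus_softmax_top(2, k)  # roughly 1/k, so the gradient term is tiny
gap_a1 = one_minus_softmax_top(1, k)  # roughly 1/2, so the gradient term is constant
```

This is the mechanism by which g^y_t self-limits: each extra factor of log k in the output shrinks the 1 - ϕ_y gradient terms by a factor of k.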

Part II.

For the next part of the proof, we characterize the separation between max_{r∈[m]} ⟨w^(0)_{y,r}, v_{y,1}⟩ and max_{r∈[m]} ⟨w^(0)_{y,r}, v_{y,2}⟩, and show that this separation (when it is significant enough) gets amplified over the course of training. To show this, we will rely largely on the techniques found in Allen-Zhu and Li (2021), and finish in a near-identical manner to the proof of Claim B.7. As with Part I, we first introduce some notation that we will use throughout this part of the proof:

S_{y,ℓ} = (1/N) Σ_{i∈N_y} Σ_{p∈P_{y,ℓ}(x_i)} β^α_{i,p}
Λ^(t)_{y,ℓ} = max_{r∈[m]} ⟨w^(t)_{y,r}, v_{y,ℓ}⟩

Here, S_{y,ℓ} represents the data-dependent quantities that show up in the gradient updates to the correlations during the phase of training in which the correlations are in the polynomial part of ReLU, while Λ^(t)_{y,ℓ} denotes the maximum correlation with the feature v_{y,ℓ} at iteration t.

Claim B.10 (Feature Separation at Initialization). For each class y, we have with probability 1 - O(1/log k) that either:

Λ^(0)_{y,1} ≥ (S_{y,2}/S_{y,1})^{1/(α-2)} (1 + Θ(1/log^2 k)) Λ^(0)_{y,2}   or   Λ^(0)_{y,2} ≥ (S_{y,1}/S_{y,2})^{1/(α-2)} (1 + Θ(1/log^2 k)) Λ^(0)_{y,1}

Proof of Claim B.10. Suppose WLOG that S_{y,1} ≥ S_{y,2}. If neither of the inequalities in the claim holds, then we have that:

Λ^(0)_{y,1} ∈ (S_{y,2}/S_{y,1})^{1/(α-2)} (1 ± Θ(1/log^2 k)) Λ^(0)_{y,2}

which follows from the fact that, for a constant A, we have 1/(1 + A/log^2 k) ≥ 1 - A/log^2 k. Now we recall that Λ^(0)_{y,1} and Λ^(0)_{y,2} are both maxima over i.i.d. N(0, 1/d) variables (again, since the feature vectors are orthonormal), so we can apply Corollary A.5 (Gaussian anti-concentration) to Λ^(0)_{y,1} while taking ϵ = (S_{y,2}/S_{y,1})^{1/(α-2)} Θ(1/log^2 k) Λ^(0)_{y,2} and x = (S_{y,2}/S_{y,1})^{1/(α-2)} Λ^(0)_{y,2}. It is crucial to note that we can only do this because Λ^(0)_{y,2} is independent of Λ^(0)_{y,1}, and both take values over all of R.
From this we get that:

P(Λ^(0)_{y,1} ∈ (S_{y,2}/S_{y,1})^{1/(α-2)} Λ^(0)_{y,2} ± ϵ) ≤ 4ϵ(1 + √(2 log m))/σ = O(σ√(log m)/log^2 k) · Θ(√(log m)/σ) = O(1/log k), with probability 1 - 1/m,

where we used the fact that m = Θ(k) and Proposition A.2 to characterize Λ^(0)_{y,2} (also noting that S_{y,2}/S_{y,1} is Θ(1)). Thus, neither of the inequalities holds with probability O(1/log k), so we have the desired result. ■

We can use the separation from Claim B.10 to show that, in the initial stages of training, the max correlated weight/feature pair grows out of the polynomial region of ReLU and becomes large much faster than the correlations with the other feature for the same class. For y ∈ [k], let ℓ* be such that Λ^(0)_{y,ℓ*} is the left-hand side of the satisfied inequality from Claim B.10. Additionally, let r* = argmax_r ⟨w^(0)_{y,r}, v_{y,ℓ*}⟩, i.e. the strongest weight/feature correlation pair at initialization. We will show that when ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ becomes Ω(ρ), the other correlations remain small. In order to do so, we need a useful lemma from Allen-Zhu and Li (2021), which we restate below.

Lemma B.11 (Lemma C.19 from Allen-Zhu and Li (2021)). Let q ≥ 3 be a constant and x_0, y_0 = o(1). Let {x_t, y_t}_{t≥0} be two positive sequences updated as

• x_{t+1} ≥ x_t + ηC_t x_t^{q-1} for some C_t = Θ(1), and
• y_{t+1} ≤ y_t + ηSC_t y_t^{q-1} for some constant S = Θ(1),

where η = O(1/poly(k)) for a sufficiently large polynomial in k. Suppose x_0 ≥ y_0 S^{1/(q-2)} (1 + Θ(1/polylog(k))); then for every A = O(1), letting T_x be the first iteration at which x_t ≥ A, we have y_{T_x} = O(y_0 polylog(k)).

To apply Lemma B.11 in our setting, we first prove the following claim.

Claim B.12. For a class y ∈ [k], we define the following two sequences:

a_{y,t} = (S_{y,ℓ*}/ρ^{α-1})^{1/(α-2)} ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩   and   b_{y,t} = (S_{y,3-ℓ*}/ρ^{α-1})^{1/(α-2)} ⟨w^(t)_{y,r}, v_{y,3-ℓ*}⟩

where the r in the definition of b_{y,t} is arbitrary.
Then with probability 1 - O(1/log k) there exist C_t, S = Θ(1) such that for all t for which ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ < ρ/(δ_2 - δ_1):

a_{y,t+1} ≥ a_{y,t} + ηC_t a_{y,t}^{α-1}
b_{y,t+1} ≤ b_{y,t} + ηSC_t b_{y,t}^{α-1}

Additionally (with the same probability), we have that a_{y,0} ≥ S^{1/(α-2)} (1 + Θ(1/polylog(k))) b_{y,0}.

Proof of Claim B.12. The update to ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ in this regime can be bounded as follows (using Corollary B.9 and recalling Equation B.5):

-η⟨∇_{w^(t)_{y,r*}} J(g, X), v_{y,ℓ*}⟩ ≥ (η/N) Σ_{i∈N_y} (1 - ϕ_y(g_t(x_i))) Σ_{p∈P_{y,ℓ*}(x_i)} β^α_{i,p} ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩^{α-1} / ρ^{α-1}
  - (η/N) Σ_{i∉N_y} ϕ_y(g_t(x_i)) Θ(P δ^α_4 max_{u≠y, ℓ′∈[2]} ⟨w^(t)_{y,r*}, v_{u,ℓ′}⟩^{α-1} / ρ^{α-1})
  ≥ η (1 - Θ(1/k)) S_{y,ℓ*} ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩^{α-1} / ρ^{α-1}   (B.13)

Similarly, we have (noting also that ⟨w^(t)_{y,r}, v_{y,3-ℓ*}⟩ < ρ/(δ_2 - δ_1)):

-η⟨∇_{w^(t)_{y,r}} J(g, X), v_{y,3-ℓ*}⟩ ≤ ηS_{y,3-ℓ*} ⟨w^(t)_{y,r}, v_{y,3-ℓ*}⟩^{α-1} / ρ^{α-1}   (B.14)

Multiplying the above inequalities by (S_{y,ℓ*}/ρ^{α-1})^{1/(α-2)}, we see that a_{y,t} and b_{y,t} satisfy the inequalities in the claim with C_t = 1 - Θ(1/k) and S = (S_{y,3-ℓ*}/S_{y,ℓ*})(1 + Θ(1/k)). Now by Claim B.10 we have:

a_{y,0} ≥ (S_{y,3-ℓ*}/S_{y,ℓ*})^{1/(α-2)} (1 + Θ(1/log^2 k)) b_{y,0} ≥ S^{1/(α-2)} (1 + Θ(1/polylog(k))) b_{y,0}

So we are done. ■

Now, by the fact that |N_y| = Θ(N/k), we have S_{y,1}, S_{y,2} = Θ(1/k) = O(ρ), which implies that (S_{y,ℓ*}/ρ^{α-1})^{1/(α-2)} = O(1/ρ). From this we get that while a_{y,t} < C/(δ_2 - δ_1) for some appropriately chosen constant C, we have ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ < ρ/(δ_2 - δ_1). Since Claim B.12 holds in this regime, we can apply Lemma B.11 with A = C/(δ_2 - δ_1), which gives us that when a_{y,t} ≥ C/(δ_2 - δ_1), we have b_{y,t} = O(b_{y,0} polylog(k)). From this we obtain that when ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ ≥ ρ/(δ_2 - δ_1), the correlations ⟨w^(t)_{y,r}, v_{y,3-ℓ*}⟩ remain within a polylog(k) factor of their initialization.

Proof of Claim B.13.
It follows from the same logic as in the proof of Claim B.6 that, at the first iteration t for which we have min_{i∈N_y} g^y_t(x_i) ≥ log k, the correlation ⟨w^(t)_{y,r}, v_{y,3-ℓ*}⟩ is still within some polylog(k) factor of initialization (here the correlation ⟨w^(t)_{y,r}, v_{y,3-ℓ*}⟩ can be viewed the same as an off-diagonal correlation from the proof of Claim B.6). The rest of the proof then follows from logic identical to that of Claim B.7; namely, we can show that:

g^y_t(x′_i) = O(m polylog(k)/√d) = O(k polylog(k)/√d)
g^s_t(x′_i) = Ω(P log(k)^α / (ρ^{α-1} k^{2.5α})) = ω(k polylog(k)/√d)   for s ≠ y, if ∃ P_{s,ℓ}(x′_i) ≠ ∅   (B.15)

where x′_i is any point x_i with i ∈ N_y modified so that all instances of the feature v_{y,ℓ*} are replaced by 0, and the second line above follows from the fact that by Claim B.7 we must have ⟨w^(t)_{s,r}, v_{s,ℓ}⟩ = Ω(log(k)/k) for at least some r, ℓ for every s (and d = Θ(k^20)). This proves that the feature v_{y,3-ℓ*} is not learned in the sense of Definition 4.5. Using Claim B.13 for each class, we have by a Chernoff bound that, with probability at least 1 - o(1/k), only a single feature is learned for (1 - o(1))k classes, which proves the theorem. □
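The tensor-power dynamics of Lemma B.11 can also be illustrated numerically. The sketch below simulates the equality case of the two updates, using q = 3 rather than α = 8 purely so that the simulation terminates quickly (all constants here are illustrative choices, not values from the proof); a modest multiplicative head start lets x_t reach a constant threshold while y_t barely grows:

```python
def run_sequences(x0, y0, q=3, S=1.0, eta=1e-3, threshold=0.5, max_steps=10**6):
    """Simulate the equality case of the Lemma B.11 updates,
    x_{t+1} = x_t + eta * C_t * x_t**(q-1) and
    y_{t+1} = y_t + eta * S * C_t * y_t**(q-1) with C_t = 1,
    until x_t crosses `threshold`. Returns (steps taken, y at that time)."""
    x, y = x0, y0
    for t in range(max_steps):
        if x >= threshold:
            return t, y
        x += eta * x ** (q - 1)
        y += eta * S * y ** (q - 1)
    return max_steps, y

# x starts 5x ahead of y; x escapes to the threshold while y barely moves.
steps, y_end = run_sequences(x0=0.05, y0=0.01)
```

Because the update is a power of the current value, the leader's growth compounds much faster, which is exactly the amplification mechanism used to show that only the better-correlated feature escapes the polynomial region under ERM.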

B.2 PROOF OF THEOREM 4.7

Theorem 4.7. For k and N sufficiently large and the settings stated in Assumption 4.3, we have that the following hold with probability 1 - O(1/k) after running gradient descent with a step size η = O(1/poly(k)) for O(poly(k)/η) iterations on J_MM(g, X) (for sufficiently large polynomials in k): 1. (Training accuracy is perfect): For all (x_i, y_i) ∈ X, we have argmax_s g^s_t(x_i) = y_i.

2. (Both features are learned):

For each class y ∈ [k], both v_{y,1} and v_{y,2} are learned in the sense of Definition 4.5 by the model g. Furthermore, the above remains true for all t = O(poly(k)) for any polynomial in k. Proof. As in the proof of Theorem 4.6, we break the proof into two parts. The first part mirrors most of the structure of Part I of the proof of Theorem 4.6, in that we analyze the off-diagonal correlations and also show that the network outputs g^y_t can grow to (and remain) Ω(log k) as training progresses. However, we do not show in Part I that the outputs stay O(log k) (as we did in the ERM case), as there are additional subtleties in the Midpoint Mixup analysis that require different techniques, which we find are more easily introduced separately. The second part of the proof differs significantly from Part II of the proof of Theorem 4.6, as our goal is now to show that any large separation between the weight-feature correlations for each class is corrected over the course of training. At a high level, we show this by proving a gradient correlation lower bound that depends only on the magnitude of the separation between correlations and the variance of the feature coefficients in the data distribution, after which we can conclude that any feature lagging behind will catch up in polynomially many training iterations. We then use the techniques from the gradient lower bound analysis to prove that the network outputs g^y_t stay O(log k) throughout training, which wraps up the proof. Part I. We first recall that z_{i,j} = (x_i + x_j)/2, and we refer to such z_{i,j} as "mixed points". In this part of the proof, we show that g^{y_i}_t(z_{i,j}) crosses log k on at least one mixed point z_{i,j} in polynomially many iterations (after which the network outputs remain Ω(log k)). As before, this requires getting a handle on the off-diagonal correlations ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ (with s ≠ y).
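The objective J_MM analyzed here can be written down directly: it averages the cross-entropy over all N^2 midpoints z_{i,j}, each with soft label (e_{y_i} + e_{y_j})/2, i.e. Equation 2.2 with λ fixed to 1/2. A minimal, framework-free sketch (the helper names are ours):

```python
import math

def _softmax(g):
    mx = max(g)
    e = [math.exp(v - mx) for v in g]
    s = sum(e)
    return [v / s for v in e]

def midpoint_mixup_loss(logits_fn, X, Y):
    """J_MM sketch: average cross-entropy over all N^2 midpoints
    z_ij = (x_i + x_j)/2, each with the soft label (e_{y_i} + e_{y_j})/2
    (i.e. lambda fixed to 1/2 in Equation 2.2)."""
    N = len(X)
    total = 0.0
    for i in range(N):
        for j in range(N):
            z = [(a + b) / 2.0 for a, b in zip(X[i], X[j])]
            p = _softmax(logits_fn(z))
            total -= 0.5 * math.log(p[Y[i]]) + 0.5 * math.log(p[Y[j]])
    return total / (N * N)

# A tiny identity-logits example with one point from each of two classes:
loss = midpoint_mixup_loss(lambda z: z, [[1.0, 0.0], [0.0, 1.0]], [0, 1])
```

Even for a perfect classifier of the unmixed points, the cross-class midpoints contribute loss log 2 each (both classes get probability 1/2), which is why the mixed points keep supplying gradient to both features.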
Throughout the proof, we will continue to rely on the notation introduced in Equation B.1 in the proof of Theorem 4.6. However, we make one slight modification to the definition of B^(t)_{y,ℓ} for the Mixup case (so as to be able to handle mixed points), which is as follows:

B^(t)_{y,ℓ} = {r : r ∈ [m] and ⟨w^(t)_{y,r}, v_{y,ℓ}⟩ ≥ 2ρ/δ_1}   (B.16)

We again start by proving a claim that will constitute our setting for the rest of this proof.

Claim B.14. With probability 1 - O(1/k), all of the following are (simultaneously) true for every class y ∈ [k]:

• |N_y| = Θ(N/k)
• max_{s∈[k], r∈[m], ℓ∈[2]} ⟨w^(0)_{s,r}, v_{y,ℓ}⟩ = O(log k/√d)
• ∀ℓ ∈ [2], max_{r∈[m]} ⟨w^(0)_{y,r}, v_{y,ℓ}⟩ = Ω(1/√d)
• For Ω(k) tuples (s, ℓ) ∈ [k] × [2], we have ⟨w^(0)_{y,r}, v_{s,ℓ}⟩ > 0.

Proof of Claim B.14. The first three items in the claim are exactly the same as in Claim B.1, and the last item is true because the correlations ⟨w^(0)_{y,r}, v_{s,ℓ}⟩ are mean-zero Gaussians. ■

Once again, in everything that follows, we will always assume the conditions of Claim B.14 unless otherwise stated. We now translate Claim B.2 to the Midpoint Mixup setting.

Claim B.15. Consider i ∈ N_y, j ∈ N_s for s ≠ y, and suppose that max_{u∉{y,s}} g^u_t(z_{i,j}) = O(log(k)/k) holds true. If we have g^y_t(z_{i,j}) = a log k and g^s_t(z_{i,j}) = b log k for a, b = O(1), then:

1 - 2ϕ_y(g_t(z_{i,j})) = -Ω(1) if a > 1 and a - b = Ω(1); ±O(1) if a > 1 and a - b = ±o(1); and Θ(1) otherwise,

where in the second case the sign of 1 - 2ϕ_y(g_t(z_{i,j})) depends on the sign of a - b.

Proof of Claim B.15. In comparison to Claim B.2, the Midpoint Mixup case is slightly more involved, in that g^s_t(z_{i,j}) can be quite large due to the x_j part of z_{i,j}. As a result, we directly assume some control over the different class outputs on the mixed points (which we will prove to hold throughout training later).
By assumption, we have for u ≠ y, s:

g^u_t(z_{i,j}) = O(log(k)/k) ⟹ exp(g^u_t(z_{i,j})) ≤ 1 + O(log(k)/k)   (B.17)

where above we used the inequality exp(x) ≤ 1 + x + x^2 for x ∈ [0, 1]. Now, by the assumptions that g^y_t(z_{i,j}) = a log k and g^s_t(z_{i,j}) = b log k, we have:

1 - 2ϕ_y(g_t(z_{i,j})) ≤ 1 - 2k^a / (k^a + k^b + (k - 2) + o(k)) = (k^b - k^a + (k - 2) + o(k)) / (k^a + k^b + (k - 2) + o(k))   (B.18)

from which the result follows. ■

Corollary B.16. Under the same conditions as Claim B.15, for u ≠ y, s we have 1 - 2ϕ_u(g_t(z_{i,j})) = Θ(1).

Proof of Corollary B.16. Follows from Equations B.17 and B.18. ■

We observe that Claim B.15 and Corollary B.16 are less precise than Claim B.2, largely because there is now a dependence on the gap between the class y and class s network outputs, as opposed to just the class y network output. We are now again ready to compare the growth of the diagonal correlations ⟨w^(t)_{y,r}, v_{y,ℓ}⟩ with the off-diagonal correlations ⟨w^(t)_{y,r}, v_{s,ℓ}⟩. However, this is not as straightforward as it was in the ERM setting: the issue is that the off-diagonal correlations can actually grow significantly, due to the fact that the features v_{y,ℓ} can show up when mixing points in class y with class s.

Claim B.17. Fix an arbitrary class y ∈ [k]. Let A ∈ [Ω(ρ), ρ/(δ_2 - δ_1)] and let T_A be the first iteration at which max_{r∈[m], ℓ∈[2]} ⟨w^(T_A)_{y,r}, v_{y,ℓ}⟩ ≥ A; we must have both that T_A = O(poly(k)) and that, for every s ≠ y and ℓ ∈ [2]:

⟨w^(T_A)_{y,r}, v_{s,ℓ}⟩ = O(max_{ℓ′∈[2]} ⟨w^(T_A)_{y,r}, v_{y,ℓ′}⟩ / k)

Additionally, for all s, ℓ with ⟨w^(0)_{y,r}, v_{s,ℓ}⟩ > 0, we have that ⟨w^(T_A)_{y,r}, v_{s,ℓ}⟩ = Ω(⟨w^(0)_{y,r}, v_{s,ℓ}⟩ / polylog(k)).

Proof of Claim B.17. By our setting, we must have that there exists a diagonal correlation ⟨w^(0)_{y,r*}, v_{y,ℓ*}⟩ = Ω(1/√d), which we will focus our attention on.
Using Calculation A.10 and the ideas from Calculation A.7, we can lower bound the update to ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ from initialization up to T_A as:

-η⟨∇_{w^(t)_{y,r*}} J_MM(g, X), v_{y,ℓ*}⟩ ≥ (η/N^2) Σ_{i∈N_y} Σ_{j∉N_y} (1 - 2ϕ_y(g_t(z_{i,j}))) Θ(⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩^{α-1} / ρ^{α-1})
  + (η/N^2) Σ_{i∈N_y} Σ_{j∈N_y} (1 - ϕ_y(g_t(z_{i,j}))) Θ(⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩^{α-1} / ρ^{α-1})
  - (η/N^2) Σ_{i∉N_y} Σ_{j∉N_y} ϕ_y(g_t(z_{i,j})) Θ(P δ^α_4 max_{u∈[k], q∈[2]} ⟨w^(t)_{y,r*}, v_{u,q}⟩^{α-1} / ρ^{α-1})   (B.19)

Above we made use of the fact that, for i ∈ N_y and j ∉ N_y, the patches of z_{i,j} containing v_{y,ℓ*} correlate with w^(t)_{y,r*} at level at least β_{i,p} ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩/2 for at least Θ(|N_y|N) mixed points, since the correlation ⟨w^(t)_{y,r*}, v_{u,q}⟩ is positive for Ω(k) tuples (u, q) ∈ [k] × [2] (under Setting B.14). We can similarly upper bound the update to ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ for an arbitrary r ∈ [m] as:

-η⟨∇_{w^(t)_{y,r}} J_MM(g, X), v_{s,ℓ}⟩ ≤ -(η/N^2) Σ_{i∉N_y} Σ_{j∈N_s} ϕ_y(g_t(z_{i,j})) Θ(⟨w^(t)_{y,r}, v_{s,ℓ}⟩^{α-1} / ρ^{α-1})
  - (η/N^2) Σ_{i∉N_y∪N_s} Σ_{j∉N_y∪N_s} ϕ_y(g_t(z_{i,j})) Θ(P δ^α_3 min_{u∈[k], q∈[2]} ⟨w^(t)_{y,r}, v_{u,q}⟩^{α-1} / ρ^{α-1})
  + (η/N^2) Σ_{i∈N_y} Σ_{j∈N_s} (1 - 2ϕ_y(g_t(z_{i,j}))) Θ(max_{q∈[2]} ⟨w^(t)_{y,r}, v_{s,ℓ} + v_{y,q}⟩^{α-1} / ρ^{α-1})
  + (η/N^2) Σ_{i∈N_y} Σ_{j∉N_y∪N_s} (1 - 2ϕ_y(g_t(z_{i,j}))) Θ(δ_4 max_{u∈[k], q∈[2]} ⟨w^(t)_{y,r}, v_{u,q}⟩^{α-1} / ρ^{α-1})
  + (η/N^2) Σ_{i∈N_y} Σ_{j∈N_y} (1 - ϕ_y(g_t(z_{i,j}))) Θ(P δ^α_4 max_{u≠y, q∈[2]} ⟨w^(t)_{y,r}, v_{u,q}⟩^{α-1} / ρ^{α-1})   (B.20)

As the above may be rather difficult to parse on first glance, let us take a moment to unpack the individual terms on the RHS. The first two terms are a precise splitting of the -2ϕ_y(g_t) term from Calculation A.10; namely, the case where we mix with the class s allows for constant size coefficients on the feature v_{s,ℓ}, while the other cases only allow for v_{s,ℓ} to show up in the feature noise patches. The next three terms consider all cases of mixing with the class y.
The first of these terms considers the case of mixing class y with class s, in which case it is possible to have patches in z i,j that have both v s,ℓ and v y,ℓ * with constant coefficients. The next term considers mixing class y with a class that is neither y nor s, in which case the feature v s,ℓ can only show up when mixing with a feature noise patch, so we suffer a factor of at least δ 4 = Θ(1/k 1.5 ) (note we do not suffer a δ α 4 factor as v y,ℓ * can still be in z i,j ) from the ⟨z i,j , v s,ℓ ⟩ part of the gradient. Finally, the last term considers mixing within class y. The first of the three positive terms in the RHS of Equation B.20 presents the main problem -the fact that the diagonal correlations can show up non-trivially in the off-diagonal correlation gradient means the gradients can be much larger than in the ERM case. However, the key is that there are only Θ(N/k 2 ) mixed points between classes y and class s. Thus, once more using the fact that Θ(|N u |) = Θ(N/k) for every u ∈ [k], the other conditions in our setting, and Claim B.15, we obtain that for all t ≤ T A : It remains to show that the off-diagonal correlations also do not decrease by too much, as if they were to become negative that would potentially cause problems in Equation B.19 due to ReLU ′ becoming 0. Using Equation B.20, we can form the following lower bound to w (t) y,r , v s,ℓ : -η∇ w (t) y,r * J M M (g, X ), v y,ℓ * ≥ Θ    η w (t) y,r * , v y,ℓ * α-1 kρ α-1    (B.21) -η∇ w (t) y,r J M M (g, X ), v s,ℓ ≤ Θ    η max u∈{s,y},q∈[2] w (t) y,r , v u,q α-1 k 2 ρ α-1    (B. -η∇ w (t) y,r J M M (g, X ), v s,ℓ ≥ -Θ    η w (t) y,r , v s,ℓ α-1 k 2 ρ α-1    (B.23) Now let T denote the number of iterations starting from initialization that it takes w  w (t) y,r , v s,ℓ = O max ℓ ′ ∈[2] w (t) y,r , v y,ℓ ′ /k Additionally, for all s, ℓ with w (0) y,r , v s,ℓ > 0, we have that w (t) y,r , v s,ℓ > 0. 
Proof. ⟨w^(t)_{y,r}, v_{y,ℓ}⟩ = O(log k) ⟹ Σ_{r∈[m]} Σ_{ℓ∈[2]} ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ = O(log(k)/k). Thus, we may disregard the off-diagonal correlations in considering the class y output on z_{i,j} (i.e. we do not need to worry about the x_j part of z_{i,j}), and the rest is identical to Claim B.5. ■

Claim B.20. For each y ∈ [k], let T_y denote the first iteration at which max_{i∈N_y, j∉N_y} g^y_{T_y}(z_{i,j}) ≥ log(k − 1). Then we have that T_y = O(poly(k)) (for a sufficiently large polynomial in k) and that min_{i∈N_y, j∈[N]} g^y_t(z_{i,j}) = Ω(log k) for all t ≥ T_y.

Proof. At this point we may lower bound the update to ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ as:

⟨−η∇_{w^(t)_{y,r*}} J_MM(g, X), v_{y,ℓ*}⟩ ≥ (η/N²) Σ_{i∈N_y} Σ_{j∉N_y} Θ(1 − 2ϕ_y(g_t(z_{i,j}))) + (η/N²) Σ_{i∈N_y} Σ_{j∈N_y} Θ(1 − ϕ_y(g_t(z_{i,j}))) − (η/N²) Σ_{i∉N_y} Σ_{j∉N_y} ϕ_y(g_t(z_{i,j})) Θ(P σ_4^α max_{u∈[k],q∈[2]} ⟨w^(t)_{y,r*}, v_{u,q}⟩^{α−1} / ρ^{α−1})    (B.24)

Using Claim B.15, we have that so long as max_{i∈N_y, j∉N_y} g^y_t(z_{i,j}) < log(k − 1), we get (using |N_u| = Θ(N/k) for all u ∈ [k]):

⟨−η∇_{w^(t)_{y,r*}} J_MM(g, X), v_{y,ℓ*}⟩ ≥ Θ(η/k)    (B.25)

This also implies, by the logic of Claim B.17, that the off-diagonal correlations ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ have updates that can be upper bounded as:

⟨−η∇_{w^(t)_{y,r}} J_MM(g, X), v_{s,ℓ}⟩ ≤ Θ(η/k²)    (B.26)

Comparing Equations B.25 and B.26, we have that g^y_t(z_{i,j}) ≥ log(k − 1) (and, consequently, g^y_t(x_i) ≥ log(k − 1)) for at least one mixed point z_{i,j} with i ∈ N_y in O(poly(k)) iterations, while the off-diagonal correlations are O(log(k)/k). This also implies that min_{i∈N_y, j∈[N]} g^y_t(z_{i,j}) = Ω(log k) by Claim B.19. Finally, since Equation B.25 is positive, the class y network outputs remain Ω(log k) for t ≥ T_y (as again we cannot decrease below log(k − 1) by more than o(1), since the gradients are o(1)). ■

Part II. Having analyzed the growth of diagonal and off-diagonal correlations in the initial stages of training, we now shift gears to focusing on the gaps between the correlations for each class.
The key idea is that J_MM will push the correlations for the features v_{y,1} and v_{y,2} closer together throughout training (so long as they are sufficiently separated), for every class y. In order to prove this, we will rely on analyzing an expectation form of the gradient for J_MM. As the expressions involved in this analysis can become cumbersome quite quickly, we will first introduce a slew of new notation to make the presentation of the results a bit easier. Firstly, in everything that follows, we assume v_{y,1} to be the better correlated feature at time t for every class y ∈ [k] in the following sense:

Σ_{r∈B^(t)_{y,1}} ⟨w^(t)_{y,r}, v_{y,1}⟩ ≥ Σ_{r∈B^(t)_{y,2}} ⟨w^(t)_{y,r}, v_{y,2}⟩

Writing C^(t)_{y,ℓ} for the corresponding sum above, we let ∆^(t)_y = C^(t)_{y,1} − C^(t)_{y,2}. Now for the aforementioned expectation analysis we introduce several relevant random variables. We use β_{y,p} (for every y ∈ [k]) to denote a random variable following the distribution of the signal coefficients for class y from Definition 4.1, and we further use β_y to denote a random variable representing the sum of C_P i.i.d. copies of β_{y,p}. Similarly, we use z_{y,s} to denote the average of two random variables following the distributions of class y and class s points respectively. Finally, we define A_1(β_s, β_y) and A_2(β_s, β_y) as:

A_1(β_s, β_y) ≜ 1 − 2ϕ_y(g_t(z_{y,s}))    (B.28)

A_2(β_s, β_y) ≜ 1 − ϕ_y(g_t(z_{y,s}))    (B.29)

In context, this notation will imply that A_1(β_y, β_s) = 1 − 2ϕ_s(g_t(z_{y,s})) (i.e. swapping the order of the arguments changes which coordinate of the softmax is being considered). Now we will first prove an upper bound on the difference of gradient correlations in the linear regime, and then use these ideas to prove that correlations in the polynomial part of ReLU will still receive a significant gradient. After we have done that, we will revisit this next claim to show that the separation between feature correlations continues to decrease even after they reach the linear regime.
For r_1 ∈ B^(t)_{y,1} and r_2 ∈ B^(t)_{y,2}, we let:

Ψ(r_1, r_2) ≜ ⟨−∇_{w^(t)_{y,r_1}} J_MM(g, X), v_{y,1}⟩ − ⟨−∇_{w^(t)_{y,r_2}} J_MM(g, X), v_{y,2}⟩    (B.30)

After which we have that:

Ψ(r_1, r_2) ≤ Θ(1/k²) Σ_{s≠y} E_{β_s,β_y}[A_1(β_s, β_y)(β_y − C_P δ_2/2)] + Θ(1/k²) E_{β_y}[A_2(β_y, β_y)(β_y − C_P δ_2/2)] + O(P δ_4^α (log k)^{α−1})    (B.31)

Proof of Claim B.21. Using the logic from Equation B.24 as well as the fact that r_1 ∈ B^(t)_{y,1} and r_2 ∈ B^(t)_{y,2} (i.e. we are considering weights in the linear regime of ReLU for each feature), we get:

Ψ(r_1, r_2) ≤ (1/N²) Σ_{i∈N_y} Σ_{j∉N_y} (1 − 2ϕ_y(g_t(z_{i,j}))) (Σ_{p∈P_{y,1}(x_i)} β_{i,p} − C_P δ_2/2) + (1/N²) Σ_{i∈N_y} Σ_{j∈N_y} (1 − ϕ_y(g_t(z_{i,j}))) (Σ_{p∈P_{y,1}(x_i)} β_{i,p} − C_P δ_2/2) + O(P δ_4^α (log k)^{α−1})    (B.32)

Now, since we took N sufficiently large in Assumption 4.3, by concentration for bounded random variables we can replace the expressions on the RHS above with their expected values, as the deviation will be within O(P δ_4^α (log k)^{α−1}) (with probability 1 − O(1/k), consistent with our setting). However, the expectations will be over all of the random variables β_u for u ∈ [k], not just the classes s and y being mixed (or, in the case of the second term above, just the class y). Fortunately, we observe that for the mixed point random variable z_{y,s}, the β_u for u ≠ y, s can only show up in the feature noise patches of z_{y,s}. Thus, by an identical calculation to the one controlling the feature noise contribution to the gradient above (once again, refer to Equation B.24), we see that we may consider the expectation over just β_y and β_s while marginalizing out the other random variables and staying within the error term above, thereby obtaining Equation B.31. ■

We will now show that E_{β_s,β_y}[A_1(β_s, β_y)(β_y − C_P δ_2/2)] is significantly negative so long as the separation between feature correlations ∆^(t)_y is sufficiently large. Once again, to simplify notation even further, we will use β̃_y = β_y − C_P δ_2/2 and use P(β̃_y) to refer to its associated probability measure.
Furthermore, we will use:

D^(t)_{y,s} = (C^(t)_{y,1} + C^(t)_{y,2}) − (C^(t)_{s,1} + C^(t)_{s,2})

In other words, D^(t)_{y,s} represents the difference in the linear outputs of classes y and s. With this in mind, we can prove the aforementioned result, which supposes that there exists at least one class s ∈ [k] such that there is a set U ⊂ [0, C_P(δ_2/2 − δ_1)] × [−C_P(δ_2/2 − δ_1), C_P(δ_2/2 − δ_1)] with (P(β̃_y) × P(β̃_s))(U) ≥ 0.01 (i.e. its measure is at least 0.01) such that for all (a, b) ∈ U we have:

(b − C_P δ_2/2)∆^(t)_s − C_P δ_2 D^(t)_{y,s}/2 ≤ (a − C_P δ_2/2)∆^(t)_y    (B.33)

Then we have:

E_{β_s,β_y}[A_1(β_s, β_y)(β_y − C_P δ_2/2)] = −Θ(1)    (B.34)

Proof of Claim B.22. We begin by first showing that the expectation on the LHS of Equation B.34 is negative. Indeed, this is almost immediate from the fact that β̃_y is a symmetric, mean zero random variable; we need only show that A_1 is monotonically decreasing in β_y. From the definition of A_1, we observe that it suffices to show that g^y_t(z_{y,s}) is monotonically increasing in β_y. However, this is straightforward to see from the assumption that ∆^(t)_y ≥ log k − o(1), as this implies that an ϵ increase in β_y leads to an Ω(ϵ log k) − O(ϵ) increase in g^y_t, since the feature noise and the weights that are in the polynomial part of ReLU can contribute at most O(1) by the logic of Claim B.19. Now we need only show that the expectation is sufficiently negative. To do this, we will rely on the following facts, which will allow us to write things purely in terms of C^(t)_{y,ℓ} and C^(t)_{s,ℓ} (i.e. disregarding the weights that are not in the linear regime):

g^y_t(z_{y,s}) ∈ [β_y C^(t)_{y,1} + (C_P δ_2 − β_y) C^(t)_{y,2}, β_y C^(t)_{y,1} + (C_P δ_2 − β_y) C^(t)_{y,2} + O(1)]    (B.35)

g^s_t(z_{y,s}) ∈ [β_s C^(t)_{s,1} + (C_P δ_2 − β_s) C^(t)_{s,2}, β_s C^(t)_{s,1} + (C_P δ_2 − β_s) C^(t)_{s,2} + O(1)]    (B.36)

g^u_t(z_{y,s}) = O(log(k)/k) for u ≠ y, s    (B.37)

These follow from Claim B.19 and Corollary B.18 (alongside the assumption that max_{u∈[k]} C^(t)_{u,1} = O(log k)), respectively.
Now we perform the substitution g^u_t ← g^u_t − C_P δ_2 (C^(t)_{y,1} + C^(t)_{y,2})/2 for all u ∈ [k], as this can be done without changing the value of ϕ_y(g_t(z_{y,s})). Under this transformation we have (using Equation B.35):

g^y_t(z_{y,s}) ∈ [(β_y − C_P δ_2/2)∆^(t)_y, (β_y − C_P δ_2/2)∆^(t)_y + O(1)]    (B.38)

g^s_t(z_{y,s}) ∈ [(β_s − C_P δ_2/2)∆^(t)_s − C_P δ_2 D^(t)_{y,s}/2, (β_s − C_P δ_2/2)∆^(t)_s − C_P δ_2 D^(t)_{y,s}/2 + O(1)]    (B.39)

which isolates the correlation gap term ∆^(t)_y. Prior to proceeding further we will let Λ^(t)_s ≜ Σ_{u≠y} exp(g^u_t(z_{y,s})), so as to prevent the equations to follow from becoming too unwieldy. Now we have:

E_{β_s,β_y}[A_1(β_s, β_y)(β_y − C_P δ_2/2)] = ∫_{C_P δ_1}^{C_P(δ_2−δ_1)} ∫_{C_P δ_1}^{C_P(δ_2−δ_1)} [(Λ^(t)_s − exp(g^y_t(z_{y,s}))) / (Λ^(t)_s + exp(g^y_t(z_{y,s})))] (β_y − C_P δ_2/2) dP(β_y) dP(β_s)    (B.40)

where we also use that Σ_{u≠y,s} exp(g^u_t(z_{y,s})) = O(1/k^{C_P δ_2 − 1}) (after making the adjustment g^u_t ← g^u_t − C_P δ_2(C^(t)_{y,1} + C^(t)_{y,2})/2 that we did above). Now using our assumption in the statement of the claim that Equation B.33 holds for some set U, we obtain:

E_{β_s,β_y}[A_1(β_s, β_y)(β_y − C_P δ_2/2)] ≤ −∫_{C_P(δ_1−δ_2/2)}^{C_P(δ_2/2−δ_1)} ∫_{ϱ}^{C_P(δ_2/2−δ_1)} Θ(β̃_y exp(β̃_y ∆^(t)_y) / (O(k^{1−C_P δ_2}) + exp(β̃_y ∆^(t)_y))) dP(β̃_y) dP(β̃_s) = −Θ(1)    (B.42)

where the last line follows after restricting the bounds of integration of the two integrals to their intersections with U (this allows us to disregard g^s_t in the asymptotic expression above via Equation B.39). This proves the claim. ■

We proved Claim B.22 in terms of β_y (the sum of the individual β_{y,p}) to keep the notation manageable (it avoids C_P iterated integrals) and to more closely mirror the proof of Proposition 3.4. However, what we will really use for our remaining analysis is the following corollary, which gives the same result as Claim B.22 but for each of the individual terms β_{y,p}.
Below we use Σ_{i=1}^{C_P} β_{y,i} to make explicit the dependence between the sum and each individual random variable β_{y,p} (so as not to mislead one into thinking of them as independent random variables).

Corollary B.23. Under the same conditions as Claim B.22, we have for every p ∈ [C_P]:

E_{β_s, β_{y,1},…,β_{y,C_P}}[A_1(β_s, Σ_{i=1}^{C_P} β_{y,i})(β_{y,p} − δ_2/2)] = −Θ(1)    (B.43)

satisfying:

⟨−∇_{w^(t)_{y,r_1}} J_MM(g, X), v_{y,1}⟩ ≥ 0    (B.44)

We have: ⟨−∇_w

For the second case, when Equation B.33 does not hold, we have that for every class s ≠ y there exists a set U′ ⊂ [0, C_P(δ_2/2 − δ_1)] × [−C_P(δ_2/2 − δ_1), C_P(δ_2/2 − δ_1)] such that (P(β̃_y) × P(β̃_s))(U′) ≥ 0.49 (note the total measure of the set of which U′ is a subset is 0.5, by symmetry) and for all (a, b) ∈ U′ we have:

Proof of Claim B.27. The idea is to consider the sum of gradient correlations across classes, and show that the cross-class mixing term in this sum becomes smaller (as this would be our only concern; we already know the same-class mixing term will become smaller by the logic of Claim B.6). As in the previous claims in this section, we will proceed with an expectation analysis. We will focus on the weights w_{y,r} that are in the linear regime for the feature v_{y,1} for each class y, as these are the only relevant weights for C^(t)_{y,1}. Additionally, instead of considering the sum of gradient correlations over all w_{y,r} with r ∈ B^(t)_{y,1}, it will suffice for our purposes to just consider the sum of gradient correlations over classes while using an arbitrary weight w_{y,r} in the linear regime. Thus, we will abuse notation slightly and use w_{y,r} to indicate such a weight for each class y for the remainder of the proof of this claim (note that we do not mean to imply by this that weight r is in the linear regime for every class simultaneously, but rather that there exists some r for every class that is in the linear regime).
Now, in the same vein as Equation B.31 (referring again to Equation B.24), we have that: And we recall that N is sufficiently large, so that the deviations from the expectations above are negligible compared to the subtracted term. We have carefully paired the expectations in the leading term of Equation B.52 so as to make use of the following fact:

A_1(β_s, β_y) = −A_1(β_y, β_s) + Σ_{u∈[k]∖{y,s}} exp(g^u_t(z_{y,s})) / Σ_{u∈[k]} exp(g^u_t(z_{y,s}))    (B.53)

The second term on the RHS of Equation B.53 is of course o(P δ_4^α / poly(k)) so long as g^s_t(z_{y,s}) and/or g^y_t(z_{y,s}) are greater than C log k for a large enough constant C, so we obtain: Now we let Ξ_y = ⟨w_y, v_{y,1}⟩ + ⟨w_y, v_{y,2}⟩ and perform the transformation g^s_t ← g^s_t − Ξ_y/2 for all s ∈ [k] (note this does not change the value of the softmax outputs). Under this transformation we have that g^y_t(z_{y,s}) = (β_y − 1/2)∆_y, which isolates the gap term ∆_y. For further convenience, let us use Λ_s ≜ Σ_{u≠y} exp(g^u_t(z_{y,s})), and observe that Λ_s depends only on β_s due to orthogonality. Using the change of variables β̃_y = β_y − 1/2, we can then compute the expectation in the first term of Equation C.6 as: Then we immediately have that Equation C.11 is Θ(1). On the other hand, if this is not the case, one can see that E_{β_s,β_y}[A_1(β_s, β_y)(1 − β_y)] = Θ(1), so we are done. ■

Proposition 3.5 (ERM Gradient Upper Bound). For every y ∈ [k], assuming the same conditions as in Proposition 3.4, if ∆_y ≥ C log k for any C > 0, then with high probability we have that:

⟨−∇_{w_y} J(g, X), v_{y,2}⟩ ≤ O(1/k^{0.1C−1})    (3.6)

Proof. From the facts that ⟨w_y, v_{y,2}⟩ ≥ 0 and ∆_y ≥ C log k, we have that g^y(x_i) ≥ β_{y,i} C log k for every i ∈ N_y (where β_{y,i} represents the coefficient in front of v_{y,1} in x_i). Since β_{y,i} ∈ [0.1, 0.9], we immediately have the result from the logic of Claim B.2 and Calculation A.9. ■
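For concreteness, the Midpoint Mixup objective J_MM differentiated throughout the appendix can be sketched numerically for a toy linear model. This is a hypothetical stand-in for the paper's 2-layer convolutional network, and the function names are ours, not from the paper:

```python
import numpy as np

def softmax(v):
    # numerically stable softmax
    e = np.exp(v - v.max())
    return e / e.sum()

def midpoint_mixup_loss(W, X, y):
    """Midpoint Mixup loss for a linear model g(x) = W @ x: the average
    cross-entropy over all midpoints z_ij = (x_i + x_j) / 2, with the two
    labels mixed equally (lambda = 1/2)."""
    N = len(X)
    total = 0.0
    for i in range(N):
        for j in range(N):
            z = 0.5 * (X[i] + X[j])
            p = softmax(W @ z)
            total -= 0.5 * (np.log(p[y[i]]) + np.log(p[y[j]]))
    return total / N**2
```

At an all-zero weight matrix the softmax is uniform, so the loss equals log k independently of the data, which is a convenient sanity check on the implementation.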



This assumption is true for any distribution with reasonable variance; for example, the uniform distribution. q−2 (1 + Θ(1/polylog(k))). For every A = O(1), letting T_x be the first iteration such that x_t ≥ A, we must have that y_{T_x} = O(y_0 polylog(k)).
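The sign argument that this kind of distributional assumption supports (as used in Claim B.22: A_1 monotonically decreasing in β_y together with a symmetric, mean-zero β̃_y forces the key expectation to be negative) can be illustrated numerically with the uniform distribution. This is only a toy check, not the paper's exact setting: `delta` and `lam` below are hypothetical stand-ins for the correlation gap and for Λ_s respectively:

```python
import numpy as np

# Monte Carlo check: for beta ~ Uniform[0.1, 0.9] (so beta - 1/2 is symmetric
# and mean zero) and A1(beta) = (lam - exp((beta - 1/2) * delta)) /
#                               (lam + exp((beta - 1/2) * delta)),
# which is monotonically decreasing in beta, E[A1(beta) * (beta - 1/2)] < 0.
rng = np.random.default_rng(0)
beta = rng.uniform(0.1, 0.9, size=200_000)
delta = 5.0   # hypothetical stand-in for the gap Delta_y
lam = 1.0     # hypothetical stand-in for Lambda_s
logits = np.exp((beta - 0.5) * delta)
A1 = (lam - logits) / (lam + logits)
val = float(np.mean(A1 * (beta - 0.5)))
assert val < -0.01  # significantly negative, matching the claim's conclusion
```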



in which every training data point is transformed such that a randomly sampled training point from a different (but randomly fixed) class is concatenated (along the channels dimension) to the original point. Additionally, to introduce a dependency structure akin to what we have in Definitions 3.1 and 4.1, we sample a γ ∼ Uni([0, 1]) and scale the first part of the training point (the true image) by γ while scaling the concatenated part by 1 -γ during training.
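A minimal sketch of this training-time modification, under the assumption that training points are `(C, H, W)` arrays (the helper name `mix_channels` is ours, not from the paper):

```python
import numpy as np

def mix_channels(x, x_other, rng):
    """Concatenate a point from a different (but randomly fixed) class along
    the channel axis, scaling the true image by gamma and the concatenated
    one by 1 - gamma, with gamma ~ Uniform([0, 1]) sampled during training."""
    gamma = rng.uniform(0.0, 1.0)
    return np.concatenate([gamma * x, (1.0 - gamma) * x_other], axis=0)

rng = np.random.default_rng(0)
x, x_other = np.ones((3, 32, 32)), np.zeros((3, 32, 32))
z = mix_channels(x, x_other, rng)   # shape (6, 32, 32)
```

Only the training transformation described above is sketched here; the channel concatenation doubles the input depth, mirroring the two-feature structure of Definitions 3.1 and 4.1.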

Figure 1: Visualization of data modification in CIFAR-10.

Figure 2: Test error comparison between Uniform Mixup (green), Midpoint Mixup (orange), and ERM (blue). Each curve represents the average of 5 model runs (over the randomness of the data augmentations and model initializations), while the surrounding area represents 1 standard deviation.

Proof of Claim B.1. We prove each part of the claim in order, starting with showing that |N_y| = Θ(N/k) with the desired probability for each y. To see this, we note that the joint distribution of the |N_y| is multinomial with uniform probability 1/k. Now by a Chernoff bound, we have that |N_1| = Θ(N/k) with probability at least 1 − exp(−Θ(N/k)). Conditioning on |N_1| = Θ(N/k), we have that the joint distribution of |N_2|, …, |N_k| is multinomial with uniform probability 1/(k − 1), so we obtain an identical Chernoff bound for |N_2|. Repeating this argument and taking a union bound gives that |N_y| = Θ(N/k) for all y ∈ [k] with probability at least 1 − k exp(−Θ(N/k)). The fact that for every y we have max_{s∈[k], r∈[m], ℓ∈[2]} ⟨w^(0)_{s,r}, v_{y,ℓ}⟩ = O(log k/√d) with probability 1 − O(1/k) follows from Proposition A.2. Namely, using Proposition A.2 with t = 2σ√(2 log m) (here σ = 1/√d by our choice of initialization) yields that max_r ⟨w^(0)_{s,r}, v_{y,ℓ}⟩ ≥ 3√(2 log k)/√d with probability bounded above by 1/k³ for any s, y. Taking a union bound over s, y then gives the result. The final fact follows by near-identical logic but using Proposition A.3 (note that the correlations ⟨w^(0)_{s,r}, v_{y,ℓ}⟩ are i.i.d. N(0, 1/d) due to the fact that the features are orthonormal and the weights themselves are i.i.d.).
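The first step of this proof (multinomial concentration of the class counts) is easy to check empirically; the band constants 0.8 and 1.2 below are arbitrary stand-ins for the Θ(·) factors:

```python
import numpy as np

# With N i.i.d. uniform labels over k classes, every class count |N_y|
# concentrates around N/k (the Chernoff-plus-union-bound step of Claim B.1).
rng = np.random.default_rng(0)
k, N = 10, 20_000
labels = rng.integers(0, k, size=N)
counts = np.bincount(labels, minlength=k)
assert counts.sum() == N
assert counts.min() > 0.8 * N / k
assert counts.max() < 1.2 * N / k
```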

Proof of Corollary B.3. Follows from Equations B.2 and B.3. ■ With these softmax bounds in hand, we now show that the "diagonal" correlations ⟨w^(t)_{y,r}, v_{y,ℓ}⟩ grow much more quickly than the "off-diagonal" correlations ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ (where s ≠ y). This will allow us to satisfy the conditions of Claim B.2 throughout training. Claim B.4. Consider an arbitrary y ∈ [k]. Let A ≤ ρ/(δ_2 − δ_1) and let T_A denote the first iteration at which max_{r∈[m], ℓ∈[2]} ⟨w^(T_A)_{y,r}, v_{y,ℓ}⟩ ≥ A. Then we must have both that T_A = O(poly(k)) and that max_{r∈[m], s≠y, ℓ∈[2]} ⟨w^(T_A)_{y,r}, v_{s,ℓ}⟩ = O(log(k)/√d). Proof of Claim B.4. Firstly, all weight-feature correlations are o(ρ) at initialization (see Claim B.1). Now for s ≠ y and w^(0)

it immediately follows from comparing to Equation B.6 and recalling that α = 8 in Assumption 4.3 that T ≫ T_A, and that T_A = O(poly(k)), so the claim is proved. ■

⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ to grow to log(k)/δ_1. From Equation B.10, we clearly have that T_y = O(poly(k)) for some polynomial in k. Furthermore, comparing to Equation B.7, we necessarily still have max_{r∈[m], s≠y, ℓ∈[2]} ⟨w^(T_y)_{y,r}, v_{s,ℓ}⟩ = O(log(k)/√d). Finally, as the update in Equation B.9 is positive at T_y (and the absolute value of a gradient update is o(1)), it follows that min_{i∈N_y} g^y_t(x_i) = Ω(log k) for all t ≥ T_y by Claim B.5. ■ The final remaining task is to show that g^y_t(x_i) = O(log k) and g^s_t(x_i) = o(1) for all t = O(poly(k)) and i ∈ N_y for every y ∈ [k]. Claim B.7. For all t = O(k^C) for any universal constant C, and for every y ∈ [k] and s ≠ y, we have that g^y_t(x_i) = O(log k) and max_{r∈[m], s≠y, ℓ∈[2]} ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ = O(log(k)/√d) for all i ∈ N_y.

represents the max class y correlation with feature v_{y,ℓ} at time t. Now we can prove essentially the same result as Proposition B.2 in Allen-Zhu and Li (2021), which quantifies the separation between Λ^(0)_{y,1} and Λ^(0)_{y,2} after taking into account S_{y,1} and S_{y,2}.

⟨w^(t)_{y,r}, v_{y,3−ℓ*}⟩ is still within a polylog(k) factor of ⟨w^(0)_{y,r}, v_{y,3−ℓ*}⟩ for any r. Now from the same logic as the proof of Claim B.7, we can show that this separation remains throughout training. Claim B.13. For any class y ∈ [k], with probability 1 − O(1/log k), we have that max_{r∈[m]} ⟨w^(t)_{y,r}, v_{y,3−ℓ*}⟩ = O(polylog(k) · max_{r∈[m]} ⟨w^(0)_{y,r}, v_{y,3−ℓ*}⟩) for all t = O(poly(k)) for any polynomial in k.

⟨w^(t)_{y,r}, v_{y,3−ℓ*}⟩ to grow by more than a polylog(k) factor we need ω(poly(k)) training iterations. ■ From Claim B.13 along with Claim B.7, it follows that with probability 1 − O(1/log k), for any class y (after polynomially many training iterations) we have:

Crucially, we have that Equation B.22 is a Θ(1/k) factor smaller than Equation B.21. Recalling that all correlations are O(log(k)/√d) at initialization, we see that the difference in the updates in Equations B.21 and B.22 is at least of the same order as Equation B.21. Thus, in O(poly(k)) iterations, it follows that ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ > ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ (this necessarily occurs for a t < T_A by definition of A and comparison to the bounds above), after which it follows from Equations B.21 and B.22 that ⟨w^(T_A)_{y,r}, v_{s,ℓ}⟩ = O(⟨w^(T_A)_{y,r*}, v_{y,ℓ*}⟩/k) (and clearly T_A = O(poly(k))). This proves the first part of the claim.

⟨w^(0)_{y,r}, v_{s,ℓ}⟩/polylog(k) for some fixed polylog(k) factor. Then it follows from Equation B.21 that in T iterations ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ has increased by at least a k/polylog(k) factor. As a result, we have that at T_A the correlation ⟨w^(t)_{y,r}, v_{s,ℓ}⟩ has decreased by at most a polylog(k)^C factor for some universal constant C, proving the claim. ■ Corollary B.18. For any class y ∈ [k], and any t ≥ T_A (for any T_A satisfying the definition in Claim B.17), we have for any s ≠ y and ℓ ∈ [2]:

w^(t)_{y,r*} has reached the linear regime of ReLU on effectively all mixed points) while the off-diagonal correlations continue to lag behind by an O(1/k) factor.

Suppose that max_{s∈[k]} C^(t)_{s,1} = O(log k). Then for any class y ∈ [k], any r_1 ∈ B^(t)_{y,1}, and any r_2 ∈ B^(t)_{y,2}

Claim B.22. Suppose that max_{u∈[k]} C^(t)_{u,1} = O(log k). Let y be any class such that ∆^(t)_y ≥ log k − o(1), and suppose that there exists at least one class s ∈ [k]

(B.37) These follow from Claim B.19 and Corollary B.18 (alongside the assumption that max_{u∈[k]} C^(t)_{u,1} = O(log k)).

⟨−∇_{w^(t)_{y,r_2}} J_MM(g, X), v_{y,2}⟩ ≥ Θ(τ^{α−1}/(ρ^{α−1} k²))    (B.45)

Proof of Claim B.24. By near-identical logic to the steps leading to Equation B.31, and using Equation B.21, we obtain (following the same notation as before):

⟨−∇_{w^(t)_{y,r_2}} J_MM(g, X), v_{y,2}⟩ ≥ Θ(τ^{α−1}/(ρ^{α−1} k²)) Σ_{s≠y} E_{β_s, β_{y,1},…,β_{y,C_P}}[A_1(β_s, Σ_{p=1}^{C_P} β_{y,p}) Σ_{p=1}^{C_P} (δ_2 − β_{y,p})^α]    (B.46)

where above we have absorbed a factor of 1/2 resulting from mixing classes, as well as the feature noise component, into the asymptotic term in front of the summation. Now we break the rest of the proof into two cases, according to whether Equation B.33 holds or not. In the former case, using Corollary B.23 and our assumption in the statement of the claim that Equation B.44 holds, we get Equation B.47 from linearity of expectation. Since Cov(A_1(β_s, Σ_{p=1}^{C_P} β_{y,p})(δ_2 − β_{y,p}), (δ_2 − β_{y,p})^{α−1}) > 0 for every p, we obtain:

⟨−∇_{w^(t)_{y,r_2}} J_MM(g, X), v_{y,2}⟩ ≥ Σ_{s≠y} Σ_{p=1}^{C_P} E_{β_s, β_{y,1},…,β_{y,C_P}}[A_1(β_s, Σ_{p=1}^{C_P} β_{y,p})(δ_2 − β_{y,p})] · E_{β_s, β_{y,1},…,β_{y,C_P}}[(δ_2 − β_{y,p})^{α−1}]    (B.48)

And the result then follows for this case from Equation B.47 and the fact that E_{β_s, β_{y,1},…,β_{y,C_P}}[(δ_2 − β_{y,p})^{α−1}] is a data-distribution-dependent constant.

(b − C_P δ_2/2)∆^(t)_s − C_P δ_2 D^(t)_{y,s}/2 > (a − C_P δ_2/2)∆^(t)_y    (B.49)

By our anti-concentration assumption on the β_{y,p}, it immediately follows that D^(t)_{y,s} = −Θ(log k), from which we obtain that the expectation terms in Equation B.46 are all Θ(1), so we are done. ■ Having proved Claim B.24, it remains to prove that both Equation B.44 and max_{y∈[k]} C^(t)_{y,1} = O(log k) hold throughout training, as after doing so we can conclude that the second feature correlation will escape the polynomial part of ReLU and become sufficiently large in polynomially many training steps. Claim B.27. For all t = O(poly(k)), for any polynomial in k, we have that max_{y∈[k]} C^(t)

Σ_{y∈[k]} ⟨−∇_{w^(t)_{y,r}} J_MM(g, X), v_{y,1}⟩ ≤ Θ(1/k²) Σ_{y∈[k]} Σ_{s≠y} (E_{β_s,β_y}[A_1(β_s, β_y)β_y] + E_{β_y,β_s}[A_1(β_y, β_s)β_s]) + Θ(1/k²) Σ_{y∈[k]} E_{β_y}[A_2(β_y, β_y)β_y] − O(P δ_4^α / poly(k))    (B.52)

E_{β_s,β_y}[A_1(β_s, β_y)(β_y − β_s)] + E_{β_y}[A_2(β_y, β_y)β_y] − O(P δ_4^α / poly(k))    (B.54)

Now again by the logic of Claim B.22 we have that Cov(A_1(β_s, β_y), β_y − β_s) < 0, so it follows that:

E_{β_y}[A_2(β_y, β_y)β_y] − O(P δ_4^α / poly(k))    (B.55)

And if g^y_t(z_{y,s}) ≥ C log k for a sufficiently large constant C, we have that the RHS above would be negative, which contradicts Corollary B.26, proving the claim. ■ We have now wrapped up all of the pieces necessary to prove Theorem 4.7. Indeed, we can now show that for every class the correlation with both features becomes large over the course of training. Claim B.28. For every class y ∈ [k], in O(poly(k)) iterations (for a sufficiently large polynomial in k) we have that both C^(t)_{y,1} = Ω(log k) and C^(t)_{y,2} = Ω(log k).

E_{β_s,β_y}[A_1(β_s, β_y)(β_y − 1/2)] = ∫∫ [(Λ_s − exp(g^y_t(z_{y,s}))) / (Λ_s + exp(g^y_t(z_{y,s})))] (β_y − 1/2) dβ_y dβ_s = ∫∫ [(Λ_s − exp(β̃_y ∆_y)) / (Λ_s + exp(β̃_y ∆_y))] β̃_y dβ̃_y dβ̃_s    (C.9)

We will focus on the inner integral in Equation C.9. Using the symmetry of β̃_y, we have that:

E_{β̃_y}[A_1(β_s, β_y)(β_y − 1/2)] = ∫_0 [(Λ_s − exp(β̃_y ∆_y)) / (Λ_s + exp(β̃_y ∆_y)) − (Λ_s − exp(−β̃_y ∆_y)) / (Λ_s + exp(−β̃_y ∆_y))] β̃_y dβ̃_y = −∫_0 [2Λ_s (exp(β̃_y ∆_y) − exp(−β̃_y ∆_y)) / (Λ_s² + Λ_s (exp(β̃_y ∆_y) + exp(−β̃_y ∆_y)) + 1)] β̃_y dβ̃_y    (C.10)

From our orthogonality assumption and the facts that ∆_y ≥ log k and ⟨w_y, v_{y,2}⟩ ≥ 0, we have that Σ_{u≠y,s} exp(g^u_t(z_{y,s})) = O(1). Additionally, if we let:

D_{y,s} = (C_{y,1} + C_{y,2}) − (C_{s,1} + C_{s,2})

then we get from Equation C.10 that:

E_{β_s,β_y}[A_1(β_s, β_y)(β_y − 1/2)] ≤ −∫∫ Θ(exp(β̃_s ∆_s − D_{y,s}/2) exp(β̃_y ∆_y) / (exp(2β̃_s ∆_s − D_{y,s}) + exp(β̃_s ∆_s − D_{y,s}/2) exp(β̃_y ∆_y))) β̃_y dβ̃_y dβ̃_s    (C.11)

Now we consider two cases. First, suppose there exists a set U ⊂ [−0.4, 0.4] × [0, 0.4] with probability measure at least 0.01 such that for all (β̃_s, β̃_y) ∈ U:

β̃_s ∆_s − D_{y,s}/2 ≤ β̃_y ∆_y

In what follows, we use inf J_MM(h, X) to indicate the global minimum of J_MM over all functions h : R^d → R^k (i.e. this is the smallest achievable loss).

Now we may show: Claim B.6. For each y ∈ [k], let T_y denote the first iteration such that max_{i∈N_y} g^y_{T_y}(x_i) ≥ log k. Then T_y = O(poly(k)) and max_{r∈[m], s≠y, ℓ∈[2]} ⟨w^(T_y)_{y,r}, v_{s,ℓ}⟩ = O(log(k)/√d). Replacing ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ in Equation B.6 with ρ/(δ_2 − δ_1) shows that in O(poly(k)) additional iterations we have ⟨w^(t)

of Corollary B.18. The O(1/k)-factor separation between the updates to diagonal and off-diagonal correlations shown in Equations B.21 and B.22 continues to hold once we pass into the linear regime of ReLU. Furthermore, the logic used to prove the lower bound for positive correlations in Claim B.17 easily extends to showing that the correlations remain positive throughout training. ■ As noted above, the bound on the off-diagonal correlations obtained in Claim B.17 and Corollary B.18 is much weaker than it was in Claim B.4, which is why we weakened the assumptions in Claim B.15. We now prove the Midpoint Mixup analogues of Claims B.5, B.6, and B.7. Claim B.19. Consider y ∈ [k] and t such that max_{i∈N_y, j∈[N]} g^y_t(z_{i,j}) = Θ(log k). Then max_{i∈N_y, j∈[N]} g^y_t(z_{i,j}) = Θ(min_{i∈N_y, j∈[N]} g^y_t(z_{i,j})).

of Claim B.20. As in the proof of Claim B.6, applying Claim B.17 to any class y yields the existence of a correlation ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ > ρ/(δ_2 − δ_1). And again, reusing the logic of Claim B.17 but replacing ⟨w^(t)_{y,r*}, v_{y,ℓ*}⟩ in Equation B.21 with ρ/(δ_2 − δ_1) yields that in an additional O(poly(k)) iterations we have ⟨w^(t)

Where above we used Equation B.38 to get the upper bound in the last step. We next focus on bounding the inner integral in Equation B.40. Using the symmetry of β̃_y, we have Equation B.41 for the inner expectation E_{β̃_y}[A_1(β_s, β_y)(β_y − C_P δ_2/2)]. One can sanity check that Equation B.41 is bounded below by −1, as we would expect. The only tricky aspect of Equation B.41 is the O(1) term in the exponential, which can lead to a positive contribution (via a negative integrand) when β̃_y is close to 0. However, we can safely restrict the bounds of integration in Equation B.41 to a region [ϱ, C_P(δ_2/2 − δ_1)] for ϱ = Θ(1/log k) (with an appropriately chosen constant), as in such a region the integrand is guaranteed to be positive since ∆^(t)_y = Ω(log k). Furthermore, this restriction does not cost us anything (like an additional positive term), as the concern with Equation B.41 is purely a consequence of how we bounded g^y_t and g^s_t. Indeed, by our earlier monotonicity argument it is clear that we can cut out the region corresponding to [−ϱ, ϱ] from the first line of Equation B.40 without decreasing the RHS.

Now we can show using Corollary B.23 that there is a significant gradient component towards correcting the separation between the feature correlations even when the second feature correlation is in the polynomial part of ReLU (which is where it got stuck for a significant number of classes in the ERM proof). ⟨w^(t)_{y,r_2}, v_{y,2}⟩ ≥ τ > 0, so long as there exists an r_1 ∈ B^(t)_{y,1}


Claim B.25. For any y ∈ [k], ℓ ∈ [2], and r ∈ [m], we have that:

⟨−∇_{w^(t+1)_{y,r}} J_MM(g, X), v_{y,ℓ}⟩ ≥ 0.99 ⟨−∇_{w^(t)_{y,r}} J_MM(g, X), v_{y,ℓ}⟩

so long as Σ_{s≠y} exp(g^s_{t+1}(z_{i,j})) ≥ Σ_{s≠y} exp(g^s_t(z_{i,j})) for all mixed points z_{i,j} with i ∈ N_y.

Proof of Claim B.25. We proceed by brute force; namely, as long as η is sufficiently small, we can prove that the gradient for J_MM does not decrease by too much between successive iterations. As the notation is going to become cumbersome quite quickly, we will use G_t and G_{t+1} as place-holders for the gradient correlation ⟨−∇_{w^(t)_{y,r}} J_MM(g, X), v_{y,ℓ}⟩ at times t and t + 1 respectively. We will now prove the result assuming r ∈ B^(t)_{y,ℓ}, as the case where r ∉ B^(t)_{y,ℓ} is strictly better (we will have the upper bound shown below with additional o(1) factors). We have that (compare to Equation B.24): Let us now focus on the ϕ_y(g_{t+1}(z_{i,j})) − ϕ_y(g_t(z_{i,j})) terms present in Equation B.50 above. We will just consider the first case above (mixing between class y and a non-y class), as the other analyses follow similarly. Furthermore, we will omit the z_{i,j} in what follows (in the interest of brevity) and simply write g^y_{t+1}. Additionally, similar to Claim B.22, we will use the notation Λ_t = Σ_{s≠y} exp(g^s_t). Now by the assumption in the statement of the claim, we have that Λ_{t+1} ≥ Λ_t, and since m = Θ(k) (the number of weights per class), we have that g^y_{t+1} ≤ g^y_t + Θ(kηG_t) (all of the updates for weights in the linear regime are identical and strictly larger than the updates for those in the poly regime). Thus, where in the last line we again used the inequality exp(x) and t, we have that:

Proof of Corollary B.26. For every class y, we have Σ_{s≠y} exp(g^s_{t+1}(z_{i,j})) ≥ Σ_{s≠y} exp(g^s_t(z_{i,j})) for all mixed points z_{i,j} with i ∈ N_y for t = 0 (see the proof of Claim B.17). If g has the property that for every class y we have ⟨w_y, v_{y,ℓ_1}⟩ = ⟨w_s, v_{s,ℓ_2}⟩ > 0 and ⟨w_y, v_{s,ℓ_2}⟩ ≤ 0 for every s ≠ y and ℓ_1, ℓ_2 ∈ [2].
Furthermore, with probability 1 − exp(−Θ(N)) (over the randomness of X), the condition ⟨w_y, v_{y,ℓ_1}⟩ = ⟨w_s, v_{s,ℓ_2}⟩ is necessary for g to satisfy Equation 3.3.

Proof. We first prove sufficiency. If g satisfies the conditions in the Lemma, then we have for any data point. We also have that g^s(x_i) = 0 for any s ≠ y_i (by the cross-class orthogonality condition). Letting C = ⟨w_y, v_{y,1} + v_{y,2}⟩ (note that this correlation is the same independent of y due to the conditions of the lemma), we then get Equation C.1 for any mixed point z_{i,j} with y_i ≠ y_j. Equation C.1 tends to 1/2 as γ → ∞, and one can easily check that this is the globally optimal prediction for the classes y_i and y_j on the Mixup point z_{i,j} (for any such mixed point). Similarly, if z_{i,j} is a mixed point with y_i = y_j, then Equation C.1 reduces to the ERM case, and we obtain the optimal prediction of 1 for the correct class in the limit. On the other hand, if there exists a pair of classes (y, s) with s ≠ y and ℓ_1, ℓ_2 ∈ [2] such that ⟨w_y, v_{y,ℓ_1}⟩ ≠ ⟨w_s, v_{s,ℓ_2}⟩, then with probability 1 − exp(−Θ(N)) there exists a mixed point z_{i,j} in X (where y_i = y, y_j = s, and y ≠ s) such that g^y(z_{i,j}) ≠ g^s(z_{i,j}), and hence lim_{γ→∞} ϕ_y(γ g(z_{i,j})) ≠ 1/2, so we cannot achieve the infimum of the Midpoint Mixup loss. ■

Proposition 3.3. For any distribution D_λ that is not a point mass on 0, 1, or 1/2, and any linear model g satisfying the conditions of Lemma 3.2, we have that with probability 1 − exp(−Θ(N)) (over the randomness of X) there exists an ϵ_0 > 0 depending only on D_λ such that:

Proof. Firstly, we observe (just from properties of cross-entropy): Now suppose a model g satisfies the conditions of Lemma 3.2. Then we have that g^{y_i}(x_i) = g^{y_j}(x_j) = C > 0 for some constant C and every pair (x_i, y_i) and (x_j, y_j). As before, with probability at least 1 − exp(−Θ(N)), we have that there exists a pair of points (x_i, y_i) and (x_j, y_j) in X with y_i ≠ y_j.
The Mixup loss restricted to this pair (for which we use the notation J_M(g, z_{i,j}, D_λ)) is then: Furthermore, we have that: From Equations C.3 and C.4 we can see that, since D_λ is supported on more than just 0, 1, and 1/2, J_M(g, z_{i,j}, D_λ) → ∞ as C → ∞ (Equation C.4 implies that in the limit ϕ_{y_i}(γ g(z_{i,j}(λ))) can only take the values 0, 1, or 1/2). It is also easy to see that the same behavior occurs if one considers C → 0. Thus it suffices to constrain our attention to C in a compact set. Since this is a continuous function of C over a compact set, it must obtain a minimum greater than 0, and we may choose ϵ_0 to be this minimum (rescaled by a factor of Ω(1/N²)), thereby finishing the proof. ■

C.2 PROOFS OF PROPOSITIONS 3.4 AND 3.5

Proposition 3.4 (Mixup Gradient Lower Bound). Let y be any class such that ∆_y ≥ log k, and suppose that both ⟨w_y, v_{y,1}⟩ ≥ 0 and the cross-class orthogonality condition ⟨w_s, v_{u,ℓ}⟩ = 0 hold for all s ≠ u and ℓ ∈ [2]. Then we have with high probability that:

Proof. The idea of the proof is to analyze the gradient correlation with v_{y,1} − v_{y,2}, and either show that this is significantly negative or, in the case where it is not, show that the gradient correlation with v_{y,2} is still significant. Firstly, using the cross-class orthogonality assumption and Calculation A.10, we can compute: Where above we used N_y to indicate the indices corresponding to class y data points (as we do in the proofs of the main results). Now using concentration of measure for bounded random variables and the fact that N is sufficiently large, we have from Equation C.5 that with high probability (and with poly(k) representing a very large polynomial in k): Where we define the functions A_1 and A_2 as: With z_{y,s} being a random variable denoting the sum of a class y point and a class s point (distributed according to Definition 4.1).
Note that Equations C.7 and C.8 are not abuses of notation: the functions A_1 and A_2 depend only on the random variables β_s and β_y, since we can ignore the cross-class correlations due to orthogonality. Let us immediately observe that the first two terms (the expectation terms) in Equation C.6 are bounded above by 0. This is due to the fact that β_y − 1/2 is a symmetric, centered random variable and the functions A_1 and A_2 are monotonically decreasing in β_y. We will focus on showing that the first term is significantly negative, as that will be sufficient for our purposes.

C.3 PROOF OF PROPOSITION 4.4

Proposition 4.4. There exists a D satisfying all of the conditions of Definition 4.1 and Assumption 4.3 such that with probability at least 1 − k² exp(−Θ(N/k²)), for any classifier h : R^{Pd} → R^k of the form h_y(x) = Σ_{p∈[P]} ⟨w_y, x^{(p)}⟩ and any X consisting of N i.i.d. draws from D, there exists a point (x, y) ∈ X and a class s ≠ y such that h_s(x) ≥ h_y(x).

Proof. For hyperparameters, we can choose

On the other hand, if there does not exist such a class pair y, u, then we are also done, as that implies all of the weight-feature correlations are the same.
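The limiting softmax behaviour invoked in the proof of Lemma 3.2 (equal positive logits on the two mixed classes, zero elsewhere, give predictions tending to 1/2 on both classes as the scale γ grows) can be checked numerically:

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

k, C = 10, 1.0
g = np.zeros(k)
g[0] = g[1] = C          # the two mixed classes share the logit C > 0
p = softmax(100.0 * g)   # a large gamma approximates the gamma -> inf limit
assert abs(p[0] - 0.5) < 1e-6 and abs(p[1] - 0.5) < 1e-6
assert p[2:].max() < 1e-6
```

With unequal logits on the two mixed classes, the same scaling instead drives the prediction to 0 or 1, which is why equality of the correlations is necessary to reach the infimum of the Midpoint Mixup loss.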

