OMNIGROK: GROKKING BEYOND ALGORITHMIC DATA

Abstract

Grokking, the unusual phenomenon for algorithmic datasets where generalization happens long after overfitting the training data, has remained elusive. We aim to understand grokking by analyzing the loss landscapes of neural networks, identifying the dependence of the generalization gap on model weight norm as a cause of grokking. We refer to this as the "LU mechanism" because training and test losses (against model weight norm) typically resemble "L" and "U", respectively. This mechanism can explain many aspects of grokking: data size dependence, weight decay dependence, the emergence of representations, etc. Guided by the intuitive picture, we are able to induce grokking on tasks involving images, language and molecules, although the grokking signals are sometimes less dramatic. We attribute the dramatic nature of grokking for algorithmic datasets to representation learning.

1. INTRODUCTION

Generalization lies at the heart of machine learning. A good machine learning model should arguably be able to generalize fast, and behave in a smooth/predictable way under changes of (hyper)parameters. Grokking, the phenomenon where the model generalizes long after overfitting the training set, has raised interesting questions after it was observed on algorithmic datasets by Power et al. (2022):

Q1 The origin of grokking: Why is generalization much delayed after overfitting?
Q2 The prevalence of grokking: Can grokking occur on datasets other than algorithmic datasets?

This paper aims to answer these questions by analyzing neural loss landscapes:

A1 Grokking can result from a mismatch between training and test loss against model weight norm. Specifically, (reduced) training and test losses plotted against model weight norm resemble "L" and "U", respectively, as shown in Figure 1b. We refer to this phenomenon as the "LU mechanism", which we elaborate on in Sections 2 and 3.
A2 Yes. Indeed, we demonstrate grokking for a wide range of machine learning tasks in Section 4, including image classification, sentiment analysis and molecule property prediction. Grokking signals observed for these tasks are usually less dramatic than for algorithmic datasets, which we attribute to representation learning in Section 5.

Partial answers to Q1 are provided in recent studies: Liu et al. (2022) attribute grokking to the slow formation of good representations, Thilak et al. (2022) attempt to link grokking to the slingshot mechanism of adaptive optimizers, and Barak et al. (2022) use a Fourier gap to describe hidden progress. This paper aims to understand grokking through the lens of neural loss landscapes. Our landscape analysis is able to explain many aspects of grokking: data size dependence, weight decay dependence, emergence of representations, etc.

The paper is organized as follows: In Section 2, we review background on generalization, and introduce the LU mechanism. In Section 3, we show how the LU mechanism leads to grokking for a toy teacher-student setup. In Section 4, we show that the intuition gained from the toy problem can transfer to realistic datasets (MNIST, IMDb reviews and QM9), for which we also observe grokking, although in a slightly non-standard setup where it is relatively weak. In Section 5, we discuss why grokking is more dramatic for algorithmic datasets than for others (e.g., MNIST), by comparing their loss landscapes. We review related work in Section 6 and summarize our conclusions in Section 7.




2. THE LU MECHANISM FOR GROKKING

Weight norm and reduced loss

Letting $\mathbf{w}$ denote the weights of a model, any function $f(\mathbf{w})$ (e.g., train/test loss/accuracy) depends on both the weight norm $w \equiv \|\mathbf{w}\|_2$ and the angular direction $\hat{\mathbf{w}} \equiv \mathbf{w}/w$. Similar to Fort and Scherlis (2019), we define a reduced function $\tilde{f}(w)$ by minimizing the training loss $l_{train}(\mathbf{w})$ over angular directions, i.e.,

$\tilde{f}(w) \equiv f(\mathbf{w}^*(w))$, where $\mathbf{w}^*(w) \equiv \mathrm{argmin}_{\|\mathbf{w}\|_2 = w} \, l_{train}(\mathbf{w})$. (1)

In this paper, we set $f$ to be the train/test loss/error, but the definition also applies to other metrics of interest. In practice, we perform the constrained minimization by rescaling the model weights back to their original norm after each unconstrained optimization step. We will see that this reduced 1D loss landscape, which is easy to visualize, captures important features related to grokking. Throughout the paper, our model is initialized by multiplying the standard initialization [footnote 0] by a factor $\alpha \equiv w/w_0$, where $w_0$ and $w$ are the weight norms of the network before and after multiplying by $\alpha$, respectively.

LU mechanism

Although the loss landscapes of neural networks are nonlinear, Fort and Scherlis (2019) reveal a simple landscape picture: There is a spherical shell in the weight space (the "Goldilocks" zone), where generalization is better than outside this zone. We illustrate the Goldilocks zone as the green area with average radius $w_c$ in Figure 1a; the green stars are the generalizing solutions. The test loss is thus higher both when $w > w_c$ and when $w < w_c$, forming a U-shape against $w$ in Figure 1b (gray curve). By contrast, the training loss has an L-shape against weight norm. There are many solutions which overfit the training data for $w > w_c$, but high training losses are incurred for $w < w_c$. This corresponds to the L-shaped curve seen in Figure 1b (orange curve, no regularization). In summary, the (reduced) training loss and test loss are L-shaped and U-shaped against weight norm, respectively, which we will refer to as the LU mechanism throughout the paper.
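The constrained minimization behind Eq. (1) (one unconstrained step, then a rescale back to the target norm) can be sketched as follows. This is a minimal illustration rather than the authors' implementation: the toy linear model, the plain gradient-descent update, and the names `rescale_to_norm`, `train_loss`, `train_grad`, `reduced_train_loss` and `toy_data` are all assumptions introduced here.

```python
import math


def rescale_to_norm(w, target_norm):
    """Return w rescaled so that its L2 norm equals target_norm."""
    norm = math.sqrt(sum(x * x for x in w))
    return [x * target_norm / norm for x in w]


def train_loss(w, data):
    """Mean squared error of a toy linear model y = w . x (hypothetical stand-in)."""
    total = 0.0
    for x, y in data:
        pred = sum(wi * xi for wi, xi in zip(w, x))
        total += (pred - y) ** 2
    return total / len(data)


def train_grad(w, data):
    """Gradient of train_loss with respect to w."""
    grad = [0.0] * len(w)
    for x, y in data:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for i, xi in enumerate(x):
            grad[i] += 2.0 * err * xi / len(data)
    return grad


def reduced_train_loss(w_init, data, target_norm, lr=0.1, steps=500):
    """Approximate the reduced training loss at a fixed weight norm.

    Each iteration takes one unconstrained gradient step and then rescales
    the weights back to target_norm.
    """
    w = rescale_to_norm(w_init, target_norm)
    for _ in range(steps):
        grad = train_grad(w, data)
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
        w = rescale_to_norm(w, target_norm)
    return train_loss(w, data)


# One training sample in two dimensions: any w with w[0] == 1 fits it exactly.
toy_data = [((1.0, 0.0), 1.0)]
loss_small_norm = reduced_train_loss([0.6, 0.8], toy_data, target_norm=0.5)
loss_large_norm = reduced_train_loss([0.6, 0.8], toy_data, target_norm=2.0)
```

In this toy problem the reduced training loss is $(1 - w)^2$ for $w < 1$ and zero for $w \geq 1$, so it already has the "L" shape discussed above; the "U" shape of the test loss requires held-out data and is not modeled here.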
It is well known in statistics that generalization error has a "U" shape against model capacity, which is usually attributed to the bias-variance trade-off. Although this common wisdom was challenged by the observation of double descent (Nakkiran et al., 2021) , the "U" curve can be recovered from a double descent simply by changing the x-axis from the number of model parameters N to the 2-norm of model parameters w ≡ ||w|| 2 , at least for linear regression (Ng and Ma, 2022) . Although the LU mechanism may remind readers of related phenomena (Schoenholz et al., 2016; Yang and Schoenholz, 2017; Nakkiran et al., 2021) , their setups are not exactly the same as ours. More importantly, our focus and contribution is to understand grokking, a brand new generalization puzzle.

Grokking dynamics

We identify the "LU mechanism" as the cause of grokking. If the weight norm is initialized to be large (e.g., the black square in the $w > w_c$ region), the model first quickly moves to a nearby overfitting solution by minimizing the training loss. Without any regularization, the model will stay where it is, because the gradient of the training loss is almost zero along the valley of overfitting solutions, so generalization does not happen. Fortunately, there are usually explicit and/or implicit regularizations that can drive the weight vector towards the Goldilocks zone $w \approx w_c$. When the regularization magnitude is non-zero but small, the radial motion can be (arbitrarily) slow. If weight decay is the only source of regularization, and the training loss is negligible after overfitting, then weight decay $\gamma$ causes $w(t) \approx \exp(-\gamma t)\, w_0$ when $w_0 > w_c$, so it takes time $t \approx \ln(w_0/w_c)/\gamma \propto \gamma^{-1}$ to generalize. A small $\gamma$ results in a huge generalization delay (i.e., grokking). The dependence on regularization magnitudes is illustrated in Figure 1b: no generalization at all happens for $\gamma = 0$, small $\gamma$ leads to slow generalization (grokking), and large $\gamma$ leads to faster generalization [footnote 1]. The above analysis only applies to large initializations $w > w_c$. Small initializations $w < w_c$ can always generalize fast [footnote 2], regardless of regularization.

Why isn't grokking commonly observed?

The standard initialization schemes typically initialize $w$ no larger than $w_c$. However, if we increase initialization scales (explicitly or implicitly), grokking can appear. In Sections 3 and 4, we find that explicitly increasing the initialization weight norm can induce grokking. In Section 5, we argue that this happens implicitly for algorithmic datasets because (as shown in Figure 6d) $w_c$(bad representation) $> w_c$(good representation), i.e., an initialization that is proper for a bad representation is effectively too large for a good representation, leading to grokking. Take addition (base $p$) for example: with the good (linear) representation or a bad (random) representation, the decoder needs to learn to classify $O(p)$ or $O(p^2)$ examples, respectively.
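The delay estimate $t \approx \ln(w_0/w_c)/\gamma$ can be checked numerically with a small sketch. It assumes, as in the text, that weight decay is the only force acting after overfitting, and that time is measured in units where the decay rate is exactly $\gamma$. The function names `predicted_delay` and `simulated_delay` and the Euler step size are illustrative choices, not part of the paper.

```python
import math


def predicted_delay(w0, wc, gamma):
    """Closed-form delay t = ln(w0 / wc) / gamma for w0 > wc, else zero."""
    if w0 <= wc:
        return 0.0
    return math.log(w0 / wc) / gamma


def simulated_delay(w0, wc, gamma, dt=1e-3):
    """Euler-integrate dw/dt = -gamma * w until the norm reaches wc."""
    w, t = w0, 0.0
    while w > wc:
        w -= gamma * w * dt
        t += dt
    return t
```

For example, `predicted_delay(2.0, 1.0, 0.03)` and `simulated_delay(2.0, 1.0, 0.03)` should agree closely, and dividing $\gamma$ by ten multiplies the predicted delay by ten, which is the $t \propto \gamma^{-1}$ scaling.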

3. GROKKING FOR A TEACHER-STUDENT SETUP

To illustrate how the LU mechanism results in grokking, we employ a toy teacher-student setup. The teacher and the student share the same architecture (a 5-100-100-5 MLP with tanh activation), but are initialized with different seeds. The student network is initialized with the standard initialization (the default one in PyTorch), but each weight is rescaled by the same factor $\alpha \equiv w/w_0$, where $w_0$ and $w$ are the weight norms of the student network before and after rescaling. The teacher network uses the standard initialization, i.e., $\alpha_{teacher} = 1$. Inputs and outputs have dimensions $d_{in} = 5$ and $d_{out} = 5$, respectively. We generate $N_{train} = 100$ training and $N_{test} = 100$ test samples by first drawing inputs from the standard Gaussian distribution $\mathcal{N}(0, I_{d_{in} \times d_{in}})$, and then feeding the inputs to the teacher to generate output labels. The student network is trained with the Adam optimizer (learning rate $3 \times 10^{-4}$) for $10^5$ steps.

LU landscapes

Firstly, we compute the reduced losses by minimizing the training loss (excluding weight decay) while constraining the weight norm of the student network to be constant. We take the point of convergence after training to be the global minimum on the spherical surface [footnote 3], which explicitly defines the reduced losses $\tilde{l}_{train}(\alpha)$ and $\tilde{l}_{test}(\alpha)$. As shown in Figure 2a, $\tilde{l}_{test}(\alpha)$ first decreases and then increases as $\alpha$ increases, displaying a U-shape with a minimum at $\alpha \approx 1$. By contrast, $\tilde{l}_{train}(\alpha)$ decreases when $\alpha < 1$ and remains flat near zero when $\alpha \geq 1$, forming an L-shape. When weight decay $\gamma$ is present, the training landscape becomes $\tilde{l}_{train}(\alpha, \gamma) = \tilde{l}_{train}(\alpha) + \gamma \alpha^2 C^2$, where $C$ is the average parameter magnitude determined by the standard initialization.

Training dynamics

Our problem is a regression task, but we can imitate the behavior of a classification task by manually setting a threshold $\beta = 0.01$ and defining a sample to be correctly "classified" if the prediction error is less than $\beta$.
We study the dynamics of training and test accuracy. Note that this is the normal training setup where the weight norm is not constrained, although with two non-standard initializations, α = 0.5 (small) and α = 2.0 (large), and three weight decays, γ = 0 (no reg), γ = 0.03 (small reg) and γ = 1 (large reg). As shown in Figure 2b (bottom), small-initialization runs always generalize fast regardless of regularization. Large-initialization runs (top) depend on weight decay: no regularization fails to generalize, small regularization generalizes slowly (grokking), while large regularization generalizes faster. For the large initialization α = 2.0, we do a finer sweep of γ in [0.03, 1]. We compute the number of steps and the weight norm w when training or test accuracy reaches 95%. As shown in Figure 2c, the time (number of steps) to reach 95% training accuracy is independent of weight decay γ, while the time to reach 95% test accuracy is inversely proportional to the weight decay, as we derived above for the LU mechanism.
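Two ingredients of this setup, the α-scaled initialization and the thresholded "accuracy", can be sketched in plain Python. This is a hedged illustration, not the authors' code. The uniform range σ = 1/√fan_in follows the paper's own description of the default initialization. Scaling the biases along with the weights, using the per-sample mean squared error as the "prediction error", and the names `init_linear`, `l2_norm` and `thresholded_accuracy` are assumptions made here.

```python
import math
import random


def init_linear(fan_in, fan_out, alpha=1.0, seed=0):
    """Draw weights and biases from U[-sigma, sigma], then scale them by alpha."""
    rng = random.Random(seed)
    sigma = 1.0 / math.sqrt(fan_in)
    weight = [[alpha * rng.uniform(-sigma, sigma) for _ in range(fan_in)]
              for _ in range(fan_out)]
    bias = [alpha * rng.uniform(-sigma, sigma) for _ in range(fan_out)]
    return weight, bias


def l2_norm(weight, bias):
    """L2 norm over all parameters of one layer."""
    total = sum(v * v for row in weight for v in row)
    total += sum(b * b for b in bias)
    return math.sqrt(total)


def thresholded_accuracy(predictions, targets, beta=0.01):
    """Fraction of samples whose per-sample MSE is below beta (assumed error metric)."""
    correct = 0
    for pred, target in zip(predictions, targets):
        mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
        if mse < beta:
            correct += 1
    return correct / len(targets)
```

With the same seed, `init_linear(5, 100, alpha=2.0)` has exactly twice the norm of `init_linear(5, 100, alpha=1.0)`, which matches the definition α ≡ w/w_0.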

4. OMNIGROK: GROKKING FOR MORE INTERESTING TASKS

We now analyze loss landscapes and search for grokking for several more interesting datasets, and see that the insights obtained from our toy model can transfer to these datasets. We report the main results here, with experiment details included in Appendix A.

Image classification

We visualize loss landscapes of MNIST (Deng, 2012) to verify the LU mechanism, and study the dependence on training data size. Similar to the teacher-student case, we reduce losses and errors (one minus accuracy) to two variables (weight norm $w$ and data size $N$) by minimizing over angular directions of weights, i.e.,

$\tilde{l}_{train}(w, N) \equiv l_{train}(\mathbf{w}^*, N)$, $\tilde{l}_{test}(w, N) \equiv l_{test}(\mathbf{w}^*, N)$, $\mathbf{w}^*(w, N) \equiv \mathrm{argmin}_{\|\mathbf{w}\|_2 = w} \, l_{train}(\mathbf{w}, N)$, (3)

shown in Figure 3(a)(b). The reduced loss landscape reveals three things:

(1) Larger initializations lead to grokking. Point A in Figure 3 corresponds to the standard initialization (α = 1), which has low training and test errors, hence no grokking. When increasing the weight norm from A to B, training error is seen to remain low while test error rises. To generalize, weight decay must be in place to bring the weight norm down, leading to grokking if weight decay is small.
(2) Larger datasets lead to de-grokking. Comparing B and C in Figure 3, C is seen to have larger training size than B and lower test error. Larger data size N makes the Goldilocks zone broader, reducing or eliminating grokking even for large weight initializations.
(3) Critical data size can be defined. As reported in Power et al. (2022) and Liu et al. (2022), we see that there exists a critical training set size below which generalization is impossible. The effective theory analysis in Liu et al. (2022) only applies to algorithmic datasets, but not to other datasets with unknown optimal representations. The loss landscape analysis presented in this work can apply to all supervised-learning tasks. As shown in Figure 3(b), the contours of constant test error are thumb-like, and the tip of the thumb determines the minimum amount of data required for generalization.

Guided by the landscape analysis, we make two nonstandard decisions to induce grokking on MNIST: (1) we reduce the size of the training set from 60k to 1k samples (by taking a random subset) and (2) we increase the scale of the weight initialization distribution (by multiplying the initial weights, sampled with Kaiming uniform initialization, by a constant α > 1). With these modifications to the training set size and initialization scale, we train a depth-3 width-200 MLP with ReLU activations with the AdamW optimizer using MSE loss with one-hot targets. We find that the network quickly fits the training set, and test accuracy improves much later, as shown in Figure 3d, just as in the stereotypical grokking first observed in algorithmic datasets. Figure 3e shows the effect of training set size on time to generalization for MNIST. We find a result similar to what Power et al. (2022) observed, namely that generalization time increases rapidly once one approaches a certain critical data set size. The conclusions still hold for the cross entropy loss (see Appendix F), although with quantitatively milder effects.
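The two data-side choices described above (a random 1k subset of the 60k training images, and MSE loss against one-hot targets) can be sketched without any deep learning library. The seed, the mean reduction over the 10 output units, and the names `sample_subset_indices`, `one_hot` and `mse_one_hot` are assumptions for illustration; the paper does not specify them.

```python
import random


def sample_subset_indices(total=60000, subset=1000, seed=0):
    """Pick `subset` distinct training indices uniformly at random."""
    return random.Random(seed).sample(range(total), subset)


def one_hot(label, num_classes=10):
    """One-hot target vector for an integer class label."""
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mse_one_hot(outputs, label):
    """MSE between network outputs and the one-hot target (mean over classes)."""
    target = one_hot(label, len(outputs))
    return sum((o - t) ** 2 for o, t in zip(outputs, target)) / len(outputs)
```

A network that outputs all zeros incurs a loss of 0.1 per sample under this convention, while a perfect one-hot output incurs zero loss.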

Sentiment analysis of text

We look for grokking using LSTMs (Hochreiter and Schmidhuber, 1997) on the IMDb dataset (Maas et al., 2011). Similar to Eq. (3), we reduce training and test losses to depend on only the weight norm w and data size N. We show the reduced training and test error in Figure 4(a)(b). For large data size, e.g., the full dataset, training and test errors have similar "U" shapes [footnote 4], so one cannot create grokking via the "LU" mechanism. For small data size, say 1k, however, the mismatch between training and test errors makes it possible to create grokking via large initializations. In Figure 4 (c), where we initialize weights larger (α = 6) with weight decay 1, overfitting is complete within 10^2 steps, but generalization does not start until around 10^3 steps. Note that the generalization "jump" is not as sharp as on algorithmic datasets (Power et al., 2022) or MNIST, but generalization is nevertheless delayed here. By contrast, if we use the standard initialization (α = 1) with no weight decay, generalization happens early on during training, and does not improve much after overfitting.

Molecules

We search for grokking using a graph convolutional neural network (GCNN) on the QM9 dataset (Ramakrishnan et al., 2014). Similar to Eq. (3), we define the reduced training/test losses, which depend only on weight norm w and data size N. As shown in Figure 5(a)(b), when data size is large, training and test losses have similar "U" shapes, hence grokking is impossible via the "LU mechanism". When data size is small, training and test losses mismatch somewhere in the region α = w/w_0 > 1, making grokking possible. Indeed, as shown in Figure 5(d), there is a sharp drop in test loss around 10^4 steps if the initialization is 3 times larger than standard, while standard initialization does not lead to grokking. Note that zero weight decay is applied in both cases, implying the existence of implicit regularizations.

5. REPRESENTATION IS KEY TO GROKKING

In Section 4, we showed that increasing initialization scales can make grokking happen for standard ML tasks. However, this seems a bit artificial and does not explain why standard initialization leads to grokking on algorithmic datasets, but not on standard ML datasets, say MNIST. The key difference is how much the task relies on representation learning. For the MNIST dataset, the quality of representation determines whether the test accuracy is 95% or 100%; by contrast in algorithmic datasets, the quality of representation determines whether test accuracy is random guess (bad representation) or 100% (good representation). So overfitting (under a bad representation) has a more dramatic effect on algorithmic datasets, i.e., the model weights increase quickly during overfitting but test accuracy remains low. During overfitting, model weight norm is much larger than at initialization, but then drops below the initialization norm when the model generalizes, shown in Figure 9 (see Appendix C), and also observed by Nanda et al. (2023) . In the following, we will compare algorithmic datasets (Section 5.1) to MNIST (Section 5.2). We show how their loss landscapes depend on representations differently, and how the difference leads to different outcomes (grokking or not).

5.1. ALGORITHMIC DATASETS

Setup

Algorithmic datasets pose the task of learning a binary operation $a \circ b = c$ ($a$, $b$, $c$ are symbols) with neural networks, which aim to predict $c$ from the input $(a, b)$. We take the toy addition setup in Liu et al. (2022), where each input digit $0 \leq i \leq p-1$ (output label $0 \leq k \leq 2(p-1)$) is embedded as a vector $E_i$ ($Y_k$). A decoder MLP is employed to predict $Y_k = \mathrm{Dec}(E_i + E_j)$ ($k = i + j$). In the setup of grokking, both the decoder and the input representations $R \equiv \{E_i\}$ are trainable, with learning rates $\eta_D$ and $\eta_R$, respectively; in the setup of landscape analysis, only the decoder is trainable, as we explain below. Training and test losses depend on three factors: (i) representation $R$, (ii) weight norm $w$ and (iii) weight direction $\hat{\mathbf{w}}$. As in previous sections, we can optimize $\hat{\mathbf{w}}$ by minimizing the training loss on constant weight norm spheres. We further reduce the high-dimensional representations to 1D by interpolating in a particular direction:

$R = m R_{random} + (1 - m) R_{linear}$, (4)

where $R_{linear}$ refers to the linear representation in which number $k$ is embedded as $E_k = [k, 0, \cdots, 0]$, $R_{random}$ is the initialized representation drawn from Gaussian distributions, i.e., $E_k \sim \mathcal{N}(0, I)$, and $m \in [0, 1]$ is a scalar interpolating between $R_{linear}$ and $R_{random}$, which we term representation messiness because $R = R_{linear}$ when $m = 0$, and $R = R_{random}$ when $m = 1$.
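The interpolation between the random and linear representations can be written out directly. The sketch below assumes each embedding is a plain list of floats of a user-chosen dimension; the names `linear_representation`, `random_representation` and `interpolate`, as well as the seed, are hypothetical.

```python
import random


def linear_representation(p, dim):
    """R_linear: number k is embedded as [k, 0, ..., 0]."""
    return [[float(k)] + [0.0] * (dim - 1) for k in range(p)]


def random_representation(p, dim, seed=0):
    """R_random: each embedding is drawn from a standard Gaussian."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(p)]


def interpolate(r_random, r_linear, m):
    """R = m * R_random + (1 - m) * R_linear, with messiness m in [0, 1]."""
    if not 0.0 <= m <= 1.0:
        raise ValueError("messiness m must lie in [0, 1]")
    return [[m * a + (1.0 - m) * b for a, b in zip(row_r, row_l)]
            for row_r, row_l in zip(r_random, r_linear)]
```

Setting `m=0.0` returns the linear representation and `m=1.0` returns the random one, so sweeping m over a grid yields the horizontal axis of a (w, m) landscape.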
After these reductions, both training and test losses become functions of two variables, representation messiness $m$ and weight norm $w$:

$\mathbf{w}^*(w, m) \equiv \mathrm{argmin}_{\|\mathbf{w}\|_2 = w} \, l_{train}(\mathbf{w}, m)$, $\tilde{l}_{train}(w, m) \equiv l_{train}(\mathbf{w}^*, m)$, $\tilde{l}_{test}(w, m) \equiv l_{test}(\mathbf{w}^*, m)$. (5)

Note that our definition of $\tilde{l}_{train}(w, m)$ excludes the weight decay term $\ell_{reg} = \frac{1}{2}\gamma w^2$, but we should be aware of its presence when we analyze the dynamics of $(w, m)$, which is governed by the gradient flow on $\tilde{l}_{train}(w, m)$ plus weight decay ($\eta_R$/$\eta_D$ are the learning rates of the representation/decoder):

$\frac{dw}{dt} = -\eta_D \left( \frac{\partial \tilde{l}_{train}}{\partial w} + \gamma w \right)$, $\frac{dm}{dt} = -\eta_R \frac{\partial \tilde{l}_{train}}{\partial m}$. (6)

However, the driving force is small because the dependence is weak, leading to grokking. We elaborate below how these particular loss landscapes lead to grokking.

Grokking dynamics

In region II, the dynamics is slow (for small $\gamma$) due to nearly vanishing gradients. By contrast, the dynamics in region I is relatively fast. As we will explain, the dynamics is also slow on the boundary of I and II, and grokking is the consequence of traversing region II and/or the boundary. Let us analyze a typical path A to E shown in Figure 6(a)(b). A rolls "downhill" to B following training gradients, possibly continuing to C due to momentum. C is located in II where $\tilde{l}_{train} \approx 0$, so according to Eq. (6), $dm/dt \approx 0$ and $dw/dt \approx -\eta_D \gamma w$ or, equivalently, $d(\log w)/dt \approx -\eta_D \gamma$. So $(\log w, m)$ moves with a constant speed $v = \eta_D \gamma$ in the $-w$ direction from C to D, a point near the boundary. Negative gradients around the boundary point towards larger $w$ and smaller $m$, as shown in Figure 6d (a zoom-in of Figure 6a). The gradients become increasingly large as the model goes deeper inside region I, and at some point, the gradient totally cancels out $v$ in the gradient direction, making the model start to drift along the boundary, as illustrated in Figure 6f. Then the model moves along the boundary with a new velocity $v' = v\cos\theta$ [footnote 5], until it reaches the generalizing solution E.
The above picture is supported by empirical experiments in Appendix C and also by Nanda et al. (2023). Based on the picture, we also show the ability to eliminate grokking in Appendix C. The slow dynamics from C to E is the origin of grokking. During this period, the model first moves in the $-w$ direction with a velocity $v$ over the distance $L_1 = L - h\cot\theta$, and then moves along the boundary with a velocity $v'$ over the distance $L_2 = h/\sin\theta$. So the total time is $t = L_1/v + L_2/v' = (L + h\tan\theta)/(\eta_D \gamma)$. This formula agrees with the observation that larger weight decays $\gamma$ and/or larger decoder learning rates $\eta_D$ can make generalization happen faster (Power et al., 2022; Liu et al., 2022). Besides, the path manifests an intriguing multiple descent of the test loss, shown in Figure 6c.

Dependence of grokking on training data size

Another important observation in Power et al. (2022) is that grokking happens faster for larger training size. Our landscape analysis can also explain the data size dependence. In Figure 6e, we show the contours (training loss = 0.02) for different training sizes (25, 35, 45, 55). The contours of training size 45 and 55 both connect to the green star, meaning that generalization will eventually happen. However, the slopes of the contours are different, i.e., $\theta_{55} < \theta_{45}$. Since $t = (L + h\tan\theta)/(\eta_D \gamma)$ increases as $\theta$ increases, we have $t_{55} < t_{45}$, i.e., more training data leads to faster grokking. For training sizes 35 and 25, the contours do not connect to the green star, so generalization will not happen, no matter how long training runs.
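For readers who want the intermediate algebra behind the total grokking time, the following worked derivation uses only the quantities defined in the text ($L_1 = L - h\cot\theta$, $L_2 = h/\sin\theta$, $v = \eta_D\gamma$, $v' = v\cos\theta$). The final line, on monotonicity in $\theta$, is an added remark that assumes $0 \leq \theta < \pi/2$.

```latex
\begin{align*}
t &= \frac{L_1}{v} + \frac{L_2}{v'}
   = \frac{L - h\cot\theta}{v} + \frac{h}{v\,\sin\theta\cos\theta} \\
  &= \frac{1}{v}\left[ L + h\,\frac{1 - \cos^2\theta}{\sin\theta\cos\theta} \right]
   = \frac{1}{v}\left[ L + h\,\frac{\sin\theta}{\cos\theta} \right]
   = \frac{L + h\tan\theta}{\eta_D \gamma}, \\
\frac{\partial t}{\partial \theta} &= \frac{h\,\sec^2\theta}{\eta_D \gamma} > 0
   \quad \text{for } h > 0 .
\end{align*}
```

The positive derivative is what allows the comparison $\theta_{55} < \theta_{45} \Rightarrow t_{55} < t_{45}$ at fixed $L$, $h$, $\eta_D$ and $\gamma$.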

5.2. MNIST

We now study how training and test losses depend on representation messiness in the MNIST dataset. We denote the 28 × 28 images as the raw representation $R_{raw}$. We construct a linearly separable representation $R_{linear}$ by assigning input representations proportional to their label $y_i$; for example, an image of a 2 is represented by a matrix with all elements being 2. Similar to Eq. (4), we use $m \in [0, 1]$ to interpolate between $R_{raw}$ and $R_{linear}$:

$R = m R_{raw} + (1 - m) R_{linear}$.

Similarly to Eq. (5), we define and plot $\tilde{l}_{train}(w, m)$ and $\tilde{l}_{test}(w, m)$ in Figure 7, using the full dataset $N = 60000$. Comparing Figures 7a and 7b reveals two things: (1) the training and test losses behave similarly; (2) both training and test losses depend very weakly on $m$. This implies that the raw image representation is already quite close to being optimal, so decent test accuracy can be obtained even without learning optimal representations. As a result, grokking does not occur (Figure 7c). Comparing Figures 6 and 7, we see that the (strong) dependence of test performance on the representation is the key to grokking: the dependence on representation is strong for algorithmic datasets, so grokking happens. By contrast, the dependence is weak for MNIST, so grokking does not happen.

6. RELATION TO RELATED WORKS

Grokking was first observed for algorithmic datasets by Power et al. (2022). Several attempts have been made to understand grokking: (a) Liu et al. (2022) attribute grokking to the slow formation of good representations. (b) Shah (2021) suggests that generalizable solutions achieve lower loss than overfitting solutions, providing a training signal encouraging generalization. (c) Nanda et al. (2023) suggest grokking is a phase change due to limited data and regularization. (d) Barak et al. (2022) suggest that generalization is due not to random search, but to hidden progress of SGD that gradually amplifies a Fourier gap. (e) Thilak et al. (2022) link grokking to the "Slingshot mechanism" specific to adaptive optimizers. (f) Millidge (2022) describes training as a random walk over parameters. Our conclusion supports (a)(b)(c)(d), but does not necessarily negate (e)(f).

Double descent

Double descent is the phenomenon that performance first gets worse and then gets better as we increase the model size, data size, training epochs or regularization (Nakkiran et al., 2021; Yilmaz and Heckel, 2022; Nakkiran, 2019). The typical "U" shape of test loss in this paper does not conflict with double descent, because we are plotting against the weight norm instead of the number of model parameters (Ng and Ma, 2022). However, the "U" shape is better regarded as empirically common rather than provably universal. In fact, the interaction between properties of data and inductive biases of learning algorithms can be more complicated than double descent (Chen et al., 2021; d'Ascoli et al., 2020).

Initialization

From the optimization perspective, initializations are usually based on the "edge of chaos" idea that the variance of features and gradients should be preserved in the forward and backward pass (Glorot and Bengio, 2010; He et al., 2015; Bahri et al., 2020; Yang and Schoenholz, 2017; Jing et al., 2017), or based on analyzing Jacobians and/or Hessians (Skorski et al., 2020). From the generalization perspective, it was shown that large initializations overfit data easily but result in poor generalization (Xu et al., 2019; Zhang et al., 2020), which agrees with our LU mechanism.

Weight decay

Weight decay regularization is a standard trick in machine learning and has various effects on optimization and generalization (Zhang et al., 2018; Van Laarhoven, 2017). In particular, Lewkowycz and Gur-Ari (2020) observe that it takes t ∝ 1/λ training steps to reach maximum test performance, where λ is the weight decay. This is strikingly similar to the grokking time t ∝ 1/γ we derived from the LU mechanism.

7. CONCLUSIONS

This study elucidates the grokking phenomenon from the perspective of loss landscapes. Our conclusions are: (i) Grokking originates from the mismatch between training and test losses at high model weight norm ("LU" mechanism). (ii) Grokking can happen in various models for a wide range of datasets, although the grokking signature is usually most dramatic for algorithmic datasets. (iii) The severity of grokking depends on how much the task relies on learning representations. This work not only reveals the mechanism of grokking, but also shows that reduced landscape analysis is a useful tool for characterizing data-model interaction and representation learning.

over training brings train accuracy and test accuracy learning curves together, almost eliminating grokking. We would like to investigate in future works whether this training trick can be helpful for more standard machine learning tasks.

D TIME TO GENERALIZE VERSUS WEIGHT DECAY

In our discussion of the "LU mechanism" as an explanation for grokking in Section 2, we predicted that the training time required for a model to generalize should be $t \propto \gamma^{-1}$, where $\gamma$ is the weight decay. To test this, we perform a grid search over weight decays $\gamma$ and plot the number of training steps required for models to reach a specified level of test accuracy in Figure 10a AdamW learning rate of 0.001. From Figure 10a, we find that $t \propto \gamma^{-1}$ holds across roughly two orders of magnitude of $t$ and $\gamma$. The generalization time depends somewhat on the seed (some seeds consistently require longer to generalize), but for each seed (corresponding to a particular model initialization) the relation $t \propto \gamma^{-1}$ appears to fit the data well. (b) ReLU MLP on MNIST: We train ReLU MLPs on MNIST as described in Appendix A. We use α = 9.0 and train on a reduced training set of 1000 samples to delay generalization / induce grokking. From Figure 10b, we find that for $\gamma$ roughly between 0.1 and 1.0 the relation $t \propto \gamma^{-1}$ holds. Very high values of weight decay appear to interfere with optimization. On the other hand, with very low weight decay the model generalizes faster than naively expected, perhaps due to implicit regularization.
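One simple way to check a relation of the form $t \propto \gamma^{-1}$ on measured (γ, t) pairs is to fit a straight line in log-log space and compare the slope with -1. The sketch below is an illustration of that check, not the authors' analysis code. The function name `loglog_slope` is hypothetical, and the numbers in the usage note are synthetic values generated from an exact inverse law, not results from the paper.

```python
import math


def loglog_slope(xs, ys):
    """Least-squares slope of log(y) against log(x)."""
    lx = [math.log(x) for x in xs]
    ly = [math.log(y) for y in ys]
    mean_x = sum(lx) / len(lx)
    mean_y = sum(ly) / len(ly)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(lx, ly))
    var = sum((a - mean_x) ** 2 for a in lx)
    return cov / var
```

For synthetic data such as `gammas = [0.03, 0.1, 0.3, 1.0]` and `steps = [1000.0 / g for g in gammas]`, the fitted slope is -1 up to floating-point error; on real sweeps, a slope close to -1 over some range of γ would be consistent with the predicted scaling on that range.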

E SECTION 5.1 SETUP

Architecture

Similar to Liu et al. (2022), the decoder architecture is an MLP with hard-coded addition. Each input symbol $i$ is encoded to a scalar $E_i$. Each output symbol $k$ is represented by a 30D random vector $\hat{Y}_k$. We consider addition with base $p$, so the inputs satisfy $0 \leq i, j \leq p-1$ and the output satisfies $0 \leq k = i + j \leq 2(p-1)$. We denote the representation as $R = \{E_0, E_1, \cdots, E_{p-1}\}$. The decoder output is $Y_k = \mathrm{Dec}_w(E_i + E_j)$, the loss function is the mean squared error (MSE) between $Y_k$ and $\hat{Y}_k$, and $w$ denotes the decoder weights. Although the common setup of grokking is to make both the representation $R$ and the decoder $w$ trainable, we will freeze part of them for easier analysis. Because this can be confusing, we explicitly distinguish three setups: landscape analysis, reduced trajectory analysis and full trajectory analysis. Each setup has a different subset of trainable parameters, as shown in Table 1.

Landscape analysis

Both the representation $R$ and the weight norm $w$ are fixed. Only the weight direction $\hat{\mathbf{w}}$ is trainable. The representation $R$ is fixed according to Eq. (4), which depends on $m$, the representation messiness. The decoder has fixed weight norm $w$, but the weight direction $\hat{\mathbf{w}}$ is trainable. For each fixed $(w, m)$, we minimize the training loss over $\hat{\mathbf{w}}$ to get $\hat{\mathbf{w}}^*(w, m) = \mathrm{argmin}_{\hat{\mathbf{w}}} \, \ell_{train}(w, m, \hat{\mathbf{w}})$, and define the reduced training and test losses as in Eq. (5). The minimization is implemented by the Adam optimizer with learning rate $10^{-3}$ for $10^4$ steps. Although $(w, m)$ are not trainable, we repeat the above minimization independently for different $(w, m)$. In Figure 6

We are still able to observe delayed generalization on MNIST using cross entropy loss, though test accuracy first plateaus at higher than random-guess accuracy.



Footnote 0: By "standard initialization" we mean the default one in PyTorch. For linear layers, each weight $w \sim U[-\sigma, \sigma]$ and bias $b \sim U[-\sigma, \sigma]$, where $\sigma = 1/\sqrt{\text{fan\_in}}$ and $U[a, b]$ denotes the uniform distribution on $[a, b]$.
Footnote 1: γ should not be too large, otherwise it will bring the weights to a trivial solution w = 0.
Footnote 2: w should not be so small that it harms optimization.
Footnote 3: This is generally not true when the loss landscape is non-convex. The aim of this assumption is to make the minimizer aligned with Eq. (1).
Footnote 4: In principle, reduced training losses should be non-increasing ("L"), but optimization issues may occur for too large initializations (Schoenholz et al., 2016).
Footnote 5: For simplicity, we assume $\eta_R = \eta_D$ here, but the analysis can apply to any $(\eta_R, \eta_D)$.



Figure 1: (a) w: L 2 norm of model weights. Generalizing solutions (green stars) are concentrated around a sphere in the weight space where w ≈ w c (green). Overfitting solutions (orange) populate the w ≳ w c region. (b) The training loss (orange) and test loss (gray) have the shape of L and U, respectively. Their mismatch in the w > w c region leads to fast-slow dynamics, resulting in grokking.

Figure 2: Teacher-student setup. α: student initialization scale, γ: weight decay. (a) The reduced training loss and test loss have the shape of "L" and "U", respectively. (b) Top row: large initialization (α = 2.0) can demonstrate no generalization (no reg), grokking (small reg) and fast generalization (large reg). Bottom: small initialization (α = 0.5) always generalizes fast, regardless of weight decay. (c) α = 2. The number of steps to overfitting is independent of weight decay, while the number of steps to generalization scales inversely with the weight decay.

(a)(b). The reduced loss landscape reveals three things: (1) Larger initializations lead to grokking. Point A in Figure 3 corresponds to the standard initialization (α = 1), which has low training and test errors, hence no grokking. When increasing the weight norm from A to B, training error is seen to remain low while test error rises. To generalize, weight decay must be in place to bring the weight norm down, leading to grokking if weight decay is small. (2) Larger datasets lead to de-grokking. Comparing B and C in Figure 3, C is seen to have larger training size than B and lower test error. Larger data size N makes the Goldilocks zone broader, reducing or eliminating grokking even for large weight initializations. (3) Critical data size can be defined. As reported in Power et al. (2022); Liu et al. (

Figure 3: MNIST. (a) reduced training error, (b) reduced test error. Comparing A and B: larger weight norm makes learning grok (delay generalization). Comparing B and C: a larger training data size makes learning de-grok (speed up generalization). (c) "LU" holds truer for smaller data. (d) Accuracy curves for MNIST in the setting where we observe grokking. (e) Time to generalize as a function of training set size N , replicating Liu et al. (2022).

Figure 4: We use an LSTM to predict IMDb reviews. (a) training error; (b) test error; (c) reduced losses for data size 1k (top) and 50k (bottom); (d) With 1k data, a (weak) grokking signal is observed for large initializations (α = 6), while no grokking is observed for standard initializations (α = 1).

Figure 6: Loss landscapes on the 2D (w, m) plane. (a) Training loss splits the plane into two regions: large loss small w (fast dynamics) and small loss large w (slow dynamics). (b) Test loss; the green star is the generalizing solution. (c) Losses along an illustrative path A → E, demonstrating multiple descent; (d) zoom-in of the training loss highlighting the gradients on the boundary. (e) the boundary depends on training data size; (f) a simple illustration of grokking dynamics.

More experimental details are included in Appendix E. Landscape We show ltrain (w, m) and ltest (w, m) in Figures 6a and 6b, indicating the generalizing solution with a green star. Based on the reduced training loss (Figure 6a), we can divide the 2D plane into two regions I and II, separated by a dashed yellow line (the contour of training loss = 0.05): (I): The darker region, with high training losses/gradients and small weight norm. (II): The lighter region, with low training losses/gradients and large weight norm. Comparing Figures 6a and 6b reveals that training and test loss landscapes differ, especially in region II. Moreover, while the training loss depends weakly on m, the test loss depends strongly on m. As we will see, the (weak) dependence of training loss on representation drives the model to the generalizing solution.

Figure 7: MNIST landscapes as functions of representation messiness m and weight norm w: (a) training loss, and (b) test loss. Training and test losses do not have a significant mismatch, and neither of them depends strongly on the representation, which is in stark contrast to algorithmic datasets (Figure 6). (c) An illustrative path A → B → C does not manifest grokking.

Figure 9: Training 1L transformer on modular addition (p = 113). (a) Weight norm, train accuracy, and test accuracy over time, initialized and trained normally. Weight norm first increases, and is highest during the period of overfitting, but then drops to become lower than initial weight norm when the model generalizes. (b) Constrained optimization at constant weight norm (α = 0.8) largely eliminates grokking, with test and train accuracy improving almost concurrently.

-10b. We also show full training curves for these runs in Figure 10c-10d. We perform experiments in two setups: (a) Transformer on modular addition: We use the replication of grokking from Nanda et al. (2023) and train a 1-layer transformer on modular addition (p = 113 and a train set fraction of 0.3) where d_model = 128, with 4 attention heads, d_mlp = 512, ReLU activations, and an

Figure 10: Time to generalize as a function of weight decay: we investigate to what extent the relation t ∝ γ^-1 holds, where t is the number of training steps needed for the model to generalize and γ is the AdamW weight decay. When a lower weight decay is used, models spend longer in the period of overfitting before eventually generalizing. We show the generalization time t as a function of γ in (a)-(b) and full training curves for these runs in (c)-(d).
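One way to check how well the inverse relation between generalization time t and weight decay γ holds is to fit a power law t = c · γ^k by least squares in log-log space and compare the exponent k with -1. The sketch below is an assumption about how such a check could be done, not the procedure used for Figure 10. The input values are synthetic placeholders and not measurements from the paper.

```python
import math

def fit_power_law(gammas, times):
    """Least-squares fit of log t = log c + k * log gamma; returns (k, c)."""
    xs = [math.log(g) for g in gammas]
    ys = [math.log(t) for t in times]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    covariance = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    variance = sum((x - mean_x) ** 2 for x in xs)
    k = covariance / variance
    c = math.exp(mean_y - k * mean_x)
    return k, c

# Synthetic placeholder data that follows t = 500 / gamma exactly.
gammas = [0.01, 0.03, 0.1, 0.3, 1.0]
times = [500.0 / g for g in gammas]

exponent, prefactor = fit_power_law(gammas, times)
```

An exponent close to -1 on measured data would support the inverse relation, while a clearly different exponent would indicate a deviation from it.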

(a)(b)(d), the background heatmaps belong to landscape analysis.

Figure 12: Training curves using cross entropy loss on MNIST. We are still able to observe delayed generalization on MNIST using cross entropy loss, though test accuracy first plateaus at higher than random-guess accuracy.

Three setups used in this paper, each with a different set of trainable parameters.

availability

//github.com/KindXiaoming

Appendix

A EXPERIMENT DETAILS

Sentiment analysis of text IMDb (Maas et al., 2011) includes 50k movie reviews to be classified as being positive or negative. To pre-process the data, we extract the 1000 most frequent words and tokenize each review into an array of token indices. Less frequent words are ignored, and each review array is padded to length 500. We adopt the LSTM model (Hochreiter and Schmidhuber, 1997) to perform the classification, with two layers, embedding dimension 64, and hidden dimension 128. We use the Adam optimizer (Kingma and Ba, 2014) with learning rate 0.001 to minimize the binary cross entropy loss. We hold back 25% of the dataset for testing.

Molecules QM9 is a database for small molecules and their properties. We use a graph convolutional neural network (GCNN) to predict the isotropic polarizability. The GCNN contains 2 convolutional layers with ReLU activation, followed by a linear layer. We use the Adam optimizer with learning rate 0.001 to minimize the MSE loss. We split the dataset into 50/50 train/test.

MNIST We train width-200 depth-3 ReLU MLPs on the MNIST dataset with MSE loss. We use the AdamW optimizer with a learning rate of 0.001 and a batch size of 200.
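The IMDb pre-processing described above (keep the most frequent words, drop the rest, pad to a fixed length) can be sketched as follows. Several details are assumptions not stated in the text: lowercasing and whitespace tokenization, reserving index 0 for padding, padding on the right, and truncating reviews longer than the maximum length. The function names and the two toy reviews are hypothetical.

```python
from collections import Counter

def build_vocab(reviews, vocab_size=1000):
    """Map the vocab_size most frequent words to indices 1..vocab_size."""
    counts = Counter(word for review in reviews for word in review.lower().split())
    # Index 0 is reserved for padding (an assumption of this sketch).
    return {word: index + 1
            for index, (word, _) in enumerate(counts.most_common(vocab_size))}

def encode(review, vocab, max_len=500, pad_index=0):
    """Convert a review to token indices, ignoring words outside the vocabulary."""
    tokens = [vocab[word] for word in review.lower().split() if word in vocab]
    tokens = tokens[:max_len]                               # truncate long reviews
    return tokens + [pad_index] * (max_len - len(tokens))   # right-pad short ones

# Toy stand-ins for IMDb reviews, with a tiny vocabulary and length.
reviews = ["a great great movie", "a dull movie"]
vocab = build_vocab(reviews, vocab_size=3)
encoded = [encode(review, vocab, max_len=6) for review in reviews]
```

For the setting in the text, the same functions would be called with vocab_size=1000 and max_len=500.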

B REDUCED LOSS FOR MODULAR ADDITION WITH TRANSFORMERS

In Figure 8 we show reduced loss landscape plots for transformers trained on modular addition. We use the setup of Nanda et al. (2023) and train a 1-layer transformer on modular addition (p = 113) with d_model = 128, 4 attention heads, and d_mlp = 512 with ReLU activations. We train with a learning rate of 0.001 while constraining the model weight norm, for a variety of α and a variety of train set fractions. The LU shape holds for α ∈ [0.1, 4] (an optimization issue may be responsible for the rise in train loss for α > 4). We see that the critical train set fraction is approximately 0.25, in line with earlier studies on grokking.
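Training at constant weight norm can be implemented in several ways. A simple one is to rescale the parameters back to the target norm after every optimizer step. The sketch below assumes this projection approach and uses plain gradient descent on a toy quadratic loss rather than the transformer setup. It is not necessarily how the constraint was implemented for Figure 8, and all names are hypothetical.

```python
import math

def l2_norm(params):
    """L2 norm of a flat list of parameters."""
    return math.sqrt(sum(p * p for p in params))

def rescale_to_norm(params, target_norm):
    """Keep the direction of params but set its L2 norm to target_norm."""
    norm = l2_norm(params)
    return [p * target_norm / norm for p in params]

def constrained_step(params, grads, lr, target_norm):
    """One gradient step followed by projection back onto the sphere."""
    updated = [p - lr * g for p, g in zip(params, grads)]
    return rescale_to_norm(updated, target_norm)

# Toy loss ||params - target||^2, minimized on the sphere of radius 2.
target = [3.0, 4.0]
target_norm = 2.0
params = rescale_to_norm([1.0, 0.0], target_norm)
for _ in range(200):
    grads = [2.0 * (p - t) for p, t in zip(params, target)]
    params = constrained_step(params, grads, lr=0.1, target_norm=target_norm)
```

In this toy problem only the direction of the parameters is optimized. The iterate converges to the point on the sphere closest to the target, here (1.2, 1.6), while the norm stays at 2 throughout.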

C WEIGHT NORM EVOLUTION OVER TIME ON ALGORITHMIC TASKS

Evolution of weight norm As mentioned in Section 5, the dynamics of the model weight norm over the course of training on algorithmic tasks support the LU mechanism picture of grokking. Figure 9a shows how the model weight norm changes over time: there is an initial increase in weight norm, which peaks during overfitting, but the norm then drops during the period of generalization to below its value at initialization. For this experiment, we again used the setup of Nanda et al. (2023). We train with AdamW with a learning rate of 0.001 and weight decay γ = 1.

Constraining a small weight norm eliminates grokking As shown in Figure 9b, reducing the initialization scale (α = 0.8) and constraining optimization to hold the model weight norm constant largely eliminates grokking, with test and train accuracy improving almost concurrently.

Reduced trajectory analysis is a "thought experiment" based on landscape analysis. Since full trajectory analysis can be intractable in high dimensions, we reduce the trajectory analysis to 2D by making two assumptions about the real dynamics: (1) Scale separation: the dynamics of ŵ is much faster than the dynamics along w and along m, such that ŵ(t) = ŵ*(w(t), m(t)) is valid at every moment during training. (2) Representation evolution is linear, i.e., it interpolates between the initial random Gaussian representation and the final linear representation. With these two assumptions, the training dynamics is effectively reduced to 2D, depending only on (w, m) and obeying Eq. (6). In Figure 6(a)(b)(c), the path A → E belongs to reduced trajectory analysis.

Admittedly, the reduced trajectory may deviate from the full trajectory since the assumptions may not be met, but it can shed light on the full trajectory: the weight norm first increases and then decreases, and the decrease of weight norm is highly correlated with generalization (please see Appendix C and Figure 9).
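The second assumption of the reduced trajectory analysis, linear interpolation between a random Gaussian representation and a linear one, can be sketched as below. The exact form of Eq. (4) is not reproduced here, so the parametrization is an assumption: m = 1 gives the random representation, m = 0 gives an unnormalized linear representation E_i = i, and intermediate m mixes the two. The function names are hypothetical.

```python
import random

def interpolate_representation(p, m, seed=0):
    """Return E(m) = m * E_random + (1 - m) * E_linear for p input symbols."""
    rng = random.Random(seed)
    random_repr = [rng.gauss(0.0, 1.0) for _ in range(p)]
    linear_repr = [float(i) for i in range(p)]
    return [m * r + (1.0 - m) * l for r, l in zip(random_repr, linear_repr)]

def spacing_spread(embedding):
    """Largest minus smallest consecutive gap; zero for equally spaced points."""
    gaps = [b - a for a, b in zip(embedding, embedding[1:])]
    return max(gaps) - min(gaps)

messy = interpolate_representation(p=10, m=1.0)   # fully random
half = interpolate_representation(p=10, m=0.5)    # partially structured
clean = interpolate_representation(p=10, m=0.0)   # exactly linear
```

Under this parametrization the spread of consecutive gaps shrinks linearly with m, so it can serve as a simple proxy for how far a representation is from the linear one.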

F MNIST EXPERIMENTS WITH CROSS ENTROPY LOSS

To respond to a reviewer's concern that our use of the MSE loss is the "secret" to get grokking on MNIST (Figure 3 ), we reran our experiments with the cross entropy (CE) loss. The results are qualitatively similar, with some quantitative differences. 

Landscape analysis

Comparing Figure 3 (MSE) and Figure 11 (CE), we notice that they are qualitatively similar: (1) for small datasets, the reduced training error and test error resemble an "L" and a "U" against the weight norm, respectively; (2) for large datasets, the "U" becomes more like an "L", i.e., the mismatch between the reduced training and test error is small. However, a quantitative difference exists: CE produces a broader "Goldilocks zone" (the weight range where generalization happens) than MSE. This implies that to induce grokking with CE, we need to increase the weight norm to a larger value (say α = 100).

Training dynamics

We are able to observe delayed generalization during training on MNIST with cross entropy loss, but doing so requires a higher α than was necessary when using MSE loss, as predicted by the reduced loss landscapes in Figure 11. Figure 12 shows training trajectories from a 3-layer ReLU MLP on MNIST trained with cross entropy loss with α = 100 and D = 200. We see that test accuracy rises to 30-40% early in training, then plateaus for an extended period, before increasing to ≈75% while train accuracy remains at 100%. While the dynamics are not as clean as with MSE loss, since test accuracy first plateaus at better-than-random accuracy, we think it is still fair to classify these dynamics as "grokking" due to the improvement in generalization late in training after a plateau.

