IMPORTANCE OF CLASS SELECTIVITY IN EARLY EPOCHS OF TRAINING Anonymous

Abstract

Deep networks trained for classification exhibit class-selective neurons in intermediate layers. Intriguingly, recent studies have shown that class-selective neurons are not strictly necessary for network function. But if class-selective neurons are not necessary, why do they exist? We attempt to answer this question in a series of experiments on ResNet-50 trained on ImageNet. We begin by showing that class-selective neurons emerge in the first few epochs of training before receding rapidly. Single-neuron ablation experiments show that class-selective neurons are important for network function during this early phase of training. The network is close to a linear regime during this early training phase, which may explain the emergence of these class-selective neurons in intermediate layers. Finally, by regularizing against class selectivity at different points in training, we show that the emergence of these class-selective neurons during the first few epochs of training is essential to the successful training of the network. Altogether, our results indicate that class-selective neurons in intermediate layers are vestigial remains of early epochs of training, during which they appear as quasi-linear shortcut solutions to the classification task which are essential to the successful training of the network.

1. INTRODUCTION

A significant body of research has attempted to understand the role of single-neuron class selectivity in the function of artificial (Zhou et al., 2015; Radford et al., 2017; Bau et al., 2017; Morcos et al., 2018; Olah et al., 2018; Rafegas et al., 2019; Dalvi et al., 2019; Meyes et al., 2019; Dhamdhere et al., 2019; Leavitt & Morcos, 2020a; Kanda et al., 2020; Leavitt & Morcos, 2020b), and biological (Sherrington, 1906; Adrian, 1926; Granit, 1955; Hubel & Wiesel, 1959; Barlow, 1972) neural networks. Neurons responding selectively to specific classes are typically found throughout networks trained for image classification, even in early and intermediate layers. Interestingly, these class-selective neurons can be ablated (i.e. their activation set to 0; Morcos et al. 2018), or class selectivity can be substantially reduced via regularization (Leavitt & Morcos, 2020a), with little consequence to overall network accuracy-sometimes even improving it. These findings demonstrate that class selectivity is not necessary for network function, but it remains unknown why class selectivity is learned if it is largely not necessary for network function. One notable limitation of many previous studies examining selectivity is that they have largely overlooked the temporal dimension of neural network training; single-unit ablations are performed only at the end of training (Morcos et al., 2018; Amjad et al., 2018; Zhou et al., 2018; Meyes et al., 2019; Kanda et al., 2020), and selectivity regularization is mostly constant throughout training (Leavitt & Morcos, 2020a; b). However, there are numerous studies demonstrating substantial differences in training dynamics during the early vs. later phases of neural network training (Sagun et al., 2018; Gur-Ari et al., 2018; Golatkar et al., 2019; Frankle et al., 2020b; Jastrzebski et al., 2020).
Motivated by these studies, we asked a series of questions about the dynamics of class selectivity during training in an attempt to elucidate why neural networks learn class selectivity: When in training do class-selective neurons emerge? Where in networks do class-selective neurons first emerge? Is class selectivity uniformly (ir)relevant for the entirety of training, or are there "critical periods" during which class selectivity impacts later network function? We addressed these questions in experiments conducted on ResNet-50 trained on ImageNet, which led to the following results:

• The emergence of class-selective neurons in early and intermediate layers follows a nontrivial and surprising path: after a prominent rise during the first few epochs of training, class selectivity subsides quickly during the next few epochs, before returning to a baseline level specific to each layer.

• During this early training phase where average class selectivity is high in early and intermediate layers, class-selective neurons in these layers are much more important for network function than later in training, as assessed with single-unit ablation.

• During this early, high-selectivity phase of training, the representations of early and latter layers are much more similar than later in training, implying that selectivity in early layers could be leveraged to solve the classification problem by transmission to the latter layers via skip connections.

2. RELATED WORK

Of particular importance is the work of Leavitt & Morcos (2020a), which introduced a regularizer for fine-grained control over the amount of selectivity learned by a network. This regularizer makes it possible to test whether the presence of selectivity is beneficial, and whether networks need to learn selectivity, two questions that ablation methods cannot test.
By regularizing to discourage or promote the learning of selectivity, they found that selectivity is neither strictly necessary nor sufficient for network performance. In follow-up work they also showed that promoting class selectivity with their regularizer confers robustness to adversarial attacks, while reducing selectivity confers robustness to naturalistic perturbations (Leavitt & Morcos, 2020b). However, they did not scale their experiments beyond ResNet-18 (He et al., 2016) and Tiny ImageNet (Fei-Fei et al., 2015). Additionally, with the exception of one set of controls in which the regularizer was linearly warmed up during the first five epochs of training (Leavitt & Morcos, 2020a), they did not examine the dynamics of class selectivity's importance over the training process. Most importantly, they did not attempt to address why selectivity is learned.

2.2. THE EARLY PHASE OF NEURAL NETWORK TRAINING

A breadth of approaches have been used to characterize the differences between the early and later phases of neural network training and highlight the impact of early-phase interventions on late-phase behavior. The application of iterative magnitude pruning in the Lottery Ticket approach (Frankle & Carbin, 2018) requires rewinding to a sufficiently early point in training-within the first few thousand iterations-for any benefit to be realized (Frankle et al., 2020a; b). The local loss landscape also changes rapidly early in training (Sagun et al., 2018), and gradient descent quickly becomes confined to a very restricted subspace (Gur-Ari et al., 2018).

3.2. CLASS SELECTIVITY INDEX

We refer to the individual channels of the bottleneck layers as "units" or "neurons" for the purpose of our experiments. We compute a class selectivity index for each unit of each bottleneck layer of every module. This index measures how much the unit favors a specific ImageNet class. It is defined as (Leavitt & Morcos, 2020a):

SI = (μ_max − μ_−max) / (μ_max + μ_−max + ε)    (1)

where μ_max is the largest class-conditional mean activation over all possible classes (activations averaged over spatial locations for that channel and over the samples of each class), μ_−max is the mean response to the remaining (i.e., non-μ_max) classes, and ε is a small value (here 10^−6) to prevent division by zero in the case of a dead unit. All activations are non-negative because they are extracted after a ReLU layer. The selectivity index can therefore range from 0 to 1: a unit with identical average activity for all classes has a selectivity of 0, and a unit that responds to only a single class has a selectivity of 1. We compute class selectivity indices across all epochs of training and then perform various experiments to see how class selectivity affects learning at different stages of training.
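The definition in Eq. 1 can be sketched directly in code. The following is a minimal illustration (not the authors' implementation); the function name and the assumption that activations arrive as an (N, C, H, W) post-ReLU tensor with integer labels are ours:

```python
import torch

def selectivity_index(acts, labels, eps=1e-6):
    """Class selectivity index per channel, per Eq. 1 (Leavitt & Morcos, 2020a).

    acts: (N, C, H, W) post-ReLU activations for N samples.
    labels: (N,) integer class labels.
    Returns a (C,) tensor of selectivity indices in [0, 1].
    """
    # Average over spatial locations -> per-sample, per-channel response.
    resp = acts.mean(dim=(2, 3))                                  # (N, C)
    classes = labels.unique()
    # Class-conditional mean activation for each channel.
    class_means = torch.stack(
        [resp[labels == k].mean(dim=0) for k in classes])         # (K, C)
    # mu_max: largest class-conditional mean per channel.
    mu_max = class_means.max(dim=0).values                        # (C,)
    # mu_-max: mean response over the remaining (non-preferred) classes.
    K = class_means.shape[0]
    mu_minus = (class_means.sum(dim=0) - mu_max) / (K - 1)
    return (mu_max - mu_minus) / (mu_max + mu_minus + eps)
```

A dead unit (all-zero response) yields SI = 0 thanks to the ε term, matching the paper's handling.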

3.3. QUANTIFYING THE RELEVANCE OF CLASS-SELECTIVE NEURONS

We used the following methods to determine whether class-selective neurons are relevant to network function.

Single-neuron ablation: This method consists of setting channel outputs of the bottleneck layers to 0. In our experiments, we progressively ablate the channels in all bottleneck layers of a given module, either in order of decreasing class selectivity index (most class-selective channels first) or in random order (random control condition). On a fully-trained ResNet-50, we find that ablation of class-selective neurons is less damaging to network accuracy than ablation of random neurons (Appendix A.2), consistent with previous reports.

Class-selectivity regularizer: We used Leavitt & Morcos (2020a)'s selectivity regularizer to either promote or suppress the emergence of class-selective neurons in different modules and at different epochs of training. The regularization term is:

μ_SI = (1/M) Σ_{m=1}^{M} (1/B) Σ_{b=1}^{B} SI_b    (2)

where M is the total number of modules, m indexes the current module, B is the number of bottleneck layers in module m, b indexes the current bottleneck layer, and SI_b denotes the selectivity indices of the channels in bottleneck layer b. This term is added to the loss function as follows:

loss = −Σ_{c=1}^{C} y_c log(ŷ_c) − α μ_SI

The left-hand term of the loss function is the cross-entropy loss, where C is the total number of classes, c is the class index, y_c is the true class label, and ŷ_c is the predicted class probability. The right-hand term (−α μ_SI) is the regularizer, and α controls its strength: a negative value of α regularizes against selectivity, while a positive value of α regularizes in favor of selectivity.
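The regularized loss above can be sketched as follows. This is an illustrative reconstruction under our own assumptions, not the authors' code: we assume the per-module selectivity indices (Eq. 1) have already been computed and are passed in as a list of tensors, one per regularized module:

```python
import torch
import torch.nn.functional as F

def selectivity_regularized_loss(logits, targets, module_sis, alpha):
    """Cross-entropy loss plus the selectivity term of Eq. 2.

    module_sis: list with one tensor per module, each holding the
        selectivity indices SI_b of that module's bottleneck layers.
    alpha > 0 promotes selectivity; alpha < 0 suppresses it.
    """
    ce = F.cross_entropy(logits, targets)
    # mu_SI: average SI within each module, then across modules (Eq. 2).
    mu_si = torch.stack([si.mean() for si in module_sis]).mean()
    # loss = cross-entropy - alpha * mu_SI
    return ce - alpha * mu_si
```

With α = −20 (the strongest setting used in the paper), the term adds a large penalty whenever average selectivity rises, which is what suppresses the early selectivity peak when applied from epoch 0.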

4.1. CLASS-SELECTIVE NEURONS EMERGE DURING EARLY EPOCHS OF TRAINING

We began by examining the dynamics of class selectivity in each module over the course of training. One possibility is that selectivity follows a similar pattern as accuracy: increasing quickly over the first few epochs, then gradually saturating over the remainder of training. We consider this the most likely scenario, as it follows the dynamics of early and rapid change followed by stabilization observed in many other quantities during training. To test this hypothesis, we calculated the Centered Kernel Alignment (CKA) (Kornblith et al., 2019; Nguyen et al., 2020) score between all modules. CKA can be used to measure the representational similarity between different components of a network. The CKA score between the early module and the latter module, and between the early module and the fully connected layer (FC), is indeed high during the first few epochs, and then recedes to a low level for the rest of training (Fig 5). The CKA score across all other pairs of modules shows a similar trend (Appendix A.4). From these results, the following picture emerges: at the early stages of training, where the network is close to being linear, all layers jointly learn class-selective features, with latter layers piggy-backing on class-selective features learned by early layers.

Figure 5: Representational similarity between modules throughout training. We measure the representational similarity between the early module 4 and latter layers (module 7 layers, fully connected layer (FC)) using CKA. We find that the similarity between these layers is high in the early epochs of training, indicating that the network is close to linear at this stage. This linearity might be key to explaining the emergence of class-selective neurons in all layers during the early phase of training (see text). Error shades indicate the 95% CI of the mean, calculated over 6 networks trained from different random initial conditions.
Leavitt & Morcos (2020a) found that regularizing beyond the range of α = [-1, -5]-depending on the model and dataset-had significant negative effects on network performance, while we were able to increase the magnitude of α all the way to -20 with minimal effect on network performance, so long as we did not begin regularizing until epoch 5.

4.4. REGULARIZING AGAINST SELECTIVITY

The model never fully recovers its performance if the regularization is kept on, but if the regularization is turned off after the first few epochs of training, the network is slowly able to recover from the perturbation after 20+ epochs of training (Appendix A.5.3).

4.6. REGULARIZING IN FAVOR OF CLASS SELECTIVITY DOES NOT IMPROVE LEARNING

As class-selective neurons are crucial to learning in the early epochs, could increasing selectivity in those early epochs further improve learning? To answer this question, we regularized in favor of class selectivity (i.e. with positive values of α) in those early epochs. We found that increasing selectivity in those early epochs did not improve performance further (Fig 13, Fig 14). Therefore, both an increase and a decrease in selectivity are harmful to learning, suggesting that the network learns roughly the "right" amount of selectivity for optimal learning. More details on the regularization-for-selectivity experiments can be found in Appendix A.5.2.

REPRESENTATION

We next explored how regularization in favor of and against class selectivity affects the balance of class representation in the network's predictions. We found that both increasing and decreasing class selectivity in early and intermediate modules resulted in poorer diversity of class representation in the network's predictions (Appendix A.5.4). This effect lasted for the first few epochs of training, after which it was no longer visible. We conclude that tampering with class selectivity in early and intermediate layers perturbs the correct learning of the network by over-representing certain classes at the prediction stage, during a period which is apparently critical for the correct training of the network.

5. DISCUSSION

Previous experiments have shown that class selectivity in individual units may not be necessary for high-accuracy converged solutions (Morcos et al., 2018; Amjad et al., 2018; Donnelly & Roegiest, 2019; Leavitt & Morcos, 2020a; Kanda et al., 2020; Leavitt & Morcos, 2020b). We sought to determine why class selectivity is learned if it is not necessary. In a series of experiments examining the training dynamics of ResNet-50 trained on ImageNet, we found that class selectivity in early and intermediate layers is actually necessary during a critical period early in training in order for the network to learn high-accuracy converged solutions. Specifically, we observed that class selectivity rapidly emerges within the first epoch of training, before receding over the next 2-3 epochs in early and intermediate layers. Ablating class-selective neurons during this critical period has a much greater impact on accuracy than ablating class-selective neurons at the end of training. Furthermore, we used Leavitt & Morcos (2020a)'s selectivity regularizer and found that strongly regularizing against class selectivity-to a degree that Leavitt & Morcos (2020a) found had catastrophic effects on network performance-had a negligible effect on network performance if regularization does not occur during the initial epochs. We also found that the representations of early and latter layers are much more similar during the early critical period of training than later in training, indicating that selectivity in early layers may be leveraged by latter layers in order to solve the classification problem early in training. Taken together, our results show that class-selective neurons in early and intermediate layers of deep networks are a vestige of their emergence during a critical period early in training, during which they are necessary for successful training of the network. One limitation of the present work is that it is limited to ResNet-50 trained on ImageNet.
While Leavitt & Morcos (2020a) found that their results were consistent across a breadth of CNN architectures and image classification datasets, it's possible that our findings may not generalize to networks trained on NLP tasks, in which single-neuron selectivity is also a topic of interest. Our observations add to the list of phenomena occurring during the critical phase early in training (see related work, Section 2), and prompt further investigation of this phase. In particular, it would be interesting to understand why the emergence of class selectivity is necessary in this early phase of training, and whether this phenomenon connects to other critical-period phenomena, such as the emergence of lottery ticket subnetworks (Frankle & Carbin, 2018; Paul et al., 2022) and critical learning periods (Achille et al., 2018).

Figure 8: Ablation of channels in a fully-trained ResNet-50. Surprisingly, ablating channels in a random order is more damaging to the network accuracy than ablating the most class-selective channels for modules 4, 5, 6. However, this trend is reversed for module 7. Error shades indicate the 95% CI of the mean, calculated over 10 networks trained from different random initial conditions.

A.4 CENTERED KERNEL ALIGNMENT (CKA) ANALYSIS

To check the representational similarity between the modules at different stages of learning, we used the Centered Kernel Alignment (CKA) metric (Kornblith et al., 2019; Nguyen et al., 2020). To perform the CKA analysis we used the torch_cka library (Subramanian, 2021). The CKA score across all cases is high during the early epochs of training (Fig 10), which indicates that the network is jointly learning to be class-selective, with neurons in latter layers relying on the class-selective features of early and intermediate layers to produce crude shortcut strategies. So even if the individual neurons are not class-selective, the model ends up over-representing certain classes in the early epochs. We think this might be because a network whose neurons are not class-selective at all has difficulty learning in the early epochs, so the model as a whole ends up overfitting to a few classes. Similarly, when the model is regularized in favor of selectivity, certain classes again end up being over-represented. This suggests that tampering with selectivity in either direction perturbs correct learning by over-representing classes.
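For readers unfamiliar with the metric, linear CKA between two activation matrices can be sketched in a few lines. This is a generic illustration of the metric from Kornblith et al. (2019), not the torch_cka implementation used in the paper, and the function name is ours:

```python
import torch

def linear_cka(X, Y):
    """Linear CKA between two activation matrices (Kornblith et al., 2019).

    X: (N, D1) and Y: (N, D2) -- activations of two layers on the same
    N inputs, one row per input. Returns a scalar in [0, 1]; 1 means the
    representations match up to an orthogonal transform and scaling.
    """
    # Center each feature dimension.
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    num = (Y.t() @ X).norm() ** 2
    den = (X.t() @ X).norm() * (Y.t() @ Y).norm()
    return (num / den).item()
```

Comparing, say, flattened module-4 and module-7 activations over a batch of validation images at each epoch would reproduce the kind of curve shown in Fig 5 and Fig 10.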



Figure 1: Structure of ResNet-50. (a) We define the early and intermediate layers as the layers present in the early and intermediate modules. (b, c) Each of these modules is made of multiple bottleneck layers, commonly known as conv blocks and identity blocks in ResNets.

Figure 2: Evolution of class selectivity throughout training in all modules of a ResNet-50. Each line represents the evolution of average class selectivity of neurons of a given module throughout training for a ResNet-50 trained on ImageNet. A sharp rise in selectivity can be seen in all modules during the first few epochs of training, which recedes rapidly for the intermediate modules (modules 4,5,6) but not for the last module (module 7).

Figure 3: Effect of class-selective ablations and random ablations on normalized network accuracy throughout training for early and intermediate modules. The right-hand side of the figure depicts the effect of class-selective neuron ablations on network accuracy across epochs. The effect of class-selective ablations on accuracy is more damaging during the earlier epochs than during later epochs. The random-ablation curves do not show this effect (curves are close to each other). The left-hand side of the figure depicts the area under the ablation curves (AUC), calculated as the sum of accuracies for each epoch-specific curve shown on the right-hand side. The positive slopes of the AUC curves for the class-selective ablations, especially visible during the early epochs of training, indicate that class-selective neurons are more important for the network's decisions during these early epochs than during later epochs. Error shades indicate the 95% CI of the mean, calculated over 10 networks trained from different random initial conditions.

Figure 6: Regularization against selectivity. (a) Class selectivity indices throughout training for three models, respectively unregularized, regularized from epoch 0 onward, and regularized from epoch 5 onward. The effect of regularization on class-selectivity is clearly visible in the second and third plots. Note that only early and intermediate modules are regularized. Error shades indicate the 95% CI of the mean, calculated over all the bottleneck layers for each module. (b) Train and Validation accuracies for each model throughout training. The model regularized against selectivity from epoch 5 onward performs as well as the original unregularized model. The model regularized against selectivity from epoch 0 onward performs significantly worse than the other two models. This result shows the importance of class selectivity in the early epochs of training.

Figure 7: Effect of regularization against selectivity when starting from different epochs. Only the model regularized from epoch 0 onward performs poorly. Models regularized starting from epochs 1-5 all performed similarly. This suggests that the critical phase during which class selectivity is important is restricted to the very first epoch of training.

Figure 9: Evolution of class selectivity throughout training across bottleneck layers. Error shades indicate the 95% CI of the mean, calculated over the channels of each bottleneck layer.

Figure 11: Regularizing against selectivity for different values of α throughout training.

Figure 12: Effect of regularization on selectivity indices across modules for different values of α throughout training. Error shades indicate the 95% CI of the mean, calculated over the bottleneck layers of each module.

Figure 13: Regularizing for selectivity with large values of α. Regularizing for selectivity with large α in the early epochs is harmful to performance. The regularization was applied to the early and intermediate modules (modules 4, 5, 6) for the first 4 epochs of training. Error shades indicate the 95% CI of the mean, calculated over the bottleneck layers of each module.

Figure 14: Regularizing for selectivity with smaller values of α. Even regularizing with smaller values of α does not improve performance. The regularization was applied to all modules in two cases: over the first four epochs, and over only the first epoch. Error shades indicate the 95% CI of the mean, calculated over the bottleneck layers of each module.

Figure 16: Class representation over epochs. The figure shows the mean of the top-5 class counts, where class counts tally the classes predicted by the model over the validation set in each epoch. The model regularized from epoch 0 over-represents classes in the early epochs. Error shades indicate the 95% CI of the mean, calculated over the top-5 mean count of two networks trained from different random initial conditions.

• In a causal experiment where we prevent class selectivity from rising sharply in early and intermediate layers during the first epochs of training, we show that the network's training accuracy suffers from the suppression of this phenomenon. This indicates that the rapid emergence of class-selective neurons in early and intermediate layers during the first phase of training plays an essential role in the successful training of the network.

In conclusion, class-selective neurons in early and intermediate layers of deep networks seem to be a vestige of their emergence during the first few epochs of training, during which they play an essential role in the successful training of the network.

These dynamic, depth-dependent changes in selectivity during the first few epochs of training indicate that the role of selectivity could also vary during training, and in a layer-dependent manner.

As we already know that selectivity is important for learning in the latter module, and as we are interested in understanding the emergence of class-selective neurons in early and intermediate modules, we only regularize the early and intermediate modules against selectivity. We used a regularization strength of α = -20 and found that the regularizer prevents the peak emergence of class-selective neurons in the early epochs of training, and that it impairs training and validation accuracy throughout training when applied from epoch 0. However, in the second scenario, where the regularizer was turned on from epoch 5 onward, the model performed almost as well as the original unregularized model (Fig 6). These results indicate that the emergence of class-selective neurons in intermediate modules in the early epochs of training is a phenomenon that plays a crucial role in the correct training of the network. Other sets of experiments with different values of α can be found in Appendix A.5.1.

A APPENDIX A.1 MODEL TRAINING

We trained 10 instances of ResNet-50 on ImageNet using a standard training procedure in PyTorch (Paszke et al., 2019). All instances were trained for 90 epochs with a batch size of 256, learning rate of 0.1, weight decay of 1e-4, and momentum of 0.9. The class selectivity indices were calculated over the validation set of 50k images for every epoch from epoch 0 to epoch 90. All plots were generated using the seaborn library (Waskom, 2021). Accuracies shown in all plots are top-1 ImageNet training/validation accuracy. Error shades in all plots represent the 95% confidence interval (CI) of the mean.

A.2 CLASS-SELECTIVE ABLATIONS ON A FULLY-TRAINED NETWORK ARE LESS DAMAGING THAN RANDOM ABLATIONS

We find that in the early and intermediate modules (modules 4, 5, 6), ablating units in order of their class-selectivity rank (i.e., most class-selective units first) affects network accuracy less than a control experiment in which we ablate channels in a random order (Fig 8), consistent with the findings of Morcos et al. (2018). This result is paradoxical, as it seems to indicate that although class-selective channels emerge during training in early modules, these channels do not support the decision of the fully-trained network. In the latter module (module 7), we find that class-selective channels do support the decision of the network more than random channels (Fig 8), consistent with our intuition (i.e., layers near the end of the network do need to be class-selective for good performance) and previous findings by Morcos et al. (2018).

A.5.1 REGULARIZING AGAINST SELECTIVITY

• For α = -1.0, there is not much difference between the unregularized model and the models regularized from epoch 0 and epoch 5 onward.

• For α = -5.0, the model regularized from epoch 0 onward performed worse than the one regularized from epoch 5 onward.

• For α = -10.0, -20.0, the models regularized from epoch 0 onward performed significantly worse than the ones regularized from epoch 5 onward.

These results indicate that selectivity is important for learning in those early epochs, and that suppressing it by a significant amount impairs learning.
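The progressive channel ablation used above can be sketched with a PyTorch forward hook. This is an illustrative sketch under our own assumptions (the function name is ours, and we assume the hooked layer outputs an (N, C, H, W) tensor); the paper's own ablation code may differ:

```python
import torch

def ablate_channels(module, channel_idx):
    """Zero the given output channels of `module` via a forward hook.

    `channel_idx` would typically list channels sorted by descending
    class selectivity index (or a random permutation for the control).
    Returns the hook handle; call .remove() to undo the ablation.
    """
    idx = torch.as_tensor(channel_idx, dtype=torch.long)

    def hook(mod, inputs, output):
        # Returning a tensor from a forward hook replaces the output.
        output = output.clone()
        output[:, idx] = 0.0   # (N, C, H, W): zero the selected channels
        return output

    return module.register_forward_hook(hook)
```

Sweeping the number of ablated channels from most to least selective (or in random order) and recording validation accuracy at each step yields the ablation curves of Fig 8.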

A.5.2 REGULARIZING FOR SELECTIVITY

If selectivity is important for learning in the early epochs, then does increasing selectivity further in the early epochs improve learning? To test this, we conducted experiments in which we regularized for selectivity with different values of α. First, we tried regularizing the early and intermediate modules with large values of α (+10, +20) for the first four epochs, after which the regularizer is turned off. The latter module (module 7) was not regularized for selectivity. We found that in both cases, the model performed much worse than the original unregularized model (Fig 13). The large α drives the selectivity index close to 1.0. Also, as module 7 was not regularized, its selectivity index ends up being lower than that of the other modules (Fig 13).

Next, we repeated the experiment with smaller values of α (+1, +5), this time regularizing module 7 for selectivity along with the early and intermediate modules. We tried two scenarios:

• Turning the regularizer on for the first four epochs for all modules with α = +1, +5.

• Turning the regularizer on only for epoch 1 (as it is the most important; Section 4.5) for all modules with α = +1, +5.

In all cases, the increase in selectivity was harmful to performance (Fig 14). Overall, the results show that increasing selectivity does not improve performance. Therefore, both an increase and a decrease in selectivity are harmful to learning.

A.5.3 CAN THE MODEL EVENTUALLY RECOVER ITS PERFORMANCE FROM THE EFFECTS OF REGULARIZATION?

The regularization experiments conducted in Section 4.4 and Appendix A.5.1 are run until epoch 20. This naturally leads to the following questions:

Does the model eventually recover from the effects of regularization if it is trained for more epochs? We trained the model regularized from epoch 0 onward and the one regularized from epoch 5 onward for 60 epochs while keeping the regularizer on with α = -20. The model regularized from epoch 0 onward is never able to recover the lost performance (Fig 15a).

Does the model eventually recover from the effects of regularization if the regularization is turned off after the early epochs? Next, we regularized a third model for only the first four

