SAMPLE IMPORTANCE IN SGD TRAINING

Abstract

Deep learning requires increasingly large models and datasets to improve generalization on unseen data, yet some training samples may be more informative than others. We investigate this assumption in supervised image classification by biasing SGD (Stochastic Gradient Descent) to draw important samples more often during training of a classifier. In contrast to state-of-the-art methods, our approach requires no additional training iterations to estimate sample importance, because it computes the estimates once during training from the training prediction probabilities. In experiments, our learning technique converges on par with or faster than state-of-the-art methods in terms of training iterations and can achieve higher test accuracy, especially when datasets are not suitably balanced. The results suggest that sample importance has intrinsic balancing properties and that an importance-weighted class distribution can converge faster than the usual balanced class distribution. Finally, in contrast to recent work, we find that sample importance is model dependent. Therefore, calculating sample importance during training, rather than in a pre-processing step, may be the only viable approach.

1. INTRODUCTION

For many gradient-descent-based models, increasing the model and training data sizes boosts performance and the ability to generalize. However, this increase comes with ever higher computational costs: longer training times and greater energy consumption are required to train a model on a given training dataset. One way to reduce these costs is to optimize the model training procedure. The most common training approach relies on random shuffling of the data samples without replacement during training for a given number of epochs. As a consequence, all samples are seen by the model the same number of times and are therefore implicitly treated as equally important. Recent works on hard example mining (Felzenszwalb et al., 2009; Loshchilov & Hutter, 2015; Simpson, 2015; Alain et al., 2015; Shrivastava et al., 2016; Chang et al., 2017; Katharopoulos & Fleuret, 2018; Arriaga & Valdenegro-Toro, 2020; Pruthi et al., 2020; Lin et al., 2017) and coreset selection (Mirzasoleiman et al., 2020; Killamsetty et al., 2020; 2021; Yoon et al., 2021; Balles et al., 2022) have shown that the training samples are not equally important for learning a given task (Katharopoulos & Fleuret, 2018) and that it is possible to speed up training by focusing on hard samples or on subsets of the training dataset that best approximate the full gradient, respectively. These methods estimate sample importance during training in an online fashion and leverage this information to speed up the learning process. However, they often incur additional computational overhead to compute the sample importance, which makes them less effective in practice. This is particularly true for coreset selection methods, which are based on conservative estimates (Paul et al., 2021).
Another research line has shown that the training samples in a dataset can be ranked according to different importance scores (Feldman & Zhang, 2020; Feldman, 2020; Jiang et al., 2020; Toneva et al., 2018; Paul et al., 2021) and that less important samples can be pruned prior to training the model with little to no loss in test accuracy. Since computing sample scores exactly is computationally infeasible, these methods rely on approximations that usually require fully training at least one model, computing and ranking the sample scores, and then selecting the smallest subset of samples that best approximates the test accuracy achieved by a model trained on the full training dataset. Our work lies between hard example mining, as we assume sample difficulty to be related to sample importance, and data pruning, as we compute sample importance only once after a few training epochs, instead of adapting it during training as in hard example mining or coreset selection.

Motivation

There is increasing interest in understanding the importance that training samples have on model generalization to unseen data, in order to boost model performance by pruning, downweighting or downsampling less important samples. However, we find that existing methods generally face two main challenges. On one hand, they impose some computational overhead due to the computation of the samples' importance. On the other hand, they often rely on many parameters to select which samples to use during training to boost the model performance. We address these challenges by proposing a simple, yet effective learning method that allows a model to estimate sample importance early in training and to switch its focus to more important samples. Our method is closely related to the data pruning method by Paul et al. (2021), which requires training 10 model instances for 20 epochs to estimate the importance scores that are then used to train a new model instance from scratch.
This method uses four parameters: two to compute the sample scores (the number of models and the number of training epochs) and two to slide a window over the scores to select the subset of samples that generalizes best on the test data (window size and position). In contrast, our method requires no additional training iterations to compute sample importance, as it is computed only once from the predicted probabilities of a single model after a few training epochs. Training then continues by sampling important samples more often. Furthermore, our method has only two parameters: the epoch E at the end of which to compute the scores and a focusing parameter γ (Lin et al., 2017), which we fix to a constant value for all experiments.
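This schedule can be summarized in a minimal numpy sketch (not the authors' implementation): uniform scanning for the first E epochs, one importance estimate at the end of epoch E, then importance-weighted sampling. The `model.fit_batch` and `model.predict_probs` hooks and the focal-style score (1 − p)^γ are illustrative assumptions.

```python
import numpy as np

def train_with_importance(model, data, labels, epochs, E, gamma=2.0,
                          batch_size=128, rng=None):
    """Sketch: uniform scan for the first E epochs, then a single
    importance estimate (computed once at the end of epoch E) drives
    importance-weighted sampling for the remaining epochs.
    `model.fit_batch` and `model.predict_probs` are hypothetical hooks."""
    if rng is None:
        rng = np.random.default_rng()
    M = len(data)
    weights = np.full(M, 1.0 / M)             # uniform before epoch E
    for epoch in range(epochs):
        if epoch < E:
            order = rng.permutation(M)        # standard scan, no replacement
        else:                                 # bias SGD toward hard samples
            order = rng.choice(M, size=M, replace=True, p=weights)
        for start in range(0, M, batch_size):
            idx = order[start:start + batch_size]
            model.fit_batch(data[idx], labels[idx])
        if epoch == E - 1:                    # scores computed exactly once
            p_true = model.predict_probs(data)[np.arange(M), labels]
            scores = (1.0 - p_true) ** gamma  # focal-style importance
            weights = scores / scores.sum()
    return model
```

Note that, consistent with the two-parameter claim, only E and γ appear beyond the usual training hyperparameters; no extra forward or backward passes are introduced besides the single prediction pass at the end of epoch E.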

Contribution

We propose an SGD-based method that pretrains a model for a few epochs under the assumption that all training samples are equally important for generalization, then focuses on samples that are more important for learning the given task. In experiments we observe that:

• our method can identify sample importance with reduced computational cost and converges faster to a comparable or even better solution compared to state-of-the-art approaches;
• sample importance evolves in a model-specific way during training;
• sample importance can automatically balance the class distributions, and sampling by class importance can be more beneficial than balancing the class distributions by number of samples;
• our method allows multiple augmentations to be used more efficiently than the other baseline methods.

2. TRAINING OPTIMIZATION

Standard training of deep learning models makes the implicit assumption that each training sample is equally important, performing a uniform scan through all samples in every epoch. Each sample is used exactly the same number of times (i.e. the number of epochs). In contrast, we hypothesize that a sample can be more or less important for a given model to learn a given task, and that we can estimate sample importance during training to speed up learning by focusing on more important samples. Our core idea is to capture sample importance as sample difficulty: the harder a sample, the more a model can learn from it to improve generalization. Conversely, a sample that can be easily classified can be considered already learned and is therefore less informative.
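As a concrete illustration of difficulty as importance, a sample can be scored by how far the model's predicted probability for its true class is from 1. The scoring function below is a hypothetical example, raising that gap to a focal-style exponent in the spirit of the focusing parameter γ of Lin et al. (2017):

```python
def difficulty(p_true, gamma=2.0):
    """Focal-style difficulty: low confidence on the true class => high score.

    p_true: model's predicted probability for the sample's true class
    gamma:  focusing exponent; larger gamma suppresses easy samples more
    """
    return (1.0 - p_true) ** gamma

# An easy sample (confidently correct) scores near zero ...
easy = difficulty(0.95)   # (1 - 0.95)**2 ≈ 0.0025
# ... while a hard sample scores orders of magnitude higher.
hard = difficulty(0.20)   # (1 - 0.20)**2 ≈ 0.64
```

Normalizing such scores over the training set yields a sampling distribution that concentrates probability mass on samples the model has not yet learned.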

Problem formulation

Let S = {(x_i, y_i)}_{i=1}^{M} be a training dataset with M samples and V = {(x_i, y_i)}_{i=1}^{L} a test dataset with L samples, where x_i are images and y_i one-hot encoded class labels. Let further f_θ(·) be a model (e.g. a neural network) that can be trained using a gradient descent optimizer (e.g. stochastic gradient descent, SGD), θ its parameters, and f_θ(x_i) the output probabilities after softmax activation. Additionally, let N be the number of training epochs, B the training batch size, T_e = ⌊M/B⌋ the number of batches (or iterations per epoch), and T = T_e · N the overall number of training iterations (i.e. mini-batch gradient updates). Moreover, let

P(i | S_e, S) = 1/(|S| − |S_e|) · 1_{i ∉ S_e}    (1)

be the standard training sampler that uniformly scans through all the samples at every epoch without replacement (Chang et al., 2017), where 1 is an indicator function and S_e is the set of samples

