HOW TO FINE-TUNE VISION MODELS WITH SGD

Abstract

SGD (with momentum) and AdamW are the two most commonly used optimizers for fine-tuning large neural networks in computer vision. When the two methods perform the same, SGD is preferable because it uses less memory (12 bytes/parameter) than AdamW (16 bytes/parameter). However, on a suite of downstream tasks, especially those with distribution shifts, we show that fine-tuning with AdamW performs substantially better than SGD on modern Vision Transformer and ConvNeXt models. We find that large gaps in performance between SGD and AdamW occur when the fine-tuning gradients in the first "embedding" layer are much larger than in the rest of the model. Our analysis suggests an easy fix that works consistently across datasets and models: merely freezing the embedding layer (less than 1% of the parameters) leads to SGD performing competitively with AdamW while using less memory. Our insights result in state-of-the-art accuracies on five popular distribution shift benchmarks: WILDS-FMoW, WILDS-Camelyon, Living-17, Waterbirds, and DomainNet.

1. INTRODUCTION

Fine-tuning large pretrained models on downstream tasks has become a dominant approach in deep learning (Kornblith et al., 2019; Chen et al., 2020; Zhai et al., 2020). The two most commonly used optimizers in current practice are SGD and AdamW (Kingma & Ba, 2015; Loshchilov & Hutter, 2019)[1]. While most modern vision architectures (ViTs, ConvNeXts, and variants) increasingly use AdamW for pretraining, it is still common to use SGD for fine-tuning. Part of the appeal is that SGD is more memory and compute efficient: AdamW maintains 4 states/parameter, while SGD only maintains 3 states/parameter (Ginsburg et al., 2019; Dettmers et al., 2022). In training ultra-large models, the additional memory from even 1 extra state/parameter can be costly. At the same time, in terms of fine-tuning accuracies, prior work (Dosovitskiy et al., 2021; Steiner et al., 2021; Kumar et al., 2022) reports similar performance between AdamW and SGD on ImageNet-like domains that are close to the pretraining data. In contrast, we reach different conclusions when fine-tuning on datasets that are far from the pretraining data or have substantial distribution shifts.

We examine 7 popular models, including vision transformers (Dosovitskiy et al., 2021; Caron et al., 2021; Radford et al., 2021), ConvNeXts (Liu et al., 2022), and ResNets (Kolesnikov et al., 2020; He et al., 2016), of different sizes and pretraining modalities. When pretrained on a large corpus and then fine-tuned, these models achieve near state-of-the-art performance on downstream benchmarks. In addition to good transfer learning, we also want our fine-tuned models to handle practical distribution shifts gracefully, so we focus on 5 distribution shift datasets that have both in-distribution (ID) and out-of-distribution (OOD) evaluations: WILDS-FMoW, WILDS-Camelyon, Waterbirds, BREEDS-Living-17, and DomainNet.
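The memory argument above can be made concrete with a back-of-the-envelope calculation. The sketch below is an illustration, not the paper's profiling code: it assumes fp32 (4 bytes) for every stored quantity and ignores activations, batch size, and mixed-precision tricks. The 3-vs-4 states/parameter counts match the abstract's 12 vs. 16 bytes/parameter.

```python
# Hypothetical estimate of fine-tuning memory for optimizer state alone,
# assuming fp32 (4 bytes) per stored value. Activations and batch memory
# are ignored; this only compares optimizer footprints.
def finetune_memory_gb(num_params, optimizer):
    # States kept per parameter: the weight itself, its gradient, plus
    # optimizer-specific buffers.
    states = {
        "sgd_momentum": 3,  # param + grad + momentum buffer (12 bytes/param)
        "sgd_plain": 2,     # param + grad, no momentum (8 bytes/param)
        "adamw": 4,         # param + grad + first and second moments (16 bytes/param)
    }[optimizer]
    return num_params * states * 4 / 1e9

vit_b16 = 86_000_000  # a ViT-B/16 has roughly 86M parameters
print(f"SGD (momentum): {finetune_memory_gb(vit_b16, 'sgd_momentum'):.2f} GB")
print(f"AdamW:          {finetune_memory_gb(vit_b16, 'adamw'):.2f} GB")
```

Even for this mid-sized model the gap is a few hundred megabytes; for ultra-large models the single extra state of AdamW becomes gigabytes.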
These were selected to capture different types of data shifts (subpopulation shifts, spurious correlations, style shifts), including two real-world shifts in medical imaging and satellite remote sensing from the WILDS benchmark (Koh et al., 2021). We find that on newer models like ViTs and ConvNeXt, AdamW can have significantly higher accuracies, especially OOD. For example, averaged across the datasets, fine-tuning a CLIP ViT-B/16 model with AdamW gets 2.1% higher accuracy ID and 8.1% higher accuracy OOD compared to SGD (Figure 1b). These gains are consistent across models too: averaged across all models and datasets, AdamW gets 1.2% higher accuracy ID and 4.0% higher accuracy OOD (Tables 1 and 2). Interestingly, a minor tweak to SGD where we freeze the first "embedding" layer (< 1% of parameters; see Figure 1a) is competitive with AdamW while using less GPU memory. Further, dropping the momentum state in SGD gives additional gains in accuracy at even lower memory cost. Gradually unfreezing a model's layers using a carefully annealed learning rate can further improve accuracies, while reducing computation.

A key difference between AdamW and SGD is that AdamW normalizes the gradient update of each parameter using an estimate of its second moment. Thus, parameters with consistently high gradients will change less under AdamW than under SGD. To better understand these dynamics, we examine the gradients at each layer of our pretrained models. We find that for the models where AdamW significantly outperforms SGD, the gradients at pretrained initialization of the first "embedding" layer are much larger than the gradients of the other layers. To test whether over-training of the embedding layer is in fact why SGD performs worse than AdamW, we consider a minor modification where we freeze the embedding layer and tune the rest of the model with SGD (Figure 1a); we call this SGD (freeze-embed).
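The layer-wise gradient diagnostic described above can be sketched in a few lines of PyTorch. The toy model below (a conv "patch embedding" plus an MLP trunk) is purely illustrative; on a real pretrained ViT one would load pretrained weights and a real data batch, then compare the embedding layer's gradient norm against the rest.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a pretrained ViT: a conv "patch embedding" followed by an
# MLP trunk and a classification head. Architecture and sizes are illustrative.
model = nn.Sequential()
model.add_module("embed", nn.Conv2d(3, 8, kernel_size=4, stride=4))
model.add_module("flatten", nn.Flatten())
model.add_module("trunk", nn.Linear(8 * 8 * 8, 32))
model.add_module("head", nn.Linear(32, 10))

x = torch.randn(4, 3, 32, 32)    # random "images"
y = torch.randint(0, 10, (4,))   # random labels
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

# Per-layer gradient norms at initialization: the diagnostic question is
# whether the embedding layer's gradients dwarf those of the other layers.
for name, module in model.named_children():
    grads = [p.grad for p in module.parameters() if p.grad is not None]
    if grads:
        norm = torch.norm(torch.stack([g.norm() for g in grads]))
        print(f"{name:8s} grad norm = {norm:.4f}")
```

With a real pretrained checkpoint, a disproportionately large norm on the embedding layer is the signature associated with SGD underperforming AdamW.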
In vision transformer models, the embedding layer is only a small fraction (around 0.7% for ViT-B/16) of the total parameters of the model, so a priori we might not expect a substantial difference in accuracies. Surprisingly, however, this simple freezing of the embedding layer consistently improves SGD performance across most models and datasets, achieving ID and OOD accuracies that are competitive with or better than AdamW (Figure 1b). Averaged across all datasets and models, SGD (freeze-embed) gets 76.7% accuracy OOD (vs. 72.0% for SGD and 76.0% for AdamW). The analogous AdamW (freeze-embed) gets 76.5%, which does not improve over SGD (freeze-embed), supporting the hypothesis that the embedding layer is the reason AdamW outperforms SGD (it is not an independent axis of improvement).

Inspired by our results from SGD (freeze-embed), we explored further variations of our method on CLIP ViTs. First, we tried a more memory-efficient variation, SGD (freeze-embed, no momentum), which drops the momentum state in SGD. At least on CLIP models, this variation gets better accuracy than SGD (freeze-embed) while giving additional memory gains. Second, we revisit the layer-wise unfreezing technique proposed in prior work like ULMFiT (Howard & Ruder, 2018) and related methods (Mukherjee & Awadallah, 2019; Romero et al., 2020). We find that with a certain exponential decay in learning rate, a gradual unfreezing procedure can in some cases lead to further gains in ID and OOD accuracies (Figure 1b). However, a large fraction of the gain over SGD appears to come from simply freezing the embedding layer, which suggests that the embedding layer plays a key role in modern vision architectures and merits further investigation.

In terms of resource comparison, our profiling on a Titan-X GPU shows that on a ViT-B/16, AdamW uses 16% and 36% more memory than SGD (freeze-embed) and SGD (freeze-embed, no momentum), respectively.
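A minimal sketch of SGD (freeze-embed) in PyTorch follows. The parameter-name prefix `patch_embed` matches timm-style ViTs but is an assumption: the embedding layer's attribute name varies across libraries, so adjust the prefix to your model. The toy `ToyViT` class below is purely illustrative wiring, not the paper's code.

```python
import torch
import torch.nn as nn

# Sketch of SGD (freeze-embed): freeze all parameters whose names start with
# the embedding-layer prefix, then hand only the remaining parameters to SGD.
def make_freeze_embed_sgd(model, lr=1e-3, momentum=0.9, embed_prefix="patch_embed"):
    for name, param in model.named_parameters():
        if name.startswith(embed_prefix):
            param.requires_grad_(False)  # freeze the embedding layer
    trainable = [p for p in model.parameters() if p.requires_grad]
    # Setting momentum=0 gives the even cheaper SGD (freeze-embed, no momentum)
    # variant, which drops the momentum buffer entirely.
    return torch.optim.SGD(trainable, lr=lr, momentum=momentum)

# Toy model standing in for a ViT, just to show the wiring.
class ToyViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, 16, kernel_size=16, stride=16)
        self.head = nn.Linear(16, 10)

    def forward(self, x):
        z = self.patch_embed(x).mean(dim=(2, 3))  # pool over patch grid
        return self.head(z)

model = ToyViT()
opt = make_freeze_embed_sgd(model)
print(model.patch_embed.weight.requires_grad)  # embedding is now frozen
```

Because the frozen parameters never enter the optimizer, no gradient, momentum, or moment state is allocated for them, which is where the memory savings come from.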
The memory overhead of AdamW grows with model size: on a ViT-L/14, AdamW uses 18% and 49% more memory than SGD (freeze-embed) and SGD (freeze-embed, no momentum), respectively. Gradual unfreezing additionally reduces computation time by about 30% relative to AdamW. These methods and insights, while simple, lead to state-of-the-art accuracies on all five datasets: WILDS-Camelyon, WILDS-FMoW, DomainNet, Waterbirds, and BREEDS Living-17, while being more compute and memory efficient than AdamW.
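The gradual unfreezing schedule with an annealed learning rate can be sketched as follows. This is a hedged illustration under assumed hyperparameters (decay factor, one block unfrozen per stage), not the paper's exact recipe: stage 0 tunes only the head, and each later stage unfreezes one more block from the top, with deeper blocks receiving exponentially smaller learning rates.

```python
import torch
import torch.nn as nn

def param_groups_for_stage(blocks, head, stage, base_lr=1e-3, decay=0.5):
    """Parameter groups for gradual unfreezing at a given stage.

    `blocks` is ordered from input to output. Stage 0 tunes only the head;
    each later stage unfreezes one more block from the top, with an
    exponentially decayed learning rate for deeper blocks. Blocks not listed
    are never passed to the optimizer, so they receive no updates (and the
    bottom-most embedding block can simply stay frozen, as in freeze-embed).
    """
    groups = [{"params": list(head.parameters()), "lr": base_lr}]
    for depth, block in enumerate(list(reversed(blocks))[:stage], start=1):
        groups.append({"params": list(block.parameters()),
                       "lr": base_lr * decay ** depth})
    return groups

# Toy trunk: three blocks plus a head; sizes are illustrative.
blocks = [nn.Linear(4, 4) for _ in range(3)]
head = nn.Linear(4, 2)
groups = param_groups_for_stage(blocks, head, stage=2)
opt = torch.optim.SGD(groups, lr=1e-3)
```

The compute saving comes from the early stages: while lower blocks are still frozen, their backward pass and parameter updates are skipped entirely.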



[1] We use SGD to refer to its usage in deep learning as minibatch stochastic gradient descent with momentum.



(a) Simplified schematic of ViT illustrating how we do freeze-embed. (b) Performance of different fine-tuning methods on a CLIP ViT-B/16, averaged over 5 distribution shift datasets.

Figure 1: We fine-tune 7 models, including ViTs, DINO, CLIP, ConvNeXt, and ResNet, on 5 distribution shift datasets (Living-17, Waterbirds, DomainNet, WILDS-Camelyon, WILDS-FMoW). Fine-tuning with SGD gets lower accuracies than AdamW on modern architectures (transformers and ConvNeXt), especially OOD. Interestingly, a minor tweak to SGD where we freeze the first "embedding" layer (< 1% of parameters; see Figure 1a) is competitive with AdamW while using less GPU memory. Further, dropping the momentum state in SGD gives additional gains in accuracy at even lower memory cost. Gradually unfreezing a model's layers using a carefully annealed learning rate can further improve accuracies, while reducing computation.

