SURGICAL FINE-TUNING IMPROVES ADAPTATION TO DISTRIBUTION SHIFTS

Abstract

A common approach to transfer learning under distribution shift is to fine-tune the last few layers of a pre-trained model, preserving learned features while also adapting to the new task. This paper shows that in such settings, selectively fine-tuning a subset of layers (which we term surgical fine-tuning) matches or outperforms commonly used fine-tuning approaches. Moreover, the type of distribution shift influences which subset is more effective to tune: for example, for image corruptions, fine-tuning only the first few layers works best. We validate our findings systematically across seven real-world data tasks spanning three types of distribution shifts. Theoretically, we prove that for two-layer neural networks in an idealized setting, first-layer tuning can outperform fine-tuning all layers. Intuitively, fine-tuning more parameters on a small target dataset can cause information learned during pre-training to be forgotten, and the relevant information depends on the type of shift.

1. INTRODUCTION

While deep neural networks have achieved impressive results in many domains, they are often brittle to even small distribution shifts between the source and target domains (Recht et al., 2019; Hendrycks & Dietterich, 2019; Koh et al., 2021). While many approaches to robustness attempt to directly generalize to the target distribution after training on source data (Peters et al., 2016; Arjovsky et al., 2019), an alternative approach is to fine-tune on a small number of labeled target datapoints. Collecting such small labeled datasets can improve downstream performance in a cost-effective manner while substantially outperforming domain generalization and unsupervised adaptation methods (Rosenfeld et al., 2022; Kirichenko et al., 2022). We therefore focus on settings where we first train a model on a relatively large source dataset and then fine-tune the pre-trained model on a small target dataset, as a means of adapting to distribution shifts.

The motivation behind existing fine-tuning methods is to fit the new data while also preserving the information obtained during the pre-training phase. Such information preservation is critical for successful transfer learning, especially in scenarios where the source and target distributions share a lot of information despite the distribution shift. To reduce overfitting during fine-tuning, existing works have proposed using a smaller learning rate than in initial pre-training (Kornblith et al., 2019; Li et al., 2020), freezing the early backbone layers and gradually unfreezing them (Howard & Ruder, 2018; Mukherjee & Awadallah, 2019; Romero et al., 2020), or using a different learning rate for each layer (Ro & Choi, 2021; Shen et al., 2021).

In this paper, we present a result in which preserving information in a non-standard way leads to better performance. Contrary to the conventional wisdom that one should fine-tune the last few layers to re-use the learned features, we observe that fine-tuning only the early layers of the network yields better performance on image corruption datasets such as CIFAR-10-C (Hendrycks & Dietterich, 2019). More specifically, as an initial finding, when transferring a model pre-trained on CIFAR-10 to CIFAR-10-C by fine-tuning on a small amount of labeled corrupted images, fine-tuning only the first block of layers and freezing the others outperforms full fine-tuning of all parameters by almost 3% on average on unseen corrupted images.

To better understand this counterintuitive result, we study a general class of fine-tuning algorithms which we call surgical fine-tuning: fine-tuning only a small contiguous subset of all layers in the pre-trained neural network. Equivalently, surgical fine-tuning can be defined as freezing all but a few layers during fine-tuning. Parameter freezing can be beneficial because, depending on the relationship between the source and target tasks, some layer parameters trained on the source task may already be close to a minimum for the target distribution; freezing these layers can therefore facilitate generalization to the target distribution. We evaluate the performance of surgical fine-tuning with various layer choices on 7 different distribution shift scenarios, which we categorize into input-level, feature-level, and output-level shifts. As shown in Figure 1, fine-tuning only the first block of layers, the middle block, or the last layer can perform best under different distribution shifts, with the best such subset consistently outperforming fine-tuning all parameters.
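As a concrete illustration of surgical fine-tuning, the PyTorch sketch below freezes every parameter except those in a chosen block and fine-tunes only that block on the small labeled target set. This is a minimal sketch rather than our exact training code: the block prefixes follow torchvision's ResNet-18 naming, and the optimizer, learning rate, and handling of BatchNorm statistics are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

def surgical_finetune(model, tunable_prefixes, target_loader, lr=1e-3, epochs=10):
    """Fine-tune only the parameters whose names start with one of
    `tunable_prefixes`; all remaining parameters stay frozen."""
    for name, param in model.named_parameters():
        param.requires_grad = any(name.startswith(p) for p in tunable_prefixes)

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    # Note: BatchNorm running statistics still update in train() mode even when
    # the affine parameters are frozen; how to handle them is a design choice
    # not specified here.
    model.train()
    for _ in range(epochs):
        for x, y in target_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
    return model

# Illustrative usage with torchvision's ResNet-18 naming (an assumption, not the
# exact architecture from our experiments):
#   first block (input-level shifts):  ["conv1", "bn1", "layer1"]
#   last layer (output-level shifts):  ["fc"]
model = resnet18(num_classes=10)
# model.load_state_dict(torch.load("source_pretrained.pt"))  # pre-trained on source data
# model = surgical_finetune(model, ["conv1", "bn1", "layer1"], target_loader)
```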
To support our empirical results, we theoretically analyze why different types of distribution shifts require fine-tuning different layers. For two-layer neural networks, we show why fine-tuning the first layer is better for input perturbations while fine-tuning the last layer is better for label perturbations. We then present a setting where surgical fine-tuning of the first layer provably outperforms fine-tuning all parameters. If the target distribution contains only a few new "directions" (inputs outside the span of the source distribution), we show that tuning only the first layer can learn these new directions from very few target examples while preserving all the information learned from the source distribution. In contrast, we show that full fine-tuning forgets information learned from the source distribution: the last layer changes to accommodate the new target directions, but then performs poorly on examples outside the span of the training data.

Motivated by the theoretical insight that freezing some layers can help generalization, we empirically analyze two criteria for automatically selecting which layers to tune based on loss gradients. Tuning the layers selected by these criteria can also outperform full fine-tuning, though this procedure does not outperform manually choosing the best layers to tune.

Our main contribution is the empirical observation that fine-tuning only a small contiguous subset of layers can outperform full fine-tuning on a range of distribution shifts. Intriguingly, the best layers to tune differ across distribution shift types (Figure 1). This finding is validated empirically across seven real-world datasets and three types of distribution shifts, and theoretically in an idealized two-layer neural network setup. We additionally analyze two criteria for automatically selecting which layers to tune and find that fine-tuning only the layers with higher relative gradient norm outperforms full fine-tuning.
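The sketch below shows one plausible instantiation of such a gradient-based criterion: the relative gradient norm of each parameter tensor, computed on a batch of labeled target data. It is a simplified sketch under our own assumptions; the exact per-layer aggregation and how the scores are turned into a tuning decision (hard selection of the top layers versus per-layer learning-rate scaling) may differ from the procedure used in our experiments.

```python
import torch
import torch.nn as nn

def relative_gradient_norms(model, loss_fn, x, y):
    """Relative gradient norm ||grad|| / ||param|| of each parameter tensor,
    computed on a single batch of labeled target data."""
    model.zero_grad()
    loss_fn(model(x), y).backward()
    rgn = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            rgn[name] = (param.grad.norm() / (param.norm() + 1e-12)).item()
    model.zero_grad()
    return rgn

# Rank parameter tensors by relative gradient norm and surgically fine-tune the
# top-k of them (k is an illustrative hyperparameter, not a value from the paper):
# rgn = relative_gradient_norms(model, nn.CrossEntropyLoss(), x_target, y_target)
# tunable = sorted(rgn, key=rgn.get, reverse=True)[:k]
```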

2. SURGICAL FINE-TUNING: FREEZING PARAMETERS DURING ADAPTATION

Our problem setting assumes two datasets from different distributions: a large dataset drawn from the source distribution P_src, and a relatively small dataset drawn from the target distribution P_tgt. The objective is to achieve high accuracy on target data by leveraging the different but closely related source distribution, a common scenario in real-world applications that require adaptation. For example, the source dataset can be the 50,000 training images in CIFAR-10 (Krizhevsky et al., 2009) while the target dataset is a smaller set of 1,000 corrupted CIFAR datapoints sharing the same image corruption (Hendrycks & Dietterich, 2019); see Figure 1 for more examples of source-target distribution pairs. A sketch of this example data setup follows.
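The snippet below sets up the CIFAR-10 to CIFAR-10-C example above. The file paths, the choice of corruption, and the preprocessing of the publicly released CIFAR-10-C .npy arrays are illustrative assumptions, not details from our experimental protocol.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms

# Source: the full CIFAR-10 training set (50,000 images).
source_train = datasets.CIFAR10(root="data", train=True, download=True,
                                transform=transforms.ToTensor())

# Target: 1,000 labeled images under a single CIFAR-10-C corruption. CIFAR-10-C
# is released as .npy arrays; the paths and corruption chosen here are assumptions,
# and each corruption file stacks five severity levels.
images = np.load("CIFAR-10-C/gaussian_noise.npy")   # (N, 32, 32, 3), uint8
labels = np.load("CIFAR-10-C/labels.npy")           # (N,)
x = torch.from_numpy(images).permute(0, 3, 1, 2).float() / 255.0
y = torch.from_numpy(labels).long()

target_ds = TensorDataset(x, y)
target_train = Subset(target_ds, range(1000))                  # small labeled target set
target_eval = Subset(target_ds, range(1000, len(target_ds)))   # held-out corrupted images
target_loader = DataLoader(target_train, batch_size=128, shuffle=True)
```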



Figure 1: Surgical fine-tuning, where we tune only one block of parameters and freeze the remaining parameters, outperforms full fine-tuning on a range of distribution shifts. Moreover, we find that tuning different blocks performs best for different types of distribution shifts. Fine-tuning the first block works best for input-level shifts such as CIFAR-C (image corruption), later blocks work best for feature-level shifts such as Entity-30 (shift in entity subgroup), and tuning the last layer works best for output-level shifts such as CelebA (spurious correlation between gender and hair color).

