A CLOSER LOOK AT MODEL ADAPTATION USING FEATURE DISTORTION AND SIMPLICITY BIAS

Abstract

Advances in the expressivity of pretrained models have increased interest in the design of adaptation protocols which enable safe and effective transfer learning. Going beyond conventional linear probing (LP) and fine-tuning (FT) strategies, protocols that can effectively control feature distortion, i.e., the failure to update features orthogonal to the in-distribution, have been found to achieve improved out-of-distribution (OOD) generalization. In order to limit this distortion, the LP+FT protocol, which first learns a linear probe and then uses this initialization for subsequent FT, was proposed. However, in this paper, we find that when adaptation protocols (LP, FT, LP+FT) are also evaluated on a variety of safety objectives (e.g., calibration, robustness, etc.), a complementary perspective to feature distortion is helpful to explain protocol behavior. To this end, we study the susceptibility of protocols to simplicity bias (SB), i.e., the well-known propensity of deep neural networks to rely upon simple features, as SB has recently been shown to underlie several problems in robust generalization. Using a synthetic dataset, we demonstrate the susceptibility of existing protocols to SB. Given the strong effectiveness of LP+FT, we then propose modified linear probes that help mitigate SB and lead to better initializations for subsequent FT. We verify the effectiveness of the proposed LP+FT variants in decreasing SB in a controlled setting, and their ability to improve OOD generalization and safety on three adaptation datasets.

1. INTRODUCTION

Figure 1: Strong and Safe Adaptation. Practical deployment in high-risk applications requires that adapted models not only generalize well to in- and out-of-distribution data of the downstream task, but that they do so safely.

Through the use of larger datasets (Yalniz et al., 2019), better architectures (Zhai et al., 2022; Chen et al., 2022; Steiner et al., 2022; Tolstikhin et al., 2021), and different self-supervised learning (SSL) approaches (He et al., 2020; Chen et al., 2020; Grill et al., 2020; Caron et al., 2020), the quality of pretrained representations available for transfer learning tasks has dramatically improved. Indeed, representations from such high-quality SSL models have been found to be more robust (Hendrycks et al., 2019; Liu et al., 2021), transferable (Ericsson et al., 2021) and semantically consistent (Caron et al., 2021) than their supervised counterparts. In this regard, there is a growing need for adaptation protocols that explicitly capitalize on these improved pretrained features to induce similarly beneficial properties, i.e., more than just high accuracy on the target task, after models have been trained on the downstream task. However, standard adaptation protocols that rely upon fine-tuning (FT) all model parameters or training only a linear probe (LP) while freezing the network parameters do not maximize the potential of high-quality representations. For example, while high-quality, pretrained models have sufficiently expressive features to perform well on both in-distribution (ID) and out-of-distribution (OOD) data, LP and FT are not able to effectively induce this property in adapted models (Andreassen et al., 2021). Recently, however, Kumar et al. (2022) proved that by modifying features only in the ID representation subspace, FT can lead to higher OOD error, as it distorts directions outside the ID subspace that are needed for OOD generalization. As both ID and OOD subspaces are represented by the pretrained model, Kumar et al. demonstrate that limiting feature distortion, i.e., controlling updates toward the ID subspace, can lead to improved ID and OOD performance. To this end, they propose a new protocol which performs LP prior to FT (abbrev. LP+FT).
By first performing LP, this two-step process ensures that subsequent FT will remain in the vicinity of the original LP solution. This reduces the overall distortion toward the ID subspace and improves performance. While strong ID and OOD generalization on the target task is indeed an important aspect of transfer learning, practical, high-risk applications require that models are also safe (Hendrycks et al., 2021). For example, adapted models should also be well-calibrated, robust to corruptions or adversaries, and able to reliably detect anomalous samples (see Figure 1). Given that existing adaptation protocols are primarily focused on improving generalization, it is unclear how existing protocols utilize high-quality pretrained features to promote safe adaptation, and whether current protocol design perspectives, such as mitigating feature distortion, will also enable safe generalization.

Our Work: In this work, we seek to understand the factors relevant to the design of adaptation protocols that promote effective and safe generalization. We take the first step towards this aim by (i) demonstrating limitations in existing LP, FT, and LP+FT protocols through an extensive, joint evaluation, and (ii) studying adaptation protocols through the complementary lens of avoiding simplicity bias, i.e., the problematic tendency of deep neural networks (DNNs) to prefer simple, potentially brittle features over complex features (Soudry et al., 2018; Gunasekar et al., 2018; Geirhos et al., 2019; Hermann et al., 2020; Shah et al., 2020). Using the insights from our analysis, we propose three variants of the LP+FT protocol that jointly improve safety and generalization on three datasets. Our contributions can be summarized as follows:

• Joint Analysis of Adaptation Protocol Safety and Generalization (Sec. 3). We show that when adaptation protocols are evaluated with respect to both ID/OOD generalization and safety, LP+FT trails behind LP or FT on several safety metrics. This demonstrates that solely mitigating feature distortion may not be sufficient for safe generalization. We also observe that keeping subsequent FT close to the LP solution is crucial for the improved OOD generalization of LP+FT. This motivates us to focus on improving the LP initialization as a mechanism for improving both safety and OOD performance.

• Role of Simplicity Bias in (Unsafe) Adaptation (Sec. 4). To understand how protocols may induce safe adaptation, we study how different protocols avoid simplicity bias. While simplicity bias (Shah et al., 2020; Geirhos et al., 2019) has been shown to underlie several problems in machine learning safety, to the best of our knowledge, we are the first to consider its role in adaptation settings. We demonstrate that protocols must not only reduce distortion, but should also mitigate simplicity bias for effective adaptation.

• Improved Protocols for Mitigating Simplicity Bias and Distortion (Sec. 5). We propose three simple, modified LP+FT protocols that help mitigate both simplicity bias and distortion (Sec. 4.1). In particular, we consider modifying the LP step with uncertainty-driven perturbations (Pagliardini et al., 2022), virtual adversarial training (Miyato et al., 2017) and model soups (Wortsman et al., 2022), as they are simple and effective strategies. Across synthetic and real datasets, the modified protocols help improve safety and generalization to some extent.
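The two-step LP+FT procedure can be sketched end-to-end on a toy problem. The linear encoder, logistic probe, data shapes, and learning rates below are illustrative assumptions, not the paper's actual models or hyperparameters; the point is only the order of operations: probe on frozen features first, then fine-tune everything from that initialization.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-np.clip(z, -30, 30)))

def lp_then_ft(X, y, W_enc, lp_steps=300, ft_steps=50, lp_lr=0.5, ft_lr=0.005):
    """Two-step LP+FT: (1) train a linear probe on frozen encoder features,
    (2) fine-tune encoder and probe jointly, starting from the LP solution."""
    rng = np.random.default_rng(0)
    w = rng.normal(scale=0.01, size=W_enc.shape[0])
    # Step 1: linear probing -- the encoder W_enc stays frozen.
    H = X @ W_enc.T                          # frozen features, shape (n, d_feat)
    for _ in range(lp_steps):
        p = sigmoid(H @ w)
        w -= lp_lr * H.T @ (p - y) / len(y)  # logistic-loss gradient step
    # Step 2: fine-tuning -- all parameters updated from the LP initialization,
    # so the solution stays in the vicinity of the probe found in step 1.
    W = W_enc.copy()
    for _ in range(ft_steps):
        H = X @ W.T
        p = sigmoid(H @ w)
        g = (p - y) / len(y)
        w_grad = H.T @ g                     # probe gradient
        W_grad = np.outer(w, g @ X)          # encoder gradient (chain rule)
        w -= ft_lr * w_grad
        W -= ft_lr * W_grad
    return W, w
```

Because step 2 starts from a probe that already fits the ID data, the encoder gradients are small and the fine-tuned encoder stays close to its pretrained weights, which is precisely the distortion-limiting effect described above.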

2. RELATED WORK AND BACKGROUND

Here, we discuss the most relevant work on adaptation protocols and simplicity bias; we discuss additional related work in Sup. A.

Adaptation Protocols. For a comprehensive overview of transfer learning, please see the surveys of Zhuang et al. (2021) and Pan & Yang (2010). Here, we discuss the works that are most relevant to our own. Kirichenko et al. (2022) recently demonstrated that models are able to learn both core features and spurious features. However, classifiers can rely upon spurious features, harming performance on minority groups. To reduce the reliance on spurious features, they propose to retrain the classifier on a small amount of "re-weighting" data, which allows the model to leverage the core features instead of the spurious features. Other modifications and heuristics have also been proposed to improve FT's performance, including side-tuning (Zhang et al., 2019), which tunes a small secondary network that is then combined with the original model, using larger/smaller learning rates for the classifier, as well as regularization-based methods (Jiang et al., 2020). In this work, we focus on two popular and effective protocols, LP and FT. We additionally study the LP+FT protocol as it is theoretically grounded, does not require re-weighting data, is designed to exploit high-quality pretrained representations, and achieves SOTA OOD performance during adaptation.

Simplicity Bias. It is well-known that DNNs demonstrate a bias toward simple, potentially less expressive features (Brutzkus et al., 2017; Soudry et al., 2018; Gunasekar et al., 2018; Geirhos et al., 2019; Hermann et al., 2020; Lubana et al., 2023), such as textures and backgrounds, and that this bias can lead to shortcuts that limit the generalization of DNNs. Indeed, Shah et al. (2020) recently formalized this intuition by more precisely defining simplicity bias, based on the number of linear components needed to define a decision boundary, and showed that SB leads to non-robust decision boundaries that affect a model's sensitivity to distribution shifts and adversarial perturbations. In brief, by learning simple features first, models become invariant to complex features, potentially leading to narrow decision boundaries which can fail to generalize under data shifts. Notably, DNNs exhibit this bias even when complex features are more expressive and necessary for fitting the distribution. While various techniques have recently been proposed to mitigate simplicity bias when training from scratch or in the context of pretraining (Teney et al., 2021), we are, to the best of our knowledge, the first to rigorously study the role of simplicity bias in the context of model adaptation.
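As a minimal illustration of this margin-driven preference, consider logistic regression with two perfectly predictive features, one with a 4x larger margin (the "simpler" feature). Gradient descent on the logistic loss drifts toward the max-margin direction and so concentrates weight on the large-margin feature. The data, scales, and training loop below are illustrative assumptions, not an experiment from the paper.

```python
import numpy as np

def train_logreg(X, y, steps=2000, lr=0.5):
    """Full-batch gradient descent on the (unregularized) logistic loss."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))
        w -= lr * X.T @ (p - y) / len(y)
    return w

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=200).astype(float)
z = 2 * y - 1  # labels as +/-1
# Both features predict the label perfectly, but feature 0 ("simple") has a
# 4x larger margin; the learned weights concentrate on it.
X = np.stack([2.0 * z, 0.5 * z], axis=1)
w = train_logreg(X, y)
```

Even though either feature alone suffices, the learned weight on the large-margin feature dominates, mirroring how a network can become effectively invariant to the harder feature.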

3. JOINT ANALYSIS OF PROTOCOL SAFETY AND GENERALIZATION

In this section, we evaluate the performance of adaptation protocols across several additional safety objectives (Hendrycks et al., 2021), as practical transfer learning applications require both strong and safe generalization. Through this expanded evaluation, we find that no single protocol is optimal across all safety objectives. Indeed, the inability of LP+FT to induce safe adaptation indicates that a complementary perspective to feature distortion, namely simplicity bias, is necessary when designing generalizable and safe protocols (see Sec. 4). We further argue that by constraining models around the LP initialization during FT, LP+FT may inadvertently harm safety performance by hampering models' ability to learn complex, task-specific features needed for robust generalization. While we expand upon the role of LP initialization in Secs. 4 and 5, we begin, here, by introducing the expanded evaluation and experimental setup.

Experimental Setup. Three downstream adaptation tasks (and their respective OOD distributions) are considered: CIFAR10 (ID) → {STL10, CIFAR10.1} (OOD), DomainNet-Sketch → {DomainNet-ClipArt, DomainNet-Painting, DomainNet-Real}, and Living17 (Source) → Living17 (Target). These datasets are selected as they correspond to two different types of distribution shift (standard domain adaptation and subpopulation shift) and three levels of distortion (low, medium, high). A MoCo-V2 ResNet-50 (He et al., 2020) pretrained on ImageNet-1K is used as the base feature extractor for the CIFAR10 and Living17 experiments, and the CLIP ResNet-50 image encoder pretrained on 400 million (image, text) pairs is used for DomainNet-Sketch. These models are selected as they provide sufficiently high-quality representations capable of generalizing to both ID and OOD downstream data (Kumar et al., 2022). We perform a grid search to find the best hyperparameters, and average over 3 seeds. See Sup. B.2 for additional details.

Expanded Evaluation.
In addition to OOD accuracy on the aforementioned distribution shifts, we report performance on the following metrics in order to evaluate adapted models on key problems in machine learning safety (Hendrycks et al., 2021). Our evaluation setup is inspired by Hendrycks et al. (2022):

• Mean Corruption Accuracy (mCA/m̄CA): We consider two sets of corruptions: the 15 naturalistic corruptions (Corr) of Hendrycks & Dietterich (2019), and 10 perceptually dissimilar corruptions (C̄orr) of Mintun et al. (2021). Corruptions are applied to the ID test dataset and the average accuracy over each set is reported.

• Calibration Error (RMSE): It is important that models are well-calibrated so that practitioners may trust the provided predictions in high-risk applications (Guo et al., 2017). We measure the root mean square calibration error, √(E_C[(P(Y = Ŷ | C = c) − c)²]), where C denotes the confidence scores, while Ŷ and Y denote the model's predictions and ground-truth labels, respectively.

• Anomaly Detection Performance (AUROC): Recognizing when samples are anomalous allows models to abstain from making uninformed and inapplicable predictions. We consider samples from the Blobs, Gaussian, LSUN, Places69, Rademacher, Textures, and SVHN datasets as anomalies, and report the AUROC (area under the ROC curve) for the binary classification problem of detecting such samples as anomalies.

• Adversarial Accuracy: DNNs are well-known to be fooled by imperceptible distortions (Ilyas et al., 2019). We use a 10-step PGD attack (Madry et al., 2018) with budget 2/255 to measure the robustness of models to such perturbations.

We make the following observations regarding the behavior of different protocols using this expanded evaluation.
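The RMS calibration error above is typically estimated by binning predictions by confidence. The sketch below is one common formulation with equal-width bins; the bin count (15) is an assumption for illustration, not necessarily the paper's exact choice.

```python
import numpy as np

def rms_calibration_error(confidences, correct, n_bins=15):
    """RMS calibration error: square root of the bin-weighted mean squared
    gap between mean confidence and empirical accuracy, using equal-width
    confidence bins over [0, 1]."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total, n = 0.0, len(confidences)
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gap = correct[mask].mean() - confidences[mask].mean()
            total += (mask.sum() / n) * gap ** 2
    return np.sqrt(total)
```

A perfectly calibrated model (e.g., 75% confidence with 75% accuracy) scores 0, while systematic overconfidence (90% confidence, 50% accuracy) scores 0.4.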
In brief, we find that no single protocol is effective across all datasets in jointly obtaining strong and safe adaptation, and that, on low-distortion adaptation tasks, the quality of the LP initialization is critical, as the pretrained feature extractor is not substantially updated during LP+FT.

OBSERVATION 1: NO SINGLE PROTOCOL IS OPTIMAL ACROSS SAFETY OBJECTIVES. Here, we ask how protocols perform when we consider both safety and generalization objectives, to better understand the feature distortion perspective. In particular, if LP+FT is able to outperform LP and FT in this expanded evaluation, then it suggests that solely mitigating feature distortion during FT may be sufficient to induce robust adaptation. To test this claim, we rank protocol performance for each safety metric, where ranks are first computed for each dataset separately and then averaged. Results are shown in Fig. 2; smaller ranks correspond to better performance.

Results. LP+FT obtains the best rank for ID and OOD accuracy, as expected, as well as for Corr and C̄orr accuracy. However, we also see that FT is better ranked for adversarial accuracy and OOD calibration, while LP is better ranked for ID calibration and Corr calibration. Given that LP+FT trails behind protocols that are not explicitly designed to limit distortion on some safety metrics, it is clear that a complementary perspective is needed to better understand protocol behavior. Nonetheless, LP+FT has the best average rank, indicating that it is a good starting point to improve upon. The above results are aggregated across different types of distribution shifts; we extend this analysis next by considering the interplay between individual datasets and protocol performance. These detailed results are presented in Table 1.

OBSERVATION 2: LINEAR PROBING SOLUTIONS MATTER. Naturally, the amount of distortion required to effectively adapt a pretrained model to a downstream task will vary in accordance with the similarity of the downstream and pretraining data.
Here, we seek to understand how protocols behave under different levels of distortion. In particular, we hypothesize that the LP initialization becomes more influential for LP+FT in low-distortion settings, as subsequent FT remains in the vicinity of the initialization. To this end, we compute the batched centered kernel alignment (CKA) score (Nguyen et al., 2021) between the adapted and pretrained models, and take a closer look at performance across metrics. We note that while CKA is better suited for measuring distortion than the L2 norm used by Kumar et al. (2022), other neural representation metrics can also be used (Ding et al., 2021; Davari et al., 2023).

Results. As shown in Fig. 3, which plots the CKA similarity between adapted and pretrained models, minimal distortion (CKA ≥ 0.9) is required to obtain competitive LP+FT performance on DomainNet and Living17. However, on CIFAR10, which requires the most distortion as evidenced by lower CKA scores, FT is the most effective protocol on safety measures and is very comparable on generalization performance (see Table 1). The effectiveness of LP and LP+FT on Living17 in improving OOD generalization over FT is hardly surprising, as Living17 is a subset of ImageNet, on which the base feature encoder was already trained. In contrast, on DomainNet, the difficulty of FT in matching ID test performance, despite achieving high training accuracy, suggests FT may learn a solution that relies upon shortcuts (or simple features) that do not generalize. We emphasize that LP+FT greatly benefits from strong LP initializations on these low-distortion datasets, as the corresponding CKA scores show that very limited updates are made during FT.
While LP+FT does induce meaningful improvements over LP on Living17 and performs comparably to LP on DomainNet, we stress that the model must be kept close to the LP initialization during FT. Indeed, to obtain acceptable LP+FT performance, small learning rates (3e-7, 1e-5) and frozen batch-norm parameters during FT are necessary.

Summary. Taken jointly, these results suggest that while solely mitigating feature distortion may not be sufficient to ensure that adapted models perform well on safety metrics across different levels of shift, improving the LP initialization may be a viable path to obtaining strong and safe generalization. Indeed, the effectiveness of LP+FT on low-distortion datasets and its high average ranking indicate that it is a promising protocol to build upon. Across datasets, different protocols may perform best when safety evaluation is also considered: on CIFAR10, which requires the most distortion as evidenced by lower CKA scores, FT is the most effective, while LP+FT and LP are most effective, respectively, on Living17 and DomainNet, which require significantly less distortion. This suggests that, while mitigating feature distortion is effective for improving generalization, it is not always sufficient for also improving safety. To understand how to build better protocols, we next introduce simplicity bias as a complementary perspective to feature distortion.
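For reference, a minimal linear (unbatched) CKA between two representation matrices can be sketched as below. This is an illustrative simplification: the paper uses the batched estimator of Nguyen et al. (2021), so treat this as a conceptual sketch of the similarity score rather than the exact quantity reported.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between two representation matrices of shape
    (n_examples, n_features). Returns 1.0 when the representations have
    identical structure (invariant to isotropic scaling and orthogonal
    transforms); values near 0 indicate unrelated representations."""
    X = X - X.mean(axis=0)  # center each feature
    Y = Y - Y.mean(axis=0)
    hsic = np.linalg.norm(Y.T @ X, "fro") ** 2
    return hsic / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro"))
```

A CKA near 1 between adapted and pretrained features (as on DomainNet and Living17) indicates the feature extractor was barely updated, whereas lower scores (as on CIFAR10) indicate substantial distortion.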

4. MITIGATING SIMPLICITY BIAS & FEATURE DISTORTION FOR SAFE ADAPTATION

As discussed in Sec. 2, simplicity bias (SB) underlies various safety issues in machine learning, as models may learn to rely upon simple features that often do not generalize under distribution shifts, such as corruptions or adversaries (Shah et al., 2020). Therefore, we argue that mitigating feature distortion in a way that also minimizes this bias is a valid mechanism for improving both safety and generalization performance. Correspondingly, in this section, we first measure the propensity of different protocols to simplicity bias in a controlled setting. In particular, given our previous observation that LP+FT models remain in close vicinity of the LP solution after FT, we focus on improving this initial LP solution so that we may capitalize upon LP+FT's strong OOD performance while simultaneously improving safety. To this end, we propose three lightweight LP+FT variants that are able to reduce both distortion and SB. We begin by introducing our synthetic dataset and experimental setup.

Dataset. As shown in Fig. 4, we create "dominoes" of complex and simple features (Shah et al., 2020) by pairing each class from CIFAR10 (complex) with the corresponding "digit" class in MNIST (simple), e.g., "bird" samples are paired with digit "2" samples, where the label for each domino is determined by the complex, CIFAR10 sample. Datasets with three levels of correlation (95%, 99%, 100%) between the simple and complex features are constructed for training. While 100% correlation allows models to learn only the simple feature and still generalize perfectly, the more realistic lower-correlation settings require models to learn at least some aspect of the complex features.

Experimental Setup. For evaluation, we also construct a randomized (10% correlation) variant, where simple features are randomly paired with complex features. We give two examples in panels 3 and 4 of Fig. 4.
To assess OOD generalization, we create a variant where complex features are sampled from STL10 instead of CIFAR10, e.g., panels 1 and 2 in Fig. 4.

Metrics. We assess the reliance on simple features using the following metrics: (1) Randomized Accuracy: the accuracy on the variant where samples contain random pairings between simple and complex features; (2) Correlated Accuracy: the accuracy when pairings between simple and complex features remain correlated. Models that are susceptible to simplicity bias will have high correlated accuracy and low randomized accuracy; models that are not susceptible will have relatively lower correlated accuracy and higher randomized accuracy.

Training Details. A MoCo-V2 ResNet-50 (He et al., 2020) pretrained on ImageNet-1K is the base feature extractor. We performed a grid search to find the best parameters; results are averaged over 3 seeds and 3 correlation strengths. See Supp. B.2 for additional details.

In particular, we observe that LP+FT and LP are effective protocols for reducing reliance upon simple features on both the synthetic datasets and the low-distortion real datasets (Sec. 3). However, as some level of distortion is typically required when adapting to downstream tasks to obtain sufficient ID task performance, we propose new variants of the LP+FT protocol that attempt to enable the subsequent FT step to distort features without compromising generalization or increasing simplicity bias. We note that, while it is possible to modify the FT step as well, modifications to LP are inexpensive as the feature encoder is not updated. Moreover, as discussed in Sec. 3, fine-tuned solutions remain in close vicinity of their LP initializations, further motivating strong starting solutions.
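The domino construction can be sketched as follows. The arrays and class count below are stand-ins (the actual experiments stack MNIST digits above CIFAR10 images), so shapes and the helper name are illustrative assumptions.

```python
import numpy as np

def make_dominoes(simple_imgs, complex_imgs, labels, correlation, rng):
    """Stack a 'simple' image (e.g. an MNIST digit) above a 'complex' image
    (e.g. a CIFAR10 sample). With probability `correlation`, the simple
    image's class matches the complex image's label; otherwise it is chosen
    at random. The domino's label always comes from the complex image."""
    n_classes = len(simple_imgs)
    dominoes, ys = [], []
    for img, y in zip(complex_imgs, labels):
        if rng.random() < correlation:
            simple_cls = y                        # correlated pairing
        else:
            simple_cls = rng.integers(n_classes)  # randomized pairing
        dominoes.append(np.concatenate([simple_imgs[simple_cls], img], axis=0))
        ys.append(y)
    return np.stack(dominoes), np.array(ys)
```

Setting `correlation=1.0` reproduces the fully correlated training set, while `correlation=0.1` yields the randomized evaluation variant used to detect reliance on the simple half.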
To this end, we introduce the following modifications to the LP step of LP+FT, where h denotes the hidden representations, θ the model parameters, y the labels, ŷ the predicted class probabilities, C the number of classes, g the classifier, and δ the perturbation. See Supp. B.1 for additional discussion on the choice of these mitigation strategies and Supp. B.2 for discussion on the importance of applying mitigations during LP.

• LP(VAT): Virtual adversarial training (VAT) (Miyato et al., 2017) enforces local distributional smoothness by minimizing the KL-divergence between the predictions of perturbed pairs of examples. Since we are using expressive pretrained models, such perturbations may be meaningful in the latent space as well, and the resulting classifiers become robust in some ϵ-neighborhood around each latent-space input. Formally, letting ϵ be a perturbation budget and α a hyper-parameter weighting distributional label smoothness, we minimize the following loss: min_θ L_CE(g_θ(h), y) + α KL[p(y | g_θ(h)) ∥ p(y | g_θ(h + δ))], where δ := argmax_{∥δ∥₂ ≤ ϵ} KL[p(y | g_θ(h)) ∥ p(y | g_θ(h + δ))].

• LP(UDP): Instead of maximizing the loss, uncertainty-driven perturbations (UDP) (Pagliardini et al., 2022) adversarially maximize a model's estimated uncertainty. UDPs have been shown to be effective in decreasing simplicity bias and improving generalization in non-adaptation settings. Formally, they can be defined as: δ_u = argmax_{∥δ∥₂ ≤ ϵ} H(g_θ(h + δ)), where H(g_θ(h)) = −Σ_{c∈C} ŷ_c log ŷ_c is the entropy of the predictions.

• LP(Soup): Inspired by Wortsman et al. (2022), we train multiple, sparse linear probes jointly and then take the average of their weights (the "soup") as the learned LP for subsequent FT. While soups of large models improve generalization by combining models from the same low-error basin, we consider sparse classifier soups as an alternative strategy that averages diverse decision rules, so as to avoid relying upon a single set of simple features.
Formally, given k classifiers, we minimize min_{θ_1,...,θ_k} (1/k) Σ_{i=1}^{k} L_CE(g_{θ_i}(h), y), and use the averaged classifier g_θ̄, with θ̄ = (1/k) Σ_i θ_i, for the FT step.

Empirical Evaluation of Hardness-Promoting Augmentations. We evaluate the effectiveness of the above LP variants, which we collectively refer to as "hardness-promoting", in mitigating the simplicity bias of LP+FT. We make the following observations (see Fig. 5). Across all correlation strengths, we find that using the modified hardness-promoting LPs during LP+FT (aka hp-LP+FT) improves Rand. OOD accuracy over vanilla LP+FT (≥ 2%) and FT (> 20%). This clearly indicates that hp-LP+FT is indeed effective in decreasing reliance on simple features, potentially also leading to improved safety. Furthermore, with the exception of LP(Soup)+FT, hp-LP+FT also performs better than vanilla LP+FT on Corr. OOD accuracy. Vanilla FT does outperform all LP+FT protocols in this setting, but this is due to its reliance upon simple features. Lastly, we observe with respect to Corr. ID accuracy that hp-LP+FT improves performance at low correlation strengths, but slightly loses performance at higher correlation strengths. This is not entirely unexpected, as FT's reliance upon simple features will be useful in the correlated setting. Given that hp-LP+FT is able to reduce reliance upon simple features in this controlled setting, we next evaluate whether these modified protocols are beneficial in improving the performance of LP+FT on real datasets.
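To make two of the modifications concrete, the sketch below computes a single-step uncertainty-driven perturbation for a linear probe on frozen features, and averages a set of probe weights into a soup. The one-step (normalized-gradient) approximation of the argmax and the purely linear probe are simplifying assumptions; the actual protocols may iterate and include sparsity.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def udp_perturbation(h, W, eps):
    """One-step uncertainty-driven perturbation on frozen features h (n, d)
    for a linear probe W (d, C): step in the direction that increases
    predictive entropy, scaled onto the eps-ball."""
    p = softmax(h @ W)                          # (n, C) class probabilities
    H = -(p * np.log(p + 1e-12)).sum(-1, keepdims=True)
    # Analytic gradient of entropy w.r.t. the logits: -p * (log p + H)
    dH_dlogits = -p * (np.log(p + 1e-12) + H)
    grad = dH_dlogits @ W.T                     # chain rule back to features
    norm = np.linalg.norm(grad, axis=-1, keepdims=True) + 1e-12
    return eps * grad / norm                    # ascent direction, length eps

def soup(classifiers):
    """Average the weights of several linear probes (a classifier 'soup')."""
    return sum(classifiers) / len(classifiers)
```

Adding `udp_perturbation(h, W, eps)` to the features during the LP step trains the probe on its own most-uncertain directions, while `soup` combines the jointly trained probes into the single initialization handed to FT.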

5. EVALUATING GENERALIZATION AND SAFETY OF THE LP+FT FAMILY

Given the effectiveness of incorporating hardness-promoting (hp) augmentations within the family of LP+FT protocols (hp-LP+FT) in mitigating simplicity bias in a synthetic setting, we further evaluate the modified protocols on the three real-world datasets (Living17, DomainNet, and CIFAR10) with respect to the generalization and safety metrics introduced in Sec. 3. We present our results in Tables 3, 4, and 5; our observations are summarized below. Any method-specific hyperparameters (e.g., epsilon) are tuned using ID validation data and all results are reported over three seeds. We provide additional results in Supp. C.

Results.

As discussed in Sec. 3, these three datasets represent scenarios where different levels of distortion are necessary when adapting the pretrained model. On Living17, a setting which requires minimal distortion during adaptation, we see that vanilla LP+FT is quite effective with respect to both generalization and safety metrics and is a difficult baseline to surpass. Indeed, while hp-LP+FT variants do not lead to significant benefits, they generally perform comparably to vanilla LP+FT. On DomainNet, a setting where fairly low distortion is required for LP+FT but FT struggles to find a good solution, we see that hp-LP+FT variants induce slight benefits with respect to ID/OOD generalization and robustness, though vanilla LP and hp-LP have better calibration performance. In contrast, on CIFAR10, which requires more distortion to obtain an acceptable solution, we see that hp-LP+FT variants lead to improved generalization and a noticeable boost in corruption robustness; LP(VAT)+FT and LP(VAT) are particularly effective in this regard. Lastly, across all datasets, we observe that hp-LP+FT protocols lead to distortion similar to that of vanilla LP+FT, which suggests that any additional benefits of hp-LP+FT should not be attributed only to reducing feature distortion.

Discussion. We find that while vanilla LP+FT is already an effective protocol, especially in settings where low distortion is required, hp-LP+FT can provide some benefits and performs competitively. We suspect that the performance of these modified protocols can be further improved if more sophisticated simplicity bias mitigation strategies are used. Indeed, our central claim, that adaptation protocols should mitigate both feature distortion and simplicity bias, is not dependent on a specific strategy. We additionally note that while such mitigation strategies may optionally also be used during FT, they cannot be used solely during FT.
Indeed, in the case of extreme simplicity bias, if the LP classifier relies upon simple features to find a low-loss solution, then during the subsequent FT step, gradients may not be backpropagated in directions that contain complex features. This entails that the decision boundary continues to rely upon simple features and remains at risk of reduced safety performance. We provide further discussion in Supp. B.2. To this end, we recommend incorporating hardness-promoting augmentations during LP as a safeguard against simplicity bias.

6. CONCLUSION

In this paper, we took a closer look at the behavior of protocols designed for adapting large-scale pretrained models to downstream datasets. While it has been argued that adaptation protocols should be designed to mitigate feature distortion (e.g., LP+FT) in order to improve ID and OOD generalization, we found that when additional aspects of safe generalization are evaluated (e.g., prediction calibration error, adversarial robustness, etc.), mitigating feature distortion alone is not sufficient. We then considered the complementary perspective that adaptation protocols should also mitigate simplicity bias. Using a synthetic dominoes dataset that allows for control over the correlation between simple and complex features, we found that protocols have varying levels of effectiveness in reducing reliance upon simple features. While, as expected, FT is most susceptible to simplicity bias, we saw that LP+FT is able to balance both distortion and simplicity bias in settings where the correlation between simple and complex features is not too extreme. Motivated by the benefits of LP+FT, and given the known relationship between simplicity bias and sub-optimal generalization, we used "hardness-promoting" LP initializations (virtual adversarial training, uncertainty-driven perturbations, sparse soups) to further improve LP+FT's performance. These modifications helped reduce LP+FT's reliance upon simple features on the synthetic dataset. On three real-world datasets, these modified protocols led to some improvements in safety and generalization performance, further validating the need to consider both distortion and simplicity bias when designing adaptation protocols.

APPENDIX A ADDITIONAL RELATED WORK

For a comprehensive overview of transfer learning, please see the surveys of Zhuang et al. and Pan & Yang. Here, we discuss a few works directly relevant to our own. Recently, Kumar et al. demonstrated that linear probing prior to fine-tuning (e.g., LP+FT) can improve both in-distribution and out-of-distribution performance when transferring to a downstream task, given a highly expressive, pretrained model. They demonstrated that FT only modifies features in the ID representation subspace and not in other directions, which can lead to higher OOD error, as directions outside the ID subspace are necessary for OOD generalization. However, by initializing FT with a trained linear probe, feature distortion can be decreased, since this initialization is closer to the optimal model and thus requires less distortion in the ID subspace, preserving the expressiveness of the original model. Concurrently, Kirichenko et al. demonstrated that models are able to learn both core features and spurious features. However, classifiers can rely upon spurious features, harming performance on minority groups. To reduce the reliance on spurious features, they propose to retrain the classifier on a small amount of "re-weighting" data, which allows the model to leverage the core features instead of the spurious features. Other modifications and heuristics have also been proposed to improve fine-tuning, including side-tuning (Zhang et al., 2019), which tunes a small secondary network that is then combined with the original model, using larger/smaller learning rates for the classifier, as well as regularization-based methods (Jiang et al., 2020). We focus on the LP+FT protocol, as it is principled and achieves strong OOD performance.
Additionally, several works have studied properties of the model that influence the effectiveness of transfer learning (Azizpour et al., 2016; Huh et al., 2016; Kornblith et al., 2019; Lee et al., 2023a; Evci et al., 2022; Lee et al., 2023b; Izmailov et al., 2022; Lubana et al., 2023; Rame et al., 2022), including the robustness of pretrained features (Salman et al., 2020; Utrera et al., 2021). While the connection between adversarial training and improved feature representations (Allen-Zhu & Li, 2021; Kaur et al., 2019) has been studied, we use virtual adversarial training during LP to learn a better classifier that is less reliant upon simple features; we do not use an adversarially trained feature extractor. Finally, we note that while we are, to the best of our knowledge, the first to consider this holistic evaluation of safety and generalization in the context of transfer learning with highly expressive pretrained models, Hendrycks et al. have considered the trade-offs induced by different data augmentation strategies (Yun et al., 2019; Devries & Taylor, 2017; Hendrycks et al., 2020; Cubuk et al., 2019; 2020) on safety metrics in supervised learning. We emphasize that while our evaluation is similar, our work focuses on a different context and contains an additional layer of complexity, as we consider the interaction between adaptation protocols, generalization behavior, and safety performance.

B EXPERIMENTAL DETAILS

Please see https://github.com/pujacomputes/23-ICLR-Adaptation.git for training details. In brief, we performed a grid search to find the best hyperparameters, which are as follows. For CIFAR-10 and CIFAR-100, we train only the classifier for 200 epochs with LR=30 during LP. For FT, the entire model is trained for 20 epochs with LR=1e-5. For LP+FT, the model's classifier is initialized with the solution found by LP, and the model is then fine-tuned for 20 epochs. A grid search was conducted to determine the LRs for LP and FT. For the DomainNet experiments, we use 200 epochs with LR=30 during LP. For FT, the entire model is trained for 20 epochs with LR=3e-4. For LP+FT, the model's classifier is initialized with the solution found by LP and then fine-tuned for 20 epochs with LR=3e-7. Furthermore, following Kumar et al., we freeze the batch-norm layers during LP+FT. A CLIP (Radford et al., 2021) pretrained ResNet-50 is used for the DomainNet experiments, while a MoCoV2 (He et al., 2020) pretrained model is used for all CIFAR experiments. We use augmentation functions from timm (Wightman, 2019) and compute CKA scores using the package provided by torch-cka. When using augmented protocols, the same LRs are used. Note, all results were obtained by averaging over 3 seeds. We consider model soups of sizes 5, 10, and 20, tune ϵ in {0.005, 0.01, 0.02, 0.1} for UDP, and α in {0.001, 0.01, 0.1} for VAT. For the CIFAR-MNIST results, LP is run for 100 epochs and FT for 20 epochs.
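As a concrete illustration, the two-stage LP+FT schedule described above can be sketched as follows. This is a minimal toy version: the random data, stand-in encoder, and scaled-down learning rates are assumptions for illustration, not our exact training code (which uses the backbones and LRs listed above).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for a pretrained SSL encoder and a downstream head; the real
# experiments use MoCoV2/CLIP ResNet backbones.
encoder = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
head = nn.Linear(16, 10)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(64, 32)
y = torch.randint(0, 10, (64,))
initial_loss = loss_fn(head(encoder(x)), y).item()

# Stage 1 (LP): freeze the encoder and fit only the linear classifier.
for p in encoder.parameters():
    p.requires_grad_(False)
lp_opt = torch.optim.SGD(head.parameters(), lr=0.1)
for _ in range(50):
    lp_opt.zero_grad()
    loss_fn(head(encoder(x)), y).backward()
    lp_opt.step()

# Stage 2 (FT): unfreeze everything and fine-tune at a much smaller LR.
# Because the head already starts near a good solution, the encoder needs
# smaller updates in the ID subspace, i.e., less feature distortion.
for p in encoder.parameters():
    p.requires_grad_(True)
ft_opt = torch.optim.SGD(
    list(encoder.parameters()) + list(head.parameters()), lr=1e-3
)
for _ in range(20):
    ft_opt.zero_grad()
    loss_fn(head(encoder(x)), y).backward()
    ft_opt.step()

final_loss = loss_fn(head(encoder(x)), y).item()
```

The plain LP and FT baselines correspond to running only Stage 1 or only Stage 2 (from a randomly initialized head), respectively.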

B.1 MOTIVATION FOR HARDNESS-PROMOTING VARIANTS

We selected UDP (Pagliardini et al., 2022), VAT (Miyato et al., 2017), and model soups (Wortsman et al., 2022) as simplicity bias mitigation strategies due to their effectiveness and ease of use. We emphasize, however, that our findings are not specific to the choice of a given mitigation strategy, and we expect that advancements in such strategies will further improve the effectiveness of our proposed LP+FT variants. At present, the selected strategies are strong, representative mitigations that we have confirmed are effective at mitigating simplicity bias in the adaptation context using the synthetic dominoes dataset in Sec. 4. We conceptually justify each strategy here: • UDP is designed to help mitigate simplicity bias by learning a large-margin classifier, as opposed to a narrow-margin classifier that relies upon simple features. As noted by Shah et al. (2020), such narrow-margin classifiers are sensitive to small perturbations, and the simple features supporting the decision boundary may not be discriminative under distribution shifts. By maximizing uncertainty (instead of loss) to create adversarial perturbations, UDP is able to learn a maximum-margin classifier that is better able to handle such shifts. Notably, to create such a maximum-margin classifier, the model will necessarily learn more complex features; • We use virtual adversarial training (VAT) to help avoid reliance upon simple features, as VAT enforces distributional smoothness so that classifiers become robust in some ϵ-neighborhood around the input. We note that we perform this training in the hidden representation space, so perturbations may alter high-level semantics.
To maintain strong performance under such high-level perturbations, the model should learn to rely upon more complex features and learn a better-margin classifier. To demonstrate that simplicity bias mitigation strategies must be applied during the LP step for maximum effectiveness, we conduct the following additional experiment. Setup. We evaluate two additional protocols in which VAT and UDP are applied only during the FT step (LP+FT(VAT) and LP+FT(UDP)) on the synthetic dominoes dataset. We plot the results for Randomized OOD Accuracy in Fig. 6. Results. Across three different correlation ratios, the FT variants lose performance with respect to the LP mitigation variants. Notably, LP+FT(UDP) loses up to 4% accuracy relative to LP(UDP)+FT. While the performance drops are not as large for VAT, LP+FT(VAT) nonetheless underperforms LP(VAT)+FT. The results in Fig. 6 support our conceptual argument that mitigation strategies must be applied during the LP step to ensure that subsequent FT moves in a direction that preserves complex features; applying mitigation strategies during FT may be too late to avoid simplicity bias. We note that applying mitigation strategies during FT, in addition to LP, may further improve performance, and we will add these variants in the final version. We did not include an FT soup variant, as it would be prohibitively expensive to train and average large soups of entire models (instead of classifiers). This highlights the computational efficiency of implementing mitigation strategies in the LP step itself.
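To make the probe-level interventions concrete, the sketch below shows (i) a one-power-iteration VAT loss applied to hidden representations, (ii) a label-free, entropy-based perturbation in the spirit of UDP, and (iii) a uniform classifier soup. The dimensions, epsilon values, and function names are illustrative assumptions, not our exact training code.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def vat_loss(probe, z, eps=0.1, xi=1e-6):
    """Virtual adversarial loss (Miyato et al., 2017) with one power
    iteration, applied to hidden representations z: find the direction
    that most changes the probe's predictions, then penalize that change."""
    with torch.no_grad():
        p = F.softmax(probe(z), dim=1)
    d = (xi * F.normalize(torch.randn_like(z), dim=1)).requires_grad_(True)
    adv_dist = F.kl_div(F.log_softmax(probe(z + d), dim=1), p,
                        reduction="batchmean")
    grad = torch.autograd.grad(adv_dist, d)[0]
    r_adv = eps * F.normalize(grad.detach(), dim=1)
    return F.kl_div(F.log_softmax(probe(z + r_adv), dim=1), p,
                    reduction="batchmean")

def udp_perturbation(probe, z, eps=0.05):
    """Label-free perturbation toward higher predictive entropy, i.e.,
    toward the decision boundary, in the spirit of UDP."""
    z = z.clone().requires_grad_(True)
    p = F.softmax(probe(z), dim=1)
    entropy = -(p * (p + 1e-12).log()).sum(dim=1).mean()
    grad = torch.autograd.grad(entropy, z)[0]
    return eps * F.normalize(grad.detach(), dim=1)

def soup_classifiers(probes):
    """Uniform soup: average the parameters of independently trained
    probes into a single classifier (Wortsman et al., 2022)."""
    souped = copy.deepcopy(probes[0])
    avg = {k: torch.stack([p.state_dict()[k] for p in probes]).mean(dim=0)
           for k in probes[0].state_dict()}
    souped.load_state_dict(avg)
    return souped

probe = nn.Linear(16, 10)
z = torch.randn(8, 16)
smooth_loss = vat_loss(probe, z)    # added to the LP objective
r_udp = udp_perturbation(probe, z)  # probe is then trained on z + r_udp
probes = [nn.Linear(16, 10) for _ in range(5)]  # stand-ins for diverse probes
soup = soup_classifiers(probes)
```

Averaging only classifiers, as in `soup_classifiers`, is what keeps the soup variant cheap relative to souping entire fine-tuned models.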

C ADDITIONAL RESULTS

Below, we include results corresponding to different hyperparameters (number of souped classifiers, α for VAT, and ϵ for UDP). We see that with a larger model and a different pretraining method, our proposed variants still provide some benefits. We note that the baseline performance is also improved as a result of the larger pretrained model.



Figure 3: Dataset Distortion. We plot the CKA similarity between adapted and pretrained models. DomainNet and Living17 require low distortion, as seen by the performance of LP+FT across metrics with high CKA (> 0.9).
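For reference, the linear variant of CKA used to measure distortion can be computed directly from feature matrices. Below is a minimal NumPy sketch (the feature values are synthetic stand-ins; our experiments use the torch-cka package on real activations):

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between feature matrices of shape (n_samples, dim).
    The score is 1 for identical representations and is invariant to
    orthogonal transforms and isotropic scaling of the features."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = np.linalg.norm(Y.T @ X, "fro") ** 2
    den = np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")
    return num / den

rng = np.random.default_rng(0)
pre = rng.normal(size=(100, 32))                   # pretrained features (stand-in)
adapted = pre + 0.1 * rng.normal(size=(100, 32))   # mildly distorted features
score = linear_cka(pre, adapted)                   # near 1 => low distortion
```

A high score between pretrained and adapted features (as for LP+FT on DomainNet and Living17) indicates that adaptation left the representation largely intact.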

Figure 4: Synthetic Data with Simple and Complex Features. Using a synthetic dominoes dataset (Shah et al., 2020), we study the effect of simplicity bias on safety and OOD generalization.

Table 1: Mitigating Feature Distortion May Not Induce Safe Adaptation.


Table 2: Simplicity Bias and Performance of Adaptation Protocols. Using the synthetic dominoes dataset, we measure the propensity of different models to simplicity bias via the Corr. OOD and Rand. OOD accuracies. With the highest Corr. OOD accuracy and lowest Rand. OOD accuracy, FT is particularly susceptible to inducing simplicity bias.

Results. Given the above experimental setup, we report the performance of different adaptation protocols in Table 2. Across all correlation strengths, FT has the lowest Rand. OOD accuracy and high Corr. OOD accuracy. This clearly indicates that FT has learned to rely upon simple features, effectively disregarding the expressive features of the pretrained model, and is easily susceptible to simplicity bias. In contrast, by preserving the expressive features of the underlying feature encoder, LP best mitigates simplicity bias in high-correlation (0.99, 1.0) settings, as evidenced by the highest Rand. OOD accuracy (though Corr. ID/OOD accuracy does slightly suffer). However, in moderate correlation settings, LP+FT provides a better balance between distortion and simplicity bias.

IMPROVED LINEAR PROBES FOR MITIGATING SIMPLICITY BIAS

As discussed earlier, adaptation protocols have varying susceptibility to simplicity bias, and mitigating this susceptibility can help improve generalization and safety.

Hardness-Promoting Augmentations Help Mitigate Simplicity Bias. We evaluate the modified LP+FT protocols on the dominoes dataset and find that they improve the Rand. OOD accuracy over vanilla FT and LP+FT. This suggests that the modified protocols rely less upon shortcuts or simple features.
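The correlation knob underlying these experiments can be illustrated with a minimal synthetic dominoes generator. The block dimensions, noise scales, and feature construction below are illustrative assumptions, not the exact MNIST-CIFAR construction of Shah et al. (2020):

```python
import numpy as np

rng = np.random.default_rng(0)

def make_dominoes(n, correlation, d=16):
    """Stack a 'simple' block (near-noiseless, linearly separable) on top
    of a 'complex' block (noisy, weaker cue). With probability
    `correlation`, the simple block agrees with the true label; otherwise
    it is drawn at random. Corr. OOD keeps this pairing at test time,
    while Rand. OOD randomizes the simple block, so a shortcut learner
    scores high on Corr. OOD but near chance on Rand. OOD."""
    y = rng.integers(0, 2, n)
    simple_label = np.where(rng.random(n) < correlation,
                            y, rng.integers(0, 2, n))
    simple = simple_label[:, None] + 0.1 * rng.normal(size=(n, d))
    complex_ = 0.5 * (2 * y[:, None] - 1) + rng.normal(size=(n, d))
    return np.concatenate([simple, complex_], axis=1), y

X, y = make_dominoes(1000, correlation=0.95)
```

Sweeping `correlation` toward 1.0 makes the simple block an increasingly reliable shortcut, reproducing the regime in which FT's Rand. OOD accuracy collapses.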

Living17: Hardness-Promoting Augmentations and Adaptation. In this low-distortion adaptation setting, we see that vanilla LP+FT is an effective baseline and that hardness-promoting variants of LP+FT tend to perform comparably.

DomainNet: Hardness-Promoting Augmentations and Adaptation. While relatively low distortion is induced by LP+FT, FT struggles to find a viable solution. Here, hardness-promoting LP+FT variants, particularly LP(VAT)+FT, slightly improve ID and OOD generalization as well as robustness to corruptions.


We use model soups so that we may learn a set of classifiers that rely upon disjoint sets of features. By learning a set of diverse classifiers, we can average classifiers that have learned to rely upon different features, instead of becoming overly reliant upon a single simple feature. In future work, we intend to build a theoretical framework that helps us better justify these interventions and create new ones.

B.2 APPLYING SIMPLICITY BIAS MITIGATION STRATEGIES TO THE FINE-TUNING STEP

CIFAR10, Hardness-Promoting Augmentations.

Living17, Hardness-Promoting Augmentations.

DomainNet, Diversity-Promoting Augmentations and Generalization Trade-offs.

CIFAR10 with ResNet-101/SimCLR Pretrained Model.

ACKNOWLEDGMENTS

We thank Ekdeep Singh Lubana for several helpful discussions during the course of this project. This work was performed under the auspices of the U.S. Department of Energy by the Lawrence Livermore National Laboratory under Contract No. DE-AC52-07NA27344 (Lawrence Livermore National Security, LLC), and was supported by the LLNL-LDRD Program under Project No. 21-ERD-012. It was also partially supported by the National Science Foundation under CAREER Grant No. IIS 1845491. PT began this work as an intern at Lawrence Livermore National Laboratory.

