A CLOSER LOOK AT MODEL ADAPTATION USING FEATURE DISTORTION AND SIMPLICITY BIAS

Abstract

Advances in the expressivity of pretrained models have increased interest in the design of adaptation protocols which enable safe and effective transfer learning. Going beyond conventional linear probing (LP) and fine-tuning (FT) strategies, protocols that can effectively control feature distortion, i.e., the failure to update features orthogonal to the in-distribution subspace, have been found to achieve improved out-of-distribution (OOD) generalization. To limit this distortion, the LP+FT protocol, which first learns a linear probe and then uses this initialization for subsequent FT, was proposed. However, in this paper, we find that when adaptation protocols (LP, FT, LP+FT) are also evaluated on a variety of safety objectives (e.g., calibration, robustness, etc.), a complementary perspective to feature distortion is helpful in explaining protocol behavior. To this end, we study the susceptibility of protocols to simplicity bias (SB), i.e., the well-known propensity of deep neural networks to rely upon simple features, as SB has recently been shown to underlie several problems in robust generalization. Using a synthetic dataset, we demonstrate the susceptibility of existing protocols to SB. Given the strong effectiveness of LP+FT, we then propose modified linear probes that help mitigate SB and lead to better initializations for subsequent FT. We verify the effectiveness of the proposed LP+FT variants for decreasing SB in a controlled setting, and their ability to improve OOD generalization and safety on three adaptation datasets.

1. INTRODUCTION

Figure 1: Strong and Safe Adaptation. Practical deployment in high-risk applications requires that adapted models not only generalize well to in- and out-of-distribution data of the downstream task, but that they do so safely.

Through the use of larger datasets (Yalniz et al., 2019), better architectures (Zhai et al., 2022; Chen et al., 2022; Steiner et al., 2022; Tolstikhin et al., 2021), and different self-supervised learning (SSL) approaches (He et al., 2020; Chen et al., 2020; Grill et al., 2020; Caron et al., 2020), the quality of pretrained representations available for transfer learning tasks has dramatically improved. Indeed, representations from such high-quality SSL models have been found to be more robust (Hendrycks et al., 2019; Liu et al., 2021), transferable (Ericsson et al., 2021), and semantically consistent (Caron et al., 2021) than their supervised counterparts. In this regard, there is a growing need for adaptation protocols that explicitly capitalize on these improved pretrained features to induce similarly beneficial properties, e.g., more than just high accuracy on the target task, after models have been trained on the downstream task. However, standard adaptation protocols that rely upon fine-tuning (FT) all model parameters, or training only a linear probe (LP) while freezing the network parameters, do not maximize the potential of high-quality representations. For example, while high-quality, pretrained models have sufficiently expressive features to perform well on both in-distribution (ID) and out-of-distribution (OOD) data, LP and FT are not able to effectively induce this property in adapted models (Andreassen et al., 2021). Recently, however, Kumar et al. (2022) proved that by modifying features only in the ID representation subspace, FT can lead to higher OOD error, as it distorts directions outside the ID subspace that are needed for OOD generalization.

(Correspondence to pujat@umich.edu.)
As both ID and OOD subspaces are represented by the pretrained model, Kumar et al. demonstrate that limiting feature distortion, i.e., controlling updates towards the ID subspace, can lead to improved ID and OOD performance. To this end, they propose a new protocol which performs LP prior to FT (abbrev. LP+FT). By first performing LP, this two-step process ensures that subsequent FT will remain in the vicinity of the original LP solution. This reduces the overall distortion towards the ID subspace and improves performance.

While strong ID and OOD generalization on the target task is indeed an important aspect of transfer learning, practical, high-risk applications require that models are also safe (Hendrycks et al., 2021). For example, adapted models should also be well-calibrated, robust to corruptions or adversaries, and able to reliably detect anomalous samples (see Figure 1). Given that existing adaptation protocols are primarily focused on improving generalization, it is unclear how they utilize high-quality pretrained features to promote safe adaptation, and whether current protocol design perspectives, such as mitigating feature distortion, will also enable safe generalization.

Our Work: In this work, we seek to understand the factors relevant to the design of adaptation protocols that promote effective and safe generalization. We take the first step towards this aim by (i) demonstrating limitations in existing LP, FT, and LP+FT protocols through an extensive, joint evaluation, and (ii) studying adaptation protocols through the complementary lens of avoiding simplicity bias, i.e., the problematic tendency of deep neural networks (DNNs) to prefer simple, potentially brittle features over complex features (Soudry et al., 2018; Gunasekar et al., 2018; Geirhos et al., 2019; Hermann et al., 2020; Shah et al., 2020).
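To make the three protocols concrete, the sketch below contrasts LP, FT, and LP+FT on a toy linear "backbone". This is purely illustrative: the data, dimensions, and the `train` helper are hypothetical stand-ins, not the paper's actual experimental setup.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for a pretrained backbone: a fixed linear featurizer.
# All names and sizes here are illustrative.
D_IN, D_FEAT, N = 10, 5, 200
W_pre = rng.normal(size=(D_IN, D_FEAT)) / np.sqrt(D_IN)

X = rng.normal(size=(N, D_IN))
y = (X[:, 0] > 0).astype(float)  # labels depend only on the first coordinate

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train(W, v, freeze_backbone, steps=500, lr=0.5):
    """Logistic-regression training of head v (and backbone W unless frozen)."""
    W, v = W.copy(), v.copy()
    for _ in range(steps):
        feats = X @ W                       # backbone features
        g = (sigmoid(feats @ v) - y) / N    # gradient of the loss w.r.t. logits
        if not freeze_backbone:             # FT also moves the backbone
            W -= lr * (X.T @ np.outer(g, v))
        v -= lr * feats.T @ g               # head update
    return W, v

# LP: freeze the backbone, train only the linear head.
W_lp, v_lp = train(W_pre, np.zeros(D_FEAT), freeze_backbone=True)

# FT: update backbone and a freshly initialized head jointly.
W_ft, v_ft = train(W_pre, rng.normal(size=D_FEAT), freeze_backbone=False)

# LP+FT: run the same FT step, but start from the LP head so fine-tuning
# stays near the linear-probe solution and distorts pretrained features less.
W_lpft, v_lpft = train(W_pre, v_lp, freeze_backbone=False)

acc_lpft = (((X @ W_lpft) @ v_lpft > 0) == (y > 0.5)).mean()
```

The key point mirrors the discussion above: LP leaves the pretrained weights untouched, FT updates them freely, and LP+FT performs the identical FT step but from the LP initialization.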
Using the insights from our analysis, we propose three variants of the LP+FT protocol that jointly improve safety and generalization on three datasets. Our contributions can be summarized as follows:

• Joint Analysis of Adaptation Protocol Safety and Generalization (Sec. 3). We show that when adaptation protocols are evaluated with respect to both ID/OOD generalization and safety, LP+FT trails behind LP or FT on several safety metrics. This demonstrates that solely mitigating feature distortion may not be sufficient for safe generalization. We also observe that keeping subsequent FT close to the LP solution is crucial for the improved OOD generalization of LP+FT. This motivates us to focus on improving the LP initialization as a mechanism for improving both safety and OOD performance.

• Role of Simplicity Bias in (Unsafe) Adaptation (Sec. 4). To understand how protocols may induce safe adaptation, we study how different protocols avoid simplicity bias. While simplicity bias (Shah et al., 2020; Geirhos et al., 2019) has been shown to underlie several problems in machine learning safety, to the best of our knowledge, we are the first to consider its role in adaptation settings. We demonstrate that protocols must not only reduce distortion, but should also mitigate simplicity bias for effective adaptation.

• Improved Protocols for Mitigating Simplicity Bias and Distortion (Sec. 5). We propose three simple modified LP+FT protocols that help mitigate both simplicity bias and distortion (Sec. 4.1). In particular, we consider modifying the LP step with uncertainty-driven perturbations (Pagliardini et al., 2022), virtual adversarial training (Miyato et al., 2017), and model soups (Wortsman et al., 2022), as they are simple and effective strategies. Across synthetic and real datasets, the modified protocols improve safety and generalization to some extent.
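Of the three probe modifications, the model-soup idea (Wortsman et al., 2022) is the simplest to illustrate: train several linear probes independently and average their weights. The sketch below is a toy version on synthetic frozen features; `train_probe`, the dimensions, and the seeds are illustrative assumptions, not the paper's procedure.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical frozen features from a pretrained backbone, plus binary labels.
N, D = 300, 8
feats = rng.normal(size=(N, D))
w_true = rng.normal(size=D)
y = (feats @ w_true > 0).astype(float)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_probe(seed, steps=300, lr=0.5):
    """One linear probe from a random init (a stand-in for one training run)."""
    w = np.random.default_rng(seed).normal(size=D)
    for _ in range(steps):
        g = (sigmoid(feats @ w) - y) / N   # gradient of logistic loss w.r.t. logits
        w -= lr * feats.T @ g
    return w

# "Soup" the probes: average the weights of independently trained runs.
probes = [train_probe(seed) for seed in range(5)]
w_soup = np.mean(probes, axis=0)

soup_acc = ((feats @ w_soup > 0) == (y > 0.5)).mean()
```

Because all probes minimize the same convex objective here, their weights land close enough together for averaging to behave well; in the LP+FT setting, the modified probe would then serve as the initialization for the subsequent FT step.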

2. RELATED WORK AND BACKGROUND

Here, we discuss the most relevant work on adaptation protocols and simplicity bias; we discuss additional related work in Sup. A.

Adaptation Protocols. For a comprehensive overview of transfer learning, please see the surveys of Zhuang et al. (2021) and Pan & Yang (2010). Here, we discuss the works that are most relevant to our own. Kirichenko et al. (2022) recently demonstrated that models are able to learn both core features and spurious features. However, classifiers can rely upon spurious features, harming

