FEATURE RECONSTRUCTION FROM OUTPUTS CAN MITIGATE SIMPLICITY BIAS IN NEURAL NETWORKS

Abstract

Deep Neural Networks (DNNs) are known to be brittle to even minor shifts in the data distribution relative to the training distribution. While one line of work has demonstrated that Simplicity Bias (SB) of DNNs, the bias towards learning only the simplest features, is a key reason for this brittleness, another recent line of work has surprisingly found that diverse/complex features are indeed learned by the backbone, and that the brittleness is due to the linear classification head relying primarily on the simplest features. To bridge the gap between these two lines of work, we first hypothesize and verify that while SB may not altogether preclude the learning of complex features, it amplifies simpler features over complex ones: simple features are replicated several times in the learned representations, while complex features might not be replicated. This phenomenon, which we term the Feature Replication Hypothesis, coupled with the implicit bias of SGD to converge to maximum-margin solutions in the feature space, leads models to rely mostly on the simple features for classification. To mitigate this bias, we propose the Feature Reconstruction Regularizer (FRR), which ensures that the learned features can be reconstructed back from the logits. Using FRR in linear-layer training (FRR-L) encourages the use of more diverse features for classification. We further propose to finetune the full network, freezing the weights of the linear layer trained with FRR-L, to refine the learned features and make them more suitable for classification. Using this simple solution, we demonstrate up to 15% gains in OOD accuracy on recently introduced semi-synthetic datasets with extreme distribution shifts. Moreover, we demonstrate noteworthy gains over existing SOTA methods on the standard OOD benchmark DomainBed as well.

1. INTRODUCTION

Despite the remarkable success of Deep Neural Networks (DNNs) in various fields, they are known to be brittle to even minor shifts in the data distribution during inference, which are not uncommon in a real-world setting (Quinonero-Candela et al., 2008; Torralba & Efros, 2011). For example, a self-driving car that works well in normal weather may perform poorly when it is snowing, leading to disastrous outcomes. The need for improving the robustness of such systems against distribution shifts has sparked interest in the area of Out-Of-Distribution (OOD) generalization (Hendrycks & Dietterich, 2019; Gulrajani & Lopez-Paz, 2020). In this work, we aim to tackle the problem of OOD generalization of neural networks in a covariate-shift (Shimodaira, 2000) based classification setting, by addressing the fundamental cause of their brittleness, rather than by explicitly enforcing invariances in the network using domain labels or data augmentations. More specifically, we aim to mitigate the issue of Simplicity Bias, which is the tendency of Stochastic Gradient Descent (SGD) based solutions to overly rely on simple features alone, rather than on a diverse set of features (Arpit et al., 2017; Valle-Perez et al., 2018). While this behavior was earlier used to explain the remarkable generalization of deep networks, recent works suggest that it is in fact a key reason behind their brittleness to domain shifts (Shah et al., 2020). The extent of Simplicity Bias seen in models is a result of two important factors: the diversity of features learned by the feature extractor[1], and the extent to which these diverse features are used for the task at hand, such as classification[2]. Recent works suggest that generalization to distribution shifts can be improved by retraining the last layer alone, indicating that the learned features may already be sufficient for this purpose (Rosenfeld et al., 2022; Kirichenko et al., 2022b).
Does this imply that the brittleness of models can be attributed to the learning of the classification head alone? If so, why does SGD fail to utilize these diverse features despite its implicit bias to converge to a maximum-margin solution in a linearly separable case (Soudry et al., 2018)? To answer these questions, we first hypothesize and empirically verify that Simplicity Bias leads to simple features being learned over and over again, as compared to other, more complex features. For example, among the 512 penultimate-layer features of a ResNet, 462 of them might capture a simple feature such as color, while the remaining 50 might capture a more complex feature such as shape; we refer to this as the (Simple) Feature Replication Hypothesis. Assuming the feature replication hypothesis, we further show theoretically and empirically that a maximum-margin classifier in the replicated feature space gives much higher importance to the replicated feature when compared to others, highlighting why the linear layer relies more on simpler features for classification. To mitigate this, we propose a novel regularizer termed the Feature Reconstruction Regularizer (FRR), which enforces that the features learned by the network can be reconstructed back from the logit (pre-softmax) layer used for the classification task. As shown in Fig. 2, we first propose to train the linear classifier alone by freezing the weights of the feature extractor. This formulation enables the learning of an invertible mapping in the output layer, specifically for the domain of features seen during training. It further allows the logit layer to act as an information bottleneck, encouraging all the factors of variation in the features to be utilized for the classification task, thereby improving the diversity of features used.
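The effect of feature replication on a max-margin classifier can be illustrated with a small numpy sketch (our own construction for illustration; the dataset, replication factor k, and hyperparameters are assumptions, not the paper's setup). Gradient descent on the logistic loss over separable data converges in direction to the max-margin separator (Soudry et al., 2018), so we can observe where the weight mass lands when one predictive feature is replicated k times and another appears once:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dataset: labels y in {-1, +1}. The "simple" feature is replicated
# k times in the representation; the "complex" feature appears once.
# Both feature types are perfectly predictive on the training data.
n, k = 200, 10
y = rng.choice([-1.0, 1.0], size=n)
simple = np.repeat(y[:, None], k, axis=1) + 0.05 * rng.normal(size=(n, k))
complex_feat = y[:, None] + 0.05 * rng.normal(size=(n, 1))
X = np.hstack([simple, complex_feat])          # shape (n, k + 1)

# Gradient descent on the logistic loss; on separable data its direction
# converges toward the max-margin separator.
w = np.zeros(k + 1)
lr = 0.1
for _ in range(2000):
    margins = y * (X @ w)
    grad = -(X * (y / (1.0 + np.exp(margins)))[:, None]).mean(axis=0)
    w -= lr * grad

simple_mass = np.abs(w[:k]).sum()   # total weight on the replicated feature
complex_mass = np.abs(w[k])         # weight on the single complex feature
```

The total weight on the replicated group ends up roughly k times that of the single complex feature, so a shift that corrupts the simple feature at test time dominates the prediction even though the complex feature alone would suffice.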
We theoretically show that adding this constraint while finetuning the linear layer can learn a max-margin classifier in the original input space, disregarding feature replication. Consequently, the learned linear classifier also gives more importance to non-replicated complex features while making predictions. We further explore the possibility of improving the quality of the features learned by the feature extractor, by using FRR to finetune the backbone as well. To do this, we freeze the linear classification head and finetune the backbone with FRR. We find that this encourages the network to learn better-quality features that are more relevant for classification. We list the key contributions of this work below:

• Key Observation: We provide a crisp hypothesis of "feature replication" to explain the brittleness of ERM-trained neural networks on OOD data (Sec 3.1). Using this, we further provide theoretical and empirical evidence to justify the existence of Simplicity Bias in maximum-margin classifiers.

• Novel Algorithm based on the Observation: Based on this, we introduce a novel FRR regularizer to safeguard against the feature replication phenomenon (Sec 3.2). We also provide theoretical support for FRR in an intuitive data distribution setting. Furthermore, we introduce a simple FRR-L method that regularizes only the linear head with FRR, and then introduce the FRR-FLFT training regimen to train the feature extractor for improved OOD robustness (Sec 4).

• Empirical validation of the hypothesis and the proposed algorithm: We demonstrate the effectiveness of FRR-FLFT and FRR-L by conducting extensive experiments on semi-real datasets (Table 2) constructed to study OOD brittleness, as well as on standard OOD generalization benchmarks, where FRR-FLFT provides up to 3% gains over SOTA methods for OOD generalization (Table 3).
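The core mechanic of FRR-L, training the linear head with a penalty that the frozen features be reconstructible from the logits, can be sketched in numpy as follows. This is a minimal illustration under our own assumptions: the linear decoder U, the loss weighting lam, and the toy data are not the paper's exact formulation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: frozen backbone features F (n, d) drawn around C class centers.
n, d, C, lam, lr = 300, 5, 3, 1.0, 0.1
labels = rng.integers(0, C, size=n)
centers = rng.normal(size=(C, d))
F = centers[labels] + 0.1 * rng.normal(size=(n, d))
Y = np.eye(C)[labels]                       # one-hot targets

W = 0.01 * rng.normal(size=(d, C))          # linear head (trained)
U = 0.01 * rng.normal(size=(C, d))          # decoder: logits -> features

for _ in range(2000):
    Z = F @ W                               # logits (n, C)
    P = np.exp(Z - Z.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)       # softmax probabilities
    R = Z @ U - F                           # reconstruction residual

    # Gradients of: mean cross-entropy + lam * mean squared recon error
    dZ = (P - Y) / n + 2 * lam * (R @ U.T) / n
    dW = F.T @ dZ
    dU = Z.T @ (2 * lam * R / n)
    W -= lr * dW
    U -= lr * dU

acc = (np.argmax(F @ W, axis=1) == labels).mean()
recon_err = np.mean((F @ W @ U - F) ** 2)
```

Because the decoder must recover the features from the C-dimensional logits, the head is pushed to route information about all feature directions through the logit layer, rather than leaning on one replicated direction; this is the bottleneck intuition described above.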

2. RELATED WORKS

Learning diverse classifiers to counter simplicity bias: Recent works have shown that ERM-trained models do learn diverse features; however, the linear layer fails to capture and utilize these diverse features properly.



[1] In this paper, we refer to the penultimate layer's activations as features.
[2] i.e., by the final classification layer.



There have been several attempts at training classifiers that can make use of such diverse features. Teney et al. (2022) train a number of linear classifiers on top of a pre-trained network with a diversity regularizer, which encourages the classifiers to rely on different features. Xu et al. (2022) and Bahng et al. (2020) propose to train debiased classifiers that are statistically independent of trained biased networks, but these need careful design and prior knowledge of the

