FEATURE RECONSTRUCTION FROM OUTPUTS CAN MITIGATE SIMPLICITY BIAS IN NEURAL NETWORKS

Abstract

Deep Neural Networks (DNNs) are known to be brittle to even minor distribution shifts relative to the training distribution. While one line of work has demonstrated that the Simplicity Bias (SB) of DNNs, i.e., their bias towards learning only the simplest features, is a key reason for this brittleness, another recent line of work has surprisingly found that diverse and complex features are in fact learned by the backbone, and that the brittleness arises because the linear classification head relies primarily on the simplest of these features. To bridge these two lines of work, we first hypothesize and verify that while SB may not altogether preclude the learning of complex features, it amplifies simpler features over complex ones: simple features are replicated several times in the learned representations, while complex features may not be replicated at all. This phenomenon, which we term the Feature Replication Hypothesis, coupled with the implicit bias of SGD to converge to maximum-margin solutions in feature space, leads models to rely mostly on simple features for classification. To mitigate this bias, we propose a Feature Reconstruction Regularizer (FRR) that encourages the learned features to be reconstructible from the logits. Applying FRR during linear-layer training (FRR-L) encourages the use of more diverse features for classification. We further propose finetuning the full network with the weights of the FRR-L-trained linear layer frozen, to refine the learned features and make them more suitable for classification. Using this simple approach, we demonstrate gains of up to 15% in OOD accuracy on recently introduced semi-synthetic datasets with extreme distribution shifts. Moreover, we demonstrate noteworthy gains over existing SOTA methods on the standard OOD benchmark DomainBed as well.
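The abstract's exact formulation of FRR is not given in this excerpt; one way to read "features can be reconstructed back from the logits" is as the residual of a linear decoder fit from logits to features. The following numpy sketch illustrates that reading (all shapes, names, and the least-squares decoder are illustrative assumptions, not the paper's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

n, d, c = 256, 16, 4                    # samples, feature dim, number of classes
feats = rng.normal(size=(n, d))         # stand-in for backbone features
W = rng.normal(size=(d, c))             # linear classification head
logits = feats @ W                      # (n, c)

# FRR idea (sketch): penalize features that cannot be linearly
# reconstructed from the logits. Here the decoder is fit in closed
# form by least squares, and the mean squared residual serves as
# the regularization loss added to the usual classification loss.
decoder, *_ = np.linalg.lstsq(logits, feats, rcond=None)   # (c, d)
recon = logits @ decoder                                   # (n, d)
frr_loss = np.mean((feats - recon) ** 2)
```

Intuitively, if the head ignores some feature directions, no decoder can recover them from the logits, the residual stays large, and minimizing it pushes the head to spread weight over more diverse features.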

1. INTRODUCTION

Despite the remarkable success of Deep Neural Networks (DNNs) in various fields, they are known to be brittle against even minor shifts in the data distribution at inference time, which are not uncommon in real-world settings (Quinonero-Candela et al., 2008; Torralba & Efros, 2011). For example, a self-driving car that works well in normal weather may perform poorly when it is snowing, leading to disastrous outcomes. The need to improve the robustness of such systems against distribution shifts has sparked interest in the area of Out-Of-Distribution (OOD) generalization (Hendrycks & Dietterich, 2019; Gulrajani & Lopez-Paz, 2020). In this work, we aim to tackle the problem of OOD generalization of neural networks in a covariate-shift (Shimodaira, 2000) based classification setting by addressing the fundamental cause of their brittleness, rather than by explicitly enforcing invariances in the network using domain labels or data augmentations. More specifically, we aim to mitigate the issue of Simplicity Bias, the tendency of Stochastic Gradient Descent (SGD) based solutions to rely overly on simple features alone, rather than on a diverse set of features (Arpit et al., 2017; Valle-Perez et al., 2018). While this behavior was earlier used to explain the remarkable generalization of deep networks, recent works suggest that it is in fact a key reason behind their brittleness to domain shifts (Shah et al., 2020).

