SIMPLICITY BIAS IN 1-HIDDEN LAYER NEURAL NETWORKS

Abstract

Recent works (Shah et al., 2020; Chen et al., 2021) have demonstrated that neural networks exhibit extreme simplicity bias (SB): they learn only the simplest features to solve the task at hand, even in the presence of other, more robust but more complex features. Due to the lack of a general and rigorous definition of features, these works showcase SB on semi-synthetic datasets such as Color-MNIST and MNIST-CIFAR, where defining features is relatively easy. In this work, we rigorously define as well as thoroughly establish SB for one-hidden-layer neural networks. More concretely, (i) we define SB as the network essentially being a function of a low-dimensional projection of the inputs; (ii) theoretically, we show that when the data is linearly separable, the network primarily depends only on the linearly separable (1-dimensional) subspace, even in the presence of an arbitrarily large number of other, more complex features that could have led to a significantly more robust classifier; (iii) empirically, we show that models trained on real datasets such as Imagenette and Waterbirds-Landbirds indeed depend on a low-dimensional projection of the inputs, thereby demonstrating SB on these datasets; (iv) finally, we present a natural ensemble approach that encourages diversity in models by training successive models on features not used by earlier models, and demonstrate that it yields models that are significantly more robust to Gaussian noise.
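The low-dimensional-projection notion of SB in (i) can be probed directly on a trained network. The sketch below is our own illustration, not the paper's exact procedure: it trains a one-hidden-layer ReLU network by plain gradient descent on data whose label depends only on the first input coordinate, then measures how much the network's predictions change when inputs are projected onto the top-k right singular subspace of the first-layer weight matrix. Under SB, agreement should already be high for very small k.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: the label depends only on coordinate 0 (a 1-D linearly
# separable direction); the remaining 19 coordinates are pure noise.
n, d = 512, 20
X = rng.normal(size=(n, d))
y = np.sign(X[:, 0])

# One-hidden-layer network: f(x) = v . relu(W x)
h = 64
W = rng.normal(size=(h, d)) / np.sqrt(d)
v = rng.normal(size=h) / np.sqrt(h)

def forward(X, W, v):
    return np.maximum(X @ W.T, 0.0) @ v

# Full-batch gradient descent on the logistic loss.
lr = 0.1
for _ in range(1000):
    z = forward(X, W, v)
    g = -y / (1.0 + np.exp(y * z)) / n        # d(loss)/dz, averaged over n
    pre = X @ W.T
    A = np.maximum(pre, 0.0)                  # hidden activations
    mask = pre > 0                            # ReLU derivative
    grad_v = A.T @ g
    grad_W = ((g[:, None] * mask) * v).T @ X
    v -= lr * grad_v
    W -= lr * grad_W

# Project inputs onto the top-k right singular subspace of W; if the
# network exhibits SB, predictions should barely change even for k = 1.
U, S, Vt = np.linalg.svd(W, full_matrices=False)
preds = np.sign(forward(X, W, v))
for k in (1, 5, d):
    P = Vt[:k].T @ Vt[:k]
    agree = np.mean(np.sign(forward(X @ P, W, v)) == preds)
    print(f"k={k:2d}: prediction agreement {agree:.3f}")
```

On real models, the same idea applies with the projection taken over the first-layer weights (or any learned embedding): sweep k and record the smallest subspace dimension that preserves the model's predictions.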

1. INTRODUCTION

Figure 1: Classification of swans vs. bears. There are several features, such as the background, the color of the animal, and the shape of the animal, each of which is sufficient for classification, but using all of them leads to a more robust model.¹

It is well known that neural networks (NNs) are vulnerable to distribution shifts as well as to adversarial examples (Szegedy et al., 2014; Hendrycks et al., 2021). A recent line of work (Geirhos et al., 2018; Shah et al., 2020; Geirhos et al., 2020) proposes that simplicity bias (SB), also known as shortcut learning, i.e., the tendency of NNs to learn only the simplest features over other useful but more complex features, is a key reason behind this non-robustness. The argument is roughly as follows: in the classification of swans vs. bears, as illustrated in Figure 1, there are many features, such as the background, the color of the animal, and the shape of the animal, that can be used for classification. However, using only one or a few of them can lead to models that are not robust to specific distribution shifts, while using all the features can lead to more robust models. Several recent works have demonstrated SB on a variety of semi-real constructed datasets (Geirhos et al., 2018; Shah et al., 2020; Chen et al., 2021) and have hypothesized SB to be the key reason for NNs' brittleness to distribution shifts (Shah et al., 2020). However, such observations are limited to specific semi-real datasets, and a general method that can identify SB for a given dataset and model is still missing from the literature. Such a method would be useful not only for estimating the robustness of a model but could also help in designing more robust models.
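One concrete way to instantiate the abstract's ensemble idea, training successive models on features not used by earlier models, is to project out the direction the first model relies on before training the second. The sketch below is our own illustration with linear models for clarity, not the paper's exact algorithm; the data construction and all names are assumptions made here.

```python
import numpy as np

rng = np.random.default_rng(1)

# Two redundant predictive features: coordinate 0 (clean, "simple") and
# coordinate 1 (noisier). The remaining 8 coordinates are pure noise.
n, d = 512, 10
y = rng.choice([-1.0, 1.0], size=n)
X = rng.normal(size=(n, d))
X[:, 0] = y + 0.05 * rng.normal(size=n)
X[:, 1] = y + 0.50 * rng.normal(size=n)

def train_linear(X, y, steps=300, lr=0.5):
    """Logistic regression by gradient descent (a stand-in for a simple model)."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        g = -(y / (1.0 + np.exp(y * (X @ w)))) @ X / len(y)
        w -= lr * g
    return w

# Model 1 trains on the raw inputs and leans on the cleanest feature(s).
w1 = train_linear(X, y)

# Model 2 only sees the orthogonal complement of the direction model 1
# uses, forcing it onto features model 1 ignored.
u = w1 / np.linalg.norm(w1)
X_perp = X - np.outer(X @ u, u)
w2 = train_linear(X_perp, y)

def ensemble(X):
    # Average the two decision values; model 2 acts only on the part of
    # the input that model 1 does not use.
    return np.sign(X @ w1 + (X - np.outer(X @ u, u)) @ w2)

# Model 2's accuracy depends on how much label signal survives the
# projection; by construction it is exactly orthogonal to model 1.
print("model 1 acc:", np.mean(np.sign(X @ w1) == y))
print("model 2 acc:", np.mean(np.sign(X_perp @ w2) == y))
print("ensemble acc:", np.mean(ensemble(X) == y))
```

Because every row of `X_perp` is orthogonal to `u`, the second model's weight vector is orthogonal to the first model's by construction, so the two models cannot rely on the same 1-D feature, which is the diversity mechanism the abstract's ensemble approach exploits.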



¹Image source: Wikipedia (swa, bea).

