FEATURE RECONSTRUCTION FROM OUTPUTS CAN MITIGATE SIMPLICITY BIAS IN NEURAL NETWORKS

Abstract

Deep Neural Networks (DNNs) are known to be brittle to even minor distribution shifts relative to the training distribution. While one line of work has demonstrated that Simplicity Bias (SB) of DNNs, the bias towards learning only the simplest features, is a key reason for this brittleness, another recent line of work has surprisingly found that diverse/complex features are indeed learned by the backbone, and that the brittleness is instead due to the linear classification head relying primarily on the simplest features. To bridge the gap between these two lines of work, we first hypothesize and verify that while SB may not altogether preclude the learning of complex features, it amplifies simpler features over complex ones: simple features are replicated several times in the learned representations, while complex features might not be. This phenomenon, which we term the Feature Replication Hypothesis, coupled with the Implicit Bias of SGD to converge to maximum-margin solutions in the feature space, leads models to rely mostly on simple features for classification. To mitigate this bias, we propose a Feature Reconstruction Regularizer (FRR) that ensures the learned features can be reconstructed back from the logits. Using FRR in linear-layer training (FRR-L) encourages the use of more diverse features for classification. We further propose to finetune the full network while freezing the weights of the linear layer trained using FRR-L, to refine the learned features and make them more suitable for classification. With this simple solution, we demonstrate up to 15% gains in OOD accuracy on recently introduced semi-synthetic datasets with extreme distribution shifts. Moreover, we demonstrate noteworthy gains over existing SOTA methods on the standard OOD benchmark DomainBed as well.

1. INTRODUCTION

Despite the remarkable success of Deep Neural Networks (DNNs) in various fields, they are known to be brittle against even minor shifts in the data distribution during inference, which are not uncommon in a real-world setting (Quinonero-Candela et al., 2008; Torralba & Efros, 2011). For example, a self-driving car that works well in normal weather may perform poorly when it is snowing, leading to disastrous outcomes. The need for improving the robustness of such systems against distribution shifts has sparked interest in the area of Out-Of-Distribution (OOD) generalization (Hendrycks & Dietterich, 2019; Gulrajani & Lopez-Paz, 2020). In this work, we aim to tackle the problem of OOD generalization of neural networks in a covariate-shift (Shimodaira, 2000) based classification setting, by addressing the fundamental cause of their brittleness, rather than by explicitly enforcing invariances in the network using domain labels or data augmentations. More specifically, we aim to mitigate the issue of Simplicity Bias, the tendency of Stochastic Gradient Descent (SGD) based solutions to rely overly on simple features alone, rather than on a diverse set of features (Arpit et al., 2017; Valle-Perez et al., 2018). While this behavior was earlier used to explain the remarkable generalization of deep networks, recent works suggest that it is a key reason behind their brittleness to domain shifts (Shah et al., 2020). The extent of Simplicity Bias seen in models is a result of two important factors: the diversity of features learned by the feature extractor, and the extent to which these diverse features are used for the task at hand, such as classification. Recent works suggest that generalization to distribution shifts can be improved by retraining the last layer alone, indicating that the learned features may already be good enough for this purpose (Rosenfeld et al., 2022; Kirichenko et al., 2022b).
Does this imply that the brittleness of models can be attributed to the learning of the classification head alone? If so, why does SGD fail to utilize these diverse features despite its Implicit Bias to converge to a maximum-margin solution in the linearly separable case (Soudry et al., 2018)? To answer these questions, we first hypothesize and empirically verify that Simplicity Bias leads to the learning of simple features over and over again, as compared to other, more complex features. For example, among the 512 penultimate-layer features of a ResNet, 462 might capture a simple feature such as color, while the remaining 50 might capture a more complex feature such as shape; we refer to this as the (Simple) Feature Replication Hypothesis. Assuming the feature replication hypothesis, we further show theoretically and empirically that a maximum-margin classifier in the replicated feature space gives much higher importance to the replicated feature than to others, highlighting why the linear layer relies more on simpler features for classification. To mitigate this, we propose a novel regularizer termed Feature Reconstruction Regularizer (FRR), which enforces that the features learned by the network can be reconstructed back from the logit (pre-softmax) layer used for classification. As shown in Fig. 2, we first propose to train the linear classifier alone by freezing the weights of the feature extractor. This formulation enables the learning of an invertible mapping in the output layer, specifically for the domain of features seen during training. This further allows the logit layer to act as an information bottleneck, encouraging all the factors of variation in the features to be utilized for the classification task, thereby improving the diversity of features used.
We theoretically show that adding this constraint while finetuning the linear layer can recover a max-margin classifier in the original input space, disregarding feature replication. Consequently, the learnt linear classifier also gives more importance to non-replicated complex features while making predictions. We further explore the possibility of improving the quality of the features learned by the feature extractor, by using FRR to finetune the backbone as well. To do this, we freeze the linear classification head and finetune the backbone with FRR. We find that this encourages the network to learn better-quality features that are more relevant for classification. We list the key contributions of this work below:

• Key Observation: We provide a crisp "feature replication" hypothesis to explain the brittleness of ERM-trained neural networks to OOD data (Sec 3.1). Using this, we further provide theoretical and empirical evidence for the existence of Simplicity Bias in maximum-margin classifiers.

• Novel Algorithm based on the Observation: We introduce a novel FRR regularizer to safeguard against the feature replication phenomenon (Sec 3.2), along with theoretical support for FRR in an intuitive data-distribution setting. Furthermore, we introduce the simple FRR-L method that regularizes only the linear head with FRR, and the FRR-FLFT training regimen that finetunes the feature extractor for improved OOD robustness (Sec 4).

• Empirical validation of the hypothesis and the proposed algorithm: We demonstrate the effectiveness of FRR-FLFT and FRR-L through extensive experiments on semi-real datasets constructed to study OOD brittleness (Table 2), as well as on standard OOD generalization benchmarks, where FRR-FLFT provides up to 3% gains over SOTA methods for OOD generalization (Table 3).

2. RELATED WORKS

Learning diverse classifiers to counter simplicity bias: Recent works have shown that ERM-trained models learn diverse features; however, the linear layer fails at capturing and utilizing these diverse features properly. There have been several attempts at training classifiers that can make use of such diverse features. Teney et al. (2022) train a number of linear classifiers on top of a pre-trained network with a diversity regularizer, which encourages the classifiers to rely on different features. Xu et al. (2022) and Bahng et al. (2020) propose to train debiased classifiers which are statistically independent from trained biased networks, but these need careful design and prior knowledge of the biases in the trained networks. Kirichenko et al. (2022a) show that reweighting train-set examples and retraining the last layer of a pre-trained deep network can alleviate spurious correlations, provided one can access a balanced dataset. In contrast to these methods, our method works simply on the training-set data, and produces a single debiased classifier. Huang et al. (2020) propose to mute the features with the highest gradients, and use only the other features to make a prediction. While this method suppresses the maximally used features, it does not encourage the learning of hard-to-learn features, which is directly realized by our loss formulation. Kumar et al. (2022) suggest that finetuning the final linear layer before finetuning the entire network can make it more robust to OOD shifts, and we utilize this insight in the FRR-FLFT phase of our method. A complementary approach to this problem is to learn features that are more diverse (Zhang et al., 2022; Wang et al., 2019). We note that applying our proposed method on top of such techniques would encourage the classifier to use the diverse features effectively, which can further benefit performance.

Domain Generalization and OOD robustness:

The performance of neural networks is known to drop when there is a mismatch between the train and test distributions (Hendrycks & Dietterich, 2019), and methods to mitigate this have been gaining a lot of attention in recent years. The problem has been studied under various assumptions on the distribution shift. The commonly studied setting of domain generalization (Gulrajani & Lopez-Paz, 2020; Li et al., 2018a) assumes that the train distribution consists of a mixture of distinct distributions (called domains), with each train sample having a domain label associated with it. The stronger setting of aggregate domain generalization (Thomas et al., 2021; Matsuura & Harada, 2020) assumes training data drawn from a mixture of distributions, but does not assume the availability of domain labels. Finally, OOD robustness (Hendrycks & Dietterich, 2019; Koh et al., 2021) drops all of these assumptions. Most works tackling the domain generalization problem attempt to train a model whose predictions are invariant to the domain label (Li et al., 2018a; Arjovsky et al., 2019), or try to align the features of the model for examples from different domains (Shi et al., 2021; Shankar et al., 2018). However, since we aim to tackle the stronger setting of OOD generalization, we do not use domain labels. Tackling the OOD robustness problem, Thomas et al. (2021) and Matsuura & Harada (2020) first cluster training examples into "pseudo-domains", after which standard domain generalization techniques are used. Another recent line of work proposes model averaging (Cha et al., 2021; Li et al., 2022) and/or ensembling (Arpit et al., 2021) for better OOD generalization. These techniques are complementary to our contribution, and we demonstrate how they can benefit each other in our empirical evaluation.

3. FEATURE REPLICATION HYPOTHESIS

Prior works have shown that neural networks trained with SGD exhibit simplicity bias (SB), even when initialized with pre-trained models that can capture complex features. Our Feature Replication Hypothesis (FRH) states that SB is observed because the simpler features of the input are replicated multiple times in the feature space of neural networks. When trained using SGD, the final linear layer then learns the max-margin classifier on these replicated features, which leads to over-reliance on simpler features in the input. Hence, the outputs of the network are brittle to distribution shifts that change such replicated features. In this section, we provide empirical and theoretical evidence for FRH, and propose a new regularizer, FRR, to mitigate this effect. We first introduce some useful notation. Let f_θ : R^d → R^m be the feature extractor of a neural network parameterized by weights θ, and let W ∈ R^{m×k} be the weight matrix of the linear classifier. For an input x ∈ R^d, the output of the network is W^T f_θ(x) ∈ R^k.

3.1. EMPIRICAL VALIDATION OF FEATURE REPLICATION HYPOTHESIS (FRH) IN ERM

Coloured MNIST dataset. To empirically demonstrate feature replication, we use a binarized version of the Coloured MNIST dataset (Gulrajani & Lopez-Paz, 2020). We construct this dataset by first assigning two digits of the MNIST dataset, namely "1" and "5", to classes 0 and 1 respectively. While training the network, we superimpose images of "1" onto colours in the range R_0 = [(115, 0, 0) - (256, 141, 0)] (i.e., red), and images of "5" onto colours in the range R_1 = [(0, 115, 0) - (141, 256, 0)] (i.e., green). The dataset is constructed such that the simple feature, namely colour, is weakly correlated with the labels, while the complex shape features are strongly correlated with the labels. See Appendix D.1 for more details about the dataset. Training setup: We train a model on this dataset, and test it on images which do not have any correlation between the label and the colour, i.e., images where the digits "1" and "5" are superimposed on randomly coloured backgrounds. We construct this test distribution to see how well different algorithms learn simple (i.e., colour) and complex (i.e., shape) features, since an algorithm which depends only on the spurious colour features would not perform well on the test domain. We train a four-layer CNN on this data. If a feature in the penultimate layer f_θ(x) has more than 90% correlation with the colour or shape of the input, we call it a colour feature or a shape feature, respectively. We also compute the correlation of these features with the output of the network (W^T f_θ(x)) over inputs from the test domain. This gives us information about the learnt features and their contributions to the final prediction of the network. Note that the feature dimension is m = 32, and the output dimension is k = 1.
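The construction above can be sketched in a few lines. The text specifies only the colour ranges, so the blending rule below (digit pixels kept bright, background pixels set to the sampled colour) is one plausible reading, and the helper name `colourise` is ours:

```python
import numpy as np

# Colour ranges from the text; each channel of the background colour is drawn
# uniformly between the endpoints of the class's range.
R0_LOW, R0_HIGH = np.array([115, 0, 0]), np.array([256, 141, 0])  # "red" range (digit "1")
R1_LOW, R1_HIGH = np.array([0, 115, 0]), np.array([141, 256, 0])  # "green" range (digit "5")

def colourise(digit_img, low, high, rng):
    """Superimpose a binary grayscale digit (H x W, values in {0, 1}) on a
    random background colour drawn per-channel from [low, high)."""
    colour = rng.uniform(low, high)  # (3,) background colour
    rgb = np.empty(digit_img.shape + (3,), dtype=np.float32)
    for c in range(3):
        # Digit pixels stay bright; background pixels take the sampled colour.
        rgb[..., c] = digit_img * 255.0 + (1.0 - digit_img) * colour[c]
    return np.clip(rgb, 0, 255)

rng = np.random.default_rng(0)
fake_digit = (rng.random((28, 28)) > 0.8).astype(np.float32)  # stand-in for an MNIST "1"
img = colourise(fake_digit, R0_LOW, R0_HIGH, rng)
```

At test time, the same helper can be called with a colour range spanning the full RGB cube to produce the decorrelated, randomly coloured backgrounds.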

Observations:

In Table 1, we report the number of colour features, shape features, and the average correlation of each of these with the final prediction. We observe that the ERM-trained model learns both shape and colour features, but the number of learnt colour features (26) is much higher than the number of shape features (4), despite their weaker correlation with the labels, thus validating our Feature Replication Hypothesis. We also visualize the inter-feature correlation of the learnt features in Fig. 5, which shows blocks of highly correlated features, further validating our hypothesis. We note that the correlation of the output with the shape features is lower, leading to an OOD accuracy of only 59%.
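The feature-counting protocol described above can be sketched as follows. The 90% correlation threshold is from the text; the toy data (26 near-copies of a binary colour attribute plus 6 unrelated features, mirroring the m = 32 setup) is purely illustrative:

```python
import numpy as np

def count_correlated_features(features, attribute, threshold=0.9):
    """Count penultimate-layer features whose absolute Pearson correlation
    with a (binary) attribute exceeds `threshold`.
    features:  (n, m) array of penultimate-layer activations.
    attribute: (n,) array, e.g. 1 if the background is red, 0 if green."""
    a = attribute - attribute.mean()
    f = features - features.mean(axis=0)
    denom = np.linalg.norm(f, axis=0) * np.linalg.norm(a) + 1e-12
    corr = (f.T @ a) / denom                 # (m,) Pearson correlations
    return int(np.sum(np.abs(corr) > threshold))

# Toy check: 26 "colour" features (copies of the attribute plus tiny noise)
# and 6 unrelated features.
rng = np.random.default_rng(0)
n = 1000
attr = rng.integers(0, 2, size=n).astype(float)
colour_feats = attr[:, None] + 0.01 * rng.standard_normal((n, 26))
other_feats = rng.standard_normal((n, 6))
F = np.concatenate([colour_feats, other_feats], axis=1)
```

The same routine, applied with the network output in place of `features`, gives the output-correlation statistics reported in Table 1.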

3.2. FEATURE RECONSTRUCTION REGULARIZER (FRR)

To alleviate the simple-feature replication issue, we propose the Feature Reconstruction Regularizer (FRR) to enforce that the learned features can be reconstructed from the output logits. We propose to retrain the final linear layer using this regularizer to allow the model to utilize diverse features to compute the final output. We implement this by introducing another neural network with the objective of reconstructing the features of the network from the output logits, i.e., the features f_θ(x) should be recoverable from the predictions of the network through a transform T_φ(·) parameterized by φ. That is, FRR is given by:

    L_FRR(x, θ, W, φ) = ||f_θ(x) - T_φ(W^T f_θ(x))||_p,

where ||·||_p denotes the ℓ_p norm; we set p = ∞ or p = 1 in our experiments. In the simplest case, T_φ(y) = φy, where φ ∈ R^{m×k}. Note that in order to find the appropriate φ, we jointly optimize W and φ using gradient-descent-based optimizers. We also experiment with T_φ being a more complex neural network. We empirically validate FRR on Coloured MNIST, where using FRR with the linear layer leads to lower correlation with colour compared to standard ERM (Table 1). Consequently, OOD accuracy improves by 5% over ERM.
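A minimal sketch of FRR with the simplest (linear) decoder, written in a row-vector convention (logits = feats @ W); the example matrices are illustrative, not the paper's setup:

```python
import numpy as np

def frr_loss(feats, W, phi, p=np.inf):
    """Feature Reconstruction Regularizer with a linear decoder.
    feats: (n, m) batch of features f_theta(x); logits are feats @ W with W (m, k).
    phi:   (k, m) decoder mapping logits back to feature space.
    Returns the mean per-sample l_p norm of f_theta(x) - T_phi(W^T f_theta(x))."""
    recon = feats @ W @ phi
    err = feats - recon
    if p == np.inf:
        per_sample = np.abs(err).max(axis=1)
    else:
        per_sample = np.sum(np.abs(err) ** p, axis=1) ** (1.0 / p)
    return per_sample.mean()

rng = np.random.default_rng(0)
F = rng.standard_normal((8, 4))

# A square, invertible head admits a phi that reconstructs exactly, so the
# penalty can be driven to zero without discarding any feature direction.
W_full = rng.standard_normal((4, 4))
loss_full = frr_loss(F, W_full, np.linalg.inv(W_full))

# A head that collapses the features onto a lower-dimensional logit space
# cannot reconstruct them, so the penalty stays bounded away from zero.
W_low = rng.standard_normal((4, 2))
loss_low = frr_loss(F, W_low, np.linalg.pinv(W_low))
```

In practice k is much smaller than m, so exact reconstruction of arbitrary vectors is impossible; the regularizer only asks for reconstruction on the domain of features seen during training, which is what makes the logit layer act as an information bottleneck.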

3.3. FRH & FRR: THEORETICAL ANALYSIS

We now present a simple and intuitive data distribution with feature replication that highlights the OOD brittleness of standard ERM, and also demonstrates that FRR can be significantly more robust.

Data Distribution: Consider a linearly separable distribution with two factors of variation, as shown in Figure 1. That is, consider (x, y) ∼ D, where

    y = ±1 with probability 0.5,   x = [y, y] + [n_1, n_2] ∈ R^2,   n_i ∼ Unif[-0.5, 0.5], i ∈ {1, 2}.    (2)

Also consider a feature extractor f_θ(·) that replicates the first feature, i.e., for every data point (x, y), the corresponding feature-replicated data point is (x̃, y), where

    f_θ(x) = x̃ = [x_1, ..., x_1, x_2] ∈ R^{d+1},    (3)

i.e., x_1 is repeated d times. The joint distribution of features and labels is denoted by D̃. Finally, we define the ℓ_2 max-margin classifier over a distribution D as

    w_MM := argmin_w (1/2) ||w||_2^2   subject to   y · ⟨w, x⟩ ≥ 1   ∀ (x, y) ∈ Supp(D).

Then we have the following results:

Claim 3.1 (Brittleness due to Feature Replication). Consider the data distribution given in Equations 2 and 3. Then the following holds: (1.) the max-margin classifier w_MM over D is given by w_MM = [1, 1]; and (2.) the max-margin classifier w̃_MM over D̃ is given by w̃_MM = [2/(d+1), ..., 2/(d+1)] ∈ R^{d+1}.

The above claim implies that when there are replicated features at the input of the linear layer, the max-margin classifier gives much more importance to the replicated feature. Hence, even a minor change in this replicated feature in the input space is amplified in the output of the classifier. This is especially concerning in light of the observations in Table 1, which validate the Feature Replication Hypothesis on Coloured MNIST.

Proposition 3.2 (Robustness of FRR). Denote the average feature reconstruction loss

    L_FRR(w̃, φ) := max_{1≤i≤d+1} E_{(x̃, y)∼D̃} (⟨w̃, x̃⟩ φ_i - x̃_i)^2,

and consider any (w̃*, φ*) satisfying

    (w̃*, φ*) ∈ argmin_{(w̃, φ)} L_FRR(w̃, φ)   subject to   y · ⟨w̃, x̃⟩ ≥ 0   ∀ (x̃, y) ∈ Supp(D̃).

We have that w̃*_1 + ... + w̃*_d = w̃*_{d+1}. Consequently, ⟨w̃*, x̃⟩ ∝ ⟨w_MM, x⟩ for all x ∈ R^2. Practically, we can implement the above as an ℓ_{2,∞}-type penalty over a batch.

The above result shows that the feature reconstruction regularizer produces a linear classifier that gives equal total weight to the replicated and non-replicated features. This is equivalent to a maximum-margin classifier in the non-replicated feature space, thereby resulting in enhanced robustness to distribution shifts. The same is reflected in Figures 1(c) and 1(d), which show the impact of FRR on the trained boundary in the non-replicated feature space. We defer the proofs of the above results to Appendix A, where we also provide a more general result assuming correlated feature representations.
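Claim 3.1 can be sanity-checked numerically (an illustration, not a proof): under the distribution above, the binding constraints sit at the noisiest points x = y · [0.5, 0.5], the stated solutions attain margin exactly 1 there, and with replication the output is d times more sensitive to the simple feature x_1 than to x_2:

```python
import numpy as np

d = 10  # number of times the simple feature x_1 is replicated

# Binding support point of D: for y = +1 the hardest point is x = [0.5, 0.5]
# (the noise pushes both coordinates down to y - 0.5); y = -1 gives its negation.
x_hard = np.array([0.5, 0.5])

w_mm = np.array([1.0, 1.0])                 # claimed max-margin classifier over D
w_mm_rep = np.full(d + 1, 2.0 / (d + 1))    # claimed max-margin classifier over D-tilde

# Replicated version of the hard point: x_1 repeated d times, then x_2.
x_hard_rep = np.concatenate([np.full(d, 0.5), [0.5]])

margin = w_mm @ x_hard              # should be exactly 1
margin_rep = w_mm_rep @ x_hard_rep  # should also be exactly 1

# Sensitivity of the output to each underlying factor: with replication, a unit
# change in x_1 moves the output d times as much as a unit change in x_2.
sens_x1 = w_mm_rep[:d].sum()   # = 2d / (d + 1)
sens_x2 = w_mm_rep[-1]         # = 2 / (d + 1)
```

As d grows, the weight ratio sens_x1 / sens_x2 = d grows without bound, which is the amplification of the replicated simple feature that the claim formalizes.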

4. TRAINING PROCEDURE

Pretraining: In order to learn features relevant to the train distribution, we first pretrain the model using standard ERM with the cross-entropy loss L_std(W, θ, (x, y)).

FRR-L: Since ERM training is known to learn several rich and diverse features, we freeze the backbone parameters θ and retrain the final layer W as follows:

    (W_FRR, φ_FRR) = argmin_{W, φ} L_std(W, θ, (x, y)) + λ_L · L_FRR(x, θ, W, φ),

where λ_L is a hyperparameter weighing the two losses. We train W and φ jointly. We refer to this step as FRR-L (Feature Reconstruction Regularizer - Linear), since only the linear layer is trained.

FRR-FLFT: Following the suggestions of Kumar et al. (2022), we follow up the linear-layer training by finetuning the feature extractor θ with a weighted combination of the cross-entropy loss and FRR, weighted by a hyperparameter λ_FLFT. In this step we freeze the weights of the linear layer to improve the stability of training, since naively imposing this constraint during full-network training could amplify Simplicity Bias in degenerate cases. For example, the backbone could learn to output a single replicated simple feature which is predictive enough on the training data; reconstructing such a feature from the logits would also be easy, but such a network might not generalize well. Formally, the optimization problem for this step is:

    θ_FLFT = argmin_θ L_std(W_FRR, θ, (x, y)) + λ_FLFT · L_FRR(x, θ, W_FRR, φ_FRR).

We view this step as "sharpening" the features for more accurate predictions. Freezing the linear head ensures that the features do not collapse to a degenerate solution. Our training algorithm is summarized in Algorithm 1, and the training pipeline is illustrated in Figure 2.
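The two-stage procedure can be sketched on a toy model in which the backbone is a plain linear map and squared error stands in for both the cross-entropy and the ℓ_p reconstruction penalty; the gradients are derived by hand for this linear case. This is a minimal sketch of the training structure under those simplifications, not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, m, k = 64, 6, 8, 3
X = rng.standard_normal((n, d_in))      # inputs
Y = rng.standard_normal((n, k))         # stand-in targets (one-hot in practice)

Theta = 0.1 * rng.standard_normal((d_in, m))  # "backbone" (linear stand-in for f_theta)
W = 0.1 * rng.standard_normal((m, k))         # linear classification head
Phi = 0.1 * rng.standard_normal((k, m))       # FRR decoder T_phi
lam, lr = 0.5, 0.05                           # lambda and step size (illustrative)

def losses(Theta, W, Phi):
    F = X @ Theta
    cls = np.mean((F @ W - Y) ** 2)        # stand-in for L_std
    frr = np.mean((F - F @ W @ Phi) ** 2)  # squared-l2 stand-in for L_FRR
    return cls, frr

init_cls, init_frr = losses(Theta, W, Phi)
init_total = init_cls + lam * init_frr

# Stage 1 (FRR-L): backbone Theta frozen; W and Phi trained jointly.
F = X @ Theta
for _ in range(200):
    E = F @ W - Y                # classification residual
    R = F - F @ W @ Phi          # reconstruction residual
    gW = 2 / (n * k) * F.T @ E - lam * 2 / (n * m) * F.T @ R @ Phi.T
    gPhi = -lam * 2 / (n * m) * (F @ W).T @ R
    W -= lr * gW
    Phi -= lr * gPhi

# Stage 2 (FRR-FLFT): head W and decoder Phi frozen; backbone finetuned.
P = np.eye(m) - W @ Phi          # reconstruction-residual map, fixed in this stage
for _ in range(200):
    F = X @ Theta
    E = F @ W - Y
    R = F @ P
    Theta -= lr * (2 / (n * k) * X.T @ E @ W.T + lam * 2 / (n * m) * X.T @ R @ P.T)

final_cls, final_frr = losses(Theta, W, Phi)
final_total = final_cls + lam * final_frr
```

In the actual method, an autodiff framework replaces the hand-derived gradients, cross-entropy replaces the squared error, and freezing is done by simply excluding the corresponding parameters from the optimizer in each stage.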

5.1. UNDERSTANDING HOW FRR MITIGATES SIMPLICITY BIAS

To empirically illustrate the extent of Simplicity Bias in neural networks, Shah et al. (2020) introduced several synthetic and semi-synthetic datasets, in which some features are explicitly simple, requiring a simpler decision boundary for prediction, while others are complex. In this section, we demonstrate the effectiveness of the proposed Feature Reconstruction Regularizer towards mitigating Simplicity Bias, by evaluating it on a 10-class variant of the semi-synthetic MNIST-CIFAR dataset, as discussed in the following section.

5.1.1. MNIST-CIFAR-10 DATASET

We extend the simple binary MNIST-CIFAR dataset proposed by Shah et al. (2020) to a 10-class dataset, in order to evaluate the impact of the proposed Feature Reconstruction Regularizer in a more complex scenario than the binary Coloured MNIST dataset presented in Section 3. We refer to this dataset as MNIST-CIFAR-10. The higher complexity of this dataset allows for a more reliable evaluation of various settings such as linear probing, full-network finetuning, and fixed-linear finetuning, with better granularity of results. To construct this dataset, we first define correspondences between the classes of CIFAR-10 and MNIST. Each image from class k of MNIST is paired with an image from class k of CIFAR-10, with the label being set to k. Thus, every training data sample (x_1, x_2, y) consists of x_1 and x_2, which are images from CIFAR-10 and MNIST respectively, along with their ground-truth class y. Note that for both CIFAR-10 and MNIST, labels are always correlated with the respective images. In such a scenario, although a classifier can achieve very good performance by relying solely on the simple (MNIST) features, the goal of Out-Of-Distribution (OOD) robustness requires it to rely on the complex (CIFAR-10) features as well. This dataset represents the toughest setting of OOD generalization, where there is no differentiation between important and spurious features.
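The class-wise pairing step can be sketched as follows, under one natural reading in which each CIFAR-10 image is given a same-class MNIST partner sampled with replacement (the exact pairing procedure and the helper name `pair_by_class` are assumptions):

```python
import numpy as np

def pair_by_class(labels_a, labels_b, rng):
    """Pair each index of dataset A (e.g. CIFAR-10) with an index of dataset B
    (e.g. MNIST) of the same class, yielding (i_a, i_b, y) triples."""
    triples = []
    for k in np.unique(labels_a):
        idx_a = np.flatnonzero(labels_a == k)
        idx_b = np.flatnonzero(labels_b == k)
        # Sample a same-class partner for every A image (with replacement,
        # since the two datasets need not have equal class counts).
        partners = rng.choice(idx_b, size=len(idx_a), replace=True)
        triples.extend((int(ia), int(ib), int(k)) for ia, ib in zip(idx_a, partners))
    return triples

rng = np.random.default_rng(0)
la = rng.integers(0, 10, size=200)   # stand-in CIFAR-10 labels
lb = rng.integers(0, 10, size=300)   # stand-in MNIST labels
pairs = pair_by_class(la, lb, rng)
```

Every triple shares a single label, which is exactly what makes both the simple (MNIST) and complex (CIFAR-10) halves fully predictive on the training distribution.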

5.2. TRAINING AND EVALUATION SETTINGS

We consider two separate ResNet-18 (He et al., 2016) feature extractors for CIFAR-10 and MNIST respectively. The outputs of the Global Average Pooling (GAP) layers of the two feature extractors are concatenated to form a 1024-dimensional vector, which is given as input to the linear classifier. This architecture allows the computation of accuracy based on either a combination of both CIFAR-10 and MNIST features, or based on features of only one of the datasets. For example, to evaluate the performance of the classifier based on CIFAR-10 features alone, we replace the 512-dimensional MNIST feature vector of each data sample with an average feature vector computed from all images in the MNIST dataset. We refer to this as the CIFAR-AvgMNIST dataset, while the corresponding one for MNIST is referred to as the MNIST-AvgCIFAR dataset. Similar to the work of Shah et al. (2020), we define two additional datasets, CIFAR-RandMNIST and MNIST-RandCIFAR, where images from one of the datasets (MNIST and CIFAR-10 respectively) are randomly shuffled with respect to their corresponding labels. The base training (E1, E2, E3) is done for 500 epochs, and the linear-layer training / finetuning (E4 - E18) is done for 20 epochs, without any augmentations.
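The Avg-feature evaluation can be sketched as follows, assuming the two backbones' GAP outputs are stored as (n, 512) arrays; the function name is ours:

```python
import numpy as np

def cifar_avg_mnist(feats_cifar, feats_mnist):
    """Build CIFAR-AvgMNIST evaluation features: keep each sample's CIFAR half
    and replace its MNIST half with the dataset-average MNIST feature vector,
    so that predictions can only vary with the CIFAR features.
    feats_cifar, feats_mnist: (n, 512) GAP outputs of the two backbones."""
    avg_mnist = feats_mnist.mean(axis=0, keepdims=True)            # (1, 512)
    avg_block = np.repeat(avg_mnist, feats_cifar.shape[0], axis=0)  # (n, 512)
    return np.concatenate([feats_cifar, avg_block], axis=1)         # (n, 1024)

rng = np.random.default_rng(0)
fc = rng.standard_normal((16, 512))  # stand-in CIFAR-10 features
fm = rng.standard_normal((16, 512))  # stand-in MNIST features
feats = cifar_avg_mnist(fc, fm)
```

The MNIST-AvgCIFAR variant is symmetric: swap the roles of the two halves before concatenation.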

5.3. EXPERIMENTAL RESULTS IN VARIOUS TRAINING REGIMES

We present the results of training on the MNIST-CIFAR-10 dataset using different algorithms in Table 2. The mean and standard deviation across five runs are reported for each case.

ERM Training: By training a randomly initialized model on the MNIST-CIFAR-10 dataset using the cross-entropy loss (E1), we obtain an accuracy of 99.84% on its corresponding test split. While the accuracy of this model on the MNIST-AvgCIFAR dataset is high (97.44%), its performance on the CIFAR-AvgMNIST dataset is poor (51.92%), indicating that the model relies more on the simpler MNIST features than on a combination of both CIFAR and MNIST features. While the performance on the CIFAR-AvgMNIST and MNIST-AvgCIFAR datasets is sufficient to understand the extent of CIFAR/MNIST features used by the classification head, it does not give a clear picture of the features learned by the two feature extractors. To understand this, we reinitialize the linear classification head randomly, and train it using the CIFAR-RandMNIST (E5) and MNIST-RandCIFAR (E6) datasets respectively. We obtain an accuracy of 65.2% on the CIFAR-AvgMNIST dataset in the former case, indicating that although the learned CIFAR features can achieve 13% higher accuracy (w.r.t. E1), the bias in the classification head prevents them from participating in the classification task. The MNIST-AvgCIFAR accuracy of the latter case is high, as expected. An upper bound on the CIFAR-10 and MNIST accuracy that can be achieved with the selected architecture and training strategy (without using any augmentations) can be seen in E2 (88.53%) and E3 (99.68%) respectively.

Training the Linear Classification Head: As discussed, while ERM training (E1) learns features that can be used for better OOD performance (E5), it does not effectively leverage these features for the classification task.
We first explore the possibility of bridging the difference in CIFAR-AvgMNIST accuracy between E1 and E5 by merely retraining the linear layer. By reinitializing and naively retraining the linear layer with the cross-entropy loss, the accuracy on CIFAR-AvgMNIST improves by less than 1% (E4). Using the proposed Feature Reconstruction Regularizer (FRR) to train the linear layer alone, the CIFAR-AvgMNIST accuracy improves by 7.21% as shown in E8, demonstrating the effectiveness of the proposed regularizer in mitigating Simplicity Bias. We penalize the ℓ_∞ norm of the difference between the original features and their reconstruction, in addition to minimizing the cross-entropy loss. The reconstruction-based regularizer forces the network to utilize both CIFAR and MNIST features for classification. Since this regularizer resembles an orthonormality constraint on the linear classification head, we additionally check the effectiveness of explicitly enforcing a full-rank constraint on the linear layer by minimizing ||WW^T - I||_F (E7). We find that this is not effective in improving the overall accuracy, possibly because it enforces a very stringent constraint on the final classification layer. In contrast, the proposed Feature Reconstruction Regularizer allows more flexibility by imposing the constraint only on the domain of features seen during training. This accounts for simple feature replication as well, enabling the logit layer to be viewed as an information bottleneck in the reconstruction.

Finetuning (FT) and Fixed Linear Finetuning (FLFT): We explore the finetuning of a given base model in two settings: first, by finetuning all layers in the network (denoted FT, or FineTuning), and second, by freezing the parameters of the linear classification head and finetuning only the feature extractors, which we refer to as FLFT, or Fixed Linear FineTuning.
By finetuning an ERM-trained base model using either of the two strategies (E9 and E10), we observe gains of less than 1%. We observe similar gains even when finetuning the full network with FRR (E11). In contrast, by using FRR-FLFT even on the ERM-trained network (E12), we obtain a 7.29% improvement over the base model. This suggests that when the full network is allowed to change while imposing the FRR constraint, the network can continue to rely on simple features, possibly by reducing the number of complex features learned by the feature extractor. However, by freezing the weights of the linear layer and then imposing this constraint, the network is forced to refine the CIFAR features that are already being used for prediction.

Combining FRR-L and FRR-FLFT: While we obtain gains of a similar order (~7%) using FRR-L and FRR-FLFT individually, the former improves the diversity of features considered by the classification head, while the latter improves the quality of the features themselves. We therefore propose a training strategy that combines both FRR-L and FRR-FLFT. Using this, we obtain gains of 16.2% over the ERM baseline as shown in E16, indicating that the combination of FRR-L and FRR-FLFT has a compounding effect: it first selects diverse features, and then refines these features to be more useful for classification. Although FRR-L followed by FRR-FT (E15) is also effective, it yields about 6% smaller gains than the proposed FRR-L + FRR-FLFT. We note that following up FRR-L with ERM-FT (E13) or ERM-FLFT (E14) also refines the learned features, making them more suitable for the classification task, yielding 2.6% and 4.6% gains respectively over FRR-L. We verify the quality of the features learned by the feature extractors after the proposed FRR-L + FRR-FLFT training strategy by reinitializing and retraining the linear classifier on the CIFAR-RandMNIST (E17) and MNIST-RandCIFAR (E18) datasets respectively.
We observe considerable gains of around 15% on MNIST-CIFAR-10 accuracy using CIFAR-RandMNIST training when compared to ERM (E5), demonstrating that the proposed approach not only results in more CIFAR features being used for classification, but also leads to the learning of better CIFAR features.
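For reference, the explicit full-rank baseline (E7) discussed above can be sketched as follows, assuming the head is stored with one row per class so that WW^T compares the k class-weight vectors:

```python
import numpy as np

def orthonormality_penalty(W):
    """||W W^T - I||_F penalty on the classification head (the E7 baseline).
    W: (k, m) weight matrix with one row per class; the penalty is zero
    iff the k class-weight vectors are orthonormal."""
    k = W.shape[0]
    G = W @ W.T                       # (k, k) Gram matrix of class vectors
    return np.linalg.norm(G - np.eye(k), ord="fro")

# An orthonormal head incurs zero penalty; a collapsed one does not.
W_ortho = np.eye(3, 8)                       # 3 orthonormal rows in R^8
W_collapsed = np.ones((3, 8)) / np.sqrt(8)   # three identical unit-norm rows
```

Unlike FRR, this penalty constrains the head everywhere in feature space rather than only on the features actually seen during training, which may explain why it proved too stringent in E7.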

5.4. OOD GENERALIZATION IN A REAL WORLD SETTING

We show the efficacy of FRR towards improving OOD generalization on the DomainBed (Gulrajani & Lopez-Paz, 2020) benchmark. We use the performance of the model on in-domain validation data (i.e., the in-domain strategy of Gulrajani & Lopez-Paz (2020)) to select the best hyperparameters, and report the average performance and standard deviation across 5 random seeds. Baselines: We compare our method against standard ERM training, which has proven to be a frustratingly difficult baseline (Gulrajani & Lopez-Paz, 2020), as well as several state-of-the-art methods on this benchmark: SWAD (Cha et al., 2021), MIRO (Cha et al., 2022) and SMA (Arpit et al., 2021). Finally, we show that our approach can be effectively integrated with stochastic weight averaging to obtain further gains. See Appendix G for further experimental details.

Main Results:

The main results of our algorithm are reported in Table 3. We find that our pipeline of training and finetuning with FRR, when combined with ERM, achieves improved performance relative to state-of-the-art methods that do not use model weight averaging, and in fact achieves comparable performance to methods that do use it. Further, our method obtains substantial gains of more than 3% over ERM across datasets. The gains are especially pronounced for the larger datasets, including DomainNet and TerraIncognita (8% and 5% respectively), indicating the efficacy and scalability of our method. It is also clear from Table 3 that finetuning the feature extractor once the linear layer is fixed provides a boost of over 1% on average over FRR-L. This empirically validates our finetuning paradigm, which we denote ERM+FRR. Finally, using our method in tandem with SWAD helps us achieve a new state-of-the-art on the DomainBed benchmark, outperforming other methods on three datasets while achieving comparable performance on two, and beating the existing SOTA by close to 1% on average. We report detailed results and further ablations in Appendices H and J.

6. CONCLUSION AND DISCUSSION

In this work, we consider the problem of OOD generalization through the lens of mitigating Simplicity Bias in Neural Network training. To unravel the paradox between the existence of Simplicity Bias towards learning only the simplest features, and the observation that the features learned by large practical models may already be sufficiently diverse, we put forth the Feature Replication Hypothesis, which conjectures that simple features are learned in replicated form while complex features are sparse. Combining this with the Implicit Bias of SGD to converge to maximum margin solutions, we provide a theoretical justification for the high OOD sensitivity of Neural Networks. To specifically overcome the effect of simple feature replication in linear layer training, we propose the Feature Reconstruction Regularizer, which penalizes the ℓ_p norm distance between the features and their reconstruction from the output logits, thus improving the diversity of features used for classification. We further propose to freeze the weights of the linear layer thus trained, and use the FRR regularizer for finetuning the full network, to refine the features to be more useful for the downstream task. We justify the proposed regularizer both theoretically and empirically on synthetic and semi-synthetic datasets, and demonstrate its effectiveness in a real-world OOD generalization setting. We believe and hope that this work can pave the way towards a better understanding of the underlying causes of the OOD brittleness of neural networks, and inspire the development of better algorithms for addressing the same. We believe the proposed regularizer can potentially work effectively in several other settings that involve linear layer training/finetuning, such as domain adaptation and transfer learning.
While the regularizer works effectively in a scenario where the network is first trained using an algorithm such as ERM to learn features that are relevant to the task at hand, the robustness of the proposed algorithm in the presence of severely non-relevant features is yet to be explored.

A PROOFS OF THEORETICAL RESULTS

Published as a conference paper at ICLR 2023

In this section, we present a generalization of Claim 3.1 and Proposition 3.2 under weaker assumptions on the featurizer, and use them to prove the claims made in Sec 3.

A more general setting: Consider the dataset distribution D as defined in equation 2. Now, let f_Θ(x) = Θx, where Θ ∈ R^{d×2}. Further, θ_{i,1} + θ_{i,2} = 1 for all i. In simple words, the feature extractor maps the input to a d-dimensional feature representation, where each feature is a convex combination of the two input feature variations. Also note that

    Θ = [ 1 0
          1 0
          ⋮ ⋮
          1 0
          0 1 ]

i.e., d−1 rows equal to (1, 0) and one row equal to (0, 1), corresponds to feature replication. We now rephrase the results from Sec 3.3 in this setting and provide proofs for the same.

Claim A.1. (Restating Claim 3.1) The max-margin classifier w minimizing ||w||_2 and satisfying y⟨w, Θx⟩ ≥ 1 is given by w = (2/d, ..., 2/d).

Proof. Consider the point x = (0.5, 0.5) with label y = 1. Due to the constraint on the max-margin classifier, we have y⟨w, Θx⟩ ≥ 1, i.e.,

    (1/2) Σ_{j=1}^{d} w_j (θ_{j,1} + θ_{j,2}) ≥ 1,   i.e.,   (1/2) Σ_{j=1}^{d} w_j ≥ 1.

The minimizer of the ℓ_2 norm under this constraint is attained when w_j = 2/d for all j. Note that the "effective classifier" in the input space in this case is w̃ = Θ^T w. In particular, for the feature replication matrix above, Σ_j θ_{j,1} = d−1 and Σ_j θ_{j,2} = 1, so that w̃ = (2(d−1)/d, 2/d), leading to a skewed classifier. ∎

Now we restate and show the robustness of FRR in this setting.

Proposition A.2. (Restating Prop 3.2) Denote the feature reconstruction regularizer for this setting as

    FRR(w) = min_U max_{1≤i≤d} E[(⟨w, Θx⟩ u_i − (Θx)_i)^2].

Let w_FRR be the minimizer of FRR(w) satisfying y⟨w_FRR, Θx⟩ ≥ 0 (i.e., it is a perfect classifier). Then, w_FRR satisfies

    ⟨w_FRR, Θ_{•,2}⟩ / ⟨w_FRR, Θ_{•,1}⟩ = (1 − (θ_{b,1} − θ_{a,2})) / (1 + (θ_{b,1} − θ_{a,2})),

where a = argmax_i θ_{i,2} and b = argmax_i θ_{i,1}.

Proof. Consider the FRR for this dataset. For each i, minimizing the quadratic in u_i gives

    FRR(w)_i = E[(Θx)_i^2] − E[⟨w, Θx⟩ (Θx)_i]^2 / E[⟨w, Θx⟩^2].

We now consider each term in this expression. Since E[x_1 x_2] = 1 and E[x_1^2] = E[x_2^2] = 1 + 1/12 = 13/12, we have

    E[(Θx)_i^2] = E[(θ_{i,1} x_1 + θ_{i,2} x_2)^2] = (13/12)(θ_{i,1}^2 + θ_{i,2}^2) + 2 θ_{i,1} θ_{i,2}.
Next,

    E[⟨w, Θx⟩ (Θx)_i] = E[(x_1 ⟨w, Θ_{•,1}⟩ + x_2 ⟨w, Θ_{•,2}⟩)(θ_{i,1} x_1 + θ_{i,2} x_2)]
                      = (13/12)(⟨w, Θ_{•,1}⟩ θ_{i,1} + ⟨w, Θ_{•,2}⟩ θ_{i,2}) + (⟨w, Θ_{•,1}⟩ θ_{i,2} + ⟨w, Θ_{•,2}⟩ θ_{i,1}).

Finally,

    E[⟨w, Θx⟩^2] = E[(x_1 ⟨w, Θ_{•,1}⟩ + x_2 ⟨w, Θ_{•,2}⟩)^2] = (13/12)(⟨w, Θ_{•,1}⟩^2 + ⟨w, Θ_{•,2}⟩^2) + 2 ⟨w, Θ_{•,1}⟩ ⟨w, Θ_{•,2}⟩.

Putting it together,

    FRR(w)_i = ((13/12)^2 − 1) (⟨w, Θ_{•,1}⟩ θ_{i,2} − ⟨w, Θ_{•,2}⟩ θ_{i,1})^2
               / [ (13/12)(⟨w, Θ_{•,1}⟩^2 + ⟨w, Θ_{•,2}⟩^2) + 2 ⟨w, Θ_{•,1}⟩ ⟨w, Θ_{•,2}⟩ ].

Since θ_{i,1} + θ_{i,2} = 1, the numerator can equivalently be written using (⟨w, Θ_{•,1}⟩ − (⟨w, Θ_{•,1}⟩ + ⟨w, Θ_{•,2}⟩) θ_{i,1})^2 or (⟨w, Θ_{•,2}⟩ − (⟨w, Θ_{•,1}⟩ + ⟨w, Θ_{•,2}⟩) θ_{i,2})^2.

Let α = ⟨w, Θ_{•,1}⟩ / (⟨w, Θ_{•,1}⟩ + ⟨w, Θ_{•,2}⟩). Further, let a, b be such that a = argmax_i (α − θ_{i,1})^2 and b = argmax_i (1 − α − θ_{i,2})^2. Then,

    FRR(α) ∝ max{ (α − θ_{a,1})^2, (1 − α − θ_{b,2})^2 } / ( 1 + (α^2 + (1 − α)^2)/12 ).

Differentiating each component of the max with respect to α shows that for α ≥ (1 + θ_{a,1} − θ_{b,2})/2 the derivative of the active component is positive, while for α ≤ (1 + θ_{a,1} − θ_{b,2})/2 it is negative. Hence, the minimum is attained at α = (1 + θ_{a,1} − θ_{b,2})/2, assuming either 0.5 ≤ θ_{a,1}, θ_{b,2} ≤ 1 or 0 ≤ θ_{a,1}, θ_{b,2} ≤ 0.5.
Now, in order to compute a and b, we can look at the maximization problem again:

    a = argmax_i (1 − (θ_{i,1} + θ_{b,2}))^2,   b = argmax_i (1 − (θ_{a,1} + θ_{i,2}))^2.

Since θ_{i,1}, θ_{i,2} are bounded in [0, 1], the function to maximize is monotonically decreasing in θ_{•,1} and θ_{•,2}. Hence, a = argmin_i θ_{i,1} and b = argmin_i θ_{i,2}; equivalently, since each row of Θ sums to 1, a = argmax_i θ_{i,2}, b = argmax_i θ_{i,1}, and α = (1 + θ_{b,1} − θ_{a,2})/2. Hence,

    ⟨w, Θ_{•,2}⟩ / ⟨w, Θ_{•,1}⟩ = (1 − α)/α = (1 − (θ_{b,1} − θ_{a,2})) / (1 + (θ_{b,1} − θ_{a,2})).

Assuming that the maximum correlations of the two features are close, FRR will lead to a solution which gives roughly equal weights to both features. Note that for the case of feature replication, θ_{b,1} = 1 and θ_{a,2} = 1; hence ⟨w, Θ_{•,2}⟩ / ⟨w, Θ_{•,1}⟩ = 1. ∎
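As a quick numerical sanity check (ours, not part of the original analysis), the closed-form FRR(α) derived above can be minimized over a grid of α values; under feature replication the minimizer is α = 1/2, i.e., equal effective weight on both input features:

```python
import numpy as np

# Feature-replication featurizer: d-1 rows carry x1, one row carries x2.
d = 8
Theta = np.zeros((d, 2))
Theta[:-1, 0] = 1.0
Theta[-1, 1] = 1.0

def frr(alpha, Theta):
    """FRR(alpha), up to a positive constant, using the closed form above."""
    num = np.max(np.maximum((alpha - Theta[:, 0]) ** 2,
                            (1.0 - alpha - Theta[:, 1]) ** 2))
    den = 1.0 + (alpha ** 2 + (1.0 - alpha) ** 2) / 12.0
    return num / den

alphas = np.linspace(0.01, 0.99, 9801)
a_star = alphas[np.argmin([frr(a, Theta) for a in alphas])]
ratio = (1.0 - a_star) / a_star   # = <w, Theta_{.,2}> / <w, Theta_{.,1}>
print(ratio)                       # close to 1, as Proposition A.2 predicts
```

The grid minimizer lands at α ≈ 0.5, so the two input features receive (nearly) equal effective weight, matching θ_{b,1} = θ_{a,2} = 1 in the proposition.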

B JUSTIFICATION OF FEATURE REPLICATION HYPOTHESIS (FRH)

In a practical scenario where features are not disentangled, our hypothesis translates to the following:

Conjecture: Simpler features of the input are represented more in the feature space of neural networks, while complex (hard-to-learn) features are sparse.

Assumptions:

• We consider simple features such as background to be spurious, and complex features such as shape to be robust.
• We consider an overparameterized network that has the capacity to learn more features than what exist, resulting in feature repetition.

Justification: We justify the conjecture by showing that all other possibilities discussed below cannot be true.

1. Assumption: DNNs learn only Simple Features

Contradiction: Prior works (Rosenfeld et al., 2022; Kirichenko et al., 2022b) show that features learned by ERM are diverse, and last layer training on target domain is good enough to obtain robustness to spurious features. This cannot be possible if the network has learned only spurious features.

2. Assumption: DNNs learn only Complex Features

Contradiction: The dominance of Simple features in the learning of DNNs is shown by Shah et al. (2020). Moreover, the existence of texture bias (Geirhos et al., 2018) and background bias (Xiao et al., 2020) has been demonstrated in prior works, which show the dominance of Simple features.

3. Assumption: DNNs learn a uniform distribution of both Simple and Complex Features

Contradiction: SGD converges to an SVM solution due to its implicit bias (Soudry et al., 2018). From Claim 3.1, in the presence of balanced features that are correlated with the label, the SVM solution gives equal weight to all features to maximize margin. This contradicts the existence of Simplicity Bias (Shah et al., 2020).

4. Assumption: DNNs learn more Complex Features and fewer Simple Features

Contradiction: Since Complex features are indeed more robust and are better correlated with the labels, the classifier would rely more on these features. This contradicts the existence of Simplicity Bias (Shah et al., 2020).

Therefore, the only feasible option which supports the empirical observations in the prior works discussed above is that DNNs learn more Simple Features while Complex features are sparse, which justifies our conjecture.

C EMPIRICAL EVIDENCE FOR FEATURE REPLICATION HYPOTHESIS (FRH)

C.1 SYNTHETIC DATASETS

We present empirical validation to support the Feature Replication Hypothesis (FRH) on several semi-real datasets, described in detail below:

1. Coloured-MNIST-2: In this dataset, we use images of digits superimposed on either of two colours, red or green. The difference from Coloured-MNIST is that we consider only two colours for the background, rather than a range. We notice extreme simplicity bias in this case, with the network learning 32 colour features and 0 shape features.

2. Coloured-MNIST-MultiDigit: This is similar to the Coloured-MNIST dataset described in Section 3.1, with the exception that each class is now composed of two digits. More specifically, the digits '1' and '7' are mapped to Class 0, and the digits '5' and '8' are mapped to Class 1. We note that '1' and '5' are chosen from the original Coloured-MNIST dataset, while the second digit (e.g., '7') in each class is selected to be one that is similar to the first digit ('1') in the same class. This dataset is constructed specifically to show that the issue of Simplicity Bias and FRH exists even when there is higher variation in the shape feature, and is reported as ColouredMNIST-MultiDigit below. We see that while more shape features are learnt as compared to Coloured-MNIST, the network still relies more on colour to make its decisions.

3. Digit-Coloured-MNIST: This is similar to the Coloured-MNIST dataset described in Section 3.1, with the exception that the digit is coloured rather than the background. This dataset is constructed specifically to show that the issues of SB and FRH exist even when the coloured region, i.e., the extent to which simple features exist in the image, is much smaller; it is reported as DigitColouredMNIST below. Although this dataset also demonstrates the presence of SB, we note that the average correlation of features with shape is higher when compared to the above datasets.
We attempt to demonstrate feature replication in a model trained with ERM on the Real domain of OfficeHome. We train a ResNet-50 on this domain and perform PCA on the features learnt by this network. The network learns 2048 features per example, and we compute the 2048 × 2048 covariance matrix of the features over samples from a test domain (Clipart). We then compute the eigenvalues of this matrix, and find that 500 principal components can explain about 97.5% of the variance, i.e., the matrix is extremely low rank, as shown in Fig. 3. This points to the fact that many of the learnt features are linearly dependent and highly correlated with each other. This trend is similar to what we observed on ColouredMNIST, where a large number of features were highly correlated with the colour, and in turn with each other (Fig. 5 of the appendix). We note that in all the additional datasets considered, simpler features are represented more in the network while complex (hard-to-learn) features are sparse. This empirically justifies our hypothesis in Section 3.

D.1 COLOURED MNIST

In order to empirically demonstrate feature replication, we use a binarized version of the coloured MNIST dataset (Gulrajani & Lopez-Paz, 2020). To construct this dataset, we first assign two digits of the MNIST dataset, namely "1" and "5", to classes 0 and 1 respectively. For the in-domain training distribution, we associate colours in the range R_0 = [(115, 0, 0), (256, 141, 0)) (i.e., red) with label 0 (i.e., the digit "1"), and the range R_1 = [(0, 115, 0), (141, 256, 0)) (i.e., green) with label 1 (i.e., the digit "5"), where colours are represented in RGB space. To summarize, while training the network, we superimpose images of "1" onto colours in range R_0, and images of "5" onto colours in range R_1.
It is to be noted that the choice of colour ranges as defined above introduces an overlapping range [(115, 115, 0), (141, 141, 0)) where images are associated with labels 0 and 1 with equal probability. This overlap reduces the correlation of the colour features with the labels, while the shape features have a correlation of 1 with the labels. In Figure 4, we show examples of images from the train and test distributions of this dataset. In Figure 5, we pictorially depict the correlations between the 32 features learnt by the network. We can see a block structure emerging, indicating a high amount of feature replication.
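The colouring scheme described above can be sketched in a few lines of NumPy. This is an illustrative reimplementation under our reading of the construction (the function name and the white-digit convention are our assumptions, not the authors' code):

```python
import numpy as np

def colourise(digit, label, rng):
    """Superimpose a grayscale digit (H x W array in [0, 1]) on a class-coloured
    background: label 0 draws a red shade from R0 = [(115,0,0), (256,141,0)),
    label 1 a green shade from R1 = [(0,115,0), (141,256,0))."""
    if label == 0:
        bg = np.array([rng.integers(115, 256), rng.integers(0, 141), 0], dtype=np.uint8)
    else:
        bg = np.array([rng.integers(0, 141), rng.integers(115, 256), 0], dtype=np.uint8)
    img = np.empty((*digit.shape, 3), dtype=np.uint8)
    mask = digit > 0.5
    img[mask] = 255        # digit pixels stay white
    img[~mask] = bg        # background carries the (spurious) colour feature
    return img
```

Because both colour ranges contain the overlap [(115, 115, 0), (141, 141, 0)), a sampled background can be consistent with either label, which is what reduces the colour-label correlation below 1.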

D.2 TOY DATASET

In line with the theoretical formulation described in Section 3.3, we further empirically validate the brittleness of SVM models and highlight the effectiveness of the proposed Feature Reconstruction Regularizer. We consider feature replication along the y-axis, and hence construct this OOD dataset to verify the extent to which the other feature is considered for classification. To select the best hyperparameter for both SVM and FRR, we assume the presence of a validation set whose distribution is similar to the test distribution. As shown in Figure 1, we observe that the SVM model starts relying on the replicated features alone in case of feature replication, compromising its performance on the OOD data. The proposed regularizer, on the other hand, gives equal importance to both features even in the presence of feature replication, resulting in improved OOD generalization.
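The SVM half of this toy experiment is easy to reproduce with scikit-learn. The sketch below is our own minimal setup (the data generation and replication scheme are assumptions consistent with Sec 3.3, not the authors' exact script): with 5 replications of the second feature, a near-hard-margin linear SVM places roughly 5x more effective weight on the replicated axis.

```python
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
n, reps = 2000, 5
y = rng.choice([-1, 1], size=n)
# Two label-correlated factors of variation with uniform noise.
x = np.stack([y + rng.uniform(-0.5, 0.5, n),
              y + rng.uniform(-0.5, 0.5, n)], axis=1)
# Replicate the second feature `reps` times in the representation.
feats = np.concatenate([x[:, :1]] + [x[:, 1:]] * reps, axis=1)
clf = SVC(kernel="linear", C=1e6).fit(feats, y)   # ~hard-margin linear SVM
w = clf.coef_.ravel()
w_eff = np.array([w[0], w[1:].sum()])             # effective input-space weights
print(w_eff[1] / w_eff[0])                        # >> 1: skewed toward replicated axis
```

Maximizing the margin under an L2 penalty of a^2 + b^2/reps on the effective weights (a, b) yields b/a ≈ reps, which is exactly the skew that FRR is designed to remove.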

E ALGORITHM

Our training procedure is detailed in Alg 1.
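A minimal PyTorch sketch of one phase of Alg 1 is given below (module and hyper-parameter names are ours; the DomainBed-based implementation appears later in the appendix). The three phases differ only in which modules are frozen and the value of the FRR weight:

```python
import torch
import torch.nn.functional as F

def run_phase(featurizer, classifier, decoder, batches, lam=0.0, lr=1e-4):
    """One training phase: only modules whose parameters have requires_grad=True
    are updated. The FRR term uses the l_{1,inf} variant (per-example max over
    feature reconstruction errors, averaged over the batch)."""
    params = [p for m in (featurizer, classifier, decoder)
              for p in m.parameters() if p.requires_grad]
    opt = torch.optim.Adam(params, lr=lr)
    for x, y in batches:
        feats = featurizer(x)
        logits = classifier(feats)
        rec = decoder(logits)  # reconstruct the features back from the logits
        frr = (feats - rec).abs().max(dim=1).values.mean()
        loss = F.cross_entropy(logits, y) + lam * frr
        opt.zero_grad()
        loss.backward()
        opt.step()
```

Phase 1 (ERM) trains everything with lam = 0; phase 2 (FRR-L) freezes the featurizer and trains the classifier and decoder with lam = λ_L; phase 3 (FRR-FLFT) freezes the classifier and finetunes the featurizer with lam = λ_FLFT.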

F DETAILS ON THE OOD GENERALIZATION SETTING CONSIDERED

The problem of improving robustness to distribution shifts has been studied in several settings, where, in addition to labeled source domain data, varying levels of access to the target domain data are assumed. Some of the well-researched settings include Unsupervised Domain Adaptation, with access to only unlabeled target domain data (Pan & Yang, 2009; Ganin et al., 2016), and Domain Generalization, where data from several source distributions is typically assumed to be available, and the target domain is unseen during training (Blanchard et al., 2011; Li et al., 2018a; Gulrajani & Lopez-Paz, 2020). In the latter case, it is assumed that all training data samples are annotated with domain labels as well, so that training algorithms can explicitly impose invariance to attributes that cause a distribution shift in input data without changing their label distribution (Muandet et al., 2013; Ganin et al., 2016; Li et al., 2018b; Arjovsky et al., 2019; Shi et al., 2021). A more challenging case is when the training data belongs to several distributions that may not even be sufficiently discernible to have explicit domain annotations, or may contain multidimensional distribution shifts, such as weather, time of day, and geographical location, that cannot be easily annotated or clustered. We investigate this crucial setting, which has been relatively less researched, and refer to it as Aggregated Domain Generalization, as introduced by Thomas et al. (2021).
We note that this setting is different from the case of training on data from a single domain such as ImageNet and evaluating on distribution shifts (Hendrycks & Dietterich, 2019), due to the availability of an aggregate of source domains during training, which enables the effective use of an in-domain validation set for hyperparameter selection. While there have been several approaches to improve the performance of models in the setting of Domain Generalization, Gulrajani & Lopez-Paz (2020) show that when evaluated fairly, that is, without assuming access to the test domain data even for selecting the best set of hyperparameters, none of the approaches perform consistently better than standard training using Empirical Risk Minimization (ERM). Furthermore, we consider the setting of Aggregated Domain Generalization, which is more challenging due to the absence of domain labels during both training and validation.

G EXPERIMENTAL DETAILS ON DOMAINBED

We test our approach on the DomainBed benchmark (Gulrajani & Lopez-Paz, 2020), comprising five different datasets, each of which has k domains. For each dataset, we train a model on k−1 domains and test it on the left-out domain. The average out-of-domain performance across the k held-out domains is then reported. In this section, we describe the hyper-parameter selection strategy and the ranges considered for our approach. In line with the DomainBed testbench, we use ImageNet-pretrained ResNet-50 models for all algorithms. We use random search to select hyperparameters for our algorithm, and use the suggested hyperparameters for the other baselines. We train for 3000 (5000 for DomainNet) steps in the FRR-L phase, and 5000 (10000 for DomainNet) steps in the FRR-FLFT phase. The batch size is fixed to 32, and SWAD hyper-parameters are the same as those used by Cha et al. (2021). We use the in-domain accuracy protocol from Gulrajani & Lopez-Paz (2020) to select hyper-parameters for each domain of each dataset, and search over 8 random combinations of hyper-parameters for each. The range of the hyperparameters is shown in Table 5; the norm hyper-parameter ranges over {ℓ_1, ℓ_{1,∞}, ℓ_{∞,1}}. Note that we experiment with two implementations of the ℓ_∞ norm: ℓ_{1,∞}, where we first compute the ℓ_∞ norm of the feature reconstruction error for each example in a batch and then average it across the batch, and ℓ_{∞,1}, where we compute the average ℓ_1 reconstruction error of each feature across the batch, and then apply the ℓ_∞ norm on this m-dimensional vector. All our experiments were done on single V100 GPUs.

H ABLATIONS ON DOMAINBED

Comparing the choices for φ: In Table 6, we experiment with various architectures for the decoder φ when computing FRR according to equation 1. We consider using a two-layer neural network as the decoder φ (FRR-L-Deeper), and also consider setting φ = W^T (FRR-L-Shared), i.e., explicitly tying the weights of the decoder and the classifier layer.
Overall, both these variants are worse than the default single-layer, freely parameterized φ. We believe that this happens because the latter approach enforces a much stricter constraint on W, leading to poorer in-domain accuracy, while the former approach enforces a weaker constraint, potentially enabling reconstruction of more complex features from a smaller amount of information about them in the logits. Both of these have a detrimental effect on the overall performance of the model.

I IMPLEMENTATION

Our DomainBed implementation of FRR-L initializes the featurizer, classifier, and decoder φ (classifier_inv) as follows:

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'],
        )
        # Freeze the featurizer: only the linear head and decoder are trained.
        for params in self.featurizer.parameters():
            params.requires_grad = False
        # Decoder phi: maps the logits back to the feature space.
        self.classifier_inv = networks.Classifier(
            num_classes,
            self.featurizer.n_outputs,
            self.hparams['nonlinear_classifier'],
        )
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = torch.optim.Adam(
            list(self.network.parameters()) + list(self.classifier_inv.parameters()),
            lr=self.hparams['lr'],
            weight_decay=self.hparams['weight_decay'],
        )
        self.reconstruction_wt = self.hparams['reconstruction_wt']
        self.norm = float(self.hparams['norm'])

In this section, we show detailed results of Table 3 in the main text.
The numbers for the baselines are taken from Gulrajani & Lopez-Paz (2020), Cha et al. (2021) and Arpit et al. (2021) , while the results for MIRO (Cha et al., 2022) were reproduced using their code-base. 



foot_0: In this paper, we refer to the penultimate layer's activations as features.
foot_1: i.e., by the final classification layer.



Figure 1: We demonstrate the brittleness of SVM (a, b) and the effectiveness of FRR (c, d) based classifiers on a toy dataset comprising 2 factors of variation, sampled from a uniform distribution. We consider d (= 0 or 5) feature replications (Rep) along the y-axis. FRR converges to a maximum margin solution in the non-replicated feature space, resulting in improved OOD robustness (d).

Figure 2: Our training procedure: Dotted fill indicates that the parameters are trainable.


Figure 3: Distribution of eigenvalues of the covariance of learnt features: a small fraction of principal components can explain most of the variance in the features, indicating that the features are highly correlated with each other.
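The eigenvalue analysis summarized in Figure 3 amounts to the following computation (a sketch; `feats` stands for the 2048-dimensional penultimate-layer activations on the test domain):

```python
import numpy as np

def explained_variance(feats: np.ndarray, k: int) -> float:
    """Fraction of total feature variance captured by the top-k principal
    components of the (n_features x n_features) feature covariance matrix."""
    centred = feats - feats.mean(axis=0, keepdims=True)
    cov = centred.T @ centred / (len(feats) - 1)
    eig = np.sort(np.linalg.eigvalsh(cov))[::-1]  # eigenvalues, descending
    return float(eig[:k].sum() / eig.sum())
```

On the OfficeHome features described above, the top 500 of 2048 components capture about 97.5% of the variance; on synthetically replicated features the effect is exact, since replicated columns are linearly dependent.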

Figure 4: Random images from the coloured MNIST dataset: The top row shows examples from the train distribution, while the bottom row has images from the test distribution. Here, colour red corresponds to the digit 1 and green corresponds to the digit 5 in the train data, while this correlation is destroyed in the test data.

Figure 5: Correlation of the features learnt on coloured MNIST

Figure 6: Variation of OOD accuracy with varying λ_FRR

The update step, prediction, and feature/reconstruction computation are implemented as follows:

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        pred, rec, feat = self.get_feats_and_rec(all_x)
        loss = F.cross_entropy(pred, all_y)
        # l_{1,inf} penalty: per-example max |feat - rec|, averaged over the batch
        reconstruction_loss = (
            torch.sum(torch.max(torch.abs(feat - rec), dim=1)[0]) / all_x.shape[0]
        )
        loss = loss + self.reconstruction_wt * reconstruction_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {
            'loss': loss.item(),
            'reconstruction_loss': reconstruction_loss.item(),
        }

    def predict(self, x):
        return self.network(x)

    def get_feats_and_rec(self, x):
        feats = self.network[0](x)
        pred = self.network[1](feats)
        rec = self.classifier_inv(pred)
        return pred, rec, feats

J DOMAIN WISE ACCURACIES

Feature replication in Coloured MNIST: We observe that ERM learns more colour features than shape features, and the prediction is less correlated with the shape features. Adding FRR makes the network depend more on shape and less on colour, leading to better OOD performance.

ID and OOD accuracy (%) by training on MNIST-CIFAR-10 in various training regimes.

Results on DomainBed: The bottom partition shows results of methods that perform model weight averaging. In both cases, without (top) and with (bottom) model weight averaging, the proposed approach outperforms existing methods.


Features learnt by an ERM trained model on synthetic datasets.

Algorithm 1: Our training algorithm
Data: Training data D_S = {(x_i, y_i) : i ∈ [n]}, model (θ, W), feature reconstruction model φ, λ_FRR, λ_FT
1  θ_std, W_std ← Adam( min_{θ,W} Σ_i L_std(θ, W, (x_i, y_i)) )                                    /* Standard training of model parameters θ and W. */
2  Freeze θ to be θ_std                                                                            /* Initializing model for training with FRR. */
3  W_FRR, φ_FRR ← Adam( min_{W,φ} Σ_i L_std(θ_std, W, (x_i, y_i)) + λ_L L_FRR(x_i, θ_std, W, φ) )  /* FRR-L: Training W, φ with FRR defined in eqn. 4 */
4  θ_FLFT ← Adam( min_θ Σ_i L_std(θ, W_FRR, (x_i, y_i)) + λ_FLFT L_FRR(x_i, θ, W_FRR, φ_FRR) )     /* FRR-FLFT: Finetuning θ with FRR according to eqn. 5 */
Result: Trained model (θ_FLFT, W_FRR)

Ranges of hyperparameters considered for DomainBed

Effect of different design choices on OOD accuracy: the rows show different architecture choices for φ.

Sensitivity Analysis: We vary λ_FRR and plot the OOD performance in Fig. 6. We find that the performance is stable for a wide range of the hyper-parameter on most domains.

Out-of-domain accuracies (%) on PACS ("—" marks entries lost or truncated in the source):

Algorithm   A           C           P           S           Avg.
—           — ± 0.8     77.4 ± 0.8  97.3 ± 0.4  73.5 ± 2.3  83.7
GroupDRO    83.5 ± 0.9  79.1 ± 0.6  96.7 ± 0.3  78.3 ± 2.0  84.4
MTL         87.5 ± 0.8  77.1 ± 0.5  96.4 ± 0.8  77.3 ± 1.8  84.6
I-Mixup     86.1 ± 0.5  78.9 ± 0.8  97.6 ± 0.1  75.8 ± 1.8  84.6
MMD         86.1 ± 1.4  79.4 ± 0.9  96.6 ± 0.2  76.5 ± 0.5  84.7
VREx        86.0 ± 1.6  79.1 ± 0.6  96.9 ± 0.5  77.7 ± 1.7  84.9
MLDG        85.5 ± 1.4  80.1 ± 1.7  97.4 ± 0.3  76.6 ± 1.1  84.9
ARM         86.8 ± 0.6  76.8 ± 0.5  97.4 ± 0.3  79.3 ± 1.2  85.1
RSC         85.4 ± 0.8  79.7 ± 1.8  97.6 ± 0.3  78.2 ± 1.2  85.2
Mixstyle    86.8 ± 0.5  79.0 ± 1.4  96.6 ± 0.1  78.5 ± 2.3  85.—
—           — ± 0.2     83.4 ± 0.6  97.3 ± 0.3  82.5 ± 0.5  88.1
SWAD+FRR    89.9 ± 0.2  83.9 ± 0.7  98.2 ± 0.3  84.8 ± 0.4  89.2

Out-of-domain accuracies (%) on VLCS ("—" marks entries lost or truncated in the source):

Algorithm   C           L           S           V           Avg.
—           — ± 0.3     63.4 ± 0.9  69.5 ± 0.8  76.7 ± 0.7  76.7
RSC         97.9 ± 0.1  62.5 ± 0.7  72.3 ± 1.2  75.6 ± 0.8  77.1
MLDG        97.4 ± 0.2  65.2 ± 0.7  71.0 ± 1.4  75.3 ± 1.0  77.2
MTL         97.8 ± 0.4  64.3 ± 0.3  71.5 ± 0.7  75.3 ± 1.7  77.2
I-Mixup     98.3 ± 0.6  64.8 ± 1.0  72.1 ± 0.5  74.3 ± 0.8  77.4
ERM         97.7 ± 0.4  64.3 ± 0.9  73.4 ± 0.5  74.6 ± 1.3  77.5
MMD         97.7 ± 0.1  64.0 ± 1.1  72.8 ± 0.2  75.3 ± 3.3  77.5
CDANN       97.1 ± 0.3  65.1 ± 1.2  70.7 ± 0.8  77.1 ± 1.5  77.5
ARM         98.7 ± 0.2  63.6 ± 0.7  71.3 ± 1.2  76.7 ± 0.6  77.6
SagNet      97.9 ± 0.4  64.5 ± 0.5  71.4 ± 1.3  77.5 ± 0.5  77.8
Mixstyle    98.6 ± 0.3  64.5 ± 1.1  72.6 ± 0.5  75.7 ± 1.7  77.9
VREx        98.4 ± 0.3  64.4 ± 1.4  74.1 ± 0.4  76.2 ± 1.3  78.3
IRM         98.6 ± 0.1  64.9 ± 0.9  73.4 ± 0.6  77.3 ± 0.9  78.6
DANN        99.0 ± 0.3  65.1 ± 1.4  73.1 ± 0.3  77.2 ± 0.6  78.6
CORAL       98.3 ± 0.1  66.1 ± 1.2  73.4 ± 0.3  77.5 ± 1.2  78.8
SMA         99.0 ± 0.2  63.0 ± 0.2  74.5 ± 0.3  76.4 ± 1.1  78.—
—           — ± 0.1     63.3 ± 0.3  75.3 ± 0.5  79.2 ± 0.6  79.1
SWAD+FRR    98.9 ± 0.4  66.3 ± 0.2  75.9 ± 0.6  79.0 ± 0.2  80.0

Out-of-domain accuracies (%) on OfficeHome ("—" marks entries lost or truncated in the source):

Algorithm   A           C           P           R           Avg.
—           — ± 0.3     53.2 ± 0.4  68.2 ± 0.7  69.2 ± 0.6  60.4
IRM         58.9 ± 2.3  52.2 ± 1.6  72.1 ± 2.9  74.0 ± 2.5  64.3
ARM         58.9 ± 0.8  51.0 ± 0.5  74.1 ± 0.1  75.2 ± 0.3  64.8
RSC         60.7 ± 1.4  51.4 ± 0.3  74.8 ± 1.1  75.1 ± 1.3  65.5
CDANN       61.5 ± 1.4  50.4 ± 2.4  74.4 ± 0.9  76.6 ± 0.8  65.7
DANN        59.9 ± 1.3  53.0 ± 0.3  73.6 ± 0.7  76.9 ± 0.5  65.9
GroupDRO    60.4 ± 0.7  52.7 ± 1.0  75.0 ± 0.7  76.0 ± 0.7  66.0
MMD         60.4 ± 0.2  53.3 ± 0.3  74.3 ± 0.1  77.4 ± 0.6  66.4
MTL         61.5 ± 0.7  52.4 ± 0.6  74.9 ± 0.4  76.8 ± 0.4  66.4
VREx        60.7 ± 0.9  53.0 ± 0.9  75.3 ± 0.1  76.6 ± 0.5  66.4
ERM         61.3 ± 0.7  52.4 ± 0.3  75.8 ± 0.1  76.6 ± 0.3  66.5
MLDG        61.5 ± 0.9  53.2 ± 0.6  75.0 ± 1.2  77.5 ± 0.4  66.8
I-Mixup     62.4 ± 0.8  54.8 ± 0.6  76.9 ± 0.3  78.3 ± 0.2  68.1
SagNet      63.4 ± 0.2  54.8 ± 0.4  75.8 ± 0.4  78.3 ± 0.3  68.1
CORAL       65.3 ± 0.4  54.4 ± 0.5  76.5 ± 0.1  78.4 ± 0.5  68.7
—           — ± 0.4     57.7 ± 0.4  78.4 ± 0.1  80.2 ± 0.2  70.6
SWAD+FRR    65.2 ± 0.2  57.7 ± 0.5  78.2 ± 0.2  80.2 ± 0.1  70.3

