HOW ROBUST IS UNSUPERVISED REPRESENTATION LEARNING TO DISTRIBUTION SHIFT?

Abstract

The robustness of machine learning algorithms to distribution shift is primarily discussed in the context of supervised learning (SL). As such, there is a lack of insight into the robustness to distribution shift of representations learned with unsupervised methods, such as self-supervised learning (SSL) and auto-encoder based algorithms (AE). We posit that the input-driven objectives of unsupervised algorithms lead to representations that are more robust to distribution shift than those produced by the target-driven objective of SL. We verify this by extensively evaluating the performance of SSL and AE on both synthetic and realistic distribution shift datasets. Following observations that the linear layer used for classification can itself be susceptible to spurious correlations, we evaluate the representations using a linear head trained on a small amount of out-of-distribution (OOD) data, isolating the robustness of the learned representations from that of the linear head. We also develop "controllable" versions of existing realistic domain generalisation datasets with adjustable degrees of distribution shift. This allows us to study the robustness of different learning algorithms under versatile yet realistic distribution shift conditions. Our experiments show that representations learned by unsupervised algorithms generalise better than SL under a wide variety of extreme as well as realistic distribution shifts.

1. INTRODUCTION

Machine Learning (ML) algorithms are classically designed under the statistical assumption that the training and test data are drawn from the same distribution. However, this assumption does not hold in most real-world deployments of ML systems. For example, medical researchers might obtain their training data from hospitals in Europe, but deploy their trained models in Asia; changes in conditions such as imaging equipment and demography result in a shift in the data distribution between train and test sets (Dockès et al., 2021; Glocker et al., 2019; Henrich et al., 2010). Performing well on such tasks requires models to generalise to unseen distributions, an important property that is not evaluated on standard machine learning datasets like ImageNet, where the train and test sets are sampled i.i.d. from the same distribution. With increasing attention on this issue, researchers have been probing the generalisation performance of ML models by creating datasets that feature distribution shift tasks (Koh et al., 2021; Gulrajani and Lopez-Paz, 2020; Shah et al., 2020) and proposing algorithms that aim to improve generalisation performance under distribution shift (Ganin et al., 2016; Arjovsky et al., 2019; Sun and Saenko, 2016; Sagawa et al., 2020; Shi et al., 2022). In this work, we identify three specific problems with current approaches to distribution shift in computer vision, and develop a suite of experiments to address them.

1.1 EXISTING PROBLEMS AND CONTRIBUTIONS

Problem 1: The outdated focus on the supervised regime for distribution shift. In ML research, distribution shift has been studied in various contexts under different terminologies, such as simplicity bias (Shah et al., 2020), dataset bias (Torralba and Efros, 2011), shortcut learning (Geirhos et al., 2020), and domain adaptation and generalisation (Koh et al., 2021; Gulrajani and Lopez-Paz, 2020). Most of these works are carried out under the scope of supervised learning (SL), including works that investigate spurious correlations (Shah et al., 2020; Hermann and Lampinen, 2020; Kalimeris et al., 2019) and works that propose specialised methods to improve generalisation and/or avoid shortcut solutions (Arjovsky et al., 2019; Ganin et al., 2016; Sagawa et al., 2020; Teney et al., 2022). However, recent research (Shah et al., 2020; Geirhos et al., 2020) highlighted the extreme vulnerability of SL methods to spurious correlations: they are susceptible to learning only features that are irrelevant to the true labelling function yet highly predictive of the labels. This behaviour is not surprising given SL's target-driven objective: when presented with two features that are equally predictive of the target label, SL models have no incentive to learn both, as learning only one of them suffices to predict the target label. This leads to poor generalisation when the learned feature is missing in the OOD test set.

On the other hand, research in computer vision has recently seen a surge of unsupervised representation learning algorithms. These include self-supervised learning (SSL) algorithms (e.g., Chen et al. (2020a); Grill et al. (2020); Chen and He (2021)), which learn representations by enforcing invariance between the representations of two distinctly augmented views of the same image, and auto-encoder based algorithms (AE) (Rumelhart et al., 1985; Kingma and Welling, 2014; Higgins et al., 2017; Burda et al., 2016), which learn representations by reconstructing the input image. The immense popularity of these methods is mostly owed to their impressive performance on balanced in-distribution (ID) test datasets; how they perform on distribution shift tasks remains largely unknown. It is particularly meaningful to study unsupervised algorithms under distribution shift because, in comparison to SL, their learning objectives are more input-driven, i.e., they are incentivised to learn representations that most accurately represent the input data (Chen et al., 2020a; Alemi et al., 2017). When presented with two features equally predictive of the labels, unsupervised learning algorithms encourage the model to go beyond learning what is enough to predict the label, and instead focus on maximising the mutual information between the learned representations and the input. We hypothesise that this property of unsupervised representation learning algorithms helps them avoid the exploitation of spurious correlations, and thus fare better under distribution shift, compared to SL.

Contribution: Systematically evaluate SSL and AE on distribution shift tasks. We evaluate and compare the generalisation performance of unsupervised representation learning algorithms, including SSL and AE, with standard supervised learning. See section 2 for more details on our experiments.

Problem 2: Disconnect between synthetic and realistic datasets. Broadly speaking, there exist two types of datasets for studying distribution shift: synthetic datasets, where the shift between train/test distributions is explicit and controlled (e.g., MNIST-CIFAR (Shah et al., 2020), CdSprites (Shi et al., 2022)), and realistic datasets featuring implicit distribution shift in the real world (e.g., WILDS (Koh et al., 2021)). We provide visual examples in fig. 1.

Figure 1: Synthetic vs. realistic distribution shift: The distribution shift in synthetic datasets (left, MNIST-CIFAR and CdSprites) is usually extreme and controllable (adjusted via changing the correlation); for realistic datasets (right, WILDS-Camelyon17 and FMoW), distribution shift can be subtle, hard to identify, and impossible to control.
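The contrast between target-driven and input-driven objectives can be made concrete. The sketch below (illustrative plain-Python pseudocode, not the implementation used in our experiments; the function names are ours) shows the invariance term at the core of SSL methods and the reconstruction term at the core of AEs; both depend on the input itself rather than only on the label:

```python
import math

def ssl_invariance_loss(z1, z2):
    # Negative cosine similarity between the representations of two
    # augmented views of the same image: minimised when both views map
    # to the same direction in representation space. This is the
    # invariance enforced by SimCLR/BYOL/SimSiam-style objectives,
    # shown here without the contrastive or stop-gradient machinery.
    dot = sum(a * b for a, b in zip(z1, z2))
    norm1 = math.sqrt(sum(a * a for a in z1))
    norm2 = math.sqrt(sum(b * b for b in z2))
    return -dot / (norm1 * norm2)

def ae_reconstruction_loss(x, x_hat):
    # Mean squared error between the input and its reconstruction: the
    # decoder can only do well if the representation retains information
    # about the whole input, not just the label-predictive parts.
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
```

Both terms are minimised only by representations faithful to the input; a supervised cross-entropy loss, by contrast, is already minimal once any single label-predictive feature has been captured.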
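As a toy illustration of what "controllable" means here, consider a dataset in which each sample carries a core feature that always determines the label and a spurious feature that agrees with the label only with probability p (a hypothetical minimal sketch; the function name and feature encoding are ours, not those of the datasets above):

```python
import random

def make_shift_dataset(n, p, seed=0):
    # Each sample: (core, spurious, label). The core feature equals the
    # label by construction; the spurious feature matches the label with
    # probability p. Setting p = 1.0 mimics the fully-correlated training
    # split of synthetic benchmarks like MNIST-CIFAR/CdSprites, while
    # lowering p at test time breaks the spurious correlation.
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        label = rng.randint(0, 1)
        core = label
        spurious = label if rng.random() < p else 1 - label
        samples.append((core, spurious, label))
    return samples

train_set = make_shift_dataset(10000, p=1.0)  # spurious feature fully predictive
test_set = make_shift_dataset(10000, p=0.5)   # predictive power removed at test time
```

A model that latched onto the spurious feature during training drops to chance accuracy on the test split; interpolating p between these extremes yields adjustable degrees of shift of the kind described above.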

