MODELING THE DATA-GENERATING PROCESS IS NECESSARY FOR OUT-OF-DISTRIBUTION GENERALIZATION

Abstract

Recent empirical studies on domain generalization (DG) have shown that DG algorithms that perform well on some distribution shifts fail on others, and no state-of-the-art DG algorithm performs consistently well on all shifts. Moreover, real-world data often has multiple distribution shifts over different attributes; hence we introduce multi-attribute distribution shift datasets and find that the accuracy of existing DG algorithms falls even further. To explain these results, we provide a formal characterization of generalization under multi-attribute shifts using a canonical causal graph. Based on the relationship between spurious attributes and the classification label, we obtain realizations of the canonical causal graph that characterize common distribution shifts and show that each shift entails different independence constraints over observed variables. As a result, we prove that any algorithm based on a single, fixed constraint cannot work well across all shifts, providing theoretical evidence for mixed empirical results on DG algorithms. Based on this insight, we develop Causally Adaptive Constraint Minimization (CACM), an algorithm that uses knowledge about the data-generating process to adaptively identify and apply the correct independence constraints for regularization. Results on fully synthetic, MNIST, small NORB, and Waterbirds datasets, covering binary and multi-valued attributes and labels, show that adaptive dataset-dependent constraints lead to the highest accuracy on unseen domains whereas incorrect constraints fail to do so. Our results demonstrate the importance of modeling the causal relationships inherent in the data-generating process.

1. INTRODUCTION

To perform reliably in real-world settings, machine learning models must be robust to distribution shifts, where the training distribution differs from the test distribution. Given data from multiple domains that share a common optimal predictor, the domain generalization (DG) task (Wang et al., 2021; Zhou et al., 2021) encapsulates this challenge by evaluating accuracy on an unseen domain. Recent empirical studies of DG algorithms (Wiles et al., 2022; Ye et al., 2022) have characterized different kinds of distribution shifts across domains. Using MNIST as an example, a diversity shift occurs when domains are created by adding new values of a spurious attribute like rotation (e.g., the Rotated-MNIST dataset (Ghifary et al., 2015; Piratla et al., 2020)), whereas a correlation shift occurs when domains exhibit different degrees of correlation between the class label and a spurious attribute like color (e.g., Colored-MNIST (Arjovsky et al., 2019)). Partly because advances in representation learning for DG (Ahuja et al., 2021; Krueger et al., 2021; Mahajan et al., 2021; Arjovsky et al., 2019; Li et al., 2018a; Sun & Saenko, 2016) have focused on one of these shifts, these studies find that the performance of state-of-the-art DG algorithms is not consistent across different shifts: algorithms performing well on datasets with one kind of shift fail on datasets with another. In this paper, we pose a harder, more realistic question: what if a dataset exhibits two or more kinds of shifts simultaneously? Such shifts over multiple attributes (where an attribute refers to a spurious high-level variable like rotation) are often observed in real data. For example, satellite imagery exhibits distribution shifts over time as well as over the region captured (Koh et al., 2021). To study this question, we introduce multi-attribute distribution shift datasets.
For instance, in our Col+Rot-MNIST dataset (see Figure 1), both the color and rotation angle of digits can shift
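To make the two kinds of shift concrete, the sketch below simulates how attribute assignments could be generated for a Col+Rot-MNIST-style dataset. This is a minimal illustration, not the paper's actual data pipeline: the function name, domain parameters, and the use of abstract labels in place of real MNIST images are all assumptions made here for exposition. Each training domain fixes a rotation angle (a diversity shift arises when the test domain uses an unseen angle) and a color-label correlation strength (a correlation shift arises when that strength changes, or reverses, at test time).

```python
import numpy as np

def sample_domain(n, rot_angle, color_corr, rng):
    """Sample attribute assignments for one domain (illustrative only).

    Each example receives:
      - a binary class label y,
      - a rotation angle shared by the whole domain (diversity shift:
        test domains use angles unseen during training),
      - a binary color that agrees with y with probability `color_corr`
        (correlation shift: this strength differs across domains).
    """
    y = rng.integers(0, 2, size=n)
    agree = rng.random(n) < color_corr       # color is spuriously tied to y
    color = np.where(agree, y, 1 - y)
    angle = np.full(n, rot_angle)
    return y, color, angle

rng = np.random.default_rng(0)
# Two training domains and one test domain shifted on BOTH attributes:
train1 = sample_domain(1000, rot_angle=15, color_corr=0.9, rng=rng)
train2 = sample_domain(1000, rot_angle=30, color_corr=0.8, rng=rng)
test = sample_domain(1000, rot_angle=60, color_corr=0.1, rng=rng)
# test uses an unseen angle (60) and a near-reversed color correlation,
# so a model relying on either spurious attribute will fail there.
```

A classifier that latches onto color (predictive in both training domains) or onto rotation (constant within each training domain) will degrade on the test domain, which is precisely the multi-attribute failure mode studied in this paper.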

