MODELING THE DATA-GENERATING PROCESS IS NECESSARY FOR OUT-OF-DISTRIBUTION GENERALIZATION

Abstract

Recent empirical studies on domain generalization (DG) have shown that DG algorithms that perform well on some distribution shifts fail on others, and no state-of-the-art DG algorithm performs consistently well on all shifts. Moreover, real-world data often has multiple distribution shifts over different attributes; hence we introduce multi-attribute distribution shift datasets and find that the accuracy of existing DG algorithms falls even further. To explain these results, we provide a formal characterization of generalization under multi-attribute shifts using a canonical causal graph. Based on the relationship between spurious attributes and the classification label, we obtain realizations of the canonical causal graph that characterize common distribution shifts and show that each shift entails different independence constraints over observed variables. As a result, we prove that any algorithm based on a single, fixed constraint cannot work well across all shifts, providing theoretical evidence for mixed empirical results on DG algorithms. Based on this insight, we develop Causally Adaptive Constraint Minimization (CACM), an algorithm that uses knowledge about the data-generating process to adaptively identify and apply the correct independence constraints for regularization. Results on fully synthetic, MNIST, small NORB, and Waterbirds datasets, covering binary and multi-valued attributes and labels, show that adaptive dataset-dependent constraints lead to the highest accuracy on unseen domains whereas incorrect constraints fail to do so. Our results demonstrate the importance of modeling the causal relationships inherent in the data-generating process.

1. INTRODUCTION

To perform reliably in real-world settings, machine learning models must be robust to distribution shifts, where the training distribution differs from the test distribution. Given data from multiple domains that share a common optimal predictor, the domain generalization (DG) task (Wang et al., 2021; Zhou et al., 2021) encapsulates this challenge by evaluating accuracy on an unseen domain. Recent empirical studies of DG algorithms (Wiles et al., 2022; Ye et al., 2022) have characterized different kinds of distribution shifts across domains. Using MNIST as an example, a diversity shift arises when domains are created by adding new values of a spurious attribute such as rotation (e.g., the Rotated-MNIST dataset (Ghifary et al., 2015; Piratla et al., 2020)), whereas a correlation shift arises when domains exhibit different degrees of correlation between the class label and a spurious attribute such as color (e.g., Colored-MNIST (Arjovsky et al., 2019)). Partly because advances in representation learning for DG (Ahuja et al., 2021; Krueger et al., 2021; Mahajan et al., 2021; Arjovsky et al., 2019; Li et al., 2018a; Sun & Saenko, 2016) have focused on one of these shifts, these studies find that the performance of state-of-the-art DG algorithms is not consistent across shifts: algorithms that perform well on datasets with one kind of shift fail on datasets with another.

In this paper, we pose a harder, more realistic question: what if a dataset exhibits two or more kinds of shifts simultaneously? Such shifts over multiple attributes (where an attribute refers to a spurious high-level variable like rotation) are often observed in real data; for example, satellite imagery exhibits distribution shifts over both time and the region captured (Koh et al., 2021). To study this question, we introduce multi-attribute distribution shift datasets.
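The two kinds of shift, and their combination into a multi-attribute shift, can be illustrated with a toy attribute-level simulation. The sketch below (no images involved; the function name, probabilities, and rotation values are illustrative assumptions, not the paper's construction) builds domains where a color attribute is correlated with the label at a domain-specific rate while a rotation attribute varies per domain independently of the label:

```python
import numpy as np

def make_domain(n, p_color_given_label, rotation, rng):
    """Sample one toy domain with two spurious attributes.

    Labels are binary; `color` agrees with the label with probability
    `p_color_given_label` (a correlation shift that varies per domain),
    while `rotation` is fixed per domain and independent of the label
    (a diversity shift). Illustrative sketch only.
    """
    y = rng.integers(0, 2, size=n)
    agree = rng.random(n) < p_color_given_label
    color = np.where(agree, y, 1 - y)
    rot = np.full(n, rotation)
    return y, color, rot

rng = np.random.default_rng(0)
# Training domains: strong color-label correlation, seen rotations.
d1 = make_domain(5000, 0.9, 15, rng)
d2 = make_domain(5000, 0.8, 30, rng)
# Test domain: correlation flipped AND an unseen rotation value.
dtest = make_domain(5000, 0.1, 75, rng)

for name, (y, c, r) in [("train1", d1), ("train2", d2), ("test", dtest)]:
    print(name, "P(color==y) =", round((c == y).mean(), 2), "rotation =", r[0])
```

A model that latches onto either spurious attribute during training will be misled at test time by the flipped correlation, the unseen rotation, or both.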
For instance, in our Col+Rot-MNIST dataset (see Figure 1), both the color and rotation angle of digits can shift across data distributions. We find that existing DG algorithms, which are often targeted at a specific shift, fail to generalize in such settings: best accuracy falls from 50-62% on the individual-shift MNIST datasets to <50% (lower than a random guess) on the multi-attribute shift dataset.

To explain such failures, we propose a causal framework for generalization under multi-attribute distribution shifts. We use a canonical causal graph to model commonly observed distribution shifts. Under this graph, we characterize a distribution shift by the type of relationship between spurious attributes and the classification label, leading to different realized causal DAGs. Using d-separation on the realized DAGs, we show that each shift entails distinct constraints over observed variables and prove that no conditional independence constraint is valid across all shifts. In the special case where datasets exhibit a single-attribute shift across domains, this result explains the inconsistent performance of DG algorithms reported by Wiles et al. (2022) and Ye et al. (2022): any algorithm based on a single, fixed independence constraint cannot work well across all shifts, since there will be a dataset on which it fails (Section 3.3).

We go on to ask whether we can develop an algorithm that generalizes to different kinds of individual shifts as well as simultaneous multi-attribute shifts. For the common shifts modeled by the canonical graph, we show that identifying the correct regularization constraints requires knowing only the type of relationship between attributes and the label, not the full graph. As we discuss in Section 3.1, the type of shift for an attribute is often available or can be inferred for real-world datasets.
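The adaptive rule can be caricatured as a lookup from an attribute's relationship with the label to the independence constraint imposed on the learned representation φ(x). The mapping below is an illustrative simplification (the function name and constraint strings are our own shorthand; the exact constraints are derived via d-separation in the paper):

```python
def constraint_for(relationship):
    """Map an attribute's relationship with label Y to a regularization
    constraint on the representation phi(x). Illustrative simplification
    of the adaptive rule, not the paper's exact derivation.
    """
    rules = {
        # Attribute varies independently of Y (e.g., rotation):
        # enforce independence of phi(x) from the attribute marginally.
        "independent": "phi(x) _||_ A",
        # Attribute is correlated/confounded with Y (e.g., color):
        # enforce independence only within each label class.
        "confounded": "phi(x) _||_ A | Y",
    }
    return rules[relationship]

print(constraint_for("independent"))  # phi(x) _||_ A
print(constraint_for("confounded"))   # phi(x) _||_ A | Y
```

The key point is that the constraint differs by relationship type, so a regularizer hard-coded for one type applies the wrong constraint to the other.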
Based on this, we propose Causally Adaptive Constraint Minimization (CACM), an algorithm that leverages knowledge about the data-generating process (DGP) to identify and apply the correct independence constraints for regularization. Given a dataset with auxiliary attributes and their relationship with the target label, CACM constrains the model's representation to obey the conditional independence constraints satisfied by the causal features of the label, generalizing past work on causality-based regularization (Mahajan et al., 2021; Veitch et al., 2021; Makar et al., 2022) to multi-attribute shifts.

We evaluate CACM on novel multi-attribute shift datasets based on MNIST, small NORB, and Waterbirds images. Across all datasets, applying the incorrect constraint, often through an existing DG algorithm, leads to significantly lower accuracy than the correct constraint. Further, CACM achieves substantially better accuracy than existing algorithms on datasets with multi-attribute shifts as well as individual shifts. Our contributions include:
• A theoretical result that an algorithm using a fixed independence constraint cannot yield an optimal classifier on all datasets.
• An algorithm, Causally Adaptive Constraint Minimization (CACM), that adaptively derives the correct regularization constraint(s) from the causal graph and outperforms existing DG algorithms.
• Multi-attribute shift benchmarks for domain generalization on which existing algorithms fail.
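One way such independence constraints can be turned into a differentiable penalty on representations is via a kernel two-sample statistic such as MMD, applied across attribute groups, either marginally or within each label class depending on the attribute's relationship with the label. The sketch below is a minimal numpy illustration under these assumptions (the function names and the choice of an RBF-kernel MMD are ours, not the paper's implementation):

```python
import numpy as np

def rbf_mmd2(x, y, gamma=1.0):
    """Biased estimate of squared MMD with an RBF kernel (rows = samples)."""
    def k(a, b):
        d = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-gamma * d)
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

def adaptive_penalty(phi, attr, y, relationship):
    """Sketch of an adaptive independence penalty on representations.

    `phi`: (n, d) representations; `attr`, `y`: (n,) integer arrays.
    For an attribute independent of the label, penalize distribution
    mismatch of phi across attribute values; for a confounded attribute,
    penalize mismatch only within each label class. Illustrative only.
    """
    def group_mismatch(mask):
        groups = [phi[mask & (attr == v)] for v in np.unique(attr[mask])]
        return sum(rbf_mmd2(a, b)
                   for i, a in enumerate(groups)
                   for b in groups[i + 1:])
    if relationship == "independent":
        return group_mismatch(np.ones(len(y), dtype=bool))
    if relationship == "confounded":
        return sum(group_mismatch(y == c) for c in np.unique(y))
    raise ValueError(relationship)
```

In training, this penalty would be added to the empirical risk with a regularization weight; a representation that encodes the spurious attribute incurs a large penalty, while one that ignores it does not.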



Figure 1: (a) Our multi-attribute distribution shift dataset Col+Rot-MNIST. We combine Colored-MNIST (Arjovsky et al., 2019) and Rotated-MNIST (Ghifary et al., 2015) to introduce distinct shifts over the Color and Rotation attributes. (b) The causal graph representing the data-generating process for Col+Rot-MNIST: Color is correlated with Y, and this correlation changes across environments, while Rotation varies independently of Y. (c) Comparison with DG algorithms optimizing for different constraints shows the superiority of Causally Adaptive Constraint Minimization (CACM) (full table in Section 5).

