A UNIFIED CAUSAL VIEW OF DOMAIN INVARIANT SUPERVISED REPRESENTATION LEARNING

Abstract

Machine learning methods can be unreliable when deployed in domains that differ from the domains on which they were trained. One intuitive approach for addressing this is to learn representations of data that are domain-invariant in the sense that they preserve data structure that is stable across domains, but throw out spuriously varying parts. There are many approaches aimed at this kind of representation learning, including methods based on data augmentation, distributional invariances, and risk invariance. Unfortunately, it is often unclear when a given method actually learns domain-invariant structure, and whether learning domain-invariant structure actually yields robust models. The key issue is that, in general, it is unclear how to formalize "domain-invariant structure". The purpose of this paper is to study these questions in the context of a particular natural notion of domain shift that admits a natural formal notion of domain invariance. This notion formalizes the idea that causal relationships are invariant, but non-causal relationships (e.g., due to confounding) may vary. We find that whether a given method learns domain-invariant structure, and whether this leads to robust prediction, both depend critically on the true underlying causal structure of the data.

1. INTRODUCTION

Machine learning methods can perform unreliably in the presence of domain shift (Shimodaira, 2000; Quinonero-Candela et al., 2008), a structural mismatch between the training domain(s) and the deployment domain(s). A variety of techniques have been proposed to mitigate domain-shift problems. One popular class of approach, which we focus on in this paper, is to learn a representation function ϕ of the data that is in some sense "invariant" across domains. Informally, the aim of such methods is to find a representation that captures the part of the data structure that is the "same" in all domains while discarding the part that varies across domains. It then seems intuitive that a predictor trained on top of such a representation would have stable performance even in new domains.

Despite this intuitive appeal, there are fundamental open questions: when does a given method actually succeed at learning the part of the data that is invariant across domains? And when does learning a domain-invariant representation actually lead to robust out-of-domain predictions?

There are many methods aimed at domain-invariant representation learning. When applied to broad ranges of real-world domain-shift benchmarks, no single approach dominates, and an approach that works well in one context is often worse in another than vanilla empirical risk minimization (i.e., simply ignoring the domain-shift problem). We study three broad classes of methods:

Data augmentation Each example is perturbed in some way, and we learn a representation that is the same for all perturbed versions. E.g., if t(X) is a small rotation of an image X, then ϕ(X) = ϕ(t(X)) (Krizhevsky et al., 2012; Hendrycks et al., 2019; Cubuk et al., 2019; Xie et al., 2020; Wei & Zou, 2019; Paschali et al., 2019; Hariharan & Girshick, 2017; Sennrich et al., 2015; Kobayashi, 2018; Nie et al., 2020).
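To make the augmentation idea concrete, here is a minimal sketch (not from the paper) of checking the consistency condition ϕ(X) = ϕ(t(X)) for a hypothetical perturbation t and two hypothetical candidate representations; in practice ϕ would be a learned network and the penalty below would be a training loss.

```python
# Sketch of the augmentation-consistency idea: a representation phi should
# satisfy phi(x) == phi(t(x)) for a label-preserving perturbation t.
# All functions here are illustrative toy choices, not the paper's method.
import numpy as np

def t(x):
    """Perturbation: a 90-degree rotation of the image."""
    return np.rot90(x)

def phi_raw(x):
    """A representation that is NOT rotation-invariant: raw pixels."""
    return x.ravel()

def phi_inv(x):
    """A representation that IS rotation-invariant: sorted pixel intensities."""
    return np.sort(x.ravel())

def consistency_penalty(phi, x):
    """||phi(x) - phi(t(x))||^2, the quantity an augmentation method drives to zero."""
    return float(np.sum((phi(x) - phi(t(x))) ** 2))

rng = np.random.default_rng(0)
x = rng.random((8, 8))
print(consistency_penalty(phi_raw, x))  # nonzero: raw pixels move under rotation
print(consistency_penalty(phi_inv, x))  # zero: sorted intensities are unchanged
```

Minimizing such a penalty during training pushes ϕ toward representations that discard whatever information the perturbation t alters.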

Distributional invariance

We learn a representation such that some distribution involving ϕ(X) is the same in all domains. There are three such distributional invariances that can be required to hold for all domains e, e′:

marginal invariance: P^e(ϕ(X)) = P^{e′}(ϕ(X)) (Muandet et al., 2013; Ganin et al., 2016; Albuquerque et al., 2020; Li et al., 2018a; Sun et al., 2017; Sun & Saenko, 2016; Matsuura & Harada, 2020);

conditional invariance: when Y is a label of interest, P^e(ϕ(X) | Y) = P^{e′}(ϕ(X) | Y) (Li et al., 2018b; Long et al., 2018; Tachet des Combes et al., 2020; Goel et al., 2020);

sufficiency: P^e(Y | ϕ(X)) = P^{e′}(Y | ϕ(X)) (Peters et al., 2016; Rojas-Carulla et al., 2018; Wald et al., 2021).

Risk minimizer invariance For supervised learning, we learn a representation ϕ(X) such that there is a fixed (domain-independent) predictor w* on top of ϕ(X) that minimizes risk in all domains (Arjovsky et al., 2019; Lu et al., 2021; Ahuja et al., 2020; Krueger et al., 2021; Bae et al., 2021).

In each case, the aim is to learn a representation that throws away information that varies 'spuriously' across domains while preserving information that is reliably useful for downstream tasks. However, the notion of what is thrown away differs substantively across the approaches, and it is unclear which, if any, is appropriate for a particular problem.

The principal challenge to answering our motivating questions is that it is unclear in general how to formalize the idea of "the part of the data that is invariant across domains". To make progress, it is necessary to specify the manner in which different domains are related. In particular, we require an assumption that is both reasonable for real-world domain shifts and that precisely specifies what structure is invariant across domains. In many problems, it is natural to assume that causal relationships, determined by the unobserved real-world dynamics underpinning the data, should be the same in all domains.
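The sufficiency condition can be illustrated on synthetic data. The sketch below (an assumed toy setup, not the paper's construction) builds two domains in which the mechanism generating a causal feature Xc from Y is fixed, while the association between Y and a spurious feature Z varies; sufficiency P^e(Y | ϕ(X)) = P^{e′}(Y | ϕ(X)) then holds for ϕ = Xc but fails for ϕ = Z.

```python
# Toy empirical check of sufficiency invariance on binary synthetic data.
# Assumed setup: X = (Xc, Z); P(Xc | Y) is a fixed "causal" mechanism,
# while P(Z | Y) (the spurious association) differs across domains.
import numpy as np

rng = np.random.default_rng(1)

def sample_domain(n, p_z_matches_y):
    y = rng.integers(0, 2, n)
    xc = (y ^ (rng.random(n) < 0.1)).astype(int)   # invariant: Xc = Y, 10% noise
    match = rng.random(n) < p_z_matches_y           # domain-varying: P(Z = Y)
    z = np.where(match, y, 1 - y)
    return y, xc, z

def p_y_given(phi, y):
    """Estimate P(Y=1 | phi = v) for v in {0, 1}."""
    return np.array([y[phi == v].mean() for v in (0, 1)])

y_e, xc_e, z_e = sample_domain(100_000, 0.9)  # train domain: Z strongly predicts Y
y_f, xc_f, z_f = sample_domain(100_000, 0.2)  # shifted domain: association flipped

# Sufficiency holds for phi = Xc but fails badly for phi = Z.
print(np.abs(p_y_given(xc_e, y_e) - p_y_given(xc_f, y_f)).max())  # near 0
print(np.abs(p_y_given(z_e, y_e) - p_y_given(z_f, y_f)).max())    # large
```

A predictor built on Z would look accurate in the training domain and fail after the shift; one built on Xc transfers, which is the intuition the invariance conditions try to capture.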
We'll use an assumption of this form; the first task is to translate it into a concrete notion of domain shift.¹ Specializing to supervised learning with label Y and covariates X, we proceed as follows. The covariates X are caused by some (unknown) factors of variation. These factors of variation are also dependent with Y. For some factors of variation, jointly denoted Z, the relationship between Y and Z is spurious in the sense that Y and Z are dependent due to an unobserved common cause. The distribution of this unobserved common cause may change across environments, which in turn means the relationship between Y and Z can shift. However, the structural causal relationships between variables will be the same in all environments; e.g., P(X | pa(X)) is invariant, where pa(X) denotes the (possibly unobserved) causal parents of X. We call a family of domains with this structure Causally Invariant with Spurious Associations (CISA).

Concretely, consider the problem of classifying images X as either camel or cow, with label Y. In training, the presence of sand in the image background Z is strongly associated with camels. But we may deploy our new classifier in an island domain where cows are frequently on beaches, changing the Z-Y association. Nevertheless, the causal relationships between the factors of variation (Y, Z, and others such as camera type or time of day) and the image X remain invariant. In this example, a natural formalization of the "domain-invariant part" of the image is the part that does not change if sand Z is added to or removed from the background; an invariant representation learning method should learn a ϕ(X) that throws away such information.

The aim of CISA is to be a reasonably broad notion of domain shift that also allows us to formalize this intuitive notion of domain invariance. Namely, under CISA, we define the domain-invariant part of X to be the part that is not causally affected by the spurious factors of variation Z.
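This notion of "not causally affected by Z" can be made concrete with a toy structural causal model. In the sketch below (an assumed illustration, not the paper's formal construction), X is generated by a fixed mechanism f(Y, Z, U); holding the exogenous noise U fixed and intervening on Z produces a counterfactual image, and a representation is invariant in the intended sense exactly when it is unchanged by that intervention.

```python
# Minimal structural-causal sketch of the "domain-invariant part" of X.
# Assumed toy mechanism: the first coordinates of X are caused by the label Y
# (the animal), the rest by the spurious factor Z (the background).
import numpy as np

def f(y, z, u):
    """Invariant causal mechanism generating 'images' X from (Y, Z, noise U)."""
    animal = y * 2.0 + u[:2]        # part of X caused by Y
    background = z * 3.0 + u[2:]    # part of X caused by Z
    return np.concatenate([animal, background])

def phi(x):
    """Candidate representation: keep only the Y-caused coordinates."""
    return x[:2]

rng = np.random.default_rng(2)
u = rng.normal(size=4)              # exogenous noise, held fixed across counterfactuals
x_sand = f(y=1.0, z=1.0, u=u)       # camel on sand
x_no_sand = f(y=1.0, z=0.0, u=u)    # counterfactual: same scene, sand removed

# phi is invariant to Z: intervening on Z leaves phi(X) unchanged,
# even though the full image X does change.
print(np.allclose(phi(x_sand), phi(x_no_sand)))   # True
print(np.allclose(x_sand, x_no_sand))             # False
```

Here ϕ discards exactly the coordinates downstream of Z, matching the definition of the domain-invariant part of X.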
Accordingly, a representation ϕ is domain invariant if ϕ(X) is not causally affected by Z. We say a representation with this property is counterfactually invariant to spurious factors (CF-invariant for short).

We now return to the motivating questions: when does a given method actually succeed at learning the part of the data that is invariant across domains? And when does learning a domain-invariant representation actually lead to more robust out-of-domain predictions? In the context of CISA-compatible domain shifts, we can answer the first question by determining the conditions under which each approach learns a CF-invariant representation, and the second by studying the relationship between CF-invariance and domain-shift performance. Informally, the technical contributions of the paper are:

1. Formalization of CISA and Counterfactually Invariant Representation Learning.



¹ This kind of causal-invariance assumption is already used in the domain-shift literature, though the formal domain-shift notion we use here differs from previous approaches; see Section

