A UNIFIED CAUSAL VIEW OF DOMAIN INVARIANT SUPERVISED REPRESENTATION LEARNING

Abstract

Machine learning methods can be unreliable when deployed in domains that differ from the domains on which they were trained. One intuitive approach for addressing this is to learn representations of data that are domain-invariant in the sense that they preserve data structure that is stable across domains, but throw out spuriously varying parts. There are many approaches aimed at this kind of representation learning, including methods based on data augmentation, distributional invariances, and risk invariance. Unfortunately, it is often unclear when a given method actually learns domain-invariant structure, and whether learning domain-invariant structure actually yields robust models. The key issue is that, in general, it is unclear how to formalize "domain-invariant structure". The purpose of this paper is to study these questions in the context of a particular notion of domain shift that admits a natural formalization of domain invariance. This notion formalizes the idea that causal relationships are invariant, but non-causal relationships (e.g., due to confounding) may vary. We find that whether a given method learns domain-invariant structure, and whether this leads to robust prediction, both depend critically on the true underlying causal structure of the data.

1. INTRODUCTION

Machine learning methods can perform unreliably in the presence of domain shift (Shimodaira, 2000; Quinonero-Candela et al., 2008), a structural mismatch between the training domain(s) and the deployment domain(s). A variety of techniques have been proposed to mitigate domain shift. One popular class of approaches, which we'll focus on in this paper, is to learn a representation function ϕ of the data that is in some sense "invariant" across domains. Informally, the aim of such methods is to find a representation that captures the part of the data structure that is the "same" in all domains while discarding the part that varies across domains. It then seems intuitive that a predictor trained on top of such a representation would have stable performance even in new domains. Despite the intuitive appeal, there are fundamental open questions: when does a given method actually succeed at learning the part of the data that is invariant across domains? When does learning a domain-invariant representation actually lead to robust out-of-domain prediction?

There are many methods aimed at domain-invariant representation learning. When applied to a broad range of real-world domain-shift benchmarks, no single approach dominates, and a method that works well in one context is often worse in another than vanilla empirical risk minimization (i.e., simply ignoring the domain-shift problem). We'll study three broad classes of method:

Data augmentation

Each example is perturbed in some way and we learn a representation that is the same for all perturbed versions. E.g., if t(X) is a small rotation of an image X, then ϕ(X) = ϕ(t(X)) (Krizhevsky et al., 2012; Hendrycks et al., 2019; Cubuk et al., 2019; Xie et al., 2020; Wei & Zou, 2019; Paschali et al., 2019; Hariharan & Girshick, 2017; Sennrich et al., 2015; Kobayashi, 2018; Nie et al., 2020).
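The augmentation constraint ϕ(X) = ϕ(t(X)) is typically enforced in practice as a soft consistency penalty added to the supervised loss. The sketch below illustrates the idea under stated assumptions: the linear representation `phi`, the noise perturbation `t`, and the weight matrix `W` are all hypothetical stand-ins (in practice ϕ would be a trained neural network and t a semantic perturbation such as a small rotation), not part of any specific method from the literature.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear representation phi(x) = W x; in practice phi is a
# neural network whose parameters are learned jointly with the predictor.
W = rng.normal(size=(4, 8))

def phi(x):
    """Representation function that we want to be augmentation-invariant."""
    return W @ x

def t(x):
    """A label-preserving perturbation; small additive noise here stands in
    for e.g. a small rotation of an image."""
    return x + 0.01 * rng.normal(size=x.shape)

def consistency_penalty(xs):
    """Mean squared distance between phi(x) and phi(t(x)) over a batch.

    Adding this term to the supervised risk pushes the learned
    representation toward satisfying phi(x) == phi(t(x)).
    """
    return float(np.mean([np.sum((phi(x) - phi(t(x))) ** 2) for x in xs]))

xs = rng.normal(size=(16, 8))
print(consistency_penalty(xs))
```

A penalty of exactly zero would mean the representation is fully invariant to the perturbations generated by t; during training the penalty is minimized alongside the prediction loss, trading off invariance against predictive accuracy.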

Distributional invariance

We learn a representation so that some distribution involving ϕ(X) is constant in all domains. There are three such distributional invariances that can be required to hold for all domains e, e′: marginal invariance: P_e(ϕ(X)) = P_{e′}(ϕ(X)) (Muandet et al., 2013; Ganin et al., 2016; Albuquerque et al., 2020; Li et al., 2018a; Sun et al., 2017; Sun & Saenko, 2016; Matsuura 

