LEARNING CAUSAL SEMANTIC REPRESENTATION FOR OUT-OF-DISTRIBUTION PREDICTION

Abstract

Conventional supervised learning methods, especially deep ones, are found to be sensitive to out-of-distribution (OOD) examples, largely because the learned representation mixes the semantic factor with the variation factor due to their domain-specific correlation, while only the semantic factor causes the output. To address the problem, we propose a Causal Semantic Generative model (CSG) based on causality to model the two factors separately, and learn it on a single training domain for prediction in a test domain, either without test-domain data (OOD generalization) or with unsupervised test-domain data (domain adaptation). We prove that CSG identifies the semantic factor on the training domain, and the invariance principle of causality subsequently guarantees the boundedness of the OOD generalization error and the success of adaptation. We also design novel and careful learning methods for both effective learning and easy prediction, following the first principle of variational Bayes and the graphical structure of CSG. Empirical studies demonstrate that our methods improve test accuracy for OOD generalization and domain adaptation.

1. INTRODUCTION

Deep learning has initiated a new era of artificial intelligence where the potential of machine learning models is greatly unleashed. Despite this great success, these methods heavily rely on the independently-and-identically-distributed (IID) assumption. This assumption does not always hold in practice, and the prediction of the output (label, response, outcome) y may be saliently affected in out-of-distribution (OOD) cases, even by an essentially irrelevant change to the input (covariate) x, like a position shift or rotation of the object in an image, or a change of background, illumination or style (Shen et al., 2018; He et al., 2019; Arjovsky et al., 2019). These phenomena raise serious concerns about the robustness and trustworthiness of machine learning methods and severely impede their deployment in risk-sensitive scenarios. Looking into the problem, although deep learning models can extract abstract representations for prediction with their powerful approximation capacity, the representation may overly rely on the correlation between semantic factors s (e.g., the shape of an object) and variation factors v (e.g., background, illumination, object position). This correlation may be domain-specific and spurious, and may change drastically in a new environment. It is thus desirable to learn a representation that separates the semantics s from the variations v (Cai et al., 2019; Ilse et al., 2019). This goal is important because, formally, s is the cause of y. Causal relations better reflect the fundamental mechanisms of nature, bringing to machine learning the merit that they tend to be universal and invariant across domains (Schölkopf et al., 2012; Peters et al., 2017; Schölkopf, 2019), thus providing the most transferable and confident information for unseen domains.
Causality has also been shown to lead to proper domain adaptation (Schölkopf et al., 2012; Zhang et al., 2013), lower adaptation cost and lighter catastrophic forgetting (Peters et al., 2016; Bengio et al., 2019; Ke et al., 2019). In this work, we propose a Causal Semantic Generative model (CSG) for proper and robust OOD prediction, including OOD generalization and domain adaptation. Both tasks have supervised data from a single training domain, but domain adaptation additionally has unsupervised test-domain data during learning, while OOD generalization has no test-domain data at all, covering cases where queries come sequentially or adaptation is unaffordable. (1) We build the model by cautiously following the principles of causality, explicitly separating the latent variables into a (group of) semantic factor s and a (group of) variation factor v. We prove that, under appropriate conditions, CSG identifies the semantic factor by fitting training data, even in the presence of an s-v correlation. (2) By leveraging causal invariance, we prove that a well-learned CSG is guaranteed to have a bounded OOD generalization error. The bound shows how the causal mechanisms affect the error. (3) We develop a domain adaptation method using CSG and causal invariance, which suggests fixing the causal generative mechanisms and adapting the prior to the new domain. We prove the identification of the new prior and the benefit of adaptation. (4) To learn and adapt the model from data, we design novel and careful reformulations of the Evidence Lower BOund (ELBO) objective following the graphical structure of CSG, so that the inference models required therein can also serve for prediction, and modeling and optimizing inference models in both domains can be avoided. To the best of our knowledge, our work is the first to identify the semantic factor and leverage latent causal invariance for OOD prediction with guarantees.
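As a concrete illustration of this structure (our own sketch, not code from the paper), the following instantiates a CSG with hypothetical linear-Gaussian mechanisms: a domain-specific prior p(s, v) that correlates s and v, and generative mechanisms p(x | s, v) and p(y | s) that causal invariance keeps fixed across domains, so a domain shift only changes the prior. All dimensions, weights, and noise scales are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear-Gaussian instantiation of the CSG structure:
#   prior p(s, v): correlated Gaussian (the domain-specific s-v correlation),
#   mechanisms p(x | s, v) and p(y | s): fixed across domains (causal invariance).
d_s, d_v, d_x = 2, 2, 4

# Training-domain prior: s and v are correlated.
cov_train = np.eye(d_s + d_v)
cov_train[0, d_s] = cov_train[d_s, 0] = 0.7

# Fixed causal mechanisms (assumed linear for this sketch).
W_x = rng.normal(size=(d_x, d_s + d_v))   # x = W_x [s; v] + noise
w_y = rng.normal(size=d_s)                # y = w_y . s + noise; y depends on s only

def sample_domain(cov, n):
    """Sample (x, y) from the CSG under the given prior covariance."""
    sv = rng.multivariate_normal(np.zeros(d_s + d_v), cov, size=n)
    x = sv @ W_x.T + 0.1 * rng.normal(size=(n, d_x))
    y = sv[:, :d_s] @ w_y + 0.1 * rng.normal(size=n)
    return x, y

# A shifted test domain: same mechanisms W_x, w_y, but a different s-v prior.
cov_test = np.eye(d_s + d_v)
cov_test[0, d_s] = cov_test[d_s, 0] = -0.5

x_tr, y_tr = sample_domain(cov_train, 1000)
x_te, y_te = sample_domain(cov_test, 1000)
```

Under this assumption, a predictor that exploits the training-domain s-v correlation (e.g., reading y off v) degrades in the test domain, while the mechanism p(y | s) remains valid; adapting only the prior, as in item (3) above, suffices in principle.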
Empirical improvements in OOD performance and adaptation are demonstrated by experiments on multiple tasks, including shifted MNIST and the ImageCLEF-DA task.

2. RELATED WORK

There have been works that aim to leverage the merits of causality for OOD prediction. For OOD generalization, some works ameliorate discriminative models towards a causal behavior. Bahadori et al. (2017) introduce a regularizer that reweights input dimensions based on their approximated causal effects on the output, and Shen et al. (2018) reweight training samples by amortizing causal effects among input dimensions. These approaches are extended to nonlinear cases (Bahadori et al., 2017; He et al., 2019) via linearly-separable representations. Heinze-Deml & Meinshausen (2019) enforce inference invariance by minimizing prediction variance within each label-identity group. These methods introduce no additional modeling effort, but may also be limited in capturing invariant causal mechanisms (they are non-generative) and may only behave quantitatively causally in the training domain. For domain adaptation/generalization, methods are developed under various causal assumptions (Schölkopf et al., 2012; Zhang et al., 2013) or using learned causal relations (Rojas-Carulla et al., 2018; Magliacane et al., 2018). Zhang et al. (2013) and Gong et al. (2016; 2018) also consider certain ways of mechanism shift. The causality considered there is among directly observed variables, which may not be suitable for general data like image pixels, where causality rather lies between the data and conceptual latent factors (Lopez-Paz et al., 2017; Besserve et al., 2018; Kilbertus et al., 2018). To consider latent factors, there are domain adaptation (Pan et al., 2010; Baktashmotlagh et al., 2013; Ganin et al., 2016; Long et al., 2015; 2018) and generalization methods (Muandet et al., 2013; Shankar et al., 2018) that learn a representation with a domain-invariant marginal distribution, and they have achieved remarkable results. Nevertheless, Johansson et al. (2019) and Zhao et al. (2019) point out that this invariance is neither sufficient nor necessary to identify the true semantics and lower the adaptation error (Supplement D). Moreover, these methods and invariant risk minimization (Arjovsky et al., 2019) also assume invariance in the inference direction (i.e., data → representation), which may not be as general as causal invariance in the generative direction (Section 3.2). There are also generative methods for domain adaptation/generalization that model latent factors. Cai et al. (2019) and Ilse et al. (2019) introduce a semantic factor and a domain-feature factor. They assume the two factors are independent in both the generative and inference models, which may not hold in reality. They also do not adapt the prior for domain shift, and thus resort to inference invariance. Zhang et al. (2020) consider a partially observed manipulation variable, but assume its independence from the output in both the joint and the posterior, and their adaptation is inconsistent with causal invariance. Atzmon et al. (2020) consider similar latent factors, but use the same (uniform) prior in all domains. These methods also show no guarantees of identifying their latent factors. Teshima et al. (2020) leverage causal invariance and adapt the prior, but also assume latent independence and do not separate the semantic factor. They require some supervised test-domain data, and their deterministic and invertible mechanism also implies inference invariance. In addition, most domain generalization methods require multiple training domains, with exceptions (e.g., Qiao et al., 2020) that still seek to augment domains. In contrast, CSG leverages causal invariance, and is guaranteed to identify the semantic factor from a single training domain, even in the presence of a correlation with the variation factor. Generative supervised learning is not new (Mcauliffe & Blei, 2008; Kingma et al., 2014), but most works do not consider the encoded causality.
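A minimal one-dimensional Gaussian example (our own illustration, not from the paper) shows why invariance in the inference direction is more restrictive than causal invariance in the generative direction: even with the generative mechanism p(x|s) = N(s, σ²) held fixed, a change in the prior p(s) across domains already changes the induced inference map E[s|x].

```python
# Fixed generative mechanism p(x|s) = N(s, sigma2) (causal invariance),
# but the induced posterior mean E[s|x] depends on the domain's prior p(s).
def posterior_mean_coef(tau2, sigma2=1.0):
    """Coefficient c in E[s|x] = c * x, for prior N(0, tau2), likelihood N(s, sigma2)."""
    return tau2 / (tau2 + sigma2)

c_train = posterior_mean_coef(1.0)  # training-domain prior N(0, 1): c = 0.5
c_test = posterior_mean_coef(4.0)   # test-domain prior N(0, 4): c = 0.8
```

So methods that assume a fixed data → representation map across domains implicitly assume more than a fixed causal mechanism; under a prior shift, the correct inference map itself must be adapted.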
Other works solve causality tasks directly, notably causal/treatment effect estimation (Louizos et al., 2017; Yao et al., 2018; Wang & Blei, 2019). These tasks do not focus on OOD prediction, and require labels for both the treated and control groups.

