LEARNING CAUSAL SEMANTIC REPRESENTATION FOR OUT-OF-DISTRIBUTION PREDICTION

Abstract

Conventional supervised learning methods, especially deep ones, are found to be sensitive to out-of-distribution (OOD) examples, largely because the learned representation mixes the semantic factor with the variation factor due to their domain-specific correlation, while only the semantic factor causes the output. To address the problem, we propose a Causal Semantic Generative model (CSG) based on causality to model the two factors separately, and learn it on a single training domain for prediction either without test-domain data (OOD generalization) or with unsupervised test-domain data (domain adaptation). We prove that CSG identifies the semantic factor on the training domain, and the invariance principle of causality subsequently guarantees the boundedness of the OOD generalization error and the success of adaptation. We also design novel and delicate learning methods for both effective learning and easy prediction, following the first principle of variational Bayes and the graphical structure of CSG. Empirical studies demonstrate that our methods improve test accuracy for both OOD generalization and domain adaptation.

1. INTRODUCTION

Deep learning has initiated a new era of artificial intelligence where the potential of machine learning models is greatly unleashed. Despite this great success, these methods rely heavily on the independent-and-identically-distributed (IID) assumption. This assumption does not always hold in practice, and the prediction of the output (label, response, outcome) y may be saliently affected in out-of-distribution (OOD) cases, even by an essentially irrelevant change to the input (covariate) x, such as a position shift or rotation of the object in an image, or a change of background, illumination, or style (Shen et al., 2018; He et al., 2019; Arjovsky et al., 2019). These phenomena raise serious concerns about the robustness and trustworthiness of machine learning methods and severely impede their deployment in risk-sensitive scenarios. Looking into the problem, although deep learning models can extract abstract representations for prediction with their powerful approximation capacity, the learned representation may be overconfident in the correlation between semantic factors s (e.g., the shape of an object) and variation factors v (e.g., background, illumination, object position). This correlation may be domain-specific and spurious, and may change drastically in a new environment. It has therefore become desirable to learn representations that separate semantics s from variations v (Cai et al., 2019; Ilse et al., 2019). Formally, the importance of this goal lies in that s represents the cause of y. Causal relations better reflect the fundamental mechanisms of nature, bringing to machine learning the merit that they tend to be universal and invariant across domains (Schölkopf et al., 2012; Peters et al., 2017; Schölkopf, 2019), thus providing the most transferable and reliable information for unseen domains.
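The failure mode described above can be made concrete with a minimal synthetic sketch (not from the paper; the domains, correlation strengths, and variable names below are illustrative assumptions). A binary semantic factor s causes the label y, while a variation factor v is only spuriously correlated with s in the training domain; a predictor that latches onto v looks accurate in training but collapses when the correlation reverses in the test domain, whereas a predictor using s is invariant:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_domain(n, corr):
    """Sample (s, v, y): the semantic factor s causes y, while the
    variation factor v merely correlates with s, with domain-specific
    strength `corr` (an illustrative assumption, not the paper's setup)."""
    s = rng.integers(0, 2, n)        # semantic factor (e.g., object shape)
    agree = rng.random(n) < corr
    v = np.where(agree, s, 1 - s)    # variation factor (e.g., background)
    y = s                            # only s causes the output y
    return s, v, y

# Training domain: strong spurious correlation; test domain: correlation reversed.
s_tr, v_tr, y_tr = sample_domain(10_000, corr=0.95)
s_te, v_te, y_te = sample_domain(10_000, corr=0.05)

acc_v_train = (v_tr == y_tr).mean()  # predictor relying on the variation factor v
acc_v_test = (v_te == y_te).mean()
acc_s_test = (s_te == y_te).mean()   # predictor relying on the causal factor s

print(acc_v_train, acc_v_test, acc_s_test)
```

The v-based predictor scores near 0.95 in training but near 0.05 OOD, while the s-based predictor stays at 1.0, which is the sense in which the causal factor provides the transferable information.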
Causality has also been shown to enable proper domain adaptation (Schölkopf et al., 2012; Zhang et al., 2013), lower adaptation cost, and lighter catastrophic forgetting (Peters et al., 2016; Bengio et al., 2019; Ke et al., 2019). In this work, we propose a Causal Semantic Generative model (CSG) for proper and robust OOD prediction, covering both OOD generalization and domain adaptation. Both tasks have supervised data from a single training domain; domain adaptation additionally has unsupervised test-domain data during learning, whereas OOD generalization has no test-domain data at all, including cases where queries arrive sequentially or adaptation is unaffordable. (1) We build the model by cautiously following the principle of causality, explicitly separating the latent variables into a (group of) semantic factors s and a (group of) variation factors v. We prove that under appropriate conditions CSG identifies the

