LATENT CAUSAL INVARIANCE MODEL

Abstract

Current supervised learning can learn spurious correlations during the data-fitting process, imposing issues regarding interpretability, out-of-distribution (OOD) generalization, and robustness. To avoid spurious correlations, we propose a Latent Causal Invariance Model (LaCIM) that pursues causal prediction. Specifically, to model the underlying causal factors, we introduce latent variables that are separated into (a) output-causative factors and (b) others that are spuriously correlated with the output via confounders. We further assume the generating mechanisms from the latent space to the observed data to be causally invariant. We give an identifiability claim for such invariance, in particular the disentanglement of output-causative factors from others, as a theoretical guarantee of precise inference and avoidance of spurious correlation. We propose a variational-Bayesian-based method for estimation, and optimize over the latent space for prediction. The utility of our approach is verified by improved interpretability, prediction power in various OOD scenarios (including healthcare), and robustness in security settings.

1. INTRODUCTION

Current data-driven deep learning models, though revolutionary in various tasks, heavily rely on i.i.d. data and exploit all types of correlations to fit the data well. Among such correlations, there can be spurious ones corresponding to biases (e.g., selection or confounding bias due to the coincident presence of a third factor) inherited from the data provided. Such data-dependent spurious correlations can erode (i) the interpretability of decision-making, (ii) the ability of out-of-distribution (OOD) generalization, i.e., extrapolation from observed to new environments, which is crucial especially in safety-critical tasks such as healthcare, and (iii) robustness to small perturbations (Goodfellow et al., 2014). Recently, there has been a renaissance of causality in machine learning, which is expected to pursue causal prediction (Schölkopf, 2019). The notion of "causality" was pioneered by Judea Pearl (Pearl, 2009) as a mathematical formulation of this metaphysical concept grasped in the human mind. The incorporation of a priori knowledge about cause and effect endows a model with the ability to identify the causal structure, which entails not only the data but also the underlying process of how they are generated. For causal prediction, earlier methods (Peters et al., 2016; Bühlmann, 2018) causally related the output label Y to the observed input X, which, however, is not conceptually reasonable in scenarios with sensory-level observed data (e.g., modeling pixels as causal factors of Y does not make much sense). For such applications, we instead adopt the view of Bengio et al. (2013); Biederman (1987) and relate the causal factors of Y to unobserved abstractions denoted by S, i.e., Y ← f_y(S, ε_y) via mechanism f_y. We further assume the existence of additional latent components, denoted by Z, that together with S generate the input X via mechanism f_x, i.e., X ← f_x(S, Z, ε_x).
Taking image classification as an example, S and Z respectively refer to object-related abstractions (e.g., contour, texture, color) and contextual information (e.g., lighting, view). Such an assumption is similarly adopted in the literature on nonlinear Independent Component Analysis (ICA) (Hyvarinen and Morioka, 2016; Hyvärinen et al., 2019; Khemakhem, Kingma and Hyvärinen, 2020; Teshima et al., 2020) and latent generative models (Suter et al., 2019), which, however, do not separate the output (y)-causative factors (a.k.a. S) from other correlating factors (a.k.a. Z); both can be picked up during the data-fitting process. We encapsulate these assumptions into a novel causal model, namely the Latent Causal Invariance Model (LaCIM), as illustrated in Fig. 1, in which we assume the structural equations f_x (associated with S, Z → X) and f_y (associated with S → Y) to be Causal Invariant Mechanisms (CIMe) that hold under any circumstances, while P(S, Z) is allowed to vary across domains. The incorporation of these priors can explain the spurious correlation embedded in the back-door path from Z to Y (from contextual information to the class label in image classification). To avoid learning spurious correlations, our goal is to identify the intrinsic CIMe f_x, f_y. Specifically, we first prove the identifiability (i.e., the possibility of being precisely inferred up to an equivalence relation) of the CIMe. Notably, going beyond the scope of the existing literature (Khemakhem, Kingma and Hyvärinen, 2020), our results are the first to disentangle the output-causative factors (a.k.a. S) from the others (a.k.a. Z) for prediction, ensuring the isolation of undesired spurious correlations. Guaranteed by this result, we propose to estimate the CIMe by extending the Variational Auto-Encoder (VAE) (Kingma and Welling, 2014) to the supervised scenario. For OOD prediction, we propose to optimize over the latent space under the identified CIMe.
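As a toy illustration of these assumptions, the following sketch simulates data from a LaCIM-like structural model: a hypothetical confounder links S and Z (so P(S, Z) shifts across environments), while the mechanisms f_x, f_y stay fixed. All functional forms here are invented for illustration; the paper places no such restrictions on the invariant mechanisms.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_environment(n, align):
    """Toy LaCIM-style generative process (illustrative functional forms only).

    A confounder C links the output-causative latent S with the contextual
    latent Z; the environment-specific coefficient `align` changes P(S, Z)
    across domains, while f_x and f_y below stay fixed (causally invariant)."""
    c = rng.normal(0.0, 1.0, size=n)              # confounder
    s = c + rng.normal(0.0, 0.5, size=n)          # output-causative latent S
    z = align * c + rng.normal(0.0, 0.5, size=n)  # spuriously correlated latent Z
    # Invariant mechanisms: X <- f_x(S, Z, eps_x), Y <- f_y(S, eps_y)
    x = np.stack([s + z, s - z], axis=1) + rng.normal(0.0, 0.1, size=(n, 2))
    y = (s + rng.normal(0.0, 0.1, size=n) > 0).astype(int)
    return x, y, s, z

# In the training domain Z is predictive of Y (spurious correlation), but the
# sign of that correlation flips in a domain where the confounding reverses,
# while the S-Y relation persists in both domains.
_, y_tr, s_tr, z_tr = sample_environment(10000, align=1.0)
_, y_te, s_te, z_te = sample_environment(10000, align=-1.0)
print(np.corrcoef(z_tr, y_tr)[0, 1], np.corrcoef(z_te, y_te)[0, 1])
```

A predictor relying on Z would exploit the back-door correlation and fail under the shifted environment; one relying only on S would not.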
To verify the correctness of our identifiability claim, we conduct a simulation experiment. We further demonstrate the utility of LaCIM via highly explainable learned semantic features, improved prediction power in various OOD scenarios (including tasks with confounding and selection bias, and healthcare), and robustness in security settings. We summarize our contributions as follows: (i) Methodologically, we propose in Section 4.1 a latent causal model in which only a subset of the latent components is causally related to the output, to avoid spurious correlation and benefit OOD generalization; (ii) Theoretically, we prove the identifiability (in Theorem 4.3) of the CIMe f_x, f_y from latent variables to observed data, which disentangles output-causative factors from others; (iii) Algorithmically, guided by identifiability, we reformulate in Section 4.3 a variational Bayesian method to estimate the CIMe during training and optimize over the latent space at test time; (iv) Experimentally, LaCIM outperforms others in terms of prediction power on OOD tasks and interpretability in Section 5.2, and robustness to small perturbations in Section 5.3.
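As a rough sketch of the estimation idea (not the paper's actual objective or architecture), a supervised VAE-style model maximizes an evidence lower bound in which x is reconstructed from the full latent code (s, z) while y depends only on the s-part. The encoder and log-density callables below are hypothetical placeholders.

```python
import numpy as np

rng = np.random.default_rng(1)

def gaussian_kl(mu_q, logvar_q):
    """KL( N(mu_q, diag e^{logvar_q}) || N(0, I) ), summed over dimensions."""
    return 0.5 * np.sum(np.exp(logvar_q) + mu_q**2 - 1.0 - logvar_q)

def supervised_elbo(x, y, encode, log_px, log_py, n_samples=64):
    """Monte-Carlo ELBO for a supervised latent-variable model in the spirit
    of LaCIM: the variational posterior q(s, z | x) is a diagonal Gaussian,
    x is reconstructed from (s, z), and y is predicted from s only.
    `encode`, `log_px`, `log_py` are hypothetical placeholder callables."""
    mu, logvar = encode(x)                   # parameters of q(s, z | x)
    elbo = -gaussian_kl(mu, logvar)          # KL to the prior p(s, z) = N(0, I)
    half = len(mu) // 2                      # first half = s, second half = z
    for _ in range(n_samples):
        eps = rng.standard_normal(mu.shape)
        sz = mu + np.exp(0.5 * logvar) * eps  # reparameterization trick
        s, z = sz[:half], sz[half:]
        elbo += (log_px(s, z, x) + log_py(s, y)) / n_samples
    return elbo

# Toy linear-Gaussian instantiation (log-densities up to additive constants).
log_px = lambda s, z, x: float(-0.5 * np.sum((x - s - z) ** 2))
log_py = lambda s, y: float(-0.5 * np.sum((y - s) ** 2))
encode = lambda x: (np.array([x / 3, x / 3]), np.log(np.array([2 / 3, 2 / 3])))
print(supervised_elbo(1.0, 0.5, encode, log_px, log_py))
```

Test-time OOD prediction in the paper then searches the latent space for the code that best explains a new input under the learned, invariant decoder; that optimization step is omitted from this sketch.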

2. RELATED WORK

The invariance/causal learning literature proposes to learn an assumed invariance for transfer. For the invariance-learning methods in Krueger et al. (2020) and Schölkopf (2019), the "invariance" can refer to a stable correlation rather than causation, which lacks interpretability and impedes generalization to a broader set of domains. For causal learning, Peters et al. (2016); Bühlmann (2018); Kuang et al. (2018); Heinze-Deml and Meinshausen (2017) assume the causal factors to be the observed input, which is inappropriate for sensory-level observational data. In contrast, our LaCIM introduces latent components as causal factors of the input; more importantly, we explicitly separate them into output-causative features and others, to avoid spurious correlation. Further, we provide an identifiability claim for the causal invariant mechanisms. In independent and concurrent works, Teshima et al. (2020) and Ilse et al. (2020) also explore latent variables in causal relations. As comparisons, Teshima et al. (2020) do not differentiate S from Z, and Ilse et al. (2020) propose to augment intervened data, which can be intractable in real cases.

Other works conceptually related to ours, as a non-exhaustive review, include (i) transfer learning, which also leverages invariance in the context of domain adaptation (Schölkopf et al., 2011; Zhang et al., 2013; Gong et al., 2016) or domain generalization (Li et al., 2018; Shankar et al., 2018); (ii) causal inference (Pearl, 2009; Peters et al., 2017), which proposes a structural causal model to incorporate interventions via "do-calculus" for cause-effect reasoning and counterfactual learning; and (iii) latent generative models, which also assume generation from a latent space to the observed data (Kingma and Welling, 2014; Suter et al., 2019) but aim at learning the generator in the unsupervised scenario.

Notation. Let X, Y respectively denote the input and output variables. The training data {D^e}_{e ∈ E_train} are collected from a set of environments E_train, where each environment e is associated with a distribution P_e(X, Y) over 𝒳 × 𝒴, and D^e = {x^e_i, y^e_i, d^e}_{i ∈ [n_e]} is drawn i.i.d. from P_e, with [k] := {1, ..., k} for any k ∈ Z^+. Here d^e ∈ {0, 1}^m denotes the one-hot encoded domain index of e, where 1 ≤ m := |E_train| ≤ n := Σ_{e ∈ E_train} n_e. Our goal is to learn a model f : 𝒳 → 𝒴 that captures the output (y)-causative factors for prediction and performs well on the set of all environments E ⊃ E_train, which is aligned with existing OOD-generalization works (Arjovsky et al., 2019; Krueger et al., 2020). We use upper-case, lower-case, and calligraphic letters respectively to denote a random variable, an instance, and a space; e.g., a is an instance in the space 𝒜 of the random variable A. [f]_A denotes f restricted to the dimensions of A. The Sobolev space W^{k,p}(𝒜) contains all f such that ∫_𝒜 |∂^α_A f|_{A=a}|^p da < ∞, ∀ α ≤ k.
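The multi-environment data setup in this notation can be sketched as follows; the sampler interface is invented for illustration.

```python
import numpy as np

def one_hot(e, m):
    """One-hot domain index d^e in {0,1}^m for environment e (m = |E_train|)."""
    d = np.zeros(m, dtype=int)
    d[e] = 1
    return d

def build_training_data(env_samplers):
    """Pool {D^e}_{e in E_train}: each hypothetical sampler returns (x, y)
    arrays drawn i.i.d. from its environment's distribution P_e(X, Y)."""
    m = len(env_samplers)
    data = []
    for e, sampler in enumerate(env_samplers):
        x, y = sampler()
        data.append({"x": x, "y": y, "d": one_hot(e, m), "n_e": len(x)})
    return data

# Two toy environments; the total sample size is n = sum_e n_e.
rng = np.random.default_rng(2)
samplers = [
    lambda: (rng.normal(size=(100, 2)), rng.integers(0, 2, size=100)),
    lambda: (rng.normal(1.0, 1.0, size=(50, 2)), rng.integers(0, 2, size=50)),
]
data = build_training_data(samplers)
print(sum(d["n_e"] for d in data))  # n = 150
```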

