LATENT CAUSAL INVARIANT MODEL

Abstract

Current supervised learning can learn spurious correlations during the data-fitting process, raising issues regarding interpretability, out-of-distribution (OOD) generalization, and robustness. To avoid spurious correlations, we propose a Latent Causal Invariance Model (LaCIM) that pursues causal prediction. Specifically, we model the underlying causal factors by introducing latent variables separated into (a) output-causative factors and (b) others that are spuriously correlated with the output via confounders. We further assume the generating mechanisms from the latent space to the observed data to be causally invariant. We establish the identifiability of such invariance, in particular the disentanglement of the output-causative factors from the others, as a theoretical guarantee for precise inference and for avoiding spurious correlations. We propose a variational-Bayes-based method for estimation and optimize over the latent space for prediction. The utility of our approach is verified by improved interpretability, predictive power in various OOD scenarios (including healthcare), and robustness in security-related settings.

1. INTRODUCTION

Current data-driven deep learning models, revolutionary in various tasks though they are, rely heavily on i.i.d. data and exploit all types of correlations to fit the data well. Among such correlations there can be spurious ones corresponding to biases (e.g., selection or confounding bias due to the presence of a third factor) inherited from the data provided. Such data-dependent spurious correlations can erode (i) the interpretability of decision-making, (ii) the ability of out-of-distribution (OOD) generalization, i.e., extrapolation from observed to new environments, which is crucial especially in safety-critical tasks such as healthcare, and (iii) robustness to small perturbations (Goodfellow et al., 2014). Recently, there has been a renaissance of causality in machine learning, which is expected to pursue causal prediction (Schölkopf, 2019). The so-called "causality" was pioneered by Judea Pearl (Pearl, 2009) as a mathematical formulation of this metaphysical concept grasped in the human mind. The incorporation of a priori knowledge about cause and effect endows a model with the ability to identify the causal structure, which entails not only the data but also the underlying process of how they are generated. For causal prediction, classical methods (Peters et al., 2016; Bühlmann, 2018) causally relate the output label Y to the observed input X, which, however, is NOT conceptually reasonable in scenarios with sensory-level observed data (e.g., modeling pixels as causal factors of Y does not make much sense). For such applications, we instead adopt the manner of Bengio et al. (2013) and Biederman (1987) and relate the causal factors of Y to unobserved abstractions denoted by S, i.e., Y ← f_y(S, ε_y) via mechanism f_y. We further assume the existence of additional latent components, denoted Z, that together with S generate the input X via mechanism f_x as X ← f_x(S, Z, ε_x).
Taking image classification as an example, S and Z respectively refer to object-related abstractions (e.g., contour, texture, color) and contextual information (e.g., lighting, view). A similar assumption is adopted in the literature on nonlinear Independent Component Analysis (ICA) (Hyvarinen and Morioka, 2016; Hyvärinen et al., 2019; Khemakhem, Kingma and Hyvärinen, 2020; Teshima et al., 2020) and latent generative models (Suter et al., 2019), which, however, do not separate the output (Y)-causative factors (i.e., S) from other correlating factors (i.e., Z); both can be learned in the data-fitting process. We encapsulate these assumptions into a novel causal model, namely the Latent Causal Invariance Model (LaCIM) illustrated in Fig. 1, in which we assume the structural equations f_x (associated with S, Z → X) and f_y (associated with S → Y) to be Causal Invariant Mechanisms (CIMe) that hold under any circumstances, while P(S, Z) is allowed to vary across domains. The incorporation of these
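The structure above can be illustrated with a toy simulation: a domain-specific prior P(S, Z) in which S and Z are spuriously correlated through a confounder, combined with mechanisms f_x and f_y that are identical in every domain. The functional forms, noise scales, and parameter names below are illustrative assumptions for a minimal sketch, not the mechanisms used in the paper:

```python
import numpy as np

def generate_domain(n, mu_s, mu_z, rho, rng):
    """Sample one domain of LaCIM-style data (illustrative assumptions only).

    The prior P(S, Z) is domain-specific: S and Z are spuriously
    correlated through a shared confounder C, with strength/sign rho
    varying across domains. The mechanisms f_x and f_y are the same
    for every domain (causal invariance).
    """
    c = rng.normal(size=n)                                     # confounder C
    s = mu_s + c + rng.normal(scale=0.5, size=n)               # output-causative S
    z = mu_z + rho * c + rng.normal(scale=0.5, size=n)         # spurious factor Z

    # Causally invariant mechanisms, shared by all domains:
    x = np.tanh(s) + 0.8 * z + rng.normal(scale=0.1, size=n)   # X <- f_x(S, Z, eps_x)
    y = (s + rng.normal(scale=0.1, size=n) > 0).astype(int)    # Y <- f_y(S, eps_y)
    return x, y, s, z

rng = np.random.default_rng(0)
# Two domains: identical f_x, f_y, but different priors P(S, Z).
x_a, y_a, s_a, z_a = generate_domain(5000, 0.0, 0.0, rho=0.9, rng=rng)
x_b, y_b, s_b, z_b = generate_domain(5000, 0.0, 0.0, rho=-0.9, rng=rng)

# Z's correlation with Y flips sign across domains (spurious),
# while S's correlation with Y stays strongly positive (causal).
print(np.corrcoef(z_a, y_a)[0, 1], np.corrcoef(z_b, y_b)[0, 1])
print(np.corrcoef(s_a, y_a)[0, 1], np.corrcoef(s_b, y_b)[0, 1])
```

A predictor fit on domain A alone would exploit the Z–Y correlation, which reverses in domain B; only the S → Y mechanism transfers, which is why disentangling S from Z matters for OOD prediction.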

