LEARNING COUNTERFACTUALLY INVARIANT PREDICTORS

Abstract

We propose a method to learn predictors that are invariant under counterfactual changes of certain covariates. This method is useful when the prediction target is causally influenced by covariates that should not affect the predictor output. For instance, it could prevent an object recognition model from being influenced by the position, orientation, or scale of the object itself. We propose a model-agnostic regularization term based on conditional kernel mean embeddings to enforce counterfactual invariance during training. We prove the soundness of our method, which can handle mixed categorical and continuous multivariate attributes. Empirical results on synthetic and real-world data demonstrate the efficacy of our method in a variety of settings.

1. INTRODUCTION AND RELATED WORK

Invariance, or equivariance, to certain transformations of data has proven essential in numerous applications of machine learning (ML), since it can lead to better generalization capabilities (Arjovsky et al., 2019; Chen et al., 2020; Bloem-Reddy & Teh, 2020). For instance, in image recognition, predictions ought to remain unchanged under scaling, translation, or rotation of the input image. Data augmentation is one of the earliest heuristics developed to promote this kind of invariance, and it has become indispensable for training successful models such as deep neural networks (DNNs) (Shorten & Khoshgoftaar, 2019; Xie et al., 2020). Well-known examples of "invariance by design" include convolutional neural networks (CNNs) for translation invariance (Krizhevsky et al., 2012), group-equivariant CNNs for other group transformations (Cohen & Welling, 2016), recurrent neural networks (RNNs) and transformers for sequential data (Vaswani et al., 2017), DeepSet for sets (Zaheer et al., 2017), and graph neural networks (GNNs) for different types of geometric structures (Battaglia et al., 2018).

Many real-world applications in modern ML, however, call for an arguably stronger notion of invariance based on causality, called counterfactual invariance. This case has been made for image classification, algorithmic fairness (Hardt et al., 2016; Mitchell et al., 2021), robustness (Bühlmann, 2020), and out-of-distribution generalization (Lu et al., 2021). These applications require predictors to exhibit invariance with respect to hypothetical manipulations of the data generating process (DGP) (Peters et al., 2016; Heinze-Deml et al., 2018; Rojas-Carulla et al., 2018; Arjovsky et al., 2019; Bühlmann, 2020). In image classification, for instance, we want a model that "would have made the same prediction, had the object position been different with everything else being equal". Similarly, in algorithmic fairness, Kilbertus et al. (2017) and Kusner et al. (2017) introduce notions of interventional and counterfactual fairness, based on certain invariances in the DGP of the causal relationships between observed variables.

Counterfactual invariance has the significant advantage that it incorporates structural knowledge of the DGP. However, enforcing this notion in practice is very challenging, since it is untestable in real-world observational settings unless strong prior knowledge of the DGP is available. Inspired by problems in natural language processing (NLP), Veitch et al. (2021) provide a method to achieve counterfactual invariance based on distribution matching via the maximum mean discrepancy (MMD). This method enforces a necessary, but not sufficient, condition of counterfactual invariance during training. Consequently, it is unclear whether this method achieves actual invariance in practice, or merely an arguably weaker proxy. Furthermore, the work by Veitch et al. (2021) only considers discrete random variables when enforcing counterfactual invariance, and it only applies to specific, selected causal graphs.

To overcome the aforementioned problems, we propose a general definition of counterfactual invariance and a novel method to enforce it. Our main contributions can be summarized as follows:

• Based on a structural causal model (SCM), we provide a new definition of counterfactual invariance (cf. Definition 2.2) that is more general than that of Veitch et al. (2021).
• We establish a connection between counterfactual invariance and conditional independence that is provably sufficient for counterfactual invariance (cf. Theorem 3.2).
• We propose a new objective function, composed of the task loss and the flexible Hilbert-Schmidt Conditional Independence Criterion (HSCIC) (Park & Muandet, 2020), to enforce counterfactual invariance in practice.

Our method works well for both categorical and continuous covariates and outcomes, as well as in multivariate settings.

2. PRELIMINARIES AND BACKGROUND

Counterfactual invariance. We introduce structural causal models as in Pearl (2000).

Definition 2.1 (Structural causal model (SCM)).

A structural causal model is a tuple (U, V, F, P_U) such that: (i) U is a set of background variables that are exogenous to the model; (ii) V is a set of observable (endogenous) variables; (iii) F = {f_V}_{V ∈ V} is a set of functions from (the domains of) pa(V) ∪ U_V to (the domain of) V, where U_V ⊂ U and pa(V) ⊆ V \ {V}, such that V = f_V(pa(V), U_V); (iv) P_U is a probability distribution over the domain of U. Further, the subsets pa(V) ⊆ V \ {V} are chosen such that the graph G over V, which contains the edge V′ → V if and only if V′ ∈ pa(V), is a directed acyclic graph (DAG).

We always denote by Y ⊂ V the outcome (or prediction target), and by Ŷ a predictor for that target. The predictor Ŷ is not strictly part of the SCM, because we get to tune f_Ŷ; since it takes its inputs from V, however, we often treat it as an observed variable in the SCM. As such, it also "derives its randomness from the exogenous variables", i.e., it is defined on the same probability space.

Each SCM implies a unique observational distribution over V (Pearl, 2000), but it also entails interventional distributions. Given a variable A ∈ V, an intervention A ← a amounts to replacing f_A in F with the constant function A = a. This yields a new SCM, which induces the interventional distribution under the intervention A ← a. Similarly, we can intervene on a set of variables A ⊆ V, written A ← a. We then write Y*_a for the outcome in the intervened SCM, also called the potential outcome. Note that the interventional distribution P_{Y*_a}(y) differs in general from the conditional distribution P_{Y|A}(y | a); this can happen, for instance, due to unobserved confounding effects.¹ We can also condition on a set of variables W ⊆ V in the (observational distribution of the) original SCM before performing an intervention, which we denote by P_{Y*_a | W}(y | w). This is a counterfactual distribution: "Given that we have observed W = w, what would Y have been, had we set A ← a instead of the value A had actually taken?"
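The gap between P_{Y*_a}(y) and P_{Y|A}(y | a) is easy to see in simulation. The following sketch uses a toy SCM of our own (not from the paper) in which U confounds A and Y, and compares conditioning on A ≈ 1 with intervening A ← 1:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Toy SCM (our illustration): U is exogenous, with U -> A, U -> Y, A -> Y.
#   A = U + N_A,   Y = A + 2*U + N_Y
U = rng.normal(size=n)
A = U + rng.normal(size=n)
Y = A + 2.0 * U + rng.normal(size=n)

# Conditioning: selecting units with A ~ 1 also selects units with larger
# U, so E[Y | A = 1] = 2 in this model.
cond_mean = Y[np.abs(A - 1.0) < 0.05].mean()

# Intervening: A <- 1 replaces f_A and severs the U -> A edge, so the
# post-interventional mean is E[Y*_1] = 1.
Y_do = 1.0 + 2.0 * U + rng.normal(size=n)
interv_mean = Y_do.mean()
```

The two means differ by roughly one, precisely the confounding contribution of U that the intervention removes.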
Note that the sets A and W need not be disjoint. We can now define counterfactual invariance.

Definition 2.2 (Counterfactual invariance). Let A, W be (not necessarily disjoint) sets of nodes in a given SCM. A predictor Ŷ is counterfactually invariant in A with respect to W if P_{Ŷ*_a | W}(y | w) = P_{Ŷ*_{a′} | W}(y | w) almost surely, for all a, a′ in the domain of A and all w in the domain of W.²

A counterfactually invariant predictor can be viewed as robust to changes of A, in the sense that the (conditional) post-interventional distribution of Ŷ does not change across values of the intervention. Our Definition 2.2 is more general than previously considered notions of counterfactual invariance. For instance, the invariance in Definition 1.1 of Veitch et al. (2021) requires Ŷ*_a = Ŷ*_{a′} almost surely for all a, a′ in the domain of A. First, it does not allow conditioning on observed evidence, i.e., it cannot consider "true counterfactuals" and is thus unable to promote, for

¹ We use P for distributions, as is common in the kernel literature (Muandet et al., 2021), and the potential outcome notation Y*_a instead of Y | do(a) for conciseness when mixing conditioning with interventions.
² With a mild abuse of notation, if W = ∅, then the requirement of conditional counterfactual invariance becomes P_{Ŷ*_a}(y) = P_{Ŷ*_{a′}}(y) almost surely, for all a, a′ in the domain of A.
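As a concrete illustration of Definition 2.2 in the unconditional case W = ∅ (cf. footnote 2), the following sketch (our toy example, with hypothetical names) simulates post-interventional distributions of two predictors: one reads a descendant of A and is not counterfactually invariant, the other reads only a non-descendant of A and is invariant by construction:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 50_000

# Toy SCM (ours): U -> A, and X = A + U + N_X, so X is a descendant of A.
U = rng.normal(size=n)
N_X = rng.normal(size=n)

def X_star(a):
    # Potential outcome of X under A <- a: the exogenous noise (U, N_X)
    # is shared across interventions, only the value of A changes.
    return a + U + N_X

# Predictor 1 reads X: the distribution of Yhat*_a shifts with a.
yhat1_a0, yhat1_a1 = X_star(0.0), X_star(1.0)

# Predictor 2 reads only U, a non-descendant of A: here Yhat*_a equals
# Yhat*_a' almost surely, so Definition 2.2 holds trivially.
yhat2_a0, yhat2_a1 = U, U
```

Comparing the sampled distributions of `yhat1_a0` and `yhat1_a1` (their means differ by one) against the identical samples of predictor 2 mirrors the almost-sure equality required by the definition.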

