LEARNING COUNTERFACTUALLY INVARIANT PREDICTORS

Abstract

We propose a method to learn predictors that are invariant under counterfactual changes of certain covariates. This method is useful when the prediction target is causally influenced by covariates that should not affect the predictor output. For instance, this could prevent an object recognition model from being influenced by position, orientation, or scale of the object itself. We propose a model-agnostic regularization term based on conditional kernel mean embeddings to enforce counterfactual invariance during training. We prove the soundness of our method, which can handle mixed categorical and continuous multivariate attributes. Empirical results on synthetic and real-world data demonstrate the efficacy of our method in a variety of settings.

1. INTRODUCTION AND RELATED WORK

Invariance or equivariance to certain transformations of the data has proven essential in numerous applications of machine learning (ML), since it can lead to better generalization capabilities Arjovsky et al. (2019); Chen et al. (2020); Bloem-Reddy & Teh (2020). For instance, in image recognition, predictions ought to remain unchanged under scaling, translation, or rotation of the input image. Data augmentation is one of the earliest heuristics developed to promote this kind of invariance, and it has become indispensable for training successful models such as deep neural networks (DNNs) Shorten & Khoshgoftaar (2019); Xie et al. (2020). Well-known examples of "invariance by design" include convolutional neural networks (CNNs) for translation invariance Krizhevsky et al. (2012), group-equivariant CNNs for other group transformations Cohen & Welling (2016), recurrent neural networks (RNNs) and transformers for sequential data Vaswani et al. (2017), DeepSet Zaheer et al. (2017) for sets, and graph neural networks (GNNs) for different types of geometric structures Battaglia et al. (2018). Many real-world applications of modern ML, however, call for an arguably stronger notion of invariance based on causality, called counterfactual invariance. This case has been made for image classification, algorithmic fairness Hardt et al. (2016); Mitchell et al. (2021), robustness Bühlmann (2020), and out-of-distribution generalization Lu et al. (2021). These applications require predictors to be invariant with respect to hypothetical manipulations of the data-generating process (DGP) Peters et al. (2016); Heinze-Deml et al. (2018); Rojas-Carulla et al. (2018); Arjovsky et al. (2019); Bühlmann (2020). In image classification, for instance, we want a model that "would have made the same prediction, had the object position been different with everything else being equal". Similarly, in algorithmic fairness, Kilbertus et al. (2017); Kusner et al. (2017) introduce notions of interventional and counterfactual fairness based on certain invariances in the DGP describing the causal relationships between observed variables.

Counterfactual invariance has the significant advantage that it incorporates structural knowledge of the DGP. However, enforcing this notion in practice is very challenging, since it is untestable in real-world observational settings unless strong prior knowledge of the DGP is available. Inspired by problems in natural language processing (NLP), Veitch et al. (2021) provide a method to achieve counterfactual invariance based on distribution matching via the maximum mean discrepancy (MMD). This method enforces a necessary, but not sufficient, condition of counterfactual invariance during training. Consequently, it is unclear whether this method achieves actual invariance in practice, or merely an arguably weaker proxy. Furthermore, the work by Veitch et al. (2021) only considers discrete random variables when enforcing counterfactual invariance, and it only applies to
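To make the distribution-matching idea concrete, the following is a minimal numerical sketch of an MMD estimate between predictor outputs computed on two strata of a covariate; adding such a term (scaled by a weight) to the training loss penalizes distributional differences between the strata. This is only an illustration of the general MMD penalty, not the specific method of Veitch et al. (2021) or of this paper; the function names and the RBF bandwidth are illustrative choices.

```python
import numpy as np

def rbf_kernel(X, Y, sigma=1.0):
    # Gaussian (RBF) kernel matrix between the rows of X and Y.
    sq_dists = (np.sum(X**2, axis=1)[:, None]
                + np.sum(Y**2, axis=1)[None, :]
                - 2.0 * X @ Y.T)
    return np.exp(-sq_dists / (2.0 * sigma**2))

def mmd2(X, Y, sigma=1.0):
    # Biased estimate of the squared maximum mean discrepancy MMD^2(P, Q)
    # from samples X ~ P and Y ~ Q.
    Kxx = rbf_kernel(X, X, sigma)
    Kyy = rbf_kernel(Y, Y, sigma)
    Kxy = rbf_kernel(X, Y, sigma)
    return Kxx.mean() + Kyy.mean() - 2.0 * Kxy.mean()

# Stand-in "predictor outputs" on two strata of a binary covariate:
rng = np.random.default_rng(0)
same = mmd2(rng.normal(0, 1, (200, 1)), rng.normal(0, 1, (200, 1)))
diff = mmd2(rng.normal(0, 1, (200, 1)), rng.normal(2, 1, (200, 1)))
# 'same' is close to zero (matched distributions), 'diff' is clearly
# positive; a penalty of the form loss + lam * mmd2(...) pushes the
# output distributions across strata to match.
```

Note that matching the output distributions across strata is exactly the kind of necessary-but-not-sufficient condition discussed above: two distributions can coincide even when individual counterfactual predictions differ.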

