COUNTERFACTUAL GENERATION UNDER CONFOUNDING

Abstract

A machine learning model, under the influence of observed or unobserved confounders in the training data, can learn spurious correlations and fail to generalize when deployed. For image classifiers, augmenting the training dataset with counterfactual examples has been empirically shown to break spurious correlations. However, the counterfactual generation task itself becomes more difficult as the level of confounding increases. Existing methods for counterfactual generation under confounding consider a fixed set of interventions (e.g., texture, rotation) and are not flexible enough to capture diverse data-generating processes. Given a causal generative process, we formally characterize the adverse effects of confounding on any downstream task and show that the correlation between generative factors (attributes) can be used to quantitatively measure the confounding between them. To minimize such correlation, we propose a counterfactual generation method that learns to modify the value of any attribute in an image and generate new images given a set of observed attributes, even when the dataset is highly confounded. These counterfactual images are then used to regularize the downstream classifier such that the learned representations are the same across various generative factors conditioned on the class label. Our method is computationally efficient, simple to implement, and works well for any number of generative factors and confounding variables. Our experimental results on both synthetic (MNIST variants) and real-world (CelebA) datasets show the usefulness of our approach.

1. INTRODUCTION

A confounder is a variable that causally influences two or more variables that are not necessarily directly causally dependent (Pearl, 2001). The presence of confounders in a data-generating process is often the reason for spurious correlations among variables in the observational data. The bias caused by such confounders is inevitable in observational data, making it challenging to identify invariant features representative of a target variable (Rothenhäusler et al., 2021; Meinshausen & Bühlmann, 2015; Wang et al., 2022). For example, the demographic area in which an individual resides often confounds that individual's race and the level of education they receive. Using such observational data, if the goal is to predict an individual's salary, a machine learning model may exploit the spurious correlation between education and race, even though those two variables should ideally be treated as independent. Removing the effects of confounding in trained machine learning models has been shown to be helpful in various applications such as zero- or few-shot learning, disentanglement, domain generalization, counterfactual generation, algorithmic fairness, and healthcare (Suter et al., 2019; Kilbertus et al., 2020; Atzmon et al., 2020; Zhao et al., 2020; Yue et al., 2021; Sauer & Geiger, 2021; Goel et al., 2021; Dash et al., 2022; Reddy et al., 2022; Dinga et al., 2020). In observational data, confounding may be observed or unobserved and can pose various challenges in learning models, depending on the task. For example, disentangling spuriously correlated features using generative modeling is challenging in the presence of confounders (Sauer & Geiger, 2021; Reddy et al., 2022; Funke et al., 2022). As stated earlier, a classifier may rely on non-causal features to make predictions in the presence of confounders (Schölkopf et al., 2021).
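The mechanism described above can be reproduced in a small simulation (illustrative only, not from the paper; all variable names and coefficients are assumptions): a confounder C drives both Z1 and Z2, inducing a correlation between them even though neither causes the other, and severing the C → Z2 edge, as an intervention would, removes that correlation.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Confounder C causally influences both Z1 and Z2;
# Z1 and Z2 have no direct causal link between them.
C = rng.normal(size=n)
Z1 = 0.8 * C + 0.6 * rng.normal(size=n)   # e.g., education level
Z2 = 0.8 * C + 0.6 * rng.normal(size=n)   # e.g., demographic attribute

# Spurious correlation in the observational data.
r_obs = np.corrcoef(Z1, Z2)[0, 1]

# "Intervention": cut the C -> Z2 edge by resampling Z2 from
# an independent source with the same marginal variance.
Z2_do = 0.8 * rng.normal(size=n) + 0.6 * rng.normal(size=n)
r_do = np.corrcoef(Z1, Z2_do)[0, 1]
```

Here `r_obs` is substantially positive while `r_do` is close to zero, which is the sense in which correlation between generative factors can serve as a quantitative signal of confounding.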
Recent years have seen a few efforts to handle the spurious correlations caused by confounding effects in observational data (Träuble et al., 2021; Sauer & Geiger, 2021; Goel et al., 2021; Reddy et al., 2022). However, these methods either make strong assumptions about the underlying causal generative process or require strong supervision. In this paper, we study the adverse effect of confounding in observational data on a classifier's performance and propose a mechanism to marginalize such effects when performing data augmentation using counterfactual data. Counterfactual data generation provides a way to address such issues arising from confounding and to build robust learning models without the additional task of building complex generative models. The causal generative processes considered throughout this paper are shown in Figure 1(a). We assume that a set of generative factors (attributes) Z_1, Z_2, ..., Z_n (e.g., background, shape, texture) and a label Y (e.g., cow) cause a real-world observation X (e.g., an image of a cow in a particular background) through an unknown causal mechanism g (Peters et al., 2017b). To study the effects of confounding, we consider Y, Z_1, Z_2, ..., Z_n to be confounded by a set of confounding variables C_1, ..., C_m (e.g., certain breeds of cows appear only in certain shapes or colors and appear only in certain countries). Such causal generative processes have been considered earlier for other tasks such as disentanglement (Suter et al., 2019; Von Kügelgen et al., 2021; Reddy et al., 2022). A related recent effort by Sauer & Geiger (2021) proposes Counterfactual Generative Networks (CGN) to address this problem using a data augmentation approach. This work assumes each image to be composed of three Independent Causal Mechanisms (ICMs) (Peters et al., 2017a) responsible for three fixed factors of variation: shape, texture, and background (represented by Z_1, Z_2, and Z_3 in Figure 1(b)).
This work then trains a generative model that learns the three ICMs for shape, texture, and background separately and combines them in a deterministic fashion to generate observations. Once the ICMs are learned, sampling images after intervening on these mechanisms yields counterfactual data that can be used along with the training data to improve classification results. However, fixing the architecture to a specific number and type of mechanisms (shape, texture, background) does not generalize and may not be directly applicable to settings where the number of underlying generative factors is unknown. It is also computationally expensive to train a separate generative model for each aspect of an image, such as texture, shape, or background. In this work, we begin by quantifying confounding in observational data generated by an underlying causal graph (more general than the one considered by CGN) of the form shown in Figure 1(a). We then provide a counterfactual data augmentation methodology called CONIC (COunterfactual geNeratIon under Confounding). We hypothesize that the counterfactual images generated using the proposed CONIC method provide a mechanism to marginalize the causal mechanisms responsible for spurious correlations (i.e., the causal arrows from C_i to Z_j for some i, j). We take a generative modeling approach and propose a neural network architecture based on conditional CycleGAN (Zhu et al., 2017) to generate counterfactual images. The proposed architecture improves CycleGAN's ability to generate quality counterfactual images from confounded data by adding contrastive losses that distinguish between fixed and modified features while learning the cross-domain translations. To demonstrate the usefulness of such counterfactual images, we consider classification as a downstream task and study the performance of various models on an unconfounded test set.
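As a rough sketch of the loss structure just described (the paper's exact objectives differ; the function names, toy vectors, and weight 0.5 below are all hypothetical), a cycle-consistency term keeps the two translations approximately inverse to each other, while a triplet-style contrastive term separates the features that must stay fixed across the translation from those of the modified attribute:

```python
import numpy as np

def cycle_consistency(x, x_cycled):
    # L1 reconstruction penalty: translating A -> B -> A should recover x.
    return np.mean(np.abs(x - x_cycled))

def contrastive(anchor, positive, negative, margin=1.0):
    # Triplet-style term: features that should remain fixed across the
    # translation (positive) are pulled toward the anchor, while features
    # of the intervened attribute (negative) are pushed at least
    # `margin` farther away than the positive.
    d_pos = np.linalg.norm(anchor - positive)
    d_neg = np.linalg.norm(anchor - negative)
    return max(0.0, d_pos - d_neg + margin)

# Toy vectors standing in for encoder features of an image.
x          = np.array([1.0, 0.0, 2.0])
x_cycled   = np.array([1.1, -0.1, 2.0])   # after the A -> B -> A cycle
f_fixed    = np.array([0.9, 0.1, 2.1])    # attributes that must not change
f_modified = np.array([3.0, 2.0, -1.0])   # the intervened attribute

loss = cycle_consistency(x, x_cycled) + 0.5 * contrastive(x, f_fixed, f_modified)
```

In this toy setting the contrastive term is already zero because the fixed features sit much closer to the anchor than the modified ones; during training it is nonzero whenever the generator leaks changes into features that should stay fixed.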
Our key contributions include:
• We formally quantify confounding in causal generative processes of the form in Figure 1(a) and study the relationship between correlation and confounding for any pair of generative factors.
• We present a counterfactual data augmentation methodology that generates counterfactual instances of observed data, works even under highly confounded data (∼95% confounding), and provides a mechanism to marginalize the causal mechanisms responsible for confounding.
• We modify conditional CycleGAN to improve the quality of generated counterfactuals. Our method is computationally efficient and easy to implement.
• Following previous work, we perform extensive experiments on well-known benchmarks (three MNIST variants and the CelebA dataset) to showcase the usefulness of our proposed methodology in improving the accuracy of a downstream classifier.
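The classifier-side use of the generated counterfactuals, training on image/counterfactual pairs that share a label while pulling their representations together, can be sketched as follows (illustrative only; `lam` and all names are assumptions, not the paper's notation):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def augmented_loss(feat, feat_cf, logits, logits_cf, y, lam=1.0):
    # Cross-entropy on the real image and on its counterfactual
    # (both carry the same class label y), plus an invariance penalty
    # that pulls the two learned representations together.
    ce    = -np.log(softmax(logits)[y])
    ce_cf = -np.log(softmax(logits_cf)[y])
    invariance = np.sum((feat - feat_cf) ** 2)
    return ce + ce_cf + lam * invariance

# Toy example: identical logits, representations one unit apart per dim.
feat, feat_cf = np.array([1.0, 2.0]), np.array([2.0, 3.0])
logits = np.array([2.0, 0.0])
loss = augmented_loss(feat, feat_cf, logits, logits, y=0)
```

The invariance term is what encourages the learned representation to be the same across generative factors conditioned on the class label, as stated in the abstract.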




Figure 1: (a) The causal data-generating process considered in this paper (CONIC, ours); (b) the causal data-generating process considered in CGN (Sauer & Geiger, 2021).

