COUNTERFACTUAL GENERATION UNDER CONFOUNDING

Abstract

A machine learning model, under the influence of observed or unobserved confounders in the training data, can learn spurious correlations and fail to generalize when deployed. For image classifiers, augmenting a training dataset with counterfactual examples has been empirically shown to break spurious correlations. However, the counterfactual generation task itself becomes more difficult as the level of confounding increases. Existing methods for counterfactual generation under confounding consider a fixed set of interventions (e.g., texture, rotation) and are not flexible enough to capture diverse data-generating processes. Given a causal generative process, we formally characterize the adverse effects of confounding on any downstream task and show that the correlation between generative factors (attributes) can be used to quantitatively measure the confounding between them. To minimize such correlation, we propose a counterfactual generation method that learns to modify the value of any attribute in an image and to generate new images given a set of observed attributes, even when the dataset is highly confounded. These counterfactual images are then used to regularize the downstream classifier such that the learned representations are the same across various generative factors conditioned on the class label. Our method is computationally efficient, simple to implement, and works well for any number of generative factors and confounding variables. Our experimental results on both synthetic (MNIST variants) and real-world (CelebA) datasets show the usefulness of our approach.

1. INTRODUCTION

A confounder is a variable that causally influences two or more variables that are not necessarily directly causally dependent (Pearl, 2001). Often, the presence of confounders in a data-generating process is the reason for spurious correlations among variables in the observational data. The bias caused by such confounders is inevitable in observational data, making it challenging to identify invariant features representative of a target variable (Rothenhäusler et al., 2021; Meinshausen & Bühlmann, 2015; Wang et al., 2022). For example, the demographic area an individual resides in often confounds that individual's race and, perhaps, the level of education they receive. If the goal is to predict an individual's salary from such observational data, a machine learning model may exploit the spurious correlation between education and race, even though those two variables should ideally be treated as independent. Removing the effects of confounding in trained machine learning models has been shown to be helpful in various applications, such as zero- or few-shot learning, disentanglement, domain generalization, counterfactual generation, algorithmic fairness, and healthcare (Suter et al., 2019; Kilbertus et al., 2020; Atzmon et al., 2020; Zhao et al., 2020; Yue et al., 2021; Sauer & Geiger, 2021; Goel et al., 2021; Dash et al., 2022; Reddy et al., 2022; Dinga et al., 2020). In observational data, confounding may be observed or unobserved and can pose various challenges in learning models depending on the task. For example, disentangling spuriously correlated features using generative modeling is challenging when there are confounders (Sauer & Geiger, 2021; Reddy et al., 2022; Funke et al., 2022). Similarly, as noted above, a classifier may rely on non-causal features to make predictions in the presence of confounders (Schölkopf et al., 2021).
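
The definition of a confounder given above can be made concrete with a small simulation. The following is a toy sketch, not the method proposed in this paper: a hidden binary confounder drives two binary attributes that never influence each other directly, and the Pearson correlation between the attributes is measured. All names (`sample`, `strength`, `attribute_correlation`, `correlation_given_confounder`) and the specific sampling scheme are assumptions made for illustration only.

```python
import random


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences (0.0 if either is constant)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        return 0.0
    return cov / (vx * vy) ** 0.5


def sample(strength, n=20000, seed=0):
    """Draw n triples (c, a, b), where c is a hidden binary confounder.

    Hypothetical generative process: each attribute copies c with probability
    `strength` and is otherwise an independent fair coin flip, so a and b
    are related only through c.
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        c = rng.randint(0, 1)
        a = c if rng.random() < strength else rng.randint(0, 1)
        b = c if rng.random() < strength else rng.randint(0, 1)
        rows.append((c, a, b))
    return rows


def attribute_correlation(rows):
    """Correlation between the two attributes, ignoring the confounder."""
    return pearson([a for _, a, _ in rows], [b for _, _, b in rows])


def correlation_given_confounder(rows, value):
    """Correlation between the attributes within the stratum c == value."""
    stratum = [r for r in rows if r[0] == value]
    return attribute_correlation(stratum)


if __name__ == "__main__":
    for strength in (0.0, 0.5, 0.9):
        rows = sample(strength)
        print(
            f"strength={strength:.1f}  "
            f"corr(a, b)={attribute_correlation(rows):+.3f}  "
            f"corr(a, b | c=1)={correlation_given_confounder(rows, 1):+.3f}"
        )
```

In this toy process the marginal correlation between the two attributes grows with `strength` (in expectation it equals `strength ** 2`), while the correlation within a fixed value of the confounder stays near zero. This is the sense in which attribute correlation can serve as a rough proxy for the degree of confounding when the attributes have no direct causal link.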
Recent years have seen a few efforts to handle the spurious correlations caused by confounding effects in observational data (Träuble et al., 2021; Sauer & Geiger, 2021; Goel et al., 2021; Reddy et al., 2022). However, these methods either make strong assumptions about the underlying causal generative process or require strong supervision. In this paper, we study the adverse effect of confounding in observational data

