ACCOUNTING FOR UNOBSERVED CONFOUNDING IN DOMAIN GENERALIZATION

Abstract

The ability to extrapolate, or generalize, from observed to new related environments is central to any form of reliable machine learning, yet most methods fail when moving beyond i.i.d. data. In some cases, the reason lies in a misappreciation of the causal structure that governs the data, and in particular in the influence of unobserved confounders that drive changes in observed distributions and distort correlations. In this paper, we argue for defining generalization with respect to a broader class of distribution shifts (defined as arising from interventions in the underlying causal model), including changes in observed, unobserved, and target variable distributions. We propose a new robust learning principle that may be paired with any gradient-based learning algorithm. This learning principle has explicit generalization guarantees and relates robustness to certain invariances in the causal model, clarifying why, in some cases, test performance lags training performance. We demonstrate the empirical performance of our approach on healthcare data from different modalities, including image and speech data.

1. INTRODUCTION

Prediction algorithms use data, necessarily sampled under specific conditions, to learn correlations that extrapolate to new or related data. If successful, the performance gap between these two domains is small, and we say that algorithms generalize beyond their training data. Doing so is difficult, however, because some form of uncertainty about the distribution of new data is unavoidable. The set of potential distributional changes that we may encounter is mostly unknown and in many cases may be large and varied. Some examples include covariate shifts (Bickel et al., 2009), interventions in the underlying causal system (Pearl, 2009), varying levels of noise (Fuller, 2009) and confounding (Pearl, 1998). All of these feature in modern applications, and while learning systems are increasingly deployed in practice, generalization of predictions and their reliability in a broad sense remains an open question. A common approach to formalize learning with uncertain data is, instead of optimizing for correlations in a fixed distribution, to do so simultaneously for a range of different distributions in an uncertainty set P (Ben-Tal et al., 2009):

minimize_f  sup_{P ∈ P}  E_{(x,y)∼P} [ L(f(x), y) ],    (1)

for some measure of error L of the function f that relates input and output examples (x, y) ∼ P. Choosing different sets P leads to estimators with different properties. This formulation includes as special cases, for instance, many approaches in domain adaptation, covariate shift, robust statistics and optimization (Kuhn et al., 2019; Bickel et al., 2009; Duchi et al., 2016; 2019; Sinha et al., 2017; Wozabal, 2012; Abadeh et al., 2015; Duchi & Namkoong, 2018). Robust solutions to problem (1) are said to generalize if potentially shifted test distributions are contained in P; however, larger sets P generally yield conservative solutions, i.e. sub-optimal performance on data sampled from distributions far from worst-case scenarios.
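When P is approximated by a finite collection of training environments, a common heuristic for problem (1) is to take each gradient step on the environment with the current worst loss. The sketch below illustrates this minimax idea on toy linear-regression data; the two-environment setup, model, and step size are illustrative assumptions, not the method proposed in this paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two toy "environments" sharing the same underlying relation y = x . w_true,
# but with different covariate scales (a simple covariate shift).
w_true = np.array([1.0, -2.0])
envs = []
for scale in (1.0, 3.0):
    X = rng.normal(0.0, scale, size=(200, 2))
    y = X @ w_true + rng.normal(0.0, 0.1, size=200)
    envs.append((X, y))

def env_loss_and_grad(w, X, y):
    """Mean squared error on one environment and its gradient in w."""
    r = X @ w - y
    return (r ** 2).mean(), 2.0 * X.T @ r / len(y)

w = np.zeros(2)
for _ in range(500):
    # Approximate sup over P by the worst environment in the finite set,
    # then descend on that environment's loss.
    stats = [env_loss_and_grad(w, X, y) for X, y in envs]
    worst = max(range(len(envs)), key=lambda i: stats[i][0])
    w -= 0.01 * stats[worst][1]

# Worst-case loss after training; should approach the noise floor.
worst_loss = max(env_loss_and_grad(w, X, y)[0] for X, y in envs)
```

Because the sup is taken over only the environments observed during training, this procedure generalizes exactly when shifted test distributions resemble some member of that finite set, mirroring the trade-off described above.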
One formulation of causality is in fact a version of this problem, with P defined as the set of distributions arising from arbitrary interventions on the observed covariates x, leading to shifts in their distribution P_x (see, e.g., Sections 3.2 and 3.3 in Meinshausen (2018)). The invariance of causal solutions to changes in covariate distributions is powerful for generalization, but it implicitly assumes that all

