DOMAIN GENERALIZATION VIA HECKMAN-TYPE SELECTION MODELS

Abstract

The domain generalization (DG) setup considers the problem where models are trained on data sampled from multiple domains and evaluated on test domains unseen during training. In this paper, we formulate DG as a sample selection problem in which each domain is sampled from a common underlying population through non-random sampling probabilities that correlate with both the features and the outcome. Under this setting, the fundamental iid assumption of empirical risk minimization (ERM) is violated, so ERM often performs worse on test domains whose non-random sampling probabilities differ from those of the training domains. We propose a Selection-Guided DG (SGDG) framework to learn the selection probability of each domain and the joint distribution of the outcome and domain selection variables. The proposed SGDG is domain generalizable because it aims to minimize the risk under the population distribution. We theoretically prove that, under certain regularity conditions, SGDG can achieve smaller risk than ERM. Furthermore, we present a class of parametric SGDG (HeckmanDG) estimators applicable to continuous, binary, and multinomial outcomes. We also demonstrate its efficacy empirically through simulations and experiments on a set of benchmark datasets, comparing against other well-known DG methods.

1. INTRODUCTION

In statistical learning theory, the standard assumption behind many supervised learning algorithms is that training and test instances are independently and identically distributed (iid) according to the same underlying data distribution (Vapnik, 1991). In other words, most statistical models assume that the training and test data are both random samples from the same population. Unfortunately, this assumption is often violated in real-world applications, causing model performance to deteriorate on out-of-distribution (OOD) test data (Koh et al., 2021). Recently, the Domain Generalization (DG) problem (Blanchard et al., 2011) has gained particular attention: it is assumed that learning systems have access to training data sampled from multiple domains, and the ultimate goal is to extrapolate to new instances sampled from previously unseen test domains. In this paper, we consider DG as a non-random sample selection problem. Let P_XY represent the population data distribution, and let S^k denote a binary random variable indicating whether a subject is selected from the population into domain k. Under a random sampling process, the selection indicators S^k_i are iid across instances and independent of (X, Y). Under non-random sample selection, the distribution of (X, Y) in domain k is the conditional distribution of (X, Y) given S^k = 1, which in general does not equal P_XY. Consequently, this leads to distributional shifts across domains: P^j_XY ≠ P^k_XY for k ≠ j. Mathematically, distribution shifts across domains P^k_XY may contain shifts in the marginal distribution of X (P^k_X, covariate shift (Bickel et al., 2009)) and in the distribution of Y conditional on X (P^k_{Y|X}, concept shift (Moreno-Torres et al., 2012)). We present a graphical model in Figure 1 to conceptually illustrate the sources of distribution shifts, assuming the existence of latent factors confounding the relationship between X, Y, and the domain (S^k).
In Figure 1, C1 represents unobserved latent factors that correlate with X and S^k, resulting in covariate shift. C2 correlates with X, Y, and S^k simultaneously, entailing both covariate and concept shifts. The goal of DG is to estimate the domain-generalizable (domain-agnostic) edge f : X → Y in the presence of these two types of latent confounders. The vast majority of DG methods are developed to identify an f that is robust to C1. In practice, however, C2-type confounders often exist, making P(S^k = 1) depend on both X and Y. For example, when we train a model to predict tumor status (Y) from histological images (X) using patients from different hospitals (S^k), there may be variation in X due to inconsistent acquisition processes such as staining differences (C1) across hospitals, and differences in patient characteristics such as age, gender, race, and disease severity (C2) that correlate with the hospital, the covariates, and the outcome. As a result, a model trained at an oncology specialist hospital may not generalize to a hospital serving veterans. Similarly, when we train a model to predict a wealth index (Y) from satellite images (X) taken in different countries (S^k), latent factors such as economic status (C2) may correlate with X, Y, and the domains simultaneously. Therefore, a model trained on one country may not perform well in another country with a different rural/urban proportion or economic status (Koh et al., 2021).
In this paper, we propose a new class of Selection-Guided Domain Generalization (SGDG) models that first estimate the selection probability of an instance being sampled into a training domain, and then use the joint distribution of the outcome Y and the selection S to learn a domain-generalizable model. In particular, SGDG builds on Heckman's bias correction framework (Heckman, 1979), a powerful tool for learning an unbiased model from non-randomly selected samples in the presence of both C1 and C2 confounders. The unique contributions of this paper are summarized as follows:
• To the best of our knowledge, this is the first paper to formulate the DG problem using a non-random sample selection framework, and to propose a Selection-Guided Domain Generalization (SGDG) method under this framework.
• We present a class of parametric SGDG (HeckmanDG) estimators applicable to continuous, binary, and multinomial outcomes†.
• We demonstrate the efficacy of our method both theoretically and empirically on simulated data and four challenging benchmarks.

2. RELATED WORK

Domain Generalization. DG has been studied in various contexts. Many studies are devoted to learning domain-invariant features that are discriminative yet independent of the domain, such as kernel-based methods (Muandet et al., 2013), moment matching (Sun & Saenko, 2016), adversarial learning (Ganin et al., 2016; Deng et al., 2020), entropy regularization (Zhao et al., 2020), and contrastive learning (Motiian et al., 2017; Kim et al., 2021). Other works exploit invariant causal effects across domains (Arjovsky et al., 2019; Ahuja et al., 2020; Rosenfeld et al., 2021). Another family of robust optimization methods seeks to minimize the worst-case error (Sagawa et al., 2020; Xie et al., 2020; Krueger et al., 2021). More recently, other prominent directions improve DG via model averaging (Cha et al., 2021; Arpit et al., 2022), gradient matching (Shi et al., 2022), meta learning (Li et al., 2018), data augmentation (Robey et al., 2021), and generating novel domains (Zhou et al., 2020).

Sample Selection Bias Correction. Zadrozny (2004) formalized sample selection bias in machine learning terms and presented a bias correction method for the case where selection depends only on the input features. Cortes et al. (2008) proposed a sample reweighting approach to the same problem but assumed the availability of additional data drawn from the true population. Du & Wu (2021) proposed a framework for robust and fair learning under biased sample selection, but assume conditional independence of Y and S given X.

3. SELECTION MODEL-GUIDED DOMAIN GENERALIZATION

3.1 DOMAIN GENERALIZATION

Suppose there exist L distinct but related domains, and let S = (S^1, ..., S^L) ~ P_S denote a binary random vector indicating domain membership, where S^k = 1 implies belonging to domain k. Let P_XY represent the population data distribution, and let each domain's data distribution be the conditional distribution of the population distribution given S^k = 1, i.e., P^k_XY = P_{XY | S^k = 1}.

Assumption 1 (Mutually Exclusive Domain Membership). We assume that if S^k = 1, then S^j = 0 for all j ≠ k, so that an instance belongs to one and only one domain.

Assumption 2 (Independent Domain Sampling Processes). We assume that S^k ⊥⊥ S^j for all j ≠ k.

In the supervised domain generalization context, we observe the joint distribution of X and Y, P^k_XY, for K out of the L domains, and refer to these K domains as source (training) domains observed during the training phase. The remaining L - K domains are referred to as target (testing) domains, which we may only encounter in the testing phase. Under this setting, we aim to learn a prediction model that performs well on both source and target domains. We formalize the problem as follows.

Definition 1 (Domain Generalization (Blanchard et al., 2011)). Domain generalization refers to the problem of learning f : X → Y with the minimum expected loss across all possible domains, summarized as the optimization problem

\min_{f \in \mathcal{F}} \sum_{k=1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\ell(f(X), Y)\right] P(S^k = 1), \quad (1)

where ℓ : Y × Y → R_+ (= {r ∈ R : r ≥ 0}) is a loss function and F is a hypothesis set. We note that some other papers have considered the same problem setting (Muandet et al., 2013; Deshmukh et al., 2019; Blanchard et al., 2021). We first introduce a proposition stating the equivalence of domain generalization and risk minimization under the population distribution:

Proposition 1 (Equivalence of Domain Generalization and Population Risk Minimization).
Problem (1) is equivalent to risk minimization under the population distribution P_XY. That is,

\min_{f \in \mathcal{F}} \sum_{k=1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\ell(f(X), Y)\right] P(S^k = 1) = \min_{f \in \mathcal{F}} \mathbb{E}_{(X,Y) \sim P_{XY}}\left[\ell(f(X), Y)\right], \quad (2)

which follows directly from the law of total expectation. We define f_PRM as the minimizer of the population risk minimization problem, which is also the best (hypothetical) model for the domain generalization problem.

Definition 2 (Source Domain Risk Minimization (SDRM)). In practice, only the K source domains are available, which leads to minimizing the risk over the source domains:

\min_{f \in \mathcal{F}} \sum_{k=1}^{K} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\ell(f(X), Y)\right] P(S^k = 1). \quad (3)

We denote its minimizer as f_SDRM; the empirical form of Problem (3) is empirical risk minimization (ERM). The generalization performance of f_SDRM depends on how well the source domains represent the population (or the target domains). If the source domains approximate the population distribution well, the generalization performance of f_SDRM will be sufficiently close to that of f_PRM. On the other hand, if P^k_XY ≠ P_XY, models trained on the source domains are not guaranteed to generalize effectively to the population. It is therefore necessary to model the selection probability to bridge the gap between P^k_XY and P_XY.
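Proposition 1 is an instance of the law of total expectation, which can be checked numerically on a finite empirical population. The sketch below is illustrative only: the hypothesis, loss, and data are arbitrary stand-ins, not the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy population of N instances partitioned into L = 3 domains.
N, L = 30_000, 3
x = rng.normal(size=N)
y = 2.0 * x + rng.normal(size=N)
domain = rng.integers(0, L, size=N)          # S^k = 1 iff domain == k

def f_hat(x):
    return 2.0 * x                           # an arbitrary fixed hypothesis f

def risk(x, y):
    return np.mean((f_hat(x) - y) ** 2)      # squared-error loss

# Population risk: E_{P_XY}[l(f(X), Y)]
pop_risk = risk(x, y)

# Selection-weighted per-domain risks: sum_k E_{P^k_XY}[l(f(X), Y)] P(S^k = 1)
dg_risk = sum(
    risk(x[domain == k], y[domain == k]) * np.mean(domain == k)
    for k in range(L)
)

assert abs(pop_risk - dg_risk) < 1e-8        # identical by total expectation
```

Because each domain's weight P(S^k = 1) is its empirical proportion, the weighted per-domain risks recombine exactly into the population risk, up to floating-point error.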

3.2. SELECTION MODEL-GUIDED DOMAIN GENERALIZATION

Now we derive our selection model-guided domain generalization problem, starting by decomposing the objective of the domain generalization problem via a domain selection model g = {g_k : X → [0, 1], k = 1, ..., L}, which predicts the probability of an instance being observed in domain k.

Assumption 3 (Decomposable Loss Function). We assume that the loss function ℓ can be decomposed into two components, one exclusively about the selection model and the other involving both:

\ell(f(X), Y) = \sum_{k=1}^{L} \mathbb{I}(S^k = 1)\, \Lambda(f(X), g(X); Y, S^k = 1) + \sum_{k=1}^{L} \mathbb{I}(S^k = 1)\, \ell_s(g(X), S^k = 1), \quad (4)

where ℓ_s : [0, 1]^L × {0, 1}^L → R_+ is the loss function for learning the domain selection models, and Λ : (Y × [0, 1]^L) × (Y × {0, 1}^L) → R_+ is the joint loss function, which assigns to the prediction and domain selection models a pair of true outcome Y and domain membership indicator S. For example, the negative log-likelihood under the probabilistic framework satisfies Assumption 3: if we let ℓ(f(X), Y) = -log p(Y|X), then Λ = -log p(Y|X, S^k = 1) and ℓ_s = -log p(S^k = 1|X), respectively. We provide the full derivation in Appendix A.1 and present a specific parametric form in Section 4. Under Assumption 3, the population risk can be expanded as follows:

\mathbb{E}_{(X,Y) \sim P_{XY}}[\ell(f(X), Y)]
= \sum_{k=1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\Lambda(f(X), g(X); Y, S^k = 1) + \ell_s(g(X), S^k = 1)\right] P(S^k = 1)
= \sum_{k=1}^{K} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\Lambda(f(X), g(X); Y, S^k = 1)\right] P(S^k = 1)
+ \sum_{k=1}^{L} \mathbb{E}_{X \sim P^k_{X}}\left[\ell_s(g(X), S^k = 1)\right] P(S^k = 1)
+ \sum_{k=K+1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\Lambda(f(X), g(X); Y, S^k = 1)\right] P(S^k = 1).

Based on this expansion, we introduce our selection model-guided domain generalization problem.

Definition 3 (Selection Model-Guided Domain Generalization (SGDG)). Suppose we have access to the data distributions of the K source domains, P^k_XY for k = 1, ..., K, and to the unlabeled data distributions of all L domains, P_{X|S^k=1} for k = 1, ..., L.
We define the selection model-guided domain generalization problem as the joint learning problem of f : X → Y and g = {g_k}_{k=1}^{L}:

\min_{f \in \mathcal{F},\, g \in \mathcal{G}} \sum_{k=1}^{K} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\Lambda(f(X), g(X); Y, S^k = 1)\right] P(S^k = 1) + \sum_{k=1}^{L} \mathbb{E}_{X \sim P^k_{X}}\left[\ell_s(g(X), S^k = 1)\right] P(S^k = 1), \quad (5)

where F and G are hypothesis sets. The SGDG problem minimizes the expected loss of the prediction model f under the joint distribution of X and Y on the source domains, plus the expected loss of the domain selection model g under the joint distribution of X and S. Under this formulation, the domain selection model g guides the correction of f through Λ, accounting for the probability of instances being drawn into particular domains. In Section 4, we introduce specific forms of this problem under parametric assumptions.

Theorem 1 (Performance Improvement of SGDG over SDRM). Let f_SDRM and f_SGDG be defined as in Definitions 2 and 3. Then f_SGDG has risk no larger than that of f_SDRM:

\mathbb{E}_{(X,Y) \sim P_{XY}}[\ell(f_{SGDG}(X), Y)] \leq \mathbb{E}_{(X,Y) \sim P_{XY}}[\ell(f_{SDRM}(X), Y)],

which implies that SGDG is expected to show better generalization performance than SDRM; by Proposition 1, SGDG is equivalently expected to perform better than SDRM on the domain generalization problem. The proof sketch is that SGDG performs as well as SDRM on the source domains and offers improvement on the target domains. The full proof is given in Appendix A.2. In reality, we may have no access to the distributions P^k_X for all domains (k = 1, ..., L) but only to the source domains' distributions (k = 1, ..., K). In such cases, we use P^k_X for k = 1, ..., K to learn the selection model g as in Problem (7).

Assumption 4. Let g* be the optimal selection model with minimum expected ℓ_s,

g^* = (g^*_1, \dots, g^*_L) = \arg\min_{g} \sum_{k=1}^{L} \mathbb{E}_{X \sim P^k_{X}}[\ell_s(g(X), S^k = 1)]\, P(S^k = 1).

We assume that its first K coordinates (g*_1, ..., g*_K) also minimize \sum_{k=1}^{K} \mathbb{E}_{X \sim P^k_{X}}[\ell_s(g(X), S^k = 1)]\, P(S^k = 1).
Conceptually, Assumption 4 means that the same selection models g_k are learned whether domain k is contrasted against the remaining domains in the training data (K \ {k}) or against the remaining domains in the population (L \ {k}). Under Assumption 4, Problem (5) reduces to

\min_{f \in \mathcal{F},\, g \in \mathcal{G}} \sum_{k=1}^{K} \mathbb{E}_{(X,Y) \sim P^k_{XY}}\left[\Lambda(f(X), g(X); Y, S^k = 1)\right] P(S^k = 1) + \sum_{k=1}^{K} \mathbb{E}_{X \sim P^k_{X}}\left[\ell_s(g(X), S^k = 1)\right] P(S^k = 1). \quad (7)

4. HECKMAN-TYPE SELECTION-GUIDED DOMAIN GENERALIZATION

In this section, we present parametric models for Λ that instantiate the f_SGDG presented in the previous section. The essence of Λ is to model the conditional distribution P(Y | X, S^k = 1). Consider the setting where we are given training data from multiple source domains in the form D = {x_i, s_i, y_i}_{i=1}^{N}, where s_i = [s_{i1}, s_{i2}, ..., s_{iK}] ∈ {0, 1}^K is a binary vector indicating domain membership. Formulating the loss function as the negative log-likelihood, the empirical form of Problem (7) becomes

\min\; -\sum_{i=1}^{N} \sum_{k=1}^{K} \left[ s_{ik} \log p(y_i | x_i, s_{ik} = 1) + s_{ik} \log p(s_{ik} = 1 | x_i) + (1 - s_{ik}) \log p(s_{ik} = 0 | x_i) \right]. \quad (8)

Heckman (1979) proposed to model the joint distribution P(Y, S | X) of the continuous outcome Y and the selection variable S, where S = I[S̃ > 0] for a latent variable S̃, via a bivariate normal distribution with a correlation coefficient ρ and means that are linear functions of X. Building upon his work, we make the following assumption on the joint distribution of the outcome and the selection variables.

Assumption 5 (Joint Distribution of Y (or Latent Variables Ỹ) and S̃). Let Y be the outcome variable and S̃_k the latent continuous variable with S^k = I[S̃_k > 0] for each domain k. We assume that, given X, Y and S̃ = (S̃_1, ..., S̃_K) are jointly distributed as a multivariate normal distribution with mean (f(X), g_1(X), ..., g_K(X)) and correlation coefficients {ρ_k}_{k=1}^{K}. For binary and multinomial outcomes, this assumption is placed on Ỹ, the latent continuous variable underlying the observed outcome. Under this assumption, g = {g_k}_{k=1}^{K} can be modeled as a set of independent probit models. In all three cases, the joint log-likelihoods have closed forms. Henceforth, we refer to the specific parametric form of f_SGDG as the Heckman-type DG (HeckmanDG) estimator.

Definition 4 (Heckman-Type Domain Generalization Estimator).
We formulate HeckmanDG as a joint learning problem of f and g = {g_k}_{k=1}^{K} with the following learning objective:

\min_{f, g, \Sigma} \sum_{i=1}^{N} \sum_{k=1}^{K} \left[ s_{ik}\, \Lambda(f(x_i), g_k(x_i); y_i, s_{ik}; \Sigma) - s_{ik} \log \Phi(g_k(x_i)) - (1 - s_{ik}) \log \Phi(-g_k(x_i)) \right], \quad (9)

where Φ(·) is the cumulative distribution function of the standard normal distribution with density φ, such that Φ(g_k(x_i)) = P(S^k = 1 | X = x_i) is the selection probability w.r.t. domain k and Φ(-g_k(x_i)) = P(S^k = 0 | X = x_i). Meanwhile, Λ(f(x_i), g_k(x_i); y_i, s_{ik}; Σ) is the conditional negative log-probability of y_i given s_{ik} = 1, i.e., -log p(y_i | s_{ik} = 1, x_i). The specific form of Λ and the auxiliary model parameters Σ depend on the prediction task (continuous, binary, or multinomial outcome prediction). For example, for a continuous outcome, where we assume Y = f(X) + ε, S^k = I[g_k(X) + η_k > 0], and (η_k, ε) ~ N(0, [[1, ρ_k σ], [ρ_k σ, σ²]]), objective (9) takes the form

\min_{f, g, \Sigma} -\sum_{i=1}^{N} \sum_{k=1}^{K} \left[ s_{ik} \left( \log \Phi\!\left( \frac{g_k(x_i) + \rho_k \frac{y_i - f(x_i)}{\sigma}}{\sqrt{1 - \rho_k^2}} \right) + \log \frac{\phi\!\left( \frac{y_i - f(x_i)}{\sigma} \right)}{\sigma} \right) + (1 - s_{ik}) \log \Phi(-g_k(x_i)) \right]. \quad (10)

The details of its derivation, and the cases of other outcome types, can be found in Appendix A.3. Ultimately, f is the outcome prediction model of primary interest: for any input x from an unseen domain, the HeckmanDG prediction is f(x). The proposed HeckmanDG differs from Heckman's original bias correction method in the following respects. First, it allows and models multiple domain-specific sample selection mechanisms. Second, by utilizing the multiple domains in the training data, HeckmanDG does not need auxiliary data (features of instances from the target population that were not sampled into the domain), which is a required input of the original Heckman model. Third, HeckmanDG relaxes the linearity assumptions of Heckman's bias correction model, allowing flexible forms for both the prediction and selection functions, including neural networks.

Figure 2: Overview of the architecture. Step 1) Learn the domain selection model (with its feature extractor). Step 2) The estimated domain selection model ĝ guides the outcome prediction model f = ω_f ∘ φ_f to be corrected by considering the selection probability of instances being drawn from the training domains.
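Objective (10) can be evaluated directly once the model outputs f(x_i) and g_k(x_i) are available. The following numpy sketch is illustrative only (a stand-in for the paper's implementation, with arrays of precomputed model outputs rather than neural networks); it computes the per-domain objective for one domain k.

```python
import numpy as np
from scipy.stats import norm

def heckman_continuous_loss(f_x, g_x, y, s, rho, sigma):
    """Negative log-likelihood of Equation (10) for a single domain k.

    f_x   : (N,) outcome-model outputs f(x_i)
    g_x   : (N,) selection-model outputs g_k(x_i)
    y     : (N,) observed continuous outcomes
    s     : (N,) binary domain-membership indicators s_ik
    rho   : correlation between selection noise eta_k and outcome noise eps
    sigma : standard deviation of eps
    """
    z = (y - f_x) / sigma
    # Selected instances: log Phi((g_k + rho*z)/sqrt(1 - rho^2)) + log(phi(z)/sigma)
    sel = (norm.logcdf((g_x + rho * z) / np.sqrt(1.0 - rho ** 2))
           + norm.logpdf(z) - np.log(sigma))
    # Non-selected instances: log Phi(-g_k)
    non = norm.logcdf(-g_x)
    return -np.sum(s * sel + (1.0 - s) * non)
```

With ρ = 0, the selection and outcome terms decouple and the objective reduces to an ordinary Gaussian regression likelihood plus an independent probit likelihood, which is a convenient sanity check.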

5. OPTIMIZATION

In this section, we take the hypothesis sets F and G to be neural networks and present an efficient algorithm for optimizing Problem (9). Specifically, we propose a two-step approach that first trains g to optimality (Step 1) and then updates the remaining parameters of f and Σ (Step 2). In Step 1, we learn the selection model parameters with the objective

\hat{g} = \arg\min_{g} -\sum_{i=1}^{N} \sum_{k=1}^{K} \left[ s_{ik} \log \Phi(g_k(x_i)) + (1 - s_{ik}) \log \Phi(-g_k(x_i)) \right], \quad (11)

which essentially learns to predict the source-domain memberships of the training instances. In Step 2, we freeze the selection model ĝ and learn the remaining parameters of f and Σ. The proposed optimization algorithm is in part motivated by Heckman's two-step estimator for bias correction, devised to avoid the non-linearity of estimating both the selection and outcome equations simultaneously (Heckman, 1979). We provide pseudocode for the overall optimization procedure in Algorithm 1 in Appendix A.4, and an ablation study in Section 7 supporting the necessity of the two-step optimization.

Neural Network Architecture. We use a common feature extractor φ_g : X → Z_g for the domain selection models g = {g_k}_{k=1}^{K}, which differ only in their final linear predictors ω_{g,k} : Z_g → [0, 1]; thus g_k = ω_{g,k} ∘ φ_g. This prevents the number of parameters of g from growing with the number of training domains K, and reduces computational cost by requiring only a single forward pass through φ_g(·). Similarly, we define f = ω_f ∘ φ_f, where φ_f : X → Z_f is the feature extractor of the outcome prediction model and ω_f : Z_f → Y is its final linear predictor. In general, φ_g and φ_f may have different neural architectures with different numbers of parameters; however, we found that simply using the same architecture (with separately learned parameters) works well in practice.
Therefore, in all our experiments, HeckmanDG has roughly twice as many trainable parameters as the comparison methods. An overview of our neural network architecture is provided in Figure 2.
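As a minimal illustration of Step 1 (Equation 11), the sketch below fits a single linear probit selection head g_k(x) = wx + b by gradient descent on a toy one-feature, single-domain sample. The data, learning rate, and step count are illustrative assumptions, not the paper's setup (which uses shared neural feature extractors).

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(1)

# Toy Step-1 data: membership in domain k is predictable from one feature.
n = 500
x = np.concatenate([rng.normal(+1.0, 1.0, n), rng.normal(-1.0, 1.0, n)])
s = np.concatenate([np.ones(n), np.zeros(n)])        # s_ik for one domain k

w, b, lr = 0.0, 0.0, 0.1
for _ in range(200):
    a = w * x + b                                    # g_k(x_i)
    # Gradient of -[s log Phi(a) + (1 - s) log Phi(-a)] w.r.t. a,
    # computed in log space for numerical stability.
    grad_a = -(s * np.exp(norm.logpdf(a) - norm.logcdf(a))
               - (1 - s) * np.exp(norm.logpdf(a) - norm.logcdf(-a)))
    w -= lr * np.mean(grad_a * x)
    b -= lr * np.mean(grad_a)

p_hat = norm.cdf(w * x + b)                          # estimated P(S^k = 1 | x)
accuracy = np.mean((p_hat > 0.5) == (s == 1))        # domain-membership accuracy
```

The probit link Φ is what ties Step 1 to Assumption 5: the same latent-normal selection mechanism reappears inside the joint likelihood in Step 2.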

6. EXPERIMENTS

Simulation. We simulate a linear regression problem to assess HeckmanDG's predictive performance. We simulate two covariates (X_1, X_2) and two training domains (K = 2) with the following domain selection and outcome mechanisms:

S^k = \mathbb{I}[\alpha^k_0 + \alpha^k_2 X_2 + \eta_k > 0], \quad Y = 1 + 1.5 X_1 + 3 X_2 + \varepsilon,
(X_1, X_2) \sim N(0, I_2), \quad (\eta_k, \varepsilon) \sim N\!\left( \begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1 & \rho_k \sigma \\ \rho_k \sigma & \sigma^2 \end{bmatrix} \right), \quad (12)

where we place a normal prior on the selection coefficient, α^k_2 ~ N(μ_{α2}, σ²_{α2}), to implicitly control the similarity between domains by varying its parameters. We consider a true population of 100,000 instances, from which we sample data for each domain with n_k ~ Uniform(1000, 2000). In each trial, we sample the selection coefficient α^k_2 ~ N(5, 3²) for the two training domains and a held-out in-distribution (ID) test set (similar to the training domains), and α^{k'}_2 ~ N(-5, 3²) for another held-out out-of-distribution (OOD) test set (dissimilar to the training domains). We also simulate a random test set from the population. We set ρ_k = 0.8 and σ = 1.

Simulation Results. We observed that HeckmanDG outperforms not only ERM but also other DG methods including IRM (Arjovsky et al., 2019), GroupDRO (Sagawa et al., 2020), and VREx (Krueger et al., 2021) (Table 1). Note that 'ERM (oracle)' is trained on iid training data, which serves as a theoretical lower bound on model performance. We highlight that HeckmanDG performs well not only on the random test set but also on the OOD test set. In contrast, the other methods tend to fit the training domains well (high ID performance) but generalize poorly to the random and OOD test sets.

Benchmark Datasets. To further demonstrate the effectiveness of HeckmanDG in high-dimensional data regimes, we conducted experiments on four datasets from the WILDS benchmark (Koh et al., 2021): 1) CAMELYON17, 2) POVERTYMAP, 3) IWILDCAM, and 4) RXRX1.
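The data-generating process of Equation 12 is straightforward to reproduce. The sketch below is illustrative (it uses a smaller population than the paper's 100,000 for speed, and fixed selection coefficients rather than the normal prior): it simulates the population and keeps the instances with S^k = 1 for one domain.

```python
import numpy as np

rng = np.random.default_rng(7)

def simulate_domain(n_pop, alpha0, alpha2, rho=0.8, sigma=1.0):
    """Simulate one domain under Equation (12): draw a population of
    n_pop instances, then keep those selected into the domain (S^k = 1)."""
    x = rng.normal(size=(n_pop, 2))                      # (X1, X2) ~ N(0, I_2)
    cov = [[1.0, rho * sigma], [rho * sigma, sigma ** 2]]
    eta, eps = rng.multivariate_normal([0.0, 0.0], cov, size=n_pop).T
    y = 1.0 + 1.5 * x[:, 0] + 3.0 * x[:, 1] + eps        # outcome mechanism
    selected = alpha0 + alpha2 * x[:, 1] + eta > 0       # selection mechanism
    return x[selected], y[selected]

# alpha2 > 0 mimics a training (ID-like) domain, alpha2 < 0 an OOD-like one.
x_id, y_id = simulate_domain(10_000, alpha0=0.0, alpha2=5.0)
x_ood, y_ood = simulate_domain(10_000, alpha0=0.0, alpha2=-5.0)

# Non-random selection skews the covariate distribution: X2 is shifted
# positive in the ID-like domain and negative in the OOD-like one, and
# rho > 0 additionally couples the selection and outcome noise terms.
```

Because ρ > 0 correlates η_k with ε, the selected samples exhibit concept shift as well as covariate shift, which is exactly the regime where ERM is biased and the Heckman correction applies.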
We used the same neural network architecture for the selection and outcome feature extractors φ_g and φ_f, followed by the linear predictors ω_g and ω_f. For the domain selection model, we tuned hyperparameters to obtain the model with the highest macro-F1 score on the training data, which in all cases achieves near-perfect accuracy. For the outcome prediction model, we adhered to the official guidelines and used the OOD validation set provided in the WILDS repository for hyperparameter tuning and model selection based on the recommended metrics. Detailed descriptions of dataset statistics are presented in

Benchmark Datasets Results

We summarize the results on the WILDS benchmark in Tables 2, 3, and 4 (values over replicates are given in parentheses; we use the original 5 folds provided in the WILDS repository). All methods use the same network architectures following the WILDS guidelines for fair comparison. We exclude methods that deviate from the DG setting, such as those centered on test-time adaptation or using additional unlabeled data, and we do not report worst-group performance for 'Fish' since it is not reported in Shi et al. (2022). We observed that HeckmanDG outperforms the other methods on two of the four datasets, while performing on par with them on the remaining two. We make the following key observations. First, HeckmanDG often performs robustly on the test domain even when it does not achieve the best performance on the validation domains. For example, on the CAMELYON17 dataset, the gap between validation and test performance for HeckmanDG is 3.3, substantially smaller than the 14.1 of 'ERM (scratch)' and the 7.1 of 'ERM (ImageNet)'. This pattern mirrors what we observed in the simulation studies: because HeckmanDG is designed to predict for the underlying population, it may not produce the best prediction performance on the (non-randomly selected) source domains. Although the specific selection probability of the test domain is unknown, HeckmanDG is often more robust on the testing domains because it optimizes for the population distribution. Second, HeckmanDG works effectively for both the domain shift and the subpopulation shift problems, as seen in both the average and worst-group performance on the POVERTYMAP dataset (i.e., 0.51). This again supports the robustness of HeckmanDG across a range of non-random selection probabilities of the testing domains.

7. ANALYSIS

Necessity of two-step optimization. To demonstrate the necessity and effectiveness of the proposed two-step optimization, we performed an ablation study on the POVERTYMAP dataset. For comparison, we trained g and f in one step, simultaneously minimizing Equation 8. One-step training obtained comparable performance on the training domains (Figure 3a) but yielded worse test performance on the OOD domains (Figure 3b). We believe that a suboptimal ĝ may mislead f to a deficient solution under one-step optimization (more details in Appendix A.6).

Relationship between the performance of g and f. To further investigate how the performance of the domain selection model influences the final outcome prediction, we took intermediate snapshots of ĝ and learned f based on each snapshot. On all five data-split folds, we consistently observed a positive correlation between the performances of ĝ and f (Figure 3c), demonstrating that a well-performing selection model is necessary to correctly guide the outcome prediction model.

(a) RXRX1
Method    | Validation | Test
ERM       | 19.4 (0.2) | 29.9 (0.4)
CORAL     | 18.5 (0.4) | 28.4 (0.3)
IRM       | 5.6 (0.4)  | 8.2 (1.1)
GroupDRO  | 15.2 (0.1) | 23.0 (0.3)
LISA      | 20.1 (0.4) | 31.9 (1.0)
Fish      | 7.5 (0.6)  | 10.1 (1.5)
SWAD      | 14.2 (0.5) | 22.9 (0.7)
L2A-OT    | 17.5 (0.3) | 27.8 (0.9)
HeckmanDG |            |

8. CONCLUSIONS

We propose a Selection-Guided Domain Generalization (SGDG) framework, in which we formulate domain generalization as a non-random sample selection problem and jointly learn the prediction model f and the domain selection model g to achieve generalization on the true population. DG is a challenging problem because the particular structure of the distribution shift in the testing domains is unknown. In the presence of this uncertainty, we propose and theoretically justify the objective of minimizing the risk with respect to the population distribution through SGDG. Furthermore, we provide a set of Heckman-type SGDG estimators for various outcome types under parametric assumptions. Although it remains an open question whether a single general-purpose training algorithm can produce models that do well on all DG datasets (Koh et al., 2021), we observed robust performance of the proposed HeckmanDG on four benchmark datasets. Note that SGDG can naturally utilize all domains in the training data, including (outcome-)labeled and unlabeled domains: the unlabeled domains contribute to the estimation of ĝ in Equation 9, which indirectly improves generalization performance, as shown in Section 7. An intriguing direction for future research is to explore whether SGDG performance can be improved by adapting ĝ(X) to domains in the test data; in that way, we may further refine predictions for unseen domains guided by their similarity to the source domains.
A.2 PROOF OF THEOREM 1

The following chain of (in)equalities completes the proof:

\mathbb{E}_{(X,Y) \sim P_{XY}}[\ell(f_{SDRM}(X), Y)]
= \sum_{k=1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}[\ell(f_{SDRM}(X), Y)]\, P(S^k = 1)
= \sum_{k=1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}[\Lambda(f_{SDRM}(X), g_{SGDG}(X); Y, S^k)]\, P(S^k = 1) + \sum_{k=1}^{L} \mathbb{E}_{X \sim P^k_{X}}[\ell_s(g_{SGDG}(X), S^k = 1)]\, P(S^k = 1)
\geq \sum_{k=1}^{K} \mathbb{E}_{(X,Y) \sim P^k_{XY}}[\Lambda(f_{SGDG}(X), g_{SGDG}(X); Y, S^k)]\, P(S^k = 1) + \sum_{k=1}^{L} \mathbb{E}_{X \sim P^k_{X}}[\ell_s(g_{SGDG}(X), S^k = 1)]\, P(S^k = 1) + \sum_{k=K+1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}[\Lambda(f_{SDRM}(X), g_{SGDG}(X); Y, S^k)]\, P(S^k = 1) \quad (13)
\geq \sum_{k=1}^{K} \mathbb{E}_{(X,Y) \sim P^k_{XY}}[\Lambda(f_{SGDG}(X), g_{SGDG}(X); Y, S^k)]\, P(S^k = 1) + \sum_{k=1}^{L} \mathbb{E}_{X \sim P^k_{X}}[\ell_s(g_{SGDG}(X), S^k = 1)]\, P(S^k = 1) + \sum_{k=K+1}^{L} \mathbb{E}_{(X,Y) \sim P^k_{XY}}[\Lambda(f_{SGDG}(X), g_{SGDG}(X); Y, S^k)]\, P(S^k = 1) \quad (14)
= \mathbb{E}_{(X,Y) \sim P_{XY}}[\ell(f_{SGDG}(X), Y)].

Inequality (13) follows from the definition of SGDG. Inequality (14) holds if Y and S are dependent and if we let ℓ(f(X), Y) = -log p(Y|X), Λ = -log p(Y|X, S^k = 1), and ℓ_s = -log p(S^k = 1|X).

A.3 DETAILS OF HECKMAN CORRECTION-GUIDED DOMAIN GENERALIZATION

In this section, we present the exact form of our loss functions for continuous, binary, and multinomial outcomes, based on their specific parametric assumptions.

Continuous Outcomes. For regression tasks where y ∈ R, we assume that the selection and outcome are generated by the following process:

S_k = \mathbb{I}[\tilde{S}_k > 0] = \mathbb{I}[g_k(x) + \eta_k > 0] \quad (15)
Y = f(x) + \varepsilon \quad (16)
(\eta_k, \varepsilon) \sim N\!\left( 0, \begin{bmatrix} 1 & \sigma \rho_k \\ \sigma \rho_k & \sigma^2 \end{bmatrix} \right)

where I(·) ∈ {0, 1} is the indicator function, ρ_k is the correlation between η_k and ε, and σ ∈ R_+ is the standard deviation of ε. Consistent with the notation in the main text, we define Σ = {ρ_1, ..., ρ_K, σ}. Denoting the model parameters as θ = {f, g_1, ..., g_K, Σ}, we formulate the data likelihood as

\mathcal{L}_c(\theta; D) = \prod_{i=1}^{N} \prod_{k=1}^{K} \left[ p(y_i | x_i, s_{ik} = 1)\, p(s_{ik} = 1 | x_i) \right]^{s_{ik}} \left[ p(s_{ik} = 0 | x_i) \right]^{1 - s_{ik}}, \quad (18)

and we take the negative log-likelihood to formulate the loss function:

\ell_c(D; \theta) = -\sum_{i=1}^{N} \sum_{k=1}^{K} s_{ik} \left[ \log \Phi_1\!\left( \frac{g_k(x_i) + \rho_k \frac{y_i - f(x_i)}{\sigma}}{\sqrt{1 - \rho_k^2}} \right) - \frac{1}{2} \log 2\pi\sigma^2 - \frac{1}{2} \left( \frac{y_i - f(x_i)}{\sigma} \right)^2 \right] - \sum_{i=1}^{N} \sum_{k=1}^{K} (1 - s_{ik}) \log \Phi_1(-g_k(x_i)).

Binary Outcomes. For binary outcomes where y ∈ {0, 1}, we assume a probit model for both selection and outcome:

S_k = \mathbb{I}[\tilde{S}_k > 0] = \mathbb{I}[g_k(x) + \eta_k > 0], \quad Y = \mathbb{I}[f(x) + \varepsilon > 0], \quad (\eta_k, \varepsilon) \sim N\!\left( 0, \begin{bmatrix} 1 & \rho_k \\ \rho_k & 1 \end{bmatrix} \right)

where ρ_k is the correlation between η_k and ε; here Σ = {ρ_1, ..., ρ_K}. Denoting the model parameters as θ = {f, g_1, ..., g_K, Σ}, we formulate the data likelihood for binary outcomes as

\mathcal{L}_b(\theta; D) = \prod_{i=1}^{N} \prod_{k=1}^{K} \left[ p(y_i = 1, s_{ik} = 1 | x_i)^{y_i}\, p(y_i = 0, s_{ik} = 1 | x_i)^{1 - y_i} \right]^{s_{ik}} p(s_{ik} = 0 | x_i)^{1 - s_{ik}}, \quad (21)

and take the negative log-likelihood to define the loss function for binary outcomes:

\ell_b(D; \theta) = -\sum_{i=1}^{N} \sum_{k=1}^{K} s_{ik}\, y_i \log \Phi_2(g_k(x_i), f(x_i); \rho_k) - \sum_{i=1}^{N} \sum_{k=1}^{K} s_{ik} (1 - y_i) \log \left[ \Phi_1(g_k(x_i)) - \Phi_2(g_k(x_i), f(x_i); \rho_k) \right] - \sum_{i=1}^{N} \sum_{k=1}^{K} (1 - s_{ik}) \log \Phi_1(-g_k(x_i)),

where Φ_2(·, ·; a) is the cumulative distribution function of the bivariate standard normal with correlation a ∈ [-1, 1].

Multinomial Outcomes. For multinomial outcome tasks where y ∈ {1, ..., J}, we assume a multinomial probit outcome model (McFadden, 1989):

S_k = \mathbb{I}[\tilde{S}_k > 0] = \mathbb{I}[g_k(x) + \eta_k > 0], \quad \tilde{Y}_j = f_j(x) + \varepsilon_j \;\; \forall j \in \{1, \dots, J\}, \quad Y = \arg\max_{j \in \{1, \dots, J\}} \tilde{Y}_j,

where (η_k, ε_j) ~ N(0, [1, ρ_kj; ρ_kj, 1]) and ρ_kj is the correlation between η_k and ε_j. We further assume that the outcome error terms are independently distributed: ε = [ε_1, ..., ε_J]^⊤ ~ N(0, I). Consequently, Σ = {ρ_11, ..., ρ_KJ}. Denoting the model parameters as θ = {f_1, ..., f_J, g_1, ..., g_K, ρ_11, ..., ρ_KJ}, we formulate the multinomial outcome data likelihood as

\mathcal{L}_m(\theta; D) = \prod_{i=1}^{N} \prod_{k=1}^{K} \left[ \prod_{j=1}^{J} p(s_{ik} = 1, y_i = j | x_i)^{\mathbb{I}[y_i = j]} \right]^{s_{ik}} p(s_{ik} = 0 | x_i)^{1 - s_{ik}},

where I[·] ∈ {0, 1} is the indicator function. We use the negative logarithm of the data likelihood as the loss function:

\ell_m(D; \theta) = -\sum_{i=1}^{N} \sum_{k=1}^{K} \left[ \sum_{j=1}^{J} s_{ik}\, \mathbb{I}[y_i = j] \log \int_{T} \phi(u \,|\, \vec{0}, \Sigma)\, du + (1 - s_{ik}) \log \Phi(-g_k(x_i)) \right],

where T = [-g_k(x_i), ∞) × [ξ_{j,1}(x_i), ∞) × ... × [ξ_{j,j-1}(x_i), ∞) × [ξ_{j,j+1}(x_i), ∞) × ... × [ξ_{j,J}(x_i), ∞) ⊆ R^J is a half-open J-dimensional hyperrectangular domain and ξ_{j,j'}(x_i) = -f_j(x_i) + f_{j'}(x_i). We use the GHK algorithm (Hajivassiliou & Ruud, 1994) to compute the multivariate normal integrals.
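The closed-form selected-sample term in the continuous-outcome loss, p(y_i | x_i, s_ik = 1) p(s_ik = 1 | x_i) = (1/σ) φ(z) Φ((g_k + ρ_k z)/√(1 − ρ_k²)) with z = (y_i − f(x_i))/σ, can be sanity-checked against brute-force integration of the joint (ε, η_k) density. The values below are arbitrary toy inputs; this is an illustrative check, not part of the paper's code.

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import multivariate_normal, norm

rho, sigma = 0.6, 1.5
f_x, g_x, y = 0.4, 0.2, 1.1          # toy values for f(x), g_k(x), and y
z = (y - f_x) / sigma

# Closed form used in the continuous-outcome loss:
closed = norm.pdf(z) / sigma * norm.cdf((g_x + rho * z) / np.sqrt(1 - rho ** 2))

# Brute force: integrate the joint density of (eps, eta_k) over eta_k > -g_k(x),
# i.e., over the selection event g_k(x) + eta_k > 0, at eps = y - f(x).
joint = multivariate_normal(mean=[0.0, 0.0],
                            cov=[[sigma ** 2, rho * sigma],
                                 [rho * sigma, 1.0]])
brute, _ = quad(lambda u: joint.pdf([y - f_x, u]), -g_x, np.inf)

assert abs(closed - brute) < 1e-6
```

The identity follows because, conditional on ε, the selection noise η_k is normal with mean ρ_k z and variance 1 − ρ_k², so the selection probability given the observed outcome is exactly the Φ term above.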
We provide further details on the analysis in Section 7, conducted to demonstrate the necessity and effectiveness of the proposed two-step optimization approach. For both one-step (jointly optimizing Equation 8) and two-step optimization, we fixed the number of training epochs at 30 with the same batch size to make a valid comparison. In Figure 4, we plot the training loss trajectories using different metrics. We observed that g tends to converge to a suboptimal point if we optimize f, g, and Σ simultaneously with a one-step approach (Figure 4a). We suspect that this happens because the gradient of g (involving s_ik = 1) depends on Σ, which changes during the joint one-step optimization process. This often results in a suboptimal ĝ, as shown in Figure 4d. Note that the original Heckman correction (Heckman, 1979) also proposed a two-step approach, in that case to avoid the computational burden of estimating both f and g jointly. One way to avoid this problem is an alternating optimization algorithm that alternately minimizes Equation 8 with respect to g, Σ, and f until convergence; our two-step approach can be regarded as this alternating procedure stopped after a single iteration. In Figure 4c, the one-step method yields (MSE) loss values for f on the training dataset comparable to those obtained from the two-step method. However, Equation 8 converged to a larger value under the one-step method, implying that it converged to a suboptimal point for g. On the test datasets, the suboptimal solution of the one-step method resulted in worse domain generalization performance than the two-step method (Figure 3b).
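To illustrate why fixing the selection model first can help, the following toy sketch applies the classic Heckman (1979) two-step correction on synthetic data: step 1 fits a probit selection model, and step 2 augments the outcome regression on the selected samples with the inverse Mills ratio computed from the frozen step-1 estimate. The data-generating coefficients and all variable names are illustrative, and this linear toy stands in for the neural-network setting discussed above.

```python
import numpy as np
from scipy.stats import norm
from scipy.optimize import minimize

rng = np.random.default_rng(1)
n = 20000
x = rng.normal(size=n)

# Correlated errors induce selection bias: eta drives selection, eps the outcome.
rho_true = 0.8
eta = rng.normal(size=n)
eps = rho_true * eta + np.sqrt(1.0 - rho_true ** 2) * rng.normal(size=n)

s = (0.5 + 1.0 * x + eta > 0).astype(float)   # selection: S = I[g(x) + eta > 0]
y = 1.0 + 2.0 * x + eps                        # outcome:   Y = f(x) + eps

# Step 1: probit MLE for the selection model g(x) = a + b * x.
def probit_nll(w):
    z = w[0] + w[1] * x
    return -np.sum(s * norm.logcdf(z) + (1.0 - s) * norm.logcdf(-z))

a_hat, b_hat = minimize(probit_nll, np.zeros(2)).x

# Step 2: OLS on selected samples, augmented with the inverse Mills ratio,
# which captures E[eps | selected] up to the scale rho * sigma.
z_hat = a_hat + b_hat * x
mills = norm.pdf(z_hat) / norm.cdf(z_hat)
sel = s == 1
X_naive = np.column_stack([np.ones(sel.sum()), x[sel]])
X_corr = np.column_stack([X_naive, mills[sel]])
beta_naive = np.linalg.lstsq(X_naive, y[sel], rcond=None)[0]
beta_corr = np.linalg.lstsq(X_corr, y[sel], rcond=None)[0]

# The corrected slope should land closer to the true value of 2.0
# than the naive slope fitted on selected data alone.
print(abs(beta_naive[1] - 2.0), abs(beta_corr[1] - 2.0))
```

Because the Mills-ratio regressor is built from a selection model estimated independently of the outcome model, step 2 sees a stable correction term, which mirrors the stability argument for our two-step schedule over joint one-step optimization.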



† code available: https://github.com/hgkahng/domain-generalization-lightning



Figure 1: A graphical model illustrating the source of distributional shifts. X: covariates, Y: outcome, S^k: domain.

The joint loss function assigns to the prediction and domain selection models a pair of the true outcome Y and the domain membership indicator S.

Figure 2: Schematic overview of HeckmanDG with K training domains. Step 1) We learn g = {ω_{g,k} ∘ φ_g}_{k=1}^K to predict the selection probabilities for each training domain. Step 2) The estimated domain selection model ĝ guides f = ω_f ∘ φ_f to be corrected by considering the selection probability of instances being drawn from the training domains.

Figure 3: Analysis based on the POVERTYMAP dataset. 3a) Pearson correlation measured on the training data for HeckmanDG learned by one-step (gray) and two-step (blue, proposed) optimization. 3b) Comparison between one-step (gray) and two-step (blue, proposed) optimization, where performance is averaged across all five folds provided in the WILDS repository. 3c) Relationship between the performance of domain selection and outcome prediction (colors indicate different folds).

EXAMPLE OF DECOMPOSABLE LOSS

$$\ell(f(X), Y) = -\log p(Y|X)$$
$$-\log p(Y|X, S^k=1)\, p(S^k=1|X) = -\log p(Y|X, S^k=1) - \log p(S^k=1|X)$$
$$= \sum_{k=1}^{K} \Big[ -\log p(Y|X, S^k=1) \underbrace{- \log p(S^k=1|X)}_{\ell_s} \Big].$$

OPTIMIZATION, DATASETS, AND MODEL TRAINING

Algorithm 1 Two-Step Optimization for HeckmanDG
1: Input: Data D = {(x_i, y_i, s_i) ∈ X × Y × S : i = 1, ..., N}, batch size B, learning rate γ.
2: Output: f̂, ĝ = {ĝ_k}_{k=1}^K, and Σ̂.
3: Initialize: f, g, Σ.
4: Step 1: Learn the domain selection models (g)
5: for all k = 1, ..., K do
6:   D_k ← {(x_i, s_{ik}) ∈ X × {0, 1} : i = 1, ..., N}
7:   repeat for each mini-batch B_k ⊂ D_k:
8:     g_k ← g_k + (γ/B) ∇ Σ_{(x_i, s_{ik}) ∈ B_k} [s_{ik} log Φ(g_k(x_i)) + (1 − s_{ik}) log Φ(−g_k(x_i))]
9:   until convergence
10:  ĝ_k ← g_k
11: end for
12: Step 2: Learn the outcome model (f, Σ)
13: repeat for each mini-batch B ⊂ D:
14:   f ← f − (γ/B) ∇ Σ_{(x_i, y_i, s_i) ∈ B} Σ_{k=1}^K s_{ik} Λ(f(x_i), ĝ_k(x_i); y_i, s_{ik}; Σ)
15:   Σ ← Σ − (γ/B) ∇ Σ_{(x_i, y_i, s_i) ∈ B} Σ_{k=1}^K s_{ik} Λ(f(x_i), ĝ_k(x_i); y_i, s_{ik}; Σ)
16: until convergence
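Step 1's per-batch ascent update can be sketched as follows for a linear selection model on synthetic data; the linear parameterization, full-batch updates, and all names are illustrative assumptions made for a self-contained example.

```python
import numpy as np
from scipy.stats import norm

def probit_loglik(w, X, s):
    """Bernoulli-probit log-likelihood: sum s*log Phi(Xw) + (1-s)*log Phi(-Xw)."""
    z = X @ w
    return np.sum(s * norm.logcdf(z) + (1.0 - s) * norm.logcdf(-z))

def step1_update(w, X, s, lr):
    """One gradient-ascent step on the probit log-likelihood (Step 1, line 8)."""
    z = X @ w
    # Per-sample derivative of the log-likelihood with respect to z.
    dz = s * norm.pdf(z) / norm.cdf(z) - (1.0 - s) * norm.pdf(z) / norm.cdf(-z)
    return w + (lr / len(s)) * (X.T @ dz)

rng = np.random.default_rng(0)
X = rng.normal(size=(512, 3))
w_true = np.array([1.0, -2.0, 0.5])
s = (X @ w_true + rng.normal(size=512) > 0).astype(float)

w = np.zeros(3)
for _ in range(200):
    w = step1_update(w, X, s, lr=0.5)
```

After these updates, the learned weights point in the same direction as the data-generating weights, and the log-likelihood is strictly higher than at initialization; in our actual framework the analogous update is applied per mini-batch to a neural selection head.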

Figure 4: Training loss of one-step (gray) vs. two-step (blue) optimization on POVERTYMAP.

Predictive performance on simulated data, measured in terms of the mean squared error (lower is better) averaged over 30 trials. Standard deviations are given in parentheses.

Table 5 of Appendix A.4. Details on hyperparameters and model training are presented in Table 6 of Appendix A.4.

CAMELYON17: We report predictive performance measured in terms of average accuracy on both the OOD validation and test sets. Standard deviations across 10 replicates are given in parentheses. 'ERM (scratch)' is trained from random initial parameters, whereas we also report 'ERM (ImageNet)' trained from ImageNet-pretrained weights (Russakovsky et al., 2015).

POVERTYMAP: We report predictive performance measured in terms of both the average and worst-

A summary of results on RXRX1 and IWILDCAM. Note that the OOD validation macro F1 score for DANN was not reported since it has not been provided in Sagawa et al. (2022).

ACKNOWLEDGEMENT

We thank the ICLR reviewers for their suggestions on improving the original version of this paper. This work was supported in part by the National Institutes of Health under awards NIH R01-LM013344, R01AG054467, R01AG065330, and R01-AG065330-02S1, by NYU Center for the Study of Asian American Health under the NIH/NIMHD grant award #U54MD000538 (HD).


[Table: per-dataset learning-rate and weight-decay search ranges.] For optimizers we searched between {Adam, SGD}. Learning rates were searched among {10^-5, 10^-4, 10^-3}, and weight decay among {0, 10^-5, 10^-3, 10^-1}. Note that since HeckmanDG training is two-phase, we searched for hyperparameters sequentially. On all datasets, Σ is optimized with the Adam optimizer with a learning rate of 10^-2 and no weight decay.
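The search space described above can be written down as a simple grid; the dictionary layout and helper below are a hypothetical sketch for clarity, not our actual configuration files.

```python
from itertools import product

# Hypothetical encoding of the hyperparameter search space described above.
search_space = {
    "optimizer": ["Adam", "SGD"],
    "learning_rate": [1e-5, 1e-4, 1e-3],
    "weight_decay": [0.0, 1e-5, 1e-3, 1e-1],
}

# Sigma is excluded from the search: it is always optimized with Adam,
# a learning rate of 1e-2, and no weight decay.
sigma_optimizer = {"optimizer": "Adam", "learning_rate": 1e-2, "weight_decay": 0.0}

def grid(space):
    """Enumerate every hyperparameter combination in the grid."""
    keys = list(space)
    for values in product(*(space[k] for k in keys)):
        yield dict(zip(keys, values))

n_configs = sum(1 for _ in grid(search_space))  # 2 * 3 * 4 = 24 combinations
```

Because HeckmanDG is trained in two phases, this grid is searched once per phase rather than over the cross product of both phases, which keeps the search budget linear in the number of phases.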

A.5 IMPLEMENTATION DETAILS FOR SWAD AND L2A-OT

In Tables 2, 3, and 4, the numbers for SWAD (Cha et al., 2021) and L2A-OT (Zhou et al., 2020) are reproduced based on our implementation. For SWAD, we set the optimum patience parameter N_s = 3, the overfitting patience parameter N_e = 6, and the tolerance rate r = 1.2, which are the default values used in the original paper. The evaluation frequency was set to 100. On all four datasets, models were trained with SGD, using the same learning rates and weight decay factors as ERM (Koh et al., 2021). For L2A-OT, we set λ_Domain = 0.5, λ_Cycle = 10, and λ_CE = 1. On all three datasets (excluding POVERTYMAP), the generator G is trained with Adam using a constant learning rate of 3·10^-4, while the prediction model F is trained with SGD, using the same learning rates and weight decay factors as ERM (Koh et al., 2021). We refer readers to the original papers for further details regarding the hyperparameters.

