GATED DOMAIN UNITS FOR MULTI-SOURCE DOMAIN GENERALIZATION

Anonymous

Abstract

Distribution shift (DS) is a common problem that degrades the performance of learning machines. To tackle this problem, we postulate that real-world distributions are composed of elementary distributions that remain invariant across different environments. We call this the invariant elementary distribution (I.E.D.) assumption. The I.E.D. assumption implies an invariant structure in the solution space that enables knowledge transfer to unseen domains. To exploit this property in domain generalization (DG), we developed a modular neural network layer that consists of Gated Domain Units (GDUs). Each GDU learns an embedding of an individual elementary distribution, which allows us to encode domain similarities during training. During inference, the GDUs compute similarities between an observation and each of the corresponding elementary distributions, which are then used to form a weighted ensemble of learning machines. Because our layer is trained with backpropagation, it can be naturally integrated into existing deep learning frameworks. Our evaluation on image, text, graph, and time-series data shows a significant improvement in performance on out-of-training target domains, without domain information or any access to data from the target domains. This finding supports the practicality of the I.E.D. assumption and demonstrates that our GDUs can learn to represent these elementary distributions.

1. INTRODUCTION

A fundamental assumption in machine learning is that training and test data are independent and identically distributed (I.I.D.). This assumption underpins the consistency results of statistical learning theory, i.e., that the learning machine obtained by empirical risk minimization (ERM) attains the lowest achievable risk as the sample size grows (Vapnik, 1998; Schölkopf, 2019). Unfortunately, a considerable amount of research and real-world applications over the past decades has provided substantial evidence against this assumption (Zhao et al., 2018; 2020; Ren et al., 2019; Taori et al., 2020) (see D'Amour et al. (2020) for case studies). The violation of the I.I.D. assumption is usually caused by a distribution shift (DS) and can result in inconsistent learning machines (Sugiyama & Kawanabe, 2012), which means that the performance guarantees of machine learning models are lost in the real world. Therefore, to tackle DS, recent work advocates for domain generalization (DG) (Blanchard et al., 2011; Muandet et al., 2013; Li et al., 2017; 2018b; Zhou et al., 2021a). Generalization to entirely unseen domains is crucial for the robust deployment of models in practice, especially when new, unforeseeable domains emerge after deployment. The central question that DG seeks to answer, however, is how to identify the right invariance that allows for generalization. The contribution of this work is twofold. First, we advocate that real-world distributions are composed of smaller "units", called invariant elementary distributions, that remain invariant across different domains; see Section 2.1. Second, we propose to implement this hypothesis through so-called gated domain units (GDUs). Specifically, we developed a modular neural network layer that consists of GDUs. Each GDU learns an embedding of an individual elementary domain that allows us to express the domain similarities during training.
For this purpose, we adopt the theoretical framework of reproducing kernel Hilbert spaces (RKHS) to obtain a geometrical representation of each distribution in the form of a kernel mean embedding (KME), which incurs no information loss when the kernel is characteristic (Berlinet & Thomas-Agnan, 2004; Smola et al., 2007; Sriperumbudur et al., 2010; Muandet et al., 2017). This representation allows us to measure similarities between distributions with tools from analytical geometry. We show that these similarity measures can be learned and utilized to improve the generalization capability of deep learning models to previously unseen domains. The remainder of this paper is organized as follows: Section 2 lays out our theoretical framework, and Section 3 presents the implementation of our modular DG layer. In Section 4, we outline related work. Our experimental evaluations are presented in Section 5. Finally, we discuss potential limitations of our approach and future work in Section 6.

2. DOMAIN GENERALIZATION WITH INVARIANT ELEMENTARY DISTRIBUTIONS

We assume a mixture component shift for the multi-source DG setting. This shift refers to the most common form of DS, stating that the data are made up of different sources, each with its own characteristics, whose proportions vary between the training and test scenario (Quinonero-Candela et al., 2022). Our work thus differs in its assumption from related work in DG, in which the central assumption is covariate shift (i.e., the conditional distribution $P(Y \mid X)$ stays the same across source and test data while the marginal $P(X)$ changes) (David et al., 2010). In the following, let $\mathcal{X}$ and $\mathcal{Y}$ be the input and output space with a joint distribution $P$. We are given a set of $D$ labeled source datasets $\{D^s_i\}_{i=1}^{D}$ with $D^s_i \subseteq \mathcal{X} \times \mathcal{Y}$. Each source dataset is assumed to be generated I.I.D. from a joint distribution $P^s_i$ with support on $\mathcal{X} \times \mathcal{Y}$, henceforth called a domain. The set of probability measures with support on $\mathcal{X} \times \mathcal{Y}$ is denoted by $\mathcal{P}$. The multi-source dataset $D^s$ comprises the merged individual source datasets $\{D^s_j\}_{j=1}^{D}$. We aim to minimize the empirical risk; see Section 3.3 for details. Important notation is summarized in Table 1. Similar to Mansour et al. (2009; 2012) and Hoffman et al. (2018a), we assume that the distribution of the source dataset can be described as a convex combination $P^s = \sum_{j=1}^{D} \alpha^s_j P^s_j$, where $\alpha^s = (\alpha^s_1, \dots, \alpha^s_D)$ is an element of the probability simplex, i.e., $\alpha^s \in \Delta_D := \{\alpha \in \mathbb{R}^D \mid \alpha_j \ge 0 \wedge \sum_{j=1}^{D} \alpha_j = 1\}$. In other words, $\alpha^s_j$ quantifies the contribution of each individual source domain to the combined source domain.
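To make the convex-combination model concrete, the following minimal Python sketch draws a merged multi-source sample from $P^s = \sum_j \alpha^s_j P^s_j$ by first picking a source domain $j$ with probability $\alpha^s_j$ and then sampling from that domain. The one-dimensional Gaussian samplers and the function names `sample_source_mixture` and `empirical_proportions` are hypothetical illustrations, not part of the method; the domain index is retained only to check that the empirical proportions approach $\alpha^s$.

```python
import random
from typing import Callable, List, Sequence, Tuple


def sample_source_mixture(
    domain_samplers: Sequence[Callable[[random.Random], float]],
    alpha: Sequence[float],
    n: int,
    seed: int = 0,
) -> List[Tuple[int, float]]:
    """Draw n samples from P^s = sum_j alpha_j P^s_j.

    Each sampler is a hypothetical stand-in for one source domain P^s_j.
    Returns (domain_index, sample) pairs; the index is kept only so the
    mixture proportions can be checked, not as a training signal.
    """
    if any(a < 0 for a in alpha) or abs(sum(alpha) - 1.0) > 1e-9:
        raise ValueError("alpha must lie on the probability simplex")
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        # Pick source domain j with probability alpha_j, then draw from P^s_j.
        j = rng.choices(range(len(alpha)), weights=alpha, k=1)[0]
        samples.append((j, domain_samplers[j](rng)))
    return samples


def empirical_proportions(samples: List[Tuple[int, float]], num_domains: int) -> List[float]:
    """Fraction of samples drawn from each source domain (an estimate of alpha)."""
    counts = [0] * num_domains
    for j, _ in samples:
        counts[j] += 1
    return [c / len(samples) for c in counts]


# Three hypothetical source domains with different means and spreads.
samplers = [
    lambda r: r.gauss(0.0, 1.0),
    lambda r: r.gauss(3.0, 1.0),
    lambda r: r.gauss(-3.0, 0.5),
]
data = sample_source_mixture(samplers, [0.5, 0.3, 0.2], n=20000)
print(empirical_proportions(data, 3))  # approximately [0.5, 0.3, 0.2]
```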

2.1. INVARIANT ELEMENTARY DISTRIBUTIONS

In contrast to these works, we generalize their problem description: we express the distribution of each domain as a convex combination of $K$ elementary distributions $\{P_j\}_{j=1}^{K} \subset \mathcal{P}$, meaning that $P^s = \sum_{j=1}^{K} \alpha_j P_j$ with $\alpha \in \Delta_K$. Our main assumption is that these elementary distributions remain invariant across domains. The advantage is that we can find an invariant subspace at a more elementary level, as opposed to treating the source domains themselves as a basis for all unseen domains. Figure 1 illustrates this idea. Theoretically, the I.E.D. assumption is appealing because it implies an invariant structure in the solution space, as shown in the following lemma. The proof is given in Appendix A.1.

Lemma 1. Let $L: \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}_+$ be a non-negative loss function, $\mathcal{F}$ a hypothesis space of functions $f: \mathcal{X} \to \mathcal{Y}$, and $P^s(X, Y)$ a data distribution. Suppose that the I.E.D. assumption holds, i.e., there exist $K$ elementary distributions $P_1, \dots, P_K$ such that any data distribution can be expressed as $P^s(X, Y) = \sum_{j=1}^{K} \alpha_j P_j(X, Y)$ for some $\alpha \in \Delta_K$. Then, the corresponding Bayes predictor $f^* \in \arg\min_{f \in \mathcal{F}} \mathbb{E}_{(X,Y) \sim P^s}[L(Y, f(X))]$ is Pareto-optimal with respect to the vector of elementary risk functionals $(R_1, \dots, R_K): \mathcal{F} \to \mathbb{R}^K_+$, where $R_j(f) := \mathbb{E}_{(X,Y) \sim P_j}[L(Y, f(X))]$.

Lemma 1 implies that, under the I.E.D. assumption, Bayes predictors must belong to a subset of $\mathcal{F}$ called the Pareto set $\mathcal{F}_{\mathrm{Pareto}} \subset \mathcal{F}$, which consists of Pareto-optimal models. A model $f$ is said to be Pareto-optimal if there exists no $g \in \mathcal{F}$ such that $R_j(g) \le R_j(f)$ for all $j \in \{1, \dots, K\}$ and $R_j(g) < R_j(f)$ for some $j$; see, e.g., Sener & Koltun (2018, Definition 1). In other words, the I.E.D. assumption allows us to translate the invariance property of data distributions to the solution space. Since the Bayes predictors of all future test domains must lie within the Pareto set, which is a strict subset of the original hypothesis space, it is still possible to identify the optimal predictors of future test domains without any data from the test domains and without any assumption beyond the I.E.D. assumption itself. Hence, given data from the training domains, it is sufficient for the purpose of generalization to maintain only solutions within this Pareto set during training. Unfortunately, neither the elementary distributions nor the weights $\alpha$ are known in practice. Motivated by this theoretical insight, we design our DG layer, presented in Section 3, to uncover them from a multi-source training dataset $D^s$. While Lemma 1 shows the theoretical appeal of the I.E.D. assumption, we discuss below a situation in which it might hold in practice; its limitations are discussed in Section 6.

Figure 1 (left panel: during training; right panel: during inference): The first challenge during the training phase (left panel) is to extract the elementary distributions from the observed data (orange). The unobserved elementary distributions are represented by the elementary bases $V_1$ and $V_2$ (cyan and pink). The second challenge during the inference phase (right panel) is to create a weighted ensemble of learning machines that uses the similarities between the embedding $\phi(x_i)$ of an unseen observation and the embeddings $\mu_{V_1}$ and $\mu_{V_2}$ of these distributions in the RKHS $\mathcal{H}$ (green rectangle) as weights $\beta_{i1}$ and $\beta_{i2}$.

Real-world example. In this work, we postulate that the elementary domain bases are the invariant subspaces that allow us to generalize to unseen domains. In practice, the question arises whether and when such elementary domains emerge. Suppose we aim to learn to predict the risk of developing diabetes from laboratory data from Europe and then infer the risk from data from the United States of America. Naturally, factors influencing the data-generating process, such as the level of physical activity and nutritional habits, may change.
While these common factors remain invariant across continents to a certain degree, their individual contributions may differ. In terms of our assumptions, we model each of these factors with a corresponding elementary distribution $P_j$. For a previously unseen individual, we can then determine the coefficients $\alpha_j$ and quantify each factor's contribution without any information about the individual's origin.
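Lemma 1 can be illustrated on a toy finite hypothesis class. In the following Python sketch, each hypothesis is represented only by its vector of elementary risks $(R_1(f), R_2(f))$; the risk values are made up, and the helper names are hypothetical. The Bayes predictor for a mixture weight $\alpha$ is the minimizer of $\sum_j \alpha_j R_j(f)$, and the sketch checks that it always lies in the Pareto set. We use strictly positive weights, the standard condition under which every minimizer of a weighted sum is Pareto-optimal.

```python
from typing import Dict, Sequence, Set, Tuple

Risks = Tuple[float, ...]  # (R_1(f), ..., R_K(f)) for one hypothesis f


def scalarized_risk(risks: Risks, alpha: Sequence[float]) -> float:
    """R(f) = sum_j alpha_j R_j(f), the decomposition used in the proof of Lemma 1."""
    return sum(a * r for a, r in zip(alpha, risks))


def dominates(r_g: Risks, r_f: Risks) -> bool:
    """g dominates f: no worse on every elementary risk, strictly better on at least one."""
    return all(g <= f for g, f in zip(r_g, r_f)) and any(g < f for g, f in zip(r_g, r_f))


def bayes_predictor(hypotheses: Dict[str, Risks], alpha: Sequence[float]) -> str:
    """Minimizer of the mixture risk over a finite hypothesis class."""
    return min(hypotheses, key=lambda name: scalarized_risk(hypotheses[name], alpha))


def pareto_set(hypotheses: Dict[str, Risks]) -> Set[str]:
    """Hypotheses that no other hypothesis dominates."""
    return {
        name
        for name, r in hypotheses.items()
        if not any(dominates(s, r) for other, s in hypotheses.items() if other != name)
    }


# Made-up elementary risks for K = 2 elementary distributions.
H = {"f1": (0.10, 0.50), "f2": (0.30, 0.20), "f3": (0.35, 0.45), "f4": (0.60, 0.05)}
print(sorted(pareto_set(H)))  # ['f1', 'f2', 'f4']  (f3 is dominated by f2)
for alpha in [(0.9, 0.1), (0.5, 0.5), (0.1, 0.9)]:
    print(alpha, bayes_predictor(H, alpha))  # f1, then f2, then f4
```

Different mixture weights select different Bayes predictors, but all of them come from the same Pareto set, which is the invariant structure the lemma refers to.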

2.2. KERNEL MEAN EMBEDDING OF DISTRIBUTIONS

We leverage the KME of distributions (Berlinet & Thomas-Agnan, 2004; Smola et al., 2007; Muandet et al., 2017) to discover the elementary distributions and evaluate similarities between them. Let $\mathcal{H}$ be a reproducing kernel Hilbert space (RKHS) of real-valued functions on $\mathcal{X}$ with a reproducing kernel $k: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ (Schölkopf et al., 2001). The KME of a probability measure $P \in \mathcal{P}$ in the RKHS $\mathcal{H}$ is defined by the mapping $\phi(P) = \mu_P := \int_{\mathcal{X}} k(x, \cdot)\, dP(x)$. We assume that the kernel $k$ is characteristic, i.e., the mapping $P \mapsto \mu_P$ is injective (Fukumizu et al., 2004; Sriperumbudur et al., 2008). Theoretically, this essential assumption ensures that no information is lost when mapping the distribution into $\mathcal{H}$. Given samples $\{x_1, \dots, x_n\}$ generated I.I.D. from $P$, $\mu_P$ can be approximated by the empirical KME $\hat{\mu}_P = \frac{1}{n} \sum_{i=1}^{n} k(x_i, \cdot) = \frac{1}{n} \sum_{i=1}^{n} \phi(x_i)$. We refer non-expert readers to Muandet et al. (2017) for a thorough review of this topic. Challenges. Figure 1 depicts two challenges that come with our assumption of elementary distributions. First, since we do not have access to samples from the hidden elementary distributions, the elementary KMEs cannot be estimated directly from the samples at hand. To overcome this challenge, we instead seek a proxy KME $\mu_{V_j} := \frac{1}{N} \sum_{k=1}^{N} \phi(v^j_k) = \frac{1}{N} \sum_{k=1}^{N} k(v^j_k, \cdot)$ for each elementary KME $\mu_{P_j}$, based on a domain basis $V_j = \{v^j_1, \dots, v^j_N\} \subseteq \mathcal{X}$ for all $j \in \{1, \dots, M\}$. Hence, $\mu_{V_j}$ can be interpreted as the KME of the empirical probability measure $\hat{P}_{V_j} = \frac{1}{N} \sum_{k=1}^{N} \delta_{v^j_k}$. Here, we assume that $M = K$. The sets $V_j$ are referred to as elementary domain bases. Intuitively, the elementary domain bases $V_1, \dots, V_M$ represent each elementary distribution by a set of vectors that mimic samples generated from the corresponding distribution. In Figure 1, $V_1$ and $V_2$ as well as their mappings into $\mathcal{H}$ visualize this first challenge.
The second challenge is to learn the unknown similarity between a single sample $x_i$ and an elementary domain $V_j$, which we denote by $\beta_{ij}$. Because KMEs let us tackle this challenge from a geometric viewpoint, we quantify similarities between KMEs. For example, in Figure 1, the similarities $\beta_{i1}$ between $\phi(x_i)$ and $\mu_{V_1}$ and $\beta_{i2}$ between $\phi(x_i)$ and $\mu_{V_2}$ could be quantified by their distance or angle. These similarity coefficients enable our domain generalization layer to represent a convex combination of elementary domain-specific learning machines, commonly known as an ensemble. We introduce this layer in Section 3.
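A minimal sketch of the empirical and proxy KMEs, assuming a Gaussian RBF kernel with bandwidth $\sigma = 1$ (the text does not fix the kernel at this point) and a small hypothetical domain basis $V_j$. Because KMEs are averages of feature maps, their inner products reduce to averages of kernel evaluations; the example checks the reproducing property $\langle \phi(x), \mu_V \rangle_{\mathcal{H}} = \mu_V(x)$.

```python
import math
from typing import Callable, Sequence

Point = Sequence[float]


def gaussian_kernel(x: Point, y: Point, sigma: float = 1.0) -> float:
    """Gaussian RBF kernel exp(-||x - y||^2 / (2 sigma^2)); it is characteristic on R^d."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def empirical_kme(points: Sequence[Point], kernel=gaussian_kernel) -> Callable[[Point], float]:
    """Empirical KME mu(t) = (1/n) sum_i k(x_i, t), returned as a function on X."""
    n = len(points)
    return lambda t: sum(kernel(p, t) for p in points) / n


def kme_inner(points_a: Sequence[Point], points_b: Sequence[Point], kernel=gaussian_kernel) -> float:
    """<mu_A, mu_B>_H = (1/(n m)) sum_{i,k} k(a_i, b_k), by bilinearity of the inner product."""
    return sum(kernel(a, b) for a in points_a for b in points_b) / (len(points_a) * len(points_b))


# A hypothetical elementary domain basis V_j with N = 3 vectors in R^2.
V = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
x = (0.5, 0.5)
mu_V = empirical_kme(V)
# Reproducing property: <phi(x), mu_V>_H equals mu_V evaluated at x.
print(mu_V(x), kme_inner([x], V))
```

Treating $\phi(x)$ as the KME of the one-point set $\{x\}$ means the same `kme_inner` helper covers sample-to-basis and basis-to-basis similarities.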

3. DOMAIN GENERALIZATION LAYER

This section translates the theoretical ideas presented in Section 2 into a deep learning framework. For the purpose of implementation, let $x \in \mathbb{R}^{h \times w}$ denote an input data point and $h_\xi: \mathbb{R}^{h \times w} \to \mathbb{R}^e$ the feature extractor (FE) that maps the input to a low-dimensional representation $\tilde{x} \in \mathbb{R}^e$. The prediction layer $g_\theta: \mathbb{R}^e \to \mathcal{Y}$ then infers the label $y$. To tackle the DG problem, we introduce a layer module called the gated domain unit (GDU). A GDU consists of three main components: (1) a similarity function $\gamma: \mathcal{H} \times \mathcal{H} \to \mathbb{R}$ that is shared by all elementary domains, (2) an elementary domain basis $V_j$, and (3) a learning machine $f(x, \theta_j)$ for each elementary domain $j \in \{1, \dots, M\}$. The architecture of the proposed layer is depicted in Figure 2. The DG layer consists of several GDUs that represent the elementary distributions. During training, these GDUs learn the elementary domain bases $V_1, \dots, V_M$ that approximate these distributions.

Figure 2: Architecture of the DG layer. Each of GDU_1, ..., GDU_M computes a coefficient $\beta_j = \gamma(\phi(x), \mu_{V_j})$ from its basis $V_j$; the learning machines $f(x; \theta_j)$ produce outputs $\hat{y}_j$, which are combined in a weighted sum ($\Sigma$).

Essentially, the process is as follows. First, the $j$-th GDU takes $\tilde{x}_i$ as input and yields $\beta_{ij}$ as output. To apply $\gamma$ and compute the similarity between $\tilde{x}_i$ and $V_j$, the KME of each domain basis $V_j$ is required. These KMEs are obtained by $\phi(V_j) := \mu_{V_j} = \frac{1}{N} \sum_{k=1}^{N} \phi(v^j_k) = \frac{1}{N} \sum_{k=1}^{N} k(v^j_k, \cdot)$. The GDU therefore allocates a coefficient $\beta_{ij}$ to each elementary domain based on the similarity function $\gamma$: the coefficients $\beta_{ij} = \gamma(\phi(x_i), \mu_{V_j})$ represent the similarity between the KME of the domain basis $V_j$ and the feature map of the input $\tilde{x}_i$. Theoretically speaking, $\mu_{V_j}$ and the feature map $\phi(x_i)$ are elements of the associated RKHS $\mathcal{H}$, which allows us to evaluate similarities of non-linear features in a higher-dimensional feature space. Each GDU is then connected to a learning machine $f(x_i, \theta_j)$ that yields an elementary domain-specific prediction.
The final prediction of the layer is then an ensemble of these learning machines, $g_\theta(x_i) = \sum_{j=1}^{M} \beta_{ij} f(x_i, \theta_j)$, where $\theta = (\theta_1, \dots, \theta_M)$. Figure 2 gives an overview of how data is processed and how information is stored in the GDU. In summary, GDUs leverage the invariant elementary distribution (I.E.D.) assumption and constitute our algorithmic contribution: the elementary domain bases are stored as weights in the layer. Storing this information as a weight matrix (i.e., a domain memory) allows us to learn the elementary domain bases efficiently using backpropagation. Hence, we avoid depending on problem-adaptive methods (e.g., domain-adversarial training) and on domain information (e.g., domain labels).
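The forward pass of the DG layer can be sketched in plain Python as follows. This is an illustrative sketch, not the authors' implementation: it assumes a Gaussian kernel, uses the cosine similarity with a kernel softmax (Section 3.1) as $\gamma$, keeps the bases fixed instead of learning them by backpropagation, and uses hypothetical constant learners in place of $f(x, \theta_j)$.

```python
import math
from typing import Callable, List, Sequence, Tuple

Point = Sequence[float]


def rbf(x: Point, y: Point, sigma: float = 1.0) -> float:
    """Assumed Gaussian kernel k(x, y)."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2.0 * sigma ** 2))


def cosine_to_basis(x: Point, basis: Sequence[Point]) -> float:
    """Cosine similarity between phi(x) and mu_V, computed with kernel evaluations only."""
    n = len(basis)
    inner = sum(rbf(x, v) for v in basis) / n                                    # <phi(x), mu_V>
    norm_mu = math.sqrt(sum(rbf(v, w) for v in basis for w in basis) / n ** 2)   # ||mu_V||
    return inner / (math.sqrt(rbf(x, x)) * norm_mu)


def dg_layer_forward(
    x: Point,
    bases: List[Sequence[Point]],              # V_1, ..., V_M (trainable weights in the paper)
    learners: List[Callable[[Point], float]],  # f(., theta_j), one per elementary domain
    kappa: float = 1.0,
) -> Tuple[float, List[float]]:
    # Each GDU scores its basis; a softmax over the GDUs turns the scores into beta_j.
    scores = [kappa * cosine_to_basis(x, V) for V in bases]
    m = max(scores)  # shift for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    betas = [e / z for e in exps]
    # Ensemble prediction g(x) = sum_j beta_j f(x, theta_j).
    y_hat = sum(b * f(x) for b, f in zip(betas, learners))
    return y_hat, betas


bases = [[(0.0, 0.0), (0.2, 0.1)], [(3.0, 3.0), (3.1, 2.9)]]  # toy V_1, V_2
learners = [lambda x: 1.0, lambda x: -1.0]                   # hypothetical constant learners
y_hat, betas = dg_layer_forward((0.1, 0.0), bases, learners, kappa=10.0)
print(y_hat, betas)  # the point lies near V_1, so beta_1 dominates and y_hat is close to 1
```

Because every step is differentiable in the basis vectors, the same computation written in an autodiff framework lets backpropagation update the $V_j$ directly, which is what storing the bases as layer weights amounts to.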

3.1. DOMAIN SIMILARITY MEASURES

For the similarity function $\gamma$, we consider two similarity measures $H(\phi(x), \mu_{V_j})$, namely the cosine similarity (CS) (Kim et al., 2019) and the maximum mean discrepancy (MMD) (Borgwardt et al., 2006; Gretton et al., 2012). To ensure that the resulting coefficients $\beta_i$ lie on the probability simplex, we apply the kernel softmax function (Gao et al., 2019) and interpret its output as the similarity between an observation $x$ and an elementary domain basis $V_j$:
$$\beta_{ij} = \gamma(\phi(x_i), \mu_{V_j}) = \frac{\exp\big(\kappa H(\phi(x_i), \mu_{V_j})\big)}{\sum_{k=1}^{M} \exp\big(\kappa H(\phi(x_i), \mu_{V_k})\big)},$$
where $\kappa > 0$ is a softness parameter of the kernel softmax. Geometrically, these similarities correspond to the angle and the distance between two KMEs in the RKHS $\mathcal{H}$. The feature map $\phi$ maps both the observation $x$ and the domain basis $V_j$ into $\mathcal{H}$: $\phi(x) = \mu_{\delta_x} = k(x, \cdot)$ is the KME of the Dirac measure $\delta_x$, and $\phi(V_j) = \mu_{V_j} = \frac{1}{N} \sum_{k=1}^{N} k(v^j_k, \cdot)$.
CS. The CS function $H(\phi(x_i), \mu_{V_j}) = \frac{\langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}}}{\|\phi(x_i)\|_{\mathcal{H}} \|\mu_{V_j}\|_{\mathcal{H}}}$ is used as an angle-based similarity.
MMD. We use the MMD to obtain a distance-based similarity measure. The distance is given by $\|\phi(x_i) - \mu_{V_j}\|_{\mathcal{H}}$, and the similarity function is its negative: $H(\phi(x_i), \mu_{V_j}) = -\|\phi(x_i) - \mu_{V_j}\|_{\mathcal{H}}$. The intuition behind the negative MMD is to assign higher weights $\beta_{ij}$ to the elementary domains whose KMEs are closer to the feature map of the sample.
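Both similarity measures can be evaluated from kernel evaluations alone, since $\langle \phi(x), \mu_V \rangle_{\mathcal{H}} = \frac{1}{N}\sum_k k(x, v_k)$, $\|\mu_V\|^2_{\mathcal{H}} = \frac{1}{N^2}\sum_{k,l} k(v_k, v_l)$, and $\|\phi(x) - \mu_V\|^2_{\mathcal{H}} = k(x, x) - 2\langle \phi(x), \mu_V \rangle_{\mathcal{H}} + \|\mu_V\|^2_{\mathcal{H}}$. The sketch below assumes a Gaussian kernel with $\sigma = 1$ and two toy bases; the function names are hypothetical. It also shows that a larger $\kappa$ yields sharper coefficients.

```python
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def rbf(x: Point, y: Point, sigma: float = 1.0) -> float:
    """Gaussian kernel; an assumed choice, any characteristic kernel would do."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2.0 * sigma ** 2))


def kernel_terms(x: Point, basis: Sequence[Point]) -> Tuple[float, float, float]:
    """Return k(x, x), <phi(x), mu_V>_H and ||mu_V||_H^2 from kernel evaluations."""
    n = len(basis)
    kxx = rbf(x, x)
    kxv = sum(rbf(x, v) for v in basis) / n
    kvv = sum(rbf(v, w) for v in basis for w in basis) / n ** 2
    return kxx, kxv, kvv


def cosine_similarity(x: Point, basis: Sequence[Point]) -> float:
    """Angle-based similarity <phi(x), mu_V> / (||phi(x)|| ||mu_V||)."""
    kxx, kxv, kvv = kernel_terms(x, basis)
    return kxv / math.sqrt(kxx * kvv)


def neg_mmd(x: Point, basis: Sequence[Point]) -> float:
    """Distance-based similarity -||phi(x) - mu_V||_H (clipped at 0 against round-off)."""
    kxx, kxv, kvv = kernel_terms(x, basis)
    return -math.sqrt(max(kxx - 2.0 * kxv + kvv, 0.0))


def kernel_softmax(scores: Sequence[float], kappa: float = 1.0) -> List[float]:
    """beta_j = exp(kappa s_j) / sum_k exp(kappa s_k), shifted by max(s) for stability."""
    m = max(scores)
    exps = [math.exp(kappa * (s - m)) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


bases = [[(0.0, 0.0), (0.5, 0.0)], [(2.0, 2.0), (2.5, 2.0)]]  # toy V_1, V_2
x = (0.2, 0.1)
for name, sim in [("CS", cosine_similarity), ("MMD", neg_mmd)]:
    scores = [sim(x, V) for V in bases]
    print(name, kernel_softmax(scores, kappa=1.0), kernel_softmax(scores, kappa=10.0))
```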

3.2. PROJECTION-BASED GENERALIZATION

For classification tasks, we introduce an alternative approach to infer the coefficients $\beta_i$ that is based on the idea of kernel sparse coding (Gao et al., 2010; 2013). Here, the goal is to find an approximate representation of each feature map $\phi(x_i)$ using the elements of a dictionary $\{\mu_{V_j}\}_{j=1}^{M}$, i.e., $\phi(x_i) \approx \sum_{j=1}^{M} \beta_{ij} \mu_{V_j}$. In contrast to the aforementioned approaches, an elementary domain KME $\mu_{V_j}$ does not necessarily represent the KME $\mu_{P_j}$ of an elementary distribution. Instead, this approach aims to find a set $\{\mu_{V_j}\}_{j=1}^{M}$ that permits $\mu_{P^s}$ to be represented as a linear combination. Since $P^s$ is assumed to be a convex combination of elementary distributions, we can represent $\mu_{P^s}$ as a linear combination of the domain KMEs $\mu_{V_j}$ as long as $\mu_{P^s} \in \mathcal{H}_M := \operatorname{span}\{\mu_{V_j} \mid j = 1, \dots, M\}$. The space $\mathcal{H}_M$ is a subspace of the RKHS $\mathcal{H}$, which allows us to represent elements of $\mathcal{H}$ at least approximately in $\mathcal{H}_M$. By keeping $\mathcal{H}_M$ large, we gain more representational power. To make $\mathcal{H}_M$ as large as possible, we have to ensure that its spanning elements are linearly independent or, even better, orthogonal. Orthogonal KMEs have two desirable properties. First, pairwise orthogonal elements in $\mathcal{H}_M$ guarantee that there is no redundancy. Second, orthogonal elements allow us to use the orthogonal projection, which geometrically yields the best approximation of $\phi(x)$ in $\mathcal{H}_M$. In other words, we achieve the best possible approximation of the feature map by using its orthogonal components (see Proposition 3.1). The orthogonal projection is given by
$$\Pi_{\mathcal{H}_M}: \mathcal{H} \to \mathcal{H}_M, \quad \phi(x) \mapsto \sum_{j=1}^{M} \frac{\langle \phi(x), \mu_{V_j} \rangle_{\mathcal{H}}}{\|\mu_{V_j}\|^2_{\mathcal{H}}}\, \mu_{V_j}. \qquad (2)$$

$$\mathcal{L}(g) = \frac{1}{b} \sum_{i=1}^{b} L(\hat{y}_i, y_i) = \frac{1}{b} \sum_{i=1}^{b} L\Big(\sum_{j=1}^{M} \gamma(\phi(x_i), \mu_{V_j}) f_j(x_i),\, y_i\Big)$$
for an underlying task loss $L$ and batch size $b$.
In addition, we want the model to learn to distinguish between different domains. Thus, we introduce the regularization $\Omega_D$ to control the domain bases. In our case, $\Omega_D$ must ensure that the KMEs of the elementary domain bases are able to represent the KMEs of the elementary domains. Therefore, we minimize the MMD between the feature maps $\phi(x_i)$ and their associated representations $\sum_{j=1}^{M} \beta_{ij} \mu_{V_j}$, where $\beta_{ij} = \gamma(\phi(x_i), \mu_{V_j})$. Hence, the regularization $\Omega_D = \Omega^{\mathrm{OLS}}_D$ is defined as
$$\Omega^{\mathrm{OLS}}_D(\|g\|_{\mathcal{H}}) = \frac{1}{b} \sum_{i=1}^{b} \Big\|\phi(x_i) - \sum_{j=1}^{M} \beta_{ij} \mu_{V_j}\Big\|^2_{\mathcal{H}}$$
(see Appendix B.2 for details). The intuition is to represent each feature map $\phi(x_i)$ by the domain KMEs $\mu_{V_j}$, i.e., to minimize the MMD between the feature map and a combination of the $\mu_{V_j}$. The minimizer of this regularizer can be interpreted as the ordinary least-squares solution of a regression of $\phi(x_i)$ onto the elements of $\mathcal{H}_M$. In other words, we want the feature maps $\phi(x_i)$ to be well represented by the domain bases $V_j$. In the particular case of projection, we want the KMEs of the elementary domains to be mutually orthogonal to ensure high expressive power. For this purpose, we introduce an additional term $\Omega^{\perp}_D$ that encourages the desired orthogonality. For a kernel with $k(x, x) = 1$, orthonormality would require the Gram matrix $K_{ij} = \langle \mu_{V_i}, \mu_{V_j} \rangle_{\mathcal{H}}$ to be close to the identity matrix $I$. A variety of methods for regularizing matrices are available (Xie et al., 2017; Bansal et al., 2018). A well-known method to encourage orthogonality is soft orthogonality (SO) regularization (Bansal et al., 2018). As pointed out by Bansal et al. (2018), spectral restricted isometry property (SRIP) and mutual coherence (MC) regularization are promising alternatives to SO; we therefore implement them in the DG layer as well.
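The regularizer $\Omega^{\mathrm{OLS}}_D$ can be computed without explicit feature maps by expanding the squared RKHS norm, $\|\phi(x) - \sum_j \beta_j \mu_{V_j}\|^2_{\mathcal{H}} = k(x, x) - 2\sum_j \beta_j \langle \phi(x), \mu_{V_j} \rangle_{\mathcal{H}} + \sum_{i,j} \beta_i \beta_j \langle \mu_{V_i}, \mu_{V_j} \rangle_{\mathcal{H}}$. The sketch below assumes a Gaussian kernel; `beta_fn` is a hypothetical stand-in for the coefficient map $x \mapsto (\gamma(\phi(x), \mu_{V_j}))_j$.

```python
import math
from typing import Callable, List, Sequence

Point = Sequence[float]


def rbf(x: Point, y: Point, sigma: float = 1.0) -> float:
    """Assumed Gaussian kernel k(x, y)."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2.0 * sigma ** 2))


def kme_inner(A: Sequence[Point], B: Sequence[Point], k=rbf) -> float:
    """<mu_A, mu_B>_H as the average of all pairwise kernel values."""
    return sum(k(a, b) for a in A for b in B) / (len(A) * len(B))


def ols_residual(x: Point, bases: List[Sequence[Point]], betas: Sequence[float], k=rbf) -> float:
    """||phi(x) - sum_j beta_j mu_Vj||_H^2, expanded with the kernel trick."""
    # sum_j beta_j <phi(x), mu_Vj>
    cross = sum(b * kme_inner([x], V, k) for b, V in zip(betas, bases))
    # ||sum_j beta_j mu_Vj||^2 = sum_{i,j} beta_i beta_j <mu_Vi, mu_Vj>
    quad = sum(
        bi * bj * kme_inner(Vi, Vj, k)
        for bi, Vi in zip(betas, bases)
        for bj, Vj in zip(betas, bases)
    )
    return k(x, x) - 2.0 * cross + quad


def omega_ols(batch: Sequence[Point], bases: List[Sequence[Point]],
              beta_fn: Callable[[Point], Sequence[float]], k=rbf) -> float:
    """(1/b) sum_i ||phi(x_i) - sum_j beta_ij mu_Vj||^2 over a batch of size b."""
    return sum(ols_residual(x, bases, beta_fn(x), k) for x in batch) / len(batch)


bases = [[(0.0, 0.0)]]    # a single one-point basis, for illustration
beta_fn = lambda x: [1.0]  # hypothetical: all weight on that basis
print(omega_ols([(0.0, 0.0), (1.0, 1.0)], bases, beta_fn))  # 1 - exp(-1) ≈ 0.632
```

The first batch point coincides with the basis and contributes zero residual; the second contributes $2 - 2e^{-1}$, so the batch average is $1 - e^{-1}$.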
Hence, in the case of projection, the orthogonality term is the SO regularizer $\Omega^{\perp}_D = \|K - I\|_F^2$, and the overall regularization is given by
$$\Omega_D(\|g\|_{\mathcal{H}}) = \lambda_{\mathrm{OLS}}\, \Omega^{\mathrm{OLS}}_D(\|g\|_{\mathcal{H}}) + \lambda_{\mathrm{ORTH}}\, \Omega^{\perp}_D(\|g\|_{\mathcal{H}}), \qquad \lambda_{\mathrm{OLS}}, \lambda_{\mathrm{ORTH}} \ge 0.$$
Lastly, sparse coding is an efficient technique for finding a small set of basis elements that recovers the data up to a reconstruction error (Olshausen & Field, 1997). Such approaches have yielded strong performance in several applications, for example in computer vision (Lee et al., 2007; Yang et al., 2009). Kernel sparse coding transfers the reconstruction problem of sparse coding into $\mathcal{H}$ via the feature map $\phi$; by applying the kernel function, the reconstruction error can be computed from inner products in $\mathcal{H}$ (Gao et al., 2010; 2013). To encourage sparsity, we apply the $L_1$-norm to the coefficients $\beta$ and add $\Omega^{L1}_D := \|(\gamma(\phi(x_i), \mu_{V_1}), \dots, \gamma(\phi(x_i), \mu_{V_M}))\|_1$ to the regularization term $\Omega_D$ with the corresponding coefficient $\lambda_{L1}$. Appendix B.3 gives a visual overview of the model training.
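A minimal sketch of the SO penalty, the $L_1$ term, and their weighted combination, again assuming a Gaussian kernel. The $\lambda$ defaults are arbitrary placeholders, not values from the paper, and the function names are hypothetical. Note that for multi-point bases the diagonal entries $\langle \mu_{V_j}, \mu_{V_j} \rangle_{\mathcal{H}}$ fall below 1 even though $k(x, x) = 1$, so $\|K - I\|_F^2$ also pushes each $\|\mu_{V_j}\|_{\mathcal{H}}$ toward 1, not only the off-diagonal entries toward 0.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def rbf(x: Point, y: Point, sigma: float = 1.0) -> float:
    """Assumed Gaussian kernel k(x, y)."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2.0 * sigma ** 2))


def kme_inner(A: Sequence[Point], B: Sequence[Point], k=rbf) -> float:
    """<mu_A, mu_B>_H as the average of all pairwise kernel values."""
    return sum(k(a, b) for a in A for b in B) / (len(A) * len(B))


def kme_gram(bases: List[Sequence[Point]], k=rbf) -> List[List[float]]:
    """K_ij = <mu_Vi, mu_Vj>_H for the elementary domain bases."""
    return [[kme_inner(Vi, Vj, k) for Vj in bases] for Vi in bases]


def soft_orthogonality(bases: List[Sequence[Point]], k=rbf) -> float:
    """SO penalty ||K - I||_F^2 on the Gram matrix of the domain KMEs."""
    K = kme_gram(bases, k)
    M = len(bases)
    return sum((K[i][j] - (1.0 if i == j else 0.0)) ** 2 for i in range(M) for j in range(M))


def l1_penalty(betas: Sequence[float]) -> float:
    """L1 norm of one sample's coefficient vector beta_i."""
    return sum(abs(b) for b in betas)


def omega_d(omega_ols_value: float, omega_orth_value: float, omega_l1_value: float,
            lam_ols: float = 1.0, lam_orth: float = 1.0, lam_l1: float = 0.0) -> float:
    """Weighted sum of the regularizers (placeholder lambda defaults)."""
    return lam_ols * omega_ols_value + lam_orth * omega_orth_value + lam_l1 * omega_l1_value


same = [[(0.0, 0.0)], [(0.0, 0.0)]]     # two identical one-point bases: fully redundant
apart = [[(0.0, 0.0)], [(10.0, 10.0)]]  # two distant one-point bases: nearly orthogonal
print(soft_orthogonality(same), soft_orthogonality(apart))  # 2.0 and a value near 0
```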

4. RELATED WORK

DG, also known as out-of-distribution (OOD) generalization, is among the hardest problems in machine learning (Blanchard et al., 2011; Muandet et al., 2013; Arjovsky et al., 2019). In contrast, domain adaptation (DA), which predates the DG and OOD problems, deals with a slightly simpler scenario in which some data from the test distribution are available (Ganin et al., 2015). Hence, based on the available data, the task is to develop learning machines that transfer knowledge learned in a source domain specifically to the target domain. Approaches pursued in DA can be grouped primarily into (1) discrepancy-based DA (Sun et al., 2016; Peng & Saenko, 2018; Ben-David et al., 2010; Fang et al., 2020; Tzeng et al., 2014; Long et al., 2015; Baktashmotlagh et al., 2016), (2) adversary-based DA (Tzeng et al., 2017; Liu & Tuzel, 2016; Ganin et al., 2015; Long et al., 2018), and (3) reconstruction-based DA (Bousmalis et al., 2016; Hoffman et al., 2018b; Kim et al., 2017; Yi et al., 2017; Zhu et al., 2017; Ghifary et al., 2014). In DA, learning the domain-invariant components requires access to unlabeled data from the target domain. Unlike in DA, where the observed data from the test domains can be used to find the most appropriate invariant structures (Ben-David et al., 2010), the lack of such data in DG calls for postulating an invariant structure that enables OOD generalization. To enable generalization to unseen domains without any access to their data, researchers have made significant progress in the past decade and developed a broad spectrum of methodologies (Zhou et al., 2021a; c; Li et al., 2019; Blanchard et al., 2011). For a thorough review, see, e.g., Zhou et al. (2021a); Wang et al. (2021). Existing works include methods based on domain-invariant representation learning (Muandet et al., 2013; Li et al., 2018b; d), meta-learning (Li et al., 2018a; Balaji et al., 2018), and data augmentation (Zhou et al., 2020), to name a few.
Another recent stream of research from a causal perspective includes invariant risk minimization (Arjovsky et al., 2019), invariant causal prediction (Peters et al., 2016), and causal representation learning (Schölkopf et al., 2021). The overall motivation here is to learn representations that are robust to domain-specific spurious correlations. In other words, it is postulated that "causal" features are the right kind of invariance to enable OOD generalization. Despite successful applications, DG remains a challenging open problem. We differentiate our work from existing ones as follows. First, we postulate the existence of a domain-invariant structure at the distributional level rather than at the level of data representations, which is the common assumption in DG. This is motivated by theoretical results (Mansour et al., 2009; Hoffman et al., 2018a) stating that a distribution-weighted combination of source hypotheses represents the ideal hypothesis. Furthermore, as argued in Section 2, our distributional assumption generalizes previous work that uses domain-specific knowledge to tackle DG, by moving to a more elementary setting. For example, approaches such as Piratla et al. (2020) and Monteiro et al. (2021) can be compared to our GDUs as domain-specific predictors in the special case where each elementary domain represents a single source domain. However, GDUs do not assume the existence of a single common classifier for all domains; instead, they combine multiple classifiers that are shared between different source domains. Second, we incorporate the I.E.D. assumption directly into our model's architecture, as shown in Figure 2. Designing effective architectures for DG has been largely neglected (Zhou et al., 2020, Sec. 4.1). Last, we do not assume access to domain information.
Obtaining such information can be difficult in practice (Niu et al., 2017; see also our short discussion in Appendix C.4), yet DG methods that can deal with its absence (e.g., Huang et al. (2020); Carlucci et al. (2019); Li et al. (2018c)) remain scarce (Zhou et al., 2020, Sec. 4.2).

5. EXPERIMENTS

Since ERM is one of the strongest baselines in DG (Gulrajani & Lopez-Paz, 2020; Koh et al., 2021), we first compare our approach to ERM and ensemble learning (Table 2 and Appendix C.1). Second, we benchmark our approach against state-of-the-art DG methods (e.g., CORAL, LISA, IRM, FISH, Group DRO) on image, graph, and text data (Table 3 and Appendix C.4). Third, we analyze the robustness of GDUs against DS that occurs in daily clinical practice (Table 12 and Appendix C.3). Finally, in Appendix C.2, we conduct an ablation study focusing on the representation learned during training (Appendix C.2.2). In our experiments, we distinguish two modes of training the DG layer: fine-tuning (FT), where we extract features using a pre-trained model, and end-to-end training (E2E), where the FE and the DG layer are trained jointly.

5.1. PROOF-OF-CONCEPT BASED ON DIGITS CLASSIFICATION

Following Feng et al. (2020), among others, we create a multi-source dataset by combining five publicly available digit image datasets, namely MNIST (Lecun et al., 1998), MNIST-M (Ganin & Lempitsky, 2015), SVHN (Netzer et al., 2011), USPS, and Synthetic Digits (SYN) (Ganin & Lempitsky, 2015). The task is to classify digits from zero to nine. Each dataset in turn serves as the out-of-training target domain, which is inaccessible during training, while the remaining four serve as source domains. Details are given in Appendix C.1. Table 2 summarizes the results for the most challenging out-of-training target domain, namely MNIST-M. In Appendix C.1, we provide the results on the remaining target domains in Table 7 and a discussion of heuristics for choosing hyperparameters for our GDUs. Our method noticeably improves the mean accuracy on all datasets and decreases the standard deviation in comparison to the ERM and ensemble baselines, making the results more stable across the ten reported iterations.
9, and the strength of the regularization terms in Figure 6, Figure 7, and Figure 8 to assess the sensitivity of the DG layer to the choice of hyperparameters, and (B) visualize the output of the FE (Figure 11). Our ablation study in (A) reveals stable results across different sets of hyperparameters. While the layer is not sensitive to the choice of regularization strength, we recommend not omitting the regularization entirely, even though omitting the orthogonal regularization reduces the computational cost. For the illustration in (B), we use t-SNE (t-distributed stochastic neighbor embedding) to project the output of the FE trained with a dense layer (ERM) and with the DG layer. The GDU-trained FE yields more concentrated and bounded clusters than the one trained by ERM, suggesting a positive effect on the representation learned by the FE.
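The leave-one-domain-out protocol described above can be sketched as follows. The toy data are placeholders rather than the actual digit datasets, and the function names are hypothetical; the point is that each target domain is unseen during training and that the pooled source data carry no domain labels.

```python
from typing import Dict, Iterator, List, Sequence, Tuple

DIGITS_DOMAINS = ["MNIST", "MNIST-M", "SVHN", "USPS", "SYN"]


def leave_one_domain_out(domains: Sequence[str]) -> Iterator[Tuple[str, List[str]]]:
    """Yield (target, sources): each domain is held out once as the unseen target."""
    for target in domains:
        yield target, [d for d in domains if d != target]


def pool_sources(datasets: Dict[str, list], sources: Sequence[str]) -> list:
    """Merge the source datasets into one multi-source dataset D^s.

    Domain membership is dropped, since the DG layer is trained without domain labels.
    """
    return [example for name in sources for example in datasets[name]]


# Toy stand-in data: (input, label) pairs per domain (hypothetical placeholders).
toy = {name: [(f"{name}-img{i}", i % 10) for i in range(3)] for name in DIGITS_DOMAINS}
for target, sources in leave_one_domain_out(DIGITS_DOMAINS):
    train = pool_sources(toy, sources)
    test = toy[target]
    print(target, len(train), len(test))  # 12 training and 3 test examples per split
```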

5.2. WILDS BENCHMARK

To challenge the I.E.D. assumption and the OOD generalization capabilities of the GDUs, we use WILDS, a curated set of real-world experiments for benchmarking DG methods (Koh et al., 2021). Further, WILDS is a semi-synthetic collection of datasets that operates under assumptions similar to the mixture component shift (Koh et al., 2021). We consider the following eight datasets, which represent real-world DG tasks: Camelyon17, FMoW, Amazon, iWildCam, RxRx1, OGB-MolPCBA, CivilComments, and PovertyMap. We closely follow Koh et al. (2021) for the experiments. Details on the datasets and benchmark methods are given in Appendix C.4. We present our benchmarking results in Table 3. We report results obtained out of the box (i.e., with default hyperparameters): hyperparameter optimization has a substantial impact on generalization performance (Gulrajani & Lopez-Paz, 2020), and we aim to highlight the improvements attributable solely to our GDUs. First, we observe that each benchmark method has strengths and weaknesses across the datasets, and each falls below ERM at least once. GDUs also vary across datasets, performing very well on some (e.g., FMoW, PovertyMap), but they do not fall below ERM in any of the GDU experiments we conducted. In addition, the baselines require domain information, whereas our approach requires less information yet achieves results comparable to the benchmarks.

5.3. ECG EXPERIMENT

The PhysioNet/Computing in Cardiology Challenge 2020 (Perez Alday et al., 2021; Goldberger et al., 2000; Perez Alday et al., 2020) aims to identify clinical diagnoses from 12-lead ECG recordings from 6 different databases. This publicly available pooled dataset contains 43,101 recordings sampled with various sampling frequencies and lengths. Each recording is labeled as having one or more of 24 cardiac abnormalities; hence, the task is to perform a multi-label binary classification. For our experiment, we iterate over the databases, taking one at a time as the test domain while utilizing 

6. CONCLUSION AND DISCUSSIONS

We introduced the I.E.D. assumption, which postulates that real-world distributions are composed of elementary distributions that remain invariant across different domains, and showed that it implies an invariant structure in the solution space that enables knowledge transfer to unseen domains. Empirical results on real-world data support the practicality of the I.E.D. assumption and indicate that such a representation can be learned. Further, we presented a modular neural network layer consisting of Gated Domain Units (GDUs) that leverage the I.E.D. assumption. Our GDUs can substantially improve the downstream performance of learning machines in real-world DG tasks. Across our experiments, we observed that FT outperforms E2E on some datasets and vice versa. In E2E training, the feature extractor (encoder) is trained jointly with the GDUs. Hence, the latent representation changes during training, i.e., the representation fed into the GDUs varies between epochs. In contrast, in FT, the feature extractor is pretrained and always produces the same embedding. Especially with large feature extractors such as ResNet-50, learning the elementary domains can be more effective when this variability in the latent representation is avoided. Limitations. A major limitation is that we cannot yet provide theoretical evidence that the I.E.D. assumption holds in practice. We aim to expand the theoretical understanding of the I.E.D. assumption and the GDUs. In addition, the particular theoretical setting of Albuquerque et al. (2019) (i.e., each elementary domain represents a source domain) seems a promising starting point for extending their generalization guarantee to cases where our I.E.D. assumption holds. Second, our GDU layer induces additional computational overhead due to the regularization and the model size, which grows with the number of elementary domains.
Notably, our improvement is achieved with a relatively small number of elementary domains, indicating that increased complexity is not a necessary consequence of applying the DG layer. Also, as the ensemble baseline shows, the results achieved are not a consequence of increased complexity. Future work. We expect the I.E.D. assumption and GDUs to be adapted, yielding novel applications that tackle DG. For example, we suggest dynamically increasing the number of elementary domains during learning until their distributional variance, as a measure of their heterogeneity, reaches a plateau. Hence, one would learn the number of elementary domains instead of fixing it prior to training.

A PROOFS

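A.1 PROOF OF LEMMA 1

Proof. The result holds trivially for $K = 1$. For $K \geq 2$, by the I.E.D. assumption, $P_s(X, Y) = \sum_{j=1}^{K} \alpha_j P_j(X, Y)$ for some $\alpha \in \Delta_K$. Then, we can write the risk functional for each $f \in \mathcal{F}$ as

$$R(f) = \int L(y, f(x))\, dP_s(x, y) = \int L(y, f(x))\, d\Big(\sum_{j=1}^{K} \alpha_j P_j(x, y)\Big) = \sum_{j=1}^{K} \alpha_j \int L(y, f(x))\, dP_j(x, y) = \sum_{j=1}^{K} \alpha_j R_j(f),$$

where $R_j : \mathcal{F} \to \mathbb{R}_+$ is the elementary risk functional associated with the elementary distribution $P_j(X, Y)$. Hence, the Bayes predictors satisfy

$$f^* \in \arg\min_{f \in \mathcal{F}} R(f) = \arg\min_{f \in \mathcal{F}} \sum_{j=1}^{K} \alpha_j R_j(f). \tag{A.3}$$

Since the right-hand side of equation A.3 corresponds to the linear scalarization of the multi-objective function $(R_1, \ldots, R_K)$, its solution (i.e., a stationary point) is Pareto-optimal with respect to these objective functions (Ma et al., 2020, Definition 3.1); see also (Hillermeier, 2001a;b). That is, the Bayes predictors for a data distribution that satisfies the I.E.D. assumption must belong to the Pareto set

$$\mathcal{F}_{\text{Pareto}} := \Big\{ f^* : f^* = \arg\min_{f \in \mathcal{F}} \sum_{j=1}^{K} \alpha_j R_j(f),\ \alpha \in \Delta_K \Big\} \subset \mathcal{F}.$$

A.2 PROOF OF PROPOSITION 3.1

Proof. Suppose we have a representation $\mu_P = \sum_{j=1}^{M} \beta_j \mu_{V_j}$ with $\langle \mu_{V_i}, \mu_{V_j} \rangle_{\mathcal{H}} = 0$ for all $i \neq j$, i.e., $\{\mu_{V_1}, \ldots, \mu_{V_M}\}$ are pairwise orthogonal. We want to minimize the MMD by minimizing

$$\Big\| \mu_P - \sum_{j=1}^{M} \beta_j \mu_{V_j} \Big\|_{\mathcal{H}}^2 = \|\mu_P\|_{\mathcal{H}}^2 - 2 \Big\langle \mu_P, \sum_{j=1}^{M} \beta_j \mu_{V_j} \Big\rangle_{\mathcal{H}} + \Big\langle \sum_{i=1}^{M} \beta_i \mu_{V_i}, \sum_{j=1}^{M} \beta_j \mu_{V_j} \Big\rangle_{\mathcal{H}}$$
$$= \|\mu_P\|_{\mathcal{H}}^2 - 2 \sum_{j=1}^{M} \beta_j \langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}} + \sum_{i=1}^{M} \sum_{j=1}^{M} \beta_i \beta_j \underbrace{\langle \mu_{V_i}, \mu_{V_j} \rangle_{\mathcal{H}}}_{= \delta_{ij} \|\mu_{V_j}\|_{\mathcal{H}}^2}$$
$$= \|\mu_P\|_{\mathcal{H}}^2 - 2 \sum_{j=1}^{M} \beta_j \langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}} + \sum_{j=1}^{M} \beta_j^2 \|\mu_{V_j}\|_{\mathcal{H}}^2. \tag{A.4}$$

By defining

$$\Phi(\beta) := \Big\| \mu_P - \sum_{j=1}^{M} \beta_j \mu_{V_j} \Big\|_{\mathcal{H}}^2, \tag{A.5}$$

we can find the optimal $\beta_j$ by setting the partial derivative to zero:

$$\frac{\partial \Phi}{\partial \beta_j} = -2 \langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}} + 2 \beta_j \|\mu_{V_j}\|_{\mathcal{H}}^2 \overset{!}{=} 0 \iff \beta_j \|\mu_{V_j}\|_{\mathcal{H}}^2 = \langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}} \iff \beta_j^* = \frac{\langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}}}{\|\mu_{V_j}\|_{\mathcal{H}}^2}. \tag{A.6}$$

Please note that the function $\Phi$ is convex in $\beta$: by equation A.4, its Hessian is $\operatorname{diag}(2\|\mu_{V_1}\|_{\mathcal{H}}^2, \ldots, 2\|\mu_{V_M}\|_{\mathcal{H}}^2)$, which is positive semi-definite, so this stationary point is a global minimizer.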
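As a numerical sanity check of Proposition 3.1, the sketch below replaces the RKHS $\mathcal{H}$ with $\mathbb{R}^4$ and the mean embeddings with explicit vectors. This substitution, the toy vectors, and the function names are illustrative assumptions, not part of the paper's method. It builds pairwise orthogonal $\mu_{V_j}$, evaluates the closed form of equation A.6, and checks that random perturbations never decrease $\Phi$, as convexity predicts.

```python
import random


def dot(u, v):
    """Euclidean inner product, standing in for <., .>_H in this toy check."""
    return sum(a * b for a, b in zip(u, v))


def phi(beta, mu_p, mu_v):
    """Phi(beta) = ||mu_P - sum_j beta_j mu_Vj||^2, as in equation A.5."""
    residual = list(mu_p)
    for b_j, m_j in zip(beta, mu_v):
        residual = [r - b_j * m for r, m in zip(residual, m_j)]
    return dot(residual, residual)


def optimal_beta(mu_p, mu_v):
    """Closed form beta_j* = <mu_P, mu_Vj> / ||mu_Vj||^2 (equation A.6)."""
    return [dot(mu_p, m) / dot(m, m) for m in mu_v]


# Hypothetical toy embeddings: three pairwise orthogonal mu_Vj in R^4,
# and a mu_P with a component (0.4) outside their span.
mu_v = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.5, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0]]
mu_p = [1.0, -2.0, 0.7, 0.4]

beta_star = optimal_beta(mu_p, mu_v)
best = phi(beta_star, mu_p, mu_v)

# Convexity check: no random perturbation of beta* yields a smaller Phi.
rng = random.Random(0)
for _ in range(1000):
    trial = [b + rng.uniform(-1.0, 1.0) for b in beta_star]
    assert phi(trial, mu_p, mu_v) >= best - 1e-12
```

Here `beta_star` is approximately `[0.5, -4.0, 0.2333]`, and the minimal $\Phi$ is approximately 0.16, the squared norm of the component of $\mu_P$ orthogonal to the span. Without orthogonality, the cross terms in equation A.4 no longer vanish and a linear system would have to be solved instead.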
As written in Section 2.1, we postulate that the elementary domain bases are the invariant subspaces that allow us to generalize to unseen domains. In practice, the question arises whether and when elementary domains evolve. Suppose we aim to learn to predict the risk of developing diabetes from laboratory data from Europe and then infer the risk from data from the United States of America. Naturally, factors influencing the data-generating process, such as the level of physical activity and nutritional habits, may change. While these common factors remain invariant across continents to a certain degree, each factor's contribution may differ. In terms of our assumptions, we model each of these factors with a corresponding elementary distribution. Figure 3 depicts our assumption and how it differs from existing works. To exploit this assumption in out-of-distribution (OOD) generalization, we developed a modular neural network layer that consists of so-called Gated Domain Units (GDUs). Figure 4 visualizes the fundamental concept of the GDUs. Each GDU learns an embedding of an individual elementary domain that allows us to encode the domain similarities during training. During inference, the GDUs compute similarities between an observation and each of the corresponding elementary distributions, which are then used to form a weighted ensemble of learning machines. In other words, for a previously unseen individual, we aim to determine the coefficients that quantify each factor's contribution without any information about the individual's origin.
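The inference step described above can be sketched as a projection-based forward pass. This is a minimal illustration under stated assumptions, not the authors' implementation. It assumes a Gaussian kernel parametrized as $\exp(-\|x - y\|^2 / (2\sigma^2))$, empirical mean embeddings of $N$ basis vectors per elementary domain, the projection coefficients $\beta_j = \langle \phi(x), \mu_{V_j} \rangle_{\mathcal{H}} / \|\mu_{V_j}\|_{\mathcal{H}}^2$ from Appendix B.2, and one linear learning machine $W_j x + b_j$ per GDU. All function names are hypothetical.

```python
import math


def gaussian_kernel(x, y, sigma):
    """k(x, y) = exp(-||x - y||^2 / (2 sigma^2)); this parametrization is assumed."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq / (2.0 * sigma ** 2))


def inner_with_embedding(x, basis, sigma):
    """<phi(x), mu_V>_H = (1/N) sum_l k(x, v_l) for the empirical mean embedding of V."""
    return sum(gaussian_kernel(x, v, sigma) for v in basis) / len(basis)


def embedding_sqnorm(basis, sigma):
    """||mu_V||_H^2 = (1/N^2) sum_l sum_m k(v_l, v_m)."""
    return sum(gaussian_kernel(u, v, sigma) for u in basis for v in basis) / len(basis) ** 2


def dg_layer_forward(x, bases, learners, sigma):
    """Hypothetical projection-based forward pass: sum_j beta_j (W_j x + b_j).

    bases    -- list of M elementary domain bases, each a list of N vectors
    learners -- list of M (W, b) pairs, one linear learning machine per GDU
    """
    betas = [inner_with_embedding(x, V, sigma) / embedding_sqnorm(V, sigma) for V in bases]
    outputs = [
        [sum(w * xi for w, xi in zip(row, x)) + b_k for row, b_k in zip(W, b)]
        for W, b in learners
    ]
    n_out = len(outputs[0])
    combined = [sum(beta * out[k] for beta, out in zip(betas, outputs)) for k in range(n_out)]
    return betas, combined


# Toy usage: two elementary domains, one near the origin and one near (5, 5).
bases = [[[0.0, 0.0], [0.2, -0.1]], [[5.0, 5.0], [5.1, 4.9]]]
learners = [([[1.0, 0.0]], [0.0]), ([[0.0, 1.0]], [1.0])]
betas, y = dg_layer_forward([0.1, 0.0], bases, learners, sigma=1.0)
```

For an observation near the origin, the first GDU receives almost all of the weight, so the output is dominated by the first learning machine. Because the projection-based coefficients are not passed through a softmax, they need not sum to one.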

B.2 DETAILED VIEW OF THE REGULARIZATION TERM $\Omega_D^{\mathrm{OLS}}$

First, consider the single term $\|\phi(x_i) - \sum_{j=1}^{M} \beta_{ij} \mu_{V_j}\|_{\mathcal{H}}^2$, which can be expressed as

$$\Big\| \phi(x_i) - \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\|_{\mathcal{H}}^2 = \underbrace{\|\phi(x_i)\|_{\mathcal{H}}^2}_{(1)} - 2 \underbrace{\Big\langle \phi(x_i), \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\rangle_{\mathcal{H}}}_{(2)} + \underbrace{\Big\| \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\|_{\mathcal{H}}^2}_{(3)}. \tag{B.1}$$

AD (1): We begin with Term (1) and write $\|\phi(x_i)\|_{\mathcal{H}}^2 = \langle \phi(x_i), \phi(x_i) \rangle_{\mathcal{H}} = k(x_i, x_i)$. We could evaluate this term using the kernel function $k$ for each data point in the batch of size $b$. However, since this term does not depend on the elementary domains $\{V_1, \ldots, V_M\}$, it is unnecessary to compute its value to minimize the penalty. Thus, minimizing the penalty without $\|\phi(x_i)\|_{\mathcal{H}}^2$ in the regularization yields the same minimizer.

[Figure 4: Visualization of the concept of the Gated Domain Unit and how they are leveraged to build distributionally weighted ensembles of learning machines. Panel "During Training": we extract the set of unobserved elementary distributions that remain invariant from the observed data. Panel "During Inference": we create a weighted ensemble using similarities between an unseen observation and the elementary distributions. Domain labels shown in the figure: DE, CH, AT, USA.]

AD (2):

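Term (2) can be expressed as

$$\Big\langle \phi(x_i), \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\rangle_{\mathcal{H}} = \sum_{j=1}^{M} \beta_{ij} \langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}}. \tag{B.2}$$

Implementation-wise, evaluating this term requires the inner products $\langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}}$. Since our CS and projection-based methods also use these inner products to determine the coefficients $\beta_{ij}$, we pre-compute $\langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}}$ once per mini-batch and store them during training to avoid computing the same term multiple times. Moreover, the projection-based method does not apply a softmax and has the linear form $\beta_{ij} = \langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}} / \|\mu_{V_j}\|_{\mathcal{H}}^2$. Therefore, Term (2) can be simplified even further:

$$\Big\langle \phi(x_i), \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\rangle_{\mathcal{H}} = \sum_{j=1}^{M} \beta_{ij} \langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}} \tag{B.3}$$
$$= \sum_{j=1}^{M} \frac{\langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}}}{\|\mu_{V_j}\|_{\mathcal{H}}^2} \langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}} \tag{B.4}$$
$$= \sum_{j=1}^{M} \frac{\langle \phi(x_i), \mu_{V_j} \rangle_{\mathcal{H}}^2}{\|\mu_{V_j}\|_{\mathcal{H}}^2}. \tag{B.5}$$

AD (3): Last, we express Term (3) as

$$\Big\| \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\|_{\mathcal{H}}^2 = \sum_{j=1}^{M} \sum_{k=1}^{M} \beta_{ij} \beta_{ik} \langle \mu_{V_j}, \mu_{V_k} \rangle_{\mathcal{H}} \tag{B.6}$$

and calculate the inner products of the domains by

$$\langle \mu_{V_j}, \mu_{V_k} \rangle_{\mathcal{H}} = \frac{1}{N^2} \sum_{l=1}^{N} \sum_{m=1}^{N} \langle \phi(v_j^l), \phi(v_k^m) \rangle_{\mathcal{H}} \tag{B.7}$$
$$= \frac{1}{N^2} \sum_{l=1}^{N} \sum_{m=1}^{N} k(v_j^l, v_k^m) =: K_{jk}, \tag{B.8}$$

where $N$ is the number of vectors per domain basis. Note that this term does not depend on the input data $x_i$; hence, the matrix $K = (K_{jk})$ can be calculated once at the beginning of each optimization step and re-used for all data points of a batch. Combining Equation B.6 and Equation B.8 yields

$$\Big\| \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\|_{\mathcal{H}}^2 = \sum_{j=1}^{M} \sum_{k=1}^{M} \beta_{ij} \beta_{ik} \langle \mu_{V_j}, \mu_{V_k} \rangle_{\mathcal{H}} \tag{B.9}$$
$$= \frac{1}{N^2} \sum_{j=1}^{M} \sum_{k=1}^{M} \beta_{ij} \beta_{ik} \sum_{l=1}^{N} \sum_{m=1}^{N} k(v_j^l, v_k^m) \tag{B.10}$$
$$= \sum_{j=1}^{M} \sum_{k=1}^{M} \beta_{ij} \beta_{ik} K_{jk} \tag{B.11}$$
$$= \beta_i^\top K \beta_i, \tag{B.12}$$

with $\beta_i = (\beta_{i1}, \ldots, \beta_{iM})^\top$. As a final step, we use the results for Terms (1), (2), and (3) to obtain the desired regularization term

$$\Omega_D^{\mathrm{OLS}} = \frac{1}{b} \sum_{i=1}^{b} \Big\| \phi(x_i) - \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\|_{\mathcal{H}}^2 \tag{B.13}$$
$$= \frac{1}{b} \sum_{i=1}^{b} \Big( \|\phi(x_i)\|_{\mathcal{H}}^2 - 2 \Big\langle \phi(x_i), \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\rangle_{\mathcal{H}} + \Big\| \sum_{j=1}^{M} \beta_{ij} \mu_{V_j} \Big\|_{\mathcal{H}}^2 \Big). \tag{B.14}$$

As mentioned above, $\|\phi(x_i)\|_{\mathcal{H}}^2$ is independent of the elementary domains and thus constant in the regularization. Hence, we can exclude this term, which avoids additional computational effort. 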
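The kernel-trick evaluation of $\Omega_D^{\mathrm{OLS}}$ derived above (Terms (2) and (3), with $K$ computed once per step) can be checked numerically. The minimal sketch below assumes a linear kernel $k(x, y) = \langle x, y \rangle$, so that $\phi$ is the identity and equation B.13 can also be evaluated explicitly in feature space. The kernel choice, toy data, and function names are illustrative assumptions. The two computations should differ exactly by the dropped constant $\frac{1}{b}\sum_i \|\phi(x_i)\|_{\mathcal{H}}^2$.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def linear_kernel(x, y):
    # Illustrative assumption: with k(x, y) = <x, y>, phi is the identity map,
    # so the kernel-trick result can be checked against an explicit computation.
    return dot(x, y)


def domain_gram(bases, k):
    """K_jk = (1/N^2) sum_l sum_m k(v_j^l, v_k^m), equation B.8; independent of x_i."""
    return [
        [sum(k(u, v) for u in vj for v in vk) / (len(vj) * len(vk)) for vk in bases]
        for vj in bases
    ]


def omega_ols_without_const(xs, betas, bases, k):
    """(1/b) sum_i [-2 sum_j beta_ij <phi(x_i), mu_Vj> + beta_i^T K beta_i].

    This is equation B.14 with the constant term ||phi(x_i)||^2 dropped."""
    gram = domain_gram(bases, k)  # computed once, re-used for the whole batch
    m = len(bases)
    total = 0.0
    for x, beta in zip(xs, betas):
        inner = [sum(k(x, v) for v in vj) / len(vj) for vj in bases]  # <phi(x_i), mu_Vj>
        term2 = dot(beta, inner)
        term3 = sum(beta[j] * beta[l] * gram[j][l] for j in range(m) for l in range(m))
        total += -2.0 * term2 + term3
    return total / len(xs)


def omega_ols_explicit(xs, betas, bases):
    """Equation B.13 computed directly in feature space (possible only for the linear kernel)."""
    mus = [[sum(c) / len(vj) for c in zip(*vj)] for vj in bases]
    total = 0.0
    for x, beta in zip(xs, betas):
        recon = [sum(b * mu[d] for b, mu in zip(beta, mus)) for d in range(len(x))]
        residual = [xi - ri for xi, ri in zip(x, recon)]
        total += dot(residual, residual)
    return total / len(xs)


xs = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]                     # toy batch, b = 3
bases = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [-1.0, 0.5]]]  # M = 2, N = 2
betas = [[0.3, 0.1], [-0.2, 0.4], [1.0, 0.0]]                  # arbitrary coefficients

const = sum(dot(x, x) for x in xs) / len(xs)  # (1/b) sum_i ||phi(x_i)||^2
gap = omega_ols_explicit(xs, betas, bases) - omega_ols_without_const(xs, betas, bases, linear_kernel)
```

With a Gaussian kernel, the dropped term is $k(x_i, x_i) = 1$ for every sample, which makes it obvious that omitting it does not change the gradients with respect to the elementary domains.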

C EXPERIMENTS

In this section, we provide a detailed description of the DG experiments presented in Section 5. Our Digits and ECG experiments are implemented using TensorFlow 2.4.1 and TensorFlow Probability 0.12.1. For the WILDS benchmarking, we use PyTorch (version 1.11.0). All source code will be made available on GitHub https://github.com/ (TensorFlow) and https://github.com/ (PyTorch). Overall, our experiments aim to show the validity of the invariant elementary distribution (I.E.D.) assumption and the Gated Domain Units (GDUs). For the DG layer, we considered two modes of model training: fine-tuning (FT) and end-to-end training (E2E). In the FT scenario, we first pre-train the FE as in the ERM single setting. Then, we extract features using the pre-trained model and pass them to the DG layer to train the latter. For E2E training, in contrast, the whole model, including the FE and the DG layer, is trained jointly from the very beginning. Zhao et al. (2018). Each dataset, except USPS, is split into training and test sets of 25,000 and 9,000 images, respectively. For USPS, we use the whole dataset for the experiment since it contains only 9,298 images. Our experimental setup regarding datasets, data loader, and FE is based on existing work (Feng et al., 2020; Peng et al., 2019). The structure of the FE is summarized in Table 5, and the subsequent learning machine is a dense layer.

[Figure 2 residue (DG layer schematic): feature extractor $h_\theta$ and feature map $\phi(x)$; elementary domain bases $V_1, \ldots, V_M$ with embeddings $\mu_{V_1}, \ldots, \mu_{V_M}$ spanning $\mathcal{H}_M = \operatorname{span}\{\mu_{V_i} \mid i = 1, \ldots, M\}$; similarities $\gamma(\phi(x), \mu_{V_j})$ yielding coefficients $\beta_j$; learning machines $\beta_j (W_j x + b_j)$ combined by a sum $\Sigma$; training objective $L(\hat{y}, y) + \Omega(\|f\|_{\mathcal{H}})$.]

In the Empirical Risk Minimization (ERM) single experiment, we add a dense layer with 10 outputs (tanh activation) as a classifier to the FE. In the ERM ensemble experiment, we add M classification heads (dense layers with 10 outputs and tanh activation each) to the FE and average their outputs for the final prediction. 
This sets a baseline for our DG layer that shows the performance gain over an ERM model with the same number of learning machines. For training, we used the Adam optimizer with a learning rate of 0.001. We used early stopping and selected the best model weights according to the validation accuracy. For the validation data, we used only the combined test splits of the respective source datasets. The batch size was set to 512. Although the DG layer requires more computational resources than the ERM models, all Digits experiments were conducted on a single GPU (NVIDIA GeForce RTX 3090). Heuristics for the main parameters of the DG layer. From a practical perspective, our layer requires choosing two main hyper-parameters: the number of elementary domains M and, since we use the characteristic Gaussian kernel, its parameter σ. The parameter M determines the size of the ensemble of learning machines and, thus, for deep learning models, the overall network size. As a heuristic to choose M, we suggest clustering the output of a pre-trained FE. In the following, we provide an example. We pre-trained the FE for the test domain MNIST-M, passed the source data through this FE, and clustered the resulting features with the k-means algorithm. Subsequently, we analyzed three different metrics (Calinski-Harabasz score, Davies-Bouldin score, and Silhouette score) to select the optimal number of clusters as the basis for choosing M. All three scores agreed on four to five clusters. Therefore, we set M to five and observed strong generalization to the unseen test domain MNIST-M (Table 2 in Section 5). As for the parameter σ, we resort to the median heuristic proposed in (Muandet et al., 2016), that is, $\sigma^2 = \operatorname{median}\{\|x_i - x_j\|^2 : i, j = 1, \ldots, n\}$. While both heuristics require a pre-trained FE, cross-validation can act as a reasonable alternative. The hyper-parameters relevant for the DG layer are summarized in Table 6. 
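The median heuristic for σ can be sketched in a few lines. This is an illustration, not the authors' code: the input is assumed to be a list of FE output vectors, and the function name is hypothetical. The formula in the text ranges over all pairs $i, j$; the sketch uses distinct pairs $i < j$, a common variant that excludes the zero self-distances, which is an assumption here.

```python
import itertools
import math
import statistics


def median_heuristic_sigma(features):
    """Bandwidth from the median heuristic: sigma^2 = median of squared pairwise distances.

    `features` would be outputs of a pre-trained FE (hypothetical input). Only distinct
    pairs i < j are used, so zero self-distances do not pull the median down."""
    sq_dists = [
        sum((a - b) ** 2 for a, b in zip(x, y))
        for x, y in itertools.combinations(features, 2)
    ]
    return math.sqrt(statistics.median(sq_dists))


# Toy usage: squared pairwise distances are 25, 100, and 25, so sigma = 5.
sigma_hat = median_heuristic_sigma([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
```

For large source sets, the number of pairs grows quadratically, so in practice one would typically apply the heuristic to a random subsample of the features.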
In the FT setting, we applied the median heuristic presented above to estimate σ of the Gaussian kernel function; the estimate is denoted by σ̂. Since the median heuristic requires a pre-trained FE, it is not applicable to the E2E scenario, and σ was fixed to 7.5 for E2E. Note that our approach to choosing the relevant parameters was kept very general to show the feasibility of the I.E.D. assumption and the generalization ability of GDUs and, most importantly, to provide easy-to-reproduce results. During training, additional epoch metrics can be monitored using our custom DG layer callback, which may help to choose the model parameters. Furthermore, we observed that the elementary domains become naturally orthogonal during the experiments; thus, we set $\lambda_{\mathrm{ORTH}}$ relatively small. Since the orthogonal regularization adds computational burden, one could omit this term completely to speed up training.

Table 6: Parameters for the DG layer in the Digits and Digits-DG experiments for the fine-tuning (FT) and end-to-end training (E2E) settings. In the case of Projection, we chose the spectral restricted isometry property (SRIP) as the orthogonal regularization $\Omega_D^{\perp}$.

EXPERIMENT         M    N    λ_L1    λ_OLS   λ_ORTH   σ     κ
FT   CS            5    10   1e-3    1e-3    -        σ̂     2
FT   MMD           5    10   1e-3    1e-3    -        σ̂     2
FT   PROJECTION    5    10   1e-3    1e-3    1e-8     σ̂     -
E2E  CS            5    10   1e-3    1e-3    -        7.5   2
E2E  MMD           5    10   1e-3    1e-3    -        7.5   2
E2E  PROJECTION    5    10   1e-3    1e-3    1e-8     7.5   -

Therefore, we follow their instructions to conduct a fair comparison and ensure reproducibility. For the hyper-parameters, however, we kept the same values that we used for the Digits experiment (see Table 6). As a first method, we consider CCSA (Classification and Contrastive Semantic Alignment), which learns a domain-invariant representation by utilizing the CCSA loss (Motiian et al., 2017). Second, MMD-AAE (Maximum Mean Discrepancy-based Adversarial Autoencoders) extends adversarial autoencoders by a maximum mean discrepancy regularization to learn a domain-invariant feature representation (Li et al., 2018b). CrossGrad (Cross-Gradient) augments data by perturbing the input space using the cross-gradients of a label and a domain predictor (Shankar et al., 2018). Another augmentation-based DG method is L2A-OT (Learning to Augment by Optimal Transport) (Zhou et al., 2021b). Specifically, a data generator trained to maximize the optimal transport distance between source and pseudo domains is used to augment the source data. All aforementioned methods rely on the availability of domain information such as domain labels. To benchmark our layer against a method for DG without domain information, we resort to the JiGen (Jigsaw puzzle based Generalization) method (Carlucci et al., 2019). JiGen introduces an auxiliary loss for solving a jigsaw puzzle task during training. Further, we use the adaptive and non-adaptive stochastic feature augmentation (SFA-S and SFA-A, respectively) methods proposed by Li et al. (2021). In principle, both methods augment the latent feature embedding of an FE using random noise. Our results are summarized in Table 9. As noted by Li et al. (2021), it is challenging to outperform augmentation-based DG methods. In addition, SFA-A and SFA-S are computationally light (i.e., they only add random noise to the feature embedding) and do not require domain information (Li et al., 2021). Nevertheless, our layer achieves competitive results even against the strongest baselines in all DG tasks without requiring domain information.

We chose the Digits dataset to conduct an ablation study, which is organized as follows: (1) ablation of the regularization terms presented in Section 3, (2) effect of the orthogonal regularization for projection-based generalization, and (3) effect on the FE's output. As a reminder, the regularization depends on the form of generalization (i.e., domain similarity measures or projection-based generalization; see Section 3). 
For the domain similarity measure case, the regularization is

$$\Omega_D(\|g\|_{\mathcal{H}}) = \lambda_{\mathrm{OLS}}\, \Omega_D^{\mathrm{OLS}}(\|g\|_{\mathcal{H}}) + \lambda_{L1}\, \Omega_D^{L1}(\|\gamma\|), \tag{C.1}$$

where $\lambda_{\mathrm{OLS}}, \lambda_{L1} \geq 0$. In the case of projection, the regularization is given by

$$\Omega_D(\|g\|_{\mathcal{H}}) = \lambda_{\mathrm{OLS}}\, \Omega_D^{\mathrm{OLS}}(\|g\|_{\mathcal{H}}) + \lambda_{\mathrm{ORTH}}\, \Omega_D^{\perp}(\|g\|_{\mathcal{H}}) \tag{C.2}$$

with $\lambda_{\mathrm{OLS}}, \lambda_{\mathrm{ORTH}} \geq 0$. Although one can additionally choose the sparse regularization in projection-based generalization, we focus the ablation study on the two main regularization terms, namely the OLS and the orthogonal regularization. For (1), we vary the two weights in Equation C.1 ($\lambda_{\mathrm{OLS}}$ and $\lambda_{L1}$) and in Equation C.2 ($\lambda_{\mathrm{OLS}}$ and $\lambda_{\mathrm{ORTH}}$) in the interval $[0, 0.1]$ and display the mean classification accuracy for the most challenging classification task, MNIST-M, as a heatmap. In Figures 6-8, we see that the classification accuracy remains at a similar level overall, which indicates that the DG layer is not very sensitive to hyper-parameter changes for MNIST-M as the test domain. Nevertheless, we observe that ablating the regularization terms by setting the corresponding weights to zero decreases the classification accuracy, and the peaks in performance occur when the regularization is included during training of the DG layer. Applying the DG layer comes with additional overhead, especially the regularization that ensures the orthogonality of the elementary domain bases. This additional effort raises the question of whether enforcing the theoretical assumptions justifies the higher computational effort. Thus, in a second step, we analyze how the orthogonal regularization affects the orthogonality of the elementary domain bases (i.e., the spectral restricted isometry property (SRIP) value) and the loss function (i.e., categorical cross-entropy). In Figure 10, we depict the mean and standard deviation of the SRIP value and the loss over five runs for 40 epochs. The SRIP value can be tracked during training with the DG layer's callback functionalities. 
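To make the SRIP value concrete, the sketch below computes the SRIP penalty in its usual form, the spectral norm $\sigma_{\max}(W^\top W - I)$ (Bansal et al., 2018), by power iteration. How the elementary domain bases are arranged into the matrix $W$ is not specified in this appendix, so the sketch treats $W$ generically; its columns could, hypothetically, be normalized domain representations. The function name and the iteration count are illustrative assumptions.

```python
import math
import random


def srip(w, iters=200, seed=0):
    """SRIP penalty sigma_max(W^T W - I), estimated by power iteration.

    `w` is a list of rows; its columns are the vectors whose orthonormality is measured.
    A value of 0 means the columns are exactly orthonormal."""
    n_cols = len(w[0])
    gram = [[sum(row[i] * row[j] for row in w) for j in range(n_cols)] for i in range(n_cols)]
    a = [[gram[i][j] - (1.0 if i == j else 0.0) for j in range(n_cols)] for i in range(n_cols)]

    def apply(v):
        return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in a]

    rng = random.Random(seed)
    v = [rng.uniform(-1.0, 1.0) for _ in range(n_cols)]
    for _ in range(iters):
        u = apply(apply(v))  # A is symmetric, so A^T A = A @ A
        norm = math.sqrt(sum(x * x for x in u))
        if norm == 0.0:
            return 0.0  # A v vanished, e.g. because W has orthonormal columns
        v = [x / norm for x in u]
    return math.sqrt(sum(x * x for x in apply(v)))


# Toy usage: orthonormal columns give 0; columns (1, 0) and (1, 1) do not.
orthonormal_value = srip([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
skewed_value = srip([[1.0, 1.0], [0.0, 1.0]])
```

Iterating with $A^\top A$ rather than $A$ avoids oscillation when $A$ has eigenvalues of equal magnitude and opposite sign. Training-time implementations usually run only a few power iterations per step to keep the penalty cheap.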
First, we observe that the elementary domains are almost orthogonal when initialized. In the first epochs, training the layer decreases their orthogonality. This initial decrease happens because the cross-entropy has a stronger influence on the optimization than the regularization in the first epochs. After five epochs, the cross-entropy has decreased to a level at which the regularization becomes more effective, and the orthogonality of the elementary domain bases increases again. In Figure 10, we also observe that the orthogonal regularization, while leading to better orthogonality of the domains than its ablated counterpart, does not significantly affect the overall cross-entropy during training. Finally, we project the output of the FE trained with a dense layer (ERM) and with the DG layer by t-SNE (t-distributed stochastic neighbor embedding) in Figure 11. The GDU-trained FE yields more concentrated and bounded clusters than the one trained by ERM. Hence, we observe a positive effect on the representation learned by the FE.

C.2.2 INTERPRETATION OF THE ELEMENTARY DOMAINS

We analyze the learned elementary domains in the Digits experiment based on two visualizations, choosing the maximum mean discrepancy (MMD) as the similarity measure and MNIST-M as the test domain. The first visualization depicts the MMD between the datasets (i.e., MNIST, MNIST-M, SVHN, USPS, and Synthetic Digits (SYN)) and the learned elementary domains (i.e., $V_1$-$V_5$) as a heatmap (see Figure 12 (left)). The heatmap indicates that the source and test domains are close to one another in terms of the MMD. Hence, we expect their closeness to be reflected in the learned elementary domains. In other words, we expect each elementary domain to contribute similarly to the source and test domains (i.e., the coefficients β are similar for each of these domains). In Section 3.1, we derive the coefficients by applying a kernel softmax function to the negative MMD distances. Since the MMD distances between the source/test domains and the elementary domains are similar, the coefficients will be similar too. We conclude that the learned elementary domains represent the distributional characteristics shared among the source and test domains. In the second visualization, we show the t-SNE embedding of the feature extractor output for each source and test domain alongside the elementary domains in Figure 12 (right). First, we observe that the learned elementary domain bases form distinctive clusters. We see these clusters as a validation of our hypothesis that each GDU learns to mimic samples generated from a corresponding elementary distribution, as pointed out in Section 2.2. However, we cannot answer whether and where these elementary distributions occur in the real world. Moreover, these elementary distributions still lack interpretability. 
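The argument above (similar MMD profiles lead to similar coefficients) can be illustrated with a small sketch. It uses the biased empirical MMD estimator with a Gaussian kernel (σ = 1) and a softmax over negative distances. The exact kernel softmax of Section 3.1 is not reproduced in this appendix, so the form $\beta_j \propto \exp(-\kappa \cdot \mathrm{MMD}_j)$, with κ playing the role of the parameter listed in Table 6, is an assumption. The toy domains and function names are hypothetical.

```python
import math


def gaussian_kernel(x, y, sigma=1.0):
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq / (2.0 * sigma ** 2))


def mmd(xs, ys, sigma=1.0):
    """Biased empirical MMD between sample sets xs and ys under a Gaussian kernel."""
    kxx = sum(gaussian_kernel(a, b, sigma) for a in xs for b in xs) / len(xs) ** 2
    kyy = sum(gaussian_kernel(a, b, sigma) for a in ys for b in ys) / len(ys) ** 2
    kxy = sum(gaussian_kernel(a, b, sigma) for a in xs for b in ys) / (len(xs) * len(ys))
    return math.sqrt(max(kxx + kyy - 2.0 * kxy, 0.0))  # clip tiny negative round-off


def softmax_coefficients(distances, kappa=2.0):
    """Hypothetical kernel softmax: beta_j proportional to exp(-kappa * MMD_j)."""
    logits = [-kappa * d for d in distances]
    top = max(logits)
    weights = [math.exp(l - top) for l in logits]  # shift for numerical stability
    total = sum(weights)
    return [w / total for w in weights]


# Two elementary bases and two similar toy domains (two points near V_1, one near V_2).
bases = [[[0.0, 0.0], [0.5, 0.0]], [[3.0, 3.0], [3.5, 3.0]]]
domain_a = [[0.1, 0.1], [0.4, -0.1], [3.1, 3.0]]
domain_b = [[0.0, 0.2], [0.6, 0.0], [3.2, 2.9]]
beta_a = softmax_coefficients([mmd(domain_a, V) for V in bases])
beta_b = softmax_coefficients([mmd(domain_b, V) for V in bases])
```

Both toy domains place most of their mass near the first basis, so both assign it the larger coefficient (roughly 0.7 each), mirroring the similar coefficients expected for the source and test domains in the heatmap.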
In summary, the MMD heatmap and t-SNE embeddings of the learned elementary and source domains on Figure 12 indicate that the GDUs learn to represent distributional structures in the dataset.

C.3 ECG EXPERIMENT

We adopted the task of multi-label binary classification of 12-lead electrocardiogram (ECG) signals combined from 6 different sources introduced in the PhysioNet/Computing in Cardiology Challenge 2020 (Perez Alday et al., 2021; Goldberger et al., 2000; Perez Alday et al., 2020). Each ECG recording is annotated with 24 binary labels indicating whether or not a certain cardiac abnormality is present. The data is aggregated from 6 different databases and contains 43,101 recordings that vary in sampling frequency, number of subjects, and length. Table 10 summarizes the most important details about the data sources for this experiment. Following the original challenge, we measure the performance in terms of the generalized Intersection-over-Union (IoU) score, where partial credit is assigned to misdiagnoses that result in similar treatments or outcomes. The score is defined as

$$\text{score} := \frac{y^\top W \hat{y}}{|y \cup \hat{y}|}, \tag{C.3}$$

where $y, \hat{y} \in \{0, 1\}^{24}$ represent the actual and predicted labels, respectively, and $W$ stands for the partial credit-assignment matrix provided as part of the challenge description. Note that if $W$ is the identity matrix, the score is exactly the IoU score. The score is then adjusted for a solution $y_{\text{majority}}$, which always predicts the normal/majority class, and is moreover normalized.

As a pre-processing step, we down-sampled all signals to 125 Hz and applied Z-score normalization, random amplification, and random stretching according to Vicar et al. (2020). For that, we partially adopted the code provided by the authors. Additionally, we cropped each signal to its first 15,000 points if the signal was too long (mostly applied to the INCART database). Each dataset was randomly split into train and validation parts with a 3:1 ratio. During each experiment, we used the train splits of 5 databases for training and the validation splits of the training databases for early stopping. 
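Equation C.3 can be sketched for a single recording as follows. This follows equation C.3 only: the majority-class adjustment and normalization mentioned in the text are not included, the official challenge code may aggregate differently, and the convention of returning 0 for an empty union is a hypothetical choice. The 4-label example is a toy stand-in for the 24 challenge labels.

```python
def challenge_score(y, y_hat, w):
    """score = y^T W y_hat / |y ∪ y_hat| for binary label vectors (equation C.3)."""
    numerator = sum(
        y[i] * w[i][j] * y_hat[j] for i in range(len(y)) for j in range(len(y_hat))
    )
    union = sum(1 for a, b in zip(y, y_hat) if a or b)  # |y ∪ y_hat|
    return numerator / union if union else 0.0  # hypothetical convention for an empty union


# Toy example with 4 labels: one correct positive, one miss, one false alarm.
y, y_hat = [1, 1, 0, 0], [1, 0, 1, 0]
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
iou = challenge_score(y, y_hat, identity)  # plain IoU: 1 / 3

partial = [row[:] for row in identity]
partial[1][2] = 0.5  # hypothetical partial credit for predicting label 2 instead of label 1
generalized = challenge_score(y, y_hat, partial)
```

With the identity matrix, the numerator counts $|y \cap \hat{y}|$, so the score reduces to the IoU. The off-diagonal entry awards half credit for the confusion between labels 1 and 2, raising the score from 1/3 to 0.5.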
The held-out sixth database was used for inference and testing only. Table 11 describes the architecture of the FE used for the task. Since the provided ECG recordings have different lengths, we used TensorFlow padded batching, which pads all recordings in a batch to the length of the longest sequence in the batch. Therefore, inputs from different batches can have different lengths, so the spatial dimensions of the 1D convolutional layers are not predefined and are denoted by *. We used the Adam optimizer to optimize the weighted binary cross-entropy loss defined as $-\big(w_{\text{pos}}\, y \log \hat{y} + (1 - y) \log(1 - \hat{y})\big)$. The positive weights $w_{\text{pos}}$ are defined per class from the training split, inversely proportional to the frequency of positive labels for each class. The learning rate was initially set to 0.001 and reduced by a factor of 0.2 during training if the training loss did not improve for 10 epochs. We also applied early stopping and, at the end of training, restored the weights of the best model according to the validation accuracy. Since the input samples in this experiment are larger than in the previous one, we decreased the batch size to 64. Each ECG experiment was performed on a single GPU (NVIDIA GTX 1080 Ti). The parameters relevant for the DG layer are summarized in Table 12. We emphasize that we did not perform extensive hyper-parameter tuning, since our goal was to show the feasibility of the I.E.D. assumption and GDUs while keeping the experiments reproducible. (Huang et al., 2017). We trained the FE from scratch. Both ERM and the DG layer were trained for 250 epochs with early stopping and a learning rate of 0.001, which is reduced by a factor of 0.2 if the cross-entropy loss has not improved after 10 epochs. All results were aggregated over ten runs. 
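The weighted binary cross-entropy used in the ECG experiment can be sketched as follows. The text only states that $w_{\text{pos}}$ is inversely proportional to each class's positive-label frequency; the concrete choice $w_{\text{pos},c} = n / n_{\text{pos},c}$, the fallback weight for classes without positives, the mean over classes, and the clipping constant are assumptions of this sketch, and the function names are hypothetical.

```python
import math


def positive_weights(labels):
    """Per-class w_pos inversely proportional to the frequency of positive labels.

    Assumption: w_pos_c = n_samples / n_pos_c; the exact scaling is not given in the text."""
    n = len(labels)
    weights = []
    for c in range(len(labels[0])):
        n_pos = sum(row[c] for row in labels)
        weights.append(n / n_pos if n_pos else 1.0)  # hypothetical fallback for absent classes
    return weights


def weighted_bce(y, y_hat, w_pos, eps=1e-7):
    """Mean over classes of -(w_pos * y * log(y_hat) + (1 - y) * log(1 - y_hat))."""
    total = 0.0
    for yc, pc, wc in zip(y, y_hat, w_pos):
        pc = min(max(pc, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(wc * yc * math.log(pc) + (1 - yc) * math.log(1 - pc))
    return total / len(y)


# Toy training split: class 0 is positive in 2 of 4 samples, class 1 in 1 of 4.
train_labels = [[1, 0], [0, 0], [1, 1], [0, 0]]
w = positive_weights(train_labels)       # [2.0, 4.0]
loss = weighted_bce([1, 0], [0.5, 0.5], w)
```

Rare abnormalities thus receive larger positive weights, which counteracts the label imbalance across the 24 classes. In TensorFlow, a comparable per-class weighting can be expressed with `tf.nn.weighted_cross_entropy_with_logits` and its `pos_weight` argument.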
FMoW. Analyzing satellite images with machine learning (ML) models may enable novel possibilities for tackling global sustainability and economic challenges such as population density mapping and deforestation tracking. However, satellite imagery changes over time due to human behavior (e.g., infrastructure development), and the extent of change differs across regions. The Functional Map of the World (FMoW) dataset consists of satellite images from different continents and years: training (76,863 images; between 2002 and 2013), validation (19,915 images; between 2013 and 2016), and test (22,108 images; between 2016 and 2017). The objective is to classify each image into one of 62 building or land-use categories (e.g., shopping malls). As instructed in Koh et al. (2021), we used the DenseNet-121 pre-trained on ImageNet without L2 regularization. For the optimization, we use the Adam optimizer with a learning rate of 1e-4, which is decayed by a factor of 0.96 per epoch. The models were trained for 50 epochs with early stopping and a batch size of 64. Additionally, we report the worst-region accuracy, a specific metric used for FMoW, which is the worst accuracy across the following regions: Asia, Europe, Africa, America, and Oceania (see Koh et al. (2021) for details). Again, we report the results over three runs. Amazon. Recent research shows that consumer-facing machine learning applications exhibit large performance disparities across different sets of users. To study these performance disparities, WILDS (Koh et al., 2021) leverages a variant of the Amazon Review dataset. The Amazon-WILDS dataset is composed of data from 3,920 domains (reviewers), and the task is multi-class sentiment classification, where the model receives a review text and has to predict the rating from one to five. 
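FMoW's worst-region accuracy, as well as the Amazon 10th-percentile accuracy and the CivilComments worst-group accuracy described below, aggregate accuracies computed per domain. The sketch below shows one way to compute such metrics from per-sample group identifiers and correctness flags. It is an illustration, not the WILDS evaluation code: the function names are hypothetical, and the percentile uses linear interpolation (`statistics.quantiles` with `method="inclusive"`), which is an assumption about how ties between groups are interpolated.

```python
import statistics
from collections import defaultdict


def per_group_accuracy(groups, correct):
    """Accuracy per group (e.g., region, reviewer, or demographic identity)."""
    hits, counts = defaultdict(int), defaultdict(int)
    for g, c in zip(groups, correct):
        hits[g] += int(c)
        counts[g] += 1
    return {g: hits[g] / counts[g] for g in counts}


def worst_group_accuracy(groups, correct):
    """Worst accuracy across groups, in the spirit of FMoW's worst-region metric."""
    return min(per_group_accuracy(groups, correct).values())


def percentile_accuracy(groups, correct, q=10):
    """q-th percentile of per-group accuracies (linear interpolation; needs >= 2 groups)."""
    accs = sorted(per_group_accuracy(groups, correct).values())
    return statistics.quantiles(accs, n=100, method="inclusive")[q - 1]


# Toy predictions for three regions with accuracies 1.0, 0.5, and 0.0.
regions = ["Asia", "Asia", "Europe", "Europe", "Africa", "Africa"]
hits = [1, 1, 1, 0, 0, 0]
worst = worst_group_accuracy(regions, hits)
p10 = percentile_accuracy(regions, hits)
```

In the toy example, the worst-group accuracy is 0.0 (Africa), while the 10th percentile interpolates between the two lowest group accuracies and gives 0.1. This illustrates why the percentile metric is less sensitive to a single very poor group than the worst-group metric.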
The dataset is split into disjoint sets of reviewers for training (245,502 reviews from 1,252 reviewers), validation (100,050 reviews from 1,334 reviewers), and test (100,050 reviews from 1,334 reviewers). For the experiments and baseline models, we use the specifications made in Koh et al. (2021). As for the FE, we used DistilBERT-base-uncased models. For ERM, we use a batch size of 8, a learning rate of 1e-5, L2 regularization of 0.01, 3 epochs with early stopping, and a maximum token length of 512. For training the DG layer, we used the same specifications as for ERM. The performance is measured in terms of the 10th-percentile accuracy, i.e., the 10th percentile of the per-reviewer accuracies. iWildCam. Wildlife camera traps offer an excellent possibility to understand and monitor biodiversity loss. We conducted the RxRx1 experiments in accordance with the specifications made in (Koh et al., 2021). As for the FE, we thus used the ResNet-50 pre-trained on ImageNet (He et al., 2016). We trained the models using AdamW with default parameters $\beta_1 = 0.9$ and $\beta_2 = 0.999$, a learning rate of 1e-4, and an L2 regularization with strength 1e-5 for 90 epochs with a batch size of 75. We scheduled the learning rate to increase linearly in the first ten epochs and then decreased it following a cosine schedule. For training the DG layer, we chose the same parameters as for ERM. All results were aggregated over three runs.

OGB-MolPCBA

For training ERM and our DG layer, we use the default parameters: five GNN layers with a dimensionality of 300 and a learning rate of 0.001. We train for 100 epochs using early stopping. As for the performance, we report the mean and standard deviation of the average precision across all scaffolds (domains) over three runs. CivilComments. Over the last decades, users have generated a vast amount of text on the Internet, some of which is toxic. Machine learning has been leveraged for automatic text review to flag toxic comments. However, the models are prone to learning spurious correlations between toxicity and demographic information in the comment, which causes the model performance to drop in specific subpopulations. To study this OOD task, we leverage the modified CivilComments dataset from Koh et al. (2021). Based on the text input, the task is to predict a binary label, toxic or non-toxic. The domains are defined according to eight demographic identities: male, female, LGBTQ, Christian, Muslim, other religions, Black, and White. All comments were randomly split into disjoint training (269,038 comments), validation (45,180 comments), and test (133,782 comments) sets. Again, we follow Koh et al. (2021) and use a DistilBERT-base-uncased model with the following parameters: batch size = 16, learning rate = 1e-5, AdamW optimizer, number of epochs = 5, L2 regularization = 0.01, and a maximum number of tokens of 300. We use these default parameters for training our DG layer. The performance is measured in terms of the worst-group accuracy, and we report the mean and standard deviation across five runs. PovertyMap. As the FMoW example shows, satellite images in combination with machine learning models can be used to monitor sustainability and economic challenges on a global scale. Another application of these satellite images is poverty estimation across different spatial regions. However, labels for developing countries are scarce since obtaining the ground truth is expensive, which makes this application attractive for machine learning models. 
To study OOD generalization to unseen countries, we use a modified version of the poverty mapping dataset of WILDS (Koh et al., 2021). The task is to predict a real-valued asset wealth index between 1 and 5 based on a multi-spectral satellite image. The domain refers to the country and whether the image is from a rural or urban area. In contrast to the other datasets, this dataset is split into five different folds, where in each fold the training, validation, and test sets contain disjoint sets of countries but data from both rural and urban regions. The average size of each set across the 5 folds is approximately 10,000 images (13-14 countries) for training, approximately 4,000 images (4-5 different countries) for validation, and approximately 4,000 images (13-14 countries) for the test set. On the challenge of obtaining domain labels. In the example of hospitals (e.g., the Camelyon17 dataset), domain labels in fact come for free. However, other examples, such as the CivilComments dataset, show the opposite. This dataset requires additional annotations (i.e., demographic identities), which can be tedious to obtain in practice. Some algorithms need these domain annotations to achieve superior performance on each subgroup. Furthermore, the task of subgroup detection is in itself a difficult and relevant problem. Coming back to our hospital example, even people from the same hospital might belong to different subpopulations (e.g., gender, race, age), and these demographic subgroups are often more relevant for diagnosis than the hospital a patient comes from. This information, however, is not always available (due to anonymization standards, for instance), and, therefore, the relevant domain annotation might be hard to obtain. We follow Koh et al. (2021) and use a pre-trained ResNet-18 model minimizing the squared error loss. 
For the optimization, we rely on the Adam optimizer with the following parameters: a learning rate of 1e-3 with a decay of 0.96 per epoch, a batch size of 64, and early stopping based on the OOD evaluation score. For evaluation, we report the Pearson correlation (r) between the predicted and actual asset wealth indices across the five different folds. General benchmark methods. Following the WILDS benchmarking procedure (Koh et al., 2021), we compare our proposed DG layer to the following baselines. First, we consider empirical risk minimization (ERM), which minimizes the average training loss over the pooled dataset. Second, we consider a group of DG algorithms provided by the WILDS benchmark, namely, Coral, Fish, IRM, and DRO. The Coral algorithm introduces a penalty for differences in the means and covariances of the domains' feature distributions. The Fish algorithm achieves DG by approximating an inter-domain gradient matching objective, i.e., maximizing the inner product between gradients from different domains (Shi et al., 2021). Conceptually, Fish learns feature representations that are invariant across domains. Invariant risk minimization (IRM) introduces a penalty for feature distributions with different optimal classifiers for each domain (Arjovsky et al., 2019). The idea is to enable OOD generalization by learning domain-invariant causal predictors. Lastly, group distributionally robust optimization (DRO) explicitly minimizes the training loss on the worst-case domain (Sagawa et al., 2020; Hu et al., 2018). In addition to the baselines originally presented in Koh et al. (2021), we consider the following more recent DG baselines. First, we describe LISA, which, instead of regularizing the internal representations for generalization, seeks to learn domain-invariant predictors with selective data 



All source code is made available on GitHub.
Of note, the figure is a completely fictional example, and we do not intend to make any medical implications.
We used the digits data from https://github.com/FengHZ/KD3A [last accessed on 2022-05-17, available under MIT License], published in Feng et al. (2020).
Results were reported by Zhou et al. (2021b) and Li et al. (2021). Of note, neither group of authors reported the standard deviation of their results.
https://physionetchallenges.org/2020/ [last accessed on 2021-03-10, available under Creative Commons Attribution 4.0 International Public License].
https://github.com/tomasvicar/BUTTeam [last accessed on 2022-05-17, available under BSD 2-Clause License].




Figure 1: A visualization of the "invariant elementary distribution (I.E.D.)" assumption for domain generalization (DG): the observed data distributions (orange and violet) are composed of the same set of unobserved elementary distributions (blue and red) that remain invariant across different domains. Hence, the first challenge during the training phase (left panel) is to extract these elementary distributions from the observed data (orange). The unobserved elementary distributions are represented by the elementary bases $V_1$ and $V_2$ (cyan and pink). The second challenge during the inference phase (right panel) is to create a weighted ensemble of learning machines that uses the similarities between the embedding $\phi(x_i)$ of the unseen observation and the embeddings $\mu_{V_1}$ and $\mu_{V_2}$ of these distributions in the RKHS $\mathcal{H}$ (green rectangle) as weights $\beta_{i1}$ and $\beta_{i2}$.
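The inference step in Figure 1, together with the coefficients $\beta_j^* = \langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}} / \|\mu_{V_j}\|^2_{\mathcal{H}}$ of Proposition 3.1, can be sketched with empirical kernel mean embeddings. This is a minimal sketch under stated assumptions: a Gaussian RBF kernel with a fixed bandwidth `gamma`, fixed (not learned) basis vectors for each $V_j$, and the single-point embedding $\phi(x_i)$ in place of $\mu_P$. All function names are hypothetical. In the GDU layer the bases are learned via backpropagation, so this only illustrates the projection-style weighting, not the authors' implementation.

```python
import math


def rbf(x, y, gamma=1.0):
    """Assumed kernel: Gaussian RBF k(x, y) = exp(-gamma * ||x - y||^2)."""
    return math.exp(-gamma * sum((a - b) ** 2 for a, b in zip(x, y)))


def kme_inner(points_a, points_b, gamma=1.0):
    """<mu_A, mu_B>_H for the empirical kernel mean embeddings of two point sets."""
    total = sum(rbf(a, b, gamma) for a in points_a for b in points_b)
    return total / (len(points_a) * len(points_b))


def squared_residual(P, V, beta, gamma=1.0):
    """||mu_P - beta * mu_V||_H^2, expanded with the kernel trick."""
    return (kme_inner(P, P, gamma) - 2 * beta * kme_inner(P, V, gamma)
            + beta ** 2 * kme_inner(V, V, gamma))


def projection_weights(x, bases, gamma=1.0):
    """beta_ij = <phi(x_i), mu_{V_j}>_H / ||mu_{V_j}||_H^2 for every basis V_j.
    phi(x_i) is the embedding of the single observation x_i."""
    return [kme_inner([x], V, gamma) / kme_inner(V, V, gamma) for V in bases]


def gdu_ensemble(x, bases, learners, gamma=1.0):
    """Weighted ensemble sum_j beta_ij * f_j(x_i) over hypothetical per-basis learners."""
    betas = projection_weights(x, bases, gamma)
    return sum(b * f(x) for b, f in zip(betas, learners))
```

An observation that lies on basis $V_1$ and far from $V_2$ receives $\beta_{i1} \approx 1$ and $\beta_{i2} \approx 0$, so the ensemble prediction is dominated by the learner attached to $V_1$. `squared_residual` can be used to check numerically that $\beta_j^*$ minimizes $\|\mu_P - \beta_j \mu_{V_j}\|^2_{\mathcal{H}}$, since this residual is a convex quadratic in $\beta_j$.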

Figure 2: Visualization of the DG layer (left panel) and its main component, the GDU (right panel). The DG layer consists of several GDUs that represent the elementary distributions. During training, these GDUs learn the elementary domain bases $V_1, \dots, V_M$ that approximate these distributions.

Figure 3: Visualization and motivation of our invariant elementary distribution assumption and of how the elementary distributions can be instantiated with Gated Domain Units.

Figure 5 depicts the layout of our DG layer.

Figure 5: Domain generalization (DG) layer.

Digit-DG Benchmark. In previous research, the aforementioned digits data has been used not only for domain adaptation (DA) but also for domain generalization (DG). For the latter, Zhou et al. (2021b) and Li et al. (2021) introduced the Digit-DG dataset and an evaluation protocol to benchmark seven DG methods and ERM. Unlike in the Digits experiment described above, the Digit-DG dataset from Zhou et al. (2021b) and Li et al. (2021) consists of only four datasets (without USPS) and uses a different feature extractor (FE), summarized in Table

Figure 6: Classification results for varying $\lambda_{L1}$ and $\lambda_{OLS}$ in the interval $[0, 0.1]$ for FT (left) and E2E (right) CS on MNIST-M.

Figure 7: Classification results for varying $\lambda_{L1}$ and $\lambda_{OLS}$ in the interval $[0, 0.1]$ for FT (left) and E2E (right) MMD on MNIST-M.

Figure 8: Classification results for varying $\lambda_{ORTH}$ and $\lambda_{OLS}$ in the interval $[0, 0.1]$ for FT (left) and E2E (right) Projection on MNIST-M.

Figure 9: Mean and standard deviation of classification accuracy over 10 runs for a varying number of elementary domains ($M$, upper panel) and a varying number of vectors per domain basis ($N$, lower panel) on the MNIST-M dataset.

Figure 10: Effect of omitting the orthogonal regularization term $\Omega_D^{\perp}$. Spectral restricted isometry property (SRIP) (left) and categorical cross-entropy (right) with and without orthogonal regularization, and their evolution during training on the MNIST-M dataset. The mean and standard deviation presented for the end-to-end (E2E) and fine-tuning (FT) training scenarios are calculated over 10 runs.

Figure 11: Visualization of the t-SNE embedding on the unseen Synthetic Digits dataset. Colors encode the true label.

Figure 12: MMD heatmap (left) and t-SNE embedding (right) for the test domain MNIST-M.

Important notation

Proposition 3.1. For a KME $\mu_P$ of a given mixture distribution $P$, the following holds: $\mu_P \in \operatorname{span}\{\mu_{V_j} \mid j = 1, \dots, M\}$, where $\langle \mu_{V_i}, \mu_{V_j} \rangle_{\mathcal{H}} = 0$ for all $i \neq j$ (i.e., the KMEs of the elementary domain bases are pairwise orthogonal). The value of the function $\sum_{j=1}^{M} \|\mu_P - \beta_j \mu_{V_j}\|^2_{\mathcal{H}_k}$ is minimal if the coefficients are set as $\beta_j^* = \langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}} / \|\mu_{V_j}\|^2_{\mathcal{H}}$.

$\ldots \mu_{V_j}$, where $\beta_j = \langle \mu_P, \mu_{V_j} \rangle_{\mathcal{H}} / \|\mu_{V_j}\|^2_{\mathcal{H}}$, $\ldots + \lambda_D \Omega_D(\|g\|_{\mathcal{H}})$. The goal of the training can be described in terms of the two components of this function. Consider a batch of training data $\{x_1, \dots, x_b\}$, where $b$ is the batch size. During training, we minimize the loss function

Results of the Digits experiment. The mean (standard deviation) accuracy over ten runs is reported. Best results are in bold.

Results on WILDS benchmarking tasks. Our results are achieved out-of-the-box (i.e., with default parameters) without hyperparameter optimization. We use a grey background to highlight methods that use no domain information for DG. We compute the metrics following Koh et al. (2021) and report the mean (standard deviation). Best benchmark and GDU results are in bold.

Feature Extractor used for the Digits Experiment

Results of the Digits experiment. All experiments were repeated ten times, and the mean (standard deviation) accuracy is reported. Best results according to the mean accuracy are highlighted in bold.

Feature Extractor used for the Digit-DG Benchmark Experiment

Results of the Digits-DG experiment. All experiments were repeated ten times. Methods are classified into augmentation-based (A) and non-augmentation-based (B), as well as DG with (✓) and without (✗) domain information, according to Li et al. (2021). Best results according to the mean accuracy are highlighted in bold.

Details of the ECG data sources.

Feature extraction architecture used for the ECG experiment: an adapted version of the LeNet architecture for 1D input signals. Note that ECG recordings have variable lengths; therefore, the spatial dimension is not defined and is denoted as *.

Parameters for DG Layer in ECG experiments for the Fine Tuning (FT) and End-to-end training (E2E) Settings.

For comparison of our approach and benchmarking, we followed the standard procedure of WILDS experiments described in Koh et al. (2021). As a technical note, all WILDS experiments have been implemented in PyTorch (version >= 1.7.0) based on the specifications in Koh et al. (2021) and their code published at https://github.com/p-lambda/wilds [last accessed on 2022-05-17, available under MIT License]. The results for the benchmarks were retrieved from the official leaderboard https://wilds.stanford.edu/leaderboard/ [last accessed on 2022-09-26].

Camelyon17. In medical applications, the goal is to apply models trained on a comparatively small set of hospitals to a larger number of hospitals. For this application, we study images of tissue slides under a microscope to determine whether a patient has cancer. Shifts in patient populations, slide staining, and image acquisition can impede model accuracy in previously unseen hospitals. Camelyon17 comprises images of tissue patches from five different hospitals. While the first three hospitals are the source domains (302,436 examples), the fourth and fifth are the validation (34,904 examples) and test domains (85,054 examples), respectively.

However, images from different camera traps differ in illumination, color, camera angle, background, vegetation, and relative animal frequencies. We use the iWildCam dataset, consisting of 323 camera traps positioned at different locations worldwide. In the dataset, we treat different camera trap locations as different domains, in particular 243 training traps (129,809 images), 32 validation traps (14,961 images), and 48 test traps (42,791 images). The objective is to classify one of 182 animal species. Following the instructions by Koh et al. (2021), we again used the ResNet-50 pre-trained on ImageNet (He et al., 2016). For ERM, we used a learning rate of 3e-5 and no L2 regularization. The models were trained for 12 epochs with a batch size of 16 using the Adam optimizer. In addition to the accuracy, we report the macro F1-score to evaluate the performance on rare species (see Koh et al. (2021) for details). All results were aggregated over three runs.

RxRx1. In biomedical research areas such as genomics or drug discovery, high-throughput screening techniques generate vast amounts of data in several batches. Because experimental designs cannot fully mitigate the effects of confounding variables such as temperature, humidity, and measurement conditions across batches, the observed datasets become heterogeneous (commonly known as the batch effect).
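The macro F1-score reported for iWildCam averages per-class F1 scores without weighting by class frequency, which is why it is sensitive to performance on rare species. Below is a minimal, dependency-free sketch; the function name is hypothetical. Similar to scikit-learn's default behavior, the label set is the union of true and predicted labels, and a class with no predictions (or no true samples) is assigned a precision (or recall) of 0.

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores: rare classes count as much as frequent ones."""
    labels = sorted(set(y_true) | set(y_pred))
    true_pos = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_count = Counter(y_pred)
    true_count = Counter(y_true)
    scores = []
    for c in labels:
        precision = true_pos[c] / pred_count[c] if pred_count[c] else 0.0
        recall = true_pos[c] / true_count[c] if true_count[c] else 0.0
        denom = precision + recall
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return sum(scores) / len(scores)
```

For example, a classifier that always predicts the majority species on `y_true = [0, 0, 0, 0, 1]` reaches 80% accuracy but a macro F1-score of only 4/9, because the rare class contributes an F1 of 0.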

In biomedical research, machine learning has the potential to accelerate drug discovery while reducing the experimental overhead by lowering the number of required experiments. However, to leverage this potential, the models need to generalize to molecules that are structurally different from those seen during training. To study this OOD generalization across molecule scaffolds, we use the OGB-MolPCBA dataset. This dataset is split into the following subsets according to the scaffold structure: training (44,930 domains), validation (31,361 domains), and test (43,739 domains). The task is to classify the presence or absence of 128 biological activities based on a graph representation of a molecule.

Parameters for DG Layer in WILDS experiments for the Fine Tuning (FT) and End-to-end training (E2E) Settings.

LISA was introduced by Yao et al. (2022). Common Gradient Descent (CGD), introduced by Piratla et al. (2021), is based on Group-DRO. However, rather than focusing on the groups with the worst training loss, it focuses on common groups that enable generalization. Last, Adaptive Risk Minimization using batch normalization (ARM-BN) by Zhang et al. (2021) differs from the methods presented above in that it adapts to previously unseen domains during test time using unlabeled observations from the test domain.
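ARM-BN adapts at test time using unlabeled observations from the test domain. One way to read its batch-normalization variant, sketched below under our own assumptions, is normalization with the mean and variance of the current unlabeled test batch instead of running statistics accumulated during training. The function name and the list-based interface are hypothetical, and this is not the implementation of Zhang et al. (2021).

```python
import math


def batch_norm_with_test_stats(batch, eps=1e-5, weight=None, bias=None):
    """Normalize each feature with statistics of the current (unlabeled) batch.

    Assumption: `batch` is a list of equal-length feature vectors from one
    test domain; `weight`/`bias` stand in for learned affine parameters.
    """
    n, dim = len(batch), len(batch[0])
    weight = weight or [1.0] * dim
    bias = bias or [0.0] * dim
    means = [sum(row[d] for row in batch) / n for d in range(dim)]
    # Biased variance (divide by n), as in standard batch normalization.
    variances = [sum((row[d] - means[d]) ** 2 for row in batch) / n for d in range(dim)]
    return [[weight[d] * (row[d] - means[d]) / math.sqrt(variances[d] + eps) + bias[d]
             for d in range(dim)] for row in batch]
```

Because the statistics are recomputed for every batch, a constant per-feature shift in a new domain (e.g., a brightness offset) is removed before the downstream classifier sees the features; the price is that predictions for one observation depend on the other observations in its batch.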

Detailed results on RxRx1 dataset.

Detailed results on FMoW dataset.

