NEURAL NETWORKS FOR LEARNING COUNTERFACTUAL G-INVARIANCES FROM SINGLE ENVIRONMENTS

Abstract

Despite, or maybe because of, their astonishing capacity to fit data, neural networks are believed to have difficulty extrapolating beyond the training data distribution. This work shows that, for extrapolations based on finite transformation groups, a model's inability to extrapolate is unrelated to its capacity. Rather, the shortcoming is inherited from a learning hypothesis: examples not explicitly observed, even with infinitely many training examples, have underspecified outcomes in the learner's model. To endow neural networks with the ability to extrapolate over group transformations, we introduce a learning framework counterfactually guided by the learning hypothesis that any invariance to (known) transformation groups is mandatory even without evidence, unless the learner deems it inconsistent with the training data. Unlike existing invariance-driven methods for (counterfactual) extrapolations, this framework allows extrapolations from a single environment. Finally, we introduce sequence and image extrapolation tasks that validate our framework and showcase the shortcomings of traditional approaches.

1. INTRODUCTION

Neural networks are widely praised for their ability to interpolate the training data. However, in some applications, they have also been shown to be unable to learn patterns that can provably extrapolate out-of-distribution (beyond the training data distribution) (Arjovsky et al., 2019).



To illustrate, consider a supervised learning task where the training data contains infinitely many sequences x^(tr) = (A, B) associated with the label y^(tr) = C, but no examples of a sequence x^(tr) = (B, A). If given a test example x^(te) = (B, A), the hypothesis considers it to be out of distribution and the prediction P(Y^(te) = C | X^(te) = (B, A)) is undefined, since P(X^(tr) = (B, A)) = 0. This happens regardless of any prior over P(X^(tr)). This unseen-is-underspecified learning hypothesis is not guaranteed to push neural networks to assume symmetric extrapolations without evidence.

Contributions. Since symmetries are intrinsically tied to human single-environment extrapolation capabilities, this work explores a learning framework that modifies the learner's hypothesis space to allow symmetric extrapolation (over known groups) without evidence, while not losing valuable antisymmetric information if it is observed to predict the target variable in the training data. Formally, a symmetry is an invariance to transformations of a group, known as a G-invariance. In Theorem 1 we show that the counterfactual invariances needed for symmetry extrapolation, denoted Counterfactual G-invariances (CG-invariances), are stronger than traditional G-invariances. Theorem 2 then introduces a condition on the structural causal model under which G-invariances of linear automorphism groups are safe to use as CG-invariances. With that, Theorem 3 defines a partial order over the appropriate invariant subspaces, which we use to learn the correct G-invariances from a single environment without evidence, while retaining the ability to be sensitive to antisymmetries shown to be relevant in the training data. Finally, we introduce sequence and image counterfactual extrapolation tasks with experiments that validate the theoretical results and showcase the advantages of our approach.

2. RELATED WORK

Counterfactual inference and invariances. Recent efforts have brought counterfactual inference to machine learning models. Independent causal mechanism (ICM) and Invariant Risk Minimization (IRM) methods (Arjovsky et al., 2019; Besserve et al., 2018; Johansson et al., 2016; Parascandolo et al., 2018; Schölkopf, 2019), Causal Discovery from Change (CDC) methods (Tian and Pearl, 2001), and representation disentanglement methods (Bengio et al., 2020; Goudet et al., 2017) broadly look for representations, classifiers, or mechanism descriptions that are invariant across multiple environments observed in the training data or inferred from it (Creager et al., 2020). They rely on samples from multiple environments in order to reason over new environments. To the best of our knowledge, there is no clear effort toward extrapolations from a single environment. The key similarity between the ICM framework and ours is the assumption of independently sampled mechanisms (the transformations) and causes.

Domain adaptation and domain generalization. Domain adaptation and domain generalization (e.g., Long et al., 2017; Muandet et al., 2013; Quionero-Candela et al., 2009; Rojas-Carulla et al., 2018; Shimodaira, 2000; Zhang et al., 2015, among others) ask questions about specific, observed or known, changes in the data distribution rather than counterfactual questions. A key difference is that counterfactual inference accounts for hypothetical interventions, not known ones.

Forced G-invariances. Forcing a G-invariance may contradict the training data, where the target variable is actually influenced by the transformation of the input. For instance, handwritten digits are not invariant to 180° rotations, since the digits 6 and 9 would get confused. Data augmentation is a type of forced G-invariance (Chen et al., 2020; Lyle et al., 2020) and hence will fail to extrapolate. Other works forcing G-invariances that will also fail include (a non-exhaustive list): Zaheer et al. (2017) and Murphy et al. (2019a;b) for permutation groups over set and graph inputs; Cohen and Welling (2016) and Cohen et al. (2019) for dihedral and spherical transformation groups over images.

Learning invariances from training data. The parallel work of Benton et al. (2020) considers learning image invariances from the training data, but does not consider extrapolation tasks. Moreover, it does not provide a concrete theoretical proof of invariance, relying on experimental results over interpolation tasks for validation. Another parallel work (Zhou et al., 2021) uses meta-learning to learn symmetries that are shared across several tasks (or environments). The works of van der Wilk et al. (2018) and Anselmi et al. (2019) focus on learning invariances from the training data for better generalization error on the training distribution. However, none of these works consider the extrapolation task. In contrast, our framework formally considers counterfactual extrapolation, for which we provide both theoretical and experimental results.

3. EXTRAPOLATIONS FROM A SINGLE ENVIRONMENT

Figure 1: Illustration of our structural causal model (SCM), where gray nodes indicate observed variables (in training). X and X_{U_I←Ũ_I} are obtained from X^(hid) and are coupled by sharing U_D. However, U_I and Ũ_I can have different support, resulting in different distributions over X and X_{U_I←Ũ_I}.

Geometrically, extrapolation can be thought of as reasoning beyond the convex hull of a set of training points (Haffner, 2002; Hastie et al., 2012; Xu et al., 2021). However, for neural networks, with their arbitrary representation mappings, this geometric interpretation can be insufficient. Rather, we believe extrapolations are better described through counterfactual reasoning (Neyman, 1923; Rubin, 1974; Pearl, 2009; Schölkopf, 2019). Specifically, in our task we ask: after seeing training data from environment A, the learner wants to extrapolate and predict what the output would have been if the training environment were B. Extrapolation differs from traditional domain adaptation due to its counterfactual nature: a what-if question about an intervention that can only be imagined given offline data (Bareinboim et al., 2020; Pearl and Mackenzie, 2018), rather than a known distributional change. Specifically, our framework follows the independent causal mechanism (ICM) principle (Schölkopf, 2019; Peters et al., 2017): a mechanism describing a variable given its causes is independent of all other mechanisms describing other variables. For instance, in the causal model U_X → X → Y ← U_Y, this implies that the conditional distribution P(Y|X) is not influenced by any change in P(X).
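This independence can be illustrated with a small simulation (the linear mechanism below is a hypothetical example chosen for illustration, not part of our framework): changing P(X) across two environments leaves the conditional mechanism P(Y|X) unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)

def mechanism(x, u_y):
    # P(Y|X): a fixed causal mechanism, independent of how X is distributed
    return 2.0 * x + u_y

def sample(px_scale, n=100_000):
    x = rng.normal(0.0, px_scale, n)   # intervene on P(X) by changing its scale
    u_y = rng.normal(0.0, 1.0, n)
    return x, mechanism(x, u_y)

# Two environments with very different P(X)
x_a, y_a = sample(px_scale=1.0)
x_b, y_b = sample(px_scale=5.0)

# Estimate E[Y | X ~ 1] in both environments: the mechanism is unchanged
cond_a = y_a[np.abs(x_a - 1.0) < 0.05].mean()
cond_b = y_b[np.abs(x_b - 1.0) < 0.05].mean()
print(cond_a, cond_b)  # both close to 2.0
```

The conditional mean at X = 1 is determined by the mechanism alone, even though the marginal P(X) differs sharply between the two environments.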

3.1. TRANSFORMATION GROUPS

We focus on extrapolations tied to finite linear automorphism groups acting on the input data. We start with an example. Consider an input x ∈ X = R^{3n²} representing a vectorized n × n RGB image. We can define at least three linear automorphism groups: (1) G_rot ≡ {T^(k)}_{k ∈ {0°, 90°, 180°, 270°}}, which rotates the image by k degrees; (2) G_color ≡ {T^(α)}_{α ∈ S_3}, which permutes the RGB channels of the image; and (3) G_vflip ≡ {T^(v), T^(0)}, which flips the image vertically. More generally, a linear automorphism group G satisfies six properties: (automorphism) ∀T ∈ G, T : X → X; (identity) I(x) = x, I ∈ G; (closure under composition) ∀T, T′ ∈ G, T • T′ ∈ G, where T • T′(x) = T(T′(x)); (associativity) ∀T, T′, T† ∈ G, T • (T′ • T†) = (T • T′) • T†; (inverses) ∀T ∈ G, ∃T⁻¹ ∈ G s.t. T⁻¹ • T = I; and (linearity) every T ∈ G is a linear function. Besides images, sequences x = (x_1, x_2, ...) are another input of interest, where x ∈ X for some appropriately defined set X. Here, the symmetric group (permutation group) S_n is the set of all permutations S_n = {π | π : {1, ..., n} → {1, ..., n} is a bijection} equipped with the composition operator. Attributed graphs (A, X) ∈ X, where A is a tensor of edge properties and X is a matrix of node attributes, are also of interest for the permutation group S_n.

Subgroups and overgroups. Just as we can compose image transformations to make new image transformations, we can also compose automorphism groups into larger automorphism groups (overgroups). For instance, we can compose rotations and image flips to form a linear automorphism group G_{rot,vflip} = G_rot ∨ G_vflip containing all such compositions, where ∨ is the group join operator. Following standard notation, we say G_rot ≤ G_{rot,vflip} to indicate that G_rot is a subgroup of G_{rot,vflip}, or, equivalently, that G_{rot,vflip} is an overgroup of G_rot. Henceforth, we use G_{1,...,m} ≡ ∪_{i=1}^m G_i to denote the group generated by the groups G_1, ..., G_m.
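For intuition, the group axioms above can be verified numerically for a toy rotation group acting on vectorized 2 × 2 single-channel images (a minimal hypothetical example; the permutation indices below correspond to row-major vectorization):

```python
import numpy as np
from itertools import product

# 90-degree rotation of a vectorized (row-major) 2x2 image as a permutation matrix:
# new pixel i takes its value from old pixel src[i]
src = [1, 3, 0, 2]
T90 = np.zeros((4, 4))
for i, j in enumerate(src):
    T90[i, j] = 1.0

I = np.eye(4)
G_rot = [I, T90, T90 @ T90, T90 @ T90 @ T90]   # rotations by 0, 90, 180, 270 degrees

def in_group(M, G):
    return any(np.array_equal(M, Tg) for Tg in G)

# Check the group axioms (associativity is automatic for matrix products)
assert in_group(I, G_rot)                                                # identity
assert all(in_group(A @ B, G_rot) for A, B in product(G_rot, repeat=2))  # closure
assert all(in_group(A.T, G_rot) for A in G_rot)                          # inverses (T^{-1} = T^T)
print("G_rot satisfies the group axioms")
```

Since each transformation is a permutation matrix, its inverse is its transpose, and every transformation is trivially linear; the asserts check the remaining axioms by brute force.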

3.2. THE CAUSAL MECHANISM AND AN ECONOMICAL DATA GENERATION PROCESS

We assume that a fundamentally economical process created the training data, where the focus was on sampling diverse environments in a way that mattered to the task. For instance, image datasets will contain mostly upright pictures, rather than images over all possible orientations, but we will assume the dataset curators strive for a somewhat diverse set of subjects for each label (e.g., a good representation of different types of subjects and environmental conditions). Hence, the absence of variation over image orientations in the dataset can be counted as evidence against its effect on the image labels. We describe the data generation with the help of a structural causal model (SCM) (Pearl, 2009, Definition 7.1.1), illustrated in Figure 1. Consider a supervised task over inputs X and their corresponding outputs Y, which are random variables defined over a suitable space. The hidden random variable is

X^(hid) := g(U_u), (1)

where g : U → X is a measurable map (deterministic function) that describes the input X in some unknown canonical form, and U_u is a random variable (e.g., U_u ∼ Uniform(0, 1)). Next, we define how X^(hid) is modified by transformations into the observed input X.

Transformation of X^(hid) into X. Consider a collection of finite linear automorphism groups G_1, ..., G_m. Let I ⊆ {1, ..., m} be a subset and D ⊆ {1, ..., m} \ I be a subset of its complement. We will later define the target variable to be dependent only on the groups indexed by D. Consider independent and identically distributed random variables U_D and U_I that select transformations in the respective overgroups G_D = ∪_{j∈D} G_j and G_I = ∪_{i∈I} G_i. We note in passing that we allow G_D ∩ G_I ≠ {T_identity}, even though G_D ∩ G_I = {T_identity} makes the counterfactual task easier. The observed input is defined as

X := T_{U_D,U_I} • X^(hid), where T_{U_D,U_I} = T^(1)_I • T^(1)_D • T^(2)_I • ..., (2)

with each T^(i)_I ∈ G_I and T^(i)_D ∈ G_D. Note that T^(i)_I or T^(i)_D could be identity transformations.
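The generation process above can be sketched in code. Below is a minimal, hypothetical instantiation for sequence inputs: X^(hid) is a sorted sequence, G_D contains a single label-relevant swap, and, following the economical process described above, the transformation from G_I is pinned to the identity (the specific h and all names are illustrative only, and the noise U_Y is omitted for simplicity):

```python
import numpy as np

rng = np.random.default_rng(1)

def sample_economical(n_samples=5, seq_len=4):
    """Economical SCM sampling sketch: U_I has singleton support (always the
    identity), while U_D selects a label-relevant transformation from G_D."""
    data = []
    for _ in range(n_samples):
        x_hid = np.sort(rng.integers(0, 100, seq_len))  # canonical form X^(hid)
        u_d = rng.integers(0, 2)                        # U_D picks T in G_D = {id, swap(0,1)}
        x = x_hid.copy()
        if u_d == 1:
            x[[0, 1]] = x[[1, 0]]                       # T_{U_D}: swap the first two entries
        # U_I <- identity always (economical): no transformation from G_I is applied
        y = (int(x_hid.sum()), int(u_d))                # Y = h(X^(hid), U_D): depends on U_D
        data.append((x, y))
    return data

for x, y in sample_economical():
    print(x, y)
```

Note how the label depends on X^(hid) and on U_D, but never on U_I, which is exactly what licenses the economical choice of a single U_I value in the training data.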
Appendix B.1 shows that this indexing is surjective, i.e., it can index every transformation in G_{D∪I}.

Target variable. The output Y associated with X is given by

Y := h(X^(hid), U_D, U_Y), (3)

where h is a deterministic function and U_Y is an independent random variable. A distribution over the set of background random variables U_all = {U_u, U_Y, U_D, U_I}, along with Equations (2) and (3), induces a joint distribution P(Y, X). If the support of U_I is a singleton set {c} for some constant c, then (Y, X) are said to be sampled using an economical data generation process. In other words, the training data can contain just one value for the variable U_I, since the outputs Y do not depend on U_I. For instance, if G_I is the rotation group and the image label Y does not depend on image rotation, then the observed images can all be upright, since the sampling is economical. This is not a required condition for our method to work, however.

Extrapolation as counterfactual reasoning. We can now ask "what would have happened to Y if we had given specific values of Ũ_I to the data generation process in Equations (2) and (3), rather than sampling from P(U_I)?". For instance, would the class of an image change if we had flipped the image along the vertical axis? Would we re-classify outlier events if we changed the order of events in a stationary time series? These are counterfactual queries over the environment background variables U_I. We now describe the counterfactual variable in our task via variable coupling (Pitman, 1976; Propp and Wilson, 1996), which we believe gives a standard-statistics-friendly description of counterfactual SCMs (Shpitser and Pearl, 2007). The coupling of two independent variables D_1 and D_2 is a proof technique that creates a random vector (D†_1, D†_2) such that D_i and D†_i have the same marginal distributions, i = 1, 2, but makes D†_1 and D†_2 structurally dependent.
For instance, consider independent 6-sided and 12-sided dice, denoted D_1 and D_2 respectively. Let D†_1 = (U + ε_1) mod 6 + 1 and D†_2 = (U + ε_2) mod 12 + 1, where U is a 12-sided die roll and ε_1, ε_2 ∈ {0, 1} are two independent coin flips. Then the tuple (D†_1, D†_2) couples the variables D_1 and D_2 via the common random variable U.

Definition 1 (Counterfactual coupling (CFC)). The counterfactual coupling of the observed data (Y, X) is a vector (Y, X, X_{U_I←Ũ_I}), where Y = h(X^(hid), U_D, U_Y), X = T_{U_D,U_I} • X^(hid), and X_{U_I←Ũ_I} = T_{U_D,Ũ_I} • X^(hid), for appropriately defined U_u, U_D, U_Y, U_I, Ũ_I. The subscript U_I ← Ũ_I denotes the counterfactual variable to X when Ũ_I replaces U_I in the data generation process. For a constant u, X_{U_I←u} gives the same definition as the twin network method of Balke and Pearl (1994).

The support of Ũ_I in Definition 1 can be very different from that of U_I, potentially inducing a different distribution over X_{U_I←Ũ_I} than over X, even though the variables X_{U_I←Ũ_I} and X are structurally dependent via U_D. Armed with Definition 1, we are now ready to describe our task.
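The dice coupling above is easy to verify empirically: both coupled dice have the correct uniform marginals, yet they are structurally dependent through the shared roll U (a consequence of the construction is that (D†_1 − D†_2) mod 6 can only take the values 0, 1, or 5):

```python
import numpy as np

rng = np.random.default_rng(2)
n = 200_000

U  = rng.integers(1, 13, n)   # shared 12-sided die roll
e1 = rng.integers(0, 2, n)    # independent coin flip epsilon_1
e2 = rng.integers(0, 2, n)    # independent coin flip epsilon_2

D1 = (U + e1) % 6 + 1         # coupled 6-sided die
D2 = (U + e2) % 12 + 1        # coupled 12-sided die

# Marginals are uniform over {1..6} and {1..12}...
print(np.bincount(D1)[1:] / n)   # each entry close to 1/6
print(np.bincount(D2)[1:] / n)   # each entry close to 1/12
# ...but the pair is structurally dependent through U:
print(np.unique((D1 - D2) % 6))  # only {0, 1, 5} ever occur
```

This mirrors the role of U_D in Definition 1: the counterfactual X_{U_I←Ũ_I} may follow a different distribution than X, but the two are tied together by the shared background variable.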

3.3. EXTRAPOLATION MODEL

We start by defining counterfactual G-invariant (CG-invariant) representations.

Definition 2 (CG-invariant representations). Let the vector (X, X_{U_I←Ũ_I}) denote the counterfactual coupling of the random variable X given in Definition 1, for any Ũ_I. A representation function Γ : X → R^d, d ≥ 1, is deemed CG-invariant if Γ(X) = Γ(X_{U_I←Ũ_I}), where the equality implies that Γ(X_{U_I←u}) = Γ(X_{U_I←ũ}) for all u ∈ supp(U_I) and ũ ∈ supp(Ũ_I).

Assume the true model satisfies

Y|X^(tr) =_d Ŷ|X^(tr), with Ŷ|X^(tr) ∼ g_true(Γ_true(X^(tr))), (6)

where =_d means the random variables have the same distribution. Then, if Γ_true(X) = Γ_true(X_{U_I←Ũ_I}), we have that, by our definition of X^(te) and X^(tr), g_true ∘ Γ_true extrapolates:

Y|X^(te) =_d Ŷ|X^(te), with Ŷ|X^(te) ∼ g_true(Γ_true(X^(te))).

Alas, learning Γ_true is the real challenge: (i) we do not know I (and, hence, we do not know the group G_I to which the CG-invariance is tied); (ii) this would also require knowing P(Ũ_I), which we do not. Without an observed X_{U_I←Ũ_I}, the statistical assumption that examples not explicitly observed, even with infinitely large training data, have underspecified outcomes in the learner's statistical model does not push the model towards learning Γ_true. We must change this assumption.

4. CG-INVARIANCES FOR EXTRAPOLATION

In this section we introduce our learning framework, which seeks to use the training data to approximate Γ_true and g_true of Equation (6). Our framework regularizes neural network weights towards representations that are invariant to groups that negligibly impact training data accuracy. We overcome some key challenges: (a) Theorem 1 below shows that CG-invariances (Definition 2) are stronger than G-invariances; Theorem 2 then defines conditions under which G-invariances suffice as CG-invariances. (b) We derive an optimization objective where all G-invariances are mandatory, except the ones deemed inconsistent with the training data, replacing the traditional unseen-is-underspecified learning hypothesis.

Our first question is whether CG-invariances are just G-invariances. Theorem 1 shows they are not.

Theorem 1 (CG-invariance is stronger than G-invariance). Let the vector (X, X_{U_I←Ũ_I}) denote the counterfactual coupling of the observed variable X given in Definition 1. For a representation Γ : X → R^d, d ≥ 1, let

G-inv: ∀T_I ∈ G_I, Γ(X) = Γ(T_I • X),
CG-inv: Γ(X) = Γ(X_{U_I←Ũ_I}),

denote the conditions on Γ for G_I-invariance and CG-invariance, respectively. Then CG-inv =⇒ G-inv, but G-inv ≠⇒ CG-inv.

The proof in Appendix B.2 constructs a task over images and a representation Γ that is G_I-invariant but not CG-invariant (for appropriately chosen G_I and G_D). The following condition ensures that a G_I-invariance is also a CG-invariance.

Theorem 2. If G_I is a normal subgroup of G_{D∪I}, then CG-inv ⇐⇒ G-inv.

A subgroup H of a group G is called normal (denoted H ⊴ G) if for all h ∈ H and g ∈ G, ghg⁻¹ ∈ H. The proof in Appendix B.2 uses the fact that if G_I ⊴ G_{D∪I}, then any T ∈ G_{D∪I} can be written as T = T_I • T_D for some T_I ∈ G_I, T_D ∈ G_D. Throughout the rest of the paper, we will assume that G_I is a normal subgroup of G_{D∪I} in the SCM Equation (2).
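The normality condition of Theorem 2 can be checked computationally for concrete matrix groups. The sketch below (a toy example with permutation matrices, not our experimental setup) verifies that the cyclic subgroup of S_3 is normal, while a single-transposition subgroup is not:

```python
import numpy as np
from itertools import permutations

def perm_matrix(p):
    # Permutation matrix: new position i takes its value from old position p[i]
    n = len(p)
    M = np.zeros((n, n))
    for i, j in enumerate(p):
        M[i, j] = 1.0
    return M

# G = S_3 acting on length-3 sequences; H = cyclic subgroup {id, (0 1 2), (0 2 1)}
G = [perm_matrix(p) for p in permutations(range(3))]
H = [perm_matrix(p) for p in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]]

def is_normal(H, G):
    # H is normal in G iff g h g^{-1} is in H for all g in G, h in H
    return all(
        any(np.array_equal(g @ h @ g.T, h2) for h2 in H)  # g^{-1} = g^T for permutations
        for g in G for h in H
    )

print(is_normal(H, G))  # True: the cyclic (alternating) subgroup is normal in S_3
K = [perm_matrix((0, 1, 2)), perm_matrix((1, 0, 2))]  # {id, swap(0,1)}
print(is_normal(K, G))  # False: conjugating a transposition can give a different transposition
```

Such a brute-force check is feasible because all groups here are finite and small; it is the matrix analogue of the definition ghg⁻¹ ∈ H above.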

4.1. CONSTRUCTING SUBSPACES OF VEC(X) PARTIALLY ORDERED BY INVARIANCE STRENGTH

As discussed before, we do not know G_I. In this subsection, we build neural network weights that are invariant to G_M for different subsets M ⊆ {1, ..., m}. A detailed step-by-step example of this construction for 3 × 3 images is shown in Appendix C. We start by restating the Reynolds operator, which has been extensively used in the literature on G-invariant representations, often without attribution:

Lemma 1 (Reynolds operator (Mumford et al. (1994), Definition 1.5)). Let G be a (finite) linear automorphism group over vec(X). Then

T̄ = (1/|G|) ∑_{T∈G} T (7)

is a G-invariant linear map, i.e., ∀T† ∈ G and ∀x ∈ vec(X), it must be that T̄(T†x) = T̄x.

Since T̄ is a projection operator (i.e., T̄² = T̄), all the eigenvalues of T̄ are either 0 or 1. Using this fact, we now describe G-invariant neurons using the left eigenspace of T̄ corresponding to the eigenvalue 1.

Lemma 2. If W denotes the left eigenspace corresponding to the eigenvalue 1 of the Reynolds operator T̄ for the group G, then ∀b ∈ R, the linear transformation γ(x; w, b) = wᵀx + b is invariant to all transformations T ∈ G, i.e., γ(Tx; w, b) = γ(x; w, b), if and only if w ∈ W.

The above property of the Reynolds operator can be leveraged to build neural networks that adhere to particular group symmetries, as done by Yarotsky (2018) and van der Pol et al. (2020). If we knew G_I, restricting the parameters of each neuron to the left 1-eigenspace of the Reynolds operator of G_I would give us a way to build a G_I-invariant neural network. Alas, we do not know I, and consequently we do not know G_I. Instead, we want to construct bases for the complete space vec(X) that are partially ordered by their invariance strength, from most invariant to least. In other words, we construct bases for subspaces B_M, M ⊆ {1, ..., m}, such that any weight vector w ∈ B_M is (a) invariant to the groups G_i for i ∈ M, and (b) not invariant to any group G_j for j ∈ {1, ..., m} \ M.
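Lemmas 1 and 2 can be verified numerically. The sketch below (a minimal hypothetical example) builds the Reynolds operator for the two-element group that swaps the first two coordinates of R³, extracts the left 1-eigenspace, and checks that the resulting weight vectors give G-invariant linear maps:

```python
import numpy as np

def reynolds(G):
    # Reynolds operator: average of all transformations in the group (Lemma 1)
    return sum(G) / len(G)

def invariant_basis(Tbar, tol=1e-8):
    # Orthonormal basis of the left 1-eigenspace of the Reynolds operator (Lemma 2)
    vals, vecs = np.linalg.eig(Tbar.T)          # left eigenvectors of Tbar
    W = vecs[:, np.abs(vals - 1.0) < tol].real
    Q, _ = np.linalg.qr(W)                      # orthonormalize
    return Q

# G = {identity, swap coordinates 0 and 1} acting on R^3
I = np.eye(3)
S = np.array([[0., 1., 0.],
              [1., 0., 0.],
              [0., 0., 1.]])
W = invariant_basis(reynolds([I, S]))

# Every w in the 1-eigenspace gives a G-invariant neuron: w^T (T x) = w^T x
x = np.array([3.0, -1.0, 7.0])
assert np.allclose(W.T @ (S @ x), W.T @ x)
print(W.shape)  # (3, 2): a 2-dimensional invariant subspace of R^3
```

Here the invariant subspace is spanned (up to rotation) by (1, 1, 0)/√2 and (0, 0, 1): the weights that treat the two swapped coordinates symmetrically, plus the untouched third coordinate.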
Later, we will use this partial order to define a regularization term for our method. Theorem 3 shows how these bases can be constructed inductively, where we start with the most invariant subspace (when M = {1, ..., m}) and judiciously work our way over increasingly less invariant subspaces. A reader more interested in the algorithm can first refer to the pseudocode in Appendix D or the example in Appendix C (Step 2).

Theorem 3 (G-invariant subspace bases can be partially ordered by invariance strength). Let W_i ⊆ vec(X) be the left eigenspace corresponding to the eigenvalue 1 of the Reynolds operator T̄_i for group G_i, i = 1, ..., m. We construct the invariant subspace partitions

B̃_M = ∩_{i∈M} W_i;  B_M = orth_{B̄_M}(B̃_M), ∀M ∈ ℘({1, ..., m}) \ ∅,

where ℘ is the power set, B̄_M = ⊕_{N⊋M} B_N, orth_{A_1}(A_2) removes from the subspace A_2 its orthogonal projection onto the subspace A_1, and ⊕ is the direct sum operator. Then the linear transformation γ(x; w, b) = wᵀx + b, b ∈ R, ∀w ∈ B_M \ {0}, is G_M-invariant but not G_j-invariant ∀j ∈ {1, ..., m} \ M.

The proof in Appendix B.3 shows that B̃_M contains all the vectors w that are invariant to G_M, but could also contain vectors that are invariant to some overgroup of G_M. Thus, each step of our inductive method performs a Gram-Schmidt orthogonalization in order to satisfy condition (b) above: we need to remove from B̃_M all weight vectors that are invariant to more groups in addition to those indexed by M (i.e., supersets of M). In addition, if needed, we obtain the basis for the rest of the space through B_∅ = orth_{B̄_∅}(vec(X)), the orthogonal complement of B̄_∅. Note that if w ∈ B_N, then w is never G_H-invariant for H ⊋ N, as we remove all such w from B̃_N. Hence, the partial order of nested subsets in ℘({1, ..., m}) induces a partial order of invariance strengths in the bases of the input domain vec(X) (see Figure 5 for an example).
We define the level of invariance (or invariance strength) of a subspace B_M as the size of M (i.e., |M|).

Practical aspects. Our algorithm should output d_X = dim(vec(X)) basis vectors covering the entire space (i.e., our new neuron, described later in Equation (10), still has d_X + 1 parameters, like the original one). Thus, we stop the algorithm in Theorem 3 once d_X basis vectors are found. Moreover, the algorithm needs to run only once for the groups G_1, ..., G_m, and the results can be reused for other neural architectures. While the worst-case runtime of finding the bases could be exponential in m, it is unclear whether this exponential runtime actually occurs in practice (all of our experimental runs take less than one minute on commodity machines).
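The construction in Theorem 3 can be sketched for m = 2 groups, G_1 = {id, swap(0,1)} and G_2 = {id, swap(1,2)}, acting on length-3 sequences. In this sketch the intersection B̃_{1,2} = W_1 ∩ W_2 is computed via the 1-eigenspace of the averaged projectors (an implementation choice of ours, not prescribed by the theorem), and the Gram-Schmidt step then removes B_{1,2} from each W_i:

```python
import numpy as np

def eigenspace_one(M, tol=1e-8):
    # Orthonormal basis of the eigenvalue-1 eigenspace of a symmetric matrix
    vals, vecs = np.linalg.eigh((M + M.T) / 2)
    return vecs[:, np.abs(vals - 1.0) < tol]

def reynolds(G):
    return sum(G) / len(G)

def orth_against(A1, A2, tol=1e-8):
    # Remove from subspace A2 its orthogonal projection onto subspace A1
    if A1.shape[1] == 0:
        return A2
    R = A2 - A1 @ (A1.T @ A2)
    Q, s, _ = np.linalg.svd(R, full_matrices=False)
    return Q[:, s > tol]

I3 = np.eye(3)
S01 = I3[[1, 0, 2]]    # swap positions 0 and 1
S12 = I3[[0, 2, 1]]    # swap positions 1 and 2
W1 = eigenspace_one(reynolds([I3, S01]))   # weights invariant to G_1
W2 = eigenspace_one(reynolds([I3, S12]))   # weights invariant to G_2

# Most invariant subspace first: B_{1,2} = W_1 ∩ W_2
B12 = eigenspace_one((W1 @ W1.T + W2 @ W2.T) / 2)
# Then remove B_{1,2} from each W_i to get the strictly-less-invariant bases
B1 = orth_against(B12, W1)
B2 = orth_against(B12, W2)

print(B12.shape[1], B1.shape[1], B2.shape[1])  # dimensions 1, 1, 1
```

The most invariant subspace here is the span of (1, 1, 1), invariant to both swaps; B_1 and B_2 each add one direction invariant to exactly one group, and together the three bases already cover all of R³, so B_∅ is empty.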

4.2. LEARNING CG-INVARIANT REPRESENTATIONS WITHOUT KNOWLEDGE OF G_I

We are now ready to learn a CG-invariant representation using neural networks Γ and g. Let G_1, ..., G_m be known linear automorphism groups. Under the assumption of Theorem 2, we just need Γ to be G_I-invariant, with G_I ≡ ∪_{i∈I} G_i, but I ⊆ {1, ..., m} is unknown to us. We achieve the correct G_I-invariance by redefining the neuron weights of Γ using the subspaces of Theorem 3 and proposing a regularized objective that pushes Γ towards the strongest overgroup G-invariance that does not significantly hurt training accuracy, where "significantly" is controlled by a regularization strength λ > 0.

More formally, let

Γ : vec(X) × R^{d_X×H} × R^H → R^d, H ≥ 1, d ≥ 1, be a neural network layer with H neurons, parameterized by free parameters Ω ∈ R^{d_X×H} and b ∈ R^H. The H neurons are arranged in an appropriate architecture as described in Section 5, but the reader can imagine a feedforward layer for now. Let g : R^d → Im P(Y|X) be a link function. The training data D^(tr) = {(y^(tr)_i, x^(tr)_i)}_{i=1}^N is assumed to be sampled according to the SCM data generation process in Equations (1) to (3), with the hidden G_I satisfying the conditions in Theorem 2. Let B_M ∈ R^{d_X×d_M} be a matrix whose columns form an orthogonal basis of a subspace B_M ≠ {0} (from Theorem 3) with dimension d_M. Any vector w ∈ B_M can be expressed as a linear combination of these basis columns, and the coefficients of the linear combination form our learnable parameters. These neuron weights Ω have a correspondence to the nonzero subspace bases B_{M_1}, B_{M_2}, ..., B_{M_B}:

Ω = [ω_{M_1,1} ··· ω_{M_1,H}; ... ; ω_{M_B,1} ··· ω_{M_B,H}],

where B ≤ d_X and ω_{M_i,h} ∈ R^{d_{M_i}×1} represents the learnable parameters for the subspace B_{M_i} and the h-th neuron. The h-th neuron in Γ, h ∈ {1, ..., H}, has the form

Γ^(h)(x) = σ(xᵀ ∑_{i=1}^B B_{M_i} ω_{M_i,h} + b_h), (10)

where σ(·) is a nonpolynomial activation function and b_h ∈ R is a bias parameter. Our optimization objective is then

(Ω̂, b̂, Ŵ_g) = arg min_{Ω,b,W_g} ∑_{(y^(tr),x^(tr))∈D^(tr)} L(y^(tr), g(Γ(x^(tr); Ω, b); W_g)) + λR(Ω), (11)

where L : Y × Im P(Y|X) → R_{≥0} is a nonnegative loss function and λ > 0 is a regularization strength. The regularization penalty R(Ω) is given by

R(Ω) = |{M_i : |M_i| > l, 1 ≤ i ≤ B}| + ∑_{i : |M_i| = l, 1 ≤ i ≤ B} 1{‖ω_{M_i,·}‖²₂ > 0}, (12)

where l = min{|M_i| : ‖ω_{M_i,·}‖²₂ > 0, 1 ≤ i ≤ B} is the least level of invariance among the used subspaces.

Intuition behind the penalty in Equation (12): A subspace B_{M_i} is said to be used in the computation of neuron h (Equation (10)) if the corresponding parameter ω_{M_i,h} is nonzero.
Then, let B_{M_k} be the least invariant subspace used by any neuron (i.e., |M_k| is the lowest among all used subspaces) and |M_k| = l. The first term in the penalty counts the number of subspaces B_{M_i} (used or unused) that are invariant to more groups than B_{M_k} (i.e., |M_i| > |M_k|). This term ensures that the optimization tries to use subspaces that are higher in the partial order, with invariance to more groups. The second term counts the number of subspaces B_{M_i} that have the same level of invariance as B_{M_k} (i.e., |M_i| = |M_k|) and have nonzero coefficients ω_{M_i,h} (i.e., the subspace B_{M_i} is used). The larger the second term, the farther away the optimization is from increasing the least level of invariance from l to l + 1. We present a differentiable approximation of the penalty in Appendix F, along with an example computation of Equation (12) in Figure 8.

Limitations of Equation (12): Recall that we stop the algorithm in Theorem 3 once a basis for vec(X) is found. In such cases, there could be parameters Ω and Ω′ that assign positive weights to the same subspace bases, but with Ω invariant to more groups than Ω′. The penalty in Equation (12), however, cannot distinguish between these two sets of weights, as they use the same subspaces and thus R(Ω) = R(Ω′). We provide an example for sequence inputs in Appendix F.3 and leave the solution as future work.

Selecting the regularization strength λ: We use a held-out training set to find the best validation accuracy achieved by any value of λ. Then, among all the values of λ that achieve validation accuracy within 5% of the best, we choose the largest λ (i.e., we opt for maximum invariance without significantly affecting validation performance).
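Under this reading, the penalty of Equation (12) depends only on which subspaces are used and on their invariance levels. A sketch (the `levels` and `used` summaries of Ω are hypothetical helpers, not notation from the text):

```python
def penalty(levels, used):
    """R(Omega) from Equation (12), sketched.
    levels[i] = |M_i|, the invariance level of subspace B_{M_i};
    used[i]   = True iff some neuron has a nonzero coefficient ||omega_{M_i,.}|| > 0."""
    used_levels = [lv for lv, u in zip(levels, used) if u]
    if not used_levels:
        return 0
    l = min(used_levels)  # least invariance level among used subspaces
    term1 = sum(1 for lv in levels if lv > l)                        # subspaces above level l
    term2 = sum(1 for lv, u in zip(levels, used) if u and lv == l)   # used subspaces at level l
    return term1 + term2

# Example: four subspaces at levels |M| = 2, 1, 1, 0 (most to least invariant)
print(penalty([2, 1, 1, 0], [True, False, False, False]))  # only the most invariant used -> 0 + 1 = 1
print(penalty([2, 1, 1, 0], [True, True, False, False]))   # a level-1 subspace used -> 1 + 1 = 2
print(penalty([2, 1, 1, 0], [True, True, True, True]))     # the level-0 subspace used -> 3 + 1 = 4
```

The penalty is smallest when only the most invariant subspaces carry weight, so minimizing it with λ > 0 pushes the network up the partial order unless the loss term objects.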

5. CG-INVARIANT NEURAL ARCHITECTURES

For image tasks: We can apply the CG-regularization of Equations (10) and (11) in the convolutional layers of a CNN architecture like VGG (Simonyan and Zisserman, 2014). The VGG architecture remains mostly the same, except that the convolutional filters are obtained using the subspaces from Theorem 3 for the given groups. Once a filter is obtained as a linear combination of the bases, it is convolved with the image or the feature maps. This ensures that the model is CG-invariant to transformations of smaller patches of the image. A sum-pooling layer over the entire channel is applied after all the convolutional layers to ensure that the model can be CG-invariant to transformations of the whole image. See Appendix E.1 for an example architecture.

For sequence and array tasks (set, graph, and tensor tasks), the architecture is more direct: one can simply apply a feedforward network with as many hidden layers as needed. Each neuron of the first layer is as given by Equation (10), ensuring that the first layer can be CG-invariant to the given groups if needed. The other layers can have regular neurons, since stacking dense layers after a CG-invariant layer does not undo the CG-invariance. See Appendix E.2 for an example architecture.
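For the sequence case, a first layer following the form of Equation (10) can be sketched in a few lines (a hypothetical minimal example with hand-written bases for the group {identity, swap(0,1)} on length-3 inputs; setting the coefficients on the non-invariant subspace to zero, as the regularizer encourages, makes the layer invariant to the swap):

```python
import numpy as np

rng = np.random.default_rng(3)

# Bases for length-3 sequences and G = {identity, swap(0,1)} (hypothetical minimal example)
B_inv  = np.array([[1, 0], [1, 0], [0, 1]]) / np.array([np.sqrt(2), 1.0])  # invariant subspace
B_rest = np.array([[1], [-1], [0]]) / np.sqrt(2)                           # non-invariant remainder

def layer(x, omega_inv, omega_rest, b):
    # One neuron per column: sigma(x^T (B_inv w_inv + B_rest w_rest) + b), sigma = ReLU
    w = B_inv @ omega_inv + B_rest @ omega_rest
    return np.maximum(0.0, x @ w + b)

H = 4
omega_inv  = rng.normal(size=(2, H))
omega_rest = np.zeros((1, H))          # regularization drove these to zero -> invariance
b = rng.normal(size=H)

x = np.array([5.0, -2.0, 1.0])
x_swapped = x[[1, 0, 2]]
print(np.allclose(layer(x, omega_inv, omega_rest, b),
                  layer(x_swapped, omega_inv, omega_rest, b)))  # True: the layer is G-invariant
```

If any `omega_rest` coefficient were nonzero, the layer would remain sensitive to the swap, which is exactly the antisymmetric information the objective preserves when the training data demands it.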

6. EMPIRICAL RESULTS

We now provide empirical results on 12 different tasks to showcase the properties and advantages of our framework. Due to space limitations, our results are only briefly summarized here, with most of the details described in Appendix G. Appendix A also shows a task where CG-invariance is stronger than G-invariance, demonstrating the practical relevance of Theorem 1.

Validation of our learning framework (CGreg):

In 12 different image and sequence tasks, we confirmed that our CG-regularization of Equation (11) is able to selectively learn to be invariant to the largest overgroup that does not contradict the training data, all without any evidence in the data supporting the invariance. The results are summarized in Table 1, which also shows that both standard neural networks and forced G-invariant networks fail to extrapolate to new environments when I ≠ ∅ and I ⊊ {1, ..., m}, respectively.

X^(hid) and transformation groups: X^(hid) is the canonically ordered input (e.g., upright images, sorted sequences). Our task considers m linear automorphism groups G_1, ..., G_m. We generate G_I from a subset I ⊆ {1, ..., m} of the groups, i.e., G_I = ∪_{i∈I} G_i. We construct G_D using a subset of {1, ..., m} \ I, while ensuring that G_I ⊴ G_{D∪I} in order to fulfill the conditions of Theorem 2. For image tasks, X^(hid) is an upright MNIST image and the m = 3 groups are G_rot, G_color, G_vertical-flip. For sequence tasks, we sample X^(hid) as a sequence of n sorted integers from a fixed vocabulary and consider m = n(n−1)/2 permutation groups for all the pairwise permutations: G_{1,2}, G_{2,3}, G_{1,3}, ..., G_{n−1,n}, where G_{i,j} := {T_identity, T_{i,j}} and T_{i,j} swaps positions i and j in the sequence.

Training data: The training data is sampled via the SCM equations using an economical data generation process. We decompose the transformation T_{U_I,U_D} into a transformation T_{U_D} ∈ G_D followed by another transformation T_{U_I|U_D} ∈ G_I, to obtain X = T_{U_I|U_D} • T_{U_D} • X^(hid). This decomposition is made possible by our assumption that G_I is a normal subgroup of G_{D∪I} (Theorem 2). Under economical sampling of the training data, in all our experiments we simply set T_{U_I|U_D} = T_identity ∈ G_I, whereas T_{U_D} is randomly sampled from G_D.
Finally, following Equation (3), the label Y is a combination of the original label of X^(hid) and the transformation T_{U_D}.

Extrapolation task: The extrapolated test data consists of samples from the coupled random variable X_{U_I←Ũ_I} = T_{U_D,Ũ_I} • X^(hid) (Definition 1). As before, we decompose T_{U_D,Ũ_I} = T_{Ũ_I|U_D} • T_{U_D}, with T_{U_D} ∈ G_D and T_{Ũ_I|U_D} ∈ G_I. Example (Table 1, row: rot,vflip): For image tasks, if G_I = G_{rot,vertical-flip} and G_D = G_color, then the extrapolation test data consists of images randomly rotated, flipped, and color-permuted, while the task is the same: predict the digit and its color.

Results: Standard neural networks such as CNNs (e.g., VGG (Simonyan and Zisserman, 2014)) for images and GRUs/Transformers (Cho et al., 2014; Vaswani et al., 2017) for sequences fail whenever the extrapolation task requires some invariance (I ≠ ∅), but excel at the interpolation task (I = ∅). Adding forced G_{D∪I}-invariances via G-CNNs (Cohen and Welling, 2016) for images and permutation-invariant models (Lee et al., 2019; Murphy et al., 2019a; Zaheer et al., 2017) for sequences clearly fails when D ≠ ∅ but succeeds when D = ∅. Our CG-regularized neural network representations, on the other hand, achieve high extrapolation accuracy across all tasks, for all choices of I ⊆ {1, ..., m} and D ⊆ {1, ..., m} \ I. These results plainly show that our approach is able to selectively learn to be invariant only to the appropriate groups. Furthermore, this G_I-invariance is achieved without any evidence in the training data, thanks to our novel learning paradigm that considers all G-invariances mandatory unless contradicted by the training data.

7. CONCLUSION

This work studied the task of learning representations that can extrapolate beyond the training data distribution (environment), even when presented with a single training environment. We considered the case of (counterfactual) extrapolation over linear automorphism groups and described a framework where all G-invariances (and CG-invariances via Theorem 2) are mandatory, except the ones deemed inconsistent with the training data (i.e., rather than learning G-invariances, we unlearn them). Our framework replaces the standard statistical learning hypothesis, that unseen data means underspecified models, with a learning hypothesis that forces models to have all (known) G-invariances (symmetries) that do not contradict the data, and our empirical results support the proposed approach. Finally, this learning paradigm offers a promising novel research direction for neural network extrapolations.

Any T ∈ G_{D∪I} can be written as T = T_1 ∘ T_2 ∘ T_3 ∘ ⋯, where each T_i is in either G_I or G_D. Then, if T_1 ∈ G_D, we can write T_1 = T_I^(1) ∘ T_D^(1) with T_I^(1) = T_identity ∈ G_I and T_D^(1) = T_1 ∈ G_D. Continuing in a similar fashion, we can find two sequences of transformations, one from G_I and the other from G_D, such that interleaving and composing the resulting sequences gives any transformation from G_{D∪I}. This ability of the noise variables to index any T ∈ G_{D∪I} will be used in the proofs of Theorems 1 and 2.
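The interleaving argument above can be sketched in a few lines. This is a toy illustration in which transformations are represented only by tags (`'I'`/`'D'`) and `'id'` stands for T_identity; `interleave` is a hypothetical helper, not part of the paper's construction:

```python
def interleave(word):
    """Split a composition T1 o T2 o ... (each factor tagged 'I' or 'D') into
    two sequences, one per group, padding with the identity so that
    interleaving them as I, D, I, D, ... reproduces the original order."""
    seq_I, seq_D = [], []
    expect = 'I'
    for tag, t in word:
        while tag != expect:                 # pad the other group's slot
            (seq_I if expect == 'I' else seq_D).append('id')
            expect = 'D' if expect == 'I' else 'I'
        (seq_I if tag == 'I' else seq_D).append(t)
        expect = 'D' if tag == 'I' else 'I'
    return seq_I, seq_D

# T = T1 o T2 o T3 with T1, T2 in G_D and T3 in G_I:
seq_I, seq_D = interleave([('D', 'T1'), ('D', 'T2'), ('I', 'T3')])
# seq_I = ['id', 'id', 'T3'], seq_D = ['T1', 'T2']
```

Composing the interleaved output id ∘ T1 ∘ id ∘ T2 ∘ T3 recovers the original word, which is exactly the identity-padding step used in the proofs.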

B.2 PROOF OF THEOREMS 1 AND 2

Theorem 1 (CG-invariance is stronger than G-invariance). Let the vector (X, X_{U_I ← U'_I}) denote the counterfactual coupling of the observed variable X given in Definition 1. For a representation Γ : X → R^d, d ≥ 1, let

G-inv: ∀T_I ∈ G_I, Γ(X) = Γ(T_I ∘ X),
CG-inv: Γ(X) = Γ(X_{U_I ← U'_I}),

denote the conditions on Γ for G_I-invariance and CG-invariance, respectively. Then CG-inv ⟹ G-inv, but G-inv ⇏ CG-inv.

Figure 3: Counterexample showing that G_I-invariance does not imply CG-invariance. Given images of a rod (shown in brown), we wish to predict the orientation of the rod, i.e., whether the rod is upright or flat. In this example, G_D = G_rot and G_I = G_h-translate, since a horizontal translation does not affect the orientation of the rod. Γ : X → R sums the pixel values across the green shaded region and is clearly G-invariant to horizontal translations. However, Γ is not CG-invariant.

Proof. First, we show that CG-invariance ⟹ G-invariance, i.e., for any CG-invariant representation Γ : X → R^d, we show that Γ is also G-invariant to G_I. Consider any u ∈ supp(U_I) and suppose the input was generated as X_{U_I←u} = T_{U_D,U_I←u} ∘ X^(hid), i.e., U_I took the value u in the structural causal equation generating the observed input (Equation (2)). We will prove G-invariance for this input, i.e., Γ(T_I^† ∘ X_{U_I←u}) = Γ(X_{U_I←u}) for any T_I^† ∈ G_I. Recall that T_{U_D,U_I←u} was generated by interleaving two separate sequences of transformations obtained via the background variables U_D and U_I respectively (Appendix B.1). In other words, we can write T_{U_D,U_I←u} = T_I^(1) ∘ T_D^(1) ∘ T_I^(2) ∘ ⋯ ∘ T^(*), where T^(*) is in G_I or G_D depending on which of the two sequences before interleaving is longer. Then T_I^† ∘ T_{U_D,U_I←u} = T_I^† ∘ T_I^(1) ∘ T_D^(1) ∘ T_I^(2) ∘ ⋯ ∘ T^(*). Further, if we write T̃_I^(1) = T_I^† ∘ T_I^(1), then T_I^† ∘ T_{U_D,U_I←u} = T̃_I^(1) ∘ T_D^(1) ∘ T_I^(2) ∘ ⋯ ∘ T^(*). Now we can find a u' such that U_I ← u' generates the sequence of transformations T̃_I^(1), T_I^(2), . . .. Interleaving this sequence with the sequence generated by U_D, we get T_{U_D,U_I←u'} = T̃_I^(1) ∘ T_D^(1) ∘ T_I^(2) ∘ ⋯ ∘ T^(*). Denote X_{U_I←u'} = T_{U_D,U_I←u'} ∘ X^(hid). Since Γ is CG-invariant, we have from Definition 2 that

Γ(X_{U_I←u}) = Γ(X_{U_I←u'}) = Γ(T_{U_D,U_I←u'} ∘ X^(hid)) = Γ(T_I^† ∘ T_{U_D,U_I←u} ∘ X^(hid)) (from the construction of u') = Γ(T_I^† ∘ X_{U_I←u}).

Since this holds for all u ∈ supp(U_I), we have Γ(X) = Γ(T_I^† ∘ X).
Next, we show that G-invariance ⇏ CG-invariance by constructing a counterexample. Let X^(hid) ∈ R^{(2n+1)×(2n+1)} be the (2n+1)×(2n+1) grayscale image of an upright rod as shown in Figure 3. Consider two groups that act on this image: the rotation group G_rot = {T^(k)}_{k∈{0°,90°,180°,270°}} and the cyclic horizontal-translation group G_h-translate = {T^(+u)}_{u∈Z_{2n+1}}. Let G_D = G_rot and G_I = G_h-translate, and let the label Y of the image be deterministically given by the orientation of the rod: upright (Y = 0) or flat (Y = 1). The top row of Figure 3 depicts the training data, which is transformed by G_h-translate only via the identity T^(+0) (i.e., no translation). Now consider a representation Γ : R^{(2n+1)×(2n+1)} → R such that Γ(X) = Σ_{i=1}^{2n+1} X_{n,i}, the sum of the middle row of the image. Note that (a) Γ is able to distinguish between the labels in the training data, and (b) Γ is G_h-translate-invariant. We can define the random variables U_I and U'_I such that X = T^(90°) ∘ X^(hid) and X_{U_I ← U'_I} = T^(90°) ∘ T^(+5) ∘ X^(hid). Then, as shown in Figure 3, Γ(X_{U_I ← U'_I}) = Γ(T^(90°) ∘ T^(+5) ∘ X^(hid)) ≠ Γ(T^(90°) ∘ X^(hid)), thus showing that Γ is not CG-invariant.

Theorem 2. If G_I is a normal subgroup of G_{D∪I}, then CG-inv ⟺ G-inv.

Proof. The proof that CG-invariance ⟹ G-invariance (from Theorem 1) still holds here. We only need to prove the converse: G-invariance ⟹ CG-invariance when G_I is a normal subgroup of G_{D∪I}. We begin with a representation Γ that is G_I-invariant and consider the simpler case where U_D generates a transformation sequence of length 1 (from G_D). In other words, X is obtained as X = T_I^(1) ∘ T_D ∘ T_I^(2) ∘ X^(hid) for arbitrary transformations T_D ∈ G_D and T_I^(1), T_I^(2) ∈ G_I. Then for any U'_I, we have X_{U_I ← U'_I} = T̃_I^(1) ∘ T_D ∘ T̃_I^(2) ∘ X^(hid) with T̃_I^(1), T̃_I^(2) ∈ G_I. Note that U'_I only affects the transformations from G_I.
The condition for CG-invariance with respect to G_I requires that

Γ(X) = Γ(T_I^(1) ∘ T_D ∘ T_I^(2) ∘ X^(hid)) = Γ(T̃_I^(1) ∘ T_D ∘ T̃_I^(2) ∘ X^(hid)) = Γ(X_{U_I ← U'_I}).   (13)

Since G_I is a normal subgroup of G_{D∪I} and G_D ≤ G_{D∪I}, we have ∀T_D ∈ G_D, ∀T_I ∈ G_I, T_D ∘ T_I ∘ T_D^{-1} ∈ G_I, or equivalently,

∀T_D ∈ G_D, ∀T_I ∈ G_I, ∃T'_I s.t. T_D ∘ T_I ∘ T_D^{-1} = T'_I ⟹ T_D ∘ T_I = T'_I ∘ T_D.   (14)

(A special case is when the groups G_D and G_I commute, as then T_D ∘ T_I = T_I ∘ T_D.) Then,

Γ(X) = Γ(T_I^(1) ∘ T_D ∘ T_I^(2) ∘ X^(hid))
     = Γ(T_D ∘ T_I^(2) ∘ X^(hid))     (Γ is invariant to G_I)
     = Γ(T'_I ∘ T_D ∘ X^(hid))        (there exists such a T'_I ∈ G_I, by Equation (14))
     = Γ(T_D ∘ X^(hid)).              (Γ is invariant to G_I)

Similarly, we can prove for the coupled variable that Γ(X_{U_I ← U'_I}) = Γ(T̃_I^(1) ∘ T_D ∘ T̃_I^(2) ∘ X^(hid)) = Γ(T_D ∘ X^(hid)), thus satisfying the requirement of CG-invariance in Equation (13). The extension to the case where U_D generates transformation sequences of length greater than one is straightforward. Any transformation T_{U_D,U_I} = T_I^(1) ∘ T_D^(1) ∘ T_I^(2) ∘ ⋯ ∘ T^(*) can be written in the form T_I^† ∘ T_D^(1) ∘ T_D^(2) ∘ ⋯ by repeatedly applying the normal-subgroup property in Equation (14). Then Γ(T_{U_D,U_I} ∘ X^(hid)) = Γ(T_I^† ∘ T_D^(1) ∘ T_D^(2) ∘ ⋯ ∘ X^(hid)) = Γ(T_D^(1) ∘ T_D^(2) ∘ ⋯ ∘ X^(hid)), as Γ is G_I-invariant. Using a similar argument for the coupled variable, Γ(T_{U_D,U'_I} ∘ X^(hid)) = Γ(T_D^(1) ∘ T_D^(2) ∘ ⋯ ∘ X^(hid)), thus proving that Γ is CG-invariant, i.e., Γ(X) = Γ(X_{U_I ← U'_I}).
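The Figure 3 counterexample to Theorem 1's converse can be checked numerically. A minimal sketch assuming an 11×11 image (n = 5) with the rod in the middle column, cyclic horizontal translation implemented with `np.roll`, and Γ as the middle-row sum:

```python
import numpy as np

n = 5
side = 2 * n + 1                       # an 11 x 11 grayscale image
x_hid = np.zeros((side, side))
x_hid[:, n] = 1.0                      # X^(hid): upright rod in the middle column

def gamma(x):
    """Sum the middle row of the image (the shaded region of Figure 3)."""
    return x[n].sum()

# G-invariance: a cyclic horizontal translation permutes entries within each
# row, so the middle-row sum is unchanged.
assert all(gamma(np.roll(x_hid, h, axis=1)) == gamma(x_hid) for h in range(side))

# CG-invariance fails: in the counterfactual coupling, the translation acts
# on X^(hid) *before* the 90-degree rotation from G_D.
x_obs = np.rot90(x_hid)                        # T(90) o X^(hid): flat rod on the middle row
x_cf = np.rot90(np.roll(x_hid, n, axis=1))     # T(90) o T(+5) o X^(hid): rod off the middle row
assert gamma(x_obs) != gamma(x_cf)
```

The observed image yields Γ = 11 (the flat rod fills the middle row), while the counterfactual image yields Γ = 0, confirming that a G_I-invariant Γ need not be CG-invariant.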

B.3 PROOFS OF LEMMA 1, LEMMA 2 AND THEOREM 3

Lemma 1 (Reynolds operator (Mumford et al. (1994), Definition 1.5)). Let G be a (finite) linear automorphism group over vec(X). Then

T̄ = (1/|G|) Σ_{T∈G} T   (7)

is a G-invariant linear automorphism, i.e., ∀T† ∈ G and ∀x ∈ vec(X), it must be that T̄(T†x) = T̄x.

Proof. Consider an arbitrary transformation T† ∈ G. Then T̄ ∘ T† = (1/|G|) Σ_{T∈G} T ∘ T† = (1/|G|) Σ_{T'∈G†} T', where we define G† = {T ∘ T† : T ∈ G}. Since G† = G (shown below), it follows that T̄ ∘ T† = T̄.

Proof of Lemma 2. Sufficiency: Let {w_i^T}_{i=1}^{d_W} be the set of left eigenvectors of T̄ with eigenvalue 1, constituting an orthogonal basis for W. Consider any nonzero w' ∈ W; then

(w')^T = Σ_{i=1}^{d_W} α_i w_i^T = Σ_{i=1}^{d_W} α_i w_i^T T̄   (15)

for some coefficients {α_i}_{i=1}^{d_W}, where we used the fact that w_i^T T̄ = w_i^T for 1 ≤ i ≤ d_W. For any x ∈ vec(X) and any T ∈ G we have

γ(Tx; w', b) = (w')^T (Tx) + b
             = Σ_{i=1}^{d_W} α_i w_i^T T̄ (Tx) + b   (using Equation (15))
             = Σ_{i=1}^{d_W} α_i w_i^T T̄ x + b      (from Lemma 1)
             = γ(x; w', b).

Turning to the proof of Theorem 3: any nonzero w is G_M-invariant if and only if it is G_i-invariant for all i ∈ M. It is possible to have B̃_M = {0}, implying that there is no nonzero w ∈ vec(X) that is G_M-invariant. Next, note that for all N ⊋ M we have B̃_N ⊆ B̃_M (using the definition of B̃_M). The direct sum of the B_N is the smallest subspace containing all such B_N, and thus ⊕_{N⊋M} B_N ⊆ B̃_M. From our earlier claim, this implies that B̂_M = ⊕_{N⊋M} B_N ⊆ B̃_M. Finally, we have B_M = orth_{B̂_M}(B̃_M) ⊆ B̃_M for all M. Thus we have proved that any nonzero w ∈ B_M also lies in B̃_M and hence is invariant to G_M. In the sequel, we prove that any w ∈ B_M is not G_j-invariant for any j ∈ {1, . . . , m} \ M. Let P ⊋ M.
Then it is clear that B̂_M = ⊕_{N⊋M} B_N ⊇ B_P, which implies, from the first part of our proof, that any w ∈ vec(X) that is G_P-invariant lies inside B̂_M. The orthogonalization step ensures that B_M ⊥ B̂_M, and thus B_M ⊥ B_P and B_M ∩ B_P = {0}. Hence there is no nonzero w ∈ B_M such that w is G_P-invariant. This applies for all supersets P ⊋ M. Finally, we consider supersets of M of the form P = M ∪ {j} for j ∈ {1, . . . , m} \ M. If a nonzero w ∈ B_M were invariant to G_j, then w would be invariant to G_P with P ⊋ M, resulting in a contradiction. Hence, if B_M ≠ {0}, any w ∈ B_M \ {0} is G_M-invariant but not G_j-invariant for any j ∈ {1, . . . , m} \ M.

C EXAMPLE CONSTRUCTION OF CG-INVARIANT NEURONS

In this section, we present a detailed example of the construction of CG-invariant neurons. Consider a 3×3 image with 3 channels, so X = R^{3×3×3}. A convolutional filter w ∈ X = R^{3×3×3} then multiplies elementwise with the image x ∈ X. Consider m = 2 groups, G_rot and G_col: the former rotates the image patch by multiples of 90 degrees and the latter permutes the color channels of the image. Our goal is to enforce invariance to rotations and color-channel permutations unless contradicted by the training data. Note that vec(X) = R^27.

Step 1: Construct the 1-eigenspace of the Reynolds operator for each group. Since we only consider linear automorphism groups, each transformation T in a group can be written as T(x) = Tx, where T is a matrix in R^{27×27} and x ∈ vec(X) = R^27. Given a group, we can directly use Lemma 1 to construct the Reynolds operator by averaging all the linear transformations (i.e., the corresponding matrices) in the group. We can then use standard methods in linear algebra to find the 1-eigenspace of the Reynolds operator (i.e., the eigenvectors with eigenvalue 1). Let W_rot and W_col be the 1-eigenspaces of the Reynolds operators of the groups G_rot and G_col respectively. Figure 4 shows these eigenspaces with the eigenvectors arranged in R^{3×3×3} instead of R^27: the eigenvectors in W_rot are invariant to rotations by multiples of 90 degrees, whereas the eigenvectors in W_col have the same values across the RGB channels and are thus invariant to permutations of these channels. Lemma 2 proves this invariance property for the 1-eigenspace of the Reynolds operator of any finite linear automorphism group.

Step 2: Construct B_M for all M ⊆ {rot, col}. Given W_rot and W_col, we construct bases for the subspaces B_M for all M ⊆ {rot, col} using Theorem 3.
1. Set M = {rot, col}: B̃_{rot,col} = W_rot ∩ W_col and B_{rot,col} = B̃_{rot,col} (because B̂_{rot,col} = {0}). The intersection of subspaces W_rot ∩ W_col can be computed using standard methods in linear algebra. The subspace B_{rot,col}, with 3 basis vectors, is visualized at the topmost level of Figure 5. As before, the basis vectors of the subspace are represented in R^{3×3×3}. It is clear that the basis vectors are invariant to both rotation and permutation of the channels. This property holds for any linear combination of the basis vectors, i.e., for any w ∈ B_{rot,col}.

2. Set M = {rot}: B̃_{rot} = W_rot and B_{rot} = orth_{B̂_{rot}}(B̃_{rot}) = orth_{B_{rot,col}}(B̃_{rot}) (because B̂_{rot} = B_{rot,col}). The subspace B̃_{rot} consists of all vectors that are invariant to rotation, but it also includes vectors that are invariant to both rotation and channel permutation. Thus, we need to remove from B̃_{rot} its projection onto B_{rot,col}.

Figure 5 visualizes the subspaces B_M and their basis vectors when the groups are just G_rot and G_color and the kernel size is 3×3 applied over an input with 3 channels. One can similarly obtain these subspaces for other groups, different kernel sizes, and different numbers of input channels. The filter is then obtained as a linear combination of these basis vectors, where the coefficients form the learnable parameters. The G-invariance of the filter depends on which of these coefficients are nonzero. Once the filter is obtained, it is convolved with the image or the feature maps; this ensures that the model can be CG-invariant to transformations of smaller patches in the image if needed. Max-pooling layers function in the standard way. After all the convolutional and max-pooling layers, we use a sum-pooling layer over the entire channel to ensure that the model can be invariant to transformations (e.g., rotations) of the whole image if needed. Finally, any number of dense layers can be added after the sum-pooling layer.
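Step 1 above can be sketched numerically for the color-permutation group. This is a minimal illustration, not the paper's code: `as_matrix` is a hypothetical helper, the channel axis is taken to be the first axis of a (3, 3, 3) patch, and `eigh` applies here because the group average of permutation matrices (a group closed under inverses) is symmetric:

```python
import itertools
import numpy as np

d = 27                                  # dim of vec(X) for a 3x3 patch with 3 channels

def as_matrix(transform):
    """27x27 matrix of a linear transform given as a function on (3,3,3) arrays
    (channel axis first)."""
    basis = np.eye(d)
    return np.stack([transform(b.reshape(3, 3, 3)).ravel() for b in basis], axis=1)

# G_col: all 3! permutations of the channel axis, as matrices on vec(X).
G_col = [as_matrix(lambda x, p=p: x[list(p)])
         for p in itertools.permutations(range(3))]

# Lemma 1: the Reynolds operator is the group average.
reynolds = sum(G_col) / len(G_col)

# Its 1-eigenspace W_col contains exactly the G_col-invariant directions.
vals, vecs = np.linalg.eigh(reynolds)
W_col = vecs[:, np.isclose(vals, 1.0)]  # 9 basis vectors, as in Figure 4(b)
```

The Reynolds operator here is a projector (T̄² = T̄), so its spectrum is {0, 1} and the 1-eigenspace has dimension 9: one channel-averaged direction per spatial position.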
In our experiments, we use the three groups G_rot, G_color, and G_vertical-flip to construct the subspaces for the filters of the first convolutional layer, but drop G_color in subsequent layers, as we do not wish to be invariant to channel permutation after the first layer.
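The filter construction described above (a learnable linear combination of subspace basis vectors) can be sketched as follows. The subspace dimensions come from Figure 5, but the bases below are random placeholders standing in for the real invariant bases of Theorem 3, so only the bookkeeping is illustrated:

```python
import numpy as np

rng = np.random.default_rng(0)

# Subspace dimensions from Figure 5 (3 + 6 + 6 + 12 = 27); random placeholders
# stand in for the actual invariant bases.
dims = {('rot', 'col'): 3, ('rot',): 6, ('col',): 6, (): 12}
basis = {M: rng.standard_normal((27, k)) for M, k in dims.items()}

def filter_from_coefficients(omega):
    """Filter w = sum_M B_M @ omega_M; the omega_M are the learnable
    parameters. Zeroing omega_M for the less invariant subsets M keeps w
    inside the more invariant subspaces."""
    return sum(basis[M] @ omega[M] for M in basis)

# A filter that uses only the fully invariant subspace B_{rot,col}:
omega = {M: np.zeros(k) for M, k in dims.items()}
omega[('rot', 'col')] = rng.standard_normal(3)
w = filter_from_coefficients(omega).reshape(3, 3, 3)   # ready to convolve
```

With the true bases, this filter would be both rotation- and channel-permutation-invariant by construction, and making any other omega_M nonzero would selectively break exactly those invariances.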

E.2 SEQUENCES

A CG-invariant architecture for sequences is depicted in Figure 7. Consider a sequence X = [x_1, . . . , x_n] ∈ R^{p×n} of length n and groups G_1, . . . , G_m as before. In the following discussion we assume that the groups are permutation groups over the sequence elements; however, one could also consider other groups over X. First, each element of the sequence is passed through a shared feedforward network φ that returns a representation Z ∈ R^{p'×n}. Then, Theorem 3 finds the bases for B_M, M ⊆ {1, . . . , m}, until all the p'n basis vectors are found, covering the space R^{p'×n}. The weight vector for the h-th neuron of the CG-invariant layer is obtained as a linear combination of these basis vectors via the learnable parameters Ω (Equation (10)). Finally, any number of dense layers can be stacked after the CG-invariant layer for the final output.
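The layer just described can be sketched in a few lines. This is a toy forward pass under stated assumptions: the encoder `W_phi` is a random placeholder for φ, and the bases are placeholder blocks partitioning R^{p'·n} rather than the invariant bases of Theorem 3:

```python
import numpy as np

rng = np.random.default_rng(1)
n, p, ph = 5, 4, 8                       # sequence length, element dim, width of phi

W_phi = rng.standard_normal((ph, p))     # shared elementwise encoder phi (placeholder)

def cg_invariant_layer(X, bases, omega):
    """CG-invariant layer sketch: Z = phi(X) applied elementwise, then each
    neuron's weight vector is a linear combination of subspace basis vectors.

    X:     (p, n) input sequence
    bases: list of (ph*n, k_M) basis matrices, one per subspace B_M
    omega: list of (k_M, H) learnable coefficient blocks (H neurons)
    """
    Z = np.tanh(W_phi @ X)                           # (ph, n)
    W = sum(B @ o for B, o in zip(bases, omega))     # (ph*n, H) neuron weights
    return W.T @ Z.ravel()                           # (H,) neuron outputs

# Placeholder bases partitioning R^{ph*n}; in practice they come from Theorem 3.
E = np.eye(ph * n)
bases = [E[:, :10], E[:, 10:]]
omega = [rng.standard_normal((10, 3)), np.zeros((30, 3))]   # only one subspace used
out = cg_invariant_layer(rng.standard_normal((p, n)), bases, omega)
```

Which coefficient blocks are nonzero determines which G_M-invariances each neuron inherits, which is exactly what the CG-regularizer of Section F acts on.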

F REGULARIZATION

F.1 EXAMPLE

Figure 8 shows an example computation of the penalty in Equation (12). The example considers an image task with m = 3 groups: G_rot, G_col, G_vflip. Each cell in the figure shows one subset M ⊆ {rot, col, vflip}. The subsets are arranged according to their levels of invariance, i.e., by the size |M|. For example, the topmost cell {rot, col, vflip} denotes the subspace with all the invariances, whereas the bottommost cell ∅ denotes the subspace with no invariance. The colors indicate the state of the parameters Ω at a single point in the optimization: cells are colored green or red depending on whether the subspace is used or unused, respectively, i.e., whether the parameters corresponding to the subspace are nonzero or not. The least invariant subspaces used at this point are in Level 1 (i.e., invariant to a single group). The penalty counts (a) all subspaces with higher levels of invariance, irrespective of whether the subspace is used or not, and (b) all the used subspaces with the same level of invariance. The former penalizes the use of subspaces lower in the partial order and ensures that subspaces with higher levels of invariance are used. The latter approximates the effort required to reach a higher level of invariance.
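The counting rule above can be written directly as code. A minimal sketch in which each subspace is summarized only by its level |M| and a boolean usage flag (`penalty` is an illustrative helper, not the paper's implementation):

```python
def penalty(subspaces):
    """Discrete CG-regularization penalty in the spirit of Equation (12).

    subspaces: list of (size_of_M, used) pairs, one per subspace B_M, where
    `used` means the corresponding parameters omega_M are nonzero.
    """
    used_sizes = [s for s, used in subspaces if used]
    if not used_sizes:
        return 0
    l = min(used_sizes)                                   # least invariant level in use
    higher = sum(1 for s, _ in subspaces if s > l)        # (a) everything above level l
    same = sum(1 for s, used in subspaces if s == l and used)  # (b) used at level l
    return higher + same

# Figure 8 setup: m = 3 groups -> 8 subsets, with two level-1 subspaces used.
subs = [(3, False), (2, False), (2, False), (2, False),
        (1, True), (1, True), (1, False), (0, False)]
# penalty(subs): 4 higher-level subspaces + 2 used at level 1 = 6
```

Lowering the penalty thus requires moving the used parameters into subspaces with larger |M|, i.e., toward more invariance.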

F.2 DIFFERENTIABLE APPROXIMATION

Recall the regularization penalty R(Ω) in Equation (12):

R(Ω) = f_l(Ω) := |{M_i : |M_i| > l, 1 ≤ i ≤ B}| + Σ_{i : |M_i| = l, 1 ≤ i ≤ B} 1{||ω_{M_i,·}||₂² > 0},   (16)

where l = min{|M_i| : ||ω_{M_i,·}||₂² > 0, 1 ≤ i ≤ B}. R(Ω) is clearly discrete but can be approximated by a differentiable formula. First, we replace the indicator function 1{z > 0} in Equation (16) with the approximation 1̃{z > 0} = τz/(τz + 1), where τ ≥ 1 is a temperature hyperparameter. Then, in order to obtain R(Ω) = f_l(Ω) for the minimum l defined in Equation (16), we use the recursion

R(Ω) = R_m(Ω),   R_l(Ω) = (1 − β_l(Ω)) · R_{l−1}(Ω) + f_l(Ω) · β_l(Ω),   l = 1, . . . , m,

with the base case R_0(Ω) = 0 and β_l(Ω) = 1̃{Σ_{N_i : |N_i| = l, 1 ≤ i ≤ B} ||ω_{N_i,·}||₂² > 0}. β_l(Ω) is approximately one if at least one neuron h has nonzero ω_{N_i,h} parameters for some N_i ⊆ {1, . . . , m} of size l (i.e., with l groups). The recursion then finds f_l(Ω) with l defined as the size of the least invariant subspace used.
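The differentiable approximation above can be implemented directly as written; a sketch in which each subspace is summarized by its level |M_i| and the squared norm of its parameters (`cg_penalty` and `soft_ind` are illustrative names, not the paper's code):

```python
import numpy as np

def soft_ind(z, tau=10.0):
    """Differentiable surrogate for 1{z > 0}: tau*z / (tau*z + 1), z >= 0."""
    return tau * z / (tau * z + 1.0)

def cg_penalty(sizes, sq_norms, m, tau=10.0):
    """Differentiable R(Omega) via the recursion R_l = (1-beta_l) R_{l-1} + f_l beta_l.

    sizes:    |M_i| for each subspace i
    sq_norms: ||omega_{M_i,.}||_2^2 for each subspace i (nonnegative)
    """
    sizes = np.asarray(sizes)
    sq_norms = np.asarray(sq_norms)

    def f(l):   # penalty assuming level l is the relevant level of invariance
        higher = np.sum(sizes > l)
        same = soft_ind(sq_norms[sizes == l], tau).sum()
        return higher + same

    R = 0.0
    for l in range(1, m + 1):
        beta = soft_ind(sq_norms[sizes == l].sum(), tau)   # ~1 if level l is used
        R = (1.0 - beta) * R + f(l) * beta
    return R
```

On the Figure 8 configuration (two level-1 subspaces used, m = 3), the soft penalty comes out close to the discrete value of 6 and tightens as τ grows.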

F.3 LIMITATION OF R(Ω)

As explained in Section 4.2, there can be overgroups (out of the total 2^m considered) with different levels of invariance that are nevertheless penalized identically by Equation (12). This scenario arises only when Theorem 3 does not construct subspace bases for all 2^m overgroups, i.e., when a basis for vec(X) is exhausted before that. In this section, we provide such an example scenario with sequence inputs and the transposition groups considered in Section 6.

Figure 9: Basis vectors of the subspaces B_M for the transposition groups G_{i,j} over sequences of length n = 5 and dimension d = 1. Each of the subspaces B_M has dimension 1; for each basis vector, elements sharing the same color have the same value. At the topmost level we have the subspace with the most invariance, i.e., invariant to the full permutation group S_n. After many levels of empty subspaces, we have the subspaces B_{M_p} for M_p = {(i, j) | i, j ∈ [n] \ {p}, i < j}, where [n] = {1, . . . , n}; in other words, B_{M_p} is invariant to all transpositions except those that move index p. Note that these n independent subspaces of dimension 1 cover the entire space R^n.

Let X ∈ X = R^n be a 1-dimensional sequence of length n. The transposition groups are {G_{i,j}}_{1≤i<j≤n}, where G_{i,j} = {T_identity, T_{i,j}} and T_{i,j} swaps positions i and j in the sequence. Given these m = n(n−1)/2 groups, we can use Lemmas 1 and 2 and Theorem 3 to find the invariant subspaces B_M for subsets M ⊆ {(i, j) | 1 ≤ i < j ≤ n} indexing the transposition groups. The basis vectors of the subspaces constructed for sequence length n = 5 are visualized in Figure 9; there are n 1-dimensional subspaces. Let b_{\∅}, b_{\{1}}, . . . , b_{\{n−1}} denote these n basis vectors, where the notation \A means that the vector has the same value for all positions k ∈ {1, . . . , n} \ A (cf. Figure 9). Let n = 5 and note that any weight vector ω ∈ R^5 can be written as

ω = α_1 b_{\∅} + α_2 b_{\{1}} + α_3 b_{\{2}} + α_4 b_{\{3}} + α_5 b_{\{4}},   (17)

where α ∈ R^5. Let α' = (1, 0, 1, 0, 1)^T. From a quick read of Figure 9, we see that the weight ω' obtained by substituting α' into Equation (17) satisfies ω'_1 = ω'_3 = ω'_5 and ω'_2 = ω'_4. For any input x ∈ R^5, the neuron σ(ω'^T x + b) is invariant to any permutation of x_1, x_3, and x_5, and to the transposition of x_2 and x_4.
The penalty R(ω') = 3, since 2 subspaces are used at the lowest level and 1 subspace lies above the lowest level (see Equation (12)). Now let α'' = (1, 0, 1, 0, 1.5)^T. The weight ω'' obtained by substituting α'' into Equation (17) satisfies ω''_1 = ω''_3 = ω''_5 but ω''_2 ≠ ω''_4. For input x ∈ R^5, the neuron σ(ω''^T x + b) is invariant to any permutation of x_1, x_3, and x_5, but sensitive to the transposition of x_2 and x_4. The penalty R(ω'') = 3 as well, since the same subspaces are used as before. In the first case, with all the nonzero parameters equal (in particular α'_3 = α'_5), ω' lies in a smaller (more invariant) subspace of span(b_{\∅}, b_{\{2}}, b_{\{4}}). In the second case, since α''_3 ≠ α''_5, the same does not hold for ω''. The penalty R(·), which only counts the subspaces used (in this case b_{\∅}, b_{\{2}}, and b_{\{4}}), is unable to distinguish between the two weight vectors ω' and ω'', one clearly more invariant than the other.

In conclusion, we obtain the observed image X in the test data by applying a random transformation from G_D to X^(hid) and then applying a random transformation from G_I to the result. The task is the same as in the training data: predict the original label of the image (i.e., the digit) and the transformation T_{U_D} that was applied to obtain X. Note that the label does not depend on the transformation T_{U_I|U_D} ∈ G_I that was applied. Once again, if G_I = G_{rot,vertical-flip} and G_D = G_color, then the extrapolated test data consists of images randomly rotated, flipped, and channel-permuted, while the task is the same: predict the digit and its color. In order to evaluate the models, we use a 5-fold cross-validation procedure as follows. We divide the training and test datasets that are pre-split in the MNIST and MNIST-34 datasets into 5 folds each, and use the above procedure to transform the training data and the test data.
Then, in each iteration i of the cross-validation procedure, we leave out the i-th fold of the transformed training data and the i-th fold of the extrapolated test data. Further, we use 20% of the training data as validation data for hyperparameter tuning and early stopping.

Baselines and architecture. For all methods, we use a VGG architecture (Simonyan and Zisserman, 2014) with 8 convolutional layers, each having 128 channels except the first layer, which has 64 channels. All convolutional layers have a receptive field of size 3×3, stride 1, and padding 1. A max-pooling layer is added after every two convolutional layers. Two feedforward layers at the end give the final output. We compare our approach with standard CNNs and Group-equivariant CNNs (G-CNNs) (Cohen and Welling, 2016) with the p4m group. We modified the G-CNN so that it has invariances to all 3 groups strictly enforced via (a) coset-pooling (Cohen and Welling, 2016) after each layer and (b) adding together the 3 input RGB channels. For our approach, we replace the standard convolutional layers in the VGG architecture with CG-invariant layers whose bases are constructed from G_rot, G_color, and G_vertical-flip. An example architecture with only 2 convolutional layers is shown in Figure 6. We optimize all models using SGD with momentum, with learning rate in {10⁻², 10⁻³, 10⁻⁴} and a batch size of 64. We use early stopping on the validation loss to select the best model, and the validation loss to select the best set of hyperparameters for each model. We choose the maximum value of λ with validation accuracy within a 5% threshold of the maximum validation accuracy obtained from any value of λ. Tables 2 and 3 show the effect of the regularization strength on the performance of the model; we observe that λ = 10 performs considerably well across all tasks.

G.2 SEQUENCES

Datasets. For the sequence tasks, we generate X^(hid) = (X_i)_{i=1}^{10} as a sequence of n = 10 canonically ordered integers uniformly sampled with replacement from the fixed vocabulary {1, . . . , 99}. The canonical ordering is fixed for a given set of sampled integers: the corresponding sequence X^(hid) is always either in increasing order or in decreasing order.

Groups. We consider the m = n(n−1)/2 transposition groups for all pairwise permutations: G_{1,2}, G_{2,3}, G_{1,3}, . . . , G_{n−1,n}, where G_{i,j} := {T_identity, T_{i,j}} and T_{i,j} swaps positions i and j in the sequence. For I ⊆ {(1, 2), (2, 3), (1, 3), . . . , (n−1, n)}, G_I is defined as before as the join of {G_{i,j}}_{(i,j)∈I}. We choose 4 different subsets I of the m groups, indicated by the second column of Table 4. For our choices of I ≠ ∅, we set D = ∅ to ensure that G_I ⊴ G_{D∪I}, i.e., that G_I is a normal subgroup of G_{D∪I}.

Table 4: Sequence tasks. The first column defines the target Y for a given sequence (X_i)_{i=1}^{10}. The second column denotes G_I, the group of transformations to which Y is invariant. Recall that G_I is constructed as the join of a subset of the n(n−1)/2 transposition groups.

Target Y                                          | G_I                         | G_D
Y_task-1 = Σ_{i=1}^{10} X_i                       | {G_{i,j}}_{1≤i<j≤n}         | {Id}
Y_task-2 = Σ_{i=2}^{10} X_i                       | {G_{i,j}}_{2≤i<j≤n}         | {Id}
Y_task-3 = Σ_{i=1}^{5} (X_{2i} − X_{2i−1})        | {G_{i,i+2k}}_{1≤i<i+2k≤n}   | {Id}
Y_task-4 = Σ_{i=1}^{10} Σ_{j=1}^{i} 1(X_j ≥ 20)   | {Id}                        | {G_{i,j}}_{1≤i<j≤n}

Tasks. The label for the sequence X^(hid) is obtained by applying to X^(hid) an arithmetic function that is invariant to the chosen group G_I. The arithmetic functions are given in the first column of Table 4. Y_task-1 is invariant to any permutation of the input elements X_i, 1 ≤ i ≤ n. Y_task-2 is invariant to any permutation of the input elements X_i with indices i > 1, but sensitive to permutations that move X_1. Y_task-3 is invariant to permutations that map elements at even indices to even indices and elements at odd indices to odd indices. Finally, Y_task-4 is sensitive to all permutations (i.e., no invariance).
Training data: Recall that X^(hid) is in sorted order. Since the training data is sampled economically, it consists only of sequences under transformations that have an effect on the label, i.e., transformations from G_D. The observed input is obtained as X = T_{U_I,U_D} ∘ X^(hid), a transformation of the sorted input X^(hid). In conclusion, we obtain the observed sequence X in the training data by applying a random transformation T_{U_D} ∈ G_D to X^(hid) and then applying a constant transformation (e.g., T_identity) from G_I to the result. The target Y is computed by applying the arithmetic function corresponding to the task (see Table 4) to T_{U_D} ∘ X^(hid) (recall from Equation (3) that Y is a function of both X^(hid) and U_D).

Extrapolation task: The extrapolated test data consists of samples from the coupled random variable X_{U_I ← U'_I} (Definition 1). Unlike the training data, which was economically sampled (i.e., with a single transformation from G_I), the extrapolated test data is obtained via the full range of transformations in G_I. Recall from Definition 1 that X_{U_I ← U'_I} = T'_{U_I,U_D} ∘ X^(hid). As before, we decompose T'_{U_I,U_D} = T'_{U_I|U_D} ∘ T_{U_D}. However, there is no economical sampling for the test data: T'_{U_I|U_D} and T_{U_D} are sampled randomly from G_I and G_D respectively.

Table 5: (Sequence tasks) Extrapolation test accuracies (%) with 95% confidence intervals for all models (bold means p < 0.05 significance). The standard sequence models cannot extrapolate when I ≠ ∅, whereas the forced G-invariant models cannot unlearn the invariances and fail when I ⊊ {1, . . . , m}.
G_I →                                 | {G_{i,j}}_{1≤i<j≤n} | {G_{i,j}}_{2≤i<j≤n} | {G_{i,i+2k}}_{1≤i<i+2k≤n} | {Id}
DeepSets (Zaheer et al., 2017)        | 100.00 (0.00)       | 2.36 (2.37)         | 0.97 (0.60)               | 16.12 (8.21)
Janossy pooling (Murphy et al., 2018) | 96

In conclusion, we obtain the observed sequence X in the test data by applying a random transformation T_{U_D} ∈ G_D to X^(hid) and then applying a random transformation from G_I to the result. The target Y is computed similarly to the training data, by applying the appropriate arithmetic function to T_{U_D} ∘ X^(hid). Note that Y is invariant to G_I.

Example: Consider the first row of Table 4 with I = {(i, j)}_{1≤i<j≤n}, i.e., I contains all m = n(n−1)/2 groups. Then the group G_I is simply the full permutation group over n elements, and the target is the sum of the elements (which is fully permutation-invariant). The sequences in the training data are always sorted (because of the economical sampling of the training data), whereas the sequences in the test data have arbitrary permutations (obtained by sampling random transformations from G_I). The task is simply to compute the sum of the elements of the sequence. The sizes of the training data and the extrapolated test data are fixed at 8000 and 2000, respectively. We repeat all experiments for 5 different random seeds.

Baselines and architecture. We compare our approach with (a) standard sequence models, specifically Transformers (Vaswani et al., 2017) and GRUs (Cho et al., 2014), and (b) forced permutation-invariant set models, specifically DeepSets (Zaheer et al., 2017), Set Transformer (Lee et al., 2019), and Janossy pooling (Murphy et al., 2018). An example of the proposed CG-invariant feedforward architecture is depicted in Figure 7. We optimize all models using Adam (Kingma and Ba, 2014) with an initial learning rate in {10⁻², 10⁻³, 10⁻⁴} and a batch size of 128. We use the validation loss for early stopping and to select the best hyperparameters for all models.
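The economical training sampling and the extrapolated test sampling for the example above (task 1, G_D = {Id}) can be sketched as follows; the helper names and the choice of increasing canonical order are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10

def sample_training_example():
    """Economical sampling for task 1 (G_D = {Id}, G_I = full permutation
    group): X^(hid) is canonically ordered (here: increasing) and
    T_{U_I|U_D} is fixed to the identity, so the observed training sequence
    is simply the sorted sequence."""
    x_hid = np.sort(rng.integers(1, 100, size=n))
    y = x_hid.sum()                      # Y_task-1, invariant to all of G_I
    return x_hid, y

def sample_test_example():
    """Extrapolated test data: a random transformation from G_I (a random
    permutation) is additionally applied; the label is unchanged."""
    x_hid, y = sample_training_example()
    return rng.permutation(x_hid), y

x_tr, y_tr = sample_training_example()
x_te, y_te = sample_test_example()
```

A model that exploits the sorted order of the training sequences can fit the training data perfectly yet fail on the permuted test sequences, which is exactly the failure mode of the standard sequence models in Table 5.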
Once again, we choose the best value of the CG-regularization strength λ by taking the maximum value of λ with validation accuracy within 5% of the maximum validation accuracy obtained from any λ. Table 6 shows the effect of the regularization strength on the performance of the model. We observe that although λ = 10 performs comparably to the rest in validation accuracy and is chosen consistently, it does not achieve the best possible extrapolation accuracy. Table 5 shows the complete set of results for all models; it clearly shows the issue with standard sequence models (they cannot extrapolate when I ≠ ∅) and with forced G-invariant models (they fail when I ⊊ {1, . . . , m}). In Table 4 of the main text, we show the results of the best model among all permutation-invariant models in the column Best FF+G-inv.



Public code available at: https://github.com/PurdueMINDS/NN_CGInvariance



where T_{U_D,U_I} is a transformation in G_{D∪I} indexed by two independent hidden environment background random variables U_D and U_I. The reader can roughly interpret U_D and U_I as the random seeds of a random number generator that outputs ordered sequences of transformations from G_D and G_I respectively. If these ordered sequences are, say, T_D^(1), T_D^(2), . . . and T_I^(1), T_I^(2), . . ., then T_{U_D,U_I} is the transformation obtained by interleaving the two sequences of transformations and composing them in order: T_{U_D,U_I} = T_I^(1) ∘ T_D^(1) ∘ T_I^(2) ∘ ⋯

Example (Table 1, row: rot,vflip): For image tasks, if G_I = G_{rot,vertical-flip} and G_D = G_color, then the training data consists of upright and unflipped images (since T_{U_I|U_D} = T_identity) with different permutations of the color channels (random transformations T_{U_D} ∈ G_color). The task is to predict the original label of the image (i.e., the digit) and the transformation T_{U_D} (i.e., the color).

However, there is no economical sampling for the test data: T'_{U_I|U_D} and T_{U_D} are sampled randomly from G_I and G_D respectively. The task is the same as in the training data.

X^(obs) = T_{U_D,U_I} ∘ X^(hid), with the label Y a function of X^(hid) and U_D (cf. Equations (2) and (3)),

Figure 2: An example task where CG-invariance is stronger than G-invariance. The task is to predict the orientation of the image while being CG-invariant to horizontal translations.

Recall that T_{U_D,U_I←u} was generated by interleaving two separate sequences of transformations obtained via the background variables U_D and U_I, respectively (Appendix B.1). In other words, we can write T_{U_D,U_I←u} = T_{U_I←u|U_D} ∘ T_{U_D}.

Now, in order to prove T̄ ∘ T† = T̄, we only need to show that G† = G. Since groups are closed under composition, we have T ∘ T† ∈ G for all T ∈ G, and thus G† ⊆ G. Finally, since T† is a bijection and T_a ∘ T† = T_b ∘ T† only if T_a = T_b for any T_a, T_b ∈ G, it must be that |G†| = |G|. Hence, G† = G.

Lemma 2. If W denotes the left eigenspace corresponding to the eigenvalue 1 of the Reynolds operator T̄ for the group G, then for all b ∈ R, the linear transformation γ(x; w, b) = w^T x + b is invariant to all transformations T ∈ G, i.e., γ(Tx; w, b) = γ(x; w, b), if and only if w ∈ W.
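Lemma 2 can be checked numerically on a toy group. The sketch below (our illustration, not the paper's code) builds the Reynolds operator for the cyclic-shift group on R^3, extracts its left 1-eigenspace, and verifies that γ(Tx; w, b) = γ(x; w, b) for every T in the group:

```python
import numpy as np

# cyclic-shift group on R^3 and its Reynolds operator
P = np.roll(np.eye(3), 1, axis=0)                 # a single cyclic shift
G = [np.linalg.matrix_power(P, k) for k in range(3)]
T_bar = sum(G) / len(G)                           # Reynolds operator

# left 1-eigenspace of the Reynolds operator
vals, vecs = np.linalg.eig(T_bar.T)
W = vecs[:, np.isclose(vals, 1.0)].real           # spanned by the all-ones direction

w, b = W[:, 0], 0.5
x = np.array([1.0, 2.0, 3.0])
gamma = lambda z: w @ z + b
# invariance holds for every transformation in the group, as Lemma 2 states
print([bool(np.isclose(gamma(Tg @ x), gamma(x))) for Tg in G])  # -> [True, True, True]
```

Here the 1-eigenspace is one-dimensional (the all-ones direction), so the only shift-invariant linear features are scaled sums of the coordinates.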

Figure 4: (a) The 1-eigenspace of the Reynolds operator for the rotation group. The eigenspace has nine basis vectors v ∈ R^27 (stacked). We represent these eigenvectors in R^{3×3×3} instead to emphasize that they are rotation-invariant. (b) The 1-eigenspace of the Reynolds operator for the color-permutation group. The eigenspace again has nine basis vectors v ∈ R^27, but we represent them in R^{3×3×3} to emphasize that they are invariant to permutations of the color channels.

Figure 5: The subspaces B_M for all M ⊆ {rot, col}. For instance, B_{rot,col} at the top has 3 basis vectors (represented in R^{3×3×3}), and each of these vectors is both rotation-invariant and channel-permutation-invariant. On the other hand, B_{rot} (of dimension 6) is rotation-invariant but strictly not channel-permutation-invariant. Finally, the vectors in B_∅ are neither rotation-invariant nor channel-permutation-invariant. All the basis vectors together cover the entire space R^27 (i.e., dim(B_{rot,col}) + dim(B_{rot}) + dim(B_{col}) + dim(B_∅) = 3 + 6 + 6 + 12 = 27).
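The dimension count in the caption (3 + 6 + 6 + 12 = 27) can be reproduced numerically. The sketch below is our own illustration: it builds permutation matrices for the rotation and channel-permutation groups acting on vec(X) = R^27 and recovers the subspace dimensions from the 1-eigenspaces of the corresponding Reynolds operators (as in Lemma 2 and Theorem 3):

```python
import numpy as np
from itertools import permutations

# positions of a 3x3 image with 3 channels, vectorized: vec(X) = R^27
idx = np.arange(27).reshape(3, 3, 3)          # (channel, row, col)

def perm_matrix(mapped):
    """Permutation matrix of the linear map sending pixel layout idx -> mapped."""
    P = np.zeros((27, 27))
    P[mapped.ravel(), idx.ravel()] = 1.0
    return P

# G_rot: rotations of the spatial axes by multiples of 90 degrees
G_rot = [perm_matrix(np.rot90(idx, k, axes=(1, 2))) for k in range(4)]
# G_col: permutations of the three color channels
G_col = [perm_matrix(idx[list(p), :, :]) for p in permutations(range(3))]

def dim_invariant(group):
    """Dimension of the 1-eigenspace of the group's Reynolds operator."""
    T_bar = sum(group) / len(group)           # Reynolds operator (a projector)
    vals = np.linalg.eigvalsh(T_bar)
    return int(np.isclose(vals, 1.0).sum())

d_rot_all = dim_invariant(G_rot)                               # dim W_rot = 9
d_col_all = dim_invariant(G_col)                               # dim W_col = 9
d_both = dim_invariant([R @ C for R in G_rot for C in G_col])  # dim B_{rot,col}

d_rot = d_rot_all - d_both            # B_{rot}: rotation-only invariance
d_col = d_col_all - d_both            # B_{col}: channel-permutation-only invariance
d_none = 27 - d_both - d_rot - d_col  # B_∅
print(d_both, d_rot, d_col, d_none)   # -> 3 6 6 12
```

The intermediate dimensions 9 and 9 also match the nine stacked basis vectors shown in Figure 4(a) and 4(b).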

Figure 7: An example architecture of a CG-invariant feedforward network.

Figure 8: (Best viewed in color.) Computation of the penalty. The cells denote different subsets M ⊆ {rot, col, vflip}. Red cells denote that the parameters corresponding to these subspaces are zero (i.e., the subspaces are unused); green cells denote otherwise (i.e., the subspaces are used). In this example, the least-invariant subspaces used are in Level 1. The penalty counts all the subspaces (used or unused) that are in higher levels (i.e., with |M| > 1) and adds to it the number of subspaces of the same level that are used.
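Our reading of this counting rule can be sketched as follows (illustrative Python; identifying a level with the subset size |M|, larger |M| meaning more invariance, is our assumption from the caption):

```python
from itertools import combinations

def figure8_penalty(used, m=3):
    """Counting rule described in the caption: let k be the smallest |M| among
    the used subspaces; count every subspace with |M| > k (used or not), and
    add the number of used subspaces with |M| == k."""
    k = min(len(M) for M in used)
    higher = sum(len(list(combinations(range(m), r)))
                 for r in range(k + 1, m + 1))
    same_level_used = sum(1 for M in used if len(M) == k)
    return higher + same_level_used

# example: least-invariant used subspace sits at level |M| = 1,
# so the penalty is C(3,2) + C(3,3) + 1 = 5
print(figure8_penalty([{0}, {0, 1, 2}], m=3))  # -> 5
```

Under this reading, moving all usage up to more-invariant subspaces (larger |M|) strictly lowers the penalty, which is the behavior the regularizer is meant to induce.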

Figure 9: (Best viewed in color.) The subspaces B_M for different M ⊆ {(i, j)}_{1≤i<j≤n}, indexing the m = n(n−1)/2 transposition groups G_{i,j} over sequences of length n = 5 and dimension d = 1. Each of the subspaces B_M is of dimension 1. For each basis vector shown above, elements sharing the same color have the same value. At the topmost level, we have the subspace with the most invariance, i.e., invariant to the full permutation group S_n. After many levels with empty subspaces, we have the subspaces B_{M_p} for M_p = {(i, j) | i, j ∈ [n] \ {p}, i < j}, where [n] = {1, . . . , n}. In other words, the subspace B_{M_p} is invariant to all transpositions except those that move index p. Note that we have covered the entire space R^n with these n independent subspaces of dimension 1.

Since G_I ⊴ G_{D∪I} (by construction), any T_{U_I,U_D} = T_{U_I|U_D} ∘ T_{U_D}; that is, the transformation can be decomposed into one transformation from G_D followed by another transformation from G_I. The subscript U_I | U_D indicates that the transformation T_{U_I|U_D} ∈ G_I may also depend on U_D. Under the assumption of economical sampling of the training data, in all our experiments we sample a single value for T_{U_I|U_D} ∈ G_I: we simply use T_{U_I|U_D} = T_identity.

for all u ∈ supp(U_I) and all ũ ∈ supp(Ũ_I), where supp(A) is the support of random variable A. Let X and X_{U_I←Ũ_I}, for some appropriately defined Ũ_I ∼ P(Ũ_I), be the random variables describing the training and test data, respectively. We do not have access to test data at training time. Let Γ_true : X → R^d, d ≥ 1, be a representation of the input data. Consider a function g_true : R^d → Im P(Y = y | X^{(tr)}), where Im P(·) is the image of P(·) (e.g., g_true could be a feedforward network with softmax output), and

Table: Extrapolation accuracy (± 95% confidence interval; bold means significant at p < 0.05). Columns cover the image transformation groups {G_rot, G_vertical-flip, G_color} (task: predict the digit and which transformation from G_D was applied to the image) and the sequence transposition groups {G_{1,2}, . . . , G_{n−1,n}} (tasks depend on I; see Appendix G).


Validation and Extrapolation test accuracies (%) with 95% confidence intervals for different CG-regularization strength λ in Equation (11). λ is chosen only based on the validation accuracy: maximum λ with validation accuracy within 5% of the best validation accuracy (bold values indicate the performance of this choice of λ).

(MNIST.)  Validation and Extrapolation test accuracies (%) with 95% confidence intervals for different CG-regularization strength λ in Equation (11). λ is chosen only based on the validation accuracy: maximum λ with validation accuracy within 5% of the best validation accuracy (bold values indicate the performance of this choice of λ).

(Sequence tasks.) Validation and Extrapolation test accuracies (%) with 95% confidence intervals for different CG-regularization strength λ in Equation (11). λ is chosen only based on the validation accuracy: maximum λ with validation accuracy within 5% of the best validation accuracy (bold values indicate the performance of this choice of λ).
[Table body not recoverable from extraction: rows report per-task validation and extrapolation accuracies for λ ∈ {0.1, 1.0, 2.0, 10.0, 100.0} and one additional setting.]

ACKNOWLEDGMENTS

This work was funded in part by the National Science Foundation (NSF) Awards CAREER IIS-1943364 and CCF-1918483, the Purdue Integrative Data Science Initiative, and the Wabash Heartland Innovation Network. Any opinions, findings and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views of the sponsors.

APPENDIX

Necessity: Given a non-zero w ∈ W and b ∈ R, let γ(Tx; w, b) = γ(x; w, b) for all x ∈ vec(X) and all T ∈ G. Then, summing over all T ∈ G,

w^T T̄ x = (1/|G|) Σ_{T∈G} w^T T x = (1/|G|) Σ_{T∈G} w^T x = w^T x,

and hence w^T is a left eigenvector of T̄ with eigenvalue 1.

Theorem 3 (G-invariant subspace bases can be partially ordered by invariance strength). Let W_i ⊆ vec(X) be the left eigenspace corresponding to the eigenvalue 1 of the Reynolds operator T̄_i for group G_i, i = 1, . . . , m. We construct the invariant subspace partitions

B_M = orth_{B̄_M}( ∩_{i∈M} W_i ),  M ∈ ℘({1, . . . , m}),

where ℘ is the power set, B̄_M = ⊕_{N⊋M} B_N, orth_{A_1}(A_2) removes from the subspace A_2 its orthogonal projection onto the subspace A_1, and ⊕ is the direct sum operator. Then, the linear transformation γ(x; w, b) with w ∈ B_M is invariant to the groups G_i for all i ∈ M and to no group G_j, j ∉ M.

Proof. Throughout this proof, we slightly abuse notation by calling w ∈ vec(X) G-invariant for some group G when we mean that the transformation γ(·; w, b) is G-invariant. Essentially, B̄_M is the direct sum of all the subspaces corresponding to the strict supersets of M. Using induction on the size of M, we first show that B̄_M = ⊕_{N⊋M} B_N. The statement trivially holds for M = {1, . . . , m}. The induction hypothesis is: for all sets M such that |M| > k, we have B̄_M = ⊕_{N⊋M} B_N; we then show that the statement holds for any set M with |M| = k, which proves the claim.

Now we are ready to prove the theorem. We begin by showing that any nonzero w ∈ B_M is invariant exactly to the groups indexed by M: since B_M ⊆ ∩_{i∈M} W_i, we have from Lemma 2 that any w ∈ B_M is G_i-invariant for all i ∈ M. It is then easy to see that w cannot be G_j-invariant for any j ∉ M.

Published as a conference paper at ICLR 2021

2. Set M = {rot}. The subspace B_{rot} with 6 basis vectors is visualized in the middle level of Figure 5. It is clear that the basis vectors are invariant to rotation but not invariant to channel-permutations. Again, this property holds for any linear combination of the basis vectors.

3. Set M = {col}. The subspace B_{col} is obtained in a similar fashion. B_{col} has 6 basis vectors and is visualized in the middle level of Figure 5.
It is clear that the basis vectors are invariant to channel-permutations but not invariant to rotation. This property holds for any linear combination of the basis vectors.

4. Set M = ∅. The subspace B_∅ represents the rest of the space, which is neither rotation-invariant nor channel-permutation-invariant. This subspace has 12 basis vectors and is visualized in the bottommost level of Figure 5.

Finally, we have 2^2 = 4 subspaces (enumerated above) with a total of 27 basis vectors covering the entire space vec(X) = R^27.

Step 3: Neuron construction. For each subspace B_M, M ⊆ {rot, col}, we denote by B_M the corresponding matrix whose columns are the basis vectors of the subspace B_M. As described above, any linear combination of the basis vectors of B_M is invariant to all groups indexed by M and nothing more (e.g., B_{rot} consists of vectors invariant to rotation but not invariant to channel-permutation).

In the following, we consider a single neuron and drop the subscript h from ω_{M,h} (where h indexed the h-th neuron in Equation (10)). Recall that ω_M ∈ R^{d_M} are the learnable parameters of the neuron corresponding to the basis vectors of the subspace B_M, where d_M is the dimension of the subspace B_M. Then, ω_{rot,col} ∈ R^3 represents the coefficients of the linear combination of the basis vectors in B_{rot,col}; the linear combination is given by the matrix-vector product B_{rot,col} ω_{rot,col}. Similarly, ω_{rot} ∈ R^6, ω_{col} ∈ R^6, and ω_∅ ∈ R^12 represent the coefficients of the basis vectors in the columns of B_{rot}, B_{col}, and B_∅, respectively.

Then, a CG-invariant neuron is given by

Γ(x) = σ( (B_{rot,col} ω_{rot,col} + B_{rot} ω_{rot} + B_{col} ω_{col} + B_∅ ω_∅)^T x + b ),

where ω_{rot,col}, ω_{rot}, ω_{col}, ω_∅, and b ∈ R are the only learnable parameters. The total number of parameters is 3 + 6 + 6 + 12 + 1 = 28, the same as that of a standard neuron with input x ∈ R^27.
Now, if for example the optimization finds ω_{rot,col} ≠ 0 with ω_{rot} = 0, ω_{col} = 0, and ω_∅ = 0, then the neuron Γ(·) is invariant to both rotation and channel-permutation. Our regularization in Equation (11) pushes the optimization toward maximum invariance as long as training performance is unaffected. A more comprehensive example of the computation of the penalty is given in Appendix F.
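A minimal numerical sketch of such a neuron (our illustration; `tanh` and the random basis matrices stand in for the actual activation and the Theorem 3 bases):

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in basis matrices with the dimensions from the text:
# B_{rot,col}: 27x3, B_{rot}: 27x6, B_{col}: 27x6, B_emptyset: 27x12
dims = (3, 6, 6, 12)
bases = [rng.standard_normal((27, d)) for d in dims]
omegas = [rng.standard_normal(d) for d in dims]
b = 0.1

def cg_neuron(x):
    """Gamma(x) = sigma((sum_M B_M omega_M)^T x + b): the neuron's weight
    vector is a learnable combination of the per-subspace basis vectors."""
    w = sum(B @ om for B, om in zip(bases, omegas))
    return np.tanh(w @ x + b)

n_params = sum(om.size for om in omegas) + 1
print(n_params)  # -> 28, same as a standard neuron on R^27
```

Zeroing any `omegas[k]` removes the contribution of that subspace, which is exactly the mechanism the regularizer exploits to force invariance.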

D PSEUDOCODE FOR THEOREM 3

We present the algorithm for Theorem 3 in Algorithm 1. The loops in the algorithm iterate over the different subsets M ⊆ {1, . . . , m} in descending order of their sizes. The worst-case complexity of the algorithm is exponential in m (to iterate over all subsets). However, since the algorithm stops after finding all the bases for the space vec(X), it is unclear whether the worst-case runtime occurs in practice.

E ARCHITECTURES

E.1 IMAGES

An example CG-invariant CNN architecture is depicted in Figure 6. The majority of the CNN architecture remains the same, with the exception that the filters are obtained using the bases of the subspaces obtained in Theorem 3 for the given set of groups. Figure 5 shows example subspaces along with their basis vectors.

In this specific case with transposition groups over sequences, one could add another penalty term that regularizes the parameters α_i to share the same value (e.g., entropy regularization of the parameters). We leave further investigation into the general scenario with other groups for future work.
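The iteration order used by Algorithm 1 (Appendix D) can be sketched as follows; the early-stopping condition on the accumulated basis count is our paraphrase of the text, and `dim_of_subspace` is a hypothetical callback:

```python
from itertools import combinations

def subsets_descending(m):
    """All subsets M of {0, ..., m-1}, largest first, as Algorithm 1 iterates."""
    for r in range(m, -1, -1):
        yield from (frozenset(M) for M in combinations(range(m), r))

def iterate_until_full(m, dim_of_subspace, total_dim):
    """Visit subsets in descending size order, stopping as soon as the
    collected bases span all of vec(X) (dimension total_dim)."""
    found, order = 0, []
    for M in subsets_descending(m):
        order.append(M)
        found += dim_of_subspace(M)
        if found == total_dim:
            break
    return order
```

This makes the remark above concrete: the loop is exponential in m in the worst case, but terminates early whenever the high-invariance subspaces already cover the space.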

G DATASETS AND EMPIRICAL RESULTS

G.1 IMAGES

Datasets. We consider the standard MNIST dataset and its subset MNIST-34, which contains only the digits 3 and 4. We chose to experiment on the MNIST-34 dataset since it has no digits that can be confused under a rotation transformation (e.g., 6 and 9) or that are invariant to some rotations (e.g., 0, 1, and 8), thus avoiding any confounding factors while testing our hypothesis. We also experiment on the full MNIST dataset to depict the scenario when the data does contain these contradictions. First, we modify all the images in the dataset to have three RGB color channels and initially color each digit red, i.e., all active pixels in the digit are set to (255, 0, 0). We sample X^{(hid)} from this dataset with the target digit as its original label.

Groups. We consider m = 3 linear automorphism groups on images: the rotation group G_rot = {T_{0°}, T_{90°}, T_{180°}, T_{270°}} that rotates the entire image by multiples of 90°, the channel-permutation group G_color = {T_α}_{α∈S_3} that permutes the three RGB channels of the image, and the vertical-flip group G_vertical-flip = {T_{(0)}, T_{(v)}} that vertically flips the image.

Tasks. For both MNIST and MNIST-34 datasets, we consider 4 classification tasks, where each task represents the case when the target Y is invariant to a different subset of {G_rot, G_vertical-flip, G_color}: invariant to all three groups, to two, to one, or to none (and sensitive to the remaining groups). We consider the following subsets I: (i) {rot, color, vertical-flip}, (ii) {rot, vertical-flip}, (iii) {color}, (iv) ∅, and generate G_I = ∪_{i∈I} G_i as the join of the respective groups. Setting D = {rot, color, vertical-flip} \ I, we generate G_D = ∪_{j∈D} G_j from the join of the groups in the complement set (our choices ensure that G_I ⊴ G_{D∪I}, thus satisfying the conditions of Theorem 2).

Training data: X^{(hid)} is the canonically ordered (standard) image in the MNIST datasets.
Recall that the training data is sampled via an economical data-generation process; thus the training data consists only of images under transformations that have an effect on the label, i.e., transformations from G_D. Recall from Equation (2) that the observed input is obtained as X = T_{U_I,U_D} ∘ X^{(hid)}, a transformation of the canonical input X^{(hid)}. Since G_I ⊴ G_{D∪I} (by construction), any T_{U_I,U_D} = T_{U_I|U_D} ∘ T_{U_D}. In conclusion, we obtain the observed image X in the training data by applying a random transformation from G_D to X^{(hid)} and then applying a constant transformation (e.g., T_identity) from G_I to the result. The task is to predict the original label of the image (i.e., the digit) and the transformation T_{U_D} that was applied to obtain X (recall from Equation (3) that Y is a function of both X^{(hid)} and U_D).

For instance, if G_I = G_{rot,vertical-flip} and G_D = G_color, then the training data consists of upright and unflipped images (as T_{U_I|U_D} is chosen to be the identity transformation) with different permutations of the color channels (since random transformations are sampled from G_D), resulting in digits with different colors. The task is then to predict the digit and its color.

Extrapolation task: The extrapolated test data consists of samples from the coupled random variable X_{U_I←Ũ_I} (Definition 1). Unlike the training data, which was economically sampled (i.e., with a single transformation from G_I), the extrapolated test data is obtained via the full range of transformations in G_I. Recall from Definition 1 that X_{U_I←Ũ_I} = T_{Ũ_I,U_D} ∘ X^{(hid)}. As before, we decompose T_{Ũ_I,U_D} = T_{Ũ_I|U_D} ∘ T_{U_D}. However, there is no economical sampling for the test data: T_{Ũ_I|U_D} and T_{U_D} are sampled randomly from G_I and G_D, respectively.

