EQUIVARIANT DISENTANGLED TRANSFORMATION FOR DOMAIN GENERALIZATION UNDER COMBINATION SHIFT

Abstract

Machine learning systems may encounter unexpected problems when the data distribution changes in the deployment environment. A major reason is that certain combinations of domains and labels are not observed during training but appear in the test environment. Although various invariance-based algorithms can be applied, we find that the performance gain is often marginal. To formally analyze this issue, we provide a unique algebraic formulation of the combination shift problem based on the concepts of homomorphism, equivariance, and a refined definition of disentanglement. From the algebraic requirements, we naturally derive a simple yet effective method, referred to as equivariant disentangled transformation (EDT), which augments the data based on the algebraic structures of labels and makes the transformation satisfy the equivariance and disentanglement requirements. Experimental results demonstrate that invariance may be insufficient, and it is important to exploit the equivariance structure in the combination shift problem.

1. INTRODUCTION

The way we humans perceive the world is combinatorial: we tend to cognize a complex object or phenomenon as a combination of simpler factors of variation. Further, we have the ability to recognize, imagine, and process novel combinations of factors that we have never observed, so that we can survive in this rapidly changing world. Such ability is usually referred to as generalization. However, despite recent super-human performance on certain tasks, machine learning systems still lack this generalization ability, especially when only a limited subset of all combinations of factors is observable (Sagawa et al., 2020; Träuble et al., 2021; Goel et al., 2021; Wiles et al., 2022). In risk-sensitive applications such as driver-assistance systems (Alcorn et al., 2019; Volk et al., 2019) and computer-aided medical diagnosis (Castro et al., 2020; Bissoto et al., 2020), performing well only on a given subset of combinations but not on unobserved combinations may cause unexpected and catastrophic failures in a deployment environment.

Domain generalization (Wang et al., 2021a) is a problem where we need to deal with combinations of two factors: domains and labels. Recently, Gulrajani & Lopez-Paz (2021) questioned the progress of domain generalization research, claiming that several algorithms are not significantly superior to an empirical risk minimization (ERM) baseline. In addition to the model selection issue raised by Gulrajani & Lopez-Paz (2021), we conjecture that this is due to the ambitious goal of the usual domain generalization setting: generalizing to a completely unknown domain. Is it really possible to understand art if we have only seen photographs (Li et al., 2017)? Besides, the datasets used for evaluation usually have almost uniformly distributed domains and classes for training, which may be unrealistic to expect in real-world applications.

A more practical but still challenging learning problem is to learn all domains and labels when only a limited subset of the domain-label combinations is given for training. We refer to the usual setting of domain generalization as domain shift and to this new setting as combination shift. An illustration is given in Fig. 1. Combination shift is more feasible because all domains are at least partially observable during training, but it is also more challenging because the distribution of labels can vary significantly across domains. The learning goal is to improve generalization with as few combinations as possible.

To solve the combination shift problem, a straightforward way is to apply methods designed for domain shift. One approach is based on the idea that the prediction of labels should be invariant to the change of domains (Ganin et al., 2016; Sun & Saenko, 2016; Arjovsky et al., 2019; Creager et al., 2021). However, we find that the performance improvement is often marginal. Recent works (Wiles et al., 2022; Schott et al., 2022) also provided empirical evidence showing that invariance-based domain generalization methods offer limited improvement.
On the other hand, they also showed that data augmentation and pre-training can be more effective. To analyze this phenomenon, a unified perspective on different methods is desired. In this work, we provide an algebraic formulation of both invariance-based methods and data augmentation methods to investigate why invariance may be insufficient and how we should learn data augmentations. We also derive a simple yet effective method from the algebraic requirements, referred to as equivariant disentangled transformation (EDT), to demonstrate its usefulness. Our main contributions are as follows:

- We provide an algebraic formulation for the combination shift problem.
- We show that invariance is only half the story and that it is important to exploit the equivariance structure.
- We present a refined definition of disentanglement beyond the one based on group actions (Higgins et al., 2018), which may be interesting in its own right.
- As a proof of concept, we demonstrate that learning data augmentations based on the algebraic structures of labels is a promising approach for the combination shift problem.
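To make the distinction between invariance and equivariance concrete, the following toy sketch (not the paper's method; all names are illustrative) treats an input as a (domain, content) pair and lets a group element g act by cyclically shifting the domain index. An invariant map satisfies f(g . x) = f(x), while an equivariant map satisfies f(g . x) = g . f(x):

```python
# Toy illustration of invariance vs. equivariance (hypothetical setup).
N_DOMAINS = 3

def act(x, g):
    """Group action of g on an input: shift the domain index modulo N_DOMAINS."""
    domain, content = x
    return ((domain + g) % N_DOMAINS, content)

def invariant_map(x):
    """Invariant feature: discards the domain, so f(g . x) == f(x)."""
    _, content = x
    return content

def equivariant_map(x):
    """Equivariant feature: preserves the domain component, so the group
    action commutes with the map: f(g . x) == g . f(x)."""
    domain, content = x
    return (domain, content)

x, g = (1, 7.0), 2
assert invariant_map(act(x, g)) == invariant_map(x)              # invariance
assert equivariant_map(act(x, g)) == act(equivariant_map(x), g)  # equivariance
```

The invariant map throws away all domain information, which is why invariance alone cannot help generate or recognize unseen domain-label combinations; the equivariant map retains the domain in a form that transforms predictably under the group action.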

2. PROBLEM: DOMAIN GENERALIZATION UNDER COMBINATION SHIFT

Throughout the following sections, we study the problem of transforming a set of features X to a set of targets Y via a function f : X → Y. Here, X can be a set of images, texts, audio clips, or more structured data, while Y is the space of outputs. Further, the target Y may have multiple components; for example, Y_1 is the set of domain indices and Y_2 is the set of target labels. Ideally, all combinations of domains and target labels would be uniformly observable. However, in reality, this may not be the case because of selection bias, uncontrolled variables, or changing environments (Sagawa et al., 2020; Träuble et al., 2021).



Figure 1: Domain generalization under domain shift (an unseen domain) and combination shift (unseen combinations of domains and labels). Domain: color; label: digit; training and test combinations are marked in the figure.
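The kind of split illustrated in Figure 1 can be sketched numerically as follows (the particular domain-label pairs are hypothetical, chosen only so that every domain and every label appears individually while some combinations are held out):

```python
from itertools import product

# Hypothetical annotations mirroring Figure 1: domains are colors {0, 1, 2},
# labels are digits {0, 1, 2}, and only some (domain, label) pairs are
# available for training.
train_pairs = {(0, 0), (0, 1), (1, 1), (1, 2), (2, 0), (2, 2)}

domains = {d for d, _ in train_pairs}
labels = {y for _, y in train_pairs}
unseen = set(product(domains, labels)) - train_pairs

# Every domain and every label is observed individually...
assert domains == {0, 1, 2} and labels == {0, 1, 2}
# ...but some combinations appear only at test time: combination shift.
print(sorted(unseen))  # [(0, 2), (1, 0), (2, 1)]
```

This is what distinguishes combination shift from domain shift: no domain is entirely missing from training, yet the model is evaluated on combinations it has never seen.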

Based on this algebraic formulation, we derive (a) what combinations are needed to effectively learn augmentations; (b) what augmentations are useful for improving generalization; and (c) what regularization can be derived from the algebraic constraints, which can serve as a guidance for designing data augmentation methods.
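As a rough illustration of point (b), an augmentation that is useful for combination shift is one that fills in unseen (domain, label) pairs. The sketch below uses a placeholder transfer function where a real method would learn a domain-transfer transformation; all names are hypothetical and not from the paper:

```python
from itertools import product

def augment_to_full_grid(data, transfer):
    """Synthesize examples for unseen (domain, label) pairs by transferring
    an exemplar with the right label into the missing domain.

    data: list of ((domain, label), x) pairs.
    transfer: callable (x, target_domain) -> transformed x. In practice this
    would be a learned domain-transfer map; here it is a placeholder.
    """
    observed = {pair for pair, _ in data}
    domains = {d for (d, _), _ in data}
    exemplar = {}  # one representative example per label
    for (d, y), x in data:
        exemplar.setdefault(y, x)
    augmented = list(data)
    for d, y in product(domains, exemplar):
        if (d, y) not in observed:
            augmented.append(((d, y), transfer(exemplar[y], d)))
    return augmented

# Placeholder transfer that leaves the example unchanged; a real method
# would learn this transformation from pairs of domains.
identity_transfer = lambda x, d: x

data = [((0, 0), "img-a"), ((0, 1), "img-b"), ((1, 1), "img-c")]
full = augment_to_full_grid(data, identity_transfer)
assert ((1, 0), "img-a") in full  # the missing combination is filled in
```

Note that this requires at least one training example per label and per domain, which hints at point (a): the observed combinations must jointly cover all domains and all labels for such augmentation to be learnable.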

Let Y_i^train and Y_i^test denote the sets of i-th components (the supports of the marginal distributions) observed in the training and test data, respectively. In the usual domain generalization setting (Wang et al., 2021a; Gulrajani & Lopez-Paz, 2021), the goal is to generalize to a completely unseen domain, i.e., domain shift. We have Y_2^train = Y_2^test but

