MULTI-DOMAIN LONG-TAILED LEARNING BY AUGMENTING DISENTANGLED REPRESENTATIONS

Abstract

There is an inescapable long-tailed class-imbalance issue in many real-world classification problems. Existing long-tailed classification methods focus on the single-domain setting, where all examples are drawn from the same distribution. However, real-world scenarios often involve multiple domains with distinct imbalanced class distributions. We study this multi-domain long-tailed learning problem and aim to produce a model that generalizes well across all classes and domains. Towards that goal, we introduce TALLY, which produces invariant predictors through balanced augmentation of hidden representations over domains and classes. Built upon a proposed selective balanced sampling strategy, TALLY achieves this by mixing the semantic representation of one example with the domain-associated nuisances of another, producing a new representation for use as data augmentation. To improve the disentanglement of semantic representations, TALLY further utilizes a domain-invariant class prototype that averages out domain-specific effects. We evaluate TALLY on four long-tailed variants of classical domain generalization benchmarks and two real-world imbalanced multi-domain datasets. The results indicate that TALLY consistently outperforms other state-of-the-art methods under both subpopulation shift and domain shift.



1. INTRODUCTION

Deep classification models can struggle when the number of examples per class varies dramatically (Beery et al., 2020; Zhang et al., 2021). This long-tailed setting arises frequently in practice, such as in wildlife recognition (Beery et al., 2020). Classifiers tend to be biased towards majority classes and perform poorly on class-balanced test distributions, i.e., when there is a shift in the label distribution between training and test. Existing approaches address the long-tailed problem by modifying the data sampling strategy (Chawla et al., 2002; Zhang & Pfister, 2021), adjusting the loss function for different classes (Cao et al., 2019; Hong et al., 2021), or augmenting minority classes (Chou et al., 2020; Zhong et al., 2021). Unlike these works, which focus on single-domain long-tailed learning, we study multi-domain long-tailed learning, where each domain has its own long-tailed distribution. Take wildlife recognition as an example (Figure 1): images of wildlife are collected from various locations; the distribution over species (classes) at each location is typically imbalanced, and the class distribution also varies between locations. In multi-domain long-tailed classification, classifiers need to handle distribution shift amidst class imbalance. Here, we focus on two types of distribution shift: subpopulation shift and domain shift. In subpopulation shift, we train a model on data from multiple domains and evaluate it on a test set with balanced domain-class pairs. In the wildlife recognition example, species are often concentrated at only a few locations, creating a spurious correlation between the label (species) and the domain (location). A machine learning model trained on the entire population may fail on the test set when this correlation no longer holds. In domain shift, we expect the trained model to generalize well to completely new test domains. For example, in wildlife recognition, we train a model on data from a fixed set of training locations and then deploy the model to new test locations.
Prior long-tailed classification methods work well in single-domain settings, but may perform poorly when the test data come from underrepresented or novel domains. Meanwhile, invariant learning approaches alleviate cross-domain performance gaps by learning representations or predictors that are invariant across different domains (Arjovsky et al., 2019; Li et al., 2018). Yet, these approaches are mostly evaluated in class-balanced settings, where models must be trained on plenty of examples from each class even if augmentation strategies are applied (Yao et al., 2022); see a detailed discussion in Appendix B. With multi-domain long-tailed data, learning a class-unbiased, domain-invariant model is not trivial since the imbalance can exist within a domain or across domains. We aim to address these challenges in this work, leading to a novel method named TALLY (mulTi-domAin Long-tailed learning with baLanced representation reassemblY). TALLY empowers augmentation to balance examples over domains and classes by decomposing and reassembling example pairs, combining the class-relevant semantic information of one example with the domain-associated nuisances of another (Zhou et al., 2022). Specifically, TALLY first decouples the representation of each example into semantic information and nuisances with instance normalization. To further mitigate the effects of nuisances, we average out domain information over examples of the same class to construct class prototype representations. Each semantic representation is then linearly interpolated with the corresponding class prototype, yielding a prototype-enhanced semantic representation. The domain-associated factors are similarly interpolated with class-agnostic domain factors to improve training stability and remove noise. Finally, TALLY produces augmented representations for training by reassembling the prototype-enhanced semantic representations and domain-associated nuisances across examples.
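The decompose-and-reassemble step above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the function names (`decompose`, `reassemble`), the NumPy feature-map layout, and the single interpolation coefficient `lam` are all assumptions for exposition; the instance-normalization view (channel statistics as domain nuisances, normalized features as semantics) is what the text describes.

```python
import numpy as np

def decompose(feat, eps=1e-5):
    """Split a feature map of shape (C, H, W) into a normalized semantic
    part and domain-associated nuisances (channel-wise mean and std),
    following the instance-normalization view described above."""
    mu = feat.mean(axis=(1, 2), keepdims=True)
    sigma = feat.std(axis=(1, 2), keepdims=True) + eps
    semantic = (feat - mu) / sigma          # style-free content
    return semantic, (mu, sigma)            # nuisances = (mu, sigma)

def reassemble(feat_a, feat_b, proto_a=None, lam=0.5):
    """Combine the semantics of example a with the nuisances of example b.
    Optionally interpolate a's semantics toward a class prototype
    (a normalized feature averaged over examples of a's class)."""
    sem_a, _ = decompose(feat_a)
    _, (mu_b, sigma_b) = decompose(feat_b)
    if proto_a is not None:
        sem_a = lam * sem_a + (1.0 - lam) * proto_a
    return sem_a * sigma_b + mu_b           # re-apply b's domain statistics
```

Because the semantic part has zero mean and unit variance per channel, the reassembled representation inherits example b's channel statistics while keeping example a's (prototype-enhanced) content.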
To further achieve balanced augmentation, we propose a selective balanced sampling strategy to draw example pairs for augmentation. Concretely, for each pair, the label of one example is uniformly sampled from all classes and the domain of the other example is uniformly sampled from all domains. In this way, TALLY encourages the model to learn a class-unbiased invariant predictor. In summary, our major contributions are: (1) we investigate and formalize an important yet less explored problem, multi-domain long-tailed learning; (2) we propose an effective augmentation algorithm called TALLY to simultaneously address the class-imbalance issue and learn domain-invariant predictors; (3) we empirically demonstrate the effectiveness of TALLY under subpopulation shift and domain shift, observing that TALLY outperforms both prior single-domain long-tailed learning and domain-invariant learning approaches, with a 5.18% error decrease over all datasets, and captures stronger invariant predictors than prior invariant learning approaches.
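The selective balanced sampling strategy can be sketched as follows. This is a hedged illustration rather than the authors' code: the helper names (`build_indices`, `sample_pair`) are invented, but the logic matches the text, one example whose class is uniform over classes (the semantic source) paired with one whose domain is uniform over domains (the nuisance source), regardless of how many examples each class or domain contains.

```python
import random
from collections import defaultdict

def build_indices(labels, domains):
    """Index the training set by class label and by domain id."""
    by_class, by_domain = defaultdict(list), defaultdict(list)
    for i, (y, d) in enumerate(zip(labels, domains)):
        by_class[y].append(i)
        by_domain[d].append(i)
    return by_class, by_domain

def sample_pair(by_class, by_domain, rng=random):
    """Draw one semantic-source index with class uniform over classes,
    and one nuisance-source index with domain uniform over domains."""
    y = rng.choice(sorted(by_class))        # uniform over classes
    i_sem = rng.choice(by_class[y])
    d = rng.choice(sorted(by_domain))       # uniform over domains
    i_dom = rng.choice(by_domain[d])
    return i_sem, i_dom
```

Sampling classes first (rather than examples) is what makes the augmentation class-balanced: a minority class is chosen as the semantic source as often as a majority class.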

2. FORMULATIONS AND PRELIMINARIES

Long-Tailed Learning. In this paper, we investigate the setting where one predicts the class label y ∈ C based on the input feature x ∈ X, where C = {1, . . . , C}. Given a machine learning model f parameterized by θ and a loss function ℓ, empirical risk minimization (ERM) trains the model by minimizing the average loss over all training examples, min_θ E_{(x,y)∼P^{tr}}[ℓ(f_θ(x), y)], which works well when the label distribution is approximately uniform. In long-tailed learning, however, the label distribution is long-tailed: a small proportion of classes account for most of the labels, while the rest of the classes are associated with only a few examples. Assume {(x_i, y_i)}_{i=1}^{N} is a training set sampled from the training distribution and the number of examples for each class is {n_1, . . . , n_C}, where Σ_{c=1}^{C} n_c = N. In long-tailed learning, all classes are sorted by cardinality (i.e., n_1 ≥ · · · ≥ n_C) and the imbalance ratio ρ is defined as ρ = n_1/n_C > 1. Note that the same definitions are used for the test set {(x_i, y_i)}_{i=1}^{N^{ts}}. Under the class-imbalanced training distribution, a vanilla ERM model tends to perform poorly on minority classes, but we expect the model to perform consistently well on all classes. Hence we typically assume the test distribution is class-balanced (i.e., ρ^{ts} = 1).

Figure 1: Illustration of imbalanced class distributions across domains in iWildCam, a wildlife recognition benchmark (Beery et al., 2020). Both subpopulation shift and domain shift settings are illustrated.

Multi-Domain Imbalanced Learning. Multi-domain long-tailed learning is a natural extension of classical long-tailed learning, where the overall data distribution is drawn from a set of domains D = {1, . . . , D} and each domain d is associated with a class-imbalanced dataset {(x_i, y_i, d)}_{i=1}^{N_d} drawn from a domain-specific distribution P_d. Following (Albuquerque et al., 2019; Koh et al., 2021), both the training and test distributions can be formulated as mixture distributions over the domain space D, i.e., P^{tr} = Σ_{d=1}^{D} η^{tr}_d P^{tr}_d and P^{ts} = Σ_{d=1}^{D} η^{ts}_d P^{ts}_d. The corresponding training and test domains are D^{tr} = {d ∈ D | η^{tr}_d > 0} and D^{ts} = {d ∈ D | η^{ts}_d > 0}, respectively, where η^{tr}_d and η^{ts}_d represent the mixture weights of domain d in the training and test distributions.
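The quantities above are straightforward to compute from per-domain class counts. The snippet below is a small worked example with made-up counts (the two-domain count matrix is purely hypothetical): it computes the per-domain imbalance ratio ρ = n_1/n_C and the empirical mixture weights η^{tr}_d from the relative domain sizes.

```python
import numpy as np

def imbalance_ratio(class_counts):
    """rho = n_1 / n_C with class counts sorted in descending order."""
    n = np.sort(np.asarray(class_counts, dtype=float))[::-1]
    return n[0] / n[-1]

# Hypothetical two-domain training set: rows = domains, cols = classes.
counts = np.array([[100, 10, 2],
                   [5,  60, 15]])

rhos = [imbalance_ratio(c) for c in counts]   # per-domain imbalance ratios
eta_tr = counts.sum(axis=1) / counts.sum()    # empirical mixture weights eta^tr_d
```

Note that each domain can be long-tailed in a different way (here class 1 dominates domain 1 but class 2 dominates domain 2), which is exactly the spurious class-domain correlation the subpopulation-shift setting stresses.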

