SPARSE MIXTURE-OF-EXPERTS ARE DOMAIN GENERALIZABLE LEARNERS

Abstract

Human visual perception can easily generalize to out-of-distribution visual data, a capability far beyond that of modern machine learning models. Domain generalization (DG) aims to close this gap, and existing DG methods mainly focus on loss function design. In this paper, we explore an orthogonal direction: the design of the backbone architecture. This direction is motivated by an empirical finding that transformer-based models trained with empirical risk minimization (ERM) outperform CNN-based models employing state-of-the-art (SOTA) DG algorithms on multiple DG datasets. We develop a formal framework that characterizes a network's robustness to distribution shifts by studying how well its architecture aligns with the correlations in the dataset. This analysis guides us to propose a novel DG model built upon vision transformers, namely Generalizable Mixture-of-Experts (GMoE). Extensive experiments on DomainBed demonstrate that GMoE trained with ERM outperforms SOTA DG baselines by a large margin. Moreover, GMoE is complementary to existing DG methods, and its performance improves substantially when it is trained with DG algorithms.

1. INTRODUCTION

1.1. MOTIVATIONS

Generalizing to out-of-distribution (OOD) data is an innate ability of human vision but is highly challenging for machine learning models (Recht et al., 2019; Geirhos et al., 2021; Ma et al., 2022). Domain generalization (DG) is one approach to this problem; it encourages models to be resilient under various distribution shifts, such as shifts in background, lighting, texture, shape, and geographic or demographic attributes. From the perspective of representation learning, there are several paradigms toward this goal, including domain alignment (Ganin et al., 2016; Hoffman et al., 2018), invariant causal prediction (Arjovsky et al., 2019; Krueger et al., 2021), meta-learning (Bui et al., 2021; Zhang et al., 2021c), ensemble learning (Mancini et al., 2018; Cha et al., 2021b), and feature disentanglement (Wang et al., 2021; Zhang et al., 2021b). The most popular way to implement these ideas is to design a specific loss function. For example, DANN (Ganin et al., 2016) aligns domain distributions through adversarial losses. Invariant causal prediction can be enforced by a penalty on a gradient norm (Arjovsky et al., 2019) or on the variance of training risks (Krueger et al., 2021). Meta-learning and domain-specific loss functions (Bui et al., 2021; Zhang et al., 2021c) have also been employed to enhance performance. Recent studies have shown that these approaches improve on ERM and achieve promising results on large-scale DG datasets (Wiles et al., 2021).

Meanwhile, in various computer vision tasks, innovations in backbone architectures play a pivotal role in boosting performance and have attracted much attention (He et al., 2016; Hu et al., 2018; Liu et al., 2021). Additionally, Sivaprasad et al. (2021) have empirically demonstrated that different CNN architectures perform differently on DG datasets. Inspired by these pioneering works, we conjecture that backbone architecture design is a promising direction for DG.
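The two penalty-based formulations of invariant causal prediction mentioned above can be sketched as follows. This is a minimal NumPy illustration, not the reference implementation of either method: it assumes binary classification with a logistic loss, treats each training domain as one environment given as a (logits, labels) pair, and uses a hypothetical penalty weight `weight`. The IRMv1-style penalty is the squared gradient of each environment's risk with respect to a dummy scalar multiplier on the logits, evaluated at 1 (Arjovsky et al., 2019); the V-REx-style penalty is the variance of the per-environment risks (Krueger et al., 2021).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce_risk(logits, labels):
    """Mean binary cross-entropy of sigmoid(logits) against labels in {0, 1}."""
    p = sigmoid(logits)
    eps = 1e-12  # guards against log(0)
    return float(-np.mean(labels * np.log(p + eps) + (1 - labels) * np.log(1 - p + eps)))


def irmv1_penalty(logits, labels):
    """Squared d/dw of the risk of sigmoid(w * logits), evaluated at w = 1.

    For the logistic loss the per-sample derivative is (sigmoid(w * z) - y) * z.
    """
    grad = np.mean((sigmoid(logits) - labels) * logits)
    return float(grad ** 2)


def vrex_penalty(env_risks):
    """Variance of the per-environment training risks."""
    return float(np.var(env_risks))


def objective(envs, penalty="erm", weight=1.0):
    """ERM risk averaged over environments, plus an optional invariance penalty."""
    risks = [bce_risk(z, y) for z, y in envs]
    erm = float(np.mean(risks))
    if penalty == "vrex":
        return erm + weight * vrex_penalty(risks)
    if penalty == "irm":
        return erm + weight * float(np.mean([irmv1_penalty(z, y) for z, y in envs]))
    return erm  # plain ERM: no penalty


# Two hypothetical environments: same labels, logits that fit one and not the other.
z = np.array([2.0, -2.0])
y = np.array([1.0, 0.0])
envs = [(z, y), (-z, y)]
for name in ("erm", "irm", "vrex"):
    print(name, round(objective(envs, name), 3))
```

Because this gradient has a closed form for the logistic loss, no automatic differentiation is needed here; with a neural network one would instead differentiate the risk with a framework such as PyTorch or JAX.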
To verify this intuition, we evaluate a transformer-based model and compare it with CNN-based architectures of equivalent computational overhead, as shown in Fig. 1 (a). To our surprise, a vanilla ViT-S/16 (Dosovitskiy et al., 2021) trained with empirical risk minimization (ERM) outperforms ResNet-50 trained with SOTA DG algorithms (Cha et al., 2021b; Rame et al., 2021; Shi et al., 2021) on the DomainNet, OfficeHome, and VLCS datasets, even though both architectures have a similar number of parameters and perform comparably on in-distribution domains. We provide a theoretical justification for this effect based on the algorithmic alignment framework (Xu et al., 2020a; Li et al., 2021). We first prove that a network trained with the ERM loss function is more robust to distribution shifts if its architecture is more similar to the invariant correlation, where similarity is formally measured by the alignment value defined in Xu et al. (2020a). Conversely, a network is less robust if its architecture aligns with the spurious correlation. We then investigate the alignment between backbone architectures (i.e., convolutions and attention) and the correlations in these datasets, which explains the superior performance of ViT-based methods. To further improve performance, our analysis suggests exploiting the properties of invariant correlations in vision tasks and designing network architectures that align with these properties. This requires an investigation at the intersection of domain generalization and classic computer vision. In domain generalization, it is widely believed that data are composed of sets of attributes and that distribution shifts of the data are distribution shifts of these attributes (Wiles et al., 2021). The latent factorization model of these attributes is almost identical to the generative model of visual attributes in classic computer vision (Ferrari & Zisserman, 2007).
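The distinction between invariant and spurious correlations in this attribute view can be illustrated with a toy simulation. In this hypothetical setup (not the paper's formal alignment framework), each sample has a binary label, an invariant "shape" attribute that agrees with the label with the same probability in every distribution, and a spurious "background" attribute whose agreement with the label changes between training and test. A predictor that reads only one attribute stands in for an architecture aligned with that correlation; the function names and probabilities are illustrative assumptions.

```python
import numpy as np


def sample(n, shape_agreement, background_agreement, rng):
    """Draw n samples of (shape, background, label) from a hypothetical attribute model."""
    label = rng.integers(0, 2, size=n)
    # Each attribute copies the label with the given probability and flips it otherwise.
    shape = np.where(rng.random(n) < shape_agreement, label, 1 - label)
    background = np.where(rng.random(n) < background_agreement, label, 1 - label)
    return shape, background, label


def accuracy(pred, label):
    return float(np.mean(pred == label))


rng = np.random.default_rng(0)
SHAPE_AGREEMENT = 0.90  # invariant correlation: identical in training and test

# Training distribution: background agrees with the label 95% of the time.
shape_tr, bg_tr, y_tr = sample(10_000, SHAPE_AGREEMENT, 0.95, rng)
# Test distribution: the spurious correlation is reversed (10% agreement).
shape_te, bg_te, y_te = sample(10_000, SHAPE_AGREEMENT, 0.10, rng)

results = {
    "invariant (shape)": (accuracy(shape_tr, y_tr), accuracy(shape_te, y_te)),
    "spurious (background)": (accuracy(bg_tr, y_tr), accuracy(bg_te, y_te)),
}
for name, (acc_tr, acc_te) in results.items():
    print(f"{name}: train acc = {acc_tr:.3f}, test acc = {acc_te:.3f}")
```

In this toy, the spurious predictor looks better on the training distribution, which is why ERM can latch onto it, but it collapses once the attribute distribution shifts, whereas the invariant predictor keeps roughly the same accuracy.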
To capture these diverse attributes, we propose Generalizable Mixture-of-Experts (GMoE), which is built upon sparse mixture-of-experts (sparse MoEs) (Shazeer et al., 2017) and vision transformers (Dosovitskiy et al., 2021). Sparse MoEs were originally proposed as key enablers of extremely large yet efficient models (Fedus et al., 2022). Through theoretical and empirical evidence, we demonstrate that MoEs act as experts for processing visual attributes, leading to better alignment with invariant correlations. Based on our analysis, we modify the architecture of sparse MoEs to enhance their DG performance. Extensive experiments demonstrate that GMoE achieves superior domain generalization performance both with and without DG algorithms.
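For readers unfamiliar with sparse MoEs, a generic top-k routed MoE layer can be sketched as below. This is a simplified NumPy illustration in the spirit of Shazeer et al. (2017), not GMoE's modified architecture: it omits the noisy gating and load-balancing losses of the original work, and assumes each expert is a two-layer ReLU MLP (the helper `make_expert` and all sizes are hypothetical). Each token is sent only to the k experts with the largest router logits, and their outputs are mixed with softmax weights renormalized over those k experts, so per-token expert compute grows with k rather than with the total number of experts.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def make_expert(d, rng):
    # Hypothetical expert: a two-layer ReLU MLP with hidden width 4 * d.
    w1 = rng.normal(size=(d, 4 * d)) / np.sqrt(d)
    w2 = rng.normal(size=(4 * d, d)) / np.sqrt(4 * d)
    return lambda x: np.maximum(x @ w1, 0.0) @ w2


def sparse_moe(tokens, w_gate, experts, k=2):
    """Route each token (row of `tokens`, shape (n, d)) to its top-k experts."""
    logits = tokens @ w_gate                                  # (n, n_experts) router scores
    topk = np.argsort(-logits, axis=1)[:, :k]                 # indices of the k largest logits
    gates = softmax(np.take_along_axis(logits, topk, axis=1), axis=1)  # renormalize over top-k
    out = np.zeros_like(tokens)
    for e, expert in enumerate(experts):
        rows, slots = np.nonzero(topk == e)                   # tokens that selected expert e
        if rows.size == 0:
            continue                                          # unselected experts do no work
        out[rows] += gates[rows, slots][:, None] * expert(tokens[rows])
    return out


rng = np.random.default_rng(0)
d, n_experts = 8, 4
experts = [make_expert(d, rng) for _ in range(n_experts)]
tokens = rng.normal(size=(5, d))
w_gate = rng.normal(size=(d, n_experts))
y = sparse_moe(tokens, w_gate, experts, k=2)
print(y.shape)  # (5, 8)
```

Renormalizing over only the selected experts keeps the mixing weights of each token summing to one, so a layer whose experts are all identical reduces to that single expert for any k.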

1.2. CONTRIBUTIONS

In this paper, we formally investigate the impact of the backbone architecture on DG and propose developing effective DG methods through backbone architecture design. Specifically, our main contributions are summarized as follows:

A Novel View of DG: In contrast to previous works, this paper initiates a formal exploration of the backbone architecture in DG. Based on algorithmic alignment (Xu et al., 2020a), we prove that a network is more robust to distribution shifts if its architecture aligns with the invariant correlation, and less robust if its architecture aligns with the spurious correlation. The theorems are verified on synthetic and real datasets.

A Novel Model for DG: Based on our theoretical analysis, we propose Generalizable Mixture-of-Experts (GMoE) and prove that it enjoys better alignment than vision transformers. GMoE is built upon sparse mixture-of-experts (Shazeer et al., 2017) and vision transformers (Dosovitskiy et al., 2021), with a theory-guided modification that enhances DG performance.

Excellent Performance: We validate GMoE's performance on all 8 large-scale datasets of DomainBed. Remarkably, GMoE trained with ERM achieves SOTA performance on 7 datasets in the train-validation setting and on 8 datasets in the leave-one-domain-out setting. Furthermore, GMoE trained with DG algorithms outperforms GMoE trained with ERM.

2. PRELIMINARIES

2.1. NOTATIONS

Throughout this paper, $a$, $\mathbf{a}$, and $\mathbf{A}$ stand for a scalar, a column vector, and a matrix, respectively. $O(\cdot)$ and $\omega(\cdot)$ are asymptotic notations. We denote the training dataset, training distribution, test dataset, and test distribution as $E_{tr}$, $D_{tr}$, $E_{te}$, and $D_{te}$, respectively.

2.2. ATTRIBUTE FACTORIZATION

The attribute factorization (Wiles et al., 2021) is a realistic generative model under distribution shifts. Consider a joint distribution of the input $\mathbf{x}$ and corresponding attributes $a_1, \cdots, a_K$ (denoted as

