PGRAD: LEARNING PRINCIPAL GRADIENTS FOR DOMAIN GENERALIZATION

Abstract

Machine learning models fail to perform when facing out-of-distribution (OOD) domains, a challenge known as domain generalization (DG). In this work, we develop a novel DG training strategy, which we call PGrad, to learn a robust gradient direction that improves models' generalization ability on unseen domains. The proposed gradient aggregates the principal directions of a sampled rollout optimization trajectory that captures the training dynamics across all training domains. PGrad's gradient design forces the DG training to ignore domain-dependent noise signals and updates the model with a robust direction covering the main components of parameter dynamics across all training domains. We further improve PGrad via a bijection-based computational refinement and direction- plus length-based calibrations. Our theoretical analysis connects PGrad to the spectral analysis of the Hessian in training neural networks. Experiments on the DomainBed and WILDS benchmarks demonstrate that our approach effectively enables robust DG optimization and leads to smoothly decreasing loss curves. Empirically, PGrad achieves competitive results across seven datasets, demonstrating its efficacy under both synthetic and real-world distributional shifts.

1. INTRODUCTION

Deep neural networks have shown remarkable generalization ability on test data following the same distribution as their training data. Yet, high-capacity models are incentivized to exploit any correlation in the training data that will lead to more accurate predictions. As a result, these models risk becoming overly reliant on "domain-specific" correlations that may harm model performance on out-of-distribution (OOD) test cases. A typical example is a camels-and-cows classification task (Beery et al., 2018; Shi et al., 2021), where camel pictures in training are almost always shown in a desert environment while cow pictures mostly have green grassland backgrounds. A typical machine learning model trained on this dataset will perform worse than random guessing on those test pictures with cows in deserts or camels in pastures. The network has learned to use the background texture as one deciding factor when we want it to learn to recognize animal shapes. Unfortunately, the model overfits to specific traps that are highly predictive of some training domains but fail on OOD target domains. Recent domain generalization (DG) research efforts deal with such a challenge. They are concerned with how to learn a machine learning model that can generalize to an unseen test distribution when given multiple different but related training domains.

Recent literature covers a wide spectrum of DG methods, including invariant representation learning, meta-learning, data augmentation, ensemble learning, and gradient manipulation (more details in Section 2.4). Despite the large body of recent DG literature, Gulrajani & Lopez-Paz (2021) showed that empirical risk minimization (ERM) provides a competitive baseline on many real-world DG benchmarks. ERM does not explicitly address distributional shifts during training. Instead, ERM calculates the gradient from each training domain and updates a model with the average gradient.
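To make the ERM baseline concrete, the following is a minimal NumPy sketch of one ERM update: per-domain gradients are simply averaged before the descent step. The function and variable names (`erm_update`, `grad_fn`) are illustrative, not from the paper's code.

```python
import numpy as np

def erm_update(theta, domains, grad_fn, lr=0.1):
    """One ERM step: average the per-domain gradients, then descend.

    theta   : current parameter vector, shape (d,)
    domains : list of per-domain payloads (here, toy loss centers)
    grad_fn : grad_fn(theta, domain) -> gradient of that domain's loss
    """
    grads = [grad_fn(theta, d) for d in domains]
    avg_grad = np.mean(grads, axis=0)  # domain-specific noise survives averaging
    return theta - lr * avg_grad

# Toy per-domain quadratic losses L_i(theta) = 0.5 * ||theta - c_i||^2:
centers = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([2.0, 2.0])]
grad_fn = lambda th, c: th - c  # gradient of the quadratic loss
theta = np.zeros(2)
for _ in range(100):
    theta = erm_update(theta, centers, grad_fn)
print(theta)  # approaches the mean of the centers, [1.0, 1.0]
```

In this toy, ERM converges to the minimizer of the averaged loss; any direction in which the per-domain gradients disagree still contributes to the averaged update, which is the caveat discussed next.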
However, one caveat of ERM is that its average-gradient-based model update preserves domain-specific noise during optimization. This observation motivates the core design of our method.

We propose a novel training strategy that learns a robust gradient direction for DG optimization, and we call it PGrad. PGrad samples an optimization trajectory in the high-dimensional parameter space by updating the current model sequentially across training domains. It then constructs a local coordinate system to explain the parameter variations in the trajectory. Via singular value decomposition (SVD), we derive an aggregated vector that covers the main components of parameter dynamics and use it as a new gradient direction $\nabla_p$ to update the target model: $\Theta_{t+1} = \Theta_t - \gamma \nabla_p$. This novel vector, which we name the "principal gradient", reduces domain-specific noise in the DG model update and prevents the model from overfitting to particular training domains. To decrease the computational complexity of the SVD, we construct a bijection between the parameter space and a low-dimensional space through transpose mapping. Hence, the computational complexity of PGrad relates only to the number of sampled training domains and does not depend on the size of the model parameters.

This paper makes the following contributions: (1) PGrad places no explicit assumption on either the joint or the marginal distributions. (2) PGrad is model-agnostic and scalable to various model architectures, since its computational cost relates only to the number of training domains.
(3) We theoretically show the connection between PGrad and Hessian approximation, and prove that PGrad improves training efficiency by learning a gradient in a smaller subspace constructed from the learning trajectory. (4) Our empirical results demonstrate the competitive performance of PGrad across seven datasets covering both synthetic and real-world distributional shifts.
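The steps above (sequential rollout, SVD of the trajectory, transpose-mapping bijection, direction and length calibration) can be sketched in NumPy as below. This is an illustrative reading of the paper's description, not its official implementation; in particular, the exact calibration rules and the number of retained components `k` are assumptions here.

```python
import numpy as np

def pgrad_step(theta, domains, grad_fn, inner_lr=0.1, outer_lr=0.1, k=2):
    """One PGrad-style update (illustrative sketch).

    1. Roll out: update theta sequentially on each training domain,
       recording the visited parameters Theta_t^1 ... Theta_t^n.
    2. Center the trajectory into a variation matrix V (rows in R^d).
    3. SVD via the small Gram matrix V V^T ("transpose mapping"), so the
       cost scales with the number of domains n, not the dimension d.
    4. Aggregate the top-k principal directions, calibrate their signs
       against the rollout displacement, rescale to its length, and step.
    """
    traj = [theta.copy()]
    th = theta.copy()
    for d in domains:                       # sequential rollout
        th = th - inner_lr * grad_fn(th, d)
        traj.append(th.copy())
    T = np.stack(traj)                      # (n+1, d)
    V = T - T.mean(axis=0)                  # centered trajectory
    G = V @ V.T                             # (n+1, n+1) Gram matrix, n << d
    w, U = np.linalg.eigh(G)                # eigenvalues in ascending order
    disp = theta - traj[-1]                 # "uphill" rollout displacement
    direction = np.zeros_like(theta)
    for i in np.argsort(w)[::-1][:k]:
        if w[i] <= 1e-12:
            continue
        v_i = V.T @ U[:, i] / np.sqrt(w[i])  # unit principal direction in R^d
        v_i *= np.sign(v_i @ disp)           # direction calibration
        direction += v_i
    if np.linalg.norm(direction) > 0:        # length calibration
        direction *= np.linalg.norm(disp) / np.linalg.norm(direction)
    return theta - outer_lr * direction

# Toy check on per-domain quadratic losses L_i = 0.5 * ||theta - c_i||^2:
centers = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([2.0, 2.0])]
grad_fn = lambda th, c: th - c
avg_loss = lambda th: np.mean([0.5 * np.sum((th - c) ** 2) for c in centers])
theta = np.zeros(2)
start = avg_loss(theta)
for _ in range(50):
    theta = pgrad_step(theta, centers, grad_fn)
print(start, "->", avg_loss(theta))  # the averaged loss decreases
```

The Gram-matrix step is the transpose-mapping bijection: the eigenvectors of $VV^\top \in \mathbb{R}^{(n+1)\times(n+1)}$ map back to the principal directions of $V$ in parameter space via $v_i = V^\top u_i / \sqrt{w_i}$, so no decomposition of a $d \times d$ matrix is ever formed.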

2. METHOD

Domain generalization (Wang et al., 2021; Zhou et al., 2021) is formalized below. In the DG setup, any prior about T te, such as inputs or outputs, is unavailable in the training phase. Despite not considering domain discrepancies from training to testing, ERM is still a competitive method for domain generalization tasks (Gulrajani & Lopez-Paz, 2021): ERM naively groups data from all training domains together and minimizes their average loss.



In the rest of this paper, we use the terms "domain" and "distribution" interchangeably.



Figure 1: Overview of our PGrad training strategy. With the current parameters $\Theta_t$, we first obtain a rollout trajectory $\Theta_t \to \Theta_t^1 \to \Theta_t^2 \to \Theta_t^3$ by sequentially optimizing across all training domains $D_{tr} = \{D_i\}_{i=1}^{3}$. Then PGrad updates $\Theta_t$ by extracting the principal gradient direction $\nabla_p$ of the trajectory. The target model's generalization is evaluated on unseen (OOD) test domains $T_j$.

Domain generalization assumes no access to instances from future unseen domains. Formally, we are given a set of training domains $D_{tr} = \{D_i\}_{i=1}^{n}$ and test domains $T_{te} = \{T_j\}_{j=1}^{m}$. Each domain $D_i$ (or $T_j$) is associated with a joint distribution $P^{D_i}_{\mathcal{X}\times\mathcal{Y}}$ (or $P^{T_j}_{\mathcal{X}\times\mathcal{Y}}$), where $\mathcal{X}$ represents the input space and $\mathcal{Y}$ is the output space. Moreover, each training domain $D_i$ is characterized by a set of i.i.d. samples $\{x^i_k, y^i_k\}$. For any two different domains sampled from either $D_{tr}$ or $T_{te}$, their joint distributions differ: $P^{D_i}_{\mathcal{X}\times\mathcal{Y}} \neq P^{D_j}_{\mathcal{X}\times\mathcal{Y}}$, and most importantly, $P^{D_i}_{\mathcal{X}\times\mathcal{Y}} \neq P^{T_j}_{\mathcal{X}\times\mathcal{Y}}$. We consider the prediction task from the input $x \in \mathcal{X}$ to the output $y \in \mathcal{Y}$. Provided with a model family whose parameter space is $\Theta \subset \mathbb{R}^d$ and a loss function $\mathcal{L}: \Theta \times (\mathcal{X} \times \mathcal{Y}) \to \mathbb{R}_{+}$, the goal is to find an optimal $\Theta^{*}_{te}$ on the test domains so that:

$$\Theta^{*}_{te} = \arg\min_{\Theta \in \Theta} \; \mathbb{E}_{T_j \sim T_{te}} \, \mathbb{E}_{(x,y) \sim P^{T_j}_{\mathcal{X}\times\mathcal{Y}}} \, \mathcal{L}[\Theta, (x, y)].$$
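The objective above can be estimated empirically as a Monte-Carlo average: an inner mean over samples from each unseen domain, then an outer mean over test domains. The sketch below illustrates this; the names `ood_risk` and `loss_fn` are ours, not the paper's.

```python
import numpy as np

def ood_risk(theta, test_domains, loss_fn):
    """Monte-Carlo estimate of E_{T_j ~ T_te} E_{(x,y) ~ P^{T_j}} L[Theta,(x,y)].

    test_domains : list of (X, Y) sample pairs, one per unseen domain T_j
    loss_fn      : loss_fn(theta, x, y) -> scalar per-example loss
    """
    per_domain = [np.mean([loss_fn(theta, x, y) for x, y in zip(X, Y)])
                  for X, Y in test_domains]
    return float(np.mean(per_domain))  # outer expectation over domains

# Toy check with a squared-error loss and a 1-D linear model:
sq_loss = lambda th, x, y: 0.5 * (th @ x - y) ** 2
domains = [([np.array([1.0]), np.array([2.0])], [1.0, 2.0]),  # fits exactly
           ([np.array([1.0])], [0.0])]                        # loss 0.5
print(ood_risk(np.array([1.0]), domains, sq_loss))  # 0.25
```

Note that during training this quantity cannot be computed, since no sample from any $T_j$ is available; it is only used for evaluation after the fact.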

