PGRAD: LEARNING PRINCIPAL GRADIENTS FOR DOMAIN GENERALIZATION

Abstract

Machine learning models often fail when facing out-of-distribution (OOD) domains; learning models that remain accurate under such shifts is the challenging task known as domain generalization (DG). In this work, we develop a novel DG training strategy, which we call PGrad, that learns a robust gradient direction and improves a model's generalization ability on unseen domains. The proposed gradient aggregates the principal directions of a sampled rollout optimization trajectory that captures the training dynamics across all training domains. PGrad's gradient design forces DG training to ignore domain-dependent noise signals and updates the model with a robust direction covering the main components of parameter dynamics across all training domains. We further improve PGrad via a bijection-based computational refinement and directional plus length-based calibrations. Our theoretical analysis connects PGrad to the spectral analysis of the Hessian in training neural networks. Experiments on the DomainBed and WILDS benchmarks demonstrate that our approach effectively enables robust DG optimization and leads to smoothly decreasing loss curves. Empirically, PGrad achieves competitive results across seven datasets, demonstrating its efficacy on both synthetic and real-world distributional shifts.

1. INTRODUCTION

Deep neural networks have shown remarkable generalization ability on test data that follows the same distribution as their training data. Yet high-capacity models are incentivized to exploit any correlation in the training data that leads to more accurate predictions. As a result, these models risk becoming overly reliant on "domain-specific" correlations that harm performance on out-of-distribution (OOD) test cases. A typical example is a camels-vs-cows classification task (Beery et al., 2018; Shi et al., 2021), where camel pictures in training are almost always shown in a desert environment while cow pictures mostly have green grassland backgrounds. A typical machine learning model trained on this dataset will perform worse than random guessing on test pictures with cows in deserts or camels in pastures: the network has learned to use the background texture as a deciding factor when we want it to recognize animal shapes. In other words, the model overfits to spurious shortcuts that are highly predictive in some training domains but fail on OOD target domains. Recent domain generalization (DG) research addresses this challenge: how to learn a machine learning model that generalizes to an unseen test distribution when given multiple different but related training domains.^1

Recent literature covers a wide spectrum of DG methods, including invariant representation learning, meta-learning, data augmentation, ensemble learning, and gradient manipulation (more details in Section 2.4). Despite this large body of work, Gulrajani & Lopez-Paz (2021) showed that empirical risk minimization (ERM) provides a competitive baseline on many real-world DG benchmarks. ERM does not explicitly address distributional shifts during training; instead, it calculates the gradient from each training domain and updates the model with the average gradient. One caveat of ERM is that this average-gradient update preserves domain-specific noise during optimization. This observation motivates the core design of our method.
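To make the contrast concrete, the sketch below compares the ERM update, which averages per-domain gradients, with a schematic version of the principal-gradient idea described in the abstract: roll out one sequential update per training domain, stack the resulting parameter deltas into a trajectory matrix, and take its top right-singular vector as the update direction. This is a minimal illustration only; the names `rollout_lr`, the sign/length calibration heuristic, and the helper `_apply_flat_update` are our illustrative assumptions, and the actual PGrad algorithm (including its bijection-based refinement and exact calibrations) is specified later in the paper.

```python
import copy
import torch

def erm_step(model, domain_batches, loss_fn, lr=1e-3):
    """ERM baseline: average per-domain gradients, then take one SGD step."""
    grads = []
    for x, y in domain_batches:
        model.zero_grad()
        loss_fn(model(x), y).backward()
        grads.append(torch.cat([p.grad.flatten() for p in model.parameters()]))
    avg = torch.stack(grads).mean(dim=0)  # domain-specific noise survives averaging
    _apply_flat_update(model, -lr * avg)

def pgrad_step(model, domain_batches, loss_fn, rollout_lr=1e-3):
    """Schematic principal-gradient step: build a sequential rollout trajectory
    across training domains, then update along its top principal direction."""
    start = torch.cat([p.detach().flatten() for p in model.parameters()])
    probe = copy.deepcopy(model)  # rollout happens on a copy of the model
    deltas = []
    for x, y in domain_batches:  # one inner update per training domain
        probe.zero_grad()
        loss_fn(probe(x), y).backward()
        with torch.no_grad():
            for p in probe.parameters():
                p -= rollout_lr * p.grad
        deltas.append(torch.cat([p.detach().flatten()
                                 for p in probe.parameters()]) - start)
    T = torch.stack(deltas)  # (num_domains, num_params) trajectory matrix
    _, _, Vh = torch.linalg.svd(T, full_matrices=False)
    direction = Vh[0]  # top principal direction of the parameter trajectory
    # Directional calibration (illustrative): SVD fixes direction only up to
    # sign, so align it with the mean delta to ensure a descent-like step.
    mean_delta = T.mean(dim=0)
    if torch.dot(direction, mean_delta) < 0:
        direction = -direction
    # Length calibration (illustrative): match the mean delta's norm.
    _apply_flat_update(model, direction * mean_delta.norm())

def _apply_flat_update(model, flat_update):
    """Add a flattened update vector back onto the model parameters in place."""
    with torch.no_grad():
        offset = 0
        for p in model.parameters():
            n = p.numel()
            p += flat_update[offset:offset + n].view_as(p)
            offset += n
```

Compared with `erm_step`, the rollout in `pgrad_step` lets the dominant singular vector capture the parameter dynamics shared across domains, while directions driven by a single domain's noise are relegated to lower-order components and dropped.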



^1 In the rest of this paper, we use the terms "domain" and "distribution" interchangeably.

