TOPOLOGY-AWARE ROBUST OPTIMIZATION FOR OUT-OF-DISTRIBUTION GENERALIZATION

Abstract

Out-of-distribution (OOD) generalization is a challenging machine learning problem that is highly desirable in many high-stakes applications. Existing methods suffer from overly pessimistic modeling with low generalization confidence. Since generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial to developing strong OOD resilience. To this end, we propose topology-aware robust optimization (TRO), which seamlessly integrates distributional topology into a principled optimization framework. More specifically, TRO solves two optimization objectives: (1) Topology Learning, which explores the data manifold to uncover the distributional topology; and (2) Learning on Topology, which exploits the topology to constrain robust optimization for tightly-bounded generalization risks. We theoretically demonstrate the effectiveness of our approach and empirically show that it significantly outperforms the state of the art in a wide range of tasks, including classification, regression, and semantic segmentation. Moreover, we empirically find that the data-driven distributional topology is consistent with domain knowledge, enhancing the explainability of our approach.

1. INTRODUCTION

Recent years have witnessed a surge of applying machine learning (ML) to high-stakes and safety-critical applications. Such applications pose an unprecedented out-of-distribution (OOD) generalization challenge: ML models are constantly exposed to unseen distributions that lie outside their training space. Despite well-documented success at interpolation, modern ML models (e.g., deep neural networks) are notoriously weak at extrapolation; a model that is highly accurate on average can fail catastrophically when presented with rare or unseen distributions (Arjovsky et al., 2019). For example, a flood predictor trained with data from all 89 major flood events in the U.S. from 2000 to 2020 would still predict erroneously on the 2021 event "Hurricane Ida". Without addressing this challenge, it is unclear when and where a model can be applied and how much risk is associated with its use.

A promising solution for out-of-distribution generalization is distributionally robust optimization (DRO) (Namkoong & Duchi, 2016; Staib & Jegelka, 2019; Levy et al., 2020). DRO minimizes the worst-case expected risk over an uncertainty set of potential test distributions. The uncertainty set is typically formulated as a divergence ball surrounding the training distribution, endowed with a certain distance metric such as f-divergence (Namkoong & Duchi, 2016) or Wasserstein distance (Shafieezadeh Abadeh et al., 2018). Compared to empirical risk minimization (ERM) (Vapnik, 1998), which minimizes the average loss, DRO is more robust against distributional drifts arising from spurious correlations, adversarial attacks, subpopulations, or naturally occurring variation (Robey et al., 2021). However, it is non-trivial to build a realistic uncertainty set that truly approximates unseen distributions.
On the one hand, to confer robustness against extensive distributional drifts, the uncertainty set has to be sufficiently large, which increases the risk of admitting implausible distributions (e.g., outliers) and thus yields overly pessimistic models with low prediction confidence (Hu et al., 2018; Frogner et al., 2021). On the other hand, the worst-case distributions are not necessarily the influential ones that are truly connected to unseen distributions; optimizing over worst-case rather than influential distributions would yield compromised OOD resilience. As generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial for constructing a realistic uncertainty set. More specifically, we propose topology-aware robust optimization (TRO), which integrates two optimization objectives: (1) Topology learning: We model the data distributions as many discrete groups lying on a common low-dimensional manifold, where we can explore the distributional topology either by using physical priors or by measuring multiscale Earth Mover's Distance (EMD) among distributions. (2) Learning on topology: The acquired distributional topology is then exploited to construct a realistic uncertainty set, where robust optimization is constrained to bound the generalization risk within a topology graph rather than blindly generalizing to unseen distributions.

Our contributions include:
1. A new principled optimization method that seamlessly integrates topological information to develop strong OOD resilience.
2. Theoretical analysis proving that our method enjoys fast convergence for both convex and non-convex loss functions while the generalization risk is tightly bounded.
3. Empirical results in a wide range of tasks, including classification, regression, and semantic segmentation, that demonstrate the superior performance of our method over the state of the art.
4. Data-driven distributional topology that is consistent with domain knowledge and facilitates the explainability of our approach.
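To make the topology-learning idea concrete, the sketch below builds a distributional-topology graph over training groups by thresholding pairwise Earth Mover's Distances. This is a minimal one-dimensional illustration: the function name `topology_graph` and the fixed `threshold` are our own assumptions, and the paper's method uses a multiscale EMD on the data manifold rather than this single-scale 1-D version.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def topology_graph(groups, threshold=1.0):
    """Build a simple distributional-topology graph: nodes are training
    groups, and an edge connects two groups whose pairwise Earth Mover's
    Distance falls below `threshold`. (Sketch only; the paper uses a
    multiscale EMD rather than this single-scale 1-D variant.)"""
    n = len(groups)
    dist = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            d = wasserstein_distance(groups[i], groups[j])
            dist[i, j] = dist[j, i] = d
    adjacency = (dist < threshold) & ~np.eye(n, dtype=bool)
    return dist, adjacency

# Three 1-D groups: two nearby distributions and one far-off outlier.
rng = np.random.default_rng(0)
groups = [rng.normal(0.0, 1.0, 500),
          rng.normal(0.2, 1.0, 500),
          rng.normal(5.0, 1.0, 500)]
dist, adj = topology_graph(groups, threshold=1.0)
print(adj)  # groups 0 and 1 are connected; the outlier group 2 is isolated
```

The resulting graph is what "learning on topology" would then consume: rather than a divergence ball that also covers the outlier group, the uncertainty set is restricted to distributions reachable along graph edges.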

2. PROBLEM FORMULATION AND PRELIMINARY WORKS

The problem of out-of-distribution (OOD) generalization is defined by a pair of random variables (X, Y) over instances x ∈ X ⊆ R^d and corresponding labels y ∈ Y, following an unknown joint probability distribution P(X, Y). The objective is to learn a predictor f ∈ F such that f(x) → y for any (x, y) ∼ P(X, Y). Here F is a function class that is model-agnostic for a prediction task. However, unlike typical supervised learning, OOD generalization is complicated by the fact that one cannot sample directly from P(X, Y). Instead, it is assumed that we can only measure (X, Y) under different environmental conditions e, so that data is drawn from a set of groups E_all such that (x, y) ∼ P_e(X, Y). For example, in flood prediction, these environmental conditions denote the latent factors (e.g., stressors, precipitation, terrain, etc.) that underlie different flood events. Let E_train ⊊ E_all be a finite subset of training groups (distributions). Given the loss function ℓ, an OOD-resilient model f can be learned by solving a minimax optimization:

$$\min_{f \in \mathcal{F}} \Big\{ R(f) := \sup_{e \in \mathcal{E}_{\text{all}}} \mathbb{E}_{(x,y) \sim P_e(X,Y)} \big[\ell(f(x), y)\big] \Big\}. \quad (1)$$

Intuitively, Eq. 1 aims to learn a model that minimizes the worst-case risk over the entire family E_all. It is nontrivial since we do not have access to data from any unseen distributions E_test = E_all \ E_train.

Empirical Risk Minimization (ERM). Typically, classic supervised learning employs ERM (Vapnik, 1998) to find a model f that minimizes the average risk under the training distribution P_tr:

$$\min_{f \in \mathcal{F}} \Big\{ R(f) := \mathbb{E}_{(x,y) \sim P_{\text{tr}}} \big[\ell(f(x), y)\big] \Big\}.$$

Though proven effective in i.i.d. settings, models trained via ERM rely heavily on spurious correlations that do not always hold under distributional drifts (Arjovsky et al., 2019).

Distributionally Robust Optimization (DRO). To develop OOD resilience, DRO (Namkoong & Duchi, 2016) minimizes the worst-case risk over an uncertainty set Q by solving:

$$\min_{f \in \mathcal{F}} \Big\{ R(f) := \sup_{Q \in \mathcal{P}(P_{\text{tr}})} \mathbb{E}_{(x,y) \sim Q} \big[\ell(f(x), y)\big] \Big\}.$$
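The contrast between the ERM and worst-case objectives above can be made concrete with a toy sketch. When the supremum in Eq. 1 is restricted to the observed training groups, it becomes a max over per-group expected losses; the helper names below (`erm_risk`, `worst_group_risk`) are our own illustration, not the paper's implementation.

```python
import numpy as np

def erm_risk(losses_per_group):
    # ERM objective: average loss pooled over all training samples,
    # ignoring which group each sample came from.
    return np.concatenate(losses_per_group).mean()

def worst_group_risk(losses_per_group):
    # Minimax objective of Eq. 1 restricted to observed groups:
    # the supremum over e becomes a max over per-group expected losses.
    return max(l.mean() for l in losses_per_group)

# Per-sample losses for three training groups; group 1 is the hardest.
losses = [np.array([0.1, 0.2]),
          np.array([0.9, 1.1]),
          np.array([0.3, 0.3])]
print(erm_risk(losses))          # pooled average over all six samples
print(worst_group_risk(losses))  # expected loss of the hardest group
```

Note how the pooled average hides the hard group: a model can look good under ERM while the worst-group risk, which Eq. 1 targets, stays high.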
Here the uncertainty set Q approximates potential test distributions. It is usually formulated as a divergence ball with radius ρ surrounding the training distribution, P(P_tr) = {Q : D(Q, P_tr) ≤ ρ}, endowed with a certain distance metric D(·, ·) such as f-divergence (Namkoong & Duchi, 2016) or Wasserstein distance (Shafieezadeh Abadeh et al., 2018). To construct a realistic uncertainty set without being overly conservative, Group DRO further formulates the uncertainty set as a mixture of training groups (Hu et al., 2018; Sagawa et al., 2019).

Despite their well-documented success, existing DRO methods suffer from critical limitations. (1) To endow robustness against a wide range of potential test distributions, the radius of the divergence ball has to be sufficiently large, with a high risk of containing implausible distributions; optimizing



The source code and pre-trained models are available at: https://github.com/joffery/TRO.

