TOPOLOGY-AWARE ROBUST OPTIMIZATION FOR OUT-OF-DISTRIBUTION GENERALIZATION

Abstract

Out-of-distribution (OOD) generalization is a challenging machine learning problem that is highly desirable in many high-stakes applications. Existing methods suffer from overly pessimistic modeling with low generalization confidence. As generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial to developing strong OOD resilience. To this end, we propose topology-aware robust optimization (TRO), which seamlessly integrates distributional topology in a principled optimization framework. More specifically, TRO solves two optimization objectives: (1) Topology Learning, which explores the data manifold to uncover the distributional topology; and (2) Learning on Topology, which exploits the topology to constrain robust optimization for tightly-bounded generalization risks. We theoretically demonstrate the effectiveness of our approach, and empirically show that it significantly outperforms the state of the art in a wide range of tasks including classification, regression, and semantic segmentation. Moreover, we empirically find that the data-driven distributional topology is consistent with domain knowledge, enhancing the explainability of our approach.

Published as a conference paper at ICLR 2023.

1. INTRODUCTION

Recent years have witnessed a surge of machine learning (ML) applications in high-stakes and safety-critical domains. Such applications pose an unprecedented out-of-distribution (OOD) generalization challenge: ML models are constantly exposed to unseen distributions that lie outside their training space. Despite well-documented success at interpolation, modern ML models (e.g., deep neural networks) are notoriously weak at extrapolation; a model that is highly accurate on average can fail catastrophically when presented with rare or unseen distributions (Arjovsky et al., 2019). For example, a flood predictor trained on data from all 89 major flood events in the U.S. from 2000 to 2020 would make erroneous predictions for the 2021 event "Hurricane Ida". Without addressing this challenge, it is unclear when and where a model can be applied and how much risk is associated with its use.

A promising solution for out-of-distribution generalization is distributionally robust optimization (DRO) (Namkoong & Duchi, 2016; Staib & Jegelka, 2019; Levy et al., 2020). DRO minimizes the worst-case expected risk over an uncertainty set of potential test distributions. The uncertainty set is typically formulated as a divergence ball surrounding the training distribution, endowed with a certain distance metric such as f-divergence (Namkoong & Duchi, 2016) or Wasserstein distance (Shafieezadeh Abadeh et al., 2018). Compared to empirical risk minimization (ERM) (Vapnik, 1998), which minimizes the average loss, DRO is more robust against distributional drifts arising from spurious correlations, adversarial attacks, subpopulations, or naturally occurring variation (Robey et al., 2021). However, it is non-trivial to build a realistic uncertainty set that truly approximates unseen distributions.
On the one hand, to confer robustness against extensive distributional drifts, the uncertainty set has to be sufficiently large, which increases the risk of including implausible distributions (e.g., outliers) and thus yielding overly pessimistic models with low prediction confidence (Hu et al., 2018; Frogner et al., 2021). On the other hand, the worst-case distributions are not necessarily the influential ones that are truly connected to unseen distributions; optimizing over worst-case rather than influential distributions would yield compromised OOD resilience. As generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial to constructing a realistic uncertainty set. More specifically, we propose topology-aware robust optimization (TRO) by integrating two optimization objectives: (1) Topology learning: we model the data distributions as many discrete groups lying on a common low-dimensional manifold, where we can explore the distributional topology either by using physical priors or by measuring the multiscale Earth Mover's Distance (EMD) among distributions. (2) Learning on topology: the acquired distributional topology is then exploited to construct a realistic uncertainty set, where robust optimization is constrained to bound the generalization risk within a topology graph, rather than blindly generalizing to unseen distributions.

Our contributions include:
1. A new principled optimization method that seamlessly integrates topological information to develop strong OOD resilience.
2. Theoretical analysis proving that our method enjoys fast convergence for both convex and non-convex loss functions while the generalization risk is tightly bounded.
3. Empirical results in a wide range of tasks, including classification, regression, and semantic segmentation, that demonstrate the superior performance of our method over the state of the art.
4. Data-driven distributional topology that is consistent with domain knowledge and facilitates the explainability of our approach.

2. PROBLEM FORMULATION AND PRELIMINARY WORKS

The problem of out-of-distribution (OOD) generalization is defined by a pair of random variables (X, Y) over instances x ∈ X ⊆ R^d and corresponding labels y ∈ Y, following an unknown joint probability distribution P(X, Y). The objective is to learn a predictor f ∈ F such that f(x) → y for any (x, y) ∼ P(X, Y). Here F is a function class that is model-agnostic for a prediction task. However, unlike typical supervised learning, OOD generalization is complicated by the fact that one cannot sample directly from P(X, Y). Instead, it is assumed that we can only measure (X, Y) under different environmental conditions e, so that data is drawn from a set of groups E_all such that (x, y) ∼ P_e(X, Y). For example, in flood prediction, these environmental conditions denote the latent factors (e.g., stressors, precipitation, terrain) that underlie different flood events. Let E_train ⊊ E_all be a finite subset of training groups (distributions). Given the loss function ℓ, an OOD-resilient model f can be learned by solving a minimax optimization:

$$\min_{f \in \mathcal{F}} \Big\{ R(f) := \sup_{e \in \mathcal{E}_{\mathrm{all}}} \mathbb{E}_{(x,y)\sim P_e(X,Y)}\,[\ell(f(x), y)] \Big\}. \quad (1)$$

Intuitively, Eq. 1 aims to learn a model that minimizes the worst-case risk over the entire family E_all. This is nontrivial since we do not have access to data from any unseen distribution in E_test = E_all \ E_train.

Empirical Risk Minimization (ERM). Classic supervised learning typically employs ERM (Vapnik, 1998) to find a model f that minimizes the average risk under the training distribution P_tr:

$$\min_{f \in \mathcal{F}} \big\{ R(f) := \mathbb{E}_{(x,y)\sim P_{\mathrm{tr}}}[\ell(f(x), y)] \big\}.$$

Though proven effective in i.i.d. settings, models trained via ERM rely heavily on spurious correlations that do not always hold under distributional drifts (Arjovsky et al., 2019).

Distributionally Robust Optimization (DRO). To develop OOD resilience, DRO (Namkoong & Duchi, 2016) minimizes the worst-case risk over an uncertainty set Q by solving:

$$\min_{f \in \mathcal{F}} \big\{ R(f) := \sup_{Q \in \mathcal{P}(P_{\mathrm{tr}})} \mathbb{E}_{(x,y)\sim Q}[\ell(f(x), y)] \big\}.$$
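To make the contrast concrete, here is a minimal numerical sketch of the two risks on fixed per-sample losses (the losses and group sizes are hypothetical, not from the paper's experiments):

```python
import numpy as np

def erm_risk(losses_per_group):
    """Average risk over the pooled training samples (the ERM objective)."""
    return np.concatenate(losses_per_group).mean()

def worst_group_risk(losses_per_group):
    """Worst-case expected risk over the observed training groups
    (the group-level surrogate of the sup in Eq. 1)."""
    return max(group.mean() for group in losses_per_group)

# Two hypothetical groups: an "easy" majority and a "hard" minority.
groups = [np.array([0.1, 0.2, 0.3]), np.array([0.9, 1.1])]
print(erm_risk(groups))          # pooled mean, dominated by the easy group
print(worst_group_risk(groups))  # tracks the hard group
```

The pooled average (0.52) hides the hard group's risk (1.0), which is precisely the failure mode DRO targets.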
Here the uncertainty set Q approximates potential test distributions. It is usually formulated as a divergence ball with radius ρ surrounding the training distribution, P(P_tr) = {Q : D(Q, P_tr) ≤ ρ}, endowed with a certain distance metric D(·, ·) such as f-divergence (Namkoong & Duchi, 2016) or Wasserstein distance (Shafieezadeh Abadeh et al., 2018). To construct a realistic uncertainty set without being overly conservative, Group DRO further formulates the uncertainty set as a mixture of training groups (Hu et al., 2018; Sagawa et al., 2019). Despite the well-documented success, existing DRO methods suffer from critical limitations. (1) To endow robustness against a wide range of potential test distributions, the radius of the divergence ball has to be sufficiently large, with high risk of containing implausible distributions; optimizing for implausible distributions would fundamentally damage OOD resilience by yielding overly pessimistic models with low prediction confidence. (2) The worst-case groups are not necessarily the influential ones that are truly connected to unseen distributions; optimizing over worst-case rather than influential groups would yield compromised OOD resilience. We propose a new principled optimization method (TRO) to develop OOD resilience, which integrates topology and optimization via a two-phase scheme: Topology Learning and Learning on Topology.

[Figure 1: Overview of TRO. Training groups {···} ∼ P_1(X, Y), P_2(X, Y), ..., P_m(X, Y) (e.g., flood events EVENT 1 through EVENT N) are either connected through a predefined physical graph, yielding the physical-based topology, or embedded on a shared data graph, where multiscale diffusion density estimates at increasing diffusion times t and Diffusion EMD uncover the data-driven topology; either topology serves as a prior for TRO.]

3. TOPOLOGY-AWARE ROBUST OPTIMIZATION

3.1. TOPOLOGY LEARNING: EXPLORE DATA MANIFOLD TO UNCOVER TOPOLOGY

We model the data groups E_all as many discrete distributions lying on a common low-dimensional manifold in a high-dimensional data measurement space. In this case their structure, i.e., the distributional topology, can be naturally captured by a graph G = (V, E), where the entities V = ∪_{e∈E_all} X_e symbolize the groups and the edges E represent interactions among groups. The topology graph is constructed by: (1) identifying entities: we assume the entities are defined by the given group identities; and (2) uncovering interactions: we consider two scenarios to measure the connectivity between discrete distributions, as illustrated in Fig. 1.

Physical-based distributional topology.
In the scenario where distributional adjacency information is available, we can instantly acquire the topology G_physic by simply imposing the predefined neighborhood information. For example, to capture the similarity of weather events in the U.S., one can construct a graph where each state realizes an entity and the physical adjacency between two states results in an edge (see Fig. 1). In this case, G_physic functions as a physical prior to constrain the robust optimization introduced in Sec. 3.2. We empirically find that G_physic yields an improvement of 9.56% over the state of the art regarding OOD generalization, as reported in Sec. 5.1.

Data-driven distributional topology. In the absence of G_physic, we propose a data-driven approach to learn the topology G_data from training data. Specifically, we embed the individual groups onto a shared data graph based on an affinity matrix of the combined data. Inspired by Leeb & Coifman (2016), such a data graph can be viewed as a discretization of an underlying closed Riemannian manifold. By simulating a time-dependent diffusion process over the graph, we obtain density estimates at multiple scales for each group, which are then used to calculate ℓ1 distances between groups. Such a multiscale ℓ1 distance has been proven topologically equivalent to the Earth Mover's Distance (EMD) along the manifold geodesic, while cutting the computational complexity between m distributions over n data points from O(m²n³) down to Õ(mn) (Tong et al., 2021). We obtain the data-driven topology through three steps. (1) Data graph construction: we construct a data graph through an affinity matrix K of the combined data. K can be implemented through kernel functions (e.g., the RBF kernel) which capture the similarity of data.
Instead of calculating the similarity between raw data, we calculate the similarity between features extracted from an ERM-trained model, as such features capture spurious correlations that preserve group identity (Creager et al., 2021). Specifically, we define the affinity matrix as K_{i,j} = exp(−∥f(x_i) − f(x_j)∥²/σ²), where σ² is the kernel scale. (2) Multiscale diffusion density estimation: to simulate the diffusion process over the graph, we obtain a Markov diffusion operator P from K. Following Coifman & Lafon (2006), we normalize the affinity matrix as M = Q⁻¹KQ⁻¹, where Q is a diagonal matrix with Q_{i,i} = Σ_j K_{i,j}. The diffusion operator is defined as P = D⁻¹M, where D is a diagonal matrix with D_{i,i} = Σ_j M_{i,j}. The operator P is used to compute the multiscale density estimates µ_e for each data group X_e:

$$\mu_e^t = \frac{1}{n_e} \mathbf{P}^t \mathbb{1}_{X_e},$$

where t is the diffusion time, P^t denotes the t-th power of P, and 1_{X_e} is the indicator function for group e. Intuitively, P^t_{i,j} sums the probabilities of all possible paths of length t between x_i and x_j. By taking multiple powers of P, µ_e reveals the topological structure of X_e at multiple scales. (3) Diffusion EMD measurement: we follow Tong et al. (2021) to measure the geodesic distance W_{α,K}(X_e, X_{e'}) between X_e and X_{e'} by aggregating the ℓ1 distances between the multiscale density estimates:

$$W_{\alpha,K}(X_e, X_{e'}) = \sum_{k=0}^{K} \big\|T_{\alpha,k}(X_e) - T_{\alpha,k}(X_{e'})\big\|_1, \qquad
T_{\alpha,k}(X_e) = \begin{cases} 2^{-(K-k-1)\alpha}\big(\mu_e^{(2^{k+1})} - \mu_e^{(2^k)}\big), & k < K \\ \mu_e^{(2^K)}, & k = K \end{cases} \quad (3)$$

where α balances long- and short-range distances and K is the maximum scale. Although G_data is computationally more expensive than G_physic, our experimental results in Sec. 5.2 indicate that optimizing with G_data yields improved OOD resilience. Moreover, the ablation study in Sec. 5.3 indicates that G_data is consistent with domain knowledge and enhances the explainability of TRO.
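The three steps above can be sketched end-to-end. This is a naive dense implementation for illustration only; the actual method uses the efficient Õ(mn) scheme of Tong et al. (2021), and the function name and default parameters here are our own assumptions:

```python
import numpy as np
from itertools import combinations

def diffusion_emd(X, group_ids, sigma=1.0, K=3, alpha=0.5):
    """Pairwise multiscale Diffusion EMD between data groups (dense sketch).

    X: (n, d) feature matrix (e.g., from an ERM-trained model);
    group_ids: (n,) group label per sample. Returns an (m, m) distance matrix.
    """
    # (1) Data graph: RBF affinity K_{ij} = exp(-||x_i - x_j||^2 / sigma^2).
    sq = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    A = np.exp(-sq / sigma**2)
    # (2) Normalize M = Q^{-1} A Q^{-1}, then row-normalize to the Markov
    #     diffusion operator P = D^{-1} M (Coifman & Lafon, 2006).
    q = A.sum(1)
    M = A / np.outer(q, q)
    P = M / M.sum(1, keepdims=True)
    group_ids = np.asarray(group_ids)
    groups = sorted(set(group_ids.tolist()))
    # Multiscale densities mu_e^{(2^k)} = (1/n_e) P^{2^k} 1_{X_e}, k = 0..K.
    mus = {}
    for e in groups:
        ind = (group_ids == e).astype(float)
        mus[e] = [np.linalg.matrix_power(P, 2 ** k) @ (ind / ind.sum())
                  for k in range(K + 1)]
    def T(e, k):
        # T_{alpha,k}: scaled density differences, plus the coarsest density.
        if k < K:
            return 2.0 ** (-(K - k - 1) * alpha) * (mus[e][k + 1] - mus[e][k])
        return mus[e][K]
    # (3) Aggregate l1 distances across scales into W_{alpha,K}.
    W = np.zeros((len(groups), len(groups)))
    for (i, e), (j, f) in combinations(list(enumerate(groups)), 2):
        d = sum(np.abs(T(e, k) - T(f, k)).sum() for k in range(K + 1))
        W[i, j] = W[j, i] = d
    return W
```

The resulting W can then be thresholded or k-NN-sparsified into the edge set of G_data.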
Last but not least, the data-driven method is fully differentiable, making it amenable to jointly conducting topology learning and learning on topology in an end-to-end manner. We leave this as future work.
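By contrast, the physical prior described earlier in this subsection requires no learning at all: G_physic is obtained by imposing a predefined neighbor list. A minimal sketch with an illustrative (not the paper's) state-adjacency dictionary:

```python
# Hypothetical adjacency of a few U.S. states; each state is a graph entity
# and physical adjacency between two states yields an edge.
state_adjacency = {
    "LA": ["TX", "MS"],
    "TX": ["LA"],
    "MS": ["LA", "AL"],
    "AL": ["MS"],
}

def build_physical_topology(adjacency):
    """Symmetric, de-duplicated edge set for G_physic from a neighbor dict."""
    edges = set()
    for u, nbrs in adjacency.items():
        for v in nbrs:
            edges.add(tuple(sorted((u, v))))
    return sorted(edges)

print(build_physical_topology(state_adjacency))
# [('AL', 'MS'), ('LA', 'MS'), ('LA', 'TX')]
```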

3.2. LEARNING ON TOPOLOGY: EXPLOIT TOPOLOGY FOR ROBUST OPTIMIZATION

Algorithm 1: TRO Algorithm
Input: Data of E_train, step sizes η_θ and η_q
Output: Learned model f
Topology Learning:
  if G_physic exists then
    G ← G_physic
  else
    Obtain the affinity matrix K from data
    Q ← Diag(Σ_j K_{i,j});  M ← Q⁻¹KQ⁻¹
    D ← Diag(Σ_j M_{i,j});  P ← D⁻¹M
    Obtain G_data via Eq. 3
    G ← G_data
  end
Learning on Topology:
  Calculate topological prior p from G
  while not converged do
    Sample (x, y) ∼ P_e(X, Y) ∀e ∈ E_train
    Calculate R(f, q) via Eq. 5
    Update θ and q via Eq. 6
  end

Next, we propose a principled method that integrates distributional topology to develop TRO. The key challenge is how to leverage G to construct an uncertainty set that approximates unseen distributions with bounded generalization risk. Our main idea is to assess the group centrality of training distributions. Graph centrality is widely used in social network analysis (Newman, 2005) to measure how much information is propagated through each entity. Here we introduce group centrality to identify influential groups that are truly connected to unseen distributions; it can be calculated using graph measurements (Tian et al., 2019) such as degree, betweenness, and closeness. More specifically, we first calculate the centrality of each entity in G as a topological prior p to identify influential groups. Then, we construct the uncertainty set as an arbitrary mixture of training groups, Q := {Σ_{e∈E_train} q_e P_e | q ∈ Δ_m}, where q_e denotes the weight of group e, P_e is the distribution of group e, and Δ_m is the (m−1)-dimensional probability simplex. Finally, we use the prior p to constrain the uncertainty set Q by solving the minimax optimization problem:

$$\min_{f \in \mathcal{F}} \Big\{ R(f, q) := \max_{q \in \Delta_m} \sum_{e \in \mathcal{E}_{\mathrm{train}}} q_e\, \mathbb{E}_{(x,y)\sim P_e(X,Y)}[\ell(f(x), y)] \Big\}, \quad \text{s.t. } D(q \,\|\, p) \le \tau. \quad (4)$$

Intuitively, groups with high training loss and high centrality are assigned large weights; this tightly bounds the OOD generalization risk within the topological graph. D is an arbitrary distributional distance metric.
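The inner maximization operates on group weights q constrained to the simplex Δ_m. A sketch of the weighted group risk and the projection P_Δm used in Alg. 1 — we assume the standard sorting-based Euclidean projection of Duchi et al. (2008), which the paper does not specify:

```python
import numpy as np

def weighted_group_risk(group_risks, q):
    """R(f, q) = sum_e q_e * (expected loss on group e), for a fixed model f."""
    return float(np.dot(q, group_risks))

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex
    (sorting-based routine; Duchi et al., 2008)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u * np.arange(1, len(v) + 1) > css - 1.0)[0][-1]
    tau = (css[rho] - 1.0) / (rho + 1)
    return np.maximum(v - tau, 0.0)

# One projected-ascent step on q: up-weight high-risk groups, stay on simplex.
risks = np.array([1.0, 0.2, 0.2])  # hypothetical per-group risks
q = np.full(3, 1.0 / 3.0)
q_new = project_simplex(q + 0.5 * risks)
assert abs(q_new.sum() - 1.0) < 1e-12 and q_new[0] > q[0]
```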
We use the ℓ2 distance to implement D due to its strong convexity and simplicity. However, solving Eq. 4 often leads to a non-convex problem, wherein methods such as stochastic gradient descent (SGD) cannot guarantee constraint satisfaction (Robey et al., 2021). To address this issue, we leverage the Karush-Kuhn-Tucker conditions (Boyd et al., 2004) and introduce a Lagrange multiplier to convert the constrained problem into its unconstrained counterpart:

$$\min_{f \in \mathcal{F}} \Big\{ R(f, q) := \max_{q \in \Delta_m} \sum_{e \in \mathcal{E}_{\mathrm{train}}} q_e\, \mathbb{E}_{(x,y)\sim P_e(X,Y)}[\ell(f(x), y)] - \lambda D(q \,\|\, p) \Big\}, \quad (5)$$

where λ is the dual variable. Let θ ∈ Θ be the model parameters of f. We can solve the primal-dual problem effectively by alternating the updates:

$$\theta^{t+1} = \theta^t - \eta_\theta^t \nabla_\theta R(f, q), \qquad q^{t+1} = \mathcal{P}_{\Delta_m}\big(q^t + \eta_q^t \nabla_q R(f, q)\big), \quad (6)$$

where η_θ^t (η_q^t) is the gradient descent (ascent) step size and P_Δm(q) projects q onto the simplex Δ_m for regularization. The overall algorithm of TRO is shown in Alg. 1. In Sec. 4, we show that TRO enjoys fast convergence for both convex and non-convex loss functions, while the generalization risk is tightly bounded by the topological constraint. We empirically demonstrate that TRO achieves strong OOD resilience by striking a good balance between the worst-case and influential groups (see Sec. 5.2).

Calculation of group centrality. We use betweenness centrality to measure the centrality of groups. Betweenness centrality measures how often an entity lies on the shortest path between two other entities in the topology. Freeman (1977) reveals that entities with higher betweenness centrality have more control over the topology, as more information passes through them. For the physical-based topology G_physic, we define the centrality of group e by computing the fraction of shortest paths that pass through it:

$$c_e^{\mathrm{physic}} = \sum_{s \in \mathcal{E}_{\mathrm{train}},\, t \in \mathcal{E}_{\mathrm{test}}} \frac{\sigma(s, t \mid e)}{\sigma(s, t)},$$

where σ(s, t) is the number of shortest paths between groups s and t in the graph ((s, t)-paths), and σ(s, t | e) is the number of (s, t)-paths that go through group e.
Intuitively, c_e^physic measures how much information is propagated through e from the start (training) to the end (test) groups. For the data-driven topology G_data, the underlying assumption is that training groups with high centrality also exert strong influence on unseen groups. Instead of sampling group pairs from two separate sets, we sample (s, t) from E_train, and the centrality is modified as:

$$c_e^{\mathrm{data}} = \sum_{s,\, t \in \mathcal{E}_{\mathrm{train}}} \frac{\sigma(s, t \mid e)}{\sigma(s, t)}.$$

We use the softmax function to normalize c_e, so the prior probability for group e ∈ E_train is defined as p_e = exp(c_e) / Σ_{e'∈E_train} exp(c_{e'}).
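The centrality prior can be sketched directly from these definitions. The brute-force shortest-path enumeration below is for illustration only (a real implementation would use Brandes' algorithm); the adjacency-list input format is our assumption:

```python
import numpy as np
from collections import deque

def shortest_paths(adj, s, t):
    """Enumerate all shortest s-t paths in an unweighted graph (brute force)."""
    best, out, queue = None, [], deque([[s]])
    while queue:
        path = queue.popleft()
        if best is not None and len(path) > best:
            break  # BFS order: all remaining paths are longer than the best
        node = path[-1]
        if node == t:
            best = len(path)
            out.append(path)
            continue
        for nb in adj[node]:
            if nb not in path:
                queue.append(path + [nb])
    return out

def topological_prior(adj, train_groups):
    """Softmax-normalized centrality c_e = sum_{s,t} sigma(s,t|e)/sigma(s,t),
    with (s, t) sampled from the training groups (data-driven variant)."""
    c = {e: 0.0 for e in train_groups}
    for s in train_groups:
        for t in train_groups:
            if s == t:
                continue
            paths = shortest_paths(adj, s, t)
            if not paths:
                continue
            for e in train_groups:
                if e not in (s, t):
                    c[e] += sum(e in p for p in paths) / len(paths)
    vals = np.array([c[e] for e in train_groups])
    ex = np.exp(vals - vals.max())  # numerically stable softmax
    return dict(zip(train_groups, ex / ex.sum()))
```

On a path graph 0-1-2, the middle group 1 lies on every 0-2 shortest path and therefore receives the largest prior weight.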

4. THEORETICAL ANALYSIS

4.1. CONVERGENCE ANALYSIS

In this section, we show that by choosing appropriate step sizes η_θ^t and η_q^t, TRO yields fast convergence rates for both convex and non-convex loss functions. We first state the assumptions behind the theorems; we then give the convergence rate for convex loss functions in Theorem 1 and for non-convex loss functions in Theorem 2.

Definition 1. (Lipschitz continuity) A mapping f : X → R^m is G-Lipschitz continuous if for any x, y ∈ X, ∥f(x) − f(y)∥ ≤ G∥x − y∥.

Definition 2. (Smoothness) A function f : X → R is L-smooth if it is differentiable on X and the gradient ∇f is L-Lipschitz continuous, i.e., ∥∇f(x) − ∇f(y)∥ ≤ L∥x − y∥ for all x, y ∈ X.

Assumption 1. We make the following assumption throughout the paper: given θ, the loss function ℓ(f_θ(x), y) is G-Lipschitz continuous and L-smooth with respect to x.

Convex Loss. The expected number of stochastic gradient computations is used to estimate the convergence rate. To reach a duality gap of ϵ (Nemirovski et al., 2009), the optimal rate of convergence for solving stochastic min-max problems is O(1/ϵ²) if the problem is convex-concave. The duality gap of a pair (θ̄, q̄) is defined as max_{q∈Δm} R(θ̄, q) − min_{θ∈Θ} R(θ, q̄). In the case of strong duality, (θ̄, q̄) is optimal iff the duality gap is zero. We show that TRO achieves the optimal O(1/ϵ²) rate.

Theorem 1. Consider the dual problem in Eq. 5 when the loss function is convex and Assumption 1 holds. Let Θ be bounded by R_Θ, E[∥∇_θ R(θ, q)∥₂²] ≤ Ĝ_θ², and E[∥∇_q R(θ, q)∥₂²] ≤ Ĝ_q². With step sizes η_θ = 2R_Θ/(Ĝ_θ√T) and η_q = 2/(Ĝ_q√T), the output of TRO satisfies:

$$\mathbb{E}\Big[\max_{q \in \Delta_m} R(\bar{\theta}_T, q) - \min_{\theta \in \Theta} R(\theta, \bar{q}_T)\Big] \le \frac{3R_\Theta \hat{G}_\theta + 3\hat{G}_q}{\sqrt{T}}.$$

Theorem 1 shows that our method requires T = O(1/ϵ²) iterations to reach a duality gap within ϵ. To derive the convergence rate for non-convex functions, we define ϵ-stationary points as follows. Definition 3.
(ϵ-stationary point) For a differentiable function f : X → R, a point x ∈ X is said to be first-order ϵ-stationary if ∥∇f (x)∥ ≤ ϵ. Nonconvex Loss. The loss function ℓ(f θ (x), y) can be nonconvex and as a result, R(θ, q) is nonconvex in θ. Following Collins et al. (2020) , we define ( θ, q) is an (ϵ, δ)-stationary point of R if: ∇ θ R( θ, q) 2 ≤ ϵ and R( θ, q) ≥ max q∈∆m R( θ, q)δ, where ϵ, δ > 0. Theorem 2. If Assumption 1 holds and the loss function is bounded by B and is M-smooth, the output of Alg. 1 satisfies: E ∥∇ θ R (θ T , q T )∥ 2 2 ≤ R θ 0 , q 0 + B T η θ + 2η q √ nB Ĝq η θ + η θ M Ĝ2 θ 2 , E [R (θ T , q T )] ≥ max q∈∆m {E [R (θ T , q)]} - 1 η q T - η q Ĝ2 q 2 . ( ) Theorem 2 shows that our method converges in expectation to an (ϵ, δ)-stationary point of R in O(1/ϵ 4 ) stochastic gradient evaluations.
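The theorems above analyze gradient descent on $\theta$ coupled with ascent on the simplex weights $q$. As an illustration only (a toy convex-concave instance with exponentiated-gradient ascent on $q$, not the paper's full Alg. 1; all constants below are ours), the averaged iterate of such a descent-ascent loop recovers the minimax solution:

```python
import math

def gda(losses, grads, m, T=4000, eta_theta=0.05, eta_q=0.05):
    """Sketch of descent-ascent for min_theta max_{q in simplex}
    R(theta, q) = sum_i q_i * l_i(theta): gradient descent on theta,
    exponentiated-gradient (mirror) ascent on q. Returns the averaged
    theta iterate, which is the quantity the convex-concave rate bounds."""
    theta = 0.5                      # arbitrary initialization
    q = [1.0 / m] * m                # uniform mixture weights
    theta_sum = 0.0
    for _ in range(T):
        l = losses(theta)
        g = grads(theta)
        theta -= eta_theta * sum(qi * gi for qi, gi in zip(q, g))
        w = [qi * math.exp(eta_q * li) for qi, li in zip(q, l)]
        z = sum(w)
        q = [wi / z for wi in w]     # renormalize: stays on the simplex
        theta_sum += theta
    return theta_sum / T

# Toy convex groups l_i(theta) = (theta - c_i)^2 with centers -1 and +1;
# the minimax solution is theta* = 0, where both group losses equal 1.
centers = [-1.0, 1.0]
theta_bar = gda(lambda t: [(t - c) ** 2 for c in centers],
                lambda t: [2.0 * (t - c) for c in centers], m=2)
```

The individual iterates oscillate (the weight $q$ chases whichever group currently has the larger loss), which is why the guarantees are stated for the averaged iterate.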

4.2. GENERALIZATION BOUNDS

In this section, we provide learning guarantees for TRO. Compared to DRO, TRO achieves a lower upper bound on the generalization risk over unseen distributions thanks to the topological constraint. Let $H$ denote the family of losses associated with a hypothesis set $F$: $H = \{(x, y) \mapsto \ell(f(x), y) : f \in F\}$, and let the distribution of some test group be $P$. The learning guarantee for $P$ is given in Theorem 3.

Theorem 3. Assume the loss function $\ell$ is bounded by $B$. For any $\epsilon \ge 0$ and $\delta > 0$, with probability at least $1 - \delta$, the following inequality holds for all $f \in F$:
$$R_P(f) \le R_{\hat P_\Lambda}(f, q) + 2 R_n(H, \Lambda) + B\, D(P \,\|\, P_\Lambda) + B \sqrt{\frac{1}{2m} \log \frac{|\Lambda|}{\delta}}.$$
The upper bound on the generalization risk on $P$ is mainly determined by its distance to $P_\Lambda$: $D(P \| P_\Lambda)$. With the topological prior, the risk on $P$ can be tightly bounded by minimizing $D(P \| P_\Lambda)$, as long as $P$ falls into the convex hull of the training groups. We empirically verify the effectiveness of the topological prior in minimizing the generalization risk over unseen distributions (see Sec. 5).

Results. The results on DG-15 and DG-60 are summarized in Tab. 1. On both datasets, our method yields the highest accuracy. For DG-15, we show the detailed results of all groups in Fig. 8. We visualize the decision boundaries of DG-15 and DG-60 in Appendix 7.3.
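To make the role of the divergence term concrete, here is a toy numeric check (all distributions below are hypothetical) that the $D(P \| P_\Lambda)$ term in the bound can be driven to zero by reweighting the training mixture whenever $P$ lies in the convex hull of the training groups:

```python
import math

def kl(p, q):
    """KL divergence D(p || q) between two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Two hypothetical training-group distributions over 3 outcomes.
P1 = [0.7, 0.2, 0.1]
P2 = [0.1, 0.3, 0.6]
# An unseen test distribution inside their convex hull.
P = [0.4 * a + 0.6 * b for a, b in zip(P1, P2)]

# Sweep mixture weights q = (w, 1 - w); the divergence term B * D(P || P_q)
# vanishes at the right mixture weight.
best_w, best_d = min(
    ((w / 100, kl(P, [w / 100 * a + (1 - w / 100) * b for a, b in zip(P1, P2)]))
     for w in range(1, 100)),
    key=lambda t: t[1],
)
```

When $P$ falls outside the convex hull, the minimum divergence is strictly positive, and the bound degrades accordingly.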

5. EXPERIMENTS

Ablation study. TRO significantly improves the generalization performance by discovering influential groups. To investigate why TRO outperforms DRO, we show the group weights q of DRO and TRO on DG-15 in Fig. 3. DRO assigns the highest weight to group "1", which is the furthest group from the test groups. Instead, TRO prioritizes the influential groups "2", "5", and "6", which are truly connected to the test ones, yielding improved performance on unseen distributions.

5.2. REGRESSION

Datasets. TPT-48 (Vose et al., 2014) contains the monthly average temperatures of the 48 contiguous states in the US from 2008 to 2019. We focus on the regression task of predicting the next 6 months' temperature from the first 6 months'. We consider two generalization tasks: (1) E(24) → W(24): we use the 24 eastern states as training groups and the 24 western states as test groups; (2) N(24) → S(24): we use the 24 northern states as training groups and the 24 southern states as test groups. Test groups one hop away from the closest training group are defined as Hop-1 test groups, those two hops away as Hop-2 test groups, and the remaining groups as Hop-3 test groups. The visualization of N(24) → S(24) on TPT-48 is shown in Fig. 4 (left).

Results. We show the results on TPT-48 in Tab. 2. TRO yields the lowest average MSE on both tasks. We also report the average MSE on Hop-1, Hop-2, and Hop-3 test groups for both tasks. Although REx yields the lowest error on Hop-1 and Hop-2 groups in N(24) → S(24), it yields the highest prediction error on Hop-3 groups, indicating that REx may yield compromised performance under large distributional drifts. TRO yields the best performance on Hop-3 groups, indicating its strong generalization capability under large distributional drifts.

Ablation study. (1) Data-driven topology yields better performance than physical-based topology. We show the group centrality of both the physical and the data topology for the task North → South in Fig. 4. "PA" is identified by TRO as the influential group in the physical-based topology; "NY", "PA", and "MA" are identified as the influential groups in the data-driven topology. The results show that the influential groups in the data topology are more effective in minimizing the generalization error. (2) The strong OOD resilience of TRO comes from the synergy of the worst-case and influential groups. We investigate which components contribute to the superior performance of TRO.
We build a simple baseline based on ERM: we directly use the group importance acquired from the topology to weight the training groups, and the weights are fixed during training. We name this baseline importance-weighted ERM (IW-ERM). We show the results of "N(24)→S(24)" on TPT-48 in Tab. 3. The results of IW-ERM are inferior to ERM and DRO, possibly because IW-ERM merely considers the influential groups. We further show the group importance of DRO and TRO in Fig. 5. TRO significantly reduces the generalization risks by prioritizing not only the worst-case groups but also the influential ones.
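The hop-based grouping used above (Hop-1/2/3 test groups on TPT-48) can be computed with a multi-source breadth-first search over the state adjacency graph; the tiny chain graph below is a hypothetical stand-in for the real adjacency data.

```python
from collections import deque

def hop_buckets(adj, train, max_hop=3):
    """Multi-source BFS: bucket test nodes by hop distance to the closest
    training node; nodes at >= max_hop hops land in the last bucket."""
    dist = {s: 0 for s in train}
    frontier = deque(train)
    while frontier:
        u = frontier.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                frontier.append(v)
    buckets = {h: [] for h in range(1, max_hop + 1)}
    for v, d in dist.items():
        if d > 0:  # skip training nodes themselves
            buckets[min(d, max_hop)].append(v)
    return buckets

# Hypothetical mini-graph: a chain A-B-C-D-E with training group {A}.
adj = {"A": ["B"], "B": ["A", "C"], "C": ["B", "D"],
       "D": ["C", "E"], "E": ["D"]}
buckets = hop_buckets(adj, ["A"])
```

With the real 48-state adjacency list, the same function reproduces the Hop-1/Hop-2/Hop-3 partition used in the evaluation.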

5.3. SEMANTIC SEGMENTATION

Datasets. Sen1Floods11 (Bonafilia et al., 2020) is a public dataset for flood mapping at the global scale. The dataset provides global coverage of 4,831 chips of 512 × 512 10 m satellite images across 11 distinct flood events, covering 120,406 km². Each image is associated with a pixel-wise label. Locations of the 11 flood events are shown in Fig. 6 (left). Flood events vary in boundary conditions, terrain, and other latent factors, posing significant OOD challenges to existing models in terms of reliability and explainability. Following Bonafilia et al. (2020), event "BOL" is held out for testing, and data of the other events are split into training and validation sets with a random 80-20 split.

       ERM   IRM   REx   SD    DRO   TRO (data)
Val   .489  .387  .484  .449  .480  .485
Test  .430  .338  .357  .400  .433  .450

Table 4: Segmentation results (IoU) on Sen1Floods11. TRO yields better performance than the other baselines on unseen flood events.

Results. We show the results on Sen1Floods11 in Tab. 4. ERM achieves the highest IoU on the validation set, while TRO achieves the highest IoU on the test set. The results show that TRO yields better performance than the other baselines on unseen flood events.

Ablation study. (1) The data-driven distributional topology is consistent with domain knowledge. We visualize the distributional topology as well as the group centrality in Fig. 6 (right). The learned distributional topology is consistent with domain knowledge, enhancing the explainability of TRO. (2) Ablation on λ. We report IoU under different λ on Sen1Floods11 in Fig. 7. IoU remains stable for a wide range of λ, and λ = 0.01 yields the best performance.

6. CONCLUSION

In this paper, we proposed a new principled optimization method that seamlessly integrates topological information to develop strong OOD resilience. Empirical results on a wide range of tasks, including classification, regression, and semantic segmentation, demonstrate the superior performance of our method over the SOTA. Moreover, the data-driven distributional topology is consistent with domain knowledge and facilitates the explainability of our approach.

Distributionally Robust Optimization. Shalev-Shwartz & Wexler (2016) argued that minimizing the maximal loss over a set of possible distributions can provide better generalization performance than minimizing the average loss. The robustness guarantee of DRO heavily relies on the quality of the uncertainty set, which is typically constructed by moment constraints (Delage & Ye, 2010), f-divergence (Namkoong & Duchi, 2016), or Wasserstein distance (Shafieezadeh Abadeh et al., 2018). To avoid yielding overly pessimistic models, group DRO (Hu et al., 2018; Sagawa et al., 2019) leverages pre-defined data groups to formulate the uncertainty set as a mixture of these groups. Although the uncertainty set of group DRO has a wider radius without being too conservative, our preliminary results show that group DRO recklessly prioritizes the worst-case groups that incur higher losses than the others. Such worst-case groups are not necessarily the influential ones that are truly connected to unseen distributions; optimizing over the worst-case rather than the influential groups would yield mediocre OOD generalization performance.

Out-of-Distribution Generalization. The goal of OOD generalization is to generalize a model from source distributions to unseen target distributions. There are two main branches of methods for tackling OOD generalization: group-invariant learning (Arjovsky et al., 2019; Koyama & Yamaguchi, 2020; Liu et al., 2021) and distributionally robust optimization.
The goal of group-invariant learning is to exploit the causally invariant correlations across multiple distributions. Invariant Risk Minimization (IRM) is one of the most representative methods; it learns a classifier that is simultaneously optimal across the source distributions. However, recent work (Rosenfeld et al., 2021) shows that IRM can fail catastrophically unless the test data are sufficiently similar to the training distribution.

Therefore, we only tune the learning rate of the mixture distribution η_q and the dual variable λ. All results are reported over 3 random-seed runs, consistent with Koh et al. (2021) and Shi et al. (2022). We select λ from {1e-3, 1e-2, 1e-1, 1, 10, 100} and η_q from {1e-4, 1e-3, 1e-2, 1e-1, 1}.

7.3. ADDITIONAL RESULTS

DG-15 and DG-60. We provide detailed classification results for each group in Fig. 8. Compared to the other baselines, TRO significantly improves the generalization performance on groups that are far from the training groups, such as group "13". We further visualize the decision boundaries of DG-15 and DG-60 in Fig. 9.

Image samples of the three DomainBed datasets are shown in Fig. 11 (left). (1) PACS is one of the most popular datasets for out-of-distribution generalization. It consists of images from four groups: "Art", "Cartoon", "Photo", and "Sketch". Results on PACS are shown in Tab. 5. Results of other baselines are from Appendix B.4 of Gulrajani & Lopez-Paz (2021). In average accuracy, TRO outperforms the SOTA by 0.5%. To further investigate the results, we visualize the learned topology in Fig. 11 (right). When "Cartoon" is the test group, the topology is a chain graph of three nodes in which "Art" is the most influential group. A possible explanation is that "Art" may contain more information than "Photo" and "Sketch", as "Art" can be viewed as the combination of photos and various kinds of styles. Even though the topology is simple, it enables our method to significantly outperform ERM and DRO by 0.8% and 2.4% on average. The results empirically demonstrate the strong explainability of our method when the number of training groups is quite limited (i.e., 3). We note that when the distributional shift across different groups is small (see the explanation of the VLCS results), the influential group may not exist and all groups share the same centrality. In this special case, TRO aims to strike a good balance between the average (ERM) risk and the worst-case (DRO) risk.
Table 7: Accuracy (%) on VLCS. The average accuracy of DRO and TRO is the same. We assume the reason is that the distributional shift across different groups is small (Li et al., 2017), and therefore the influential group may not exist and all groups share the same centrality. In this special case, TRO aims to strike a good balance between the average (ERM) risk and the worst-case (DRO) risk. The images of VLCS are all photos, and the distributional shift is not as significant as in PACS (e.g., Photo vs. Sketch). As stated in Sec. 2.1 of Li et al. (2017), "despite the famous analysis of dataset bias that motivated the creation of the VLCS benchmark, it was later shown that the domain shift is much smaller with recent deep features"; PACS (Li et al., 2017) was proposed to address this limitation.

7.4. PROOF OF THEOREM 1

Proof. By using the property of convex loss functions, we can obtain:
$$\max_{q \in \Delta_m} R(\theta_T, q) - \min_{\theta \in \Theta} R(\theta, q_T) \le \frac{1}{T} \max_{q \in \Delta_m, \theta \in \Theta} \sum_{t=1}^{T} \left[ R(\theta^t, q) - R(\theta, q^t) \right],$$
where for all $t \ge 1$:
$$R(\theta^t, q) - R(\theta, q^t) = R(\theta^t, q) - R(\theta^t, q^t) + R(\theta^t, q^t) - R(\theta, q^t) \le \langle q - q^t, \nabla_q R(\theta^t, q^t) \rangle + \langle \theta^t - \theta, \nabla_\theta R(\theta^t, q^t) \rangle.$$
By rearranging the terms in the above equation, we obtain:
$$\max_{q \in \Delta_m, \theta \in \Theta} \sum_{t=1}^{T} \left[ R(\theta^t, q) - R(\theta, q^t) \right] \le E\left[ \max_{\theta \in \Theta} \sum_{t=1}^{T} \langle \theta^t - \theta, \hat g_\theta^t \rangle \right] + E\left[ \max_{q \in \Delta_m} \sum_{t=1}^{T} \langle q - q^t, \hat g_q^t \rangle \right] + E\left[ \max_{q \in \Delta_m} \sum_{t=1}^{T} \langle q, \nabla_q R(\theta^t, q^t) - \hat g_q^t \rangle \right] + E\left[ \max_{\theta \in \Theta} \sum_{t=1}^{T} \langle \theta, \hat g_\theta^t - \nabla_\theta R(\theta^t, q^t) \rangle \right].$$
Following Collins et al. (2020), we derive the combined bound by bounding the expectation of each term in the above equation. For the first term, by utilizing the telescoping sum, we can obtain:
$$E\left[ \max_{\theta \in \Theta} \sum_{t=1}^{T} \langle \theta^t - \theta, \hat g_\theta^t \rangle \right] \le \frac{2 R_\Theta^2}{\eta_\theta} + \frac{\eta_\theta T \hat G_\theta^2}{2}.$$
Similarly, for the second term:
$$E\left[ \max_{q \in \Delta_m} \sum_{t=1}^{T} \langle q - q^t, \hat g_q^t \rangle \right] \le \frac{2}{\eta_q} + \frac{\eta_q T \hat G_q^2}{2}.$$
The third term and the last term are bounded by $\sqrt{T} \hat\sigma_q$ and $R_\Theta \sqrt{T} \hat\sigma_\theta$, respectively. Combining the bounds and dividing by $T$, we can derive the overall bound as:
$$E\left[ \max_{q \in \Delta_m} R(\theta_T, q) - \min_{\theta \in \Theta} R(\theta, q_T) \right] \le \frac{2 R_\Theta^2}{\eta_\theta T} + \frac{\eta_\theta \hat G_\theta^2}{2} + \frac{2}{\eta_q T} + \frac{\eta_q \hat G_q^2}{2} + \frac{R_\Theta \hat\sigma_\theta}{\sqrt{T}} + \frac{\hat\sigma_q}{\sqrt{T}}.$$
The above bound can be minimized by setting the step sizes $\eta_\theta = 2 R_\Theta / (\hat G_\theta \sqrt{T})$ and $\eta_q = 2 / (\hat G_q \sqrt{T})$.

7.5. PROOF OF THEOREM 2

Proof. Inspired by Qian et al. (2019) and Collins et al. (2020), we start from the $M$-smoothness of the loss:
$$E\left[ \sum_{i=1}^{n} q_i^{t+1} \ell_i(\theta^{t+1}) \right] \le E\left[ \sum_{i=1}^{n} q_i^t \ell_i(\theta^t) \right] - \left( \eta_\theta - \frac{\eta_\theta^2 M}{2} \right) E\left[ \| g_\theta^t \|^2 \right] + E\left[ \sum_{i=1}^{n} \left( q_i^{t+1} - q_i^t \right) \ell_i(\theta^{t+1}) \right] + \frac{\eta_\theta^2 M \sigma_\theta^2}{2}.$$
The third term of the above equation can be bounded by:
$$E\left[ \sum_{i=1}^{n} \left( q_i^{t+1} - q_i^t \right) \ell_i(\theta^{t+1}) \right] \le E\left[ \| q^{t+1} - q^t \|_2 \left( \sum_{i=1}^{n} \ell_i(\theta^{t+1})^2 \right)^{1/2} \right] \le \eta_q \sqrt{n}\, B \hat G_q.$$
By using the law of iterated expectations and telescoping over $t$, we can obtain:
$$\left( \eta_\theta - \frac{\eta_\theta^2 M}{2} \right) \sum_{t=1}^{T} E\left[ \| g_\theta^t \|^2 \right] \le E\left[ \sum_{i=1}^{n} q_i^1 \ell_i(\theta^1) \right] - E\left[ \sum_{i=1}^{n} q_i^{T+1} \ell_i(\theta^{T+1}) \right] + 2 T \eta_q \sqrt{n}\, B \hat G_q + \frac{T M \eta_\theta^2 \sigma_\theta^2}{2} \le R(\theta^1, q^1) + B + 2 T \eta_q \sqrt{n}\, B \hat G_q + \frac{T \eta_\theta^2 M \sigma_\theta^2}{2}. \quad (9)$$
Next, we investigate the convergence of $q$:
$$E\left[ R(\theta^t, q) - R(\theta^t, q^t) \right] \le E\left[ \frac{\| q - q^t \|_2^2 + \eta_q^2 \| \hat g_q^t \|_2^2 - \| q - q^{t+1} \|_2^2}{2 \eta_q} \right] \le E\left[ \frac{\| q - q^t \|_2^2 + \eta_q^2 \hat G_q^2 - \| q - q^{t+1} \|_2^2}{2 \eta_q} \right].$$
By aggregating the difference at all time steps, we obtain:
$$\sum_{t=1}^{T} E\left[ R(\theta^t, q) - R(\theta^t, q^t) \right] \le \sum_{t=1}^{T} \left( \frac{E\left[ \| q - q^t \|_2^2 \right] - E\left[ \| q - q^{t+1} \|_2^2 \right]}{2 \eta_q} + \frac{\eta_q \hat G_q^2}{2} \right) = \frac{E\left[ \| q - q^1 \|_2^2 \right]}{2 \eta_q} + \frac{\eta_q T \hat G_q^2}{2} \le \frac{1}{\eta_q} + \frac{\eta_q T \hat G_q^2}{2}.$$
Since the above inequality holds for all $q \in \Delta_m$, we maximize over $q \in \Delta_m$ and divide by $T$:
$$\frac{1}{T} \sum_{t=1}^{T} E\left[ R(\theta^t, q^t) \right] \ge \max_{q \in \Delta_m} E\left[ R(\theta_T, q) \right] - \frac{1}{\eta_q T} - \frac{\eta_q \hat G_q^2}{2}. \quad (10)$$
Eqs. 9 and 10 show that TRO converges in expectation to an $(\epsilon, \delta)$-stationary point of $R$ in $O(1/\epsilon^4)$ stochastic gradient evaluations.



The source code and pre-trained models are available at: https://github.com/joffery/TRO. The official implementation of Tong et al. (2021), from which we take the topology-learning hyperparameters, is available at: https://github.com/KrishnaswamyLab/DiffusionEMD.




Figure 1: Overview of topology-aware distributionally robust optimization (TRO).

Let $n = (n_1, \ldots, n_m)$ denote the vector of sample sizes for all training groups. Following Mohri et al. (2019), we define the weighted Rademacher complexity for any $F$ as:
$$\hat R_n(H, q) = E_\sigma\left[ \sup_{f \in F} \sum_{e=1}^{m} \frac{q_e}{n_e} \sum_{i=1}^{n_e} \sigma_{e,i}\, \ell(f(x_{e,i}), y_{e,i}) \right],$$
where $e$ denotes the group index, $S_e$ a sample of size $n_e$ drawn from $P_e$, the distribution of group $e$, and $\sigma = (\sigma_{e,i})_{e \in [m], i \in [n_e]}$ a collection of Rademacher variables. The minimax weighted Rademacher complexity for a subset $\Lambda \subseteq \Delta_m$ is defined by $R_n(H, \Lambda) = \max_{q \in \Lambda} \hat R_n(H, q)$, where $n = \sum_{e=1}^{m} n_e$. Let $P_\Lambda$ be the distribution over the mixture of training groups and $\hat P_\Lambda$ be its empirical counterpart.
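For intuition, the weighted Rademacher complexity can be estimated by Monte Carlo for a finite hypothesis set; this is an illustrative simplification (the paper's $F$ is not finite), and all inputs below are made up:

```python
import random

def weighted_rademacher(losses_per_group, q, trials=200, seed=0):
    """Monte-Carlo estimate of R_n(H, q) =
    E_sigma[ sup_f sum_e (q_e / n_e) sum_i sigma_ei * loss_f(e, i) ]
    for a finite hypothesis set (toy assumption for illustration).

    losses_per_group[f][e] is the list of per-sample losses of
    hypothesis f on training group e; all hypotheses share the samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        # one Rademacher sign per (group, sample) position, shared across f
        sigmas = [[rng.choice((-1.0, 1.0)) for _ in grp]
                  for grp in losses_per_group[0]]
        best = max(
            sum(q_e / len(grp) * sum(s * l for s, l in zip(sig, grp))
                for q_e, sig, grp in zip(q, sigmas, f_losses))
            for f_losses in losses_per_group
        )
        total += best
    return total / trials
```

Maximizing the returned estimate over $q \in \Lambda$ gives a plug-in estimate of the minimax quantity $R_n(H, \Lambda)$.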

Figure 2: Illustration of data groups in (a) DG-15 and (b) DG-60 datasets.

5.1. CLASSIFICATION

Datasets. DG-15 (Xu et al., 2022) is a synthetic binary classification dataset with 15 groups, each containing 100 data points. In this dataset, adjacent groups have similar decision boundaries. Following Xu et al. (2022), we use six connected groups as the training groups and the others as test groups. Note that, unlike Xu et al. (2022), which focuses on domain adaptation, the data of the test groups are unseen in OOD generalization. DG-60 (Xu et al., 2022) is another synthetic dataset generated by the same procedure as DG-15, except that it contains 60 groups with 6,000 data points in total. We randomly select six groups as the training groups and use the others as test groups. Visualizations of DG-15 and DG-60 are shown in Fig. 2 (a) and (b), respectively.

Figure 4: Left: Generalization task of North → South on TPT-48. Middle: Group centrality of physical-based topology. Right: Group centrality of data-driven topology. "PA" is identified by TRO as the influential group in physical-based topology; "NY", "PA", and "MA" are identified by TRO as influential groups in data-driven topology. Data topology yields lower MSE than physical topology.

Figure 6: Left: Locations of the 11 flood events in Sen1Floods11. We use the event "BOL" for testing and the other events for training. Right: Data-driven distributional topology on Sen1Floods11. (1) "IND" and "NGA" are identified by TRO as the most influential groups. A possible explanation is that the floods in both "IND" and "NGA" were caused by heavy rainfall, the most prevalent driver of floods. (2) "GHA" and "KHM" are identified by TRO as the least influential groups. A possible explanation is that the floods in both "GHA" and "KHM" arose from edge cases such as dam collapse. The data-driven distributional topology is consistent with domain knowledge and facilitates the explainability of TRO.

Figure 7: Ablation study on λ. IoU remains stable for a wide range of λ.

7.2. IMPLEMENTATION DETAILS

In Sec. 3.1, for all hyperparameters such as the kernel scale σ² and the maximum scale K, we use the default values from the official implementation of Tong et al. (2021). In Sec. 3.2, for the learning rate of the model parameters η_θ, we use the default values from Xu et al. (2022) (DG-15/-60 and TPT-48) and Bonafilia et al. (2020) (Sen1Floods11).

Figure11: Left: Image samples of DomainBed. Right: the data-driven topology of PACS when "Cartoon" is the test group while the other three are training groups. We assume the reason why "Art" is the most influential group is that "Art" may contain more information than "Photo" and "Sketch" as "Art" is the combination of photos and various kinds of styles.

Accuracy (%) on DG-15 and DG-60. TRO sets the new SOTA on both DG-15 and DG-60.

Mean Squared Error (MSE) for both tasks E(24) → W(24) and N(24) → S(24) on TPT-48. TRO (data-driven topology) consistently outperforms TRO (physical-based topology) on both tasks, indicating that the data-driven topology captures the distributional relations more accurately.

MSE on TPT-48. Ignoring either the worst-case (IW-ERM) or influential (DRO) groups would yield compromised performance.

Alexander Y Tong, Guillaume Huguet, Amine Natik, Kincaid MacDonald, Manik Kuchroo, Ronald Coifman, Guy Wolf, and Smita Krishnaswamy. Diffusion earth mover's distance and distribution embeddings. In International Conference on Machine Learning, pp. 10336-10346. PMLR, 2021.



Accuracy (%) on Terra. TRO achieves comparable results with the SOTA and outperforms ERM and DRO by 1.8% and 2.0%.

DomainBed. Following the instructions of the official implementation of DomainBed (Gulrajani & Lopez-Paz, 2021), we have conducted experiments on PACS (Li et al., 2017), Terra (Beery et al., 2018), and VLCS (Fang et al., 2013).

(2) Terra consists of images of wild animals captured by camera traps at four locations. Results on Terra are shown in Tab. 6. Results of other baselines are from Appendix B.6 of Gulrajani & Lopez-Paz (2021). In average accuracy, TRO achieves comparable results with the SOTA and outperforms ERM and DRO by 1.8% and 2.0%. (3) Results on VLCS are shown in Tab. 7. Results of other baselines are from Appendix B.3 of Gulrajani & Lopez-Paz (2021).


We evaluate TRO on a wide range of tasks including classification, regression, and semantic segmentation. We compare TRO with SOTA baselines on OOD generalization and conduct ablation studies on the key components of TRO. Following Gulrajani & Lopez-Paz (2021), we perform model selection based on a validation set constructed from training groups only. We provide implementation details in Appendix 7.2 and results on DomainBed (Gulrajani & Lopez-Paz, 2021) in Appendix 7.3.

Baselines. We compare TRO with the following methods: (1) Empirical Risk Minimization (ERM) (Vapnik, 1998); (2) Group distributionally robust optimization (DRO) (Sagawa et al., 2019); (3) Invariant Risk Minimization (IRM) (Arjovsky et al., 2019); (4) Risk Extrapolation (REx) (Krueger et al., 2021); (5) Spectral Decoupling (SD) (Pezeshki et al., 2021).

ACKNOWLEDGEMENTS

This work is partially supported by National Science Foundation (NSF) CMMI-2039857, General University Research (GUR), and University of Delaware Research Foundation (UDRF). The authors would like to thank Kien X. Nguyen for helping with the experiments on DomainBed.

