PARETO INVARIANT RISK MINIMIZATION: TOWARDS MITIGATING THE OPTIMIZATION DILEMMA IN OUT-OF-DISTRIBUTION GENERALIZATION

Abstract

Recently, there has been growing interest in enabling machine learning systems to generalize well to Out-of-Distribution (OOD) data. Most efforts are devoted to advancing optimization objectives that regularize models to capture the underlying invariance; however, there are often compromises in the optimization of these OOD objectives: i) many OOD objectives have to be relaxed into penalty terms added to Empirical Risk Minimization (ERM) for ease of optimization, and the relaxed forms can weaken the robustness of the original objective; ii) the penalty terms also require careful tuning of the penalty weights due to the intrinsic conflicts between the ERM and OOD objectives. Consequently, these compromises can easily lead to suboptimal performance of either the ERM or the OOD objective. To address these issues, we introduce a multi-objective optimization (MOO) perspective to understand the OOD optimization process, and propose a new optimization scheme called PAreto Invariant Risk Minimization (PAIR). PAIR improves the robustness of relaxed OOD objectives by optimizing them cooperatively with other OOD objectives, thereby bridging the gaps caused by the relaxations; it then approaches a Pareto optimal solution that properly trades off the ERM and OOD objectives. Extensive experiments on the challenging WILDS benchmarks show that PAIR alleviates the compromises and yields top OOD performance.

1. INTRODUCTION

The interplay between optimization and generalization is crucial to the success of deep learning (Zhang et al., 2017; Arora et al., 2019; Allen-Zhu et al., 2019; Jacot et al., 2021; Allen-Zhu & Li, 2021). Guided by empirical risk minimization (ERM) (Vapnik, 1991), simple optimization algorithms can find uneventful descent paths in the non-convex loss landscape of deep neural networks (Sagun et al., 2018). However, when distribution shifts are present, the optimization is usually biased by spurious signals such that the learned models can fail dramatically on Out-of-Distribution (OOD) data (Beery et al., 2018; DeGrave et al., 2021; Geirhos et al., 2020). Therefore, overcoming the OOD generalization challenge has drawn much attention recently. Most efforts are devoted to proposing better optimization objectives (Rojas-Carulla et al., 2018; Koyama & Yamaguchi, 2020; Parascandolo et al., 2021; Krueger et al., 2021; Creager et al., 2021; Liu et al., 2021; Pezeshki et al., 2021; Ahuja et al., 2021a; Wald et al., 2021; Shi et al., 2022; Rame et al., 2021; Chen et al., 2022b) that regularize the gradient signals produced by ERM, while it has long been neglected that the interplay between optimization and generalization under distribution shifts has changed its nature. In fact, the optimization process of OOD objectives turns out to be substantially more challenging than that of ERM, and there are often compromises when applying OOD objectives in practice. Due to the optimization difficulty, many OOD objectives have to be relaxed into penalty terms added to ERM (Arjovsky et al., 2019; Koyama & Yamaguchi, 2020; Krueger et al., 2021; Pezeshki et al., 2021; Ahuja et al., 2021a; Rame et al., 2021), but the relaxed formulations can behave very differently from the original objective (Kamath et al., 2021) (Fig. 1(a)). Moreover, due to the generally existing gradient conflicts between ERM and OOD objectives (Fig. 1(b)), trade-offs between ERM and OOD performance are often needed during optimization; Sagawa* et al. (2020) and Zhai et al. (2022) suggest that ERM performance usually needs to be sacrificed for better OOD generalization. On the other hand, the OOD penalty hyperparameters usually require careful tuning (Zhang et al., 2022a) (Fig. 1(d)), which either weakens the power of the OOD objectives or makes them so strong that they prevent models from capturing all desirable patterns. Consequently, using the traditional optimization wisdom to train and select models can easily lead to suboptimal performance of either the ERM or the OOD objective. Indeed, most OOD objectives still struggle with distribution shifts or even underperform ERM (Gulrajani & Lopez-Paz, 2021; Koh et al., 2021). This phenomenon calls for a better understanding of optimization in OOD generalization, and raises a challenging question: How can one obtain a desired OOD solution under the conflicts of ERM and OOD objectives?

To answer this question, we take a multi-objective optimization (MOO) perspective on OOD optimization. Specifically, using the representative OOD objective IRM (Arjovsky et al., 2019) as an example, we find that the failures in OOD optimization can be attributed to two issues. The first is the compromised robustness of OOD objectives due to the relaxations in their practical variants. In fact, the relaxation can even eliminate the desired invariant solution from the Pareto front w.r.t. the ERM loss and the OOD penalty (Fig. 1(a)); merely optimizing the ERM loss and the relaxed OOD penalty can therefore hardly approach the desired solution. On the other hand, even when the Pareto front contains the desired solution, as shown in Fig. 1(c), the traditional linear weighting scheme, which linearly reweights the ERM and OOD objectives, cannot reach the solution if it lies in the non-convex part of the front (Boyd & Vandenberghe, 2014). Even when the OOD solution is reachable (i.e., lies in the convex part), approaching it still requires careful tuning of the OOD penalty weights, as shown in Fig. 1(d).

To address these issues, we propose a new optimization scheme for OOD generalization, called PAreto Invariant Risk Minimization (PAIR), which includes a new optimizer (PAIR-o) and a new model selection criterion (PAIR-s). Owing to the MOO formulation, PAIR-o allows for cooperative optimization with other OOD objectives to improve the robustness of practical OOD objectives. Despite the huge gaps between IRMv1 and IRM, we show that incorporating VREx (Krueger et al., 2021) into IRMv1 provably recovers the causal invariance (Arjovsky et al., 2019) for some group of problem instances (Sec. 3.2). When given robust OOD objectives, PAIR-o finds a descent path with adaptive penalty weights, which leads to a Pareto optimal solution that trades off ERM and OOD performance properly (Sec. 4). In addition, the MOO analysis also motivates PAIR-s, which facilitates OOD model selection by considering the trade-offs between ERM and OOD objectives. We conducted extensive experiments on challenging OOD benchmarks. Empirical results show that PAIR-o successfully alleviates the objective conflicts and empowers IRMv1 to achieve high performance.

Figure 1: Optimization issues in OOD algorithms. (a) OOD objectives such as IRM usually require several relaxations for ease of optimization, which however introduce huge gaps. The ellipsoids denote solutions that satisfy the invariance constraints of the practical IRM variant IRMv1. When optimized with ERM, IRMv1 prefers f_1 instead of f_IRM (the predictor produced by IRM). (b) Gradient conflicts between ERM and OOD objectives generally exist for different objectives at different penalty weights (x-axis). (c) The typically used linear weighting scheme for combining ERM and OOD objectives requires careful tuning of the weights to approach the solution; moreover, it cannot reach any solution in the non-convex part of the Pareto front. In contrast, PAIR finds an adaptive descent direction under gradient conflicts that leads to the desired solution. (d) Due to the optimization dilemma, the best OOD performance (e.g., of IRMv1 on a modified COLOREDMNIST from Sec. 5) usually requires exhaustive tuning of hyperparameters (y-axis: penalty weights; x-axis: pretraining epochs), whereas PAIR robustly yields top performance by resolving the compromises.
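The IRMv1-style relaxation discussed above can be made concrete with a small sketch. The setup below is an illustrative assumption, not the paper's implementation: a linear predictor with squared loss in NumPy stands in for a deep network, and `irmv1_objective` is a hypothetical helper name. It computes the mean per-environment risk plus a penalty weight times the squared gradient of each environment's risk with respect to a fixed scalar classifier multiplier w = 1, which is the form the exact IRM bi-level objective is relaxed into.

```python
import numpy as np

def irmv1_objective(envs, phi, lam):
    """Mean ERM risk plus the IRMv1-style squared-gradient penalty.

    envs: list of (X, y) pairs, one per training environment.
    phi:  weight vector of a linear predictor f(x) = X @ phi.
    lam:  penalty weight trading off ERM against the invariance penalty.
    """
    risks, penalties = [], []
    for X, y in envs:
        pred = X @ phi
        # Squared-error risk of the scaled predictor w * f(x), at w = 1.
        risk = np.mean((pred - y) ** 2)
        # Analytic gradient of the risk w.r.t. the scalar multiplier w at w = 1:
        # d/dw mean((w*pred - y)^2) |_{w=1} = mean(2 * (pred - y) * pred).
        grad_w = np.mean(2.0 * (pred - y) * pred)
        risks.append(risk)
        penalties.append(grad_w ** 2)
    return np.mean(risks) + lam * np.mean(penalties)

# Toy check: a predictor that is exactly right in every environment has
# zero risk and zero penalty, hence zero objective for any lam.
X1, X2 = np.array([[1.0], [2.0]]), np.array([[3.0], [4.0]])
phi = np.array([1.0])
envs = [(X1, X1 @ phi), (X2, X2 @ phi)]
print(irmv1_objective(envs, phi, lam=1e4))  # -> 0.0
```

Because the penalty enters as a weighted additive term, large values of `lam` dominate the ERM term and small values barely constrain it, which is exactly the tuning sensitivity described above.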

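The limitation of linear weighting on non-convex fronts can also be reproduced numerically. In this toy sketch (the front (t, 1 - t^2) is a made-up example, not taken from the paper), sweeping the weight of a linear combination of two objectives over a concave Pareto front only ever selects the two endpoints; every interior trade-off point is unreachable:

```python
import numpy as np

# Candidate solutions along a concave (non-convex) Pareto front:
# objective pairs (f1, f2) = (t, 1 - t^2) for t in [0, 1];
# f2 decreases as f1 increases, so no point dominates another.
t = np.linspace(0.0, 1.0, 101)
front = np.stack([t, 1.0 - t ** 2], axis=1)

# Linear weighting scheme: for each weight w, pick the candidate
# minimizing the scalarized objective w * f1 + (1 - w) * f2.
found = set()
for w in np.linspace(0.0, 1.0, 1001):
    scores = w * front[:, 0] + (1.0 - w) * front[:, 1]
    found.add(int(np.argmin(scores)))

# The scalarized score is concave in t, so its minimum always lies at an
# endpoint: only indices 0 and 100 are ever selected.
print(sorted(found))  # -> [0, 100]
```

This is the Fig. 1(c) phenomenon in miniature: no amount of weight tuning recovers the 99 interior points of this front, which is why an adaptive descent direction, rather than a fixed linear reweighting, is needed.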

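PAIR-o's adaptive penalty weights are derived from its preference-guided MOO formulation; as a simpler, generic illustration of how adaptive weighting can resolve a gradient conflict, the sketch below implements the classic two-objective minimum-norm (MGDA-style) combination. This is not PAIR-o itself, and the gradient values are arbitrary assumptions; it only shows that a common descent direction exists even when the ERM and OOD gradients conflict.

```python
import numpy as np

def min_norm_direction(g1, g2):
    """Minimum-norm convex combination d = a*g1 + (1-a)*g2 (two-task MGDA).

    The closed form for a follows from minimizing ||a*g1 + (1-a)*g2||^2
    over a in [0, 1]. The resulting d satisfies <d, g1> >= ||d||^2 and
    <d, g2> >= ||d||^2, so stepping along -d decreases both objectives
    whenever d != 0.
    """
    diff = g1 - g2
    denom = diff @ diff
    if denom == 0.0:
        return g1  # identical gradients: no conflict to resolve
    a = np.clip((g2 - g1) @ g2 / denom, 0.0, 1.0)
    return a * g1 + (1.0 - a) * g2

# Two conflicting gradients (negative inner product), e.g. ERM vs. an
# OOD penalty:
g_erm = np.array([1.0, 0.2])
g_ood = np.array([-0.6, 1.0])
d = min_norm_direction(g_erm, g_ood)
print(d @ g_erm, d @ g_ood)  # both positive: a common descent direction
```

Here the combination weight a is recomputed from the current gradients at every step, which is the sense in which the penalty weight becomes adaptive rather than a fixed hyperparameter.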