OUT-OF-DISTRIBUTION GENERALIZATION ANALYSIS VIA INFLUENCE FUNCTION

Anonymous

Abstract

The mismatch between training and target data is one major challenge for current machine learning systems. When training data is collected from multiple domains and the target domains include all training domains as well as new, unseen domains, we face an Out-of-Distribution (OOD) generalization problem that aims to find a model with the best OOD accuracy. One common definition of OOD accuracy is worst-domain accuracy. In general, the set of target domains is unknown, and the worst-case target domain may be unobserved when the number of observed domains is limited. In this paper, we show that the worst accuracy over the observed domains may dramatically fail to identify the OOD accuracy. To this end, we introduce the influence function, a classical tool from robust statistics, into the OOD generalization problem and propose the variance of the influence function to monitor the stability of a model across training domains. We show that the accuracy on test domains and the proposed index together can help us discern whether OOD algorithms are needed and whether a model achieves good OOD generalization.

1. INTRODUCTION

Most machine learning systems assume that training and test data are independently and identically distributed, an assumption that does not always hold in practice (Bengio et al. (2019)). Consequently, performance often degrades greatly when the test data come from a different domain (distribution). A classical example is the problem of identifying cows and camels (Beery et al. (2018)), where empirical risk minimization (ERM, Vapnik (1992)) may classify images by background color instead of object shape. As a result, when the test domain is "out-of-distribution" (OOD), e.g. when the background color is changed, performance drops significantly. OOD generalization aims to obtain a predictor that is robust against this distribution shift. Suppose that we have training data collected from m domains: S = {S^e : e ∈ E_tr, |E_tr| = m}, where S^e = {z^e_1, z^e_2, ..., z^e_{n_e}} with z^e_i ~ P^e. Here P^e is the distribution corresponding to domain e, E_tr is the set of all available domains (including validation domains), and z^e_i is a data point. The OOD problem we consider is to find a model f_OOD such that f_OOD = arg min_f sup_{e ∈ E_all} L(f, P^e), where E_all is the set of all target domains and L(f, P^e) is the expected loss of f on the domain P^e. Recent algorithms address this OOD problem by recovering invariant (causal) features and building the optimal model on top of these features, e.g. Invariant Risk Minimization (IRM, Arjovsky et al. (2019)). However, OOD algorithms have been reported not to outperform ERM on several domain generalization benchmarks (Gulrajani & Lopez-Paz (2020)). This is not surprising, since these tasks only require high performance on certain domains, while an OOD algorithm is expected to learn truly invariant features and perform well on a large set of target domains E_all. This phenomenon is described as the "accuracy-vs-invariance trade-off" in Akuzawa et al. (2019). Two questions arise from the min-max problem above. First, previous works assume that there is sufficient diversity among the domains in E_all.
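The difference between the ERM objective (average loss) and the OOD objective (worst-domain loss) can be made concrete with a small numerical sketch. The loss values below are invented for illustration; the point is only that arg min of the average and arg min of the supremum can disagree.

```python
import numpy as np

# Hypothetical per-domain expected losses L(f, P^e) for three candidate
# models f over four target domains e in E_all (rows: models, cols: domains).
losses = np.array([
    [0.05, 0.06, 0.95, 0.05],   # relies on a spurious feature: great on
                                # average, fails badly on one domain
    [0.30, 0.31, 0.33, 0.30],   # invariant-like: uniformly moderate
    [0.40, 0.41, 0.42, 0.40],
])

avg_loss = losses.mean(axis=1)     # what ERM targets (average over domains)
worst_loss = losses.max(axis=1)    # sup over E_all, the OOD objective

f_erm = int(np.argmin(avg_loss))   # model 0 wins on average loss
f_ood = int(np.argmin(worst_loss)) # model 1 wins on worst-domain loss
```

Under these losses, ERM selects model 0 while the min-max criterion selects model 1, illustrating why a predictor with the best average performance need not be f_OOD.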
Thus the supremum of L(f, P^e) may be much larger than the average, which implies that ERM may fail to discover f_OOD. But in reality, we do not know whether this is true. If not, L(f, P^e) is concentrated around its expectation, and ERM is sufficient to find an invariant model for E_all. Therefore, we call for a method to judge whether an OOD algorithm is needed. Second, how do we judge a model's OOD performance? Traditionally, we consider test domains E_test ⊂ E_tr and use the worst-domain accuracy over E_test (which we call test accuracy) to approximate the OOD accuracy. However, test accuracy is a biased estimate of the OOD accuracy unless E_tr is close to E_all. More seriously, it may be irrelevant or even negatively correlated with the OOD accuracy. This phenomenon is not uncommon, especially when there are features that are spurious in E_all but show a strong correlation with the target in E_tr. We give a toy example in Colored MNIST where the test accuracy fails to approximate the OOD accuracy; for more details, please refer to Section 5.1 and Appendix A.4. We choose three domains from Colored MNIST and use cross-validation (Gulrajani & Lopez-Paz (2020)) to select models, i.e. we take turns selecting a domain S ∈ E_tr as the test domain, train on the rest, and select the model with the maximum average test accuracy. Figure 1 shows the comparison between ERM and IRM. One can find that no matter which domain is the test domain, the ERM model uniformly outperforms the IRM model on the test domain. However, the IRM model achieves consistently better OOD accuracy. The shortcomings of the test accuracy here are obvious, regardless of whether cross-validation is used. In short, naive use of the test accuracy may result in a non-OOD model. To address this obstacle, we hope to find a metric that correlates better with a model's OOD property, even when E_tr is much smaller than E_all and the "worst" domain remains unobserved.
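The failure mode of test accuracy can be reproduced in a toy simulation in the spirit of Colored MNIST (all names and probabilities below are illustrative, not the paper's actual experimental setup): a spurious "color" feature agrees with the label in all observed domains but is anti-correlated in an unseen domain, so a model that reads only the color scores well on every test domain and collapses out of distribution.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_domain(n, p_color):
    """Toy Colored-MNIST-style domain: binary label y, an invariant feature
    x_inv that matches y w.p. 0.75 in every domain, and a spurious color
    feature x_sp that matches y w.p. p_color (domain-dependent)."""
    y = rng.integers(0, 2, n)
    x_inv = np.where(rng.random(n) < 0.75, y, 1 - y)
    x_sp = np.where(rng.random(n) < p_color, y, 1 - y)
    return x_inv, x_sp, y

def acc(pred, y):
    return float((pred == y).mean())

# A non-OOD model that predicts directly from the spurious color feature.
test_domains = [make_domain(10_000, p) for p in (0.9, 0.8)]  # observed E_test
ood_domain = make_domain(10_000, 0.1)                        # color flipped

test_accs = [acc(x_sp, y) for _, x_sp, y in test_domains]
ood_acc = acc(ood_domain[1], ood_domain[2])
# worst test accuracy is roughly 0.8, while OOD accuracy is roughly 0.1
```

Here the worst-domain accuracy over E_test is high even though the OOD accuracy is far below chance, so the two quantities are negatively related exactly as described above.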
Without any assumption on E_all, our goal is unrealistic. Therefore, we assume that features that are invariant across E_tr are also invariant across E_all. This assumption is necessary; otherwise, the only thing we can do is collect more domains. Hence, we need to focus on what features the model has learnt. Specifically, we want to check whether the model learns invariant features and avoids varying features. The influence function (Cook & Weisberg (1980)) can serve this purpose. The influence function was proposed to measure the parameter change when a data point is removed or upweighted by a small perturbation (details in 3.2). When modified to the domain level, it measures the influence of a domain, instead of a data point, on the model. Note that we are not emulating the change of the parameters when a domain is removed; rather, we care exactly about upweighting the domain by δ → 0+ (to be specified later). Based on this, the variance of the influence function allows us to measure the OOD property and overcome the obstacle above.

Contributions. We summarize our contributions here: (i) We extend the influence function to the domain level and propose the index V_{γ|θ} (formula 6) based on the influence function of the model f_θ. Our index can measure the OOD extent of the available domains, i.e. how different these domains (distributions) are. This measurement provides a basis for deciding whether to adopt an OOD algorithm and whether to collect more diverse domains. See Section 4.1 and Section 5.1.1 for details. (ii) We point out that the
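The domain-level construction can be sketched numerically. The sketch below assumes a logistic-regression model with parameters fitted on the pooled data: `domain_influences` computes, for each domain e, the classical influence direction I(e) = -H^{-1} ∇_θ L_e(θ) obtained by upweighting that domain infinitesimally, and `influence_variance` is one plausible variance-style index of how much the domains pull the parameters in different directions. The function names, the damping term, and this particular choice of the functional are assumptions for illustration; the paper's actual index V_{γ|θ} (formula 6) is not reproduced in this excerpt.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def grad_and_hess(X, y, theta):
    """Gradient and Hessian of the mean logistic loss on one domain."""
    p = sigmoid(X @ theta)
    g = X.T @ (p - y) / len(y)
    H = (X * (p * (1 - p))[:, None]).T @ X / len(y)
    return g, H

def domain_influences(domains, theta, damping=1e-3):
    """Domain-level influence: parameter change when domain e is upweighted
    by an infinitesimal delta, I(e) = -H^{-1} grad L_e(theta), where H is
    the Hessian of the pooled loss."""
    d = len(theta)
    H = np.zeros((d, d))
    n_tot = sum(len(y) for _, y in domains)
    for X, y in domains:
        _, He = grad_and_hess(X, y, theta)
        H += He * len(y) / n_tot
    H += damping * np.eye(d)  # small damping for numerical stability
    infl = []
    for X, y in domains:
        ge, _ = grad_and_hess(X, y, theta)
        infl.append(-np.linalg.solve(H, ge))
    return np.stack(infl)     # shape: (num_domains, num_params)

def influence_variance(infl):
    """Spread of domain influences around their mean: near zero when all
    domains push the parameters the same way, large when they disagree."""
    return float(np.mean(np.sum((infl - infl.mean(0)) ** 2, axis=1)))
```

Two identical domains give an index of (numerically) zero, while domains whose labels depend on different features give a strictly larger one, matching the intended reading of the index as a stability monitor.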



(2019)), Risk Extrapolation (REx, Krueger et al. (2020)), Group Distributionally Robust Optimization (gDRO, Sagawa et al. (2019)) and Inter-domain Mixup (Mixup, Xu et al. (2020); Yan et al. (2020); Wang et al. (2020)). Most works evaluate on Colored MNIST (see 5.1 for details), where we can directly obtain the worst-domain accuracy over E_all. Gulrajani & Lopez-Paz (2020) have assembled many algorithms and multi-domain datasets, and find that OOD algorithms cannot outperform ERM in some domain generalization tasks, e.g. VLCS (Torralba & Efros (2011)) and PACS (Li et al. (

Figure 1: Experiments on Colored MNIST showing that test accuracy is not enough to reflect a model's OOD accuracy. The top left panel shows the test accuracy of ERM and IRM. The other three panels present the relationship between test accuracy (x-axis) and OOD accuracy (y-axis) in three setups.

