UNDERSTANDING NEW TASKS THROUGH THE LENS OF TRAINING DATA VIA EXPONENTIAL TILTING

Abstract

Deploying machine learning models on new tasks is a major challenge due to differences between the distributions of the training (source) data and the new (target) data. However, the training data likely captures some of the properties of the new task. We consider the problem of reweighing the training samples to gain insights into the distribution of the target task. Specifically, we formulate a distribution shift model based on the exponential tilt assumption and learn train data importance weights by minimizing the KL divergence between the labeled train and unlabeled target datasets. The learned train data weights can then be used for downstream tasks such as target performance evaluation, fine-tuning, and model selection. We demonstrate the efficacy of our method on the WATERBIRDS and BREEDS benchmarks.1

1. INTRODUCTION

Machine learning models are often deployed in a target domain that differs from the domain in which they were trained and validated. This leads to the practical challenge of adapting models to a new domain and evaluating their performance there without costly labeling of the dataset of interest. For example, in the Inclusive Images challenge (Shankar et al., 2017), the training data largely consists of images from countries in North America and Western Europe. If a model trained on this data is presented with images from countries in Africa and Asia, then (i) it is likely to perform poorly, and (ii) its performance in the training (source) domain may not mirror its performance in the target domain. However, due to the presence of a small fraction of images from Africa and Asia in the source data, it may be possible to reweigh the source samples to mimic the target domain.

In this paper, we consider the problem of learning a set of importance weights so that the reweighted source samples closely mimic the distribution of the target domain. We pose an exponential tilt model of the distribution shift between the train and the target data and an accompanying method that leverages unlabeled target data to fit the model. Although similar methods are widely used in statistics (Rosenbaum & Rubin, 1983) and machine learning (Sugiyama et al., 2012) to train and evaluate models under covariate shift (where the decision function/boundary does not change), one of the main benefits of our approach is that it allows concept drift (where the decision function/boundary is expected to differ) (Cai & Wei, 2019; Gama et al., 2014) between the source and the target domains. We summarize our contributions below:

• In Section 3 we develop a model and an accompanying method for learning source importance weights to mimic the distribution of the target domain without labeled target samples.
• In Section 4 we establish theoretical guarantees on the quality of the weight estimates and their utility in the downstream tasks of fine-tuning and model selection.

2. RELATED WORK

Out-of-distribution generalization is essential for safe deployment of ML models. There are two prevalent problem settings: domain generalization and subpopulation shift (Koh et al., 2020). Domain generalization typically assumes access to several datasets during training that are related to the same task but differ in their domain or environment (Blanchard et al., 2011; Muandet et al., 2013). The goal is to learn a predictor that generalizes to unseen related datasets via learning invariant representations (Ganin et al., 2016; Sun & Saenko, 2016), invariant risk minimization (Arjovsky et al., 2019; Krueger et al., 2021), or meta-learning (Dou et al., 2019). Domain generalization is a very challenging problem, and recent benchmark studies demonstrate that the corresponding methods rarely improve over vanilla empirical risk minimization (ERM) on the source data unless given access to labeled target data for model selection (Gulrajani & Lopez-Paz, 2020; Koh et al., 2020). The subpopulation shift setting assumes that the train and test data consist of the same groups with different group fractions; it is typically approached via distributionally robust optimization.

Model selection on out-of-distribution (OOD) data is an important and challenging problem, as noted by several authors (Gulrajani & Lopez-Paz, 2020; Koh et al., 2020; Zhai et al., 2021; Creager et al., 2021). Xu & Tibshirani (2022); Chen et al. (2021b) propose solutions specific to covariate shift based on parametric bootstrap and reweighing; Garg et al. (2022); Guillory et al. (2021); Yu et al. (2022) align model confidence and accuracy with a threshold; Jiang et al. (2021); Chen et al. (2021a) train several models and use their ensembles or disagreement. Our importance weighting approach is computationally simpler than the latter and more flexible than the former, as it allows for concept drift and can be used in downstream tasks beyond model selection, as we demonstrate both theoretically and empirically.

Domain adaptation is another closely related problem setting. Domain adaptation (DA) methods require access to a labeled source domain and an unlabeled target domain during training and aim to improve target performance via a combination of distribution matching (Ganin et al., 2016; Sun & Saenko, 2016; Shen et al., 2018), self-training (Shu et al., 2018; Kumar et al., 2020), data augmentation (Cai et al., 2021; Ruan et al., 2021), and other regularizers. DA methods are typically challenging to train and require retraining for every new target domain. Our importance weights, on the other hand, are easy to learn for a new domain, allowing for efficient fine-tuning, similar to test-time adaptation methods (Sun et al., 2020; Wang et al., 2020; Zhang et al., 2020), which adjust the model based on unlabeled target samples. Our importance weights can also be used to define additional regularizers to enhance existing DA methods. Importance weighting is often used in the domain adaptation literature on label shift (Lipton et al., 2018; Azizzadenesheli et al., 2019; Maity et al., 2022) and covariate shift (Sugiyama et al., 2007; Hashemi & Karimi, 2018), but it has seen little use for concept drift models (Cai & Wei, 2019; Maity et al., 2021) because it is generally impossible to estimate the weights without labeled data from the target. In this paper, we introduce an exponential tilt model that accommodates concept drift while still allowing us to estimate the importance weights for the distribution shift.

3. THE EXPONENTIAL TILT MODEL

Notation We consider a K-class classification problem. Let X ⊆ R^d and Y ≜ [K] be the space of inputs and the set of possible labels, and let P and Q be probability distributions on X × Y for the source and target domains, respectively.

1 Code can be found at https://github.com/smaityumich/exponential-tilting.
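To make exponential tilting concrete: for a standard Gaussian source, tilting the density by exp(θx) produces another Gaussian with shifted mean, so self-normalized tilt weights let reweighted source samples mimic target expectations, including target performance estimates. Below is a minimal 1-D numerical sketch of this idea; the specific distributions and the tilt parameter θ are illustrative assumptions, not the paper's full fitting procedure.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: source P = N(0, 1), target Q = N(1, 1).
# Q is an exponential tilt of P: dQ/dP(x) = exp(theta * x - theta**2 / 2),
# since tilting N(0, 1) by exp(theta * x) yields N(theta, 1); here theta = 1.
x_src = rng.normal(0.0, 1.0, size=100_000)

def tilt_weights(x, theta):
    """Self-normalized exponential-tilt importance weights on source samples."""
    w = np.exp(theta * x)
    return w / w.sum()

w = tilt_weights(x_src, theta=1.0)

# Reweighted source samples mimic target expectations:
reweighted_mean = float(np.sum(w * x_src))   # close to E_Q[X] = 1

# Downstream use (evaluation without target labels): estimate a target-domain
# statistic, e.g. Q(X > 0) = Phi(1) ~ 0.841, from weighted source samples.
# The per-sample correctness of any fixed classifier can be reweighed the same way.
acc_est = float(np.sum(w * (x_src > 0)))
```

The self-normalization step makes the estimates invariant to the (typically unknown) normalizing constant of the tilt, which is why only the tilt parameters need to be learned.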


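Under a log-linear tilt of this form, the log density ratio between target and source is linear in the features, so a standard way to estimate such a ratio from unlabeled target data is the density-ratio trick: train a probabilistic classifier to discriminate source from target samples and read off the log-odds. The sketch below uses this generic technique for illustration; it is not necessarily the paper's estimator, which fits the tilt by minimizing a KL divergence.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(1)

# Source P = N(0, 1) and target Q = N(1, 1), so the true log density ratio is
# log dQ/dP(x) = x - 1/2, i.e. a tilt with slope theta = 1.
x_src = rng.normal(0.0, 1.0, size=(20_000, 1))
x_tgt = rng.normal(1.0, 1.0, size=(20_000, 1))

# Density-ratio trick: label source as 0, target as 1, fit a discriminator.
X = np.vstack([x_src, x_tgt])
z = np.concatenate([np.zeros(len(x_src)), np.ones(len(x_tgt))])
clf = LogisticRegression(C=1e4).fit(X, z)  # weak regularization

# With balanced classes, the fitted log-odds slope estimates the tilt parameter.
theta_hat = float(clf.coef_[0, 0])
```

Weak regularization (large `C`) is used so the fitted log-odds track the true density ratio rather than being shrunk toward zero; with balanced source/target sample sizes the intercept absorbs the log normalizing constant.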