TOWARDS UNDERSTANDING GD WITH HARD AND CONJUGATE PSEUDO-LABELS FOR TEST-TIME ADAPTATION

Abstract

We consider a setting in which a model needs to adapt to a new domain under distribution shift, given that only unlabeled test samples from the new domain are accessible at test time. A common idea in most related works is to construct pseudo-labels for the unlabeled test samples and apply gradient descent (GD) to a loss function defined with the pseudo-labels. Recently, Goyal et al. (2022) proposed conjugate labels, a new kind of pseudo-label for self-training at test time. They empirically show that conjugate labels outperform other ways of pseudo-labeling on many domain adaptation benchmarks. However, provably showing that GD with conjugate labels learns a good classifier for test-time adaptation remains open. In this work, we aim at theoretically understanding GD with hard and conjugate labels for a binary classification problem. We show that for square loss, GD with conjugate labels converges to an ε-optimal predictor under a Gaussian model for any arbitrarily small ε, while GD with hard pseudo-labels fails at this task. We also analyze them under different loss functions for the update. Our results shed light on when and why GD with hard labels or conjugate labels works in test-time adaptation.

1. INTRODUCTION

Fully test-time adaptation is the task of adapting a model from a source domain so that it fits a new domain at test time, without accessing the true labels of samples from the new domain or the data from the source domain (Goyal et al., 2022; Wang et al., 2021a; Li et al., 2020; Rusak et al., 2021; Zhang et al., 2021a; S & Fleuret, 2021; Mummadi et al., 2021; Iwasawa & Matsuo, 2021; Liang et al., 2020; Niu et al., 2022; Thopalli et al., 2022; Wang et al., 2022b; Kurmi et al., 2021). This setting differs from many works in domain adaptation or test-time training, where the source data or statistics of the source data are available, e.g., Xie et al. (2021); Liu et al. (2021a); Prabhu et al. (2021); Sun et al. (2020); Chen et al. (2022); Hoffman et al. (2018); Eastwood et al. (2022); Kundu et al. (2020); Liu et al. (2021b); Schneider et al. (2020); Gandelsman et al. (2022); Zhang et al. (2021b); Morerio et al. (2020); Su et al. (2022). Test-time adaptation has drawn growing interest recently, thanks to its potential in real-world applications where annotating test data from a new domain is costly and distribution shifts arise at test time due to natural factors, e.g., sensor degradation (Wang et al., 2021a), evolving road conditions (Gong et al., 2022; Kumar et al., 2020), weather conditions (Bobu et al., 2018), or changes in demographics, users, and time periods (Koh et al., 2021). The central idea in many related works is the construction of pseudo-labels or the design of self-training loss functions for the unlabeled samples; see, e.g., Wang et al. (2021a); Goyal et al. (2022). More precisely, at each test time t, one receives some unlabeled samples from a new domain, constructs pseudo-labels for them, and applies a GD step to the corresponding self-training loss function, as summarized in Algorithm 1. Recently, Goyal et al. (2022) proposed a new type of pseudo-labels called conjugate labels, based on the observation that certain loss functions can be naturally connected to conjugate functions; the pseudo-labels are obtained by exploiting a property of conjugate functions (to be elaborated soon).
They provide a modular approach for constructing conjugate labels from common loss functions, e.g., square loss, cross-entropy loss, and exponential loss. An interesting finding of Goyal et al. (2022) is that a recently proposed self-training loss for test-time adaptation of Wang et al. (2021a) can be recovered from their conjugate-label framework. They also show that GD with conjugate labels empirically outperforms GD with other pseudo-labels such as hard labels and robust pseudo-labels (Rusak et al., 2021) across many benchmarks, e.g., ImageNet-C (Hendrycks & Dietterich, 2019), ImageNet-R (Hendrycks et al., 2021), VISDA-C (Peng et al., 2017), and MNIST-M (Ganin & Lempitsky, 2015). However, certain questions are left open in their work. For example, why does GD with conjugate labels work? Why can it dominate GD with other pseudo-labels? To our knowledge, while pseudo-labels are indispensable for self-training in the literature (Li et al., 2019; Zou et al., 2019), works that theoretically study the dynamics of GD with pseudo-labels are very sparse; the only one we are aware of is Chen et al. (2020).

Chen et al. (2020) show that when data have spurious features, if projected GD is initialized with sufficiently high accuracy in a new domain, then by minimizing the exponential loss with hard labels, projected GD converges to an approximately Bayes-optimal solution under certain conditions. In this work, we study vanilla GD (without projection) for minimizing the self-training losses derived from square loss, logistic loss, and exponential loss under hard labels and conjugate labels. We prove a performance gap between GD with conjugate labels and GD with hard labels under a simple Gaussian model (Schmidt et al., 2018; Carmon et al., 2019). Specifically, we show that GD with hard labels for minimizing square loss cannot converge to an ε-optimal predictor (see (8) for the definition) for any arbitrarily small ε, while GD with conjugate labels converges to an ε-optimal predictor exponentially fast. Our theoretical result supports the conjugate-label approach of Goyal et al. (2022). We then analyze GD with hard and conjugate labels under logistic loss and exponential loss, and we show that in these scenarios they converge to an optimal solution at a log(t) rate, where t is the number of test-time iterations. Our results suggest that the performance of GD in test-time adaptation depends crucially on the choice of pseudo-labels and loss functions. Interestingly, minimizing the self-training losses associated with conjugate labels in this work amounts to non-convex optimization; hence, our theoretical results also exhibit non-convex problems on which GD enjoys provable guarantees.

Algorithm 1: Test-time adaptation via pseudo-labeling
1: Init: w_1 = w_S, where w_S is the model learned from a source domain.
2: Given: access to samples from the data distribution D_test of a new domain.
3: for t = 1, 2, . . . , T do
4:   Get a sample x_t ∼ D_test from the new domain.
5:   Construct a pseudo-label y_{w_t}^pseudo(x_t) and the corresponding self-training loss ℓ_self(w_t; x_t).
6:   Apply a gradient descent (GD) step: w_{t+1} = w_t − η ∇_w ℓ_self(w_t; x_t).
7: end for

2. PRELIMINARIES

We now give an overview of hard labels and conjugate labels. We note that there are other proposals of pseudo-labels in the literature; we refer the reader to Li et al. (2019); Zou et al. (2019); Rusak et al. (2021) and the references therein for details.

Hard labels: Suppose that a model w outputs h_w(x) ∈ R^K, where each element of h_w(x) can be viewed as the predicted score of a class in a K-class classification problem. A hard pseudo-label y_w^hard(x) is a one-hot vector which is 1 on dimension k (and 0 elsewhere) if k = arg max_{k'} h_w(x)[k'], i.e., class k has the largest predicted score by the model w for a sample x (Goyal et al., 2022). On the other hand, for a binary classification problem with a linear predictor, i.e., h_w(x) = w^⊤ x, a hard pseudo-label is simply defined as:

y_w^hard(x) := sign(w^⊤ x),   (1)

see, e.g., Kumar et al. (2020); Chen et al. (2020). GD with hard labels is the case in which Algorithm 1 uses a hard label to construct the gradient ∇_w ℓ_self(w_t; x_t) and update the model w.

Conjugate labels (Goyal et al., 2022): The approach of using conjugate labels as pseudo-labels crucially relies on the assumption that the original loss function is of the following form:

ℓ(w; (y, x)) := f(h_w(x)) − y^⊤ h_w(x),   (2)

where f(·): R^K → R is a scalar-valued function, and y ∈ R^K is the label of x, which could be a one-hot encoding vector in multi-class classification. Since the true label y of a sample x is not
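As a concrete sketch of the two labeling schemes, the snippet below computes a hard label via the arg-max rule and soft labels obtained as ∇f(h_w(x)) for two standard choices of f in the loss form (2). Identifying the conjugate label with ∇f evaluated at the model's output is our reading of the framework of Goyal et al. (2022), stated here as an assumption rather than a quotation; the elaboration in the text continues beyond this excerpt.

```python
import numpy as np

def softmax(h):
    # Cross-entropy corresponds to f(h) = log Σ_k exp(h_k), whose gradient is softmax(h).
    e = np.exp(h - h.max())  # shift by max for numerical stability
    return e / e.sum()

def hard_label(h):
    # One-hot vector on the arg-max coordinate of the score vector h.
    y = np.zeros_like(h)
    y[np.argmax(h)] = 1.0
    return y

def conjugate_label_square(h):
    # Square loss corresponds to f(h) = ||h||²/2, so ∇f(h) = h: the model's own output.
    return h

h = np.array([2.0, 0.5, -1.0])  # toy scores for K = 3 classes (hypothetical values)
y_hard = hard_label(h)          # one-hot label on the top-scoring class
y_conj = softmax(h)             # soft pseudo-label for the cross-entropy instance of (2)
```

Note that the hard label keeps only the arg-max information, while the ∇f-based labels retain the model's confidence, which is one intuition for why the two schemes can behave very differently under GD.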





