TEST-TIME ADAPTATION VIA SELF-TRAINING WITH NEAREST NEIGHBOR INFORMATION

Abstract

Test-time adaptation (TTA) aims to adapt a trained classifier using online unlabeled test data only, without any information related to the training procedure. Most existing TTA methods adapt the trained classifier using the classifier's predictions on the test data as pseudo labels. However, under test-time domain shift, the accuracy of these pseudo labels cannot be guaranteed, and thus TTA methods often suffer performance degradation after adaptation. To overcome this limitation, we propose a novel test-time adaptation method, called Test-time Adaptation via Self-Training with nearest neighbor information (TAST), which is composed of the following procedures: (1) adds trainable adaptation modules on top of the trained feature extractor; (2) newly defines a pseudo-label distribution for the test data by using the nearest neighbor information; (3) trains these modules only a few times during test time to match the nearest neighbor-based pseudo-label distribution and a prototype-based class distribution for the test data; and (4) predicts the label of each test sample using the average predicted class distribution from these modules. The pseudo-label generation is based on the basic intuition that a test sample and its nearest neighbors in the embedding space are likely to share the same label under domain shift. By utilizing multiple randomly initialized adaptation modules, TAST extracts information useful for classifying the test data under domain shift, based on the nearest neighbor information. TAST outperforms state-of-the-art TTA methods on two standard benchmarks: domain generalization (VLCS, PACS, OfficeHome, and TerraIncognita) and image corruption (CIFAR-10/100C).

1. INTRODUCTION

Deep neural networks often encounter significant performance degradation under domain shift (i.e., distribution shift). This phenomenon has been observed in various tasks including classification (Taori et al., 2020; Wang et al., 2021b), visual recognition (Saenko et al., 2010; Csurka, 2017), and reinforcement learning (Cobbe et al., 2019; Mendonca et al., 2020; Lee and Chung, 2021b). There are two broad classes of domain adaptation methods that attempt to solve this problem: supervised domain adaptation (SDA) (Tzeng et al., 2015; Motiian et al., 2017) and unsupervised domain adaptation (UDA) (Ganin and Lempitsky, 2015; Long et al., 2016; Sener et al., 2016). Both SDA and UDA methods aim to obtain domain-invariant representations by closely aligning the representations of training and test data in the embedding space. At test time, UDA methods require access to the training dataset, and SDA methods additionally require labeled data from the test domain. In practice, however, it is often difficult to access the training dataset or labeled test-domain data at test time, due to data security concerns or labeling cost. Test-time adaptation (TTA) (Iwasawa and Matsuo, 2021; Wang et al., 2021a) is a prominent approach to alleviating the problems caused by domain shift. TTA methods aim to adapt the trained model to the test domain without a labeled dataset in the test domain and without any information related to the training procedure (e.g., the training dataset or feature statistics of the training domain (Sun et al., 2020; Liu et al., 2021; Eastwood et al., 2022)). TTA methods have access to the online unlabeled test data only, whereas domain adaptation methods assume access to the whole (i.e., offline) test data. There are three popular categories of TTA methods: normalization-based methods (Schneider et al., 2020), entropy minimization methods (Liang et al., 2020; Wang et al., 2021a), and prototype-based methods (Iwasawa and Matsuo, 2021).
Normalization-based methods replace the batch normalization (BN) statistics of the trained model with BN statistics estimated on the test data, and do not update model parameters outside the BN layers. Entropy minimization methods fine-tune the trained feature extractor, i.e., the trained classifier without its last linear layer, by minimizing the prediction entropy on the test data. These methods force the classifier to make over-confident predictions for the test data, and thus risk degrading model calibration (Guo et al., 2017; Mukhoti et al., 2020), a measure of model interpretability and reliability. One form of entropy minimization is self-training (Rosenberg et al., 2005; Lee, 2013; Xie et al., 2020). Self-training methods use the classifier's predictions as pseudo labels for the test data and fine-tune the classifier to fit these pseudo labels. A limitation of these methods is that the fine-tuned classifier can overfit to inaccurate pseudo labels, resulting in confirmation bias (Arazo et al., 2020). This limitation can be harmful when the performance of the trained classifier is significantly degraded by the domain shift. On the other hand, Iwasawa and Matsuo (2021) proposed a prototype-based TTA method, named T3A, which modifies only the trained linear classifier (the last layer) by constructing pseudo-prototype representations of each class and performing prototype-based classification for the test data, where the prototypes are built from previous test data and the trained classifier's predictions for them. T3A does not update the trained feature extractor at test time. T3A is simple, but it brings only a marginal performance gain (Tables 1 and 3). In this work, we propose a new test-time adaptation method that is simple yet effective in mitigating the confirmation bias problem of self-training, by adding adaptation modules on top of the feature extractor that are trainable during test time.
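To make the entropy-minimization idea concrete, the plain-Python sketch below takes one gradient-descent step on the softmax prediction entropy of an unlabeled batch, using the analytic gradient dH/dz_j = -p_j (log p_j + H). Note this is only an illustration under a simplifying assumption: the update is applied directly to the logits, whereas methods like Tent instead update normalization parameters of the network; the function names here are hypothetical, not from any library.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def entropy(probs):
    # Shannon entropy H(p) = -sum_i p_i log p_i
    return -sum(p * math.log(p) for p in probs if p > 0)

def entropy_grad(logits):
    # Gradient of the softmax entropy w.r.t. the logits:
    # dH/dz_j = -p_j * (log p_j + H)
    p = softmax(logits)
    h = entropy(p)
    return [-pj * (math.log(pj) + h) for pj in p]

def entropy_min_step(logits_batch, lr=0.5):
    # One descent step per example: entropy drops, i.e. the predictions
    # become more confident on the unlabeled test batch.
    return [[z - lr * g for z, g in zip(zs, entropy_grad(zs))]
            for zs in logits_batch]

batch = [[1.0, 0.8, 0.2], [2.0, 0.0, 0.5]]
before = sum(entropy(softmax(zs)) for zs in batch)
after = sum(entropy(softmax(zs)) for zs in entropy_min_step(batch))
assert after < before  # predictions became more confident
```

The over-confidence risk discussed above is visible here: every descent step pushes the probabilities toward a one-hot vector regardless of whether the implied label is correct.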
We use a prototype-based classifier as in T3A, but in the embedding space of the adaptation modules trained with nearest neighbor information, rather than in the embedding space of the original feature extractor, to achieve higher performance gains than the original simple prototype-based classifier. Our method, named Test-time Adaptation via Self-Training with nearest neighbor information (TAST), is composed of the following procedures: (1) adds randomly initialized adaptation modules on top of the feature extractor at the beginning of test time (Figure 1); (2) generates a pseudo-label distribution for each test sample using the nearest neighbor information; (3) trains the adaptation modules only a few times during test time to match the nearest neighbor-based pseudo-label distribution and a prototype-based class distribution for the test data; and (4) predicts the label of each test sample using the average predicted class distribution from the adaptation modules. Specifically, in (1), we add the trainable adaptation modules to obtain new feature embeddings that are useful for classification in the test domain. In (2), TAST assigns the mean of the labels of the nearby examples in the embedding space as the pseudo-label distribution for the test sample, based on the idea that a test sample and its nearest neighbors are likely to share the same label. In (3), TAST trains the adaptation modules to output the pseudo-label distribution when the test data is fed into them (Figure 1, right). And in (4), we average the predicted class distributions from the adaptation modules to predict the label of the test data (Figure 1, left). We investigate the effectiveness of TAST on two standard benchmarks, domain generalization and image corruption. We demonstrate that TAST outperforms current state-of-the-art test-time adaptation methods such as Tent (Wang et al., 2021a), T3A, and TTT++ (Liu et al., 2021) on the two benchmarks.
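Steps (2)-(4) above can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the helper names (`nn_pseudo_label`, `prototype_distribution`, `ensemble_predict`), the cosine-similarity metric, and the temperature value are illustrative choices, and in TAST the support set and prototypes are maintained from previous test data in the adaptation modules' embedding space rather than given up front.

```python
import math

def cosine(u, v):
    # Cosine similarity between two embedding vectors.
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def nn_pseudo_label(test_emb, support_embs, support_onehots, k=2):
    # Step (2): pseudo-label distribution = mean one-hot label of the
    # k nearest support embeddings under cosine similarity.
    sims = [cosine(test_emb, s) for s in support_embs]
    top = sorted(range(len(sims)), key=lambda i: -sims[i])[:k]
    n_cls = len(support_onehots[0])
    return [sum(support_onehots[i][c] for i in top) / k for c in range(n_cls)]

def prototype_distribution(test_emb, support_embs, labels, n_cls, temp=0.1):
    # Step (3) target: softmax over similarities to per-class prototypes
    # (class centroids of the support embeddings).
    logits = []
    for c in range(n_cls):
        members = [support_embs[i] for i, y in enumerate(labels) if y == c]
        proto = [sum(col) / len(members) for col in zip(*members)]
        logits.append(cosine(test_emb, proto) / temp)
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def ensemble_predict(module_distributions):
    # Step (4): average the class distributions produced by the
    # adaptation modules and predict the argmax class.
    n = len(module_distributions)
    n_cls = len(module_distributions[0])
    avg = [sum(d[c] for d in module_distributions) / n for c in range(n_cls)]
    return avg.index(max(avg)), avg
```

In step (3), the adaptation modules would be trained so that `prototype_distribution` for a test sample matches `nn_pseudo_label` for the same sample; the gradient-based update itself is omitted here for brevity.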
For example, TAST surpasses the current state-of-the-art algorithm by 1.01% on average with a ResNet-18 trained by Empirical Risk Minimization (ERM) on the domain generalization benchmarks. Extensive ablation studies show that both the nearest neighbor information and the use of adaptation modules contribute to the performance increase. Moreover, we experimentally found that the adaptation modules adapt the feature extractor outputs effectively, even though they are randomly initialized at the beginning of test time and trained with only a few gradient steps per test batch.

2. PRELIMINARIES

Test-time domain shift Consider a labeled dataset $D_{\text{train}} = \{(x_i, y_i)\}_{i=1}^{n_{\text{train}}}$ drawn from a distribution $P_{\text{train}}$, where $x \in \mathbb{R}^d$ and $y \in \mathcal{Y} := \{1, 2, \ldots, K\}$ for a $K$-class classification problem.

Code availability

Our code is available at https://github.com/mingukjang/