TEST-TIME ADAPTATION VIA SELF-TRAINING WITH NEAREST NEIGHBOR INFORMATION

Abstract

Test-time adaptation (TTA) aims to adapt a trained classifier using only online unlabeled test data, without any information related to the training procedure. Most existing TTA methods adapt the trained classifier using its predictions on the test data as pseudo labels. However, under test-time domain shift, the accuracy of the pseudo labels cannot be guaranteed, so TTA methods often suffer performance degradation in the adapted classifier. To overcome this limitation, we propose a novel test-time adaptation method, called Test-time Adaptation via Self-Training with nearest neighbor information (TAST), which is composed of the following procedures: (1) add trainable adaptation modules on top of the trained feature extractor; (2) define a new pseudo-label distribution for the test data using nearest neighbor information; (3) train these modules for only a few steps during test time to match the nearest neighbor-based pseudo-label distribution and a prototype-based class distribution for the test data; and (4) predict the label of each test datum using the average predicted class distribution from these modules. The pseudo-label generation is based on the intuition that a test datum and its nearest neighbors in the embedding space are likely to share the same label under domain shift. By utilizing multiple randomly initialized adaptation modules, TAST extracts information useful for classifying the test data under domain shift. TAST showed better performance than state-of-the-art TTA methods on two standard benchmark tasks: domain generalization, namely VLCS, PACS, OfficeHome, and TerraIncognita, and image corruption, specifically CIFAR-10C/100C. Our code is available at https://github.com/mingukjang/TAST.

1. INTRODUCTION

Deep neural networks often suffer significant performance degradation under domain shift (i.e., distribution shift). This phenomenon has been observed in various tasks including classification (Taori et al., 2020; Wang et al., 2021b), visual recognition (Saenko et al., 2010; Csurka, 2017), and reinforcement learning (Cobbe et al., 2019; Mendonca et al., 2020; Lee and Chung, 2021b). There are two broad classes of domain adaptation methods that attempt to solve this problem: supervised domain adaptation (SDA) (Tzeng et al., 2015; Motiian et al., 2017) and unsupervised domain adaptation (UDA) (Ganin and Lempitsky, 2015; Long et al., 2016; Sener et al., 2016). Both SDA and UDA methods aim to obtain domain-invariant representations by closely aligning the representations of training and test data in the embedding space. At test time, UDA methods require the training dataset, and SDA methods additionally require labeled data from the test domain. In practice, however, it is often difficult to access the training dataset or labeled test-domain data during test time, due to data security or labeling cost. Test-time adaptation (TTA) (Iwasawa and Matsuo, 2021; Wang et al., 2021a) is a prominent approach to alleviate the problems caused by domain shift. TTA methods aim to adapt the trained model to the test domain without a labeled dataset in the test domain or any information related to the training procedure (e.g., the training dataset or feature statistics of the training domain (Sun et al., 2020; Liu et al., 2021; Eastwood et al., 2022)). TTA methods have access to the online unlabeled test data only, whereas domain adaptation methods assume access to the whole (i.e., offline) test data. There are three popular categories of TTA methods: normalization-based methods (Schneider et al., 2020), entropy minimization methods (Liang et al., 2020; Wang et al., 2021a), and prototype-based methods (Iwasawa and Matsuo, 2021).
Normalization-based methods replace the batch normalization (BN) statistics of the trained model with BN statistics estimated on the test data, and do not update model parameters outside the BN layers. Entropy minimization methods fine-tune the trained feature extractor, i.e., the trained classifier without its last linear layer, by minimizing the prediction entropy on test data. These methods force the classifier to make over-confident predictions for the test data, and thus risk degrading model calibration (Guo et al., 2017; Mukhoti et al., 2020), a measure of model interpretability and reliability. One form of entropy minimization is self-training (Rosenberg et al., 2005; Lee, 2013; Xie et al., 2020). Self-training methods use the classifier's predictions as pseudo labels for the test data and fine-tune the classifier to fit those pseudo labels. Their limitation is that the fine-tuned classifier can overfit to inaccurate pseudo labels, resulting in confirmation bias (Arazo et al., 2020). This limitation can be harmful when the performance of the trained classifier is significantly degraded by domain shift. On the other hand, Iwasawa and Matsuo (2021) proposed a prototype-based TTA method, named T3A, that modifies only the trained linear classifier (the last layer) via pseudo-prototype representations of each class and prototype-based classification of test data, where the prototypes are constructed from previous test data and the trained classifier's predictions for them. T3A does not update the trained feature extractor at test time. T3A is simple, but it brings only a marginal performance gain (Tables 1 and 3). In this work, we propose a new test-time adaptation method that is simple yet effective in mitigating the confirmation bias problem of self-training, by adding adaptation modules on top of the feature extractor that are trainable during test time.
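As a concrete illustration of the normalization-based baseline above, the following minimal NumPy sketch replaces training-time BN statistics with statistics computed on the current test batch. All names and sizes here are illustrative, not taken from any released implementation:

```python
import numpy as np

def bn_forward(x, mean, var, gamma, beta, eps=1e-5):
    """Batch normalization with explicitly supplied statistics."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
# A test batch from a shifted domain (mean 3, std 2 instead of 0, 1).
test_batch = rng.normal(loc=3.0, scale=2.0, size=(128, 8))
gamma, beta = np.ones(8), np.zeros(8)

# Training-time running statistics (estimated on the source domain)
# would leave the activations shifted; using the test batch's own
# statistics re-centers and re-scales them.
out = bn_forward(test_batch, test_batch.mean(axis=0),
                 test_batch.var(axis=0), gamma, beta)
assert np.allclose(out.mean(axis=0), 0.0, atol=1e-7)
```

Only the statistics change here; the affine BN parameters gamma and beta are what entropy minimization methods such as Tent additionally fine-tune.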
We use a prototype-based classifier as in T3A, but in the embedding space of adaptation modules trained with nearest neighbor information, rather than in the embedding space of the original feature extractor, to achieve higher performance gains than the original simple prototype-based classifier. Our method, named Test-time Adaptation via Self-Training with nearest neighbor information (TAST), is composed of the following procedures: (1) add randomly initialized adaptation modules on top of the feature extractor at the beginning of test time (Figure 1); (2) generate a pseudo-label distribution for each test datum using nearest neighbor information; (3) train the adaptation modules for only a few steps during test time to match the nearest neighbor-based pseudo-label distribution and a prototype-based class distribution for the test data; and (4) predict the label of test data using the average predicted class distribution from the adaptation modules. Specifically, in (1), we add the trainable adaptation modules to obtain new feature embeddings that are useful for classification in the test domain. In (2), TAST assigns the mean of the labels of nearby examples in the embedding space as the pseudo-label distribution for the test data, based on the idea that a test datum and its nearest neighbors are likely to have the same label. In (3), TAST trains the adaptation modules to output the pseudo-label distribution when the test data is fed in (Figure 1, right). And in (4), we average the predicted class distributions from the adaptation modules for the prediction of test data (Figure 1, left). We investigate the effectiveness of TAST on two standard benchmarks, domain generalization and image corruption, and demonstrate that TAST outperforms current state-of-the-art test-time adaptation methods such as Tent (Wang et al., 2021a), T3A, and TTT++ (Liu et al., 2021) on both benchmarks.
For example, TAST surpasses the current state-of-the-art algorithm by 1.01% on average with ResNet-18 trained by Empirical Risk Minimization (ERM) on the domain generalization benchmarks. Extensive ablation studies show that both the nearest neighbor information and the adaptation module utilization contribute to the performance increase. Moreover, we experimentally found that the adaptation modules adapt the feature extractor outputs effectively, even though they are randomly initialized at the beginning of test time and trained with only a few gradient steps per test batch.

[Figure 1 caption fragment: "... as the pseudo label of x. We train the adaptation modules to predict the pseudo labels when the test data is fed in. Notice that the feature extractor f_θ is frozen during test time."]

2. PRELIMINARIES

A number of classifiers have been proposed that easily classify unseen test data under the i.i.d. assumption that the unseen test data D_test is drawn from the same distribution as the training data, i.e., P_train = P_test. We assume the classifier is a deep neural network composed of two parts: a feature extractor f_θ: R^d → R^{d_z} and a linear classifier g_w: R^{d_z} → Y, where θ and w are the neural network parameters. ERM optimizes θ and w to obtain a good classifier for future samples in D_test by minimizing the objective function L(θ, w) = E_{(x,y)∈D_train}[l(g_w(f_θ(x)), y)], where l is a loss function such as the cross-entropy loss. However, under test-time domain shift (i.e., distribution shift), the i.i.d. assumption between the training and test distributions does not hold, i.e., P_train ≠ P_test, and the trained classifiers often show poor classification performance on the test data.

Prototype-based classification in test-time adaptation Prototype-based classification refers to a method that obtains prototype representations, one per class in the embedding space, and then predicts the label of an input as the class of the nearest prototype. Since labeled data is not available in the TTA setting, T3A (Iwasawa and Matsuo, 2021) utilizes a support set composed of previous test data together with the trained classifier's predictions for them. T3A does not modify the parameters of the classifier. Since the embedding space of the feature extractor is unchanged during test time, T3A constructs the support set from the feature representations of the test data instead of the data itself. Specifically, a support set S_t = {S_t^1, S_t^2, ..., S_t^K} is built from the test samples seen up to time t. The support set is initialized with the normalized rows of the last linear classifier, i.e., S_0^k = {w_k / ∥w_k∥}, where w_k is the part of w related to the k-th class, for k = 1, 2, ..., K.
At time t, the support set is updated as

S_t^k = S_{t-1}^k ∪ {f_θ(x_t)/∥f_θ(x_t)∥} if argmax_c p^c = k, and S_t^k = S_{t-1}^k otherwise,   (1)

where p^k represents the likelihood that the classifier assigns x_t to the k-th class. Using the support set S_t^k, one can obtain the class prototype for class k by taking the centroid of the representations in the support set. Formally, the prototype µ_k for class k is computed as

µ_k = (1/|S_t^k|) Σ_{z∈S_t^k} z, for k = 1, 2, ..., K.

Then, the prediction for an input x_t is made by comparing the distances between the embedding of x_t and the prototypes, i.e., ŷ = argmin_c d(f_θ(x_t), µ_c), with a predefined metric d such as the Euclidean distance or cosine similarity. To keep the support set reliable, T3A filters out high-entropy samples:

S_t^k ← {z | z ∈ S_t^k, H(σ(g_w(z))) ≤ α_k},

where α_k is the M-th largest prediction entropy of the samples in S_t^k, H is the Shannon entropy (Lin, 1991), and σ is the softmax function. T3A modifies only the support set configuration and does not update the trained model parameters at test time. Thus, T3A cannot effectively mitigate the classification performance degradation caused by test-time domain shift. To address this issue, we extract information useful for classifying the test data by utilizing multiple randomly initialized adaptation modules trained with nearest neighbor-based pseudo labels.

Algorithm 1: TAST (per test batch B)
  Update the support set S with eq. (1) and the entropy-based filtering
  Retrieve the nearest neighbors N(x; S) for all x ∈ B with eq. (2)
  for t = 1 : T do
    for i = 1 : N_e do
      for x ∈ B do
        Obtain the nearest neighbor-based pseudo label p̃_i^TAST(·|x) of x with eq. (4)
        Compute the prototype-based class distribution p_i^proto(·|x) of x with eq. (6)
      end for
      φ_i ← φ_i − α ∇_{φ_i} (1/|B|) Σ_{x∈B} CE(p̃_i^TAST(·|x), p_i^proto(·|x))
    end for
  end for
  Compute the predictions for all x ∈ B with eq. (8)
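The support set update and prototype-based prediction described above can be sketched in a few lines of NumPy. This is an illustrative toy, not the authors' implementation; the entropy-based filtering is omitted and all names are hypothetical:

```python
import numpy as np

def t3a_step(z, support):
    """One T3A-style step for an embedded test sample z (shape (d_z,)).

    `support` maps class k -> list of stored (normalized) embeddings.
    The sample is appended to the class predicted by the prototype
    classifier, and that prediction is returned.
    """
    z = z / np.linalg.norm(z)
    # Class prototypes: centroids of the per-class support embeddings.
    protos = {k: np.mean(S, axis=0) for k, S in support.items()}
    # Predict the class of the nearest prototype (Euclidean distance).
    pred = min(protos, key=lambda k: np.linalg.norm(z - protos[k]))
    support[pred].append(z)
    return pred

# Toy usage: a 2-class problem in a 2-d embedding space. The support
# set is seeded with the normalized rows of the last linear layer.
w = np.array([[1.0, 0.0], [0.0, 1.0]])          # last-layer weights (K x d_z)
support = {k: [w[k] / np.linalg.norm(w[k])] for k in range(2)}
print(t3a_step(np.array([0.9, 0.1]), support))  # closest to class-0 prototype
```

Note that the stored classifier weights act as initial prototypes, so the method degrades gracefully to the linear classifier when the support set is nearly empty.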

3. METHODOLOGY

In this section, we describe two main components of our method TAST: adaptation module utilization (Section 3.1) and pseudo-label generation considering nearest neighbor information (Section 3.2).

3.1. ADAPTATION MODULE

We first discuss which parts of the trained classifier to fine-tune before explaining our test-time adaptation method. One choice is to fine-tune all network parameters of the classifier during test time, but this approach can be unstable and inefficient (Wang et al., 2021a; Kumar et al., 2022). Another choice is to fine-tune only the parameters of the batch normalization (BN) layers, as in Wang et al. (2021a). Although this achieves effective test-time adaptation, it can be used only if there are BN layers in the trained classifier. The remaining choice is to train a new classifier added on top of the frozen feature extractor during test time, as in Lee and Chung (2021a). We construct the new classifier by adding a randomly initialized adaptation module, as illustrated in Figure 1. During test time, we train the adaptation module and predict the label of the test data using the prototype-based class distributions from the adaptation module. Since random initialization of a single adaptation module may degrade the performance of the trained classifier, we adopt an ensemble scheme (Wen et al., 2020; YM. et al., 2020; Mesbah et al., 2021) to alleviate the issues caused by random initialization and obtain more robust and accurate predictions: we train the adaptation modules independently and predict the label of the test data using their average predicted class distribution.
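A minimal sketch of this ensemble scheme: several randomly initialized adaptation heads sit on a frozen feature extractor, and their predicted class distributions are averaged. The plain linear heads and the tanh "feature extractor" here are stand-ins for illustration only (TAST itself uses BatchEnsemble modules, described in Appendix A):

```python
import numpy as np

rng = np.random.default_rng(0)
d_z, n_classes, n_modules = 8, 3, 4      # illustrative sizes

# Frozen stand-in for the trained feature extractor (never updated here)
# and randomly initialized linear adaptation heads.
feature_extractor = lambda x: np.tanh(x)
modules = [rng.normal(0.0, np.sqrt(2.0 / d_z), size=(d_z, n_classes))
           for _ in range(n_modules)]    # Kaiming-style initialization

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def ensemble_predict(x):
    """Average the class distributions predicted by the modules."""
    z = feature_extractor(x)
    return np.mean([softmax(z @ W) for W in modules], axis=0)

p = ensemble_predict(rng.normal(size=d_z))
assert np.isclose(p.sum(), 1.0)          # averaged output is a distribution
```

Averaging distributions rather than hard votes keeps the prediction differentiable and smooths out the variance introduced by any single random initialization.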

3.2. SELF-TRAINING WITH NEAREST NEIGHBOR INFORMATION

TAST generates pseudo-label distributions for unlabeled test data using nearest neighbor information and fine-tunes the adaptation modules with those distributions. The whole adaptation procedure of TAST is described in Algorithm 1. We first update the support set S and filter out the unconfident examples from the support set as in Iwasawa and Matsuo (2021). Then, we find the N_s nearby support examples of a test datum x in the embedding space of f_θ. We denote by N(x; S) the set of nearby support examples of x:

N(x; S) := {z ∈ S | d(f_θ(x), z) ≤ β_x},   (2)

where β_x is the distance between x and the N_s-th nearest neighbor of x from S in the embedding space of f_θ. Each adaptation module is trained individually during test time. For the i-th adaptation module h_{φ_i}, we compute the prototype representations µ_{i,1}, µ_{i,2}, ..., µ_{i,K} in the embedding space of h_{φ_i} ∘ f_θ with a support set S = {S^1, S^2, ..., S^K}, i.e., µ_{i,k} = (1/|S^k|) Σ_{z∈S^k} h_{φ_i}(z), for k = 1, 2, ..., K. With the prototypes, we compute the prototype-based predicted class distribution of the nearby support examples in the embedding space of h_{φ_i} ∘ f_θ, i.e., for z ∈ N(x; S), the likelihood that the prototype-based classifier assigns z to the k-th class is computed as

p_i^proto(k|z) := exp(−d(h_{φ_i}(z), µ_{i,k})/τ) / Σ_c exp(−d(h_{φ_i}(z), µ_{i,c})/τ),   (3)

where τ is the softmax temperature. With the nearest neighbor information, TAST generates a pseudo-label distribution p̃_i^TAST of x by aggregating the prototype-based predicted class distributions of the nearby support examples in N(x; S) as

p̃_i^TAST(k|x) := (1/N_s) Σ_{z∈N(x;S)} 1[argmax_c p_i^proto(c|z) = k], for k = 1, 2, ..., K.   (4)

Specifically, we use one-hot class distributions for pseudo-label generation, as in Lee (2013); Sohn et al. (2020).
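The nearest neighbor-based pseudo-label distribution p̃^TAST described above can be sketched as follows, where `support_pred` holds the hard prototype-based predictions of the stored support examples. The names and toy data are illustrative:

```python
import numpy as np

def nn_pseudo_label(z, support_z, support_pred, n_classes, n_s):
    """Nearest neighbor-based pseudo-label distribution for embedding z.

    support_z:    (N, d) embeddings of the stored support examples
    support_pred: (N,) hard prototype-based predictions for them
    Returns the fraction of the n_s nearest neighbors voting per class.
    """
    d = np.linalg.norm(support_z - z, axis=1)    # distances to support set
    nn = np.argsort(d)[:n_s]                     # indices of n_s nearest
    votes = np.bincount(support_pred[nn], minlength=n_classes)
    return votes / n_s                           # one-hot votes, averaged

# Two tight clusters: support examples near the origin are class 0,
# those near (5, 5) are class 1.
support_z = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
support_pred = np.array([0, 0, 1, 1])
print(nn_pseudo_label(np.array([0.05, 0.0]), support_z, support_pred, 2, 2))
# -> [1. 0.]: both nearest neighbors vote for class 0
```

With N_s larger than one cluster, the votes split and the pseudo label becomes a soft distribution over classes, which is exactly what the cross-entropy target in eq. (5) consumes.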
Then, we fine-tune the adaptation modules by minimizing the cross-entropy loss between the predicted class distribution of the test example and the nearest neighbor-based pseudo-label distribution:

L_TAST(φ_i) = (1/|D_test|) Σ_{x∈D_test} CE(p̃_i^TAST(·|x), p_i^proto(·|x)),   (5)

p_i^proto(k|x) := exp(−d(h_{φ_i}(f_θ(x)), µ_{i,k})/τ) / Σ_c exp(−d(h_{φ_i}(f_θ(x)), µ_{i,c})/τ), k = 1, 2, ..., K,   (6)

where CE denotes the standard cross-entropy loss. We iterate the pseudo-labeling and fine-tuning processes for T steps per batch. We note that our method does not propagate gradients into the pseudo labels, as in Laine and Aila (2017); Berthelot et al. (2019). Finally, we predict the label of x using the average predicted class distribution p_i^TAST from the adaptation modules, i.e.,

p_i^TAST(k|x) := (1/N_s) Σ_{z∈N(x;S)} p_i^proto(k|z),   (7)

ŷ_TAST = argmax_c p_TAST(c|x) = argmax_c (1/N_e) Σ_{i=1}^{N_e} p_i^TAST(c|x).   (8)

Additionally, we consider a variant of TAST, named TAST-BN, that fine-tunes the BN layers instead of the adaptation modules. Its support set stores the test data itself instead of feature representations, since the embedding space of the feature extractor steadily changes during test time. The pseudocode for TAST-BN is presented in Appendix B.
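To make the fine-tuning step concrete, here is a self-contained sketch of one gradient update of a single linear adaptation module against a fixed nearest neighbor-based pseudo label. The finite-difference gradient stands in for autograd, the prototypes are held fixed within the step for simplicity, and all sizes, the temperature, and the learning rate are illustrative:

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def ce_loss(phi, x, protos, target, tau=1.0):
    """CE(target, p_proto(.|x)) for a linear adaptation module phi:
    p_proto is a softmax over negative distances to the prototypes."""
    h = x @ phi                                       # projected embedding
    logits = -np.linalg.norm(protos - h, axis=1) / tau
    return -np.sum(target * np.log(softmax(logits) + 1e-12))

def sgd_step(phi, x, protos, target, lr=0.02, eps=1e-5):
    """One descent step on phi via central finite differences,
    a stand-in for autograd in this self-contained sketch."""
    g = np.zeros_like(phi)
    for idx in np.ndindex(*phi.shape):
        hi, lo = phi.copy(), phi.copy()
        hi[idx] += eps
        lo[idx] -= eps
        g[idx] = (ce_loss(hi, x, protos, target)
                  - ce_loss(lo, x, protos, target)) / (2 * eps)
    return phi - lr * g

rng = np.random.default_rng(0)
phi = 0.5 * rng.normal(size=(4, 2))           # one linear adaptation module
x = rng.normal(size=4)                        # frozen feature-extractor output
protos = np.array([[1.0, 0.0], [0.0, 1.0]])   # projected class prototypes
target = np.array([1.0, 0.0])                 # NN-based pseudo label (one-hot)

before = ce_loss(phi, x, protos, target)
phi = sgd_step(phi, x, protos, target)
assert ce_loss(phi, x, protos, target) < before  # the step reduces the loss
```

Because the pseudo label is treated as a constant, no gradient flows into it, mirroring the stop-gradient behavior described above.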

4. EXPERIMENTS

In this section, we show the effectiveness of our method compared to state-of-the-art test-time adaptation methods on two standard benchmarks, domain generalization and image corruption. We compare TAST with the following baseline methods: (1) Pseudo Labeling (PL) (Lee, 2013) fine-tunes the trained classifier using its predictions on the test data as pseudo labels; (2) PLClf fine-tunes only the last linear classifier with the pseudo labels; (3) Tent (Wang et al., 2021a) fine-tunes only the parameters of the BN layers to minimize the prediction entropy of test data; (4) TentAdapter is a modified version of Tent that adds a BN layer between the feature extractor and the last linear classifier, and fine-tunes only the added BN layer; (5) TentClf is a modified version of Tent that fine-tunes only the last linear classifier instead of the BN layers; (6) SHOTIM (Liang et al., 2020) updates the feature extractor to maximize the mutual information between an input and its prediction; (7) SHOT adds a pseudo-label loss to SHOTIM; (8) T3A predicts the label of the test data by comparing distances between the test data and generated pseudo-prototypes. Originally, SHOT is a source-free domain adaptation method, which focuses on the offline setting, but we compare our method with the online version of SHOT for a fair comparison.

4.1. DOMAIN GENERALIZATION

The domain generalization benchmarks are designed to evaluate the generalization ability of trained classifiers to unseen domains. The evaluation follows a leave-one-domain-out procedure, which uses one domain as the test domain and the remaining domains as training domains. We use the publicly released code of T3A for the domain generalization benchmarks.

4.1.1. EXPERIMENTAL SETUP

Training setup We test TAST on four domain generalization benchmarks, specifically VLCS (Fang et al., 2013), PACS (Li et al., 2017), OfficeHome (Venkateswara et al., 2017), and TerraIncognita (Beery et al., 2018). For a fair comparison, we follow the training setup, including dataset splits and the hyperparameter selection method, used in T3A. We use residual networks (He et al., 2016) containing batch normalization layers, with 18 and 50 layers (hereinafter ResNet-18 and ResNet-50, respectively), which are widely used for classification tasks. We train the networks with various learning algorithms such as ERM and CORAL (Sun and Saenko, 2016); details about the learning algorithms are given in Appendix A. The backbone networks are trained with the default hyperparameters introduced in Gulrajani and Lopez-Paz (2021). We use BatchEnsemble (Wen et al., 2020), an efficient ensemble method that reduces computational cost by weight-sharing, for the adaptation modules of TAST. The output dimension of each adaptation module is set to a quarter of the output dimension of the feature extractor, e.g., 128 for ResNet-18. We use Kaiming initialization (He et al., 2015) for the adaptation modules at the beginning of test time. We run experiments using four different random seeds. More details on the benchmarks and the training setups, as well as a discussion of computational complexity such as a runtime comparison, can be found in Appendix A.

Hyperparameters For a fair comparison, the baseline methods use the same hyperparameters as in Iwasawa and Matsuo (2021). TAST uses the same set of possible values for each hyperparameter as the baseline methods. TAST involves four hyperparameters: the number of gradient steps per adaptation T, the number of support examples per class M, the number of nearby support examples N_s, and the number of adaptation modules N_e.
We define a finite set of possible values for each hyperparameter: N_s ∈ {1, 2, 4, 8}, T ∈ {1, 3}, and M ∈ {1, 5, 20, 50, 100, -1}, where -1 means storing all samples without filtering. N_e is set to 20. We use the Adam optimizer with a learning rate of 0.001. More details on the hyperparameters can be found in Appendix A. Moreover, refer to Appendix C for the sensitivity analysis on hyperparameters including the test batch sizes.

4.1.2. EXPERIMENTAL RESULTS

In Table 1, we summarize the experimental results of test-time adaptation methods using classifiers trained by ERM. Our method consistently improves the performance of the trained classifiers, by 2.17% for ResNet-18 and 1.21% for ResNet-50 on average. TAST also outperforms the baseline methods, including the state-of-the-art test-time adaptation method T3A: compared to T3A, TAST performs better by 1.01% for ResNet-18 and 0.69% for ResNet-50 on average. In particular, our method significantly improves the performance of the trained classifiers on TerraIncognita, a challenging benchmark on which the trained classifier shows the lowest prediction accuracy. We observe that the baseline methods that fine-tune the feature extractors perform worse than the classifiers without adaptation, whereas TAST-BN improves the performance of the trained classifiers. Refer to Appendix C for the experimental results of test-time adaptation methods using classifiers trained by different learning algorithms such as CORAL (Sun and Saenko, 2016) and MMD (Li et al., 2018).

Effect of nearest neighbor information

To understand the effect of nearest neighbor information, we compare Tent and TAST-BN, both of which fine-tune the BN layers. To adjust the BN layers, Tent uses an entropy minimization loss, whereas TAST-BN uses the pseudo-label loss based on nearest neighbor information. As shown in Table 1, the performance of TAST-BN is better than that of Tent by 3.3% for ResNet-18 and 2.89% for ResNet-50, respectively. In addition, we consider an ablated variant of TAST, named TAST-N, that removes the adaptation modules from TAST. TAST-N is optimization-free and has the same support set configuration as T3A. T3A uses the prototype-based prediction of the test data itself, whereas TAST-N uses the aggregated predicted class distribution of the nearby support examples. As shown in Table 2, the prediction using nearest neighbor information leads to a performance gain of 0.43% on average.

Effect of adaptation modules TAST adds randomly initialized adaptation modules on top of the trained feature extractor as illustrated in Figure 1 and trains the adaptation modules during test time. For each test batch, we update the adaptation modules T times using pseudo-label distributions based on nearest neighbor information. We set T to 1 or 3 throughout all experiments. To verify that a few update steps are sufficient to train the adaptation modules, we conduct experiments with different T ∈ {0, 1, 2, 4, 8}. We test on PACS using classifiers trained by ERM, with M set to -1 and N_s to one of {1, 2, 4, 8}. The results are summarized in Figure 2. We observe that the performance of the adapted classifier is better than that of the non-adapted classifier (i.e., T = 0) and is robust to changes in T. Hence, we conjecture that a sufficiently good adaptation module can be obtained with a few update steps, similar to Lee and Chung (2021a).
In addition, to investigate the effect of the adaptation modules, we test TAST with a varying number of adaptation modules, i.e., N_e ∈ {1, 5, 10, 20}. In Table 2, we find that a single adaptation module leads to worse performance than TAST-N. However, TAST with multiple adaptation modules improves over both TAST-N and T3A on average.

4.2. IMAGE CORRUPTION

The image corruption benchmark is designed to evaluate the robustness of a classifier to unseen corrupted samples when the classifier is trained on clean samples. We use the publicly released code of TTT++ (Liu et al., 2021) for the image corruption benchmark. For a fair comparison, we compare our method with the online version of TTT++, which fine-tunes the feature extractor using an instance discrimination task along with matching the feature statistics of training and test time.

4.2.1. EXPERIMENTAL SETUP

We test the robustness of TAST to image corruption on CIFAR-10/100 (Krizhevsky and Hinton, 2009), which consist of generic images from 10/100 classes, respectively. To make a corrupted test dataset, we apply 15 types of common image corruptions (e.g., Gaussian noise, shot noise) to the test dataset; we call the corrupted datasets CIFAR-10C/100C (Hendrycks and Dietterich, 2019). We use the highest level (i.e., level 5) of image corruption in this experiment.

4.2.2. EXPERIMENTAL RESULTS

The overall experimental results on CIFAR-10C/100C are summarized in Table 3. We note that the best TTA method for the image corruption benchmarks can differ from that for the domain generalization benchmarks, since the two benchmarks involve very different types of domain/distribution shift. From Tables 1 and 3, we observe that test-time adaptation algorithms using a frozen feature extractor, such as T3A and TAST, show poor performance on the image corruption benchmarks but better performance on the domain generalization benchmarks, compared to those using an adapted feature extractor, such as Tent and TAST-BN. Specifically, TAST-BN outperforms all the TTA methods and TTT++, achieving performance gains over Tent of 1.25% for CIFAR-10C and 4.56% for CIFAR-100C on average. Refer to Appendix E for detailed experimental results on the 15 types of image corruption.

Comparison with test-time training methods Test-time training methods modify the training procedure with a self-supervised task: Sun et al. (2020) use a rotation prediction task (Feng et al., 2019), which predicts the rotation angle of rotated images, and Liu et al. (2021) use an instance discrimination task (Chen et al., 2020). However, TTA methods, our focus in this paper, have no access to any information related to the training procedure. We empirically demonstrated that our method outperforms the existing test-time training methods on the image corruption benchmark even without knowledge of the self-supervised learning task.

5. RELATED WORKS

Source-free domain adaptation methods Source-Free Domain Adaptation (SFDA) methods (Liang et al., 2020; Ishii and Sugiyama, 2021; Yeh et al., 2021; Eastwood et al., 2022) adapt a trained classifier to the test domain without access to the training dataset. Unlike TTA methods, which observe the unlabeled test data online, SFDA methods typically assume the whole (i.e., offline) test dataset is available.

6. DISCUSSION

We proposed TAST, which effectively adapts trained classifiers during test time using nearest neighbor information. We demonstrated the efficiency and effectiveness of our method through experiments on domain generalization and image corruption benchmarks. To the best of our knowledge, our work is the first to utilize an ensemble scheme built at test time for test-time adaptation. We expect that adaptation using the ensemble scheme can be combined with other methods in source-free domain adaptation or test-time training. One limitation of TAST is its extension to large-scale benchmarks. TAST and TAST-BN require good prototypes in the embedding space for prediction and pseudo-labeling. To obtain good prototypes, TAST and TAST-BN construct and update the prototypes using the pseudo-labeled data encountered during test time. This prototype construction/update, however, can be ineffective for large-scale benchmarks, especially with many classes and small batch sizes. A detailed discussion of TAST/TAST-BN on large-scale benchmarks and a possible improvement of TAST-BN for such benchmarks are described in Appendix D.

A BENCHMARK AND IMPLEMENTATION DETAILS

A.1 DOMAIN GENERALIZATION BENCHMARKS

We test on four domain generalization benchmarks, specifically VLCS (Fang et al., 2013), PACS (Li et al., 2017), OfficeHome (Venkateswara et al., 2017), and TerraIncognita (Beery et al., 2018). VLCS is composed of photographic images from four different datasets (PASCAL VOC2007 (Everingham et al., 2010), LabelMe (Russell et al., 2008), Caltech101 (Fei-Fei et al., 2007), and SUN09).

A.2 IMPLEMENTATION DETAILS ON DOMAIN GENERALIZATION BENCHMARKS

All the hyperparameters for training and test-time adaptation are taken from T3A and DomainBed. We train the network with the Adam optimizer using the default hyperparameters introduced in DomainBed, e.g., a learning rate of 0.00005, a weight decay of 0, a dropout rate of 0, and a batch size of 32. In addition to the hyperparameters for test-time adaptation described in Section 4.1.1 of the manuscript, there is one more hyperparameter β for the baseline methods. The learning rate for test-time adaptation is obtained by multiplying the training-time learning rate by β. We set the confidence threshold for PL and PLClf to 0.9. The possible values for β are 0.1, 1.0, and 10.0. For TAST-BN, we restrict the size of the whole support set to 150 for effective memory usage and reduced runtime, since the test data and the support examples are fed into the classifier for every test batch.

A.3 IMPLEMENTATION DETAILS ON IMAGE CORRUPTION BENCHMARKS

We use the same hyperparameters introduced in TTT++. We train ResNet-50 for 1000 epochs using the classification and instance discrimination tasks jointly. The weight on the instance discrimination task for balancing the two tasks is set to 0.1. For the instance discrimination task, we use the same data augmentation schemes as TTT++, e.g., RandomResizeCrop, RandomHorizontalCrop, HorizontalFlip, ColorJitter, RandomGrayscale, and Normalization. We set the batch size for training the networks to 256. At test time, PL, SHOT, and TTT++ use the SGD optimizer with a learning rate of 0.001 and a momentum of 0.9, while Tent, TAST, and TAST-BN use the Adam optimizer with a learning rate of 0.001. We set the batch size during test time to 128 for effective memory usage. We run experiments using four different random seeds: 0, 1, 2, and 3. We set the confidence threshold for PL and PLClf to 0.9. For PL, we adjust only the BN layers in the trained model, as in Tent. For TAST-BN, we restrict the size of the whole support set to 200. However, on CIFAR-100C we could store only two support examples per class with a support set size of 200, so we do not restrict the support set size for TAST-BN on CIFAR-100C.

A.4 RUNTIME COMPARISON

We conduct our experiments on a TITAN XP. We report the average runtime to adapt classifiers with a ResNet-18 backbone in Table 4. We note that TAST, which updates the support set and the adaptation modules, requires only 1/3 to 1/4 of the running time of methods that update the entire feature extractor, e.g., SHOT or SHOTIM. On the other hand, TAST-BN, which updates the support set as well as the BN layers, requires more running time (about 2x) than SHOT or SHOTIM. The overhead is not significant, though, due to the online setting.

A.5 DETAILS ABOUT ADAPTATION MODULES

We use BatchEnsemble (BE) for the adaptation modules of our method. BE is a simple and efficient ensemble method that greatly reduces computational cost through weight-sharing. Each ensemble member of BE is composed of a shared weight and its own rank-one factors. Specifically, the weight matrix of the j-th ensemble member is W ∘ r_j s_j^T, where W is a shared weight and r_j s_j^T is the rank-one factor of the j-th ensemble member. Whereas existing deep ensemble (DE) methods share no weights, all BE members share W, so BE uses fewer parameters than DE. Moreover, unlike DE, the members of BE differ only in their rank-one factors, and thus they can be easily vectorized and trained simultaneously. Therefore, BE greatly reduces the computational cost. The adaptation module structure is used in many fields, such as self-supervised learning (where it is often called a "projection head"). While the existing methods mainly focus on training time, TAST focuses on test time. For example, SimCLR (Chen et al., 2020) adds a projection head on top of a feature extractor at the beginning of training and trains the feature extractor and the projection head with an instance discrimination loss; for downstream tasks after training, SimCLR uses the feature extractor outputs rather than the projection head outputs. In contrast, TAST adds adaptation modules at the beginning of test time and trains them with the nearest neighbor-based pseudo-label distribution. To predict the label of test data, we use the averaged predicted class distribution from the adaptation modules.
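The rank-one parameterization can be verified numerically: the vectorized forward pass y_j = ((x ∘ r_j) W) ∘ s_j matches multiplying by the materialized member weight W ∘ r_j s_j^T. A small NumPy sketch, with illustrative sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_members = 6, 3, 4        # illustrative sizes

W = rng.normal(size=(d_in, d_out))      # slow weight shared by all members
r = rng.normal(size=(n_members, d_in))  # per-member rank-one fast weights
s = rng.normal(size=(n_members, d_out))

def be_naive(x, j):
    """Member j with its full weight W * (r_j s_j^T) materialized."""
    return x @ (W * np.outer(r[j], s[j]))

def be_vectorized(x):
    """All members at once, never materializing any member weight:
    y_j = ((x * r_j) @ W) * s_j equals x @ (W * r_j s_j^T)."""
    return ((x * r) @ W) * s            # shape (n_members, d_out)

x = rng.normal(size=d_in)
assert np.allclose(be_vectorized(x)[2], be_naive(x, 2))
```

The vectorized form stores one shared matrix plus two small vectors per member, which is why the ensemble adds little memory or compute over a single module.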

B PSEUDOCODE FOR TAST-BN

We present the pseudocode for TAST-BN in Algorithm 2. TAST-BN fine-tunes the BN layers in the feature extractor instead of the adaptation modules. Since the embedding space of the feature extractor changes steadily, the support set stores the test data itself instead of feature representations. Formally, a support set S_t = {S_t^1, S_t^2, ..., S_t^K} is a set of test samples observed up to time t. The support set is initialized as an empty set. At time t, the support set is updated as

S_t^k = S_{t-1}^k ∪ {x_t}  if argmax_c p_c = k,  and  S_t^k = S_{t-1}^k  otherwise,   (9)

where p_k is the likelihood the classifier assigns x_t to class k. Using the support set, we retrieve the N_s nearest support examples of x in the embedding space of f_θ, i.e.,

N(x; S) := {z ∈ S | d(f_θ(x), f_θ(z)) ≤ β_x},   (10)

where β_x is the distance between x and the N_s-th nearest neighbor of x in S in the embedding space of f_θ. Then, we generate a pseudo-label distribution for the test data and fine-tune the BN layers to match the nearest neighbor-based pseudo-label distribution and the prototype-based class distribution for the test data, following the same procedure described in Section 3 of the manuscript.

Algorithm 2 TAST-BN
  Update the support set S with eq. (9)
  Retrieve the nearest neighbors N(x; S) for all x ∈ B with eq. (10)
  for t = 1 : T do
    Compute prototypes {μ_k}_{k=1}^K using the support set in the embedding space of f_θ
    for z ∈ N(x; S), x ∈ B do
      p_proto(k|z) ← exp(-d(f_θ(z), μ_k)/τ) / Σ_c exp(-d(f_θ(z), μ_c)/τ),  k = 1, 2, ..., K
    end for
    for x ∈ B do
      p_TAST(k|x) ← (1/N_s) Σ_{z ∈ N(x;S)} 1[argmax_c p_proto(c|z) = k],  k = 1, 2, ..., K
      p_proto(k|x) ← exp(-d(f_θ(x), μ_k)/τ) / Σ_c exp(-d(f_θ(x), μ_c)/τ),  k = 1, 2, ..., K
    end for
    θ ← θ - α ∇_θ (1/|B|) Σ_{x ∈ B} CE(p_TAST(·|x), p_proto(·|x))
  end for
  for x ∈ B do
    p_TAST(k|x) ← (1/N_s) Σ_{z ∈ N(x;S)} p_proto(k|z),  k = 1, 2, ..., K
    ŷ_x ← argmax_c p_TAST(c|x)
  end for
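The support-set update of eq. (9) and the nearest-neighbor retrieval of eq. (10) can be sketched in a few lines of NumPy. This is a toy illustration, not the released implementation; the array names and values are hypothetical:

```python
import numpy as np

def update_support_set(support_x, support_y, batch_x, probs):
    """Eq. (9): each test sample joins the bucket S^k of its pseudo label
    k = argmax_c p_c, where p is the classifier's predicted distribution.
    Buckets are represented implicitly by the pseudo-label array."""
    pseudo = probs.argmax(axis=1)
    return (np.concatenate([support_x, batch_x], axis=0),
            np.concatenate([support_y, pseudo], axis=0))

def nearest_neighbors(z_batch, z_support, n_s):
    """Eq. (10): indices of the N_s nearest support examples of each test
    embedding under cosine distance d = 1 - cosine similarity."""
    a = z_batch / np.linalg.norm(z_batch, axis=1, keepdims=True)
    b = z_support / np.linalg.norm(z_support, axis=1, keepdims=True)
    dist = 1.0 - a @ b.T                      # (batch, |S|)
    return np.argsort(dist, axis=1)[:, :n_s]  # (batch, N_s)

# Toy usage: two test samples, empty initial support set
support_x = np.empty((0, 2)); support_y = np.empty((0,), dtype=int)
batch_x = np.array([[1.0, 0.1], [0.1, 1.0]])
probs = np.array([[0.9, 0.1], [0.2, 0.8]])    # classifier outputs (hypothetical)
support_x, support_y = update_support_set(support_x, support_y, batch_x, probs)
nn_idx = nearest_neighbors(batch_x, support_x, n_s=1)
```

Here the embeddings stand in for f_θ outputs; in TAST-BN they are recomputed from the stored raw test data after every BN update, which is why the support set stores inputs rather than features.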

C ADDITIONAL EXPERIMENTS

C.1 CLASSIFIERS TRAINED BY DIFFERENT LEARNING ALGORITHMS

In Table 5, we show the results of test-time adaptation methods using classifiers trained by three different learning algorithms, namely CORAL, MMD, and Mixup. TAST consistently enhances the performance of the trained classifiers on the benchmarks, by 1.73%, 1.81%, and 2.30% on average for the classifiers trained by CORAL, MMD, and Mixup, respectively. The performance gain of TAST is smaller than in Table 1 of the manuscript, but TAST still surpasses T3A on most of the benchmarks: compared to T3A, TAST performs better by 0.21%, 0.14%, and 0.40% on average with the classifiers trained by CORAL, MMD, and Mixup, respectively. Refer to Appendix E for the experimental results of the other baseline methods.

C.2 FINE-TUNING BOTH ADAPTATION MODULES AND BN LAYERS SIMULTANEOUSLY

We consider a method, named TAST-both, that fine-tunes both the attached adaptation modules and the BN layers in the feature extractor simultaneously. Table 6 reports the experimental results using classifiers learned by ERM on the domain generalization benchmarks, with ResNet-18 as the backbone network. As shown in Table 6, TAST-both performs worse than both TAST-BN and TAST. We conjecture that the random initialization of the adaptation modules and the changes in feature representation caused by training the BN layers negatively affect the learning of the other component.

C.3 EXPERIMENTAL RESULTS USING DIFFERENT HYPERPARAMETERS ON CIFAR-10C

In Table 4 of the manuscript, we report the experimental results on CIFAR-10C when N_s and M are set to 1 and 100, respectively. In Table 7, we summarize the experimental results using different combinations of N_s and M on CIFAR-10C. Two observations stand out in Table 7: (1) T3A shows its best performance when M is set to 100; and (2) TAST and TAST-BN perform better with smaller N_s.

C.4 SENSITIVITY ANALYSIS ON HYPERPARAMETERS

We follow the hyperparameter selection method used in T3A. We split the dataset of the training domains into training and validation sets, and use the validation set to select the hyperparameters that maximize the validation accuracy of the adapted classifier. For the image corruption benchmark, on the other hand, we use manually determined hyperparameters as in Tent; we therefore summarize experimental results for other hyperparameter combinations in Tables 8-11.

Additionally, we investigate the sensitivity to two hyperparameters that are set manually throughout all experiments: the softmax temperature τ and the output dimension of the adaptation modules d_ϕ. We set τ to 0.1 and d_ϕ to d_z/4, where d_z is the output dimension of the feature extractor. In Tables 8-11, we report the average accuracy of the classifier adapted by TAST for different combinations of τ and d_ϕ. In these experiments, we use a ResNet-18 backbone trained by ERM on PACS, one of the domain generalization benchmarks. The results show that the performance of TAST is robust to changes in τ and d_ϕ. We believe the classification performance is not significantly affected by changes in τ because τ affects both the prototype-based predicted class distribution of the test data and the nearest neighbor-based pseudo-label distribution, and the adaptation modules are trained with the τ-dependent cross-entropy loss only a few times per test batch. Moreover, we observe similar classification performance regardless of the output dimension of the adaptation modules, similar to Chen et al. (2020).

We used the test batch sizes of T3A and Tent for the domain generalization and image corruption benchmarks, respectively, as described in Appendix A and Section 4 of the manuscript. We also summarize experimental results using different test batch sizes.
We conduct experiments using classifiers with ResNet-18 backbone networks learned by ERM on PACS. As shown in Table 12, Tent and PL show reduced performance with smaller test batch sizes, whereas T3A, TAST, and TAST-BN are robust to changes in the test batch size.
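The robustness to τ discussed above has a simple explanation that can be illustrated numerically: τ rescales all prototype-distance logits uniformly, so it changes the sharpness of the predicted distribution but never the predicted class. A small sketch with hypothetical distances:

```python
import numpy as np

def proto_softmax(dists, tau):
    """Prototype-based class distribution: softmax over -d(f(x), mu_k)/tau."""
    logits = -dists / tau
    e = np.exp(logits - logits.max())   # subtract max for numerical stability
    return e / e.sum()

dists = np.array([0.2, 0.5, 0.9])       # hypothetical distances to three prototypes
for tau in (0.05, 0.1, 1.0):
    p = proto_softmax(dists, tau)
    # tau rescales all logits by the same factor, so the argmax is invariant;
    # only the sharpness (entropy) of the distribution changes
    assert p.argmax() == 0
```

Since both p_proto and the nearest neighbor-based pseudo labels are derived from the same family of distributions, and only a few gradient steps are taken per batch, moderate changes to τ leave the adapted predictions largely unchanged.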

D TAST ON IMAGENET-C

ImageNet-C is an image corruption benchmark like CIFAR-10/100C, but it is a large-scale benchmark composed of larger images from many more classes. ImageNet-C is challenging for existing test-time adaptation/training methods, including TTT++. TAST and TAST-BN may also struggle with ImageNet-C, since they require prototypes that represent each class in the embedding space. Obtaining good prototypes requires a sufficient amount of data per class, but we have no access to any labeled data under the TTA setting. Pseudo-labeling alleviates this issue on CIFAR-10/100C, but not on ImageNet-C, due to the following concerns:

• The prototype updates of TAST and TAST-BN are based on the labels of the test data estimated by the classifier, not the ground-truth labels. Under test-time domain shift, classifier bias may occur, assigning most test data to only a subset of classes. As observed in Chen et al. (2022), such classifier bias often arises under covariate shift such as image corruption and style transfer. Then, even after a large number of batch updates that cover every ground-truth class with at least one sample, some prototypes may never have been updated because no test data has been classified into those classes. For example, in the experiments with Gaussian noise, it took 768 out of 782 batches until every prototype had been updated at least once.

• Since the number of classes (1,000) is much larger than the test batch size (64), only a few prototypes are updated per test batch while the rest remain unupdated, which can hurt prototype-based classification. Addressing this would require a batch size larger than 1,000, which is impractical due to the hardware cost.
When the number of classes (1,000) is much larger than the test batch size (64), obtaining good prototypes for TAST-BN can be difficult, especially at the early stage of test time, as explained above. To alleviate these concerns, we consider a variant of TAST-BN in which the prototypes are initialized with the weights of the last linear classifier, as in TAST, and kept fixed during test time. We call this variant TAST-BN (w/ fixed prototypes). In Table 13, we report the experimental results (test accuracy) on ImageNet-C with severity level 5 when (N_s, M, T) is set to (1, -1, 1). One can still update the prototypes over test time, but the performance gain from updating may not be as significant as before. Nonetheless, the results in Table 13 show that effective adaptation on ImageNet-C can be achieved by combining the prototype approach with the self-training (entropy minimization) of TAST-BN (w/ fixed prototypes).
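The fixed-prototype initialization can be sketched as follows: each row of the trained linear classifier's weight matrix serves as a frozen prototype, so no test data is needed to form class representatives. A minimal NumPy sketch; the sizes, random weights, and function names are illustrative stand-ins, not the released implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, d_z = 1000, 512                      # ImageNet-scale sizes (illustrative)
W_cls = rng.standard_normal((num_classes, d_z))   # trained linear classifier weight (stand-in)

# TAST-BN (w/ fixed prototypes): one prototype per class, taken from the
# classifier weight rows and kept frozen during test time.
prototypes = W_cls / np.linalg.norm(W_cls, axis=1, keepdims=True)

def proto_predict(z, tau=0.1):
    """Cosine-similarity prototype classification with temperature tau.
    Note: cos_sim/tau equals -(1 - cos_sim)/tau up to a per-row constant,
    which cancels inside the softmax."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    logits = (z @ prototypes.T) / tau
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

z = rng.standard_normal((4, d_z))   # feature extractor outputs (stand-in)
p = proto_predict(z)
assert p.shape == (4, num_classes)
```

Because every class has a prototype from the first test batch onward, this variant sidesteps the cold-start problem of prototype updates when the batch size (64) is far below the number of classes.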



We use the cosine similarity as the distance metric d for all experiments in this paper.
A detailed explanation of the adaptation modules is given in Section 4.1.1 and Appendix A.
We set τ manually to 0.1, inspired by Oreshkin et al. (2018), for all experiments; more results with different τ are summarized in Appendix C.
More results with different output dimensions are summarized in Appendix C.
https://github.com/matsuolab/T3A
https://github.com/vita-epfl/ttt-plus-plus



Test-time domain shift Consider a labeled dataset D_train = {(x_i, y_i)}_{i=1}^{n_train} drawn from a distribution P_train, where x ∈ R^d and y ∈ Y := {1, 2, ..., K} for a K-class classification problem.

Figure 1: Overview of TAST. Left: A schematic of TAST compared to T3A. The dashed class indicates the ground-truth class. (a) T3A constructs prototypes that represent classes in the embedding space of the feature extractor f_θ using a support set S. Then T3A predicts the label of the test data x as the class of the nearest prototype. (b) TAST adds trainable adaptation modules {h_{ϕ_i}} on top of f_θ and computes the estimated class distribution of x by aggregating the pseudo labels of the nearest support examples of x in the embedding space of the adaptation modules. Right: Overview of TAST training. Based on the intuition that a test data x and its nearest neighbors N(x; S) are likely to share the same label, we use the mean of the prototype-based predictions of the support examples in N(x; S) as the pseudo label of x. We train the adaptation modules to predict the pseudo labels when the test data is fed in. Note that the feature extractor f_θ is frozen during test time.

Test-time training methods Test-time training methods fine-tune trained classifiers by the selfsupervised learning task used at training time. Sun et al. (

1 Since wrongly pseudo-labeled examples can degrade the classification performance, support examples with unconfident pseudo labels are regarded as unreliable examples and filtered out at time stamp t.

Algorithm 1 Test-time Adaptation via Self-Training with nearest neighbor information (TAST)
Require: feature extractor f_θ, number of adaptation modules N_e, adaptation modules {h_{ϕ_i}}_{i=1}^{N_e}, test batch B, support set S, number of gradient steps per adaptation T, number of support examples per class M, number of nearby support examples N_s, learning rate α
Ensure: predictions for all x ∈ B
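The confidence filtering mentioned in footnote 1 can be sketched as keeping, for each class, only the M support examples whose predicted distributions have the lowest entropy. This is a hedged NumPy illustration (the entropy criterion follows T3A-style filtering; the values below are hypothetical):

```python
import numpy as np

def filter_support(probs, pseudo_labels, m):
    """Keep, per class, the m support examples with the lowest predictive
    entropy (i.e., the most confident pseudo labels); the rest are treated
    as unreliable and filtered out."""
    ent = -(probs * np.log(probs + 1e-12)).sum(axis=1)
    keep = np.zeros(len(probs), dtype=bool)
    for k in np.unique(pseudo_labels):
        idx = np.where(pseudo_labels == k)[0]
        keep[idx[np.argsort(ent[idx])[:m]]] = True   # m lowest-entropy per class
    return keep

probs = np.array([[0.99, 0.01],
                  [0.60, 0.40],   # highest entropy in class 0 -> dropped with m=2
                  [0.95, 0.05],
                  [0.10, 0.90]])
pseudo_labels = probs.argmax(axis=1)   # [0, 0, 0, 1]
mask = filter_support(probs, pseudo_labels, m=2)
```

Filtering by entropy rather than a hard threshold keeps the support set balanced per class, which matters because the support examples define the class prototypes.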

Average accuracy (%) using classifiers learned by ERM on the domain generalization benchmarks. We use ResNet-18 and ResNet-50 as backbone networks. Bold indicates the best performance for each benchmark. Underline indicates the best performance among the baseline methods for each benchmark. Most of the baseline methods degrade the classification performance of the trained classifiers on the benchmarks. However, our method consistently outperforms all the baselines on all of the benchmarks.

Ablation studies to evaluate the effects of the number of adaptation modules and the nearest neighbor information. We use ResNet-18 trained by ERM. TAST-N is a variant of TAST with the adaptation modules removed.



We use ResNet-50 as a backbone network. For a fair comparison, we use the released trained model of Liu et al. (2021) and the same hyperparameters whenever possible. The number of nearby support examples N

Ensemble scheme in test-time adaptation BACS (Zhou and Levine, 2021), which incorporates a Bayesian inference framework into the TTA setting, adapts the trained model to an unseen test domain with a regularization term induced by a posterior approximated at training time. BACS constructs an ensemble of predictive models to obtain diverse labelings for uncertainty estimates at the beginning of training time and trains the models independently during training. During test time, BACS averages the predictions of the adapted ensemble members. In contrast, TAST builds an ensemble of adaptation modules to alleviate the issues caused by the random initialization of the modules at the beginning of test time.

VLCS is composed of images from four domains (PASCAL VOC 2007, LabelMe, Caltech101, and SUN09 (Choi et al., 2010)), consisting of 10,729 examples of 5 categories (bird, car, chair, dog, and person). PACS is composed of images of objects from four different domains (photo, art, cartoon, and sketch), consisting of 9,991 examples of 7 categories (dog, elephant, giraffe, guitar, horse, house, and person). OfficeHome is composed of images of objects in the office and home from 4 different domains (artistic images, clip art, product, and real-world images), consisting of 15,588 examples of 65 categories (e.g., alarm clock, backpack, and batteries). TerraIncognita is composed of wild animal images taken from 4 different locations (L100, L38, L43, and L46), consisting of 24,788 examples of 10 classes.

A.2 IMPLEMENTATION DETAILS ON DOMAIN GENERALIZATION BENCHMARKS

We follow the dataset splits and the hyperparameter selection method used in T3A. We split each dataset of training domains into training and validation sets, used for network training and hyperparameter selection, respectively. Specifically, we split each dataset into 80% and 20% and use the smaller set as the validation set. We choose the hyperparameters that maximize the validation accuracy of the adapted classifier; this selection method is called training-domain validation. We train backbone networks using four different learning algorithms: ERM, CORAL, MMD, and Mixup. ERM is explained in Section 2 of the manuscript; CORAL aims to obtain domain-invariant representations by aligning the covariance matrices of training and test data; MMD tries to match the training and test data distributions using the MMD measure; Mixup trains classifiers using mixed images/features and mixed labels created by linear interpolation of examples from the training domains. We run experiments using four different random seeds: 0, 1, 2, and 3.

Mean runtime (sec) to adapt classifiers that use ResNet-18 as a backbone network with a single hyperparameter combination (T = 1, N_s = 8, M = -1).

Require: feature extractor f_θ, test batch B, support set S, number of gradient steps per adaptation T, number of support examples per class M, number of nearby support examples N_s, learning rate α. Ensure: predictions ŷ_x for all x ∈ B.

Average accuracy (%) on domain generalization benchmarks using classifiers trained by different learning algorithms, namely CORAL, MMD, and Mixup. We use ResNet-18 as a backbone network. Bold indicates the best performance for each benchmark. TAST and TAST-BN consistently improve the performance of the trained classifiers and they outperform T3A on most of the benchmarks.

Average accuracy (%) using classifiers trained by ERM on domain generalization benchmarks. We use ResNet-18 as a backbone network. TAST-both is a method that fine-tunes both the attached adaptation modules and the BN layers simultaneously. TAST-both performs worse than TAST-BN and TAST.

Average error rate (%) in the online setting on CIFAR-10C with different hyperparameters.

Sensitivity analysis of the softmax temperature τ and the output dimension of the adaptation modules d_ϕ. Average accuracy on test environment A using classifiers learned by ERM on PACS.

Sensitivity analysis of the softmax temperature τ and the output dimension of the adaptation modules d_ϕ. Average accuracy on test environment C using classifiers learned by ERM on PACS.

Sensitivity analysis of the softmax temperature τ and the output dimension of the adaptation modules d_ϕ. Average accuracy on test environment P using classifiers learned by ERM on PACS.

Sensitivity analysis of the softmax temperature τ and the output dimension of the adaptation modules d_ϕ. Average accuracy on test environment S using classifiers learned by ERM on PACS.

Ablation studies to evaluate the effects of the test batch size.

Accuracy of TAST-BN (w/ fixed prototypes) on ImageNet-C

Full results using classifiers trained by ERM for Table 1 of the manuscript on VLCS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by ERM for Table 1 of the manuscript on PACS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by ERM for Table 1 of the manuscript on OfficeHome. We use ResNet-18 as a backbone network.

Full results using classifiers trained by ERM for Table 1 of the manuscript on TerraIncognita. We use ResNet-18 as a backbone network.

Full results using classifiers trained by ERM for Table 1 of the manuscript on VLCS. We use ResNet-50 as a backbone network.

Full results using classifiers trained by ERM for Table 1 of the manuscript on PACS. We use ResNet-50 as a backbone network.

Full results using classifiers trained by ERM for Table 1 of the manuscript on OfficeHome. We use ResNet-50 as a backbone network.

Full results using classifiers trained by ERM for Table 1 of the manuscript on TerraIncognita. We use ResNet-50 as a backbone network.

Average accuracy (%) using classifiers trained by CORAL for Table 5 on the domain generalization benchmarks, namely VLCS, PACS, OfficeHome, and TerraIncognita. We use ResNet-18 and ResNet-50 as backbone networks. Bold indicates the best performance for each benchmark. Our proposed method TAST outperforms all the baselines on most of the benchmarks.

Full results using classifiers trained by CORAL for Table 22 on VLCS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by CORAL for Table 22 on PACS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by CORAL for Table 22 on OfficeHome. We use ResNet-18 as a backbone network.

Full results using classifiers trained by CORAL for Table 22 on VLCS. We use ResNet-50 as a backbone network.

Full results using classifiers trained by CORAL for Table 22 on PACS. We use ResNet-50 as a backbone network.

Full results using classifiers trained by CORAL for Table 22 on OfficeHome. We use ResNet-50 as a backbone network.

Average accuracy (%) using classifiers trained by MMD for Table 5 on the domain generalization benchmarks, namely VLCS, PACS, OfficeHome, and TerraIncognita. We use ResNet-18 as a backbone network.

Full results using classifiers trained by MMD for Table 31 on VLCS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by MMD for Table 31 on PACS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by MMD for Table 31 on OfficeHome. We use ResNet-18 as a backbone network.

Average accuracy (%) using classifiers trained by Mixup for Table 5 on the domain generalization benchmarks, namely VLCS, PACS, OfficeHome, and TerraIncognita. We use ResNet-18 as a backbone network.

Full results using classifiers trained by Mixup for Table 36 on VLCS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by Mixup for Table 36 on PACS. We use ResNet-18 as a backbone network.

Full results using classifiers trained by Mixup for Table 36 on OfficeHome. We use ResNet-18 as a backbone network.

ACKNOWLEDGEMENT

This research was supported by the National Research Foundation of Korea under grant 2021R1C1C11008539, and by the Ministry of Science and ICT, Korea, under IITP (Institute for Information and Communications Technology Planning and Evaluation) grant No. 2020-0-00626.


Published as a conference paper at ICLR 2023.

