DYNAMIC BATCH NORM STATISTICS UPDATE FOR NATURAL ROBUSTNESS

Abstract

DNNs trained on clean natural samples have been shown to perform poorly on corrupted samples, such as noisy or blurry images. Various data augmentation methods have recently been proposed to improve DNNs' robustness against common corruptions. Despite their success, they require computationally expensive training and cannot be applied to off-the-shelf trained models. Recently, updating only the BatchNorm (BN) statistics of a model on a single corruption has been shown to significantly improve its accuracy on that corruption. However, when the corruption type changes at inference time, the effectiveness of this method decreases. In this paper, we harness the Fourier domain to detect the corruption type, a task that is challenging in the image domain. We propose a unified framework, consisting of a corruption-detection model and a BN statistics update, that can improve the corruption accuracy of any off-the-shelf trained model. We benchmark our framework on different models and datasets. Our results demonstrate about 8% and 4% accuracy improvement on CIFAR10-C and ImageNet-C, respectively. Furthermore, our framework can further improve the accuracy of state-of-the-art robust models, such as AugMix and DeepAug.

1. INTRODUCTION

Deep neural networks (DNNs) have been successfully applied to solve various vision tasks in recent years. At inference time, DNNs generally perform well on data points sampled from the same distribution as the training data. However, they often perform poorly on data points from a different distribution, including corrupted data such as noisy or blurred images. These corruptions often appear naturally at inference time in many real-world applications, such as cameras in autonomous cars, X-ray images, etc. Not only does DNNs' accuracy drop across shifts in the data distribution, but the well-known overconfidence problem of DNNs also impedes the detection of such shifts. One straightforward approach to improve robustness is to augment the training data to cover various corruptions. Recently, many more advanced data augmentation schemes have been proposed and shown to improve model robustness on corrupted data, such as SIN Geirhos et al. (2018a), ANT Rusak et al. (2020a), AugMix Hendrycks et al. (2019), and DeepAug Hendrycks et al. (2021). Despite their effectiveness, these approaches require a computationally expensive training or re-training process. Two recent works (Benz et al., 2021; Schneider et al., 2020) proposed a simple batch normalization (BN) statistics update to improve the robustness of a pre-trained model against various corruptions with minimal computational overhead. The idea is to only update the BN statistics of a pre-trained model on a target corruption. If the corruption type is unknown beforehand, the model can keep updating its BN statistics at inference time to adapt to the ongoing corruption. Despite its effectiveness, this approach is only suitable when a constant flow of inputs with the same type of corruption is fed to the model so that it can adjust the BN stats accordingly. In this work, we first investigate how complex the corruption type detection task itself would be.
Although corruption type detection is challenging in the image domain, employing the Fourier domain makes it much more manageable because each corruption has a relatively unique frequency profile. We show that a very simple DNN can modestly detect corruption types when fed with a specifically normalized frequency spectrum. Given the ability to detect corruption types in the Fourier domain, we adapt the BN statistics update method so that it changes the BN values dynamically based on the detected corruption type. The overall architecture of our approach is depicted in Figure 1. First, we calculate the Fourier transform of the input image and, after applying a specifically designed normalization, feed it to the corruption type detection DNN. Based on the detected corruption, we fetch the corresponding BN statistics from the BN stat lookup table and update the pre-trained network's BNs accordingly. Finally, the dynamically updated pre-trained network processes the original input image. In summary, our contributions are as follows:

• We harness the frequency spectrum of an image to identify the corruption type. On ImageNet-C, a shallow 3-layer fully connected neural network can identify 16 different corruption types with 65.88% accuracy. The majority of the misclassifications occur between similar corruptions, such as different types of noise, for which the BN stat updates are similar nevertheless.

• Our framework can be used on any off-the-shelf pre-trained model, even robustly trained models, such as AugMix Hendrycks et al. (2019) and DeepAug Hendrycks et al. (2021), and further improves their robustness.

• We demonstrate that updating BN statistics at inference time as suggested in (Benz et al., 2021; Schneider et al., 2020) does not achieve good performance when the corruption type does not remain the same for a long time. Our framework, on the other hand, is insensitive to the rate of corruption changes and outperforms these methods when dealing with dynamic corruption changes.

2. METHOD

2.1. OVERALL FRAMEWORK

The overview of our framework is depicted in Figure 1. It consists of three main modules: A) a model pre-trained on the original task, such as object detection; B) a DNN trained to detect the corruption type; and C) a lookup table storing the BN statistics corresponding to each type of corruption. This paper mainly focuses on improving the natural robustness of trained DNNs. However, the framework can easily be extended to domain generalization and to settings where the lookup table updates the entire model weights or even the model architecture itself. In (Benz et al., 2021; Schneider et al., 2020), a simple BN statistics update significantly improved the natural robustness of trained DNNs. Figure 2 shows the effectiveness of their approach on various corruption types. The drawback of their approach is that the BN statistics obtained for one type of corruption often significantly degrade the accuracy on other types of corruption, except for similar corruptions, such as different types of noise. The authors claim that in many applications, such as autonomous vehicles, the corruption type remains the same for a considerable amount of time; consequently, the BN statistics can be updated at inference time. However, neither of those papers has shown the performance of the BN statistics update when the corruption type changes. We conduct an experiment in Section 3.4 to show that detecting corruption types and utilizing the appropriate BN stats provides better results when the corruption type is not fixed.
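As a concrete illustration, the adaptation step described above can be sketched in PyTorch as follows. This is a minimal sketch, not a released implementation: the function and table names (`adapted_forward`, `bn_lookup`, `detect_corruption`) are ours. The detector predicts a corruption label, the matching statistics are copied into the model's BN buffers, and the patched model then processes the image.

```python
import torch
import torch.nn as nn

def apply_bn_stats(model: nn.Module, stats: dict) -> None:
    """Overwrite BatchNorm running statistics with looked-up values.

    stats maps module names (as in model.named_modules()) to (mean, var).
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d) and name in stats:
            mean, var = stats[name]
            module.running_mean.copy_(mean)
            module.running_var.copy_(var)

def adapted_forward(model, detect_corruption, bn_lookup, x):
    # 1) detect the corruption type, e.g. "gaussian_noise"
    corruption = detect_corruption(x)
    # 2) fetch the corresponding BN stats and patch the pre-trained model
    apply_bn_stats(model, bn_lookup[corruption])
    # 3) run inference with the patched running statistics
    model.eval()
    return model(x)
```

Because only BN buffers are swapped, the base model's weights are untouched, which is what makes the framework applicable to any off-the-shelf trained model.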

2.3. CORRUPTION DETECTION

The average Fourier spectrum of different corruptions has been shown to have different visual appearances Yin et al. (2019). However, corruption classification on the Fourier spectrum of individual images is not a trivial task: feeding a DNN with the raw Fourier spectrum leads to poor results and unstable training. Here, we first visually investigate the Fourier spectrum of various corruption types. Then, we propose a tailored normalization technique and a shallow DNN to detect corruption types. We denote an image of size (d1, d2) by x ∈ R^{d1×d2}. We omit the channel dimension here because the Fourier spectrum of all channels turns out to be similar when the average is taken over all samples; we only show the results of the first channel. We denote the natural and corrupted data distributions by D_n and D_c, respectively, and the 2D discrete Fourier transform by F. In this paper, we only consider the amplitude component of F, since the phase component does not help much with corruption detection. Moreover, we shift the low-frequency components to the center for better visualization. Figure 3 shows the normalized Fourier spectrum of different corruption types in CIFAR10-C; the results on ImageNet-C are presented in Figure 4. We explain the normalization process in the next paragraph. For visualization purposes, we clamp the values above one. However, we do not clamp pixel values of the input when it is fed to the corruption detection model. As shown in Figure 3, most corruption types have a distinguishable average Fourier spectrum. The almost identical ones, i.e., different types of noise, need not be distinguished accurately because the BN stat update for one of them improves the accuracy for the others nevertheless, as shown in Figure 2. Concretely, with ε_n = E_{x∼D_n}[|F(x)|] denoting the average amplitude spectrum of natural data, we visualize log(E_{x∼D_c}[|F(x)|] / ε_n + 1) for each corruption type, separately.
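The average-spectrum statistic visualized in Figures 3 and 4 can be computed as in the following NumPy sketch (our notation and helper names; `eps_n` denotes the average amplitude spectrum of natural images):

```python
import numpy as np

def mean_amplitude(images: np.ndarray) -> np.ndarray:
    """images: array of shape (N, H, W); returns the mean |F(x)|,
    with low frequencies shifted to the center."""
    spectra = np.abs(np.fft.fft2(images, axes=(-2, -1)))
    return np.fft.fftshift(spectra.mean(axis=0), axes=(-2, -1))

def normalized_avg_spectrum(corrupted_images: np.ndarray,
                            natural_images: np.ndarray) -> np.ndarray:
    """Per-corruption visualization map: log(E_c|F(x)| / eps_n + 1)."""
    eps_n = mean_amplitude(natural_images)
    return np.log(mean_amplitude(corrupted_images) / eps_n + 1.0)
```

Dividing by the natural spectrum before the log highlights the frequency bands where a corruption deviates from clean data, which is what makes the per-corruption maps visually distinguishable.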
For corruption detection purposes, we substitute the expected value over the entire corruption dataset by an individual image, i.e., log(|F(x)| / ε_n + 1). We empirically find this specific normalization to significantly outperform others. The intuition behind this normalization is twofold. First, natural images have a higher concentration in low frequencies Yin et al. (2019). Although corrupted images also have large values on low-frequency components, they may additionally have a large concentration on high-frequency components, depending on the corruption. Hence, we divide the values by ε_n to ensure that the model does not exclusively focus on low-frequency components during training. Second, the range of values from one pixel to another may vary by multiple orders of magnitude, which causes instability during training. Typical normalization techniques for unbounded data, such as tanh or sigmoid transforms, lead to poor accuracy because values beyond a certain point converge to 1 and become indistinguishable. We employ a three-layer fully connected (FC) neural network for corruption-type detection. Although the input has an image-like structure, we avoid using convolutional neural networks (CNNs) here because of the apparent absence of shift-invariance in the Fourier spectrum. Due to the symmetry of the Fourier spectrum, we only feed half of it to the model. For CIFAR10, we flatten the 2D data and feed it to a three-layer FC model with 1024, 512, and 16 neurons. We train the model with stochastic gradient descent (SGD) for 50 epochs, decreasing the learning rate by a factor of 10 at epochs 20 and 35. We only use a small number of samples for training, i.e., 100 samples per corruption/intensity, and keep the rest for validation. Using the Fourier spectrum and the proposed normalization method, we achieve validation accuracies of 49.21% and 65.88% on CIFAR10-C and ImageNet-C, respectively.
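A minimal sketch of the per-image feature and the detector head described above. The layer widths (1024, 512, 16) follow the text; the helper names are ours, and the input size in the usage example assumes 32 × 32 CIFAR10 images with a precomputed ε_n map:

```python
import numpy as np
import torch
import torch.nn as nn

def fourier_feature(image: np.ndarray, eps_n: np.ndarray) -> torch.Tensor:
    """image, eps_n: (H, W) arrays; returns a flattened half-spectrum.

    The amplitude spectrum is normalized by log(|F(x)| / eps_n + 1);
    only half of the symmetric spectrum is kept.
    """
    amp = np.abs(np.fft.fft2(image))
    feat = np.log(amp / eps_n + 1.0)
    half = feat[:, : feat.shape[1] // 2 + 1]  # symmetric spectrum: keep half
    return torch.from_numpy(half.reshape(-1)).float()

def make_detector(in_dim: int, num_corruptions: int = 16) -> nn.Module:
    """Three-layer FC corruption-type classifier (1024, 512, 16 neurons)."""
    return nn.Sequential(
        nn.Linear(in_dim, 1024), nn.ReLU(),
        nn.Linear(1024, 512), nn.ReLU(),
        nn.Linear(512, num_corruptions),
    )
```

The training schedule in the text would correspond to plain SGD with a `MultiStepLR`-style decay (factor 10 at epochs 20 and 35), applied to the cross-entropy loss over corruption labels.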
The same architecture and model capacity yield only 7.64% and 6.32% accuracy in the image domain, and we could not achieve good accuracy with CNNs in the image domain either. The confusion matrix of the corruption detection is presented in Figure 5.

BN Statistics. In this paper, we adopt the BN stat update from Schneider et al. (2020) with parameters N = 1 and n = 1. For a corruption c, this choice of parameters means that we take the average of the natural BN stats and the BN stats of corruption c. We compute the BN stats from the same samples we use to train the corruption-type detection model. Due to the small sample size for BN stat adaptation, we find that averaging with the natural BN stats leads to better results than using only the target corruption's BN stats.
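The BN stat computation described above can be sketched as follows. This is a hypothetical implementation of the N = n = 1 averaging, with our helper names; setting `momentum=None` makes PyTorch's BatchNorm accumulate a cumulative average over the adaptation batches:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def estimate_bn_stats(model: nn.Module, batches) -> dict:
    """Re-estimate BN running stats from corrupted samples via forward passes."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            m.momentum = None          # cumulative moving average over batches
    model.train()                      # train mode so forward passes update stats
    for x in batches:
        model(x)
    model.eval()
    return {name: (m.running_mean.clone(), m.running_var.clone())
            for name, m in model.named_modules()
            if isinstance(m, nn.BatchNorm2d)}

def mix_stats(natural: dict, corrupted: dict) -> dict:
    """Equal-weight average of natural and corruption BN stats (N = n = 1)."""
    return {k: ((natural[k][0] + corrupted[k][0]) / 2,
                (natural[k][1] + corrupted[k][1]) / 2)
            for k in natural}
```

The mixed statistics for each corruption would then be stored as one entry of the BN stat lookup table.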

3.1. EVALUATION ON CIFAR10-C

Table 1 presents the results on CIFAR10-C for several models. Our approach improves the accuracy over all corruptions by around 8%, while the accuracy on natural samples drops by less than 1%. Because the base model is trained on natural samples, any misclassification of natural samples by the corruption detection model negatively affects performance, while any correct classification of corruptions affects accuracy positively. As shown in Table 2, our approach significantly improves the accuracy on all corruption types except brightness and JPEG, for which the accuracy barely changes. Note that these two corruptions benefit the least from the BN stat update itself, as shown in Figure 2.

3.2. EVALUATION ON IMAGENET-C

Evaluation results on ImageNet-C are shown in Tables 3 and 4. We observe a similar pattern to CIFAR10-C, with a slightly smaller improvement: here, the accuracy improvement is around 4%. Similarly, the improvement occurs over all corruptions except brightness and JPEG.

3.3. EVALUATION ON ROBUST MODELS

In this section, we investigate whether our approach can further improve the accuracy of state-of-the-art models on ImageNet-C. Table 5 presents the evaluation of five state-of-the-art models. Our approach consistently improves the performance of robust approaches even further. Note that here we exclude the data used to train the corruption type detection model from the validation set, which explains the small discrepancy between the base accuracies reported in this paper and those in previous work.

3.4. INFERENCE TIME ADAPTATION

Two recent papers (Benz et al., 2021; Schneider et al., 2020) that investigated the BN statistics update suggested that the idea can be used at inference time, with the model eventually adapting to a new corruption. However, they never empirically evaluated its performance for inference-time adaptation. Here, we start with the original model trained on clean samples. During evaluation, after a certain number of batches, we randomly pick another corruption and continue evaluating the model. The samples within one batch come from a single corruption, and there are 16 samples in each batch. At the beginning of each batch, we let the model's BN stats be updated from the last ten batches. Because our approach does not update the BN stat lookup table, it is insensitive to how the inference-time evaluation is conducted, and consequently its performance is similar across settings. The results of the experiment are shown in Figure 6. On CIFAR10, our approach is outperformed only for VGG-19, and only when the corruption stays the same for 32 consecutive batches. On ImageNet, both VGG-19 and ResNet18 outperform our approach only after 32 successive batches. This experiment reveals that the original BN stat update mechanism in (Benz et al., 2021; Schneider et al., 2020) only works when the input corruption remains the same for a considerable number of consecutive samples. Although this assumption is reasonable for some applications, such as autonomous vehicles with a continuous input stream, it does not hold for many others, particularly for the non-stream inputs common in healthcare applications.
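For reference, the inference-time baseline can be sketched as follows (a minimal sketch with our helper names): BN statistics are re-estimated from a sliding window of the most recent batches, which is why the baseline needs many consecutive batches of the same corruption before its statistics catch up.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def sliding_window_bn_update(model: nn.Module, window: list,
                             new_batch: torch.Tensor, k: int = 10) -> None:
    """Refresh BN running stats from the last k batches (inference-time baseline)."""
    window.append(new_batch)
    if len(window) > k:
        window.pop(0)
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            m.momentum = None              # cumulative average over the window
    model.train()
    for b in window:
        model(b)                           # forward passes refresh the stats
    model.eval()
```

When the corruption changes, the window still contains batches of the previous corruption, so the estimated statistics lag behind; our lookup-table approach avoids this lag entirely.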

4. LIMITATIONS & DISCUSSION

One major limitation of the current framework is that it needs data samples from all corruption types to train the corruption type detection model. Although using the Fourier spectrum allows us to train the corruption detector easily with a small number of samples, it still limits the generalizability of the framework to unseen corruptions. One solution to this problem is to attach an off-the-shelf outlier detection or uncertainty mechanism to discover new types of corruption at inference time. Then, we can create a new entry in the BN stat lookup table, and the model can gradually learn the BN statistics at inference time by observing multiple samples of the new class. Hence, we can avoid the need to collect image samples from all corruptions during training. Another related perspective is to frame the supervised corruption type detection as an unsupervised problem. This reformulation is possible because the corruption labels themselves are nonessential in our framework. For example, we can use a clustering algorithm to cluster different corruptions and then associate each cluster with an entry in the BN stats table. This strategy can also be extended to detect new clusters at inference time for better generalization. We will investigate this idea in future work. In this paper, our framework is only evaluated on natural and corrupted images, but the same corruption detection idea can be employed for domain detection. Since the pre-trained model does not need to be re-trained in our framework, it might be interesting to adopt our framework for domain generalization. For instance, a natural image and a cartoon have distinguishable features, such as color distributions, Fourier spectra, etc., so accurate domain detection might be a simple task if proper features are found. Currently, our framework's accuracy is bounded by the BN statistics update proposed in (Benz et al., 2021; Schneider et al., 2020).
As a result, even with perfect corruption/domain detection, the accuracy may not improve if the BN statistics update does not work for the target corruption/domain. In the future, we will investigate other approaches to eliminate this limitation.

5. RELATED WORK

Dodge & Karam (2017) revealed that deep models' accuracy drops significantly on corrupted images despite their human-level performance on clean data. Several studies (Geirhos et al., 2018b; Vasiljevic et al., 2016) verified that training with some corruptions does not improve the accuracy on unseen corruptions. However, Rusak et al. (2020b) later challenged this notion by showing that Gaussian data augmentation can enhance the accuracy on some other corruptions as well. In (Benz et al., 2021; Schneider et al., 2020), the authors showed that corruption accuracy can be significantly increased by only updating the BN statistics of a trained model on a specific corruption. Although they claim the method can easily be adopted at inference time by updating the model's BN stats using a batch of the most recent samples, its performance has not been evaluated in a situation where the corruption type changes. In Stylized-ImageNet, the idea of using style transfer was adopted for data augmentation Geirhos et al. (2018a). Using an adversarially learned noise distribution was proposed in Rusak et al. (2020a). In DeepAug Hendrycks et al. (2021), images are passed through image-to-image models while being distorted to create new images, leading to large improvements in robustness. The use of adversarial training to improve corruption robustness has not been consistent: Rusak et al. (2020b) showed that adversarial training does not improve corruption robustness, while Shen et al. (2021) and Ford et al. (2019) reported otherwise using l∞ adversarial training.

6. CONCLUSION

In this paper, we propose a framework in which an off-the-shelf, naturally trained vision model can be adapted to perform better on corrupted inputs. Our framework consists of three main components: 1) a corruption type detector, 2) a BN stats lookup table, and 3) an off-the-shelf trained model. Upon detecting the corruption type with the first component, our framework pulls the corresponding BN stats from the lookup table and substitutes them for the BN stats of the trained model. Then, the original image is fed to the updated model. Even though detecting the corruption type is very challenging in the image domain, the Fourier spectrum of an image can be used to detect it. We use a shallow three-layer FC neural network that detects the corruption type based on the Fourier amplitudes of the input, and we show that this model can achieve significant accuracy when trained on very few samples. The same small sample size is also shown to be enough to obtain the BN stats stored in the lookup table.



https://github.com/bearpaw/pytorch-classification



Figure 1: Overall Framework

Figure 2: ResNet18 (ImageNet-C): The y-axis shows the corruption with which the model BN stats are updated. The x-axis shows the corruption on which the model performance is evaluated. The numbers in the cells are accuracy gains compared to the original model, i.e., the model with BN stats obtained from the natural dataset.

Figure 3: Normalized Fourier spectrum of CIFAR10-C dataset

Figure 5: Corruption type detection model's confusion matrix

For ImageNet-C, corruption errors are normalized by AlexNet's errors: CE_c = Σ_{s=1}^{5} E_{c,s} / Σ_{s=1}^{5} E^{AlexNet}_{c,s}. The corruption error averaged over all 15 corruptions is denoted by mCE.

Models. Our framework consists of two DNNs, namely the corruption type detector and a model pre-trained on the original task. The details of the corruption type detector are explained in Section 2.3. For CIFAR10, we consider ResNet-20, ResNet-110 He et al. (2016), VGG-19 Simonyan & Zisserman (2014), WideResNet-28-10 Zagoruyko & Komodakis (2016), and DenseNet (L=100, k=12)

Figure 6: BN stat update vs. our framework at inference time. The x-axis shows the number of consecutive batches for which the corruption remains the same during the evaluation. It takes many consecutive batches of the same corruption for models to catch up to the corruption change when inference-time BN stat update is deployed. Our framework, however, is insensitive to the corruption changes and can adapt instantly.

There are numerous data augmentation methods shown to improve corruption robustness. AutoAugment Cubuk et al. (2019) automatically searches for improved data augmentation policies and was later shown to also reduce corruption error Yin et al. (2019). AugMix Hendrycks et al. (2019) combines a set of transforms with a regularization term based on the Jensen-Shannon divergence. It has been shown that applying Gaussian noise to image patches can also improve accuracy Lopes et al. (2019).

Datasets & Metrics. The CIFAR10 dataset Krizhevsky et al. (2009) contains 32 × 32 color images of 10 classes, with 50,000 training samples and 10,000 test samples. The ImageNet dataset Deng et al. (2009) contains around 1.2 million images of 1000 classes. For ImageNet, we resize images to 256 × 256 and take the center 224 × 224 crop as input. The CIFAR10-C and ImageNet-C datasets Hendrycks & Dietterich (2019) contain corrupted test samples of the original CIFAR10 and ImageNet. There are 15 test corruptions and 4 hold-out corruptions. For a fair comparison with previous work, we only use the 15 test corruptions, as in (Benz et al., 2021; Schneider et al., 2020). Each corruption type c contains 5 different intensities or severity levels, denoted by s. Similar to Hendrycks et al. (2019), we use the unnormalized corruption error uCE_c = Σ_{s=1}^{5} E_{c,s}, where E_{c,s} is the test error on corruption c at severity s.

Table 1: Evaluation results on CIFAR10-C

Table 2: Per corruption accuracy on CIFAR10-C

Table 3: Evaluation results on ImageNet-C

Table 4: Per corruption accuracy on ImageNet-C

Table 5: Accuracy of ResNet50 on ImageNet-C

