LEARNING TEST-TIME AUGMENTATION WITH CASCADE LOSS PREDICTION

Abstract

Data augmentation has been a successful common practice for improving the performance of deep neural networks during the training stage. In recent years, studies on test time augmentation (TTA) have also been promising due to its effectiveness in improving robustness against out-of-distribution data at inference. Instead of simply adopting pre-defined handcrafted geometric operations such as cropping and flipping, recent TTA methods learn predictive transformations that are expected to provide the best performance gain on each test sample. However, the inference cost of the transformation predictor grows in proportion to the desired number of transformations, and the gain from ensembling multiple augmented inputs still requires additional forward passes of the target model. In this paper, we propose a cascade method for test time augmentation prediction. It requires only a single forward pass of the transformation predictor, yet can output multiple desirable transformations iteratively. These transformations are then applied sequentially to the test sample before a single inference of the target model. The experimental results show that our method provides a better trade-off between computational cost and overall performance at test time, and achieves significant improvement compared to existing methods.

1. INTRODUCTION

Robustness of artificial intelligence systems has been recognized as an important topic in recent years, especially for application scenarios closely related to human life or health, such as biometrics, autonomous driving, medical diagnosis, and virtual and augmented reality. Though heavily reliant on training data, AI models deployed in the real world will inevitably encounter unforeseen circumstances, which requires not only high performance in terms of accuracy, but also high robustness in terms of generalization. Data augmentation has been a successful strategy for improving robustness in many deep learning training applications. During the training stage, various transformations are applied to the input samples to expand the diversity of the training space without truly collecting novel data. In the computer vision community, basic augmentation operations such as rotation, zoom-in and -out, cropping, flipping, translation, blur, and contrast are commonly used. Some advanced techniques also explore sub-instance-level operations such as mixing samples together (Zhang et al., 2018; DeVries & Taylor, 2017; Hendrycks et al., 2019), or learnable augmentation search strategies (Cubuk et al., 2019; Lim et al., 2019; Hataya et al., 2020; Zheng et al., 2021). While data augmentation at training time brings much benefit, the challenge lies in the training cost and difficulty, given the continually increasing size of training datasets. Moreover, training time augmentation cannot solve the issue once and for all. In general, we regard the performance on in-distribution data as standard accuracy, and the performance on out-of-distribution data as robustness or generalization. Specifically, we consider corruptions that occur at test time and are unknown a priori. Consequently, such corruptions cannot be explicitly learnt during the training stage, e.g. by adopting certain data augmentations that attempt to explore the unknown data distribution.

Test time augmentation (TTA) is defined as transforming samples before inference at test time. Conventional TTA typically averages multiple predictions over different augmented test samples to obtain a final prediction. The major performance gain of conventional TTA methods thus lies in the ensembling mechanism (Lakshminarayanan et al., 2017), which inevitably requires multiple forward passes of the inference model. Recent studies on learnable TTA methods focus on how to select the best transformation policy for each inference, i.e. the one expected to provide the largest performance gain compared to no transformation (Kim et al., 2020; Chun et al., 2022). By adopting instance-level transformation policies, these methods show significant improvement on corrupted (out-of-distribution) data without harming performance on clean (in-distribution) data. However, several limitations remain: (i) most methods still require ensembling over different predictions to achieve the best performance; (ii) the cost of the transformation predictor is proportional to the desired number of transformations before inference, which limits the variety of the transformations; (iii) the transformation policy search is still under-explored and thus leads to sub-optimal performance.

In this paper, we propose a cascade loss prediction method that, for the first time, requires only a single forward pass of the transformation predictor, yet can output multiple desirable transformations iteratively. Our contributions can be summarized as follows:

• a novel cascade test time augmentation with sequential predictions in a single forward pass;

• a better trade-off between target model performance and inference cost, with the first compatibility study and analysis on various network architectures;
• a better exploration of the test data space, which leads to state-of-the-art performance on various corruption benchmarks.
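The cascade inference scheme can be illustrated with a toy NumPy sketch. All names, shapes, and the transformation pool below are hypothetical stand-ins, not the paper's actual architecture: a single forward pass of a predictor yields an ordered sequence of K transformation ids, which are applied one after another before a single target-model pass.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical pool of candidate test-time transformations.
TRANSFORMS = {
    0: lambda x: x,                       # identity
    1: lambda x: np.flip(x, axis=1),      # horizontal flip
    2: lambda x: np.rot90(x, k=1),        # 90-degree rotation
    3: lambda x: np.clip(x * 1.2, 0, 1),  # contrast boost
}

class ToyCascadePredictor:
    """Toy stand-in for a cascade loss predictor.

    One forward pass maps an input image to K transformation ids,
    interpreted as a sequence of operations to apply in order.
    """
    def __init__(self, n_steps=3, n_ops=len(TRANSFORMS)):
        self.W = rng.standard_normal((64, n_steps * n_ops))
        self.n_steps, self.n_ops = n_steps, n_ops

    def __call__(self, x):
        logits = x.reshape(-1) @ self.W            # single forward pass
        logits = logits.reshape(self.n_steps, self.n_ops)
        return logits.argmax(axis=1)               # one op id per cascade step

def cascade_tta(x, predictor, target_model):
    ops = predictor(x)              # one predictor pass -> K transformations
    for op in ops:                  # apply the K transformations sequentially
        x = TRANSFORMS[int(op)](x)
    return target_model(x)          # single target-model pass

# Toy target model: a linear classifier over the flattened 8x8 image.
W_cls = rng.standard_normal((64, 10))
target_model = lambda x: int((x.reshape(-1) @ W_cls).argmax())

x = rng.random((8, 8))
pred = cascade_tta(x, ToyCascadePredictor(), target_model)
print(pred)
```

The contrast with learnable TTA baselines is in the call pattern: here the predictor is invoked once per test sample regardless of how many transformations are applied, whereas an iterative predictor would need one forward pass per transformation step.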

2. RELATED WORK

General Data Augmentation Traditional data augmentation aims at enlarging training datasets to improve predictive performance. Recent works explore more diverse data augmentation strategies, such as mixing up features and their corresponding labels (Zhang et al., 2018), cutting out random regions of samples (DeVries & Taylor, 2017), or cutting out and then mixing samples with different strategies (Yun et al., 2019; Han et al., 2022). On the other hand, there are studies on trainable augmentation policies (Cubuk et al., 2019; Lim et al., 2019; Hataya et al., 2020; Zheng et al., 2021), which focus on the exploration of a larger data space and on automatic learning strategies for efficient training. These techniques are commonly used in many state-of-the-art models for their benefit on both accuracy and calibration, bringing performance gains on standard benchmarks such as CIFAR (Krizhevsky et al., 2009) and ImageNet (Deng et al., 2009).

Out-of-Distribution Robustness Sufficient augmentation is also a successful practice for improving out-of-distribution robustness. Hendrycks & Dietterich (2018) built the first benchmark for evaluating model robustness under different image corruptions at test time. Hendrycks et al. (2019) proposed a simple data processing method to improve robustness; it augments training samples by mixing weighted random transformation operations and learns a distribution similarity between the original samples and the augmented samples. Wen et al. (2020) argued that simple model ensembles on top of such augmentation degrade performance, and proposed an improved variant that dismisses the augmentations with high uncertainty. Zhang et al. (2021) proposed to adapt the model parameters at test time by minimizing the entropy of the model's average output distribution across the augmentations. However, the inference becomes expensive due to the augmentation and adaptation procedure, which limits its usability for other models or tasks.

Test Time Augmentation Given a trained model, conventional test time augmentation is often carried out together with model ensembling, i.e., inference on different augmented test samples, using conventional transformations such as cropping or flipping. Lyzhov et al. (2020) demonstrated that test time augmentation policies can be learned and introduced a greedy method for learning such a policy. Shanmugam et al. (2021) analyzed when and why test time augmentation works and presented a learning-based method for aggregating test time augmentations. Kim et al. (2020) selected suitable transformations for a test input based on their proposed loss predictor; without high additional computational cost, it carried out instance-level transformation at inference for the first time. However, the proposed method only explores one single trans-
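Conventional ensembling-based TTA, against which the learnable methods above are compared, can be sketched in a few lines of NumPy. This is a toy illustration with a hypothetical linear classifier and a fixed handcrafted augmentation pool, not any of the cited methods: each augmented copy costs one target-model forward pass, and the predictive distributions are averaged.

```python
import numpy as np

rng = np.random.default_rng(1)

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

# Toy target model: a linear classifier over a flattened 8x8 image.
W = rng.standard_normal((64, 10))
model = lambda x: softmax(x.reshape(-1) @ W)

# Fixed, handcrafted augmentation pool (identity, flips, rotation).
augments = [
    lambda x: x,
    lambda x: np.flip(x, axis=1),
    lambda x: np.flip(x, axis=0),
    lambda x: np.rot90(x),
]

def conventional_tta(x):
    # One target-model forward pass per augmented copy,
    # then average the predictive distributions.
    probs = np.stack([model(a(x)) for a in augments])
    return probs.mean(axis=0)

x = rng.random((8, 8))
avg_probs = conventional_tta(x)
print(avg_probs.argmax())
```

Note that the number of target-model forward passes equals the size of the augmentation pool, which is exactly the inference cost that single-pass prediction schemes aim to avoid.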

