LEARNING TEST TIME AUGMENTATION WITH CASCADE LOSS PREDICTION

Abstract

Data augmentation has been a successful common practice for improving the performance of deep neural networks during the training stage. In recent years, studies on test time augmentation (TTA) have also been promising due to their effectiveness in improving robustness against out-of-distribution data at inference. Instead of simply adopting pre-defined handcrafted geometric operations such as cropping and flipping, recent TTA methods learn predictive transformations that are expected to provide the best performance gain on each test sample. However, the inference time of the predictor grows in proportion to the desired number of transformations, and the gain from ensembling multiple augmented inputs still requires additional forward passes of the target model. In this paper, we propose a cascade method for test time augmentation prediction. It requires only a single forward pass of the transformation predictor, yet outputs multiple desirable transformations iteratively. These transformations are then applied sequentially to the test sample, all before the target model's inference. Experimental results show that our method provides a better trade-off between computational cost and overall performance at test time, with significant improvement over existing methods.

1. INTRODUCTION

Robustness in artificial intelligence systems has been recognized as an important topic in recent years, especially for application scenarios closely related to human life or health, such as biometrics, autonomous driving, medical diagnosis, and virtual and augmented reality. Though they rely heavily on training data, AI models in the real world will inevitably encounter unforeseen circumstances, which requires not only high performance in terms of accuracy, but also high robustness in terms of generalization.

Data augmentation has been a successful strategy for improving robustness in many deep learning training applications. During the training stage, various transformations are applied to the input samples to expand the diversity of the training space without actually collecting novel data. In the computer vision community, basic augmentation operations such as rotation, zoom-in and zoom-out, cropping, flipping, translation, blur, and contrast adjustment are commonly used. Some advanced techniques also explore sub-instance-level operations such as mixing samples together (Zhang et al., 2018; DeVries & Taylor, 2017; Hendrycks et al., 2019), or learnable augmentation search strategies (Cubuk et al., 2019; Lim et al., 2019; Hataya et al., 2020; Zheng et al., 2021). While data augmentation during training brings much benefit, the challenge lies in the training cost and difficulty, given the continually increasing size of training datasets. On the other hand, training time augmentation cannot solve the issue once and for all. In general, we regard the performance on in-distribution data as the standard accuracy, and the performance on out-of-distribution data as the robustness or generalization. Specifically, we consider corruptions that occur at test time and are unknown a priori. Consequently, such corruptions cannot be explicitly learned during the training stage, e.g. by adopting certain data augmentations that attempt to explore the unknown data distribution.

Test time augmentation (TTA) is defined as transforming samples before inference at test time. Conventional TTA typically requires averaging multiple predictions over different augmented test samples to obtain a final prediction. The major performance gain of conventional TTA methods heavily lies
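The averaging scheme of conventional TTA can be sketched as follows. This is a minimal illustration, not the paper's proposed method; the `toy_model` and the particular transforms are hypothetical stand-ins, and in practice each augmented view costs one additional forward pass of the target model, which is exactly the overhead the paper aims to reduce.

```python
import numpy as np

def conventional_tta(model, x, transforms):
    """Average a model's predictions over several augmented views
    of a single test sample (conventional TTA ensembling).

    `model` maps an input array to a probability vector; each entry
    of `transforms` produces one augmented copy of `x`.
    Note: the model is run once per transform."""
    preds = [model(t(x)) for t in transforms]
    return np.mean(preds, axis=0)

# Illustrative stand-in for a trained classifier (2 classes),
# whose output depends only on the mean pixel value.
def toy_model(x):
    s = x.mean()
    logits = np.array([s, 1.0 - s])
    e = np.exp(logits - logits.max())
    return e / e.sum()  # softmax probabilities

# A tiny 2x3 "image" and two geometric TTA operations.
x = np.arange(6.0).reshape(2, 3) / 5.0
transforms = [lambda a: a,            # identity
              lambda a: a[:, ::-1]]   # horizontal flip
p = conventional_tta(toy_model, x, transforms)
print(p.shape)  # (2,) -- one averaged probability per class
```

Each extra transform linearly increases inference cost, which motivates methods that predict a small set of effective transformations instead of ensembling many handcrafted ones.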

