PROVABLE ROBUST LEARNING FOR DEEP NEURAL NETWORKS UNDER AGNOSTIC CORRUPTED SUPERVISION

Abstract

Training deep neural models in the presence of corrupted supervision is challenging, as corrupted data points may significantly degrade generalization performance. To alleviate this problem, we present an efficient robust algorithm that achieves strong guarantees without any assumption on the type of corruption, and that provides a unified framework for both classification and regression problems. Different from many existing approaches that quantify the quality of individual data points (e.g., by loss values) and filter them out accordingly, the proposed algorithm focuses on controlling the collective impact of data points on the averaged gradient. Even when a corrupted data point fails to be excluded by the proposed algorithm, it will have very limited impact on the overall loss, compared with state-of-the-art methods that filter data points based on loss values. Extensive empirical results on multiple benchmark datasets demonstrate the robustness of the proposed method under different types of corruption.

1. INTRODUCTION

Corrupted supervision is a common issue in real-world learning tasks, where the learning targets are not accurate due to various factors in the data collection process. Such corruption is especially severe for deep learning models, whose high degrees of freedom make them easily memorize corrupted examples and thus susceptible to overfitting (Zhang et al., 2016). There have been extensive efforts to achieve robustness against corrupted supervision. A natural approach to dealing with corrupted supervision in deep neural networks (DNNs) is to reduce the model's exposure to corrupted data points during training. By detecting and filtering (or re-weighting) possibly corrupted samples, the learning is expected to deliver a model similar to one trained on clean data (without corruption) (Kumar et al., 2010; Han et al., 2018; Zheng et al., 2020). Different criteria have been designed to identify corrupted data points during training. For example, Kumar et al. (2010); Han et al. (2018); Jiang et al. (2018) leveraged the loss function values of data points; Zheng et al. (2020) tapped prediction uncertainty for filtering data; Malach & Shalev-Shwartz (2017) used the disagreement between two deep networks; and Reed et al. (2014) utilized the prediction consistency of neighboring iterations. The success of these methods highly depends on the effectiveness of the detection criteria in correctly identifying the corrupted data points. Since the corrupted labels remain unknown throughout the learning process, such "unsupervised" detection approaches may not be effective: they either lack theoretical guarantees of robustness (Han et al., 2018; Reed et al., 2014; Malach & Shalev-Shwartz, 2017; Li et al., 2017) or provide guarantees only under the assumption that prior knowledge about the type of corruption is available (Zheng et al., 2020; Shah et al., 2020; Patrini et al., 2017; Yi & Wu, 2019).
Another limitation of many existing approaches is that they are exclusively designed for classification problems (e.g., Malach & Shalev-Shwartz (2017); Reed et al. (2014); Menon et al. (2019); Zheng et al. (2020)) and are not straightforward to extend to regression problems.

To tackle these challenges, this paper presents a unified optimization framework that provides robustness guarantees without any assumptions on how supervisions are corrupted, and that is applicable to both classification and regression problems. Instead of developing an accurate criterion for detecting corrupted samples, we adopt a novel perspective and focus on limiting the collective impact of corrupted samples during the learning process through robust mean estimation of gradients. Specifically, if our estimated average gradient is close to the gradient from the clean data during the learning iterations, then the final model will be close to the model trained on clean data. As such, a corrupted data point can still be used during training when it does not considerably alter the averaged gradient. This observation has a remarkable impact on our algorithm design: instead of explicitly quantifying (and identifying) individual corrupted data points, which is a hard problem in itself, we deal with an easier task, i.e., eliminating training data points that significantly distort the mean gradient estimation. One immediate consequence of this design is that, even when a corrupted data point fails to be excluded by the proposed algorithm, it is likely to have very limited impact on the overall loss, compared with state-of-the-art methods that filter data points based on loss values. We perform experiments on both regression and classification with corrupted supervision on multiple benchmark datasets. The results show that the proposed method outperforms state-of-the-art approaches.
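The gradient-centric filtering idea described above can be sketched in a few lines. The following is a minimal illustration of the general principle under simplifying assumptions, not the paper's exact algorithm: per-sample gradients are averaged after discarding the eps-fraction that deviates most from the naive mean, so any corrupted point that survives filtering is, by construction, close to the mean and has limited influence on the update.

```python
import numpy as np

def robust_mean_gradient(per_sample_grads, eps=0.1):
    """Average per-sample gradients after discarding the eps-fraction
    of samples whose gradients lie farthest from the naive mean.

    per_sample_grads: array of shape (n_samples, n_params).
    eps: assumed upper bound on the fraction of corrupted samples.
    """
    n = per_sample_grads.shape[0]
    naive_mean = per_sample_grads.mean(axis=0)
    # Distance of each per-sample gradient from the naive mean.
    dists = np.linalg.norm(per_sample_grads - naive_mean, axis=1)
    # Keep the (1 - eps) fraction of samples closest to the mean.
    n_keep = max(1, int(np.ceil((1.0 - eps) * n)))
    keep = np.argsort(dists)[:n_keep]
    return per_sample_grads[keep].mean(axis=0)

# Nine clean gradients near (1, 1) plus one corrupted outlier:
grads = np.vstack([np.ones((9, 2)), np.array([[100.0, -100.0]])])
robust_g = robust_mean_gradient(grads, eps=0.1)  # close to (1, 1)
```

Note that the filtering here is by gradient deviation, not loss value: a corrupted point with a large loss but a near-average gradient would be kept, which is exactly the behavior the framework permits.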

2. BACKGROUND

Learning from corrupted data (Huber, 1992) has attracted considerable attention in the machine learning community (Natarajan et al., 2013). Many recent studies have investigated the robustness of classification tasks with noisy labels. For example, Kumar et al. (2010) proposed a self-paced learning (SPL) approach, which assigns higher weights to examples with smaller loss. A similar idea was used in curriculum learning (Bengio et al., 2009), in which the model learns easy samples first before learning harder ones. Alternative methods inspired by SPL include learning the data weights (Jiang et al., 2018) and collaborative learning (Han et al., 2018; Yu et al., 2019). Label correction (Patrini et al., 2017; Li et al., 2017; Yi & Wu, 2019) is another approach, which revises original labels with the goal of recovering clean labels from corrupted ones. However, since we do not know which data points are corrupted, it is hard to obtain provable guarantees for label correction without strong assumptions regarding the corruption type.

Accurate estimation of gradients is a key step in successful optimization. The relationship between gradient estimation and final convergence has been widely studied in the optimization community. Since computing an approximated (and potentially biased) gradient is often more efficient than computing the exact gradient, many studies used approximated gradients to optimize their models and showed that the models suffer from the biased-estimation problem if no assumptions are made on the gradient estimation (d'Aspremont, 2008; Schmidt et al., 2011; Bernstein et al., 2018; Hu et al., 2020; Ajalloeian & Stich, 2020). A closely related topic is robust mean estimation. Given corrupted data, robust mean estimation aims at producing an estimate μ̂ such that the distance between the estimate computed on the corrupted data and the mean μ of the clean data, ‖μ̂ − μ‖₂, is minimized.
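To make the small-loss criterion behind SPL-style methods concrete, the sketch below assigns binary weights that keep only examples whose loss falls below a threshold. The fixed threshold and binary weighting are simplified placeholders for illustration, not the original SPL formulation (which grows the threshold over training and couples the weights with the model objective).

```python
import numpy as np

def small_loss_weights(losses, threshold):
    """Binary self-paced weights: keep examples whose loss is below the
    threshold, drop the rest (the small-loss heuristic of SPL-style methods)."""
    return (np.asarray(losses) < threshold).astype(float)

# Three typical losses and one suspiciously large loss (a likely corrupted label):
losses = np.array([0.1, 0.3, 2.5, 0.2])
w = small_loss_weights(losses, threshold=1.0)
# The weighted average ignores the high-loss example entirely.
robust_loss = (w * losses).sum() / w.sum()
```

The known weakness of this criterion, which motivates the gradient-based view taken in this paper, is that a corrupted example is not guaranteed to have a large loss, and a clean but hard example may be discarded.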
It has been shown that the median and the trimmed mean are optimal statistics for mean estimation in one-dimensional data (Huber, 1992). However, robustness in high dimensions is quite challenging, since applying the coordinate-wise optimal robust estimator leads to an error factor O(√d) that scales with the data dimension. Although some classical work, such as the Tukey median (Tukey, 1975), successfully designed algorithms that get rid of the O(√d) error, the algorithms themselves are not polynomial-time. More recently, Diakonikolas et al. (2016); Lai et al. (2016) successfully designed polynomial-time algorithms with dimension-free error bounds. These results have been widely applied to improve algorithmic efficiency in various scenarios (Dong et al., 2019; Cheng et al., 2020).

Robust optimization aims to optimize the model given corrupted data. Many previous studies improve the robustness of optimization in different problem settings. However, most of them study either linear regression and its variants (Bhatia et al., 2015; 2017; Shen & Sanghavi, 2019) or convex optimization (Prasad et al., 2018); thus, those results cannot be directly generalized to deep neural networks. Diakonikolas et al. (2019) proposed a very general non-convex optimization method with an agnostic corruption guarantee. However, the space complexity of the algorithm is high, so it cannot be applied to deep neural networks given current hardware limitations.
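To make the one-dimensional estimators concrete, the sketch below implements a trimmed mean and its naive coordinate-wise extension. This is only an illustration of the baseline estimator discussed above: applying the 1-D estimator per coordinate is robust in each dimension separately, but the per-coordinate errors can accumulate across dimensions, which is the source of the O(√d) factor.

```python
import numpy as np

def trimmed_mean(x, trim=0.1):
    """One-dimensional trimmed mean: drop the smallest and largest
    trim-fraction of values, then average the rest."""
    x = np.sort(np.asarray(x, dtype=float))
    k = int(trim * len(x))
    return x[k:len(x) - k].mean() if len(x) > 2 * k else x.mean()

def coordinatewise_trimmed_mean(X, trim=0.1):
    """Apply the 1-D trimmed mean to each coordinate of (n, d) data.
    Robust per coordinate, but errors accumulate across the d dimensions."""
    return np.apply_along_axis(trimmed_mean, 0, X, trim)

# One gross outlier among five 2-D points; trimming removes its influence
# in each coordinate separately.
X = np.array([[0.0, 10.0], [1.0, 11.0], [2.0, 12.0],
              [3.0, 13.0], [1000.0, -1000.0]])
m = coordinatewise_trimmed_mean(X, trim=0.2)
```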

3. METHODOLOGY

Before introducing our algorithm, we first discuss the corrupted-supervision setting. To characterize agnostic corruption, we consider an adversary that tries to corrupt the supervision of clean data. There is no limitation on how the adversary corrupts the supervision: it may randomly permute the targets, or corrupt them in a way that maximizes the negative impact (i.e., lowers performance).
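As one minimal instance of such an adversary, the sketch below randomly permutes the targets of an eps-fraction of examples. The function name and parameters are illustrative only; under the agnostic model above, a worst-case adversary may choose corruptions far more adversarially than a random permutation.

```python
import numpy as np

def corrupt_labels(y, eps, seed=None):
    """Randomly permute the supervision of an eps-fraction of examples:
    one simple instance of the agnostic corruption model (illustrative only)."""
    rng = np.random.default_rng(seed)
    y = np.array(y, copy=True)
    n_corrupt = int(eps * len(y))
    # Pick which examples the adversary touches, then shuffle their targets.
    idx = rng.choice(len(y), size=n_corrupt, replace=False)
    y[idx] = rng.permutation(y[idx])
    return y

y_clean = np.arange(100)
y_corrupt = corrupt_labels(y_clean, eps=0.2, seed=0)  # at most 20 labels changed
```

Because the adversary here only permutes existing targets, the label marginal is unchanged; an unrestricted adversary could also replace targets with arbitrary values, which the agnostic setting equally allows.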






