UNDERSTANDING THE ROLE OF IMPORTANCE WEIGHTING FOR DEEP LEARNING

Abstract

The recent paper by Byrd & Lipton (2019), based on empirical observations, raises a major concern about the impact of importance weighting on over-parameterized deep learning models. They observe that, as long as the model can separate the training data, the impact of importance weighting diminishes as training proceeds. Nevertheless, a rigorous characterization of this phenomenon has been lacking. In this paper, we provide formal characterizations and theoretical justifications for the role of importance weighting with respect to the implicit bias of gradient descent and margin-based learning theory. We reveal both the optimization dynamics and the generalization performance under deep learning models. Our work not only explains the various novel phenomena observed for importance weighting in deep learning, but also extends to settings where the weights are optimized as part of the model, which applies to a number of topics under active research.

1. INTRODUCTION

Importance weighting is a standard tool for estimating a quantity under a target distribution when only samples from some source distribution are accessible. It has drawn extensive attention in the statistics and machine learning communities. Causal inference for deep learning relies heavily on propensity score weighting, which applies off-policy optimization with counterfactual estimators (Gilotte et al., 2018; Jiang & Li, 2016), modelling with observational feedback (Schnabel et al., 2016; Xu et al., 2020) and learning from controlled intervention (Swaminathan & Joachims, 2015). Importance weighting methods are also applied to characterize distribution shifts for deep learning models (Fang et al., 2020), with modern applications such as domain adaptation (Azizzadenesheli et al., 2019; Lipton et al., 2018) and learning from noisy labels (Song et al., 2020). Other usages include curriculum learning (Bengio et al., 2009) and knowledge distillation (Hinton et al., 2015), where the weights characterize the model's confidence on each sample. To reduce the discrepancy between the source and target distributions for model training, a standard routine is to minimize a weighted risk (Rubinstein & Kroese, 2016). Many techniques have been developed to this end; the common strategy is re-weighting the classes proportionally to the inverse of their frequencies (Huang et al., 2016; 2019; Wang et al., 2017). For example, Cui et al. (2019) propose re-weighting by the inverse of the effective number of samples, the focal loss (Lin et al., 2017) down-weights well-classified examples, and Li et al. (2019) suggest an improved technique that down-weights examples based on the magnitude of their gradients.

Despite the empirical successes of various re-weighting methods, it remains unclear how importance weighting exerts its influence from a theoretical standpoint. The recent study of Byrd & Lipton (2019) observes from experiments that importance weights have little impact on the converged deep neural network, if the data can be separated by the model using gradient descent.
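The inverse-frequency re-weighting routine can be sketched in a few lines. The sketch below is our minimal illustration, not code from any of the cited works; the function names and the weighted logistic loss are our choices, and the normalization to mean weight 1 is one common convention among several.

```python
import numpy as np

def inverse_frequency_weights(labels):
    """Per-sample weights proportional to the inverse of each class's frequency."""
    classes, counts = np.unique(labels, return_counts=True)
    freq = counts / counts.sum()
    class_weight = {c: 1.0 / f for c, f in zip(classes, freq)}
    w = np.array([class_weight[y] for y in labels])
    return w / w.mean()  # normalize so the average weight is 1

def weighted_logistic_loss(scores, labels, weights):
    """Weighted empirical risk: mean of w_i * log(1 + exp(-y_i * s_i)), y_i in {-1, +1}."""
    return np.mean(weights * np.log1p(np.exp(-labels * scores)))

# 4 positives and 1 negative: the minority class receives a 4x larger weight
y = np.array([1, 1, 1, 1, -1])
w = inverse_frequency_weights(y)
```

Minimizing `weighted_logistic_loss` instead of the unweighted mean is exactly the weighted-ERM setting studied in the rest of the paper.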
They connect this phenomenon to the implicit bias of gradient descent (Soudry et al., 2018), a novel topic that studies why over-parameterized models trained on separable data are biased toward solutions that generalize well. The implicit bias of gradient descent has been observed and studied for linear models (Soudry et al., 2018; Ji & Telgarsky, 2018b), linear neural networks (Ji & Telgarsky, 2018a; Gunasekar et al., 2018), two-layer neural networks with homogeneous activations (Chizat & Bach, 2020) and smooth neural networks (Nacson et al., 2019; Lyu & Li, 2019). To summarize, these works reveal that the direction of the parameters (for linear predictors) and the normalized margin (for nonlinear predictors), regardless of initialization, respectively converge to those of a max-margin solution. The pivotal role of the margin for deep learning models has been explored actively after a long journey of understanding the generalization of over-parameterized neural networks (Bartlett et al., 2017; Golowich et al., 2018; Neyshabur et al., 2018). For instance, Wei et al. (2019) study the margin of neural networks for separable data under weak regularization. They show that the normalized margin also converges to that of the max-margin solution, and provide a generalization bound for a neural network that hinges on its margin. Although there are rich understandings of the implicit bias of gradient descent and margin-based generalization, very few efforts are dedicated to studying how they adapt to the weighted empirical-risk minimization (ERM) setting. The established results do not directly transfer, since importance weighting can change both the optimization geometry and how the generalization is measured. In this paper, we fill this gap by showing the impact of importance weighting on the implicit bias of gradient descent as well as on the generalization performance.
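The implicit-bias phenomenon for linear predictors, and the Byrd & Lipton observation that importance weights wash out on separable data, can be reproduced with a small simulation. The setup below is our own illustrative example (not from the cited works): three linearly separable 2-D points chosen so that the max-margin direction is (1, 1)/sqrt(2), trained by gradient descent on a weighted logistic loss. Up-weighting one sample changes the early trajectory but not the limiting direction.

```python
import numpy as np

def weighted_gd_direction(weights, steps=50000, lr=0.1):
    """Run gradient descent on a weighted logistic loss over separable data
    and return the normalized parameter direction theta / ||theta||."""
    X = np.array([[4.0, 1.0], [1.0, 4.0], [-1.0, -1.0]])
    y = np.array([1.0, 1.0, -1.0])
    theta = np.zeros(2)
    for _ in range(steps):
        margins = y * (X @ theta)
        # gradient of sum_i w_i log(1 + exp(-m_i)) is -sum_i w_i y_i x_i sigmoid(-m_i)
        grad = -(weights * y / (1.0 + np.exp(margins))) @ X
        theta -= lr * grad
    return theta / np.linalg.norm(theta)

uniform = weighted_gd_direction(np.ones(3))
skewed = weighted_gd_direction(np.array([10.0, 1.0, 1.0]))  # up-weight one sample
max_margin = np.array([1.0, 1.0]) / np.sqrt(2)
```

Both runs drift toward the max-margin direction as the parameter norm grows, so the cosine similarity between `uniform`, `skewed`, and `max_margin` approaches 1: the importance weights affect the speed and the early iterates, but not the limit direction.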
By studying the optimization dynamics of linear models, we first reveal the effect of importance weighting on the convergence speed under linearly separable data. When the data is not linearly separable, we characterize the unique role of importance weighting in defining the intercept term upon the implicit bias. We then investigate the non-linear neural network under weak regularization, as in Wei et al. (2019). We provide a novel generalization bound that reflects how importance weighting leads to an interplay between the empirical risk and a compounding term that consists of the model complexity as well as the deviation between the source and target distributions. Based on our theoretical results, we discuss several exploratory developments on importance weighting that are worthy of further investigation.

• A good set of weights for learning can be inversely proportional to the margin, i.e., proportional to the hard-to-classify extent: a sample that is close to (far from) the oracle decision boundary should have a large (small) weight.
• If the importance weights are jointly trained according to a weighting model, the impact of the weighting model eventually diminishes after showing a strong correlation with measures of the hard-to-classify extent such as the margin.
• The usefulness of explicit regularization on weighted ERM can be studied, via its impact on the margin, in balancing the empirical loss and the distribution divergence.

In summary, our contributions are threefold.

• We characterize the impact of importance weighting on the implicit bias of gradient descent.
• We derive a generalization bound that hinges on the importance weights. For finite-step training, the role of importance weighting in the generalization bound is reflected in how the margin is affected, and how it balances the source and target distributions.
• We propose several exploratory topics for importance weighting that are worth further investigation from both the application and theoretical perspectives.
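The margin-based weighting idea above can be sketched concretely. The scheme and function name below are our hypothetical illustration, not a method from this paper: each sample is weighted by the inverse of its distance to the decision boundary, so near-boundary (hard-to-classify) samples receive large weights.

```python
import numpy as np

def hardness_weights(scores, labels, eps=0.1):
    """Hypothetical scheme: weight each sample inversely to its distance from
    the decision boundary, so hard samples near the boundary get large weights."""
    margins = labels * scores              # y_i * f(x_i), with y_i in {-1, +1}
    ws = 1.0 / (np.abs(margins) + eps)     # eps avoids division by zero at the boundary
    return ws / ws.mean()                  # normalize to average weight 1

# a confident positive, a near-boundary positive, and a confident negative
scores = np.array([3.0, 0.1, -2.0])
labels = np.array([1.0, 1.0, -1.0])
ws = hardness_weights(scores, labels)      # the near-boundary sample gets the largest weight
```

If such weights are re-estimated jointly with the model, our second exploratory point suggests their influence should fade once they become strongly correlated with the margin.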
The rest of the paper is organized as follows. In Section 2, we introduce the background, preliminary results and the experimental setup. In Sections 3 and 4, we demonstrate the influence of importance weighting for linear and non-linear models in terms of the implicit bias of gradient descent and the generalization performance. We then discuss extended investigations in Section 5.




