WHY RESAMPLING OUTPERFORMS REWEIGHTING FOR CORRECTING SAMPLING BIAS WITH STOCHASTIC GRADIENTS

Abstract

A data set sampled from a certain population is biased if the subgroups of the population are sampled at proportions that are significantly different from their underlying proportions. Training machine learning models on biased data sets requires correction techniques to compensate for the bias. We consider two commonly-used techniques, resampling and reweighting, that rebalance the proportions of the subgroups to maintain the desired objective function. Though statistically equivalent, it has been observed that resampling outperforms reweighting when combined with stochastic gradient algorithms. By analyzing illustrative examples, we explain the reason behind this phenomenon using tools from dynamical stability and stochastic asymptotics. We also present experiments from regression, classification, and off-policy prediction to demonstrate that this is a general phenomenon. We argue that it is imperative to consider the objective function design and the optimization algorithm together while addressing the sampling bias.

1. INTRODUCTION

A data set sampled from a certain population is called biased if the subgroups of the population are sampled at proportions that are significantly different from their underlying population proportions. Applying machine learning algorithms naively to biased training data can raise serious concerns and lead to controversial results (Sweeney, 2013; Kay et al., 2015; Menon et al., 2020). In many domains, such as demographic surveys, fraud detection, identification of rare diseases, and natural disaster prediction, a model trained from biased data tends to favor the oversampled subgroups, achieving high accuracy there while sacrificing performance on the undersampled subgroups. Although one can improve matters by diversifying and balancing the data collection process, it is often hard or impossible to eliminate the sampling bias due to historical and operational issues. In order to mitigate the biases and discriminations against the undersampled subgroups, a common technique is to preprocess the data set by compensating for the mismatch between the population proportions and the sampling proportions. Among various approaches, two commonly-used choices are reweighting and resampling. In reweighting, one multiplies the loss of each sample by the ratio of its population proportion over its sampling proportion. In resampling, on the other hand, one corrects the proportion mismatch by either generating new samples for the undersampled subgroups or selecting a subset of samples for the oversampled subgroups. Both methods result in statistically equivalent models in terms of the loss function (see details in Section 2). However, it has been observed in practice that resampling often outperforms reweighting significantly, for example in boosting algorithms for classification (Galar et al., 2011; Seiffert et al., 2008), off-policy prediction in reinforcement learning (Schlegel et al., 2019), and so on. The obvious question is why.

Main contributions.
Our main contribution is to provide an answer to this question: resampling outperforms reweighting because of the stochastic gradient-type algorithms used for training. To the best of our knowledge, our explanation is the first theoretical quantitative analysis of this phenomenon. With stochastic gradient descent (SGD) being the dominant method for model training, our analysis is based on some recent developments for understanding SGD. We show via simple and explicitly analyzable examples why resampling generates expected results while reweighting performs undesirably. Our theoretical analysis is based on two points of view, one from the dynamical stability perspective and the other from stochastic asymptotics. In addition to the theoretical analysis, we present experimental examples from three distinct categories (classification, regression, and off-policy prediction) to demonstrate that resampling outperforms reweighting in practice. This empirical study illustrates that this is a quite general phenomenon when models are trained using stochastic gradient-type algorithms. Our theoretical analysis and experiments show clearly that adjusting only the loss function is not sufficient for fixing the biased data problem. The output can be disastrous if one overlooks the optimization algorithm used in the training. In fact, recent understanding has shown that objective function design and optimization algorithm are closely related; for example, optimization algorithms such as SGD play a key role in the generalizability of deep neural networks. Therefore, in order to address the biased data issue, we advocate for considering data, model, and optimization as an integrated system.

Related work. In a broader scope, resampling and reweighting can be considered as instances of preprocessing the training data to tackle biases of machine learning algorithms.
Though there are many well-developed resampling (Mani & Zhang, 2003; He & Garcia, 2009; Maciejewski & Stefanowski, 2011) and reweighting (Kumar et al., 2010; Malisiewicz et al., 2011; Chang et al., 2017) techniques, we only focus on the reweighting approaches that do not change the optimization problem. It is well known that training algorithms on disparate data can lead to algorithmic discrimination (Bolukbasi et al., 2016; Caliskan et al., 2017), and over the years there have been growing efforts to mitigate such biases; see, for example, (Amini et al., 2019; Kamiran & Calders, 2012; Calmon et al., 2017; Zhao et al., 2019; López et al., 2013). We also refer to (Guo et al., 2017; He & Ma, 2013; Krawczyk, 2016) for a comprehensive review of this growing research field. Our approaches for understanding the dynamics of resampling and reweighting under SGD are based on tools from numerical analysis for stochastic systems. Connections between numerical analysis and stochastic algorithms have been developing rapidly in recent years. The dynamical stability perspective was used in (Wu et al., 2018) to show the impact of learning rate and batch size on minima selection. The stochastic differential equation (SDE) approach for approximating stochastic optimization methods can be traced through the line of work (Li et al., 2017; 2019; Rotskoff & Vanden-Eijnden, 2018; Shi et al., 2019), just to mention a few.

2. PROBLEM SETUP

Let us consider a population comprised of two different groups, where a proportion a_1 of the population belongs to the first group and the rest, with proportion a_2 = 1 - a_1, belongs to the second (i.e., a_1, a_2 > 0 and a_1 + a_2 = 1). In what follows, we shall call a_1 and a_2 the population proportions. Consider an optimization problem for this population over a parameter θ. For simplicity, we assume that each individual from the first group experiences a loss function V_1(θ), while each individual from the second group has a loss function of type V_2(θ). Here the loss function V_1(θ) is assumed to be identical across all members of the first group, and the same for V_2(θ) across the second group; however, it is possible to extend the formulation to allow for loss function variation within each group. Based on this setup, a minimization problem over the whole population is to find

θ* = arg min_θ V(θ), where V(θ) ≡ a_1 V_1(θ) + a_2 V_2(θ). (1)

For a given set Ω of N individuals sampled uniformly from the population, the empirical minimization problem is

θ* = arg min_θ (1/N) Σ_{r∈Ω} V_{i_r}(θ), (2)

where i_r ∈ {1, 2} denotes which group an individual r belongs to. When N grows, the empirical loss in (2) is consistent with the population loss in (1), as approximately an a_1 fraction of the samples comes from the first group and an a_2 fraction from the second.

However, the sampling can be far from uniformly random in reality. Let n_1 and n_2 with n_1 + n_2 = N denote the number of samples from the first and the second group, respectively. It is convenient to define f_i, i = 1, 2, as the sampling proportions for each group, i.e., f_1 = n_1/N and f_2 = n_2/N with f_1 + f_2 = 1. The data set is biased when the sampling proportions f_1 and f_2 are different from the population proportions a_1 and a_2. In such a case, the empirical loss is f_1 V_1(θ) + f_2 V_2(θ), which is clearly wrong when compared with (1).

Let us consider two basic strategies to adjust the model: reweighting and resampling. In reweighting, one assigns to each sample r ∈ Ω a weight a_{i_r}/f_{i_r}, and the reweighting loss function is

V_w(θ) ≡ (1/N) Σ_{r∈Ω} (a_{i_r}/f_{i_r}) V_{i_r}(θ) = a_1 V_1(θ) + a_2 V_2(θ). (3)
In resampling, one either adds samples to the minority group (i.e., oversampling) or removes samples from the majority group (i.e., undersampling). Although the actual implementation of oversampling and undersampling can be quite sophisticated in order to avoid overfitting or loss of information, mathematically we interpret resampling as constructing a new set of samples Ω' of size M, among which a_1 M samples are of the first group and a_2 M samples are of the second. The resampling loss function is

V_s(θ) ≡ (1/M) Σ_{s∈Ω'} V_{i_s}(θ) = a_1 V_1(θ) + a_2 V_2(θ). (4)

Notice that both V_w(θ) and V_s(θ) are consistent with the population loss function V(θ). This means that, under mild conditions on V_1(θ) and V_2(θ), a deterministic gradient descent algorithm started from a generic initial condition converges to similar solutions for V_w(θ) and V_s(θ). For a stochastic gradient algorithm, the expectations of the stochastic gradients of V_w(θ) and V_s(θ) also agree at any θ. However, as we shall explain below, the training behavior can be drastically different for a stochastic gradient algorithm. The key reason is that the variances experienced for V_w(θ) and V_s(θ) can be drastically different: computing the covariances of the stochastic gradients for resampling and reweighting reveals that

Var[∇V_s(θ)] = a_1 ∇V_1(θ)∇V_1(θ)^T + a_2 ∇V_2(θ)∇V_2(θ)^T - E[∇V_s(θ)] E[∇V_s(θ)]^T,
Var[∇V_w(θ)] = (a_1^2/f_1) ∇V_1(θ)∇V_1(θ)^T + (a_2^2/f_2) ∇V_2(θ)∇V_2(θ)^T - E[∇V_w(θ)] E[∇V_w(θ)]^T. (5)

These formulas indicate that, when f_1/f_2 is significantly misaligned with a_1/a_2, the variance for reweighting can be much larger. Without knowing the optimal learning rates a priori, it is difficult to select a learning rate with reliable and stable performance for stiff problems when only reweighting is used. In comparison, resampling is more favorable, especially when the choice of learning rates is restrictive.
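As a sanity check on (5), the two stochastic gradient estimators can be simulated directly. The sketch below is our own illustration (scalar gradients g1, g2 and the proportions used later in Figure 1 are assumed values, not from the paper); it confirms that resampling and reweighting are unbiased for the same population gradient but have very different variances.

```python
import numpy as np

rng = np.random.default_rng(0)

a1, a2 = 0.4, 0.6      # population proportions (assumed illustrative values)
f1, f2 = 0.9, 0.1      # biased sampling proportions
g1, g2 = 1.0, -1.0     # toy per-group gradients at a fixed theta (assumed)

n = 200_000

# Resampling: draw group i with probability a_i, use its gradient as-is.
groups_rs = rng.random(n) < a1
grads_rs = np.where(groups_rs, g1, g2)

# Reweighting: draw group i with probability f_i, weight gradient by a_i/f_i.
groups_rw = rng.random(n) < f1
grads_rw = np.where(groups_rw, (a1 / f1) * g1, (a2 / f2) * g2)

# Both estimators are unbiased for a1*g1 + a2*g2 = -0.2 ...
print(grads_rs.mean(), grads_rw.mean())

# ... but the variances match the two formulas in (5): ~0.96 vs ~3.74.
var_rs = a1 * g1**2 + a2 * g2**2 - (a1 * g1 + a2 * g2) ** 2
var_rw = (a1**2 / f1) * g1**2 + (a2**2 / f2) * g2**2 - (a1 * g1 + a2 * g2) ** 2
print(grads_rs.var(), var_rs)
print(grads_rw.var(), var_rw)
```

With these numbers the reweighted gradient occasionally takes the large value (a_2/f_2) g_2 = -6, which is exactly the heavy-tailed noise that the analysis below blames for the instability.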

3. STABILITY ANALYSIS

Let us use a simple example to illustrate why resampling outperforms reweighting under SGD, from the viewpoint of stability. Consider two loss functions V_1 and V_2 with disjoint supports,

V_1(θ) = { (1/2)(θ + 1)^2 - 1/2, θ ≤ 0;  0, θ > 0 },   V_2(θ) = { 0, θ ≤ 0;  (1/2)(θ - 1)^2 - 1/2, θ > 0 },

each of which is quadratic on its support. The population loss function is V(θ) = a_1 V_1(θ) + a_2 V_2(θ), with two local minima at θ = -1 and θ = 1. The gradients of V_1 and V_2 are

∇V_1(θ) = { θ + 1, θ ≤ 0;  0, θ > 0 },   ∇V_2(θ) = { 0, θ ≤ 0;  θ - 1, θ > 0 }.

Suppose that the population proportions satisfy a_2 > a_1; then θ = 1 is the global minimizer, and it is desired that SGD be stable near it. However, as shown in Figure 1, when the sampling proportion f_2 is significantly less than the population proportion a_2, θ = 1 can easily become unstable for reweighting: even if one starts near the global minimizer θ = 1, the trajectories for reweighting always gear towards θ = -1 after a few steps (see Figure 1(1)). On the other hand, for resampling θ = 1 is quite stable (see Figure 1(2)). The expectations of the stochastic gradient are the same for both methods.

Figure 1: Comparison of reweighting and resampling with a_1/a_2 = 0.4/0.6 and f_1/f_2 = 0.9/0.1 at the learning rate η = 0.5. The resampling strategy here is to randomly select the sub-population i with probability a_i, with replacement, in each iteration. (1) For reweighting, the trajectory starting from θ_0 = 1.1 can end up at θ = -1 after a few iterations, but θ = -1 is not the global minimizer. (2) For resampling, the trajectory starting from θ_0 = 2.0 stays close to the desired minimizer θ = 1. Hence resampling is more reliable than reweighting. We include more comparisons with various learning rates in Appendix D to show that resampling is stable for a wider range of η.
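The dynamics of Figure 1 are easy to reproduce. Below is a minimal SGD simulation of this two-quadratic example, our own sketch using the stated proportions and learning rate; starting near θ = 1, the reweighted iteration is typically ejected to θ = -1 while the resampled one stays put.

```python
import numpy as np

rng = np.random.default_rng(1)

a1, a2 = 0.4, 0.6   # population proportions (global minimum at theta = 1)
f1, f2 = 0.9, 0.1   # biased sampling proportions
eta = 0.5           # learning rate from Figure 1

def grad_V1(t):  # gradient of V_1: (theta + 1) for theta <= 0, else 0
    return t + 1 if t <= 0 else 0.0

def grad_V2(t):  # gradient of V_2: (theta - 1) for theta > 0, else 0
    return t - 1 if t > 0 else 0.0

def sgd(theta0, weighted, steps=500):
    t = theta0
    for _ in range(steps):
        if weighted:  # reweighting: sample group i w.p. f_i, weight by a_i/f_i
            g = (a1 / f1) * grad_V1(t) if rng.random() < f1 else (a2 / f2) * grad_V2(t)
        else:         # resampling: sample group i w.p. a_i
            g = grad_V1(t) if rng.random() < a1 else grad_V2(t)
        t -= eta * g
    return t

# Start near the global minimizer theta = 1.
print(sgd(1.1, weighted=True))    # typically ends near the wrong minimizer -1
print(sgd(1.1, weighted=False))   # stays near the global minimizer 1
```

The mechanism is visible by hand: for reweighting near θ = 1, a group-2 draw multiplies (θ - 1) by 1 - η a_2/f_2 = -2, so the distance to the minimizer doubles with every such draw until the iterate is thrown past θ = 0.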
It is the difference in the second moment that explains why trajectories near the two minima exhibit different behaviors. Our explanation is based on the stability analysis framework used in (Wu et al., 2018). By definition, a stationary point θ* is stochastically stable if there exists a uniform constant 0 < C ≤ 1 such that E[|θ_k - θ*|^2] ≤ C|θ_0 - θ*|^2 for all k, where θ_k is the k-th iterate of SGD. The stability conditions for resampling and reweighting are stated in the following two lemmas, in which we use η to denote the learning rate.

Lemma 1. For resampling, the conditions for the SGD to be stochastically stable around θ = -1 and θ = 1 are, respectively,

(1 - η a_1)^2 + η^2 a_1 a_2 ≤ 1,   (1 - η a_2)^2 + η^2 a_1 a_2 ≤ 1.

Lemma 2. For reweighting, the conditions for the SGD to be stochastically stable around θ = -1 and θ = 1 are, respectively,

(1 - η a_1)^2 + η^2 f_1 f_2 (a_1/f_1)^2 ≤ 1,   (1 - η a_2)^2 + η^2 f_1 f_2 (a_2/f_2)^2 ≤ 1.

Note that the stability conditions for resampling are independent of the sampling proportions (f_1, f_2), while the ones for reweighting clearly depend on (f_1, f_2). We defer the detailed computations to Appendix A. Lemma 2 shows that reweighting can incur a more stringent stability criterion. For instance, using a_1 + a_2 = 1 and f_1 + f_2 = 1, the conditions at θ = 1 reduce to η ≤ 2 for resampling but η ≤ 2f_2/a_2 for reweighting; with a_2 = 0.6 and f_2 = 0.1 as in Figure 1, reweighting is stable around the global minimizer only for η ≤ 1/3, which explains why the trajectory at η = 0.5 escapes.

4. STOCHASTIC ASYMPTOTICS

Let us again use a simple example to illustrate the main idea. Consider the following two loss functions,

V_1(θ) = { |θ + 1| - 1, θ ≤ 0;  εθ, θ > 0 },   V_2(θ) = { -εθ, θ ≤ 0;  |θ - 1| - 1, θ > 0 },

with 0 < ε ≪ 1. The population loss function is V(θ) = a_1 V_1(θ) + a_2 V_2(θ), with local minimizers θ = -1 and θ = 1. Note that the O(ε) terms are necessary: without them, if the SGD starts in (-∞, 0), all iterates stay in this region because there is no drift from V_2(θ); similarly, if the SGD starts in (0, ∞), no iterates move to (-∞, 0). That means the result of SGD would depend only on the initialization if the O(ε) terms were absent.
In Figure 2, we present numerical simulations of the resampling and reweighting methods for the designed loss function V(θ). If a_2 > a_1, then the global minimizer of V(θ) is θ = 1 (see Figure 2(1)). Consider a setup with population proportions a_1/a_2 = 0.4/0.6 along with sampling proportions f_1/f_2 = 0.9/0.1, which are quite different. Figures 2(2) and 2(3) show the dynamics under the reweighting and resampling methods, respectively. The plots show that, while the trajectory for resampling is stable across time, the trajectory for reweighting quickly escapes to the (non-global) local minimizer θ = -1 even when it starts near the global minimizer θ = 1.

Figure 2: Comparison of reweighting and resampling with learning rate η = 0.12. We set a_1/a_2 = 0.4/0.6, f_1/f_2 = 0.9/0.1, and ε = 0.1. Both experiments start at θ_0 = 0.9. The resampling strategy here is to randomly select the sub-population i with probability a_i, with replacement, in each iteration. (1) The loss function V(θ). (2) For reweighting, the trajectory eventually jumps to the local minimizer θ = -1. (3) For resampling, the trajectory stabilizes at the global minimizer θ = 1 the whole time. We include more comparisons with various learning rates in Appendix D to show that resampling is more reliable for a wider range of η.

When the learning rate is sufficiently small, one can approximate the SGD by an SDE, which in this piecewise linear loss example is approximately a Langevin dynamics with a piecewise constant mobility. In particular, when the dynamics reaches equilibrium, the stationary distribution of the stochastic process is approximated by a Gibbs distribution, which gives the probability densities at the stationary points. Let us denote by p_s(θ) and p_w(θ) the stationary distributions over θ under resampling and reweighting, respectively. The following lemmas quantitatively summarize the results.
Lemma 3. When a_2 > a_1, V(1) < V(-1). The stationary distribution for resampling satisfies

p_s(1)/p_s(-1) = exp( -(2/(a_1 a_2 η)) (V(1) - V(-1)) ) + O(ε) > 1.

Lemma 4. With a_2 > a_1, V(1) < V(-1) < 0. Under the condition f_2/f_1 ≤ (a_2/a_1)(V(-1)/V(1)) on the sampling proportions, the stationary distribution for reweighting satisfies

p_w(1)/p_w(-1) = (a_1^2/f_1^2)/(a_2^2/f_2^2) exp( -(2f_2/f_1)/(a_2^2 η) V(1) + (2f_1/f_2)/(a_1^2 η) V(-1) ) + O(ε) < 1.

The proofs of the above two lemmas can be found in Appendix B. Lemma 3 shows that for resampling it is always more likely to find θ at the global minimizer 1 than at the local minimizer -1. Lemma 4 states that for reweighting it is more likely to find θ at the local minimizer -1 when f_2/f_1 ≤ (a_2/a_1)(V(-1)/V(1)). Together, they explain the phenomenon shown in Figure 2. To better understand the condition in Lemma 4, let us consider the case a_1 = 1/2 - δ, a_2 = 1/2 + δ with a small constant δ > 0. Under this setup, V(-1)/V(1) ≈ 1. Whenever the ratio of the sampling proportions f_2/f_1 is significantly less than the ratio of the population proportions a_2/a_1 ≈ 1, reweighting will lead to the undesired behavior. The smaller the ratio f_2/f_1 is, the less likely the global minimizer will be visited.

The reason for constructing the above piecewise linear loss function is to obtain an approximately explicitly solvable SDE with a constant noise coefficient. One can further extend the results in 1D to piecewise strictly convex functions with two local minima (see Lemmas 9 and 10 in Appendix B.3). Here we present the most general results in 1D, that is, for a piecewise strictly convex function with a finite number of local minima. One may consider the population loss function V(θ) = Σ_{i=1}^k a_i V_i(θ) with V_i(θ) = h_i(θ) for θ_{i-1} < θ ≤ θ_i and V_i(θ) = O(ε) otherwise, where the h_i(θ) are strictly convex and continuously differentiable functions and the O(ε) term is sufficiently small and smooth.
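Plugging the Figure 2 parameters into Lemmas 3 and 4 makes the contrast concrete. The arithmetic below is our own check, dropping the O(ε) corrections and using the leading-order values V(-1) ≈ -a_1 and V(1) ≈ -a_2 of the piecewise linear example.

```python
import math

# Parameters from the Figure 2 experiment (epsilon terms dropped).
a1, a2 = 0.4, 0.6
f1, f2 = 0.9, 0.1
eta = 0.12

# Loss values at the two minimizers, to leading order in epsilon.
V_m1, V_p1 = -a1, -a2   # V(-1) = -a1 + O(eps), V(1) = -a2 + O(eps)

# Lemma 3 (resampling): p_s(1)/p_s(-1).
ratio_s = math.exp(-2.0 / (a1 * a2 * eta) * (V_p1 - V_m1))

# Lemma 4 (reweighting): p_w(1)/p_w(-1).
prefactor = (a1**2 / f1**2) / (a2**2 / f2**2)
ratio_w = prefactor * math.exp(
    -(2 * f2 / f1) / (a2**2 * eta) * V_p1
    + (2 * f1 / f2) / (a1**2 * eta) * V_m1
)

print(ratio_s)  # huge (>1): resampling concentrates at the global minimizer
print(ratio_w)  # tiny (<1): reweighting concentrates at the wrong minimizer
```

The reweighting exponent is dominated by the term (2f_1/f_2) V(-1)/(a_1^2 η) ≈ -375, so the mass at θ = 1 is exponentially negligible, matching the escape seen in Figure 2(2).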
Here {θ_i}_{i=1}^{k-1} are k - 1 distinct points, and θ_0 = -∞, θ_k = ∞. We assume that V(θ) has k local minimizers θ*_i with θ*_i ∈ (θ_{i-1}, θ_i). We present the following two lemmas under suitable assumptions (see Appendix B.3 for the details of the assumptions, the proofs, and follow-up discussions).

Lemma 5. The stationary distribution for resampling at any two local minimizers θ*_p, θ*_q with p > q satisfies

p_s(θ*_p)/p_s(θ*_q) = exp( (2/η) ∫_{θ*_p}^{θ_p} (1/h_p(θ)) dθ ( 1/(1 - a_p) - 1/(1 - a_q) ) ) + O(ε)  { > 1 if a_p > a_q;  < 1 if a_p < a_q }.

Lemma 6. The stationary distribution for reweighting at any two local minimizers θ*_p, θ*_q with p > q satisfies

p_w(θ*_p)/p_w(θ*_q) = exp( (2/η) ∫_{θ*_p}^{θ_p} (1/h_p(θ)) dθ ( f_p/(a_p(1 - f_p)) - f_q/(a_q(1 - f_q)) ) ) + O(ε).

Multi-dimensional results. Let us now consider the minimization of V(θ) = a_1 V_1(θ) + a_2 V_2(θ) for more general V_1, V_2 and for θ in high dimensions. It is in fact not clear how to extend the above stochastic analysis to more general functions V(θ). Instead, we focus on the transition time from one stationary point to another in order to understand the behavior of resampling and reweighting. For this purpose, we again resort to the SDE approximation of the SGD in the continuous-time limit. Such an SDE approximation, first introduced in (Li et al., 2017), involves a data-dependent covariance coefficient for the diffusion term and is justified in the weak sense with an error of order O(√η). More specifically, the dynamics can be approximated by

dΘ = -∇V(Θ)dt + √η Σ(Θ)^{1/2} dB, (7)

where Θ(t = kη) ≈ θ_k for the parameter θ_k at step k, η is the learning rate, and Σ(Θ) is the covariance of the stochastic gradient at Θ. In the SDE theory, the drift term ∇V(·) is usually assumed to be Lipschitz.
However, in machine learning (for example, neural network training with non-smooth activation functions), it is common to encounter non-Lipschitz gradients of loss functions (as in the example presented in Section 3). To fill this gap, we provide in Appendix C a justification of the SDE approximation for drifts with jump discontinuities, based on the proof presented in (Müller-Gronbach et al., 2020). The following two lemmas summarize the transition times between the two local minimizers.

Lemma 7. Assume that there are only two local minimizers θ*_1, θ*_2 for the objective function V(θ). Let τ_{θ*_1→θ*_2} be the transition time for Θ(t) in (7) from the ε-neighborhood of θ*_1 (a closed ball of radius ε centered at θ*_1) to the ε-neighborhood of θ*_2, and let τ_{θ*_2→θ*_1} be the transition time in the opposite direction. Then

E[τ_{θ*_1→θ*_2}] / E[τ_{θ*_2→θ*_1}] = ( det(∇^2 L(θ*_2)) / det(∇^2 L(θ*_1)) ) exp( (2/η) ( δV(θ*_1)/Σ(θ*_1) - δV(θ*_2)/Σ(θ*_2) ) ) + O(√ε).

Here det(∇^2 L(θ*_1)) and det(∇^2 L(θ*_2)) are the determinants of the Hessians at θ*_1 and θ*_2, respectively, and δV(θ*_k) ≡ V(θ°) - V(θ*_k) for k = 1, 2, where θ° is the saddle point between θ*_1 and θ*_2. This lemma is known in the diffusion process literature as the Eyring-Kramers formula; see, e.g., (Berglund, 2011; Bovier et al., 2004; 2005). Using the above lemma, we obtain the following result for the transition times for resampling and reweighting.

Lemma 8. Assume that there are only two local minimizers θ*_1, θ*_2 for the objective function V(θ). Also assume that the loss function V_1(·) for the first group is O(ε) in the ε-neighborhood of θ*_2 and the loss function V_2(·) for the second group is O(ε) in the ε-neighborhood of θ*_1. In addition, assume that the determinants of the Hessians at the two local minimizers are the same.
Then the ratio of the transition times between the two local minimizers for resampling is

E[τ^s_{θ*_1→θ*_2}] / E[τ^s_{θ*_2→θ*_1}] = exp( (2/η) ( δV(θ*_1)/(a_1 ∇V_1(θ*_1)∇V_1(θ*_1)^T) - δV(θ*_2)/(a_2 ∇V_2(θ*_2)∇V_2(θ*_2)^T) ) ) + O(√ε),

and the ratio for reweighting is

E[τ^w_{θ*_1→θ*_2}] / E[τ^w_{θ*_2→θ*_1}] = exp( (2/η) ( f_1 δV(θ*_1)/(a_1^2 ∇V_1(θ*_1)∇V_1(θ*_1)^T) - f_2 δV(θ*_2)/(a_2^2 ∇V_2(θ*_2)∇V_2(θ*_2)^T) ) ) + O(√ε).

See Appendix B for the proof. When the ratio is larger than 1, it means that θ*_1 is more stable than θ*_2. This result shows that for reweighting the relative stability of the two minimizers depends strongly on the sampling proportions (f_1, f_2). On the other hand, for resampling it is independent of (f_1, f_2), which is precisely the desired property for a bias correction procedure. To see how the sampling proportions affect the behavior of reweighting, let us consider a simple case where ∇V_1(θ*_1)∇V_1(θ*_1)^T = ∇V_2(θ*_2)∇V_2(θ*_2)^T, a_1 = 1/2 + δ, a_2 = 1/2 - δ for a small δ > 0, and f_1 ≪ f_2. Then θ*_1 is the global minimizer, which ensures that δV(θ*_1) > δV(θ*_2), and the above ratio for resampling is larger than 1, which is the desired result. However, f_1 ≪ f_2 implies that f_1/a_1^2 ≪ f_2/a_2^2, so the above ratio for reweighting is much smaller than 1, which means that the local minimizer θ*_2 is more stable than the global minimizer θ*_1.
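The effect of the sampling proportions on the Lemma 8 ratios can be seen with a quick back-of-the-envelope evaluation. All inputs below are assumed illustrative values (equal gradient covariances g at both minimizers, a nearly balanced population, heavily biased sampling with f_1 ≪ f_2, and made-up barrier heights δV); only the qualitative signs matter.

```python
import math

eta, g = 0.1, 1.0            # learning rate and common gradient covariance (assumed)
a1, a2 = 0.55, 0.45          # population proportions: theta*_1 is the global minimizer
f1, f2 = 0.05, 0.95          # biased sampling with f1 << f2
dV1, dV2 = 0.30, 0.20        # barrier heights deltaV(theta*_1) > deltaV(theta*_2) (assumed)

# Lemma 8, resampling: independent of (f1, f2).
ratio_rs = math.exp((2 / eta) * (dV1 / (a1 * g) - dV2 / (a2 * g)))

# Lemma 8, reweighting: the sampling proportions enter the exponent.
ratio_rw = math.exp((2 / eta) * (f1 * dV1 / (a1**2 * g) - f2 * dV2 / (a2**2 * g)))

print(ratio_rs)  # > 1: the global minimizer theta*_1 is the more stable one
print(ratio_rw)  # << 1: the spurious minimizer theta*_2 becomes more stable
```

Even though the deeper basin has the larger barrier, the factor f_1/a_1^2 ≈ 0.17 suppresses its term in the reweighting exponent while f_2/a_2^2 ≈ 4.7 inflates the other, flipping the sign of the exponent.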

5. EXPERIMENTS

This section examines the empirical performance of resampling and reweighting for problems from classification, regression, and reinforcement learning. As mentioned in the previous sections, the noise of stochastic gradient algorithms makes the selection of good learning rates much more restrictive for reweighting when the data sampling is highly biased. In order to achieve good learning efficiency and reasonable performance in neural network training, adaptive stochastic gradient methods such as Adam (Kingma & Ba, 2014) are applied in the first two experiments. We observe that resampling consistently outperforms reweighting at various sampling ratios when combined with these adaptive learning methods.

Classification. This experiment uses the Bank Marketing data set from (Moro et al., 2014) to predict whether a client will subscribe to a term deposit. After preprocessing, the provided data distribution over the variable "y", which indicates the subscription, is highly skewed: the ratio of "yes" to "no" is f_1/f_2 = 4640/36548 ≈ 1/7.88. We assume that the underlying population distribution is a_1/a_2 = 1. We set up a 3-layer neural network with the binary cross-entropy loss function and train with the default Adam optimizer. The training and testing data sets are obtained using train_test_split provided in sklearn. The training takes 5 epochs with batch size equal to 100. The performance is compared among the baseline (i.e., trained without using either resampling or reweighting), resampling (oversampling the minority group with replacement), and reweighting. We run the experiments 10 times for each case, and then compute and plot the averaged results. To measure the performance, rather than using the classification accuracy, which can be misleading for biased data, we use the metric that computes the area under the receiver operating characteristic curve (ROC-AUC) from the prediction scores.
The ROC curve plots the true positive rate on the y-axis versus the false positive rate on the x-axis; a larger area under the curve therefore indicates a better classifier. From both Table 1 and Figure 3, we see that oversampling has the best performance compared to the others. We choose oversampling rather than undersampling for the resampling method because naively downsampling the majority group throws away much information that could be useful for the prediction.

Off-policy prediction. In the off-policy prediction problem in reinforcement learning, the objective is to find the value function of a policy π using a trajectory {(a_t, s_t, s_{t+1})}_{t=1}^T generated by a behavior policy µ. To achieve this, the standard approach is to update the value function based on the behavior policy's temporal difference (TD) error δ(s_t) = R(s_t) + γV(s_{t+1}) - V(s_t) with an importance weight:

E_π[δ|s_t = s] = Σ_{a∈A} (π(a|s)/µ(a|s)) E[δ|s_t = s, a_t = a] µ(a|s),

where the summation is taken over the action space A. The resulting reweighting TD learning rule for policy π is

V_{t+1}(s_t) = V_t(s_t) + η (π(a_t|s_t)/µ(a_t|s_t)) (R(s_t) + γV_t(s_{t+1}) - V_t(s_t)),

where η is the learning rate. This update rule is an example of reweighting. On the other hand, the expected TD error can also be written in the resampling form

E_π[δ|s_t = s] = Σ_{a∈A} E[δ|s_t = s, a_t = a] π(a|s) = Σ_{a∈A} π(a|s) (1/N) Σ_{j=1}^N E[δ_j|s_t = s, a_t = a],

where N is the total number of samples with s_t = s. This results in a resampling TD learning algorithm: at step t,

V_{t+1}(s_t) = V_t(s_t) + η (R(s_k) + γV_t(s_{k+1}) - V_t(s_k)),

where (a_k, s_k, s_{k+1}) is randomly chosen from the data set {(a_j, s_j, s_{j+1})}_{s_j = s_t} with probability π(a_k|s_t).
Consider a simple example with discrete state space S = {i}_{i=0}^{n-1}, action space A = {±1}, discount factor γ = 0.9, and transition dynamics s_{t+1} = mod(s_t + a_t, n), where mod(m, n) gives the remainder of m divided by n. Figure 4 shows the results of off-policy TD learning by these two approaches, with the choice n = 32, reward r(s) = 1 + sin(2πs/n), and learning rate η = 0.1. The target policy is π(a_i|s) = 1/2, while the behavior policy is µ(a_i|s) = 1/2 + c a_i. The difference between the two policies becomes larger as the constant c ∈ [0, 1/2] increases. From the previous analysis, if one group has much fewer samples than it should have, then the minimizer of the reweighting method is strongly affected by the sampling bias. This is verified in the plots: as c becomes larger, the performance of reweighting deteriorates, while resampling is rather stable and shows almost no difference from the on-policy prediction in this example.

Figure 4: Absolute error e_t = |V_t(s) - V^π(s)|_2^2 for off-policy TD learning. RW and RS in the upper right corner denote reweighting and resampling, respectively, and c determines the behavior policy µ(a_i|s) = 1/2 + c a_i. The value function is trained on a trajectory of length 10^5 generated by the behavior policy. The value function obtained from resampling is fairly close to the exact value function, while the result of reweighting gets worse as the behavior policy gets further from the target policy.
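The two TD variants above can be sketched in a few lines. The following is our own minimal implementation of the ring example (n = 32, r(s) = 1 + sin(2πs/n), η = 0.1), with c = 0.45 chosen to make the policy mismatch large; the exact V^π is available from the linear Bellman system for comparison. The resampling variant is a simplified replay of stored transitions, not the paper's exact code.

```python
import numpy as np

rng = np.random.default_rng(0)
n, gamma, eta, c = 32, 0.9, 0.1, 0.45
reward = 1 + np.sin(2 * np.pi * np.arange(n) / n)

# Exact value function of the target policy pi(a|s) = 1/2 from V = r + gamma*P*V.
P = np.zeros((n, n))
for s in range(n):
    P[s, (s + 1) % n] += 0.5
    P[s, (s - 1) % n] += 0.5
V_exact = np.linalg.solve(np.eye(n) - gamma * P, reward)

# Trajectory from the behavior policy mu(+1|s) = 1/2 + c, mu(-1|s) = 1/2 - c.
T = 100_000
traj, s = [], 0
for _ in range(T):
    a = 1 if rng.random() < 0.5 + c else -1
    traj.append((s, a, (s + a) % n))
    s = (s + a) % n

# Reweighting TD: importance weight pi/mu on the behavior TD error.
V_rw = np.zeros(n)
for s, a, s2 in traj:
    rho = 0.5 / (0.5 + c if a == 1 else 0.5 - c)
    V_rw[s] += eta * rho * (reward[s] + gamma * V_rw[s2] - V_rw[s])

# Resampling TD: at each visit, replay a stored transition from the current
# state with the action drawn from the target policy pi.
stored = {si: {1: [], -1: []} for si in range(n)}
for s, a, s2 in traj:
    stored[s][a].append(s2)
V_rs = np.zeros(n)
for s, _, _ in traj:
    a = 1 if rng.random() < 0.5 else -1
    if stored[s][a]:
        s2 = stored[s][a][rng.integers(len(stored[s][a]))]
        V_rs[s] += eta * (reward[s] + gamma * V_rs[s2] - V_rs[s])

err_rw = np.linalg.norm(V_rw - V_exact)
err_rs = np.linalg.norm(V_rs - V_exact)
print(err_rs, err_rw)  # resampling tracks V^pi much more closely
```

Note that for the rare action (probability 1/2 - c = 0.05) the importance weight is π/µ = 10, so the effective step size ηρ reaches 1: exactly the large-noise regime the analysis above warns about.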

6. DISCUSSIONS

This paper examines the different behaviors of reweighting and resampling for training on biasedly sampled data with stochastic gradient descent. From both the dynamical stability and stochastic asymptotics viewpoints, we explain why resampling is numerically more stable and robust than reweighting. Based on this theoretical understanding, we advocate for considering data, model, and optimization as an integrated system while addressing the bias. An immediate direction for future work is to apply the analysis to more sophisticated stochastic training algorithms and understand their impact on resampling and reweighting. Another direction is to extend our analysis to unsupervised learning problems. For example, in principal component analysis one computes the dominant eigenvectors of the covariance matrix of a data set. When the data set consists of multiple subgroups sampled with biases and a stochastic algorithm is applied to compute the eigenvectors, an interesting question is how resampling or reweighting would affect the result.

A PROOFS IN SECTION 3

A.1 PROOF OF LEMMA 1

Proof. In resampling, near θ = -1 the gradient is θ + 1 with probability a_1 and 0 with probability a_2. Let us denote the random gradient at each step by W(θ_k + 1), where W is a Bernoulli random variable with mean E(W) = a_1 and variance V(W) = a_1 a_2. At the learning rate η, the iteration can be written as (θ_{k+1} + 1) = (1 - ηW)(θ_k + 1). The first and second moments of the iterates are

E[θ_k + 1] = (1 - η a_1)^k (θ_0 + 1),   E[(θ_k + 1)^2] = ((1 - η a_1)^2 + η^2 a_1 a_2)^k (θ_0 + 1)^2.

According to the definition of stochastic stability, SGD is stable around θ = -1 if the multiplicative factor in the second equation is bounded by 1, i.e.,

(1 - η a_1)^2 + η^2 a_1 a_2 ≤ 1.

Consider now the stability around θ = 1. The iteration can be written as (θ_{k+1} - 1) = (1 - ηW)(θ_k - 1), where W is again a Bernoulli random variable, now with E(W) = a_2 and V(W) = a_1 a_2. The same computation shows that the second moment follows

E[(θ_k - 1)^2] = ((1 - η a_2)^2 + η^2 a_1 a_2)^k (θ_0 - 1)^2.

Therefore, the condition for the SGD to be stable around θ = 1 is

(1 - η a_2)^2 + η^2 a_1 a_2 ≤ 1.

A.2 PROOF OF LEMMA 2

Proof.
In reweighting, near θ = -1 the gradient is (a_1/f_1)(θ + 1) with probability f_1 and 0 with probability f_2. Let us denote the random gradient at each step by W(θ_k + 1), where W is a scaled Bernoulli random variable with E(W) = a_1 and V(W) = f_1 f_2 (a_1/f_1)^2. At the learning rate η, the iteration can be written as (θ_{k+1} + 1) = (1 - ηW)(θ_k + 1). Hence the second moments of the iterates are given by

E[(θ_k + 1)^2] = ((1 - η a_1)^2 + η^2 f_1 f_2 (a_1/f_1)^2)^k (θ_0 + 1)^2.

Therefore, the condition for the SGD to be stable around θ = -1 is

(1 - η a_1)^2 + η^2 f_1 f_2 (a_1/f_1)^2 ≤ 1.

Consider now the stability around θ = 1: the gradient is 0 with probability f_1 and (a_2/f_2)(θ - 1) with probability f_2. An analysis similar to the case θ = -1 shows that the condition for the SGD to be stable around θ = 1 is

(1 - η a_2)^2 + η^2 f_1 f_2 (a_2/f_2)^2 ≤ 1.
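These second-moment factors are easy to verify numerically. The sketch below (our own, with the Figure 1 parameters) evaluates E[(1 - ηW)^2] directly from the two-point distribution of W and compares it with the closed forms in Lemmas 1 and 2; at η = 0.5 the reweighting factor at θ = 1 exceeds 1, matching the instability seen in Figure 1.

```python
import numpy as np

a1, a2 = 0.4, 0.6   # population proportions
f1, f2 = 0.9, 0.1   # sampling proportions
eta = 0.5           # learning rate from Figure 1

def factor(vals, probs):
    """E[(1 - eta*W)^2] for a two-point random variable W."""
    return sum(p * (1 - eta * w) ** 2 for w, p in zip(vals, probs))

# Resampling near theta = -1 (W = 1 w.p. a1) and near theta = 1 (W = 1 w.p. a2).
rs_m1 = factor([1.0, 0.0], [a1, a2])
rs_p1 = factor([1.0, 0.0], [a2, a1])
assert np.isclose(rs_m1, (1 - eta * a1) ** 2 + eta**2 * a1 * a2)
assert np.isclose(rs_p1, (1 - eta * a2) ** 2 + eta**2 * a1 * a2)

# Reweighting near theta = -1 (W = a1/f1 w.p. f1) and theta = 1 (W = a2/f2 w.p. f2).
rw_m1 = factor([a1 / f1, 0.0], [f1, f2])
rw_p1 = factor([a2 / f2, 0.0], [f2, f1])
assert np.isclose(rw_m1, (1 - eta * a1) ** 2 + eta**2 * f1 * f2 * (a1 / f1) ** 2)
assert np.isclose(rw_p1, (1 - eta * a2) ** 2 + eta**2 * f1 * f2 * (a2 / f2) ** 2)

print(rs_p1, rw_p1)  # 0.55 vs 1.3: only reweighting is unstable at theta = 1
```

The factor 1.3 > 1 at the global minimizer is driven entirely by the η^2 f_1 f_2 (a_2/f_2)^2 = 0.81 noise term, i.e., by the mismatch between f_2 = 0.1 and a_2 = 0.6.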

B PROOFS IN SECTION 4

B.1 PROOF OF LEMMA 3

Proof. In resampling, with probability a_1 the gradients over the four intervals (-∞, -1), (-1, 0), (0, 1), and (1, ∞) are -1, 1, ε, and ε, respectively. With probability a_2, they are -ε, -ε, -1, and 1 across these four intervals. The variances of the gradients are a_1 a_2 (1 - ε)^2, a_1 a_2 (1 + ε)^2, a_1 a_2 (1 + ε)^2, and a_1 a_2 (1 - ε)^2, respectively, across the same intervals. Since ε ≪ 1, the variance can be written as a_1 a_2 + O(ε) across all intervals. Then the SGD dynamics with learning rate η can be approximated by

θ_{k+1} ← θ_k - η ( V'(θ_k) + √(a_1 a_2 + O(ε)) W ),

where W ∼ N(0, 1) is a standard normal random variable. When η is small, one can approximate the dynamics by a stochastic differential equation of the form

dΘ = -V'(Θ)dt + √(η (a_1 a_2 + O(ε))) dB

by identifying θ_k ≈ Θ(t = kη) (see Appendix C for details). The stationary distribution of this stochastic process is

p_s(θ) = (1/Z) exp( -2V(θ) / ((a_1 a_2 + O(ε)) η) ),

where Z is a normalization constant. Plugging in θ = -1, 1 results in

p_s(1)/p_s(-1) = exp( -2(V(1) - V(-1)) / ((a_1 a_2 + O(ε)) η) ) = exp( -2(V(1) - V(-1)) / (a_1 a_2 η) ) + O(ε).

Under the assumption that ε ≪ 1, the last term is negligible. When a_2 > a_1, V(θ) is minimized at θ = 1, which implies -(V(1) - V(-1)) > 0. Hence, this ratio is larger than 1.

B.2 PROOF OF LEMMA 4

Proof. In reweighting, with probability $f_1$ the gradients are $-\frac{a_1}{f_1}$, $\frac{a_1}{f_1}$, $\frac{a_1}{f_1}\epsilon$, and $\frac{a_1}{f_1}\epsilon$ over the four intervals $(-\infty, -1)$, $(-1, 0)$, $(0, 1)$, and $(1, \infty)$, respectively. With probability $f_2$, they are $-\frac{a_2}{f_2}\epsilon$, $-\frac{a_2}{f_2}\epsilon$, $-\frac{a_2}{f_2}$, and $\frac{a_2}{f_2}$. The variances of the gradients are $f_1 f_2 (\frac{a_1}{f_1} - \frac{a_2}{f_2}\epsilon)^2$, $f_1 f_2 (\frac{a_1}{f_1} + \frac{a_2}{f_2}\epsilon)^2$, $f_1 f_2 (\frac{a_1}{f_1}\epsilon + \frac{a_2}{f_2})^2$, and $f_1 f_2 (\frac{a_1}{f_1}\epsilon - \frac{a_2}{f_2})^2$, respectively, across the same intervals. Since $\epsilon \ll 1$, the variance can be written as $f_1 f_2 \frac{a_1^2}{f_1^2} + O(\epsilon)$ for $\theta < 0$ and $f_1 f_2 \frac{a_2^2}{f_2^2} + O(\epsilon)$ for $\theta > 0$. With $\theta_k \approx \Theta(k\eta)$, the approximate SDE for $\theta < 0$ is given by
$$d\Theta = -V'(\Theta)\,dt + \sqrt{\eta\left(f_1 f_2 \tfrac{a_1^2}{f_1^2} + O(\epsilon)\right)}\, dB,$$
while the one for $\theta > 0$ is
$$d\Theta = -V'(\Theta)\,dt + \sqrt{\eta\left(f_1 f_2 \tfrac{a_2^2}{f_2^2} + O(\epsilon)\right)}\, dB$$
(see Appendix C for the SDE derivations). The stationary distributions for $\theta < 0$ and $\theta > 0$ are, respectively,
$$\frac{1}{Z_1}\exp\left(-\frac{2}{\left(f_1 f_2 \frac{a_1^2}{f_1^2} + O(\epsilon)\right)\eta} V(\theta)\right), \qquad \frac{1}{Z_2}\exp\left(-\frac{2}{\left(f_1 f_2 \frac{a_2^2}{f_2^2} + O(\epsilon)\right)\eta} V(\theta)\right).$$
Plugging in $\theta = -1, 1$ results in
$$\frac{p_w(1)}{p_w(-1)} = \frac{Z_1}{Z_2}\exp\left(-\frac{2}{\left(f_1 f_2 \frac{a_2^2}{f_2^2} + O(\epsilon)\right)\eta} V(1) + \frac{2}{\left(f_1 f_2 \frac{a_1^2}{f_1^2} + O(\epsilon)\right)\eta} V(-1)\right) = \frac{Z_1}{Z_2}\exp\left(-\frac{2f_2/f_1}{a_2^2\eta} V(1) + \frac{2f_1/f_2}{a_1^2\eta} V(-1)\right) + O(\epsilon). \quad (11)$$
The next step is to figure out the relationship between $Z_1$ and $Z_2$. Consider an SDE with non-smooth diffusion $d\Theta = -V'(\Theta)\,dt + \sigma\,dB$. The Kolmogorov equation for the stationary distribution is
$$0 = p_t = \left(V'(\theta)p + \frac{\sigma^2}{2}p_\theta\right)_\theta.$$
This suggests that $\sigma^2 p$ is continuous at the discontinuity $\theta = 0$. In our setting, since $V(0) = 0$, this simplifies to
$$\left(f_1 f_2 \frac{a_1^2}{f_1^2} + O(\epsilon)\right)\eta \cdot \frac{1}{Z_1} = \left(f_1 f_2 \frac{a_2^2}{f_2^2} + O(\epsilon)\right)\eta \cdot \frac{1}{Z_2},$$
which gives
$$\frac{Z_1}{Z_2} = \frac{f_1 f_2 \frac{a_1^2}{f_1^2} + O(\epsilon)}{f_1 f_2 \frac{a_2^2}{f_2^2} + O(\epsilon)} = \frac{a_1^2/f_1^2}{a_2^2/f_2^2} + O(\epsilon).$$
Inserting this into (11) results in
$$\frac{p_w(1)}{p_w(-1)} = \left(\frac{a_1^2/f_1^2}{a_2^2/f_2^2} + O(\epsilon)\right)\exp\left(-\frac{2f_2/f_1}{a_2^2\eta} V(1) + \frac{2f_1/f_2}{a_1^2\eta} V(-1)\right) = \frac{a_1^2/f_1^2}{a_2^2/f_2^2}\exp\left(-\frac{2f_2/f_1}{a_2^2\eta} V(1) + \frac{2f_1/f_2}{a_1^2\eta} V(-1)\right) + O(\epsilon).$$
By the assumption $\frac{f_2}{f_1} \le \frac{a_2}{a_1}\sqrt{\frac{V(-1)}{V(1)}}$ and $V(1) < V(-1) < 0$, one has $\left(\frac{a_1}{a_2}\right)^2\left(\frac{f_2}{f_1}\right)^2 \le \frac{V(-1)}{V(1)} < 1$ and $-\frac{f_2/f_1}{a_2^2}V(1) \le -\frac{f_1/f_2}{a_1^2}V(-1)$. Hence the above ratio is less than 1.
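The same experiment as for Lemma 3, run under reweighting, illustrates this conclusion. In the sketch below (all numeric values are illustrative assumptions, not the paper's setup), group $i$ is drawn at the sampling proportion $f_i$ and its gradient is multiplied by $a_i/f_i$; with $f_1/f_2$ large, the rare group-2 gradients are very large, and SGD started at the global minimizer $\theta = 1$ drifts to the spurious minimizer $\theta = -1$:

```python
import numpy as np

# Reweighting dynamics of Lemma 4 on the same piecewise-linear double well as
# Lemma 3: group i is drawn at the *sampling* proportion f_i and its gradient
# is scaled by a_i/f_i.  All numeric values are illustrative choices.
rng = np.random.default_rng(2)
a1, a2, f1, f2 = 0.4, 0.6, 0.9, 0.1
eps, eta = 0.01, 0.5

def grad_V1(t):
    """Gradient of the group-1 loss: well at t = -1, nearly flat for t > 0."""
    return -1.0 if t < -1 else (1.0 if t < 0 else eps)

def grad_V2(t):
    """Gradient of the group-2 loss: well at t = +1, nearly flat for t < 0."""
    return -eps if t < 0 else (-1.0 if t < 1 else 1.0)

n_steps, theta = 100_000, 0.9
traj = np.empty(n_steps)
for k, pick2 in enumerate(rng.random(n_steps) < f2):   # group 2 w.p. f2
    g = (a2 / f2) * grad_V2(theta) if pick2 else (a1 / f1) * grad_V1(theta)
    theta -= eta * g
    traj[k] = theta

frac_left = np.mean(traj[n_steps // 2:] < 0)
print(f"fraction of the second half of the run spent near theta = -1: {frac_left:.3f}")
```

The contrast with the resampling run (which stays near $\theta = 1$ at comparable learning rates) is the phenomenon quantified by the ratio $p_w(1)/p_w(-1) < 1$.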

B.3 PROOFS OF LEMMAS 5 AND 6

Let us consider the population loss function

$$V(\theta) = a_1 V_1(\theta) + a_2 V_2(\theta), \qquad V_1(\theta) = \begin{cases} h_1(\theta), & \theta \le 0 \\ \epsilon\theta, & \theta > 0 \end{cases}, \qquad V_2(\theta) = \begin{cases} -\epsilon\theta, & \theta \le 0 \\ h_2(\theta), & \theta > 0 \end{cases},$$
where $h_1, h_2$ are strictly convex and continuously differentiable. We assume that $V(\theta)$ has two local minimizers $\theta_1 < 0$, $\theta_2 > 0$ and that the values of $V$ at the two local minima are negative. Therefore, when $a_2 > a_1$, $\theta_2$ is the global minimizer. In addition, we assume that the geometries of $h_1, h_2$ at the two local minimizers are similar, i.e., $h_1(\theta_1) = h_2(\theta_2)$ and $h_1'(\theta_1) = h_2'(\theta_2)$; if we set $g_i(\theta)$ to be the antiderivative of $1/h_i'(\theta)$, then $g_1(\theta_1) = g_2(\theta_2)$. Moreover, we assume that the two convex pieces match smoothly at the junction point, i.e., $h_1'(0) = h_2'(0)$ and $g_1(0) = g_2(0)$. The following two lemmas extend Lemmas 3 and 4 to piecewise strictly convex losses under the above assumptions.

Lemma 9. When $a_2 > a_1$, $V(\theta_2) < V(\theta_1)$. The stationary distribution for resampling satisfies
$$\frac{p_s(\theta_2)}{p_s(\theta_1)} = \exp\left(\frac{2}{\eta}\left(\frac{1}{a_1} - \frac{1}{a_2}\right)\int_{\theta_1}^{0}\frac{1}{h_1'(\theta)}\,d\theta\right) + O(\epsilon) > 1.$$

Proof. In resampling, with probability $a_1$ the gradients on the two intervals $(-\infty, 0)$, $(0, \infty)$ are $h_1'(\theta)$, $\epsilon$, respectively; with probability $a_2$ the gradients are $-\epsilon$, $h_2'(\theta)$, respectively. Therefore, the expectation of the gradient $\mu(\theta)$ is $a_1 h_1'(\theta) + O(\epsilon)$ in $(-\infty, 0)$ and $a_2 h_2'(\theta) + O(\epsilon)$ in $(0, \infty)$. The variance of the gradient $\sigma(\theta)$ is $a_1 a_2 h_1'(\theta)^2 + O(\epsilon)$ in $(-\infty, 0)$ and $a_1 a_2 h_2'(\theta)^2 + O(\epsilon)$ in $(0, \infty)$. The p.d.f. $p_s(t, \theta)$ satisfies
$$\partial_t p_s = \partial_\theta\left(\mu p_s + \frac{\eta}{2}\partial_\theta(\sigma p_s)\right);$$
therefore, the stationary distribution $p_s(\theta)$ satisfies
$$\left(\mu + \frac{\eta}{2}\partial_\theta\sigma\right)p_s + \frac{\eta\sigma}{2}\partial_\theta p_s = 0, \qquad\text{or equivalently,}\qquad \left(\frac{2\mu}{\eta\sigma} + \frac{\partial_\theta\sigma}{\sigma}\right)p_s + \partial_\theta p_s = 0,$$
which implies $p_s(\theta) = \frac{1}{Z}e^{-F(\theta)}$ with normalization constant $Z = \int_{-\infty}^{\infty} e^{-F(\theta)}\,d\theta$ and
$$F(\theta) = \int_{-\infty}^{\theta}\left(\frac{2\mu(\xi)}{\eta\sigma(\xi)} + \frac{\partial_\xi\sigma(\xi)}{\sigma(\xi)}\right)d\xi = \begin{cases} F_1(\theta) - F_1(-\infty), & \theta \le 0, \\ F_2(\theta) - F_2(0) + F_1(0) - F_1(-\infty), & \theta > 0. \end{cases} \quad (13)$$
By inserting $\mu, \sigma$ on the different intervals, one has
$$F_1(\theta) = \frac{2}{\eta a_2}\int\frac{1}{h_1'}\,d\theta + \log\left(a_1 a_2\, h_1'(\theta)^2\right) + O(\epsilon), \qquad F_2(\theta) = \frac{2}{\eta a_1}\int\frac{1}{h_2'}\,d\theta + \log\left(a_1 a_2\, h_2'(\theta)^2\right) + O(\epsilon).$$
Hence, the ratio of the stationary probability at the two local minimizers $\theta_1 < 0$, $\theta_2 > 0$ is
$$\frac{p_s(\theta_1)}{p_s(\theta_2)} = \exp(-F(\theta_1) + F(\theta_2)) = \exp(-F_1(\theta_1) + F_2(\theta_2) + (F_1(0) - F_2(0)))$$
$$= \exp\left(-\frac{2}{\eta a_2}g_1(\theta_1) + \frac{2}{\eta a_1}g_2(\theta_2) + \log\frac{h_2'(\theta_2)^2}{h_1'(\theta_1)^2}\right)\cdot\exp\left(\frac{2}{\eta a_2}g_1(0) - \frac{2}{\eta a_1}g_2(0) + \log\frac{h_1'(0)^2}{h_2'(0)^2}\right) + O(\epsilon),$$
where $g_i(\theta) = \int\frac{1}{h_i'}\,d\theta$, $i = 1, 2$. By the assumptions $g_1(\theta_1) = g_2(\theta_2)$, $h_1'(\theta_1) = h_2'(\theta_2)$, $g_1(0) = g_2(0)$, and $h_1'(0) = h_2'(0)$, one has
$$\frac{p_s(\theta_1)}{p_s(\theta_2)} = \exp\left(\frac{2}{\eta}\left(g_1(0) - g_1(\theta_1)\right)\left(\frac{1}{a_2} - \frac{1}{a_1}\right)\right) + O(\epsilon).$$
Since $a_2 > a_1 > 0$, $\frac{1}{a_2} - \frac{1}{a_1} < 0$. Because of the strict convexity of $h_1$, $h_1'(\theta) > 0$ in $(\theta_1, 0)$; therefore,
$$g_1(0) - g_1(\theta_1) = \int_{\theta_1}^{0}\frac{1}{h_1'(\theta)}\,d\theta > 0.$$
Therefore,
$$\frac{p_s(\theta_1)}{p_s(\theta_2)} = \exp\left(\frac{2}{\eta}\left(g_1(0) - g_1(\theta_1)\right)\left(\frac{1}{a_2} - \frac{1}{a_1}\right)\right) + O(\epsilon) < 1.$$

Lemma 10. When $a_2 > a_1$, $V(\theta_2) < V(\theta_1)$. Under the condition $\frac{f_1}{f_2} > \frac{a_1}{a_2}$, the stationary distribution for reweighting satisfies
$$\frac{p_w(\theta_2)}{p_w(\theta_1)} = \exp\left(\frac{2}{\eta}\left(\frac{f_2}{f_1 a_2} - \frac{f_1}{f_2 a_1}\right)\int_{\theta_1}^{0}\frac{1}{h_1'(\theta)}\,d\theta\right) + O(\epsilon) < 1.$$
One sufficient condition for $\frac{f_1}{f_2} > \frac{a_1}{a_2}$ is that $f_1, f_2$ are significantly different from $a_1, a_2$ in the sense that $f_1 > f_2$ while the actual population proportions satisfy $a_1 < a_2$.

Proof. In reweighting, with probability $f_1$ the gradients over the two intervals $(-\infty, 0)$, $(0, \infty)$ are $\frac{a_1}{f_1}h_1'(\theta)$, $\frac{a_1}{f_1}\epsilon$, respectively; with probability $f_2$ the gradients are $-\frac{a_2}{f_2}\epsilon$, $\frac{a_2}{f_2}h_2'(\theta)$, respectively. Therefore, the expectation of the gradient $\mu(\theta)$ is $a_1 h_1'(\theta) + O(\epsilon)$ in $(-\infty, 0)$ and $a_2 h_2'(\theta) + O(\epsilon)$ in $(0, \infty)$. The variance of the gradient $\sigma(\theta)$ is $\frac{f_2}{f_1}a_1^2 h_1'(\theta)^2 + O(\epsilon)$ in $(-\infty, 0)$ and $\frac{f_1}{f_2}a_2^2 h_2'(\theta)^2 + O(\epsilon)$ in $(0, \infty)$.
From an analysis similar to that of Lemma 9, the stationary distribution is $p_w(\theta) = \frac{1}{Z}e^{-F(\theta)}$ with the same $F(\theta)$ defined in equation (13), but with $F_1, F_2$ given by
$$F_1(\theta) = \frac{2f_1}{\eta f_2 a_1}\int\frac{1}{h_1'}\,d\theta + \log\left(\frac{f_2 a_1^2}{f_1}h_1'(\theta)^2\right) + O(\epsilon), \qquad F_2(\theta) = \frac{2f_2}{\eta f_1 a_2}\int\frac{1}{h_2'}\,d\theta + \log\left(\frac{f_1 a_2^2}{f_2}h_2'(\theta)^2\right) + O(\epsilon).$$
Hence, the ratio of the stationary probability at the two local minimizers $\theta_1 < 0$, $\theta_2 > 0$ is
$$\frac{p_w(\theta_1)}{p_w(\theta_2)} = \exp(-F_1(\theta_1) + F_2(\theta_2) + (F_1(0) - F_2(0)))$$
$$= \exp\left(-\frac{2f_1}{\eta f_2 a_1}g_1(\theta_1) + \frac{2f_2}{\eta f_1 a_2}g_2(\theta_2) + \log\frac{f_1^2 a_2^2}{f_2^2 a_1^2}\cdot\frac{h_2'(\theta_2)^2}{h_1'(\theta_1)^2}\right)\cdot\exp\left(\frac{2f_1}{\eta f_2 a_1}g_1(0) - \frac{2f_2}{\eta f_1 a_2}g_2(0) + \log\frac{f_2^2 a_1^2}{f_1^2 a_2^2}\cdot\frac{h_1'(0)^2}{h_2'(0)^2}\right) + O(\epsilon),$$
where $g_i(\theta) = \int\frac{1}{h_i'}\,d\theta$, $i = 1, 2$. By the assumptions $g_1(\theta_1) = g_2(\theta_2)$, $h_1'(\theta_1) = h_2'(\theta_2)$, $g_1(0) = g_2(0)$, and $h_1'(0) = h_2'(0)$, one has
$$\frac{p_w(\theta_1)}{p_w(\theta_2)} = \exp\left(\frac{2}{\eta}\left(g_1(0) - g_1(\theta_1)\right)\left(\frac{f_1}{f_2 a_1} - \frac{f_2}{f_1 a_2}\right)\right) + O(\epsilon).$$
Because of the strict convexity of $h_1$, one has $g_1(0) - g_1(\theta_1) > 0$. By the assumption $\frac{f_1}{f_2} > \frac{a_1}{a_2}$ (note that $(\frac{f_1}{f_2})^2 > \frac{a_1}{a_2}$ in the regime $f_1 > f_2$ considered here), one has $\frac{f_1}{f_2 a_1} - \frac{f_2}{f_1 a_2} > 0$, which gives $\frac{p_w(\theta_1)}{p_w(\theta_2)} > 1$.

Proof of Lemmas 5 and 6. We can further extend the 1D results to a finite number of local minima, as stated in Lemmas 5 and 6. As in the case of two local minima, we assume that the $h_i(\theta)$ have a similar geometry at the minimizers and that $h_i(\theta), h_{i+1}(\theta)$ are smooth enough at the junction point $\theta_i$. In order to obtain the ratio of the stationary distribution at two arbitrary local minimizers, we make the additional assumption that $g_i(\theta_{i-1}) = g_i(\theta_i)$ for all $i$, where $g_i(\theta)$ is the antiderivative of $1/h_i'(\theta)$. Intuitively, this assumption requires that each local minimum have an equal barrier on both sides.
To be more specific, the assumptions mentioned above are the following: at all local minimizers, $h_i(\theta_i^*) = h_j(\theta_j^*) < 0$ and $h_i'(\theta_i^*) = h_j'(\theta_j^*)$; letting $g_i(\theta) = \int\frac{1}{h_i'(\theta)}\,d\theta$, we have $g_i(\theta_i^*) = g_j(\theta_j^*)$ for any $i \ne j$; at all junction points, $h_i'(\theta_i) = h_{i+1}'(\theta_i)$ and $g_i(\theta_{i-1}) = g_i(\theta_i) = g_{i+1}(\theta_i)$ for all $i$. Lemmas 5 and 6 hold under these assumptions.

Proof of Lemma 5. For resampling, with probability $a_i$ the gradient is $h_i'(\theta)$ for $\theta \in (\theta_{i-1}, \theta_i)$ and $O(\epsilon)$ for $\theta \notin (\theta_{i-1}, \theta_i)$. Therefore, the expectation and variance on $(\theta_{i-1}, \theta_i)$ are $\mu = a_i h_i'(\theta) + O(\epsilon)$ and $\sigma = a_i(1 - a_i)h_i'(\theta)^2 + O(\epsilon)$. The stationary solution is $p_s(\theta) = \frac{1}{Z}e^{-F(\theta)}$, with
$$F(\theta) = F_i(\theta) - F_i(\theta_{i-1}) + \sum_{j=1}^{i-1}\left(F_j(\theta_j) - F_j(\theta_{j-1})\right) \quad\text{for } \theta \in (\theta_{i-1}, \theta_i),$$
where $Z = \int_{-\infty}^{\infty} e^{-F(\theta)}\,d\theta$ is a normalization constant and
$$F_i(\theta) = \frac{2}{\eta(1 - a_i)}\int\frac{1}{h_i'(\theta)}\,d\theta + \log\left(a_i(1 - a_i)h_i'(\theta)^2\right) + O(\epsilon).$$
Therefore, the ratio of the stationary probability at any two local minimizers $\theta_p^*, \theta_q^*$ ($p < q$) is
$$\frac{p_s(\theta_p^*)}{p_s(\theta_q^*)} = \exp\left(-\Big(F_p(\theta_p^*) - F_p(\theta_{p-1}) + \sum_{j=1}^{p-1}(F_j(\theta_j) - F_j(\theta_{j-1}))\Big) + \Big(F_q(\theta_q^*) - F_q(\theta_{q-1}) + \sum_{j=1}^{q-1}(F_j(\theta_j) - F_j(\theta_{j-1}))\Big)\right)$$
$$= \exp\left(-F_p(\theta_p^*) + F_q(\theta_q^*) + \sum_{j=p}^{q-1}\left(F_j(\theta_j) - F_{j+1}(\theta_j)\right)\right)$$
$$= \exp\left(-\frac{2}{\eta(1 - a_p)}g_p(\theta_p^*) + \frac{2}{\eta(1 - a_q)}g_q(\theta_q^*) + \log\frac{a_q(1 - a_q)h_q'(\theta_q^*)^2}{a_p(1 - a_p)h_p'(\theta_p^*)^2}\right)\cdot\exp\left(\sum_{j=p}^{q-1}\left(\frac{2}{\eta(1 - a_j)}g_j(\theta_j) - \frac{2}{\eta(1 - a_{j+1})}g_{j+1}(\theta_j) + \log\frac{a_j(1 - a_j)h_j'(\theta_j)^2}{a_{j+1}(1 - a_{j+1})h_{j+1}'(\theta_j)^2}\right)\right) + O(\epsilon).$$
Published as a conference paper at ICLR 2021

By the assumptions $g_p(\theta_p^*) = g_q(\theta_q^*)$, $h_p'(\theta_p^*) = h_q'(\theta_q^*)$, and $g_i(\theta_{i-1}) = g_i(\theta_i) = g_{i+1}(\theta_i)$, $h_i'(\theta_i) = h_{i+1}'(\theta_i)$ for all $i$, the above ratio simplifies to
$$\frac{p_s(\theta_p^*)}{p_s(\theta_q^*)} = \exp\left(\frac{2}{\eta}\left(g_p(\theta_p) - g_p(\theta_p^*)\right)\left(\frac{1}{1 - a_p} - \frac{1}{1 - a_q}\right)\right) + O(\epsilon) \;\begin{cases} > 1, & \text{if } a_p > a_q; \\ < 1, & \text{if } a_p < a_q, \end{cases}$$
where the last inequality follows from $g_p(\theta_p) - g_p(\theta_p^*) = \int_{\theta_p^*}^{\theta_p}\frac{1}{h_p'(\theta)}\,d\theta > 0$, a consequence of the strict convexity of $h_p$.

Proof of Lemma 6. For reweighting, with probability $f_i$ the gradient is $\frac{a_i}{f_i}h_i'(\theta)$ for $\theta \in (\theta_{i-1}, \theta_i)$ and $O(\epsilon)$ for $\theta \notin (\theta_{i-1}, \theta_i)$. Therefore, the expectation and variance on $(\theta_{i-1}, \theta_i)$ are $\mu = a_i h_i'(\theta) + O(\epsilon)$ and $\sigma = \frac{(1 - f_i)a_i^2}{f_i}h_i'(\theta)^2 + O(\epsilon)$. The stationary solution is $p_w(\theta) = \frac{1}{Z}e^{-F(\theta)}$, with
$$F(\theta) = F_i(\theta) - F_i(\theta_{i-1}) + \sum_{j=1}^{i-1}\left(F_j(\theta_j) - F_j(\theta_{j-1})\right) \quad\text{for } \theta \in (\theta_{i-1}, \theta_i),$$
where $Z = \int_{-\infty}^{\infty} e^{-F(\theta)}\,d\theta$ is a normalization constant and
$$F_i(\theta) = \frac{2f_i}{\eta a_i(1 - f_i)}\int\frac{1}{h_i'(\theta)}\,d\theta + \log\left(\frac{(1 - f_i)a_i^2}{f_i}h_i'(\theta)^2\right) + O(\epsilon).$$
Therefore, the ratio of the stationary probability at any two local minimizers $\theta_p^*, \theta_q^*$ is
$$\frac{p_w(\theta_p^*)}{p_w(\theta_q^*)} = \exp\left(-F_p(\theta_p^*) + F_q(\theta_q^*) + \sum_{j=p}^{q-1}\left(F_j(\theta_j) - F_{j+1}(\theta_j)\right)\right)$$
$$= \exp\left(-\frac{2f_p}{\eta a_p(1 - f_p)}g_p(\theta_p^*) + \frac{2f_q}{\eta a_q(1 - f_q)}g_q(\theta_q^*) + \log\frac{f_p(1 - f_q)a_q^2 h_q'(\theta_q^*)^2}{f_q(1 - f_p)a_p^2 h_p'(\theta_p^*)^2}\right)\cdot\exp\left(\sum_{j=p}^{q-1}\left(\frac{2f_j}{\eta a_j(1 - f_j)}g_j(\theta_j) - \frac{2f_{j+1}}{\eta a_{j+1}(1 - f_{j+1})}g_{j+1}(\theta_j) + \log\frac{f_{j+1}(1 - f_j)a_j^2 h_j'(\theta_j)^2}{f_j(1 - f_{j+1})a_{j+1}^2 h_{j+1}'(\theta_j)^2}\right)\right) + O(\epsilon).$$
By the same assumptions as in the proof of Lemma 5, the above ratio simplifies to
$$\frac{p_w(\theta_p^*)}{p_w(\theta_q^*)} = \exp\left(\frac{2}{\eta}\left(g_p(\theta_p) - g_p(\theta_p^*)\right)\left(\frac{f_p}{a_p(1 - f_p)} - \frac{f_q}{a_q(1 - f_q)}\right)\right) + O(\epsilon).$$
Follow-up discussion of Lemmas 5 and 6. We first note that $g_p(\theta_p) - g_p(\theta_p^*) = \int_{\theta_p^*}^{\theta_p}\frac{1}{h_p'(\theta)}\,d\theta > 0$ due to the strict convexity of $h_p$. Therefore, one can see from Lemma 5 that for resampling, the stationary solution always places the highest probability at the global minimizer. On the other hand, for the stationary solution of reweighting in Lemma 6, let us consider the case $a_p > a_q$. In this case $V(\theta_p^*) < V(\theta_q^*)$, so one would like the above ratio to be larger than 1, which requires $\frac{f_p}{a_p(1 - f_p)} - \frac{f_q}{a_q(1 - f_q)} > 0$. Note that if $f_p = a_p$ and $f_q = a_q$, this term is always positive. However, when $f_p, f_q$ are significantly different from $a_p, a_q$, in the sense that $f_p < f_q$ while $f_p < a_p$ and $f_q > a_q$, then $\frac{f_p}{a_p(1 - f_p)} - \frac{f_q}{a_q(1 - f_q)} < 0$, which leads to $\frac{p_w(\theta_p^*)}{p_w(\theta_q^*)} < 1$, i.e., a higher probability of converging to $\theta_q^*$, which is not desirable. To sum up, Lemma 6 shows that for reweighting, the stationary solution need not place the highest probability at the global minimizer when the sampling proportions are significantly different from the population proportions.
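The sign condition in this discussion is easy to check numerically. The snippet below (with made-up proportions) evaluates $\kappa(f, a) = \frac{f}{a(1-f)}$, whose difference $\kappa(f_p, a_p) - \kappa(f_q, a_q)$ determines which minimizer the reweighting stationary distribution favors:

```python
# Numeric check of the sign condition above: kappa(p) - kappa(q) with
# kappa = f / (a * (1 - f)) decides which minimizer reweighting favors.
# All proportions below are illustrative choices.
def kappa(f, a):
    return f / (a * (1.0 - f))

a_p, a_q = 0.7, 0.3          # population proportions, a_p > a_q

# Unbiased sampling (f = a): the difference is positive, so the global
# minimizer is favored.
d_unbiased = kappa(a_p, a_p) - kappa(a_q, a_q)

# Strongly biased sampling with f_p < f_q while a_p > a_q: the sign flips.
f_p, f_q = 0.2, 0.8
d_biased = kappa(f_p, a_p) - kappa(f_q, a_q)

print("unbiased difference:", d_unbiased, " biased difference:", d_biased)
```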

B.4 PROOF OF LEMMA 8

Proof. By the variances of the gradients for resampling and reweighting in (5), and given that at the stationary points $\mathbb{E}[\nabla V(\theta_1^*)] = \mathbb{E}[\nabla V(\theta_2^*)] = 0$, one can omit the last term in the variance. In addition, since $\nabla V_1(\theta_2^*)$ and $\nabla V_2(\theta_1^*)$ are $O(\epsilon)$ relative to $\nabla V_1(\theta_1^*)$ and $\nabla V_2(\theta_2^*)$ by assumption, all higher-order terms can be collected into an $O(\sqrt{\epsilon})$ term. Lemma 8 then follows from Lemma 7.

C A JUSTIFICATION OF THE SDE APPROXIMATION

The stochastic differential equation approximation of SGD with a data-dependent covariance for the Gaussian noise was first introduced in (Li et al., 2017) and justified in the weak sense. Consider the SDE
$$d\Theta = b(\Theta)\,dt + \sigma(\Theta)\,dB. \quad (14)$$
The Euler–Maruyama discretization with time step $\eta$ results in
$$\Theta_{k+1} = \Theta_k + \eta b(\Theta_k) + \sqrt{\eta}\,\sigma(\Theta_k)Z_k, \qquad Z_k \sim N(0, 1), \qquad \Theta_0 = \theta_0.$$
In our case, $b(\cdot) = -V'(\cdot)$. When $b$ is Lipschitz continuous and satisfies some technical smoothness conditions, according to (Li et al., 2017), for any function $g$ from a smooth class $\mathcal{M}$ there exist $C > 0$ and $\alpha > 0$ such that for all $k = 0, 1, 2, \dots, N$,
$$|\mathbb{E}[g(\Theta_{k\eta})] - \mathbb{E}[g(\theta_k)]| \le C\eta^{\alpha}.$$
However, as the loss functions considered in this paper have jump discontinuities in the first derivative, the classical approximation error results for SDEs do not apply. In fact, $V \notin C^1(\mathbb{R}^n)$ is a common issue in machine learning and deep neural networks, as many loss functions involve non-smooth activation functions such as ReLU and leaky ReLU. In our case, we need to justify the SDE approximation adopted in Section 3. It turns out that a strong approximation error bound can be obtained if

• the noise coefficient $\sigma$ is Lipschitz continuous and non-degenerate, and
• the drift coefficient $b$ is piecewise Lipschitz continuous, in the sense that $b$ has finitely many discontinuity points $-\infty = \xi_0 < \xi_1 < \dots < \xi_m < \xi_{m+1} = \infty$ and $b$ is Lipschitz continuous on each interval $(\xi_{i-1}, \xi_i)$.

Under these conditions, the following approximation result holds: there exists $C > 0$ such that for all $k = 0, 1, 2, \dots, N$,
$$\mathbb{E}[|\Theta_{k\eta} - \theta_k|] \le C\sqrt{\eta}. \quad (16)$$
Here $\Theta_{k\eta}$ is the solution to the SDE at time $k\eta$. The proof strategy closely follows (Müller-Gronbach et al., 2020). The key is to construct a bijective mapping $G : \mathbb{R} \to \mathbb{R}$ that transforms (14) into an SDE with Lipschitz continuous coefficients. With such a bijection $G$, one can define a stochastic process $Z : [0, T] \times \Omega \to \mathbb{R}$ by $Z_t = G(\Theta_t)$; the transformed SDE is
$$dZ_t = \tilde{b}(Z_t)\,dt + \tilde{\sigma}(Z_t)\,dB_t, \qquad t \in [0, T], \qquad Z_0 = G(\Theta_0),$$
with $\tilde{b} = (G' \cdot b + \frac{1}{2}G'' \cdot \sigma^2) \circ G^{-1}$ and $\tilde{\sigma} = (G' \cdot \sigma) \circ G^{-1}$. As the SGD updates can essentially be viewed as data from the Euler–Maruyama scheme, considering $Z_k$ as the Euler–Maruyama updates of the transformed SDE leads to
$$\mathbb{E}[|\Theta_{k\eta} - \theta_k|] \le c_1\mathbb{E}[|Z_{k\eta} - G \circ \theta_k|] = c_1\mathbb{E}[|Z_{k\eta} - Z_k + Z_k - G \circ \theta_k|] \le c_2\sqrt{\eta} + c_1\mathbb{E}[|Z_k - G \circ \theta_k|].$$
To control the second term, we introduce the time-continuous interpolation
$$\tilde{\theta}_t := \theta_k + b(\theta_k)(t - k\eta) + \sqrt{t - k\eta}\,\sigma(\theta_k)Z_k, \qquad t \in [k\eta, (k+1)\eta].$$
Then, as shown in (Müller-Gronbach et al., 2020),
$$\mathbb{E}[|Z_k - G \circ \theta_k|] \le c\sqrt{\eta} + c\,\mathbb{E}\left[\int_0^{k\eta} 1_B(\tilde{\theta}_t, \theta_k)\,dt\right],$$
with $B$ being the set of pairs $(y_1, y_2) \in \mathbb{R}^2$ for which the joint Lipschitz estimate of $|b(y_1) - b(y_2)|$ does not apply due to at least one discontinuity. In (Müller-Gronbach et al., 2020), this term is estimated by
$$\mathbb{E}\left[\int_0^{k\eta} 1_B(\tilde{\theta}_t, \theta_k)\,dt\right] \le c\sqrt{\eta},$$
which leads to (16).

Figure 4: The left plot shows the approximate value function obtained by the two methods. The right plot shows the evolution of the relative error $\log(e_t/e_0)$, where the absolute error is $e_t = \|V_t(s) - V^{\pi}(s)\|_2^2$. RW and RS in the upper-right corner stand for reweighting and resampling, respectively. $c$ determines the behavior policy $\mu(a_i|s) = \frac{1}{2} + ca_i$. The value function is trained on a trajectory of length $10^5$ generated by the behavior policy. The value function obtained from resampling is fairly close to the exact value function, while the result of reweighting gets worse as the behavior policy gets further from the target policy.

D NUMERICAL COMPARISONS WITH DIFFERENT LEARNING RATES

In this section, we present further numerical results on the effect of the learning rate in our toy examples. Figure 5 corresponds to the example in Section 3, and Figure 6 corresponds to the example in Section 4. Recall from Section 3 that with $a_1 = \frac{1}{2} - \epsilon$ and $a_2 = \frac{1}{2} + \epsilon$ for a small constant $\epsilon > 0$ and $f_2/f_1 \ll 1$, the global minimizer $\theta = 1$ is stochastically stable for reweighting only if $\eta(1 + f_1/f_2) \le 4 + O(\epsilon)$. This condition is rather stringent on the learning rate $\eta$ since $f_1/f_2 \gg 1$. The local minimizer $\theta = -1$, on the other hand, is stable if $\eta(1 + f_2/f_1) \le 4 + O(\epsilon)$, which is satisfied for a much broader range of $\eta$ because $f_2/f_1 \ll 1$. In other words, for a fixed learning rate $\eta$, when the ratio $f_2/f_1$ between the sampling proportions is sufficiently small, the desired minimizer $\theta = 1$ is no longer stochastically stable with respect to SGD. While this stability analysis applies only to learning rates $\eta$ of a finite size, the SDE analysis of Section 4 shows that reweighting remains unreliable even for small $\eta$, which is the regime illustrated in Figure 6.

Figure 5: A comparison of reweighting (upper row) and resampling (lower row) with $a_1/a_2 = 0.4/0.6$ and $f_1/f_2 = 0.9/0.1$ at various learning rates $\eta$. All experiments start at $\theta_0 = 1.6$. Unless the learning rate satisfies $\eta < 0.4$, resampling is more stable near the minimizer $\theta = 1$.

Figure 6: A comparison of reweighting (upper row) and resampling (lower row) with $a_1/a_2 = 0.4/0.6$, $f_1/f_2 = 0.9/0.1$, and $\epsilon = 0.1$ at various learning rates $\eta$. All experiments start at $\theta_0 = 0.9$. Unless the learning rate satisfies $\eta < 0.12$, resampling is more reliable in the sense that its trajectory stays around the desired minimizer.

Footnote (Section 4): the formal definition of $\theta^{\bullet}$ is as follows. Let $\theta(t)$ be a continuous path with $\theta(0) = \theta_1^*$ and $\theta(1) = \theta^*$, and let $\hat{\theta}(t) = \arg\min_{\theta(\cdot)} \sup_{t \in [0,1]} V(\theta(t))$ be a path with minimal saddle-point height among all such paths; $\theta^{\bullet}$ is the point of $\hat{\theta}$ attaining $\sup_{t \in (0,1)} V(\hat{\theta}(t))$, i.e., the saddle point of this path.

Footnotes for the experiments: https://scikit-learn.org/stable (the scikit-learn library) and https://www.kaggle.com/camnugent/california-housing-prices (the California Housing Prices data set).
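The $O(\sqrt{\eta})$ strong error bound (16) of Appendix C can be probed with a small experiment. The sketch below (drift taken from the double-well example; the value of $\sigma$, the step sizes, and the sample count are illustrative choices) compares Euler–Maruyama at a coarse step $\eta$ against a reference path driven by the same Brownian increments at step $\eta/16$, and checks that the mean gap shrinks as $\eta$ decreases:

```python
import numpy as np

# Strong-error behaviour of Euler-Maruyama for an SDE whose drift b = -V' is
# only piecewise Lipschitz (the double-well example).  A coarse step eta is
# compared against a reference driven by the same Brownian increments at step
# eta/16.  sigma, the etas, and the sample count are illustrative choices.
rng = np.random.default_rng(3)
sigma, T = 0.3, 10.0

def b(t):
    """-V'(theta) for the piecewise-linear double well with a1/a2 = 0.4/0.6."""
    a1, a2, eps = 0.4, 0.6, 0.01
    g1 = -1.0 if t < -1 else (1.0 if t < 0 else eps)
    g2 = -eps if t < 0 else (-1.0 if t < 1 else 1.0)
    return -(a1 * g1 + a2 * g2)

def em_path(dW, dt, x0=1.5):
    """Euler-Maruyama endpoint for the given increments and step size."""
    x = x0
    for inc in dW:
        x += b(x) * dt + sigma * inc
    return x

errors = []
for eta in (0.1, 0.05, 0.025):
    errs = []
    for _ in range(100):
        n_coarse = round(T / eta)
        dW_fine = rng.normal(0.0, np.sqrt(eta / 16), n_coarse * 16)
        x_ref = em_path(dW_fine, eta / 16)
        x_coarse = em_path(dW_fine.reshape(n_coarse, 16).sum(axis=1), eta)
        errs.append(abs(x_ref - x_coarse))
    errors.append(float(np.mean(errs)))

print("mean strong errors for eta = 0.1, 0.05, 0.025:", errors)
```

The mean error should decrease as the step size is refined, in line with the $\sqrt{\eta}$ scaling.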

The loss is the binary cross-entropy with a 3-layer neural network. We see that, averaged over 10 trials, the resampling method (oversampling) achieves the lowest training loss and the highest ROC-AUC score on the test data among all tested cases.

Nonlinear Regression. This experiment uses the California Housing Prices data set to predict the median house values. The target median house values, ranging from 15k to 500k, are distributed quite non-uniformly. We select the subgroups with median house values > 400k (1726 in total) and < 200k (11767 in total) and combine them to form our data set. In the preprocessing step, we drop the "ocean proximity" feature and randomly set 30% of the data to be the test data. The remaining training data set with 8 features is fed into a 3-layer neural network. The population proportion of the two subgroups is assumed to be a1/a2 ≈ 1, while resampling and reweighting are tested with various sampling ratios f1/f2 near 11767/1726. Their performance is also compared with the baseline. In each test, the mean squared error (MSE) is chosen as the loss function and Adam is used as the optimizer. The batch size is 32 and the number of epochs is 400 for each case. As shown in Table 2, resampling significantly outperforms reweighting for all sampling ratios in terms of a lower averaged MSE, and its good stability is reflected in its lowest standard deviation over multiple runs.
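The two corrections can also be compared outside of a full training run by looking at the stochastic gradients they produce. The snippet below is a minimal sketch on synthetic two-group data with a logistic loss (all sizes, proportions, and the evaluation point are illustrative assumptions, not the paper's setup): it verifies that resampling and reweighting share the same expected gradient, namely the gradient of the population loss, while reweighting carries much larger gradient noise when $f_1/f_2$ is far from $a_1/a_2$:

```python
import numpy as np

# Resampling vs. reweighting on synthetic imbalanced two-group data with a
# logistic loss.  Both corrections have the population-loss gradient as their
# expected stochastic gradient; they differ in the gradient noise.
# All sizes and proportions below are illustrative assumptions.
rng = np.random.default_rng(4)
a1, a2 = 0.5, 0.5                 # assumed population proportions
n1, n2 = 9000, 1000               # biased sample: f1/f2 = 0.9/0.1
X = np.vstack([rng.normal(+1.0, 1.0, (n1, 2)), rng.normal(-1.0, 1.0, (n2, 2))])
y = np.hstack([np.ones(n1), -np.ones(n2)])
f1, f2 = n1 / (n1 + n2), n2 / (n1 + n2)

def per_sample_grads(theta, idx):
    """Logistic-loss gradients of the samples in idx at parameter theta."""
    m = y[idx] * (X[idx] @ theta)
    return (-y[idx] / (1.0 + np.exp(m)))[:, None] * X[idx]

theta = np.array([0.1, -0.2])     # an arbitrary evaluation point
G = per_sample_grads(theta, np.arange(n1 + n2))
g_pop = a1 * G[:n1].mean(0) + a2 * G[n1:].mean(0)   # population gradient

mc = 200_000
# resampling: draw group i with probability a_i, then a sample within it
grp = rng.random(mc) < a2
idx_rs = np.where(grp, n1 + rng.integers(0, n2, mc), rng.integers(0, n1, mc))
g_rs = per_sample_grads(theta, idx_rs)

# reweighting: draw uniformly from the biased data, scale gradient by a_i/f_i
idx_rw = rng.integers(0, n1 + n2, mc)
scale = np.where(idx_rw < n1, a1 / f1, a2 / f2)
g_rw = scale[:, None] * per_sample_grads(theta, idx_rw)

print("population gradient:", g_pop)
print("resampling mean:", g_rs.mean(0), " reweighting mean:", g_rw.mean(0))
print("total gradient variance: resampling", g_rs.var(0).sum(),
      " reweighting", g_rw.var(0).sum())
```

The two sample means agree with the population gradient up to Monte Carlo error, while the reweighting variance is several times larger; this is the noise gap that the analysis above traces back to the factor $a_i^2/f_i$ versus $a_i$.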

Table 2: Mean squared errors (MSE) for the nonlinear regression problem. RS stands for resampling and RW for reweighting. The weights used in reweighting are a1/f1 and a2/f2, respectively. For each case, we run the experiment 10 times and report the corresponding mean and standard deviation. Resampling (oversampling the minority group) achieves the lowest mean and standard deviation of MSE among all tested cases.

ACKNOWLEDGEMENTS

The work of L.Y. and Y.Z. is partially supported by the U.S. Department of Energy via the Scientific Discovery through Advanced Computing (SciDAC) program and by the National Science Foundation under award DMS-1818449. J.A. is supported by the Joe Oliger Fellowship from Stanford University.

