ON THE IMPLICIT BIAS OF WEIGHT DECAY IN SHALLOW UNIVARIATE RELU NETWORKS

Abstract

We give a complete characterization of the implicit bias of infinitesimal weight decay (i.e., an $\ell_2$ penalty on network weights) in the modest setting of univariate one-layer ReLU networks. Our main result is a surprisingly simple geometric description of all one-layer ReLU networks that exactly fit a dataset $D = \{(x_i, y_i)\}$ with the minimum value of the $\ell_2$-norm of the neuron weights. Specifically, we prove that such functions must be either concave or convex between any two consecutive data sites $x_i$ and $x_{i+1}$. Our description implies that interpolating ReLU networks with weak $\ell_2$-regularization achieve the best possible $\ell_\infty$ generalization error for learning 1d Lipschitz functions, up to universal constants.

1. INTRODUCTION

The ability of overparameterized neural networks to simultaneously fit training data (i.e., interpolate) and generalize to unseen test data (i.e., extrapolate) is a robust empirical finding that underpins the success of deep learning in computer vision He et al. (2016); Krizhevsky et al. (2012), natural language processing Brown et al. (2020), and reinforcement learning Jumper et al. (2021); Silver et al. (2016); Vinyals et al. (2019). This observation is surprising when viewed through the lens of traditional learning theory Bartlett & Mendelson (2002); Vapnik & Chervonenkis (1971), chiefly because such complexity-based methods are agnostic to the choice of optimizer and seek to predict generalization based solely on the complexity of the overall hypothesis class and on how well a learned model fits the training data. In an overparameterized neural network, however, the quality of predictions at test time often varies dramatically across settings of trainable parameters (e.g., weights and biases) that exactly fit all training data Zhang et al. (2017). Which setting of parameters is learned depends crucially on the optimization procedure. An insightful analysis of generalization in the presence of overparameterization must therefore combine properties of the model class with the often subtle criteria by which different optimizers select among the minimizers of an empirical risk. This has led to a vibrant subfield of deep learning theory that analyzes the implicit bias, or implicit regularization, of optimizers used in practice Arora et al. (2019); Blanc et al. (2020); Gunasekar et al. (2018); Hanin & Sun (2021); Jacot et al. (2020); Ma et al. (2018); Razin & Cohen (2020); Smith et al. (2021).
The high-level goal of this line of work is to explain how optimization hyperparameters such as the initialization scheme, learning rate, batch size, data augmentation scheme, and choice of explicit regularizer influence which of the many global minima of the empirical risk are selected in the course of optimization. A key difficulty in studying implicit bias is that it is unclear how to understand the effect of particular optimization hyperparameters concretely in terms of the network function. For example, a well-chosen initialization for gradient-based optimizers is key to ensuring good generalization properties of the resulting learned network He et al. (2015); Mishkin & Matas (2015); Xiao et al. (2018). However, the corresponding geometric or analytic properties of the learned network are often hard to pin down, obscuring our understanding of what it is about the learned functions that encourages generalization. In a similar vein, it is standard practice to experiment with explicit regularizers such as an $\ell_2$ penalty on network weights. While the effect of this choice is easy to describe in terms of model parameters (e.g., it tends to make them smaller), it is typically challenging to translate such a description into […] 2019) explore and develop the fact that $\ell_2$ regularization on parameters in this setting is provably equivalent to penalizing the total variation of the derivative of the network function (cf. Theorem 1.3 below). These articles apply to networks with any input dimension. In this article, however, we consider the simplest case of input dimension 1 and significantly refine these prior results to give a complete geometric answer to how interpolating ReLU networks with a weak $\ell_2$ penalty use training data to make predictions on unseen data. Our main results are:

1. We consider a dataset $D = \{(x_i, y_i)\}$ with $x_i, y_i \in \mathbb{R}$ and give a complete description of the space of one-layer ReLU networks with a single linear unit that fit the data and, among all such interpolating networks, do so with the minimal $\ell_2$-norm of the neuron weights. There are infinitely many such networks, and they are characterized by the constraint that they fit the data with as few inflection points as possible (see Thms. 1.1, 1.2).
2. The above description of the space of interpolants of $D$ gives uniform control of the Lipschitz constant of any such interpolant and immediately yields sharp generalization bounds for learning 1d Lipschitz functions. This is stated in Corollary 1.1. Specifically, if the dataset $D$ is generated by setting $y_i = f^*(x_i)$ for a Lipschitz function $f^* : [0,1] \to \mathbb{R}$, then any one-layer ReLU network with a single linear unit that interpolates $D$ with minimal $\ell_2$-norm of the neuron weights will generalize as well as possible to unseen data, up to a small universal multiplicative constant. To the author's knowledge, this is the first time such generalization guarantees have been obtained.

Let us denote $[t]_+ := \mathrm{ReLU}(t) = \max\{0, t\}$ and consider a one-layer ReLU network

$$z(x) = z(x;\theta) = z(x;\theta,n) := ax + b + \sum_{j=1}^{n} W_j^{(2)}\left[W_j^{(1)}x + b_j^{(1)}\right]_+ \qquad (1)$$

with input and output dimensions equal to 1 and a single linear unit $ax + b$.¹ For a given dataset

$$D = \{(x_i, y_i),\ i = 1, \ldots, m\}, \qquad -\infty < x_1 < \cdots < x_m < \infty, \qquad y_i \in \mathbb{R},$$

if the number of datapoints $m$ is smaller than the network width $n$, there are infinitely many choices of the parameter vector $\theta$ for which $z(x;\theta)$ interpolates (i.e., fits) the data:

$$z(x_i;\theta) = y_i, \qquad \forall\, i = 1, \ldots, m. \qquad (2)$$

Without further information about how $\theta$ was selected, little can be said about the function $x \mapsto z(x;\theta)$ on intervals $(x_i, x_{i+1})$ between two consecutive datapoints when $n$ is much larger than $m$. This precludes useful generalization guarantees that hold uniformly over all $\theta$ subject only to the interpolation condition (2).

In practice, interpolants are not chosen arbitrarily. Instead, they are typically learned by some variant of gradient descent starting from a random initialization. For a given network architecture, initialization scheme, optimizer, data augmentation scheme, regularizer, and so on, understanding how the learned network uses the known labels $\{y_i,\ i = 1, \ldots, m\}$ to extrapolate values of $z(x;\theta)$ for $x$ in intervals $(x_i, x_{i+1})$ away from the datapoints in $D$ is an important open problem. To obtain nontrivial generalization estimates and make progress on this problem, a fruitful line of inquiry in prior work has been to search for additional complexity measures based on margins Wei et al. […] 2018), etc., that, while perhaps not explicitly regularized, are hopefully small in trained networks. The idea is that small values of these complexity measures impose additional constraints on the capacity of the space of learned networks. We refer the interested reader to Jiang et al. (2019) for a review and empirical comparison of many such approaches.

In this article, we take a different approach to studying generalization. We do not seek general results that are valid for any network architecture.
Instead, our goal is to describe completely, in concrete geometric terms, the properties of one-layer ReLU networks $z(x;\theta)$ that interpolate a dataset $D$ in the sense of (2) with the minimal possible $\ell_2$ penalty

$$C(\theta) = C(\theta, n) = \frac{1}{2}\sum_{j=1}^{n}\left(\left(W_j^{(1)}\right)^2 + \left(W_j^{(2)}\right)^2\right)$$

on the neuron weights. More precisely, we study the space of ridgeless ReLU interpolants RidgelessReLU(D) of a dataset $D$, defined by

$$\mathrm{RidgelessReLU}(D) := \left\{f : \mathbb{R} \to \mathbb{R} \;\middle|\; \exists\, \theta, n \text{ s.t. } f(x) = z(x;\theta)\ \forall x \in \mathbb{R},\ z(x_i;\theta) = y_i\ \forall i = 1, \ldots, m,\ C(\theta) = C^*\right\},$$

$$C^* := \inf_{\theta, n}\left\{C(\theta, n) \;\middle|\; z(x_i;\theta, n) = y_i\ \forall (x_i, y_i) \in D\right\}.$$

While we do not prove this directly here, a simple intuition for the elements of RidgelessReLU(D) is that they are all univariate one-layer ReLU networks that minimize a weakly penalized loss

$$\mathcal{L}(\theta; D) + \lambda\, C(\theta), \qquad \lambda \ll 1,$$

where $\mathcal{L}$ is an empirical loss, such as the mean squared error over $D$, and the strength $\lambda$ of the weight decay penalty $C(\theta)$ is infinitesimal. There is an important subtlety in the definition of RidgelessReLU(D). Namely, given $\theta$, there exist infinitely many $\tilde{\theta}$ such that $z(x;\theta) = z(x;\tilde{\theta})$ for every $x$. Thus, a function $f$ belongs to RidgelessReLU(D) if and only if $f$ interpolates the dataset $D$ and $f(x) = z(x;\theta)$ for some setting of $\theta$ that achieves the minimal value of $C(\theta)$ among all such interpolants. It is plausible but by no means obvious that, with high probability, gradient descent from a random initialization with a weight decay penalty whose strength decreases to zero over training converges to an element of RidgelessReLU(D). This article does not study optimization, and we therefore leave this as an interesting open problem. Our main result is a simple description of RidgelessReLU(D) and can informally be stated as follows:

Theorem 1.1 (Informal Statement of Theorem 1.2). Fix a dataset $D = \{(x_i, y_i),\ i = 1, \ldots, m\}$ and define

$$\epsilon_i := \mathrm{sgn}(s_i - s_{i-1}), \qquad s_i := \frac{y_{i+1} - y_i}{x_{i+1} - x_i}.$$
Note that $s_i$ is the slope of the line connecting $(x_i, y_i)$ to $(x_{i+1}, y_{i+1})$ and that $\epsilon_i$ is an estimate for the sign of the local curvature of the function that generated the data (Figure 1). Among all continuous and piecewise linear functions $f$ that interpolate $D$ exactly, the ones in RidgelessReLU(D) are precisely those that:
• Are linear (or, more precisely, affine) on intervals $(x_i, x_{i+1})$ when neighboring datapoints disagree on the local curvature in the sense that $\epsilon_i \cdot \epsilon_{i+1} \neq 1$.
• Are convex (resp. concave) on sequences of intervals $(x_i, x_{i+1}), \ldots, (x_{i+q-1}, x_{i+q})$ on which the datapoints $x_i, \ldots, x_{i+q}$ agree on the local curvature in the sense that $\epsilon_i = \cdots = \epsilon_{i+q} = 1$ (resp. $\epsilon_i = \cdots = \epsilon_{i+q} = -1$). On such intervals $f$ lies below (resp. above) the straight-line interpolant of the data.
See Figures 5 and 7.

Before giving a precise statement of our results, we mention that, as described in detail below, the space RidgelessReLU(D) has been considered in a number of prior articles Ongie et al. (2019); Parhi & Nowak (2020a); Savarese et al. (2019). Our starting point will be the useful but abstract characterization of RidgelessReLU(D) they obtained in terms of the total variation of the derivative of $z(x;\theta)$ (see equation (5)). Let us also note that the conclusions of Theorem 1.1 (and Theorem 1.2) also hold under seemingly very different hypotheses from ours. Namely, instead of $\ell_2$-regularization on the parameters, Blanc et al. (2020) consider SGD training for mean squared error with iid noise added to the labels. Their Theorem 2 shows (modulo some assumptions about interpreting the derivative of the ReLU) that, among all ReLU networks with a linear unit that interpolate a dataset $D$, the only ones that minimize the implicit regularization induced by adding iid noise to SGD are precisely those that satisfy the conclusions of Theorem 1.1, and hence are exactly the networks in RidgelessReLU(D).
This suggests that our results hold under much more general conditions; it would be interesting to characterize them. Further, our characterization of RidgelessReLU(D) in Theorem 1.2 immediately implies strong generalization guarantees uniformly over RidgelessReLU(D). We give a representative example in Corollary 1.1, which shows that such ReLU networks achieve the best possible generalization error for Lipschitz functions, up to constants. Finally, note that we allow networks $z(x;\theta)$ of any width, but if the width $n$ is too small relative to the dataset size $m$, then the interpolation condition (2) cannot be satisfied. Also, we point out that in our formulation of the cost $C(\theta)$ we have left both the linear term $ax + b$ and the neuron biases unregularized. This is not standard practice but seems to yield the cleanest results.
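The homogeneity of the ReLU, which Theorem 1.3 below relies on, can be checked directly on the cost $C(\theta)$. The following is a minimal NumPy sketch, assuming the parameterization of equation (1) with arrays `W1`, `b1`, `W2` holding $W_j^{(1)}, b_j^{(1)}, W_j^{(2)}$; these names, the helper `balance`, and the numerical values are hypothetical and chosen only for illustration. Rescaling neuron $j$ by $(cW_j^{(1)}, cb_j^{(1)}, W_j^{(2)}/c)$ with $c > 0$ leaves $z(x;\theta)$ unchanged because $[ct]_+ = c[t]_+$. The best choice of $c$ lowers that neuron's contribution to $C(\theta)$ to $|W_j^{(1)}W_j^{(2)}|$ (by the AM-GM inequality), and the rescaled biases cost nothing since they are unpenalized.

```python
import numpy as np

def z(x, a, b, W1, b1, W2):
    # z(x; theta) = a*x + b + sum_j W2_j * [W1_j * x + b1_j]_+   (equation (1))
    x = np.atleast_1d(np.asarray(x, dtype=float))
    return a * x + b + np.maximum(0.0, np.outer(x, W1) + b1) @ W2

def cost(W1, W2):
    # C(theta) = 1/2 * sum_j (W1_j^2 + W2_j^2); a, b and the biases b1 are not penalized
    return 0.5 * float(np.sum(W1**2 + W2**2))

def balance(W1, b1, W2):
    """Hypothetical helper: rescale neuron j by c_j > 0, (W1, b1, W2) -> (c W1, c b1, W2 / c).

    Since [c t]_+ = c [t]_+ for c > 0, the network function is unchanged. The choice
    c_j^2 = |W2_j| / |W1_j| minimizes (c^2 W1_j^2 + W2_j^2 / c^2) / 2, and the minimum
    equals |W1_j * W2_j| (AM-GM). Assumes all entries of W1 and W2 are nonzero.
    """
    c = np.sqrt(np.abs(W2) / np.abs(W1))
    return c * W1, c * b1, W2 / c

# Example width-2 network (illustrative values)
a, b = 0.1, 0.2
W1 = np.array([2.0, -0.5])
b1 = np.array([1.0, 0.3])
W2 = np.array([0.5, 4.0])
W1_bal, b1_bal, W2_bal = balance(W1, b1, W2)

xs = np.linspace(-3.0, 3.0, 13)
print(cost(W1, W2), cost(W1_bal, W2_bal))  # approximately 10.25 and 3.0
print(np.allclose(z(xs, a, b, W1, b1, W2), z(xs, a, b, W1_bal, b1_bal, W2_bal)))  # True
```

With these values the cost drops from 10.25 to $|2 \cdot 0.5| + |(-0.5) \cdot 4| = 3$ while the network function is unchanged. This is why minimizing $C(\theta)$ over all parameterizations of a fixed function effectively penalizes the products $|W_j^{(1)}W_j^{(2)}|$, which are the sizes of the slope jumps of $z$ at the neurons' breakpoints.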

1.2. STATEMENT OF RESULTS AND RELATION TO PRIOR WORK

Every ReLU network $z(x;\theta)$ is a continuous and piecewise linear function from $\mathbb{R}$ to $\mathbb{R}$ with a finite number of affine pieces. Let us denote by PL the space of all such functions and define

$$\mathrm{PL}(D) := \{f \in \mathrm{PL} \mid f(x_i) = y_i\ \forall i = 1, \ldots, m\}$$

to be the space of piecewise linear interpolants of $D$. Perhaps the most natural element in PL(D) is the "connect-the-dots interpolant" $f_D : \mathbb{R} \to \mathbb{R}$ given by

$$f_D(x) := \begin{cases} \ell_1(x), & x < x_2,\\ \ell_i(x), & x_i < x < x_{i+1},\ i = 2, \ldots, m-2,\\ \ell_{m-1}(x), & x > x_{m-1}, \end{cases}$$

where for $i = 1, \ldots, m-1$ we have set

$$\ell_i(x) := (x - x_i)s_i + y_i, \qquad s_i := \frac{y_{i+1} - y_i}{x_{i+1} - x_i}.$$

See Figure 1. In addition to $f_D$, there are many other elements in RidgelessReLU(D). Theorem 1.2 gives a complete description of all of them, phrased in terms of how they may behave on intervals $(x_i, x_{i+1})$ between consecutive datapoints. Our description is based on the signs

$$\epsilon_i = \mathrm{sgn}(s_i - s_{i-1}), \qquad 2 \le i \le m-1,$$

of the (discrete) second derivatives of $f_D$ at the inputs $x_i$ from our dataset.

Theorem 1.2. The space RidgelessReLU(D) consists of those $f \in \mathrm{PL}(D)$ satisfying:
1. $f$ coincides with $f_D$ on the following intervals:
(1a) Near infinity, i.e., on the intervals $(-\infty, x_2)$ and $(x_{m-1}, \infty)$.
(1b) Near datapoints that have zero discrete curvature, i.e., on intervals $(x_{i-1}, x_{i+1})$ with $i = 2, \ldots, m-1$ such that $\epsilon_i = 0$.
(1c) Between datapoints with opposite discrete curvature, i.e., on intervals $(x_i, x_{i+1})$ with $i = 2, \ldots, m-1$ such that $\epsilon_i \cdot \epsilon_{i+1} = -1$.
2. $f$ is convex (resp. concave) and bounded above (resp. below) by $f_D$ between any consecutive datapoints at which the discrete curvature is positive (resp. negative). Specifically, suppose for some $3 \le i \le i+q \le m-2$ that $x_i$ and $x_{i+q}$ are consecutive discrete inflection points in the sense that

$$\epsilon_{i-1} \neq \epsilon_i, \qquad \epsilon_i = \cdots = \epsilon_{i+q}, \qquad \epsilon_{i+q} \neq \epsilon_{i+q+1}.$$

If $\epsilon_i = 1$ (resp. $\epsilon_i = -1$), then restricted to the interval $(x_i, x_{i+q})$, $f$ is convex (resp. concave) and lies above (resp.
below) the incoming and outgoing support lines and below (resp. above) $f_D$:

$$\epsilon_i = 1 \implies \max\{\ell_{i-1}(x), \ell_{i+q}(x)\} \le f(x) \le f_D(x),$$
$$\epsilon_i = -1 \implies \min\{\ell_{i-1}(x), \ell_{i+q}(x)\} \ge f(x) \ge f_D(x),$$

for all $x \in (x_i, x_{i+q})$.

We prove Theorem 1.2 in §A. Before doing so, let us illustrate Theorem 1.2 as an algorithm that, given the dataset $D$, describes all elements in RidgelessReLU(D) (see Figures 5 and 7):

Step 1. Linearly interpolate the endpoints: by property (1a), $f \in$ RidgelessReLU(D) must agree with $f_D$ on $(-\infty, x_2)$ and $(x_{m-1}, \infty)$.
Step 2. Compute discrete curvature: for $i = 2, \ldots, m-1$, calculate the discrete curvature $\epsilon_i$ at the data point $x_i$.
Step 3. Linearly interpolate on intervals with zero curvature: for all $i = 2, \ldots, m-1$ at which $\epsilon_i = 0$, property (1b) guarantees that $f$ coincides with $f_D$ on $(x_{i-1}, x_{i+1})$.
Step 4. Linearly interpolate on intervals with ambiguous curvature: for all $i = 2, \ldots, m-1$ at which $\epsilon_i \cdot \epsilon_{i+1} = -1$, property (1c) guarantees that $f$ coincides with $f_D$ on $(x_i, x_{i+1})$.
Step 5. Determine convexity/concavity on the remaining intervals: all intervals $(x_i, x_{i+1})$ on which $f$ has not yet been determined occur in sequences $(x_i, x_{i+1}), \ldots, (x_{i+q-1}, x_{i+q})$ on which either $\epsilon_{i+j} = 1$ for all $j = 0, \ldots, q$ or $\epsilon_{i+j} = -1$ for all $j = 0, \ldots, q$. If $\epsilon_i = 1$ (resp. $\epsilon_i = -1$), then $f$ is any convex (resp. concave) function that lies below (resp. above) $f_D$ and above (resp. below) the support lines $\ell_{i-1}(x)$, $\ell_{i+q}(x)$.

Theorem 1.3 (… 2019) and near eq. (17) of Savarese et al. (2019)). For any dataset $D$ we have

$$\mathrm{RidgelessReLU}(D) = \{f \in \mathrm{PL}(D) \mid \|Df\|_{TV} = \|Df_D\|_{TV}\}. \qquad (5)$$

Theorem 1.3 says that RidgelessReLU(D) is precisely the space of functions in PL(D) that achieve the minimal possible total variation norm of the derivative. Intuitively, functions in RidgelessReLU(D) are therefore averse to oscillation in their slopes. The proof of this fact uses a simple idea introduced in Theorem 1 of Neyshabur et al.
(2014), which leverages the homogeneity of the ReLU to translate between the regularizer $C(\theta)$, which is positively homogeneous of degree 2 in the network weights, and the penalty $\|Df\|_{TV}$, which is positively homogeneous of degree 1 in the network function.

Theorem 1.2 yields strong generalization guarantees uniformly over RidgelessReLU(D). To state a representative example, suppose $D$ is generated by a function $f^* : [0,1] \to \mathbb{R}$, i.e., $y_j = f^*(x_j)$. We then find the following.

Corollary 1.1 (Sharp generalization on Lipschitz functions over a compact set). Fix a dataset $D = \{(x_i, y_i),\ i = 1, \ldots, m\}$ with $x_i \in [0,1]$. We have

$$\sup_{f \in \mathrm{RidgelessReLU}(D)} \|f\|_{Lip} \le \|f^*\|_{Lip}. \qquad (6)$$

Hence, if $f^*$ is $L$-Lipschitz and we denote by

$$\Delta := \max_{0 \le i \le m}\,(x_{i+1} - x_i)$$

the maximal distance between consecutive training points (with $x_0 = 0$, $x_{m+1} = 1$), then

$$\sup_{f \in \mathrm{RidgelessReLU}(D)}\ \sup_{x \in [0,1]} |f(x) - f^*(x)| \le 2L\Delta, \qquad (7)$$

which is the best generalization error possible, up to multiplicative constants.

Proof. Fix $i = 2, \ldots, m-2$ and $x \in (x_i, x_{i+1})$ at which $Df(x)$ exists. If $\epsilon_i = \epsilon_{i+1} \neq 0$, then $(x_i, x_{i+1})$ lies in a run of intervals on which $f$ is convex or concave, and the local convexity/concavity of $f$ in property (2) gives

$$\epsilon_i(s_{i-1} - s_i) \le \epsilon_i(Df(x) - s_i) \le \epsilon_i(s_{i+1} - s_i). \qquad (8)$$

Otherwise $\epsilon_i \cdot \epsilon_{i+1} \in \{0, -1\}$, and $f$ coincides with $f_D$ on $(x_i, x_{i+1})$ by properties (1b) and (1c), so that $Df(x) = s_i$. Hence, combining equation (8) with properties (1b) and (1c) shows that for each $i = 2, \ldots, m-2$

$$\|Df\|_{L^\infty(x_i, x_{i+1})} \le \max\{|s_{i-1}|, |s_i|, |s_{i+1}|\}.$$

Using property (1a) on $(-\infty, x_2)$ and $(x_{m-1}, \infty)$, where $Df$ equals $s_1$ and $s_{m-1}$, and taking the maximum over $i$, we find

$$\|Df\|_{L^\infty(\mathbb{R})} \le \max_{1 \le i \le m-1} |s_i| = \|f_D\|_{Lip}.$$

To complete the proof of (6), observe that for every $i = 1, \ldots, m-1$

$$|s_i| = \left|\frac{y_{i+1} - y_i}{x_{i+1} - x_i}\right| = \left|\frac{f^*(x_{i+1}) - f^*(x_i)}{x_{i+1} - x_i}\right| \le \|f^*\|_{Lip} \implies \|f_D\|_{Lip} \le \|f^*\|_{Lip}.$$

Given any $x \in [0,1]$, let us write $x'$ for its nearest neighbor among the data points $\{x_i,\ i = 1, \ldots, m\}$, so that $|x - x'| \le \Delta$ (in fact $|x - x'| \le \Delta/2$ unless $x \in [0, x_1)$ or $x \in (x_m, 1]$). Since $f(x') = f^*(x')$, we find

$$|f(x) - f^*(x)| \le |f(x) - f(x')| + |f^*(x') - f^*(x)| \le \left(\|f\|_{Lip} + \|f^*\|_{Lip}\right)|x - x'| \le 2L\Delta.$$

Taking the supremum over $f \in$ RidgelessReLU(D) and $x \in [0,1]$ proves (7).

Corollary 1.1 gives the best possible generalization error for Lipschitz functions, up to a universal multiplicative constant, in the sense that if all we knew about $f^* : [0,1] \to \mathbb{R}$ was that it was $L$-Lipschitz and we were given its values on $\{x_i,\ i = 1, \ldots, m\}$, then we could not recover $f^*$ in $L^\infty$ to an accuracy better than a constant times $L\Delta$. For $m$ uniformly spaced points we have $\Delta = 1/(m+1)$, while classical results (e.g., Theorem 2.2 in Holst (1980)) show that if $x_i \sim \mathrm{Unif}([0,1])$ are iid, then $\Delta$ is bounded above by a constant times $\log(m)/m$ with high probability.

1.3. OUTLINE OF PROOF OF THEOREM 1.2

In this section, we briefly outline the main steps in proving Theorem 1.2:
• A "local straightening" result given in Proposition A.1. This shows that any element $f$ in RidgelessReLU(D) must be either convex or concave on any interval of the form $(x_i, x_{i+1})$ between two consecutive inputs in the training data. The main idea is that non-monotonicity of $Df$ on such intervals can only increase $\|Df\|_{TV}$.
• A "linearity at endpoints" result given in Proposition A.3. This shows that any element $f \in$ RidgelessReLU(D) agrees with $f_D$ to the left of $x_2$ and to the right of $x_{m-1}$. The main idea is that, given $f$ restricted to $(x_2, x_{m-1})$, a linear extension of $f$ to the complement of this interval can already interpolate the values at $x_1$ and $x_m$ at zero additional cost to $\|Df\|_{TV}$.
• A "left-right compatibility" result given in Propositions A.4, A.5, A.6. This gives constraints, by dividing into cases, on the monotonicity of the "incoming slopes" $s_{in}(x_i)$ and "outgoing slopes" $s_{out}(x_i)$ of any $f \in$ RidgelessReLU(D). The main idea is that the slope of $f$ on each interval $(x_i, x_{i+1})$ must attain values that are both less than or equal to and greater than or equal to the slope $s_i$ of $f_D$. This gives constraints between $s_{i-1}$, $s_i$, $s_{in}(x_i)$, and $s_{out}(x_i)$.
• Combining the preceding results allows us to conclude that RidgelessReLU(D) is a subset of the set of functions satisfying the conclusions of Theorem 1.2.
• Finally, Proposition A.7 shows that the set of functions satisfying the conclusions of Theorem 1.2 is a subset of RidgelessReLU(D).
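Both the total-variation characterization of Theorem 1.3 and the Lipschitz bound of Corollary 1.1 are easy to check numerically for piecewise linear functions. The sketch below is a minimal illustration. It assumes each function is given by its breakpoints ("knots") and values and is extended affinely beyond the first and last knot, as $f_D$ is. The function names and the toy dataset are hypothetical. For such a function, $\|Df\|_{TV}$ is the sum of the absolute slope jumps at the knots and $\|f\|_{Lip}$ is the largest absolute slope. The toy data have $\epsilon_2 = \epsilon_5 = -1$ and $\epsilon_3 = \epsilon_4 = 1$, so on $(x_3, x_4) = (2, 3)$ property (2) of Theorem 1.2 allows any convex $f$ lying between $\max\{\ell_2, \ell_4\}$ and $f_D$.

```python
import numpy as np

def piece_slopes(knots_x, knots_y):
    # Slopes of the affine pieces of the continuous piecewise linear function through the knots
    return np.diff(np.asarray(knots_y, float)) / np.diff(np.asarray(knots_x, float))

def tv_of_derivative(knots_x, knots_y):
    # ||Df||_TV: total size of the slope jumps (the affine extension past the end knots adds none)
    return float(np.sum(np.abs(np.diff(piece_slopes(knots_x, knots_y)))))

def lipschitz(knots_x, knots_y):
    # ||f||_Lip: largest absolute slope of any affine piece
    return float(np.max(np.abs(piece_slopes(knots_x, knots_y))))

def max_gap(xs):
    # Delta: largest gap between consecutive points of {0, x_1, ..., x_m, 1}; data assumed in [0, 1]
    pts = np.concatenate(([0.0], np.sort(np.asarray(xs, float)), [1.0]))
    return float(np.max(np.diff(pts)))

# Toy dataset (hypothetical): slopes s = (1, 0, 1, 2, 0, 0)
xs = [0, 1, 2, 3, 4, 5, 6]
ys = [0, 1, 1, 2, 4, 4, 4]

f_D = (xs, ys)                                                        # connect-the-dots interpolant
convex = ([0, 1, 2, 2.5, 3, 4, 5, 6], [0, 1, 1, 1, 2, 4, 4, 4])      # equals max{l_2, l_4} on (2, 3)
overshoot = ([0, 1, 2, 2.25, 3, 4, 5, 6], [0, 1, 1, 2, 2, 4, 4, 4])  # interpolates D but is not convex on (2, 3)

for name, (kx, ky) in [("f_D", f_D), ("convex", convex), ("overshoot", overshoot)]:
    print(name, tv_of_derivative(kx, ky), lipschitz(kx, ky))
# f_D 5.0 2.0
# convex 5.0 2.0
# overshoot 13.0 4.0

print(max_gap([0.25, 0.5, 0.75]))  # 0.25, i.e. 1/(m+1) for m = 3 uniformly spaced points
```

Only the convex interpolant matches $\|Df_D\|_{TV}$, as Theorem 1.3 requires. The overshooting interpolant pays for its oscillating slope with a larger total variation and a Lipschitz constant exceeding $\|f_D\|_{Lip}$.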

1.4. DISCUSSION OF LIMITATIONS AND FUTURE WORK

In this article, we completely characterized all ReLU networks that interpolate a given dataset $D$ in the simple setting of weakly $\ell_2$-regularized one-layer ReLU networks with a single linear unit and input/output dimension 1. Moreover, our characterization shows that, to assign labels to unseen data, such networks simply "look at the curvature of the nearest neighboring datapoints on each side," in a way made precise in Theorem 1.2. This simple geometric description led to sharp generalization results for learning 1d Lipschitz functions in Corollary 1.1. It opens many directions for future investigation. Theorem 1.2 shows, for instance, that there are infinitely many ridgeless ReLU interpolants of a given dataset $D$. It would be interesting to understand which ones are actually learned by gradient descent from a random initialization with a weak (or even decaying in time) $\ell_2$-penalty. Further, as already pointed out after Theorem 1.1, the conclusions of Theorem 1.2 appear to hold under very different kinds of regularization (e.g., Theorem 2 in Blanc et al. (2020)). This raises the question: what is the most general kind of regularizer that is equivalent to weight decay, at least in our simple setup? It would also be quite natural to extend the results in this article to ReLU networks with higher input dimension, for which weight decay is known to correspond to regularization of a certain weighted Radon transform of the network function Ongie et al. (2019); Parhi & Nowak (2020a; b; 2021). Finally, extensions to deeper networks and beyond fully connected architectures are left to future work.
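The rule "look at the curvature of the nearest neighboring datapoints on each side" can be made concrete with a minimal sketch of Steps 1 to 5 following Theorem 1.2. The sketch assumes exact, noise-free data, so that the signs $\epsilon_i$ can be read off with `np.sign` without a tolerance. It uses 1-based indices to match the text, and its function names and toy dataset are hypothetical. For each interval $(x_i, x_{i+1})$ it reports whether every element of RidgelessReLU(D) must coincide with $f_D$ there ("affine") or may bend. A "convex" (resp. "concave") label means convexity (resp. concavity) is required on the whole run of consecutive labeled intervals, not on each interval separately, between the support lines and $f_D$.

```python
import numpy as np

def discrete_curvature(xs, ys):
    # s_i = (y_{i+1} - y_i) / (x_{i+1} - x_i) and eps_i = sgn(s_i - s_{i-1}) for i = 2, ..., m-1
    s = np.diff(np.asarray(ys, float)) / np.diff(np.asarray(xs, float))  # s[k] holds s_{k+1}
    return {i: int(np.sign(s[i - 1] - s[i - 2])) for i in range(2, len(xs))}

def interval_shapes(xs, ys):
    """Label each interval (x_i, x_{i+1}), i = 1, ..., m-1, following Steps 1-5 after Theorem 1.2."""
    m = len(xs)
    eps = discrete_curvature(xs, ys)
    shapes = {}
    for i in range(1, m):
        if i == 1 or i == m - 1:
            shapes[i] = "affine"  # Step 1, property (1a): agree with f_D near infinity
        elif eps[i] * eps[i + 1] == 1:
            # Step 5: inside a run with eps = +1 (convex) or eps = -1 (concave)
            shapes[i] = "convex" if eps[i] == 1 else "concave"
        else:
            shapes[i] = "affine"  # Steps 3-4, properties (1b)-(1c): eps_i * eps_{i+1} is 0 or -1
    return shapes

# Toy dataset (hypothetical)
xs = [0, 1, 2, 3, 4, 5, 6]
ys = [0, 1, 1, 2, 4, 4, 4]
print(discrete_curvature(xs, ys))  # {2: -1, 3: 1, 4: 1, 5: -1, 6: 0}
print(interval_shapes(xs, ys))     # only interval 3, i.e. (x_3, x_4) = (2, 3), is "convex"
```

Flipping the sign of every label $y_i$ flips every $\epsilon_i$, so the same interval is then labeled "concave", matching the "resp." clauses of Theorem 1.2.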



¹ The presence of the linear term $ax + b$ is not really standard in practice but is adopted in keeping with prior work Ongie et al. (2019); Parhi & Nowak (2020a); Savarese et al. (2019), since it leads to a cleaner mathematical formulation of results.



Figure 1: A dataset $D$ with $m = 8$ points. Shown are the "connect-the-dots" interpolant $f_D$ (dashed line), its slopes $s_i$, and the "discrete curvature" $\epsilon_i$ at each $x_i$.

Figure 2: Step 1

Figure 6: Step 4


