SGD THROUGH THE LENS OF KOLMOGOROV COMPLEXITY

Abstract

We initiate a thorough study of the dynamics of stochastic gradient descent (SGD) under minimal assumptions using the tools of entropy compression. Specifically, we characterize a quantity of interest which we refer to as the accuracy discrepancy. Roughly speaking, this measures the average discrepancy between the model accuracy on batches and on large subsets of the entire dataset. We show that if this quantity is sufficiently large, then SGD finds a model which achieves perfect accuracy on the data in O(1) epochs. Conversely, if the model cannot perfectly fit the data, this quantity must remain below a global threshold, which depends only on the size of the dataset and the batch. We use the above framework to lower bound the amount of randomness required to allow (non-stochastic) gradient descent to escape from local minima using perturbations. We show that even if the model is extremely overparameterized, at least a linear (in the size of the dataset) number of random bits are required to guarantee that GD escapes local minima in subexponential time.

1. INTRODUCTION

Stochastic gradient descent (SGD) is at the heart of modern machine learning. However, we still lack a theoretical framework that explains its performance for general, non-convex functions. Current results make significant assumptions regarding the model. Global convergence guarantees only hold under specific architectures, activation units, and when models are extremely overparameterized (Du et al., 2019; Allen-Zhu et al., 2019; Zou et al., 2018; Zou and Gu, 2019). In this paper, we take a step back and explore what can be said about SGD under the most minimal assumptions. We only assume that the loss function is differentiable and L-smooth, that the learning rate is sufficiently small, and that models are initialized randomly. Clearly, we cannot prove general convergence to a global minimum under these assumptions. However, we can try to understand the dynamics of SGD: what types of execution patterns can and cannot happen.

Motivating example: Suppose, hypothetically, that for every batch, the accuracy of the model after the Gradient Descent (GD) step on the batch is 100%. However, its accuracy on the set of previously seen batches (including the current batch) remains at 80%. Can this process go on forever? At first glance, this might seem like a possible scenario. However, we show that this cannot be the case. That is, if the above scenario repeats sufficiently often, the model must eventually achieve 100% accuracy on the entire dataset. To show the above, we identify a quantity of interest which we call the accuracy discrepancy (formally defined in Section 3). Roughly speaking, this is how much the model accuracy on a batch differs from the model accuracy on all previous batches in the epoch. We show that when this quantity (averaged over epochs) is higher than a certain threshold, we can guarantee that SGD converges to 100% accuracy on the dataset within O(1) epochs w.h.p. (with high probability).
We note that this threshold is global; that is, it only depends on the size of the dataset and the size of the batch. In doing so, we provide a sufficient condition for SGD convergence. The above result is especially interesting when applied to weak models that cannot achieve perfect accuracy on the data. Imagine a dataset of size n with random labels, a model with n^0.99 parameters, and a batch of size log n. The above implies that the accuracy discrepancy must eventually go below the global threshold. In other words, the model cannot consistently make significant progress on batches. This is surprising because even though the model is underparameterized with respect to the entire dataset, it is extremely overparameterized with respect to the batch. We verify this observation experimentally (Appendix B). This holds for a single GD step, but what if we were to allow many GD steps per batch? Would this mean that we still cannot make significant progress on the batch?

This leads us to consider the role of randomness in (non-stochastic) gradient descent. It is well known that overparameterized models trained using SGD can perfectly fit datasets with random labels (Zhang et al., 2017). It is also known that when models are sufficiently overparameterized (and wide), GD with random initialization converges to a near-global minimum (Du et al., 2019). This leads to an interesting question: how much randomness does GD require to escape local minima efficiently (in polynomial time)? It is obvious that without randomness we could initialize GD next to a local minimum, and it will never escape it. However, consider the case where we are provided an adversarial input and we can perturb that input (for example, by adding a random vector to it). How many bits of randomness are required to guarantee that, after the perturbation, GD achieves good accuracy on the input in polynomial time?
In Section 4 we show that if the amount of randomness is sublinear in the size of the dataset, then for any differentiable and L-smooth model class (e.g., a neural network architecture), there are datasets that require an exponential running time to achieve any non-trivial accuracy (i.e., better than 1/2 + o(1) for a two-class classification task), even if the model is extremely overparameterized. This result highlights the importance of randomness for the convergence of gradient methods. Specifically, it provides an indication of why SGD converges in certain situations and GD does not. We hope this result opens the door to the design of randomness in other versions of GD.

Outline of our techniques

We consider batch SGD, where the dataset is shuffled once at the beginning of each epoch and then divided into batches. We do not deal with the generalization abilities of the model. Thus, the dataset is always the training set. In each epoch, the algorithm goes over the batches one by one, and performs gradient descent to update the model. This is the "vanilla" version of SGD, without any acceleration or regularization (for a formal definition, see Section 2). For the sake of analysis, we add a termination condition after every GD step: if the accuracy on the entire dataset is 100%, we terminate. Thus, in our case, termination implies 100% accuracy.

To achieve our results, we make use of entropy compression, first considered by Moser and Tardos (2010) to prove a constructive version of the Lovász local lemma. Roughly speaking, the entropy compression argument allows one to bound the running time of a randomized algorithm by leveraging the fact that a random string of bits (the randomness used by the algorithm) is computationally incompressible (has high Kolmogorov complexity) w.h.p. If one can show that throughout the execution of the algorithm, it (implicitly) compresses the randomness it uses, then one can bound the number of iterations the algorithm may execute without terminating. To show that the algorithm has such a property, one would usually consider the algorithm after executing t iterations, and would try to show that just by looking at an "execution log" of the algorithm and some set of "hints", whose size together is considerably smaller than the number of random bits used by the algorithm, it is possible to reconstruct all of the random bits used by the algorithm. We apply this approach to SGD with an added termination condition when the accuracy over the entire dataset is 100%. Thus, termination in our case guarantees perfect accuracy. The randomness we compress is the bits required to represent the random permutation of the data at every epoch.
So indeed the longer SGD executes, the more random bits are generated. We show that under our assumptions it is possible to reconstruct these bits efficiently starting from the dataset X and the model after executing t epochs. The first step in allowing us to reconstruct the random bits of the permutation in each epoch is to show that under the L-smoothness assumption and a sufficiently small step size, SGD is reversible. That is, if we are given a model W_{i+1} and a batch B_i such that W_{i+1} results from taking a gradient step with model W_i where the loss is calculated with respect to B_i, then we can uniquely retrieve W_i using only B_i and W_{i+1}. This means that if we can efficiently encode the batches used in every epoch (i.e., using fewer bits than encoding the entire permutation of the data), we can also retrieve all intermediate models in that epoch (at no additional cost). We prove this claim in Section 2. The crux of this paper is to show that when the accuracy discrepancy is high for a certain epoch, the batches can indeed be compressed. To exemplify our techniques, let us consider the scenario where, in every epoch, just after a single GD step on a batch we consistently achieve perfect accuracy on the batch. Let us consider some epoch of our execution, assume we have access to X, and let W_f be the model at the end of the epoch. If the algorithm did not terminate, then W_f has accuracy at most 1 − ε on the entire dataset (assume for simplicity that ε is a constant). Our goal is to retrieve the last batch of the epoch, B_f ⊂ X (without knowing the permutation of the data for the epoch). A naive approach would be to simply encode the indices in X of the elements in the batch. However, we can use W_f to achieve a more efficient encoding. Specifically, we know that W_f achieves 1.0 accuracy on B_f but only 1 − ε accuracy on X.
Thus it is sufficient to encode the elements of B_f using a smaller subset of X (the elements classified correctly by W_f, which has size at most (1 − ε)|X|). This allows us to significantly compress B_f. Next, we can use B_f and W_f together with the reversibility of SGD to retrieve W_{f−1}. We can now repeat the above argument to compress B_{f−1}, and so on, until we are able to reconstruct all of the random bits used to generate the permutation of X in the epoch. This results in a linear reduction in the number of bits required for the encoding. In our analysis, we show a generalized version of the scenario above. We show that high accuracy discrepancy implies that entropy compression occurs. For our second result, we consider a modified SGD algorithm that, instead of performing a single GD step per batch, first perturbs the batch with a limited amount of randomness and then performs GD until a desired accuracy on the batch is reached. We assume towards contradiction that GD can always reach the desired accuracy on the batch in subexponential time. This forces the accuracy discrepancy to be high, which guarantees that we always find a model with good accuracy. Applying this reasoning to models of sublinear size and data with random labels, we arrive at a contradiction, as such models cannot achieve good accuracy on the data. This implies that when we limit the amount of randomness GD can use for perturbations, there must exist instances where GD requires exponential running time to achieve good accuracy.
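The savings from the encoding step above can be checked with a back-of-the-envelope calculation: specifying a batch as a subset of the correctly classified elements, rather than of all of X, saves roughly b·log(1/(1 − ε)) bits per batch. A minimal numerical sketch (the values of n, b, and ε are illustrative, not from the paper):

```python
import math

# Bits to encode a batch of size b as a subset of all n elements,
# versus as a subset of the (1 - eps) * n correctly classified elements.
n, b, eps = 10_000, 100, 0.2

bits_naive = math.log2(math.comb(n, b))
bits_compressed = math.log2(math.comb(int((1 - eps) * n), b))
saving = bits_naive - bits_compressed

# The saving is at least b * log2(1 / (1 - eps)) bits: it grows linearly
# in the batch size, which is what drives the entropy compression argument.
assert saving >= b * math.log2(1 / (1 - eps))
print(round(saving, 1))
```

Repeating this for every batch of the epoch is what yields the linear (in n) reduction mentioned above.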

Related work

There has been a long line of research proving convergence bounds for SGD under various simplifying assumptions, such as linear networks (Arora et al., 2019; 2018), shallow networks (Safran and Shamir, 2018; Du and Lee, 2018; Oymak and Soltanolkotabi, 2019), etc. However, the most general results are the ones dealing with deep, overparameterized networks (Du et al., 2019; Allen-Zhu et al., 2019; Zou et al., 2018; Zou and Gu, 2019). All of these works make use of the NTK (Neural Tangent Kernel) (Jacot et al., 2018) and show global convergence guarantees for SGD when the hidden layers have width at least poly(n, L), where n is the size of the dataset and L is the depth of the network. We note that the exponents of the polynomials are quite large. A recent line of work by Zhang et al. (2022) notes that in many real-world scenarios models do not converge to stationary points. They instead take a different approach which, similar to us, studies the dynamics of neural networks. They show that under certain assumptions (e.g., considering a fully connected architecture with sub-differentiable and coordinate-wise Lipschitz activations and weights lying on a compact set) the change in training loss gradually converges to 0, even if the full gradient norms do not vanish. In (Du et al., 2017) it was shown that GD can take exponential time to escape saddle points, even under random initialization. They provide a highly engineered instance, while our results hold for many model classes of interest. Jin et al. (2017) show that adding perturbations during the execution of GD guarantees that it escapes saddle points. This is done by occasionally perturbing the parameters within a ball of radius r, where r depends on the properties of the function to be optimized. Therefore, a single perturbation must require an amount of randomness linear in the number of parameters.

2. PRELIMINARIES

We consider the following optimization problem. We are given an input (dataset) of size n, denoted X = {x_i}_{i=1}^n (our inputs contain both data and labels; we do not need to distinguish between them in this work). We also associate every x ∈ X with a unique id of log n bits. We often consider batches of the input B ⊂ X. The size of the batch is denoted by b (all batches have the same size). We have some model whose parameters are denoted by W ∈ R^d, where d is the model dimension. We aim to optimize a goal function of the following type:

f(W) = (1/n) Σ_{x∈X} f_x(W),

where the functions f_x : R^d → R are completely determined by x ∈ X. We also define, for every set A ⊆ X, f_A(W) = (1/|A|) Σ_{x∈A} f_x(W). Note that f_X = f. We denote by acc(W, A) : R^d × 2^X → [0, 1] the accuracy of model W on the set A ⊆ X (where we use W to classify elements from X). Note that for x ∈ X, acc(W, x) is a binary value indicating whether x is classified correctly or not.

We require that every f_x is differentiable and L-smooth: ∀W_1, W_2 ∈ R^d, ‖∇f_x(W_1) − ∇f_x(W_2)‖ ≤ L‖W_1 − W_2‖. This implies that every f_A is also differentiable and L-smooth. To see this, consider the following:

‖∇f_A(W_1) − ∇f_A(W_2)‖ = ‖(1/|A|) Σ_{x∈A} ∇f_x(W_1) − (1/|A|) Σ_{x∈A} ∇f_x(W_2)‖
= (1/|A|) ‖Σ_{x∈A} (∇f_x(W_1) − ∇f_x(W_2))‖
≤ (1/|A|) Σ_{x∈A} ‖∇f_x(W_1) − ∇f_x(W_2)‖
≤ L‖W_1 − W_2‖

We state another useful property of f_A:

Lemma 2.1. Let W_1, W_2 ∈ R^d and α < 1/L. For any A ⊆ X, if it holds that W_1 − α∇f_A(W_1) = W_2 − α∇f_A(W_2), then W_1 = W_2.

Proof. Rearranging the terms, we get W_1 − W_2 = α∇f_A(W_1) − α∇f_A(W_2). Taking norms on both sides:

‖W_1 − W_2‖ = ‖α∇f_A(W_1) − α∇f_A(W_2)‖ ≤ α·L‖W_1 − W_2‖ < ‖W_1 − W_2‖,

where the final strict inequality holds unless W_1 = W_2, leading to a contradiction.

The above means that for a sufficiently small gradient step, the gradient descent process is reversible.
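Lemma 2.1 can be illustrated numerically. Since α < 1/L, the map W ↦ W_next + α∇f_A(W) is a contraction whose unique fixed point is the pre-step model, so fixed-point iteration recovers it (the paper's proof instead searches exhaustively over F^d; the contraction view is just a convenient way to demonstrate uniqueness). A minimal sketch with a hypothetical quadratic loss:

```python
import numpy as np

# Hypothetical smooth loss f(W) = 0.5 * ||M W - y||^2 whose gradient
# grad f(W) = M^T (M W - y) is L-smooth with L = lambda_max(M^T M).
rng = np.random.default_rng(0)
M = rng.standard_normal((5, 3))
y = rng.standard_normal(5)
grad = lambda W: M.T @ (M @ W - y)
L = np.linalg.eigvalsh(M.T @ M).max()
alpha = 0.5 / L          # step size below 1/L, as Lemma 2.1 requires

W0 = rng.standard_normal(3)
W1 = W0 - alpha * grad(W0)   # one gradient step

# Reverse the step: W -> W1 + alpha * grad(W) has Lipschitz constant
# alpha * L < 1, so iterating it converges to the unique pre-image
# whose existence and uniqueness Lemma 2.1 guarantees.
W = np.zeros(3)
for _ in range(200):
    W = W1 + alpha * grad(W)

print(np.allclose(W, W0))  # → True: the pre-step model is recovered
```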
That is, we can always recover the previous model parameters given the current ones, assuming that the batch is fixed. We use the notion of reversibility throughout this paper. However, in practice we only have finite precision; thus instead of R we work with the finite set F ⊂ R. Furthermore, due to numerical stability issues, we do not have access to exact gradients, but only to approximate values ∇̃f_A. For the rest of this paper, we assume these values are L-smooth on all elements in F^d. That is,

∀W_1, W_2 ∈ F^d, A ⊆ X: ‖∇̃f_A(W_1) − ∇̃f_A(W_2)‖ ≤ L‖W_1 − W_2‖

This immediately implies that Lemma 2.1 holds even when precision is limited. Let us state the following theorem:

Theorem 2.2. Let W_1, W_2, ..., W_k ∈ F^d ⊂ R^d, A_1, A_2, ..., A_{k−1} ⊆ X and α < 1/L. If it holds that W_i = W_{i−1} − α∇̃f_{A_{i−1}}(W_{i−1}), then given A_1, A_2, ..., A_{k−1} and W_k we can retrieve W_1.

Proof. Given W_k, we iterate over all W ∈ F^d until we find W such that W_k = W − α∇̃f_{A_{k−1}}(W). By Lemma 2.1, there is only a single element for which this equality holds, and thus W = W_{k−1}. We repeat this process until we retrieve W_1.

SGD. We analyze the classic SGD algorithm presented in Algorithm 1. One difference to note in our algorithm, compared to the standard implementation, is the termination condition when the accuracy on the dataset is 100%. In practice the termination condition is not used; we use it only to prove that at some point in time the accuracy of the model is 100%.

Algorithm 1: SGD
  i ← 1  // epoch counter
  W_{1,1} is an initial model
  while True do
    Take a random permutation of X, divided into batches {B_{i,j}}_{j=1}^{n/b}
    for j from 1 to n/b do
      if acc(W_{i,j}, X) = 1 then Return W_{i,j}
      W_{i,j+1} ← W_{i,j} − α∇f_{B_{i,j}}(W_{i,j})
    i ← i + 1, W_{i,1} ← W_{i−1,n/b+1}

Kolmogorov complexity. The Kolmogorov complexity of a string x ∈ {0, 1}*, denoted by K(x), is defined as the size of the smallest prefix Turing machine which outputs this string.
We note that this definition depends on which encoding of Turing machines we use. However, one can show that changing the encoding changes the Kolmogorov complexity by at most a constant (Li and Vitányi, 2019). We also use the notion of conditional Kolmogorov complexity, denoted by K(x | y). This is the length of the shortest prefix Turing machine which gets y as an auxiliary input and prints x. Note that the length of y does not count towards the size of the machine which outputs x, so it can be the case that |x| ≪ |y| but K(x | y) < K(x). We can also consider the Kolmogorov complexity of functions. Let g : {0, 1}* → {0, 1}*; then K(g) is the size of the smallest Turing machine which computes the function g. The following properties of Kolmogorov complexity will be of use. Let x, y, z be three strings:

• Extra information: K(x | y, z) ≤ K(x | z) + O(1) ≤ K(x, y | z) + O(1)
• Subadditivity: K(xy | z) ≤ K(x | z, y) + K(y | z) + O(1) ≤ K(x | z) + K(y | z) + O(1)

Random strings have the following useful property (Li and Vitányi, 2019):

Theorem 2.3. For an n-bit string x chosen uniformly at random, some string y independent of x (i.e., y is fixed before x is chosen), and any c ∈ N, it holds that Pr[K(x | y) ≥ n − c] ≥ 1 − 1/2^c.

Entropy and KL-divergence. Our proofs make extensive use of binary entropy and KL-divergence. In what follows we define these concepts and provide some useful properties.

Entropy: For p ∈ [0, 1] we denote by h(p) = −p log p − (1 − p) log(1 − p) the entropy of p. Note that h(0) = h(1) = 0.

KL-divergence: For p, q ∈ (0, 1), let D_KL(p ‖ q) = p log(p/q) + (1 − p) log((1 − p)/(1 − q)) be the Kullback-Leibler divergence (KL-divergence) between two Bernoulli distributions with parameters p and q. We also extend the above to the case where p, q ∈ {0, 1} as follows: D_KL(1 ‖ q) = D_KL(0 ‖ q) = 0, D_KL(p ‖ 1) = log(1/p), D_KL(p ‖ 0) = log(1/(1 − p)). This is just notation that agrees with Lemma 2.4.
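The definitions above, including the boundary conventions for D_KL, translate directly into code. A minimal sketch (base-2 logarithms, matching the bit-counting arguments; the numeric pairs are arbitrary examples):

```python
import math

def h(p):
    """Binary entropy, with h(0) = h(1) = 0."""
    if p in (0.0, 1.0):
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

def d_kl(p, q):
    """Bernoulli KL-divergence with the paper's boundary conventions."""
    if p in (0.0, 1.0):
        return 0.0                        # D(1||q) = D(0||q) = 0
    if q == 1.0:
        return math.log2(1 / p)           # D(p||1) = log(1/p)
    if q == 0.0:
        return math.log2(1 / (1 - p))     # D(p||0) = log(1/(1-p))
    return p * math.log2(p / q) + (1 - p) * math.log2((1 - p) / (1 - q))

assert h(0.5) == 1.0 and d_kl(0.5, 0.5) == 0.0 and d_kl(1.0, 0.3) == 0.0

# Spot-check the standard Pinsker-type bound D(p||q) >= 2 (p - q)^2
# on a few Bernoulli pairs.
for p, q in [(0.9, 0.5), (0.3, 0.7), (0.5, 0.55)]:
    assert d_kl(p, q) >= 2 * (p - q) ** 2
print("ok")
```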
We also state the following consequence of Pinsker's inequality applied to Bernoulli random variables: D_KL(p ‖ q) ≥ 2(p − q)².

Representing sets. Let us state some useful bounds on the Kolmogorov complexity of sets. A more detailed explanation regarding the Kolmogorov complexity of sets and permutations, together with the proof of the lemma below, appears in Appendix A. First, let us define some useful notation (W_{i,j}, B_{i,j} are formally defined in Algorithm 1):

• λ_{i,j} = acc(W_{i,j}, X). This is the accuracy of the model in epoch i on the entire dataset X, before performing the GD step on batch j.
• ϕ_{i,j} = acc(W_{i,j}, B_{i,j−1}). This is the accuracy of the model on the (j − 1)-th batch in the i-th epoch, after performing the GD step on that batch.
• X_{i,j} = ∪_{k=1}^{j} B_{i,k} (note that ∀i, X_{i,0} = ∅, X_{i,n/b} = X). This is the set of elements in the first j batches of epoch i. Let us also denote n_j = |X_{i,j}| = jb (note that ∀j, i_1, i_2, |X_{i_1,j}| = |X_{i_2,j}|, thus i need not appear in the subscript).
• λ̂_{i,j} = acc(W_{i,j}, X_{i,j−1}) and λ̃_{i,j} = acc(W_{i,j}, X \ X_{i,j−1}), where λ̂_{i,j} is the accuracy of the model on the set of all previously seen batch elements, after performing the GD step on the (j − 1)-th batch, and λ̃_{i,j} is the accuracy of the same model on all remaining elements (j-th batch onward). To avoid computing the accuracy on empty sets, λ̂_{i,j} is defined for j ∈ [2, n/b + 1] and λ̃_{i,j} is defined for j ∈ [1, n/b].
• ρ_{i,j} = D_KL(λ̂_{i,j} ‖ ϕ_{i,j}) is the accuracy discrepancy for the j-th batch in iteration i, and ρ_i = Σ_{j=2}^{n/b+1} ρ_{i,j} is the accuracy discrepancy at iteration i.

In our analysis, we consider t epochs of the SGD algorithm. Our goal for this section is to derive a connection between Σ_{i=1}^{t} ρ_i and t.

Bounding t: Our goal is to use the entropy compression argument to show that if Σ_{i=1}^{t} ρ_i is sufficiently large, we can bound t. Let us start by formally defining the random bits which the algorithm uses.
Let r_i be the string of random bits representing the random permutation of X at epoch i. As we consider t epochs, let r = r_1 r_2 ... r_t. Note that the number of bits required to represent an arbitrary permutation of [n] is given by:

log(n!) = n log n − n log e + O(log n) = n log(n/e) + O(log n),

where we used Stirling's approximation. Thus, it holds that |r| = t(n log(n/e) + O(log n)), and according to Theorem 2.3, with probability at least 1 − 1/n² it holds that K(r) ≥ tn log(n/e) − O(log n). In the following lemma we show how to use the model at every iteration to efficiently reconstruct the batch at that iteration, where the efficiency of reconstruction is expressed via ρ_i.

Lemma 3.1. It holds w.h.p. for all i ∈ [t] that:

K(r_i | W_{i+1,1}, X) ≤ n log(n/e) − bρ_i + (n/b)·O(log n)

Proof. Recall that B_{i,j} is the j-th batch in the i-th epoch, and let P_{i,j} be a permutation of B_{i,j} such that the order of the elements in B_{i,j} under P_{i,j} is the same as under r_i. Note that given X, if we know the partition into batches and all permutations, we can reconstruct r_i. According to Theorem 2.2, given W_{i,j} and B_{i,j−1} we can compute W_{i,j−1}. Let us denote by Y the encoding of this procedure. To implement Y we need to iterate over all possible vectors in F^d and over batch elements to compute the gradients. To express this program we require auxiliary variables of size at most O(log min{d, b}) = O(log n). Thus it holds that K(Y) = O(log n). Let us abbreviate B_{i,1}, B_{i,2}, ..., B_{i,j} as (B_{i,k})_{k=1}^{j}. We write the following.
K(r_i | X, W_{i+1,1}) ≤ K(r_i, Y | X, W_{i+1,1}) + O(1)
≤ K(r_i | X, W_{i+1,1}, Y) + K(Y | X, W_{i+1,1}) + O(1)
≤ O(log n) + K((B_{i,k}, P_{i,k})_{k=1}^{n/b} | X, W_{i+1,1}, Y)
≤ O(log n) + K((B_{i,k})_{k=1}^{n/b} | X, W_{i+1,1}, Y) + K((P_{i,k})_{k=1}^{n/b} | X, W_{i+1,1}, Y)
≤ O(log n) + K((B_{i,k})_{k=1}^{n/b} | X, W_{i+1,1}, Y) + Σ_{j=1}^{n/b} K(P_{i,j})

Let us bound K((B_{i,k})_{k=1}^{n/b} | X, W_{i+1,1}, Y) by repeatedly using the subadditivity and extra information properties of Kolmogorov complexity.

K((B_{i,k})_{k=1}^{n/b} | X, Y, W_{i+1,1})
≤ K(B_{i,n/b} | X, W_{i+1,1}) + K((B_{i,k})_{k=1}^{n/b−1} | X, Y, W_{i+1,1}, B_{i,n/b}) + O(1)
≤ K(B_{i,n/b} | X, W_{i+1,1}) + K((B_{i,k})_{k=1}^{n/b−1} | X, Y, W_{i,n/b}, B_{i,n/b}) + O(1)
≤ K(B_{i,n/b} | X, W_{i+1,1}) + K(B_{i,n/b−1} | X, W_{i,n/b}, B_{i,n/b}) + K((B_{i,k})_{k=1}^{n/b−2} | X, Y, W_{i,n/b−1}, B_{i,n/b}, B_{i,n/b−1}) + O(1)
≤ ...
≤ O(n/b) + Σ_{j=1}^{n/b} K(B_{i,j} | X, W_{i,j+1}, (B_{i,k})_{k=j+1}^{n/b})
≤ O(n/b) + Σ_{j=1}^{n/b} K(B_{i,j} | X_{i,j}, W_{i,j+1})

where in the transitions we used the fact that given W_{i,j}, B_{i,j−1} and Y we can retrieve W_{i,j−1}. That is, we can always bound K(... | Y, W_{i,j}, B_{i,j−1}, ...) by K(... | Y, W_{i,j−1}, B_{i,j−1}, ...) + O(1). To encode the order P_{i,j} inside each batch, b log(b/e) + O(log b) bits are sufficient. Finally we get that:

K(r_i | X, W_{i+1,1}) ≤ O(n/b) + Σ_{j=1}^{n/b} [K(B_{i,j} | X_{i,j}, W_{i,j+1}) + b log(b/e) + O(log b)]

Let us now bound K(B_{i,j−1} | X_{i,j−1}, W_{i,j}). Knowing X_{i,j−1}, we know that B_{i,j−1} ⊆ X_{i,j−1}. Thus we need to use W_{i,j} to compress B_{i,j−1}. Applying Lemma 2.4 with parameters A = B_{i,j−1}, B = X_{i,j−1}, γ = b/n_{j−1}, κ_A = ϕ_{i,j}, κ_B = λ̂_{i,j} and g(x) = acc(W_{i,j}, x), we get the following:

K(B_{i,j−1} | X_{i,j−1}, W_{i,j}) ≤ b(log(e·n_{j−1}/b) − ρ_{i,j}) + O(log n_{j−1})

Adding b log(b/e) + O(log b) to the above, we get the following bound on every element in the sum:

b(log(e·n_{j−1}/b) − ρ_{i,j}) + b log(b/e) + O(log b) + O(log n_{j−1}) ≤ b log n_{j−1} − bρ_{i,j} + O(log n_{j−1})

Note that the most important term in the sum is −bρ_{i,j}. That is, the more the accuracy of W_{i,j} on the batch B_{i,j−1} differs from its accuracy on the set of elements containing the batch, X_{i,j−1}, the more efficiently we can represent the batch. Let us now bound the sum Σ_{j=2}^{n/b+1} [b log n_{j−1} − bρ_{i,j} + O(log n_{j−1})]. We first bound the sum over b log n_{j−1}:

Σ_{j=2}^{n/b+1} b log n_{j−1} = Σ_{j=1}^{n/b} b log(jb) = Σ_{j=1}^{n/b} b(log b + log j) = n log b + b log((n/b)!)
= n log b + n log(n/(b·e)) + O(log n) = n log(n/e) + O(log n)

Finally, we can write:

K(r_i | X, W_{i+1,1}) ≤ O(n/b) + Σ_{j=2}^{n/b+1} [b log n_{j−1} − bρ_{i,j} + O(log n)] ≤ n log(n/e) − bρ_i + (n/b)·O(log n)

Using the above, we know that when the value ρ_i is sufficiently high, the random permutation of the epoch can be compressed. We use the fact that random strings are incompressible to bound (1/t)Σ_{i=1}^{t} ρ_i.

Theorem 3.2. If the algorithm does not terminate by the t-th iteration, then it holds w.h.p. that ∀t, (1/t)Σ_{i=1}^{t} ρ_i ≤ O(n log n / b²).

Proof. Using arguments similar to Lemma 3.1, we can show that K(r, W_{1,1} | X) ≤ K(W_{t+1,1}) + O(t) + Σ_{k=1}^{t} K(r_k | X, W_{k+1,1}) (formally proved in Lemma A.3). Combining this with Lemma 3.1, we get that K(r, W_{1,1} | X) ≤ K(W_{t+1,1}) + t[n log(n/e) + (n/b)·O(log n) + O(log n)] − Σ_{i=1}^{t} bρ_i. Our proof implies that we can reconstruct not only r, but also W_{1,1}, using X and W_{t+1,1}. Due to the incompressibility of random strings, we get that w.h.p. K(r, W_{1,1} | X) ≥ d + tn log(n/e) − O(log n). Combining the lower and upper bounds for K(r, W_{1,1} | X), we get the following inequality:

d + tn log(n/e) − O(log n) ≤ d + t[n log(n/e) + (n/b)·O(log n) + O(log n)] − Σ_{i=1}^{t} bρ_i   (1)

⟹ (1/t)Σ_{i=1}^{t} ρ_i ≤ [n·O(log n)/b² + O(log n)/b] + O(log n)/(bt) = O(n log n / b²)

Let β(n, b) be the exact value of the bracketed expression in Inequality 1 (the terms independent of t). Theorem 3.2 says that as long as SGD does not terminate, the average accuracy discrepancy cannot be too high. Using the contrapositive, we get the following useful corollary (the proof is deferred to Appendix A.3).

Corollary 3.3. If ∀k, (1/k)Σ_{i=1}^{k} ρ_i > β(n, b) + γ for γ = Ω(b^{−1} log n), then w.h.p. SGD terminates within O(1) epochs.

The case for weak models. Using the above, we can also derive some interesting negative results when the model is not expressive enough to achieve perfect accuracy on the data: it must be the case that the average accuracy discrepancy drops below β(n, b) over time.
We verify this experimentally on the MNIST dataset (Appendix B), showing that the average accuracy discrepancy indeed drops over time when the model is weak compared to the dataset. We also confirm that the dependence of the threshold on b is indeed inversely quadratic.
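The quantities of this section can also be monitored directly during training. Below is a toy sketch, under illustrative assumptions (a linear model with logistic loss on synthetic data; none of these modeling choices come from the paper), that runs an Algorithm-1-style loop and records the running average of the accuracy discrepancy ρ_i:

```python
import math
import numpy as np

def d_kl(p, q):  # Bernoulli KL with the paper's boundary conventions
    if p in (0.0, 1.0):
        return 0.0
    if q == 1.0:
        return math.log2(1 / p)
    if q == 0.0:
        return math.log2(1 / (1 - p))
    return p * math.log2(p / q) + (1 - p) * math.log2((1 - p) / (1 - q))

# Synthetic linearly labeled data; the model and loss are illustrative.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
w_star = rng.standard_normal(5)
y = (X @ w_star > 0).astype(float)
n, b, alpha = len(X), 20, 0.5

def acc(W, idx):
    return float(np.mean(((X[idx] @ W) > 0) == y[idx]))

W = np.zeros(5)
avg_rho, rho_sum = [], 0.0
for epoch in range(1, 31):
    perm = rng.permutation(n)            # fresh random permutation per epoch
    rho_i = 0.0
    for j in range(0, n, b):
        batch = perm[j:j + b]
        p = 1 / (1 + np.exp(-np.clip(X[batch] @ W, -30, 30)))
        W = W - alpha * X[batch].T @ (p - y[batch]) / b   # one GD step
        seen = perm[:j + b]              # previously seen elements (incl. batch)
        # accuracy discrepancy term: D(acc on seen || acc on batch)
        rho_i += d_kl(acc(W, seen), acc(W, batch))
    rho_sum += rho_i
    avg_rho.append(rho_sum / epoch)      # running average (1/t) sum rho_i

print(all(r >= 0 for r in avg_rho))  # → True: each term is a KL-divergence
```

Plotting `avg_rho` against the threshold β(n, b) is essentially what the experiments in Appendix B do.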

4. THE ROLE OF RANDOMNESS IN GD INITIALIZATION

Our goal for this section is to show that when the amount of randomness in the perturbation is too small, for any model architecture which is differentiable and L-smooth there are inputs for which Algorithm 2 requires exponential time to terminate, even for extremely overparameterized models.

Perturbation families. Let us consider a family of functions indexed by length-ℓ real-valued vectors, Ψ = {ψ_z}_{z∈R^ℓ}. Recall that throughout this paper we assume finite precision; thus every z can be represented using O(ℓ) bits. We say that Ψ is a reversible perturbation family if it holds that ∀z ∈ R^ℓ, ψ_z is one-to-one. We often use the notation Ψ(W), which means: pick z ∈ R^ℓ uniformly at random and apply ψ_z(W). We often refer to Ψ simply as a perturbation. We note that the above captures a wide range of natural perturbations. For example, ψ_z(W) = W + W_z, where W_z[i] = z[i mod ℓ]. Clearly ψ_z is reversible.
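The example perturbation above is easy to implement and invert: each ψ_z is a translation of W by a tiled copy of z, so knowing z suffices to undo it. A minimal sketch (ℓ and d are arbitrary illustrative values, with ℓ much smaller than d):

```python
import numpy as np

# The perturbation family from the text: psi_z(W) = W + W_z,
# where W_z[i] = z[i mod l].  Each psi_z is a translation, hence
# one-to-one, so the family is reversible.
def psi(z, W):
    W_z = np.resize(z, len(W))        # tile z to length d: z[i mod l]
    return W + W_z

def psi_inverse(z, W_perturbed):
    return W_perturbed - np.resize(z, len(W_perturbed))

rng = np.random.default_rng(2)
l, d = 4, 11                          # only l random values perturb d parameters
z = rng.standard_normal(l)
W = rng.standard_normal(d)

assert np.allclose(psi_inverse(z, psi(z, W)), W)
print("reversible")
```

Note that the randomness budget is only the O(ℓ) bits of z, regardless of the model dimension d, which is exactly the regime the lower bound addresses.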

Gradient descent

The GD algorithm we analyze is formally given in Algorithm 2.

Algorithm 2: GD(W, Y, δ)
  Input: initial model W, dataset Y, desired accuracy δ
  i ← 1, T = o(2^m) + poly(d)
  W ← Ψ(W)
  while acc(W, Y) < δ and i < T do
    W ← W − α∇f_Y(W)
    i ← i + 1
  Return W

Let us denote by m the number of elements in Y. We make the following two assumptions for the rest of this section: (1) ℓ = o(m). (2) There exists T = o(2^m) + poly(d) and a perturbation family Ψ such that for every input W, Y, within T iterations GD terminates and returns a solution that has at least δ accuracy on Y with constant probability. We show that the above two assumptions cannot hold together. That is, if the amount of randomness is sublinear in m, there must be instances with exponential running time, even when d ≫ m. To show the above, we define a variant of SGD which uses GD as a subprocedure (Algorithm 3). Assume that our dataset is a binary classification task (it is easy to generalize our results to any number of classes), and that elements in X are assigned random labels. Furthermore, let us assume that d = o(n), e.g., d = n^0.99. It holds that w.h.p. we cannot train a model with d parameters that achieves any accuracy better than 1/2 + o(1) on X (Lemma A.4). Let us take ε to be a small constant. We show that if assumptions 1 and 2 hold, then Algorithm 3 must terminate and return a model with 1/2 + Θ(1) accuracy on X, leading to a contradiction. Our analysis follows the same line as the previous section, and uses the same notation.

Reversibility. First, we must show that Algorithm 3 is still reversible. Note that we can take the same approach as before, where the only difference is that in order to get W_{i,j} from W_{i,j+1} we must now recover all the intermediate values from the call to GD. As the GD steps are applied to the same batch, this amounts to applying Lemma 2.1 several times instead of once per iteration.
More specifically, we must encode for every batch a number T_{i,j} = o(2^b) + poly(d) = o(2^b) + poly(n) of GD steps (recall that d = o(n)) and apply Lemma 2.1 T_{i,j} times. This results in ψ_z(W_{i,j}). If we know z and Ψ, then we can retrieve ψ_z and efficiently retrieve W_{i,j} using only O(log d) = O(log n) additional bits (by iterating over all values in F^d). Therefore, in every epoch the compression argument of Section 3 still applies.

Algorithm 3: Perturbed SGD
  i ← 1  // epoch counter
  W_{1,1} is an initial model
  while True do
    Take a random permutation of X, divided into batches {B_{i,j}}_{j=1}^{n/b}
    for j from 1 to n/b do
      if acc(W_{i,j}, X) ≥ 1/(2(1 − ε)) then Return W_{i,j}
      W_{i,j+1} ← GD(W_{i,j}, B_{i,j}, 1/(2(1 − 2ε)))
    i ← i + 1, W_{i,1} ← W_{i−1,n/b+1}

Repeating the analysis of Lemma 3.1 with the extra encodings above, it holds w.h.p. for all i ∈ [t] that K(r_i | W_{i+1,1}, X, Ψ) ≤ n log(n/e) − bρ_i + β(n, b) + o(n). We show that under our assumptions, Algorithm 3 must terminate, leading to a contradiction.

Lemma 4.2. Algorithm 3 with b = Ω(log n) terminates within O(T) iterations w.h.p.

Proof. Our goal is to lower bound ρ_i = Σ_{j=2}^{n/b+1} D_KL(λ̂_{i,j} ‖ ϕ_{i,j}). Let us first upper bound λ̂_{i,j}. Using the fact that λ̂_{i,j} ≤ nλ_{i,j}/((j − 1)b) (Lemma A.5), combined with the fact that λ_{i,j} ≤ 1/(2(1 − ε)) as long as the algorithm does not terminate, we get that ∀j ∈ [2, n/b + 1] it holds that λ̂_{i,j} ≤ n/(2(1 − ε)(j − 1)b). Using the above, we conclude that as long as we do not terminate, it must hold that λ̂_{i,j} ≤ 1/(2(1 − ε)²) whenever j ∈ I = [(1 − ε)n/b + 1, n/b + 1]. That is, λ̂_{i,j} must be close to λ_{i,j} towards the end of the epoch, and therefore must be sufficiently small. Note that |I| ≥ εn/b. We know that as long as the algorithm does not terminate, it holds that ϕ_{i,j} > 1/(2(1 − 2ε)) with some constant probability. Furthermore, this probability is taken over the randomness used in the call to GD (the randomness of the perturbation). This fact allows us to use Hoeffding-type bounds for the ϕ_{i,j} variables. If ϕ_{i,j} > 1/(2(1 − 2ε)), we say that it is good. Therefore, in expectation a constant fraction of ϕ_{i,j}, j ∈ I are good. Applying a Hoeffding-type bound, we get that w.h.p. a constant fraction of ϕ_{i,j}, j ∈ I are good. Denote these good indices by I_g ⊆ I. We are now ready to bound ρ_i.
ρ_i = Σ_{j=2}^{n/b+1} D_KL(λ̂_{i,j} ‖ ϕ_{i,j}) ≥ Σ_{j∈I_g} D_KL(λ̂_{i,j} ‖ ϕ_{i,j}) ≥ Σ_{j∈I_g} D_KL(1/(2(1 − ε)²) ‖ 1/(2(1 − 2ε)))
≥ Θ(εn/b) · 2(1/(2(1 − 2ε)) − 1/(2(1 − ε)²))² = Θ(n/b) · ε⁵ = Θ(n/b)

where in the transitions we used the fact that KL-divergence is non-negative, and Pinsker's inequality. Finally, requiring that b = Ω(log n), we get that bρ_i − β(n, b) − o(n) = Θ(n) − Θ(n log n / log² n) − o(n) = Θ(n).

q·h(pγ/q) + (1 − q)·h((1 − p)γ/(1 − q))
= −q[(pγ/q) log(pγ/q) + (1 − pγ/q) log(1 − pγ/q)] − (1 − q)[((1 − p)γ/(1 − q)) log((1 − p)γ/(1 − q)) + (1 − (1 − p)γ/(1 − q)) log(1 − (1 − p)γ/(1 − q))]
= −(pγ log(pγ/q) + (q − pγ) log((q − pγ)/q)) − ((1 − p)γ log((1 − p)γ/(1 − q)) + ((1 − q) − (1 − p)γ) log(((1 − q) − (1 − p)γ)/(1 − q)))
= −γ log γ − γ·D_KL(p ‖ q) − (q − pγ) log((q − pγ)/q) − ((1 − q) − (1 − p)γ) log(((1 − q) − (1 − p)γ)/(1 − q))

where in the last equality we simply sum the first terms on both lines. To complete the proof, we use the log-sum inequality on the last expression. The log-sum inequality states: let {a_k}_{k=1}^{m}, {b_k}_{k=1}^{m} be non-negative numbers, and let a = Σ_k a_k, b = Σ_k b_k; then Σ_k a_k log(a_k/b_k) ≥ a log(a/b). We apply it with a_1 = q − pγ, a_2 = (1 − q) − (1 − p)γ, a = 1 − γ and b_1 = q, b_2 = 1 − q, b = 1, getting:

(q − pγ) log((q − pγ)/q) + ((1 − q) − (1 − p)γ) log(((1 − q) − (1 − p)γ)/(1 − q)) ≥ (1 − γ) log(1 − γ)

Putting everything together, we get:

−γ log γ − γ·D_KL(p ‖ q) − (q − pγ) log((q − pγ)/q) − ((1 − q) − (1 − p)γ) log(((1 − q) − (1 − p)γ)/(1 − q))
≤ −γ log γ − (1 − γ) log(1 − γ) − γ·D_KL(p ‖ q) = h(γ) − γ·D_KL(p ‖ q)

Lemma 2.4. Let A ⊆ B, |B| = m, |A| = γm, and let g : B → {0, 1}. For any set Y ⊆ B, let Y_1 = {x | x ∈ Y, g(x) = 1}, Y_0 = Y \ Y_1, and κ_Y = |Y_1|/|Y|. It holds that K(A | B, g) ≤ mγ(log(e/γ) − D_KL(κ_B ‖ κ_A)) + O(log m).

Proof. The algorithm is very similar to Algorithm 4; the main difference is that we must first compute B_1, B_0 from B using g, and select A_1, A_0 from B_1, B_0, respectively, using two indices i_{A_1}, i_{A_0}. Finally we print A = A_1 ∪ A_0. We can now bound the number of bits required to represent i_{A_1}, i_{A_0}. Note that |B_1| = κ_B·m, |B_0| = (1 − κ_B)m.
Note that for $A_1$ we pick $\gamma\kappa_A m$ elements from $\kappa_B m$ elements and for $A_0$ we pick $\gamma(1-\kappa_A)m$ elements from $(1-\kappa_B)m$ elements. The number of bits required to represent this selection is:

$$\log\binom{\kappa_B m}{\gamma\kappa_A m} + \log\binom{(1-\kappa_B)m}{\gamma(1-\kappa_A)m} \leq \kappa_B m\, h\!\left(\frac{\gamma\kappa_A}{\kappa_B}\right) + (1-\kappa_B)m\, h\!\left(\frac{\gamma(1-\kappa_A)}{1-\kappa_B}\right) \leq m(h(\gamma) - \gamma D_{KL}(\kappa_B \,\|\, \kappa_A)) \leq m\gamma(\log(e/\gamma) - D_{KL}(\kappa_B \,\|\, \kappa_A))$$

where in the first inequality we used the fact that $\forall\, 0 \leq k \leq n$, $\log\binom{n}{k} \leq n\,h(k/n)$, Lemma A.2 in the second transition, and Lemma A.1 in the third transition. Note that when $\kappa_A \in \{0, 1\}$ we only have one term of the initial sum. For example, for $\kappa_A = 1$ we get:

$$\log\binom{\kappa_B m}{\gamma\kappa_A m} = \log\binom{\kappa_B m}{\gamma m} \leq \kappa_B m\, h\!\left(\frac{\gamma}{\kappa_B}\right) \leq m\gamma\log(e\kappa_B/\gamma) = m\gamma(\log(e/\gamma) - \log(1/\kappa_B))$$

and a similar computation yields $m\gamma(\log(e/\gamma) - \log(1/(1-\kappa_B)))$ for $\kappa_A = 0$. Finally, the additional $O(\log m)$ factor is due to various counters and variables, similarly to Algorithm 4.

Proof. Let us assume that $g$ agrees with $f$ on all except $\epsilon n$ elements in $X$ and bound $\epsilon$. Using Theorem 2.3, it holds w.h.p that $K(f \mid X) > n - O(\log n)$. We show that if $\epsilon$ is sufficiently far from $1/2$, we can use $g$ to compress $f$ below its Kolmogorov complexity, arriving at a contradiction. We can construct $f$ using $g$ and the set of values on which they do not agree, which we denote by $D$. This set is of size $\epsilon n$ and therefore can be encoded using $\log\binom{n}{\epsilon n} \leq n\,h(\epsilon)$ bits (recall that $\forall\, 0 \leq k \leq n$, $\log\binom{n}{k} \leq n\,h(k/n)$) given $X$ (i.e., $K(D \mid X) \leq n\,h(\epsilon)$). To compute $f(x)$ using $D, g$ we simply check whether $x \in D$ and output $g(x)$ or $1 - g(x)$ accordingly. The total number of bits required for the above is $K(g, D \mid X) \leq o(n) + n\,h(\epsilon)$ (where auxiliary variables are subsumed in the $o(n)$ term). We conclude that $K(f \mid X) \leq o(n) + n\,h(\epsilon)$. Combining the upper and lower bounds on $K(f \mid X)$, it must hold that $o(n) + n\,h(\epsilon) \geq n - O(\log n) \implies h(\epsilon) \geq 1 - o(1)$. This inequality only holds when $\epsilon = 1/2 \pm o(1)$.
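The counting bound invoked twice above, $\log\binom{n}{k} \leq n\,h(k/n)$, is easy to sanity-check numerically. Below is a small self-contained sketch (the function names are ours, not the paper's):

```python
import math

def log2_binom(n: int, k: int) -> float:
    """Exact log2 of the binomial coefficient C(n, k)."""
    return math.log2(math.comb(n, k))

def binary_entropy(p: float) -> float:
    """Binary entropy h(p) in bits, with the convention h(0) = h(1) = 0."""
    if p in (0.0, 1.0):
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

# The bound log C(n, k) <= n * h(k/n) used for encoding the sets D and A_1, A_0.
for n in (10, 100, 500):
    for k in range(n + 1):
        assert log2_binom(n, k) <= n * binary_entropy(k / n) + 1e-9
```

The slack in this bound is only $O(\log n)$, which is why it can be absorbed into the paper's lower-order terms.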

B EXPERIMENTS

Experimental setup We perform experiments on the MNIST dataset and on the same dataset with random labels (MNIST-RAND). We use SGD with learning rate 0.01, without momentum or regularization. We use a simple fully connected architecture with a single hidden layer, GELU activation units (a differentiable alternative to ReLU) and the cross entropy loss. We run experiments with a hidden layer of size 2, 5, and 10, and consider batches of size 50, 100, and 200. For each of the datasets we run experiments for all configurations of architecture sizes and batch sizes for 300 epochs. Results Figure 2 and Figure 3 show the accuracy discrepancy and accuracy over epochs for all configurations for MNIST and MNIST-RAND, respectively. Figure 4 and Figure 5 show, for every batch size, the accuracy discrepancy of all three model sizes on the same plot. All of the values displayed are averaged over epochs, i.e., the value for epoch $t$ is $\frac{1}{t}\sum_{i \leq t} x_i$. First, we indeed observe that the scale of the accuracy discrepancy is inversely quadratic in the batch size, as our analysis suggests. Second, for MNIST-RAND we can clearly see that the average accuracy discrepancy drops below a certain threshold over time, where the threshold appears to be independent of the number of model parameters. We see similar results for MNIST when the model is small, but not when it is large. This is because the large model does not reach its capacity within the timeframe of our experiment.
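The plotted curves can be reproduced from recorded per-batch statistics. The following is a minimal bookkeeping sketch (the accuracy values are hypothetical, and `epoch_discrepancy`/`running_average` are our names, not part of any experiment code):

```python
def epoch_discrepancy(batch_acc, seen_acc):
    """Average accuracy discrepancy x_i over one epoch: the mean difference
    between accuracy on each batch and accuracy on the data seen so far."""
    assert len(batch_acc) == len(seen_acc) and batch_acc
    return sum(b - s for b, s in zip(batch_acc, seen_acc)) / len(batch_acc)

def running_average(xs):
    """The plotted curve: entry t (1-indexed) is (1/t) * sum_{i<=t} x_i."""
    out, total = [], 0.0
    for t, x in enumerate(xs, start=1):
        total += x
        out.append(total / t)
    return out

# Motivating example from the introduction: 100% accuracy on every batch but
# only 80% on the previously seen data gives a discrepancy of 0.2 per epoch.
x1 = epoch_discrepancy([1.0, 1.0, 1.0], [0.8, 0.8, 0.8])
curve = running_average([x1, x1])
```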



"With high probability" means a probability of at least 1 - 1/n, where n is the size of the dataset. We require that the number of random bits used is proportional to the execution time of the algorithm. That is, the algorithm flips coins at every iteration of a loop, rather than just a constant number of times at the beginning of the execution.



Figure 1: A visual summary of our notations.

SGD′:
$i \leftarrow 1$ // epoch counter
$W_{1,1}$ is an initial model
while True do
    Take a random permutation of $X$, divided into batches $\{B_{i,j}\}_{j=1}^{n/b}$
    for $j$ from 1 to $n/b$ do
        if $\mathrm{acc}(W_{i,j}, X) \geq \frac{1}{2(1-\epsilon)}$ then Return $W_{i,j}$
        $W_{i,j+1} \leftarrow GD(W_{i,j}, B_{i,j}, \frac{1}{2(1-2\epsilon)})$
    $i \leftarrow i+1$, $W_{i,1} \leftarrow W_{i-1,n/b+1}$
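A minimal executable sketch of the SGD′ loop above, assuming caller-supplied stand-ins: `acc(w, data)` for the accuracy oracle and `gd(w, batch, target)` for the (possibly perturbed) gradient-descent subroutine. Both are hypothetical placeholders, not the paper's implementations:

```python
import random

def sgd_prime(w, data, b, acc, gd, eps, max_epochs=1000):
    """Per epoch: randomly permute `data` into n/b batches; return the model
    as soon as its accuracy on the full dataset reaches 1/(2(1 - eps));
    otherwise take a GD step on the current batch with target 1/(2(1 - 2*eps))."""
    threshold = 1 / (2 * (1 - eps))
    for _ in range(max_epochs):  # stands in for "while True" in the pseudocode
        perm = random.sample(data, len(data))
        for j in range(0, len(perm), b):
            if acc(w, data) >= threshold:
                return w
            w = gd(w, perm[j:j + b], 1 / (2 * (1 - 2 * eps)))
    return None  # did not terminate within max_epochs

# Toy run with stub oracles: the "model" is a single accuracy number that
# each GD step improves by 0.3 (purely illustrative).
model = sgd_prime(0.0, list(range(6)), 2, lambda w, d: w,
                  lambda w, batch, target: w + 0.3, eps=0.2)
```

Injecting `acc` and `gd` keeps the loop structure separate from the model class, which is all the sketch is meant to illustrate.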

Following the same calculation as in Corollary 3.3, this guarantees termination within $O(\frac{\log n}{n})$ epochs, or $O(T \cdot \frac{n}{b} \cdot \frac{\log n}{n}) = O(T)$ iterations (gradient descent steps). The above leads to a contradiction. It is critical to note that the above does not hold if $T = 2^m = 2^b$ or if $\ell = \Theta(n)$, as both would imply that the $o(n)$ term becomes $\Theta(n)$. We state our main theorem:

Theorem 4.3. For any differentiable and L-smooth model class with $d$ parameters and a perturbation class $\Psi$ such that $\ell = o(m)$, there exists an input data set $Y$ of size $m$ such that GD requires $\Omega(2^m)$ iterations to achieve $\delta$ accuracy on $Y$, even if $\delta = 1/2 + \Theta(1)$ and $d \gg m$.

A OMITTED PROOFS AND EXPLANATIONS

A.1 REPRESENTING SETS AND PERMUTATIONS

Throughout this paper, we often consider the value $K(A)$ where $A$ is a set. Here the program computing $A$ need only output the elements of $A$ (in any order). When considering $K(A \mid B)$ such that $A \subseteq B$, it holds that $K(A \mid B) \leq \log\binom{|B|}{|A|} + O(\log|B|)$. To see why, consider Algorithm 4. In the algorithm, $i_A$ is the index of $A$ when considering some ordering of all subsets of $B$ of size $|A|$. Thus $\log\binom{|B|}{|A|}$ bits are sufficient to represent $i_A$. The remaining variables $i, m_A, m_B$ and any additional variables are of size at most $O(\log|B|)$.

Proof. Let us expand the left-hand side using the definition of entropy: $q\,h\!\left(\frac{p\gamma}{q}\right) + \dots$

The log-sum inequality: for non-negative numbers $\{a_k\}_{k=1}^m, \{b_k\}_{k=1}^m$ with $a = \sum_{k=1}^m a_k$ and $b = \sum_{k=1}^m b_k$, it holds that $\sum_{k=1}^m a_k \log\frac{a_k}{b_k} \geq a\log\frac{a}{b}$. We apply the log-sum inequality with $m = 2$, $a_1 = q - p\gamma$, $a_2 = (1-q) - (1-p)\gamma$, $a = 1 - \gamma$ and $b_1 = q$, $b_2 = 1-q$, $b = 1$.

Lemma A.5. It holds that $1 - \frac{n(1-\lambda_{i,j})}{(j-1)b} \leq \bar\lambda_{i,j} \leq \frac{n\lambda_{i,j}}{(j-1)b}$.

Proof. Writing $\tilde\lambda_{i,j}$ for the accuracy of $W_{i,j}$ on $X \setminus X_{i,j-1}$, we can write the following for $j \in [2, n/b+1]$:

$$n\lambda_{i,j} = \sum_{x \in X} \mathrm{acc}(W_{i,j}, x) = \sum_{x \in X_{i,j-1}} \mathrm{acc}(W_{i,j}, x) + \sum_{x \in X \setminus X_{i,j-1}} \mathrm{acc}(W_{i,j}, x) = (j-1)b\,\bar\lambda_{i,j} + (n - (j-1)b)\,\tilde\lambda_{i,j}$$
$$\implies \bar\lambda_{i,j} = \frac{n\lambda_{i,j} - (n - (j-1)b)\tilde\lambda_{i,j}}{(j-1)b}.$$

Setting $\tilde\lambda_{i,j} = 0$ we get

$$\bar\lambda_{i,j} = \frac{n\lambda_{i,j} - (n - (j-1)b)\tilde\lambda_{i,j}}{(j-1)b} \leq \frac{n\lambda_{i,j}}{(j-1)b}$$

and setting $\tilde\lambda_{i,j} = 1$ we get

$$\bar\lambda_{i,j} = \frac{n\lambda_{i,j} - (n - (j-1)b)\tilde\lambda_{i,j}}{(j-1)b} \geq 1 - \frac{n(1 - \lambda_{i,j})}{(j-1)b}.$$
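Lemma A.5 is a pure counting identity, so it can be verified directly on synthetic 0/1 per-example accuracies. A small sketch (our notation: `lam` for the accuracy on all of $X$, `lam_bar` for the accuracy on the first $(j-1)b$ seen examples):

```python
import random

def overall_and_seen_accuracy(acc_bits, j, b):
    """Return (lam, lam_bar): accuracy over all of X, and over the (j-1)*b
    examples seen before batch j, for 0/1 accuracies listed in epoch order."""
    n = len(acc_bits)
    seen = acc_bits[:(j - 1) * b]
    return sum(acc_bits) / n, sum(seen) / len(seen)

random.seed(0)
n, b = 120, 10
bits = [random.randint(0, 1) for _ in range(n)]
for j in range(2, n // b + 2):
    lam, lam_bar = overall_and_seen_accuracy(bits, j, b)
    # Lemma A.5: 1 - n(1 - lam)/((j-1)b) <= lam_bar <= n*lam/((j-1)b)
    assert 1 - n * (1 - lam) / ((j - 1) * b) <= lam_bar + 1e-9
    assert lam_bar <= n * lam / ((j - 1) * b) + 1e-9
```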

Figure 2: Full results for the MNIST dataset.

Figure 3: Full results for the MNIST-RAND dataset.

iteration we have the following additional terms: $\log T_{i,j} + O(\log n) = o(b) + O(\log n)$. Summing over the $n/b$ iterations of an epoch, we get $o(n)$ per epoch. We state the following lemma, analogous to Lemma 3.1.

Lemma 4.1. For Algorithm 3 it holds w.h.p that for all $i \in [t]$:

$$K(r_i \mid W_{i+1,1}, X, \Psi) \leq n\log\frac{n}{e} - b\rho_i + \beta(n, b) + o(n).$$
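The $n\log\frac{n}{e}$ term is the Stirling approximation of the $\log(n!)$ bits needed to write down the epoch's permutation $r_i$. Both directions of that approximation are easy to check numerically (function name is ours):

```python
import math

def log2_factorial(n: int) -> float:
    """Exact log2(n!) computed via the log-gamma function."""
    return math.lgamma(n + 1) / math.log(2)

# Encoding the permutation r_i costs about n*log2(n/e) bits: the Stirling
# lower bound holds exactly, and the slack is only O(log n).
for n in (10, 100, 10_000):
    naive = n * math.log2(n / math.e)
    assert log2_factorial(n) >= naive
    assert log2_factorial(n) - naive <= 2 * math.log2(n)
```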


Algorithm 4: Compute $A$ given $B$ as input

Any additional variables required to construct the set $C$ are all of size at most $O(\log|B|)$, and there is at most a constant number of them. During our analysis, we often bound the Kolmogorov complexity of tuples of objects. For example, $K(A, P \mid B)$ where $A \subseteq B$ is a set and $P : A \to [|A|]$ is a permutation of $A$ (note that $A, P$ together form an ordered tuple of the elements of $A$). Instead of explicitly presenting a program such as Algorithm 4, we say that if $K(A \mid B) \leq c_1$ and $c_2$ bits are sufficient to represent $P$, then $K(A, P \mid B) \leq c_1 + c_2 + O(\log|B|)$. This just means that the program that computes $A$ given $B$ directly contains a variable encoding $P$ and uses it in the code. For example, we can add a permutation to Algorithm 4 and output an ordered tuple of elements rather than a set. Note that when representing a permutation of $A$, $|A| = k$, instead of using functions we can simply use values in $[k!]$. That is, we can decide on some predetermined ordering of all permutations of $k$ elements, and represent a permutation as its index in this ordering, using $\lceil \log k! \rceil$ bits.
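Algorithm 4's encoding (a subset represented by its index in a fixed ordering of all $|A|$-subsets of $B$) can be prototyped directly with `itertools`. A brute-force sketch for small sets, where the lexicographic ordering is our arbitrary choice of the "predetermined ordering":

```python
from itertools import combinations
from math import comb

def subset_index(A, B):
    """Index i_A of A in the lexicographic ordering of all |A|-subsets of B;
    it always fits in ceil(log2 C(|B|, |A|)) bits."""
    target = tuple(sorted(A))
    for i, cand in enumerate(combinations(sorted(B), len(A))):
        if cand == target:
            return i
    raise ValueError("A is not a subset of B")

def subset_from_index(i, k, B):
    """Decoder: recover the k-subset of B from its index."""
    return set(list(combinations(sorted(B), k))[i])

B = set(range(8))
A = {1, 4, 6}
i_A = subset_index(A, B)
assert i_A < comb(len(B), len(A))       # the index fits in log2 C(8, 3) bits
assert subset_from_index(i_A, len(A), B) == A
```

A real encoder would use the combinatorial number system instead of enumeration, but the bit budget is the same.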

A.2 OMITTED PROOFS FOR SECTION 2

Lemma A.1. For $p \in [0,1]$ it holds that $h(p) \leq p\log(e/p)$.

Proof. Let us write our lemma as:

$$p\log\frac{1}{p} + (1-p)\log\frac{1}{1-p} \leq p\log\frac{1}{p} + p\log e.$$

Rearranging, we get:

$$(1-p)\log\frac{1}{1-p} \leq p\log e,$$

which holds because $\ln\frac{1}{1-p} \leq \frac{1}{1-p} - 1 = \frac{p}{1-p}$, where in the final transition we use the fact that $\frac{1}{1-x}$ is monotonically increasing on $[0,1)$. This completes the proof.

Lemma A.2. For $p, \gamma, q \in [0,1]$ where $p\gamma \leq q$ and $(1-p)\gamma \leq 1-q$, it holds that

$$q\,h\!\left(\frac{p\gamma}{q}\right) + (1-q)\,h\!\left(\frac{(1-p)\gamma}{1-q}\right) \leq h(\gamma) - \gamma D_{KL}(p \,\|\, q).$$

Proof. Similarly to the definition of $Y$ in Lemma 3.1, let $Y'$ be the program which receives $X, r_i, W_{i+1,1}$ as input and repeatedly applies Theorem 2.2 to retrieve $W_{i,1}$. As $Y'$ just needs to reconstruct all batches from $X, r_i$ and call $Y$ $n/b$ times, it holds that […]. Using the subadditivity and extra-information properties of $K(\cdot)$, together with the fact that $W_{1,1}$ can be reconstructed given $X, W_{t+1,1}, Y'$, we can write the following: […] where in the last inequality we simply execute $Y'$ on $X, W_{i+2,1}, r_{i+1}$ to get $W_{i+1,1}$. Let us write: […] Combining everything together, we get that: […]

Proof. Let us simplify Inequality 1: […] Our condition implies that $\sum_{i=1}^t \rho_i > t(\beta(n,b) + \gamma)$. This allows us to rewrite the above inequality as: […]
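The inequalities of Lemmas A.1 and A.2, and the Pinsker step used in Lemma 4.2, can all be checked numerically. A minimal sketch with entropies and KL in bits (helper names are ours):

```python
import math

def h(p: float) -> float:
    """Binary entropy in bits, with h(0) = h(1) = 0."""
    if p in (0.0, 1.0):
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

def kl(p: float, q: float) -> float:
    """KL divergence (bits) between Bernoulli(p) and Bernoulli(q)."""
    def term(a, b):
        return 0.0 if a == 0.0 else a * math.log2(a / b)
    return term(p, q) + term(1 - p, 1 - q)

# Lemma A.1: h(p) <= p * log2(e / p).
for k in range(1, 100):
    p = k / 100
    assert h(p) <= p * math.log2(math.e / p) + 1e-9

# Lemma A.2: q*h(p*g/q) + (1-q)*h((1-p)*g/(1-q)) <= h(g) - g*KL(p||q),
# for p*g <= q and (1-p)*g <= 1 - q.
p, g, q = 0.9, 0.2, 0.3
lhs = q * h(p * g / q) + (1 - q) * h((1 - p) * g / (1 - q))
assert lhs <= h(g) - g * kl(p, q) + 1e-9

# Pinsker's inequality in bits: KL(p||q) >= (2 / ln 2) * (p - q)^2.
for p, q in [(0.6, 0.5), (0.1, 0.3), (0.8, 0.25)]:
    assert kl(p, q) >= 2 * (p - q) ** 2 * math.log2(math.e) - 1e-12
```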

A.4 OMITTED PROOFS FOR SECTION 4

Lemma A.4. Let X be some set of size n and let f : X → {0, 1} be a random binary function. It holds w.h.p that there exists no function g : X → {0, 1} such that K(g | X) = o(n) and g agrees with f on n(1/2 + Θ(1)) elements in X. 
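The compression step behind Lemma A.4 (encode $f$ as $g$ plus the disagreement set $D$) can be made concrete. A sketch in bits, with $h$ the binary entropy; the $o(n)$ auxiliary terms are deliberately ignored:

```python
import math

def h(p: float) -> float:
    """Binary entropy in bits, with h(0) = h(1) = 0."""
    if p in (0.0, 1.0):
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

def bits_to_describe_f(n: int, eps: float) -> float:
    """Upper bound on K(f | X) given g: encode the disagreement set D of size
    eps*n using log C(n, eps*n) <= n*h(eps) bits (auxiliary o(n) ignored)."""
    return n * h(eps)

# A random f needs about n bits w.h.p. If g agreed with f on 80% of X
# (eps = 0.2), f would compress to ~0.72n bits -- the contradiction.
n = 10_000
assert bits_to_describe_f(n, 0.2) < 0.75 * n
# At eps = 1/2 there is no compression, so no contradiction arises.
assert bits_to_describe_f(n, 0.5) == n
```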

