SUB-TASK DECOMPOSITION ENABLES LEARNING IN SEQUENCE TO SEQUENCE TASKS

Abstract

The field of Natural Language Processing (NLP) has experienced a dramatic leap in capabilities with the recent introduction of huge Language Models (LMs). Despite this success, natural language problems that involve several compounded steps are still practically unlearnable, even by the largest LMs. This is consistent with experimental failures of end-to-end learning of composite problems that have been demonstrated in a variety of domains. An effective mitigation is to introduce intermediate supervision for solving sub-tasks of the compounded problem. Recently, several works have demonstrated high gains by taking a straightforward approach for incorporating intermediate supervision in compounded natural language problems: the sequence-to-sequence LM is fed with an augmented input, in which the decomposed tasks' labels are simply concatenated to the original input (see figure 1). In this paper, we prove a positive learning result that motivates these recent efforts. We show that when concatenating intermediate supervision to the input and training a sequence-to-sequence model on this modified input, unlearnable composite problems can become learnable. We show that this is true for any family of tasks which, on the one hand, are unlearnable, and on the other hand, can be decomposed into a polynomial number of simple sub-tasks, each of which depends only on O(1) previous sub-task results. Beyond motivating contemporary empirical efforts for incorporating intermediate supervision in sequence-to-sequence language models, our positive theoretical result is the first of its kind in the landscape of results on the benefits of intermediate supervision for neural-network learning: until now, all theoretical results on the subject have been negative, i.e., they show cases where learning is impossible without intermediate supervision, while our result is positive, showing a setting where learning is facilitated by the presence of intermediate supervision.

1. INTRODUCTION

Large-scale language models such as BERT (Devlin et al., 2019), T5 (Raffel et al., 2020), and GPT-3 (Brown et al., 2020) have recently pushed the envelope in many NLP tasks. Nevertheless, there are some problem families that even the largest models do not seem capable of solving. One such family is that of "multi-hop" reasoning problems (see, e.g., Geva et al. (2021); Kalyan et al. (2021); Press et al. (2022)) that require compounding operations in order to produce an answer. For example, Gopher (Rae et al., 2021), one of the largest available language models, achieves 61% accuracy on the StrategyQA benchmark (Geva et al., 2021), which requires implicit decomposition into reasoning steps, while human-level performance is around 87% accuracy. The limitations of learning compounded tasks with neural networks in an end-to-end manner have been observed in a variety of non-linguistic domains. A leading experimental approach for addressing these is to first explicitly break the compounded operations into more basic "single-hop" operations and then combine the results. Gülçehre & Bengio (2016), one of the earliest works on this subject, propose that supervision for the single-hop intermediate steps is crucial for avoiding bad local minima in the optimization of neural networks. Afterward, Glasmachers (2017) demonstrated that gradient-based end-to-end multi-hop learning is inefficient for solving complex problems that are easily solved by a divide-and-conquer strategy.

Figure 1: An illustrative example of the prominent method for introducing sub-task decomposition and intermediate supervision for math word problems (Ling et al., 2017; Cobbe et al., 2021). Intermediate sub-tasks and their labels are concatenated to the original task's input to form a new input sequence. At training time, the likelihood of the entire sequence following the original input is maximized conditioned on the input, and at test time only the original input is given to the model.

Beyond position papers, specific examples were shown, e.g., Chang et al. (2020) showed that SATNet (Wang et al., 2019) could not solve visual Sudoku without using intermediate labels to identify individual Sudoku digit images. Similar limitations were observed in language-related compounded tasks, including commonsense reasoning (Liu et al., 2022; Wei et al., 2022; Zelikman et al., 2022), math word problems (Piękos et al., 2021; Wei et al., 2022), and program execution (Nye et al., 2022). The go-to architectures in this domain are powerful language models, which are trained as sequence-to-sequence models over text. In this setting, a particular form of introducing intermediate supervision for compounded tasks has emerged: intermediate sub-tasks and their labels are concatenated to the original task's input to form a new input sequence, on which the sequence-to-sequence LM is trained. This approach has recently been widely adopted, e.g., by Rajani et al. (2019); Cobbe et al. (2021); Piękos et al. (2021); Recchia (2021); Nye et al. (2022); Wei et al. (2022); Zelikman et al. (2022). Figure 1 illustrates this approach for math problems, as done in Ling et al. (2017); Cobbe et al. (2021). These works show that training sequence-to-sequence models with concatenated sub-task decomposition supervision significantly improves the results when compared to training the same model without the intermediate supervision. For example, Nye et al. (2022) show >99% accuracy on 8-digit addition when concatenating intermediate calculations to the input, while the vanilla accuracy without intermediate supervision is around 35%. While such decomposition-based approaches are intuitive, we are not aware of theoretical results that motivate and formulate their benefits for learning composite problems with neural networks.
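The concatenation scheme of figure 1 can be sketched in a few lines. The separator token and the textual format below are illustrative assumptions, not the exact templates used by the cited works:

```python
def augment(question: str, steps: list[str], answer: str, sep: str = " ; ") -> str:
    """Form the training target: original input, then the decomposed
    sub-tasks' labels, then the final answer (cf. figure 1)."""
    return sep.join([question, *steps, "answer: " + answer])

# At training time the model maximizes the likelihood of the full sequence
# conditioned on the question; at test time only `question` is fed in.
seq = augment(
    "Tom has 2 apples and buys 3 bags of 4 apples. How many apples does Tom have?",
    ["3 * 4 = 12", "2 + 12 = 14"],
    "14",
)
```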
In this paper, we provide positive theoretical results in this domain, which are in fact the first of their kind (see related work in section 2). We show our results for sequential models, integrating the intermediate supervision in a manner that mimics the above-cited successful empirical approaches in the language domain. In this formulation, a learner learns to predict a sequence composed of the task input x, followed by the single-hop reasoning steps, referred to as the evidence, and finally the final answer y. We extend provable guarantees for the convergence of overparameterized recurrent neural networks (Wang et al., 2021) and prove that with intermediate sub-task supervision, even a simple sequence-to-sequence model provably learns any task that obeys an efficient decomposition into simpler sub-tasks that depend only on a small fraction of the input. Importantly, both the sample complexity and the required number of gradient updates are polynomial. In contrast, we rely on existing works (Valiant, 1984; Goldreich et al., 1986; Daniely & Shalev-Shwartz, 2016) to show that in the absence of intermediate supervision, there exist efficiently decomposable tasks that are unlearnable by polynomial time learning algorithms. Our results apply to a broad family of tasks. As a first exemplifying step, we show a positive result for learning bit-subset parity, a setting that is notoriously not amenable to efficient gradient-based algorithms without intermediate supervision (Kearns, 1998; Shalev-Shwartz et al., 2017; Abbe & Sandon, 2020; Abbe et al., 2021). In this setting, the family of target functions consists of parities over subsets of unknown input bits. Specifically, the input is d bits and the task is to predict whether the number of 1's in a certain unknown subset of d/2 bits is odd or even. The corresponding sub-tasks we consider are the parities of subsets of the unknown input subset.
We prove a theorem guaranteeing that, when intermediate supervision is available, efficient neural network learning is made possible. As a result, we show an exponential gap between the end-to-end and decomposition-based neural network learnability of the bit-subset parity problem. Next, we generalize the above result and show that when sufficient intermediate supervision is available, any family of functions with polynomial time complexity, i.e., functions that belong to the P time complexity class, is efficiently learnable by neural networks. Accordingly, based on either standard cryptographic assumptions (Valiant, 1984; Goldreich et al., 1986) or computational complexity hardness assumptions (Daniely & Shalev-Shwartz, 2016), we prove that there exist tasks that, on the one hand, cannot be learned by any polynomial time algorithm and, on the other hand, can be efficiently learned by neural networks when intermediate supervision is present. Our main result can be stated as follows:

Theorem 1. (Informal) There exists a binary classification problem parameterized by size d, such that the following holds:
• On one hand, when equipped with sub-task decomposition supervision, a simple sequence-to-sequence model can get arbitrarily low ϵ > 0 zero-one loss with a number of gradient updates that is polynomial in d and ϵ⁻¹.
• On the other hand, when supervision regarding sub-tasks is missing, then for any polynomial time learning algorithm and (constant) ϵ > 0, the zero-one loss will be higher than 1/2 − ϵ.

To summarize, the main contributions of this paper are:
1. We show the first positive result that guarantees neural network learnability in the presence of intermediate supervision for a problem that is unlearnable without it.
2. We do so in the sequence-to-sequence setting that is currently used for applying state-of-the-art language models to complex multi-hop tasks in NLP.
3. We show that with sufficient intermediate supervision this sequence-to-sequence setting allows learning any function in the P time complexity class.

The remainder of this paper is organized as follows. Section 3 presents the sequence-to-sequence model we analyze. Section 4 presents a hypothesis class and proves that it can be learned with our sequence-to-sequence model. Section 5 presents a concrete example, demonstrating that the task of learning bit-subset parity with sequence-to-sequence models can be learned with sub-task decomposition and the corresponding intermediate supervision. Finally, in section 6 we generalize the positive results to any function in the P time complexity class, thus establishing our main result.

2. RELATED WORK

The concept of enhancing learning by guiding a learner through intermediate tasks is an old one, dating back to animal training by shaping (Skinner, 1958; Peterson, 2004; Krueger & Dayan, 2009). Since then, a large body of work has shown its practical benefits for various machine learning tasks. For example, there exists a rich line of work on the importance of shaping rewards and adding sub-goals in reinforcement learning tasks. Karlsson (1994) introduced the methodology of using knowledge in the reward function in order to decompose a holistic task into several sub-tasks. Ng et al. (1999) established necessary and sufficient conditions for reward shaping to preserve optimal policies. Marthi (2007) investigates the problem of automatically learning a decomposition of the reward function. These works all rely, intuitively, on the benefits of adding intermediate supervision. Recently, Zhai et al. (2022) showed that adding sub-goal rewards provably reduces the complexity of the synchronous value iteration algorithm. However, this reduction is linear in the number of sub-goals, unlike our work, which proves an exponential gap in the supervised learning setting. Moreover, several of the notions in their analysis are unique to the reinforcement learning setup and cannot easily be translated into the supervised learning setting (e.g., One-Way Intermediate States). Negative theoretical results exist, showing that end-to-end learning of multi-hop problems is infeasible without decomposition. Shalev-Shwartz et al. (2017) explored the theoretical limitations of end-to-end gradient-based learning, studying the learnability of tasks that are composed of classification and parity tasks, and proving that the end-to-end approach does not converge in a polynomial number of gradient updates. They do show that when intermediate supervision is provided, the gradients have a much higher signal-to-noise ratio.
However, they provide no guarantees that in this case learning is possible in a polynomial number of gradient updates. In addition, Shalev-Shwartz & Shashua (2016) proved an exponential gap between the end-to-end-based and the decomposition-based verification sample complexity. However, again, an explicit setting in which providing intermediate supervision for training actually improves the situation to the point that learning becomes feasible was lacking. We provide the first theoretical result proving that neural networks benefit from sub-task decomposition, while earlier theoretical works in this space only prove that end-to-end learning is infeasible in some compounded cases.

3. THE ANALYZED SEQUENCE-TO-SEQUENCE LEARNING ALGORITHM

A recent successful empirical approach for solving compounded natural language problems (Ling et al., 2017; Rajani et al., 2019; Piękos et al., 2021; Recchia, 2021; Cobbe et al., 2021; Nye et al., 2022; Wei et al., 2022; Zelikman et al., 2022) is to concatenate the intermediate sub-tasks' labels to the original input and train a sequence-to-sequence language model on the resulting augmented sequence (see figure 1). We analyze the classical Elman recurrent neural network (Elman, 1990) with ReLU activations as our sequence-to-sequence model. Given an input sequence z of length T = d + len(S) as defined above, the architecture f^RNN computes:

∀t ∈ [T]: h_t(z) = ReLU(W h_{t−1}(z) + A e_{z_t})   (1)
∀t ∈ [T]: f_t^RNN(z) = B^T h_t(z)   (2)
h_0(z) = ReLU(M_0)   (3)

where e_0, e_1 ∈ R² are one-hot vectors, A ∈ R^{m×2} translates the input to the hidden dimension m, W ∈ R^{m×m} is the learned hidden weights matrix, B ∈ R^m is the output weights vector, and M_0 ∈ R^m is the initialization of the hidden state. We will use the binary cross-entropy loss over output locations t ≥ d, i.e., our loss ignores the architecture's prediction of x and depends on its predictions of the intermediate labels and the final outcome:

ℓ(y, s) = (1/(T − d)) Σ_{t=d}^{T} log(1 + e^{−y_t · s_t})   (4)

Algorithm 1 below describes the analyzed training procedure of our sequence-to-sequence model. It is a straightforward SGD procedure where, for simplicity, we analyze a variant that updates only the hidden weights W while keeping A, B, M_0 frozen at initialization. This amounts to keeping the input, output and t = 0 hidden state untrained, and training only the core recurrence operation to perform the task while complying with these frozen components.
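The architecture of eqs. (1)-(3) and the loss of eq. (4) are small enough to state directly in code. The following NumPy sketch (the dimensions and the random placeholder task below are our own assumptions, chosen only for illustration) computes the forward pass and the loss over locations t ≥ d:

```python
import numpy as np

rng = np.random.default_rng(0)
m, d, T = 16, 4, 7                         # hidden width, input length, total length

A  = rng.normal(0, 1 / np.sqrt(m), (m, 2))  # embeds the two possible tokens
W  = rng.normal(0, 1 / np.sqrt(m), (m, m))  # the (trainable) recurrence matrix
B  = rng.normal(0, 1 / np.sqrt(m), m)       # output weights vector
M0 = rng.normal(0, 1 / np.sqrt(m), m)       # initial hidden state parameter

def f_rnn(z):
    """Eqs. (1)-(3): h_t = ReLU(W h_{t-1} + A e_{z_t}), f_t = B^T h_t."""
    h = np.maximum(M0, 0.0)                # h_0 = ReLU(M_0)
    outs = []
    for tok in z:                          # tok in {0, 1}; A[:, tok] == A e_tok
        h = np.maximum(W @ h + A[:, tok], 0.0)
        outs.append(B @ h)
    return np.array(outs)

def loss(y, s):
    """Eq. (4): binary cross-entropy over locations t >= d (0-indexed here)."""
    return np.mean(np.log1p(np.exp(-y[d:] * s[d:])))

z = rng.integers(0, 2, T)                  # placeholder input sequence
y = rng.choice([-1, 1], T)                 # placeholder +/-1 labels
print(loss(y, f_rnn(z)))                   # a finite, non-negative scalar
```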

4. COMPOUNDED SEQUENCE TO SEQUENCE LEARNABILITY

In this section, we present a hypothesis class and prove that the above-described "teacher forcing" (Williams & Zipser, 1989) of intermediate supervision at training time with algorithm 1 provably leads to generalization with polynomial sample complexity and gradient updates. This guarantee will allow us to prove our positive results in the following sections, as we will show that interesting function families belong to this hypothesis class.

Algorithm 1: Training f^RNN with SGD
Data: Data set D, learning rate η.
Initialization: The entries of W^(0), A, M_0 are i.i.d. generated from N(0, 1/m). The entries of B are i.i.d. generated from N(0, 1/m).
for i = 1, 2, 3, ..., n do
    Randomly sample (x_i, y_i) from the data set D.
    W^(i) = W^(i−1) − η ∇_{W^(i−1)} ℓ(y_i, f^{RNN,W^(i−1)}(z_i))
end

In order to analyze the teacher forcing technique, we begin with an important observation. Essentially, we show that when the zero-one loss of all the single-hop sub-tasks is low, then even at test time, when the model does not have the ground-truth results of the previous sub-tasks and errors might accumulate, the zero-one loss on the final answer is still low:

Lemma 1. Denote by z^train the teacher-forced input at training time and by z^test the iteratively predicted input at test time. Then, for any W the following holds:

E_x[ℓ^{0-1}(y, f_T^{RNN,W}(z^test))] ≤ E_x[Σ_{t=d}^{T} ℓ^{0-1}(y_t, f_t^{RNN,W}(z^train))]   (5)

Proof. Clearly, for any x, when f^{RNN,W}(z^train) solves all of the sub-tasks correctly we have that z^test = z^train, and therefore they have the same zero-one loss. So it is enough to upper bound the probability that f^{RNN,W}(z^train) is erroneous in any sub-task. Now, by definition, for any t the zero-one loss at the t'th location is equal to the probability of a wrong prediction at this location. Therefore, by the union bound, the sum of the zero-one losses over all locations upper bounds the probability of f^{RNN,W}(z^train) making an error in any sub-task. See full details in section A of the appendix.
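Algorithm 1 updates only W by SGD through the recurrence. A self-contained sketch of one such gradient step, with the backward pass written out by hand, might look as follows (all sizes and the random data are placeholder assumptions); the finite-difference check below confirms the manual gradient:

```python
import numpy as np

rng = np.random.default_rng(1)
m, d, T, eta = 12, 3, 6, 0.1

A  = rng.normal(0, 1 / np.sqrt(m), (m, 2))  # frozen at initialization
B  = rng.normal(0, 1 / np.sqrt(m), m)       # frozen
M0 = rng.normal(0, 1 / np.sqrt(m), m)       # frozen
W  = rng.normal(0, 1 / np.sqrt(m), (m, m))  # the only trained parameter

def forward(W, z):
    hs, pres, outs = [np.maximum(M0, 0.0)], [], []
    for tok in z:
        pre = W @ hs[-1] + A[:, tok]
        pres.append(pre)
        hs.append(np.maximum(pre, 0.0))
        outs.append(B @ hs[-1])
    return hs, pres, np.array(outs)

def loss(W, z, y):
    return np.mean(np.log1p(np.exp(-y[d:] * forward(W, z)[2][d:])))

def grad_W(W, z, y):
    """Backpropagation through time for the loss of eq. (4), w.r.t. W only."""
    hs, pres, s = forward(W, z)
    dW, dh = np.zeros_like(W), np.zeros(m)
    for t in range(T - 1, -1, -1):
        ds = -y[t] / (1 + np.exp(y[t] * s[t])) / (T - d) if t >= d else 0.0
        dpre = (dh + ds * B) * (pres[t] > 0)   # ReLU gate
        dW += np.outer(dpre, hs[t])            # hs[t] is the previous hidden state
        dh = W.T @ dpre
    return dW

z, y = rng.integers(0, 2, T), rng.choice([-1, 1], T)
W_new = W - eta * grad_W(W, z, y)              # one SGD iteration of Algorithm 1
```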
As expected, due to the union bound, when the model does not have the ground-truth results of the previous sub-tasks the error can increase by a factor of T − d, but this increase is relatively modest as long as T is polynomial in d. Lemma 1 above assures us that, in order to prove that the teacher forcing technique works, it is enough to find a hypothesis class for which algorithm 1 converges and generalizes when we do have the ground-truth results of the previous sub-tasks. As a candidate for such a hypothesis class, we consider tasks for which the output at each location d ≤ t ≤ T can be written as the sign of a composition of a linear function (represented by w^(t) below) of at most N < T input locations j_1, ..., j_N ≤ t with a polynomial activation ψ_t(x) = Σ_{i=0}^{deg(ψ_t)} a_{t,i} x^i:

∀d ≤ t ≤ T: h_t(z) = sign(ψ_t(⟨ w^(t)/∥w^(t)∥ , (e_{z_{j_1}}, ..., e_{z_{j_N}}) ⟩))   (6)

In order to prove convergence and generalization results, we will measure the complexity of functions in the above hypothesis class by a function ϕ(T, ψ, N), described formally in appendix A. Importantly, ϕ(T, ψ, N) is polynomial in both T and max_{t,i} |a_{t,i}|, while exponential in both max_t deg(ψ_t) and N. We will denote by H_{ϕ(T,ψ,N)} the hypothesis class described in eq. 6.

Now, with this hypothesis class, we can combine lemma 1 with theorem 2 of Wang et al. (2021). They study the learnability of RNNs for binary classification tasks without intermediate supervision, and prove that algorithm 1 is capable of learning functions whose final answer y has low complexity ϕ(T, ψ, N).

Theorem 2 (adapted from Wang et al. (2021)). For any ϵ, δ > 0 and a number of iterations n polynomial in T, ϕ(T, ψ, N), ϵ⁻¹ and δ⁻¹, there exists m⋆ = poly(n, δ⁻¹, T) such that if m > m⋆, then for any h ∈ H_{ϕ(T,ψ,N)}, with probability at least 1 − δ over the randomness in Algorithm 1, the following holds:

(1/n) Σ_{i=1}^{n} E_x[ℓ^{0-1}(y, f_T^{RNN,W^(i)}(z^test))] < ϵ   (7)

where W^(i) denotes the output of Algorithm 1 at the i'th iteration and ℓ^{0-1} is the zero-one loss. Note that sections C and D of the appendix extend theorem 2 to both SGD and GD with finite precision. In the next sections, we will prove our positive results by showing that the intermediate single-hop sub-tasks of the analyzed tasks belong to H_{ϕ(T,ψ,N)} with low complexity ϕ(T, ψ, N).

5. LEARNING BIT-SUBSET PARITY WITH SEQUENCE TO SEQUENCE MODELS

As a concrete demonstration, in this section we show that, unlike in the end-to-end case, bit-subset parity can be learned with neural networks when intermediate supervision is provided. We begin by defining the challenging task of learning parities over unknown subsets of input bits. Specifically, for a d-bit string with a subset of d/2 randomly predefined unique indices i_1, ..., i_{d/2}, our goal is to train a predictor mapping x ∈ {0, 1}^d to y = (−1)^{Σ_{j=1}^{d/2} x_{i_j}}, where x is uniformly distributed. In words, y indicates whether the number of 1's in the given subset of coordinates of x is odd or even. We analyze this task as a "multi-hop" task by decomposing it into natural intermediate sub-tasks: parities of subsets of the predefined input subset x_{i_1}, ..., x_{i_{d/2}}. Concretely, assuming for simplicity that d/2 is a power of 2, and beginning with only two adjacently indexed input bits at a time, we recursively treat the parity of every two adjacently indexed subgroups as an intermediate task. Figure 2(a) illustrates this binary-tree sub-task decomposition of our parity problem. The leaves of the tree of intermediate labels T are (−1)^{x_{i_1} + x_{i_2}}, ..., (−1)^{x_{i_{d/2−1}} + x_{i_{d/2}}}, and each node in the tree represents the sub-task of calculating the parity function over its descendants. In order to fit into the sequence-to-sequence setting of section 3, we translate our imaginary tree of intermediate labels T into a sequence of intermediate labels S by an inverse-BFS-like tree traversal, and then concatenate the sequence S after the input x. An exact formulation of the mapping from the tree T to the sequence of intermediate labels S is given in appendix B.
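The binary-tree decomposition and its flattening into the label sequence S can be illustrated as follows (a simplified sketch; the exact index bookkeeping is the one formalized in appendix B):

```python
import numpy as np

def parity_labels(x, idx):
    """Intermediate labels for bit-subset parity: XOR adjacent pairs of the
    selected bits, then recursively combine adjacent results level by level.
    Returned bottom-up (inverse-BFS-like order); the last label is y itself."""
    level = [(-1) ** (x[idx[2 * j]] + x[idx[2 * j + 1]])
             for j in range(len(idx) // 2)]
    labels = list(level)
    while len(level) > 1:
        level = [level[2 * j] * level[2 * j + 1]   # parity of two parities
                 for j in range(len(level) // 2)]
        labels.extend(level)
    return labels

rng = np.random.default_rng(0)
d = 16
x = rng.integers(0, 2, d)
idx = rng.choice(d, size=d // 2, replace=False)    # the unknown subset
S = parity_labels(x, idx)
assert S[-1] == (-1) ** (x[idx].sum())             # the root equals y
augmented = list(x) + [(1 + s) // 2 for s in S]    # the sequence fed to the model
```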

5.1. LEARNABILITY OF BIT-SUBSET PARITY WITH AND WITHOUT INTERMEDIATE SUPERVISION

In this section we show that the sub-tasks of learning bit-subset parity are simple enough to be covered by the result in theorem 2, i.e., we prove that our sequence-to-sequence formulation of the bit-subset parity target function, which includes the intermediate labels sequence S defined above, can be written as a multivariate function where each of its outputs is a simple low-degree polynomial of at most O(1) input bits. We show that this is indeed the case, i.e., that all the parity target functions comprising the intermediate supervision of our problem belong to H_{ϕ(T,ψ,N)} (see section 4) with N, max_t deg(ψ_t), and max_{t,i} |a_{t,i}| that do not grow with d. Importantly, when defining the input sequence to be only the original input, without the concatenated sub-task decomposition labels, the function h_T(x) clearly depends on d/2 bits, and therefore requires N = d/2, which leads to exponential complexity ϕ(T, ψ, N) = Ω(e^d). Thus, no efficient learning is guaranteed for the original compounded task. We begin by showing that our single-hop tasks of parities over two bits (see illustration in figure 2(a)) are simple degree-2 polynomials:

Lemma 2. There exists a degree-two polynomial ψ(x) = a_2 x² + a_1 x + a_0 with bounded coefficients |a_i| < 10 for all i, as well as w ∈ R^4, such that for all z_1, z_2 ∈ {0, 1}:

ψ(⟨ w/∥w∥ , (e_{z_1}, e_{z_2}) ⟩) = 1 if z_1 = z_2 ; −1 if z_1 ≠ z_2   (8)

Proof. We will use w to sum the first coordinates of e_{z_1}, e_{z_2}, and use polynomial interpolation to find a degree-two polynomial ψ that interpolates the z_1 = z_2 = 0, z_1 ≠ z_2, and z_1 = z_2 = 1 points; see full details in appendix B.

The above lemma implies that all of the target functions in our defined intermediate supervision belong to H_{ϕ(T,ψ,N)} with ϕ(T, ψ, N) = O(d). Therefore, together with theorem 2, it assures us that when intermediate supervision is available, Algorithm 1 can learn bit-subset parities with polynomial network size, sample complexity, and number of gradient updates.
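Lemma 2 can be verified numerically: interpolating a degree-two polynomial through the three attainable inner-product values recovers the two-bit parity, with coefficients well below the stated bound (a sketch; `np.polyfit` stands in for the explicit interpolation of appendix B):

```python
import numpy as np

e = {0: np.array([1.0, 0.0]), 1: np.array([0.0, 1.0])}  # one-hot tokens
w = np.array([1.0, 0.0, 1.0, 0.0]) / np.sqrt(2)         # unit vector of lemma 2

# Attainable values of <w, (e_z1; e_z2)> and the required +/-1 outputs:
# z1=z2=1 -> 0, z1!=z2 -> 1/sqrt(2), z1=z2=0 -> sqrt(2)
xs = np.array([0.0, 1 / np.sqrt(2), np.sqrt(2)])
ys = np.array([1.0, -1.0, 1.0])
a2, a1, a0 = np.polyfit(xs, ys, 2)                      # degree-two interpolation

for z1 in (0, 1):
    for z2 in (0, 1):
        u = w @ np.concatenate([e[z1], e[z2]])
        psi = a2 * u**2 + a1 * u + a0
        assert np.isclose(psi, 1.0 if z1 == z2 else -1.0)

assert max(abs(a2), abs(a1), abs(a0)) < 10              # the lemma's coefficient bound
```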
Having shown that bit-subset parities can be learned by a neural network when incorporating intermediate supervision, we now use the results of Shalev-Shwartz et al. (2017) to establish an exponential gap between the end-to-end and decomposition-based neural network learnability:

Corollary 1. When learning bit-subset parities using neural networks, the following holds:
• On one hand, when equipped with sub-task decomposition supervision, a simple sequence-to-sequence model can get arbitrarily low ϵ > 0 zero-one loss with a number of gradient updates that is polynomial in d and ϵ⁻¹.
• On the other hand, when supervision regarding sub-tasks is missing, then for any (constant) ϵ > 0, with high probability over the target parity, the zero-one loss will be higher than 1/2 − ϵ unless the number of gradient updates is exponential in d.

Proof. Follows directly by combining theorem 2 and lemma 2 with the negative results in Shalev-Shwartz et al. (2017). See full details in section F of the appendix.

5.2. BIT-SUBSET PARITY EXPERIMENTS

In section 5.1 we proved an exponential gap when using Elman RNNs (Elman, 1990) to learn bit-subset parity with and without sub-task decomposition. This section empirically demonstrates that the same gap exists with the commonly used Transformer (Vaswani et al., 2017) architecture. We trained a series of models while varying the input sizes from 8 bits to 256 bits. For each input size, we trained a BERT-base-sized Transformer model for 100k iterations with and without intermediate supervision. The intermediate supervision was introduced exactly as described in the previous subsection; see Figure 2 for an illustration. See full technical details of the training apparatus in appendix G. Figure 3 clearly shows that in a practical setting, using common Transformer networks, a very large gap quickly opens between the settings with and without intermediate supervision. The employed BERT-base-sized Transformer architecture is a strong network that has pushed the envelope on very challenging NLP tasks, and is much stronger than the theoretically analyzed RNN. Still, learning even the 32-bit subset parity task without supervision proved too challenging for this network (no learning after over 2M steps), while it easily learned the task in the presence of intermediate supervision. Overall, this experiment, performed on the same task on which we prove our theoretical results, reinforces their relevance to common Transformer architectures.

6. UNIVERSALITY OF DECOMPOSITION BASED SEQUENCE-TO-SEQUENCE LEARNING

In this section, we prove our main result (outlined in the introductory theorem 1). On the one hand, we generalize the positive results of section 5.1 by showing that when sufficient intermediate supervision is available, a neural network can efficiently learn any function in the P time complexity class. On the other hand, we rely on existing works (Valiant, 1984; Goldreich et al., 1986; Daniely & Shalev-Shwartz, 2016) to show that under either standard cryptographic assumptions or computational complexity hardness assumptions, there exist functions in the P time complexity class that cannot be learned by any polynomial time learning algorithm without intermediate supervision. We begin by defining the decomposition of any function f in the P time complexity class into sub-tasks. For that, we will use the fact that any such f has polynomial circuit complexity (see, for example, theorem 9.30 in Sipser (2013)), and therefore can be computed by a boolean circuit of polynomial size. We will denote by G = (V, E) the directed acyclic graph associated with such a circuit, and by l_v the logic gate of each vertex v. Furthermore, since both the "AND" and "OR" logical gates can be decomposed into a boolean circuit with a binary-tree-like structure, we may assume that the input degree of each vertex is O(1).

(Footnote, belonging to the experiments of section 5.2: By learning we mean validation accuracy higher than 60%. While this definition is somewhat arbitrary, in practice we observed a grokking phenomenon (Power et al., 2021) where, very soon after the accuracy rose above the random level, it also became almost perfect (accuracy > 95%).)

Now, in order to fit into the sequence-to-sequence setting of section 3, we define the intermediate labels sequence S for any f.
Basically, each non-leaf vertex v ∈ V will represent an intermediate task with its ground-truth label determined by l_v, and we will use a topological sorting of G in order to translate G into a sequence of intermediate labels S of length T := |V| (see figure 2 for a concrete example of this abstract construction strategy). Importantly, as in the bit-subset parity task, T is polynomial in d. In order to show our generalized positive result, theorem 2 motivates us to prove that our sequence-to-sequence formulation of any function f in the P time complexity class, which includes the intermediate labels sequence S defined above, can be written as a multivariate function where each of its outputs is a simple low-degree polynomial of at most O(1) input bits. Lemma 3 below shows that this is indeed the case, i.e., that all the target functions comprising the intermediate supervision of our problem belong to H_{ϕ(T,ψ,N)} (see section 4) with N, max_t deg(ψ_t), and max_{t,i} |a_{t,i}| that do not grow with d.
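The construction can be made concrete with a tiny circuit evaluator: gates are listed in topological order, each intermediate label is a gate's output, and the last label is f(x). The circuit below is a hypothetical example of our own, not one from the paper:

```python
from itertools import product

def circuit_labels(x, gates):
    """gates: list of (fn, input_indices) in topological order; an index
    below len(x) refers to an input bit, otherwise to an earlier gate's
    output. Returns the intermediate label sequence S (last entry = f(x))."""
    vals = list(x)
    for fn, ins in gates:
        vals.append(fn(*(vals[i] for i in ins)))
    return vals[len(x):]

AND = lambda a, b: a & b
XOR = lambda a, b: a ^ b
NOT = lambda a: 1 - a

# Example: f(x) = (x0 AND x1) XOR (NOT x2); every gate has O(1) inputs.
gates = [(AND, (0, 1)), (NOT, (2,)), (XOR, (3, 4))]

for x in product((0, 1), repeat=3):
    S = circuit_labels(x, gates)
    assert S[-1] == (x[0] & x[1]) ^ (1 - x[2])
```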

Lemma 3. For any logical gate l_v : {0, 1}^N → {0, 1} with N = O(1), there exists an O(1)-degree polynomial ψ(x) = Σ_{i=0}^{deg(ψ)} a_i x^i with bounded coefficients max_i |a_i| = O(1), as well as w ∈ R^{2N}, such that for all z_1, ..., z_N ∈ {0, 1}:

ψ(⟨ w/∥w∥ , (e_{z_1}, ..., e_{z_N}) ⟩) = l_v(z_1, ..., z_N)   (9)

Proof. We will use w to uniquely represent each possible combination of z_1, ..., z_N as an N-bit real-valued number, and use polynomial interpolation to find a 2^N-degree polynomial ψ that interpolates the 2^N points z_1 = ⋯ = z_N = 0, ..., z_1 = ⋯ = z_N = 1.

Combining lemma 3 with theorem 2 and with the existing hardness results establishes our main result, the formal counterpart of theorem 1:
• On one hand, when equipped with sub-task decomposition supervision, a simple sequence-to-sequence model can get arbitrarily low ϵ > 0 zero-one loss with a number of gradient updates that is polynomial in d and ϵ⁻¹.
• On the other hand, when supervision regarding sub-tasks is missing, then for any polynomial time learning algorithm and (constant) ϵ > 0, the zero-one loss will be higher than 1/2 − ϵ.

Proof. Follows directly by combining theorem 2 and lemma 3 with either the negative results in Valiant (1984); Goldreich et al. (1986) or in Daniely & Shalev-Shwartz (2016).
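The proof's construction can be checked numerically for a concrete gate: assignments are encoded as distinct inner-product values, and a Lagrange interpolant (in the form of eq. 22 in appendix B) reproduces the gate. The three-input majority gate below is our own illustrative choice:

```python
import numpy as np
from itertools import product

N = 3
gate = lambda z: int(sum(z) >= 2)            # example gate: 3-input majority

# w reads the "bit is 0" coordinate of each one-hot pair with weight 2^i,
# so distinct assignments map to distinct inner-product values.
w = np.zeros(2 * N)
w[0::2] = [2.0 ** i for i in range(N)]
w /= np.linalg.norm(w)

def encode(z):
    e = np.zeros(2 * N)
    for i, zi in enumerate(z):
        e[2 * i + zi] = 1.0                  # one-hot encoding of bit i
    return w @ e

points = [(encode(z), gate(z)) for z in product((0, 1), repeat=N)]

def psi(u):
    """Lagrange interpolation through all 2^N (encoding, gate value) points."""
    total = 0.0
    for j, (uj, vj) in enumerate(points):
        term = vj
        for k, (uk, _) in enumerate(points):
            if k != j:
                term *= (u - uk) / (uj - uk)
        total += term
    return total

for z in product((0, 1), repeat=N):
    assert np.isclose(psi(encode(z)), gate(z))
```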

7. DISCUSSION

In this paper, we show for a broad family of functions an exponential gap between learning algorithms that rely on intermediate supervision and algorithms that do not. Across domains and architectures, there has been a wide range of proposed methods for introducing intermediate supervision: some design specialized architectures, some add dedicated loss terms, etc. The method that is taking over in the NLP domain is straightforward, and is particularly natural for this domain, in which the core architectures are strong sequence-to-sequence Language Models: concatenate the intermediate supervision to the input, and thus jointly train the model to maximize the likelihood of all the intermediate labels as well as the overall output. Our analysis is framed in this space, and motivates this intuitive incorporation of intermediate supervision in the framework of sequence-to-sequence models. We show that even with a simple sequence-to-sequence architecture it is feasible to expect such simultaneous compounded learning to be useful. In this regard, we view our work as providing timely theoretical feedback on the rapid empirical advances in this field. Limitations: We proved universal learnability results when sufficient intermediate supervision is provided. A fundamental question is what happens when we limit the amount of sub-task supervision. For the task of bit-subset parity, we demonstrated that supervision regarding O(d) sub-tasks can yield an exponential advantage. An interesting question that we leave open for future work is whether a similar advantage exists with only O(1) sub-tasks. In addition, while our results show an exponential gain, it is still unclear which tasks are solvable by end-to-end methods and which require decomposition. Interestingly, a recent study (Abbe et al., 2022) addressed exactly this question for one-hidden-layer networks in the mean-field regime.
However, our understanding of this question for practical architectures is still very limited.

REPRODUCIBILITY STATEMENT

A complete proof of all the theoretical claims is included in the appendix. We also provide the source code for the bit-subset parity experiment at https://github.com/HUJI-Deep/sub_task_decomposition.

A COMPOUNDED SEQUENCE TO SEQUENCE LEARNABILITY DETAILS

We start by formally defining our sequence-to-sequence function complexity measure as:

ϕ(T, ψ, N) := Õ( T^{16 + 3N + max_t deg(ψ_t)} · C^{2N} · max_t deg(ψ_t)^{3N} · max_{t,i} |a_{t,i}|² )   (10)

where C > 0 is some constant. Now we prove lemma 1 from the main text. Essentially, this lemma applies the union bound to show that when the zero-one loss of all the single-hop sub-tasks is low, then also at test time (when the model does not have the ground-truth results of the previous sub-tasks and errors may accumulate) the zero-one loss on the final answer is still low.

Proof. Denote by ϵ the zero-one loss for z^train, i.e., the right-hand side of eq. 5. Clearly, for any x, when f^{RNN,W}(z^train) solves all the sub-tasks correctly we have that z^test = z^train, and therefore ℓ^{0-1}(y, f_T^{RNN,W}(z^test)) = 0. So it is enough to upper bound by ϵ the probability that f^{RNN,W}(z^train) makes an error in any sub-task. But by the zero-one loss definition, for any t we have that:

P_x[f_t^{RNN,W}(z^train) ≠ y_t] = E_x[ℓ^{0-1}(y_t, f_t^{RNN,W}(z^train))]   (11)

and therefore Σ_{t=d}^{T} P_x[f_t^{RNN,W}(z^train) ≠ y_t] = ϵ. Finally, by the union bound we get that:

P_x[∃ d ≤ t ≤ T : f_t^{RNN,W}(z^train) ≠ y_t] ≤ Σ_{t=d}^{T} P_x[f_t^{RNN,W}(z^train) ≠ y_t] = ϵ   (12)

B SUB-TASK LEARNABILITY PROOFS

In this section we prove lemmas 2 and 3 from the main text, i.e., we prove that our intermediate steps are simple enough to be covered by theorem 2. We begin by formally describing the details of learning parities with sequence-to-sequence models. Since sequence-to-sequence models expect their inputs to be sequences, we translate the tree described in section 5 into a sequence by an inverse-BFS-like tree traversal, and concatenate the resulting sequence after x. Therefore, our input sequence includes all the d variables in x, together with all the sub-task decomposition nodes in the binary tree except the root (which represents y).
So we will have an input sequence of length $T$ equal to:

$$T := d + \#\left\{\text{nodes in a full binary tree with } \tfrac{d}{4} \text{ leaves}\right\} - 1 = \tfrac{3}{2}d - 2$$

And the ground-truth sub-task results are recursively defined by:

$$\forall t \ge d \quad y_t = \begin{cases} (-1)^{x_{i_{2(t-d)+1}} + x_{i_{2(t-d)+2}}} & t < \tfrac{5}{4}d \\ (-1)^{y_{2(t - 5d/4) + d} + y_{2(t - 5d/4) + d + 1}} & \text{else} \end{cases} \quad (14)$$

For $t > d$, at training time $z_t$ will be the ground-truth sub-task result $y_{t-1}$. At test time, $z_t$ will be the model's prediction at time $t-1$:

$$\forall t \in [T] \quad z_t = \begin{cases} x_t & t \le d \\ \tfrac{1}{2} + \tfrac{1}{2}\operatorname{sign}\left( f^{\text{RNN}}_{t-1}(z_1, \dots, z_{t-1}) \right) & t > d \wedge \text{test} \\ \tfrac{1 + y_{t-1}}{2} & t > d \wedge \text{training} \end{cases} \quad (15)$$

Note that $f^{\text{RNN}}$ is a causal model, i.e., $f^{\text{RNN}}_{t_1}$ does not depend on $z_{t_2}$ for any $t_2 > t_1$, and therefore eq 15 is well defined. Now, we prove lemma 2 from the main text. Essentially, this lemma shows that the intermediate steps of length-2 parities belong to the hypothesis class defined in eq 6.

Proof. Define $w := \tfrac{1}{\sqrt{2}}(1, 0, 1, 0)^\top$; then:

$$\left\langle \tfrac{w}{\|w\|}, \begin{pmatrix} e_{z_1} \\ e_{z_2} \end{pmatrix} \right\rangle = \begin{cases} \sqrt{2} & z_1 = 0 \wedge z_2 = 0 \\ \tfrac{1}{\sqrt{2}} & z_1 \neq z_2 \\ 0 & z_1 = 1 \wedge z_2 = 1 \end{cases} \quad (16)$$

Therefore, it is enough to find $\psi$ such that $\psi(0) = \psi(\sqrt{2}) = 1$ and $\psi\left(\tfrac{1}{\sqrt{2}}\right) = -1$. Finally, we use the Lagrange basis functions to find the required polynomial and get:

$$\psi(z) = \frac{\left(z - \tfrac{1}{\sqrt{2}}\right)\left(z - \sqrt{2}\right)}{\left(0 - \tfrac{1}{\sqrt{2}}\right)\left(0 - \sqrt{2}\right)} - \frac{z\left(z - \sqrt{2}\right)}{\left(\tfrac{1}{\sqrt{2}} - 0\right)\left(\tfrac{1}{\sqrt{2}} - \sqrt{2}\right)} + \frac{z\left(z - \tfrac{1}{\sqrt{2}}\right)}{\left(\sqrt{2} - 0\right)\left(\sqrt{2} - \tfrac{1}{\sqrt{2}}\right)} \quad (18)$$
$$= \left(z - \tfrac{1}{\sqrt{2}}\right)\left(z - \sqrt{2}\right) + 2z\left(z - \sqrt{2}\right) + z\left(z - \tfrac{1}{\sqrt{2}}\right) \quad (19)$$
$$= z^2 - \tfrac{3}{\sqrt{2}}z + 1 + 2z^2 - 2\sqrt{2}z + z^2 - \tfrac{1}{\sqrt{2}}z \quad (20)$$
$$= 4z^2 - 4\sqrt{2}z + 1 \quad (21)$$

Now we prove lemma 3 from the main text. Essentially, this lemma shows that our intermediate steps for any function in the P time complexity class also belong to the hypothesis class defined in eq 6.

Proof. Denote by $\alpha(z_1, \dots, z_N) := \sum_{i=0}^{N-1} 2^i \cdot 1_{z_i = 0}$ the function that converts $N$ bits into the corresponding number, and define

$$w := \sqrt{\tfrac{3}{4^N - 1}} \left( 2^0, 0, 2^1, 0, \dots, 2^{N-1}, 0 \right)^\top,$$

then $w$ is a unit vector that represents $z_1, \dots, z_N$ as $N$-bit numbers: $\left\langle \tfrac{w}{\|w\|}, \left( e_{z_1}; \dots; e_{z_N} \right) \right\rangle = \alpha(z_1, \dots, z_N)$. Now, we can use the Lagrange basis functions to find the required polynomial:

$$\psi(x) = \sum_{z_1, \dots, z_N \in \{0,1\}} f_v(z_1, \dots, z_N) \prod_{(\tilde{z}_1, \dots, \tilde{z}_N) \neq (z_1, \dots, z_N)} \frac{x - \alpha(\tilde{z}_1, \dots, \tilde{z}_N)}{\alpha(z_1, \dots, z_N) - \alpha(\tilde{z}_1, \dots, \tilde{z}_N)} \quad (22)$$

Finally, the $O(1)$ boundedness of the coefficients follows from taking the maximum (footnote 6) over all possible $f_v$ functions.
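To make the decomposition of appendix B concrete, the following sketch builds the sub-task label sequence of eq 14 for a small bit-subset parity instance. The subset indices and input bits are illustrative placeholders, and the 1-based indexing of eq 14 is replaced by 0-based Python indexing:

```python
def subtask_labels(x, idx):
    """Sub-task decomposition sequence for bit-subset parity (cf. eq 14).

    x   : list of d input bits in {0, 1}
    idx : the d/2 predefined subset indices (0-based here)

    Returns the sub-task results in {-1, +1}: first the d/4 pairwise
    parities of the selected bits (the leaves), then each level of the
    binary tree up to and including the root, which equals the
    bit-subset parity itself.
    """
    nodes = [(-1) ** (x[idx[2 * j]] + x[idx[2 * j + 1]])
             for j in range(len(idx) // 2)]
    seq = list(nodes)
    while len(nodes) > 1:  # combine pairs of previous results, level by level
        nodes = [nodes[2 * j] * nodes[2 * j + 1] for j in range(len(nodes) // 2)]
        seq.extend(nodes)
    return seq

d = 8
idx = [0, 2, 4, 6]                  # hypothetical subset of d/2 = 4 indices
x = [1, 0, 1, 1, 0, 1, 1, 0]
seq = subtask_labels(x, idx)
assert seq[-1] == (-1) ** sum(x[i] for i in idx)  # root = bit-subset parity
assert d + len(seq) - 1 == 3 * d // 2 - 2         # matches T = 3d/2 - 2
print(seq)                                        # → [1, -1, -1]
```

The last assertion reflects that the root is the output $y$ rather than part of the input sequence, which is why $T$ counts all nodes except it.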

C EXTENSION FOR SGD WITH FINITE PRECISION

In this section, we prove that theorem 2 from the main text also holds for algorithm 2, which is a finite-precision variant of SGD (footnote 7). We follow the proof in Wang et al. (2021) while taking into account the finite-precision gradients.

Algorithm 2: Training $f^{\text{RNN}}$ with finite-precision SGD (a finite-precision variant of algorithm 1)
Data: data set $D$, learning rate $\eta$, finite precision $\sigma$.
Initialization: the entries of $W^{(0)}, A, M_0$ are generated i.i.d. from $N(0, \tfrac{2}{m})$. The entries of $B$ are generated i.i.d. from $N(0, \tfrac{1}{m})$.
for $i = 1, 2, 3, \dots, n$ do
  Randomly sample $(x_i, y_i)$ from the data set $D$.
  Get an arbitrary $\sigma$-approximation of the gradient: $G^{(i)} \in B_\infty\left( \nabla_{W^{(i-1)}} \ell(y_i, f^{\text{RNN},W^{(i-1)}}(z_i)), \sigma \right)$ (footnote 8).
  Update weights: $W^{(i)} = W^{(i-1)} - \eta G^{(i)}$.
end

We begin by stating theorem 2 (footnote 9) of Wang et al. (2021) in our notation:

Theorem 3. Let $\delta > 0$, and assume we run algorithm 1 for $n$ iterations with learning rate $\eta = \frac{1}{m\sqrt{n}}$. Then there exists $m^\star = \text{poly}(n, \delta^{-1}, T)$ such that if $m > m^\star$, then for any $h \in H_{\phi(T,\psi,N)}$, with probability at least $1 - \delta$ over the randomness in algorithm 1, the following holds:

$$E_x\left[ \frac{1}{n(T-d)} \sum_{t=d}^{T} \sum_{i=1}^{n} l_{0\text{-}1}\left( y_t, f^{\text{RNN},W^{(i)}}_t(z) \right) \right] < \tilde{O}\left( \frac{\phi(T,\psi,N)}{\sqrt{n}} \right) + O\left( \frac{\log\frac{1}{\delta}}{\sqrt{n}} \right)$$

where $W^{(i)}$ denotes the output of algorithm 1 at the $i$'th iteration and $l_{0\text{-}1}$ is the zero-one loss.

Clearly, when using infinite-precision SGD, theorem 2 follows from theorem 3 and lemma 1 by simple algebraic manipulations. Therefore, to prove theorem 2 with finite-precision gradient-based optimization, it is enough to modify theorem 3's proof to analyze algorithm 2, which uses finite-precision gradients, instead of algorithm 1, which uses full-precision gradients. While complicated, Wang et al. (2021)'s proof of theorem 3 can be divided into two high-level arguments. The first argument measures the complexity of the learned hypothesis class with respect to the random initialization of $f^{\text{RNN}}$ (lemma 6).
And the second argument is a generalization bound for algorithm 1 with sufficiently overparameterized networks. Since the first argument is independent of the gradients, its proof still holds, and we only need to prove a generalization bound for algorithm 2. More specifically, we only need to prove a lemma equivalent to lemma 14 in Wang et al. (2021); the rest of the second argument (lemma 7 in Wang et al. (2021)) remains unchanged.

Lemma 4. Let $n \in \mathbb{N}$, and denote by $L_i(W) := l\left( y^{(i)}, f^{\text{RNN},W}(z^{(i)}) \right)$ the training loss. Suppose there exists $W^\star \in B\left( W^{(0)}, \frac{R}{\sqrt{m}} \right)$ (footnote 10) such that $R = O(\text{poly}(T)) = \Omega(T^{16})$ and $L_i(W^\star) \le \frac{1 + R^2}{n}$. Then for any $\delta > 0$ there exists $m^\star = \text{poly}(n, R, T, \delta^{-1})$ such that if $m > m^\star$, then with probability at least $1 - \delta$, algorithm 2 with $\eta = \frac{1}{m\sqrt{n}}$ and finite precision $\sigma = O\left( \frac{1}{m} \right)$ will output:

$$E_x\left[ \frac{1}{n(T-d)} \sum_{t=d}^{T} \sum_{i=1}^{n} l_{0\text{-}1}\left( y_t, f^{\text{RNN},W^{(i)}}_t(z) \right) \right] < O\left( \frac{R^2 + \log\frac{1}{\delta}}{\sqrt{n}} \right) \quad (24)$$

Proof. We begin by showing that, with high probability over the initialization of $f^{\text{RNN}}$, at any iteration $i \le n$ of algorithm 2 the distance of the learned hidden weights matrix $W^{(i)}$ from its initialization point $W^{(0)}$ is not too large. As a result, we will get that the assumptions of lemma 8 hold, and therefore its upper bound on the deviation from linearization is valid.
By the triangle inequality, for any $0 \le i < n$ we have that:

$$\left\| W^{(i+1)} - W^{(0)} \right\|_F \le \sum_{k=0}^{i} \left\| W^{(k+1)} - W^{(k)} \right\|_F \quad (25)$$

Substituting algorithm 2's update rule for $W^{(k+1)}$, we get that there exist $\sigma_k$ with $\|\sigma_k\|_\infty < \sigma$ such that:

$$W^{(k+1)} - W^{(k)} = -\eta \left( \nabla_{W^{(k)}} \ell(y_k, f^{\text{RNN},W^{(k)}}(z_k)) + \sigma_k \right) \quad (26)$$

Now, explicitly writing $\nabla_{W^{(k)}} \ell(y_k, f^{\text{RNN},W^{(k)}}(z_k))$ with the chain rule, we have that:

$$\nabla_{W^{(k)}} \ell(y_k, f^{\text{RNN},W^{(k)}}(z_k)) = \frac{1}{T-d} \sum_{t=d}^{T} \frac{-y^{(k)}_t \cdot e^{-y^{(k)}_t f^{\text{RNN},W}_t(z_k)}}{1 + e^{-y^{(k)}_t f^{\text{RNN},W}_t(z_k)}} \cdot \nabla_{W^{(k)}} f^{\text{RNN},W^{(k)}}_t(z_k) \quad (27)$$

and since $0 \le \frac{x}{1+x} \le 1$ for any $x \ge 0$, we conclude that:

$$\left\| \nabla_{W^{(k)}} \ell(y_k, f^{\text{RNN},W^{(k)}}(z_k)) \right\|_F \le \max_{d \le t \le T} \left\| \nabla_{W^{(k)}} f^{\text{RNN},W^{(k)}}_t(z_k) \right\|_F \quad (28)$$

Now we use induction over $i$ to show that $\left\| W^{(i+1)} - W^{(0)} \right\|_F = O\left( \frac{(i+1)T^8}{m\sqrt{n}} \right)$ for any $0 \le i < n$. By the induction hypothesis, lemma 10 assures us that for wide enough networks, $m^\star = \Omega\left( \max\left\{ T^6 \log^3 \frac{nT}{\delta}, \sqrt{n}\,T^8 \right\} \right)$, with probability of at least $1 - \delta$ over the initialization of $f^{\text{RNN}}$, $\max_{d \le t \le T} \left\| \nabla_{W^{(k)}} f^{\text{RNN},W^{(k)}}_t(z_k) \right\|_F = O(T^8)$. Therefore, in this case:

$$\left\| W^{(k+1)} - W^{(k)} \right\|_F = \eta\left( O(T^8) + \frac{1}{m} \right) = O\left( \frac{T^8}{m\sqrt{n}} \right) + \frac{1}{m^2\sqrt{n}} \quad (29)$$

and hence $\left\| W^{(i+1)} - W^{(0)} \right\|_F = O\left( \frac{(i+1)T^8}{m\sqrt{n}} \right)$ as required. Now, having shown that the assumptions of lemma 8 hold, we can use it to obtain a first-order Taylor approximation of the training loss:

$$L_i\left( W^{(i)} \right) - L_i(W^\star) \le \left\langle \nabla_{W^{(i)}} \ell(y_i, f^{\text{RNN},W^{(i)}}(z_i)), W^{(i)} - W^\star \right\rangle \quad (30)$$
$$+ \underbrace{\max_{d \le t \le T} \left| \frac{y_t \cdot e^{-y_t f^{\text{RNN},W}(z_i)}}{1 + e^{-y_t f^{\text{RNN},W}(z_i)}} \right|}_{\le 1} \cdot O\left( \left( \tfrac{R}{\sqrt{m}} \right)^{\frac{1}{3}} T^{10} \sqrt{m \log m} \right) \left\| W^{(i)} - W^\star \right\|_F \quad (31)$$

where we assumed that $m^\star > n$ and therefore $\left\| W^{(i)} - W^{(0)} \right\|_F < O\left( \frac{R}{\sqrt{m}} \right)$.
Using algorithm 2's update rule for $W^{(i+1)}$ again (see eq 26), we can use an inequality from lemma 14.1 in Shalev-Shwartz & Ben-David (2014) to get that:

$$\sum_{i=1}^{n} \left\langle \nabla_{W^{(i)}} \ell(y_i, f^{\text{RNN},W^{(i)}}(z_i)) + \sigma_i, W^{(i)} - W^\star \right\rangle \quad (32)$$
$$\le \frac{\left\| W^{(1)} - W^\star \right\|_F^2}{2\eta} + \frac{\eta}{2} \sum_{i=1}^{n} \left\| \nabla_{W^{(i)}} \ell(y_i, f^{\text{RNN},W^{(i)}}(z_i)) + \sigma_i \right\|_F^2 \quad (33)$$

Now, combining with the Cauchy-Schwarz inequality, we have that:

$$\sum_{i=1}^{n} L_i\left( W^{(i)} \right) - L_i(W^\star) \le \frac{\left\| W^{(1)} - W^\star \right\|_F^2}{2\eta} + \frac{\eta}{2} \sum_{k=1}^{n} \left\| \nabla_{W^{(k)}} \ell(y_k, f^{\text{RNN},W^{(k)}}(z_k)) + \sigma_k \right\|_F^2 \quad (34)$$
$$+ O\left( \left( \tfrac{R}{\sqrt{m}} \right)^{\frac{1}{3}} T^{10} \sqrt{m \log m} \right) \sum_{i=1}^{n} \left\| W^{(i)} - W^\star \right\|_F \quad (35)$$
$$- \sum_{i=1}^{n} \left\langle \sigma_i, W^{(i)} - W^\star \right\rangle \quad (36)$$

Substituting the upper bounds from eqs 28 and 29, and using the assumption that $R = \Omega(T^{16})$, we get that:

$$\sum_{i=1}^{n} L_i\left( W^{(i)} \right) - L_i(W^\star) \le O\left( \frac{R^2 + T^{16}}{2\eta m} \right) + \frac{\eta n}{2} \left( O(T^8) + \frac{1}{m} \right)^2 \quad (37)$$
$$+ O\left( \left( \tfrac{R}{\sqrt{m}} \right)^{\frac{1}{3}} T^{10} n \sqrt{\log m} \left( R + nT^8 \right) \right) + O\left( m^{-\frac{3}{2}} \left( n^2 T^8 + nR \right) \right) \quad (38)$$
$$\le O\left( R^2 \sqrt{n} \right) + O\left( \left( \tfrac{R}{\sqrt{m}} \right)^{\frac{1}{3}} T^{10} n \sqrt{\log m} \left( R + nT^8 \right) \right) + O\left( m^{-\frac{3}{2}} n^2 R \right) \quad (39)$$

To ensure the left-hand side is upper bounded by $O\left( R^2 \sqrt{n} \right)$, we choose $m^\star$ such that $m^\star \log^3 m^\star > n^{\frac{9}{2}}$; note that, as required, $m^\star$ is polynomial in $n, T$. Then for $m > m^\star$ we have that $\sum_{i=1}^{n} L_i\left( W^{(i)} \right) \le O\left( R^2 \sqrt{n} \right)$, and therefore $\frac{1}{n} \sum_{i=1}^{n} L_i\left( W^{(i)} \right) \le O\left( \frac{R^2}{\sqrt{n}} \right)$.

Now, to prove the generalization bound, we follow lemma 4.3 in Ji & Telgarsky (2020) and use a martingale Bernstein bound argument. We begin by showing that during the whole training process our binary cross-entropy loss is bounded. Indeed, lemma 9 assures us that:

$$\max_{x \in \{0,1\}^d} \max_{d < t \le T} \left| f^{\text{RNN},W^{(i)}}_t(z) \right| = O\left( T^{14} \cdot \frac{\sqrt{n}}{m} + T \right) \le O\left( T^{14} \right) \quad (41)$$

Therefore, there exists a constant $C > 0$ such that the binary cross-entropy loss is bounded by $\log\left( 1 + e^{O(T^{14})} \right) \le C \cdot T^{14}$. Now we define a bounded martingale. For any $i \ge 0$, let $s_i$ denote $(x_i, y_i)$ and $s_{0,i}$ denote $(s_0, \dots, s_i)$. Importantly, the quantity

$$\frac{1}{C \cdot T^{14}} \sum_{t < i} E_x\left[ l\left( y, f^{\text{RNN},W^{(t)}}(z) \right) \right] - l\left( y_t, f^{\text{RNN},W^{(t)}}(z_t) \right) \quad (42)$$

is a martingale w.r.t. the filtration $\sigma(s_{0,i-1})$.
The corresponding martingale difference sequence is bounded:

$$\frac{1}{C \cdot T^{14}} \left| E_x\left[ l\left( y, f^{\text{RNN},W^{(t)}}(z) \right) \right] - l\left( y_t, f^{\text{RNN},W^{(t)}}(z_t) \right) \right| \le 1 \quad (43)$$

Moreover, we have:

$$E_{s_{0,t}}\left[ \frac{1}{C^2 T^{28}} \left( E_x\left[ l\left( y, f^{\text{RNN},W^{(t)}}(z) \right) \right] - l\left( y_t, f^{\text{RNN},W^{(t)}}(z_t) \right) \right)^2 \,\middle|\, \sigma(s_{0,i-1}) \right] \quad (44)$$
$$= \frac{1}{C^2 T^{28}} \left( E_{s_{0,t}}\left[ l\left( y_t, f^{\text{RNN},W^{(t)}}(z_t) \right)^2 \,\middle|\, \sigma(s_{0,i-1}) \right] - E_x\left[ l\left( y, f^{\text{RNN},W^{(t)}}(z) \right) \right]^2 \right) \quad (45)$$
$$\le E_{s_{0,t}}\left[ \frac{1}{C \cdot T^{14}} l\left( y_t, f^{\text{RNN},W^{(t)}}(z_t) \right) \,\middle|\, \sigma(s_{0,i-1}) \right] \quad (46)$$
$$= \frac{1}{C \cdot T^{14}} E_x\left[ l\left( y, f^{\text{RNN},W^{(t)}}(z) \right) \right] \quad (47)$$

Therefore, by lemma C.2 in Ji & Telgarsky (2020), with probability $1 - \delta$:

$$\frac{1}{C \cdot T^{14}} \sum_{i=1}^{n} E_x\left[ l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) \right] - l\left( y_i, f^{\text{RNN},W^{(i)}}(z_i) \right) \quad (48)$$
$$\le (e - 2) \cdot \frac{1}{C \cdot T^{14}} \sum_{i=1}^{n} E_x\left[ l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) \right] + \ln\frac{1}{\delta} \quad (49)$$

and hence:

$$\sum_{i=1}^{n} E_x\left[ l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) \right] \le \frac{1}{3 - e} \left( \sum_{i=1}^{n} l\left( y_i, f^{\text{RNN},W^{(i)}}(z_i) \right) + O\left( T^{14} \ln\frac{1}{\delta} \right) \right) \quad (50)$$
$$= O\left( R^2 \sqrt{n} \right) + O\left( R \ln\frac{1}{\delta} \right) \quad (51)$$

Finally, for $y_t \cdot f^{\text{RNN},W}_t(z) < 0$ we have that $\log\left( 1 + e^{-y_t \cdot f^{\text{RNN},W}_t(z)} \right) > \log 2$. In addition, clearly $\log\left( 1 + e^{-y_t \cdot f^{\text{RNN},W}_t(z)} \right) > 0$ always. Therefore, we conclude that

$$\frac{1}{T-d} \sum_{t=d}^{T} l_{0\text{-}1}\left( y_t, f^{\text{RNN},W^{(i)}}_t(z) \right) < \frac{1}{\log 2}\, l\left( y, f^{\text{RNN},W^{(i)}}(z) \right)$$

and thus eq 24 holds.
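The finite-precision oracle assumed by algorithm 2 is easy to emulate: any elementwise rounding of the true gradient to a grid of spacing $\sigma$ lands inside the required $B_\infty$ ball. The sketch below runs such $\sigma$-approximate SGD on a toy logistic model; this is only a stand-in for $f^{\text{RNN}}$ (the data, model, and hyper-parameters are illustrative, not those of the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def quantize(g, sigma):
    # Elementwise rounding to a grid of spacing sigma: each entry moves by
    # at most sigma/2, so the result lies in B_inf(g, sigma), which is all
    # algorithm 2 asks of its sigma-approximate gradient oracle.
    return sigma * np.round(g / sigma)

# toy linearly separable data standing in for the training pairs; labels in {-1, +1}
w_true = rng.normal(size=5)
X = rng.normal(size=(200, 5))
y = np.sign(X @ w_true)

def grad(w, x, yi):
    # gradient of the logistic loss log(1 + exp(-y <w, x>)) at one sample
    return -yi * x / (1.0 + np.exp(yi * (x @ w)))

w = np.zeros(5)
eta, sigma = 0.1, 1e-3
for i in rng.integers(0, len(X), size=5000):   # finite-precision SGD steps
    w -= eta * quantize(grad(w, X[i], y[i]), sigma)

train_acc = np.mean(np.sign(X @ w) == y)
print(train_acc)
```

The quantization error per step is at most $\sigma$ per entry, mirroring the $\sigma_k$ terms absorbed into eqs 26 and 29 above.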

D EXTENSION FOR GD WITH FINITE PRECISION

In this section, we prove that theorem 2 from the main text still holds when using the gradient-descent-based algorithm 3 instead of the stochastic-gradient-descent-based algorithm 2. This establishes our positive result, that the parities task is efficiently learnable with sub-task decomposition supervision, in the exact same setting as the negative results of section F, which show that learning is impossible without intermediate supervision. We follow the proof in section C while taking into account the full, non-stochastic gradients.

Algorithm 3: Training $f^{\text{RNN}}$ with finite-precision GD
Data: data set $D$, learning rate $\eta$, finite precision $\sigma$.
Initialization: the entries of $W^{(0)}, A, M_0$ are generated i.i.d. from $N(0, \tfrac{2}{m})$. The entries of $B$ are generated i.i.d. from $N(0, \tfrac{1}{m})$.
for $i = 1, 2, 3, \dots, n$ do
  Get an arbitrary $\sigma$-approximation of the gradient: $G^{(i)} \in B_\infty\left( E_x\left[ \nabla_{W^{(i-1)}} \ell(y, f^{\text{RNN},W^{(i-1)}}(z)) \right], \sigma \right)$ (footnote 11).
  Update weights: $W^{(i)} = W^{(i-1)} - \eta G^{(i)}$.
end

We begin by sampling a fake training set, denoted by $\left( x^{(1)}, \dots, x^{(n)} \right), \left( y^{(1)}, \dots, y^{(n)} \right)$. Essentially, we will apply the same reasoning as in section C to these fake training points. As in the finite-precision SGD case, it is enough to prove a lemma equivalent to lemma 4 in section C.

Lemma 5. Let $n \in \mathbb{N}$, and denote by $L_i(W) := l\left( y^{(i)}, f^{\text{RNN},W}(z^{(i)}) \right)$ the training loss. Suppose there exists $W^\star \in B\left( W^{(0)}, \frac{R}{\sqrt{m}} \right)$ (footnote 12) such that $R = O(\text{poly}(T)) = \Omega(T^{16})$ and $L_i(W^\star) \le \frac{1 + R^2}{n}$. Then for any $\delta > 0$ there exists $m^\star = \text{poly}(n, R, T, \delta^{-1})$ such that if $m > m^\star$, then with probability at least $1 - \delta$, algorithm 3 with $\eta = \frac{1}{m\sqrt{n}}$ and finite precision $\sigma = O\left( \frac{1}{m} \right)$ will output:

$$E_x\left[ \frac{1}{n(T-d)} \sum_{t=d}^{T} \sum_{i=1}^{n} l_{0\text{-}1}\left( y_t, f^{\text{RNN},W^{(i)}}_t(z) \right) \right] < O\left( \frac{R^2 + \log\frac{1}{\delta}}{\sqrt{n}} \right) \quad (52)$$

Proof. As in lemma 4, the expected gradient's norm is bounded by:

$$\left\| E_x\left[ \nabla_{W^{(k)}} \ell(y, f^{\text{RNN},W^{(k)}}(z)) \right] \right\|_F \le \frac{1}{T-d} \sum_{t=d}^{T} E_x\left[ \left\| \nabla_{W^{(k)}} f^{\text{RNN},W^{(k)}}_t(z) \right\|_F \right] \quad (57)$$

Now we use induction over $i$ to show that $\left\| W^{(i+1)} - W^{(0)} \right\|_F = O\left( \frac{(i+1)T^8}{m\sqrt{n}} \right)$ for any $0 \le i < n$.
By the induction hypothesis, lemma 10 assures us that for wide enough networks, $m^\star = \Omega\left( \max\left\{ T^7 \log^3 \frac{nT}{\delta}, \sqrt{n}\,T^8 \right\} \right)$, with probability of at least $1 - \delta$ over the initialization of $f^{\text{RNN}}$, $E_x\left[ \left\| \nabla_{W^{(k)}} f^{\text{RNN},W^{(k)}}_t(z) \right\|_F \right] = O(T^8)$ for any $d \le t \le T$. Therefore, in this case $\left\| W^{(k+1)} - W^{(k)} \right\|_F = \eta\left( O(T^8) + \frac{1}{m} \right) = O\left( \frac{T^8}{m\sqrt{n}} \right) + \frac{1}{m^2\sqrt{n}}$, and hence $\left\| W^{(i+1)} - W^{(0)} \right\|_F = O\left( \frac{(i+1)T^8}{m\sqrt{n}} \right)$ as required.

Now, having shown that the assumptions of lemma 11 hold, we can use it to obtain a first-order Taylor approximation of the training loss for any $x$:

$$l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) - l\left( y, f^{\text{RNN},\widetilde{W}}(z) \right) \le \left\langle \nabla_{W^{(i)}} \ell(y, f^{\text{RNN},W^{(i)}}(z)), W^{(i)} - \widetilde{W} \right\rangle$$
$$+ \underbrace{\max_{d \le t \le T} \left| \frac{y_t \cdot e^{-y_t f^{\text{RNN},W}(z)}}{1 + e^{-y_t f^{\text{RNN},W}(z)}} \right|}_{\le 1} \cdot O\left( \left( \tfrac{T^8}{\sqrt{m}} \right)^{\frac{1}{3}} T^{10} \sqrt{m \log m} \right) \left\| W^{(i)} - \widetilde{W} \right\|_F$$

where we assumed that $m^\star > n$ and therefore $\left\| W^{(i)} - W^{(0)} \right\|_F < O\left( \frac{T^8}{\sqrt{m}} \right)$.

Using algorithm 3's update rule for $W^{(k+1)}$ again (see eq 55), we can use an inequality from section 14.1.1 in Shalev-Shwartz & Ben-David (2014) to get that:

$$\sum_{i=1}^{n} \left\langle E_x\left[ \nabla_{W^{(i)}} \ell(y, f^{\text{RNN},W^{(i)}}(z)) \right] + \sigma_i, W^{(i)} - \widetilde{W} \right\rangle \le \frac{\left\| W^{(1)} - \widetilde{W} \right\|_F^2}{2\eta} + \frac{\eta}{2} \sum_{i=1}^{n} \left\| E_x\left[ \nabla_{W^{(i)}} \ell(y, f^{\text{RNN},W^{(i)}}(z)) \right] + \sigma_i \right\|_F^2$$

Now, we can take the expectation of eq 59 and combine the Cauchy-Schwarz inequality with the above bound to get that:

$$\sum_{i=1}^{n} E_x\left[ l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) - l\left( y, f^{\text{RNN},\widetilde{W}}(z) \right) \right] \le \frac{\left\| W^{(1)} - \widetilde{W} \right\|_F^2}{2\eta} + \frac{\eta}{2} \sum_{i=1}^{n} \left\| E_x\left[ \nabla_{W^{(i)}} \ell(y, f^{\text{RNN},W^{(i)}}(z)) \right] + \sigma_i \right\|_F^2 + O\left( \left( \tfrac{T^8}{\sqrt{m}} \right)^{\frac{1}{3}} T^{10} \sqrt{m \log m} \right) \sum_{i=1}^{n} \left\| W^{(i)} - \widetilde{W} \right\|_F$$

Substituting the upper bounds from eqs 57 and 58, we get that:

$$\sum_{i=1}^{n} E_x\left[ l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) - l\left( y, f^{\text{RNN},W^\star}(z) \right) \right] \le O\left( \frac{T^{16}}{2\eta m} \right) + \frac{\eta n}{2} \left( O(T^8) + \frac{1}{m} \right)^2 \quad (65)$$
$$+ O\left( \left( \tfrac{T^8}{\sqrt{m}} \right)^{\frac{1}{3}} T^{10} n \sqrt{\log m} \cdot 2nT^8 \right) + O\left( m^{-\frac{3}{2}} n^2 T^8 \right) \quad (66)$$
$$\le O\left( T^{16} \sqrt{n} \right) + O\left( m^{-\frac{1}{6}} T^{\frac{62}{3}} n^2 \sqrt{\log m} \right) + O\left( m^{-\frac{3}{2}} n^2 T^8 \right) \quad (67)$$

To ensure the left-hand side is upper bounded by $O\left( R^2 \sqrt{n} \right)$, we choose $m^\star$ such that $m^\star \log^3 m^\star > n^2$; note that, as required, $m^\star$ is polynomial in $n, T$. Then, since eq 53 assures us that $E_x\left[ l\left( y, f^{\text{RNN},\widetilde{W}}(z) \right) \right] \le O\left( \frac{R^2}{\sqrt{n}} \right)$, we have that $\sum_{i=1}^{n} E_x\left[ l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) \right] \le O\left( R^2 \sqrt{n} \right)$ for $m > m^\star$.
Therefore, $\frac{1}{n} \sum_{i=1}^{n} E_x\left[ l\left( y, f^{\text{RNN},W^{(i)}}(z) \right) \right] \le O\left( \frac{R^2}{\sqrt{n}} \right)$. Finally, for $y_t \cdot f^{\text{RNN},W}_t(z) < 0$ we have that $\log\left( 1 + e^{-y_t \cdot f^{\text{RNN},W}_t(z)} \right) > \log 2$. In addition, clearly $\log\left( 1 + e^{-y_t \cdot f^{\text{RNN},W}_t(z)} \right) > 0$ always. Therefore, we conclude that

$$\frac{1}{T-d} \sum_{t=d}^{T} l_{0\text{-}1}\left( y_t, f^{\text{RNN},W^{(i)}}_t(z) \right) < \frac{1}{\log 2}\, l\left( y, f^{\text{RNN},W^{(i)}}(z) \right)$$

and thus eq 52 holds.

E LEARNABILITY OF RNNS

In this section we state and extend several lemmas from Wang et al. (2021) and Allen-Zhu et al. (2019). We use these lemmas in sections C and D to extend theorem 2 of the main text to finite-precision SGD and GD. Following Wang et al. (2021)'s notation, for any target function $h \in H_{\phi(T,\psi,N)}$ and $n$ samples $\left( z^{(i)} \right)_{i=1}^{n}$ we denote:

$$H^t_{i,j} := \frac{1}{m} \left\langle \nabla_{W^{(0)}} f^{\text{RNN},W^{(0)}}_t(z^{(i)}), \nabla_{W^{(0)}} f^{\text{RNN},W^{(0)}}_t(z^{(j)}) \right\rangle \quad \text{and} \quad \tilde{y}^{(t)} = \begin{pmatrix} h_t\left( z^{(1)} \right) \\ \vdots \\ h_t\left( z^{(n)} \right) \end{pmatrix}$$

We start with theorem 4 (footnote 13) from Wang et al. (2021), measuring the complexity of a learned hypothesis class with respect to the random initialization of $f^{\text{RNN}}$.

Lemma 6. Let $\delta > 0$ and $n \in \mathbb{N}$; then there exists $m^\star = \text{poly}(n, \delta^{-1}, T)$ such that if $m > m^\star$, then for any $h \in H_{\phi(T,\psi,N)}$ and $\left( z^{(i)} \right)_{i=1}^{n}$, with probability at least $1 - \delta$ over the random initialization of $f^{\text{RNN}}$, there exist matrices $H^{d,\infty}, H^{d+1,\infty}, \dots, H^{T,\infty} \in \mathbb{R}^{n \times n}$ such that for any $d \le t \le T$ the following hold:

1. There exists $v^{(t)} \in B_F\left( 0, \frac{1}{100} \left( \tilde{y}^{(t)} \right)^\top \left( H^{t,\infty} \right)^{-1} \tilde{y}^{(t)} \right)$ such that $H^t + \left( v^{(t)} \right)^\top v^{(t)} - H^{t,\infty}$ is a positive semi-definite matrix.
2. $\left( \tilde{y}^{(t)} \right)^\top \left( H^{t,\infty} \right)^{-1} \tilde{y}^{(t)} = O\left( \phi(T, \psi, N) \right)$.

Now we state lemma 15 from Wang et al. (2021). Essentially, this lemma states that, with high probability over the initialization of $f^{\text{RNN}}$, there exist weights that are not far from the initialization and have low training loss.

Lemma 7. Let $\left( z^{(i)}, y^{(i)} \right)_{i=1}^{n}$ be the training set and denote by $L_i(W) := l\left( y^{(i)}, f^{\text{RNN},W}(z^{(i)}) \right)$ the training loss. Then for any $\delta > 0$ there exists $m^\star = \text{poly}(n, R, T, \delta^{-1})$ such that if $m > m^\star$, then with probability at least $1 - \delta$ there exists $W^\star \in B\left( W^{(0)}, \frac{R}{\sqrt{m}} \right)$ such that $L_i(W^\star) \le \frac{1 + R^2}{n}$ and $R = \tilde{O}\left( T \sum_{t=d}^{T} \left( \tilde{y}^{(t)} \right)^\top \left( H^{t,\infty} \right)^{-1} \tilde{y}^{(t)} \right)$.

Now we state theorem 13 from Wang et al. (2021). Essentially, this lemma bounds the finite-width deviation of $f^{\text{RNN}}$ from its infinite-width neural tangent kernel (Jacot et al., 2018).

Lemma 8.
Let $m, n \in \mathbb{N}$ and $r = O\left( \frac{\text{poly}(n,T)}{\sqrt{m}} \right)$; then with probability of at least $1 - O(n)\exp\left( -\Omega\left( m^{\frac{1}{3}} \right) \right)$ (footnote 14) over the initialization of $f^{\text{RNN}}$, for all $\left( z^{(i)} \right)_{i=1}^{n}$ and $W, \widetilde{W} \in B\left( W^{(0)}, r \right)$, for any $d \le t \le T$ and $i \in [n]$ the following holds:

$$\left| f^{\text{RNN},\widetilde{W}}_t\left( z^{(i)} \right) - f^{\text{RNN},W}_t\left( z^{(i)} \right) - \left\langle \nabla_W f^{\text{RNN},W}_t(z^{(i)}), \widetilde{W} - W \right\rangle \right| = O\left( r^{\frac{1}{3}} T^{10} \sqrt{m \log m} \right) \left\| \widetilde{W} - W \right\|_F \quad (71)$$

For lemma 9, the output of $f^{\text{RNN}}$ is bounded by:

$$\left| f^{\text{RNN},W}_t(z) \right| = O\left( T^6 \left\| W - W^{(0)} \right\|_F + t \right) \quad (73)$$

Moreover, with probability of at least $1 - e^{O(T) - \Omega\left( \frac{m^{1/3}}{T^2} \right)}$, the following also holds:

$$\max_{x \in \{0,1\}^d} \max_{d < t \le T} \left| f^{\text{RNN},W}_t(z) \right| = O\left( T^6 \left\| W - W^{(0)} \right\|_F + T \right) \quad (74)$$

Proof. We begin by showing that it is enough to bound the hidden state $h^W_t(z)$. This is indeed the case since it is well known that $\|B\|_2 = O(1)$ with probability of at least $1 - e^{-\Omega(m)}$ over the initialization of $f^{\text{RNN}}$, and the cited lemmas of Allen-Zhu et al. (2019) yield $\left\| h^W_t(z) - h^{W^{(0)}}_t(z) \right\| = O\left( T^6 \left\| W - W^{(0)} \right\|_F \right)$. Finally, a simple union bound over all the possible $d$-bit configurations also proves equation 74.

Now, we use lemma 9's proof together with lemmas B.11 and C.9 from Allen-Zhu et al. (2019) to prove a lemma that bounds the magnitudes of $f^{\text{RNN}}$'s gradients. We use this lemma for bounding the finite-width deviation of $f^{\text{RNN}}$ from its infinite-width neural tangent kernel (Jacot et al., 2018). Beyond its application for proving our positive results, lemma 10 below also implies that $f^{\text{RNN}}$, for which we proved our positive results when intermediate supervision exists, satisfies, with high probability, the polynomially-bounded-gradients requirement in the parities negative results. This shows that the actual factor enabling efficient learning is intermediate supervision.

Lemma 10. Let $m \in \mathbb{N}$ and $W \in B\left( W^{(0)}, \frac{\text{poly}(T)}{\sqrt{m}} \right)$ (footnote 16); then for a given $x$ and $d < t \le T$, with probability of at least $1 - e^{-\Omega\left( \frac{m^{1/3}}{T^2} \right)}$ over the random initialization of $f^{\text{RNN}}$, the following holds:

$$\left\| \nabla_W f^{\text{RNN},W}_t(z) \right\|_F = O\left( T^8 \left\| W - W^{(0)} \right\|_F + T^4 \right) \quad (75)$$

Moreover, with probability of at least $1 - e^{O(T) - \Omega\left( \frac{m^{1/3}}{T^2} \right)}$, it holds simultaneously for all $x$'s and $t$'s.

Proof.
Let $D_t \in \mathbb{R}^{m \times m}$ be the diagonal matrix whose diagonal entries equal 1 where $h^W_t(z) > 0$ and 0 otherwise. Then we can write the gradients of $f^{\text{RNN}}$ as:

$$\nabla_W f^{\text{RNN},W}_t(z) = \sum_{i=1}^{t} \left( \prod_{j=i+1}^{t} W^\top D_j \right) B D_i h^W_i(z)$$

The proof of lemma 9 assures us that $\left\| h^W_i(z) \right\| = O\left( T^6 \left\| W - W^{(0)} \right\|_F + T^3 \right)$ for all $i$, with probability of at least $1 - e^{-\Omega\left( \frac{m^{1/3}}{T^2} \right)}$. Now, it is well known that $\|B\|_2 = O(1)$ with probability of at least $1 - e^{-\Omega(m)}$ (see, for example, Bandeira & Van Handel (2016)), and clearly $\|D_t\|_2 \le 1$ for any $t$. Therefore, overall we get that equation 75 holds with probability of at least $1 - e^{-\Omega\left( \frac{m^{1/3}}{T^2} \right)}$. Finally, a simple union bound over all the possible $d$-bit configurations and the $T - d$ values of $t$ proves that the bound holds simultaneously for all $x$'s and $t$'s.

Finally, we rely on lemma 8 from Wang et al. (2021) to bound the finite-width deviation of $f^{\text{RNN}}$ from its infinite-width neural tangent kernel (Jacot et al., 2018). Lemma 11 states that, over the initialization of $f^{\text{RNN}}$, for all $W, \widetilde{W} \in B\left( W^{(0)}, r \right)$ the following holds:

$$\max_{d \le t \le T} \max_{z} \left| f^{\text{RNN},\widetilde{W}}_t(z) - f^{\text{RNN},W}_t(z) - \left\langle \nabla_W f^{\text{RNN},W}_t(z), \widetilde{W} - W \right\rangle \right| = O\left( r^{\frac{1}{3}} T^{10} \sqrt{m \log m} \right) \left\| \widetilde{W} - W \right\|_F \quad (77)$$

Proof. A simple union bound over all the possible $d$-bit configurations and the $T - d$ values of $t$ proves that the bound in lemma 8 holds simultaneously for all $x$'s and $t$'s.

F BIT-SUBSET PARITY END-TO-END NEGATIVE RESULT PROOFS

In this section, we present a theorem stating that without intermediate supervision, any neural network with polynomially bounded gradients must use an exponential number of GD steps to achieve non-negligible accuracy on the above-presented bit-subset parity task. This theorem follows directly from Shalev-Shwartz et al. (2017) and Shamir (2018), together with the fact that random guessing has high zero-one loss; see the full proof below. Importantly, lemma 10 shows that $f^{\text{RNN}}$, for which we proved our positive results when intermediate supervision exists, satisfies, with high probability, the polynomially-bounded-gradients requirement of theorem 4 above. Moreover, section D shows that our positive results also hold for a finite-precision gradient descent optimization algorithm, for which theorem 4 proves the negative results. So overall, both the positive and negative proofs concern the exact same setup, and hence corollary 1 follows.

We will abuse notation and sometimes identify a predictor $h \in H$ with the corresponding vector in $\{0,1\}^d$ whose $i$'th coordinate equals 1 if $i$ is one of the indices in the subset and 0 otherwise. We start by describing a measure from Shalev-Shwartz et al. (2017) that quantifies the amount of "signal" about the underlying target function contained in the gradient. Consider the stochastic optimization problem associated with learning a target function $h$:

$$\min_w F_h(w) \quad \text{where:} \quad F_h(w) := E_x\left[ l\left( p_w(x), h(x) \right) \right] \quad (79)$$

where $l$ is a loss function, $x$ are the stochastic inputs, and $p_w$ is some predictor parametrized by a parameter vector $w$ (e.g., a neural network of a certain architecture). We assume that $F$ is differentiable.
We measure the variance of the gradient of $F$ with respect to $h$, when $h$ is drawn uniformly at random from a collection of candidate target functions $H$:

$$\text{Var}(H, F, w) := E_h\left[ \left\| \nabla F_h(w) - E_h\left[ \nabla F_h(w) \right] \right\|^2 \right]$$

To measure meaningful learnability, we define the expected loss of random guessing from an $\alpha \in (0,1)$ fraction of the hypothesis class $H$ as:

$$E_{H,\alpha,l} := \min_{h \in H}\ \min_{\hat{H} \subseteq H,\ |\hat{H}| \ge \alpha |H|}\ E_{\tilde{h} \sim U(\hat{H})} E_x\left[ l\left( \tilde{h}(x), h(x) \right) \right]$$

As $\alpha$ approaches 1, this measure reflects the expected loss attained by randomly guessing a hypothesis from $H$. We use this measure in the following lemma, which is a direct corollary of theorem 4 in Shamir (2018). Essentially, it addresses any iterative algorithm (possibly randomized) which relies on an $\epsilon$-approximate gradient oracle to optimize $F_h$ in eq 79. The lemma states that if the number of iterations is not larger than $\text{Var}(H, F, w)^{-1/3}$, then with high probability the algorithm will return the same predictor independently of $h$.

Lemma 12. Define $\epsilon = \sqrt[3]{\sup_w \text{Var}(H, F, w)}$ and let $A$ be any iterative gradient-based (footnote 19) optimization algorithm that runs for $n$ iterations and at each iteration receives an $\Omega(\epsilon)$-approximation of $\nabla F_h(w)$. Then the following holds:

$$P_{h \sim U(H)}\left[ E_x\left[ l\left( h(x), p_A(x) \right) \right] \ge E_{H,\, 1 - n\epsilon,\, l} \right] \ge 1 - n\epsilon$$

Lemma 12 above assures us that in order to prove a negative learning result for some task, it is enough to show that the task has both exponentially low gradient variance and high random-guessing error. Now we show both for the bit-subset parity task presented in section 5. Following Shalev-Shwartz et al. (2017), we start by showing that different parity functions are uncorrelated, and hence this task has exponentially low gradient variance:

Lemma 13. For any $(i_j)_{j=1}^{d/2} \neq (\tilde{i}_j)_{j=1}^{d/2}$ we have that:

$$E_x\left[ (-1)^{\sum_{j=1}^{d/2} x_{i_j}} (-1)^{\sum_{j=1}^{d/2} x_{\tilde{i}_j}} \right] = 0$$

Proof. Denote by $h \neq \tilde{h} \in \{0,1\}^d$ the vectors corresponding to $(i_j)_{j=1}^{d/2}$ and $(\tilde{i}_j)_{j=1}^{d/2}$.
Then:

$$E_x\left[ (-1)^{\sum_{j=1}^{d/2} x_{i_j}} (-1)^{\sum_{j=1}^{d/2} x_{\tilde{i}_j}} \right] = E_x\left[ (-1)^{\langle x, h \rangle} (-1)^{\langle x, \tilde{h} \rangle} \right] = E_x\left[ (-1)^{\langle x, h + \tilde{h} \rangle} \right] = E_x\left[ \prod_{j=1}^{d} (-1)^{x_j (h_j + \tilde{h}_j)} \right] \quad (85)$$

Since the coordinates are independent, we can swap the order of the expectation and the product:

$$= \prod_{j=1}^{d} E_x\left[ (-1)^{x_j (h_j + \tilde{h}_j)} \right] = \prod_{j=1}^{d} \frac{(-1)^{h_j + \tilde{h}_j} + (-1)^0}{2}$$

Since $h \neq \tilde{h}$, there exists some $j \in [d]$ such that $h_j \neq \tilde{h}_j$, and therefore $h_j + \tilde{h}_j = 1$, which makes the $j$'th factor, and hence the whole product, equal to 0.

Proof (of lemma 14). We use theorem 1 from Shalev-Shwartz et al. (2017). Essentially, this theorem shows that if the functions in the hypothesis class are uncorrelated, then the variance is upper bounded by $E_x\left[ \left\| \frac{\partial}{\partial w} p_w(x) \right\|^2 \right]$ times the inverse of the hypothesis class size. In our case, lemma 13 above shows that:

$$(i_j)_{j=1}^{d/2} \neq (\tilde{i}_j)_{j=1}^{d/2} \implies E_x\left[ (-1)^{\sum_{j=1}^{d/2} x_{i_j}} (-1)^{\sum_{j=1}^{d/2} x_{\tilde{i}_j}} \right] = 0$$

and therefore we have that $\text{Var}(H, F, w) \le |H|^{-1}\, E_x\left[ \left\| \frac{\partial}{\partial w} p_w(x) \right\|^2 \right]$.

Having low gradient variance alone does not imply that learning is impossible; it only implies that all hypotheses perform similarly. To establish the negative result for our parity problem, we must also show that the random-guessing error is high. Lemma 15 establishes this, showing that the bit-subset parity task has a non-vanishing random-guessing error of $\Theta(1)$ with respect to the dimension $d$, and linear in the fraction of the hypothesis class from which the guessing occurs.

Lemma 15. Let $H$ denote the hypothesis class of $d$-dimensional parities described in section 5, and let $l_{0\text{-}1}$ denote the zero-one loss. Then the following holds:

$$\forall \alpha \in (0,1) \quad E_{H,\alpha,l_{0\text{-}1}} = \frac{1}{2} \left( 1 - \frac{1}{\alpha |H|} \right)$$

Proof. Let $h \in H$. We begin by writing the zero-one loss with elementary algebraic operations:

$$E_{\tilde{h} \sim U(\hat{H})} E_x\left[ l_{0\text{-}1}\left( \tilde{h}(x), h(x) \right) \right] = E_{\tilde{h} \sim U(\hat{H})} E_x\left[ 1_{(-1)^{\langle x, h \rangle} \neq (-1)^{\langle x, \tilde{h} \rangle}} \right] = E_{\tilde{h} \sim U(\hat{H})} E_x\left[ \frac{1 - (-1)^{\langle x, h \rangle} \cdot (-1)^{\langle x, \tilde{h} \rangle}}{2} \right] \quad (93)$$
$$= \frac{1}{2} \left( 1 - E_{\tilde{h} \sim U(\hat{H})} E_x\left[ (-1)^{\langle x, h + \tilde{h} \rangle} \right] \right) \quad (94)$$

Since the coordinates are independent, we can swap the order of the expectation and the product:

$$= \frac{1}{2} \left( 1 - E_{\tilde{h} \sim U(\hat{H})}\left[ \prod_{i=1}^{d} E_{x_i}\left[ (-1)^{x_i (h_i + \tilde{h}_i)} \right] \right] \right) \quad (95)$$
$$= \frac{1}{2} \left( 1 - \frac{1}{2^d} E_{\tilde{h} \sim U(\hat{H})}\left[ \prod_{i=1}^{d} \left( (-1)^{h_i + \tilde{h}_i} + (-1)^0 \right) \right] \right) \quad (96)$$

Now, in cases where $h \neq \tilde{h}$, there exists some $j \in [d]$ such that $h_j \neq \tilde{h}_j$.
Therefore, in such cases $h_j + \tilde{h}_j = 1$, which implies that $\prod_{i=1}^{d} \frac{(-1)^{h_i + \tilde{h}_i} + 1}{2} = 0$, and we get that:

$$= \frac{1}{2} \left( 1 - \frac{1}{2^d} E_{\tilde{h} \sim U(\hat{H})}\left[ 2^d \cdot 1_{h = \tilde{h}} \right] \right) \quad (97)$$
$$= \frac{1}{2} \left( 1 - E_{\tilde{h} \sim U(\hat{H})}\left[ 1_{h = \tilde{h}} \right] \right) \quad (98)$$

Now, clearly:

$$E_{\tilde{h} \sim U(\hat{H})}\left[ 1_{h = \tilde{h}} \right] = \begin{cases} \frac{1}{|\hat{H}|} & h \in \hat{H} \\ 0 & h \notin \hat{H} \end{cases} \quad (99)$$

and therefore:

$$E_{\tilde{h} \sim U(\hat{H})} E_x\left[ l_{0\text{-}1}\left( \tilde{h}(x), h(x) \right) \right] = \frac{1}{2} - \frac{1}{2} \cdot \begin{cases} \frac{1}{|\hat{H}|} & h \in \hat{H} \\ 0 & h \notin \hat{H} \end{cases} \quad (100)$$

Finally, by combining lemmas 14 and 15 with lemma 12, we prove theorem 4. By lemma 14, $\sup_w \text{Var}(H, F, w) = O\left( |H|^{-1} \right) = O\left( e^{-d} \right)$, i.e., we have exponentially low variance. Therefore, lemma 12 assures us that with probability of at least $1 - n \cdot O\left( e^{-d/3} \right)$, the loss of $A$ will be higher than $E_{H,\, 1 - n \cdot O(|H|^{-1/3}),\, l_{0\text{-}1}}$. Finally, lemma 15 implies that:

$$E_{H,\, 1 - n \cdot O(|H|^{-1/3}),\, l_{0\text{-}1}} = \frac{1}{2} \left( 1 - \frac{1}{|H| - O\left( |H|^{\frac{2}{3}} \right) \cdot n} \right) = \frac{1}{2} - O\left( e^{-d} \right) \quad (101)$$

G BIT-SUBSET PARITY EXPERIMENTS FULL DETAILS

Section 5.2 in the main text empirically demonstrates a large gap when using the commonly used Transformer architecture to learn bit-subset parity with and without intermediate supervision. In this section, we present the full details of the experiments.

Following the bit-subset parity definition in section 5, we randomly sampled a subset of $d/2$ predefined unique indices, and then randomly sampled non-overlapping training, validation, and test datasets. For reproducibility purposes, Table 2 reports the random seeds used for this sampling. Then, we used the standard implementation of GPT-2 from the transformers framework (Wolf et al., 2020) and trained a BERT-base-sized Transformer decoder model with a vocabulary size of 2. Even though the small vocabulary size may indicate that the model is too shallow for its size (Wies et al., 2021), we still chose to use the standard Transformer-base configuration. See full architecture and optimization details in Table 2. The evaluation of a model without intermediate supervision is straightforward: we can simply use argmax over the last token's logits.
This is no longer the case when intermediate supervision is used, since at evaluation time the model must generate its own intermediate steps. As a result, in this case we need a decoding algorithm that iteratively predicts the intermediate steps. Since we found in preliminary experiments that greedy decoding tends to be more efficient for longer sequences than random sampling, we chose to use it. To complement Figure 3 in the main text, Table 1 shows the first iteration (footnote 21) at which the model is better than random, as well as the test accuracy after 100K training iterations. Although "better than random" could be defined in a variety of ways, Figure 4 illustrates a typical learning trajectory and demonstrates that the differences between definitions are insignificant due to the observed grokking phenomenon (Power et al., 2021). For each task we repeated the experiment 3 times with different dataset seeds, and performed a grid search over the hyper-parameters in Table 2.

Figure 4: Illustration of a typical bit-subset parity learning trajectory. In these tasks we observed a grokking phenomenon (Power et al., 2021): very soon after the validation accuracy became higher than random level, it also became almost flawless (accuracy > 95%). Therefore, successful learning is not sensitive to the exact cutoff.
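The iterative greedy decoding used at evaluation time can be sketched as follows. Here `model_logits` stands in for the trained model's next-token logits over the vocabulary {0, 1}; the toy rule used below (next token is the XOR of the last two) is purely illustrative and not the paper's trained model:

```python
def greedy_decode(model_logits, x, n_steps):
    """Iteratively predict the intermediate sub-task tokens at evaluation
    time: each new token is the argmax over the vocabulary {0, 1} given
    the sequence decoded so far, then is appended and fed back in."""
    seq = list(x)
    for _ in range(n_steps):
        logits = model_logits(seq)
        seq.append(0 if logits[0] >= logits[1] else 1)
    return seq[len(x):]           # the decoded intermediate steps + answer

def toy_logits(seq):
    # stand-in "model" that has perfectly learned a pairwise-parity rule
    nxt = seq[-2] ^ seq[-1]       # XOR of the last two tokens
    return [1.0, 0.0] if nxt == 0 else [0.0, 1.0]

print(greedy_decode(toy_logits, [1, 0, 1], 3))   # → [1, 0, 1]
```

At training time no decoding is needed, since teacher forcing feeds the ground-truth intermediate labels instead of the model's own predictions.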



Footnotes:
- For clarity, we wrote $y_t = z_{t+1}$, although the input domain of our model was $\{0,1\}$ while the output domain was $\{-1,1\}$, and hence strictly $y_t = 2\left( z_{t+1} - \tfrac{1}{2} \right)$.
- They also discussed the case where $y$ is a sequence; however, in their case it was considered part of the task, unlike our case, where we add intermediate steps as a method of supervision.
- Note that the negative result holds only for full gradient descent and does not hold for stochastic gradient descent, for which Abbe & Sandon (2020) and Abbe et al. (2021) show that parities are efficiently learnable when using complex non-random initialization.
- With the exception of the 16- and 18-bit tasks without intermediate supervision, for which we increased the number of iterations to 300K and 2M, respectively, in order to try to successfully learn the task.
- This maximum exists since there is a finite number of possible functions from $\{0,1\}^N$ into $\{0,1\}$.
- This maximum exists since there is a finite number of possible functions from $\{0,1\}^N$ into $\{0,1\}$.
- See section D for the extension of the proof to algorithm 3, which is based on GD.
- The ball is with respect to the distance defined by the max matrix norm, i.e., the elementwise distances are at most $\sigma$.
- Combined with Remark H.1 in Wang et al. (2021)'s supplementary materials. In addition, for simplicity, we state a $1/\sqrt{n}$ convergence rate, as opposed to the linear convergence rate in Wang et al. (2021).
- The ball is with respect to the distance defined by the Frobenius matrix norm.
- The ball is with respect to the distance defined by the max matrix norm, i.e., the elementwise distances are at most $\sigma$.
- The ball is with respect to the distance defined by the Frobenius matrix norm.
- Combined with Remark H.1 in Wang et al. (2021)'s supplementary materials.
- Note that lemma 13 from Wang et al. (2021) ensures only the weaker guarantee of probability $1 - O(n)\exp(-\Omega(\log m))$, but we confirmed with the authors that their proof also yields our stronger guarantee.
- The ball is with respect to the distance defined by the Frobenius matrix norm.
- The ball is with respect to the distance defined by the Frobenius matrix norm.
- For any $\theta$ that is reachable by $A$ within $n$ iterations.
- See footnote 3.
- See footnote 3.
- For models that converged after fewer than 5K steps, we disabled the learning-rate warmup.
- For models that did not converge after 100K steps, we continued the training until convergence. Note that even for those models, Table 1 reports the test accuracy at step 100K. Validation accuracy was evaluated at 1K-step intervals, except for models that converged in fewer than 4K steps, for which we ran the evaluation every 10 steps.



Denote by $z^{\text{train}}_t := y_{t-1}$ the ground-truth input at training time, and by $z^{\text{test}}_t$ the model's prediction at time $t-1$, used as input at test time.

Figure 2: Illustration of the proposed input and output for learning the d = 8 bit-subset parity problem with sequence-to-sequence models.

Figure 3: The number of steps until a BERT-base-sized Transformer learns bit-subset parities with and without intermediate supervision. By learning we mean validation accuracy higher than 60%. While this definition is somewhat arbitrary, in practice we observed a grokking phenomenon (Power et al., 2021) where very soon after the accuracy became higher than random level, it also became almost perfect (accuracy > 95%).

Now we state a simple corollary of lemmas B.3 and C.2.a in Allen-Zhu et al. (2019). Essentially, this corollary bounds the output of $f^{\text{RNN}}$ with high probability over its initialization. Lemma 9. Let $m \in \mathbb{N}$ and $W \in B\left( W^{(0)}, \frac{\text{poly}(T)}{\sqrt{m}} \right)$ (footnote 15); then for a given $x$ and $d < t \le T$, with probability of at least $1 - e^{-\Omega\left( \frac{m^{1/3}}{T^2} \right)}$, eqs 73 and 74 hold.

initialization of $f^{\text{RNN}}$. In addition, lemma B.11 together with lemma C.9.d from Allen-Zhu et al. (2019) assures us that $\left\| \prod_{j=i+1}^{t} W^\top D_j \right\|_2 = O\left( T^7 \right)$, $\left\| W - W^{(0)} \right\|$

Let $m \in \mathbb{N}$ and $r = O\!\left(\mathrm{poly}(T)/\sqrt{m}\right)$; then, with probability of at least $1 - \exp\!\left(O(T \log T) - \Omega\!\left(m^{1/3}\right)\right)$

Let $f_\theta$ be any neural network with $\mathbb{E}_x \left\| \frac{\partial}{\partial \theta} f_\theta(x) \right\|^2 = O(\mathrm{poly}(d))$ (polynomially bounded gradients), and let $A$ be any iterative gradient-based optimization algorithm that runs for $n = O(\mathrm{poly}(d))$ iterations and at each iteration receives an $\Omega\!\left(e^{-d/3}\right)$-approximation of $\mathbb{E}_x\!\left[\nabla \ell(y, f_\theta(x))\right]$. Then, with probability of at least $1 - O\!\left(e^{-d/3}\right)$ over the target parity function, the loss of $A$ will be higher than $\frac{1}{2} - O\!\left(e^{-d}\right)$.

for $h \neq \tilde{h}$ there exists some $j \in [d]$ such that $h_j \neq \tilde{h}_j$, and therefore $h_j + \tilde{h}_j = 1$. Now, following Shalev-Shwartz et al. (2017), we show that the zero correlation between different parities implies exponentially low gradient variance. Lemma 14. Assuming that $\mathbb{E}_x \left\| \frac{\partial}{\partial w} p_w(x) \right\|^2 = O(\mathrm{poly}(d))$, then $\sup_w \mathrm{Var}(\mathcal{H}, F, w) = O\!\left(|\mathcal{H}|^{-1}\right)$, where $\mathcal{H}$ denotes the hypothesis class of $d$-dimensional parities described in section 5.

By lemma 14 we know that $\sup_w \mathrm{Var}(\mathcal{H}, F, w) = O\!\left(|\mathcal{H}|^{-1}\right)$. Now, since $\binom{n}{k} > \left(\frac{n}{k}\right)^k$ for any $k < n$, we have that $|\mathcal{H}| = \binom{d}{d/2} > 2^{d/2}$.
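The zero correlation between distinct parities that drives this variance bound is easy to verify exactly for small $d$: enumerating all $2^d$ inputs, the product of two distinct parity functions averages to exactly zero under the uniform distribution. The sketch below is ours (variable and function names are illustrative), not the paper's code.

```python
from itertools import product, combinations

d = 4
# Enumerate all 2^d inputs so the expectations below are exact.
inputs = list(product([0, 1], repeat=d))

def chi(subset, x):
    """Parity of the bits of x indexed by `subset`, in +/-1 form."""
    s = sum(x[j] for j in subset) % 2
    return 1 - 2 * s  # parity 0 -> +1, parity 1 -> -1

subsets = [frozenset(c) for r in range(1, d + 1)
           for c in combinations(range(d), r)]

# E_x[chi_S(x) * chi_T(x)] over the uniform distribution is 1 when
# S = T and exactly 0 otherwise (the product is itself a parity over
# the nonempty symmetric difference of S and T).
for S, T in combinations(subsets, 2):
    corr = sum(chi(S, x) * chi(T, x) for x in inputs) / len(inputs)
    assert corr == 0
```

Since any two distinct hypotheses are uncorrelated, the gradient carries almost no information about which parity is the target, which is exactly the mechanism behind the $O(|\mathcal{H}|^{-1})$ variance bound.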

is to concatenate intermediate supervision labels to the input. This way, the language model receives a sequence composed of the input followed by the labels of the intermediate tasks, before emitting the final compounded answer. For a compounded binary classification task which consists of a d-bit input string x, with S denoting the string of intermediate step results, we denote the combined input sequence as z = Concat{x; S}, and the combined output sequence as y, defined in a standard autoregressive fashion by $y_t = z_{t+1}$ (see figure 2 for a d = 8 example). Training and testing follow the conventional sequence-to-sequence model protocol: at training time, $z_t$ for $t > d$ will be the ground-truth sub-task result $y_{t-1}$ (a practice sometimes referred to as "teacher forcing" (Williams & Zipser, 1989)), and at test time, $z_t$ for $t > d$ will be the model's prediction at time $t - 1$.
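The construction above can be sketched in a few lines of Python for the bit-subset parity task. This is a minimal sketch, not the paper's code: the function name `make_training_example` and the list encoding are ours, and we keep the $\{0, 1\}$ token domain for simplicity (the model's actual output domain is $\{-1, 1\}$).

```python
import random

def make_training_example(x, subset):
    """Build the augmented sequence for bit-subset parity with
    intermediate supervision.

    x      : list of d input bits
    subset : indices of the bits whose parity is the target
    Returns (z, y): input z = Concat{x; S} and shifted target y,
    where S holds the running parities of the relevant bits; each
    intermediate label depends only on the previous one, i.e. on
    O(1) previous sub-task results.
    """
    s, steps = 0, []
    for j in subset:
        s ^= x[j]           # next sub-task: XOR in one more relevant bit
        steps.append(s)
    z = x + steps            # teacher forcing: ground-truth labels as input
    y = z[1:] + [steps[-1]]  # autoregressive target y_t = z_{t+1}
    return z, y

x = [random.randint(0, 1) for _ in range(8)]
z, y = make_training_example(x, subset=[0, 2, 5, 7])
assert y[-1] == x[0] ^ x[2] ^ x[5] ^ x[7]  # final answer = full parity
```

At test time the same layout is used, except that the positions after the first d tokens are filled with the model's own predictions rather than the ground-truth running parities.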

1 points. Finally, the $O(1)$ boundedness of the coefficients follows from taking the maximum over all possible $l_v$ logical gates; see full details in appendix B. The above lemma implies that all of the target functions in our defined intermediate supervision belong to $\mathcal{H}_{\phi(T,\psi,N)}$ for $\phi(T, \psi, N) = O(d)$. Therefore, together with theorem 2, it assures us that when intermediate supervision is available, Algorithm 1 can learn any function in the P time-complexity class with polynomial network size, sample complexity, and number of gradient updates. Now, having shown that with intermediate supervision any function in the P time-complexity class can be learned by a neural network, our main result is a simple corollary of the above results: Corollary 2. Under either standard cryptographic assumptions or computational-complexity hardness assumptions, there exists a binary classification problem parameterized by size d, such that the following holds:

$-\Omega(m)$ (see, for example, Bandeira & Van Handel (2016)). Now, lemma B.3 from Allen-Zhu et al. (2019) assures us that $\left\| h_t^{W^{(0)}}(z) \right\| = O(t)$, and finally, lemma C.2.a from Allen-Zhu et al. (2019) bounds the deviation from the initialization $h$

Learning bit-subset parity with Transformers: results. We say that a model is better than random when its validation accuracy is higher than 60%. In addition, for each task we also evaluated the test accuracy after 100K training iterations. The reported test accuracies are those of the models with the best hyper-parameters according to the validation loss. Each value after the ± sign indicates two standard deviations over three different random seeds.

Hyper-parameters and random seeds examined in the experiment of learning bit-subset parity with Transformers.

Learning Rate: $10^{-6}$, $10^{-5}$, $10^{-4}$
Weight Decay: $0$, $10^{-6}$, $10^{-4}$, $10^{-2}$

Table 1 reports the mean and standard deviation of the test accuracy for the best hyper-parameters according to the validation binary cross-entropy loss, where we break ties by choosing the model that converged faster. As expected, without intermediate supervision, the Transformer models cannot learn tasks with more than 18 bits beyond random guessing. In contrast, with intermediate supervision, the test accuracy is almost perfect even at 128 bits.
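The selection rule described here (lowest validation loss, ties broken by faster convergence) can be sketched as a tuple comparison; the run names and numbers below are illustrative, not values from the paper.

```python
def select_best_run(runs):
    """Pick the best hyper-parameter setting: lowest validation
    binary cross-entropy loss, with ties broken by faster
    convergence (fewer steps).  `runs` maps a setting name to a
    (val_loss, steps_to_converge) pair; tuples compare
    lexicographically, which encodes exactly this rule."""
    return min(runs, key=lambda name: runs[name])

runs = {
    "lr=1e-4,wd=0":    (0.02, 3200),
    "lr=1e-5,wd=1e-4": (0.02, 4100),   # same loss, slower: loses the tie
    "lr=1e-6,wd=1e-2": (0.35, 100000),
}
assert select_best_run(runs) == "lr=1e-4,wd=0"
```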

ACKNOWLEDGMENTS AND DISCLOSURE OF FUNDING

We thank Eran Malach and Shai Shalev-Shwartz for a helpful discussion on our stronger negative results, as well as Lifu Wang for clarifying Wang et al. (2021). This research was supported by the ERC (European Research Council) and the ISF (Israel Science Foundation). Yoav Levine was supported by the Israel Academy of Sciences Adams fellowship.

APPENDIX

Proof. We begin by showing that there exists $W \in B\!\left(W^{(0)}, T^8/\sqrt{m}\right)$ such that:
Indeed, since the minimum cannot be larger than the mean, eq. 50 in the proof of lemma 4 assures us that, under the assumptions of lemma 5, for $m > \max\!\left\{n^2, \ln^4(1/\delta)\right\}$, with probability of at least $1 - \delta$, algorithm 2 will reach such a $W$ during the first $\max\!\left\{n, \ln^2(1/\delta)\right\}$ SGD iterations. Now we show that, with high probability over the initialization of $f_{\mathrm{RNN}}$, at any iteration $i \leq n$ of algorithm 3, the distance of the learned hidden weights matrix $W^{(i)}$ from its initialization point $W^{(0)}$ is not too large. As a result, the assumptions of lemma 8 hold, and therefore its upper bound on the deviation from linearization is valid. By the triangle inequality, for any $0 \leq i < n$ we have that:
Substituting the update rule of algorithm 3 for $W^{(k+1)}$, we get that there exists $\|\sigma_i\|_\infty < \sigma$ such that:
By the chain rule, we have that:
and since $0 \leq \frac{x}{1+x} \leq 1$ for any $x \geq 0$, we conclude by Jensen's inequality that:

