TRANSFORMERS LEARN SHORTCUTS TO AUTOMATA

Abstract

Algorithmic reasoning requires capabilities which are most naturally understood through recurrent models of computation, like the Turing machine. However, Transformer models, while lacking recurrence, are able to perform such reasoning using far fewer layers than the number of reasoning steps. This raises the question: what solutions are these shallow and non-recurrent models finding? We investigate this question in the setting of learning automata, discrete dynamical systems naturally suited to recurrent modeling and expressing algorithmic tasks. Our theoretical results completely characterize shortcut solutions, whereby a shallow Transformer with only o(T) layers can exactly replicate the computation of an automaton on an input sequence of length T. By representing automata using the algebraic structure of their underlying transformation semigroups, we obtain O(log T)-depth simulators for all automata and O(1)-depth simulators for all automata whose associated groups are solvable. Empirically, we perform synthetic experiments by training Transformers to simulate a wide variety of automata, and show that shortcut solutions can be learned via standard training. We further investigate the brittleness of these solutions and propose potential mitigations.

* The majority of this work was completed while B. Liu was an intern at Microsoft Research NYC.
† This work was completed while S. Goel was at Microsoft Research NYC.
1 Compared to the number of symbols it can process. For example, DistilBERT (Sanh et al., 2019) can handle thousands of tokens with 6 sequential layers.

1. INTRODUCTION

Modern deep learning pipelines demonstrate an increasing capability to perform combinatorial reasoning: pretrained on large, diverse distributions of natural language, math, and code, they are nascently solving tasks which seem to require a rigid "understanding" of syntax, entailment, and state inference. How do these neural networks internally represent the primitives of logic and the algorithms they execute? When considering this question, there is an immediate mismatch between classical sequential models of computation (e.g., Turing machines) and the Transformer architecture, which has delivered many of the recent breakthroughs in reasoning domains. If we are to think of an algorithm as a set of sequentially-executed computational rules, why would we use a shallow,^1 non-recurrent network?

We study this question through the lens of finite semiautomata, which compute state sequences q_1, ..., q_T from inputs σ_1, ..., σ_T by iterated application of a transition function δ, starting from an initial state q_0:

    q_t = δ(q_{t-1}, σ_t).

Semiautomata are the underlying structures governing the computations realizable by automata (such as regular expression parsers or finite-state transducers), which are simply semiautomata equipped with mappings from states to outputs. One natural motivation for studying them thus comes from the question of whether Transformers can subsume the structures found in classical NLP pipelines. Another comes from reinforcement learning and control, where Transformers are beginning to be used as world models: semiautomata specify deterministic discrete-state dynamical systems. We perform a theoretical and empirical investigation of whether (and how) non-recurrent Transformers learn semiautomata.
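The recurrence q_t = δ(q_{t-1}, σ_t) can be made concrete with a minimal sketch (our own illustrative Python, not the paper's code), using the mod-2 counter of Figure 1 as the running example:

```python
def simulate(delta, q0, inputs):
    """Return the state sequence q_1, ..., q_T for inputs sigma_1, ..., sigma_T
    by T sequential applications of the transition function delta."""
    states = []
    q = q0
    for sigma in inputs:
        q = delta(q, sigma)   # q_t = delta(q_{t-1}, sigma_t)
        states.append(q)
    return states

# Parity (mod-2 counter): Q = {0, 1}, Sigma = {0, 1}.
parity = lambda q, sigma: (q + sigma) % 2

print(simulate(parity, 0, [1, 0, 1, 1]))  # -> [1, 1, 0, 1]
```

The loop makes the mismatch explicit: a naive simulation uses T sequential steps, whereas a depth-L Transformer performs only L sequential rounds of computation.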
We characterize and analyze how shallow Transformers find shortcut solutions, which correctly and efficiently simulate the transition dynamics of semiautomata with far fewer sequential computations than required for iteratively inferring each state q_t.

Our contributions. Our theoretical results provide structural guarantees for the representability of semiautomata by shallow, non-recurrent Transformers. In particular, we show that:

• Shortcut solutions, with depth logarithmic in the sequence length, always exist (Theorem 1).
• Constant-depth shortcuts exist for solvable semiautomata (Theorem 2). There do not exist constant-depth shortcuts for non-solvable semiautomata, unless TC^0 = NC^1 (Theorem 4).
• For a natural class of semiautomata corresponding to path integration in a "gridworld" with boundaries, there are even shorter shortcuts (Theorem 3), beyond those guaranteed by the general structure theorems above.

We accompany these theoretical findings with an extensive set of experiments:

• End-to-end learnability of shortcuts via SGD (Section 4). The theory shows that shortcut solutions exist; is the non-convexity of the optimization problem an obstruction to learning them in practice? For a variety of semiautomaton simulation problems, we find empirically that there is no such obstruction: shallow non-recurrent Transformers are able to learn shortcuts which generalize near-perfectly in-distribution.
• More challenging settings (Section 5). We compare non-recurrent and recurrent models in the presence of additional considerations: out-of-distribution generalization (including to unseen sequence lengths) and limited supervision. This reveals the brittleness of non-recurrent models, in line with prior "spurious representation" notions of shortcuts in deep learning. Toward mitigating these drawbacks and obtaining the best of both worlds, we show that with recency-biased scratchpad training, Transformers can be guided to learn the robust recurrent solutions.
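The existence of logarithmic-depth shortcuts rests on a simple observation: each symbol σ_t induces a state map δ(·, σ_t), and map composition is associative, so all prefix compositions can be computed by a balanced divide-and-conquer scheme rather than a sequential scan. A minimal sketch of this idea (our own illustrative Python; the paper's Transformer constructions are more refined):

```python
def transition_maps(inputs, delta, states):
    """Represent each delta(., sigma_t) as an explicit map over the state set."""
    return [{q: delta(q, sigma) for q in states} for sigma in inputs]

def compose(f, g):
    """Composition g . f as maps: apply f first, then g."""
    return {q: g[f[q]] for q in f}

def prefix_compose(maps):
    """All prefix compositions via divide-and-conquer: O(log T) levels of
    composition, instead of T sequential steps."""
    if len(maps) == 1:
        return list(maps)
    mid = len(maps) // 2
    left = prefix_compose(maps[:mid])
    right = prefix_compose(maps[mid:])
    # each right-half prefix is extended by the full left-half composition
    return left + [compose(left[-1], r) for r in right]

# Sanity check on the parity semiautomaton, starting from q_0 = 0.
parity = lambda q, sigma: (q + sigma) % 2
maps = transition_maps([1, 0, 1, 1], parity, states=[0, 1])
print([m[0] for m in prefix_compose(maps)])  # -> [1, 1, 0, 1]
```

This naive recursion already collapses the sequential depth to O(log T) (at the cost of extra total work); the theorems sharpen this to the stated depth bounds using the algebraic structure of the transformation semigroup.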

1.1. RELATED WORK

Emergent reasoning in neural sequence models. Neural sequence models, both recurrent (Wu et al., 2016; Peters et al., 2018; Howard & Ruder, 2018) and non-recurrent (Vaswani et al., 2017; Devlin et al., 2018), have ushered in an era of broadly-applicable and (with pretraining) sample-efficient natural language understanding. Building on this, large-scale non-recurrent Transformer models have demonstrated capabilities in program synthesis, mathematical reasoning, and in-context multi-task adaptation. A nascent frontier is to leverage neural dynamics models, again both recurrent (Hafner et al., 2019) and non-recurrent (Chen et al., 2021a; Janner et al., 2021), for decision making. At the highest level, the present work seeks to idealize and understand the mechanisms by which deep learning solves tasks requiring combinatorial and algorithmic reasoning.

Computational models of neural networks. In light of the above, it is empirically evident that neural networks are successfully learning circuits which generalize on some combinatorial tasks. Many efforts in the theory and empirical science of deep learning are dedicated to the rigorous analysis of this phenomenon. Various perspectives map self-attention to bounded-complexity circuits (Hahn, 2020; Elhage et al., 2021; Merrill et al., 2021; Edelman et al., 2022), declarative programs (Weiss et al., 2021), and Turing machines (Dehghani et al., 2019). The research program of BERTology (Clark et al., 2019; Vig, 2019; Tenney et al., 2019) interprets trained models in terms of known linguistic and symbolic primitives.

The most relevant theoretical work to ours is (Barrington & Thérien, 1988), which acts as a "Rosetta Stone" between classical circuit complexity and semigroup theory. The core technical ideas for



Figure 1: Various examples of semiautomata. From left to right: a mod-2 counter, a 2-state memory unit, Grid_4, a 2-dimensional gridworld constructible via a direct product Grid_3 × Grid_4, and a Rubik's Cube, whose transformation semigroup is a very large non-abelian group.
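The direct-product construction in the caption can be sketched in a few lines (our own illustrative Python, with simplified paired inputs; the paper's 2D gridworld instead uses single moves from {←, →, ↑, ↓}):

```python
def grid_delta(n):
    """1D gridworld Grid_n: states 1..n, moves +1/-1, clamped at the boundaries."""
    return lambda q, move: min(n, max(1, q + move))

def product(delta_a, delta_b):
    """Direct product of two semiautomata: the components evolve
    independently on paired inputs (sigma_a, sigma_b)."""
    return lambda q, s: (delta_a(q[0], s[0]), delta_b(q[1], s[1]))

# A 2D gridworld as Grid_3 x Grid_4; (0 in a coordinate means "stay").
grid2d = product(grid_delta(3), grid_delta(4))
print(grid2d((1, 1), (1, 0)))  # -> (2, 1): one step in the first coordinate
print(grid2d((3, 4), (1, 1)))  # -> (3, 4): clamped at the boundary
```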

