TRANSFORMERS LEARN SHORTCUTS TO AUTOMATA

Abstract

Algorithmic reasoning requires capabilities which are most naturally understood through recurrent models of computation, like the Turing machine. However, Transformer models, while lacking recurrence, are able to perform such reasoning using far fewer layers than the number of reasoning steps. This raises the question: what solutions are these shallow and non-recurrent models finding? We investigate this question in the setting of learning automata, discrete dynamical systems naturally suited to recurrent modeling and expressing algorithmic tasks. Our theoretical results completely characterize shortcut solutions, whereby a shallow Transformer with only o(T) layers can exactly replicate the computation of an automaton on an input sequence of length T. By representing automata using the algebraic structure of their underlying transformation semigroups, we obtain O(log T)-depth simulators for all automata and O(1)-depth simulators for all automata whose associated groups are solvable. Empirically, we perform synthetic experiments by training Transformers to simulate a wide variety of automata, and show that shortcut solutions can be learned via standard training. We further investigate the brittleness of these solutions and propose potential mitigations.

1. INTRODUCTION

Modern deep learning pipelines demonstrate an increasing capability to perform combinatorial reasoning: pretrained on large, diverse distributions of natural language, math, and code, they are nascently solving tasks which seem to require a rigid "understanding" of syntax, entailment, and state inference. How do these neural networks represent the primitives of logic and the algorithms they execute internally? When considering this question, there is an immediate mismatch between classical sequential models of computation (e.g., Turing machines) and the Transformer architecture, which has delivered many of the recent breakthroughs in reasoning domains. If we are to think of an algorithm as a set of sequentially-executed computational rules, why would we use a shallow,* non-recurrent network?

We study this question through the lens of finite semiautomata, which compute state sequences q_1, ..., q_T from inputs σ_1, ..., σ_T by application of a transition function δ (and initial state q_0): q_t = δ(q_{t-1}, σ_t). Semiautomata are the underlying structures governing the computations realizable by automata (such as regular expression parsers or finite-state transducers), which are simply semiautomata equipped with mappings from states to outputs. Thus, one natural motivation for studying them comes from the question of whether Transformers can subsume the structures found in classical NLP pipelines. Another motivation comes from the perspective of reinforcement learning and control, where Transformers are beginning to be used as world models: semiautomata specify deterministic discrete-state dynamical systems.

We perform a theoretical and empirical investigation of whether (and how) non-recurrent Transformers learn semiautomata. We characterize and analyze how shallow Transformers find shortcut
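To make the contrast concrete, the following sketch simulates a semiautomaton two ways: the T-step recurrence q_t = δ(q_{t-1}, σ_t), and a logarithmic-depth shortcut that lifts each input symbol to an element of the transformation semigroup and composes prefixes by divide and conquer, in the spirit of the O(log T) construction above. The parity example and all function names are illustrative assumptions, not definitions from the paper.

```python
def run_recurrent(delta, q0, symbols):
    """Baseline: T sequential applications of the transition function."""
    states, q = [], q0
    for sigma in symbols:
        q = delta(q, sigma)
        states.append(q)
    return states


def run_shortcut(delta, q0, symbols, num_states):
    """Depth-O(log T) simulation: each symbol becomes a transformation on
    states (a tuple mapping state -> next state), and all prefix
    compositions are computed by a divide-and-conquer scan."""
    # Lift each input symbol to an element of the transformation semigroup.
    fs = [tuple(delta(q, sigma) for q in range(num_states)) for sigma in symbols]

    def compose(f, g):
        # Apply f first, then g.
        return tuple(g[f[q]] for q in range(num_states))

    def prefix_scan(ts):
        # Returns all prefix compositions of ts. On parallel hardware the two
        # recursive halves run concurrently and the fix-up below is a single
        # composition layer, giving logarithmic depth overall.
        if len(ts) == 1:
            return ts
        mid = len(ts) // 2
        left, right = prefix_scan(ts[:mid]), prefix_scan(ts[mid:])
        return left + [compose(left[-1], g) for g in right]

    # Evaluate every prefix transformation at the initial state.
    return [f[q0] for f in prefix_scan(fs)]


# Parity semiautomaton: the state tracks the XOR of the inputs seen so far.
xor = lambda q, sigma: q ^ sigma
seq = [1, 0, 1, 1]
print(run_recurrent(xor, 0, seq))               # [1, 1, 0, 1]
print(run_shortcut(xor, 0, seq, num_states=2))  # [1, 1, 0, 1]
```

The key point is that the shortcut never iterates the recurrence: it composes transition functions, which is associative, so the T steps collapse into O(log T) layers of composition.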



* Compared to the number of symbols it can process. For example, DistilBERT (Sanh et al., 2019) can handle thousands of tokens with 6 sequential layers.

