THINKING LIKE TRANSFORMERS

Abstract

What is the computational model behind a transformer? Where recurrent neural networks have direct parallels in finite state machines, allowing clear discussion and thought around architecture variants or trained models, transformers have no such familiar parallel. In this paper we aim to change that, proposing a computational model for the transformer-encoder in the form of a programming language. We map the basic components of a transformer-encoder (attention and feed-forward computation) into the simple primitives of select, aggregate, and zipmap, around which we form a programming language: the Restricted Access Sequence Processing Language (RASP). We show how RASP can be used to program solutions to tasks that could conceivably be learned by a transformer, augmenting it with tools we discover in our work. In particular, we provide RASP programs for histograms, sorting, and even logical inference similar to that of Clark et al. (2020). We further use our model to compare the difficulty of these tasks in terms of the number of layers and attention heads they require. Finally, we see how insights gained from our abstraction might be used to explain phenomena seen in recent works.

1. INTRODUCTION

While Yun et al. (2019) show that sufficiently large transformers can approximate any constant-length sequence-to-sequence function, and Hahn (2019) provides theoretical limitations on their ability to compute functions on unbounded input lengths, neither of these provides insight on how a transformer may achieve a specific task. Orthogonally, Bhattamishra et al. (2020) provide transformer constructions for several counting languages, but this also does not direct us towards a general model. This is in stark contrast to other neural network architectures, which do have clear computational models. For example, convolutional networks are seen as a sequence of filters (Zhang et al., 2018), and finite-state automata and their variants have been extensively used both for extraction from and theoretical analysis of recurrent neural networks (RNNs) (Omlin & Giles, 1996; Weiss et al., 2018; Rabusseau et al., 2018; Merrill et al., 2020), even inspiring new RNN variants (Joulin & Mikolov, 2015).

In this work we propose a computational model for the transformer-encoder, in the form of a simple sequence-processing language which we dub RASP (Restricted Access Sequence Processing Language). Much like how automata describe the token-by-token processing behavior of an RNN, our language captures the unique information-flow constraints under which a transformer (Vaswani et al., 2017) operates as it processes input sequences. Considering computational problems and their implementation in the RASP language allows us to "think like a transformer" while abstracting away the technical details of a neural network in favor of symbolic programs. A RASP program operates on sequences of values of uniform atomic types, and transforms them by composing a restricted set of sequence processors. One pair of processors is used to select inputs for aggregation, and then aggregate the selected items. Another processor performs arbitrary but local computation over its (localized) input.
However, access to the complete sequence is available only through aggregate operations that reduce a stream of numbers to a scalar. The key to performing complex global computations under this model is to compose the aggregations such that they gather the correct information, which can then be locally processed into a final output. Given a RASP program, we can analyze it to infer the minimal number of layers and maximum number of heads required to implement it as a transformer. We show several examples of expressive programs written in the RASP language, demonstrating how complex operations can be implemented within its constraints.
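To make this select-aggregate-and-local-compute pattern concrete, here is a sketch in plain Python of one such composition: counting the occurrences of a token. The helper name count_token and the code layout are our own illustration, not RASP syntax; aggregation is modeled as the averaging performed by uniform attention, and a local (zipmap-style) step multiplies the resulting fraction by the sequence length to recover a count.

```python
def count_token(tokens, target):
    """Count occurrences of `target`, composed RASP-style (illustrative sketch)."""
    n = len(tokens)
    # select: every position attends uniformly to the whole sequence
    selector = [[True] * n for _ in range(n)]
    # values to aggregate: 1.0 where the token matches the target, else 0.0
    indicator = [1.0 if t == target else 0.0 for t in tokens]
    # aggregate: averaging reduces the stream to a scalar per position --
    # the fraction of matching tokens
    fractions = [sum(v for v, s in zip(indicator, row) if s) / n
                 for row in selector]
    # local (zipmap-style) step: multiply by the length to recover the count
    length = [n] * n
    return [round(f * l) for f, l in zip(fractions, length)]
```

For example, count_token("hello", "l") yields the count 2 at every position, since the average of the indicator sequence is 2/5 and the local step rescales it by the length 5.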

2. THE RESTRICTED ACCESS SEQUENCE PROCESSING LANGUAGE

In this section, we present the Restricted Access Sequence Processing Language (RASP). RASP assumes a machine composed of several Turing-complete processors, each of which can only run functions taking and returning a fixed number of primitive arguments, and a simple memory accessor that is controlled by these processors. The select, aggregate, and zipmap operations which we present will define and constrain how the processors work together to process an input sequence. We focus here only on the language itself, leaving the discussion of its exact relation to transformers to Section 3.

Overview A RASP program works by manipulating sequences, occasionally with the help of selectors. Sequences contain values of uniform atomic type, such as booleans, integers, floats, or strings. Selectors are functions used for selecting elements from sequences, and are used (together with the appropriate operations) only in the process of creating new sequences. All sequences in RASP are lazily evaluated, meaning that their length and contents are not populated until the program is passed an input.

The Base Sequences Every program in RASP begins from the same set of base sequences, and then creates new ones using a small number of core operations. These base sequences are indices, length, and tokens, evaluated on input x_1, x_2, ..., x_n as their names suggest: (0, 1, ..., n-1), (n, n, ..., n) (of length n), and (x_1, x_2, ..., x_n), respectively.

Combining Sequences Sequences can be combined in an 'elementwise' manner, such that the value of the resulting sequence at each position i is a function of the values of the combined sequences at position i (similar to a map operation), or have their positions 'mixed' in more complicated ways using selectors: functions f : N × N → {True, False} whose sole purpose is to guide the combination of existing sequences into new ones. We present the basic ingredients of RASP using an example.
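The base sequences and core operations just described can be sketched in plain Python. The signatures below are illustrative approximations of our own, not RASP's actual syntax; in particular, aggregate here averages the selected numeric values, mirroring the averaging that attention performs.

```python
def indices(x):
    """Base sequence: (0, 1, ..., n-1)."""
    return list(range(len(x)))

def length(x):
    """Base sequence: (n, n, ..., n)."""
    return [len(x)] * len(x)

def tokens(x):
    """Base sequence: the input itself, (x_1, ..., x_n)."""
    return list(x)

def zipmap(seqs, f):
    """Elementwise combination: position i of the output depends only
    on position i of each input sequence (similar to a map)."""
    return [f(*vals) for vals in zip(*seqs)]

def select(key_seq, query_seq, predicate):
    """Build a selector: a boolean matrix whose entry (i, j) says
    whether position i selects position j."""
    return [[predicate(k, q) for k in key_seq] for q in query_seq]

def aggregate(selector, val_seq):
    """Average the selected numeric values at each position."""
    out = []
    for row in selector:
        chosen = [v for v, s in zip(val_seq, row) if s]
        out.append(sum(chosen) / len(chosen) if chosen else 0)
    return out
```

For instance, selecting with the predicate k <= q over indices and aggregating the values (1.0, 2.0, 3.0) yields the running prefix averages (1.0, 1.5, 2.0).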
Figure 1 shows a simple RASP function for sorting a sequence of values according to a sequence of keys. It accepts an input sequence vals and uses the base sequence indices, which is available to any RASP program, to compute its output in three operations as follows:

1. count_conditioned of line 2 creates a new sequence that counts, for each element of keys, the number of "previous items" it has in keys, where the "previous items" are defined to be all items that have a lesser value, or an equal value and a lower index. Thus, num_prevs is a sequence of numbers representing the target sorted position of each item.
2. select of line 7 creates a new selector which will focus each position i on the corresponding position j for which indices[i] is equal to num_prevs[j]. Effectively, it will direct the elements in each position j towards their target location i.
3. Finally, aggregate applies this selector to vals, placing each value at its target position and producing the sorted output.



Figure 1: RASP program taking two sequences vals, keys and returning a sequence y that sorts the elements of vals according to keys, e.g.: if vals(x)=[a,b,c] and keys(x)=[0,4,2], then y(x)=[a,c,b].
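The three operations described above can be sketched as a self-contained Python function (an illustrative model of the program in Figure 1, not the official RASP syntax; the helper name sort_by_keys is our own):

```python
def sort_by_keys(vals, keys):
    """Sort vals according to keys, mirroring Figure 1's three steps."""
    n = len(keys)
    # Step 1 (count_conditioned): for each position i, count the keys that
    # are smaller, breaking ties by index -- this is i's target position.
    num_prevs = [sum(1 for j in range(n)
                     if keys[j] < keys[i] or (keys[j] == keys[i] and j < i))
                 for i in range(n)]
    # Step 2 (select): position i focuses on the position j whose target is i.
    selector = [[num_prevs[j] == i for j in range(n)] for i in range(n)]
    # Step 3 (aggregate): read out the single selected value at each position.
    return [next(vals[j] for j in range(n) if row[j]) for row in selector]
```

On the example of the caption, sort_by_keys(['a', 'b', 'c'], [0, 4, 2]) produces ['a', 'c', 'b']: num_prevs evaluates to [0, 2, 1], so each value is routed to its sorted position.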

