SYSTEMATIC GENERALIZATION AND EMERGENT STRUCTURES IN TRANSFORMERS TRAINED ON STRUCTURED TASKS

Abstract

Transformer networks have seen great success in natural language processing and machine vision, where task objectives such as next word prediction and image classification benefit from nuanced context sensitivity across high-dimensional inputs. However, there is an ongoing debate about how and when transformers can acquire highly structured behavior and achieve systematic generalization. Here, we explore how well a causal transformer can perform a set of algorithmic tasks, including copying, sorting, and hierarchical compositions of these operations. We demonstrate strong generalization to sequences longer than those used in training by replacing the standard positional encoding typically used in transformers with labels arbitrarily paired with items in the sequence. We search for the layer and head configuration sufficient to solve these tasks, then probe for signs of systematic processing in latent representations and attention patterns. We show that two-layer transformers learn generalizable solutions to multi-level problems, develop signs of systematic task decomposition, and exploit shared computation across related tasks. These results provide key insights into how stacks of attention layers support structured computation both within individual tasks and across tasks.

1. INTRODUCTION

Since their introduction (Vaswani et al., 2017), transformer-based models have become the new norm of natural language modeling (Brown et al., 2020; Devlin et al., 2018) and are being leveraged for machine vision tasks as well as in reinforcement learning contexts (Chen et al., 2021; Dosovitskiy et al., 2020; Janner et al., 2021; Ramesh et al., 2021). Transformers trained on large amounts of data under simple self-supervised, sequence modeling objectives are capable of subsequent generalization to a wide variety of tasks, making them an appealing option for building multi-modal, multi-task, generalist agents (Bommasani et al., 2021; Reed et al., 2022). Central to this success is the ability to represent each part of the input in the context of other parts through the self-attention mechanism. This may be especially important for task objectives such as next word prediction and image classification at scale with naturalistic data, which benefit from nuanced context sensitivity across high-dimensional inputs. Interestingly, transformer-based language models seem to also acquire structured knowledge without being explicitly trained to do so and display few-shot learning capabilities (Brown et al., 2020; Linzen & Baroni, 2021; Manning et al., 2020). These insights have led to ongoing work exploring these models' potential to develop broader reasoning capabilities (Binz & Schulz, 2022; Dasgupta et al., 2022). Despite success in learning large-scale, naturalistic data and signs of acquisition of structured knowledge or generalizable behavior, how transformer models support systematic generalization remains to be better understood. Recent work demonstrated that large language models struggle at longer problems and fail to robustly reason beyond the training data (Anil et al., 2022; Razeghi et al., 2022).
Different architectural variations have been proposed to improve length generalization in transformers, highlighting the role of variants of position-based encodings (Csordás et al., 2021a; b; Ontanón et al., 2021; Press et al., 2021). Indeed, whether neural networks will ever be capable of systematic generalization without building in explicit symbolic components remains an open question (Fodor & Pylyshyn, 1988; Smolensky et al., 2022). Here, we approach this question by training a causal transformer model to perform a set of algorithmic operations, including copy, reverse, and hierarchical group or sort tasks. We explicitly sought the minimal transformer that would reliably solve these simple tasks, and we thoroughly analyze this minimal solution through attention ablation and representation analysis. Exploring how a transformer with no predefined task-aligned structure can adapt to the structure in these algorithmic tasks provides a starting point for understanding how self-attention can tune to structure in more complex problems, e.g., those with the kinds of exceptions and partial regularities found in natural datasets, where the exploitation of task structure may occur in a more approximate, graded manner. Our main contributions are:

1. We present a set of two-layer causal transformers that are capable of learning multiple algorithmic operations, and show that putting more attention heads at deeper layers has advantages for learning multi-level tasks.

2. We show that the attention layers in these models reveal signs of systematic decomposition within tasks and exploitation of shared structures across tasks.

3. We highlight a simple label-based order encoding method in place of the positional encoding methods typically used in transformers, and show that it helps to achieve strong length generalization performance.

2. METHOD

Dataset. We created an item pool covering all combinations of 5 shapes, 5 colors, and 5 textures, and generated a sequence dataset by sampling 100k sequences of 5-50 items randomly selected from the item pool. The tasks we used to train the models are shown in Fig 1A. Each task corresponds to one of the following rules, which rely on item feature and/or item order information to rearrange an input sequence (grouping or sorting items by a particular feature is with respect to a pre-defined feature sort order, e.g., circles < squares < pentagons, or red < purple < blue):

COPY (C): copy the input sequence.

REVERSE (R): reverse the input sequence.

GROUP[SHAPE] (G[S]): group the items by shape, preserving the input order within each shape group.

Figure 1: Task and model design.

Token encoding. We instantiated the token vocabularies as onehot or multihot vectors. The task tokens were onehot vectors with the corresponding task category set to one, with one additional task dimension corresponding to the end-of-sequence (EOS) token. The item tokens were multihot vectors whose units indicated the item's value in each feature dimension (equivalent to concatenated onehot feature vectors). As such, the model receives disentangled feature information in the input, though in principle it could learn to disentangle feature information given onehot encodings for each unique item.

Label-based order encoding. Using position-based order encodings (whether absolute or relative), models trained with sequences up to length L encounter an out-of-distribution problem when tested on longer sequences.
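One plausible implementation of such a label-based order encoding is to draw, for each sequence, a sorted set of random labels from a range larger than any sequence length, and embed those labels in place of absolute positions. The function name, the sampling scheme, and the `max_label` parameter below are our own illustrative assumptions, not details taken from the paper:

```python
import numpy as np

def sample_order_labels(seq_len, max_label=100, rng=None):
    """Sample sorted, unique random labels to mark item order.

    Unlike fixed positions 0..seq_len-1, labels are drawn from the full
    range [0, max_label), so a test sequence longer than any training
    sequence can still be tagged with in-distribution label values.
    """
    rng = rng or np.random.default_rng()
    labels = rng.choice(max_label, size=seq_len, replace=False)
    return np.sort(labels)

# Each item embedding would then be combined with the embedding of its
# sampled label rather than the embedding of its absolute position.
labels = sample_order_labels(8, max_label=50, rng=np.random.default_rng(0))
print(labels)  # 8 sorted, unique integers in [0, 50)
```

Because only the relative order of the labels carries information, the model is encouraged to learn order relations rather than memorize absolute position values.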


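The task rules above can be made concrete with a small reference implementation. This is a sketch under our own representational assumptions: items are (shape, color, texture) index triples, and the pre-defined feature sort order corresponds to ascending indices (e.g., circles < squares < pentagons maps to shape indices 0 < 1 < 2):

```python
def copy_seq(items):
    """COPY (C): return the input sequence unchanged."""
    return list(items)

def reverse_seq(items):
    """REVERSE (R): return the input sequence in reverse order."""
    return list(reversed(items))

def group_by_shape(items):
    """GROUP[SHAPE] (G[S]): group items by shape index.

    Python's sort is stable, so the input order is preserved
    within each shape group, as the task rule requires.
    """
    return sorted(items, key=lambda item: item[0])

# Items as (shape, color, texture) index triples.
seq = [(2, 0, 1), (0, 1, 1), (2, 4, 0), (1, 3, 2)]
print(group_by_shape(seq))
# → [(0, 1, 1), (1, 3, 2), (2, 0, 1), (2, 4, 0)]
```

Note that the two shape-2 items keep their original relative order in the output, which is exactly the within-group order preservation the GROUP rule specifies.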