SYSTEMATIC GENERALIZATION AND EMERGENT STRUCTURES IN TRANSFORMERS TRAINED ON STRUCTURED TASKS

Abstract

Transformer networks have seen great success in natural language processing and machine vision, where task objectives such as next word prediction and image classification benefit from nuanced context sensitivity across high-dimensional inputs. However, there is an ongoing debate about how and when transformers can acquire highly structured behavior and achieve systematic generalization. Here, we explore how well a causal transformer can perform a set of algorithmic tasks, including copying, sorting, and hierarchical compositions of these operations. We demonstrate strong generalization to sequences longer than those used in training by replacing the standard positional encoding typically used in transformers with labels arbitrarily paired with items in the sequence. We search for the layer and head configuration sufficient to solve these tasks, then probe for signs of systematic processing in latent representations and attention patterns. We show that two-layer transformers learn generalizable solutions to multi-level problems, develop signs of systematic task decomposition, and exploit shared computation across related tasks. These results provide key insights into how stacks of attention layers support structured computation both within task and across tasks.

1. INTRODUCTION

Since their introduction (Vaswani et al., 2017), transformer-based models have become the new norm in natural language modeling (Brown et al., 2020; Devlin et al., 2018) and are being leveraged for machine vision tasks as well as in reinforcement learning contexts (Chen et al., 2021; Dosovitskiy et al., 2020; Janner et al., 2021; Ramesh et al., 2021). Transformers trained on large amounts of data under simple self-supervised, sequence modeling objectives are capable of subsequent generalization to a wide variety of tasks, making them an appealing option for building multi-modal, multi-task, generalist agents (Bommasani et al., 2021; Reed et al., 2022). Central to this success is the ability to represent each part of the input in the context of other parts through the self-attention mechanism. This may be especially important for task objectives such as next word prediction and image classification at scale with naturalistic data, which benefit from nuanced context sensitivity across high-dimensional inputs. Interestingly, transformer-based language models also seem to acquire structured knowledge without being explicitly trained to do so, and they display few-shot learning capabilities (Brown et al., 2020; Linzen & Baroni, 2021; Manning et al., 2020). These insights have led to ongoing work exploring these models' potential to develop broader reasoning capabilities (Binz & Schulz, 2022; Dasgupta et al., 2022).

Despite this success in learning from large-scale, naturalistic data, and despite signs of acquired structured knowledge or generalizable behavior, how transformer models support systematic generalization remains to be better understood. Recent work has demonstrated that large language models struggle with longer problems and fail to reason robustly beyond their training data (Anil et al., 2022; Razeghi et al., 2022).
Different architectural variations have been proposed to improve length generalization in transformers, highlighting the role of variants of position-based encodings (Csordás et al., 2021a;b; Ontañón et al., 2021; Press et al., 2021). Indeed, whether neural networks will ever be capable of systematic generalization without building in explicit symbolic components remains an open question (Fodor & Pylyshyn, 1988; Smolensky et al., 2022).
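The label-based alternative to standard positional encoding described in the abstract can be sketched concretely: instead of encoding consecutive positions 0, 1, 2, ..., each sequence item is paired with a label drawn at random (and then sorted) from a range larger than any training sequence, so the model can only exploit the relative order of labels rather than absolute positions. The sketch below is a minimal illustration of this idea under stated assumptions; the function names, the `max_label` parameter, and the use of sinusoidal embeddings over the sampled labels are our assumptions for illustration, not the paper's exact implementation.

```python
import numpy as np

def sample_position_labels(seq_len, max_label=50, rng=None):
    """Sample sorted random labels to stand in for positions 0..seq_len-1.

    Because labels are drawn from a range larger than any training
    sequence, absolute positions carry no stable meaning; only the
    relative order of labels is informative, which is the property
    intended to support generalization to longer sequences.
    """
    rng = rng or np.random.default_rng()
    labels = rng.choice(max_label, size=seq_len, replace=False)
    return np.sort(labels)

def sinusoidal_encoding(labels, d_model=32):
    """Standard sinusoidal encoding evaluated at the sampled labels."""
    labels = np.asarray(labels, dtype=float)[:, None]   # (seq_len, 1)
    dims = np.arange(d_model // 2)[None, :]             # (1, d_model/2)
    angles = labels / (10000.0 ** (2 * dims / d_model))
    enc = np.zeros((labels.shape[0], d_model))
    enc[:, 0::2] = np.sin(angles)   # even dims: sine
    enc[:, 1::2] = np.cos(angles)   # odd dims: cosine
    return enc
```

At training time a fresh set of labels would be sampled per sequence and their encodings added to (or concatenated with) the token embeddings; at test time, longer sequences simply draw more labels from the same range.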

