CHARACTERIZING INTRINSIC COMPOSITIONALITY IN TRANSFORMERS WITH TREE PROJECTIONS

Abstract

When trained on language data, do transformers learn some arbitrary computation that utilizes the full capacity of the architecture or do they learn a simpler, tree-like computation, hypothesized to underlie compositional meaning systems like human languages? There is an apparent tension between compositional accounts of human language understanding, which are based on a restricted bottom-up computational process, and the enormous success of neural models like transformers, which can route information arbitrarily between different parts of their input. One possibility is that these models, while extremely flexible in principle, in practice learn to interpret language hierarchically, ultimately building sentence representations close to those predictable by a bottom-up, tree-structured model. To evaluate this possibility, we describe an unsupervised and parameter-free method to functionally project the behavior of any transformer into the space of tree-structured networks. Given an input sentence, we produce a binary tree that approximates the transformer's representation-building process and a score that captures how "tree-like" the transformer's behavior is on the input. While calculation of this score does not require training any additional models, it provably upper-bounds the fit between a transformer and any tree-structured approximation. Using this method, we show that transformers for three different tasks become more tree-like over the course of training, in some cases recovering, without supervision, the same trees as supervised parsers. These trees, in turn, are predictive of model behavior, with more tree-like models generalizing better on tests of compositional generalization.

1. INTRODUCTION

Consider the sentence Jack has more apples than Saturn has rings, which you have almost certainly never encountered before. Such compositionally novel sentences consist of known words in unknown contexts, and can be reliably interpreted by humans. One leading hypothesis suggests that humans process language according to a hierarchical, tree-structured computation and that such a restricted computation is, in part, responsible for compositional generalization. Meanwhile, popular neural network models of language processing such as the transformer can, in principle, learn an arbitrarily expressive computation over sentences, with the ability to route information between any two pieces of the sentence. In practice, when trained on language data, do transformers instead constrain their computation to look equivalent to a tree-structured bottom-up computation? While generalization tests on benchmarks (Lake & Baroni, 2018; Bahdanau et al., 2019; Hupkes et al., 2019; Kim & Linzen, 2020, among others) assess whether a transformer's behavior is aligned with tree-like models, they do not measure whether the transformer's computation is tree-structured, largely because model behavior on benchmarks could be entirely due to orthogonal properties of the dataset (Patel et al., 2022). Thus, to understand whether transformers implement tree-structured computations, the approach we take is based on directly approximating them with a separate, tree-structured computation. Prior methods based on this approach (Andreas, 2019; McCoy et al., 2019) require putatively gold syntax trees; this not only commits to a specific theory of syntax but, crucially, such trees may not exist in some domains due to syntactic indeterminacy. Consequently, these methods will fail to recognize a model as tree-like if it is tree-structured according to a different notion of syntax.
Moreover, all of these approaches involve an expensive training procedure for explicitly fitting a tree-structured model (Socher et al., 2013; Smolensky, 1990) to the neural network.

Figure 1: (a) Given a transformer model f, our method finds the tree projection of f, i.e., binary trees corresponding to the tree-structured neural network g_{ϕproj} (in the space of all tree-structured models) that best approximates the outputs of f on a given set of strings. (b) (i) Given a string, we compute context-free representations (ṽ_ij) for all spans of the string via attention masking (Section 3). (ii) We use the distance between (average-pooled) context-free and contextual representations (v_ij) to populate a chart data structure. (iii) We decode a tree structure from chart entries.

Instead, we present a method that is completely unsupervised (no gold syntax needed) and parameter-free (no neural network fitting needed). At a high level, our proposed method functionally projects¹ transformers into the space of all tree-structured models, via an implicit search over the joint space of tree structures and parameters of the corresponding tree-structured models (Figure 1). The main intuition behind our approach is to appeal to the notion of representational invariance: bottom-up tree-structured computations over sentences build intermediate representations that are invariant to outside context, and so we can approximate a transformer with a tree-structured computation by searching for a "bracketing" of the sentence in which the transformer's representations of intermediate brackets are maximally invariant to their context. Concretely, the main workhorse of our approach is a subroutine that computes distances between contextual and context-free representations of all spans of a sentence. We use these distances to induce a tree projection of the transformer using classical chart parsing (Section 3), along with a score that estimates tree-structuredness.
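The chart-parsing decoding step described above can be sketched as a standard CKY-style dynamic program. The sketch below assumes the per-span distances d[(i, j)] (between contextual and context-free span representations) have already been computed from the transformer; all function and variable names here are illustrative, not taken from the paper's released code.

```python
def tree_projection(d, n):
    """Decode the binary tree over tokens 0..n-1 that minimizes the total
    distance d[(i, j)] summed over all spans included in the tree.

    d: dict mapping (i, j) with j > i to a non-negative distance score.
    Returns (total_score, tree), where a tree is a token index (leaf)
    or a nested pair (left_subtree, right_subtree).
    """
    memo = {}

    def solve(i, j):
        if i == j:
            return 0.0, i  # leaf: a single token contributes no distance
        if (i, j) in memo:
            return memo[(i, j)]
        best_score, best_tree = float("inf"), None
        for k in range(i, j):  # try every split [i, k] | [k+1, j]
            left_score, left_tree = solve(i, k)
            right_score, right_tree = solve(k + 1, j)
            score = d[(i, j)] + left_score + right_score
            if score < best_score:
                best_score, best_tree = score, (left_tree, right_tree)
        memo[(i, j)] = (best_score, best_tree)
        return memo[(i, j)]

    return solve(0, n - 1)


# Toy example: span (0, 1) is much more context-invariant than (1, 2),
# so the decoder brackets the first two tokens together.
d = {(0, 1): 0.2, (1, 2): 0.9, (0, 2): 1.0}
score, tree = tree_projection(d, 3)
```

Because every candidate tree over n tokens contains the full span (0, n-1), the score of the root span is a constant offset, and the search effectively compares trees by the distances of their internal brackets.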
First, we prove that our approach can find the best tree-structured account of a transformer's computation under mild assumptions (Theorem 1). Empirically, we find that transformer encoders of varying depths become more tree-like as they train on three sequence transduction datasets, with corresponding tree projections gradually aligning with gold syntax on two of the three datasets (Section 5). We then use tree projections as a tool to predict behaviors associated with compositionality: induced trees reliably reflect the contextual dependence structure implemented by encoders (Section 6.1), and on two of three datasets, both tree scores and the parsing F1 of tree projections correlate better with compositional generalization to configurations unseen in training than in-domain accuracy does (Section 6.2).

2. BACKGROUND

How can we compute the meaning of red apples are delicious? Substantial evidence (Crain & Nakayama, 1987; Pallier et al., 2011; Hale et al., 2018) supports the hypothesis that semantic interpretation of sentences by humans involves a tree-structured, hierarchical computation, where smaller constituents (red, apples) recursively combine into larger constituents (red apples), until we reach the full sentence. Concretely, suppose we have a sentence S ≜ {w_1, w_2, …, w_{|S|}}. Let T be a function that returns a binary tree for any sentence S, defined recursively as T(S) ≜ ⟨T(S_{1,j}), T(S_{j+1,|S|})⟩, where T(S_{a,b}) refers to a subtree over the span S_{a,b} ≜ {w_a, w_{a+1}, …, w_b}. We say that a span S_{a,b} ∈ T(S) if the node T(S_{a,b}) exists as a subtree in T(S). For notational convenience, we sometimes write S_l and S_r for the left and right subtrees of T(S), i.e., T(S) = ⟨S_l, S_r⟩.
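The recursive definition above can be made concrete with a small sketch: a binary tree is represented as nested pairs over words, and a helper enumerates every span S_{a,b} contained in the tree (i.e., the spans for which S_{a,b} ∈ T(S)). The representation and the `spans` helper are our own illustrative choices, not notation from the paper.

```python
def spans(tree, start=0):
    """Return (span_set, next_index) for a binary tree over tokens.

    A tree is either a word (leaf, str) or a pair (left, right) mirroring
    T(S) = <S_l, S_r>. A span (a, b) is in span_set iff the subtree over
    tokens a..b (0-indexed, inclusive) exists as a node in the tree.
    """
    if isinstance(tree, str):                    # leaf word w_i
        return {(start, start)}, start + 1
    left, right = tree
    left_spans, mid = spans(left, start)         # spans of S_l
    right_spans, end = spans(right, mid)         # spans of S_r
    return left_spans | right_spans | {(start, end - 1)}, end


# Example tree for "red apples are delicious":
# ((red apples) (are delicious))
t = (("red", "apples"), ("are", "delicious"))
all_spans, n_tokens = spans(t)
```

Under this encoding, the tree above contains the spans (0, 1) for red apples and (2, 3) for are delicious, plus the leaves and the root span (0, 3).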



¹ We provide a functional account of the transformer's computation and not a topological account, i.e., we are agnostic to whether the attention patterns of the transformer themselves look tree-structured (see Appendix C for examples).

