CHARACTERIZING INTRINSIC COMPOSITIONALITY IN TRANSFORMERS WITH TREE PROJECTIONS

Abstract

When trained on language data, do transformers learn some arbitrary computation that utilizes the full capacity of the architecture, or do they learn a simpler, tree-like computation hypothesized to underlie compositional meaning systems like human languages? There is an apparent tension between compositional accounts of human language understanding, which are based on a restricted bottom-up computational process, and the enormous success of neural models like transformers, which can route information arbitrarily between different parts of their input. One possibility is that these models, while extremely flexible in principle, in practice learn to interpret language hierarchically, ultimately building sentence representations close to those predictable by a bottom-up, tree-structured model. To evaluate this possibility, we describe an unsupervised and parameter-free method to functionally project the behavior of any transformer into the space of tree-structured networks. Given an input sentence, we produce a binary tree that approximates the transformer's representation-building process and a score that captures how "tree-like" the transformer's behavior is on the input. While calculation of this score does not require training any additional models, it provably upper-bounds the approximation error of the best tree-structured model of the transformer. Using this method, we show that transformers for three different tasks become more tree-like over the course of training, in some cases recovering, without supervision, the same trees as supervised parsers. These trees, in turn, are predictive of model behavior, with more tree-like models generalizing better on tests of compositional generalization.

1. INTRODUCTION

Consider the sentence "Jack has more apples than Saturn has rings", which you have almost certainly never encountered before. Such compositionally novel sentences consist of known words in unknown contexts, and can be reliably interpreted by humans. One leading hypothesis suggests that humans process language according to hierarchical tree-structured computation and that such a restricted computation is, in part, responsible for compositional generalization. Meanwhile, popular neural network models of language processing such as the transformer can, in principle, learn an arbitrarily expressive computation over sentences, with the ability to route information between any two pieces of the sentence. In practice, when trained on language data, do transformers instead constrain their computation to look equivalent to a tree-structured bottom-up computation? While generalization tests on benchmarks (Lake & Baroni, 2018; Bahdanau et al., 2019; Hupkes et al., 2019; Kim & Linzen, 2020, among others) assess if a transformer's behavior is aligned with tree-like models, they do not measure if the transformer's computation is tree-structured, largely because model behavior on benchmarks could be entirely due to orthogonal properties of the dataset (Patel et al., 2022). Thus, to understand if transformers implement tree-structured computations, we take the approach of directly approximating them with a separate, tree-structured computation. Prior methods based on this approach (Andreas, 2019; McCoy et al., 2019) require putatively gold syntax trees, which not only commits them to a specific theory of syntax but, crucially, may not exist in some domains due to syntactic indeterminacy. Consequently, these methods will fail to recognize a model as tree-like if it is tree-structured according to a different notion of syntax.
Moreover, all of these approaches involve an expensive training procedure for explicitly fitting a tree-structured model (Socher et al., 2013; Smolensky, 1990) to the neural network. Instead, we present a method that is completely unsupervised (no gold syntax needed) and parameter-free (no neural network fitting needed). At a high level, our proposed method functionally projects transformers into the space of all tree-structured models, via an implicit search over the joint space of tree structures and parameters of corresponding tree-structured models (Figure 1). The main intuition behind our approach is to appeal to the notion of representational invariance: bottom-up tree-structured computations over sentences build intermediate representations that are invariant to outside context, so we can approximate transformers with a tree-structured computation by searching for a "bracketing" of the sentence where transformer representations of intermediate brackets are maximally invariant to their context. Concretely, the main workhorse of our approach is a subroutine that computes distances between contextual and context-free representations of all spans of a sentence. We use these distances to induce a tree projection of the transformer using classical chart parsing (Section 3), along with a score that estimates tree-structuredness.

Figure 1: (a) Given a transformer model $f$, our method finds the tree projection of $f$, i.e., binary trees corresponding to the tree-structured neural network $g_{\phi_{\text{proj}}}$ (in the space of all tree-structured models) that best approximates the outputs of $f$ on a given set of strings. (b) (i) Given a string, we compute context-free representations ($\tilde{v}_{ij}$) for all spans of the string via attention masking (Section 3). (ii) We use the distance between (average-pooled) context-free and contextual representations ($v_{ij}$) to populate a chart data structure. (iii) We decode a tree structure from chart entries.
First, we prove that our approach can find the best tree-structured account of a transformer's computation under mild assumptions (Theorem 1 and Corollary 1.1). Empirically, we find that transformer encoders of varying depths become more tree-like as they train on three sequence transduction datasets, with corresponding tree projections gradually aligning with gold syntax on two of three datasets (Section 5). Then, we use tree projections as a tool to predict behaviors associated with compositionality: induced trees reliably reflect the contextual dependence structure implemented by encoders (Section 6.1), and both tree scores and parsing F1 of tree projections correlate better with compositional generalization to configurations unseen in training than in-domain accuracy does on two of three datasets (Section 6.2).

2. BACKGROUND

How can we compute the meaning of "red apples are delicious"? Substantial evidence (Crain & Nakayama, 1987; Pallier et al., 2011; Hale et al., 2018) supports the hypothesis that semantic interpretation of sentences by humans involves a tree-structured, hierarchical computation, where smaller constituents ("red", "apples") recursively combine into larger constituents ("red apples"), until we reach the full sentence. Concretely, suppose we have a sentence $S \triangleq \{w_1, w_2, \ldots, w_{|S|}\}$. Let $T$ be a function that returns a binary tree for any sentence $S$, defined recursively as $T(S) \triangleq \langle T(S_{1,j}), T(S_{j+1,|S|}) \rangle$, where $T(S_{a,b})$ refers to a subtree over the span $S_{a,b} \triangleq \{w_a, w_{a+1}, \ldots, w_b\}$. We say that a span $S_{a,b} \in T(S)$ if the node $T(S_{a,b})$ exists as a subtree in $T(S)$. For notational convenience, we sometimes use $S_l$ and $S_r$ as the left and right subtrees of $T(S)$, i.e., $T(S) = \langle S_l, S_r \rangle$.

Compositionality in Meaning Representations. While theories of compositional meaning formation might differ on specifics of syntax, at a high level, they propose that computing the meaning of $S$ must involve a bottom-up procedure along some syntax tree $T(S)$ of the sentence $S$. Formally, following Montague (1970), we say that a meaning representation system $m$ is compositional if the meaning $m(s)$ of some expression $s$ is a homomorphic image of the syntax of $s$, i.e., $m(s) = \phi(m(s_l), m(s_r))$ for some $\phi$. Crucially, such a $\phi$ exists only if $m(s)$ can be fully determined by the contents of $s$, that is, if $m(s)$ is contextually invariant. While several phenomena necessarily require a non-compositional, context-sensitive interpretation (indexicals, idioms, pronouns and lexical ambiguity, among others), compositional interpretation remains a central component in explanations of the human ability to systematically interpret novel sentences.

Compositionality in Neural Models. A class of neural networks that are obviously compositional are tree-structured models such as Socher et al. (2013), which obtain vector representations of sentences by performing a bottom-up computation over syntax. Specifically, given $S$ and a corresponding binary tree $T(S)$, the output of the tree-structured network $g_\phi$ is defined recursively: for any span $p \in T(S)$,

$$g_\phi(p, T(p)) \triangleq h_\theta\big(g_\phi(p_l, T(p_l)),\; g_\phi(p_r, T(p_r))\big),$$

where $h_\theta : \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}^d$ is some feedforward neural network. For leaf nodes $w_i$, $g_\phi(w_i, T(w_i)) \triangleq \eta_{w_i}$, where $\eta_w \in \mathbb{R}^d$ is the word embedding for $w$. The parameters of the network are $\phi = \{\theta, \eta_{w_1}, \eta_{w_2}, \ldots\}$.
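The recursion for $g_\phi$ can be sketched in a few lines. This is a minimal illustration, assuming trees are nested pairs of words and using a hypothetical one-layer tanh network as $h_\theta$ (the text only requires "some feedforward neural network"); `tree_encode`, `W` and `b` are illustrative names, not the authors' code.

```python
import numpy as np

def tree_encode(tree, embeddings, h_theta):
    """Bottom-up computation of g_phi over a binary tree.

    tree: a word (leaf) or a pair (left_subtree, right_subtree).
    embeddings: dict mapping each word w to its vector eta_w.
    h_theta: composition function R^d x R^d -> R^d.
    """
    if isinstance(tree, tuple):
        left, right = tree
        return h_theta(tree_encode(left, embeddings, h_theta),
                       tree_encode(right, embeddings, h_theta))
    return embeddings[tree]  # leaf: the word embedding eta_w

# Hypothetical parameters phi = {theta, eta_w}: a one-layer tanh composition.
rng = np.random.default_rng(0)
d = 4
W = rng.normal(size=(d, 2 * d))
b = np.zeros(d)

def h_theta(x, y):
    return np.tanh(W @ np.concatenate([x, y]) + b)

embeddings = {w: rng.normal(size=d) for w in ["red", "apples", "are", "delicious"]}
vec = tree_encode((("red", "apples"), ("are", "delicious")), embeddings, h_theta)
```

Because each node only sees its two children, the vector for ("red", "apples") is identical in any sentence containing that constituent, which is exactly the contextual invariance that Section 3 exploits.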

3. OUR APPROACH

While tree-structured networks were built to reflect the compositional structure of natural language, they have been superseded by relatively unstructured transformers (Vaswani et al., 2017). How can we measure if the computation implemented by a transformer is compositional and tree-like? We start by noting that in any bottom-up tree computation over a sentence, the representation of an intermediate constituent depends only on the span it corresponds to, while being fully invariant to outside context. Thus, one way to assess the tree-structuredness of a computation over some span is to measure the contextual invariance of the resulting representation. Consequently, we construct a tree-structured approximation of a transformer's computation over a sentence by searching for a bracketing of the sentence where spans have maximal contextual invariance.

Suppose $f$ is a transformer model that produces contextual vectors of words in $S$ as $f(S) \triangleq \{v^S_{w_1}, v^S_{w_2}, \ldots, v^S_{w_{|S|}}\}$, where $v^S_w$ is a contextual vector representation of $w$. Given a span $p$, let $v^S_p$ be the span representation obtained by pooling the contextual vectors of words in $p$, $v^S_p = \sum_{w \in p} v^S_w$. Similarly, let $\tilde{v}_p$ be a context-free representation of the span $p$. For transformers, we obtain context-free representations through a simple attention masking scheme. In particular, to obtain $\tilde{v}_p$, we apply a "T-shaped" attention mask and take the pooled representation of the words in $p$ at the final layer (Figure 2). The mask ensures that attention heads do not attend to tokens outside of $p$ after an optional threshold layer.
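As a rough sketch of the masking idea, the helper below builds one boolean mask per layer (True meaning "may attend"). It assumes full attention below the threshold layer and, from the threshold on, restricts queries inside $p$ to keys inside $p$. The exact mask layout is the one in Figure 2; leaving queries outside $p$ unrestricted is our assumption, and `t_shaped_masks` is a hypothetical name.

```python
import numpy as np

def t_shaped_masks(n_tokens, span, n_layers, threshold_layer=0):
    """Hypothetical helper: per-layer boolean attention masks (True = may attend)
    for computing a context-free representation of span = (i, j), inclusive and
    0-indexed.

    Layers below `threshold_layer` use full attention. From `threshold_layer` on,
    queries inside the span may only attend to keys inside the span. Queries
    outside the span stay unrestricted (an assumption: their outputs are never
    pooled into the span representation).
    """
    i, j = span
    inside = np.zeros(n_tokens, dtype=bool)
    inside[i:j + 1] = True
    full = np.ones((n_tokens, n_tokens), dtype=bool)
    restricted = full.copy()
    restricted[inside] = inside  # in-span query rows: only in-span keys allowed
    return [full if layer < threshold_layer else restricted
            for layer in range(n_layers)]

# Usage: 5 tokens, span covering tokens 1..2, masking from layer 1 onward.
masks = t_shaped_masks(5, (1, 2), n_layers=3, threshold_layer=1)
```

Before plugging such masks into an attention implementation, check its convention: PyTorch's `scaled_dot_product_attention` treats True as "take part in attention", whereas `nn.MultiheadAttention` treats True as "not allowed to attend".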

3.1. SPAN CONTEXTUAL INVARIANCE

We define the span contextual invariance (SCI) of a span $p$ in the sentence $S$ as $\mathrm{SCI}(S, p) \triangleq d(v^S_p, \tilde{v}_p)$ for some distance function $d$. Similarly, we define the cumulative SCI score for a tree $T$ to be:

$$\mathrm{SCI}(S, T) \triangleq \sum_{p \in T} d(v^S_p, \tilde{v}_p). \qquad (1)$$
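A minimal sketch of filling an SCI chart for every span, assuming the context-free vectors $\tilde{v}_p$ have already been computed (e.g., with a masked forward pass) and are passed in as a dict keyed by inclusive, 0-indexed spans. It uses sum pooling for $v^S_p$ and the cosine distance from Section 4; the function names are illustrative.

```python
import numpy as np

def cosine_distance(x, y):
    # d(x, y) = 1 - x^T y / (||x|| ||y||)
    return 1.0 - float(x @ y) / (np.linalg.norm(x) * np.linalg.norm(y))

def sci_chart(contextual, context_free):
    """contextual: (n, d) array of contextual word vectors v^S_w.
    context_free: dict (i, j) -> pooled context-free vector for span (i, j)
    (assumed precomputed). Returns dict (i, j) -> SCI(S, p)."""
    n = contextual.shape[0]
    chart = {}
    for i in range(n):
        for j in range(i, n):
            v_p = contextual[i:j + 1].sum(axis=0)  # v^S_p: sum over words in p
            chart[(i, j)] = cosine_distance(v_p, context_free[(i, j)])
    return chart

def tree_sci(chart, spans):
    # Cumulative SCI of a tree given as the collection of its spans (Eq. 1)
    return sum(chart[p] for p in spans)
```

A perfectly context-free span gets SCI 0, and larger values mean the span's contextual representation drifts further from its context-free counterpart.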

3.2. COMPUTING TREE PROJECTIONS BY MINIMIZING SCI

Consider a collection of strings $D = \{S\}$ and some function $T$ that produces binary trees for any $S \in D$. The cumulative error from approximating the outputs of the transformer $f$ with the outputs of a tree-structured network $g_\phi$ structured according to $T$ can be written as

$$L(f, g_\phi, T) \triangleq \sum_{S \in D} \sum_{p \in T(S)} d\big(g_\phi(p, T(p)), v^S_p\big). \qquad (2)$$

Suppose we are interested in finding the best tree-structured approximation to $f$ over all possible trees, i.e., a configuration of tree structures and corresponding model parameters that best approximates the transformer's behavior. We define this as the exact tree projection of $f$:

$$\phi_{\text{proj}}, T_{\text{proj}} \triangleq \arg\min_{\phi, T} L(f, g_\phi, T). \qquad (3)$$

Theorem 1. $\min_{\phi, T} L(f, g_\phi, T) \le \sum_{S \in D} \min_{T(S)} \mathrm{SCI}(S, T(S))$. In other words, the best tree-structured approximation to $f$ has an error upper-bounded by cumulative SCI scores.

In general, finding tree projections involves a joint search over all discrete tree structures $T(S)$ as well as over continuous parameters $\phi$, which is intractable. However, Theorem 1 substantially simplifies this search: the upper bound depends only on the parses $T(S)$ and properties of the transformer, and it can be exactly minimized for a given $f$ in polynomial time with efficient parsing algorithms. We minimize this upper bound itself to approximately recover the best tree-structured approximation to $f$ over all choices of trees and parameters. The output of this minimization is an approximate tree projection, $\hat{T}_{\text{proj}}(S) = \arg\min_{T(S)} \mathrm{SCI}(S, T(S))$ for every $S \in D$. Under a mild assumption, SCI minimization recovers tree projections exactly.

Assumption 1. Let $\mathcal{S}_p$ denote the collection of sentences that contain the span $p$. Then, for every span $p$, we have $\min_v \sum_{S \in \mathcal{S}_p} d(v^S_p, v) = \sum_{S \in \mathcal{S}_p} d(v^S_p, \tilde{v}_p)$. That is, context-free vectors minimize the cumulative distance to their contextual counterparts.

Corollary 1.1. Under Assumption 1, $\min_{\phi, T} L(f, g_\phi, T) = \sum_{S \in D} \min_{T(S)} \mathrm{SCI}(S, T(S))$. Moreover, $T_{\text{proj}}(S) = \arg\min_{T(S)} \mathrm{SCI}(S, T(S))$ for any $S \in D$.
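The polynomial-time claim can be illustrated with a CKY-style dynamic program that exactly minimizes $\mathrm{SCI}(S, T)$ over all binary trees in $O(|S|^3)$ time. This is a sketch of one such "efficient parsing algorithm", not the procedure used in the experiments (which decode greedily, Section 4); it assumes the chart is a dict of span SCI scores with inclusive, 0-indexed spans.

```python
from functools import lru_cache

def min_sci_tree(chart, n):
    """Exact minimization of SCI(S, T) over all binary trees on n words:
    best(i, j) = chart[(i, j)] + min_k [best(i, k) + best(k + 1, j)].
    Returns (minimal total SCI, sorted list of spans of an optimal tree)."""
    @lru_cache(maxsize=None)
    def best(i, j):
        if i == j:
            return chart[(i, i)], ((i, i),)
        options = []
        for k in range(i, j):
            left_cost, left_spans = best(i, k)
            right_cost, right_spans = best(k + 1, j)
            options.append((left_cost + right_cost, left_spans + right_spans))
        cost, spans = min(options, key=lambda option: option[0])
        return chart[(i, j)] + cost, spans + ((i, j),)

    cost, spans = best(0, n - 1)
    return cost, sorted(spans)
```

Every tree contains all single words and the full sentence, so only the internal spans actually influence which tree is chosen.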

3.3. MEASURING INTRINSIC COMPOSITIONALITY

SCI minimization provides two natural ways to measure the intrinsic compositionality of $f$ on $D$. To measure tree-structuredness, we use

$$t_{\text{score}} \triangleq \frac{1}{|D|} \sum_{S \in D} \Big( \mathbb{E}_T\, \mathrm{SCI}(S, T) - \mathrm{SCI}\big(S, \hat{T}_{\text{proj}}(S)\big) \Big),$$

which computes the averaged SCI score of induced trees, normalized against the expected SCI score under a uniform distribution over trees. We find normalization to be necessary to prevent our method from spuriously assigning high tree-structuredness to entirely context-free encoders (which have low SCI scores, i.e., near-perfect contextual invariance, for all trees). When gold syntax $T_g$ is available, we use $t_{\text{parseval}} \triangleq \mathrm{PARSEVAL}(\hat{T}_{\text{proj}}, T_g, D)$, the bracketing F1 score (PARSEVAL; Black et al., 1991) of $\hat{T}_{\text{proj}}$ against $T_g$ on $D$.
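Both measures can be sketched as follows. One deliberate swap: the paper uses a sampling estimate of $\mathbb{E}_T\,\mathrm{SCI}(S, T)$, whereas this sketch computes the expectation exactly by recursion, using the fact that a uniformly random binary tree on $m$ leaves puts $a$ leaves in its left subtree with probability $N(a)N(m-a)/N(m)$, where $N$ counts binary trees (Catalan numbers). The bracketing F1 is per sentence and ignores single words and the whole sentence; that convention, and all function names, are assumptions.

```python
from functools import lru_cache
from math import comb

def num_trees(m):
    # Number of binary trees with m leaves: the Catalan number C_{m-1}
    return comb(2 * m - 2, m - 1) // m

def expected_sci(chart, n):
    """Exact E_T[SCI(S, T)] for T uniform over binary trees on n words.
    A left subtree with a leaves is chosen w.p. N(a) N(m - a) / N(m)."""
    @lru_cache(maxsize=None)
    def e(i, j):
        if i == j:
            return chart[(i, i)]
        m = j - i + 1
        total = 0.0
        for k in range(i, j):
            a = k - i + 1  # leaves in the left subtree
            p = num_trees(a) * num_trees(m - a) / num_trees(m)
            total += p * (e(i, k) + e(k + 1, j))
        return chart[(i, j)] + total
    return e(0, n - 1)

def t_score(examples):
    """examples: list of (chart, n, sci_of_projected_tree), one per sentence.
    Returns the mean of E_T[SCI] - SCI(S, T_proj(S)); higher = more tree-like."""
    gaps = [expected_sci(chart, n) - proj for chart, n, proj in examples]
    return sum(gaps) / len(gaps)

def bracketing_f1(pred_spans, gold_spans, n):
    """Unlabeled bracketing F1 between two trees on n words, ignoring single
    words and the whole sentence (an assumed convention)."""
    def nontrivial(spans):
        return {(i, j) for i, j in spans if i < j and (i, j) != (0, n - 1)}
    pred, gold = nontrivial(pred_spans), nontrivial(gold_spans)
    if not pred or not gold:
        return 1.0 if pred == gold else 0.0
    tp = len(pred & gold)
    if tp == 0:
        return 0.0
    precision, recall = tp / len(pred), tp / len(gold)
    return 2 * precision * recall / (precision + recall)
```

For a context-free encoder every chart entry is near 0, so the expected and projected SCI coincide and $t_{\text{score}}$ stays near 0, which is the behavior the normalization is meant to guarantee. Corpus-level PARSEVAL usually pools span counts across sentences rather than averaging per-sentence F1.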

4. EXPERIMENTAL SETUP

Our experiments are organized as follows. First, we show that on three sequence transduction tasks, transformers of varying depths become more tree-like over the course of training, and sometimes learn tree projections that progressively evolve towards ground-truth syntax. Then, we show how tree projections can be used to assess various model behaviors related to compositionality.

Datasets. We consider three datasets (Table 1) commonly used for benchmarking compositional generalization: COGS (Kim & Linzen, 2020), M-PCFGSET (Hupkes et al., 2019) and GeoQuery (Zelle & Mooney, 1996). COGS consists of automatically generated sentences from a context-free grammar paired with logical forms, split into in-domain examples (for training) and a compositionally challenging evaluation set. M-PCFGSET is a slightly modified version of PCFGSET (Hupkes et al., 2019), where inputs are a nested sequence of expressions that specify a unary or binary operation over lists. The objective is to execute the function specified by the input to obtain the final list. We focus on the "systematicity split" for measuring compositional generalization. Finally, GeoQuery consists of natural language queries about US geography paired with logical forms. To measure compositional generalization, we use the "query" split from Finegan-Dollak et al. (2018).

Implementation Details. We use greedy top-down chart parsing to approximately minimize SCI. In particular, we use SCI scores for all $O(|S|^2)$ spans of a string $S$ to populate a chart data structure, which is used to induce a tree by minimizing SCI via a top-down greedy procedure (see Algorithm 1 in the Appendix), similar to Stern et al. (2017). Our procedure outputs a tree and simultaneously returns the normalized SCI score of the tree, using a sampling estimate of the expected SCI score (Equation 5). We train transformer encoder-decoder models with encoders of depths {2, 4, 6} and a fixed decoder of depth 2.
We omit 6-layer transformer results for GeoQuery, as this model rapidly overfit and failed to generalize, perhaps due to the small size of the dataset. We choose a shallow decoder to ensure that most of the sentence processing is performed on the encoder side. We train for 100k iterations on COGS, 300k iterations on M-PCFGSET and 50k iterations on GeoQuery. We collect checkpoints every 1000, 2000 and 500 gradient updates, respectively, and use the encoder at these checkpoints to obtain parses as well as tree scores. In all experiments, $d$ is the cosine distance, i.e., $d(x, y) = 1 - \frac{x^\top y}{\|x\|\,\|y\|}$. All transformer layers have 8 attention heads and a hidden dimensionality of 512. We use a learning rate of 1e-4 (linearly warming up from 0 to 1e-4 over 5k steps) with the AdamW optimizer. All accuracies refer to exact-match accuracy against the gold target sequence. For all seq2seq transformers, we tune the threshold layer based on $t_{\text{parseval}}$.
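The greedy top-down decoding step can be sketched as below: at each span, pick the split point whose two children have the lowest combined SCI, then recurse. This is an illustration under the assumption that a split is scored by the sum of its two children's SCI; Algorithm 1 in the Appendix is not reproduced here. Unlike the exact dynamic program for Theorem 1's bound, greedy decoding can return a suboptimal tree, but it only inspects $O(|S|)$ split points per node.

```python
def greedy_top_down(chart, i, j):
    """Greedy top-down decoding (in the spirit of Stern et al., 2017) over an
    SCI chart: split span (i, j) at the k minimizing
    SCI(i, k) + SCI(k + 1, j), then recurse on both halves.
    Returns the spans (inclusive, 0-indexed) of the decoded binary tree."""
    if i == j:
        return [(i, i)]
    split = min(range(i, j), key=lambda k: chart[(i, k)] + chart[(k + 1, j)])
    return ([(i, j)]
            + greedy_top_down(chart, i, split)
            + greedy_top_down(chart, split + 1, j))

# Usage with a toy 3-word chart: "w0 w1" is more context-invariant than "w1 w2".
chart = {(0, 0): 0.0, (1, 1): 0.0, (2, 2): 0.0,
         (0, 1): 0.1, (1, 2): 0.5, (0, 2): 0.0}
tree = greedy_top_down(chart, 0, 2)  # splits as "w0 w1 | w2", since 0.1 < 0.5
```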

Inputs | Outputs
i. The ball was found | ball(x_1) AND find.theme(x_3, x_1)
   A cookie was blessed | cookie(x_1) AND bless.theme(x_3, x_1)
ii. copy interleave second reverse shift H13 C19 H9 O20 H9 H13 O20 C19
    repeat interleave second interleave first S1 E3 W3 N11 H4 Y3 L8 E1 R13 T12 E1 T12 L8 E1 R13 T12 E1 T12
iii. Which state has the lowest population density? | (A, smallest(B, (state(A), density(A, B))))
     What is the population density of Wyoming? | (A, (density(B, A), const(B, stateid(wyoming))))

Table 1: Example (x, y) pairs from COGS (i), M-PCFGSET (ii) and GeoQuery (iii). See Appendix B for more details on pre-processing as well as dataset statistics.

5. TRAINED TRANSFORMERS IMPLEMENT A TREE-LIKE COMPUTATION

How does the intrinsic compositionality of a transformer encoder evolve during the course of training on sequence transduction tasks? To study this, we plot $t_{\text{score}}$ (how tree-like is a model?) and $t_{\text{parseval}}$ (how accurate is the tree projection of a model?) of encoder checkpoints throughout training. As a comparison, we track how well a supervised probe recovers syntax from encoders; that is, we train a 1-layer transformer decoder to autoregressively predict linearized gold parse trees of $S$ from transformer outputs $f(S)$ at various points of training, and measure the PARSEVAL score of probe outputs ($p_{\text{parseval}}$) on a test set.

Results. We plot $t_{\text{parseval}}$ and $t_{\text{score}}$ over the course of training in Figure 3. We observe that 7/8 encoders gradually become more tree-like, i.e., increase $t_{\text{score}}$ over the course of training, with the 4-layer transformer on GeoQuery being the exception. Interestingly, $t_{\text{parseval}}$ also increases over time for all encoders on COGS and M-PCFGSET, suggesting that the tree projection of trained transformers progressively becomes more like ground-truth syntax. In other words, all encoders trained on COGS and M-PCFGSET learn a computation that is gradually more "syntax aware". Can supervised probing also reveal this gradual syntactic enrichment? We plot the PARSEVAL score of parse trees predicted by the probe on held-out sentences ($p_{\text{parseval}}$) in Figure 4: while $p_{\text{parseval}}$ does improve over time on both COGS and M-PCFGSET, we observe that all checkpoints after some threshold have similar probing accuracies.

Figure 3: We plot $t_{\text{score}}$ and $t_{\text{parseval}}$ by computing approximate tree projections at various checkpoints. 7/8 models become more tree-structured (increased $t_{\text{score}}$) and all models on COGS and M-PCFGSET learn tree projections that gradually align with ground-truth syntax (increased $t_{\text{parseval}}$). (a) Normalized tree scores for COGS, M-PCFGSET and GeoQuery (↑ is better). (b) Parsing accuracies for COGS, M-PCFGSET and GeoQuery (↑ is better).
We quantitatively compare gradual syntactic enrichment by computing the Spearman correlation between $t_{\text{parseval}}$ ($p_{\text{parseval}}$) and training step, and find that $\rho_{p_{\text{parseval}}}$ is significantly smaller than $\rho_{t_{\text{parseval}}}$ for both datasets. Interestingly, we also find that our unsupervised procedure produces better trees than the supervised probe on M-PCFGSET, as observed by comparing $p_{\text{parseval}}$ and $t_{\text{parseval}}$. Overall, we conclude that supervised probing is unable to discover latent tree structures as effectively as our method.

How does supervisory signal affect compositionality? Could a purely self-supervised objective (i.e., no output logical form supervision) also lead to similar emergent tree-like behavior? To test this, we train the transformer encoder on COGS and GeoQuery with a masked language modeling objective, similar to Devlin et al. (2019). Concretely, for every $S$, we mask out 15% of input tokens and jointly train a transformer encoder and a 1-layer feedforward network, such that the feedforward network can decode the identities of masked-out words from the encoder's contextual embeddings. As before, we collect checkpoints during training and plot both $t_{\text{parseval}}$ and $t_{\text{score}}$ over time in Figure 5. We find that $t_{\text{parseval}}$ does not improve over time for any of the models. Additionally, $t_{\text{score}}$ increases for all models on GeoQuery, but only for the 2-layer model on COGS. Taken together, these results suggest that, in the low-data regime studied here, transformers trained with a self-supervised objective do not learn tree-structured computations.

6. TREE PROJECTIONS AND MODEL BEHAVIOR

Given $S$ and corresponding contextual vectors $f(S)$, the contextual dependence structure captures the dependence between contextual vectors and words in $S$, i.e., how much $v^S_{w_i}$ changes when $w_j$ is perturbed to a different word. Contextual dependence structure is important for assessing compositional behavior. For instance, consider the span $p$ = "red apples" appearing in some sentences. If the contextual vectors for $p$ have a large dependence on outside context, we expect the model to generalize poorly when the span appears in novel contexts, i.e., to show poor compositional generalization. We first show that tree projections reflect the contextual dependence structure implemented by a transformer. Next, we show that both $t_{\text{score}}$ and $t_{\text{parseval}}$ are better predictors of compositional generalization than in-domain accuracy.

Figure 4: We plot $p_{\text{parseval}}$ and $t_{\text{parseval}}$ over time for the 4-layer transformer encoder on COGS and M-PCFGSET. We find that $t_{\text{parseval}}$ improves gradually over time, suggesting that the model becomes more "syntax aware". Such gradual syntax enrichment is not uncovered well by the probe, since all checkpoints after 4000 (for COGS) and 50000 (for M-PCFGSET) iterations have similar $p_{\text{parseval}}$.

Figure 5: We plot $t_{\text{parseval}}$ and $t_{\text{score}}$ at various checkpoints for models trained with a masked language modeling objective on COGS (first) and GeoQuery (second). Only 2/5 models become tree-structured and none learn tree projections aligned with gold syntax, suggesting that self-supervision may fail to produce tree-like computation in a relatively low-data regime. (a) Parsing accuracies. (b) Normalized tree scores.

6.1. INDUCED TREES CORRESPOND TO CONTEXTUAL DEPENDENCE STRUCTURE


Figure 6: For word $w$ ("apples") in constituent $c$, an in-constituent perturbation adds noise $\epsilon \sim \mathcal{N}(0, 0.01)$ to another word's vector within $c$ ("red"), while an out-of-constituent perturbation adds noise to a word vector at the same relative distance outside $c$ ("are").

Intuitively, greedily decoding with an SCI-populated chart makes split-point decisions where the resulting spans are maximally invariant with one another. Thus, for a given constituent $c$ and a word $w \in c$, we expect $v^S_w$ to depend more on words within the same constituent than on words outside the constituent. We therefore compare the change in $v^S_w$ when another word inside $c$ is perturbed (in-constituent perturbations) to the change when a word outside $c$ is perturbed (out-of-constituent perturbations), where word perturbations are performed by adding Gaussian noise to the corresponding word vectors in layer 0 (see Figure 6). We ensure that both perturbations are made to words at the same relative distance from $w$. As a control, we also compute changes to $v^S_w$ when perturbations are made with respect to constituents from random trees.

Setup and Results. We sample 500 random inputs from each of COGS, M-PCFGSET and GeoQuery and consider encoders from all transformer models. We obtain the mean $L_2$ distance between the contextual vector of $w$ in the original and perturbed sentence for in-constituent perturbations ($\Delta_{ic}$) and out-of-constituent perturbations ($\Delta_{oc}$), and plot the relative difference between the two in Figure 7. For 6/8 models, in-constituent perturbations result in larger $L_2$ changes than out-of-constituent perturbations (statistically significant according to a two-sided t-test, $p < 10^{-4}$). Meanwhile, when constituents are chosen according to random trees, changes resulting from both perturbations are similar. Overall, this suggests that induced trees reflect the contextual dependence structure learnt by a transformer.

Figure 7: We measure the mean $L_2$ distance in the contextual vector of words when in-constituent and out-of-constituent words are perturbed. We plot the relative difference between $\Delta_{ic}$ and $\Delta_{oc}$ when constituents are obtained from tree projections (in blue). As a control, we also compute $\Delta_{ic}$ and $\Delta_{oc}$ when constituents are chosen from random trees (in orange). For all models except those marked with ‡, in-constituent perturbations lead to a significantly (as measured by a t-test, $p < 10^{-5}$) larger change to contextual vectors than out-of-constituent perturbations.
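A single in-/out-of-constituent comparison can be sketched as follows, assuming the encoder is available as a callable from layer-0 vectors of shape (n, d) to contextual vectors of the same shape. The toy linear encoder below is a hypothetical stand-in that only mixes tokens within the constituents [red apples] and [are delicious], and the noise scale 0.01 is treated as a standard deviation, which is our reading of $\mathcal{N}(0, 0.01)$. The experiments average such deltas over 500 sentences.

```python
import numpy as np

def perturbation_deltas(encoder, emb, w, in_idx, out_idx, sigma=0.01, seed=0):
    """Return (delta_ic, delta_oc): L2 change in the contextual vector of the
    word at position `w` when the layer-0 vector at `in_idx` (inside w's
    constituent) or `out_idx` (outside it) receives Gaussian noise."""
    rng = np.random.default_rng(seed)
    base = encoder(emb)[w]

    def delta(j):
        perturbed = emb.copy()
        perturbed[j] += rng.normal(scale=sigma, size=emb.shape[1])
        return float(np.linalg.norm(encoder(perturbed)[w] - base))

    return delta(in_idx), delta(out_idx)

# Hypothetical toy encoder: tokens mix only within [red apples] and [are delicious].
mix = np.array([[0.5, 0.5, 0.0, 0.0],
                [0.5, 0.5, 0.0, 0.0],
                [0.0, 0.0, 0.5, 0.5],
                [0.0, 0.0, 0.5, 0.5]])

def toy_encoder(x):
    return mix @ x

emb = np.random.default_rng(0).normal(size=(4, 3))  # red, apples, are, delicious
# w = "apples" (1); "red" (0) and "are" (2) are both at distance 1 from w.
d_ic, d_oc = perturbation_deltas(toy_encoder, emb, w=1, in_idx=0, out_idx=2)
```

For this perfectly tree-structured toy encoder, perturbing "are" leaves the vector for "apples" untouched, so $\Delta_{oc}$ is zero while $\Delta_{ic}$ is positive.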

6.2. TREE-STRUCTUREDNESS CORRELATES BETTER WITH GENERALIZATION THAN IN-DOMAIN ACCURACY

We study the connection between compositionality and generalization for the 4-layer transformer encoder on COGS and GeoQuery. On each dataset, we train the model with 5 different random seeds and collect checkpoints every 1000/500 iterations, respectively. For each checkpoint, we measure accuracy on the in-domain validation set (IID acc) and accuracy on the out-of-domain compositional generalization set (CG acc). Additionally, we compute $t_{\text{parseval}}$ and $t_{\text{score}}$ for the encoders at each of these checkpoints. To measure the relationship between compositionality and generalization, we compute the Spearman correlation between $t_{\text{parseval}}$ ($t_{\text{score}}$) and CG acc, and denote it as $\rho^{CG}_{t_{\text{parseval}}}$ ($\rho^{CG}_{t_{\text{score}}}$). As a comparison, we also compute the correlation between IID acc and CG acc ($\rho^{CG}_{IID}$).

Results. We plot the relationship between various properties and generalization, along with the corresponding correlations, in Figure 8. In general, we expect both IID acc and CG acc to improve together over time, so it is unsurprising that $\rho^{CG}_{IID} > 0$. Moreover, for COGS, both $t_{\text{parseval}}$ and $t_{\text{score}}$ increase over time, so it is expected that both $\rho^{CG}_{t_{\text{parseval}}}$ and $\rho^{CG}_{t_{\text{score}}}$ are positive. Crucially, however, we find that both $\rho^{CG}_{t_{\text{parseval}}}$ and $\rho^{CG}_{t_{\text{score}}}$ are greater than $\rho^{CG}_{IID}$ on both COGS and GeoQuery. Thus, tree-like behavior ($t_{\text{score}}$) as well as the right tree-like behavior ($t_{\text{parseval}}$) are better predictors of compositional generalization than in-domain accuracy. This result gives a simple model selection criterion to maximize CG accuracy in the absence of a compositional generalization test set (true for most practical scenarios): given a collection of checkpoints with similar in-domain accuracies, choose the checkpoint with the highest $t_{\text{score}}$ or $t_{\text{parseval}}$ (if syntactic annotations are available) to get the model with the best generalization behavior, in expectation.
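The correlation analysis and the selection rule can be sketched without external dependencies (`scipy.stats.spearmanr` would serve equally well for the correlation). Spearman's $\rho$ is the Pearson correlation of rank vectors, with ties given their average rank. The checkpoint dicts and the `tol` threshold for "similar in-domain accuracy" are hypothetical choices made for illustration, not values from the paper.

```python
def ranks(values):
    # 1-based ranks; tied values share the mean of their positions
    order = sorted(range(len(values)), key=lambda i: values[i])
    r = [0.0] * len(values)
    pos = 0
    while pos < len(order):
        end = pos
        while end + 1 < len(order) and values[order[end + 1]] == values[order[pos]]:
            end += 1
        avg = (pos + end) / 2 + 1
        for t in range(pos, end + 1):
            r[order[t]] = avg
        pos = end + 1
    return r

def spearman(x, y):
    # Spearman's rho = Pearson correlation of the rank vectors
    rx, ry = ranks(x), ranks(y)
    n = len(x)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)

def select_checkpoint(checkpoints, tol=0.01):
    """Among checkpoints within `tol` of the best in-domain accuracy, pick the
    one with the highest t_score (`tol` is a hypothetical knob)."""
    best_iid = max(c["iid_acc"] for c in checkpoints)
    pool = [c for c in checkpoints if c["iid_acc"] >= best_iid - tol]
    return max(pool, key=lambda c: c["t_score"])
```

For example, with checkpoints at IID accuracies 0.95 and 0.945 (tree scores 0.1 and 0.3) and a third at 0.80, the rule keeps the first two as "similar" and returns the more tree-like one.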

7. RELATED WORK

Measuring Linguistic Structure. A common analysis tool for assessing a model's competence in a specific linguistic phenomenon is behavioral testing (Linzen et al., 2016; Marvin & Linzen, 2018; Ribeiro et al., 2020), where the model's performance on a curated test set is used as the measure of competence. Widely used in prior work to assess the compositionality of neural models (Lake & Baroni, 2018; Bahdanau et al., 2019; Yu & Ettinger, 2020), behavioral tests are inherently extrinsic, since they are agnostic to whether the model implements an appropriately constrained, tree-like computation. While most prior approaches for assessing intrinsic compositionality (Andreas, 2019; McCoy et al., 2019) require putatively gold syntax trees, our proposed approach does not require any pre-determined ground-truth syntax, since we search over the space of all possible trees to find the best tree structure that approximates a transformer's computation.

Tree-structured Neural Networks. Inspired by the widely accepted belief that natural language is mostly tree-structured (Chomsky, 1957), there have been several attempts to construct tree-shaped neural networks for various NLP tasks, such as Recursive Neural Networks (Socher et al., 2013), Tree RNNs (Tai et al., 2015), Recurrent Neural Network Grammars (Dyer et al., 2016), Neural Module Networks (Andreas et al., 2016) and Ordered Neurons (Shen et al., 2019), among others. These approaches have largely been superseded by transformers (Vaswani et al., 2017), often pre-trained on a large corpus of text (Devlin et al., 2019, inter alia). We show that transformers, though not explicitly tree-structured, may still learn to become tree-like when trained on language data.

Figure 8: … $t_{\text{parseval}}$ and CG acc. We find that both $t_{\text{parseval}}$ and $t_{\text{score}}$ correlate better with generalization than in-domain accuracy. All correlations are statistically significant (p-values $< 10^{-3}$).

Invariances and Generalization. The general problem of studying model performance under domain shifts has been widely studied under domain generalization (Blanchard et al., 2011). When domain shift is a result of changing feature covariates only, an effective strategy for domain generalization is to learn domain-invariant representations (Muandet et al., 2013; Ganin et al., 2016). We apply the notion of domain invariance in the context of compositional generalization, and posit that models that produce span representations that are more contextually invariant can generalize better to inputs where the span appears in a novel context, which is precisely the motivation behind SCI.

8. CONCLUSION

When trained on language data, how can we know whether a transformer learns a compositional, tree-structured computation hypothesized to underlie human language processing? While extrinsic behavioral tests only assess if the model is capable of the same generalization as that expected from tree-structured models, this work proposes an intrinsic approach that directly estimates how well a parametric tree-structured computation approximates the model's computation. Our method is unsupervised and parameter-free, and it provably upper-bounds the error of the best approximation of a transformer's representation-building process by a tree-structured neural network, effectively providing a functional projection of the transformer into the space of all tree-structured models. The central conceptual notion in our method is span contextual invariance (SCI), which measures how much the contextual representation of a span depends on the context of the span vs. the content of the span. SCI scores of all spans are plugged into a standard top-down greedy parsing algorithm to induce a binary tree along with a corresponding tree score. Experimentally, we show that tree projections uncover interesting training dynamics that a supervised probe is unable to discover: on 3 sequence transduction tasks, transformer encoders tend to become more tree-like over the course of training, with tree projections that become progressively closer to true syntactic derivations on 2/3 datasets. We also find that tree-structuredness as well as parsing F1 of tree projections are better predictors of generalization to a compositionally challenging test set than in-domain accuracy, i.e., given a collection of models with similar in-domain accuracies, select the model that is most tree-like for best compositional generalization.
Overall, our results suggest that making further progress on human-like compositional generalization might require inductive biases that encourage the emergence of latent tree-like structure. A PROOFS Lemma 1. L(f, g ϕ * , T ) ≤ S∈D SCI(S, T (S)) Proof. Let l(f, g ϕ , S, T ) ≜ s∈T (S) d(g ϕ (s, T (s)), v S s ) for any S ∈ D, where g is a tree-structured network indexed by ϕ ∈ R p . The overall error of g ϕ on D is L(f, g ϕ , T ) = S∈D l(f, g ϕ , S, T ). Let ϕ * ≜ arg min ϕ L(f, g ϕ , T ). Next, consider φ ∈ R p such that g φ(s, T (s)) = ṽs for all s ∈ D. Such a φ always exists for large enough p, since there exists a unique ṽs for any p given D and f . Clearly, l(f, g φ, S, T ) = s∈T (S) d(v S s , ṽs ). By definition, we have L(f, g ϕ * , T ) ≤ L(f, g φ, T ) (8) = S∈D s∈T (S) d(v S s , ṽs ) = S∈D SCI(S, T (S)). Theorem 1. min ϕ,T L(f, g ϕ , T ) ≤ S∈D min T (S) SCI(S, T (S)). In other words, the best tree structured approximation to f has an error upper bounded by cumulative SCI scores. Proof. We have min ϕ,T L(f, g ϕ , T ) = min T min ϕ L(f, g ϕ , T ) For any given T , we have min ϕ L(f, g ϕ , T ) ≤ S∈D SCI(S, T (S)). Thus minimizing both sides with respect to T , we have Proof. Let s T be the collection of all spans that occur as a constituent for some T (S) where S ∈ D. We have L(f, g ϕ , T ) = S∈D s∈T (S) d(g ϕ (s, T (s)), v S s ) (13) = s∈s T S∈Ss d(g ϕ (s, T (s)), v S s ). Now, using Assumption 1, we note that S∈Ss d(g ϕ (s, T (s)), v S s ) ≥ min v S∈Ss d(v, v S s ) = S∈Ss d(ṽ s , v S s ). Combining Equation 15and Lemma 1, we have 2 . COGS. We use the standard train, validation and test splits provided by Kim & Linzen (2020) , where we use the "gen" split as our test set. The validation set is drawn from the same distribution as the training data, while the test set consists of compositionally challenging input sentences. Figure 9 : We plot d(v * s , ṽs ) for randomly sampled spans at various points during training. 
As a control, we also plot $d(v^S_{s_c}, \tilde{v}_s)$ for a random span $s_c$. We observe that for COGS and GeoQuery, the distance between the optimal $v^*_s$ and $\tilde{v}_s$ eventually becomes less than 0.05. We conclude that the conditions of Assumption 1 approximately hold for 2 of the 3 datasets.

GeoQuery. We use the pre-processed JSON files corresponding to the query split from Finegan-Dollak et al. (2018). We create an 80/20 split of the original training data to obtain an IID validation set.
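An 80/20 IID split like the one used for GeoQuery can be drawn with a seeded shuffle. This is a minimal sketch; the seed, the rounding of the cut point, and the helper name `iid_split` are assumptions, since the text does not say how the split was drawn.

```python
import random


def iid_split(examples: list, train_frac: float = 0.8, seed: int = 0) -> tuple[list, list]:
    """Shuffle a copy of `examples` and cut it into train and validation parts.

    Both parts come from one shuffled pool, so the validation set follows the
    same distribution as the training set (an IID split).
    """
    pool = list(examples)
    random.Random(seed).shuffle(pool)  # deterministic for a fixed seed
    cut = int(round(train_frac * len(pool)))
    return pool[:cut], pool[cut:]
```

Fixing the seed keeps the validation set reproducible across runs, which matters when checkpoints are selected on it.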

C FUNCTIONAL VS. TOPOLOGICAL TREE-STRUCTUREDNESS

We emphasize that our approach finds a functional tree approximation to a transformer, not a topological one. That is, we fit a separate, tree-structured neural network to vector representations from a transformer, instead of decoding a tree structure from the attention patterns. As a result, our definition of tree-structuredness does not require the transformer's attention pattern to be tree-structured (see Figure 10 for examples).
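What "fitting a separate, tree-structured neural network" means functionally can be sketched with a toy network $g_\phi$ that builds a vector for every constituent bottom-up from its two children only, and whose error against the transformer's span vectors is the quantity $\mathcal{L}$ analyzed in Appendix A. The leaf embeddings, the single-layer tanh composition, the Euclidean distance, and the names `TreeComposer` and `approximation_error` are assumptions chosen for brevity; the parameters are random here rather than fit, and the bound in the text is stated for any tree-structured network.

```python
import numpy as np


class TreeComposer:
    """Toy tree-structured network g_phi: leaves are word embeddings and each
    internal node is tanh(W [left; right] + b). Parameters are random here;
    in practice they would be fit to the transformer's span vectors."""

    def __init__(self, vocab_size: int, dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.embed = rng.normal(scale=0.1, size=(vocab_size, dim))
        self.W = rng.normal(scale=0.1, size=(dim, 2 * dim))
        self.b = np.zeros(dim)

    def encode(self, word_ids, tree) -> dict:
        """Return {(i, j): vector} for every constituent of `tree` (nested pairs of positions)."""
        vectors = {}

        def build(node):
            if isinstance(node, int):
                vectors[(node, node)] = self.embed[word_ids[node]]
                return node, node
            (li, lj), (ri, rj) = build(node[0]), build(node[1])
            children = np.concatenate([vectors[(li, lj)], vectors[(ri, rj)]])
            # A node sees only its children, so its vector ignores the rest of the sentence.
            vectors[(li, rj)] = np.tanh(self.W @ children + self.b)
            return li, rj

        build(tree)
        return vectors


def approximation_error(g_vectors: dict, transformer_vectors: dict) -> float:
    """Sum over constituents s of d(g_phi(s, T(s)), v^S_s), with d = Euclidean distance."""
    return sum(float(np.linalg.norm(g_vectors[s] - transformer_vectors[s])) for s in g_vectors)
```

Because each node's vector depends only on its subtree, the same span always receives the same vector regardless of context, which is exactly the restriction the transformer is being compared against.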

D ANALYZING INDUCED TREE STRUCTURES

For all our datasets, we choose the checkpoint with the best bracketing F1 score on the training split and report the corresponding bracketing F1 scores on the IID validation set in Table 3. We compare against standard constituency parsing baselines: LBranch (choosing a completely left-branching tree), RBranch (choosing a completely right-branching tree), and Random (choosing a random binary tree).

[Figure 10, panel (a): "SCI score: How well can a tree-shaped computation be used to approximate a graph?"]

Figure 10: We show 3 instances of computations implemented by a transformer on the input "red apples are delicious", along with the tree projections our method outputs for each instance. We divide the space of possibilities into 4 quadrants. In quadrant (i), we show an instance that is both topologically and functionally tree-like. Quadrant (ii) is empty, since no transformer can be topologically tree-like but not a good functional approximation to a tree. In quadrant (iii), we show a transformer that is neither topologically nor functionally tree-like. Finally, in quadrant (iv), we show a transformer that is functionally tree-like but does not resemble a tree structure topologically.

Table 3: Parsing accuracies.
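The comparison against LBranch, RBranch, and Random rests on unlabeled bracketing F1 between span sets. The following is a minimal sketch under our own conventions: trees are nested pairs of word positions, and we score every span of two or more words, including the whole-sentence span (whether the root or single words are counted varies across papers and is not stated in the text). The names `lbranch`, `rbranch`, `random_tree`, `spans`, and `bracketing_f1` are hypothetical; note that choosing split points uniformly does not sample uniformly over binary trees.

```python
import random


def lbranch(n: int):
    """Completely left-branching tree, e.g. (((0, 1), 2), 3) for n = 4."""
    tree = 0
    for i in range(1, n):
        tree = (tree, i)
    return tree


def rbranch(n: int):
    """Completely right-branching tree, e.g. (0, (1, (2, 3))) for n = 4."""
    tree = n - 1
    for i in range(n - 2, -1, -1):
        tree = (i, tree)
    return tree


def random_tree(i: int, j: int, rng: random.Random):
    """Random binary tree over words i..j, picking a uniform split point at each node."""
    if i == j:
        return i
    k = rng.randint(i, j - 1)  # split into [i, k] and [k + 1, j]
    return (random_tree(i, k, rng), random_tree(k + 1, j, rng))


def spans(tree) -> set:
    """All multi-word spans (i, j) of a tree given as nested pairs."""
    out = set()

    def walk(node):
        if isinstance(node, int):
            return node, node
        (i, _), (_, j) = walk(node[0]), walk(node[1])
        out.add((i, j))
        return i, j

    walk(tree)
    return out


def bracketing_f1(pred, gold) -> float:
    """Unlabeled bracketing F1 between the span sets of two trees."""
    p, g = spans(pred), spans(gold)
    overlap = len(p & g)
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)
```

For example, a 4-word left-branching tree and a 4-word right-branching tree share only the whole-sentence span, so their F1 under this convention is 1/3.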
Interestingly, we find that the trees discovered by our approach on COGS beat RBranch, which is a competitive constituency parsing baseline for English.

Algorithm 1: Tree Projections via greedy SCI minimization



We provide a functional account of the transformer's computation and not a topological one, i.e., we are agnostic to whether the attention patterns of the transformer themselves look tree-structured; see Appendix C for examples.

This procedure outputs vectors that are entirely context-free only if the threshold is exactly 0, but we find that tuning the threshold layer often leads to significantly better induced parses.

Figure 9 in the Appendix shows that this assumption approximately holds in practice.

Code and data will be available here.

See Appendix B for details.

IID accuracy perfectly predicts generalization for M-PCFGSET, so we omit it in these experiments.



Figure 2: We use a T-shaped attention mask with a threshold layer to obtain approximate context-free vectors for transformers.
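One way to read the T-shaped mask and the threshold layer is sketched below, under our own assumptions: at layers whose index is at least the threshold, tokens inside the span may attend only to tokens inside the span, while all other attention stays unrestricted. Under this reading, a threshold of 0 masks every layer and yields entirely context-free span vectors, consistent with the footnote on the threshold. The function name `span_masks` and the boolean convention (True means "may attend") are hypothetical, and wiring the masks into an actual transformer is not shown.

```python
import numpy as np


def span_masks(seq_len: int, span: tuple[int, int], num_layers: int, threshold: int) -> np.ndarray:
    """Boolean masks of shape (num_layers, seq_len, seq_len); entry [l, q, k] says
    whether query position q may attend to key position k at layer l.

    ASSUMPTION: layers l >= threshold confine span tokens to the span; layers
    below the threshold use ordinary full attention.
    """
    i, j = span
    in_span = np.zeros(seq_len, dtype=bool)
    in_span[i : j + 1] = True
    restricted = np.ones((seq_len, seq_len), dtype=bool)
    restricted[in_span] = in_span  # rows of span queries: only span keys allowed
    masks = np.ones((num_layers, seq_len, seq_len), dtype=bool)
    masks[threshold:] = restricted  # broadcast the restricted mask to upper layers
    return masks
```

Raising the threshold lets span tokens absorb context in the lower layers, trading exact context-freeness for representations closer to those the model sees in training.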

Figure 4: We plot $p_{\text{parseval}}$ and $t_{\text{parseval}}$ over time for the 4-layer transformer encoder on COGS and M-PCFGSET. We find that $t_{\text{parseval}}$ improves gradually over time, suggesting that the model becomes more "syntax aware". The probe does not capture this gradual syntax enrichment well, since all checkpoints after 4000 (for COGS) and 50000 (for M-PCFGSET) iterations have similar $p_{\text{parseval}}$.

Figure 8: We plot the Spearman correlation between (a) IID acc and CG acc, (b) $t_{\text{score}}$ and CG acc, and (c) $t_{\text{parseval}}$ and CG acc. We find that both $t_{\text{parseval}}$ and $t_{\text{score}}$ correlate better with generalization than in-domain accuracy does. All correlations are statistically significant (p-values $< 10^{-3}$).

neural networks for various NLP tasks, such as Recursive Neural Networks (Socher et al., 2013), Tree RNNs (Tai et al., 2015), Recurrent Neural Network Grammars (Dyer et al., 2016), Neural Module Networks (Andreas et al., 2016), and Ordered Neurons (Shen et al., 2019), among others. These approaches have largely been superseded by transformers (Vaswani et al., 2017), often pre-trained on a large corpus of text (Devlin et al., 2019, inter alia). We show that transformers, though not explicitly tree-structured, may still learn to become tree-like when trained on language data.
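The Spearman correlations in Figure 8 compare how models rank under a score (IID acc, $t_{\text{score}}$, or $t_{\text{parseval}}$) with how they rank under CG acc. A minimal, dependency-free sketch of Spearman's $\rho$ follows: tied values share their average rank, and $\rho$ is the Pearson correlation of the two rank vectors. The helper names are hypothetical and p-values are not computed here; `scipy.stats.spearmanr` returns both the coefficient and a p-value.

```python
def average_ranks(values: list[float]) -> list[float]:
    """1-based ranks, with tied values sharing the mean of their positions."""
    order = sorted(range(len(values)), key=lambda idx: values[idx])
    ranks = [0.0] * len(values)
    pos = 0
    while pos < len(order):
        end = pos
        while end + 1 < len(order) and values[order[end + 1]] == values[order[pos]]:
            end += 1  # extend the run of tied values
        mean_rank = (pos + end) / 2 + 1
        for k in range(pos, end + 1):
            ranks[order[k]] = mean_rank
        pos = end + 1
    return ranks


def spearman(x: list[float], y: list[float]) -> float:
    """Spearman's rho: the Pearson correlation of the rank vectors of x and y."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(x)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)
```

Rank correlation suits this comparison because it asks only whether a more tree-like model tends to generalize better, not whether the relationship is linear.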

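$$\min_{\phi, T} \mathcal{L}(f, g_\phi, T) \le \min_T \sum_{S \in \mathcal{D}} \mathrm{SCI}(S, T(S)) \tag{11}$$
$$= \sum_{S \in \mathcal{D}} \min_{T(S)} \mathrm{SCI}(S, T(S)), \tag{12}$$

where the last equality holds because the tree for each sentence can be chosen independently.

Using Assumption 1 together with Theorem 1, we obtain Corollary 1.1, whose proof we present next.

Corollary 1.1. Under Assumption 1, $\min_{\phi, T} \mathcal{L}(f, g_\phi, T) = \sum_{S \in \mathcal{D}} \min_{T(S)} \mathrm{SCI}(S, T(S))$. Moreover, $T_{\text{proj}}(S) = \arg\min_{T(S)} \mathrm{SCI}(S, T(S))$ for any $S \in \mathcal{D}$.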
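Because $\mathrm{SCI}(S, T(S))$ is a sum of per-constituent distances, the minimizer $T_{\text{proj}}(S) = \arg\min_{T(S)} \mathrm{SCI}(S, T(S))$ in Corollary 1.1 can also be found exactly with a CKY-style dynamic program over spans. This is a swapped-in technique for illustration, not the greedy top-down procedure of Algorithm 1. The callable `span_cost(i, j)`, assumed to return $d(v^S_s, \tilde{v}_s)$ for the span of words i..j, and the name `exact_projection` are hypothetical.

```python
from functools import lru_cache
from typing import Callable


def exact_projection(n: int, span_cost: Callable[[int, int], float]):
    """Return (tree, score): the binary tree over words 0..n-1 (nested pairs) that
    minimizes the sum of span_cost over all constituents, single words and the
    full sentence included."""

    @lru_cache(maxsize=None)  # memoize best(i, j) so each span is solved once
    def best(i: int, j: int):
        if i == j:
            return i, span_cost(i, i)
        candidates = []
        for k in range(i, j):  # split into [i, k] and [k + 1, j]
            left_tree, left_score = best(i, k)
            right_tree, right_score = best(k + 1, j)
            candidates.append((left_score + right_score, (left_tree, right_tree)))
        score, tree = min(candidates, key=lambda c: c[0])
        return tree, score + span_cost(i, j)

    return best(0, n - 1)
```

The program runs in $O(n^3)$ cost evaluations versus $O(n^2)$ for the greedy recursion, so it mainly serves as a check on how far greedy projections are from the true minimizer on short inputs.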

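$$T_{\text{proj}}(S) = \arg\min_{T(S)} \mathrm{SCI}(S, T(S)).$$

Next, we consider specific examples of the distance metric $d$ and what Assumption 1 implies for the context-free vectors $\tilde{v}_s$.

Example A.1. Suppose $d$ is the squared Euclidean distance, i.e., $d(x, y) = \|x - y\|_2^2$. Then, Assumption 1 requires that $\tilde{v}_s = \frac{1}{|\mathcal{S}_s|} \sum_{S \in \mathcal{S}_s} v^S_s$.

Proof Sketch. We have $v^*_s = \arg\min_v \sum_{S \in \mathcal{S}_s} d(v^S_s, v) = \arg\min_v \sum_{S \in \mathcal{S}_s} \|v - v^S_s\|_2^2$. Setting the derivative with respect to $v$ to 0, i.e., $2 \sum_{S \in \mathcal{S}_s} (v - v^S_s) = 0$, we have $v^*_s = \frac{1}{|\mathcal{S}_s|} \sum_{S \in \mathcal{S}_s} v^S_s$.

Example A.2. Let $d$ be the cosine distance of $x$ and $y$, i.e., $d(x, y) = 1 - \frac{x^\top y}{\|x\| \|y\|}$. Then, Assumption 1 requires that $\tilde{v}_s = 1$…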
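A quick numerical check of Example A.1: for the squared Euclidean distance, the mean of a span's contextual vectors minimizes $\sum_{S \in \mathcal{S}_s} d(v, v^S_s)$, because moving away from the mean by $\delta$ raises the objective by exactly $|\mathcal{S}_s| \, \|\delta\|_2^2$. The random vectors below merely stand in for the contextual vectors $v^S_s$ and are not data from the paper; the helper names are hypothetical.

```python
import numpy as np


def sum_sq_dist(v: np.ndarray, contextual: np.ndarray) -> float:
    """Objective sum_S ||v - v^S_s||^2, with one contextual vector per row."""
    return float(((contextual - v) ** 2).sum())


def optimal_context_free(contextual: np.ndarray) -> np.ndarray:
    """v*_s for the squared Euclidean distance: the mean of the contextual vectors."""
    return contextual.mean(axis=0)


rng = np.random.default_rng(0)
contextual = rng.normal(size=(8, 5))  # stand-in: 8 sentences containing span s, 5-dim vectors
v_star = optimal_context_free(contextual)
perturbed = [v_star + 0.1 * rng.normal(size=5) for _ in range(100)]
# No perturbation of the mean can lower the objective.
assert all(sum_sq_dist(v_star, contextual) <= sum_sq_dist(p, contextual) for p in perturbed)
```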

…
    ▷ Greedily select the split point that minimizes the SCI of the resulting constituents
    k* ← arg min_{k∈[i,j)} [SCI(S_{i,k}) + SCI(S_{k+1,j})]
    s_{k*} ← SCI(S_{i,k*}) + SCI(S_{k*+1,j})
    ▷ Select a random split point for normalization
    s_b ← SCI(S_{i,k_b}) + SCI(S_{k_b+1,j}), where k_b ∼ U[i, j−1]
    ▷ Recursively call the function to get a tree structure and score for the left span
    S_l, ts_l ← TREEPROJECTIONRECURSE(S, f, i, k*)
    ▷ Recursively call the function to get a tree structure and score for the right span
    S_r, ts_r ← TREEPROJECTIONRECURSE(S, f, k*+1, j)
    return ⟨S_l, S_r⟩, s_b − s_{k*} + ts_l + ts_r
  end if
end function
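The recursion of Algorithm 1 can be sketched in runnable form as follows. This is a minimal sketch under stated assumptions: instead of the sentence S and model f, it takes a hypothetical callable `sci(i, j)` returning the SCI score of the span of words i..j, and its base case (a single word is a leaf with tree score 0) is our assumption, because the opening lines of Algorithm 1 are not reproduced in this section. At every internal node the tree score accumulates how much lower the greedy split's SCI is than that of a random split.

```python
import random
from typing import Callable


def tree_projection_recurse(sci: Callable[[int, int], float], i: int, j: int, rng: random.Random):
    """Greedy top-down tree projection over words i..j (inclusive).

    Returns (tree, tree_score), where tree is nested pairs of word positions.
    """
    if i == j:  # ASSUMED base case: a single word is a leaf with score 0
        return i, 0.0
    # Greedily select the split point that minimizes the SCI of the two constituents.
    k_star = min(range(i, j), key=lambda k: sci(i, k) + sci(k + 1, j))
    s_k_star = sci(i, k_star) + sci(k_star + 1, j)
    # Select a random split point k_b ~ U[i, j - 1] for normalization.
    k_b = rng.randint(i, j - 1)
    s_b = sci(i, k_b) + sci(k_b + 1, j)
    # Recurse on the left and right spans.
    left, ts_left = tree_projection_recurse(sci, i, k_star, rng)
    right, ts_right = tree_projection_recurse(sci, k_star + 1, j, rng)
    return (left, right), s_b - s_k_star + ts_left + ts_right


def tree_projection(sci: Callable[[int, int], float], n: int, seed: int = 0):
    """Project a sentence of n words; the seed fixes the random normalization splits."""
    return tree_projection_recurse(sci, 0, n - 1, random.Random(seed))
```

Since the greedy split is never worse than the random one, every term $s_b - s_{k^*}$ is non-negative, so a larger tree score means the chosen splits are more clearly preferred over arbitrary ones.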

Dataset Statistics

M-PCFGSET. We make two modifications to the PCFGSET dataset. First, we remove commas from expressions so that the model is forced to implicitly learn to correctly partition the input expression for a correct interpretation. Second, to ensure that a unique parse exists even without commas, we ensure that all lists have exactly 2 elements. For instance, the expression append A B…
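Why a fixed list length restores a unique parse can be seen with a small prefix-notation parser: once every function has a fixed arity and every list has exactly 2 elements, a comma-free token sequence can be read left to right in only one way. The function inventory and arities below, the example expression, and the helpers `parse` and `strip_commas` are assumptions for illustration, not the dataset's exact grammar.

```python
# ASSUMED function inventory for illustration: name -> number of arguments.
ARITY = {"copy": 1, "reverse": 1, "append": 2, "prepend": 2}
LIST_LEN = 2  # every list has exactly 2 elements, as in M-PCFGSET


def parse(tokens: list[str]):
    """Parse a comma-free prefix expression into nested tuples; raise ValueError if malformed."""

    def expr(pos: int):
        if pos >= len(tokens):
            raise ValueError("unexpected end of input")
        head = tokens[pos]
        if head in ARITY:
            args, pos = [], pos + 1
            for _ in range(ARITY[head]):
                arg, pos = expr(pos)
                args.append(arg)
            return (head, *args), pos
        # Otherwise read a list of exactly LIST_LEN symbols.
        items = tokens[pos : pos + LIST_LEN]
        if len(items) < LIST_LEN or any(t in ARITY for t in items):
            raise ValueError(f"malformed list at position {pos}")
        return tuple(items), pos + LIST_LEN

    tree, end = expr(0)
    if end != len(tokens):
        raise ValueError("trailing tokens")
    return tree


def strip_commas(expression: str) -> list[str]:
    """Tokenize on whitespace and drop comma tokens, mimicking the M-PCFGSET preprocessing."""
    return [t for t in expression.split() if t != ","]
```

With variable-length lists, a sequence such as `append A B C` could split its arguments as `A | B C` or `A B | C`; fixing the length at 2 removes that choice.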

9. ACKNOWLEDGEMENTS

SM was funded by a gift from Apple Inc. JA is supported by the MIT Quest for Intelligence through a grant from Liberty Mutual Insurance. CM is a fellow in the CIFAR Learning in Machines and Brains program. We thank Ekin Akyürek, Marco Tulio Ribeiro, John Hewitt, Alexis Ross and members of the Stanford NLP group for feedback on early drafts of the paper.

