BEAM TREE RECURSIVE CELLS

Abstract

Recursive Neural Networks (RvNNs) generalize Recurrent Neural Networks (RNNs) by allowing sequential composition in a more flexible order, typically based on some tree structure. While user-annotated tree structures were used initially, in due time, several approaches were proposed to automatically induce tree structures from raw text to guide the recursive compositions in RvNNs. In this paper, we present an approach called Beam Tree Recursive Cell (or BT-Cell) based on a simple yet overlooked backpropagation-friendly framework. BT-Cell adapts beam search to easy-first parsing for simulating RvNNs with automatic structure induction. Our results show that BT-Cell achieves near-perfect performance on several aspects of challenging structure-sensitive synthetic tasks like ListOps, and also performance comparable to other RvNN-based models on realistic data. We further introduce and analyze several extensions of BT-Cell based on relaxations of the hard top-k operators in beam search. We evaluate the models on different out-of-distribution splits of both synthetic and realistic data. Additionally, we identify a previously unknown failure case for neural models in generalization to an unseen number of arguments in ListOps. Code is in the supplementary material.

1. INTRODUCTION

In the space of sequence encoders, Recursive Neural Networks (RvNNs) can be said to lie somewhere in between Recurrent Neural Networks (RNNs) and Transformers in terms of flexibility. While vanilla Transformers show phenomenal performance and efficient scalability on a variety of tasks, they can often struggle with length generalization and systematicity in syntax-sensitive tasks (Tran et al., 2018; Shen et al., 2019a; Lakretz et al., 2021; Csordás et al., 2022). RvNN-based models, on the other hand, can often excel on some of the latter kind of tasks (Shen et al., 2019a; Chowdhury & Caragea, 2021; Liu et al., 2021; Bogin et al., 2021), making them worthy of further study, although they may suffer from limited scalability in their current formulations. Given an input text, RvNNs (Pollack, 1990; Socher et al., 2010) are designed to build up the representation of the whole text by recursively building up the representations of its constituents, starting from the most elementary representations (tokens) in a bottom-up fashion. As such, RvNNs can model the hierarchical part-whole structures underlying texts. However, RvNNs originally required access to pre-defined hierarchical constituency-tree structures. Several works (Socher et al., 2011; Havrylov et al., 2019; Choi et al., 2018; Maillard et al., 2019; Chowdhury & Caragea, 2021) introduced latent-tree RvNNs that sought to move beyond this limitation by making RvNNs able to learn to automatically determine the structure of composition from any arbitrary downstream task objective, given just the raw input text. Among these approaches, Gumbel-Tree models (Choi et al., 2018) are particularly attractive for their simplicity, and they often serve as a standard baseline for latent-tree models.
However, Gumbel-Tree models not only suffer from biased gradients (due to the use of Straight-Through Estimation (STE)), but they also perform poorly on synthetic tasks like ListOps (Nangia & Bowman, 2018) that were specifically designed to diagnose the capacity of neural models for automatically inducing underlying hierarchical structures. In this paper, we tackle these issues by introducing the Beam Tree Cell (BT-Cell) framework that applies beam search as a simple modification over Gumbel-Tree models. Instead of greedily selecting the highest-scoring sub-tree representations like Gumbel-Tree models, BT-Cell chooses and maintains the top-k highest-scoring sub-tree representations. We show that this simple modification increases the performance of Gumbel-Tree models in challenging structure-sensitive tasks severalfold. For example, in ListOps, when testing on samples of length 900-1000, a BT-Cell based model increases the performance of a comparable Gumbel-Tree model from 37.9% to 86.7% (see Table 1). We further explore several variants of BT-Cell. Particularly, we explore ways to replace the non-differentiable top-k operators involved in beam search with different alternatives such as top-k Gumbel-Softmax with STE and a novel strategy of maintaining a convex combination of bottom-scoring paths. Our best extension achieves a new state-of-the-art in length generalization and depth generalization on structure-sensitive synthetic tasks like ListOps and performs comparably on realistic data against other latent-tree models. A few recently proposed latent-tree models simulating RvNNs, like LSTM-RL (Havrylov et al., 2019), Ordered Memory (OM) (Shen et al., 2019a), or CRvNN (Chowdhury & Caragea, 2021), are also strong contenders to BT-Cell and its extensions on synthetic data. However, unlike BT-Cell, LSTM-RL relies on expensive reinforcement learning and several sophisticated techniques to stabilize training.
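The mechanism can be sketched roughly as follows. Note that this is an illustrative sketch, not the paper's actual batched implementation: `cell` and `score` here are hypothetical stand-ins for the learned composition function and the learned scorer, and we operate on plain Python lists of sub-tree representations rather than tensors.

```python
def beam_tree_step(beams, cell, score, k):
    # One easy-first step: for every beam (a list of sub-tree
    # representations plus an accumulated score), try merging every
    # adjacent pair, then keep the k highest-scoring results overall.
    candidates = []
    for seq, acc in beams:
        for i in range(len(seq) - 1):
            merged = cell(seq[i], seq[i + 1])            # compose adjacent sub-trees
            new_seq = seq[:i] + [merged] + seq[i + 2:]
            candidates.append((new_seq, acc + score(merged)))
    # Gumbel-Tree keeps only the argmax candidate; BT-Cell keeps the top-k.
    candidates.sort(key=lambda b: b[1], reverse=True)
    return candidates[:k]

def beam_tree_encode(tokens, cell, score, k):
    # Reduce the sequence to a single root representation while
    # maintaining k candidate partial parses throughout.
    beams = [(list(tokens), 0.0)]
    while len(beams[0][0]) > 1:
        beams = beam_tree_step(beams, cell, score, k)
    return beams[0][0][0]    # root representation of the highest-scoring beam
```

With k = 1 this reduces to the greedy, Gumbel-Tree-style selection; larger k lets gradients flow through several candidate tree structures instead of just one.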
Moreover, compared to OM and CRvNN, one distinct advantage of BT-Cell is that it provides not just the final sequence encoding (representing the whole input text) but also the intermediate constituent representations at different levels of hierarchy (representations of all nodes of the underlying induced trees). Such tree-structured node representations can be useful as inputs to further downstream modules like a Transformer (Vaswani et al., 2017) or GNN (Scarselli et al., 2009) in a full end-to-end setting. While CYK-based RvNNs (Maillard et al., 2019) are also promising and can similarly provide multiple span representations, they tend to be much more expensive than BT-Cell. All these architectural trade-offs among different latent-tree models are discussed in more detail in Appendix E.6. Besides the proposal and evaluation of BT-Cell variants, our paper also serves as a survey of how well previously proposed latent-tree RvNNs work on structure-sensitive synthetic tasks and out-of-distribution splits in natural language tasks, particularly when combined with more powerful recursive cells. Additionally, as a further contribution, we identify a previously unknown failure case for even the best-performing neural models when it comes to argument generalization in ListOps (Nangia & Bowman, 2018), opening up a new challenge for future research.

2. PRELIMINARIES

Problem Formulation: Similar to Choi et al. (2018), throughout this paper we explore the use of RvNNs as a sentence encoder. Formally, given a sequence of token embeddings X = (e_1, e_2, ..., e_n) (where X ∈ IR^{n×d_e}, e_i ∈ IR^{d_e}, and d_e is the embedding size), the task of a sentence-encoding function E : IR^{n×d_e} → IR^{d_h} is to encode the whole sequence of vectors into a single vector o = E(X) (where o ∈ IR^{d_h} and d_h is the size of the encoded vector). We can use a sentence encoder for sentence-pair comparison tasks like logical inference or for text classification.
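To make the interface concrete, the signature of E and its use in a sentence-pair task can be sketched as follows. The mean-pooling encoder below is a hypothetical placeholder standing in for BT-Cell, used only to illustrate the shapes involved, and the concatenated feature vector is one common sentence-pair combination, not one prescribed by the paper.

```python
import numpy as np

def encode(X):
    # Hypothetical stand-in for the encoder E : IR^{n×d_e} -> IR^{d_h}.
    # It maps a sequence of n token embeddings to a single vector;
    # here we simply mean-pool (so d_h = d_e), purely for illustration.
    return X.mean(axis=0)

# Sentence-pair comparison (e.g. logical inference): encode each
# sentence independently, then feed a combination of the two encoded
# vectors to a downstream classifier.
X1 = np.random.randn(7, 16)   # n = 7 tokens, d_e = 16
X2 = np.random.randn(5, 16)   # n = 5 tokens, d_e = 16
o1, o2 = encode(X1), encode(X2)
features = np.concatenate([o1, o2, np.abs(o1 - o2), o1 * o2])
```

Because E produces a fixed-size vector regardless of sequence length, the same classifier head can be applied to sentences of any length.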

2.1. RECURRENT NEURAL NETWORKS AND RECURSIVE NEURAL NETWORKS

A core component of both RNNs and RvNNs is a recursive cell. In our context, the cell function cell : IR^{d_{a1}} × IR^{d_{a2}} → IR^{d_v} takes as arguments two vectors (a_1 ∈ IR^{d_{a1}} and a_2 ∈ IR^{d_{a2}}) and returns a single vector v = cell(a_1, a_2) (where v ∈ IR^{d_v}). In our settings, we generally set d_{a1} = d_{a2} = d_v = d_h. Given a sequence X, both RNNs and RvNNs sequentially process it through recursive application of the cell function. For a concrete example, consider a sequence of token embeddings such as (2 + 4 × 4 + 3) (assume the symbols 2, 4, +, etc. represent transformations of the corresponding embedding vectors ∈ IR^{d_h}). Given any such sequence, RNNs can only follow a fixed left-to-right order of composition. For this particular sequence, an RNN-like application of the cell function can be expressed as:

o = cell(cell(cell(cell(cell(cell(cell(h_0, 2), +), 4), ×), 4), +), 3)    (1)

Here h_0 is some input-independent initial state ("initial hidden state") set in the model. In contrast to RNNs, RvNNs can compose the sequence in more flexible orders. For example, one way (among



There are several works that have used intermediate span representations for better compositional generalization (Liu et al., 2020; Herzig & Berant, 2021; Bogin et al., 2021; Liu et al., 2021; Mao et al., 2021). We leave it as future work to explore whether the span representations returned by BT-Cell can be used in relevant ways.

