CLUSTER-FORMER: CLUSTERING-BASED SPARSE TRANSFORMER FOR QUESTION ANSWERING

Abstract

Transformer has become ubiquitous in the deep learning field. One of the key ingredients that destined its success is the self-attention mechanism, which allows fully-connected contextual encoding over input tokens. However, despite its effectiveness in modeling short sequences, self-attention suffers when handling inputs with extreme long-range dependencies, as its complexity grows quadratically w.r.t. the sequence length. Therefore, long sequences are often encoded by Transformer in chunks using a sliding window. In this paper, we propose Cluster-Former, a novel clustering-based sparse Transformer to perform attention across chunked sequences. The proposed framework is pivoted on two unique types of Transformer layer: Sliding-Window Layer and Cluster-Former Layer, which encode local sequence information and global context jointly and iteratively. This new design allows information integration beyond local windows, which is especially beneficial for question answering (QA) tasks that rely on long-range dependencies. Experiments show that Cluster-Former achieves state-of-the-art performance on several major QA benchmarks.

1. INTRODUCTION

Long-range contextual understanding has proven critical in many natural language processing (NLP) tasks. For example, the relevant context for correctly answering an open-domain question can span thousands of words. Encoding long sequences via deep neural networks, however, has remained an expensive and challenging task due to the high demand on training time and GPU memory. Traditional sequence modeling methods (Hochreiter & Schmidhuber, 1997) encode long sequences in chronological order, which suffers from high latency. In place of sequential encoding, recent models such as Transformer (Vaswani et al., 2017) use simultaneous self-attention over the entire input, which has been successfully adopted in many NLP tasks such as textual entailment (Devlin et al., 2019), dependency parsing (Zhou & Zhao, 2019), and summarization (Lewis et al., 2019). A caveat with Transformer, though, is that building full connections over long sequences translates to quadratic growth in memory demand and computational complexity w.r.t. sequence length.

One way to efficiently encode long sequences is to first chunk a sequence into much shorter ones with a sliding window, then build connections between the shorter sequences (Figure 1(a)). For example, Child et al. (2019), Beltagy et al. (2020) and Zaheer et al. (2020) apply sparse attention to chunked sequences in hand-designed patterns in order to gather information from the chunks (Figure 1(b)). Choi et al. (2017) and Wang et al. (2019) first use a simpler model to filter chunked sequences, then process selected sequences with fully-connected self-attention. Rae et al. (2019) make use of the shared memory of chunked sequences to build connections between them. However, these methods cannot encode long-range dependencies with as much flexibility or accuracy as fully-connected self-attention, due to their dependency on hand-designed patterns.
Recently, several studies (Kitaev et al., 2020; Tay et al., 2020b) propose to further improve the sparse attention mechanism by hashing or sorting the hidden states into different buckets (Figure 1(c)). These works mainly explore tasks with relatively short sequences, such as sentence-level Machine Translation (MT), where the number of hashing vectors is relatively small (less than 16 in Kitaev et al. (2020)), allowing randomly initialized hashing vectors to hash hidden states into correct buckets. However, how to use hashing-based attention in the context of long sequences (e.g., up to thousands of words) is still unexplored territory.

2. RELATED WORK

Efficient Transformers

With Transformer models growing larger and larger, how to handle longer sequences arises as a critical challenge. Many works have been proposed to improve the computational and memory efficiency of Transformers, including Sparse Transformer (Child et al., 2019), Routing Transformer (Roy et al., 2020), Reformer (Kitaev et al., 2020), Sinkhorn Transformer (Tay et al., 2020b), Longformer (Beltagy et al., 2020), ETC (Ainslie et al., 2020), Synthesizer (Tay et al., 2020a), Performer (Choromanski et al., 2020), Linformer (Wang et al., 2020), Linear Transformer (Katharopoulos et al., 2020), and BigBird (Zaheer et al., 2020). Tay et al. (2020c) provided an excellent literature survey on this emerging topic. Our method falls into the setting of learnable sparse-attention patterns, which includes Routing Transformer, Reformer and Sinkhorn Transformer. Our method is closest to Routing Transformer (Roy et al., 2020), which also uses cluster centroids to learn patterns, but we target quite different tasks (language modeling vs. question answering), which leads to significant differences in the frameworks. Moreover, our cluster centroids are updated in very different ways (online exponentially moving centroids vs. periodic centroid updates by K-Means).

Long Sequence in Question Answering

For tasks such as open-domain question answering (Chen et al., 2017), a large volume of documents or paragraphs is usually retrieved to infer the answer, yielding extremely long context content. Although state-of-the-art NLP models are capable of extracting answers amid complex context, they still struggle with extremely long input sequences. Recent advances that advocate the use of large-scale pre-trained models (Lewis et al., 2019; Liu et al., 2019; Lan et al., 2020) for question answering make this problem more prominent, due to tremendous memory consumption.
To process long sequences, the most widely used method is to first use a lightweight model to filter out redundant text, then use sliding-window-based approaches to encode the remaining sequences with a more sophisticated model. Chen et al. (2017) integrated bi-gram features into Information Retrieval (IR) methods to retrieve related documents more accurately. Wang et al. (2018) trained a paragraph selector, using as the reward whether the entire system can obtain the correct answer. Lin et al. (2018) proposed to use a paragraph ranking model to curate data that are required for training reading comprehension models. Wang et al. (2019) trained a ranker to merge paragraphs for multi-passage reasoning. Asai et al. (2020) trained a recurrent retriever to select paragraphs for multi-hop question answering. Besides the above methods, directly applying Efficient Transformers to process long sequences in question answering is another option. In this paper, we focus on this direction by directly training our Cluster-Former on the long context without using a lightweight model for context filtering.

3. PROPOSED APPROACH

The proposed framework to handle long sequences is pivoted on two types of Transformer layer: (i) Sliding-Window Layer; and (ii) Cluster-Former Layer. The former focuses on encoding local sequence information, while the latter encodes global context and is always built on top of the former. An overview of the two layers is illustrated in Figure 2.

3.1. SLIDING-WINDOW LAYER

Although our focus is on capturing long-range dependencies for global context, local information also plays a critical role in knowledge propagation. Therefore, in the lower section of our network, we adopt the traditional sliding-window encoding mechanism. A sliding window segments a long sequence $X$ into short, overlapping ones with window size $l$ and stride $m$, as illustrated in Figure 2(a). Note that in this paper, we focus on question answering tasks, for which we concatenate the question $Q$ with each sequence chunked from the document:

$$H^0_k = [Q; X[m \times k : (m \times k + l)]],$$

where $Q \in \mathbb{R}^{q \times d}$ denotes the question embeddings given a QA task, $q$ is the number of tokens in the question, $X \in \mathbb{R}^{x \times d}$ denotes the embeddings of the whole context, and $x$ is the number of tokens in the context. $k$ is the ID of the chunked sequence, $l$ is the window size, and $m$ is the stride of the sliding window. $[idx_1 : idx_2]$ indicates selecting the rows between index $idx_1$ and index $idx_2$ of the matrix, and $[\cdot; \cdot]$ means concatenating the matrices along the row dimension. We first use Transformer to encode each sequence in the sliding window as follows:

$$H^{n+1}_k = \mathrm{Transformer}(H^n_k),$$

where $H^{n+1}_k \in \mathbb{R}^{(q+l) \times d}$ is the output of Transformer on the $k$-th sequence in the $n$-th layer. It is not yet the final output of the $n$-th layer: as we expect the neighbouring sequences to share useful information in hidden states as well, we always set $m < l$ to allow overlapping between sequences, and we use the mean values of the Transformer hidden states at the overlapped tokens between windows as final outputs. To merge the representations from the $(k-1)$-th sequence:

$$H^{n+1}_k[q : q+l-m] \mathrel{+}= H^{n+1}_{k-1}[q+m : \mathrm{end}], \qquad H^{n+1}_k[q : q+l-m] \mathrel{/}= 2,$$

and to merge the representations from the $(k+1)$-th sequence:

$$H^{n+1}_k[q+m : \mathrm{end}] \mathrel{+}= H^{n+1}_{k+1}[q : q+l-m], \qquad H^{n+1}_k[q+m : \mathrm{end}] \mathrel{/}= 2,$$

where $\mathrel{+}=$ adds matrices in-place and $\mathrel{/}=$ divides a matrix by a scalar value in-place.
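The chunking and overlap-averaging steps above can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation: the function names `chunk_with_question` and `merge_overlaps` are hypothetical, the sketch assumes the windows tile the context exactly, and it assumes $l - m \le m$ so that every context token lies in at most two windows. The averaging is computed from a copy of the per-window outputs, so the result does not depend on the order in which windows are merged.

```python
import numpy as np

def chunk_with_question(Q, X, l, m):
    """Build H^0_k = [Q; X[m*k : m*k + l]] for every window k."""
    x = X.shape[0]
    # Simplifying assumption: the windows tile the context exactly.
    assert x >= l and (x - l) % m == 0
    num_windows = (x - l) // m + 1
    return np.stack([
        np.concatenate([Q, X[m * k : m * k + l]], axis=0)
        for k in range(num_windows)
    ])  # shape: (num_windows, q + l, d)

def merge_overlaps(H, q, l, m):
    """Average the hidden states of context tokens shared by adjacent windows."""
    overlap = l - m
    # Assumes each context token lies in at most two windows.
    assert 0 < overlap <= m
    out = H.copy()
    head = H[1:, q : q + overlap]   # first `overlap` context rows of window k
    tail = H[:-1, q + m :]          # last `overlap` context rows of window k-1
    mean = (head + tail) / 2
    out[1:, q : q + overlap] = mean
    out[:-1, q + m :] = mean
    return out

rng = np.random.default_rng(0)
q, x, d, l, m = 2, 10, 4, 4, 3
Q, X = rng.normal(size=(q, d)), rng.normal(size=(x, d))
H0 = chunk_with_question(Q, X, l, m)   # three windows of q + l = 6 rows each
H1 = rng.normal(size=H0.shape)         # stand-in for Transformer(H0)
merged = merge_overlaps(H1, q, l, m)
```

After merging, the rows that two neighbouring windows share hold identical values, which is what allows the next layer to treat them as one token.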
The merged hidden states $H^{n+1}_k \in \mathbb{R}^{(q+l) \times d}$ are the final outputs of the $n$-th layer. If the next layer is a Cluster-Former layer, the output hidden states $H^{n+1}_k$ of this layer are saved into a memory bank for computing the cluster centroids.

3.2. CLUSTER-FORMER LAYER

The input of the Cluster-Former layer comes from the hidden states of the prior layer (in our case, a Sliding-Window layer). After merging the overlaps between sequence chunks, the input of this layer is defined as:

$$\bar{H}^n = [H^n_0[0 : q+m]; \ldots; H^n_k[0 : q+m]],$$

where $\bar{H}^n \in \mathbb{R}^{(q \lceil x/m \rceil + x) \times d}$ contains the hidden states to cluster, and $x$ is the number of tokens in the context. As the hidden states with larger cosine similarity are more likely to have higher attention weights, we build sparse self-attention only on the hidden states in the same cluster. In this work, we use K-Means as the clustering method for simplicity. More advanced clustering algorithms have the potential of yielding better performance. Since running K-Means on the fly in each training iteration is computationally expensive, we re-compute the cluster centroids with low frequency (every epoch or every few epochs). In addition, to avoid dramatic changes in the cluster centroids due to limited hidden state inputs, we maintain a memory bank for the most recent hidden states. The entire procedure is depicted in Algorithm 1. Once we compute the cluster centroids, we can directly use them for hidden state clustering as follows:

$$v^n = \operatorname{argmax}\left(\frac{\bar{H}^n (C^n)^T}{\|\bar{H}^n\|_2 \, \|C^n\|_2}\right),$$

where $C^n \in \mathbb{R}^{p \times d}$ are the cluster centroids for layer $n$, and $p$ is the pre-defined number of clusters. The function $\operatorname{argmax}(\cdot)$ operates on the last dimension and assigns each input hidden state to a cluster based on the maximum cosine similarity between the hidden state and the cluster centroids. $v^n \in \mathbb{R}^{(q \lceil x/m \rceil + x)}$ contains the assigned cluster IDs of all the input hidden states.

Since the number of hidden states in different clusters can vary substantially, padding them to the maximum length for Transformer training would significantly increase the computational time. To make the extraction of global context more efficient, we greedily pick the cluster centroids based on the nearest neighbour (measured by cosine similarity), as shown in the function GETCENTROIDS in Algorithm 1. Thus, the hidden states with similar cluster IDs are also close to each other.
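As a rough illustration of the assignment step, the sketch below computes cluster IDs by maximum cosine similarity and orders centroids greedily by nearest neighbour. Both function names are hypothetical. In particular, `order_centroids_greedy` is only one plausible reading of the greedy nearest-neighbour selection described above (start from the first centroid, then repeatedly append the most similar unvisited one); it is not taken from Algorithm 1.

```python
import numpy as np

def assign_clusters(H, C, eps=1e-12):
    """Cluster ID per hidden state: argmax of cosine similarity to the centroids."""
    Hn = H / (np.linalg.norm(H, axis=1, keepdims=True) + eps)
    Cn = C / (np.linalg.norm(C, axis=1, keepdims=True) + eps)
    return np.argmax(Hn @ Cn.T, axis=-1)

def order_centroids_greedy(C, eps=1e-12):
    """Assumed ordering: each centroid is followed by its most similar unvisited one."""
    Cn = C / (np.linalg.norm(C, axis=1, keepdims=True) + eps)
    order = [0]
    remaining = set(range(1, len(C)))
    while remaining:
        last = Cn[order[-1]]
        nxt = max(remaining, key=lambda j: float(Cn[j] @ last))
        order.append(nxt)
        remaining.remove(nxt)
    return np.array(order)

# Toy example: two centroids along the axes of a 2-d space.
C = np.eye(2)
H = np.array([[2.0, 0.1], [0.1, 3.0], [5.0, 1.0]])
ids = assign_clusters(H, C)

# Centroids at 0, 90, 10 and 80 degrees on the unit circle.
angles = np.deg2rad([0.0, 90.0, 10.0, 80.0])
order = order_centroids_greedy(np.stack([np.cos(angles), np.sin(angles)], axis=1))
```

With centroids ordered this way, neighbouring cluster IDs refer to nearby centroids, so sorting hidden states by cluster ID keeps similar states adjacent even when a fixed-size chunk straddles a cluster boundary.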
Then, we can directly sort the cluster IDs of the hidden states and uniformly chunk the hidden states (same window size and stride $m$):

$$u^n = \operatorname{argsort}(v^n), \qquad a^n_k = u^n[mk : m(k+1)], \qquad E^n_k = \bar{H}^n[a^n_k],$$

where the function $\operatorname{argsort}(\cdot)$ returns the indexes of the input values in sorted order (equal values are ordered by the position of the corresponding hidden states). $a^n_k \in \mathbb{R}^m$ contains the chunked indexes of the hidden states. $E^n_k \in \mathbb{R}^{m \times d}$ is the $k$-th chunk of clustered hidden states, and we run Transformer on top of it to build connections beyond the words in the initial sliding window:

$$E^{n+1}_k = \mathrm{Transformer}(E^n_k).$$

After updating the hidden states, we map them back to the order before clustering:

$$\hat{H}^{n+1} = [E^{n+1}_0; E^{n+1}_1; \ldots; E^{n+1}_K], \qquad \bar{a}^n = [a^n_0; a^n_1; \ldots; a^n_K],$$

$$\bar{H}^{n+1}[\bar{a}^n] = \operatorname{clone}(\hat{H}^{n+1}),$$

where $\bar{H}^{n+1}$ is the final output hidden state of this layer and has the same word order as the input $\bar{H}^n$. In experiments, we stack these two types of layer interchangeably to capture both global and local context efficiently.

We evaluate our proposed approach on multiple question answering benchmarks. The statistics of all the datasets are summarized in Table 1.
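The sort, chunk, and map-back steps of the Cluster-Former layer (Section 3.2) can be sketched as below. This is a hedged illustration rather than the authors' code: `cluster_former_regroup` and `layer_fn` are hypothetical names, and `layer_fn` stands in for the Transformer layer applied to each chunk. A stable sort is used so that hidden states with the same cluster ID keep their original relative order, as the text describes.

```python
import numpy as np

def cluster_former_regroup(H, v, m, layer_fn):
    """Sort states by cluster ID, process chunks of size m, restore the input order."""
    u = np.argsort(v, kind="stable")     # ties keep their original position order
    out = np.empty_like(H)
    for start in range(0, len(u), m):
        a = u[start : start + m]         # indexes a_k of the k-th chunk
        out[a] = layer_fn(H[a])          # E_k^{n+1}, scattered back to positions a_k
    return out

H = np.arange(8, dtype=float).reshape(4, 2)   # four hidden states, d = 2
v = np.array([1, 0, 1, 0])                    # cluster IDs

# An identity "layer" must return the input in its original order.
same = cluster_former_regroup(H, v, m=2, layer_fn=lambda E: E)

# A toy "layer" that replaces every state with its chunk mean, to show mixing.
mixed = cluster_former_regroup(
    H, v, m=2, layer_fn=lambda E: np.repeat(E.mean(0, keepdims=True), len(E), axis=0)
)
```

In the toy example, states 1 and 3 (cluster 0) end up in one chunk and states 0 and 2 (cluster 1) in the other, so each pair exchanges information even though its members are not adjacent in the input.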

4. EXPERIMENTS

4.1. DATASETS

Quasar-T¹ (Dhingra et al., 2017): The goal of this task is to answer open-domain questions from Trivia Challenge. All the passages harvested through information retrieval can be used to answer questions. The task requires the model to generate answers in phrases. The evaluation metric on this dataset is based on Exact Match and F1 score of the bag-of-words matching. Our evaluation tool² comes from the SQuAD dataset.

SearchQA³ (Dunn et al., 2017): The setting of this dataset is the same as Quasar-T, except that the questions are sourced from Jeopardy! instead.

Natural Questions⁴ (Kwiatkowski et al., 2019): This task aims to answer questions based on a given Wikipedia document, and has two settings. (i) Long answer: select a paragraph that can answer the question based on the Wikipedia document, if any. (ii) Short answer: extract an answer phrase from the document if the document contains the answer. As the given document may not contain an answer, we can either predict an answer or predict no answer. The evaluation metric on this dataset is the F1 score, where true positives are exactly correct answers, false positives are incorrect answer predictions, and false negatives are incorrect "no answer" predictions. As the test set is hidden, we split off 5% of the training set for validation, and use the original validation set for testing. We use the official tool from the dataset to evaluate our models. We also submit our best model to the leaderboard.
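For readers unfamiliar with the Exact Match and bag-of-words F1 metrics used on Quasar-T and SearchQA, the following is a simplified sketch. It is not the official SQuAD evaluation script, which additionally strips punctuation and articles before comparing; the function names here are illustrative.

```python
from collections import Counter

def exact_match(prediction, gold):
    """1.0 if the lowercased, stripped strings are identical, else 0.0."""
    return float(prediction.strip().lower() == gold.strip().lower())

def bag_of_words_f1(prediction, gold):
    """Harmonic mean of token-level precision and recall on lowercased tokens."""
    pred_tokens = prediction.lower().split()
    gold_tokens = gold.lower().split()
    common = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if common == 0:
        return 0.0
    precision = common / len(pred_tokens)
    recall = common / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

em = exact_match("the Underground Railroad", "Underground Railroad")      # 0.0
f1 = bag_of_words_f1("the Underground Railroad", "Underground Railroad")  # 0.8
```

Here the prediction contains one extra token, so precision is 2/3 and recall is 1, giving F1 = 0.8 while Exact Match is 0.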

4.2. IMPLEMENTATION DETAILS

All the models are trained on 8 Nvidia V100 GPUs. For clustering, we adopt "Yinyang kmeans" (Ding et al., 2015)⁵, which takes less than 5 seconds for clustering in all of our experiment settings. We set the memory size for clustering to $M = 100{,}000$ in Algorithm 1. We use the cluster centroids that perform the best on the validation set for test set experiments. We initialize our models with RoBERTa-large (Liu et al., 2019). As the number of position embeddings of RoBERTa is limited to 512, we cannot assign different position embeddings to all tokens. Instead, we assign the same position embeddings to each chunked sequence. The majority of our model is made up of Sliding-Window Layers, as local information is essential for QA tasks. We adopt the proposed Cluster-Former Layer in layers 15 and 20 to further capture long-range information. We set the sliding window size $l$ to 256 and the stride $m$ to 224, and vary the number of clusters in {64, 256, 512} to analyze its impact on the final performance. We prepend a special token to the beginning of all the given/retrieved paragraphs and directly concatenate all the paragraphs as the final context sequence. Due to memory constraints, we set the max length to 5000 during training and 10000 during inference.

During dataset finetuning, we use Adam (Kingma & Ba, 2015) to optimize the model. We set warm-up updates to 2,220, maximal updates to 22,200, learning rate to $5 \times 10^{-5}$, and batch size to 160. We tune the dropout rate from {0.1, 0.15, 0.2} for all methods including baselines and report the best results. The model converges in one day for all the QA datasets.

For Quasar-T and SearchQA, we predict the start and end positions of the answer. For Natural Questions, we first identify whether the question has short/long answers or not based on the mean value of the first hidden state of all the chunked sequences, $\frac{1}{K} \sum_{k=1}^{K} H^N_k[0]$, where $K$ is the number of chunks and $N$ is the number of layers. If answerable, we rank all the candidates for long answer selection, and predict the start and end positions of short answers. Our model submitted to the Natural Questions leaderboard is an ensemble of 3 models with 512 clusters, and only these models are first trained on SQuAD2.0 and then finetuned on the Natural Questions dataset.
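To make these settings concrete, the sketch below works out how many chunks the stated stride produces and computes the pooled vector used for the answerability decision. It assumes the chunk count is $\lceil x/m \rceil$, as suggested by the dimension of the clustered hidden states in Section 3.2; the function names and the toy shapes are illustrative only.

```python
import math
import numpy as np

def num_chunks(x, m):
    """Assumed chunk count for x context tokens with stride m: ceil(x / m)."""
    return math.ceil(x / m)

def pooled_first_state(H_final):
    """Mean over chunks of the first hidden state H^N_k[0]; input shape (K, q + l, d)."""
    return H_final[:, 0, :].mean(axis=0)

train_chunks = num_chunks(5000, 224)    # max length used during training
infer_chunks = num_chunks(10000, 224)   # max length used during inference

# Toy final-layer states: K = 3 chunks, q + l = 5 rows, d = 4.
H_final = np.arange(60, dtype=float).reshape(3, 5, 4)
pooled = pooled_first_state(H_final)    # shape (4,)
```

With stride 224, a 5000-token training context yields 23 chunks and a 10000-token inference context yields 45, so the pooled vector averages 23 or 45 first-position states respectively.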

4.3. BASELINE

We compare our models with several strong baselines, including:

• R3 (Wang et al., 2018) proposes to use reinforcement learning to jointly train the passage ranker and the reader.
• DS-QA (Lin et al., 2018) proposes to first use paragraph selection to filter the noisy data and then train the model on the denoised data.
• Multi-passage BERT (Wang et al., 2019) proposes to filter the passages and then merge multiple useful passages into one sequence, which can be encoded by BERT.
• DrQA (Chen et al., 2017) makes use of an attention mechanism across the question and the document for answer phrase extraction.
• DecAtt and DocReader (Kwiatkowski et al., 2019) are based on a pipeline approach that first uses a simpler model to select long answers and then a reading comprehension model to extract short answers from the long answers.
• BERT joint (Alberti et al., 2019) jointly trains short and long answer extraction in a single model rather than using a pipeline approach.
• BERT wwm + SQuAD2 (Pan et al., 2019) makes use of multi-task learning to further boost performance.
• RikiNet-RoBERTa (Liu et al., 2020) proposes a dynamic paragraph dual-attention reader and a multi-level cascaded answer predictor.
• BigBird-ETC (Zaheer et al., 2020) makes use of a sparse attention mechanism to encode long sequences.

Model                                    Quasar-T    SearchQA    NQ (long)   NQ (short)
                                         EM/F1       EM/F1       F1          F1
R3 (Wang et al., 2018)                   35.3/41.7   49.0/55.3   -           -
DECAPROP (Tay et al., 2018)              38.6/46.9   62.2/70.8   -           -
DS-QA (Lin et al., 2018)                 42.2/49.3   58.8/64.5   -           -
Multi-passage BERT (Wang et al., 2019)   51.1/59.1   65.1/70.7   -           -
DrQA (Chen et al., 2017)                 37.7/44.5   41.9/48.7

We also re-implement several strong baselines which have not been applied to process long context in question answering tasks:

• Sliding Window: The original method is fully made up of Sliding-Window Layers and can only attend to local information. To make a fair comparison among different methods on long-range information collection, we replace several layers of this sliding window baseline with Sparse Attention, Locality-Sensitive Hashing, and Cluster-Former.
• Sparse Attention (Child et al., 2019): This method replaces several layers in the previous baseline by training a Transformer layer across sequences on pre-selected positions. We run this sparse Transformer on all the hidden states in the same position across sequences, so that the output of the sparse Transformer can merge the information from different sequences.
• Locality-Sensitive Hashing (Kitaev et al., 2020): This method hashes hidden states into different buckets determined by randomly-initialized hashing vectors. A Transformer layer is then applied across buckets to build Sparse Attention across the whole sequence. Note that this method cannot be directly used for question answering without adding Sliding-Window layers, as our QA model is initialized by RoBERTa, which only has 512 position embeddings.
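The grouping used by the two re-implemented sparse baselines can be sketched as follows. Both functions are hypothetical simplifications: `group_same_positions` only shows how hidden states at the same position across chunks are brought together for the Sparse Attention baseline, and `lsh_buckets` uses a plain argmax over random projections as a stand-in for the hashing step, which is an assumption rather than the exact scheme of Kitaev et al. (2020).

```python
import numpy as np

def group_same_positions(H):
    """(K, L, d) -> (L, K, d): row j collects position j of every chunk."""
    return np.transpose(H, (1, 0, 2))

def lsh_buckets(H, R):
    """Assumed hashing: bucket ID is the random vector with the largest projection."""
    return np.argmax(H @ R.T, axis=-1)

# Sparse Attention baseline: attend across chunks at a fixed position.
H = np.arange(24, dtype=float).reshape(3, 4, 2)   # K = 3 chunks, L = 4, d = 2
G = group_same_positions(H)                        # shape (4, 3, 2)

# LSH baseline: hashing vectors are randomly initialized, not learned from data.
rng = np.random.default_rng(0)
R = rng.normal(size=(8, 2))                        # 8 hashing vectors
buckets = lsh_buckets(H.reshape(-1, 2), R)         # one bucket ID per hidden state
```

The contrast with Cluster-Former lies in where the grouping vectors come from: here `R` is random, whereas the cluster centroids are computed from a memory bank of actual hidden states.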

4.4. EXPERIMENTAL RESULTS

State-of-the-Art Results on QA

Tables 2 and 3 show that our proposed method outperforms several strong baselines, thanks to its ability to encode both local and global information. Cluster-Former with 512 clusters achieves new state-of-the-art results on Quasar-T, SearchQA and Natural Questions (long answer).

Effect of Cluster-Former

We also test the ability of Cluster-Former to model long-range dependencies. Note that Sparse Attention (Child et al., 2019) and Locality-Sensitive Hashing (Kitaev et al., 2020) have never been tested on question answering tasks with long context.

Effect of Number of Cluster Centroids

We also test the effect of different numbers of cluster centroids (C) on model performance. We observe that the model with 512 clusters works significantly better than the model with 64 clusters on most of the tasks. However, for the Natural Questions long-answer setting, the improvement is marginal. This is because we mainly rely on the hidden states of the special tokens "<s>" for long answer selection, and the same tokens can be assigned to the same chunk more easily even with a smaller number of clusters.

Selection of Cluster-Former Layers

We also analyze which layers are best suited to be Cluster-Former layers. As shown in Table 4, we conduct a hyper-parameter search and find that performance is better with at least one Cluster-Former layer among the middle layers (8-16). The worst results come from using only one Cluster-Former layer at layer 22 or 23.

Language Modeling

Although we focus on QA tasks in this paper, to demonstrate the versatility of Cluster-Former, we further conduct additional experiments on language modeling using the Wikitext-103 (Merity et al., 2017) and Enwik8 (Mahoney, 2011) benchmarks. Implementation details are provided in the Appendix. As shown in Table 5, Cluster-Former outperforms strong state-of-the-art baselines.

4.5. QUALITATIVE ANALYSIS

We perform qualitative analysis on how the hidden states are clustered, by visualizing the corresponding words and positions of the hidden states in Table 6. From the first row, we can see that the special tokens "<s>" tend to belong to the same cluster. Note that "<s>" is the start token of each long answer candidate, and its hidden state is used for final long answer selection. Therefore, a Transformer on this cluster can compare across the candidates to make the final prediction. We further observe that tokens of the same type are more likely to appear in the same cluster. For example, the words in the second to the fourth rows cover the topics of time, stopwords, and organization & geopolitical entities. Finally, we randomly sample a cluster and list the positions of the clustered hidden states in the last row of the table. We find that states far apart, such as the 50-th and 6060-th states (over 6000 tokens apart), can be in one cluster, which demonstrates the ability of Cluster-Former to detect long-range dependencies. Further, we observe that states tend to cluster in phrases. For example, we see consecutive positions such as "49, 50, 51, 52, 53, 54, 55", which likely results from the sliding-window encoding.

5. CONCLUSION

In this paper, we present Cluster-Former, a new method to encode global information for long sequences. We achieve new state of the art on three question answering datasets: Quasar-T, SearchQA, and Natural Questions. Further, we observe that a larger number of clusters in Cluster-Former can lead to better performance on question answering tasks. Cluster-Former is a generic approach, and we believe that it can benefit other NLP tasks that rely on long-range dependencies as well.



¹ https://github.com/bdhingra/quasar
² https://rajpurkar.github.io/SQuAD-explorer/
³ https://github.com/nyu-dl/dl4ir-searchQA
⁴ https://ai.google.com/research/NaturalQuestions
⁵ https://github.com/src-d/kmcuda



Figure 1: Illustration of different methods for processing long sequences. Each square represents a hidden state. The black-dotted boxes are Transformer layers. (a) is the sliding-window-based method to chunk a long sequence into short ones with window size 3 and stride 2. (b) builds cross-sequence attention based on sliding window over pre-selected positions (red-dotted boxes). (c) hashes the hidden states into different buckets by randomly-initialized vectors. (d) is our proposed approach to cluster the hidden states. Our final model is a combination of (a) and (d) that processes both local and global context.

Figure 2: An overview of proposed Transformer layer. (a) Sliding-Window layer over a sequence. (b) Cluster-Former layer over clustered hidden states from the output of (a). Cluster centroids are periodically updated based on the memory bank of the hidden states in the corresponding layer. Note that the sequence inputs in (a) and (b) usually come from two different samples.

We introduce a Cluster-Former layer to add global representational power to Transformer beyond sliding windows. An in-depth visualization of the layer is illustrated in Figure 2(b).

Table 2: Results on Quasar-T, SearchQA test sets and NQ dev set. #C: number of clusters.

Table 3: Results on the Natural Questions (NQ) leaderboard (test set). We show two published results here from over 40 submissions. Our model ranks No. 1 for long answer and No. 4 for short answer.

Question:       Where did the underground railroad start and finish ?
Context:        The Underground Railroad by artist Charles T. Webber , 1893 Date Late 1700s - 1865 Location Northern United States with routes to Canada , Mexico ...
Special token:  <s><s><s>Island island in the colonies city<s><s><s>With in the in .
Time:           did start and finish 1893 Date 1700 1865 Location Participants Outcome Deaths 19 1763
Stopwords:      the the , the , , , , to , , , , the American runaway slaves of free states the , , , it to , a the the
Entity:         Canada Mexico Canada is applied Florida Spanish Railroad Railroad Railroad
Positions:      49, 50, 51, 52, 53, 54, 55, 115, 116, 168, 273, 394, ..., 6022, 6040, 6042, 6060, 6094, 6095

Table 6: An example from the Natural Questions dataset. The rows in the middle section show the corresponding words of the clustered hidden states, and the bottom row shows the positions of the clustered hidden states. "<s>" refers to the start token of a long answer candidate.

For fair comparison, we set layers 15 and 20 as either Sparse Attention, Locality-Sensitive Hashing or our Cluster-Former, and the remaining layers are Sliding-Window layers.

Table 4: Experiments on the Quasar-T dev set.

