CLUSTER-FORMER: CLUSTERING-BASED SPARSE TRANSFORMER FOR QUESTION ANSWERING

Abstract

The Transformer architecture has become ubiquitous in deep learning. One of the key ingredients behind its success is the self-attention mechanism, which allows fully-connected contextual encoding over input tokens. However, despite its effectiveness in modeling short sequences, self-attention struggles on inputs with extremely long-range dependencies, as its complexity grows quadratically with sequence length. Therefore, long sequences are often encoded by Transformer in chunks using a sliding window. In this paper, we propose Cluster-Former, a novel clustering-based sparse Transformer that performs attention across chunked sequences. The proposed framework is pivoted on two unique types of Transformer layer: the Sliding-Window Layer and the Cluster-Former Layer, which encode local sequence information and global context jointly and iteratively. This new design allows information integration beyond local windows, which is especially beneficial for question answering (QA) tasks that rely on long-range dependencies. Experiments show that Cluster-Former achieves state-of-the-art performance on several major QA benchmarks.

1. INTRODUCTION

Long-range contextual understanding has proven critical in many natural language processing (NLP) tasks. For example, the relevant context for correctly answering an open-domain question can span thousands of words. Encoding long sequences via deep neural networks, however, remains an expensive and challenging task due to high demands on training time and GPU memory. Traditional sequence modeling methods (Hochreiter & Schmidhuber, 1997) encode long sequences in chronological order, which suffers from high latency. In place of sequential encoding, recent models such as Transformer (Vaswani et al., 2017) use simultaneous self-attention over the entire input instead, which has been successfully adopted in many NLP tasks such as textual entailment (Devlin et al., 2019), dependency parsing (Zhou & Zhao, 2019), and summarization (Lewis et al., 2019). A caveat with Transformer, though, is that building full connections over long sequences translates to quadratic growth in memory demand and computational complexity w.r.t. sequence length.

One way to efficiently encode long sequences is to first chunk a sequence into much shorter ones with a sliding window, then build connections between the shorter sequences (Figure 1(a)). For example, Child et al. (2019), Beltagy et al. (2020) and Zaheer et al. (2020) apply sparse attention to chunked sequences in hand-designed patterns in order to gather information from the chunks (Figure 1(b)). Choi et al. (2017) and Wang et al. (2019) first use a simpler model to filter chunked sequences, then process selected sequences with fully-connected self-attention. Rae et al. (2019) make use of the shared memory of chunked sequences to build connections between them. However, these methods cannot encode long-range dependencies with as much flexibility or accuracy as fully-connected self-attention, due to their dependency on hand-designed patterns.

Recently, several studies (Kitaev et al., 2020; Tay et al., 2020b) propose to further improve the sparse attention mechanism by hashing or sorting the hidden states into different buckets (Figure 1(c)). These works mainly explore tasks with relatively short sequences, such as sentence-level machine translation (MT), where the number of hashing vectors is relatively small (fewer than 16 in Kitaev et al.
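The sliding-window chunking described above (and shown in Figure 1(a) with window size 3 and stride 2) can be sketched as follows. This is an illustrative helper, not code from any of the cited systems; the function name and its `window`/`stride` arguments are our own.

```python
def chunk_with_sliding_window(seq_len, window, stride):
    """Return (start, end) index pairs of overlapping chunks covering a sequence.

    Each chunk is short enough for a standard Transformer layer to encode
    independently, which is the basis of sliding-window approaches.
    """
    starts = list(range(0, max(seq_len - window, 0) + 1, stride))
    # Ensure the tail of the sequence is covered by a final chunk.
    if starts[-1] + window < seq_len:
        starts.append(seq_len - window)
    return [(s, min(s + window, seq_len)) for s in starts]

# Window size 3, stride 2, as in Figure 1(a): adjacent chunks overlap by one position.
print(chunk_with_sliding_window(7, window=3, stride=2))  # [(0, 3), (2, 5), (4, 7)]
```

The overlap between adjacent chunks lets some local information flow across chunk boundaries, but any dependency longer than the window remains invisible to a single layer, which motivates the cross-chunk mechanisms discussed next.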
(2020)), allowing randomly initialized hashing vectors to hash hidden states into the correct buckets. However, how to use hashing-based attention in the context of long sequences (e.g., up to thousands of words) remains unexplored territory.

Our proposed framework for efficient long-sequence encoding, Cluster-Former, marries sliding-window and hashing-based methods to achieve effective local and long-range dependency encoding. Cluster-Former consists of two types of encoding layer. The first (denoted as Sliding-Window Layer) focuses on extracting local information within a sliding window. It applies Transformer to the hidden states of each chunked sequence independently, as shown in Figure 1(a). The other (denoted as Cluster-Former Layer) learns to encode global information beyond the initial chunked sequences. Specifically, we first apply clustering to the input hidden states so that similar hidden states are assigned to the same cluster, as shown in Figure 1(d). The clustered and sorted input is then divided uniformly into chunks, each encoded by a Transformer layer. Note that to make model training more efficient, the cluster centroids are not computed online but updated periodically (every epoch or every few epochs). We accumulate the hidden states from the layer prior to the Cluster-Former layer in a memory bank, and apply the K-Means algorithm to form cluster centroids during each update cycle. Compared to previously discussed sparse attention based on pre-selected positions (Figure 1(b)) or randomly-initialized hashing vectors (Figure 1(c)), experimental results show that our method can encode dependencies across chunked sequences more effectively.

Our contributions can be summarized as follows. (i) We propose Cluster-Former, a novel approach to capturing long-range dependencies more effectively than locality-sensitive hashing methods. (ii) We propose a new Transformer-based framework to process long sequences by combining Sliding-Window and Cluster-Former layers to extract both local and global contextual information. (iii) Our model achieves the best performance on the question answering datasets Natural Questions (long answer), SearchQA, and Quasar-T.
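The grouping step of the Cluster-Former Layer can be sketched as follows. This is a minimal sketch of the assign-sort-chunk procedure described above, not the authors' implementation; all names (`cluster_former_grouping`, `chunk_size`, etc.) are illustrative.

```python
import numpy as np

def cluster_former_grouping(hidden, centroids, chunk_size):
    """Group hidden states for a Cluster-Former layer (illustrative sketch).

    hidden:    (n, d) array of hidden states
    centroids: (k, d) cluster centroids, assumed to be refreshed periodically
               by K-Means over a memory bank of past hidden states
    Each state is assigned to its nearest centroid, states are sorted by
    cluster id so that similar ones become adjacent, and the sorted sequence
    is cut into fixed-size chunks that a Transformer layer would then
    attend over independently. Returns the chunks and the permutation
    needed to restore the original token order afterwards.
    """
    # Nearest-centroid assignment by Euclidean distance: (n, k) distance matrix.
    dists = np.linalg.norm(hidden[:, None, :] - centroids[None, :, :], axis=-1)
    cluster_ids = dists.argmin(axis=1)
    order = np.argsort(cluster_ids, kind="stable")  # similar states become adjacent
    sorted_hidden = hidden[order]
    chunks = [sorted_hidden[i:i + chunk_size]
              for i in range(0, len(hidden), chunk_size)]
    return chunks, order
```

After attention is applied within each chunk, the outputs would be concatenated and un-permuted with the inverse permutation `np.argsort(order)` so that subsequent Sliding-Window Layers see the original token order.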

2. RELATED WORK

Efficient Transformers With Transformer models growing larger and larger, how to handle longer sequences has become a critical challenge. Many approaches have been proposed to improve the computational and memory efficiency of Transformers, including Sparse Transformer (Child et al., 2019), Routing Transformer (Roy et al., 2020), Reformer (Kitaev et al., 2020), Sinkhorn Transformer (Tay et al., 2020b), Longformer (Beltagy et al., 2020), ETC (Ainslie et al., 2020), Synthesizer (Tay et al., 2020a), Performer (Choromanski et al., 2020), Linformer (Wang et al., 2020), Linear Transformer (Katharopoulos et al., 2020), and BigBird (Zaheer et al., 2020). Tay et al. (2020c) provide an excellent literature survey on this emerging topic. Our method falls into the category of learnable sparse-attention patterns, alongside Routing Transformer, Reformer, and Sinkhorn Transformer. It is closest to Routing Transformer (Roy et al., 2020), which also uses cluster centroids to learn attention patterns; however, we target a quite different task (question answering vs. language modeling), which leads to significant differences in framework design. Moreover, our cluster centroids are updated in a very different way (periodic K-Means updates over a memory bank vs. Routing Transformer's online exponentially moving centroids).
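The periodic centroid refresh that distinguishes our update scheme from Routing Transformer's online exponential moving average can be sketched as a plain Lloyd's K-Means over the memory bank of accumulated hidden states. The function and argument names below are illustrative and taken from neither paper's code.

```python
import numpy as np

def update_centroids_kmeans(memory_bank, k, n_iter=10, seed=0):
    """Refresh cluster centroids from a memory bank of hidden states
    (illustrative sketch of a periodic K-Means update).

    memory_bank: (n, d) array of hidden states accumulated since the last
                 update cycle; centroids are recomputed from scratch rather
                 than nudged online after every batch.
    """
    rng = np.random.default_rng(seed)
    # Initialize centroids from k distinct states in the memory bank.
    centroids = memory_bank[rng.choice(len(memory_bank), k, replace=False)]
    for _ in range(n_iter):
        # Assign each state to its nearest centroid.
        d = np.linalg.norm(memory_bank[:, None] - centroids[None], axis=-1)
        assign = d.argmin(axis=1)
        # Move each centroid to the mean of its assigned states.
        for c in range(k):
            members = memory_bank[assign == c]
            if len(members):
                centroids[c] = members.mean(axis=0)
    return centroids
```

Because the refresh happens only every epoch or every few epochs, its cost is amortized over many training steps, whereas an online scheme pays a small update cost on every batch.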




Figure 1: Illustration of different methods for processing long sequences. Each square represents a hidden state. The black-dotted boxes are Transformer layers. (a) The sliding-window-based method chunks a long sequence into short ones with window size 3 and stride 2. (b) Cross-sequence attention is built on top of the sliding window over pre-selected positions (red-dotted boxes). (c) Hidden states are hashed into different buckets by randomly-initialized vectors. (d) Our proposed approach clusters the hidden states. Our final model is a combination of (a) and (d), processing both local and global context.

