ADDING RECURRENCE TO PRETRAINED TRANSFORMERS

Abstract

Fine-tuning a pretrained transformer for a downstream task has become a standard method in NLP in the last few years. While the results from these models are impressive, applying them can be extremely computationally expensive, as is pretraining new models with the latest architectures. We present a novel method for applying pretrained transformer language models which lowers their memory requirement both at training and inference time. An additional benefit is that our method removes the fixed context size constraint that most transformer models have, allowing for more flexible use. When applied to the GPT-2 language model, we find that our method attains better perplexity than an unmodified GPT-2 model on the PG-19 and WikiText-103 corpora, for a given amount of computation or memory.

1. INTRODUCTION

Recent progress in NLP has been dominated by large pretrained transformer neural networks (Vaswani et al., 2017), such as BERT (Devlin et al., 2019) and GPT-2 (Radford et al., 2019). However, these models have a memory footprint that is quadratic in input sequence length. Although architectural innovations such as those of Kitaev et al. (2019) and Rae et al. (2019) mitigate this and the issue of a predetermined maximum context size, large pretrained models applying these techniques are not available at this time. Even if large pretrained models of this kind are released in the future, they will likely not cover the wide range of domains that BERT-family models have been published for. For example, there have been BERT-based models trained for other languages such as French (Le et al., 2020; Martin et al., 2020), Italian (Polignano et al., 2019), and many other languages (see Nozza et al. (2020) for an overview), as well as specific domains such as scientific papers (Beltagy et al., 2019), biomedical papers (Lee et al., 2020), and health records (Rasmy et al., 2020). Individuals working with these models may not have the resources to train new models from scratch using the latest techniques, as the computation requirements for pretraining are extremely high. As such, identifying ways that already existing models can be improved could be widely impactful. Another drawback of this family of models is that they have an a priori fixed maximum context size (typically 512 or 1024 tokens for the currently available pretrained models). A typical application of pretrained language models is producing contextual embeddings for a document. If the document is simply chunked into disjoint segments of 512 tokens, tokens at the boundary of a window will have less contextual information than tokens in the center of a window.
This can be mitigated by striding the evaluation of the model and keeping, for each token, only the embedding from the window in which it has the largest context, but this adds a substantial amount of wasted computation. In this paper, we propose a method for augmenting and fine-tuning pretrained transformer language models to use context without directly attending to it. Our method simultaneously allows for increasing the context size a transformer processes, while allowing a controllable trade-off between computation and perplexity. We accomplish this by adding a small recurrence module that computes a fixed-size representation from the transformer hidden states in a window of text. The representation for that window is then used during processing of the next window. Shrinking the window size thus becomes a way to reduce the memory footprint of the model, with less loss of performance than would occur with a standard transformer. Our experiments add recurrence to GPT-2 language models and fine-tune them on the PG-19 (Rae et al., 2019) and WikiText-103 (Merity et al., 2016) corpora, requiring only the same amount of memory used for standard fine-tuning of a pretrained language model. We demonstrate improvements in perplexity compared to a baseline model using the same amount of computation. Qualitative analysis shows that our recurrent module propagates certain information from previous windows of text, which can facilitate handling of long-distance dependencies with fixed-size input windows.
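The disjoint-chunking versus strided-evaluation trade-off described above can be sketched as follows. This is an illustrative sketch only: the window and stride sizes are arbitrary defaults, and the functions slice token lists rather than invoking any particular model.

```python
def disjoint_chunks(tokens, window=512):
    """Split tokens into non-overlapping windows. Tokens near a
    window boundary are embedded with little surrounding context."""
    return [tokens[i:i + window] for i in range(0, len(tokens), window)]

def strided_windows(tokens, window=512, stride=256):
    """Overlapping windows: each token's embedding can be taken from
    the window where it is most central, at the cost of re-processing
    the overlapping tokens (the wasted computation noted above)."""
    last_start = max(1, len(tokens) - window + 1)
    return [tokens[i:i + window] for i in range(0, last_start, stride)]

tokens = list(range(1024))
print(len(disjoint_chunks(tokens)))   # 2 disjoint windows
print(len(strided_windows(tokens)))   # 3 overlapping windows: ~1.5x the work
```

With a stride of half the window, every token (past the first half-window) appears in a window where it has at least 256 tokens of left context, but roughly half of all token positions are processed twice.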

2. RELATED WORK

Many methods have been proposed to lower the memory footprint or computation time of transformer language models, or to allow them to be used on larger contexts. The Transformer-XL (Dai et al., 2019) allows a position within an attention window to attend to tokens from the previous windows by introducing relative position embeddings. While that mechanism, like ours, allows information to flow between windows, existing BERT and GPT-2 models do not use relative position embeddings, so training from scratch would be necessary to take advantage of this architecture. Additionally, each layer in the Transformer-XL attends to the previous layer in the previous window, so the maximum attention horizon is finite. Our recurrent method could theoretically pass information across an arbitrary distance, although one would not expect it to exceed the Transformer-XL's horizon without a much larger scale of data than we experiment with. We list here some other modifications of the transformer architecture, somewhat imprecisely grouping them for brevity; for a more detailed discussion, see Tay et al. (2020b). Child et al. (2019), Qiu et al. (2019), Kitaev et al. (2019), Sukhbaatar et al. (2019), and Roy et al. (2020) introduce sparsity to self-attention in various forms, reducing its memory cost. Rae et al. (2019) and Beltagy et al. (2020) (dynamically and statically, respectively) add extra tokens to attend to, which allow for global passing of information. Tay et al. (2020a) and Wu et al. (2019) replace dynamically computed self-attention with cheaper alternatives. While the above methods all allow for a reduction in computation, they also all require training from scratch. Our goal is to allow more efficient and powerful use of the wide array of existing pretrained models that cover many domains. Cao et al. (2020) propose the DeFormer, which also modifies the execution of a pretrained transformer. However, unlike our method, they decompose a single window into multiple windows by removing the attention interactions between these windows. This is largely orthogonal to our method, as one could both decompose windows of text and additionally use our method to allow information to be passed between neighboring windows.
Similarly, distilled versions of pretrained models such as DistilBERT (Sanh et al., 2019) provide more computational efficiency, but could be combined with our method to apply them to longer contexts or to reduce the quadratic cost of self-attention. Hao et al. (2019) apply pretrained transformers recurrently for machine translation, but do so by using an attention network to embed the document, applying a recurrent encoder to those embeddings, and using the recurrent encoder alongside a typical transformer encoder. This differs from our method: we fine-tune language models, which are transformer decoders, and directly modify the transformer's computation with a recurrent connection, rather than running an RNN on top of embeddings produced by a transformer.

3. METHOD

The main idea of our method is to take a transformer that was pretrained in a fixed context size setting and add recurrence at the level of T-token windows of text. For example, instead of executing the model on one 1000-token window of text, we could execute our model on 10 windows of 100 tokens. The first window is processed by the transformer model as normal, but for subsequent windows we add a supplementary embedding, which is generated using the hidden states from the preceding window (see Figure 1). The recurrence module is extremely small compared to the size of the transformer language model, so the additional computation required is negligible.
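As a structural sketch of this windowed execution, consider the following plain-Python toy. The mean-pool-plus-tanh layer and the toy "transformer" are hypothetical stand-ins for illustration, not the paper's exact recurrence module or GPT-2; only the control flow (summarize each window, feed the summary to the next) reflects the method.

```python
import math

def recurrence_module(hidden_states, weights):
    """Compress one window's hidden states (a list of k-dim vectors)
    into a single fixed-size k-dim summary: mean-pool over the window,
    then a small learned transform (a single tanh layer here)."""
    k = len(hidden_states[0])
    pooled = [sum(h[j] for h in hidden_states) / len(hidden_states)
              for j in range(k)]
    return [math.tanh(sum(weights[j][m] * pooled[m] for m in range(k)))
            for j in range(k)]

def run_with_recurrence(transformer, weights, token_windows):
    """Process windows left to right; every window after the first
    receives a fixed-size summary of the previous window's states."""
    summary, outputs = None, []
    for window in token_windows:
        hidden = transformer(window, summary)         # (T, k) hidden states
        summary = recurrence_module(hidden, weights)  # carried forward
        outputs.append(hidden)
    return outputs

# Toy demonstration: k = 4, scaled-identity weights, and a fake
# "transformer" whose outputs are shifted by the carried summary.
k = 4
weights = [[0.1 if j == m else 0.0 for m in range(k)] for j in range(k)]

def toy_transformer(window, prev_summary):
    base = prev_summary if prev_summary is not None else [0.0] * k
    return [[base[j] + 0.01 * tok for j in range(k)] for tok in window]

outs = run_with_recurrence(toy_transformer, weights, [[1, 2], [3, 4], [5, 6]])
```

Because the summary is a single k-dimensional vector, the memory cost of processing a window is independent of how many windows precede it, which is what allows the window size (and hence the quadratic attention cost) to be shrunk.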

3.1. ADDING RECURRENCE TO PRETRAINED TRANSFORMERS

Starting by defining terms, we will consider a pretrained transformer with L layers, a hidden state size of k, and a maximum context size of T tokens. Let h_i^(ℓ) ∈ ℝ^k be the output of the ℓ-th layer of the pretrained model at position i. To produce a fixed-size representation of tokens t_1, t_2, . . . , t_T,




