WAVEFORMER: LINEAR-TIME ATTENTION WITH FORWARD AND BACKWARD WAVELET TRANSFORM

Abstract

We propose Waveformer, which learns the attention mechanism in the wavelet coefficient space, requires only linear time complexity, and enjoys universal approximation power. Specifically, we first apply a forward wavelet transform to project the input sequences onto multi-resolution orthogonal wavelet bases, then conduct non-linear transformations (in this case, a random feature kernel) in the wavelet coefficient space, and finally reconstruct the representation in the input space via a backward wavelet transform. Since other non-linear transformations may also be used, we name the general learning paradigm Wavelet transformatIon for Sequence lEarning (WISE). We emphasize the importance of backward reconstruction in the WISE paradigm: without it, one would mix information from both the input space and the coefficient space through skip-connections, which is not mathematically sound. Compared with the Fourier transform used in recent works, the wavelet transform is more efficient in time complexity and better captures local and positional information; we further support this through our ablation studies. Extensive experiments on seven long-range understanding datasets from the Long Range Arena benchmark and on code understanding tasks demonstrate that (1) Waveformer achieves competitive and even better accuracy than a number of state-of-the-art Transformer variants, and (2) WISE can boost the accuracy of various attention approximation methods without increasing their time complexity. Together, these results showcase the superiority of learning attention in a wavelet coefficient space over the input space.

1. INTRODUCTION

Transformer (Vaswani et al., 2017) has become one of the most influential models in natural language processing (Devlin et al., 2018; Brown et al., 2020), computer vision (Dosovitskiy et al., 2020), speech processing (Baevski et al., 2020), code understanding (Chen et al., 2021a), and many other applications. It is composed of the attention layer and the feed-forward layer, with layer norms and skip-connections added in between. The original design of the attention layer scales quadratically with the sequence length, which becomes a scalability bottleneck of Transformers, as texts, images, speech, and code can be of vast lengths. State-of-the-art attention approximation methods have enabled Transformers to scale sub-quadratically or even linearly with the input sequence length. Typical approaches to computing a cheaper pseudo-attention include sparse attention patterns (Parmar et al., 2018; Wang et al., 2019; Beltagy et al., 2020; Zaheer et al., 2020), low-rank approximation (Wang et al., 2020; Chen et al., 2021b), and kernel approximation (Katharopoulos et al., 2020; Choromanski et al., 2020; Peng et al., 2020), where most of these methods are linear in time complexity. For a comprehensive review, please refer to Section 4.

Recent works on improving the effectiveness and efficiency of long-range capabilities of Transformers have started to explore attention learning in a transformed space. For example, conducting low-cost token-mixing with a forward Fourier transform leads to remarkable accuracy improvement with a quasi-linear time complexity (Lee-Thorp et al., 2021). Token-mixing ideas (You et al., 2020; Lee-Thorp et al., 2021) are simple and effective; however, they lose the Transformer's universal approximation power by replacing attention with hard averaging (Yun et al., 2019).
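To make the kernel-approximation idea concrete, the sketch below contrasts standard softmax attention, whose score matrix is O(n^2) in the sequence length n, with a linearized variant in the style of Katharopoulos et al. (2020) that reorders the matrix products to run in O(n). The ReLU-plus-epsilon feature map is an illustrative stand-in, not the specific kernel used by any of the cited methods.

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Standard attention: materializes an (n, n) score matrix -> O(n^2)."""
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V

def linear_kernel_attention(Q, K, V, feature_map=lambda x: np.maximum(x, 0.0) + 1e-6):
    """Kernel-approximated attention: reorders the matmuls to avoid the
    (n, n) matrix, giving O(n) time in the sequence length."""
    Qf, Kf = feature_map(Q), feature_map(K)  # (n, d) feature-mapped queries/keys
    KV = Kf.T @ V                            # (d, d): computed once, independent of n^2
    Z = Qf @ Kf.sum(axis=0)                  # (n,): per-row normalizer
    return (Qf @ KV) / Z[:, None]

rng = np.random.default_rng(0)
n, d = 128, 16
Q, K, V = rng.normal(size=(3, n, d))
out = linear_kernel_attention(Q, K, V)
assert out.shape == (n, d)
```

The key point is that `Kf.T @ V` is a small (d, d) matrix, so the cost is linear in n rather than quadratic.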
Moreover, without the backward transform, the model would mix information from both the input and transformed spaces, which is not mathematically sound.

We propose Waveformer, which facilitates attention mechanism learning in a wavelet coefficient space, as shown in Figure 1 (a). It requires only linear time complexity and enjoys universal approximation power. Specifically, we first apply a forward wavelet transform to project the input sequence onto multi-resolution orthogonal wavelet bases, then apply a non-linearity (e.g., a random feature kernel (Rahimi & Recht, 2007)) in the wavelet coefficient space, and finally reconstruct the representation in the input space via a backward wavelet transform. We name this general learning paradigm WISE, as shown in Figure 1 (b); it can be paired with attention approximation methods to boost their long-range understanding capabilities. We implement the wavelet transform using the Fast Wavelet Transform (FWT) (Mallat, 1989), so both transform steps run in linear time. Intuitively, WISE operates on a local-to-global, coarse-to-fine-grained cascading structure. Compared with the Fourier transform, the wavelet transform is more efficient in time complexity and better captures local and positional information, since the wavelet basis is localized in space with varying granularity. For the non-linear transformation in the wavelet coefficient space, one can apply any non-linearity, while we suggest using attention approximation methods. The reason is that, since the wavelet transform is invertible and exact, WISE is a universal approximator whenever it is coupled with a universal approximator as its non-linearity. We conduct extensive experiments on the Long Range Arena (LRA) benchmark and common code understanding tasks to empirically ablate and justify this method. Compared with a number of widely used Transformer variants, Waveformer with linear time complexity can achieve competitive and even better performance.
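The forward-backward schema above can be sketched end to end. The snippet below uses the Haar wavelet, the simplest orthogonal basis, purely for illustration (the paper's choice of basis and non-linearity may differ): a multi-level forward FWT in O(n) total work, an exact O(n) inverse, and a placeholder non-linearity (here `tanh`) applied only in the coefficient space, as WISE prescribes.

```python
import numpy as np

def haar_forward(x):
    """Multi-level Haar FWT; O(n) total work for a length-2^k signal."""
    coeffs, a = [], x.astype(float)
    while len(a) > 1:
        even, odd = a[0::2], a[1::2]
        coeffs.append((even - odd) / np.sqrt(2))  # detail coefficients at this scale
        a = (even + odd) / np.sqrt(2)             # coarser approximation
    coeffs.append(a)                              # final approximation coefficient
    return coeffs

def haar_backward(coeffs):
    """Inverse Haar FWT: exact reconstruction (orthogonal basis)."""
    a = coeffs[-1]
    for d in reversed(coeffs[:-1]):
        out = np.empty(2 * len(a))
        out[0::2] = (a + d) / np.sqrt(2)
        out[1::2] = (a - d) / np.sqrt(2)
        a = out
    return a

x = np.arange(8.0)
# Round trip is exact, so the forward/backward pair loses no information ...
recon = haar_backward(haar_forward(x))
assert np.allclose(recon, x)
# ... and a WISE-style non-linearity acts purely in the coefficient space.
y = haar_backward([np.tanh(c) for c in haar_forward(x)])
```

Because each level halves the signal, the total work is n + n/2 + n/4 + ... < 2n, which is the linear-time property the FWT provides.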
When combined with various representative attention approximation functions, WISE can boost their performance without incurring extra time complexity. This shows that learning in a wavelet coefficient space provides better long-range understanding capability than direct learning in the input space. Our ablation studies also support the use of the forward-backward schema and the superiority of the wavelet transform over the Fourier transform. In summary, our major contributions are as follows.

• We propose WISE to facilitate learning in the wavelet coefficient space following a forward-backward paradigm, which can be paired with attention approximation methods to boost their long-range understanding capabilities.

• Based on WISE, we develop Waveformer, which requires only linear time complexity and enjoys universal approximation power for sequence-to-sequence functions.

• Extensive experiments on the Long Range Arena benchmark and code understanding tasks demonstrate the effectiveness and justify the design of Waveformer.

Reproducibility. We will release our code on GitHub.

2. WAVEFORMER

As shown in Figure 1 (a), the only difference between a Transformer block and a Waveformer block is the attention computation. In this section, we introduce the details of how we replace the attention computation.



Figure 1: An overview of our proposed Waveformer and WISE. (a) The only difference between a Transformer block and a Waveformer block is the attention computation. (b) The general flow of computation in WISE with forward and backward wavelet transform.

Since multiplication in the Fourier coefficient space, after being projected back to the input space, is equivalent to directly calculating convolutions in the input space, prior work has also utilized the forward and backward Fourier transform to learn large global filters with linear weights (Rao et al., 2021) and non-linearities (Guibas et al., 2021).
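The equivalence invoked here is the discrete convolution theorem, which a few lines of NumPy can verify: pointwise multiplication of DFT coefficients, transformed back, matches circular convolution computed directly in the input space. The filter `h` below is a random placeholder standing in for a learned global filter.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 16
x = rng.normal(size=n)   # input sequence
h = rng.normal(size=n)   # stand-in for a learned "global filter"

# Pointwise multiplication in the Fourier coefficient space,
# projected back to the input space via the inverse DFT ...
y_fourier = np.fft.ifft(np.fft.fft(x) * np.fft.fft(h)).real

# ... equals circular convolution computed directly in the input space.
y_conv = np.array([sum(x[j] * h[(i - j) % n] for j in range(n))
                   for i in range(n)])

assert np.allclose(y_fourier, y_conv)
```

This is why learning weights in the coefficient space amounts to learning a global convolutional filter, at O(n log n) cost via the FFT instead of O(n^2) for the direct convolution.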

