NEURAL ATTENTION MEMORY

Abstract

Scaled dot-product attention has become the essence of state-of-the-art deep neural networks for various machine learning tasks. Despite its ubiquitous accomplishments, it is inefficient for long-sequence tasks and ill-suited for tasks requiring memory states, such as compositional generalization. We propose a novel perspective on the attention mechanism by reinventing it as a memory architecture for neural networks, namely Neural Attention Memory (NAM). NAM follows the same query-key-value structure by constructing a memory matrix while reducing the computational complexity from quadratic to linear in the sequence length. NAM writes to the memory matrix by adding outer products of value and unit key vectors, and reads from it by multiplying the matrix with a unit query vector. We define the read and write primitives of NAM and mathematically prove their functionalities. One benefit of NAM is that it can serve as a basis for efficient linear attention, namely normalized outer-product attention. We evaluate a NAM-based Transformer on long-range tasks and demonstrate NAM's efficiency and efficacy. Most importantly, NAM provides building blocks for memory-augmented neural networks. We propose two NAM-augmented neural networks, namely Long Short-term Attention Memory (LSAM) and the NAM Turing Machine (NAM-TM), and test their compositional generalization capabilities on four different tasks. LSAM replaces LSTM's long-term cell state with a NAM memory matrix, and NAM-TM implements a Turing tape structure using NAM read/write primitives. The experiments show that they have greater computational power than the Transformer and LSTM, as well as DNC. NAM opens up possibilities in diverse research problems, including hierarchical data modeling, efficient edge inference, and few-shot learning.

1. INTRODUCTION

Scaled dot-product attention (Vaswani et al., 2017) has become a core mechanism of state-of-the-art deep learning models for a variety of machine learning tasks, including natural language processing (Devlin et al., 2018), multi-modal tasks (Li et al., 2019), and graph data processing (Hamilton et al., 2017). In particular, Transformers using self-attention have replaced recurrent neural networks (RNNs) by outperforming them on most tasks. Despite this success, the mechanism has limitations. First, it needs the information of the entire sequence to compute a single attention output, so its computational complexity is quadratic in the length of the sequence. Hence, it is inefficient for long-sequence tasks (Tay et al., 2020) and edge inference environments (Tambe et al., 2020). Also, its stateless design enables efficient parallelism but makes it impossible to solve tasks that require memory states. As a result, Transformers fail to generalize rules that require inductive bias (Dehghani et al., 2018) or compositional generalization (Lake & Baroni, 2018).

There have been studies designing neural networks with external memory to solve algorithmic tasks where Transformers fail. These memory-augmented neural networks (MANNs) design differentiable read/write functions that can be trained by backpropagation. Some of them implement basic data structures like stacks (Joulin & Mikolov, 2015) and queues (Grefenstette et al., 2015), while others implement complex memory structures using attention mechanisms (Graves et al., 2014; 2016). They outperform generic neural networks on synthetic algorithmic tasks but are considered impractical due to their complexity and inefficiency.

In this work, we re-invent the attention mechanism as a memory architecture for neural networks, namely neural attention memory (NAM). NAM's design objective is to build a simple, efficient, yet powerful external memory that also incorporates the attention mechanism.
Following the same query-key-value structure of attention, NAM stores key-value pairs in a memory matrix by additively writing their outer products. Reading the memory matrix is done simply by multiplying the matrix with a unit query vector. We provide a mathematical formulation of the read/write primitives and present theoretical analyses showing that these primitives can replace attention.

One big benefit of NAM is that it can perform attention more efficiently. By sacrificing the erasure capability of the NAM write operation, we can design an efficient and parallel attention mechanism, namely normalized outer-product attention. This special variant of NAM is almost equivalent to linear attention (Katharopoulos et al., 2020), enjoying the same computational complexity, linear in the sequence length. We evaluate a NAM-based efficient Transformer on Long Range Arena (Tay et al., 2020) tasks. Its efficacy is on par with the base Transformer and the Linear Transformer, implying that NAM can be an efficient alternative to scaled dot-product attention.

The bigger value of NAM is that its read and write primitives can be building blocks for augmenting deep neural networks with memory structures. Using NAM read/write primitives, we design two memory-augmented neural networks (MANNs), namely Long Short-term Attention Memory (LSAM) and the NAM Turing Machine (NAM-TM). LSAM is a generic RNN architecture that replaces LSTM's long-term cell state with a memory matrix. Instead of additively updating a vector cell state, LSAM reads and writes the memory matrix using NAM primitives. The design combines the strengths of attention and RNNs while maintaining the same computational complexity as LSTM. NAM-TM is a MANN for algorithmic tasks that leverages a Turing tape structure. The tape has read and write heads accessing the memory with NAM read/write primitives. The heads can move along the tape with four actions: NO-OP, LEFT, RIGHT, and JUMP.
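The read/write primitives described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation: it assumes keys and queries are normalized to unit length before use, and omits the erasure capability mentioned later.

```python
import numpy as np

def unit(x):
    # Normalize a vector to unit length; NAM uses unit key/query vectors.
    return x / (np.linalg.norm(x) + 1e-8)

def nam_write(M, key, value):
    # Write: add the outer product of the value and the unit key vector.
    return M + np.outer(value, unit(key))

def nam_read(M, query):
    # Read: multiply the memory matrix with a unit query vector.
    return M @ unit(query)

# Writing one key-value pair and reading with the same key recovers the
# value exactly, since the unit key's dot product with itself is 1.
d_k, d_v = 4, 3
k = np.random.randn(d_k)
v = np.random.randn(d_v)
M = np.zeros((d_v, d_k))
M = nam_write(M, k, v)
r = nam_read(M, k)
```

Reading with a query that is only partially aligned with a stored key returns the value scaled by the cosine similarity between query and key, which is what lets the same primitive act as a soft attention lookup.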
The actions are implemented as differentiable functions to enable end-to-end training with backpropagation. We compare LSAM and NAM-TM against other models on the compositional generalization tasks of number sequence prediction (Nam et al., 2019), sequence reduction, and SCAN (Lake & Baroni, 2018). Specifically, we test their zero-shot length generalization by training the models on sequences of limited length and validating them on longer sequences unseen during training. The evaluation results show that their computational power is superior to that of other baselines, including the Universal Transformer (Dehghani et al., 2018) and DNC (Graves et al., 2016). While the generic LSAM model consistently outperforms the others, NAM-TM shows even better results on algorithmic tasks. The results indicate that NAM is a powerful method for implementing memory in neural networks.

The efficient, simple, and flexible structure of NAM opens up new possibilities in multiple machine learning research fields. One straightforward application is leveraging NAM's efficiency for edge inference environments. Another possibility is using NAM for hierarchical data modeling by generalizing NAM with tensor products. Moreover, memorizing input-output mappings with NAM can be a solution for one-shot and few-shot learning.

The main contributions of this work are as follows:
• We re-invent the attention mechanism as a memory architecture for neural networks, namely neural attention memory (NAM).
• We present a mathematical basis for the NAM read/write primitives, and give theoretical proofs that NAM is equivalent to attention under certain conditions.
• We show that NAM can construct an efficient Transformer for long-range sequence tasks.
• We propose two memory-augmented neural network designs, LSAM and NAM-TM, and show their capabilities on compositional generalization tasks.
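To make the tape-head actions differentiable, one common approach (used, for example, in Neural Turing Machines) is to keep a soft address distribution over tape cells and mix the shifted distributions by the predicted action probabilities. The sketch below illustrates that idea for NAM-TM's four actions; the exact mechanics, in particular how JUMP selects its target address, are assumptions for illustration and not a claim about the paper's implementation.

```python
import numpy as np

def move_head(addr, action_probs, jump_addr):
    # addr: soft address distribution over tape cells (non-negative, sums to 1).
    # action_probs: probabilities over the four actions [NO-OP, LEFT, RIGHT, JUMP].
    # jump_addr: hypothetical target distribution used by the JUMP action.
    p_noop, p_left, p_right, p_jump = action_probs
    return (p_noop * addr
            + p_left * np.roll(addr, -1)    # shift distribution one cell left
            + p_right * np.roll(addr, 1)    # shift distribution one cell right
            + p_jump * jump_addr)           # replace with the jump target

addr = np.array([0.0, 1.0, 0.0, 0.0])       # head sharply at cell 1
probs = np.array([0.0, 0.0, 1.0, 0.0])      # deterministic RIGHT
new_addr = move_head(addr, probs, np.ones(4) / 4)
```

Because the result is a convex combination of valid distributions, it remains a valid address distribution, and gradients flow through the action probabilities during backpropagation.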

2. BACKGROUND

2.1 SCALED DOT-PRODUCT ATTENTION

Attention mechanisms of deep neural networks (Bahdanau et al., 2014; Luong et al., 2015) provide differentiable methods of selectively attending to items in a variable-length sequence. While there are multiple variations of the attention mechanism, most of them share the same high-level structure: 1) compute the attention scores of the items, and 2) return the weighted sum of their vector representations using the scores. Among the variations, scaled dot-product attention (Vaswani et al., 2017) has been the most successful. Each token has a key vector and a value vector associated with it.
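The two-step structure above (score, then weighted sum) is standard; a minimal NumPy version of scaled dot-product attention, as defined by Vaswani et al. (2017), looks as follows:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    # Q: (n, d_k) queries, K: (m, d_k) keys, V: (m, d_v) values.
    d_k = Q.shape[-1]
    # Step 1: attention scores, scaled by sqrt(d_k) for stability.
    scores = Q @ K.T / np.sqrt(d_k)
    # Softmax over the keys (numerically stabilized).
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Step 2: weighted sum of the value vectors.
    return weights @ V
```

Note that computing `scores` requires all n queries against all m keys, which is the quadratic cost in sequence length that motivates the linear-complexity alternative developed in this paper.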

