WHY SELF-ATTENTION IS NATURAL FOR SEQUENCE-TO-SEQUENCE PROBLEMS? A PERSPECTIVE FROM SYMMETRIES

Abstract

In this paper, we show that structures similar to self-attention are natural for learning many sequence-to-sequence (seq2seq) problems, from the perspective of symmetry. Inspired by language processing applications, we study the orthogonal equivariance of seq2seq functions with knowledge: functions that take two inputs, an input sequence and a "knowledge", and output another sequence. The knowledge consists of a set of vectors in the same embedding space as the input sequence, and contains the information of the language used to process the input sequence. We show that orthogonal equivariance in the embedding space is natural for seq2seq functions with knowledge, and that under such equivariance the function must take a form close to self-attention. This suggests that network structures similar to self-attention are the right structures for representing the target functions of many seq2seq problems. The representation can be further refined if a "finite information principle" is considered, or if a permutation equivariance holds for the elements of the input sequence.

1. INTRODUCTION

Neural network models using self-attention, such as Transformers Vaswani et al. (2017), have become the new benchmark in fields such as natural language processing and protein folding. However, the design of self-attention is largely heuristic, and theoretical understanding of its success is still lacking. In this paper, we provide a perspective on this problem from the symmetries of sequence-to-sequence (seq2seq) learning problems. By identifying and studying appropriate symmetries for seq2seq problems of practical interest, we demonstrate that structures like self-attention are natural for representing these problems.

Symmetries in learning problems can inspire the invention of simple and efficient neural network structures. This is because symmetries reduce the complexity of the problems, and a network with matching symmetries can learn the problems more efficiently. For instance, convolutional neural networks (CNNs) have seen great success on vision problems, with the translation invariance/equivariance of these problems being one of the main reasons. This is not only observed in practice, but also justified theoretically Li et al. (2020b). Many other symmetries have been studied and exploited in the design of neural network models. Examples include permutation equivariance Zaheer et al. (2017) and rotational invariance Kim et al. (2020); Chidester et al. (2019), with various applications in learning physical problems. See Section 2.1 for more related work.

In this work, we start by studying the symmetry of seq2seq functions in the embedding space, the space in which each element of the input and output sequences lies. For a language processing problem, for example, words or tokens are usually vectorized by a one-hot embedding using a dictionary. In this process, the order of words in the dictionary should not influence the meaning of the input and output sentences.
Thus, if a permutation is applied to the dimensions of the embedding space, the input and output sequences should experience the same permutation, without any other changes. This implies a permutation equivariance in the embedding space. In our analysis, we consider equivariance under the orthogonal group, which is slightly larger than the permutation group. We show that if a function f is orthogonal equivariant in the embedding space, then its output can be expressed as linear combinations of the elements of the input sequence, with coefficients depending only on the inner products of these elements. Concretely, let X ∈ R^{d×n} denote an input sequence of length n in the embedding space R^d. If f(QX) = Qf(X) holds for any orthogonal Q ∈ R^{d×d}, then there exists a function g such that f(X) = Xg(X^T X).

However, the symmetry on the embedding space is actually more complicated than a simple orthogonal equivariance. In Section 3.2, we show that the target function for a simple seq2seq problem is not orthogonal equivariant, because the target function works in a fixed embedding. To accurately capture the symmetry in the embedding space, we propose to study seq2seq functions with knowledge, which are functions with two inputs, f(X, Z), where X ∈ R^{d×n} is the input sequence and Z ∈ R^{d×k} is another input representing our "knowledge" of the language. The knowledge lies in the same embedding space as X and is used to extract information from X. With this additional input, the symmetry in the embedding space can be formulated as an orthogonal equivariance of f(X, Z), i.e. f(QX, QZ) = Qf(X, Z) for any inputs and any orthogonal matrix Q. Intuitively, in a language application, as long as the knowledge is always in the same embedding as the input sequence, the meaning of the output sequence will not change with the embedding.
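The representation above can be checked numerically. The following sketch (not from the paper; the choice g = tanh is an arbitrary, hypothetical nonlinearity) constructs a function of the form f(X) = Xg(X^T X) and verifies that it satisfies f(QX) = Qf(X) for a random orthogonal Q:

```python
import numpy as np

def f(X):
    # Any map of the form X @ g(X^T X) is orthogonal equivariant in the
    # embedding space, since (QX)^T (QX) = X^T X for orthogonal Q.
    # Here g = tanh is an arbitrary, hypothetical choice of nonlinearity.
    return X @ np.tanh(X.T @ X)

rng = np.random.default_rng(0)
d, n = 8, 5
X = rng.standard_normal((d, n))          # input sequence: n elements in R^d
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))  # random orthogonal matrix

# Equivariance check: f(QX) == Q f(X)
assert np.allclose(f(Q @ X), Q @ f(X))
```

Each column of the output is a linear combination of the columns of X, with coefficients that depend on X only through the Gram matrix X^T X, exactly as the representation theorem requires.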
Based on the earlier theoretical result for simple orthogonal equivariant functions, if a seq2seq function with knowledge is orthogonal equivariant, then it must have the form

f(X, Z) = Xg_1(X^T X, Z^T X, Z^T Z) + Zg_2(X^T X, Z^T X, Z^T Z).

If Z is understood as a parameter matrix to be learned, the following subset of this representation,

f(X, Z) = Xg(X^T Z),

is close to the self-attention used in practice, with Z being the concatenation of the query and key parameters. This reveals one possible reason behind the success of self-attention based models on language problems.

Based on the results from orthogonal equivariance, we further study the permutation equivariance on the elements of the input sequence. Under this symmetry, we show that seq2seq functions with knowledge have a further reduced form which involves only four different nonlinear functions. Finally, we discuss the possible forms of g (or g_1 and g_2) in the formulations above. Based on the assumption that these functions are described by a finite amount of information (although their output sizes need to change with the sequence length n), we reason that a quadratic form with a nonlinearity, as used in usual self-attention, is one of the simplest choices of g. We also discuss practical considerations that add to the complexity of the models used in applications compared with the theoretical forms.
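To make the connection to self-attention concrete, the following sketch (not from the paper; taking g = column-wise softmax is a hypothetical instance of the representation) implements f(X, Z) = Xg(X^T Z) and verifies the joint equivariance f(QX, QZ) = Qf(X, Z):

```python
import numpy as np

def softmax(A, axis=0):
    # Numerically stable softmax along the given axis.
    A = A - A.max(axis=axis, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def f(X, Z):
    # f(X, Z) = X g(X^T Z) with g = softmax: a self-attention-like map in
    # which Z plays the role of learned query/key parameters.
    # X^T Z is invariant under X -> QX, Z -> QZ for orthogonal Q.
    return X @ softmax(X.T @ Z, axis=0)

rng = np.random.default_rng(1)
d, n, k = 8, 5, 3
X = rng.standard_normal((d, n))          # input sequence
Z = rng.standard_normal((d, k))          # "knowledge" in the same embedding space
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))  # random orthogonal matrix

# Joint equivariance check: f(QX, QZ) == Q f(X, Z)
assert np.allclose(f(Q @ X, Q @ Z), Q @ f(X, Z))
```

The output is a convex combination of the columns of X with weights softmax(X^T Z), mirroring how standard self-attention mixes value vectors with weights computed from query-key inner products.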

2. RELATED WORK

2.1 NEURAL NETWORKS AND SYMMETRIES

Implementing symmetries in neural networks can help the models learn certain problems more efficiently. A well-known example is the success of convolutional neural networks (CNNs) on image problems due to their (approximate) translation invariance LeCun et al. (1989). Many types of symmetries have been explored in the design of neural networks, such as permutation equivariance and invariance Zaheer et al. (2017); Guttenberg et al. (2016); Rahme et al. (2021); Qi et al. (2017a;b), rotational equivariance and invariance Thomas et al. (2018); Shuaibi et al. (2021); Fuchs et al. (2020); Kim et al. (2020), and more Satorras et al. (2021); Wang et al. (2020b); Ling et al. (2016a); Ravanbakhsh et al. (2017). Some works deal with multiple symmetries. In Villar et al. (2021), the forms of functions with various symmetries are studied. These networks see many applications in physical problems, where symmetries are intrinsic to the problems being learned. Examples include fluid dynamics Wang et al. (2020a); Ling et al. (2016b); Li et al. (2020a); Mattheakis et al. (2019), molecular dynamics Anderson et al. (2019); Schütt et al. (2021); Zhang et al. (2018), quantum mechanics Luo et al. (2021a;b); Vieijra et al. (2020), etc. Theoretical studies have also been conducted to show the benefit of preserving symmetry during learning Bietti et al. (2021); Elesedy & Zaidi (2021); Li et al. (2020b); Mei et al. (2021).

2.2 SELF-ATTENTION

Self-attention Vaswani et al. (2017); Parikh et al. (2016); Paulus et al. (2017); Lin et al. (2017); Shaw et al. (2018) is a type of attention mechanism Bahdanau et al. (2014); Luong et al. (2015) that attends

