DEEP TRANSFORMERS WITHOUT SHORTCUTS: MODIFYING SELF-ATTENTION FOR FAITHFUL SIGNAL PROPAGATION

Abstract

Skip connections and normalisation layers form two standard architectural components that are ubiquitous for the training of Deep Neural Networks (DNNs), but whose precise roles are poorly understood. Recent approaches such as Deep Kernel Shaping have made progress towards reducing our reliance on them, using insights from wide NN kernel theory to improve signal propagation in vanilla DNNs (which we define as networks without skips or normalisation layers). However, these approaches are incompatible with the self-attention layers present in transformers, whose kernels are intrinsically more complicated to analyse and control. And so the question remains: is it possible to train deep vanilla transformers? We answer this question in the affirmative by designing several approaches that use combinations of parameter initialisations, bias matrices and location-dependent rescaling to achieve faithful signal propagation in vanilla transformers. Our methods address several intricacies specific to signal propagation in transformers, including the interaction with positional encoding and causal masking. In experiments on WikiText-103 and C4, our approaches enable deep transformers without normalisation to train at speeds matching their standard counterparts, and deep vanilla transformers to reach the same performance as standard ones after about 5 times more iterations.

1. INTRODUCTION

Despite numerous impressive successes, the practice of training deep neural networks (DNNs) has progressed to a large extent independently of theoretical justification. Most successful modern DNN architectures rely on particular arrangements of skip connections and normalisation layers, but a general principle for how to use these components in new architectures (assuming they are even applicable) remains unknown, and their roles in existing ones are still not completely understood. The residual architecture, arguably the most popular and successful of these, was first developed in the context of convolutional networks (CNNs) (He et al., 2016), and later in self-attention networks, yielding the ubiquitous transformer architecture (Vaswani et al., 2017). One proposed explanation for the success of residual architectures is that they have superior signal propagation compared to vanilla DNNs (e.g. Balduzzi et al., 2017; Xiao et al., 2018; Hayou et al., 2019; De & Smith, 2020; Martens et al., 2021), where signal propagation refers to the transmission of geometric information through the layers of a DNN, as represented by a kernel function (Daniely et al., 2016; Poole et al., 2016; Schoenholz et al., 2017). Recently, using signal propagation principles to train DNNs at high depths, without the skip connections and/or normalisation layers found in residual architectures, has become an area of interest in the community. The reasons are two-fold. First, it would validate the signal propagation hypothesis for the effectiveness of residual architectures, thus clarifying our understanding of DNN trainability. And second, it could lead to general principles and techniques for achieving trainability in DNNs beyond the residual paradigm, with the potential for improved or more efficient architectures. For CNNs, Xiao et al. (2018) showed that improved signal propagation from better initialisation enables very deep vanilla networks to be effectively trained, although at significantly reduced speeds.

The key quantity that is analysed in signal propagation is the DNN's initialisation-time kernel, or more precisely, the approximate kernel given by the infinite-width limit (Neal, 2012; Matthews et al., 2018; Lee et al., 2018; Yang, 2019). For MLPs, and for CNNs that use a Delta-initialisation (Balduzzi et al., 2017; Xiao et al., 2018), this kernel can be written as a simple recursion over layers that involves only 2D functions, facilitating a straightforward analysis. Unfortunately, the evolution of the kernel across the layers of a transformer is more complicated, and as a result, existing approaches like Deep Kernel Shaping (DKS) are not applicable to transformers (or indeed any architecture that contains self-attention layers). More concretely, if X_l ∈ R^{T×d} denotes a length-T sequence of activations at layer l of a transformer, then the kernel matrix Σ_l = X_l X_l^⊤ / d ∈ R^{T×T} for layer l (or more precisely its limit as d → ∞) can be written as a function of the kernel matrix Σ_{l-1} of the previous layer (Hron et al., 2020).
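To make the objects in this recursion concrete, here is a small numpy sketch (our own illustration, not code from the paper; the sizes T and d are arbitrary) that computes a kernel matrix Σ_l from a matrix of activations and then forms the normalised kernel of cosine similarities plotted in Figure 1:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 6, 512                      # sequence length and (finite) width

X = rng.standard_normal((T, d))    # stand-in for the activations X_l
Sigma = X @ X.T / d                # kernel matrix Sigma_l = X_l X_l^T / d

# Normalised kernel diag(Sigma)^(-1/2) . Sigma . diag(Sigma)^(-1/2):
# the matrix of cosine similarities between sequence locations.
inv_sqrt_diag = 1.0 / np.sqrt(np.diag(Sigma))
C = Sigma * np.outer(inv_sqrt_diag, inv_sqrt_diag)

print(np.diag(C))   # unit diagonal by construction
```

By construction C has ones on its diagonal and entries in [-1, 1], so degradation of signal propagation shows up purely in the off-diagonal structure, as in Figure 1.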
In the case of self-attention layers, the dependence of Σ_l on Σ_{l-1} cannot be simplified or decomposed into lower-dimensional functions, leading to a recursion that is intrinsically high dimensional and harder to analyse or control. Analogously to the case of MLPs, where signal propagation is judged by looking at the behaviour of the (one-dimensional) kernel, signal propagation in transformers can be judged by looking at the evolution of these (high-dimensional) kernel matrices through the layers of the network. One situation we must avoid is where the diagonal entries rapidly grow or shrink with depth, which corresponds to uncontrolled activation norms and can lead to saturated losses or numerical issues. A more subtle form of signal degradation can occur where Σ_l converges to a rank-1 matrix, which is known as rank collapse (Dong et al., 2021). Dong et al. (2021) showed that skip connections are essential to avoid the collapsed state: skipless transformers quickly converge to rank collapse at large depths, which we corroborate in Fig. 1 (top). Moreover, Noci et al. (2022) showed that rank collapse may lead to zero gradients for certain parameters in attention layers, hindering the trainability of deep transformers. Thus, avoiding rank collapse is necessary for deep transformers to be trainable, and the question of whether one can train deep skipless transformers remains open. In the present work we address this question, demonstrating for the first time that it is possible to successfully train deep transformers without skip connections or normalisation layers. To do so, we study the problem of signal propagation and rank collapse in deep skipless transformers, and derive three approaches to prevent it in Section 3. Our methods use combinations of: 1) parameter initialisations, 2) bias matrices, and 3) location-dependent rescaling, and highlight several intricacies specific to signal propagation in transformers.
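Rank collapse is easy to reproduce in a toy numpy simulation (our own sketch, not an experiment from the paper): iterating a skipless, attention-only layer with randomly initialised query, key and value matrices drives every entry of the normalised kernel towards 1, i.e. Σ_l towards rank 1.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, depth = 10, 64, 30           # toy sizes chosen for illustration

def cosine_matrix(X):
    """Normalised kernel diag(S)^(-1/2) . S . diag(S)^(-1/2), S = X X^T / d."""
    S = X @ X.T / X.shape[1]
    s = 1.0 / np.sqrt(np.diag(S))
    return S * np.outer(s, s)

X = rng.standard_normal((T, d))
for _ in range(depth):
    # One attention-only layer with random Q, K, V (no skips, no LayerNorm).
    Q, K, V = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
    logits = (X @ Q) @ (X @ K).T / np.sqrt(d)
    A = np.exp(logits - logits.max(axis=-1, keepdims=True))
    A = A / A.sum(axis=-1, keepdims=True)   # row-wise softmax
    X = A @ (X @ V)

C = cosine_matrix(X)
print(C.min())   # all entries near 1: the normalised kernel has collapsed
```

In this sketch the minimum cosine similarity is already close to 1 after a few tens of layers, matching the collapsed kernels of standard vanilla transformers in Fig. 1 (top): each softmax step forms convex combinations of the rows of X, pulling all sequence locations towards a common direction.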



Figure 1: Normalised kernel matrices diag(Σ_l)^{-1/2} · Σ_l · diag(Σ_l)^{-1/2} (which are like kernel matrices except with cosine similarities instead of inner-products) at various depths for standard attention-only vanilla transformers and two of our proposed alternatives (Section 3). Standard attention-only vanilla transformers (top) quickly suffer from rank collapse, where all entries of the normalised kernel converge to 1, whereas our approaches, U-SPA and E-SPA, maintain controlled signal propagation even at large depths. Moreover, our main method E-SPA (bottom) exhibits a recency bias, where cosine similarities corresponding to nearby pairs of locations are larger, akin to positional encoding. Equivalent plots for attention-only transformers with skips and normalisation can be found in Fig. 7.

