σREPARAM: STABLE TRANSFORMER TRAINING WITH SPECTRAL REPARAMETRIZATION

Abstract

Training stability is of great importance to Transformers. In this work, we investigate the training dynamics of Transformers by examining the evolution of the attention layers. In particular, we track the "attention entropy" of each attention head over the course of training, which serves as a proxy for the sharpness of the attention distribution. We observe a common, non-monotonic evolution of attention entropy across different settings: the attention entropy first drops quickly in the initial phase of training, then rises quickly, and finally enters a long stable phase. While the exact shape is affected by hyperparameters such as warmup schedule, initialization, and learning rate, we find a close correlation between the minima of attention entropy and the model's training stability. To this end, we propose a simple and efficient solution dubbed σReparam, in which we reparametrize all linear layers with Spectral Normalization and an additional learned scalar. We provide a lower bound on the attention entropy as a function of the spectral norms of the query and key projections, which suggests that small attention entropy can be obtained with large spectral norms. σReparam decouples the growth rate of a weight matrix's spectral norm from its dimensionality, which we verify empirically. We conduct experiments with σReparam on image classification, image self-supervised learning, automatic speech recognition, and language modeling tasks, and show that σReparam provides great stability and robustness with respect to the choice of hyperparameters.

1. INTRODUCTION

Transformers (Vaswani et al., 2017) are state-of-the-art models in many application domains. However, despite their empirical success and wide adoption, great care often needs to be taken to achieve good training stability and convergence. In the original paper (Vaswani et al., 2017), residual connections and Layer Normalizations (LNs) (Ba et al., 2016) are used extensively in each Attention and MLP block (specifically, in the "Post Norm" fashion). There have since been various works attempting to promote better training stability and robustness; for example, the "Pre Norm" scheme (Radford et al., 2019), which moves the placement of LNs to the beginning of each residual block, has gained wide popularity.

In this work, we study the training instability of Transformers through the lens of training dynamics. We start by monitoring the average entropy of the attention heads (treating each attention head as a multinomial distribution) over all query positions and examples. Interestingly, the average attention entropy often evolves in a pattern consisting of three phases: it starts high (corresponding to uniform attention scores) and quickly drops to a small value; it then quickly increases to a relatively high-entropy regime; lastly, it stabilizes and smoothly evolves to convergence. See the top left plot of Figure 1 for an illustration, which shows a Vision Transformer (ViT) (Touvron et al., 2021) trained on ImageNet classification with well-optimized hyperparameters. Empirically, we have found that attention entropy is directly correlated with the model's stability and convergence. In particular, small attention entropy reached in the initial phase often causes slow convergence, fluctuations in training loss and, in the worst case, divergence. This is shown in Figure 1, where we vary the learning rate and warmup epochs of the baseline ViT model.
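The attention-entropy diagnostic described above can be sketched in a few lines of NumPy. This is an illustrative reimplementation, not the authors' code; the array layout (batch, heads, queries, keys) and the function name are our own assumptions:

```python
import numpy as np

def attention_entropy(attn, eps=1e-12):
    """Average entropy of the attention distributions.

    attn: array of shape (batch, heads, queries, keys), where each slice
    along the last axis is a softmax output (a multinomial distribution
    that sums to 1). Returns a scalar: the entropy averaged over batch,
    heads, and query positions, as tracked in Figure 1.
    """
    ent = -np.sum(attn * np.log(attn + eps), axis=-1)  # (batch, heads, queries)
    return float(ent.mean())

# Two extremes for a sequence of T keys: uniform attention attains the
# maximum entropy log(T); a one-hot (collapsed) row has entropy near 0.
T = 8
uniform = np.full((1, 1, 1, T), 1.0 / T)
onehot = np.zeros((1, 1, 1, T))
onehot[..., 0] = 1.0
```

Since the entropy of each row lies between 0 (fully collapsed attention) and log T (uniform attention), the averaged value can be compared directly against these two extremes while monitoring training.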
We see that both decreasing the learning rate and increasing the warmup epochs smooth the attention entropy curves, which in turn yields lower training losses. On the other hand, increasing the learning rate has a detrimental impact on training, where the attention entropy collapses to near zero and training diverges. We refer to this rapid dip of attention entropy to a near-zero value, and the resulting pathological optimization dynamics, as "entropy collapse". The remaining questions are: 1) How do we prevent entropy collapse? 2) Can we improve training stability by doing so? We answer them by showing that attention entropy is closely related to the spectral norms of the query and key projections. In particular, we derive a lower bound on the attention entropy which suggests that large spectral norms of the projections can more easily lead to entropy collapse. We then provide a simple fix, dubbed σReparam, which reparametrizes all weight matrices by sequentially applying Spectral Normalization (Miyato et al., 2018) and a learned multiplicative scalar. Intuitively, σReparam decouples the update of the spectral norms of the weights from their dimensionality, allowing them to change smoothly in a controlled way. Note also that σReparam does not change the model space, which allows one to learn an arbitrarily expressive model. We validate σReparam on four tasks: image classification, image self-supervised learning, automatic speech recognition, and language modeling. We show that σReparam effectively slows the growth of each layer's spectral norms and, as a result, greatly smooths the attention entropy curves. This yields strong robustness with respect to the choice of hyperparameters; in certain cases, we are able to remove Layer Norms entirely and still achieve competitive results.
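A minimal NumPy sketch of the σReparam reparametrization follows, with the spectral norm estimated by power iteration as in Miyato et al. (2018). The class name, the initialization of the learned scalar γ to 1.0, and the iteration count are illustrative choices on our part, not prescribed by the text:

```python
import numpy as np

def spectral_norm(W, n_iters=20, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=W.shape[0])
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v) + 1e-12
        u = W @ v
        u /= np.linalg.norm(u) + 1e-12
    # u and v approximate the top singular vectors, so u^T W v ~ sigma_1(W).
    return float(u @ W @ v)

class SigmaReparamLinear:
    """Linear map y = W_hat x with W_hat = (gamma / sigma(W)) * W.

    gamma is a learned scalar (updated by the optimizer alongside W);
    here it is simply stored and initialized to 1.0 for illustration.
    """
    def __init__(self, W):
        self.W = W
        self.gamma = 1.0

    def __call__(self, x):
        W_hat = (self.gamma / spectral_norm(self.W)) * self.W
        return W_hat @ x
```

Because W_hat = (γ / σ(W)) W, the spectral norm of the effective weight equals γ regardless of the shape or scale of W, which is the sense in which the reparametrization decouples spectral-norm growth from dimensionality.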

2. RELATED WORKS

Transformers have relied heavily on LNs to achieve training stability. Besides the popular Post Norm and Pre Norm configurations, other variants have been proposed (Wang et al., 2022; Shleifer et al., 2021). Others have argued that it is important to properly condition the residual connections. Bachlechner et al. (2021) propose to initialize the residual connections to zero to promote better signal propagation. Zhang et al. (2018); Huang et al. (2020) remove LNs with carefully designed initialization schemes.

Figure 1: The training loss curves of ViT-B on ImageNet, together with the attention entropy for three layers. From top left to bottom right: baseline with default hyperparameters from Touvron et al. (2021); 0.2× learning rate; 2× warmup epochs; 2× learning rate. We see a close correlation between the dip of the attention entropy and the convergence and stability of the training loss.

