NORMSOFTMAX: NORMALIZE THE INPUT OF SOFT-MAX TO ACCELERATE AND STABILIZE TRAINING

Abstract

Softmax is a basic function that normalizes a vector to a probability distribution and is widely used in machine learning, most notably in the cross-entropy loss function and in dot-product attention operations. However, the optimization of softmax-based models is sensitive to changes in the input statistics. We observe that the input of softmax changes significantly during the initial training stage, causing slow and unstable convergence when training a model from scratch. To remedy the optimization difficulty of softmax, we propose a simple yet effective substitute, named NormSoftmax, where the input vector is first normalized to unit variance and then fed to the standard softmax function. Similar to other existing normalization layers in machine learning models, NormSoftmax can stabilize and accelerate the training process, and it also increases the robustness of the training procedure against hyperparameters. Experiments on Transformer-based models and convolutional neural networks validate that our proposed NormSoftmax is an effective plug-and-play module to stabilize and speed up the optimization of neural networks with cross-entropy loss or dot-product attention operations.

1. INTRODUCTION

Softmax is a critical and widely used function in machine learning algorithms: it takes a vector as input and maps it onto the standard simplex, so it is usually used to generate a categorical probability distribution. The most notable applications of softmax are the cross-entropy loss function for classification tasks and attention-map generation in dot-product attention operations. By introducing a temperature into softmax, we can control the information entropy and sharpness of its output. However, gradient-based optimization of softmax-based models often suffers from slow and unstable convergence and is sensitive to optimization hyperparameters. Transformer-based models (Vaswani et al., 2017) are known to be hard to optimize, and much effort has been devoted to solving this optimization difficulty (Liu et al., 2020). For instance, Bolya et al. (2022) report that softmax attention may crash with too many heads and propose new attention functions. Chen et al. (2021) show that the Vision Transformer's (Dosovitskiy et al., 2021) loss landscape is very sharp and that advanced optimizers are required to facilitate its training (Foret et al., 2020). Huang et al. (2020) propose a better initialization to improve Transformer optimization. Xiong et al. (2020) show that the location of layer normalization (LN) has a remarkable impact on the gradients and claim that the Pre-LN Transformer has better training stability. Among the many causes of the optimization difficulty of Transformers, cascaded softmax functions are one contributor to training instability, yet limited prior work has discussed the impact of softmax on optimization. Based on our experiments, we find that the training difficulty can be attributed to the rapid change in the variance of the softmax inputs and in the information entropy of its outputs.
In dot-product attention, where softmax generates the weight distribution over key-value pairs, we observe significant statistical fluctuation in the softmax inputs. The rapid and extensive variance change in the initial learning stage can lead to unstable training. Moreover, for the softmax used in the cross-entropy loss of classification problems, the input of the softmax usually has a lower variance at the initial training stage, since the model has less knowledge of the problem (Wei et al., 2022). The model is likely to stay in this low-confidence zone, implying that it is difficult to train (Pearce et al., 2021). We need a specially designed mechanism to push the model out of the low-confidence zone for stable and fast learning.

Figure 1: The standard softmax and the proposed NormSoftmax.

In the two cases above, the significant change of the softmax input variance is one of the reasons for the optimization difficulty. In this paper, we propose NormSoftmax to stabilize and accelerate training by simply re-scaling the softmax inputs, especially during early-stage optimization. With NormSoftmax, we dynamically compute vector-specific factors to scale the inputs before they are fed to the standard softmax. Specifically, when the input variance is too small, softmax generates small gradients that hinder the learning process. In contrast, our proposed NormSoftmax re-scales the input distribution such that the information entropy of the output remains stable, without fluctuation, during the training process, which boosts and stabilizes early-stage training. NormSoftmax shares similar properties with existing normalization techniques in machine learning. We summarize its advantages below.

• NormSoftmax re-scales gradients to stabilize the training process, making training robust to model architectures and optimization recipes (such as optimizers and weight-decay schedules).

• NormSoftmax accelerates the early training stage without hurting the model's representability.

• NormSoftmax is an easy-to-use, low-cost replacement for the standard softmax. The induced computation and memory overhead is negligible.

• NormSoftmax has a regularization effect, since the re-scaling slightly restricts the representation space of the input vectors.

In this paper, we focus on two applications of the softmax function: (1) the activation function in dot-product attention, and (2) the cross-entropy loss of classification problems. ViT-B with our NormSoftmax shows significantly higher robustness to different head settings, achieving an average of +4.63% higher test accuracy on CIFAR-10 than its softmax-based counterpart.
When training for 100 epochs on ImageNet-1K, ViT with our NormSoftmax achieves +0.91% higher test accuracy than its softmax baseline.
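As a concrete illustration of the idea, here is a minimal NumPy sketch. It assumes NormSoftmax simply divides the input vector by its standard deviation before the standard softmax; the exact normalization details (e.g., whether the mean is subtracted, and the stabilizing `eps` value) are our assumptions, not a verbatim description of the paper's implementation.

```python
import numpy as np

def softmax(x):
    # Numerically stable standard softmax.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def norm_softmax(x, eps=1e-6):
    # Sketch of NormSoftmax: re-scale the input vector to (roughly)
    # unit standard deviation, then apply the standard softmax.
    # `eps` guards against division by zero and is our choice.
    x = np.asarray(x, dtype=float)
    return softmax(x / (x.std() + eps))

# Low-variance logits, as in the early training stage: the plain softmax
# output is near-uniform, while NormSoftmax restores a fixed input scale.
logits = np.array([0.01, 0.03, 0.02, 0.04])
p_plain = softmax(logits)
p_norm = norm_softmax(logits)
```

Because the output depends only on the direction of the input (up to `eps`), the entropy of the NormSoftmax output stays in a stable range regardless of how the input variance drifts during training.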

2. BACKGROUND

We briefly introduce the softmax function and normalization in machine learning. Then we discuss the two cases we focus on in this paper: softmax in dot-product attention and in the cross-entropy loss. Throughout this paper, we use µ(a), σ(a) to represent the mean and standard deviation (square root of the variance) of a vector a.

2.1. SOFTMAX

The standard softmax function z = softmax(x), where x, z ∈ R^n, is defined by Equation 1:

z_i = e^{x_i} / \sum_{j=1}^{n} e^{x_j},  for i = 1, 2, ..., n.  (1)

The output of softmax can be seen as a categorical probability distribution, since 0 < z_i < 1 and \sum_i z_i = 1. Instead of e, we can also use a different base; a temperature parameter T > 0 is introduced to adjust it:

softmax_T(x) = softmax(x / T).

Given the same input vector x, a higher temperature smooths the differences among the entries of the input vector and generates a probability distribution with high information entropy H(z) = -\sum_i z_i log(z_i). On the contrary, a lower temperature sharpens the output distribution, yielding low entropy. Agarwala et al. (2020) claim that the temperature has a crucial impact on the initial learning process.
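The temperature relation above is easy to check numerically. The short sketch below (our own illustration, not code from the paper) implements softmax_T(x) = softmax(x / T) and verifies that a higher temperature yields a higher-entropy output:

```python
import numpy as np

def softmax_t(x, T=1.0):
    # Softmax with temperature: softmax_T(x) = softmax(x / T).
    # Subtracting the max before exponentiation keeps it numerically stable.
    z = (np.asarray(x, dtype=float) - np.max(x)) / T
    e = np.exp(z)
    return e / e.sum()

def entropy(p):
    # Information entropy H(z) = -sum_i z_i log(z_i).
    return -np.sum(p * np.log(p))

x = np.array([1.0, 2.0, 3.0])
# Higher temperature -> smoother, higher-entropy distribution;
# lower temperature -> sharper, lower-entropy distribution.
assert entropy(softmax_t(x, T=10.0)) > entropy(softmax_t(x, T=1.0))
assert entropy(softmax_t(x, T=1.0)) > entropy(softmax_t(x, T=0.1))
```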

