IMPROVED FULLY QUANTIZED TRAINING VIA RECTIFYING BATCH NORMALIZATION

Abstract

Quantization-aware Training (QAT) can reduce the training cost by quantizing neural network weights and activations in the forward pass, and improves speed at the inference stage. QAT can be extended to Fully-Quantized Training (FQT), which further accelerates training by also quantizing gradients in the backward pass, since backpropagation typically occupies half of the training time. Unfortunately, gradient quantization is challenging, as Stochastic Gradient Descent (SGD)-based training is sensitive to the precision of the gradient signal. In particular, the noise introduced by gradient quantization accumulates during the backward pass, which causes the exploding gradient problem and results in unstable training and a significant accuracy drop. Although Batch Normalization (BatchNorm) is the de-facto tool for stabilizing training in the regular full-precision scenario, we observe that it fails to prevent gradient explosion when gradient quantizers are injected into the backward pass. Surprisingly, our theory shows that BatchNorm can amplify the noise accumulation, which in turn hastens the explosion of gradients. From this theory we derive a BatchNorm rectification method that suppresses the amplification effect and bridges the performance gap between full-precision training and FQT. Adding this simple rectification loss to baselines yields better results than most prior FQT algorithms on various neural network architectures and datasets, regardless of the gradient bit-widths used (8, 4, and 2 bits).

1. INTRODUCTION

Quantization-aware Training (QAT) is a popular line of research that simulates neural network quantization (of weights and activations) during training to curb the inference-time accuracy drop of low-bit models (e.g., INT8 quantization). On the other hand, theoretical calculations of BitOps computation costs (Yang & Jin (2021); Guo et al. (2020)) readily show that backpropagation accounts for half of the computation during training. Empirical data¹ shows that the backward pass sometimes costs even more in practice. Decreasing the gradient bit-widths reduces the computation overhead of backpropagation Horowitz (2014). If the variables in the backward pass are also quantized, then, together with the forward quantization in QAT, all network variables required in training are fully quantized and the whole training process could be accelerated on dedicated hardware, namely Fully-Quantized Training (FQT), making large-model training accessible to users with limited computational resources. Recent work Zhu et al. (2020) has shown that INT8 FQT speeds up the forward and backward passes by 1.63× and 1.94×, respectively, when training ResNet-50 on ImageNet with an NVIDIA Pascal GPU. Yet gradient quantization under the FQT scheme remains vastly underexplored, as it is notoriously more challenging than forward quantization in QAT. Network training is known to be sensitive to the precision of gradients, and low-bit gradient quantization leads to unstable training and a significant accuracy drop (see Fig. 1). More importantly, the accumulation of gradient quantization noise in the backward pass (see Fig. 2) causes the exploding gradient problem during backpropagation, which can even result in training failure. In contrast to weight/activation quantization noise, gradient quantization noise produced during backpropagation cannot be automatically corrected by optimizing the objective loss.
Unlike prior works that optimize gradient quantizers to reduce quantization noise Zhou et al. (2016); Choi et al. (2018); Zhu et al. (2020), this paper reveals the negative effect of Batch Normalization (BatchNorm) in amplifying the accumulation of gradient quantization noise when training deep Convolutional Neural Networks (CNNs) with low-bit gradients. We show that this noise amplification effect further explodes the gradients during the backward pass. We thus propose a BatchNorm rectification method to suppress the noise amplification effect and alleviate the gradient explosion problem, which in turn enables stabilized training and better accuracy with low-bit gradients. Our contributions are summarized as follows:
• Through theoretical analysis, we find that BatchNorm fails to prevent the explosion of low-bit gradients in fully-quantized training, and may even amplify the accumulated gradient quantization noise, which further aggravates the gradient explosion.
• Based on our theory, we propose a simple yet effective BatchNorm variance rectification algorithm that introduces no noticeable overhead and suppresses the noise amplification effect, thereby alleviating gradient explosion.
• Extensive experiments on MNIST, CIFAR-10, and ImageNet show that, using only vanilla gradient quantizers, our method improves training and achieves higher accuracy than state-of-the-art methods, regardless of the gradient bit-widths used (8, 4, and 2 bits).

2. RELATED WORK

Quantization-Aware Training (QAT). DoReFa-Net Zhou et al. (2016) proposed to optimize the clipping value and the scaling factor of the uniform quantizers for weights and activations separately. It was validated on image classification under multiple bit-widths, but only with the relatively simple AlexNet architecture. PACT Choi et al. (2018) proposed to quantize activations with a learnable layer-wise clipping value, which, not surprisingly, achieved better accuracy than DoReFa-Net.

Norm-free Networks. Another interesting line of work removes normalization layers from CNN architectures altogether Zhang et al. (2018b) while still training full-precision networks stably. However, norm-free networks have rarely been adapted to QAT, let alone FQT, where the problems introduced by low-bit quantizers would need to be addressed in the first place. We therefore consider this line of work not yet mature enough to discuss in this paper, which targets FQT.

Prior efforts reducing variances in backpropagation. Rectifying gradient variances during backpropagation has been discussed sporadically for full-precision training, e.g., in Kaiming initialization He et al. (2015), which leverages the weight distributions in Conv and FC layers; one can opt for backward variance rectification if the gradients are observed to be chaotic. However, optimal forward and backward rectification cannot be satisfied simultaneously, especially when the backward signals contain a significant amount of quantization noise. In our attempts, using the "fan-out" mode of Kaiming initialization alone still could not avoid training crashes in harder cases (e.g., MobileNet-V2 under W4A4G4). To our knowledge, no prior work has studied the impact of normalization layers in CNNs on gradient rectification under the FQT setting.

3. PRELIMINARIES

Key Notations. We denote the variance of a random variable as $D(\cdot)$. We denote the gradient w.r.t. weights as $g_w$ and the error signal as $g_x$. We use the plural term "gradients" to refer generally to all backward-pass variables, including $g_w$ and $g_x$; in such contexts, the subscripts of $g$ are omitted. Similar to the additive quantization noise for weight and activation quantization in Meller et al. (2019), the quantized gradients $\tilde{g}$ at each layer can be decomposed into three parts:

$$\tilde{g} = g + e(g) + \delta_q(g), \quad (1)$$

where $g$ and $e(g)$ denote the original gradients and the gradient quantization noise at the current layer, respectively, and $\delta_q(g)$ represents the accumulated gradient quantization noise propagated from all succeeding layers during the backward pass. For example, for the $(l-1)$-th layer, $\delta_q(g_{a_{l-1}})$ represents the quantization noise accumulated from layer $l$ to the last layer (see Fig. 3). In this manuscript, we denote the bit-widths of weights (W), activations (A), backward errors (dx), and gradients of weights (dW) used in experiments as W/A/dx/dW when dx and dW are assigned different bit-widths, or simply as W/A/G when dx and dW share the same bit-width.

Batch Normalization. Batch Normalization Ioffe & Szegedy (2015) is a widely adopted technique for stabilizing the training of deep full-precision networks. The forward pass of a BatchNorm layer computes the mean and variance of each channel over a mini-batch of $N$ input samples $\{x_n\}_{n=1}^{N}$. Each channel of the input $x_n$ is first normalized to $\hat{x}_c = (x_c - \mu_c)/\sigma_c$. The normalized input $\hat{x}$ is then linearly scaled and shifted as $y = \gamma \hat{x} + \beta$. The relationship between the backward error on BatchNorm's input $x_n$ and that on $\hat{x}_n$ is (channel indices are omitted for simplicity):

$$g_{x_n} = \frac{1}{N\sigma}\left(N g_{\hat{x}_n} - \sum_{m=1}^{N} g_{\hat{x}_m} - \hat{x}_n \sum_{m=1}^{N} g_{\hat{x}_m}\hat{x}_m\right). \quad (2)$$

Figure 3: Illustration of weight/activation quantization in the forward pass and gradient quantization in the backward pass, shown for the $(l-1)$-th and $l$-th Conv/FC layers followed by BN-ReLU. MAC denotes the multiply-accumulate operations. $Q(\cdot)$ denotes the quantizer for weights (W), activations (A), or gradients (G). In the backward pass, the gradient quantization noise $\delta_q(g_{a_{l-1}})$ accumulated from the $l$-th layer to the last layer is propagated to the $(l-1)$-th layer and added to the $(l-1)$-th layer's own gradient quantization noise $e(g_{a_{l-1}})$.

4. PROBLEM IDENTIFICATION: ACCUMULATION OF GRADIENT QUANTIZATION NOISE EXPLODES GRADIENTS

Gradients play a crucial role in backpropagation-based optimization and have a large impact on training stability and convergence speed. Intuitively, as we inject quantizers into the backward pass, the error signal becomes noisier each time it passes through a quantizer. As the bit-width decreases, the quantization noise injected into the error signal at each layer increases exponentially. From the perspective of variance, since quantization adds noise $e(g)$ to the original signal $g$, the variance of the quantization noise $D(e(g))$ is also added to that of the error signal $D(g)$; this extra variance is carried into the propagated errors and grows over the course of backpropagation. As shown in Fig. 2, under various gradient bit-widths, the variance of the quantized error signals expands during backpropagation. As the bit-width decreases, the variance inflates much more drastically in shallow layers than in deep layers, implying that quantization noise is the culprit of the variance accumulation and explosion. As a result, the error signals and weight updates are severely affected, especially in the early layers. In the worst cases, when the accumulated quantization noise becomes overwhelming, training goes off course and crashes (e.g., under extremely low bit-widths, see Fig. 7).

Fig. 3 illustrates the accumulation mechanism of the gradient quantization noise in the backward pass. During backpropagation, quantizing the error signal at the $l$-th layer introduces the quantization noise $e(g_{a_l})$, which is propagated to its predecessor, the $(l-1)$-th layer, together with the gradient quantization noise $\delta_q(g_{a_l})$ accumulated from the $(l+1)$-th layer to the last layer $L$. Similarly, $e(g_{a_{l-1}})$ and $\delta_q(g_{a_{l-1}})$ are propagated to the $(l-2)$-th layer, and so on. As shown in Fig. 3, the forward and backward passes exhibit a similar accumulation phenomenon, so why does the backward pass suffer more from quantization during training? Quantization noise introduced in the forward pass is reflected in the computation graph w.r.t. the objective loss, so its impact can be partially offset by quantization-aware training, whereas quantization noise introduced during the backward pass does not contribute to the objective loss. Moreover, distributed gradient compression Alistarh et al. (2017), which quantizes full-precision gradients only after backpropagation has finished, can stably train CNNs with gradients as low as 4 bits. We therefore attribute the exploding gradient problem in the FQT setting mainly to the accumulation of gradient quantization noise introduced by low-bit gradient quantizers.
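To make the additive-variance argument above concrete, here is a toy sketch (not the authors' code) under three loudly stated assumptions: (1) the symmetric uniform quantizer of Appendix A.1, whose de-quantization step is $2c/(2^{B-1}-1)$; (2) the common model in which the rounding noise $e(g)$ is uniform within one step, giving variance $\Delta^2/12$; and (3) a hypothetical constant per-layer backward variance gain `gain`, with all noise terms independent. Under these assumptions $D(\delta_q)$ grows linearly with the number of quantized layers traversed (for `gain=1`) and roughly quadruples for every bit removed. All function names are illustrative.

```python
from typing import List


def quant_step(clip: float, bits: int) -> float:
    """De-quantization step of the symmetric uniform quantizer Quant_u(g, -c, c, B):
    (b - a) / (2^(B-1) - 1) with b - a = 2c. Requires bits >= 2."""
    return 2.0 * clip / (2 ** (bits - 1) - 1)


def rounding_noise_var(clip: float, bits: int) -> float:
    """Variance of e(g) under the (assumed) uniform-rounding-noise model: step^2 / 12."""
    return quant_step(clip, bits) ** 2 / 12.0


def accumulated_noise_var(num_layers: int, clip: float, bits: int,
                          gain: float = 1.0) -> List[float]:
    """Toy model of D(delta_q) entering each layer, listed from the last layer backwards.

    Each quantized layer adds fresh, independent noise e(g); the backward map then
    scales the variance of everything it passes on by gain**2 (hypothetical constant).
    """
    v_e = rounding_noise_var(clip, bits)
    delta_var = 0.0  # no accumulated noise reaches the last layer
    history = []
    for _ in range(num_layers):
        history.append(delta_var)
        delta_var = gain ** 2 * (delta_var + v_e)
    return history


if __name__ == "__main__":
    for bits in (8, 4, 2):
        acc = accumulated_noise_var(num_layers=20, clip=1.0, bits=bits)
        print(f"{bits}-bit: D(delta_q) entering the first of 20 layers = {acc[-1]:.3e}")
```

The real backward maps are neither constant nor noise-independent, so this only illustrates the direction of the effect described above, not its magnitude in an actual network.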

5. BATCHNORM AMPLIFIES THE ACCUMULATED NOISE

In this section, we develop a theoretical framework to understand the role of BatchNorm in gradient quantization, explaining why BatchNorm may worsen the gradient explosion problem in FQT. Proofs of the theorems are given in Appendix A.2.

Quantifying the Impact of BatchNorm on Noise Accumulation. Through theoretical analysis, we find that BatchNorm may amplify the accumulation of gradient quantization noise. This finding may seem counter-intuitive, as BatchNorm is expected to regularize the "variance" and prevent the gradient explosion problem by scaling the activations in the forward pass. However, vanilla BatchNorm not only focuses mainly on rectifying variances in the forward pass, but also does not account for the situation in which the gradient signals are noisy. When the error signal passes through the scaling operation inside BatchNorm, it is scaled by the reciprocal of the corresponding scaling factor ($\sigma$) when the derivative is computed. In full-precision training, such scaling of the error signal is manageable. But when the error signal contains noise accumulated from succeeding layers, the noise is scaled at the same time, causing unpredictable behavior in backpropagation. Hence our theoretical focus is fundamentally different from previous "variance rectification and reduction" studies. The following theorem quantifies specifically how BatchNorm affects the accumulation effect.

Assumption 1. $\delta_q(g_{\hat{x}_i})$ and $\hat{x}_i$ are i.i.d. and both zero-mean Zhao et al. (2021).

Theorem 1. Given Assumption 1, for a BatchNorm layer with batch size $N$ in a quantized network, the relationship between the gradient quantization noise w.r.t. BatchNorm's input $x_i$ and that w.r.t. the normalized input $\hat{x}_i$ depends on the $\sigma$ of BatchNorm, in the form of

$$\eta = \frac{D(\delta_q(g_{x_i}))}{D(\delta_q(g_{\hat{x}_i}))} = \frac{1}{N^2\sigma^2}\left(N^2 + 2N\right). \quad (3)$$

Remark 1.1. In Eq. (3), we define the amplification factor of the accumulated noise as $\eta$, the ratio between the variance of the scaled error-signal noise $\delta_q(g_{x_i})$ and that of the noise before scaling, $\delta_q(g_{\hat{x}_i})$ (see Fig. 4).

Corollary 1.1. To prevent BatchNorm from introducing more gradient quantization noise (i.e., increasing the variance $D(\delta_q(g_{x_i}))$) when propagating the error signal to preceding layers, a desirable $\eta^*$ should not be greater than 1. Since Eq. (3) can be rewritten as $\eta = (1 + 2/N)/\sigma^2$, a desirable $\sigma$ of the BatchNorm layer should satisfy

$$\sigma \geq \sigma^* = \sqrt{1 + \frac{2}{N}}.$$

As the accumulated quantization noise in the backward pass cannot be automatically amortized by the training objective in FQT, one is left with two options: (1) minimize the primary source of noise by improving the gradient quantizer design, or (2) minimize the accumulation of such noise. Since the first option is the obvious one, in this paper we instead aim to raise awareness of the second option, namely the importance of properly scaling the noisy error signal to alleviate the noise amplification problem in the backward pass and thereby improve training stability.
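The closed-form quantities in Eq. (2), Eq. (3), and Corollary 1.1 can be evaluated directly. The sketch below (hypothetical helper names, pure Python, a single channel, biased batch variance with no epsilon) implements the BatchNorm forward normalization, the backward map of Eq. (2), which makes the $1/\sigma$ scaling of the error signal explicit, the amplification factor $\eta$ of Eq. (3), and the threshold $\sigma^*$. It only evaluates the expressions stated above and is not a re-derivation of Theorem 1. For example, with $N = 32$ and $\sigma = 1$, $\eta = 1 + 2/32 = 1.0625$ and $\sigma^* = \sqrt{1.0625} \approx 1.031$.

```python
import math
from typing import List, Tuple


def bn_normalize(x: List[float]) -> Tuple[List[float], float]:
    """BatchNorm forward for one channel: x_hat = (x - mu) / sigma, biased variance, no epsilon."""
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / n)
    return [(v - mu) / sigma for v in x], sigma


def bn_input_grad(g_xhat: List[float], x_hat: List[float], sigma: float) -> List[float]:
    """Eq. (2): error on BatchNorm's input x_n given the error on x_hat_n (one channel)."""
    n = len(g_xhat)
    sum_g = sum(g_xhat)
    sum_gx = sum(g * xh for g, xh in zip(g_xhat, x_hat))
    # The whole bracket is divided by N * sigma, so any noise in g_xhat is scaled by 1/sigma.
    return [(n * g - sum_g - xh * sum_gx) / (n * sigma) for g, xh in zip(g_xhat, x_hat)]


def amplification_factor(sigma: float, n: int) -> float:
    """Eq. (3): eta = (N^2 + 2N) / (N^2 * sigma^2)."""
    return (n * n + 2 * n) / (n * n * sigma * sigma)


def sigma_star(n: int) -> float:
    """Corollary 1.1: the smallest sigma for which eta <= 1."""
    return math.sqrt(1.0 + 2.0 / n)


if __name__ == "__main__":
    n = 32
    print(amplification_factor(1.0, n))  # 1.0625
    print(round(sigma_star(n), 4))       # 1.0308
```

Halving $\sigma$ doubles every entry returned by `bn_input_grad`, which is the scaling behavior the Remark and Corollary build on.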

6. OUR METHOD: RECTIFYING BATCHNORM FOR GRADIENT QUANTIZATION

Inspired by our theory in Sec. 5, we develop a principled solution to suppress the noise amplification effect, which in turn reduces gradient explosion and improves the training of quantized networks with low-bit gradients. Based on Theorem 1 and Corollary 1.1, we expect $\sigma$ to be no smaller than the ideal value $\sigma^*$, so that the noise amplification factor $\eta = D(\delta_q(g_{x_i}))/D(\delta_q(g_{\hat{x}_i}))$ between the output and the input of BatchNorm (in the backward direction) does not exceed 1. Therefore, we propose to stabilize training with low-bit gradients by rectifying the variance of BatchNorm computed in the forward pass:

$$\min_{w} f(w) \quad \text{s.t.} \quad L_\sigma = \frac{1}{L}\sum_{l=1}^{L} \mathrm{MSE}\left(\min\left(\frac{\sigma_l}{\sigma^*}, 1\right), 1\right) = 0, \quad (4)$$

where $f: S \rightarrow \mathbb{R}$ is the regular objective loss (e.g., cross-entropy loss for classification) and $w \in S$ denotes the neural network weights. $\sigma_L = \{\sigma_l\}_{l=1}^{L}$ is the set of $\sigma$ from all BatchNorm layers in a network of depth $L$. We use the Mean Square Error (MSE) between $\min(\sigma_l/\sigma^*, 1)$ and 1 to push $\sigma_l$ at each BatchNorm layer up to $\sigma^*$ whenever it falls below. The Lagrangian form of Eq. (4) is:

$$f'(w) = f(w) + \lambda L_\sigma, \quad (5)$$

where $\lambda$ is an adjustable parameter that balances $f(w)$ and the proposed rectification loss $L_\sigma$.

Gradient Computation and Computation Overhead. The overhead introduced by the proposed rectification term has only linear time complexity. During backpropagation, the error signal of the rectification term $L_\sigma$ propagated to BatchNorm's input $a_l$ on channel $i$ at layer $l$ is:

$$\frac{\partial L_\sigma}{\partial a_l^i} = \begin{cases} \dfrac{2}{\sigma^* N}\left(\dfrac{1}{\sigma^*} - \dfrac{1}{\sigma_l^i}\right)\left(a_l^i - \mu_l^i\right), & \sigma_l^i < \sigma^*; \\ 0, & \text{otherwise}, \end{cases} \quad (6)$$

where $\sigma^*$, $N$, and the channel-wise BatchNorm statistics $\mu_l^i$ and $\sigma_l^i$ are all constant scalars during backpropagation. The added rectification is therefore an element-wise affine map whose cost is negligible on devices that support vectorized computation, such as GPUs. A more detailed analysis can be found in Appendix A.5.
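The rectification term can be sketched in a few lines of pure Python under our reading of Eq. (4) to Eq. (6): each BatchNorm channel contributes $(\min(\sigma/\sigma^*, 1) - 1)^2$, these terms are averaged over the channels of a layer (the MSE) and then over the $L$ layers, and Eq. (6) is the derivative of a single channel's term (the $1/L$ and channel-averaging constants are omitted there, as in the text). The sketch assumes biased batch variance without an epsilon; all function names are hypothetical, and it is not the authors' implementation, which would operate on framework tensors.

```python
import math
from typing import List, Tuple


def channel_stats(acts: List[float]) -> Tuple[float, float]:
    """Biased mini-batch mean and std of one channel (BatchNorm training-mode statistics)."""
    n = len(acts)
    mu = sum(acts) / n
    var = sum((a - mu) ** 2 for a in acts) / n
    return mu, math.sqrt(var)


def sigma_star(n: int) -> float:
    """Corollary 1.1 threshold for batch size N."""
    return math.sqrt(1.0 + 2.0 / n)


def channel_term(acts: List[float]) -> float:
    """(min(sigma / sigma*, 1) - 1)^2 for one channel; exactly zero once sigma >= sigma*."""
    _, sigma = channel_stats(acts)
    return (min(sigma / sigma_star(len(acts)), 1.0) - 1.0) ** 2


def rectification_loss(layers: List[List[List[float]]]) -> float:
    """Eq. (4): L_sigma = (1/L) * sum over layers of the MSE over that layer's channels.
    layers[l][c] holds the N mini-batch activations of channel c at layer l."""
    per_layer = [sum(channel_term(ch) for ch in layer) / len(layer) for layer in layers]
    return sum(per_layer) / len(per_layer)


def channel_term_grad(acts: List[float]) -> List[float]:
    """Eq. (6): gradient of channel_term w.r.t. each activation of the channel."""
    n = len(acts)
    s_star = sigma_star(n)
    mu, sigma = channel_stats(acts)
    if sigma >= s_star:
        return [0.0] * n
    coef = 2.0 / (s_star * n) * (1.0 / s_star - 1.0 / sigma)
    return [coef * (a - mu) for a in acts]  # an element-wise affine map of the inputs


def rectified_objective(task_loss: float, layers: List[List[List[float]]], lam: float) -> float:
    """Eq. (5): f'(w) = f(w) + lambda * L_sigma."""
    return task_loss + lam * rectification_loss(layers)
```

A central finite difference of `channel_term` agrees with `channel_term_grad`, which is a quick way to sanity-check Eq. (6) against Eq. (4).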

7. EXPERIMENTS

To evaluate our method, we conduct extensive experiments on image classification with low-bit gradients across various neural network architectures and popular datasets.

Experimental Setup. To highlight the impact of BatchNorm, we use only two vanilla gradient quantizers in all experiments, without any optimizations: the uniform quantizer and the logarithmic quantizer (see Appendix A.1). Our method introduces only one hyper-parameter, $\lambda$ in Eq. (5), which is set manually and then either ramped down by the cosine rule during training or kept constant. For each parametrized layer in backpropagation, the gradient $g_w$ and the error signal $g_x$ are quantized separately and can be quantized to different bit-widths. To evaluate our rectification method, we ensure that all CNN architectures used have BatchNorm layers, including ShallowNet and AlexNet-BN, whose architecture details are listed in Appendix A.3. More details are given in Appendix A.4.

7.1. MAIN RESULTS

INT8 comparisons.

We compare our method (training with $L_\sigma$) to state-of-the-art gradient quantization approaches that report results with 8-bit gradients: UI8 Zhu et al. (2020), FP8 Wang et al. (2018), AFP Zhang et al. (2020), SBM Banner et al. (2018), and DAINT8 Zhao et al. (2021). Using an 8-bit logarithmic quantizer for gradients together with the proposed rectifier $L_\sigma$, Tab. 1 shows that our method outperforms the state-of-the-art in almost all cases, even though we simply use vanilla quantizers on the gradients while the quantizer designs of the counterparts are heavily engineered; e.g., DAINT8 Zhao et al. (2021) adopts vector quantization to process the error signal channel-wise. We found that MobileNet-V2 on ImageNet is harder to train with 8-bit gradients and vanilla quantizer designs, even with our $L_\sigma$, ending up around 1% below the SOTA Zhao et al. (2021). As an ablation, for ResNet-20 on CIFAR-10, our method boosts accuracy by 1% over baseline training (w/o $L_\sigma$), and for ResNet-18 on ImageNet, $L_\sigma$ achieves an improvement of almost 2%.

Comparisons on 4-bit gradients. Since very few works on quantized neural networks use gradients with fewer than 8 bits, we compare our method to FP4 Sun et al. (2020), which adopts a 4-bit floating-point representation, and to the baseline without the proposed $L_\sigma$. As shown in Tab. 2, in most cases our method with INT4 gradient quantization achieves higher accuracy than FP4 Sun et al. (2020), even though FP4 adopts floating-point quantization with customized radix and scaling selections. Our method (w/ $L_\sigma$) also performs better than the baseline, further verifying the effectiveness of the proposed $L_\sigma$.
On the other hand, we observe that MobileNet-V2 with INT4 gradients is still unstable and hard to converge even with our rectification method deployed, which may explain why most FQT works do not report results for INT4 gradients on MobileNet-V2. We therefore leave the FQT of MobileNet-V2 for future work.
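The comparisons above rely on the two vanilla gradient quantizers of Appendix A.1. Below is a minimal scalar sketch of the symmetric uniform quantizer $\mathrm{Quant}_s$ and the logarithmic quantizer as we read the formulas there, each followed by de-quantization ("fake quantization"). Two assumptions are ours: the exponent passed to $2^{(\cdot)}$ in the logarithmic quantizer is taken to be the de-quantized value of $\mathrm{Quant}_u$, and Python's built-in `round` (which rounds exact halves to even) stands in for $\mathrm{round}$. Function names are hypothetical, and an actual FQT kernel would operate on integer tensors rather than Python floats.

```python
import math


def quant_u(g: float, a: float, b: float, bits: int) -> int:
    """Integer code round(clip(g, a, b) * (2^(B-1) - 1) / (b - a)), following Appendix A.1."""
    scale = (2 ** (bits - 1) - 1) / (b - a)
    return round(min(max(g, a), b) * scale)  # Python rounds exact halves to even


def dequant_u(q: int, a: float, b: float, bits: int) -> float:
    """De-quantization: q * (b - a) / (2^(B-1) - 1)."""
    return q * (b - a) / (2 ** (bits - 1) - 1)


def fake_quant_symmetric(g: float, c: float, bits: int) -> float:
    """Quant_s = Quant_u(g, -c, c, B), followed by de-quantization."""
    return dequant_u(quant_u(g, -c, c, bits), -c, c, bits)


def fake_quant_log(g: float, c: float, bits: int) -> float:
    """Logarithmic quantizer: uniformly quantize log2|g| on [log2(c) - 2^(B-1), log2(c)]."""
    if g == 0.0:
        return 0.0
    hi = math.log2(c)
    lo = hi - 2 ** (bits - 1)
    # Assumption: the quantized exponent is de-quantized before exponentiation.
    exponent = dequant_u(quant_u(math.log2(abs(g)), lo, hi, bits), lo, hi, bits)
    return math.copysign(2.0 ** exponent, g)


if __name__ == "__main__":
    print(fake_quant_symmetric(0.3, c=1.0, bits=4))
    print(fake_quant_log(-0.3, c=1.0, bits=4))
```

Within the clipping range, the uniform quantizer's error is bounded by half a step in value, while the logarithmic quantizer's error is bounded by half a step in $\log_2$ magnitude, which is why the latter preserves small gradients with relative rather than absolute precision.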

7.2. EXTREMELY LOW BIT-WIDTHS

Although such settings are rarely explored and remain challenging, we further attempt to quantize gradients to even lower bit-widths, with different backbone architectures, bit-width combinations, and gradient quantizers, to further probe the capability of the proposed rectifier $L_\sigma$. When gradients are quantized to very low bit-widths, we expect training to become extremely unstable, as the quantization noise, and hence the accumulation effect, becomes much more severe. Considering the increased training instability, we perform three independent trials for each experiment and report scores as mean±std.

Simple network architectures. As shown in Tab. 4, we first study the training of a two-layer quantized neural network, ShallowNet, on MNIST. We set the bit-widths of weights, activations, and gradients to 2 bits (W2A2G2) and use the logarithmic quantizer for gradients. Baseline training without $L_\sigma$ crashed in two of three runs, while training with the proposed $L_\sigma$ was stable throughout. In other words, our method outperforms the baseline by a large margin, improving the average accuracy by +54.4% (from 38.8% to 93.2%). On the larger ImageNet dataset, we make a similar observation for AlexNet-BN: the baseline method fails to train completely, while our method with $L_\sigma$ trains stably in all three trials.

More complex network architectures. We further test the effectiveness of our method for FQT of networks with more complex structures, pushing toward the lowest bit-widths we can achieve, e.g., as low as W2A2G2 for ResNet-18 in Tab. 5. We observe that FQT of VGG-16 is slightly more sensitive to quantization than that of ResNet-18, so we have to set higher bit-widths for VGG-16. We also notice that on VGG-16, the error (dx) requires more bits than the gradients w.r.t. weights (dW).
In all, our method performs consistently better than the baseline (w/o $L_\sigma$) in all settings; in particular, the performance gain on VGG-16 is significant. As an additional remark, we notice throughout our experiments that ResNets are more robust to quantization noise than the other models, as suggested in Tab. 2 and Tab. 5. We conjecture that this is due to their full-precision shortcuts, which make them naturally more robust against the accumulation effect during the backward pass. A theoretical investigation of this phenomenon would be an interesting future research topic.

Figure 5: Illustration of stabilized training of quantized networks with lower than 4-bit gradients w/ $L_\sigma$, compared to the baseline w/o $L_\sigma$. (d) Test Acc. for ShallowNet@MNIST. (e) Test Acc. for VGG-16@CIFAR-10. (f) Test Acc. for ResNet-18@CIFAR-10. Our method (w/ $L_\sigma$) shows higher averaged accuracy and lower variance across repeated runs. The x-axis in all sub-figures represents epochs.

Bit-widths (W/A/dx/dW) | $L_\sigma$ | Run #1 | Run #2 | Run #3 | Avg Top-1 (%) | Diff. (%)
ShallowNet on MNIST
2/2/2

7.3. OTHER DISCUSSIONS

Can $L_\sigma$ stabilize training? Fig. 5 illustrates the improved training of quantized networks with less than 4-bit gradients, thanks to the rectification loss $L_\sigma$. Fig. 5a and Fig. 5d show ShallowNet trained on MNIST with bit-width configuration W2A2G2 and the logarithmic gradient quantizer. Fig. 5b and Fig. 5e show VGG-16 trained on CIFAR-10 with W4A4dx4dW2 and the logarithmic gradient quantizer. Fig. 5c and Fig. 5f show ResNet-18 trained on CIFAR-10 with W4A4G4 and the uniform gradient quantizer. We plot the average (AVG) and standard deviation (STD) of the loss/accuracy over 3 trials. From Fig. 5a, Fig. 5b, and Fig. 5c, we observe that the rectification loss $L_\sigma$ dominates the training loss in the first few epochs, forcing the optimization to adapt to the low-bit gradients. Afterwards, training becomes much more stable, with the training loss decreasing and the test accuracy increasing gradually (see Fig. 5d, Fig. 5e, and Fig. 5f). In contrast, training without $L_\sigma$ cannot suppress the negative effect of gradient quantization noise, resulting in training instability or even crashes.

Can $L_\sigma$ suppress the noise amplification effect? To verify that this stabilizing effect indeed comes from the proposed rectification, we further study its impact on the gradient distribution layer by layer. Fig. 6 illustrates the distribution of variances of gradients w.r.t. activations when training VGG-16 on CIFAR-10 with bit-width W4A4dx4dW2. $L_\sigma$ also helps prevent crashes when training with low-bit gradients (see the NaN at the 2.4k-th iteration when training without $L_\sigma$).

Effect of $\lambda$. As shown in Tab. 3, we study the effect of the hyper-parameter $\lambda$ when training quantized ResNet-18 on CIFAR-10 with bit-width W2A2G2. The optimal value of $\lambda$ may differ across experiments, so we tune it heuristically for each.

Can BatchNorm layers be discarded when training with low-bit gradients?
Since BatchNorm amplifies the accumulated quantization noise during backpropagation, one may argue that a straightforward way to prevent the noise amplification effect is to remove BatchNorm layers from the network architecture. To verify this point, we trained AlexNet with and without BatchNorm under W4A4dx4dW2 using the logarithmic quantizer. As shown in Fig. 7, training AlexNet without BatchNorm collapses quickly in the early stage of training. This implies that BatchNorm remains an essential building block for training deep CNNs, owing mainly to its rectifying benefits in the forward pass, while our rectification method complementarily stabilizes the backward pass when quantization is applied, especially in FQT.

8. CONCLUSIONS

In this paper, we study, from a theoretical perspective, an under-explored factor that causes the gradient explosion problem when training deep CNNs with low-bit gradients. Our theory sheds light on the negative effect of BatchNorm in amplifying the accumulated gradient quantization noise during backpropagation, which leads to unstable training or even crashes. The theory inspires a simple yet effective method to stabilize FQT with low-bit gradients, which consistently brings performance gains on a wide range of CNNs and datasets compared to state-of-the-art FQT algorithms.

A.1 GRADIENT QUANTIZERS

The uniform quantizer is the default and most common choice for gradient quantization Zhou et al. (2016); Zhu et al. (2020); Yang et al. (2020); Wang et al. (2018). Given gradients $g \in \{dx, dW\}$, the asymmetric uniform quantizer with bit-width $B$ and quantization range $[a, b]$ can be formulated as:

$$g_q = \mathrm{Quant}_u(g', a, b, B) = \mathrm{round}\left(\mathrm{clip}(g', a, b) \cdot \frac{2^{B-1} - 1}{b - a}\right),$$

where $g' = g + \delta_q(g)$ (see Sec. 3 for details) and $\mathrm{clip}(x, a, b) = \min(\max(x, a), b)$. The quantized gradients are de-quantized as

$$\hat{g} = g_q \cdot \frac{b - a}{2^{B-1} - 1}.$$

Similarly, the symmetric uniform quantizer can be defined as $\mathrm{Quant}_s = \mathrm{Quant}_u(g, -c, c, B)$, where $c$ is the clipping value. We also explore the logarithmic quantizer Miyashita et al. (2016) for gradient quantization, given by:

$$g_q = \mathrm{Quant}_{\log}(g', c, B) = \begin{cases} \mathrm{sign}(g') \cdot 2^{\mathrm{Quant}_u\left(\log_2|g'|,\ \log_2(c) - 2^{B-1},\ \log_2(c),\ B\right)}, & g' \neq 0; \\ 0, & g' = 0. \end{cases}$$

A.2 QUANTIFYING THE AMPLIFICATION EFFECT IN GQNA

Theorem 1. For a BatchNorm layer within a quantized network, the relationship between the gradient quantization error of BatchNorm's input $x_i$ and that of the normalized input $\hat{x}_i$ depends only on the batch size $N$ and the $\sigma$ of BatchNorm, in the form of

$$\frac{D(\delta_q(g_{x_i}))}{D(\delta_q(g_{\hat{x}_i}))} = \frac{1}{N^2\sigma^2}\left(N^2 + 2N\right).$$

Proof. First we extend Eq. (2) to its quantized counterpart:

$$\tilde{g}_{x_i} = \frac{1}{N\sigma}\left(N\tilde{g}_{\hat{x}_i} - \sum_{j=1}^{N}\tilde{g}_{\hat{x}_j} - \hat{x}_i\sum_{j=1}^{N}\tilde{g}_{\hat{x}_j}\hat{x}_j\right). \quad (10)$$

Expanding Eq. (10) using Eq. (1):

$$g_{x_i} + \delta_q(g_{x_i}) + e(g_{x_i}) = \frac{1}{N\sigma}\Big(N\big(g_{\hat{x}_i} + \delta_q(g_{\hat{x}_i}) + e(g_{\hat{x}_i})\big) - \sum_{j=1}^{N}\big(g_{\hat{x}_j} + \delta_q(g_{\hat{x}_j}) + e(g_{\hat{x}_j})\big) - \hat{x}_i\sum_{j=1}^{N}\big(g_{\hat{x}_j} + \delta_q(g_{\hat{x}_j}) + e(g_{\hat{x}_j})\big)\hat{x}_j\Big). \quad (12)$$

Eliminating terms in Eq. (12) using Eq. (1) and Eq. (11), we have:

$$\delta_q(g_{x_i}) = \frac{1}{N\sigma}\left(N\delta_q(g_{\hat{x}_i}) - \sum_{j=1}^{N}\delta_q(g_{\hat{x}_j}) - \hat{x}_i\sum_{j=1}^{N}\delta_q(g_{\hat{x}_j})\hat{x}_j\right). \quad (13)$$

Taking the variance $D(\cdot)$ of both sides (assuming $\delta_q(g_{\hat{x}_i})$ and $\hat{x}_i$ are i.i.d. and both zero-mean), we have:

$$D(\delta_q(g_{x_i})) = \frac{1}{N^2\sigma^2} D\left(N\delta_q(g_{\hat{x}_i}) - \sum_{j=1}^{N}\delta_q(g_{\hat{x}_j}) - \hat{x}_i\sum_{j=1}^{N}\delta_q(g_{\hat{x}_j})\hat{x}_j\right) \quad (14)$$

$$= \frac{1}{N^2\sigma^2}\left(N^2 D(\delta_q(g_{\hat{x}_i})) + D\Big(\sum_{j=1}^{N}\delta_q(g_{\hat{x}_j})\Big) + D\Big(\hat{x}_i\sum_{j=1}^{N}\delta_q(g_{\hat{x}_j})\hat{x}_j\Big)\right) \quad (15)$$

$$= \frac{1}{N^2\sigma^2}\left(N^2 D(\delta_q(g_{\hat{x}_i})) + \sum_{j=1}^{N} D(\delta_q(g_{\hat{x}_j})) + D\Big(\hat{x}_i\sum_{j=1}^{N}\delta_q(g_{\hat{x}_j})\hat{x}_j\Big)\right) \quad (16)$$

Comparing

2019) for weight quantization. We choose the SGD optimizer and train all quantized network models from scratch. The learning rate is adjusted with a cosine scheduler (Loshchilov & Hutter (2017)). For MNIST, we set the learning rate as 0.1 and weight decay as 0.0001, and train for 20 epochs. For CIFAR-10, we set the learning rate as 0.1 for all architectures, weight decay as 0.0001 for ResNet-18 and VGG-16, and weight decay as 0.0004 for MobileNet-V2. For ImageNet, we set the learning rate as 0.0512 and weight decay as 0.0001 for ResNet-18. We specify the choices of λ and its ramping down strategy as below.

A.5 DETAILED ANALYSIS OF COMPUTATION OVERHEAD

To see the linear complexity more directly, one can simplify Eq. (6) into $\partial L_\sigma / \partial a_l = C_1 \odot a_l + C_2$, where $\odot$ denotes element-wise multiplication and $C_1$, $C_2$ are constant tensors with the same shape as $a_l$. The partial derivative $\partial L_\sigma / \partial a_l$ is then aggregated with the error signal from the objective function, $\partial f(w) / \partial a_l$, to compute the gradient w.r.t. weights, $\partial f'(w) / \partial w_l$. Here, $L_\sigma$ introduces no extra computation cost. Tab. 9 gives an empirical verification of the actual training overhead of our method compared to the baseline.

To comprehensively evaluate the influence of the hyper-parameter $\lambda$, we further test model performance at higher bit-widths. Tab. 10 shows the results for different $\lambda$ values with ResNet-18 on CIFAR-10, using the logarithmic quantizer for gradients.
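Sec. 7 states that $\lambda$ is set manually and then either kept constant or ramped down by the cosine rule during training. Because the exact form of the ramp is not given here, the following is only a sketch under an assumed standard cosine decay from an initial value `lam0` to zero over `total_steps`, $\lambda_t = \tfrac{1}{2}\lambda_0\,(1 + \cos(\pi t / T))$. The function name `cosine_lambda`, the zero end value, and the clamping of `step` to $[0, T]$ are hypothetical choices.

```python
import math


def cosine_lambda(step: int, total_steps: int, lam0: float, constant: bool = False) -> float:
    """Weight of L_sigma at a given training step.

    constant=True keeps lambda fixed at lam0; otherwise lambda follows an assumed
    cosine ramp-down from lam0 (step 0) to 0 (step total_steps).
    """
    if constant:
        return lam0
    t = min(max(step, 0), total_steps) / total_steps  # training progress in [0, 1]
    return 0.5 * lam0 * (1.0 + math.cos(math.pi * t))


if __name__ == "__main__":
    for step in (0, 250, 500, 750, 1000):
        print(step, round(cosine_lambda(step, 1000, lam0=0.1), 5))
```

With this kind of schedule the rectification term weighs most early in training and fades later; whether the authors' ramp ends at zero or at a floor value is not stated in this section.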



https://github.com/jcjohnson/cnn-benchmarks



PACT's accuracy gains over DoReFa-Net hold at bit-widths from 5 bits down to 2 bits. Most QAT works quantize weights and activations simultaneously, by optimizing the uniform quantization parameters Zhang et al. (2018a); Esser et al. (2020); Bhalgat et al. (2020), by using layer-wise or channel-wise mixed-precision quantization Jin et al. (2020); Lou et al. (2020), or by leveraging non-uniform quantization such as the logarithmic quantizer Miyashita et al. (2016) and the piece-wise linear quantizer Fang et al. (2020). Most recent QAT works Zhou et al. (2016); Choi et al. (2018); Zhang et al. (2018a); Esser et al. (2020); Bhalgat et al. (2020) use the Straight-Through Estimator (STE) Bengio et al. (2013) to estimate the gradient of the non-differentiable quantization function, while Gong et al. (2019) soften the linear quantization operation in order to match the true gradient with STE.

Figure 4: Illustration of gradient propagation inside BatchNorm.

Figure 6: Distribution of variances of gradients w.r.t. activations when training VGG-16 on CIFAR-10 with bit-width W4A4dx4dW2. The logarithmic quantizer is used for gradients.

Figure 7: Training loss curve of AlexNet on ImageNet with W4A4dx4dW2, with or without BatchNorm (dx denotes the error signal and dW the gradients of weights; same below).


For ImageNet, AlexNet uses learning rate 0.01 and weight decay 0.0005, and MobileNet-V2 uses learning rate 0.1 and weight decay 0.00004 (ResNet-18 uses learning rate 0.0512 and weight decay 0.0001). All models on CIFAR-10 and ImageNet are trained for 120 epochs, except MobileNet-V2, which is trained for 150 epochs. For ImageNet, training images are randomly cropped to 224×224 and then randomly flipped horizontally. For experiments on MNIST and CIFAR-10, we repeat the training of each model for 3 runs with different random seeds and report the average accuracy with standard deviation. Following other INT8 works such as Zhu et al. (2020), we keep the first and last weighted layers, as well as activations, in full precision for all INT8 experiments. For INT4 experiments, we follow the setting in Choi et al. (2019), keeping only the shortcut layers in ResNets in full precision and quantizing all other layers.

Ablation study of λ under bit-widths W4A4G4.

Effect of λ for training quantized ResNet-18 on CIFAR-10 with bit-width W2A2G2.

Ablations of $L_\sigma$ under extremely low bit-widths. Scores in red denote failed training runs.

Fig. 6 shows the variances of gradients w.r.t. layer output activations when training quantized VGG-16 on CIFAR-10 with bit-width W4A4dx4dW2. By injecting the rectification loss $L_\sigma$, defined on BatchNorm, into the training objective, the noise amplification effect is largely suppressed, and the gradients therefore explode less (as measured by their variances).

Ablations of $L_\sigma$ and quantizers under extremely low bit-widths.

Comparing state-of-the-art methods with 8-bit gradients.

Comparing our 4-bit fixed-point gradient quantization to the 4-bit floating-point gradient quantization FP4 Sun et al. (2020).

Less than 4-bit gradients.

Comparison of training time (hours).

A.6 DETAILED ABLATIONS ON HYPER-PARAMETER (ON HIGHER BIT-WIDTHS)



