IMPROVED FULLY QUANTIZED TRAINING VIA RECTIFYING BATCH NORMALIZATION

Abstract

Quantization-aware Training (QAT) quantizes neural network weights and activations in the forward pass of training, improving the speed of the resulting low-bit model at the inference stage. QAT can be extended to Fully-Quantized Training (FQT), which further accelerates training by also quantizing gradients in the backward pass, since backpropagation typically occupies half of the training time. Unfortunately, gradient quantization is challenging because Stochastic Gradient Descent (SGD) based training is sensitive to the precision of the gradient signal. In particular, the noise introduced by gradient quantization accumulates during the backward pass, which causes the exploding gradient problem and results in unstable training and a significant accuracy drop. Although Batch Normalization (BatchNorm) is the de facto resort for stabilizing training in the regular full-precision scenario, we observe that it fails to prevent gradient explosion when gradient quantizers are injected in the backward pass. Surprisingly, our theory shows that BatchNorm can amplify the noise accumulation, which in turn hastens the explosion of gradients. From this theory we derive a BatchNorm rectification method that suppresses the amplification effect and bridges the performance gap between full-precision training and FQT. Adding this simple rectification loss to baselines yields better results than most prior FQT algorithms on various neural network architectures and datasets, regardless of the gradient bit-width used (8, 4, and 2 bits).

1. INTRODUCTION

Quantization-aware Training (QAT) is a popular line of research that simulates neural network quantization (of weights and activations) during the course of training to curb the inference-time accuracy drop of low-bit models (e.g., INT8 quantization). On the other hand, a simple count of BitOps (Yang & Jin (2021); Guo et al. (2020)) shows that backpropagation accounts for half of the computation during training, and empirical data shows that the backward pass sometimes costs even more in practice. Decreasing the gradient bit-width apparently reduces the computation overhead of backpropagation Horowitz (2014). If variables in the backward pass are also quantized, on top of the forward quantization in QAT, then all network variables required in training are fully quantized and the whole training process can be accelerated on dedicated hardware, i.e., Fully-Quantized Training (FQT). Yet gradient quantization under the FQT scheme is vastly underexplored, as it is notoriously more challenging than forward quantization in QAT. It is observed that network training is sensitive to the precision of gradients, and low-bit gradient quantization leads to unstable training and a significant accuracy drop (see Fig. 1). More importantly, the accumulation of gradient quantization noise in the backward pass (see Fig. 2) causes the exploding gradient problem during backpropagation, even resulting in training failure. In contrast to weight/activation quantization, gradient quantization noise produced during backpropagation cannot be automatically corrected by optimizing the objective loss. Unlike prior works that optimize gradient quantizers for quantization noise reduction Zhou et al. (2016); Choi et al. (2018); Zhu et al. (2020), this paper reveals the negative effect of Batch Normalization (BatchNorm) on amplifying the accumulation of gradient quantization noise when training deep Convolutional Neural Networks (CNNs) with low-bit gradients.
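To make the backward-pass quantization concrete, a vanilla low-bit gradient quantizer can be sketched as a symmetric max-abs uniform quantizer with stochastic rounding. This is an illustrative sketch only, not the design of any particular FQT paper; the function name and the choice of stochastic rounding are our own assumptions:

```python
import numpy as np

def quantize_grad(g, bits=8, rng=None):
    """Illustrative k-bit uniform gradient quantizer (max-abs scaling,
    stochastic rounding). Not a specific paper's quantizer design."""
    rng = rng or np.random.default_rng(0)
    qmax = 2 ** (bits - 1) - 1                       # symmetric signed range
    scale = max(np.max(np.abs(g)) / qmax, 1e-12)     # per-tensor scale factor
    scaled = g / scale
    # Stochastic rounding: round up with probability equal to the fraction,
    # so the quantizer is unbiased in expectation.
    floor = np.floor(scaled)
    q = floor + (rng.random(g.shape) < (scaled - floor))
    q = np.clip(q, -qmax - 1, qmax)
    return q * scale                                 # dequantized low-bit gradient

g = np.random.default_rng(1).normal(size=(4, 4)).astype(np.float32)
gq = quantize_grad(g, bits=4)
err = np.abs(gq - g).max()   # per-element error is bounded by one scale step
```

Even with an unbiased quantizer like this, the per-layer rounding error is what accumulates layer by layer during backpropagation.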
We show that the noise amplification effect further explodes the gradients during the backward pass. We thus propose a BatchNorm rectification method to suppress the noise amplification effect and alleviate the gradient explosion problem, which in turn enables stabilized training and better accuracy with low-bit gradients. Our contributions are summarized as follows:

• Through theoretical analysis, we discover that BatchNorm fails to prevent the explosion of low-bit gradients in fully-quantized training, and may even amplify the accumulated gradient quantization noise, which further aggravates the gradient explosion.

• Based on our theory, we propose a simple yet effective BatchNorm variance rectification algorithm that suppresses the noise amplification effect without introducing noticeable overhead, resulting in alleviated gradient explosion.

• Extensive experiments on MNIST, CIFAR-10, and ImageNet show that our method achieves more stable training and higher accuracy than the state of the art with vanilla gradient quantizers, regardless of the gradient bit-width used (8, 4, and 2 bits).
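The accumulation effect can be observed even in a toy setting: propagating an error signal backward through a stack of random linear layers while quantizing it at every step, the deviation from the exact gradient grows with depth. A minimal sketch, assuming an illustrative nearest-rounding max-abs quantizer and random Gaussian weight matrices (both our own choices, not the paper's setup):

```python
import numpy as np

def quant(x, bits):
    # Illustrative deterministic max-abs uniform quantizer.
    qmax = 2 ** (bits - 1) - 1
    scale = max(np.max(np.abs(x)) / qmax, 1e-12)
    return np.round(x / scale).clip(-qmax, qmax) * scale

rng = np.random.default_rng(0)
depth, width = 20, 64
# Random linear layers, scaled so the signal norm stays roughly constant.
Ws = [rng.normal(scale=width ** -0.5, size=(width, width)) for _ in range(depth)]

g0 = rng.normal(size=width)
exact, noisy = g0.copy(), g0.copy()
rel_err = []
for W in Ws:                           # simulated backward pass through the chain
    exact = W.T @ exact                # exact error signal
    noisy = W.T @ quant(noisy, bits=2) # quantize the signal at every layer
    rel_err.append(np.linalg.norm(noisy - exact) / np.linalg.norm(exact))
# rel_err grows with depth: fresh quantization noise is injected at every
# layer and compounds with the noise already accumulated.
```

The same mechanism is what makes low-bit gradients explode in real networks; the paper's analysis shows BatchNorm layers can make this compounding worse rather than damping it.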



https://github.com/jcjohnson/cnn-benchmarks



FQT makes the training of large models accessible to users with limited computation capability. Recent work Zhu et al. (2020) has shown that INT8 FQT speeds up the forward pass and the backward pass by 1.63× and 1.94×, respectively, when training ResNet-50 on ImageNet with an NVIDIA Pascal GPU.

2. RELATED WORK

Quantization-Aware Training (QAT). DoReFa-Net Zhou et al. (2016) proposed to optimize the clipping value and the scaling factor of the uniform quantizers for weights and activations separately. It was validated on the image classification task under multiple bit-widths, but only with the rather simple AlexNet architecture. PACT Choi et al. (2018) proposed to quantize the activations with a learnable layer-wise clipping value, which unsurprisingly achieved better accuracy than DoReFa-Net from 5-bit down to 2-bit. Most QAT works quantize weights and activations simultaneously, by optimizing the uniform quantization parameters Zhang et al. (2018a); Esser et al. (2020); Bhalgat et al. (2020), by layer-wise or channel-wise mixed-precision quantization Jin et al. (2020); Lou et al. (2020), or by leveraging non-uniform quantization such as the logarithmic quantizer Miyashita et al. (2016) and the piece-wise linear quantizer Fang et al. (2020). Most recent QAT works Zhou et al. (2016); Choi et al. (2018); Zhang et al. (2018a); Esser et al. (2020); Bhalgat et al. (2020) used the "Straight-Through Estimator" (STE) Bengio et al. (2013) to estimate the gradient of the non-differentiable quantization function, while other work Gong et al. (2019) softened the linear quantization operation so that the STE gradient matches the true gradient.

Fully-Quantized Training (FQT). FQT aims to accelerate and quantize the backward pass of network training with low-bit error signals and gradients, agnostic to single-machine or parallel training. An early attempt Zhou et al. (2016) adopted a primitive quantizer design based on a uniform quantizer for gradients (without scaling or other optimization), and large performance drops were witnessed when training with low-bit gradients. SBM Banner et al. (2018) adopted fixed-point 8-bit gradient quantization, but only focused on improving the quantization schemes in the forward pass. WAGEUBN Yang et al. (2020) quantized gradients to 8-bit integers, but also showed a huge performance gap against its full-precision counterpart. NITI Wang et al. (2020) integrated gradient calculations with parameter update operations to reduce the gradient quantization noise with well-designed quantizers. However, it only supports shallow CNN architectures and did not explore deeper networks with BatchNorm. In Zhu et al. (2020), the authors considered the sharp and wide distribution of gradients, and proposed to clip the gradients according to the deviation of the gradient distribution before quantization, achieving on-par results with full-precision training. To compensate for the quantization loss on gradients, AFP Zhang et al. (2020) and CPT Fu et al. (2021) used higher-precision data to aid low-precision training. DAINT8 Zhao et al. (2021) adopted a bespoke 8-bit channel-wise gradient quantization to suppress the negative effect of quantization noise during training. Gradient quantization with less than INT8 representations remains largely unexplored. FP4 Sun et al. (2020) managed to train modern CNN architectures using 4-bit gradients without significant accuracy loss, but the gradients were represented as floating-point numbers.

Batch Normalization in QAT. Most QAT approaches either leave BatchNorm between the parameterized layer (Conv/FC) and the activation layer (ReLU), without quantization Zhou et al. (2016); Choi et al. (2018) or with quantization Yang et al. (2020), absorb BatchNorm into Conv before weight quantization in the forward pass Jacob et al. (2018), or directly train a BatchNorm-free shallow network architecture to achieve full 8-bit integer-only arithmetic Wang et al. (2020). To our
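The STE mentioned above quantizes in the forward pass but treats the non-differentiable rounding as identity in the backward pass, zeroing the gradient only where the input was clipped. A minimal NumPy sketch of these semantics (function names are our own; this is not tied to any framework's autograd):

```python
import numpy as np

def quantize_ste_forward(w, bits=4):
    """Forward: uniform quantization of values clipped to [-1, 1]."""
    levels = 2 ** bits - 1
    w_clipped = np.clip(w, -1.0, 1.0)
    return np.round((w_clipped + 1) / 2 * levels) / levels * 2 - 1

def quantize_ste_backward(w, grad_out):
    """Backward with STE: round() is treated as identity, so the incoming
    gradient passes through unchanged wherever w was inside the clip range."""
    return grad_out * ((w >= -1.0) & (w <= 1.0))

w = np.array([-1.5, -0.3, 0.2, 0.9, 1.2])
g = np.ones_like(w)
wq = quantize_ste_forward(w)          # quantized forward values in [-1, 1]
gw = quantize_ste_backward(w, g)      # gradient zeroed only at clipped entries
```

This is the standard trick the cited QAT works rely on for weight/activation quantization; the key point of the FQT literature above is that no analogous correction exists for noise injected by quantizing the gradients themselves.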

