REVISITING BFLOAT16 TRAINING

Abstract

State-of-the-art generic low-precision training algorithms use a mix of 16-bit and 32-bit precision, creating the folklore that 16-bit precision alone is not enough to maximize model accuracy. As a result, deep learning accelerators are forced to support both 16-bit and 32-bit compute units, which is more costly for hardware design than using 16-bit units alone. We ask whether we can do pure 16-bit training, which requires only 16-bit compute units, while still matching the model accuracy attained by 32-bit training. Towards this end, we study pure 16-bit training algorithms on the widely adopted BFloat16 compute unit. While these units conventionally use nearest rounding to cast outputs to 16-bit precision, we show that nearest rounding of model weight updates can often cancel small updates, which degrades convergence and model accuracy. Motivated by this, we identify two simple existing techniques, stochastic rounding and Kahan summation, that remedy the model accuracy degradation in pure 16-bit training. We empirically show that these two techniques can enable up to a 7% absolute validation accuracy gain in pure 16-bit training, yielding validation accuracy from 0.1% lower to 0.2% higher than that of 32-bit precision training across seven deep learning applications.

1. INTRODUCTION

Recently there has been an explosion in the compute resources required for training deep learning models (Shoeybi et al., 2019; Rajbhandari et al., 2019; Real et al., 2019). As a result, there has been broad interest in leveraging low-precision (< 32-bit) training algorithms to reduce the required compute resources (De Sa et al., 2017; Hubara et al., 2017; Gupta et al., 2015). Among these algorithms, mixed-precision training, in which model activations and gradients are stored in a 16-bit floating-point format while model weights and optimizer states use 32-bit precision, is commonly used when training generic deep learning models (Micikevicius et al., 2017; Kalamkar et al., 2019). While there is a wide body of literature showing that low-precision training can minimally impact accuracy on specific models (Wang et al., 2018b; De Sa et al., 2015; Zhang et al., 2017), conventional wisdom suggests that at least some 32-bit computation is required as a fail-safe in generic deep learning training. As such, new accelerator architectures for deep learning are forced to support both 32-bit and 16-bit compute units. This is much more costly in terms of area, power, and speed than hardware with only 16-bit compute units (Horowitz, 2014; Galal et al., 2013).

In this paper we question whether 32-bit compute units are truly needed for new deep learning hardware accelerators. Namely, can we match the model accuracy of 32-bit-precision algorithms while leveraging only 16-bit compute units? To answer this question, we study pure 16-bit training algorithms, which use only 16-bit compute units and store activations, gradients, model weights, and optimizer states all in 16-bit precision. Specifically, we focus on training with the BFloat16 compute unit, which is widely adopted in modern deep learning accelerators (Jouppi et al., 2017; Burgess et al., 2019). Such units take 16-bit inputs, perform computation, and then round the results to a 16-bit output.
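To make this rounding behavior concrete, the following sketch (using NumPy; illustrative only, not the paper's implementation) simulates BFloat16 by keeping the top 16 bits of a float32. Near 1.0 the BFloat16 spacing is 2^-7 ≈ 0.0078, so a weight update of 1e-3 is entirely cancelled by nearest rounding, while the two remedies studied here, stochastic rounding and Kahan summation, preserve the accumulated updates. All function names and the update sizes below are made up for the example.

```python
import numpy as np

def bf16_nearest(x):
    """Round a finite float32 value to BFloat16 (round-to-nearest-even),
    returned as a float32 holding a BFloat16-representable value."""
    bits = int(np.float32(x).view(np.uint32))
    lsb = (bits >> 16) & 1                        # tie-break toward even
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000  # keep top 16 bits
    return np.uint32(rounded).view(np.float32)

def bf16_stochastic(x, rng):
    """Stochastically round a finite float32 value to BFloat16: round up
    with probability given by the discarded low 16 bits."""
    bits = int(np.float32(x).view(np.uint32))
    noise = int(rng.integers(0, 1 << 16))         # uniform in [0, 2^16)
    rounded = (bits + noise) & 0xFFFF0000
    return np.uint32(rounded).view(np.float32)

def kahan_step(w, c, delta):
    """One Kahan-compensated weight update with every intermediate rounded
    to BFloat16; c is an auxiliary 16-bit value tracking lost low bits."""
    y = bf16_nearest(delta - c)
    t = bf16_nearest(w + y)
    c = bf16_nearest(bf16_nearest(t - w) - y)
    return t, c

# 100 SGD-style updates of size 1e-3 on a weight initialized to 1.0
# (the exact 32-bit result would be 1.1).
w_nearest = np.float32(1.0)
w_kahan, comp = np.float32(1.0), np.float32(0.0)
for _ in range(100):
    w_nearest = bf16_nearest(w_nearest + np.float32(1e-3))
    w_kahan, comp = kahan_step(w_kahan, comp, np.float32(1e-3))

print(float(w_nearest))  # -> 1.0: every small update was rounded away
print(float(w_kahan))    # close to 1.1: compensation preserves the updates
```

Stochastic rounding fixes the same problem in expectation rather than exactly: each individual update still usually rounds to zero, but it occasionally rounds the weight up by a full BFloat16 step with matching probability, so the rounded updates are unbiased on average.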
BFloat16 compute units can provide 3× higher power efficiency, 1.5× lower latency, and 1.5× less chip area than 32-bit units (Horowitz, 2014; Galal et al., 2013). In addition, pure 16-bit training algorithms can reduce the memory footprint and bandwidth consumption of model weights and optimizers by 2× compared to mixed-precision or 32-bit precision training, especially for large models with billions of weights (Shoeybi et al., 2019; Rajbhandari et al., 2019). Developing reliable pure 16-bit training algorithms will enable hardware designers to realize these advantages. The simplest approach to pure 16-bit training is to take a 32-bit baseline and "make it low-precision" by replacing all the 32-bit numbers with 16-bit numbers and replacing each 32-bit floating-point op-

