CHARACTERIZING SIGNAL PROPAGATION TO CLOSE THE PERFORMANCE GAP IN UNNORMALIZED RESNETS

Abstract

Batch Normalization is a key component in almost all state-of-the-art image classifiers, but it also introduces practical challenges: it breaks the independence between training examples within a batch, can incur compute and memory overhead, and often results in unexpected bugs. Building on recent theoretical analyses of deep ResNets at initialization, we propose a simple set of analysis tools to characterize signal propagation on the forward pass, and leverage these tools to design highly performant ResNets without activation normalization layers. Crucial to our success is an adapted version of the recently proposed Weight Standardization. Our analysis tools show how this technique preserves the signal in networks with ReLU or Swish activation functions by ensuring that the per-channel activation means do not grow with depth. Across a range of FLOP budgets, our networks attain performance competitive with the state-of-the-art EfficientNets on ImageNet.

1. INTRODUCTION

BatchNorm has become a core computational primitive in deep learning (Ioffe & Szegedy, 2015), and it is used in almost all state-of-the-art image classifiers (Tan & Le, 2019; Wei et al., 2020). A number of different benefits of BatchNorm have been identified. It smooths the loss landscape (Santurkar et al., 2018), which allows training with larger learning rates (Bjorck et al., 2018), and the noise arising from the minibatch estimates of the batch statistics introduces implicit regularization (Luo et al., 2019). Crucially, recent theoretical work (Balduzzi et al., 2017; De & Smith, 2020) has demonstrated that BatchNorm ensures good signal propagation at initialization in deep residual networks with identity skip connections (He et al., 2016b;a), and this benefit has enabled practitioners to train deep ResNets with hundreds or even thousands of layers (Zhang et al., 2019). However, BatchNorm also has many disadvantages. Its behavior is strongly dependent on the batch size, performing poorly when the per-device batch size is too small or too large (Hoffer et al., 2017), and it introduces a discrepancy between the behavior of the model during training and at inference time. BatchNorm also adds memory overhead (Rota Bulò et al., 2018), and is a common source of implementation errors (Pham et al., 2019). In addition, it is often difficult to replicate batch-normalized models trained on different hardware. A number of alternative normalization layers have been proposed (Ba et al., 2016; Wu & He, 2018), but these alternatives typically generalize poorly or introduce drawbacks of their own, such as added compute costs at inference. Another line of work has sought to eliminate layers which normalize hidden activations entirely. A common trend is to initialize residual branches to output zeros (Goyal et al., 2017; Zhang et al., 2019; De & Smith, 2020; Bachlechner et al., 2020), which ensures that the signal is dominated by the skip path early in training.
However, while this strategy enables us to train deep ResNets with thousands of layers, it still degrades generalization when compared to well-tuned baselines (De & Smith, 2020). These simple initialization strategies are also not applicable to more complicated architectures like EfficientNets (Tan & Le, 2019), the current state of the art on ImageNet (Russakovsky et al., 2015). This work seeks to establish a general recipe for training deep ResNets without normalization layers which achieve test accuracy competitive with the state of the art. Our contributions are as follows:

• We introduce Signal Propagation Plots (SPPs): a simple set of visualizations which help us inspect signal propagation at initialization on the forward pass in deep residual networks. Leveraging these SPPs, we show how to design unnormalized ResNets which are constrained to have signal propagation properties similar to batch-normalized ResNets.

• We identify a key failure mode in unnormalized ResNets with ReLU or Swish activations and Gaussian weights. Because the mean output of these non-linearities is positive, the squared mean of the hidden activations on each channel grows rapidly with network depth. To resolve this, we propose Scaled Weight Standardization, a minor modification of the recently proposed Weight Standardization (Qiao et al., 2019; Huang et al., 2017b), which prevents this growth in the mean signal and leads to a substantial boost in performance.

• We apply our normalization-free network structure in conjunction with Scaled Weight Standardization to ResNets on ImageNet, where we attain, for the first time, performance comparable to or better than batch-normalized ResNets on networks as deep as 288 layers.

• Finally, we apply our normalization-free approach to the RegNet architecture (Radosavovic et al., 2020). By combining this architecture with the compound scaling strategy proposed by Tan & Le (2019), we develop a class of models without normalization layers which are competitive with the current ImageNet state of the art across a range of FLOP budgets.
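To make the SPP idea concrete, the sketch below (our own toy example, not the paper's code) propagates a unit-Gaussian batch through a small unnormalized pre-activation residual stack at initialization and records, at each block output, the average squared channel mean and the average channel variance — the quantities an SPP plots against depth:

```python
import torch
import torch.nn as nn

# Illustrative toy residual block: x + conv(relu(x)), no normalization.
class Block(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return x + self.conv(torch.relu(x))

def signal_propagation_stats(depth=16, channels=32, size=8):
    """Return per-block (avg squared channel mean, avg channel variance)."""
    torch.manual_seed(0)
    x = torch.randn(64, channels, size, size)  # unit-Gaussian input batch
    blocks = [Block(channels) for _ in range(depth)]
    stats = []
    with torch.no_grad():
        for block in blocks:
            x = block(x)
            # Statistics per channel, reduced over batch and spatial dims.
            ch_mean = x.mean(dim=(0, 2, 3))
            ch_var = x.var(dim=(0, 2, 3))
            stats.append((float((ch_mean ** 2).mean()),
                          float(ch_var.mean())))
    return stats
```

Plotting both quantities against block index reproduces the qualitative behavior discussed in the paper: in an unnormalized ResNet the variance grows steadily with depth, since each residual branch adds its output variance on top of the skip path.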

2. BACKGROUND

Deep ResNets at initialization: The combination of BatchNorm (Ioffe & Szegedy, 2015) and skip connections (Srivastava et al., 2015; He et al., 2016a) has allowed practitioners to train deep ResNets with hundreds or thousands of layers. To understand this effect, a number of papers have analyzed signal propagation in normalized ResNets at initialization (Balduzzi et al., 2017; Yang et al., 2019). In a recent work, De & Smith (2020) showed that in normalized ResNets with Gaussian initialization, the activations on the ℓ-th residual branch are suppressed by a factor of O(√ℓ), relative to the scale of the activations on the skip path. This biases the residual blocks in deep ResNets towards the identity function at initialization, ensuring well-behaved gradients. In unnormalized networks, one can preserve this benefit by introducing a learnable scalar at the end of each residual branch, initialized to zero (Zhang et al., 2019; De & Smith, 2020; Bachlechner et al., 2020). This simple change is sufficient to train deep ResNets with thousands of layers without normalization. However, while this method is easy to implement and achieves excellent convergence on the training set, it still achieves lower test accuracies than well-tuned normalized baselines. These insights from studies of batch-normalized ResNets are also supported by theoretical analyses of unnormalized networks (Taki, 2017; Yang & Schoenholz, 2017; Hanin & Rolnick, 2018; Qi et al., 2020). These works suggest that, in ResNets with identity skip connections, if the signal does not explode on the forward pass, the gradients will neither explode nor vanish on the backward pass. Hanin & Rolnick (2018) conclude that multiplying the hidden activations on the residual branch by a factor of O(1/d) or less, where d denotes the network depth, is sufficient to ensure trainability at initialization.
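The zero-initialized learnable scalar described above takes only a few lines to implement. The module below is our own minimal sketch (the parameter name `alpha` is ours, not from the cited works): because the scalar starts at zero, the block computes the identity exactly at initialization, so the signal is carried entirely by the skip path.

```python
import torch
import torch.nn as nn

class ScaledResidualBlock(nn.Module):
    """Residual block whose branch is gated by a zero-initialized scalar."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        # Learnable scalar initialized to zero: at initialization the
        # residual branch contributes nothing, and the block is the identity.
        self.alpha = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        return x + self.alpha * self.conv(torch.relu(x))
```

During training, gradients flow into `alpha`, allowing each residual branch to grow away from zero at its own rate.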

Alternative normalizers: To counteract the limitations of BatchNorm in different situations, a range of alternative normalization schemes have been proposed, each operating on different components of the hidden activations. These include LayerNorm (Ba et al., 2016), InstanceNorm (Ulyanov et al., 2016), GroupNorm (Wu & He, 2018), and many more (Huang et al., 2020). While these alternatives remove the dependency on the batch size and typically work better than BatchNorm for very small batch sizes, they also introduce limitations of their own, such as additional computational costs at inference. Furthermore, for image classification, these alternatives still tend to achieve lower test accuracies than well-tuned BatchNorm baselines. As one exception, we note that the combination of GroupNorm with Weight Standardization (Qiao et al., 2019) was recently identified as a promising alternative to BatchNorm in ResNet-50 (Kolesnikov et al., 2019).
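As a rough sketch of the Weight Standardization operation referenced above (Qiao et al., 2019) — shown here in its original form, without the scaling modification this paper proposes — each convolutional filter is standardized to zero mean and unit variance over its fan-in before the convolution is applied:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WSConv2d(nn.Conv2d):
    """Conv2d that standardizes its weights on every forward pass."""
    def forward(self, x):
        w = self.weight
        # Statistics over each output filter's fan-in (in_ch * kh * kw).
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = w.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        w_hat = (w - mean) / torch.sqrt(var + 1e-5)
        return F.conv2d(x, w_hat, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```

Note that the standardization acts on the weights rather than the activations, so it adds no batch-size dependence and, unlike activation normalizers, the standardized weights can be folded in after training so inference cost matches a plain convolution.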

