ON THE STABILITY OF MULTI-BRANCH NETWORKS

Anonymous

Abstract

Multi-branch architectures are widely used in state-of-the-art neural networks. Their empirical success relies on design wisdom such as adding normalization layers and/or scaling down the initialization. In this paper, we investigate multi-branch architectures from the stability perspective. Specifically, we establish the forward/backward stability of multi-branch networks, which leads to several new findings. Our analysis shows that scaling down the initialization alone may not be enough to train multi-branch networks successfully, because the backward process remains uncontrolled. We also unveil a new role of the normalization layer in stabilizing multi-branch architectures. More importantly, we propose a new design, "STAM aggregation", that is guaranteed to STAbilize the forward/backward process of Multi-branch networks irrespective of the number of branches. We demonstrate that with STAM aggregation, the same training strategy is applicable to models with different numbers of branches, which reduces the hyper-parameter tuning burden. Our experiments verify our theoretical findings and demonstrate that STAM aggregation can considerably improve the performance of multi-branch networks.

1. INTRODUCTION

The multi-branch architecture is a building block of state-of-the-art neural network models for many tasks, e.g., ResNeXt (Xie et al., 2017) for computer vision and the Transformer (Vaswani et al., 2017) for machine translation. One noted benefit of the multi-branch architecture is parameter efficiency (Xie et al., 2017): the number of parameters grows linearly with the number of branches but quadratically with the width (the number of neurons in one layer). It has also been argued that multiple branches bring diversity when the branches are sub-networks with different filters and depths (Huang et al., 2017; Li et al., 2019).

Training multi-branch networks successfully, however, usually requires careful design and hyperparameter tuning, such as adding normalization layers, scaling down the initialization, and adjusting the learning rate. As a verifying example, starting from a trainable single-branch network, simply replicating the branch multiple times and aggregating the outputs together often does not work as expected, e.g., the training instability of sum aggregation in Figure 4. This demonstrates the difficulty of training multi-branch networks and motivates our study.

In this paper, we aim to understand the behavior of training multi-branch networks. Specifically, we study the forward and backward processes of multi-branch networks, which are believed to govern whether a network is easy to optimize by gradient-based methods. We find that the aggregation scheme, i.e., the way the branch outputs are combined, plays a central role in determining the training behavior of a multi-branch network. We show that sum aggregation becomes unstable as the number of branches grows, which explains the poor performance of simply adding branches. Moreover, we characterize the condition on the aggregation scheme under which forward and backward stability is guaranteed.
Inspired by the theoretical analysis, we propose the "STAM" aggregation, which STAbilizes Multi-branch networks by scaling the sum of the branch outputs by a branch-aware factor α (see the later part of Section 3.1 for details). We argue for the benefit of STAM aggregation over the sum and average aggregations by analyzing the Hessian of the multi-branch network. We show that STAM allows the same gradient-based optimizer to work across settings with different numbers of branches, which can substantially reduce the tuning burden when training networks with a flexible number of branches.
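A minimal sketch of the three aggregation schemes discussed above. The exact branch-aware factor α used by STAM is specified in Section 3.1; here we substitute the illustrative choice α = 1/√m (an assumption, not the paper's definition), which keeps the output variance constant as the number of branches m varies:

```python
import numpy as np

def aggregate(branch_outputs, scheme="stam"):
    """Combine a list of branch output arrays.

    'sum'  : plain sum; output variance grows linearly with m.
    'avg'  : mean; output variance shrinks as 1/m.
    'stam' : sum scaled by a branch-aware factor alpha(m); we use
             alpha = 1/sqrt(m) here as an illustrative choice.
    """
    m = len(branch_outputs)
    s = np.sum(branch_outputs, axis=0)
    if scheme == "sum":
        return s
    if scheme == "avg":
        return s / m
    if scheme == "stam":
        return s / np.sqrt(m)
    raise ValueError(f"unknown scheme: {scheme}")

# Quick check with unit-variance, independent branch outputs:
rng = np.random.default_rng(0)
for m in (1, 4, 16):
    outs = [rng.standard_normal(4096) for _ in range(m)]
    print(m, aggregate(outs, "sum").var(), aggregate(outs, "stam").var())
```

Under this scaling the 'stam' variance stays near 1 for every m, while 'sum' grows and 'avg' shrinks, which is why the same optimizer settings can transfer across different branch counts.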

