ON THE STABILITY OF MULTI-BRANCH NETWORKS

Anonymous

Abstract

Multi-branch architectures are widely used in state-of-the-art neural networks. Their empirical success relies on design wisdom such as adding normalization layers and/or scaling down the initialization. In this paper, we investigate multi-branch architectures from the stability perspective. Specifically, we establish the forward/backward stability of multi-branch networks, which leads to several new findings. Our analysis shows that scaling down the initialization alone may not be enough to train multi-branch networks successfully, because the backward process can remain uncontrolled. We also unveil a new role of the normalization layer in stabilizing multi-branch architectures. More importantly, we propose a new design, "STAM aggregation", which provably STAbilizes the forward/backward process of Multi-branch networks irrespective of the number of branches. We demonstrate that with STAM aggregation, the same training strategy is applicable to models with different numbers of branches, which reduces the hyper-parameter tuning burden. Our experiments verify our theoretical findings and demonstrate that STAM aggregation can improve the performance of multi-branch networks considerably.

1. INTRODUCTION

Multi-branch architecture is a building block in state-of-the-art neural network models for many tasks, e.g., ResNeXt (Xie et al., 2017) for computer vision and the Transformer (Vaswani et al., 2017) for machine translation. It has been pointed out that one benefit of the multi-branch architecture is parameter efficiency (Xie et al., 2017): the number of parameters grows linearly with the number of branches but quadratically with the width (the number of neurons in one layer). It has also been argued that multiple branches bring diversity if the branches are composed of sub-networks with different filters and depths (Huang et al., 2017; Li et al., 2019). Training multi-branch networks successfully usually requires careful design and hyper-parameter tuning, such as adding normalization layers, scaling down the initialization, and adjusting the learning rate. As a verifying example, for a trainable single-branch network, simply replicating the branch multiple times and aggregating the outputs together often does not work as expected, e.g., the training instability of sum aggregation in Figure 4. This demonstrates the difficulty of training multi-branch networks and motivates this study.

In this paper, we aim to understand the behavior of training multi-branch networks. Specifically, we study the forward and backward processes of multi-branch networks, which are believed to govern whether a network is easy to optimize by gradient-based methods. We find that the aggregation scheme, i.e., the way of combining the branch outputs, plays a central role in determining the training behavior of multi-branch networks. We show that sum aggregation becomes unstable as the number of branches grows, which explains the poor performance of simply adding branches. Moreover, we characterize the condition on the aggregation scheme under which forward and backward stability is guaranteed.
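A toy numerical check illustrates the forward instability of sum aggregation described above. The setup is hypothetical (independent random linear branches applied to a shared unit-norm input, standing in for the branches B_k); it only demonstrates that the output norm of a sum-aggregated block grows with the number of branches C:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 256  # branch output dimension

def sum_aggregate_norm(C):
    """Norm of the sum of C independent linear branches on one unit-norm input."""
    x = rng.normal(size=d)
    x /= np.linalg.norm(x)                      # shared unit-norm input h_in
    # Entry scale 1/sqrt(d) gives each branch output expected norm ~ 1.
    out = sum(rng.normal(0.0, 1.0 / np.sqrt(d), (d, d)) @ x for _ in range(C))
    return np.linalg.norm(out)

mean_norm = {C: float(np.mean([sum_aggregate_norm(C) for _ in range(30)]))
             for C in (1, 4, 16, 64)}
print(mean_norm)  # the aggregated norm grows with C (roughly like sqrt(C) here)
```

Under sum aggregation the forward signal magnitude thus depends on C, so a learning rate tuned for one branch count need not transfer to another.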
Inspired by the theoretical analysis, we propose the "STAM" aggregation, which STAbilizes Multi-branch networks by scaling the sum of the branch outputs with a branch-aware factor α (see the later part of Section 3.1 for details). We argue for the benefit of STAM aggregation over the sum and average aggregations by analyzing the Hessian of the multi-branch network. We show that STAM permits the same gradient-based optimizer to work across different settings, which can greatly reduce the tuning burden when training networks with a flexible number of branches. We further examine the usual design wisdom through the stability lens. We find that scaling down the initialization may control the forward or the backward stability but not necessarily both, which is verified in experiments. We also unveil a new role of the normalization layer: it can stabilize the forward and backward processes of multi-branch networks, in addition to the many wanted and unwanted properties argued before (Ioffe & Szegedy, 2015; Yang et al., 2018; Santurkar et al., 2018; Xiong et al., 2020). Apart from the usual feedforward multi-branch architecture, we analyze the multi-head attention layer, a multi-branch architecture widely used in natural language processing. We give an upper bound on the multi-head representations when the softmax operator is replaced with the max operation. The upper bound unveils the relation between the head dimension and the length of the sequence, which interprets empirical observations well. This relation cannot be discovered if one assumes softmax outputs equal probabilities, as in Xiong et al. (2020). Overall, our contributions can be summarized as follows.

• We analyze the forward/backward stability of multi-branch networks, under which we can clearly interpret the benefits and potential problems of the practical wisdom, i.e., scaling down the initialization and adding normalization layers.

• We propose a theoretically inspired STAM aggregation design for multi-branch networks, which can handle an arbitrary number of branches with the same optimizer.

• We also analyze the forward/backward process of the multi-head attention layer and identify a special property of it that has not been characterized before.
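The three aggregation schemes discussed above can be contrasted in a few lines. Note the STAM factor here is an assumption: the exact branch-aware factor α is deferred to Section 3.1, and the 1/√C scaling below is only the choice suggested by the stability argument, used purely for illustration:

```python
import numpy as np

def aggregate(branch_outputs, scheme="stam"):
    """Combine C branch outputs (each of shape (d,)) under one aggregation scheme.

    The "stam" factor 1/sqrt(C) is an illustrative assumption; the paper's exact
    branch-aware factor alpha is specified in its Section 3.1.
    """
    C = len(branch_outputs)
    stacked = np.stack(branch_outputs)          # shape (C, d)
    if scheme == "sum":                         # alpha_k = 1: norm grows with C
        alpha = np.ones(C)
    elif scheme == "average":                   # alpha_k = 1/C: shrinks the signal
        alpha = np.full(C, 1.0 / C)
    elif scheme == "stam":                      # alpha_k assumed 1/sqrt(C)
        alpha = np.full(C, 1.0 / np.sqrt(C))
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return alpha @ stacked                      # sum_k alpha_k * B_k(h_in)
```

With identical branch outputs, "sum" scales the result by C and "average" leaves it fixed, while the assumed STAM factor sits between the two at √C.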

1.1. RELATED WORK

Multi-branch architecture, also known as split-transform-merge architecture, has been widely used in computer vision tasks, e.g., the Inceptions (Szegedy et al., 2017; Chollet, 2017), ResNeXt (Xie et al., 2017), and many others (Abdi & Nahavandi, 2016; Ahmed & Torresani, 2017). Models for natural language tasks have also leveraged the multi-branch architecture, including the BiLSTM (Wu et al., 2016; Zhou et al., 2016) and the multi-head attention layer in the Transformer (Vaswani et al., 2017; Anonymous, 2020). Apart from the sum or average aggregation, recent works (Li et al., 2019; Zhang et al., 2020) integrate the attention mechanism with the aggregation scheme, i.e., attentive aggregation, although only a small number (2 ∼ 3) of parallel branches are considered. Theoretically, Zhang et al. (2018) interpret the benefit of the multi-branch architecture as reducing the duality gap, or the degree of non-convexity. The theory of training general deep neural networks has been widely studied, via stability analysis (Arpit et al., 2019; Zhang et al., 2019a; c; Yang & Schoenholz, 2017; Zhang et al., 2019b; Yang, 2019; Lee et al., 2019) and the neural tangent kernel (Jacot et al., 2018; Allen-Zhu et al., 2018; Du et al., 2018; Chizat & Bach, 2018; Zou et al., 2018; Zou & Gu, 2019; Arora et al., 2019; Oymak & Soltanolkotabi, 2019; Chen et al., 2019; Ji & Telgarsky, 2019). In contrast, we focus on multi-branch networks, which have not been studied theoretically before.

2. MODEL DESCRIPTION AND NOTATIONS

In practice, the multi-branch architecture is often used as a building block in a whole network. In this paper, we describe a multi-branch architecture/network N(·) as follows (see Figure 1).

• N(·) has C branches {B_k}_{k=1}^{C}, input h_in ∈ R^p, and output h_out ∈ R^d;

• The aggregation is parameterized with a vector α = (α_1, . . . , α_C)^T:

h_out := N(h_in) := Σ_{k=1}^{C} α_k · B_k(h_in).  (1)

Each branch B_k often consists of multiple layers with various structures and flexible configuration: depth, width, kernel size for convolution layers, activation functions, and normalization layers. The aggregation weight is given by α. This description covers popular multi-branch architectures in state-of-the-art models, e.g., Inception, ResNeXt, and the Transformer, if B_k and α are specified properly.
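Eq. (1) translates directly into code. The sketch below is illustrative: each branch is an arbitrary callable mapping R^p to R^d, and the two toy linear branches in the usage example are hypothetical, made up for demonstration:

```python
import numpy as np

def multi_branch(branches, alpha):
    """Build N(h_in) = sum_{k=1}^{C} alpha_k * B_k(h_in), as in Eq. (1).

    branches: list of C callables, each mapping R^p -> R^d.
    alpha:    length-C sequence of aggregation weights (alpha_1, ..., alpha_C).
    """
    if len(branches) != len(alpha):
        raise ValueError("need exactly one aggregation weight per branch")
    def N(h_in):
        return sum(a * B(h_in) for a, B in zip(alpha, branches))
    return N

# Usage: two hypothetical linear branches R^5 -> R^3, aggregated with alpha = (0.5, 0.5).
rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(3, 5)), rng.normal(size=(3, 5))
N = multi_branch([lambda h: W1 @ h, lambda h: W2 @ h], alpha=[0.5, 0.5])
h_out = N(np.ones(5))  # h_out lies in R^3
```

Specifying the B_k as convolutional stacks (ResNeXt/Inception) or attention heads (Transformer), with a suitable α, recovers the architectures named above.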

