ON THE STABILITY OF MULTI-BRANCH NETWORK

Anonymous

Abstract

Multi-branch architectures are widely used in state-of-the-art neural networks. Their empirical success relies on design wisdom such as adding normalization layers and/or scaling down the initialization. In this paper, we investigate the multi-branch architecture from the stability perspective. Specifically, we establish the forward/backward stability of multi-branch networks, which leads to several new findings. Our analysis shows that scaling down the initialization alone may not be enough to train multi-branch networks successfully, because the backward process can remain uncontrolled. We also unveil a new role of the normalization layer in stabilizing multi-branch architectures. More importantly, we propose a new design, "STAM aggregation", that is guaranteed to STAbilize the forward/backward processes of Multi-branch networks irrespective of the number of branches. We demonstrate that with STAM aggregation, the same training strategy is applicable to models with different numbers of branches, which reduces the hyper-parameter tuning burden. Our experiments verify our theoretical findings and also demonstrate that STAM aggregation can improve the performance of multi-branch networks considerably.

1. INTRODUCTION

The multi-branch architecture is a building block in state-of-the-art neural network models for many tasks, e.g., ResNeXt (Xie et al., 2017) for computer vision and the Transformer (Vaswani et al., 2017) for machine translation. It has been pointed out that one benefit of the multi-branch architecture is parameter efficiency (Xie et al., 2017): the number of parameters grows linearly with the number of branches but quadratically with the width (the number of neurons in one layer). It has also been argued that multiple branches can bring diversity if the branches are composed of sub-networks with different filters and depths (Huang et al., 2017; Li et al., 2019). To train multi-branch networks successfully, one usually needs careful designs and hyper-parameter tuning, such as adding normalization layers, scaling down the initialization, and adjusting the learning rate. As a verifying example, starting from a trainable single-branch network, simply duplicating the branch multiple times and aggregating the outputs often does not work as expected, e.g., the training instability of the sum aggregation in Figure 4. This demonstrates the difficulty of training multi-branch networks and motivates this study.

In this paper, we try to understand the behavior of training multi-branch networks. Specifically, we study the forward and backward processes of multi-branch networks, which are believed to govern whether a network is easy to optimize by gradient-based methods. We find that the aggregation scheme, i.e., the way of combining the multi-branch outputs, plays a central role in determining the training behavior. We show that the sum aggregation becomes unstable as the number of branches grows, which explains the poor performance of simply adding branches. Moreover, we characterize the condition on the aggregation scheme under which forward and backward stability is guaranteed.
Inspired by the theoretical analysis, we propose the "STAM" aggregation, which can STAbilize Multi-branch networks by scaling the sum of the branch outputs with a branch-aware factor α (see the later part of Section 3.1 for details). We argue the benefit of STAM aggregation over the sum and average aggregations by analyzing the Hessian of the multi-branch network. We show that STAM permits the same gradient-based optimizer to work across different settings, which can reduce much of the tuning burden when training networks with a flexible number of branches. We further examine the usual design wisdom through the stability lens. As a result, we find that scaling down the initialization may control the forward or the backward stability but not necessarily both, which is verified in experiments. We also unveil a new role of the normalization layer: it can stabilize the forward and backward processes of multi-branch networks, besides the many wanted and unwanted properties that have been discussed before (Ioffe & Szegedy, 2015; Yang et al., 2018; Santurkar et al., 2018; Xiong et al., 2020). Apart from the usual feedforward multi-branch architecture, we analyze the multi-head attention layer, a multi-branch architecture widely used in natural language processing. We give an upper bound on the multi-head representations when the softmax operator is replaced with a max operation. The upper bound unveils a relation between the head dimension and the length of the sequence, which interprets empirical observations well. This relation cannot be discovered if one assumes that the softmax outputs equal probabilities, as in Xiong et al. (2020). Overall, our contributions can be summarized as follows.

• We analyze the forward/backward stability of multi-branch networks, under which we can clearly interpret the benefits and potential problems of the practical wisdom, i.e., scaling down the initialization and adding normalization layers.
• We propose a theoretically inspired STAM aggregation design for multi-branch networks, which can handle an arbitrary number of branches with the same optimizer.
• We also analyze the forward/backward process of the multi-head attention layer and identify a special property of it that has not been characterized before.

1.1. RELATED WORK

The multi-branch architecture, also known as the split-transform-merge architecture, has been widely used in computer vision, e.g., the Inceptions (Szegedy et al., 2017; Chollet, 2017), ResNeXt (Xie et al., 2017), and many others (Abdi & Nahavandi, 2016; Ahmed & Torresani, 2017). Models for natural language tasks have also leveraged the multi-branch architecture, including the BiLSTM (Wu et al., 2016; Zhou et al., 2016) and the multi-head attention layer in the Transformer (Vaswani et al., 2017; Anonymous, 2020). Apart from the sum or average aggregation, recent works (Li et al., 2019; Zhang et al., 2020) integrate the attention mechanism with the aggregation scheme, i.e., attentive aggregation, although only a small number (2 ∼ 3) of parallel branches are considered. Theoretically, Zhang et al. (2018) interpret the benefit of the multi-branch architecture as reducing the duality gap or the degree of non-convexity. The theory of training general deep neural networks has been widely studied, via stability analysis (Arpit et al., 2019; Zhang et al., 2019a;c; Yang & Schoenholz, 2017; Zhang et al., 2019b; Yang, 2019; Lee et al., 2019) and the neural tangent kernel (Jacot et al., 2018; Allen-Zhu et al., 2018; Du et al., 2018; Chizat & Bach, 2018; Zou et al., 2018; Zou & Gu, 2019; Arora et al., 2019; Oymak & Soltanolkotabi, 2019; Chen et al., 2019; Ji & Telgarsky, 2019). In contrast, we focus on multi-branch networks, which have not been studied theoretically before.

2. MODEL DESCRIPTION AND NOTATIONS

In practice, the multi-branch architecture is often used as a building block in a whole network. In this paper, we describe a multi-branch architecture/network $N(\cdot)$ as follows (see Figure 1).

• $N(\cdot)$ has $C$ branches $\{B_k\}_{k=1}^{C}$, input $h_{in} \in \mathbb{R}^p$ and output $h_{out} \in \mathbb{R}^d$;
• The aggregation is parameterized with a vector $\alpha = (\alpha_1, \ldots, \alpha_C)^T$:

$$h_{out} := N(h_{in}) := \sum_{k=1}^{C} \alpha_k \cdot B_k(h_{in}). \tag{1}$$

Each branch $B_k$ often consists of multiple layers with various structures and flexible configurations: depth, width, kernel size for convolution layers, activation functions, and normalization layers. The aggregation weight is given by $\alpha$. This description covers popular multi-branch architectures in state-of-the-art models, e.g., Inception, ResNeXt and the Transformer, if $B_k$ and $\alpha$ are specified properly.

Throughout the paper, we use $\|\cdot\|$ to denote the $\ell_2$ norm of a vector. We further use $\|\cdot\|$ and $\|\cdot\|_F$ to denote the spectral norm and the Frobenius norm of a matrix, respectively. We denote sets of naturals by $[n] := \{1, \ldots, n\}$ and $[c:d] := \{c, c+1, \ldots, d\}$. We use bold lower-case letters, e.g., $v$, to denote vectors and bold upper-case letters, e.g., $M$, to denote matrices. Moreover, $I_{d\times d}$ is the $d \times d$ identity matrix, $1_d$ is the $d$-dimensional all-ones vector, and $\mathrm{Vec}(M)$ stacks the row vectors of the matrix $M$ into a long vector.
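As a concrete reference, the aggregation in Equation 1 can be sketched in a few lines of numpy; the single-matrix toy branches below are our own illustrative stand-ins, not part of the paper's model:

```python
import numpy as np

def aggregate(h_in, branches, alpha):
    # Equation (1): h_out = sum_k alpha_k * B_k(h_in)
    return sum(a * B(h_in) for a, B in zip(alpha, branches))

# toy branches (hypothetical): each is a single random linear map R^p -> R^d
rng = np.random.default_rng(0)
C, p, d = 4, 8, 6
branches = [(lambda h, W=rng.normal(size=(d, p)): W @ h) for _ in range(C)]

h_in = rng.normal(size=p)
h_sum = aggregate(h_in, branches, np.ones(C))         # sum aggregation
h_avg = aggregate(h_in, branches, np.ones(C) / C)     # average aggregation
```

By linearity of the weighted sum, the average aggregation here is exactly the sum aggregation scaled by $1/C$.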

3. STABILITY OF MULTI-BRANCH FEEDFORWARD NETWORK

To theoretically study the multi-branch network, we introduce a simplified multi-branch feedforward network. Specifically, we assume that each branch is a $b$-layer fully connected network with ReLU activation, and each layer has the same width $m$. Branches share the same structure and differ from each other only through the random initialization. One branch is given by

$$B_k(h_{in}) = W^k_b\, \phi\big(W^k_{b-1} \cdots \phi(W^k_1 h_{in})\big), \tag{2}$$

where $W^k_1 \in \mathbb{R}^{m\times p}$, $W^k_b \in \mathbb{R}^{d\times m}$ and $\phi(\cdot)$ is the ReLU activation $\phi(\cdot) := \max\{0, \cdot\}$. We further introduce $\vec{W}^k := (W^k_1, \ldots, W^k_b)$ that collects the parameters of $B_k$, and $\vec{W} := (\vec{W}^1, \ldots, \vec{W}^C)$.

Next we analyze the forward and backward propagation of the multi-branch network given by (Equation 1) and (Equation 2), and characterize a condition that guarantees forward/backward stability. Based on the theoretical analysis, we propose the STAM aggregation that can stabilize the forward/backward process of multi-branch networks. We further argue why the practical wisdom of scaling down the initialization and adding normalization layers works, and when it could fail.
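A minimal sketch of one such branch, with the Kaiming initialization used later in Section 3.1 (function names and dimensions are our own illustrative choices):

```python
import numpy as np

def kaiming_branch(rng, p, m, d, b):
    """Weights of one b-layer branch: W_1,...,W_{b-1} ~ N(0, 2/m), W_b ~ N(0, 1/d)."""
    Ws = [rng.normal(0, np.sqrt(2 / m), size=(m, p))]
    Ws += [rng.normal(0, np.sqrt(2 / m), size=(m, m)) for _ in range(b - 2)]
    Ws.append(rng.normal(0, np.sqrt(1 / d), size=(d, m)))
    return Ws

def branch_forward(Ws, h):
    """B_k(h_in) = W_b phi(W_{b-1} ... phi(W_1 h_in)), phi = ReLU."""
    for W in Ws[:-1]:
        h = np.maximum(0.0, W @ h)   # ReLU after every layer except the last
    return Ws[-1] @ h
```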

3.1. FORWARD AND BACKWARD PROCESS

We assume that the feedforward multi-branch network given by (Equation 1) and (Equation 2) adopts Kaiming's initialization (He et al., 2016): entries of $W_a$ for $a \in [1:b-1]$ are independently sampled from $N(0, \frac{2}{m})$, and entries of $W_b$ are independently sampled from $N(0, \frac{1}{d})$. Then the forward norm is well concentrated around its mean value as follows.

Theorem 1. Suppose the multi-branch network $N(\cdot)$ is given by Equation 1 and Equation 2 with Kaiming's initialization. For an input $h_{in}$, the following holds with probability at least $1 - O(bC)\cdot e^{-\Omega(m\epsilon^2/b)}$ over the initialization randomness of $\vec{W}$:

$$\|N(h_{in})\| \in (1 \pm \epsilon)\sqrt{\sum\nolimits_{k\in[C]} \alpha_k^2}\; \|h_{in}\|,$$

where $C$ is the number of branches, and $m$, $b$ are the width and depth of each branch, respectively.

Proof. The proof is based on the Gaussianness of $B_k(h_{in})$ and a concentration property; the full version is deferred to Appendix A.2.

Theorem 1 is presented for one input sample. If we want such a result to hold for all the training samples, the probability loses a factor of $n$ by the union bound and becomes $1 - O(nbC)\cdot e^{-\Omega(m\epsilon^2/b)}$.

Remark 1. Under the same assumptions as Theorem 1, we have

$$\mathbb{E}\,\|N(h_{in})\|^2 = \sum\nolimits_{k\in[C]} \alpha_k^2\, \|h_{in}\|^2,$$

where $C$ is the number of branches and the expectation is over the random initialization.

These results are based on the bound on the forward propagation of one feedforward branch given by (Equation 2), which is studied in Allen-Zhu et al. (2018) and restated in Appendix A. Furthermore, if the weight matrices follow the Gaussian distribution, then $B_k(h_{in})$ is roughly Gaussian, and jointly Gaussian across different input samples, as the width $m$ goes to infinity. The aggregation (Equation 1) can then be viewed as a sum of weighted Gaussian vectors. Hence, at initialization, we can characterize the aggregation of multiple branches as above. We next analyze the backward propagation of the multi-branch feedforward network.
We abuse "gradient" to refer to the values computed through back-propagation even for non-smooth functions. We assume the loss function $\ell(\cdot,\cdot)$ is quadratic, i.e., $\ell(h_{out}, y^*) = \frac{1}{2}\|h_{out} - y^*\|_2^2$. Hence, the objective function is $L(\vec{W}) := \frac{1}{n}\sum_{i=1}^n L_i(\vec{W})$, where $L_i(\vec{W}) := \ell(N(x_i), y_i^*)$. We next show that the backward process is bounded for each individual sample of the multi-branch network $N(\cdot)$.

Theorem 2. With probability at least $1 - O(nbC)\cdot \exp(-\Omega(m))$ over the randomness of $\vec{W}$, for every $a \in [b]$, every $k \in [C]$, and every $i \in [n]$,

$$\|\nabla_{W^k_a} L_i(\vec{W})\|_F^2 \le O\big(L_i(\vec{W})\big)\, \alpha_k^2 \times \frac{m}{d}, \qquad \|\nabla_{\vec{W}} L_i(\vec{W})\|_F^2 \le O\big(L_i(\vec{W})\big) \times \frac{mb}{d} \sum\nolimits_{k\in[C]} \alpha_k^2.$$

Proof. We can compute the gradient with respect to the intermediate layer outputs via the backward propagation procedure. The gradient upper bound is guaranteed if the intermediate layer outputs and their gradients are bounded across layers. The full proof is relegated to Appendix A.3.

If we further assume the gradient independence condition, i.e., the weights in the backward process can be assumed to be independent from the weights in the forward pass (Yang, 2019), we can estimate the expectation of the gradient norm as follows.

Remark 2. Assuming gradient independence, for every $a \in [b]$, $k \in [C]$ and $i \in [n]$,

$$\mathbb{E}\,\|\nabla_{W^k_a} L_i(\vec{W})\|_F^2 = L_i(\vec{W})\, \alpha_k^2 \times \frac{m}{d}, \qquad \mathbb{E}\,\|\nabla_{\vec{W}} L_i(\vec{W})\|_F^2 = L_i(\vec{W}) \times \frac{mb}{d} \sum\nolimits_{k\in[C]} \alpha_k^2.$$

With Theorems 1 and 2 and the two remarks, we can discuss the forward and backward processes of the multi-branch network. Both the output of the multi-branch network and the gradient are under control if $\sum_k \alpha_k^2 \le O(1)$. Specifically, for the sum aggregation we have $\sum_k \alpha_k^2 = C$, which grows unboundedly with the number of branches $C$. For the average aggregation we have $\sum_k \alpha_k^2 = 1/C$, which diminishes with $C$. There exists a better choice: $\alpha_k = 1/\sqrt{C}$ for $k \in [C]$, which keeps $\sum_k \alpha_k^2 = 1$ constant as the number of branches varies.
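The $\alpha_k^2$ scaling of the per-branch gradient in Theorem 2 can be checked by hand in the depth-one case $b = 1$, a simplification we adopt purely for illustration: there $\nabla_{W_k} L = \alpha_k\, e\, h^T$ with $e := h_{out} - y^*$, so $\|\nabla_{W_k} L\|_F^2 = \alpha_k^2 \|e\|^2 \|h\|^2$ exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
C, p, d = 8, 16, 16
alpha = np.full(C, 1 / np.sqrt(C))                 # STAM weights
Ws = [rng.normal(0, np.sqrt(1 / d), size=(d, p)) for _ in range(C)]

h = rng.normal(size=p)
y = np.zeros(d)
out = sum(a * W @ h for a, W in zip(alpha, Ws))
e = out - y                                        # dL/dh_out for the quadratic loss
grads = [a * np.outer(e, h) for a in alpha]        # dL/dW_k = alpha_k * e h^T

# squared Frobenius norm of each per-branch gradient scales with alpha_k^2
g2 = [np.linalg.norm(g, "fro") ** 2 for g in grads]
```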
We call it the "STAM" aggregation, abbreviating STAble Multi-branch aggregation. We plot the output norm of the first residual block in multi-branch ResNets at initialization in Figure 2. The multi-branch ResNets are generated by varying the number of branches in the residual block, with batch normalization removed. We can see that the forward norm of the STAM aggregation roughly remains the same, while that of the sum aggregation explodes and that of the average aggregation diminishes as the number of branches grows. We also analyze the Hessians of the different aggregation schemes. We find that the spectral norm of the Hessian, which determines the smoothness of the objective, scales proportionally with the number of branches for the sum aggregation and reciprocally with the number of branches for the average aggregation. In contrast, the Hessian for the STAM aggregation remains unchanged as the number of branches varies. Hence, with STAM, the same learning rate works for networks with different numbers of branches. We present the details in Appendix B.
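The behavior in Figure 2 can be reproduced with a small Monte Carlo sketch of Theorem 1; the widths, depths, and tolerances below are arbitrary choices of ours:

```python
import numpy as np

rng = np.random.default_rng(0)
C, b, p, m, d = 8, 3, 64, 512, 64     # branches, depth, input dim, width, output dim

def make_branch():
    Ws = [rng.normal(0, np.sqrt(2 / m), size=(m, p))]
    Ws += [rng.normal(0, np.sqrt(2 / m), size=(m, m)) for _ in range(b - 2)]
    return Ws + [rng.normal(0, np.sqrt(1 / d), size=(d, m))]

def forward(Ws, h):
    for W in Ws[:-1]:
        h = np.maximum(0.0, W @ h)     # ReLU
    return Ws[-1] @ h

h_in = rng.normal(size=p)
h_in /= np.linalg.norm(h_in)           # ||h_in|| = 1
outs = [forward(make_branch(), h_in) for _ in range(C)]

norms = {}
for name, a in [("sum", 1.0), ("avg", 1.0 / C), ("STAM", 1.0 / np.sqrt(C))]:
    norms[name] = np.linalg.norm(a * np.sum(outs, axis=0))
# Theorem 1 predicts ||N(h_in)|| ~ a * sqrt(C): explodes for sum,
# diminishes for average, stays near 1 for STAM
```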

3.2. UNDERSTANDING THE PRACTICAL WISDOM

In practice, people have design wisdom to stabilize the multi-branch network. We next analyze it through the stability lens.

One can scale down the initialization of the multi-branch network. For example, one can initialize each $W^k_l$ with a scaling-down factor $C^{-\frac{1}{2b}}$ for all $l \in [b]$ and $k \in [C]$, so that the forward norm $\|N(h_{in})\|$ is bounded around $\|h_{in}\|$, irrespective of $C$. However, with this initialization, the norm of the output update induced by one gradient descent step scales with $b\,C^{1/b}$. Alternatively, we suggest initializing each $W^k_l$ with a scaling-down factor $(bC)^{-\frac{1}{2(b-1)}}$ for all $l \in [b]$ and $k \in [C]$ to stabilize the backward process, such that the expected update on the output is independent of $b$ and $C$. This indicates that a constant learning rate can train multi-branch networks with any $b$ and $C$, although the forward process then has a diminishing output norm scaling with $b^{-\frac12}(bC)^{-\frac{1}{2(b-1)}}$. A more detailed discussion is presented in Appendix A.4. We empirically verify this in Section 5.1.

Another widely used technique is adding normalization layers. It is easy to see that a normalization layer stabilizes the forward process, as it normalizes the output to have mean 0 and variance 1. Moreover, the normalization layer also stabilizes the backward process, as the error signal is divided by the standard deviation, which is proportional to $\sqrt{C}$, when propagating backward. Therefore, we show a new role of the normalization layer: it automatically stabilizes the forward and backward processes of the multi-branch network, or more generally stabilizes the architecture, besides the previous understandings of increasing smoothness (Santurkar et al., 2018), handling covariate shift (Ioffe & Szegedy, 2015), or implicit structural bias (De & Smith, 2020).

We next examine concrete structures. One popular aggregation scheme for multi-branch outputs is concatenation followed by a linear transformation, which is equivalent to a sum aggregation. The default Xavier initialization (Glorot & Bengio, 2010) scales down the initialization of the linear transformation by roughly $1/\sqrt{C}$ to stabilize the forward process, but the training can still be unstable because of the unbounded backward process, as discussed above and verified in Section 5.1. A well-known multi-branch structure is ResNeXt for image classification. It uses batch normalization on the output of the residual branch to stabilize the forward and backward processes. However, if we use the pre-act form, i.e., remove the batch normalization on the output of the residual branch and add one on its input, the training becomes unstable as the number of branches increases (see Section 5.2). With the pre-act residual block, one needs a normalization layer on the final output of the network, which has been studied in Xiong et al. (2020). This can still be unstable for deep half-precision ResNets, as the output may overflow before the final normalization layer; this is interesting but beyond the scope of this paper. The Transformer also uses multi-branch structures in the attention layer and the fully-connected layer. It does not use a normalization layer on the output of the residual block because of practical performance concerns.
Hence, unstable training or deteriorating performance is observed when increasing the number of heads in the attention layer or adding fully-connected branches (see Section 5.3).

4. STABILITY OF MULTI-HEAD ATTENTION

The multi-head attention layer computes

$$\mathrm{MultiHead}(X) := \sum_{c\in[C]} \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V\, \tilde{O}_c,$$

where $Q := X\tilde{Q}_c$, $K := X\tilde{K}_c$, $V := X\tilde{V}_c$, and $\tilde{Q}_c \in \mathbb{R}^{d\times d_k}$, $\tilde{K}_c \in \mathbb{R}^{d\times d_k}$, $\tilde{V}_c \in \mathbb{R}^{d\times d_v}$, $\tilde{O}_c \in \mathbb{R}^{d_v\times d}$ are trainable projection matrices, $C$ is the number of heads, and $d_v$, $d_k$ are the head dimensions. The multi-head attention computes each head representation (of dimension $d_v$) independently, concatenates the $C$ heads into a large ($Cd_v$-dimensional) representation, and finally uses a fully-connected layer $\tilde{O} \in \mathbb{R}^{d\times(Cd_v)}$ to project back to $\mathbb{R}^d$ ($d$ is the model embedding dimension). This indicates that the original multi-head attention uses the sum aggregation.

It is convenient to first consider the case of one head with dimension $d_v$ when studying the multi-head attention layer. We give an upper bound on the forward process when the softmax behaves more like a "max" operator; the softmax output can be rather peaked after some initial training, which is usually the case in practice.

Proposition 1. Consider a self-attention layer as described above and assume $\|x_i\| = \sqrt{d}$ for $i \in [n]$. If the parameters $\tilde{Q}$, $\tilde{K}$, $\tilde{V}$ are randomly initialized with $N(0, \frac{1}{d})$, then for symbol $x_i$ and its head representation $h_i \in \mathbb{R}^{d_v}$, we have

$$\mathbb{E}\,\|h_i\|^2 \lesssim d_v + \Big(\log n - 1 + 2\sqrt{(2d_v - 1)\log n}\Big), \tag{9}$$

where "$\lesssim$" means that the bound holds asymptotically and the expectation is over the initialization.

Proof. The proof is based on estimating extreme Chi-square values and is deferred to Appendix C.

We can see that besides $d_v$, the norm also scales with $\log n$, where $n$ is the number of symbols. In the Transformer literature, one usually fixes $C \cdot d_v = d$, and then the upper bound on the head norm grows with $C$ because of the second term on the right-hand side of Equation 9. This term is ignored in previous analyses, yet it can dominate when $d_v < \Omega(\log n)$.
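A quick tabulation shows how the $\log n$ term takes over when $C \cdot d_v = d$ is held fixed and the number of heads grows; the formula is our reading of the bound in Proposition 1, and the dimensions ($d = 512$, $n = 512$) are illustrative:

```python
import math

def head_norm_bound(d_v, n):
    # reconstructed bound: E||h_i||^2 <~ d_v + (log n - 1) + 2*sqrt((2*d_v - 1)*log n)
    L = math.log(n)
    return d_v + (L - 1) + 2 * math.sqrt((2 * d_v - 1) * L)

# hold C*d_v = d = 512 fixed and vary the number of heads C; n = 512 tokens
# the ratio bound/d_v measures how much the second term inflates the head norm
ratios = {C: head_norm_bound(512 // C, 512) / (512 // C) for C in (1, 8, 64)}
```

With one head the bound is dominated by $d_v$; with 64 heads ($d_v = 8$) the $\log n$ term is several times larger than $d_v$ itself.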
We plot the norms of the self-attention layer (in the first decoder block) in Transformers with varying numbers of heads on the IWSLT de-en task in Figure 3 to illustrate this point. At initialization, the norm does not change with the number of heads, matching the expected behavior at initialization (Xiong et al., 2020). However, this is no longer true after training (see the curve at Epoch 10): the assumption that the softmax outputs a uniform simplex does not hold. With a large $C$ and a small $d_v$, the second term in Equation 9 dominates, which may hurt forward/backward stability and the training process, matching the observation in practice. For the backward process, we can estimate the gradient at initialization under the gradient independence assumption (Yang, 2019). Suppose the backward signals are $\{e_1, e_2, \ldots, e_n\}$. The gradient on $\tilde{O}$ is $\nabla_{\tilde{O}} = \sum_{i\in[n]} e_i h_i^T$ and $\mathbb{E}\,\|\nabla_{\tilde{O}}\|_F^2 = \sum_{i\in[n]} (\mathbb{E}\|h_i\|^2)\, \|e_i\|^2$, which scales with the head norm. A large forward norm results in large gradients, which may destabilize the training procedure.

5. EXPERIMENTS

In this section, we first verify the theoretical claims in Section 3. We then apply the STAM aggregation on the popular ResNeXt (Xie et al., 2017) model for image classification and on Transformer (Vaswani et al., 2017) to demonstrate its efficacy.

5.1. THEORETICAL VERIFICATION

We conduct experiments to verify the theoretical claims in Section 3. Specifically, we use a Fixup ResNet20 (Zhang et al., 2019a) as the backbone network, which uses the Fixup initialization without any batch normalization (BN) layers and can be trained to a decent accuracy. The multi-branch ResNets are generated by varying the number of branches in each residual block (Figure 1). We train these networks on the CIFAR10/CIFAR100 classification tasks (Krizhevsky & Hinton, 2009) and record the top-1 validation error in Figure 4 with a standard training procedure, i.e., batch size 128, 200 epochs, initial learning rate 0.1, and step-wise learning rate decay by 1/10 at epochs 100 and 150, respectively. We compare several approaches: the STAM aggregation with $\alpha_k = 1/\sqrt{C}$; the sum aggregation with $\alpha_k = 1$; the average aggregation with $\alpha_k = 1/C$; the backward scaling-down initialization $\gamma = 1/\sqrt{C}$, referred to as "Back-scaled init" in Figure 4; and concatenation followed by a linear layer with scaled-down initialization $\gamma = 1/\sqrt{C}$, referred to as "linear layer aggregation" in Figure 4. As shown in Figure 4, the performance of the model trained with the STAM aggregation consistently improves as the number of branches $C$ grows. In contrast, the models with the sum aggregation fail to converge for $C > 4$, as predicted by Theorems 1 and 2. The performance of the average aggregation degrades once the number of branches exceeds a certain value. In addition, the backward scaling-down initialization works well in terms of stability, as argued in Section 3.2. However, the scheme of concatenation followed by a linear transformation becomes unstable as the number of branches increases, even with a scaled-down initialization.

5.2. APPLY STAM AGGREGATION TO TRAIN RESNEXT

We next investigate a well-known multi-branch network, the ResNeXt model (Xie et al., 2017), which divides the bottleneck residual block into multiple branches. It uses a 1×1 convolution to aggregate the branch outputs, which is equivalent to the sum aggregation if the 1×1 convolution is taken as part of each branch (see Figure 3 in Xie et al. (2017)). ResNeXt uses a BN layer to normalize the aggregated result. We have argued in Section 3 how the normalization layer stabilizes the forward and backward processes; indeed, ResNeXt fails to converge for $C > 2$ without this BN layer. To verify the efficacy of the STAM aggregation on ResNeXt, we use a separate BN layer for each branch, equivalent to viewing the BN layer as part of the branch, and then apply the STAM aggregation to the branch outputs. The hyper-parameters and the data augmentation are the same as in Xie et al. (2017). From Table 1, we see that our STAM aggregation is applicable to cases with many branches (more than in the original paper) and improves over the strong baselines on both CIFAR10 and CIFAR100. More details about the experiments can be found in Appendix D.

5.3. APPLY STAM AGGREGATION TO TRAIN TRANSFORMER WITH MANY BRANCHES

In this section, we investigate the aggregation scheme in Transformer models (Vaswani et al., 2017) on machine translation tasks. A Transformer model consists of several encoder and decoder layers, with each layer stacking one or two multi-head attention blocks and one fully connected (FC) block. We conduct two sets of experiments: the IWSLT14 German-English (de-en) task and the WMT16 English-German (en-de) task. For the baseline models, we choose the default Transformer and the big Transformer for IWSLT14 de-en and WMT16 en-de, respectively. For the multi-branch versions of the Transformer, we first construct a mini model and then generate multi-branch Transformers by copying the mini model multiple times. This allows us to explore many branches while controlling the model size. For the IWSLT14 de-en task, the generated multi-branch Transformers have configurations $(d, d_h, h, d_{fc}) = (256, 64, 8, 1024), (256, 64, 12, 1536), \ldots$. These models are trained with the Adam (Kingma & Ba, 2014) optimizer with initial learning rate 1e-3, $\beta_1 = 0.9$, $\beta_2 = 0.98$, and the inverse-sqrt learning rate scheduler, which are standard hyper-parameter choices following Ott et al. (2018). For the STAM aggregation, we set $\alpha = 1/\sqrt{C}$, where $C$ is the equivalent number of mini models. We evaluate the translation quality by BLEU scores using multi-bleu.perl. We train every model on a single Tesla P40 GPU for 200 epochs and test the single model that achieves the best BLEU score on the validation set. Other detailed configurations can be found in the code. For the WMT16 English-German (en-de) task, the baseline big Transformer has configuration $d = 1024$, $d_h = 64$, $d_{fc} = 4096$, $h = 16$. We choose the mini model for the WMT16 en-de task with configuration $d = 512$, $d_h = 64$, $h = 8$, $d_{fc} = 2048$, and the generated multi-branch Transformers have configurations $(d, d_h, h, d_{fc}) = (512, 64, 16, 4096), (512, 64, 32, 8192), \ldots$. For the STAM aggregation, we set $\alpha = 1/\sqrt{C}$, where $C$ is the equivalent number of mini models. All models are trained for 150 epochs on a single node with 8 Tesla V100 GPUs.
We average the model parameters of the last 5 checkpoints for evaluation. We report the BLEU scores of all the models on the test sets in Tables 2 and 3. The performance of the original architecture degrades as the number of branches increases, while with the STAM aggregation the performance gradually improves as the number of branches increases.

6. CONCLUSION

In this paper, we study the training process of multi-branch networks, especially the forward and backward stability. The theory shows that the sum aggregation is not stable as the number of branches increases. Motivated by the theoretical analysis, we propose the STAM aggregation, which not only guarantees forward/backward stability but also allows the same training strategy to work for networks with different numbers of branches. We show that with the STAM aggregation, the models consistently benefit from increasing the number of branches. We believe that our analysis and the proposed STAM aggregation give practitioners a new direction for designing multi-branch models.

A PROOFS IN SECTION 3

For feedforward branches given by Equation 2, we adopt Kaiming's initialization (He et al., 2016): entries of $W_a$ for $a \in [1:b-1]$ are independently sampled from $N(0, \frac{2}{m})$, and entries of $W_b$ are independently sampled from $N(0, \frac{1}{d})$.

A.1 ONE BRANCH RESULT

Lemma 1 (Lemma 7.1 in Allen-Zhu et al. (2018)). For one branch $B(h_{in})$ given by Equation 2 with Kaiming's initialization and any small constant $\epsilon > 0$, with probability at least $1 - O(b)\cdot e^{-\Omega(m\epsilon^2/b)}$ over the initialization randomness of $\vec{W} \in (\mathbb{R}^{m\times m})^b$, the following holds:

$$\forall a \in [b-1]: \quad \|h_a\| \in (1\pm\epsilon)\,\|h_{in}\|,$$

where $h_a := \phi(W_a h_{a-1})$ for $a = 1, \ldots, b-1$ with $h_0 := h_{in}$. We note that $\|h_a\| \in (1\pm\epsilon)\|h_{in}\|$ means $(1-\epsilon)\|h_{in}\| \le \|h_a\| \le (1+\epsilon)\|h_{in}\|$. The proof of this lemma is based on the randomness of $\vec{W}^k$ and a concentration property.

A.2 PROOF OF THEOREM 1

Theorem 1. Suppose the multi-branch network $N(\cdot)$ is given by Equation 1 and Equation 2 with Kaiming's initialization. For an input $h_{in}$, the following holds with probability at least $1 - O(bC)\cdot e^{-\Omega(m\epsilon^2/b)}$ over the initialization randomness of $\vec{W}$:

$$\|N(h_{in})\| \in (1\pm\epsilon)\sqrt{\sum\nolimits_{k\in[C]} \alpha_k^2}\; \|h_{in}\|,$$

where $C$ is the number of branches, and $m$, $b$ are the width and depth of each branch, respectively.

Proof. In the following derivation, we fix a specific sample $i \in [n]$ with input $h_{i,in}$; for notational simplicity, we drop the index $i$. Define the event $E := \{\|h^k_{b-1}\| \in (1\pm\epsilon)\|h_{in}\|,\ \forall k \in [C]\}$. By Lemma 1 and a union bound over the $C$ branches, $E$ holds with probability $1 - O(bC)e^{-\Omega(m\epsilon^2/b)}$. On the event $E$, using only the randomness of $W^k_b \in \mathbb{R}^{d\times m}$, the vectors $h^k_b = W^k_b h^k_{b-1}$ are independent Gaussians across $k$ with $\mathbb{E}[h^k_b] = 0$ and $\mathrm{Var}[h^k_b] = \frac{\|h^k_{b-1}\|^2}{d} I_{d\times d}$ for $k \in [C]$. Hence, $N(h_{in}) = \sum_{k\in[C]} \alpha_k B_k(h_{in})$ is Gaussian with mean $0$ and covariance matrix $\frac{1}{d}\sum_{k\in[C]} \alpha_k^2 \|h^k_{b-1}\|^2 I_{d\times d}$. By the law of large numbers (over the $d$ output coordinates), we have

$$\|N(h_{in})\|^2 \in (1\pm\epsilon)^2 \sum\nolimits_k \alpha_k^2\, \|h_{in}\|^2.$$

The claim is proved.

Proof of Remark 1

The proof of Remark 1 is straightforward. In expectation we always have $\mathbb{E}[h^k_l \mid h^k_{l-1}] = 0$ and $\mathbb{E}[\|h^k_l\|^2 \mid h^k_{l-1}] = \|h^k_{l-1}\|^2$ for $k \in [C]$ and $l \in [b]$. Furthermore, as the $h^k_b$'s are independent across $k$, we obtain $\mathbb{E}\|N(h_{in})\|^2 = \sum_k \alpha_k^2 \|h_{in}\|^2$.

A.3 PROOF OF THEOREM 2

Given the branch equation (2), we further introduce notation for the intermediate representations: $g^k_a := W^k_a h^k_{a-1}$ and $h^k_a := \phi(g^k_a)$ for $a \in [b-1]$ and $k \in [C]$.

Theorem 2. With probability at least $1 - O(nbC)\cdot\exp(-\Omega(m))$ over the randomness of $\vec{W}$, for every $a \in [b]$, $k \in [C]$ and $i \in [n]$,

$$\|\nabla_{W^k_a} L_i(\vec{W})\|_F^2 \le O\big(L_i(\vec{W})\big)\,\alpha_k^2 \times \frac{m}{d}, \qquad \|\nabla_{\vec{W}} L_i(\vec{W})\|_F^2 \le O\big(L_i(\vec{W})\big) \times \frac{mb}{d} \sum\nolimits_{k\in[C]} \alpha_k^2.$$

Deep neural networks are often trained with gradient-based optimization algorithms. The gradient with respect to the parameters is computed through back-propagation, e.g., $\partial W_l = \partial h_l \cdot g_{l-1}^T$, where $\partial(\cdot)$ denotes the gradient of the objective with respect to $(\cdot)$. Therefore, the gradient upper bound is guaranteed if $g_{l-1}$ and $\partial h_l$ are bounded across layers and iterations.

Proof. The gradient is computed as follows via back-propagation; we omit the sample index $i$ for notational simplicity. With $h_{out} = \sum_k \alpha_k g^k_b$, we have

$$\frac{\partial h_{out}}{\partial g^k_b} = \alpha_k I_{d\times d}, \qquad \partial g^k_b = \Big(\frac{\partial h_{out}}{\partial g^k_b}\Big)^T \partial h_{out} = \alpha_k\, \partial h_{out}, \qquad \partial h^k_{b-1} = (W^k_b)^T \partial g^k_b, \qquad \partial W^k_b = \partial g^k_b \cdot (h^k_{b-1})^T,$$

$$\partial g^k_{b-1} = D^k_{b-1}\, \partial h^k_{b-1}, \quad \cdots, \quad \partial h^k_0 = (W^k_1)^T \partial g^k_1, \qquad \partial W^k_1 = \partial g^k_1 \cdot (h^k_0)^T,$$

where $D^k_a$ is the diagonal 0/1 matrix encoding the ReLU activation pattern at layer $a$. For the quadratic loss and sample $i \in [n]$, we have $\|\partial h_{i,out}\|^2 = 2L_i(\vec{W})$. Then

$$\|\nabla_{W^k_a} L_i(\vec{W})\|_F = \|\partial g^k_{i,a} \cdot (h^k_{i,a-1})^T\|_F = \big\|D^k_a (W^k_{a+1})^T \cdots D^k_{b-1} (W^k_b)^T\, \alpha_k\, \partial h_{i,out}\big\| \cdot \|h^k_{i,a-1}\| \le O(\sqrt{m/d}) \cdot \alpha_k \sqrt{L_i(\vec{W})} \cdot O(1)\|x_i\|, \tag{17}$$

where the last inequality is due to Theorem 1 and (Allen-Zhu et al., 2018, Lemma 7.4b). Thus, with high probability $1 - O(nbC)\cdot\exp(-\Omega(m))$, for all $i \in [n]$, $a \in [b]$ and $k \in [C]$, we have $\|\nabla_{W^k_a} L_i(\vec{W})\|_F^2 \le O\big(\frac{L_i(\vec{W})}{d}\big)\, \alpha_k^2 \times m$ under the assumption $\|x_i\| = 1$.

Proof of Remark 2

To establish the expectation of the backward process, we require the gradient independence assumption (Yang, 2019), which can be verified and argued for in certain cases. With this assumption, we can take the expectation over the forward pass and the backward pass independently in (Equation 17), from which the expectation estimates in Remark 2 follow.

A.4 MORE DISCUSSION ON PRACTICAL WISDOM

We discuss the choices of scaling down the initialization in detail. Without loss of generality, we first initialize the parameters following Kaiming's scheme, as in Section 3, and then multiply the parameters in the multi-branch block by the same scaling-down factor $\gamma$. We use the sum aggregation, $\alpha_k = 1$ for $k \in [C]$. For forward stability, we need $\mathbb{E}\|N(h_{in})\|^2 \le O(1)$, which requires $\sum_{k\in[C]} (\gamma^b)^2 \le 1$ and therefore $\gamma \le C^{-\frac{1}{2b}}$.

For the backward process, we focus on the output update induced by one gradient descent step, i.e., $N(h_{in}; \vec{W} + \eta\nabla_{\vec{W}}) - N(h_{in}; \vec{W})$. Suppose the backward error signal on $h_{out}$ is $e$ with $\|e\| = 1$. Following the proof of Theorem 2, we have

$$\mathbb{E}\|\nabla_{W^k_a}\|_F = \gamma^{b-1}\cdot\sqrt{m/d}\cdot\|e\|\cdot\|h_{in}\| = \Omega(\gamma^{b-1}), \qquad \mathbb{E}\big\|B_k(h_{in}; W^k_l + \eta\nabla_{W^k_l}) - B_k(h_{in}; W^k_l)\big\| = \Omega(\eta\gamma^{2(b-1)}),$$

where the second equality is due to a Taylor expansion and the forward pass contributing a factor $\gamma^{b-1}$. The forward output update of branch $k$ is

$$B_k\big(h_{in}; \vec{W}^k + \eta\nabla_{\vec{W}^k}\big) - B_k\big(h_{in}; \vec{W}^k\big) \approx \sum_{l\in[b]} \Big[B_k\big(h_{in}; W^k_l + \eta\nabla_{W^k_l}\big) - B_k\big(h_{in}; W^k_l\big)\Big],$$

where "$\approx$" ignores second-order perturbations. Hence $\mathbb{E}\|B_k(h_{in}; \vec{W}^k + \eta\nabla_{\vec{W}^k}) - B_k(h_{in}; \vec{W}^k)\| \approx \Omega(\eta b \gamma^{2(b-1)})$. Summing over the $C$ branches, we have $\mathbb{E}\|N(h_{in}; \vec{W} + \eta\nabla_{\vec{W}}) - N(h_{in}; \vec{W})\| \approx \Omega(\eta b C \gamma^{2(b-1)})$. Hence, to make the expected output update independent of $C$ and $b$, it requires $\gamma = (bC)^{-\frac{1}{2(b-1)}}$. With this choice, the forward process does not explode but has a diminishing output norm $\Omega\big(b^{-\frac12}(bC)^{-\frac{1}{2(b-1)}}\big)$.
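The two scale-down factors can be compared numerically; the depth and branch count below are arbitrary illustrative choices:

```python
import math

b, C = 4, 32     # depth per branch, number of branches (illustrative)

# forward-stable scale-down: sqrt(C) * gamma^b = 1  =>  gamma = C^(-1/(2b))
gamma_f = C ** (-1 / (2 * b))

# backward-stable scale-down (expected output update O(1)): gamma = (bC)^(-1/(2(b-1)))
gamma_b = (b * C) ** (-1 / (2 * (b - 1)))

# under gamma_b the forward norm diminishes like b^(-1/2) * (bC)^(-1/(2(b-1)))
shrink = b ** -0.5 * (b * C) ** (-1 / (2 * (b - 1)))
```

Note that the backward-stable factor is strictly smaller than the forward-stable one, which is exactly why the forward norm diminishes under the backward-stable choice.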

B HESSIAN ANALYSIS

Next, to obtain a more refined understanding of the different aggregation schemes, we analyze the Hessian of a multi-branch network. For simplicity, we assume that each $B_k$ has only one linear layer: $B_k(h_{in}) = W_k h_{in}$. The Hessian of the objective with respect to the input $h_{in}$ and the learnable parameters $\vec{W}$ is
$$H_{h_{in}} = \Big(\sum_k \alpha_k W_k\Big)^T \Big(\sum_k \alpha_k W_k\Big), \qquad H_{\vec{W}} = \alpha\alpha^T \otimes I_{d\times d} \otimes \mathbb{E}\,h_{in}h_{in}^T,$$
where $H_{\vec{W}} := H_{\mathrm{Vec}(\vec{W})}$, $\mathrm{Vec}(\vec{W}) := (\mathrm{Vec}(W_1)^T, \mathrm{Vec}(W_2)^T, \ldots, \mathrm{Vec}(W_C)^T)^T$ is the long vector stacking $\vec{W}$, $\otimes$ is the Kronecker product, and $\mathbb{E}$ averages over the training samples.

Fact 1. For $B_k(h_{in}) = W_k h_{in}$ and $\alpha = \alpha\mathbf{1}_C$, the spectral norms of the Hessians satisfy
$$\mathbb{E}\|H_{h_{in}}\| = \Omega(\alpha^2 C), \qquad \|H_{\vec{W}}\| = \alpha^2 C\,\|\mathbb{E}\,h_{in}h_{in}^T\|,$$
where the first expectation is over the randomness at initialization and the second expectation is over the training samples.

Proof. With $\alpha_k = \alpha$, the Hessians reduce to
$$H_{h_{in}} = \alpha^2 \Big(\sum_k W_k\Big)^T \Big(\sum_k W_k\Big), \qquad H_{\vec{W}} = \alpha^2\,(\mathbf{1}_{C\times C}) \otimes I_{d\times d} \otimes \mathbb{E}\,h_{in}h_{in}^T,$$
where $H_{\vec{W}}$ has dimension $Cdp \times Cdp$ (with $h_{in} \in \mathbb{R}^p$). Because the entries of each $W_k$ are initialized from the Gaussian distribution $N(0, 1/d)$, the entries of $\sum_k W_k$ follow $N(0, C/d)$. The singular values of a Gaussian matrix are characterized by the following Bai-Yin law (Bai & Yin, 1993).

Lemma 2. Let $A \in \mathbb{R}^{N\times n}$ have independent standard Gaussian entries. Suppose the dimensions $N$ and $n$ grow to infinity while the aspect ratio $n/N$ converges to a constant in $[0, 1]$. Then, almost surely, $s_{max}(A) = \sqrt{N} + \sqrt{n} + o(\sqrt{n})$, where $s_{max}(A)$ is the largest singular value of $A$.

Hence $\mathbb{E}\|H_{h_{in}}\| = \alpha^2 C(1 + \sqrt{m/d})^2 = \Omega(\alpha^2 C)$ if $m/d$ is viewed as a fixed hyper-parameter. For the spectral norm of $H_{\vec{W}}$, we use the Kronecker-product property $\lambda_{max}(A \otimes B) = \lambda_{max}(A)\cdot\lambda_{max}(B)$ for positive semi-definite matrices $A$ and $B$. Hence
$$\|H_{\vec{W}}\| = \alpha^2\,\lambda_{max}(\mathbf{1}_{C\times C}) \cdot 1 \cdot \lambda_{max}(\mathbb{E}\,h_{in}h_{in}^T) = \alpha^2 C\,\|\mathbb{E}\,h_{in}h_{in}^T\|,$$
where the second equality uses $\lambda_{max}(\mathbf{1}_{C\times C}) = C$ and $\lambda_{max}(I_{d\times d}) = 1$.

The spectral norm of the Hessian corresponds to the smoothness, or Lipschitz coefficient, of the gradient; it determines the largest allowable learning rate for the gradient descent algorithm. For the sum aggregation, the spectral norms are $\|H_{h_{in}}\| = \Omega(C)$ and $\|H_{\vec{W}}\| = C\,\|\mathbb{E}\,h_{in}h_{in}^T\|$, both $C$ times larger than in the single-branch case. This indicates that the learning rate for gradient descent must be scaled down appropriately with the number of branches to guarantee convergence. For the average aggregation, at initialization $\|H_{h_{in}}\| = \Omega(1/C)$ and $\|H_{\vec{W}}\| = \frac{1}{C}\|\mathbb{E}\,h_{in}h_{in}^T\|$, both shrunk by $1/C$ compared to the single-branch case, so the learning rate or the gradient must be scaled up appropriately with the number of branches. Thus one cannot apply the same training strategy to networks with varying numbers of branches under the sum or average aggregation. In contrast, for the STAM aggregation, $\|H_{h_{in}}\| = \Omega(1)$ and $\|H_{\vec{W}}\| = \|\mathbb{E}\,h_{in}h_{in}^T\|$; both remain the same as the number of branches varies. This indicates that we can apply the same training strategy for any number of branches.
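The $C$-scaling of the Hessian spectral norm is easy to observe numerically in the one-linear-layer setting of Fact 1. The sketch below compares the sum aggregation ($\alpha_k = 1$) against a STAM-style choice $\alpha_k = C^{-1/2}$ (consistent with the $\Omega(1)$ spectral norm stated above); the function and variable names are ours.

```python
import numpy as np

rng = np.random.default_rng(2)
d, C = 200, 16

# One linear layer per branch, entries ~ N(0, 1/d) as in the proof above.
Ws = [rng.normal(0.0, np.sqrt(1.0 / d), (d, d)) for _ in range(C)]

def hessian_norm(alpha):
    """Spectral norm of H_hin = (sum_k alpha_k W_k)^T (sum_k alpha_k W_k)."""
    S = sum(a * W for a, W in zip(alpha, Ws))
    return np.linalg.norm(S, 2) ** 2       # ord=2 gives the largest singular value

sum_agg  = hessian_norm([1.0] * C)         # sum aggregation, alpha_k = 1
stam_agg = hessian_norm([C ** -0.5] * C)   # STAM-style, alpha_k = 1/sqrt(C)
one      = np.linalg.norm(Ws[0], 2) ** 2   # single-branch baseline
```

`sum_agg` grows roughly like $C$ times `one`, while `stam_agg` stays at the single-branch level, matching the learning-rate discussion above.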

Proof of Proposition 1

Without loss of generality, we assume there is only one head and omit the head subscript for simplicity. For a specific symbol representation $x$, the query is $q = x^T Q$, the keys are $k_1 = x_1^T K, \ldots, k_n = x_n^T K$ (i.e., the stacked keys are $XK$), and the values are $v_1 = x_1^T \tilde{V}, \ldots, v_n = x_n^T \tilde{V}$ (i.e., $V = X\tilde{V}$). Note that $q$, $k_j$, and $v_j$ are row vectors. Let
$$(p_1, \ldots, p_n) = \mathrm{softmax}\Big(\frac{qk_1^T}{\sqrt{d_k}}, \frac{qk_2^T}{\sqrt{d_k}}, \ldots, \frac{qk_n^T}{\sqrt{d_k}}\Big).$$
For symbol $x$, the head representation is $h = (p_1, \ldots, p_n)V$, which gives the upper bound
$$\|h\| = \|(p_1, \ldots, p_n)X\tilde{V}\| \le \max_{i\in[n]} \|x_i^T\tilde{V}\|,$$
since $h$ is a convex combination of the rows $x_i^T\tilde{V}$. Suppose $\|x_i\| = \sqrt{d}$ for all $i \in [n]$, which is a reasonable assumption because of the layer normalization. Then $\|x_i^T\tilde{V}\|^2$ is a chi-square $\chi^2_{d_v}$ variable. Next, we estimate the maximum of $n$ independent chi-square variables using the following asymptotic relation: for a random variable $U \sim \chi^2_\nu$ with large $\nu$, we have $\sqrt{2U} - \sqrt{2\nu - 1} \sim N(0, 1)$. We combine this with the extreme-value estimate for Gaussian variables: for $Z = \max_{i\in[n]} V_i$ with i.i.d. $V_i \sim N(0, 1)$, we have $\mathbb{E}(Z) \le \sqrt{2\log n}$. Based on the above estimates, asymptotically
$$\max_{i\in[n]} \|x_i^T\tilde{V}\|^2 \lesssim d_v - \tfrac{1}{2} + \log n + \sqrt{2(2d_v - 1)\log n}.$$
We could also estimate the maximum of a sequence of $\chi^2$ variables using the fact that the extreme value of the $\chi^2$ distribution (a case of the Gamma distribution) converges asymptotically to a Gumbel distribution, which yields a similar result. Hence, we have
$$\|h\|^2 \lesssim d_v + \Big(\log n - \tfrac{1}{2} + \sqrt{2(2d_v - 1)\log n}\Big).$$
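The extreme-value estimate in this proof can be sanity-checked by simulation: the empirical mean of the maximum over $n$ i.i.d. $\chi^2_{d_v}$ draws should exceed $d_v$ but stay below the asymptotic bound $d_v - \tfrac12 + \log n + \sqrt{2(2d_v-1)\log n}$. A minimal sketch with assumed sizes:

```python
import numpy as np

rng = np.random.default_rng(3)
d_v, n, trials = 64, 100, 500

# Empirical mean of max_{i in [n]} over n i.i.d. chi-square(d_v) variables.
emp = rng.chisquare(d_v, size=(trials, n)).max(axis=1).mean()

# Asymptotic estimate from the proof above.
bound = d_v - 0.5 + np.log(n) + np.sqrt(2 * (2 * d_v - 1) * np.log(n))
```

For $d_v = 64$ and $n = 100$ the bound evaluates to roughly $102$, comfortably above the simulated maximum, and the $\sqrt{\log n}$ growth is visible when $n$ is increased.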

C.2 BACKWARD PROCESS

We focus on the case with one head and omit the subscript $c$. We assume $d_v = d_k = d_q =: d_h$. We rewrite the forward process as follows. Given the parameters $Q \in \mathbb{R}^{d\times d_h}$, $K \in \mathbb{R}^{d\times d_h}$, $\tilde{V} \in$



C FORWARD AND BACKWARD PROCESS OF MULTI-HEAD ATTENTION LAYER

In this section, we analyze the multi-head attention layer, a multi-branch structure used in the Transformer (Vaswani et al., 2017), with each head viewed as one branch. It is worth noting that the multi-head attention layer behaves differently from the feed-forward network, and at the same time it is in general very hard to analyze because of the softmax operator and the inter-symbol dependence. Previous work assumes that the softmax outputs equal weights (Xiong et al., 2020) and ignores the inter-symbol dependence, which does not fully reflect the attention behavior. Suppose the input is a sequence of symbols $X = (x_1, \ldots, x_n)^T$, where each row of $X$ is a symbol representation $x_i \in \mathbb{R}^d$. The multi-head self-attention layer is given by
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(h_1, \ldots, h_C)\,\tilde{O}, \tag{7}$$
$$h_c = \mathrm{softmax}\Big(\frac{QK^T}{\sqrt{d_k}}\Big)V, \quad \text{for } c \in [C]. \tag{8}$$

https://github.com/pytorch/fairseq/blob/v0.9.0/examples/translation/README.md
https://github.com/AnonymousAKES/STAM-aggregation
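Equations (7)-(8) can be written out as a small NumPy sketch of the multi-head forward pass; the per-head projections, the scale of the random weights, and the choice $d_h = d/C$ below are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, C = 5, 32, 4
d_h = d // C                       # per-head dimension (h * d_h = d)

X = rng.normal(size=(n, d))        # each row is a symbol representation
heads = []
for _ in range(C):
    Q, K, V = (rng.normal(0.0, d ** -0.5, (d, d_h)) for _ in range(3))
    scores = (X @ Q) @ (X @ K).T / np.sqrt(d_h)        # n x n attention logits
    P = np.exp(scores - scores.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)                  # row-wise softmax
    heads.append(P @ (X @ V))                          # n x d_h head output
O_tilde = rng.normal(0.0, d ** -0.5, (C * d_h, d))     # output projection
out = np.concatenate(heads, axis=1) @ O_tilde          # MultiHead output, n x d
```

Each head's output is a convex combination of the value rows, which is exactly the property the norm bound in Proposition 1 exploits.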



Figure 1: A multi-branch network.

Figure 3: Head representation norm of the self-attention layer in the first decoder block. We fix $h\,d_v = d$ and vary $h$. When $d_v$ is small, the representation norm grows with $h$.

Figure 4: Top1 validation error of multibranch ResNets on CIFAR10. Models trained with STAM aggregation consistently benefit from increasing branches.

For the IWSLT14 de-en task, the baseline Transformer configuration is $d = 512$, $d_h = 128$, $d_{fc} = 1024$, $h = 4$. We choose the mini model for IWSLT14 de-en with configuration $d = 256$, $d_h = 64$, $h = 4$, $d_{fc} = 512$, and the generated multi-branch Transformers use configurations $(d, d_h, h, d_{fc}) =$

and $h_a^k := \phi(g_a^k)$ for $a = 1, \ldots, b$, with $h_0^k := h_{in}$ and $h_b^k := g_b^k$ for all $k \in [C]$. Let $D_a^k$ be the diagonal matrix representing the activation state, with $(D_a^k)_{jj} = 1$ if $(g_a^k)_j > 0$ and $0$ otherwise. Let $\vec{g}_b := (g_b^1, g_b^2, \ldots, g_b^C) \in \mathbb{R}^{d\times C}$.

Theorem 2. With probability at least $1 - (nb)\cdot\exp(-\Omega(m))$ over the randomness of $\vec{W}$, it holds for every $a \in [b]$ and $k$

Top1 validation accuracy (in %) of ResNeXt models. The numbers with * are reported in (Xie et al., 2017); the other numbers are from our implementation (averaged over 3 repeats).

BLEU scores on IWSLT14 de-en test sets. The higher, the better. * is baseline setting.

BLEU scores on newstest2014 for WMT En-De. The higher, the better. The number with * is reported in (Ott et al., 2018).

There are several hyperparameters in the model setting: the model/embedding dimension $d$, the head dimension $d_h$, the number of heads $h$, and the intermediate FC dimension $d_{fc}$. A more detailed introduction to the Transformer can be found in Appendix C.


We use head-level dropout, which drops whole head representations (before the output projection of the attention block) with probability $p$. We set $p = 0.3$ for the IWSLT14 De-En task and $p = 0.1$ for the WMT16 En-De task, respectively. Our source code is available at an anonymous GitHub page 2.
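For concreteness, head-level dropout (as opposed to element-wise dropout) can be sketched as follows: one Bernoulli draw per head zeroes the whole head representation before the output projection, with the usual $1/(1-p)$ rescaling of survivors so the expectation is unchanged. The function name and the list-of-heads interface are our own illustrative choices, not the paper's code.

```python
import numpy as np

def head_dropout(heads, p, rng, training=True):
    """Drop entire head representations (before the attention output
    projection) with probability p; rescale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return heads
    mask = rng.random(len(heads)) >= p      # one Bernoulli draw per head
    return [h * (m / (1.0 - p)) for h, m in zip(heads, mask)]
```

At inference time (`training=False`) the heads pass through unchanged, mirroring standard inverted dropout.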


$\mathbb{R}^{d\times d_h}$, and $\tilde{O} \in \mathbb{R}^{d_h\times d}$, the forward process reads
$$X = (x_1, \ldots, x_n)^T \in \mathbb{R}^{n\times d}, \quad \text{with each row a symbol embedding}, \tag{25}$$
$$P = \mathrm{softmax}(\tilde{P}) \in \mathbb{R}^{n\times n}, \quad \text{with each row a probability simplex}, \tag{30}$$
$$O = H\tilde{O} \in \mathbb{R}^{n\times d}, \quad \text{with each row the new representation of a symbol}. \tag{32}$$
For the backward process, suppose the error signal on the output $O$ is $E = (e_1, \ldots, e_n)^T$, with each row an error signal on a symbol representation. We can further compute the expected norm of the gradient, as in the proof of Proposition 1, under the gradient independence assumption (Yang, 2019).

C.3 MORE ABOUT TRANSFORMER

Apart from the multi-head attention block, the fully-connected (FC) block is also a multi-branch architecture with sum aggregation. We use $d_{fc}$ to denote the intermediate dimension of the FC block, which consists of two FC layers. In practice, $d_{fc}$ is several times larger than $d$; for example, the "big" Transformer in Vaswani et al. (2017) uses $d_{fc} = 4d$. These sum-aggregation operations restrict practitioners from exploring Transformer models with large $h$ and $d_{fc}$.
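The claim that the FC block is a multi-branch architecture with sum aggregation can be made explicit: $W_2\,\phi(W_1 x) = \sum_{j=1}^{d_{fc}} (W_2)_{:,j}\,\phi\big((W_1)_{j,:}\,x\big)$, i.e., each intermediate neuron acts as a one-hidden-unit branch and the $d_{fc}$ branch outputs are summed. A small numerical check (the dimensions are arbitrary illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(6)
d, d_fc = 8, 32                              # arbitrary illustrative sizes
W1 = rng.normal(size=(d_fc, d))              # first FC layer
W2 = rng.normal(size=(d, d_fc))              # second FC layer
x = rng.normal(size=d)

block = W2 @ np.maximum(W1 @ x, 0.0)         # standard FC block: W2 phi(W1 x)
branches = sum(W2[:, j] * max(float(W1[j] @ x), 0.0)   # d_fc rank-one "branches"
               for j in range(d_fc))
```

The two computations agree exactly, so increasing $d_{fc}$ is increasing the number of sum-aggregated branches, which is why large $d_{fc}$ inherits the stability issues discussed in the main text.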

D MORE DETAILS ABOUT EXPERIMENTS

The training procedure of the ResNeXt models in Section 5 is the same as in Xie et al. (2017). Specifically, we train the models on 8 GPUs with batch size 128, 300 epochs in total, and weight decay 0.0005. We use Nesterov momentum with coefficient 0.9. The learning rate is multiplied by 0.1 at the 150-th and 225-th epochs. For data augmentation, we take a random 32×32 crop from a zero-padded 40×40 image or its horizontal flip.

We now introduce more details of the Transformers on the translation tasks in Section 5. The IWSLT14 De-En dataset is collected from the Fairseq official site 1, which contains 156K/7K/7K sentence pairs for training/validation/test, respectively. For the WMT English-German task, we use the same data setup as Ott et al. (2018). Specifically, we train the model on the WMT16 training data, which contains 4.5M sentence pairs. The validation and test sets are newstest13 and newstest14, respectively.

