SCALING FORWARD GRADIENT WITH LOCAL LOSSES

Abstract

Forward gradient learning computes a noisy directional gradient and is a biologically plausible alternative to backprop for training deep neural networks. However, the standard forward gradient algorithm, when applied naively, suffers from high variance when the number of parameters to be learned is large. In this paper, we propose a series of architectural and algorithmic modifications that together make forward gradient learning practical for standard deep learning benchmark tasks. We show that it is possible to substantially reduce the variance of the forward gradient estimator by applying perturbations to activations rather than weights. We further improve the scalability of forward gradient by introducing a large number of local greedy loss functions, each of which involves only a small number of learnable parameters, and a new MLPMixer-inspired architecture, LocalMixer, that is more suitable for local learning. Our approach matches backprop on MNIST and CIFAR-10 and significantly outperforms previously proposed backprop-free algorithms on ImageNet.

1. INTRODUCTION

Most deep neural networks today are trained using the backpropagation algorithm (a.k.a. backprop) (Werbos, 1974; LeCun, 1985; Rumelhart et al., 1986), which efficiently computes the gradients of the weight parameters by propagating the error signal backwards from the loss function to each layer. Although artificial neural networks were originally inspired by biological neurons, backprop has long been considered "biologically implausible", as the brain does not form symmetric backward connections or perform synchronized computations. From an engineering perspective, backprop is incompatible with massive model parallelism and restricts potential hardware designs. These concerns call for a drastically different learning algorithm for deep networks. In the past, there have been attempts to address this weight transport problem by introducing random backward weights (Lillicrap et al., 2016; Nøkland, 2016), but they have been found to scale poorly on larger datasets such as ImageNet (Bartunov et al., 2018). Addressing the issue of global synchronization, several papers showed that greedy local loss functions can be almost as good as end-to-end learning (Belilovsky et al., 2019; Löwe et al., 2019; Xiong et al., 2020). However, they still rely on backprop for learning a number of internal layers within each local module. Approaches based on weight perturbation, on the other hand, directly send the loss signal back to the weight connections and hence do not require any backward weights. In the forward pass, the network adds a slight perturbation to the synaptic connections, and the weight update is then the perturbation multiplied by the negative change in the loss. Weight perturbation was previously proposed as a biologically plausible alternative to backprop (Xie & Seung, 1999; Seung, 2003; Fiete & Seung, 2006).
Instead of directly perturbing the weights, it is also possible to use forward-mode automatic differentiation (AD) to compute a directional gradient of the final loss along the perturbation direction (Pearlmutter, 1994). Algorithms based on forward-mode AD have recently received renewed interest in the context of deep learning (Baydin et al., 2022; Silver et al., 2022). However, existing approaches suffer from the curse of dimensionality: the variance of the estimated gradients is too high to effectively train large networks. In this paper, we revisit activity perturbation (Le Cun et al., 1988; Widrow & Lehr, 1990; Fiete & Seung, 2006) as an alternative to weight perturbation. Whereas previous works focused on specific settings, we explore its general applicability to large networks trained on challenging vision tasks.

2. RELATED WORK

Forward gradient and reinforcement learning. Our work leverages forward-mode automatic differentiation (AD), which was first proposed by Wengert (1964). It was later used to learn recurrent neural networks (Williams & Zipser, 1989) and to compute Hessian-vector products (Pearlmutter, 1994). Computing the true gradient using forward-mode AD requires the full Jacobian, which is often large and expensive to compute. Recently, Baydin et al. (2022) and Silver et al. (2022) proposed to update the weights based on the directional gradient along a random or learned perturbation direction. They found that this approach is sufficient for small-scale problems. This general family of algorithms is also related to reinforcement learning (RL) and evolution strategies (ES), since in each case the network receives a global reward. RL and ES have a long history of application in neural networks (Whitley, 1993; Stanley & Miikkulainen, 2002; Salimans et al., 2017), and they are effective for certain continuous control and decision-making tasks. Clark et al. (2021) found that global credit assignment can also work well in vector neural networks, where weights are only present between vectorized groups of neurons.
Greedy local learning. There have been numerous attempts to use local greedy learning objectives for training deep neural networks. Greedy layerwise pretraining (Bengio et al., 2006; Hinton et al., 2006; Vincent et al., 2010) trains individual layers or modules one at a time to greedily optimize an objective. Local losses are typically applied to different layers or residual stages, using common supervised and self-supervised loss formulations (Belilovsky et al., 2019; Nøkland & Eidnes, 2019; Löwe et al., 2019; Belilovsky et al., 2020). Xiong et al. (2020) and Gomez et al. (2020) proposed overlapping losses to reduce the impact of greedy learning. Patel et al. (2022) proposed to split a network into neuron groups. Laskin et al. (2020) applied greedy local learning to model-parallel training, and Wang et al. (2021) proposed to add a local reconstruction loss for preserving information. However, most local learning approaches proposed in the last decade rely on backprop to compute the weight updates within a local module. One exception is the work of Nøkland & Eidnes (2019), which avoided backprop by using layerwise objectives coupled with a similarity loss or a feedback alignment mechanism. Gated linear networks and their variants (Veness et al., 2017; 2021; Sezener et al., 2021) ask every neuron to make a prediction, and have shown interesting results on avoiding catastrophic forgetting. From a theoretical perspective, Baldi & Sadowski (2016) provided insights and proofs on why local learning can be worse than global learning.
Asymmetric feedback weights. Backprop relies on weight symmetry: the backward weights are the same as the forward weights. Past research has looked at whether this constraint is necessary. Lillicrap et al. (2016) proposed feedback alignment (FA), which uses random and fixed backward weights, and found that it can support error-driven learning in neural networks. Direct FA (Nøkland, 2016) uses a single backward layer to wire the loss function back to each layer. There have also been methods that aim to explicitly update backward weights. Recirculation (Hinton & McClelland, 1987) and target propagation (TP) (Bengio, 2014; Lee et al., 2015; Bartunov et al., 2018) use local reconstruction objectives to learn separate forward and backward weights as approximate inverses of each other. Ladder networks (Rasmus et al., 2015) found that local reconstruction objectives and asymmetric weights can help achieve strong semi-supervised learning performance. However, Bartunov et al. (2018) reported that both FA and TP algorithms do not scale to larger problems such as ImageNet, where their error rates are over 90%. Liao et al. (2016) and Xiao et al. (2019) proposed sign symmetry (SS), where each backward connection weight shares the same sign as its forward counterpart. Akrout et al. (2019) proposed weight mirroring and the modified Kolen-Pollack algorithm (Kolen & Pollack, 1994) to align forward and backward weights. Woo et al. (2021) proposed to update using activities from several layers below to avoid bidirectional connections. Compared to these works, we circumvent the issue of weight symmetry, and more generally network symmetry, by using only the reward (and its rate of change) instead of backward weights.
Biologically plausible perturbation learning. Forward gradient is related to perturbation learning in the biological context. Traditionally, neural plasticity learning rules focus on deriving weight updates as a function of the input and output activity of a neuron (Hebb, 1949; Widrow & Hoff, 1960; Oja, 1982; Bienenstock et al., 1982; Abbott & Nelson, 2000).
Weight perturbation learning (Jabri & Flower, 1992), on the other hand, is much more general, as it permits any form of global reward (Schultz et al., 1997). It was developed in both rate-based and spike-based formulations (Xie & Seung, 1999; Seung, 2003). Activity (or node) perturbation was proposed in shallow networks (Le Cun et al., 1988; Widrow & Lehr, 1990) and later in a spike-based continuous-time network (Fiete & Seung, 2006), where it was interpreted as the perturbation of the conductance of neurons. Werfel et al. (2003) showed that backprop has a faster convergence rate than perturbation learning, and that activity perturbation in turn converges faster than weight perturbation by a further factor. In our work, we show that activity perturbation has lower gradient estimation variance than weight perturbation.

3. FORWARD GRADIENT LEARNING

In this section, we review and establish the technical background for our learning algorithm. We first review forward-mode automatic differentiation (AD). We then formulate two types of perturbation: in the weight space and in the activity space.

3.1. FORWARD-MODE AUTOMATIC DIFFERENTIATION (AD)

Let $f: \mathbb{R}^m \to \mathbb{R}^n$. The Jacobian of $f$, $J_f$, is a matrix of size $n \times m$. Forward-mode AD computes the matrix-vector product $J_f v$, where $v \in \mathbb{R}^m$. It is defined as the directional gradient along $v$ evaluated at $x$: $J_f v := \lim_{\delta \to 0} \frac{f(x + \delta v) - f(x)}{\delta}$. For comparison, backprop, also known as reverse-mode AD, computes the vector-Jacobian product $v J_f$, where $v \in \mathbb{R}^n$, which corresponds to the last term in the chain rule. In contrast to reverse-mode AD, forward-mode AD requires only one forward pass, augmented with the derivative information. To compute the Jacobian-vector product of a node in a computation graph, the input node is first augmented with $v$, the vector to be multiplied. Then, for every other node, we send in a tuple $(x, x')$ as input and compute a tuple $(y, y')$ as output, where $x'$ and $y'$ are the intermediate derivatives at nodes $x$ and $y$, i.e. $y' = \frac{dy}{dx} x'$, and $\frac{dy}{dx}$ is the Jacobian between $y$ and $x$. In the JAX library (Bradbury et al., 2018), forward-mode AD is implemented as jax.jvp.
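As a concrete illustration of the two modes, the minimal sketch below uses a small hypothetical function `f` (not taken from this paper). It checks that `jax.jvp` returns $J_f v$, agreeing with the finite-difference limit, while `jax.vjp` returns the reverse-mode product $u J_f$ for a cotangent `u`. It is meant for intuition only.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical f: R^3 -> R^2, used only to illustrate the two AD modes.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

x = jnp.array([1.0, 2.0, 0.5])
v = jnp.array([0.3, -1.0, 2.0])   # tangent direction in R^m (m = 3)
u = jnp.array([1.0, -0.5])        # cotangent vector in R^n (n = 2)

# Forward mode: one augmented forward pass returns f(x) and J_f v together.
y, jvp_out = jax.jvp(f, (x,), (v,))

# Finite-difference approximation of the same directional derivative.
delta = 1e-3
fd_out = (f(x + delta * v) - f(x)) / delta

# Reverse mode (backprop): the vector-Jacobian product u J_f.
_, f_vjp = jax.vjp(f, x)
(vjp_out,) = f_vjp(u)

# Explicit n x m Jacobian, used only to check both products.
J = jax.jacfwd(f)(x)
print(jvp_out, J @ v)    # agree up to float rounding
print(vjp_out, u @ J)
```

`jax.jvp` returns the primal output and the tangent from a single augmented forward pass, which corresponds to the $(y, y')$ tuple described above.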

3.2. WEIGHT-PERTURBED FORWARD GRADIENT

Weight perturbation for generating weight updates was originally explored by Barto et al. (1983), Xie & Seung (1999), and Seung (2003). Baydin et al. (2022) use forward-mode AD to implement weight perturbation, which is numerically more stable than finite differences. Let $w_{ij}$ be the weight connection between units $i$ and $j$, let $f$ be the loss function, and write $\nabla w_{ij}$ for the true gradient $\partial f / \partial w_{ij}$. We can estimate the gradient by sampling a random matrix with iid elements $v_{ij}$ drawn from a zero-mean, unit-variance Gaussian distribution. The estimator is
$$g_w(w_{ij}) = \Big(\sum_{i'j'} \nabla w_{i'j'}\, v_{i'j'}\Big)\, v_{ij}.$$
Intuitively, this estimator samples a random perturbation direction $v$, tests how well it aligns with the true gradient by using forward mode to compute the dot product, and then multiplies the scalar alignment by the perturbation direction again. Baydin et al. (2022) referred to this form of gradient estimation using forward-mode AD as "forward gradient". To distinguish it from another form of perturbation detailed later, we refer to this as "weight-perturbed forward gradient", or simply "weight perturbation".
Table 1: Comparing weight ($g_w$) and activity ($g_a$) perturbation. V = dimension-wise average gradient variance; S = dimension-wise average squared gradient norm; p = fan-in; q = fan-out; N = batch size.
Estimator | Unbiased? | Avg. variance (shared) | Avg. variance (independent)
$g_w(\cdot)$ | Yes | $\frac{pq+2}{N}V + (pq+1)S$ | $\frac{pq+2}{N}V + \frac{pq+1}{N}S$
$g_a(\cdot)$ | Yes | $\frac{q+2}{N}V + (q+1)S$ | $\frac{q+2}{N}V + \frac{q+1}{N}S$
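The estimator can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the authors' implementation: the single linear layer, the squared-error `loss_fn`, and the helper name `weight_perturbed_fg` are hypothetical. The directional derivative $\sum_{i'j'} \nabla w_{i'j'} v_{i'j'}$ comes from one call to `jax.jvp`, with no backward pass; backprop is used only to check the result.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, t):
    # Hypothetical loss: one linear layer followed by a squared error.
    z = x @ w
    return 0.5 * jnp.sum((z - t) ** 2)

def weight_perturbed_fg(w, x, t, key):
    # Sample v with iid N(0, 1) entries, one per weight (shared across the batch).
    v = jax.random.normal(key, w.shape)
    # Forward mode gives the scalar sum_{i'j'} grad_w[i', j'] * v[i', j'].
    _, dir_deriv = jax.jvp(lambda w_: loss_fn(w_, x, t), (w,), (v,))
    # g_w = (grad . v) * v
    return dir_deriv * v, v

k_w, k_x, k_t, k_v = jax.random.split(jax.random.PRNGKey(0), 4)
w = jax.random.normal(k_w, (4, 3))   # p = 4 inputs, q = 3 outputs
x = jax.random.normal(k_x, (8, 4))   # batch of N = 8 examples
t = jax.random.normal(k_t, (8, 3))

g_w, v = weight_perturbed_fg(w, x, t, k_v)
true_grad = jax.grad(loss_fn)(w, x, t)   # backprop, used only for checking
```

Because one $v$ is shared by all examples in the batch here, this sketch corresponds to the "shared" column of Table 1.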

3.3. ACTIVITY-PERTURBED FORWARD GRADIENT

An alternative to perturbing the weights is to perturb the activities instead, which can reduce the number of perturbation dimensions per example. Activity perturbation was originally explored by Le Cun et al. (1988) and Widrow & Lehr (1990) under restrictive assumptions. Here, we introduce a general way to estimate gradients using activity perturbation. It is potentially biologically plausible, since it could correspond to perturbation of the conductance in each neuron (Fiete & Seung, 2006). For simplicity, we focus on a discrete-time rate-based formulation. Let $x_i$ denote the activity of the $i$-th presynaptic neuron, $z_j$ the activity of the $j$-th postsynaptic neuron before the nonlinear activation function, and $u_j$ the perturbation of $z_j$. The activity-perturbed forward gradient estimator is
$$g_a(w_{ij}) = x_i \Big(\sum_{j'} \nabla z_{j'}\, u_{j'}\Big)\, u_j,$$
where the inner product between $\nabla z$ and $u$ is again computed using forward-mode AD.
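A minimal JAX sketch of the activity-perturbed estimator follows, again with a hypothetical layer and loss (`head_loss` and `activity_perturbed_fg` are illustrative names, not from the paper). Each example gets its own perturbation $u$ of its pre-activations $z$. Because the per-example losses are returned as a vector, a single `jax.jvp` yields one directional derivative per example, and the weight update remains a single batched matrix product.

```python
import jax
import jax.numpy as jnp

def head_loss(z, t):
    # Hypothetical per-example loss on top of the pre-activations z: shape (N,).
    return 0.5 * jnp.sum((jnp.tanh(z) - t) ** 2, axis=-1)

def activity_perturbed_fg(w, x, t, key):
    z = x @ w                                  # pre-activations, shape (N, q)
    u = jax.random.normal(key, z.shape)        # independent perturbation per example
    # Per-example directional derivatives: sum_j' grad_z[n, j'] * u[n, j'].
    _, dir_deriv = jax.jvp(lambda z_: head_loss(z_, t), (z,), (u,))
    # g_a(w_ij) = sum_n x[n, i] * dir_deriv[n] * u[n, j], as one batched matmul.
    return x.T @ (dir_deriv[:, None] * u), z, u

k_w, k_x, k_t, k_u = jax.random.split(jax.random.PRNGKey(1), 4)
w = jax.random.normal(k_w, (4, 3))
x = jax.random.normal(k_x, (8, 4))
t = jax.random.normal(k_t, (8, 3))

g_a, z, u = activity_perturbed_fg(w, x, t, k_u)
grad_z = jax.grad(lambda z_: head_loss(z_, t).sum())(z)   # for checking only
```

Sampling a separate $u$ per example costs nothing extra here, which is the batching advantage over independent weight perturbation discussed in Section 3.4.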

3.4. THEORETICAL PROPERTIES

In this section, we analyze the expectation and variance of forward gradient estimators. We focus our analysis on the gradient of one weight matrix $\{w_{ij}\}$, but the conclusions also hold for a network with many weight matrices. Table 1 summarizes the theoretical results. With a batch size of $N$, independent perturbation achieves a $1/N$ reduction in variance, whereas shared perturbation has a constant variance term dominated by the squared gradient norm. However, when performing independent weight perturbation, matrix multiplications cannot be batched, because each example's activation vector is multiplied by a different weight matrix. By contrast, independent activity perturbation admits batched matrix multiplications. Moreover, the variance of activity perturbation is smaller than that of weight perturbation by a factor of the fan-in $p$, since the number of perturbed elements is the number of output units instead of the size of the whole weight matrix. The only drawback of activity perturbation is the memory required to store intermediate activations, in exchange for a reduction in variance of roughly a factor of $Np$ (compared with shared weight perturbation: $(pq+1)S$ versus $\frac{q+1}{N}S$). However, for both activity and weight perturbation, the variance still grows with larger networks. In Section 4, we further reduce the variance by introducing local loss functions.
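The following sketch checks Table 1 numerically in the simplest case of a single fixed example ($N = 1$, so the data-variance term $V$ vanishes). It plugs a fixed, hypothetical gradient into both estimators, draws many perturbations, and compares the empirical mean with the true gradient (unbiasedness) and the element-wise average variance with $(pq+1)S$ for weight perturbation and $(q+1)S$ for activity perturbation. For brevity, it forms the dot products from the known gradient instead of calling `jax.jvp`; the estimators are otherwise the same.

```python
import jax
import jax.numpy as jnp

p, q, K = 4, 3, 200_000                 # fan-in, fan-out, Monte Carlo draws
x = jnp.array([1.0, -2.0, 0.5, 1.5])    # hypothetical presynaptic activity
grad_z = jnp.array([0.5, -1.0, 2.0])    # hypothetical gradient w.r.t. z
grad_w = jnp.outer(x, grad_z)           # true gradient: grad_w[i, j] = x_i * grad_z[j]
S = jnp.mean(grad_w ** 2)               # element-wise average squared gradient

k_v, k_u = jax.random.split(jax.random.PRNGKey(0))

# Weight perturbation: one random entry per weight (p * q dimensions).
v = jax.random.normal(k_v, (K, p, q))
g_w = jnp.sum(grad_w * v, axis=(1, 2))[:, None, None] * v

# Activity perturbation: one random entry per output unit (q dimensions).
u = jax.random.normal(k_u, (K, q))
g_a = x[None, :, None] * (jnp.sum(grad_z * u, axis=1)[:, None] * u)[:, None, :]

for name, g, predicted in [("weight", g_w, (p * q + 1) * S),
                           ("activity", g_a, (q + 1) * S)]:
    bias = jnp.max(jnp.abs(g.mean(axis=0) - grad_w))
    avg_var = jnp.mean(g.var(axis=0))
    print(f"{name}: max |bias| = {float(bias):.3f}, "
          f"avg var = {float(avg_var):.2f}, predicted = {float(predicted):.2f}")
```

With these values, $S = 3.28125$, so the predicted average variances are about 42.66 for weight perturbation and 13.13 for activity perturbation; the ratio $(pq+1)/(q+1)$ grows roughly like the fan-in $p$.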

3.5. CONTINUOUS-TIME RATE-BASED MODELS

Forward-mode AD can be viewed as computing the first-order time derivative in a continuous-time physical system. Suppose the tuples passed between nodes of the computation graph are $(x, \dot{x})$, where $\dot{x}$ is the rate of change of $x$ over time. The computation is then the same as forward-mode AD: for each node, $\dot{y} = \frac{dy}{dx}\dot{x}$, where $\frac{dy}{dx}$ is the Jacobian between the output and the input. Note that in a physical system we do not have to explicitly perform the differentiation operation by running two forward passes. Instead, the first-order derivative information is readily available in the analog signal, and we only need to feed the output signal into a differentiator circuit. The activity-perturbed learning rule for a continuous-time system is thus $\dot{w}_{ij} \propto x_i \dot{y}_j \dot{r}$, where $x_i$ is the presynaptic activity, $\dot{y}_j$ is the rate of change of the postsynaptic activity (which acts as the perturbation direction over a short period of time), and $\dot{r}$ is the rate of change of the reward (or negative loss). The sign of $\dot{r}$ determines whether learning is Hebbian or anti-Hebbian. Both Hinton et al. (2007) and Bengio et al. (2017) proposed to use a product of presynaptic activity and the rate of change of postsynaptic activity. However, they did not consider using the rate of change of reward as a modulator, and instead relied on another set of feedback weights to communicate the error signal through inputs. In contrast, we show that by broadcasting the rate of change of reward, we can bypass the weight transport problem.
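A discrete-time caricature of this rule can be checked numerically. In the sketch below (all names are hypothetical, and the postsynaptic activity is identified with the pre-activation $z$ for simplicity), the activity drifts with velocity $\dot{y} = u$ over a short window $\Delta t$, the reward is the negative of a smooth hypothetical loss, and $\dot{r}$ is estimated by a forward finite difference. The resulting update $x_i \dot{y}_j \dot{r}$ matches the negative of the activity-perturbed forward gradient $g_a$, i.e. a descent direction, up to discretization error.

```python
import jax
import jax.numpy as jnp

def loss(z, t):
    # Hypothetical smooth loss on the postsynaptic activity (tanh avoids ReLU kinks).
    return 0.5 * jnp.sum((jnp.tanh(z) - t) ** 2)

k_x, k_w, k_t, k_u = jax.random.split(jax.random.PRNGKey(2), 4)
x = jax.random.normal(k_x, (4,))                  # presynaptic activity x_i
w = 0.5 * jax.random.normal(k_w, (4, 3))
t = jax.random.uniform(k_t, (3,), minval=-1.0, maxval=1.0)
z = x @ w

dt = 1e-3                                         # short time window
y_dot = jax.random.normal(k_u, (3,))              # rate of change of postsynaptic activity
# Rate of change of the reward r = -loss while the activity drifts along y_dot.
r_dot = (-loss(z + dt * y_dot, t) + loss(z, t)) / dt
w_dot = jnp.outer(x, y_dot) * r_dot               # w_dot_ij proportional to x_i * y_dot_j * r_dot

# Forward-gradient view of the same quantity: g_a = x_i (grad_z . u) u_j with u = y_dot.
_, dir_deriv = jax.jvp(lambda z_: loss(z_, t), (z,), (y_dot,))
g_a = jnp.outer(x, dir_deriv * y_dot)
```

In an analog system, $\dot{r}$ would be read off a differentiator circuit rather than computed by two evaluations; the finite difference here only stands in for that signal.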

3.6. ACTIVATION SPARSITY AND NORMALIZATION FUNCTIONS

In networks with ReLU activations, we can leverage ReLU sparsity to achieve further variance reduction: inactive units have zero gradient, so there is no need to perturb them, and we set their perturbation to zero. Normalization layers are often added after linear layers in deep neural networks. To compute the correct gradient in activity perturbation, we also need to account for normalization in the weight update rule. Since there are no backward weight connections, one option is to simply apply backprop through the normalization layers. However, we found that it is also acceptable to ignore the gradient of the normalization layer when using layer normalization.
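A sketch of the ReLU trick, assuming a single hypothetical layer followed by a ReLU and a squared-error loss (`loss_after_relu` and `masked_activity_fg` are illustrative names): the perturbation is multiplied by the activity mask $\mathbb{1}[z_j > 0]$ before the forward-mode pass. Since $\nabla z_j = 0$ for inactive units, the estimator stays unbiased, while inactive units no longer add noise to the directional derivative or receive spurious updates.

```python
import jax
import jax.numpy as jnp

def loss_after_relu(z, t):
    # Hypothetical loss applied after the ReLU nonlinearity.
    return 0.5 * jnp.sum((jax.nn.relu(z) - t) ** 2)

def masked_activity_fg(w, x, t, key):
    z = x @ w                                      # single example: x (p,), z (q,)
    active = (z > 0).astype(z.dtype)               # ReLU activity mask
    u = jax.random.normal(key, z.shape) * active   # no perturbation on inactive units
    _, dir_deriv = jax.jvp(lambda z_: loss_after_relu(z_, t), (z,), (u,))
    return jnp.outer(x, dir_deriv * u), active

k_w, k_x, k_t, k_u = jax.random.split(jax.random.PRNGKey(3), 4)
w = jax.random.normal(k_w, (5, 8))
x = jax.random.normal(k_x, (5,))
t = jax.random.normal(k_t, (8,))

g_a, active = masked_activity_fg(w, x, t, k_u)
grad_z = jax.grad(loss_after_relu)(x @ w, t)      # for checking only
```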

4. SCALING WITH LOCAL LOSSES

As explained in the previous section, perturbation learning can suffer from a curse of dimensionality: the variance grows with the number of perturbation dimensions, and in deep networks there are often millions of parameters changing at the same time. One way to limit the number of learnable dimensions is to divide the network into submodules, each with a separate loss function. In this section, we explore several ways to increase the number of local losses to tame the variance. 1) Blockwise loss. First, we divide the network into modules in depth. Each module consists of several layers. At the end of each module, we compute a loss function, and that loss is used to update the parameters in that module. This approach is equivalent to adding a "stop gradient" operator between modules. Such local greedy losses were previously explored by Belilovsky et al. (2019) and Löwe et al. (2019). 2) Patchwise loss. Sensory input signals such as images have spatial dimensions. We apply a separate loss patchwise along these spatial dimensions. In the Vision Transformer architecture (Vaswani et al., 2017; Dosovitskiy et al., 2021), each spatial token represents a patch of the image. In modern deep networks, parameters at each spatial location are often shared to improve data efficiency and reduce memory bandwidth utilization. Although naive weight sharing is not biologically plausible, we still consider shared weights in this work. It may be possible to mimic the effect of weight sharing by adding knowledge distillation (Hinton et al., 2015) losses between patches. 3) Groupwise loss. Lastly, we turn to the channel dimension. To create multiple losses, we split the channels into a number of groups, and each group is attached to a loss function (Patel et al., 2022). To prevent groups from communicating with each other, channels are only connected to other channels within the same group.
A grouped linear layer is computed as $z_{g,j} = \sum_i w_{g,i,j}\, x_{g,i}$ for each group $g$. Whereas previous work used channel groups to improve computational efficiency (Krizhevsky et al., 2012; Ioannou et al., 2017; Xie et al., 2017), in our work adding groups increases the total number of losses and thus reduces variance. Feature aggregators. Naively applying losses separately to the spatial and channel dimensions leads to suboptimal performance, since each dimension contains only local information. For losses of standard tasks such as classification, the model needs a global view of the inputs to make a decision. Standard architectures obtain this global view by applying a global average pooling layer before the final classification layer. We therefore explore strategies for aggregating information from other groups and spatial patches before the local loss function. We would prefer to perform aggregation without reducing the total number of dimensions. We thus propose a replicated design for feature aggregation, shown in Figure 3.
Figure 3: Feature aggregator designs. A) In the conventional design, average pooling is performed to aggregate features from different spatial locations. B) In the proposed replicated design, features are first concatenated across groups and then averaged across spatial locations. We create copies of the same feature with different stop gradient masks so that we obtain more local losses instead of a global one. The stop gradient mask ensures that the perturbation in one spatial group corresponds to its own loss function. The numerical value of the loss function is the same as in the conventional design.
First, channel groups are copied and communicated to one another, but every group except the active group itself is masked with stop gradient so that other groups do not affect the forward gradient computation:
$$\tilde{x}_{p,g} = \big[\mathrm{StopGrad}(x_{p,1}, \ldots, x_{p,g-1}),\; x_{p,g},\; \mathrm{StopGrad}(x_{p,g+1}, \ldots, x_{p,G})\big],$$
where $p$ and $g$ index the patches and groups, respectively. Similarly, each spatial location is also copied, communicated, and masked, and then averaged locally:
$$\hat{x}_{p,g} = \frac{1}{P}\Big(\tilde{x}_{p,g} + \sum_{p' \neq p} \mathrm{StopGrad}(\tilde{x}_{p',g})\Big). \tag{5}$$
The output of feature aggregation is the same as that of the conventional global average pooling layer. The difference is that here the loss is replicated, and a different patch and group is active in each copy of the loss. Learning objectives. We consider the supervised classification loss and the contrastive InfoNCE loss (van den Oord et al., 2018; Chen et al., 2020), which are the two most commonly used losses in image representation learning. For supervised classification, we attach a shared linear layer (shared across $p$ and $g$) on top of the aggregated features for a cross-entropy loss:
$$L^{s}_{p,g} = -\sum_k t_k \log \mathrm{softmax}(W_l \hat{x}_{p,g})_k.$$
The loss has the same value for every group and patch location. For contrastive learning, the linear layer becomes a linear feature projector. Suppose $x^{(1)}$ and $x^{(2)}$ are the features of two different views of the same batch; the contrastive loss is
$$L^{c}_{p,g} = -\sum_n \log \frac{\exp\big((W \hat{x}^{(1)}_{n,p,g})^\top \mathrm{StopGrad}(W x^{(2)}_n)\big)}{\sum_m \exp\big((W \hat{x}^{(1)}_{n,p,g})^\top \mathrm{StopGrad}(W x^{(2)}_m)\big)}.$$
Note that we add a stop gradient operator on the second view. It is usually unnecessary to add this stop gradient in the InfoNCE loss; however, we found that perturbation-based methods require it, and otherwise the loss does not go down. This is likely because we share the perturbations across both views, and having the same perturbation increases the dot product between the two views, which is not desired from a representation learning perspective. Figure 4 shows a comparison of the loss curves. Non-shared perturbations also work but are worse than the stop gradient.
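The replicated aggregation of Eq. (5) and the supervised loss can be sketched for a single example as follows. Names and shapes are hypothetical, and this naive version materializes every copy, so it illustrates the semantics rather than the memory-efficient custom JVP/VJP implementation mentioned later in the text. The checks confirm that every copy has the same value as global average pooling, and that the loss copy at $(p, g)$ sends gradient only to patch $p$, group $g$.

```python
import jax
import jax.numpy as jnp

def replicated_aggregate(x):
    """x: (P, G, D) features of one example -> x_hat: (P, G, G * D)."""
    P, G, D = x.shape
    sg = jax.lax.stop_gradient(x)
    live = jnp.eye(G)[None, :, :, None]       # live[0, g, g', 0] = 1 iff g' == g
    # x_tilde[p, g] = [StopGrad(x[p, :g]), x[p, g], StopGrad(x[p, g+1:])]
    x_tilde = live * x[:, None] + (1.0 - live) * sg[:, None]
    x_tilde = x_tilde.reshape(P, G, G * D)
    # Eq. (5): x_hat[p, g] = (x_tilde[p, g] + sum_{p' != p} StopGrad(x_tilde[p', g])) / P
    sg_tilde = jax.lax.stop_gradient(x_tilde)
    others = sg_tilde.sum(axis=0, keepdims=True) - sg_tilde
    return (x_tilde + others) / P

def replicated_ce(x, W, t):
    # One cross-entropy loss per (patch, group) copy, with a shared linear head W.
    logits = replicated_aggregate(x) @ W.T    # (P, G, K)
    return -jnp.sum(t * jax.nn.log_softmax(logits, axis=-1), axis=-1)   # (P, G)

P, G, D, K = 3, 2, 4, 5
k_x, k_w = jax.random.split(jax.random.PRNGKey(6))
x = jax.random.normal(k_x, (P, G, D))
W = jax.random.normal(k_w, (K, G * D))
t = jax.nn.one_hot(2, K)                      # hypothetical one-hot label

x_hat = replicated_aggregate(x)
losses = replicated_ce(x, W, t)
# Gradient of the single loss copy at (p, g) = (1, 0).
g_copy = jax.grad(lambda x_: replicated_ce(x_, W, t)[1, 0])(x)
```

Summing `losses` over all copies and perturbing each $(p, g)$ slice separately gives one local loss per patch and group, while the forward values stay identical to the conventional design.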
Our LocalMixer architecture is inspired by MLPMixer (Tolstikhin et al., 2021), which consists of fully connected networks and residual blocks. We leverage the fully connected networks so that each spatial patch performs computations without interfering with other patches, which is more compatible with our local learning objective. An image is divided into non-overlapping patches (i.e. tokens), and each block consists of token and channel mixing layers. Figure 1 shows the high-level architecture, and Figure 2 shows the detailed diagram of one residual block. We add a linear projector/classification layer to attach a loss function at the end of each block. The last layer always uses backprop to update its weights. For token mixing layers, we use one fully connected linear layer instead of an MLP, since we would like to make each block as shallow as possible. Before the last channel mixing layer, features are reshaped into a number of groups, and the last layer is fully connected within each feature group. Table 2 shows architectural details for the different sizes of models we investigate. Normalization. There are many ways of performing normalization within a neural network across different tensor dimensions (Krizhevsky et al., 2012; Ioffe & Szegedy, 2015; Ba et al., 2016; Ren et al., 2017; Wu & He, 2018). We opted for a local variant of layer normalization that normalizes within each local spatial patch of features (Ren et al., 2017). For grouped linear layers, each group is normalized separately (Wu & He, 2018). Empirically, we found that such local normalization performs better on contrastive learning experiments and about the same as layer normalization on supervised experiments. Local normalization is also more biologically plausible, as it does not perform global communication. Conventionally, normalization layers are placed after linear layers. In MLPMixer (Tolstikhin et al., 2021), layer normalization is placed at the beginning of each residual block.
We found it best to place normalization before and after each linear layer, as shown in Figure 2. Empirically, this design choice does not make much difference for backprop, but it allows forward gradient learning to learn much faster and achieve lower training error. Efficient implementation of replicated losses. Due to the design of feature aggregation and replicated losses, a naïve implementation of groups can be very inefficient in terms of both memory consumption and compute. However, each spatial group actually computes the same aggregated feature and loss function. This means that it is possible to share most of the computation across loss functions when performing both backprop and forward gradient. We implemented custom JAX JVP/VJP functions (Bradbury et al., 2018) and observed significant memory savings and compute speed-ups for replicated losses, which would otherwise not be feasible to run on modern hardware. The results are reported in Figure 5. A code snippet is included in Appendix 12.
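The local normalization, the grouped linear layer, and the "normalize before and after" placement can be sketched together. This is a simplified, hypothetical fragment (no learnable scale or shift, and a single grouped linear layer rather than a full LocalMixer block; `local_norm`, `grouped_linear`, and `norm_linear_norm` are illustrative names). Statistics are computed only over the channels of one patch and one group, and the grouped layer is equivalent to a dense layer with a block-diagonal weight matrix, so no information crosses groups.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

def local_norm(x, eps=1e-5):
    # x: (N, P, G, D). Statistics over the D channels of each (example, patch, group)
    # only, so no information is shared across patches or groups.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def grouped_linear(h, w):
    # z[n, p, g, j] = sum_i w[g, i, j] * h[n, p, g, i]; channels never cross groups.
    return jnp.einsum("npgi,gij->npgj", h, w)

def norm_linear_norm(x, w):
    # Normalization placed both before and after the (grouped) linear layer.
    return local_norm(grouped_linear(local_norm(x), w))

N, P, G, D = 2, 4, 3, 16
k_x, k_w = jax.random.split(jax.random.PRNGKey(7))
x = jax.random.normal(k_x, (N, P, G, D)) * 3.0 + 1.0
w = jax.random.normal(k_w, (G, D, D))          # one (D x D) block per group
y = norm_linear_norm(x, w)

# The grouped layer equals a dense layer with a block-diagonal weight matrix.
h = local_norm(x)
dense = h.reshape(N, P, G * D) @ block_diag(*w)

# Perturb one (patch, group) with a non-constant offset; only that output should move.
x_mod = x.at[:, 0, 1].add(jnp.arange(D, dtype=x.dtype))
diff = jnp.abs(norm_linear_norm(x_mod, w) - y).max(axis=(0, 3))   # shape (P, G)
```

The offset is deliberately non-constant: a constant shift within a group would be removed by the mean subtraction and leave every output unchanged.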

6. EXPERIMENTS

We compare our proposed algorithm to a set of alternatives: Backprop, Feedback Alignment, and other global variants of Forward Gradient. Backprop is a biologically implausible oracle, since it computes true gradients whereas we compute noisy gradients. Feedback alignment computes approximate gradients by using a set of random backward weights. We explain each method below. 1) Backprop (BP). We include the standard backprop algorithm as well as its local variants. Local Backprop (L-BP) adds local losses as proposed, but still allows gradients to flow in an end-to-end fashion. Local Greedy Backprop (LG-BP) additionally adds stop gradient operators between blocks. This provides a comparison to our methods by computing true local gradients. LG-BP is similar in spirit to recent local learning algorithms (Belilovsky et al., 2019; Löwe et al., 2019). 3) Forward Gradient (FG). This family of methods comprises our proposed algorithm and related approaches. Weight-perturbed forward gradient (FG-W) was proposed by Baydin et al. (2022). In this paper, we propose the activity-perturbed variant (FG-A). We further add local objective functions, producing LG-FG-W and LG-FG-A, which stand for Local Greedy Forward Gradient Weight/Activity-Perturbed. For local perturbation to work, we have to add a stop gradient between blocks so that each perturbation has a single corresponding loss. We expect LG-FG-A to achieve the best performance among the forward gradient variants, because it can leverage the variance reduction benefits of both activity perturbation and local losses. Datasets. We use standard image classification datasets to benchmark the learning algorithms. MNIST (LeCun, 1998) et al. (2020). Because forward gradient suffers from high variance, we apply weaker augmentations for contrastive learning experiments, increasing the area lower bound for random crops from 0.08 to 0.3-0.5. We find that this change has relatively little effect on the performance of backprop. Main results.
Our main results are shown in Table 3 and Table 4. In supervised experiments, there is almost no cost to introducing local greedy losses, and our local forward gradient method can match the test error of backprop on MNIST and CIFAR. Note that LG-FG-A fails to overfit the training set to 0% error when trained without data augmentation, which suggests that variance could still be an issue. For CIFAR-10 contrastive learning, our method obtains an error rate approaching that of backprop (26.81% vs. 17.53%), and most of the gap is due to greedy learning rather than gradient estimation (6.09% vs. 3.19%, respectively). On ImageNet, we achieve reasonable performance compared to backprop (58.37% vs. 36.82% for supervised and 73.24% vs. 55.66% for contrastive). However, we find that the error due to greediness grows as the problem gets more complex and requires more layers to cooperate. We significantly outperform the FA family on ImageNet (by 25% for supervised and 10% for contrastive). Interestingly, local greedy FA also performs better than global feedback alignment, which suggests that the benefit of local learning transfers to other types of gradient approximation. TP-based methods were evaluated by Bartunov et al. (2018) and were found to be worse than FA on ImageNet. In sum, although there is still a noticeable gap between our method and backprop, we have made a large stride forward compared to backprop-free algorithms. More results are included in Appendix 14. Effect of local losses. In Figure 6, we ablate the benefit of placing local losses at different locations: blockwise, patchwise, and groupwise. A combination of all three is the strongest. Global perturbation learning fails to learn, as its accuracy is similar to that of a network initialized with random weights. Effect of groups. In Figure 7, we investigate the effect of the number of groups by showing the training curves.
Adding more groups brings significant improvement to local perturbation learning in terms of lowering both training and test errors, but the effect vanishes at around 8 channels per group.

7. CONCLUSION

It is often believed that perturbation-based learning cannot scale to large and deep networks. We show that this is to some extent true, because the gradient estimation variance grows with the number of hidden dimensions for activity perturbation, and is even worse for shared weight perturbation. But more optimistically, we show that a huge number of local greedy losses can help forward gradient learning scale much better. We explored blockwise, patchwise, and groupwise local losses; a combination of all three, with a total of a quarter of a million losses in one of the larger networks, performs best. Local activity-perturbed forward gradient performs better than previous backprop-free algorithms on larger networks. The idea of local losses opens up opportunities for different loss designs, and sheds light on the search for biologically plausible learning algorithms in the brain and in alternative computing devices.

8. PROOFS OF UNBIASEDNESS

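In this section, we show that $g_w(w_{ij})$ and $g_a(w_{ij})$ are unbiased. The first proof was given by Baydin et al. (2022).
Proposition 1. $g_w(w_{ij})$ is an unbiased gradient estimator if $\{v_{ij}\}$ are independent zero-mean unit-variance random variables (Baydin et al., 2022).
Proof. We can rewrite the weight perturbation estimator as
$$g_w(w_{ij}) = \Big(\sum_{i'j'} \nabla w_{i'j'}\, v_{i'j'}\Big) v_{ij} = \nabla w_{ij}\, v_{ij}^2 + \sum_{i'j' \neq ij} \nabla w_{i'j'}\, v_{ij} v_{i'j'}.$$
Since each dimension of $v$ is an independent zero-mean unit-variance random variable, $\mathbb{E}[v_{ij}] = 0$, $\mathbb{E}[v_{ij}^2] = \mathrm{Var}[v_{ij}] + \mathbb{E}[v_{ij}]^2 = 1 + 0 = 1$, and $\mathbb{E}[v_{ij} v_{i'j'}] = 0$ if $ij \neq i'j'$. Therefore,
\begin{align}
\mathbb{E}[g_w(w_{ij})] &= \mathbb{E}\big[\nabla w_{ij}\, v_{ij}^2\big] + \mathbb{E}\Big[\sum_{i'j' \neq ij} \nabla w_{i'j'}\, v_{ij} v_{i'j'}\Big] \tag{8}\\
&= \nabla w_{ij}\, \mathbb{E}\big[v_{ij}^2\big] + \sum_{i'j' \neq ij} \nabla w_{i'j'}\, \mathbb{E}[v_{ij} v_{i'j'}] \tag{9}\\
&= \nabla w_{ij} \cdot 1 + \sum_{i'j' \neq ij} \nabla w_{i'j'} \cdot 0 \tag{10}\\
&= \nabla w_{ij}. \tag{11}
\end{align}
Proposition 2. $g_a(w_{ij})$ is an unbiased gradient estimator if $\{u_j\}$ are independent zero-mean unit-variance random variables.
Proof. The true gradient of the weights, $\nabla w_{ij}$, is the product of $x_i$ and $\nabla z_j$. Therefore, we can rewrite the activity perturbation estimator as
\begin{align}
g_a(w_{ij}) &= x_i \Big(\sum_{j'} \nabla z_{j'}\, u_{j'}\Big) u_j = x_i \nabla z_j\, u_j^2 + x_i \Big(\sum_{j' \neq j} \nabla z_{j'}\, u_{j'}\Big) u_j \tag{12}\\
&= x_i \nabla z_j\, u_j^2 + \Big(\sum_{j' \neq j} x_i \nabla z_{j'}\, u_{j'}\Big) u_j \tag{13}\\
&= \nabla w_{ij}\, u_j^2 + \Big(\sum_{j' \neq j} \nabla w_{ij'}\, u_{j'}\Big) u_j. \tag{14}
\end{align}
Since each dimension of $u$ is an independent zero-mean unit-variance random variable, $\mathbb{E}[u_j] = 0$, $\mathbb{E}[u_j^2] = \mathrm{Var}[u_j] + \mathbb{E}[u_j]^2 = 1 + 0 = 1$, and $\mathbb{E}[u_j u_{j'}] = 0$ if $j \neq j'$. Therefore,
\begin{align}
\mathbb{E}[g_a(w_{ij})] &= \mathbb{E}\Big[\nabla w_{ij}\, u_j^2 + \Big(\sum_{j' \neq j} \nabla w_{ij'}\, u_{j'}\Big) u_j\Big] \tag{15}\\
&= \nabla w_{ij}\, \mathbb{E}\big[u_j^2\big] + \sum_{j' \neq j} \nabla w_{ij'}\, \mathbb{E}[u_{j'} u_j] \tag{16}\\
&= \nabla w_{ij} \cdot 1 + \sum_{j' \neq j} \nabla w_{ij'} \cdot 0 \tag{17}\\
&= \nabla w_{ij}. \tag{18}
\end{align}

9. PROOFS OF VARIANCES

Following Wen et al. (2018), we show that the variance of the gradient estimators can be decomposed.
Lemma 1.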
The variance of the gradient estimator can be decomposed into three parts, $\mathrm{Var}(g(w_{ij})) = Z_1 + Z_2 + Z_3$, where
$$Z_1 = \frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x), \qquad Z_2 = \frac{1}{N}\mathbb{E}_x\big[\mathrm{Var}_v(g(w_{ij})|x)\big],$$
$$Z_3 = \frac{1}{N^2}\mathbb{E}_B\Big[\sum_{x^{(n)} \in B}\sum_{x^{(m)} \in B \setminus \{x^{(n)}\}} \mathrm{Cov}_v\big(g(w_{ij})|x^{(n)},\, g(w_{ij})|x^{(m)}\big)\Big].$$

Proof. By the law of total variance,
$$\mathrm{Var}(g(w_{ij})) = \mathrm{Var}_B\big(\mathbb{E}_v[g(w_{ij})|B]\big) + \mathbb{E}_B\big[\mathrm{Var}_v(g(w_{ij})|B)\big].$$
The first term comes from the gradient variance due to data sampling, and it vanishes as the batch size grows. Using unbiasedness (Propositions 1 and 2) and the fact that the examples in $B$ are sampled independently,
$$\begin{aligned}
\mathrm{Var}_B\big(\mathbb{E}_v[g(w_{ij})|B]\big) &= \mathrm{Var}_B\Big(\mathbb{E}_v\Big[\frac{1}{N}\sum_{x^{(n)} \in B} g(w_{ij})|x^{(n)}\Big]\Big) = \frac{1}{N^2}\mathrm{Var}_B\Big(\sum_{x^{(n)} \in B}\mathbb{E}_v\big[g(w_{ij})|x^{(n)}\big]\Big) \\
&= \frac{1}{N^2}\mathrm{Var}_B\Big(\sum_{x^{(n)} \in B}\nabla w_{ij}|x^{(n)}\Big) = \frac{1}{N^2}\sum_n \mathrm{Var}_x(\nabla w_{ij}|x) = \frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x) = Z_1.
\end{aligned}$$
The second term comes from the gradient estimation variance:
$$\begin{aligned}
\mathbb{E}_B\big[\mathrm{Var}_v(g(w_{ij})|B)\big] &= \mathbb{E}_B\Big[\mathrm{Var}_v\Big(\frac{1}{N}\sum_{x^{(n)} \in B} g(w_{ij})|x^{(n)}\Big)\Big] \\
&= \mathbb{E}_B\Big[\frac{1}{N^2}\Big(\sum_{x^{(n)} \in B}\mathrm{Var}_v\big(g(w_{ij})|x^{(n)}\big) + \sum_{x^{(n)} \in B}\sum_{x^{(m)} \in B \setminus \{x^{(n)}\}}\mathrm{Cov}_v\big(g(w_{ij})|x^{(n)},\, g(w_{ij})|x^{(m)}\big)\Big)\Big] \\
&= \frac{1}{N}\mathbb{E}_x\big[\mathrm{Var}_v(g(w_{ij})|x)\big] + \frac{1}{N^2}\mathbb{E}_B\Big[\sum_{x^{(n)} \in B}\sum_{x^{(m)} \in B \setminus \{x^{(n)}\}}\mathrm{Cov}_v\big(g(w_{ij})|x^{(n)},\, g(w_{ij})|x^{(m)}\big)\Big] = Z_2 + Z_3.
\end{aligned}$$

Remark. $Z_2$ is the variance of the gradient estimator in the deterministic case (a single fixed input), and $Z_3$ measures the correlation between different gradient estimates within the batch. $Z_3$ is zero if the perturbations are independent across examples, and nonzero if the perturbations are shared within the mini-batch.

Proposition 3. Let $p \times q$ be the size of the weight matrix. The element-wise average variance of the weight-perturbed gradient estimator with batch size $N$ is $\frac{pq+2}{N}V + (pq+1)S$ if the perturbations are shared across the batch, and $\frac{pq+2}{N}V + \frac{pq+1}{N}S$ if they are independent, where $V$ is the element-wise average variance of the true gradient and $S$ is the element-wise average squared gradient.

Proof.
We first derive $Z_2$. Conditioned on $x$, the true gradients are constants, so
$$\begin{aligned}
\mathrm{Var}_v(g_w(w_{ij})|x) &= \mathrm{Var}_v\Big(\Big(\sum_{i'j'}\nabla w_{i'j'}v_{i'j'}\Big)v_{ij}\Big) = \mathrm{Var}_v\Big(\nabla w_{ij}v_{ij}^2 + \sum_{i'j' \neq ij}\nabla w_{i'j'}v_{ij}v_{i'j'}\Big) \\
&= \mathrm{Var}_v\big(\nabla w_{ij}v_{ij}^2\big) + \mathrm{Var}_v\Big(\sum_{i'j' \neq ij}\nabla w_{i'j'}v_{ij}v_{i'j'}\Big) + 2\,\mathrm{Cov}_v\Big(\nabla w_{ij}v_{ij}^2,\, \sum_{i'j' \neq ij}\nabla w_{i'j'}v_{ij}v_{i'j'}\Big).
\end{aligned}$$
The covariance term vanishes:
$$\mathrm{Cov}_v(\cdot,\cdot) = \sum_{i'j' \neq ij}\nabla w_{ij}\nabla w_{i'j'}\,\mathbb{E}_v[v_{ij}^3 v_{i'j'}] - \nabla w_{ij}\,\mathbb{E}_v[v_{ij}^2]\sum_{i'j' \neq ij}\nabla w_{i'j'}\,\mathbb{E}_v[v_{ij}v_{i'j'}] = \sum_{i'j' \neq ij}\nabla w_{ij}\nabla w_{i'j'} \cdot 0 - \nabla w_{ij} \cdot 1 \cdot \sum_{i'j' \neq ij}\nabla w_{i'j'} \cdot 0 = 0.$$
The products $v_{ij}v_{i'j'}$ for different $i'j'$ are uncorrelated, so the variance of their weighted sum is the sum of the variances. Using $\mathbb{E}_v[v_{ij}^4] = 3\,\mathrm{Var}_v[v_{ij}]^2 = 3$ for Gaussian perturbations and
$$\mathrm{Var}_v(v_{ij}v_{i'j'}) = \big(\mathrm{Var}_v[v_{ij}] + \mathbb{E}_v[v_{ij}]^2\big)\big(\mathrm{Var}_v[v_{i'j'}] + \mathbb{E}_v[v_{i'j'}]^2\big) - \mathbb{E}_v[v_{ij}]^2\,\mathbb{E}_v[v_{i'j'}]^2 = \mathrm{Var}_v(v_{ij})\,\mathrm{Var}_v(v_{i'j'}) = 1,$$
we obtain, with the gradients evaluated at $x$,
$$\mathrm{Var}_v(g_w(w_{ij})|x) = \nabla w_{ij}^2\big(\mathbb{E}_v[v_{ij}^4] - \mathbb{E}_v[v_{ij}^2]^2\big) + \sum_{i'j' \neq ij}\nabla w_{i'j'}^2\,\mathrm{Var}_v(v_{ij}v_{i'j'}) = 2\nabla w_{ij}^2 + \sum_{i'j' \neq ij}\nabla w_{i'j'}^2.$$
Taking the expectation over $x$ and using $\mathbb{E}_x[(\nabla w_{ij}|x)^2] = \nabla w_{ij}^2 + \mathrm{Var}_x(\nabla w_{ij}|x)$, where $\nabla w_{ij}$ without conditioning denotes the expected gradient $\mathbb{E}_x[\nabla w_{ij}|x]$,
$$\begin{aligned}
Z_2 &= \frac{1}{N}\mathbb{E}_x\big[\mathrm{Var}_v(g_w(w_{ij})|x)\big] = \frac{2}{N}\nabla w_{ij}^2 + \frac{2}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\sum_{i'j' \neq ij}\mathbb{E}_x\big[(\nabla w_{i'j'}|x)^2\big] \\
&= \frac{1}{N}\Big(\nabla w_{ij}^2 + \mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{i'j'}\big(\nabla w_{i'j'}^2 + \mathrm{Var}_x(\nabla w_{i'j'}|x)\big)\Big).
\end{aligned}$$
$Z_3$ is nonzero if the perturbations are shared within a batch.
Assuming that the perturbations are shared, and using $\mathbb{E}_v[g_w(w_{ij})|x^{(n)}] = \nabla w_{ij}|x^{(n)}$,
$$\begin{aligned}
Z_3 &= \frac{1}{N^2}\mathbb{E}_B\Big[\sum_n\sum_{m \neq n}\mathrm{Cov}_v\big(g_w(w_{ij})|x^{(n)},\, g_w(w_{ij})|x^{(m)}\big)\Big] \\
&= \frac{1}{N^2}\mathbb{E}_B\Big[\sum_n\sum_{m \neq n}\Big(\mathbb{E}_v\big[\big(g_w(w_{ij})|x^{(n)}\big)\big(g_w(w_{ij})|x^{(m)}\big)\big] - \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)}\Big)\Big].
\end{aligned}$$
With a shared $v$, expanding both estimators gives
$$\mathbb{E}_v\big[\big(g_w(w_{ij})|x^{(n)}\big)\big(g_w(w_{ij})|x^{(m)}\big)\big] = \sum_{i'j'}\sum_{i''j''}\nabla w_{i'j'}|x^{(n)}\,\nabla w_{i''j''}|x^{(m)}\,\mathbb{E}_v\big[v_{i'j'}v_{i''j''}v_{ij}^2\big] = 3\,\nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)} + \sum_{i'j' \neq ij}\nabla w_{i'j'}|x^{(n)}\,\nabla w_{i'j'}|x^{(m)},$$
because $\mathbb{E}_v[v_{i'j'}v_{i''j''}v_{ij}^2]$ equals $\mathbb{E}_v[v_{ij}^4] = 3$ if $i'j' = i''j'' = ij$, equals $1$ if $i'j' = i''j'' \neq ij$, and is $0$ otherwise (every remaining term contains an odd power of some $v$). Since $x^{(n)}$ and $x^{(m)}$ are independent for $m \neq n$, $\mathbb{E}_B[\nabla w_{i'j'}|x^{(n)}\,\nabla w_{i'j'}|x^{(m)}] = \mathbb{E}_x[\nabla w_{i'j'}|x]^2 = \nabla w_{i'j'}^2$. Hence
$$\begin{aligned}
Z_3 &= \frac{1}{N^2}\sum_n\sum_{m \neq n}\Big(3\nabla w_{ij}^2 + \sum_{i'j' \neq ij}\nabla w_{i'j'}^2\Big) - \frac{1}{N^2}\sum_n\sum_{m \neq n}\nabla w_{ij}^2 \\
&= \frac{1}{N^2}\sum_n\sum_{m \neq n}\Big(2\nabla w_{ij}^2 + \sum_{i'j' \neq ij}\nabla w_{i'j'}^2\Big) = \frac{N(N-1)}{N^2}\Big(\nabla w_{ij}^2 + \sum_{i'j'}\nabla w_{i'j'}^2\Big).
\end{aligned}$$
Lastly, we average the variance across all weight dimensions:
$$\begin{aligned}
\mathrm{mVar}(g_w(w_{ij})) &= \frac{1}{pq}\sum_{ij}\mathrm{Var}(g_w(w_{ij})) = \frac{1}{pq}\sum_{ij}\{Z_1 + Z_2 + Z_3\} \\
&= \frac{1}{pq}\sum_{ij}\Big[\frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\Big(\nabla w_{ij}^2 + \mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{i'j'}\big(\nabla w_{i'j'}^2 + \mathrm{Var}_x(\nabla w_{i'j'}|x)\big)\Big) + \frac{N(N-1)}{N^2}\Big(\nabla w_{ij}^2 + \sum_{i'j'}\nabla w_{i'j'}^2\Big)\Big] \\
&= \frac{1}{pq}\sum_{ij}\Big[\frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\Big(\mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{i'j'}\mathrm{Var}_x(\nabla w_{i'j'}|x)\Big) + \nabla w_{ij}^2 + \sum_{i'j'}\nabla w_{i'j'}^2\Big] \\
&= \frac{2}{N}\mathrm{mVar}(\nabla w) + \frac{pq}{N}\mathrm{mVar}(\nabla w) + (pq+1)\,\mathrm{mSqNorm}(\nabla w) = \frac{pq+2}{N}V + (pq+1)S.
\end{aligned}$$
If the perturbations are independent, we show that $Z_3$ is 0.
$$Z_3 = \frac{1}{N^2}\mathbb{E}_B\Big[\sum_n\sum_{m \neq n}\Big(\mathbb{E}_v\big[\big(g_w(w_{ij})|x^{(n)}\big)\big(g_w(w_{ij})|x^{(m)}\big)\big] - \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)}\Big)\Big],$$
where the two examples now use independent perturbations $v^{(n)}$ and $v^{(m)}$. Expanding both estimators and factorizing the expectation by independence,
$$\begin{aligned}
\mathbb{E}_v\big[\big(g_w(w_{ij})|x^{(n)}\big)\big(g_w(w_{ij})|x^{(m)}\big)\big] &= \sum_{i'j'}\sum_{i''j''}\nabla w_{i'j'}|x^{(n)}\,\nabla w_{i''j''}|x^{(m)}\,\mathbb{E}_v\big[v^{(n)}_{i'j'}v^{(n)}_{ij}\big]\,\mathbb{E}_v\big[v^{(m)}_{i''j''}v^{(m)}_{ij}\big] \\
&= \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)}\,\mathbb{E}_v\big[v^{(n)2}_{ij}\big]\,\mathbb{E}_v\big[v^{(m)2}_{ij}\big] = \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)},
\end{aligned}$$
since $\mathbb{E}_v[v^{(n)}_{i'j'}v^{(n)}_{ij}]$ is $1$ if $i'j' = ij$ and $0$ otherwise, so only the $i'j' = i''j'' = ij$ term survives. Therefore
$$Z_3 = \frac{1}{N^2}\sum_n\sum_{m \neq n}\mathbb{E}_x[\nabla w_{ij}|x]^2 - \frac{1}{N^2}\sum_n\sum_{m \neq n}\nabla w_{ij}^2 = 0.$$
Then the average variance becomes:
$$\begin{aligned}
\mathrm{mVar}(g_w(w_{ij})) &= \frac{1}{pq}\sum_{ij}\mathrm{Var}(g_w(w_{ij})) = \frac{1}{pq}\sum_{ij}\{Z_1 + Z_2 + Z_3\} \\
&= \frac{1}{pq}\sum_{ij}\Big[\frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\Big(\nabla w_{ij}^2 + \mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{i'j'}\big(\nabla w_{i'j'}^2 + \mathrm{Var}_x(\nabla w_{i'j'}|x)\big)\Big)\Big] \\
&= \frac{pq+2}{N}\mathrm{mVar}(\nabla w) + \frac{pq+1}{N}\mathrm{mSqNorm}(\nabla w) = \frac{pq+2}{N}V + \frac{pq+1}{N}S.
\end{aligned}$$

Proposition 4. Let $p \times q$ be the size of the weight matrix. The element-wise average variance of the activity-perturbed gradient estimator with batch size $N$ is $\frac{q+2}{N}V + (q+1)S$ if the perturbations are shared across the batch, and $\frac{q+2}{N}V + \frac{q+1}{N}S$ if they are independent, where $V$ is the element-wise average variance of the true gradient and $S$ is the element-wise average squared gradient.

Proof. From the proof of Proposition 2, $g_a(w_{ij}) = \nabla w_{ij}u_j^2 + \sum_{j' \neq j}\nabla w_{ij'}u_{j'}u_j$. Conditioned on $x$,
$$\mathrm{Var}_u(g_a(w_{ij})|x) = \mathrm{Var}_u\big(\nabla w_{ij}u_j^2\big) + \mathrm{Var}_u\Big(\sum_{j' \neq j}\nabla w_{ij'}u_ju_{j'}\Big) + 2\,\mathrm{Cov}_u\Big(\nabla w_{ij}u_j^2,\, \sum_{j' \neq j}\nabla w_{ij'}u_ju_{j'}\Big).$$
The covariance term vanishes:
$$\mathrm{Cov}_u(\cdot,\cdot) = \sum_{j' \neq j}\nabla w_{ij}\nabla w_{ij'}\,\mathbb{E}_u[u_j^3u_{j'}] - \nabla w_{ij}\,\mathbb{E}_u[u_j^2]\sum_{j' \neq j}\nabla w_{ij'}\,\mathbb{E}_u[u_ju_{j'}] = \sum_{j' \neq j}\nabla w_{ij}\nabla w_{ij'} \cdot 0 - \nabla w_{ij} \cdot 1 \cdot \sum_{j' \neq j}\nabla w_{ij'} \cdot 0 = 0.$$
Using $\mathbb{E}_u[u_j^4] = 3\,\mathrm{Var}_u(u_j)^2 = 3$ for Gaussian perturbations and $\mathrm{Var}_u(u_ju_{j'}) = \mathrm{Var}_u(u_j)\,\mathrm{Var}_u(u_{j'}) = 1$ for $j' \neq j$,
$$\mathrm{Var}_u(g_a(w_{ij})|x) = \nabla w_{ij}^2\big(\mathbb{E}_u[u_j^4] - \mathbb{E}_u[u_j^2]^2\big) + \sum_{j' \neq j}\nabla w_{ij'}^2\,\mathrm{Var}_u(u_ju_{j'}) = 2\nabla w_{ij}^2 + \sum_{j' \neq j}\nabla w_{ij'}^2.$$
Taking the expectation over $x$,
$$\begin{aligned}
Z_2 &= \frac{1}{N}\mathbb{E}_x\big[\mathrm{Var}_u(g_a(w_{ij})|x)\big] = \frac{2}{N}\nabla w_{ij}^2 + \frac{2}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\sum_{j' \neq j}\mathbb{E}_x\big[(\nabla w_{ij'}|x)^2\big] \\
&= \frac{1}{N}\Big(\nabla w_{ij}^2 + \mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{j'}\big(\nabla w_{ij'}^2 + \mathrm{Var}_x(\nabla w_{ij'}|x)\big)\Big).
\end{aligned}$$
$Z_3$ is nonzero if the perturbations are shared within a batch. Assuming that the perturbations are shared,
$$Z_3 = \frac{1}{N^2}\mathbb{E}_B\Big[\sum_n\sum_{m \neq n}\Big(\mathbb{E}_u\big[\big(g_a(w_{ij})|x^{(n)}\big)\big(g_a(w_{ij})|x^{(m)}\big)\big] - \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)}\Big)\Big],$$
and, with a shared $u$,
$$\mathbb{E}_u\big[\big(g_a(w_{ij})|x^{(n)}\big)\big(g_a(w_{ij})|x^{(m)}\big)\big] = \sum_{j'}\sum_{j''}\nabla w_{ij'}|x^{(n)}\,\nabla w_{ij''}|x^{(m)}\,\mathbb{E}_u\big[u_{j'}u_{j''}u_j^2\big] = 3\,\nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)} + \sum_{j' \neq j}\nabla w_{ij'}|x^{(n)}\,\nabla w_{ij'}|x^{(m)},$$
because $\mathbb{E}_u[u_{j'}u_{j''}u_j^2]$ equals $\mathbb{E}_u[u_j^4] = 3$ if $j' = j'' = j$, equals $1$ if $j' = j'' \neq j$, and is $0$ otherwise. Since $x^{(n)}$ and $x^{(m)}$ are independent for $m \neq n$,
$$\begin{aligned}
Z_3 &= \frac{1}{N^2}\sum_n\sum_{m \neq n}\Big(3\nabla w_{ij}^2 + \sum_{j' \neq j}\nabla w_{ij'}^2\Big) - \frac{1}{N^2}\sum_n\sum_{m \neq n}\nabla w_{ij}^2 \\
&= \frac{1}{N^2}\sum_n\sum_{m \neq n}\Big(2\nabla w_{ij}^2 + \sum_{j' \neq j}\nabla w_{ij'}^2\Big) = \frac{N(N-1)}{N^2}\Big(\nabla w_{ij}^2 + \sum_{j'}\nabla w_{ij'}^2\Big).
\end{aligned}$$
Then we compute the average variance across all weight dimensions (for shared perturbation):
$$\begin{aligned}
\mathrm{mVar}(g_a(w_{ij})) &= \frac{1}{pq}\sum_{ij}\mathrm{Var}(g_a(w_{ij})) = \frac{1}{pq}\sum_{ij}\{Z_1 + Z_2 + Z_3\} \\
&= \frac{1}{pq}\sum_{ij}\Big[\frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\Big(\nabla w_{ij}^2 + \mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{j'}\big(\nabla w_{ij'}^2 + \mathrm{Var}_x(\nabla w_{ij'}|x)\big)\Big) + \frac{N(N-1)}{N^2}\Big(\nabla w_{ij}^2 + \sum_{j'}\nabla w_{ij'}^2\Big)\Big] \\
&= \frac{1}{pq}\sum_{ij}\Big[\frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\Big(\mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{j'}\mathrm{Var}_x(\nabla w_{ij'}|x)\Big) + \nabla w_{ij}^2 + \sum_{j'}\nabla w_{ij'}^2\Big] \\
&= \frac{2}{N}\mathrm{mVar}(\nabla w) + \frac{q}{N}\mathrm{mVar}(\nabla w) + (q+1)\,\mathrm{mSqNorm}(\nabla w) = \frac{q+2}{N}V + (q+1)S.
\end{aligned}$$
If the perturbations are independent, we show that $Z_3$ is 0.
$$Z_3 = \frac{1}{N^2}\mathbb{E}_B\Big[\sum_n\sum_{m \neq n}\Big(\mathbb{E}_u\big[\big(g_a(w_{ij})|x^{(n)}\big)\big(g_a(w_{ij})|x^{(m)}\big)\big] - \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)}\Big)\Big],$$
where the two examples now use independent perturbations $u^{(n)}$ and $u^{(m)}$. Expanding both estimators and factorizing the expectation by independence,
$$\begin{aligned}
\mathbb{E}_u\big[\big(g_a(w_{ij})|x^{(n)}\big)\big(g_a(w_{ij})|x^{(m)}\big)\big] &= \sum_{j'}\sum_{j''}\nabla w_{ij'}|x^{(n)}\,\nabla w_{ij''}|x^{(m)}\,\mathbb{E}_u\big[u^{(n)}_{j'}u^{(n)}_j\big]\,\mathbb{E}_u\big[u^{(m)}_{j''}u^{(m)}_j\big] \\
&= \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)}\,\mathbb{E}_u\big[u^{(n)2}_j\big]\,\mathbb{E}_u\big[u^{(m)2}_j\big] = \nabla w_{ij}|x^{(n)}\,\nabla w_{ij}|x^{(m)},
\end{aligned}$$
since $\mathbb{E}_u[u^{(n)}_{j'}u^{(n)}_j]$ is $1$ if $j' = j$ and $0$ otherwise. Therefore
$$Z_3 = \frac{1}{N^2}\sum_n\sum_{m \neq n}\mathbb{E}_x[\nabla w_{ij}|x]^2 - \frac{1}{N^2}\sum_n\sum_{m \neq n}\nabla w_{ij}^2 = 0.$$
Then the average variance becomes:
$$\begin{aligned}
\mathrm{mVar}(g_a(w_{ij})) &= \frac{1}{pq}\sum_{ij}\mathrm{Var}(g_a(w_{ij})) = \frac{1}{pq}\sum_{ij}\{Z_1 + Z_2 + Z_3\} \\
&= \frac{1}{pq}\sum_{ij}\Big[\frac{1}{N}\mathrm{Var}_x(\nabla w_{ij}|x) + \frac{1}{N}\Big(\nabla w_{ij}^2 + \mathrm{Var}_x(\nabla w_{ij}|x) + \sum_{j'}\big(\nabla w_{ij'}^2 + \mathrm{Var}_x(\nabla w_{ij'}|x)\big)\Big)\Big] \\
&= \frac{q+2}{N}\mathrm{mVar}(\nabla w) + \frac{q+1}{N}\mathrm{mSqNorm}(\nabla w) = \frac{q+2}{N}V + \frac{q+1}{N}S.
\end{aligned}$$

In Figure 8, we ran numerical simulations to verify our analytical variance properties. We used a multi-layer network with 4 input units, 4 hidden units, 1 output unit, a tanh activation function, and the mean squared error loss. We varied the batch size ($N$) between 1 and 4096, and tested the gradient estimator of the first-layer weights using 5000 random samples. We also calculated the theoretical variance by plugging in the gradient norm and gradient variance constants computed with backprop from 5000 mini-batch true gradients. We then fixed the batch size to 4 and varied the number of input units ($p$, fan-in) and the number of hidden units ($q$, fan-out) between 1 and 256. The theoretical variance for backprop was computed only for the batch-size experiment, where it follows an inverse relationship ($1/N$); we do not analyze the theoretical backprop variance as a function of fan-in and fan-out. In the figure, "wt perturb" denotes weight perturbation with shared noise, "ind wt perturb" denotes weight perturbation with independent noise, and "act perturb" denotes activity perturbation with independent noise. Note that independent weight perturbation is much more costly to compute in neural networks. As shown in the figure, the empirical variances closely match our theoretical predictions.
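Propositions 3 and 4 can also be checked numerically in the special case $V = 0$, obtained by filling a batch with $N$ copies of one fixed example. They then predict an average variance of $(pq+1)S$ (shared) and $(pq+1)S/N$ (independent) for weight perturbation, and $(q+1)S$ and $(q+1)S/N$ for activity perturbation. Below is a minimal Monte Carlo sketch, assuming a single linear layer with squared-error loss and standard Gaussian noise. The sizes, sample counts, and helper names (`g_weight`, `g_activity`, `mvar`) are arbitrary choices, and this is not the simulation behind Figure 8.

```python
import jax
import jax.numpy as jnp

# Arbitrary sizes: p inputs, q outputs, M Monte Carlo draws, batch size N.
p, q, M, N = 4, 3, 50_000, 8
kx, kw, ky, kv, ku = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(kx, (p,))
w = jax.random.normal(kw, (p, q))
y = jax.random.normal(ky, (q,))

# Linear layer z = x @ w with loss 0.5 * ||z - y||^2: analytic gradients.
dz = x @ w - y              # dL/dz, shape [q]
grad = jnp.outer(x, dz)     # dL/dw, shape [p, q]
S = jnp.mean(grad ** 2)     # element-wise average squared gradient


def g_weight(v):
    """Weight perturbation for one draw v [p, q]: (grad . v) v."""
    return jnp.sum(grad * v) * v


def g_activity(u):
    """Activity perturbation for one draw u [q]: x_i (dz . u) u_j."""
    return jnp.outer(x, jnp.dot(dz, u) * u)


v = jax.random.normal(kv, (M, N, p, q))
u = jax.random.normal(ku, (M, N, q))

# Shared noise: all N copies use the same draw, so batch averaging does not help.
gw_shared = jax.vmap(g_weight)(v[:, 0])
ga_shared = jax.vmap(g_activity)(u[:, 0])
# Independent noise: each copy gets its own draw; the batch mean is the estimate.
gw_indep = jax.vmap(jax.vmap(g_weight))(v).mean(axis=1)
ga_indep = jax.vmap(jax.vmap(g_activity))(u).mean(axis=1)


def mvar(g):
    """Element-wise average variance over the M Monte Carlo draws."""
    return float(jnp.mean(jnp.var(g, axis=0)))


predicted = {
    'wt shared': float((p * q + 1) * S), 'wt indep': float((p * q + 1) * S / N),
    'act shared': float((q + 1) * S), 'act indep': float((q + 1) * S / N),
}
empirical = {
    'wt shared': mvar(gw_shared), 'wt indep': mvar(gw_indep),
    'act shared': mvar(ga_shared), 'act indep': mvar(ga_indep),
}
for name in predicted:
    print(f'{name:10s} empirical={empirical[name]:.4f} predicted={predicted[name]:.4f}')
```

The shared-noise runs illustrate the nonzero $Z_3$ term: averaging over the batch does not reduce their variance, whereas the independent-noise variances shrink by a factor of $N$.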

11. TRAINING DETAILS

Here we provide more training details.

MNIST. We use a batch size of 128 and the SGD optimizer with learning rate 0.01 and momentum 0.9 for a total of 1000 epochs, with no data augmentation and a linear learning rate decay schedule.

CIFAR-10. For the supervised experiments, we use a batch size of 128 and the SGD optimizer with learning rate 0.01 and momentum 0.9 for a total of 200 epochs, with no data augmentation and a linear learning rate decay schedule. For the contrastive M/8 experiments, we use a batch size of 512 and the SGD optimizer with learning rate 1.0 and momentum 0.9 for a total of 1000 epochs, with BYOL data augmentation (Grill et al., 2020) using an area crop lower bound of 0.5, and a cosine decay schedule with a warm-up period of 10 epochs. For the contrastive L/8 experiments, we use a batch size of 2048 and the SGD optimizer with learning rate 4.0 and momentum 0.9 for a total of 1000 epochs, with BYOL data augmentation (Grill et al., 2020) using an area crop lower bound of 0.3, and a cosine decay schedule with a warm-up period of 10 epochs.

ImageNet. For the supervised experiments, we use a batch size of 256 and the SGD optimizer with learning rate 0.05 and momentum 0.9 for a total of 120 epochs, with BYOL data augmentation (Grill et al., 2020) using an area crop lower bound of 0.3, and a cosine learning rate decay schedule with a warm-up period of 10 epochs. For the contrastive experiments, we use a batch size of 2048 and the LARS optimizer with learning rate 0.1 and momentum 0.9 for a total of 800 epochs, with BYOL data augmentation (Grill et al., 2020) using an area crop lower bound of 0.08, and a cosine learning rate decay schedule with a warm-up period of 10 epochs.
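The two learning rate schedules above can be sketched as functions of the (possibly fractional) epoch. This is an illustrative sketch rather than the training code. The text does not state end values, so decaying to zero, warming up linearly from zero, and the helper names `linear_decay` and `warmup_cosine` are all assumptions. The optimizer itself (SGD or LARS) is not shown.

```python
import jax.numpy as jnp


def linear_decay(epoch, base_lr, total_epochs):
    """Linear decay from base_lr at epoch 0 to 0 at total_epochs (end value assumed)."""
    frac = jnp.clip(epoch / total_epochs, 0.0, 1.0)
    return base_lr * (1.0 - frac)


def warmup_cosine(epoch, base_lr, warmup_epochs, total_epochs):
    """Linear warm-up from 0 to base_lr, then cosine decay to 0 (both ends assumed)."""
    warm = base_lr * epoch / warmup_epochs
    progress = jnp.clip((epoch - warmup_epochs) / (total_epochs - warmup_epochs), 0.0, 1.0)
    cosine = 0.5 * base_lr * (1.0 + jnp.cos(jnp.pi * progress))
    return jnp.where(epoch < warmup_epochs, warm, cosine)


# CIFAR-10 supervised: lr 0.01, 200 epochs, linear decay (halfway point).
print(float(linear_decay(100, 0.01, 200)))
# ImageNet supervised: lr 0.05, 120 epochs, 10 warm-up epochs, cosine decay.
print(float(warmup_cosine(5, 0.05, 10, 120)))   # middle of the warm-up
print(float(warmup_cosine(65, 0.05, 10, 120)))  # middle of the cosine phase
```

Expressing the schedules per epoch keeps them independent of dataset size; a per-step version would multiply the epoch counts by the number of steps per epoch.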

12. FUSED JVP/VJP DETAILS

In Algorithm 1, we provide a JAX code snippet implementing fused operators for the supervised cross entropy loss. "Fused" here means that we package several operations into one function. In the supervised cross entropy loss, we combine average pooling, channel concatenation, a linear classifier layer, and cross entropy all together. Key steps and expected tensor shapes are annotated in the comments. The fused InfoNCE loss implementation will be included in our full code release.
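The naive operator in Algorithm 1 relies on a stop-gradient trick. Each patch-wise local loss sees the pooled features in the forward pass, but forward-mode derivatives reach it only from its own patch. Because the forward value is the same for every patch, the fused operator can compute a single loss and replicate it. Below is a minimal sketch of the trick in isolation on an arbitrary [N, P, C] tensor; the function names are illustrative. It compares the JVP of this pooling with that of ordinary average pooling.

```python
import jax
import jax.numpy as jnp


def local_avg_pool(x):
    """x: [N, P, C] -> [N, P, C].

    Forward value: every patch holds the average over all P patches.
    Derivatives: output patch p depends only on input patch p (scaled by 1/P),
    because the remaining terms are wrapped in stop_gradient.
    """
    P = x.shape[1]
    avg = jnp.mean(x, axis=1, keepdims=True)
    x_div_p = x / float(P)
    return x_div_p + jax.lax.stop_gradient(avg - x_div_p)


def plain_avg_pool(x):
    """Ordinary average pooling broadcast back to every patch, for comparison."""
    return jnp.broadcast_to(jnp.mean(x, axis=1, keepdims=True), x.shape)


N, P, C = 2, 4, 3
x = jax.random.normal(jax.random.PRNGKey(0), (N, P, C))
dx = jnp.zeros_like(x).at[:, 0].set(1.0)  # tangent that perturbs patch 0 only

out_local, dout_local = jax.jvp(local_avg_pool, (x,), (dx,))
out_plain, dout_plain = jax.jvp(plain_avg_pool, (x,), (dx,))
print(dout_local[0, :, 0])  # only patch 0 moves, by 1/P
print(dout_plain[0, :, 0])  # every patch moves by 1/P
```

The forward outputs agree, but only the stop-gradient version keeps each local loss's derivative confined to its own patch. That property lets the local losses train without cross-patch gradient paths.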

13. LOCALMIXER ARCHITECTURE

In Algorithm 2, we provide code in JAX style that implements our proposed LocalMixer architecture.
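The grouped channel-mixing layer (`group_linear` in Algorithm 2) computes the same thing as a dense layer with a block-diagonal weight matrix. This is one way to see the point made in the additional results that adding groups introduces sparsity in the weight matrix. The small check below uses arbitrary sizes, reuses the einsum from Algorithm 2, and builds the dense equivalent with `jax.scipy.linalg.block_diag`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


def group_linear(x, w, b):
    """Grouped linear layer as in Algorithm 2: x [N,P,G,D], w [G,D,E] -> [N,P,G,E]."""
    return jnp.einsum('npgc,gcd->npgd', x, w) + b


N, P, G, D, E = 2, 3, 4, 5, 6  # arbitrary sizes
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (N, P, G, D))
w = jax.random.normal(kw, (G, D, E))
b = jnp.zeros((G, E))

y_grouped = group_linear(x, w, b).reshape(N, P, G * E)

# Same computation as one dense layer whose [G*D, G*E] weight is block-diagonal.
w_dense = block_diag(*[w[g] for g in range(G)])
y_dense = x.reshape(N, P, G * D) @ w_dense

print(float(jnp.mean(w_dense != 0)))  # fraction of nonzero weights = 1/G
```

With $G$ groups, only a $1/G$ fraction of the equivalent dense weights are nonzero. This reduces capacity, but it also keeps each group's output dependent on fewer parameters.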

14. ADDITIONAL RESULTS

In this section we provide additional experimental results.

Normalization scheme. Table 5 compares different normalization schemes. Layer normalization (LN) is often better than batch normalization (BN) on our mixer architecture. Local LN is better on contrastive learning experiments and achieves lower error rates using forward gradient learning. Although backprop was used through the normalization layers in our main paper, it is not necessary for Local LN; cf. the "NG" (No Gradient) columns in Table 5.

Place of normalization. We investigate where to add normalization layers. Traditionally, normalization is added after linear layers. In MLPMixer, LN is added at the beginning of each block. With forward gradient learning, it is an open question which location is optimal. Adding normalization after the linear layer has the advantage of shaping the activations to be better behaved, which can make perturbation learning more effective. Adding it before the linear layer can also help reduce the variance, since the inputs are always multiplied by the gradient of the output activity. The results are reported in Table 6. Adding normalization both before and after the linear layer helps forward gradient achieve lower training errors. While this could result in some overfitting on supervised learning, it is good for contrastive learning, which needs more model capacity. This is reasonable, as forward gradient introduces a lot of variance, and more normalization layers help achieve better training performance.

Algorithm 1 Naïve and fused local cross entropy, with custom JVP and VJP operators.



All proofs can be found in Appendices 8 and 9. Numerical simulation results can be found in Appendix 10.



Figure 1: A LocalMixer network consists of several mixer blocks. A=Activation function (ReLU).

Figure 2: A LocalMixer residual block with local losses. Token mixing consists of a linear layer and channels are grouped in the channel mixing layers. Layer norm is applied before and after every linear layer. LN=Layer Norm; FC=Fully Connected layer; A=Activation function (ReLU); T=Transpose.

Figure 4: Importance of StopGradient in the InfoNCE loss, using M/8 on CIFAR-10 with 256 channels and 1 group.

Figure 5: Memory and compute usage of naïve and fused implementation of replicated losses.

Figure 8: Numerical verification of the theoretical variance properties

```python
# N: batch size; P: num patches; G: num grps; C: num channels; D: channels / grp; K: num cls
# x: encoder features [N,P,G,C/G]
# w: classifier weights [C,K]; b: classifier bias [K]
# labels: class labels [N,K]
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def naive_avg_group_linear_xent(x, w, b, labels):
    N, P, G, _ = x.shape
    # Average pooling, with stop gradients. [N,P,G,C/G] -> [N,1,G,C/G]
    avg_pool_p = jnp.mean(x, axis=1, keepdims=True)
    x_div_p = x / float(P)
    # Value equals the patch average; gradient flows only through x / P. [N,P,G,C/G]
    x = x_div_p + jax.lax.stop_gradient(avg_pool_p - x_div_p)
    # Concatenate everything, with stop gradients. [N,P,G,C/G] -> [N,P,G,G,C/G]
    x = jnp.tile(jnp.reshape(x, [N, P, 1, G, -1]), [1, 1, G, 1, 1])
    mask = jnp.eye(G)[None, None, :, :, None]
    x = mask * x + jax.lax.stop_gradient((1.0 - mask) * x)
    # [N,P,G,G,C/G] -> [N,P,G,C]
    x = jnp.reshape(x, [N, P, G, -1])
    logits = jnp.einsum('npgc,cd->npgd', x, w) + b
    logits = logits - logsumexp(logits, axis=-1, keepdims=True)
    loss = -jnp.sum(logits * labels[:, None, None, :], axis=-1)
    return loss


def fused_avg_group_linear_xent(x, w, b, labels):
    # This is for the forward pass. The numerical value of each local loss is the same,
    # so we compute one and replicate it many times.
    N, P, G, _ = x.shape
    # [N,P,G,C/G] -> [N,G,C/G]
    x_avg = jnp.mean(x, axis=1)
    # [N,G,C/G] -> [N,C]
    x_grp = jnp.reshape(x_avg, [x_avg.shape[0], -1])
    # [N,C] -> [N,K]
    logits = jnp.einsum('nc,ck->nk', x_grp, w) + b
    logits = logits - logsumexp(logits, axis=-1, keepdims=True)
    loss = -jnp.sum(logits * labels, axis=-1)
    # Key step: after computing the loss, replicate it PxG times. [N] -> [N,P,G]
    return jnp.tile(jnp.reshape(loss, [N, 1, 1]), [1, P, G])


def fused_avg_group_linear_xent_jvp(primals, tangents):
    # This JVP operator performs both the regular forward pass and forward-mode autodiff.
    x, w, b, labels = primals
    dx, dw, db, dlabels = tangents  # dlabels is unused.
    N, P, G, D = x.shape
    dx_avg = dx / float(P)
    # Reshape the classifier weights, since only one group passes gradient at a time.
    w_ = jnp.reshape(w, [G, D, -1])
    b = jnp.reshape(b, [-1])
    # Regular forward pass.
    # [N,P,G,C/G] -> [N,G,C/G]
    x_avg = jnp.mean(x, axis=1)
    # [N,G,C/G] -> [N,C]
    x_grp = jnp.reshape(x_avg, [x_avg.shape[0], -1])
    # [N,C] -> [N,K]
    logits = jnp.einsum('nd,dk->nk', x_grp, w) + b
    logits = logits - logsumexp(logits, axis=-1, keepdims=True)
    loss = -jnp.sum(logits * labels, axis=-1)
    # We can compute the gradient through the cross entropy first.
    dlogits_bwd = jax.nn.softmax(logits, axis=-1) - labels  # [N,K]
    # Key step: dloss = dx * w * dloss/dlogit + (x * dw + db) * dloss/dlogit.
    # Do the einsums together to avoid replicating outputs.
    dloss = (jnp.einsum('npgd,gdk,nk->npg', dx_avg, w_, dlogits_bwd)
             + jnp.einsum('nk,nk->n', jnp.einsum('nc,ck->nk', x_grp, dw) + db,
                          dlogits_bwd)[:, None, None])  # [N,P,G]
    # Return the loss and the loss tangents, both [N,P,G].
    return jnp.tile(jnp.reshape(loss, [N, 1, 1]), [1, P, G]), dloss


def fused_avg_group_linear_xent_vjp(res, g):
    # This is a fused backprop (VJP) operator; `res` holds residuals saved in the forward pass.
    x, w, logits, labels = res
    N, P, G, D = x.shape
    x_avg = jnp.mean(x, axis=1)
    x_grp = jnp.reshape(x_avg, [x_avg.shape[0], -1])
    # Key step: only use the first patch/group cotangent, since all replicas are the same.
    g_ = g[:, 0:1, 0]  # [N,1]
    dlogits = g_ * (jax.nn.softmax(logits, axis=-1) - labels)  # [N,K]
    # Remember to multiply gradients by P*G due to weight sharing.
    db = jnp.reshape(jnp.sum(dlogits, axis=0), [-1]) * float(P * G)
    dw = jnp.reshape(jnp.einsum('nc,nk->ck', x_grp, dlogits), [G * D, -1]) * float(P * G)
    # Key step: use grouped weights to perform backprop. [N,K] x [G,D,K] -> [N,G,D]
    dx = jnp.einsum('nk,gdk->ngd', dlogits, jnp.reshape(w, [G, D, -1])) / float(P)
    # Broadcast gradients across patches. [N,G,D] -> [N,P,G,D]
    dx = jnp.tile(dx[:, None, :, :], [1, P, 1, 1])
    return dx, dw, db, None
```

Algorithm 2 A LocalMixer architecture implemented with JAX style code.
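```python
import jax
import jax.numpy as jnp

# Module-level settings used below but not defined in this snippet:
# num_groups, num_blocks, num_patches, image_mean, image_std, and a
# preprocess(x, image_mean, image_std, num_patches) function that
# standardizes images and splits them into patches, returning [N,P,C].


def linear(x, w, b):
    """Linear layer over the last axis; accepts [N,C], [N,P,C] or [N,P,G,C] inputs."""
    return jnp.einsum('...c,cd->...d', x, w) + b


def group_linear(x, w, b):
    """Linear layer with groups: [N,P,G,C/G] x [G,C/G,D] -> [N,P,G,D]."""
    return jnp.einsum('npgc,gcd->npgd', x, w) + b


def normalize(x, axis=-1, eps=1e-5):
    """Normalization layer (no learnable scale or bias)."""
    mean = jnp.mean(x, axis=axis, keepdims=True)
    mean_of_squares = jnp.mean(jnp.square(x), axis=axis, keepdims=True)
    var = mean_of_squares - jnp.square(mean)
    inv = jax.lax.rsqrt(var + eps)
    y = (x - mean) * inv
    return y


def block0(x, params):
    """Initial block with only channel mixing."""
    N, P, _ = x.shape
    G = num_groups
    x = normalize(x)
    x = linear(x, params[0][0], params[0][1])
    x = normalize(x)
    x = jax.nn.relu(x)
    x = jnp.reshape(x, [N, P, G, -1])
    x = normalize(x)
    x = group_linear(x, params[1][0], params[1][1])
    x = normalize(x)
    x = jax.nn.relu(x)
    return x


def mlp_block(x, params):
    """Regular MLP block with token & channel mixing."""
    # NOTE: the start of this function was lost in extraction. Everything up to the
    # channel-mixing `linear` call is a reconstruction following Figure 2 (one linear
    # layer across patches, LN before and after), not the authors' verbatim code.
    N, P, G, _ = x.shape
    inputs = x
    x = jnp.reshape(x, [N, P, -1])  # [N,P,C]
    x = normalize(x)
    x = jnp.swapaxes(x, 1, 2)  # [N,C,P]: token mixing across patches
    x = linear(x, params[0][0], params[0][1])
    x = jnp.swapaxes(x, 1, 2)  # back to [N,P,C]
    x = normalize(x)
    x = jax.nn.relu(x)
    # Channel mixing: a dense layer followed by a grouped layer.
    x = normalize(x)
    x = linear(x, params[1][0], params[1][1])
    x = normalize(x)
    x = jax.nn.relu(x)
    x = jnp.reshape(x, [N, P, G, -1])
    x = normalize(x)
    x = group_linear(x, params[2][0], params[2][1])
    x = normalize(x)
    x = x + inputs
    x = jax.nn.relu(x)
    return x


def local_mixer(x, params):
    """LocalMixer."""
    x = preprocess(x, image_mean, image_std, num_patches)
    pred_local = []  # Local predictions.
    # Build network blocks.
    for blk in range(num_blocks):
        if blk == 0:
            x = block0(x, params[f'block_{blk}'])
        else:
            x = mlp_block(x, params[f'block_{blk}'])
        # Projector connects to local losses.
        x_proj = normalize(x)
        pred_local.append(linear(x_proj, params[f'proj_{blk}'][0], params[f'proj_{blk}'][1]))
        # Disconnect gradients between blocks.
        x = jax.lax.stop_gradient(x)
    x = jnp.reshape(x, [x.shape[0], x.shape[1], -1])
    x = jnp.mean(x, axis=1)  # [N,C]
    x = normalize(x)
    pred = linear(x, params['classifier'][0], params['classifier'][1])
    return pred, pred_local
```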

LocalMixer Architecture Details

5. IMPLEMENTATION

Network architecture. We propose the LocalMixer architecture, which is more suitable for local learning. It takes inspiration from MLPMixer (Tolstikhin




tree/master/local_forward_gradient.

ACKNOWLEDGMENT

We thank Timothy Lillicrap for his helpful feedback on our earlier draft.

Figure 9 legend: BP Test; LG-BP Train, LG-BP Test; LG-FG-A Train, LG-FG-A Test.

Effect of groups. We provide additional results summarizing the training and test performance of adding more groups in Figure 9. Backprop and local greedy backprop always achieve zero training error on supervised CIFAR-10, regardless of the number of groups, but adding groups significantly lowers the training error of forward gradient. This suggests that the main obstacle here is still the gradient estimation variance, and that lowering training errors can generally lower test errors as well; on the other hand, adding groups has a negligible effect on backprop. Contrastive learning requires higher model capacity, and adding groups effectively reduces the model capacity by introducing sparsity in the weight matrix. As a result, we observe a slight performance drop of less than 5% for both backprop and local greedy backprop. By contrast, forward gradient gains over 10% in performance with 16 groups.

