Reintroducing Straight-Through Estimators as Principled Methods for Stochastic Binary Networks

Abstract

Training neural networks with binary weights and activations is a challenging problem due to the lack of gradients and the difficulty of optimization over discrete weights. Many successful experimental results have been achieved with empirical straight-through (ST) approaches, which propose a variety of ad-hoc rules for propagating gradients through non-differentiable activations and updating discrete weights. At the same time, ST methods can be truly derived as estimators in the stochastic binary network (SBN) model with Bernoulli weights. We advance these derivations to a more complete and systematic study. We analyze properties and estimation accuracy, obtain different forms of correct ST estimators for activations and weights, explain existing empirical approaches and their shortcomings, and explain how latent weights arise from the mirror descent method when optimizing over probabilities. This allows us to reintroduce ST methods, long known empirically, as sound approximations, apply them with clarity and develop further improvements.

1. Introduction

Neural networks with binary weights and activations have much lower computation costs and memory consumption than their real-valued counterparts [18, 26, 45]. They are therefore very attractive for applications in mobile devices, robotics and other resource-limited settings, in particular for solving vision and speech recognition problems [8, 56]. The seminal works that showed the feasibility of training networks with binary weights [15] and with binary weights and activations [27] used the empirical straight-through gradient estimation approach. In this approach the derivative of a step function such as sign, which is zero almost everywhere, is substituted on the backward pass with the derivative of some other function, hereafter called a proxy function. One possible choice is the identity proxy, i.e., to completely bypass sign on the backward pass, hence the name straight-through [5]. This ad-hoc solution appears to work surprisingly well, and the later mainstream research on binary neural networks relies heavily on it [2, 6, 9, 11, 18, 34, 36, 45, 52, 60].

The de-facto standard straight-through approach in the above-mentioned works is to use deterministic binarization and the clipped identity proxy, as proposed by Hubara et al. [27]. However, other proxy functions have been tried experimentally, including tanh and the piece-wise quadratic ApproxSign [18, 34], illustrated in Fig. 1. This gives rise to a diversity of empirical ST methods, where the various choices are studied purely experimentally [2, 6, 52].

We gratefully acknowledge support by Czech OP VVV project "Research Center for Informatics (CZ.02.1.01/0.0/0.0/16019/0000765)".
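The empirical ST rule described above can be sketched in a few lines: the forward pass applies the hard sign, while the backward pass substitutes the derivative of a chosen proxy. The following NumPy sketch is ours for illustration (function names and the proxy selection are our own conventions, not from the cited works):

```python
import numpy as np

def st_forward(x):
    """Forward pass: hard binarization to {-1, +1}."""
    return np.where(x >= 0, 1.0, -1.0)

def st_backward(x, grad_out, proxy="clipped_identity"):
    """Backward pass: the zero derivative of sign is replaced
    with the derivative of a proxy function."""
    if proxy == "identity":
        # plain straight-through: pass the gradient unchanged
        dproxy = np.ones_like(x)
    elif proxy == "clipped_identity":
        # Htanh proxy of Hubara et al.: derivative is 1 on [-1, 1], else 0
        dproxy = (np.abs(x) <= 1.0).astype(x.dtype)
    elif proxy == "tanh":
        dproxy = 1.0 - np.tanh(x) ** 2
    else:
        raise ValueError(f"unknown proxy: {proxy}")
    return grad_out * dproxy
```

Note that the clipped identity zeroes the gradient for pre-activations outside [-1, 1], while the identity proxy lets all gradients through unchanged.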
Since binary weights can also be represented as a sign mapping of real-valued latent weights, the same type of method is applied to the weights. However, a different proxy is often used for the weights, producing additional unclear choices. The dynamics and interpretation of latent weights have also been studied purely empirically [51]. Given this obscurity of latent weights, Helwegen et al. [24] argue that "latent weights do not exist", meaning that discrete optimization over binary weights needs to be considered. The existing partial justifications of deterministic straight-through approaches are limited to one-layer networks with Gaussian data [58] or to binarization of weights only [1], and do not lead to practical recommendations. In contrast to the deterministic variant used by the mainstream SOTA, straight-through methods were originally proposed (also empirically) for stochastic autoencoders [25] and studied in models with stochastic binary neurons [5, 44]. In the stochastic binary network (SBN) model which we consider, all hidden units and/or weights are Bernoulli random variables. The expected loss is then a truly differentiable function of the parameters (i.e., weight probabilities) and its gradient can be estimated. This framework allows one to pose questions such as: "What is the true expected gradient?" and "How far from it is the estimate computed by



Fig. 1: The sign function and the different proxy functions for derivatives used in empirical ST estimators. Variants (c-e) can be obtained by choosing the noise distribution in our framework. Specifically, for a real-valued noise z with cdf F, in the upper plots we show E_z[sign(a - z)] = 2F(a) - 1 and, respectively, twice the density, 2p(a), in the lower plots. Choosing a uniform distribution for z gives the density p(z) = (1/2)·1l[z ∈ [-1, 1]] and recovers the common Htanh proxy in (c). Logistic noise has cdf F(z) = σ(2z), which recovers the tanh proxy in (d). Triangular noise has density p(z) = max(0, (2 - |z|)/4), which recovers a scaled version of ApproxSign [34] in (e). The scaling (standard deviation) of the noise in each case is chosen so that 2p(0) = 1. The identity ST form in (b) is recovered as latent weight updates with mirror descent.
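The correspondence in the caption between a noise distribution and a proxy can be checked numerically. For logistic noise with cdf F(z) = σ(2z), the expected sign is 2F(a) - 1 = 2σ(2a) - 1 = tanh(a), i.e., the tanh proxy of panel (d). A Monte-Carlo sketch (the function name is ours; NumPy's logistic sampler with scale 1/2 has cdf σ(2z)):

```python
import numpy as np

rng = np.random.default_rng(0)

def expected_sign_mc(a, n=200_000):
    """Monte-Carlo estimate of E_z[sign(a - z)] for logistic noise z.

    A logistic distribution with scale s has cdf sigma(z / s), so
    scale 0.5 gives F(z) = sigma(2z), for which 2*p(0) = 1.
    """
    z = rng.logistic(loc=0.0, scale=0.5, size=n)
    return np.mean(np.sign(a - z))
```

Comparing `expected_sign_mc(a)` against `np.tanh(a)` confirms the closed form 2F(a) - 1 = tanh(a) up to Monte-Carlo error.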

