Differentiate Everything with a Reversible Embedded Domain-Specific Language

Abstract

Reverse-mode automatic differentiation (AD) suffers from a large space overhead: intermediate computational states must be traced back for back-propagation. The traditional method for tracing back states is checkpointing, which stores intermediate states on a global stack and restores them through either stack pops or re-computation. The overhead of stack manipulation and re-computation makes general-purpose (not tensor-based) AD engines unable to meet many industrial needs. Instead of checkpointing, we propose to trace back states with reverse computing, by designing and implementing a reversible programming eDSL in which a program can be executed bi-directionally without implicit stack operations. The absence of implicit stack operations makes the program compatible with existing compiler features, including existing optimization passes and compilation of the code as GPU kernels. We implement AD for sparse matrix operations and several machine learning applications to show that our framework achieves state-of-the-art performance.

1. Introduction

Most of the popular automatic differentiation (AD) tools on the market, such as TensorFlow (Abadi et al., 2015), PyTorch (Paszke et al., 2017), and Flux (Innes et al., 2018), implement reverse-mode AD at the tensor level to meet the needs of machine learning. Later, people in the scientific computing domain also realized the power of these AD tools and used them to solve scientific problems such as seismic inversion (Zhu et al., 2020), variational quantum circuit simulation (Bergholm et al., 2018; Luo et al., 2019), and variational tensor network simulation (Liao et al., 2019; Roberts et al., 2019). To meet the diverse needs of these applications, one sometimes has to define backward rules manually. For example: 1. to differentiate the sparse matrix operations used in Hamiltonian engineering (Hao Xie & Wang), people defined backward rules for sparse matrix multiplication and dominant eigensolvers (Golub & Van Loan, 2012); 2. in tensor network algorithms for studying phase transition problems (Liao et al., 2019; Seeger et al., 2017; Wan & Zhang, 2019; Hubig, 2019), people defined backward rules for the singular value decomposition (SVD) and QR decomposition (Golub & Van Loan, 2012). Instead of defining backward rules manually, one can also use a general-purpose AD (GP-AD) framework like Tapenade (Hascoet & Pascual, 2013), OpenAD (Utke et al., 2008), or Zygote (Innes, 2018; Innes et al., 2019). Researchers have used these tools in practical applications such as bundle adjustment (Shen & Dai, 2018) and earth system simulation (Forget et al., 2015), where differentiating scalar operations is important. However, the power of these tools is often limited by their relatively poor performance. In many practical applications, a program might do billions of computations, and in each computational step the AD engine might cache some data for back-propagation.
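To make the idea of a manually defined backward rule concrete, consider dense matrix multiplication C = A B (a simpler relative of the sparse multiplication mentioned above): given the adjoint dC of the output, the reverse-mode rule is dA = dC Bᵀ and dB = Aᵀ dC. The sketch below is purely illustrative (pure Python with lists of lists, not the paper's code; all function names are hypothetical):

```python
# Hand-written backward rule for C = A @ B, in the spirit of the
# manually defined rules discussed above. Pure Python, no dependencies.

def matmul(A, B):
    """Dense matrix product of lists-of-lists A (n x k) and B (k x m)."""
    n, k, m = len(A), len(B), len(B[0])
    return [[sum(A[i][p] * B[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]

def transpose(M):
    return [list(row) for row in zip(*M)]

def matmul_backward(A, B, dC):
    """Given the adjoint dC of C = A @ B, return (dA, dB).

    The rule follows from the chain rule:
        dA = dC @ B^T,   dB = A^T @ dC.
    """
    return matmul(dC, transpose(B)), matmul(transpose(A), dC)
```

With dC set to the identity (so the "loss" is trace(A B)), the rule yields dA = Bᵀ and dB = Aᵀ, matching the textbook derivative of the trace.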
Frequent caching of data slows down the program significantly (Griewank & Walther, 2008), and memory usage becomes a bottleneck as well. Implicit caching also makes these frameworks incompatible with kernel functions. To avoid such issues, we need a new GP-AD framework that does not cache automatically for the user. In this paper, we propose to implement reverse-mode AD on top of a reversible (domain-specific) programming language (Perumalla, 2013; Frank, 2017), where intermediate states can be traced backward without accessing an implicit stack. Reversible programming lets people exploit reversibility to run a program backward. In machine learning, reversibility has been shown to substantially decrease memory usage in unitary recurrent neural networks (MacKay et al., 2018), normalizing flows (Dinh et al., 2014), hyper-parameter learning (Maclaurin et al., 2015), and residual neural networks (Gomez et al., 2017; Behrmann et al., 2018); reversible programming makes these approaches natural. The power of reversible programming is not limited to such inherently reversible applications: any program can be written in a reversible style, although converting an irreversible program to a reversible form incurs overheads in time and space. Reversible programming provides a flexible time-space trade-off scheme, different from checkpointing (Griewank, 1992; Griewank & Walther, 2008; Chen et al., 2016) and classical reverse computing (Bennett, 1989; Levine & Sherman, 1990), that lets the user handle these overheads explicitly. There have been many prototypes of reversible languages, such as Janus (Lutz, 1986), as well as prototypes of reversible computing devices (Likharev, 1977; Semenov et al., 2003; Takeuchi et al., 2014; 2017); such devices can be orders of magnitude more energy-efficient. Landauer proved that only when a device does not erase information (i.e., is reversible) can its energy efficiency go beyond the thermodynamic limit.
(Landauer, 1961; Reeb & Wolf, 2014) However, these reversible programming languages cannot be used directly in real scientific computing, since most of them lack basic elements such as floating-point numbers, arrays, and complex numbers. This motivates us to build a new embedded domain-specific language (eDSL) in Julia (Bezanson et al., 2012; 2017) as a new playground for GP-AD. In this paper, we first compare the time-space trade-offs of optimal checkpointing and optimal reverse computing in Sec. 2. Then we introduce the language design of NiLang in Sec. 3. In Sec. 4, we explain the implementation of automatic differentiation in NiLang. In Sec. 5, we benchmark NiLang's AD against other AD software and explain why it is fast.

2. Reverse Computing as an Alternative to Checkpointing

One can use either checkpointing or reverse computing to trace back the intermediate states of a T-step computational process s_1 = f_1(s_0), s_2 = f_2(s_1), ..., s_T = f_T(s_{T-1}) with run-time memory S. In the checkpointing scheme, the program first takes snapshots of the states at certain time steps, S = {s_a, s_b, ...} with 1 ≤ a < b < ... ≤ T, by running a forward pass. When retrieving a state s_k, if s_k ∈ S, that state is returned directly; otherwise, the nearest earlier snapshot s_j ∈ S (with j = max{j : j < k, s_j ∈ S}) is found and s_k is re-computed from s_j. In the reverse computing scheme, one first writes the program in a reversible style. Without prior knowledge, a regular program can be transpiled to the reversible style by the transformation in Listing 1.

Listing 1: Transpiling a regular program to reversible code without prior knowledge.

    s_1 += f_1(s_0)
    s_2 += f_2(s_1)
    ...
    s_T += f_T(s_{T-1})

Listing 2: The reverse of Listing 1.

    s_T -= f_T(s_{T-1})
    ...
    s_2 -= f_2(s_1)
    s_1 -= f_1(s_0)

Then one can visit the states in reverse order by running the reversed program in Listing 2, which erases the computed results from the tail.
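The transformation of Listings 1 and 2 can be sketched in ordinary Python (the paper's own setting is a Julia eDSL; this is only an illustrative model, with hypothetical names). Each state slot starts at zero and is written with `+=`, so running the `-=` sequence in reverse order uncomputes every state exactly; integer arithmetic is used here so that `+=`/`-=` are exactly invertible:

```python
# Model of Listings 1 and 2: a T-step computation in reversible (+=) form,
# and its reverse, which erases the computed states from the tail.

def forward(s0, fs):
    """Listing 1: s[t] += f[t](s[t-1]) for t = 1..T; slots start at 0."""
    s = [s0] + [0] * len(fs)
    for t, f in enumerate(fs, start=1):
        s[t] += f(s[t - 1])
    return s

def backward(s, fs):
    """Listing 2: s[t] -= f[t](s[t-1]) for t = T..1, restoring all zeros."""
    for t in range(len(fs), 0, -1):
        s[t] -= fs[t - 1](s[t - 1])
    return s
```

For example, with fs = [x+1, 2x, x²] and s0 = 3, the forward pass produces the states [3, 4, 8, 64], and the backward pass erases them back to [3, 0, 0, 0] without any stack.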
One may argue that erasing through uncomputing is not necessary here. This is not true for a general reversible program, because the intermediate states might be mutable and used in other parts of the program. It is easy to see that both checkpointing and reverse computing can trace back states without time overhead, but both then suffer from a space overhead linear in time (Table 1): the checkpointing scheme snapshots the output at every step, and the reverse computing scheme allocates extra storage for the output of every step. On the other hand, only checkpointing can achieve zero space overhead, by re-computing everything from the beginning s_0 with a time complexity of O(T^2). The minimum space complexity in reverse
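The checkpointing retrieval rule described above (return s_k directly if it was snapshotted, otherwise recompute from the nearest earlier checkpoint) can be sketched as follows; this is an illustrative Python model, not the paper's implementation. Passing an empty checkpoint set gives the zero-space extreme, where every retrieval recomputes from s_0 and retrieving all T states costs O(T^2) time:

```python
# Model of checkpointing: snapshot some states in a forward pass,
# then retrieve any state by recomputing from the nearest snapshot.

def run_forward(s0, fs, checkpoints):
    """Forward pass over a T-step process, storing snapshots only at
    the time steps listed in `checkpoints` (s_0 is always kept)."""
    S = {0: s0}
    s = s0
    for t, f in enumerate(fs, start=1):
        s = f(s)
        if t in checkpoints:
            S[t] = s
    return S

def retrieve(S, fs, k):
    """Return s_k: directly if snapshotted, otherwise recompute it from
    the nearest earlier snapshot s_j, j = max{j : j < k, s_j in S}."""
    if k in S:
        return S[k]
    j = max(t for t in S if t < k)
    s = S[j]
    for t in range(j + 1, k + 1):
        s = fs[t - 1](s)
    return s
```

With T = 5 increment steps and checkpoints at steps {2, 4}, retrieving s_3 recomputes one step from s_2; with no checkpoints at all, every retrieval replays the whole prefix from s_0, trading time for space exactly as described above.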

