MEMORY OPTIMIZATION FOR DEEP NETWORKS

Abstract

Deep learning is slowly, but steadily, hitting a memory bottleneck. While the tensor computation in top-of-the-line GPUs increased by 32× over the last five years, the total available memory only grew by 2.5×. This prevents researchers from exploring larger architectures, as training large networks requires more memory for storing intermediate outputs. In this paper, we present MONET, an automatic framework that minimizes both the memory footprint and computational overhead of deep networks. MONET jointly optimizes the checkpointing schedule and the implementation of various operators. MONET is able to outperform all prior hand-tuned operator optimizations as well as automated checkpointing. MONET reduces the overall memory requirement by 3× for various PyTorch models, with a 9-16% overhead in computation. For the same computation cost, MONET requires 1.2-1.8× less memory than current state-of-the-art automated checkpointing frameworks. Our code is available at https://github.com/utsaslab/MONeT.

1. INTRODUCTION

Deep networks are widely used in domains ranging from image classification (Krizhevsky et al., 2012; Simonyan & Zisserman, 2015; He et al., 2016) to video recognition (Wu et al., 2019; Feichtenhofer et al., 2019) and natural language processing (Devlin et al., 2019; Yang et al., 2019). However, training deep networks is resource-intensive. In particular, the amount of available GPU memory bottlenecks the training of many deep networks (Dong et al., 2016; Kim et al., 2016; Chen et al., 2018; Child et al., 2019). This bottleneck requires either modifying the network architecture or scaling training to multiple nodes, incurring significant overheads. We present MONET, an automatic framework to minimize the memory footprint of deep networks. MONET jointly optimizes global compute-graph-level techniques (such as checkpointing) and local techniques (such as memory-efficient implementations of individual operators). At the heart of MONET is a theoretical analysis that enables joint optimization and provides tight bounds on memory consumption. We analyze the memory consumption and computational cost of a general forward and backward pass under changing local operator implementations and a global checkpointing schedule. Specifically, we tightly bound the peak memory consumption of the forward, backward, and recomputation stages. MONET uses these constraints to find the most efficient forward and backward implementation, both locally and globally, under a fixed memory budget. We linearize all memory bounds and express both implementation selection and checkpointing as a 0-1 integer program, which we solve using standard solvers. We conduct extensive experiments, demonstrating that MONET significantly outperforms existing automatic frameworks that use only local or only global techniques.
On multiple architectures (ResNet (He et al., 2016) , VGG (Simonyan & Zisserman, 2015) , UNet (Ronneberger et al., 2015) , GoogleNet (Szegedy et al., 2015) , MobileNet-V2 (Sandler et al., 2018) ), memory budgets (5-10 GB), and network configurations (multiple resolutions), MONET consistently achieves lower memory footprints at equivalent or lower computational overhead. MONET reduces the overall memory requirement by 3× for various models, with a 9-16% overhead in computation. For the same computation cost, MONET requires 1.2-1.8× less memory than the current state-of-the-art automated checkpointing framework. The results achieved by MONET demonstrate the power of jointly optimizing global checkpointing schedules and local operator implementations. 

2. RELATED WORK

There are two broad families of approaches for reducing the memory footprint of a deep network during training: operator-level implementation changes and global, graph-level optimizations. The novel aspect of MONET is that it combines both approaches, finding the optimal mix of local and global techniques for a given network. Operator-Specific Optimizations. Researchers have found creative ways to implement individual operators, or groups of operators, in a more memory-efficient manner. Standard deep learning frameworks (Jia et al., 2014; Collobert et al., 2011; Paszke et al., 2019; Abadi et al., 2016) provide different implementations of certain operators that trade computation for intermediate memory use. These implementations are chosen according to local search heuristics and are not globally optimal. Gist (Jain et al., 2018) proposes several hand-crafted optimizations, such as storing only ReLU signs. RevNets (Gomez et al., 2017) redesigns the ResNet (He et al., 2016) architecture to make each network block reversible, thereby eliminating the need to store intermediate activations for backpropagation. Memory-efficient DenseNets (Pleiss et al., 2017) reduce the memory used for feature maps by recomputing all intermediate feature maps during the backward pass, at a small compute overhead. In-place activated BatchNorm (Bulò et al., 2018) and ReLU layers use output activations to compute their gradients, thus reusing a single memory buffer for the gradient computation in consecutive layers. Mixed-precision training (Micikevicius et al., 2018) uses half precision (FP16) instead of single precision (FP32) for all tensors and arithmetic during training, reducing memory by nearly half.
While training at precision lower than FP16 results in a loss of training quality (Banner et al., 2018), prior work on backpropagation with approximate activations (Chakrabarti & Moseley, 2019) carefully quantizes certain intermediate outputs (activations) to 4 bits, resulting in significant memory savings. Although these hand-crafted techniques independently result in memory savings, there is no one-size-fits-all recipe, and different implementations perform best on different architectures. In contrast, MONET automatically finds the best implementation for each forward and backward operator given a memory budget. Checkpointing. Chen et al. (2016) proposed dividing a network into segments, dropping all intermediate outputs within each segment, and recomputing them later. Chen et al. use √n equal segments, trading memory savings for the cost of an extra forward pass. Checkmate (Jain et al., 2019) solves the problem in a more general setting, using a mixed-integer linear program solver to decide which layers to recompute for a given network. Like Checkmate, our work optimizes a checkpointing schedule, but on a different computation graph. Our computation graph allows for the optimization of an entire execution plan, jointly finding a checkpointing schedule and the best implementation of each forward and backward operator. In Checkmate, a change in operator implementation induces a different computation graph and could thus not directly be optimized over. Appendix F highlights some of the difficulties of adding operator optimizations to Checkmate. In summary, while much work has been done on local optimizations (operator implementations) and global compute-graph-level techniques (automated checkpointing), MONET is the first system to jointly optimize a given architecture using both local and global techniques.
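The √n-segment scheme of Chen et al. (2016) mentioned above can be illustrated with a short sketch. This is a simplified cost model with hypothetical names, counting one unit of memory per activation and one unit of compute per layer on a plain chain of layers, not an implementation of any framework's API:

```python
import math

def sqrt_n_checkpointing_cost(n):
    """Estimate peak memory and extra compute for sqrt(n)-segment
    checkpointing on a chain of n layers, where each layer produces one
    unit-sized activation and costs one unit of compute."""
    seg = max(1, round(math.sqrt(n)))      # segment length ~ sqrt(n)
    num_segments = math.ceil(n / seg)
    # Keep one checkpoint per segment, plus one segment's worth of
    # activations recomputed at a time during the backward pass.
    peak_memory = num_segments + seg
    # Each dropped activation is recomputed once: roughly one extra
    # forward pass over the whole network.
    extra_compute = n
    return peak_memory, extra_compute

# A 100-layer chain: memory drops from 100 to about 20 at the cost of
# one extra forward pass.
mem, extra = sqrt_n_checkpointing_cost(100)
```

Under this toy model, memory scales as O(√n) while the compute overhead stays at roughly one extra forward pass, which is the trade-off MONET generalizes by optimizing per-tensor decisions instead of fixed segments.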

3. PRELIMINARIES

Let the forward pass of a CNN with parameters Θ be expressed as a directed acyclic graph (DAG), where each node i ∈ {1, . . . , N} corresponds to an operator forward_i, and edges (i, j) ∈ E specify data-flow dependencies, i.e., the output of operator i is used as input to operator j. Without loss of generality, a computational dependency (i, j) ∈ E implies i < j. Let N_j = {i : (i, j) ∈ E} be the set of all incoming edges of an operator j. We first discuss the forward pass through a network and the basic form of a backward pass using checkpointing. The backward pass reverses all computational dependencies expressed in our DAG and induces certain dependencies on forward activations. We call these checkpoint dependencies D_k. They are either saved or recomputed, depending on a schedule (s, r). Checkpointing creates a trade-off between computation and memory consumption. To highlight this trade-off, we formally compute the amount of memory consumed in both the forward and backward passes, which allows us to optimize for the ideal execution plan in Sec. 4. We provide a reference for the notation introduced in this section and the next, with explanations, in Appendix A. The Forward Pass. Alg. 1 gives a general overview of the forward pass in a deep network, as implemented in standard deep learning frameworks (Jia et al., 2014; Collobert et al., 2011; Paszke et al., 2019; Abadi et al., 2016). The algorithm proceeds in increasing order of the index i. Each operator forward_i(·) depends on a set of tensors L stored in local memory. These tensors include model parameters Θ, computational dependencies N_i, and tensors stored for later forward operators, e.g., skip or residual activations (He et al., 2016). At each iteration, we add any output tensors of forward_i to the local memory L. Early deep learning frameworks (Jia et al., 2014; Collobert et al., 2011) strictly grew the set of local tensors L, leading to unnecessarily high memory consumption.
Modern graph-based frameworks (Paszke et al., 2019; Abadi et al., 2016) reduce the memory footprint by aggressively pruning local memory L and freeing any tensor that is no longer used in later computations. Some output activations x_i are used in the backward pass and have to be saved for later. We use a checkpointing schedule s^N to determine which ones. Formally, s^N_i ∈ {0, 1} indicates whether the activation of node i is stored during the forward pass. An activation that is not stored will be recomputed if it is needed during the backward pass. Analyzing peak memory consumption of the forward pass. Only the forward_i operator (Alg. 1 L. 4) allocates memory. All other operations perform mere bookkeeping on existing tensors. It is thus sufficient to study the peak memory consumption m_i of forward_i for each node i. Let L_i and S^N_i be the set of local tensors L and saved tensors S while calling forward_i, respectively. L_i includes all parameters and the computational dependencies of this and later forward passes: $L_i = \Theta \cup \{x_j : j \in N_t \text{ for some } t \ge i \text{ and } j < i\}$. L_i is constant and computed ahead of time. The schedule s^N determines the set of saved tensors $S^N_i = \{x_j : s^N_j = 1 \text{ for } j < i\}$. In addition, each forward operator uses a certain amount of workspace memory c_i to store intermediate results. The total memory consumption of a forward operator is thus
$$m_i = c_i + |x_i| + |S^N_i \cup L_i| = c_i + |x_i| + \sum_{x_j \in L_i} |x_j| + \sum_{j < i:\, x_j \notin L_i} |x_j|\, s^N_j, \quad (1)$$
where |·| refers to the memory consumed by a tensor or set of tensors. Most of the memory consumption is constant and does not depend on the schedule. The Backward Pass. The backward pass proceeds in reverse order, as summarized in Alg. 2. backward_k(·) of each node k depends on a set of gradient tensors $\hat L$ and forward tensors $\{x_i : i \in D_k\}$. Any gradients required by the current and later backward passes are stored in local memory $\hat L$.
Dependencies D_k may either be stored in S_k or need to be recomputed from checkpoints in S_k. Recomputation involves the forward computation of one or more nodes, which increases computational overhead, and allows a new set of tensors S^{k-1} to be saved. After recomputation, all dependencies D_k are kept in memory. The backward operation produces a gradient for each input tensor of the original forward operation, which is added to $\hat L$ if required for a later backward computation. We aggressively remove tensors in $\hat L$ that are not required. Analyzing the peak memory consumption of the backward pass. Peak memory consumption $\hat m_k$ again only depends on the forward_i (Alg. 2 L. 7) and backward_k (Alg. 2 L. 12) operations. For the backward_k operation, let $\hat c_k$ be the workspace memory, $\hat L_k$ the set of gradient tensors stored, $D_k = \{x_i : i \in D_k\}$ the forward tensors used, and $S^{k-1}$ the set of newly saved tensors. Here $\hat L_k$ and $D_k$ can be pre-computed. The total memory consumption of the backward_k call is
$$\hat m_k = \hat c_k + |y_k| + |S^{k-1} \cup \hat L_k \cup D_k| = \hat c_k + |y_k| + \sum_{y_l \in \hat L_k} |y_l| + \sum_{x_i \in D_k} |x_i| + \sum_{x_i \notin D_k} s^{k-1}_i |x_i|. \quad (2)$$
Here again, only the last term depends on the checkpointing schedule, while the rest is constant. Analyzing the peak memory consumption of the recomputation. Finally, the peak memory $\hat m^k_i$ of the forward_i call (Alg. 2 L. 7) depends on the set of local tensors, checkpoint dependencies, saved tensors, and gradient tensors, named $L^k_i$, $D_k$, $S^{k-1}_i$, $\hat L_k$ respectively. Following the forward pass:
$$\hat m^k_i = c_i + |x_i| + |\hat L_k| + |S^{k-1}_i \cup L^k_i \cup D_k| = c_i + |x_i| + |\hat L_k| + \sum_{j < i:\, x_j \notin L^k_i \cup D_k} s^{k-1}_j |x_j| + \sum_{j < i:\, x_j \in L^k_i \cup D_k} |x_j| + \sum_{j > i} s^k_j |x_j|. \quad (3)$$
Unlike in the forward pass, $L^k_i$ is no longer constant, but depends on past saved tensors and future recomputations in (s, r): $L^k_i = \Theta \cup \{x_j : j \in N_t \text{ for some } t \ge i \text{ with } r^k_t = 1 \text{ and } j < i\}$.
In the next section, we show how to take this formalization of the forward and backward pass, and find an optimal execution plan including checkpointing schedule (s, r), forward i implementations, and backward k implementations, under a fixed memory budget.

4. METHOD

Our goal is to find a global checkpointing schedule (s, r) and local forward_i/backward_k implementations that jointly minimize the computation cost τ within a memory budget M. We show how to express this optimization as a 0-1 integer program and solve it efficiently. To this end, we linearize all peak memory consumption constraints, ensure that the checkpointing schedule is valid, and minimize a computation cost objective. We keep track of the three contributors to memory and computational cost: the forward pass, the backward pass, and the recomputation of forward operators. Memory Constraints. Consider the case of basic checkpointing using only a single implementation for each forward_i and backward_k. The memory consumption of the forward (Eq. 1) and backward (Eq. 2) passes is linear in s, and thus efficiently expressed in an integer program. However, recomputation depends on both s^{k-1} and r^k in a non-linear manner through the local memory $L^k_i$. This joint dependence on optimization variables gives rise to quadratic constraints, which cannot directly be incorporated into an integer program. For simplicity in this derivation, we bound the set of local tensors from above, assuming every future tensor is recomputed. We give more information about this in Appendix B. The upper bound $\bar L^k_i$ is constant, yielding a linear upper bound $\bar m^k_i$ on the recomputation memory $\hat m^k_i$, analogous to Eq. 3. The set of memory constraints is thus
$$m_i \le M \;\; \forall i, \qquad \hat m_k \le M \;\; \forall k, \qquad \bar m^k_i \le M \;\; \forall k, i. \quad (4)$$
To enable operator optimization, we use a bit-vector δ to indicate the selection of an operator implementation. We add δ to the constraints, which allows us to jointly optimize checkpointing (s, r) and operator implementations δ. Forward Operator Optimization. Let each forward operator forward_i have multiple different implementations $I_i = \{a, b, c, \ldots\}$. For example, convolution may be implemented using matrix multiplication, the Winograd algorithm (Winograd, 1980), a Fourier transform, etc.
(Chetlur et al., 2014). All implementations follow the same DAG structure, and thus use the same dependencies N_i. However, each implementation trades workspace memory $\{c^a_i, c^b_i, \ldots\}$ against computational efficiency $\{\tau^a_i, \tau^b_i, \ldots\}$ in a different manner. Our experiments show that this trade-off is often complex. Our goal is to represent the peak memory when using multiple forward_i implementations in the forward pass and recomputation. Let $\delta_{i,a} \in \{0, 1\}$ indicate that implementation $a \in I_i$ is used for forward_i in the forward pass. Each forward operator should use exactly one implementation: $\sum_l \delta_{i,l} = 1$. The choice of implementation determines the operator's computational cost $\sum_l \tau^l_i \delta_{i,l}$ and workspace memory $c_i = \sum_l c^l_i \delta_{i,l}$. Analogously, each recomputation of forward_i during backward_k chooses between implementations $\delta^k_{i,a} \in \{0, 1\}$ when needed, $\sum_l \delta^k_{i,l} = r^k_i$, with equivalent cost estimates $\sum_l \tau^l_i \delta^k_{i,l}$ and workspace memory use $c^k_i = \sum_l c^l_i \delta^k_{i,l}$. In this formulation, all additional memory requirements remain linear and are directly integrated into the linear memory constraints or their linear relaxations (Eq. 4). We again aim to represent memory in terms of implementations for each backward_k operator. Let $\hat\delta_{k,a} \in \{0, 1\}$ indicate that implementation $a \in \hat I_k$ is used at node k in the backward pass. Each backward operator should use exactly one implementation, $\sum_l \hat\delta_{k,l} = 1$, with computational cost $\sum_l \hat\tau^l_k \hat\delta_{k,l}$ and workspace memory $\hat c_k = \sum_l \hat c^l_k \hat\delta_{k,l}$. The workspace memory adds a linear term to the memory consumption $\hat m_k$ of Eq. 2.
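The implementation-selection part of this program can be illustrated with a toy brute-force search in place of the integer program solver. All names and numbers below are hypothetical; the sketch only shows why the choices interact: under a tight memory budget, not every operator can afford its fast, workspace-hungry implementation at once.

```python
from itertools import product

def best_plan(impls, budget):
    """Pick one implementation per operator, minimizing total compute
    while total workspace memory stays within the budget.

    impls[i] -- list of (workspace, compute) options for operator i.
    Returns (total_compute, chosen_indices) or None if infeasible.
    Brute force over the 0-1 choices stands in for the ILP solver.
    """
    best = None
    for choice in product(*[range(len(ops)) for ops in impls]):
        mem = sum(impls[i][l][0] for i, l in enumerate(choice))
        cost = sum(impls[i][l][1] for i, l in enumerate(choice))
        if mem <= budget and (best is None or cost < best[0]):
            best = (cost, choice)
    return best

# Two operators, each trading workspace memory for speed: with a tight
# budget only one of them can afford its fast implementation.
impls = [[(4, 1), (1, 3)],   # op 0: fast-but-large vs slow-but-small
         [(4, 1), (1, 3)]]   # op 1: same options
```

With a budget of 8 both operators run fast (cost 2); at 5 only one can (cost 4); at 1 the problem is infeasible, mirroring how the memory constraints prune the feasible set of δ.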

Backward Operator Optimization.

The biggest change to the optimization problem comes from the changing dependency structure: $D_k$ is no longer constant. Instead, the implementation of a backward operator changes the set of computational dependencies $D_k$, obtained from $D^l_k$. To deal with this changing dependency structure, we use the indicator vector $\hat\delta_k$ to select the memory contribution of the dependencies of the chosen implementation. This changes the backward memory consumption to
$$\hat m_k = \sum_l \hat c^l_k \hat\delta_{k,l} + |y_k| + |\hat L_k| + \sum_l \hat\delta_{k,l}\, |D^l_k \cup S^{k-1}|, \quad (5)$$
and the corresponding peak recomputation memory $\bar m^k_i$ to
$$\bar m^k_i = c_i + |x_i| + |\hat L_k| + \sum_l \hat\delta_{k,l}\, |S^{k-1}_i \cup \bar L^k_i \cup D^l_k|. \quad (6)$$
Note that the last terms of Eq. 5 and Eq. 6 are quadratic in the original optimization variables $s^{k-1}_i$, which determine $S^{k-1}$, and $\hat\delta_{k,l}$. However, for binary variables, such products can be linearized using an auxiliary variable (see Appendix C.4). We show the full equation expansion in Appendix C.1. Checkpointing Constraints. The computational dependencies of forward and backward operators impose strict constraints on the checkpointing schedule. Any schedule violating these constraints cannot be executed. Recomputation $r^k_i$ requires saved $s^{k-1}_j$ or recomputed $r^k_j$ dependencies $j \in N_i$, and only previously stored or recomputed tensors can be saved:
$$r^k_i \le s^{k-1}_j + r^k_j \;\; \forall i, k, j \in N_i \qquad \text{and} \qquad s^{k-2}_i \le s^{k-1}_i + r^k_i \;\; \forall i, k. \quad (7)$$
Furthermore, all forward tensors $D^l_k$ required by backward_k need to be stored or computed:
$$s^{k-1}_i + r^k_i \ge \hat\delta_{k,l} \;\; \forall k, l, i \in D^l_k. \quad (8)$$
Objective. Our goal is to minimize the amount of computation required for the forward and backward passes. This is the sum of the computational costs of all operators:
$$\underbrace{\sum_i \sum_l \tau^l_i \delta_{i,l}}_{\text{forward pass}} + \underbrace{\sum_k \sum_l \hat\tau^l_k \hat\delta_{k,l}}_{\text{backward pass}} + \underbrace{\sum_k \sum_i \sum_l \tau^l_i \delta^k_{i,l}}_{\text{recomputation}}. \quad (9)$$
Objective Eq. 9 with constraints Eq. 4, Eq. 7, and Eq. 8, and definitions Eq. 1, Eq. 5, and Eq. 6, form our final optimization problem.
It jointly solves for the optimal implementation of each forward and backward operator, as well as an efficient checkpointing schedule.
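The feasibility of a candidate schedule under the first checkpointing constraint of equation 7 can be checked directly. The sketch below uses hypothetical names and plain Python containers; the second constraint, on carrying checkpoints between stages, is analogous and omitted for brevity:

```python
def schedule_is_valid(s, r, deps, n, num_stages):
    """Check the first constraint of Eq. 7 on a checkpointing schedule.

    s[k][i] -- 1 if activation x_i is saved going into backward stage k+1
    r[k][i] -- 1 if forward_i is recomputed during backward stage k
               (r[0] is an unused placeholder; stages run k = 1..num_stages)
    deps[i] -- predecessor indices N_i
    """
    for k in range(1, num_stages + 1):
        for i in range(n):
            # r^k_i <= s^{k-1}_j + r^k_j: recomputing i needs each
            # dependency j either saved earlier or recomputed now.
            if r[k][i] and any(not (s[k - 1][j] or r[k][j])
                               for j in deps[i]):
                return False
    return True

# Chain 0 -> 1 -> 2 with only x_0 checkpointed after the forward pass:
# recomputing node 1 is feasible, recomputing node 2 alone is not.
deps = {0: [], 1: [0], 2: [1]}
ok = schedule_is_valid([[1, 0, 0]], [None, [0, 1, 0]], deps, 3, 1)
bad = schedule_is_valid([[1, 0, 0]], [None, [0, 0, 1]], deps, 3, 1)
```

In the integer program these checks become linear inequalities over the binary s and r variables rather than a post-hoc validation.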

5. EXPERIMENTS

Implementation Details. We develop MONET in PyTorch v1.5.1 and solve the joint optimization problem using the Gurobi (2014) solver. Appendix D provides more implementation details and a full list of optimized operators. The UNet experiments use 608×416 inputs following prior work (Jain et al., 2019). All other experiments use 224×224 inputs following convention (Krizhevsky et al., 2012; Simonyan & Zisserman, 2015; He et al., 2016). The batch size for each experiment is fixed to the maximum at which the model can be trained with baseline PyTorch on a 16 GB GPU. Since Checkmate's (Jain et al., 2019) execution engine is built for TensorFlow, and an official Gist (Jain et al., 2018) implementation is not available, we reimplement both in PyTorch for our comparisons. Our Checkmate implementation is competitive: it uses the original Checkmate solver and operates on the same network structure as MONET. Checkmate does not optimize operator implementations such as convolutions, so we show its runtime using the default convolution algorithm (Checkmate-D). For a stronger comparison, we also show the runtime of a Checkmate schedule that is post-optimized to greedily run the fastest convolution algorithm (Checkmate-O). Wherever not explicitly specified, we compare with Checkmate-O. All checkpointing schedules are run using the same software implementations, and costs are profiled on the same hardware (NVIDIA P100 GPUs). To compare against operator-specific optimizations, we reimplement all Gist techniques in PyTorch and run them on our execution engine. See Appendix E for more details about our baseline implementations. Detailed Comparison to Baselines. (a) Checkpointing: We compare MONET and Checkmate schedules at memory budgets ranging from 5 GB to 10 GB, or equivalently, 0.33× to 0.70× of PyTorch's memory consumption. The batch size for each model is given in parentheses.
For all models, MONET reduces memory usage by 3× (0.33 memory ratio) compared to baseline PyTorch, with a 9-16% compute overhead. For the same memory budget, MONET schedules are up to 34% faster than Checkmate schedules. Note that we measure the empirical performance of the schedules running on GPUs, instead of just reporting the solver's simulated runtime and memory; this is important since Checkmate does not consider workspace cost and overestimates its savings. For networks with individual memory-intensive layers, like VGG-16, operator optimization becomes even more important for reducing memory: Checkmate can reduce memory for VGG-16 only down to 7 GB, whereas MONET with its optimizations is able to run VGG-16 with 5.5 GB of memory. The small runtime improvement of MONET schedules over PyTorch for VGG-16 and UNet at higher memory budgets mainly comes from choosing faster convolution algorithms. MobileNet-V2 uses depthwise convolutions, and hence does not significantly benefit from joint convolution optimization. As a result, the performance of MONET and Checkmate is closer for MobileNet-V2. We provide additional results for MONET on a memory-intensive model, 3D-UNet (Çiçek et al., 2016), in Appendix J, for which we observe a consistent memory reduction to 0.54× of PyTorch memory with an overhead of 8.86%. For our evaluation, we cap the solver time to 24 hours for both MONET and Checkmate, and run the resulting schedule on our execution framework. At tighter memory budgets for non-linear models like ResNet-50 and GoogleNet, Checkmate is unable to find a feasible solution within a couple of hours. In contrast, MONET finds execution plans efficiently. For all the models and memory limits that we evaluate, MONET reaches a solution within 5% of optimal in a few hours, or sometimes even minutes.
Table 2 shows the time it takes for the solver to reach within 5% of the optimal solution, for Checkmate, MONET-NOOP (MONET with checkpointing enabled but operator optimization disabled), and MONET. MONET-NOOP converges to a close-to-optimal solution 1.6×-117.4× faster than Checkmate. For larger models, MONET's solver converges to a close-to-optimal solution up to 27× faster than Checkmate. Note that running the solver is a one-time cost per model: once a MONET schedule has been solved for, it can be reused by everyone to train the model for different purposes with different batch sizes. The cost (typically seconds to hours) is tiny compared to the effort and cost of developing a model for distribution in most cases. See Appendix H for more discussion of solver times, problem statistics, and the full Table 2 data. (b) Operator optimizations: Table 3 compares MONET with Gist. While MONET can determine a range of memory-runtime trade-offs, purely operator-optimization-based schemes like Gist only provide a single memory-runtime data point. For MONET, we show the data point with the greatest memory saving. MONET uses 1.4×-2.1× less memory than Gist across multiple architectures while maintaining full precision. Overall, Gist provides impressive memory savings, but incurs a high computation cost to achieve them. While we get similar memory savings for reimplemented Gist as Jain et al. (2018) report for VGG-16, our compute overhead results are higher. This could be because of evaluations on different frameworks (PyTorch vs. CNTK) and different GPU models (NVIDIA P100 vs. NVIDIA Maxwell GTX Titan X). Gist uses dense-to-sparse conversion via cusparseSdense2csr in one of its techniques. For the first ReLU-Conv layer in VGG-16 (shape (2207744, 256)), this function takes 144 ms, which by itself is 10% of the VGG-16 execution time. We see similar results for other networks.
To ensure a fair comparison, we focus on the maximum memory savings obtained by MONET and Gist, while reporting the compute overhead for completeness. Ablation Experiments. Fig. 4 shows additional ablation experiments. We show the % compute overhead over PyTorch on ResNet-50, GoogleNet, and VGG-16 for different types of MONET checkpointing schedules with a memory budget of 8 GB: with no operator optimizations enabled, with only one type of operator optimization enabled (conv-optimized, output-activated optimized, or intermediate-activated optimized), and with all optimizations enabled. Schedules that do not jointly optimize convolution algorithms are run with greedily post-optimized convolution algorithms. Plots for other models look similar to those of ResNet-50 and GoogleNet. The only difference between 'none' and 'conv' is that convolution algorithms are jointly optimized in the latter; this alone leads to a significant improvement in compute time in all cases. In fact, convolution algorithms have complex workspace-memory/compute characteristics: reserving slightly more memory for convolution workspace while checkpointing can allow a much faster convolution (see Appendix I). This makes it important to jointly optimize convolution algorithms with checkpointing. Similarly, output-activated optimization also provides significant benefits over vanilla checkpointing, since it effectively reduces the number of recomputations required. For memory-intensive networks, intermediate-activated optimization becomes more important. Jointly optimizing all strategies together gives the lowest computational overhead. See Appendix G for detailed ablation plots. Detailed Case Study. The top graph of Fig. 5 shows memory usage while executing PyTorch, MONET without operator optimization, and MONET for ResNet-50 at batch size 184.
As training progresses along the network layers, represented on the X-axis, PyTorch and both MONET schedules store forward-pass outputs, leading to an increasing memory footprint. MONET reaches a peak memory of 8 GB, whereas PyTorch requires 14.7 GB. Stored forward outputs are freed one after the other as the backward pass proceeds, reducing memory usage. According to the checkpointing schedule, MONET saves only a subset of the outputs stored by PyTorch; the middle graph shows the resulting memory savings for layer outputs that are not stored. The bottom graph shows MONET's per-layer compute overhead of recomputation over PyTorch. For MONET, later layers, which correspond to backward operators, trigger recomputation of forward operators and hence have higher overhead.

6. CONCLUSION

We present MONET, a system that automatically reduces the memory requirements of training deep networks. MONET jointly applies local (operator-level) and global (graph-level) optimizations to yield a compute- and memory-efficient checkpointing schedule. MONET reduces memory usage by 3× over PyTorch, with a 9-16% compute overhead. It uses 1.2-1.8× less memory than the state-of-the-art automated checkpointing framework for the same computational cost. Our experimental results show that MONET leads to better memory-computation trade-offs than the state of the art.

A NOTATIONS

Table 4: Notations used in the paper with their explanations. Notations with only i in the subscript/superscript generally relate to the forward pass, those with only k relate to the backward pass, and those with both i and k relate to the recomputation phase.

B BOUNDS ON LOCAL MEMORY

In Section 3, we mentioned that the local memory $L^k_i$ depends on the solver variables $r^k_t$:
$$L^k_i = \Theta \cup \{x_j : j \in N_t \text{ for some } t \ge i \text{ with } r^k_t = 1 \text{ and } j < i\}.$$
To remove this dependence, we obtain an upper bound $\bar L^k_i$ on $L^k_i$ by assuming that all future tensors after i are always recomputed, i.e., $r^k_t = 1 \; \forall t > i$:
$$L^k_i \subseteq \bar L^k_i = \Theta \cup \{x_j : j \in N_t \text{ for some } t \ge i \text{ and } j < i\}.$$
Our experiments use this upper bound. It is possible to tighten it by noting that $r^k_t$ may be 1 only when $t \le k$: forward node t will not be recomputed before computing the backward of node k if node t lies after node k. Thus, a tighter bound on $L^k_i$ is
$$L^k_i \subseteq \tilde L^k_i = \Theta \cup \{x_j : j \in N_t \text{ for some } i \le t \le k \text{ and } j < i\} \subseteq \bar L^k_i.$$
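Both bounds can be computed directly from the dependency structure. The sketch below uses hypothetical names and omits the parameters Θ, which are common to both sets; it is an illustration of the set definitions above, not MONET's code:

```python
def local_memory_bound(i, k, deps, n, tight=False):
    """Indices of activations x_j (j < i) in the bound on L_i^k.

    deps[t] -- predecessor indices N_t.  The loose bound assumes every
    node t >= i may still be recomputed; the tight variant additionally
    requires t <= k.
    """
    last = k if tight else n - 1
    return {j for t in range(i, last + 1) for j in deps[t] if j < i}

# Node 3 depends on x_0: the loose bound keeps x_0 alive at i = 2 even
# when backward stage k = 2 means node 3 can no longer be recomputed.
deps = {0: [], 1: [0], 2: [1], 3: [0, 2]}
loose = local_memory_bound(2, 2, deps, 4)
tight = local_memory_bound(2, 2, deps, 4, tight=True)
```

The tight bound is always a subset of the loose one, so it can only shrink the memory constraints.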

C DETAILED CONSTRAINTS C.1 EXPANDED BACKWARD PASS MEMORY CONSTRAINTS

Sec. 4 formulates the backward peak memory $\hat m_k$ and the recomputation peak memory $\bar m^k_i$ as the memory of a set of tensors. We expand these formulations and express them in terms of the optimization variables here:
$$\hat m_k = \sum_l \hat c^l_k \hat\delta_{k,l} + |y_k| + |\hat L_k| + \sum_l \hat\delta_{k,l}\, |D^l_k \cup S^{k-1}| = \sum_l \hat c^l_k \hat\delta_{k,l} + |y_k| + \sum_{y_l \in \hat L_k} |y_l| + \sum_l \sum_{x_i \in D^l_k} \hat\delta_{k,l} |x_i| + \sum_l \sum_{x_i \notin D^l_k} \underbrace{\hat\delta_{k,l}\, s^{k-1}_i}_{\sigma_{k,l,i}} |x_i|,$$
$$\bar m^k_i = c_i + |x_i| + |\hat L_k| + \sum_l \hat\delta_{k,l}\, |S^{k-1}_i \cup \bar L^k_i \cup D^l_k| = c_i + |x_i| + |\hat L_k| + \sum_l \sum_{j < i:\, x_j \notin \bar L^k_i \cup D^l_k} \hat\delta_{k,l}\, s^{k-1}_j |x_j| + \sum_l \sum_{j < i:\, x_j \in \bar L^k_i \cup D^l_k} \hat\delta_{k,l} |x_j| + \sum_{j > i} s^k_j |x_j|.$$
The products of binary variables (e.g., $\hat\delta_{k,l}\, s^{k-1}_i$, abbreviated $\sigma_{k,l,i}$ above) are linearized as described in Appendix C.4.

C.2 COMPLETE MEMORY CONSTRAINTS

In this section, we present the complete memory constraints used in the MONET optimization. These constraints include the recomputation variables $r^k_i$, which were excluded from the main text for simplicity. As discussed in Sec. 3, the peak memory of a forward_i recomputation before computing backward_k is denoted by $\hat m^k_i$. This represents the recomputation memory (renamed $\hat m^k_{Ri}$) when forward_i is actually recomputed, i.e., $r^k_i = 1$. When this is not the case, the peak memory $\hat m^k_{Si}$ only depends on the stored checkpoints $S^{k-1}_i$, the checkpoint dependencies $D_k$, and the gradient tensors $\hat L_k$. Thus,
$$\hat m^k_{Ri} = c_i + |x_i| + |\hat L_k| + |S^{k-1}_i \cup L^k_i \cup D_k| = r^k_i c_i + r^k_i |x_i| + |\hat L_k| + \sum_{j < i:\, x_j \notin L^k_i \cup D_k} s^{k-1}_j |x_j| + \sum_{j < i:\, x_j \in L^k_i} r^k_i |x_j| + \sum_{j < i:\, x_j \in D_k \setminus L^k_i} |x_j| + \sum_{j > i} s^k_j |x_j|, \quad (12)$$
$$\hat m^k_{Si} = |\hat L_k| + |S^{k-1}_i \cup D_k| = |\hat L_k| + \sum_{j \le i:\, x_j \notin D_k} s^{k-1}_j |x_j| + \sum_{j \le i:\, x_j \in D_k} |x_j| + \sum_{j > i} s^k_j |x_j|.$$
The local memory $L^k_i$ can be bounded by $\bar L^k_i$, which gives us $\bar m^k_{Ri}$. To add forward operator optimizations to $\bar m^k_{Ri}$, we recall the trade-off between workspace memory and compute time and replace the workspace term $r^k_i c_i$ in Eq. 12 with $\sum_l \delta^k_{i,l} c^l_i$. The complete memory constraints are:
$$m_i \le M \;\; \forall i, \qquad \hat m_k \le M \;\; \forall k, \qquad \bar m^k_{Ri} \le M \;\; \forall k, i, \qquad \hat m^k_{Si} \le M \;\; \forall k, i.$$

C.3 IN-PLACE CONSTRAINTS

We show how to represent the decision of computing an operator using an in-place or out-of-place implementation. If an operator like ReLU uses an in-place implementation, its input tensor is overwritten with its output. In this case, its input tensor can neither be stored nor used as input to another computation in this stage. This needs to be reflected in our constraints. We introduce two new binary variables to model in-place computations: $q^k_i$ represents whether forward_i is recomputed in-place when computing backward_k; $p^k_i$ represents that the output of forward_i has been computed and will not be overwritten by any other forward-node recomputation in this stage. If $q^k_i$ is true, then $p^k_j$ is false; otherwise $p^k_j$ is the same as $r^k_j$, where $j \in N_i$. Further, $s^{k-1}_j$ must also be false if $q^k_i$ is true. This can be written as boolean constraints:
$$p^k_j \ge r^k_j - 2q^k_i, \qquad p^k_j \le 2 - 2q^k_i, \qquad s^{k-1}_j \le 2 - 2q^k_i.$$
The checkpointing constraint Eq. 7 changes, with $p^k_j$ replacing $r^k_j$ on the right-hand side. Further, $q^k_i$ (or $p^k_j$) can only be true if forward_i (or forward_j) is actually recomputed prior to computing backward node k. Thus, $p^k_j \le r^k_j$ and $q^k_i \le r^k_i$.

C.4 CONSTRAINT LINEARIZATION

The memory constraints we introduce in Section 4 contain quadratic terms of the form $x_i \cdot x_j$, with $x_i, x_j \in \{0, 1\}$. Quadratic terms cannot directly be incorporated into an integer program. However, we can linearize them by replacing each quadratic term $x_i \cdot x_j$ with an auxiliary variable $\alpha_{i,j} \in \{0, 1\}$ and introducing the additional linear constraints $\alpha_{i,j} \ge x_i + x_j - 1$, $\alpha_{i,j} \le x_i$, and $\alpha_{i,j} \le x_j$. After this substitution for all quadratic terms, all constraints in MONET are linear.
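The correctness of this linearization can be checked exhaustively, since there are only four binary assignments. A quick sketch (illustrative names only):

```python
from itertools import product

def linearized_product(x_i, x_j):
    """The alpha in {0,1} satisfying alpha >= x_i + x_j - 1,
    alpha <= x_i, and alpha <= x_j.  For binary x_i, x_j these three
    linear constraints pin alpha to exactly x_i * x_j."""
    feasible = [a for a in (0, 1)
                if a >= x_i + x_j - 1 and a <= x_i and a <= x_j]
    assert len(feasible) == 1      # alpha is uniquely determined
    return feasible[0]

# Exhaustive check over all four binary assignments.
all_match = all(linearized_product(a, b) == a * b
                for a, b in product((0, 1), repeat=2))
```

Because α is forced to equal the product at every feasible point, the substitution changes neither the feasible set nor the optimum of the integer program.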

D IMPLEMENTATION

We develop MONeT in the PyTorch (v1.5.1) framework. We use PyTorch's default Autograd package for the backward implementation of elementary functions when the autograd implementation is stateless. In all other cases, we implement custom forward and backward functions leveraging PyTorch ATen library functions to flexibly support multiple operators and execution schedules. Each backward operator implementation is annotated with its computational dependencies, which are generally the input or the output of its corresponding forward operator. Certain backward operator implementations may depend on intermediate activations generated in the forward pass. For example, an intermediate-activated ReLU backward uses an encoded bitmask representing the sign of the forward operator's input. We annotate this as an intermediate storage node and add it to our optimization problem, with a strict recomputation dependency of the intermediate storage node on its creator node. Our operator optimizations select from different backward operator implementations, convolution algorithms, in-place operators, etc. We split the convolution backward operator into two: a parameter-gradient operator followed by an input-gradient operator. Since the input-gradient operator has no computational dependency on the forward pass, we can aggressively free the forward input tensor right after the parameter gradient is computed. We also reuse BatchNorm statistics when they are recomputed. For our experiments, we limit ourselves to full-precision training, as quantization or lower-precision computation introduces additional noise into SGD and changes its convergence properties. We solve the joint optimization problem using the CVXPY (Diamond & Boyd, 2016; Agrawal et al., 2018) modeling language with a Gurobi (2014) backend. MONeT workflow. We obtain the forward pass dependencies in MONeT by JIT tracing a model to obtain its graph.
We profile each layer for workspace memory and compute cost, and obtain the memory usage of tensors from their shape and type. Note that the workspace memory for many convolution operators in VGG-16 exceeds 2 GB, making it an important factor to model. Unlike prior approaches like Checkmate, we account for this workspace memory in our optimization problem, bringing the memory model very close to the actual memory allocation. We phrase a boolean integer programming problem using the generated graph and the profiled compute costs and workspace memory, and solve it using the CVXPY (Diamond & Boyd, 2016; Agrawal et al., 2018) modeling language and the GUROBI (Gurobi, 2014) solver. The solution is used to generate a schedule that can be run by the MONeT scheduler. Operator optimizations. We divide operator optimizations according to the types of implementations they select from. (1) Output-activated: the backward computation of operators like ReLU and BatchNorm can depend either on their forward node's inputs or on its outputs. (2) Intermediate-activated: the backward of ReLU can depend on a 1-bit encoding of the sign of its forward node's input; the backward of MaxPool is computed using an intermediate 8-bit output-shaped tensor containing the kernel index of the maximum element. (3) Convolution algorithms: we choose from 8 forward and 6 backward cuDNN convolution algorithms. (4) In-place operations: the solver can choose to perform operators like the ReLU forward in-place. We discuss the constraints for in-place operator selection in C.3. All MONeT experiments enable in-place operation selection.
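The intermediate-activated ReLU idea above can be sketched in plain Python: the forward pass stashes only a packed 1-bit sign mask instead of the full-precision input, and the backward pass reads gradients through that mask. This is a hypothetical illustration of the technique, not MONeT's actual GPU implementation, which packs the mask via ATen kernels.

```python
def relu_forward_with_mask(x):
    """Forward ReLU over a list of floats; stash only a 1-bit-per-element sign mask."""
    out = [max(v, 0.0) for v in x]
    bits = 0
    for idx, v in enumerate(x):
        if v > 0:
            bits |= 1 << idx  # bit set where the forward input was positive
    return out, bits

def relu_backward_from_mask(grad_out, bits):
    """Backward ReLU reading only the bitmask, never the stored forward input."""
    return [g if (bits >> idx) & 1 else 0.0 for idx, g in enumerate(grad_out)]
```

The memory saving is the point: for a float32 activation, the stash shrinks from 32 bits to 1 bit per element, at the cost of tying this backward implementation to the extra intermediate storage node described above.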

E BASELINE IMPLEMENTATIONS

Checkmate implementation. We reimplement Checkmate (Jain et al., 2019) in PyTorch for our comparisons. We use the Checkmate solver as-is to obtain Checkmate schedules. Since Checkmate does not provide an execution engine for PyTorch, we run the generated Checkmate schedules on our own execution framework. Our execution engine uses the same operator implementations for Checkmate and MONeT. We have released our Checkmate implementation with the MONeT code. Gist implementation. In their paper, Jain et al. (2018) use the most memory-efficient convolution algorithms in Gist and compare its memory savings against a baseline that also chooses the most memory-efficient convolution algorithm. Using memory-efficient convolution algorithms, our Gist reimplementation can train VGG-16 with 0.55× of the PyTorch-required memory (a 1.81× reduction in memory footprint), which is close to the data presented by Jain et al. (2018). However, it is 59% slower than when convolution selection is enabled, in which case it trains using 0.76× of the PyTorch-required memory. Since implementing Gist with memory-efficient convolutions is not optimal in terms of compute time, we implement Gist to use PyTorch's convolution selection algorithm. For all models other than VGG-16 and UNet, we see similar memory savings for Gist with memory-efficient convolutions and with convolution selection enabled. We have released our Gist implementation with the MONeT code.

F ON OPERATOR SELECTION FOR CHECKMATE

In this section, we briefly explain the difficulties of including operator selection directly in Checkmate (Jain et al., 2019). We refer directly to notation and equations in the Checkmate paper (arXiv v3; 14 May 2020). The most direct way to incorporate operator selection into Checkmate is to introduce an auxiliary variable $R^v_{t,i} \in \{0, 1\}$ that indicates re-computing layer $i$ at time $t$ using implementation $v$. Most constraints in equation 1 could stay the same, given $R_{t,i} = \sum_v R^v_{t,i}$, and loss (1a) becomes $\sum_t \sum_i \sum_v R^v_{t,i} C^v_i$. Some of our operators produce a different kind of checkpoint (e.g. binary-activated ReLUs), which could be handled in Checkmate by splitting $S^v_{t,i}$. The main issues in Checkmate arise in its memory modeling and relaxations (eq. 4, 5, 7). The memory consumed by a specific checkpoint may depend on the operator implementation: DEPS[k] and USERS[i] both depend on the operator implementation (output-activated, input-activated, ...). In short, the Checkmate computation graph is dynamic and depends on the operator implementations. The most direct way to address this is to compute mem_freed_t(v_k) in an implementation-dependent way, mem_freed^v_t(v_k), with mem_freed_t(v_k) = $\sum_v R^v_{t,i}$ mem_freed^v_t(v_k), selecting the right version depending on the operator used. Likewise, we need to extend FREE_{i,t,k} to account for the different operator implementations in $R^v_{t,k}$, and the product in equation (5) now goes over all implementations $R^v_{i,j}$ using different USERS sets. This leads to a linear blowup in the number of constraints and auxiliary variables, and thus an at least quadratic expansion in computational cost. Furthermore, mem_freed_t(v_k) = $\sum_v R^v_{t,i}$ mem_freed^v_t(v_k) is a quadratic constraint that further needs to be resolved using additional auxiliary variables. Given that Checkmate already pushes the limits of current solvers, it is unlikely to handle this explosion in constraints and variables without significant modifications.
MONeT, on the other hand, represents the compute graph more compactly and efficiently integrates different operator implementations.

G DETAILED ABLATIONS

Fig. 6 shows a detailed plot of our ablation experiments comparing the compute overhead of MONeT variants across a range of memory limits. The y-axis shows the compute overhead over PyTorch and the x-axis shows the memory ratio to the PyTorch model. All variants which are not conv-optimized are greedily post-optimized to use the fastest convolution. We see that MONeT with no operator optimization (NoOp) is generally slower than the other variants for all models and memory limits. Convolution and output-activated optimizations are both important in reducing compute overhead. MobileNet-V2 uses depthwise-separable convolutions, and hence does not benefit significantly from convolution optimization. Further, MobileNet-V2 has hardtanh operators instead of ReLU operators, for which we have not implemented the intermediate-activated backward optimization. Intermediate-activated optimizations provide memory savings in memory-intensive models, allowing models like VGG-16 to reach memory savings that are not attainable with the other optimizations. All optimizations together result in the lowest compute overhead for any model or memory limit.

H SOLVER TIME AND ILP STATISTICS

Solver runtime. MONeT's solver runtimes vary across models and memory limits. We evaluate schedules obtained with solver times capped at 24 hours. For moderate memory limits, both MONeT and Checkmate achieve an optimal solution within 24 hours. For tighter memory limits, the solutions obtained by MONeT and Checkmate may not be optimal. For multiple models and memory limits, Table 2 in Sec. 5 shows the time it takes the solver to reach within 5% of the optimal solution for Checkmate, MONeT-NoOp (MONeT with checkpointing enabled but operator optimization disabled), and MONeT. We add the data for MobileNet-V2 and UNet in Table 5, which follow a similar pattern. We also provide Table 6, which shows the time taken by the solver to reach within 2% of the optimal solution; it behaves similarly to the 5% case. MONeT-NoOp converges to a 2% close-to-optimal solution 1.3×-139× faster than Checkmate. For larger models, MONeT's solver converges to a 2% close-to-optimal solution up to 16× faster than Checkmate. At tighter memory limits for MobileNet-V2, the Checkmate solver reaches a 2% close-to-optimal solution faster than MONeT, but is still much slower than MONeT-NoOp. ILP statistics. For different models, Table 7 shows the solver statistics after presolving for the problems formulated by Checkmate, MONeT-NoOp, and MONeT at a 10 GB memory limit. It lists the number of forward operators in the model and the number of constraints and variables for each solver. MONeT-NoOp, which is MONeT with only checkpointing enabled and without operator optimization, has on average 50% fewer constraints and 67% fewer variables than Checkmate. Jointly-optimized MONeT has a slightly larger number of constraints, and on average 40% fewer variables than Checkmate.

I CONVOLUTION ALGORITHMS

Fig. 7 shows the complex workspace memory-compute trade-off for different convolution algorithms.
The memory used is not always inversely proportional to the compute requirement. Jointly optimizing convolution algorithms enables MONET to make the best decisions for which convolution algorithm to select.

J APPLICABILITY TO MEMORY-INTENSIVE MODELS

To further show MONeT's applicability to memory-intensive models, we evaluate it on 3D-UNet (Çiçek et al., 2016), a fully-convolutional model for volumetric images. Fig. 8 presents the runtime-memory trade-off for MONeT on 3D-UNet. We use a commonly used 3D-UNet implementation (Wolny, 2019; Wolny et al., 2020) with a training configuration similar to the 3DUnet confocal boundary configuration provided in the repository, and a batch size of 22, which just fits on a 16 GB P100 GPU. MONeT reduces memory usage to 0.54× that of PyTorch while incurring an 8.86% overhead in compute time. At a memory ratio of 0.81, MONeT incurs almost no computational overhead, because it makes use of operator optimizations and is able to bring the recomputation cost down to zero.



Figure 1: Memory Optimized Network Training (MONeT), an automatic framework that minimizes the memory footprint of deep networks by jointly optimizing global and local techniques.

Figure 2: Schematic overview of the forward and backward passes. The algorithms aggressively save memory by greedily freeing unused tensors, and allow a general checkpointing schedule (s, r) to be executed.

Optimization. Let each backward operator backward$_k$ have a set of different implementations $\hat{I}_k = \{a, b, c, \dots\}$. Each implementation again trades workspace memory $\{\hat{c}^a_k, \hat{c}^b_k, \dots\}$ for computational cost $\{\tau^a_k, \tau^b_k, \dots\}$. While gradient tensors follow the fixed DAG structure, different implementations may depend on different forward activations $\{D^a_k, D^b_k, \dots\}$. For example, in-place activated operators (Bulò et al., 2018) depend on their output activation, while regular operators use the input activation. This change in the dependency structure makes optimizing for backward-operator implementations challenging.
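For a single operator in isolation, the workspace-versus-compute trade-off amounts to picking the cheapest implementation whose workspace fits a memory budget. The sketch below illustrates only this local picture with made-up implementation names and costs; MONeT itself makes these choices jointly via the ILP, since the dependency structure couples the choices across operators.

```python
def pick_backward_impl(impls, workspace_budget):
    """impls: dict mapping implementation name -> (workspace_mem, compute_cost).
    Return the name of the cheapest implementation whose workspace fits,
    or None if no implementation fits the budget."""
    feasible = {n: (c, t) for n, (c, t) in impls.items() if c <= workspace_budget}
    if not feasible:
        return None  # memory must be freed elsewhere, e.g. by checkpointing
    # among feasible implementations, minimize compute cost
    return min(feasible, key=lambda n: feasible[n][1])
```

With a generous budget the fastest (memory-hungry) implementation wins; as the budget tightens, slower but leaner implementations are selected, which is exactly the trade-off the joint optimization exploits.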

Figure 3: Comparing MONeT with PyTorch and Checkmate. MONeT reduces memory by 3× compared to PyTorch, with 9-16% compute overhead. It achieves a better memory-compute tradeoff than default Checkmate-D and conv-optimized Checkmate-O.

Figure 5: Detailed case study on ResNet-50. Top : memory usage along execution (forward and backward). Middle: memory saving of MONeT over PyTorch for each layer. Bottom: compute overhead of MONeT over PyTorch. MONeT saves memory in early layers to reduce peak memory. Most compute overhead happens at recomputation during backward (right-hand-side of the figure).

$m^k_i$ — peak memory of forward$_i$ when it is recomputed before backward$_k$.
$m^k$ — peak memory of backward$_k$.
Operator costs:
$c^l_i$ — workspace memory of operator forward$_i$ executed using implementation $l \in I_i$.
$\hat{c}^l_k$ — workspace memory of operator backward$_k$ executed using implementation $l \in \hat{I}_k$.
$\tau^l_i$ — compute cost of operator forward$_i$ executed using implementation $l \in I_i$.
$\tau^l_k$ — compute cost of operator backward$_k$ executed using implementation $l \in \hat{I}_k$.

Figure 6: Ablation results on ResNet-50, GoogleNet, UNet, VGG-16, MobileNet-V2.

Figure 7: Memory vs. compute for 7 convolution algorithms with 256×64×56×56 input, 3×3 kernel, 64 output channels.

Table 1 compares the memory savings obtained by MONeT and Checkmate for five different models when the computational overhead over PyTorch is fixed at 10%. MONeT schedules use 2-3× less memory than PyTorch. For the same computational overhead, MONeT uses 1.2-1.8× less memory than Checkmate. Fig. 3 shows more detailed runtime-memory trade-offs of MONeT compared to PyTorch and Checkmate for different models. We plot the average iteration time of training as % overhead over PyTorch.

Solver time (in hours) to reach within 5% of the optimal solution. MONeT-NoOp reaches a 5% close-to-optimal solution 1.6×-117× faster than Checkmate. MONeT gets within 5% of the optimal solution in only a few hours, up to 16× faster than Checkmate for larger models.

Memory ratio and overhead (%) over PyTorch for Gist and MONeT. MONeT obtains 1.4×-2.1× higher memory savings than Gist across models. The number in parentheses after the model name shows the batch size.

Ablation results for memory ratio 0.53. The lowest compute overhead across models is seen only when all optimizations are jointly optimized.

gives a list of notations used in the paper along with explanations. ($N$ = number of backward operators.)
$S_i$ — set of stored tensors after forward$_i$ in the forward pass.
$L_i$ — set of all parameters and forward tensors created till forward node $i$, required as computational dependencies for forward$_i$ and later forward passes.
$D_k$ — set of forward-pass tensors required as computational dependencies for backward$_k$.
$S^{k-1}$ — set of stored forward-pass tensors right before calling backward$_k$.
$\hat{L}_k$ — set of gradient tensors created before backward node $k$, required as computational dependencies for backward$_k$ and later backward passes.
$L^k_i$ — set of all parameters and forward tensors created till forward node $i$, required as computational dependencies for forward$_i$ and later forward recomputations to be done before backward$_k$.
$s^k_i$ — indicates if the output of forward$_i$ is stored in memory when computing backward$_k$.
$r^k_i$ — indicates if forward$_i$ is recomputed before computing backward$_k$.
$\delta_{i,l}$ — indicates if forward$_i$ uses implementation $l \in I_i$ in the forward pass.
$\delta^k_{i,l}$ — indicates if forward$_i$ uses implementation $l \in I_i$ when recomputed before backward$_k$.
$\hat{\delta}_{k,l}$ — indicates if backward$_k$ uses implementation $l \in \hat{I}_k$.

Gist (Jain et al., 2018) is an operator-based memory-efficient scheme for training DNNs. It encodes stashed forward tensors into smaller tensors which require less memory. Jain et al. (2018) evaluate Gist using CNTK on an Nvidia Maxwell GTX Titan X GPU. Since we implement MONeT in PyTorch and have access to an Nvidia P100 GPU, a direct comparison with the numbers in the Gist paper is not possible. As an official Gist implementation is not available, we reimplement it in PyTorch and evaluate its execution using MONeT's execution framework. We implement all Gist optimizations over MONeT's execution framework: Binarize (intermediate encodings for ReLU-Pool layers), Sparse Storage Dense Compute (compress and store sparse convolution inputs in ReLU-Conv layers as sparse storage), Delayed Precision Reduction (storing stashed non-compressed tensors in FP16, but computing in FP32), and Inplace (performing ReLU operators in-place wherever possible). In Gist, the Sparse Storage Dense Compute (SSDC) technique creates a sparse storage tensor in the Compressed Sparse Row (CSR) representation using the Nvidia cuSPARSE library. The dense storage is reshaped into a 256-sized column tensor before storing it in sparse format, allowing the column index of the CSR representation to be saved in 8 bits instead of 32 bits (termed Narrow Value Optimization in the paper). We also implement SSDC using Nvidia's cuSPARSE library (function cusparseSdense2csr) with CUDA Toolkit version 10.1 using PyTorch's C++ extensions.
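The Narrow Value Optimization above can be illustrated in plain Python: once the dense tensor is reshaped to 256-wide rows, every CSR column index is below 256 and fits in one byte. This is a hypothetical CPU sketch of the encoding, not the cuSPARSE-backed GPU implementation used in our Gist reimplementation.

```python
def to_csr_narrow(dense_rows):
    """CSR-encode a list of 256-wide rows; column indices fit in a single byte."""
    values, col_idx, row_ptr = [], bytearray(), [0]
    for row in dense_rows:
        assert len(row) == 256  # narrow width guarantees column index < 256
        for c, v in enumerate(row):
            if v != 0:
                values.append(v)
                col_idx.append(c)  # one byte per index instead of four
        row_ptr.append(len(values))  # cumulative nonzero count per row
    return values, col_idx, row_ptr
```

For sparse ReLU outputs the savings compound: only nonzero values are stored, and the per-nonzero index overhead shrinks 4×, which is exactly what makes SSDC attractive for ReLU-Conv layers.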

ILP statistics for Checkmate, MONeT-NoOp, and MONeT. MONeT-NoOp has on average 50% fewer constraints and 67% fewer variables than Checkmate. MONeT has a slightly larger number of constraints and on average 40% fewer variables than Checkmate.

ACKNOWLEDGMENTS

We would like to thank the anonymous reviewers for their feedback. Aashaka Shah and Vijay Chidambaram were partially supported by donations from VMware and Google. Chao-Yuan Wu was partially supported by a Facebook Fellowship. Jayashree Mohan was supported by a Microsoft Research Fellowship. The results presented in this paper were obtained using the Chameleon testbed supported by the National Science Foundation.

