TENSOR REMATERIALIZATION

Abstract

Checkpointing enables the training of deep learning models under restricted memory budgets by freeing intermediate activations from memory and recomputing them on demand. Current checkpointing techniques statically plan these recomputations offline and assume static computation graphs. We demonstrate that a simple online algorithm can achieve comparable performance by introducing Dynamic Tensor Rematerialization (DTR), a greedy online algorithm for checkpointing that is extensible and general, is parameterized by eviction policy, and supports dynamic models. We prove that DTR can train an N-layer linear feedforward network on an Ω(√N) memory budget with only O(N) tensor operations. DTR closely matches the performance of optimal static checkpointing in simulated experiments. We incorporate a DTR prototype into PyTorch merely by interposing on tensor allocations and operator calls and collecting lightweight metadata on tensors.

1. INTRODUCTION

As state-of-the-art deep learning (DL) models continue to grow, training them within the constraints of on-device memory becomes increasingly challenging. The memory demands of emerging models prevent their training on memory-limited devices (such as specialized accelerators, low-powered embedded devices, or older GPUs) and limit researchers' ability to explore memory-intensive architectures and training techniques. Checkpointing is one technique that enables training with models and batches that exceed on-device memory without modifying the model's design. It is achieved by freeing some activations from memory and recomputing them on demand. Adapted from techniques in automatic differentiation (Baydin et al., 2015; Griewank & Walther, 2000; Siskind & Pearlmutter, 2018), checkpointing in the DL context exploits the fact that intermediate activations for backpropagation dominate memory usage during training (Sohoni et al., 2019) but can be easily recomputed by replaying parts of the forward pass. Current DL checkpointing techniques (Chen et al., 2016; Jain et al., 2020; Kumar et al., 2019; Gruslys et al., 2016) statically plan which activations to recompute offline, requiring an initial stage of model analysis.

In this paper, we demonstrate that static planning is unnecessary for DL checkpointing. We present Dynamic Tensor Rematerialization (DTR), a greedy online algorithm for heuristically checkpointing arbitrary DL models. DTR operates like a tensor-level cache: it collects metadata on tensors and operators as a model is trained and uses it to guide heuristics that choose which activations to free and later recompute. As a runtime system, DTR can utilize dynamically gathered information (e.g., measured operator costs). Additionally, its simple, cache-like approach requires no advance knowledge of the model or application, letting it immediately support arbitrarily dynamic models and applications featuring higher-order differentiation.
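To make the cache analogy concrete, the sketch below shows one way a caching-inspired eviction heuristic could score resident tensors from lightweight, dynamically gathered metadata (measured compute cost, memory footprint, and staleness). The names (`TensorMeta`, `eviction_score`, `choose_victim`) and the particular score are illustrative assumptions, not DTR's definition:

```python
from dataclasses import dataclass

# Hypothetical per-tensor metadata; a DTR-style runtime tracks fields
# like these as the model trains (exact fields/score are assumptions).
@dataclass
class TensorMeta:
    compute_cost: float   # measured time to recompute this tensor
    size: float           # memory footprint in bytes
    last_access: float    # timestamp of most recent access

def eviction_score(m: TensorMeta, now: float) -> float:
    # Cheap-to-recompute, large, and stale tensors score lowest,
    # making them the most attractive eviction candidates.
    staleness = now - m.last_access
    return m.compute_cost / (m.size * staleness)

def choose_victim(metas, now):
    # Greedy choice: evict the lowest-scoring resident tensor.
    return min(metas, key=lambda m: eviction_score(m, now))
```

Because the score uses only runtime measurements, it needs no advance model analysis and applies unchanged to dynamic control flow.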
For example, given a model with data-dependent control flow like TreeLSTM (Tai et al., 2015), DTR's runtime can simply evict tensors when memory runs out and rematerialize them as needed. By contrast, static planning techniques assume a static dataflow graph, which requires "unrolling" dynamic models and performing (potentially expensive) planning for every distinct input.

[Figure 1: (Top) Pseudocode for DTR's basic logic (independent of heuristic), and (Bottom) DTR's sequence of events in an operator call. Note that PerformOp() may make further recursive calls in order to rematerialize arguments.]

This paper describes DTR's design (Sec. 2) and makes the following contributions:

• We prove that DTR can train an N-layer linear feedforward network on an Ω(√N) memory budget with only O(N) tensor operations (Sec. 3), which is within a constant factor of optimal and matches the offline bound of the Chen et al. (2016) static checkpointing technique.

• We formalize DL model checkpointing as an online rematerialization problem and define a greedy algorithm parameterized by caching-inspired heuristics. In simulated trials, our heuristic attains near-optimal performance on a variety of DL models (Sec. 4).

• We implement a DTR prototype by making only modest modifications to the PyTorch framework, enabling training under restricted memory budgets for both static and dynamic models and demonstrating the ease with which our algorithm can be incorporated into an existing DL framework (Sec. 5).

Note that techniques other than checkpointing, such as swapping tensors between devices, can also enable training under limited memory. In Sec. 6, we discuss these approaches and how they could operate with DTR.
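The basic logic described above (evict until an allocation fits; rematerialize evicted tensors on access, recursing through evicted arguments) can be rendered as a minimal runnable sketch. The class and method names (`DTRRuntime`, `allocate`, `perform_op`) and the trivial evict-first heuristic are our own; DTR's actual implementation interposes on a framework's allocator and operator dispatch:

```python
# Minimal sketch of DTR's eviction/rematerialization loop, assuming a
# toy cost model where each tensor occupies `size` abstract units.

class Tensor:
    def __init__(self, op, args, size):
        self.op = op            # parent operation (None for constants)
        self.args = args        # argument tensors of the parent op
        self.size = size        # memory footprint (abstract units)
        self.resident = False   # whether the value is currently in memory
        self.value = None

class DTRRuntime:
    def __init__(self, budget):
        self.budget = budget
        self.pool = []          # evictable (non-constant) tensors

    def used(self):
        return sum(t.size for t in self.pool if t.resident)

    def allocate(self, size):
        # AllocateBuffer: evict heuristically until the request fits.
        while self.used() + size > self.budget:
            victim = next(t for t in self.pool if t.resident)  # toy heuristic
            victim.resident, victim.value = False, None

    def access(self, t):
        # On access: rematerialize on demand if the tensor was evicted.
        if not t.resident:
            self.perform_op(t)
        return t.value

    def perform_op(self, t):
        # PerformOp: recursively ensure arguments are resident, then
        # replay the parent operation that produced t.
        args = [self.access(a) for a in t.args]
        self.allocate(t.size)
        t.value, t.resident = t.op(*args), True

    def compute(self, op, args, size):
        t = Tensor(op, args, size)
        self.perform_op(t)
        self.pool.append(t)
        return t

def constant(value):
    # Constants are never evicted: no operation rematerializes them.
    t = Tensor(None, [], 0)
    t.value, t.resident = value, True
    return t

rt = DTRRuntime(budget=2)
x = constant(1.0)
a = rt.compute(lambda v: v + 1, [x], 1)   # a = 2.0
b = rt.compute(lambda v: v * 2, [a], 1)   # b = 4.0
c = rt.compute(lambda v: v - 3, [b], 1)   # evicts a to make room for c
```

Accessing `a` afterward triggers rematerialization: its parent operation is replayed (evicting `b` to free space), illustrating how eviction and recomputation interleave under a fixed budget.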

2. DYNAMIC TENSOR REMATERIALIZATION

We introduce Dynamic Tensor Rematerialization (DTR), a thin runtime layer that intercepts tensor allocations, accesses, and deallocations and eliminates the need for ahead-of-time model analysis to support checkpointing. Figure 1 shows DTR's high-level approach. When a tensor allocation occurs (AllocateBuffer), DTR first checks if sufficient memory is available. If so, it generates a fresh tensor identifier, initializes its metadata for future recomputation, allocates the requested memory, and returns a new tensor. If not, DTR heuristically selects and evicts resident tensors until the requested allocation can be accommodated. Constant tensors (loaded from external data) cannot be evicted since no corresponding operation rematerializes them. Upon tensor access, DTR first checks if the tensor is resident in memory. If so, it updates tensor metadata before returning the requested tensor. If the tensor has been evicted, DTR rematerializes it by replaying the parent operation that originally produced the tensor. Crucially, rematerialization can be recursive: if the arguments to an evicted tensor's parent operation have also been evicted, then they must first be

