ITERATIVE PATCH SELECTION FOR HIGH-RESOLUTION IMAGE RECOGNITION

Abstract

High-resolution images are prevalent in various applications, such as autonomous driving and computer-aided diagnosis. However, training neural networks on such images is computationally challenging and easily leads to out-of-memory errors even on modern GPUs. We propose a simple method, Iterative Patch Selection (IPS), which decouples the memory usage from the input size and thus enables the processing of arbitrarily large images under tight hardware constraints. IPS achieves this by selecting only the most salient patches, which are then aggregated into a global representation for image recognition. For both patch selection and aggregation, a cross-attention based transformer is introduced, which exhibits a close connection to Multiple Instance Learning. Our method demonstrates strong performance and has wide applicability across different domains, training regimes and image sizes while using minimal accelerator memory. For example, we are able to finetune our model on whole-slide images consisting of up to 250k patches (>16 gigapixels) with only 5 GB of GPU VRAM at a batch size of 16.

1. INTRODUCTION

Image recognition has made great strides in recent years, spawning landmark architectures such as AlexNet (Krizhevsky et al., 2012) or ResNet (He et al., 2016). These networks are typically designed and optimized for datasets like ImageNet (Russakovsky et al., 2015), which consist of natural images well below one megapixel.¹ In contrast, real-world applications often rely on high-resolution images that reveal detailed information about an object of interest. For example, in self-driving cars, megapixel images are beneficial to recognize distant traffic signs far in advance and react in time (Sahin, 2019). In medical imaging, a pathology diagnosis system has to process gigapixel microscope slides to recognize cancer cells, as illustrated in Fig. 1.

Training neural networks on high-resolution images is challenging and can lead to out-of-memory errors even on dedicated high-performance hardware. Although downsizing the image can fix this problem, details critical for recognition may be lost in the process (Sabottke & Spieler, 2020; Katharopoulos & Fleuret, 2019). Reducing the batch size is another common approach to decrease memory usage, but it does not scale to arbitrarily large inputs and may lead to instabilities in networks involving batch normalization (Lian & Liu, 2019). On the other hand, distributed learning across multiple devices increases resources but is more costly and incurs higher energy consumption (Strubell et al., 2019).

We propose Iterative Patch Selection (IPS), a simple patch-based approach that decouples the consumed memory from the input size and thus enables the efficient processing of high-resolution images without running out of memory. IPS works in two steps: First, the most salient patches of an image are identified in no-gradient mode. Then, only the selected patches are aggregated to train the network.
We find that the attention scores of a cross-attention based transformer link both of these steps, and have a close connection to Multiple Instance Learning (MIL). In the experiments, we demonstrate strong performance across three very different domains and training regimes: traffic sign recognition on megapixel images, multi-task classification on synthetic megapixel MNIST digits, and using self-supervised pre-training together with our method for memory-efficient learning on the gigapixel CAMELYON16 benchmark. Furthermore, our method exhibits a significantly lower memory consumption compared to various baselines. For example, when scaling megapixel MNIST images from 1k to 10k pixels per side at a batch size of 16, we can keep peak memory usage at a constant 1.7 GB while maintaining high accuracy, in contrast to a comparable CNN, which already consumes 24.6 GB at a resolution of 2k×2k. In an ablation study, we further analyze and provide insights into the key factors driving computational efficiency in IPS. Finally, we visualize exemplary attention distributions and present an approach to obtain patch-level class probabilities in a weakly-supervised multi-label classification setting.

2. METHODS

We regard an image as a set of N patches. Each patch is embedded independently by a shared encoder network, resulting in D-dimensional representations, X ∈ R^{N×D}. Given the embeddings, we select the most salient patches and aggregate the information across these patches for the classification task. Thus our method, illustrated in Fig. 2, consists of two consecutively executed modules: an iterative patch selection module that selects a fixed number of patches, and a transformer-based patch aggregation module that combines patch embeddings into a global image embedding, which is passed on to a classification head. Crucially, the patch aggregation module contains a cross-attention layer that is used by the patch selection module in no-gradient mode to score patches. We discuss these modules in detail next and provide code at https://github.com/benbergner/ips.
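The shared role of the cross-attention layer, pooling patch embeddings into a global representation while exposing per-patch attention scores, can be sketched as follows. This is a minimal single-head illustration with hypothetical names, not the paper's exact multi-head transformer module:

```python
import torch
import torch.nn as nn


class CrossAttentionPooling(nn.Module):
    """Minimal sketch: a learnable query cross-attends to patch embeddings.

    Returns a global image embedding (for the classification head) and
    per-patch attention scores (reused by IPS to rank patches).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim))  # learnable query token
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor):
        # x: (B, N, D) patch embeddings
        k, v = self.to_k(x), self.to_v(x)
        attn = (self.query @ k.transpose(-2, -1)) * self.scale  # (B, 1, N)
        attn = attn.softmax(dim=-1)                             # attention over patches
        pooled = (attn @ v).squeeze(1)                          # (B, D) global embedding
        return pooled, attn.squeeze(1)                          # embedding + patch scores
```

Because the same scores drive both selection and aggregation, the module needs to be evaluated only once per set of candidate patches.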

2.1. ITERATIVE PATCH SELECTION

Given infinite memory, one could use an attention module to score each patch and select the top M patches for aggregation. However, due to limited GPU memory, one cannot compute and store all patch embeddings at the same time. We instead propose to iterate over patches, I at a time, and autoregressively maintain a buffer of the top M patch embeddings. In other words, let P_M^t be the buffer of M patch embeddings at time step t and P_I^{t+1} be the next I patch embeddings in the autoregressive update step. We run the following for T iterations:

P_M^{t+1} = Top-M{P_M^t ∪ P_I^{t+1} | a^{t+1}},

where T = ⌈(N − M)/I⌉, P_M^0 = {X_1, . . . , X_M} is the initial buffer of embeddings X_{1,...,M}, and a^{t+1} ∈ R^{M+I} are the attention scores of the patches considered at iteration t + 1, based on which the Top-M selection is made. These attention scores are obtained from the cross-attention transformer as described in Sect. 2.2. The output of IPS is the set of M patches corresponding to P_M^T. Note that both patch embedding and patch selection are executed in no-gradient and evaluation mode. The former entails that no gradients are computed and stored, which renders IPS runtime- and memory-efficient. The latter ensures deterministic patch selection behavior when using BatchNorm and Dropout.

Data loading. We introduce three data loading strategies that trade off memory and runtime efficiency during IPS. In eager loading, a batch of images is loaded onto the GPU and IPS is applied to each image in parallel; this is the fastest variant but requires storing multiple images at once. In eager sequential loading, individual images are loaded onto the GPU, and patches are selected for one image at a time until a batch of M patches per image has been selected for training. This enables the processing of different sequence lengths without padding and reduces memory usage at the cost of a higher runtime. In contrast, lazy loading loads a batch of images into CPU memory; then, only the patches and corresponding embeddings pertinent to the current iteration are stored on the GPU. This decouples GPU memory usage from the image size, again at the cost of a higher runtime.
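The selection loop above can be sketched in a few lines of PyTorch. This is an illustrative single-image version with hypothetical names (`encoder`, `score_fn`); it assumes N ≥ M and that `score_fn` maps a (1, K, D) batch of embeddings to K attention scores, e.g. from the cross-attention layer:

```python
import torch


@torch.no_grad()  # no-gradient mode: nothing is stored for backprop
def iterative_patch_selection(patches, encoder, score_fn, M, I):
    """Maintain a buffer of the top-M patch embeddings, iterating I patches at a time.

    patches:  (N, ...) tensor of image patches, N >= M
    encoder:  shared patch encoder mapping (K, ...) patches to (K, D) embeddings
    score_fn: maps (1, K, D) embeddings to (K,) attention scores
    Returns the indices of the M most salient patches.
    """
    encoder.eval()  # evaluation mode: deterministic BatchNorm/Dropout
    N = patches.shape[0]
    buf_emb = encoder(patches[:M])  # initial buffer P_M^0
    buf_idx = torch.arange(M)
    for start in range(M, N, I):  # T = ceil((N - M) / I) iterations
        new_emb = encoder(patches[start:start + I])
        cand_emb = torch.cat([buf_emb, new_emb], dim=0)       # P_M^t ∪ P_I^{t+1}
        cand_idx = torch.cat([buf_idx, torch.arange(start, min(start + I, N))])
        scores = score_fn(cand_emb.unsqueeze(0))              # a^{t+1}, (M + I,) scores
        top = scores.topk(M).indices                          # Top-M selection
        buf_emb, buf_idx = cand_emb[top], cand_idx[top]
    return buf_idx
```

Under lazy loading, only the current slice `patches[start:start + I]` would reside on the GPU, which is what decouples memory usage from N.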



¹ For instance, a 256×256 image corresponds to only 0.06 megapixels.



Figure 1: IPS selects salient patches for high-resolution image recognition. In the top-right inset, green contours represent ground-truth cancerous cells and white overlays indicate high-scoring patches from our model. Example patches are shown in the bottom right.

