A FRAMEWORK FOR LEARNED COUNTSKETCH

Abstract

Sketching is a compression technique that can be applied to many problems to solve them quickly and approximately. The matrices used to project data to smaller dimensions are called "sketches". In this work, we consider the problem of optimizing sketches to obtain low approximation error over a data distribution. We introduce a general framework for "learning" and applying CountSketch, a type of sparse sketch. The sketch optimization procedure has two stages: one for optimizing the placements of the sketch's non-zero entries and another for optimizing their values. Next, we provide a way to apply learned sketches that has worst-case guarantees for approximation error. We instantiate this framework with three sketching applications: least-squares regression, low-rank approximation (LRA), and k-means clustering. Our experiments demonstrate that our approach substantially decreases approximation error compared to classical and naïvely learned sketches. Finally, we investigate the theoretical aspects of our approach. For regression and LRA, we show that our method obtains state-of-the-art accuracy for fixed time complexity. For LRA, we prove that it is strictly better to include the first optimization stage for two standard input distributions. For k-means, we derive a more straightforward means of retaining approximation guarantees.

1. INTRODUCTION

In recent years, we have seen the influence of machine learning extend far beyond the field of artificial intelligence. The underlying paradigm, which assumes that a given algorithm has an input distribution over which its parameters can be optimized, has even been applied to classical algorithms. Examples of classical problems that have benefited from ML include cache eviction strategies, online algorithms for job scheduling, frequency estimation of data stream elements, and indexing strategies for data structures (Lykouris & Vassilvitskii, 2018; Purohit et al., 2018; Hsu et al., 2019; Kraska et al., 2018). This input distribution assumption is often realistic. For example, many real-world applications use data streaming to track quantities like product purchasing statistics in real time. Consecutively streamed datapoints are usually tightly correlated and closely fit certain distributions.

We are interested in how this distributional paradigm can be applied to sketching, a data compression technique. With the dramatic increase in the dimensions of data collected over the past decade, compression methods are more important than ever, so it is of practical interest to improve the accuracy and efficiency of sketching algorithms. We study a sketching scheme in which the input matrix is compressed by multiplying it with a small "sketch" matrix. This smaller, sketched input is then used to compute an approximate solution. Typically, the sketch matrix and the approximation algorithm are designed to satisfy worst-case bounds on approximation error for arbitrary inputs. With the ML perspective in mind, we examine whether it is possible to construct sketches that also have low error in expectation over an input distribution. Essentially, we aim for the best of both worlds: good performance in practice with theoretical worst-case guarantees. Further, we are interested in methods that work for multiple sketching applications.
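To make the setting concrete, the classical sketch-and-solve recipe for least-squares regression with a CountSketch might look as follows (a minimal illustration with arbitrary dimensions, not the paper's code; the sketch is formed densely only for simplicity):

```python
import numpy as np

rng = np.random.default_rng(0)

def countsketch(m, n, rng):
    """Classical CountSketch: each column has a single +/-1 entry in a
    uniformly random row.  Applied via sparse multiplication, S @ A
    costs only O(nnz(A)) time; we build S densely here for clarity."""
    S = np.zeros((m, n))
    S[rng.integers(0, m, size=n), np.arange(n)] = rng.choice([-1.0, 1.0], size=n)
    return S

# Sketch-and-solve least squares: replace min_x ||Ax - b|| by the much
# smaller problem min_x ||(SA)x - Sb||, where SA has m rows instead of n.
n, d, m = 2000, 10, 300
A = rng.standard_normal((n, d))
b = rng.standard_normal(n)

S = countsketch(m, n, rng)
x_sketch, *_ = np.linalg.lstsq(S @ A, S @ b, rcond=None)
x_exact, *_ = np.linalg.lstsq(A, b, rcond=None)

# The sketched solution's residual is close to optimal (with constant probability)
err_sketch = np.linalg.norm(A @ x_sketch - b)
err_exact = np.linalg.norm(A @ x_exact - b)
```

Here the sketch is drawn randomly, independent of the data; the question studied in this work is whether its entries can instead be optimized for a given input distribution.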
Typically, sketching is very application-specific: the sketch construction and approximation algorithm are tailored to individual applications, like robust regression or clustering (Sarlos, 2006; Clarkson & Woodruff, 2009; 2014; 2017; Cohen et al., 2015; Makarychev et al., 2019). Instead, we consider an approach that applies across applications.

Our results. At a high level, our work aims to make sketch learning more effective, more general, and ultimately more practical. We propose a framework for constructing and using learned CountSketch. We chose CountSketch because it is a sparse, input-independent sketch (Charikar et al., 2002). Specifically, it has one non-zero entry (±1) per column and does not need to be constructed anew for each input matrix it is applied to. These qualities enable CountSketch to be applied quickly, since sparse matrix multiplication is fast and the same CountSketch can be reused for different inputs. Our "learned" CountSketch retains this characteristic sparsity pattern and input-independence¹, but its non-zero entries may take arbitrary real values. We list our main contributions and follow this with a discussion.

• Two-stage sketch optimization: first place the non-zero entries, then learn their values.
• Theoretical worst-case guarantees, two ways: we derive a time-optimal method that applies to MRR, LRA, k-means, and more; we also prove that a simpler method works for k-means.
• State-of-the-art experimental results: we show the versatility of our method on five data sets spanning three data types; our method dominates on the majority of experiments.
• Theoretical analysis of the necessity of two stages: we prove that including the first stage is strictly better for LRA under two common input distributions.
• Empirical demonstration of the necessity of two stages: including the first stage gives a 12% (MRR) and 20% (LRA) boost.

Our sketch learning algorithm first places the sparse non-zero entries using a greedy strategy, and then learns their values using gradient descent.
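The greedy placement and gradient training themselves are beyond a short snippet, but the object being learned is simple: a sketch with exactly one non-zero per column, whose positions (stage one) and real values (stage two) are optimized. A minimal sketch of this representation and its fast application follows (the helper name and the random placeholder values are our own, not the paper's):

```python
import numpy as np

def apply_countsketch(rows, vals, m, A):
    """Apply a CountSketch S with one non-zero per column,
    S[rows[j], j] = vals[j], in O(nnz(A)) time without forming S.
    A classical sketch has vals in {-1, +1}; a learned sketch keeps
    the same sparsity pattern but allows arbitrary real values."""
    SA = np.zeros((m, A.shape[1]))
    np.add.at(SA, rows, vals[:, None] * A)  # row j of A is scaled by vals[j], added to row rows[j]
    return SA

rng = np.random.default_rng(0)
n, d, m = 1000, 8, 50
A = rng.standard_normal((n, d))
rows = rng.integers(0, m, size=n)  # stage one would choose these greedily
vals = rng.standard_normal(n)      # stage two would learn these by gradient descent

SA = apply_countsketch(rows, vals, m, A)

# Sanity check against the dense representation of the same sketch
S_dense = np.zeros((m, n))
S_dense[rows, np.arange(n)] = vals
```

Because only the (rows, vals) pair is stored, a learned CountSketch is exactly as cheap to apply and store as a classical one.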
The resulting learned CountSketch is very different from the classical CountSketch: the non-zero entries no longer have random positions and ±1 values. As a result, the usual worst-case guarantees do not hold. We sought a way to recover worst-case guarantees that is fast and reasonably general. Our solution is a fast comparison step which performs an approximate evaluation of the learned and classical sketches and takes the better of the two. Importantly, we can run this step before the approximation algorithm without increasing its overall time complexity. As such, this solution is time-optimal and applies to MRR, LRA, k-means, and more. An alternative method was proposed in previous work, but it was only proved for LRA (Indyk et al., 2019). This "sketch concatenation" method simply sketches with the concatenation of a learned and a classical sketch. Since it is somewhat simpler, we wanted to extend its applicability; in a novel theoretical result, we prove that it works for k-means as well.

We also ran a diverse set of experiments to demonstrate the versatility and practicality of our approach. We chose five data sets spanning three categories (image, text, and graph) to test our method on three applications (MRR, LRA, k-means). Importantly, these experiments have real-world counterparts. For example, LRA and k-means can be used to compress images, applying SVD (LRA) to text data is the basis of a natural language processing technique, and LRA can be used to compute approximate max cuts on graph adjacency matrices. Ultimately, our method dominated on the vast majority of tests, giving a 31% (MRR) and 70% (LRA) improvement over classical CountSketch. Finally, we conducted an ablation study of the components of our algorithm. In another novel theoretical result, we prove that including the time-consuming first optimization stage is strictly better than omitting it for LRA under two input distributions (spiked covariance and Zipfian).
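The comparison step can be illustrated for regression as follows (an illustrative sketch with our own function names, not the paper's implementation; for regression the objective is cheap to evaluate directly, while for problems like LRA the paper's comparison evaluates it approximately):

```python
import numpy as np

rng = np.random.default_rng(1)

def countsketch(m, n, rng):
    """Classical CountSketch: one random +/-1 entry per column."""
    S = np.zeros((m, n))
    S[rng.integers(0, m, size=n), np.arange(n)] = rng.choice([-1.0, 1.0], size=n)
    return S

def solve_with_sketch(S, A, b):
    x, *_ = np.linalg.lstsq(S @ A, S @ b, rcond=None)
    return x

def best_of_two(S_learned, S_classic, A, b):
    """Comparison step: solve the sketched problem with each sketch and
    keep the candidate with the smaller objective value.  For regression
    the residual costs O(nnz(A)) to evaluate, so the comparison does not
    change the overall time complexity, and the result is never worse
    than the classical sketch's guarantee."""
    x_l = solve_with_sketch(S_learned, A, b)
    x_c = solve_with_sketch(S_classic, A, b)
    err_l = np.linalg.norm(A @ x_l - b)
    err_c = np.linalg.norm(A @ x_c - b)
    return x_l if err_l <= err_c else x_c

# Demo: even with a useless "learned" sketch (all zeros), the comparison
# step falls back to the classical sketch's solution.
n, d, m = 2000, 10, 300
A = rng.standard_normal((n, d))
b = A @ rng.standard_normal(d) + 0.01 * rng.standard_normal(n)
x = best_of_two(np.zeros((m, n)), countsketch(m, n, rng), A, b)
```

On in-distribution inputs the learned sketch typically wins the comparison; on adversarial inputs the classical one does, which is what restores the worst-case guarantee.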
Empirically, this is the case for all three applications.

Related work. In the last few years, there has been much work on leveraging ML to improve classical algorithms; we only mention a few examples here. One related body of work is data-dependent



¹ While learned CountSketch is data-dependent (it is optimized using sample input matrices), it is still considered input-independent because it is applied, unchanged, to unseen input matrices (test samples).

