LEARNING THE POSITIONS IN COUNTSKETCH

Abstract

We consider sketching algorithms which first compress data by multiplication with a random sketch matrix, and then apply the sketch to quickly solve an optimization problem, e.g., low-rank approximation or regression. In the learning-based sketching paradigm proposed by Indyk, Vakilian, and Yuan (2019), the sketch matrix is found by choosing a random sparse matrix, e.g., CountSketch, and then updating the values of its non-zero entries by running gradient descent on a training data set. Despite the growing body of work on this paradigm, a noticeable omission is that in all previous algorithms the locations of the non-zero entries were held fixed, and only their values were learned. In this work, we propose the first learning-based algorithms that also optimize the locations of the non-zero entries. Our first algorithm is based on a greedy search. However, one drawback of the greedy algorithm is its slow training time. We fix this issue and propose approaches for learning a sketching matrix for both low-rank approximation and Hessian approximation for second-order optimization. The latter is helpful for a range of constrained optimization problems, such as LASSO and matrix estimation with a nuclear norm constraint. Both approaches achieve good accuracy with a fast running time. Moreover, our experiments suggest that our algorithms can still reduce the error significantly even when only a very limited number of training matrices is available.

1. INTRODUCTION

The work of Indyk et al. (2019) investigated learning-based sketching algorithms for low-rank approximation. A sketching algorithm is a method of constructing approximate solutions for optimization problems by summarizing the data. In particular, linear sketching algorithms compress data by multiplication with a sparse "sketch matrix" and then use just the compressed data to find an approximate solution. Generally, this technique results in much faster or more space-efficient algorithms for a fixed approximation error. The pioneering work of Indyk et al. (2019) shows it is possible to learn sketch matrices for low-rank approximation (LRA) with better average performance than classical sketches. In this model, we assume inputs come from an unknown distribution and learn a sketch matrix with strong expected performance over the distribution.

This distributional assumption is often realistic: there are many situations where a sketching algorithm is applied to a large batch of related data. For example, genomics researchers might sketch DNA from different individuals, which is known to exhibit strong commonalities. The high-performance computing industry also uses sketching; e.g., researchers at NVIDIA have created standard implementations of sketching algorithms for CUDA, a widely used GPU library. They investigated the (classical) sketched singular value decomposition (SVD), but found that the solutions were not accurate enough across a spectrum of inputs (Chien & Bernabeu, 2019). This is precisely the issue addressed by the learned sketch paradigm, where we optimize for "good" average performance across a range of inputs.

While promising results have been shown using previous learned sketching techniques, notable gaps remain. In particular, all previous methods work by initializing the sketching matrix with a random sparse matrix, e.g., each column of the sketching matrix has a single non-zero value placed at a uniformly random position. The values of the non-zero entries are then updated by running gradient descent on a training data set, or via other methods, but the locations of the non-zero entries are held fixed throughout the entire training process. Clearly this is sub-optimal. Indeed, suppose the input matrix $A$ is an $n \times d$ matrix whose first $d$ rows equal the $d \times d$ identity matrix and whose remaining rows equal $0$. A random sketching matrix $S$ with a single non-zero entry per column is known to require $m = \Omega(d^2)$ rows in order for $SA$ to preserve the rank of $A$ (Nelson & Nguyên, 2014); this follows by a birthday paradox argument. On the other hand, it is clear that if $S$ is a $d \times n$ matrix whose first $d$ rows equal the identity matrix, then $\|SAx\|_2 = \|Ax\|_2$ for all vectors $x$, and so $S$ preserves not only the rank of $A$ but all of its important spectral properties. A random matrix would be very unlikely to choose the non-zero entries in the first $d$ columns of $S$ so perfectly, whereas an algorithm trained to optimize the locations of the non-zero entries would notice and correct for this, as illustrated below. This is precisely the gap in our understanding that we seek to fill.
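To make this example concrete, the following is a small numpy illustration of our own (not an implementation from the paper); the helper name `countsketch` and the specific dimensions are ours. It shows that a CountSketch with randomly placed non-zeros typically loses rank on the identity-block input, while hand-placed positions preserve norms exactly.

```python
import numpy as np

rng = np.random.default_rng(0)

def countsketch(m, n, rng):
    """Classical CountSketch: one non-zero (+/-1) per column, at a random row."""
    S = np.zeros((m, n))
    rows = rng.integers(0, m, size=n)       # random positions (the part we learn)
    signs = rng.choice([-1.0, 1.0], size=n)
    S[rows, np.arange(n)] = signs
    return S

n, d = 1000, 30
A = np.zeros((n, d))
A[:d, :d] = np.eye(d)                       # first d rows = identity, rest zero

# Random positions: with m only moderately larger than d, two of the first d
# columns of S often land in the same row (birthday paradox), so S @ A loses rank.
S_rand = countsketch(2 * d, n, rng)
print("rank of S A (random positions):", np.linalg.matrix_rank(S_rand @ A))  # typically < d

# Hand-placed positions: the non-zeros of the first d columns occupy distinct
# rows, so the leading d x d block of S is the identity and ||S A x|| = ||A x||.
S_opt = np.zeros((2 * d, n))
S_opt[:d, :d] = np.eye(d)
x = rng.standard_normal(d)
print(np.allclose(np.linalg.norm(S_opt @ A @ x), np.linalg.norm(A @ x)))     # True
```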
Learned CountSketch Paradigm of Indyk et al. (2019). Throughout the paper, we assume our data $A \in \mathbb{R}^{n \times d}$ is sampled from an unknown distribution $\mathcal{D}$. Specifically, we have a training set $\mathrm{Tr} = \{A_1, \ldots, A_N\}$ drawn from $\mathcal{D}$. The generic form of our optimization problems is $\min_X f(A, X)$, where $A \in \mathbb{R}^{n \times d}$ is the input matrix.

For a given optimization problem and a set $\mathcal{S}$ of sketching matrices, define $\mathrm{ALG}(\mathcal{S}, A)$ to be the output of the classical sketching algorithm resulting from using $\mathcal{S}$; this uses the sketching matrices in $\mathcal{S}$ to map the given input $A$ and construct an approximate solution $X$. We remark that the number of sketches used by an algorithm can vary: in the simplest case, $\mathcal{S}$ is a single sketch, but more complicated sketching approaches may apply sketching more than once, so $\mathcal{S}$ may also denote a set of more than one sketching matrix.

The learned sketch framework has two parts: (1) offline sketch learning and (2) "online" sketching (i.e., applying the learned sketch and some sketching algorithm to possibly unseen data). In offline sketch learning, the goal is to construct a CountSketch matrix (abbreviated as CS matrix) with the minimum expected error for the problem of interest. Formally, that is,

$$\arg\min_{\text{CS } S} \; \mathbb{E}_{A \in \mathrm{Tr}} \left[ f(A, \mathrm{ALG}(S, A)) - f(A, X^*) \right] \;=\; \arg\min_{\text{CS } S} \; \mathbb{E}_{A \in \mathrm{Tr}} \, f(A, \mathrm{ALG}(S, A)),$$

where $X^*$ denotes the optimal solution and the minimum is taken over all possible constructions of a CS matrix. We remark that when ALG needs more than one learned CS matrix (e.g., in the sketching algorithm we consider for LRA), we optimize each CS matrix independently using a surrogate loss function.

In the second part of the learned sketch paradigm, we take the sketch from part one and use it within a sketching algorithm. This learned sketch and sketching algorithm can be applied, again and again, to different inputs. Finally, we augment the sketching algorithm to provide worst-case guarantees when used with learned sketches: the goal is to have good performance on $A \in \mathcal{D}$ while the worst-case performance on $A \notin \mathcal{D}$ remains comparable to the guarantees of classical sketches. We remark that the learned matrix $S$ is trained offline only once using the training data; hence, no additional computational cost is incurred when solving the optimization problem on the test data.

Our Results. In this work, in addition to learning the values of the non-zero entries, we learn the locations of the non-zero entries. Namely, we propose three algorithms that learn the locations of the non-zero entries in CountSketch. Our first algorithm (Section 4) is based on a greedy search, and our empirical results show that this approach achieves good performance. Further, we show that the greedy algorithm is provably beneficial for LRA when inputs follow a certain input distribution (Section F). However, one drawback of the greedy algorithm is its much slower training time. We fix this issue and propose two specific approaches for optimizing the positions of the sketches for low-rank approximation and second-order optimization, which run much faster than all previous algorithms while achieving better performance. For low-rank approximation, our approach first samples a small set of rows according to their ridge leverage scores, assigns each sampled row to a unique hash bucket, and then places each remaining row in the hash bucket of the sampled row to which it is most similar, i.e., with which it has the largest dot product. We also show that the worst-case guarantee of this approach is strictly better than that of the classical CountSketch (see Section 5).
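The following is a minimal sketch of the position-learning heuristic just described, under our own assumptions rather than the paper's exact implementation: we illustrate it on a single training matrix, fix the ridge parameter `lam` arbitrarily, and initialize all non-zero values to 1 (in the full pipeline these values would subsequently be trained by gradient descent over the training set). The helper names `ridge_leverage_scores` and `learn_positions` are hypothetical.

```python
import numpy as np

def ridge_leverage_scores(A, lam):
    """tau_i = a_i^T (A^T A + lam I)^{-1} a_i for each row a_i of A."""
    G = A.T @ A + lam * np.eye(A.shape[1])
    return np.einsum('ij,jk,ik->i', A, np.linalg.inv(G), A)

def learn_positions(A, m, rng, lam=1.0):
    """Build an m x n CountSketch whose non-zero positions follow the heuristic:
    sample m rows by ridge leverage scores, give each its own bucket, then route
    every remaining row to the bucket of the sampled row it is most similar to."""
    n = A.shape[0]
    tau = ridge_leverage_scores(A, lam)
    p = tau / tau.sum()
    sampled = rng.choice(n, size=m, replace=False, p=p)   # m "anchor" rows
    bucket = np.empty(n, dtype=int)
    bucket[sampled] = np.arange(m)                        # unique bucket per anchor
    rest = np.setdiff1d(np.arange(n), sampled)
    sims = A[rest] @ A[sampled].T                         # dot products with anchors
    bucket[rest] = np.argmax(sims, axis=1)                # most similar anchor wins
    S = np.zeros((m, n))
    S[bucket, np.arange(n)] = 1.0                         # values learned afterwards
    return S

# Usage: positions from one training matrix, e.g. S = learn_positions(A, 64, rng).
```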

