SAMPLED TRANSFORMER FOR POINT SETS

Abstract

The sparse transformer can reduce the computational complexity of the self-attention layers to O(n), whilst still being a universal approximator of continuous sequence-to-sequence functions. However, this permutation-variant operation is not appropriate for direct application to sets. In this paper, we propose an O(n) complexity sampled transformer that can process point set elements directly without any additional inductive bias. Our sampled transformer introduces random element sampling, which randomly splits point sets into subsets, followed by applying a shared Hamiltonian self-attention mechanism to each subset. The overall attention mechanism can be viewed as a Hamiltonian cycle in the complete attention graph, and the permutation of point set elements is equivalent to randomly sampling Hamiltonian cycles. This mechanism implements a Monte Carlo simulation of the O(n^2) dense attention connections. We show that it is a universal approximator for continuous set-to-set functions. Experimental results for classification and few-shot learning on point clouds show comparable or better accuracy with significantly reduced computational complexity compared to the dense transformer or alternative sparse attention schemes.

1. INTRODUCTION

Encoding structured data has become a focal point of modern machine learning. In recent years, the de facto choice has been to use transformer architectures for sequence data, e.g., in language (Vaswani et al., 2017) and image (Dosovitskiy et al., 2020) processing pipelines. Indeed, transformers have not only shown strong empirical results, but have also been proven to be universal approximators of sequence-to-sequence functions (Yun et al., 2019). Although the standard transformer is a natural choice for set data, with permutation-invariant dense attention, its versatility is limited by the costly O(n^2) computational complexity. To decrease the cost, a common trick is to use sparse attention, reducing the complexity from O(n^2) to O(n) (Yun et al., 2020; Zaheer et al., 2020; Guo et al., 2019). However, in general this results in an attention mechanism that is not permutation invariant: swapping two set elements changes which elements they attend to. As a result, sparse attention cannot be directly used for set data. Recent work has explored the representation power of transformers on point sets as a plug-in module (Lee et al., 2019), in a pretraining-finetuning pipeline (Yu et al., 2022; Pang et al., 2022), and with a hierarchical structure (Zhao et al., 2021). However, these set transformers introduced additional inductive biases to (theoretically) approach the same performance as the densely connected case in language and image processing applications. For example, to achieve permutation invariance with efficient computational complexity, previous work has required nearest neighbor search (Zhao et al., 2021) or inducing point sampling (Lee et al., 2019). This analysis naturally raises a research question about avoiding unneeded inductive bias: Can O(n) complexity sparse attention mechanisms be applied directly to sets?
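The failure of a fixed sparse pattern on set data can be made concrete with a toy numerical check. The sketch below uses a simplified uniform-weight "attention" in which each element attends to itself and a fixed successor; the function name and setup are illustrative and not part of the paper. Because the sparsity pattern is tied to element positions, permuting the inputs does not merely permute the outputs:

```python
import numpy as np

def sparse_avg_attention(X):
    """Toy sparse 'attention': each element attends, with uniform weights,
    only to itself and the next element under a FIXED cyclic pattern."""
    n = X.shape[0]
    out = np.empty_like(X)
    for i in range(n):
        out[i] = 0.5 * (X[i] + X[(i + 1) % n])
    return out

X = np.array([[1.0], [2.0], [4.0]])
perm = np.array([1, 0, 2])              # swap the first two elements
lhs = sparse_avg_attention(X[perm])     # permute the set, then attend
rhs = sparse_avg_attention(X)[perm]     # attend, then permute the outputs
assert not np.allclose(lhs, rhs)        # fixed pattern: not permutation equivariant
```

Dense attention passes this check for any permutation, which is why it (unlike fixed sparse patterns) can be applied to sets directly.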
We propose the sampled transformer to address this question, which is distinguished from the original sparse transformer by mapping the permutation of set elements to the permutation of attention matrix elements. Viewing this permutation sampling as attention matrix sampling, the proposed sampled attention approximates O(n^2) dense attention. This is achieved with the proposed random element sampling and Hamiltonian self-attention. To be specific, in random element sampling the input point set is first randomly split into several subsets of n_s points (Fig. 1b), each of which is processed by shared self-attention layers. In addition, a sparse attention mechanism, namely Hamiltonian self-attention (Fig. 1c), is applied to reduce the complexity of the subset inputs, so that n_s point connections are sampled from the O(n_s^2) possible connections. The combination of the Hamiltonian self-attention mechanisms of all subsets, namely cycle attention (Fig. 1d), can be viewed as a Hamiltonian cycle in the complete attention graph. As a result, the permutation of set elements is equivalent to the permutation of nodes in a Hamiltonian cycle (Fig. 1e), which is in fact randomly sampling Hamiltonian cycles from the complete graph, thereby yielding the proposed sampled attention (Fig. 1f). Finally, viewing this randomization as a Monte Carlo sample of attention pairs, repeated sampling can be used to approximate the complete O(n^2) dense connections. The contributions of this paper are summarized as follows.

• We propose the sampled attention mechanism, which maps the random permutation of set elements to the random sampling of Hamiltonian cycle attention matrices, permitting the direct processing of point sets.

• We prove that the proposed sampled transformer is a universal approximator of continuous set-to-set functions; see Corollary 1.

• Compared to previous transformer architectures, our empirical results show that the proposed sampled transformer achieves comparable (or better) performance with less inductive bias and lower complexity.
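The cycle-based mechanism described above can be sketched in a few lines of NumPy. This is a simplified single-head illustration, not the paper's reference implementation: the function name, weight matrices, and the choice of letting each point attend to itself and its cycle successor are assumptions made for exposition. A fresh random permutation defines a sampled Hamiltonian cycle, giving O(n) attention pairs per pass; repeating with new permutations is the Monte Carlo approximation of dense attention:

```python
import numpy as np

def hamiltonian_cycle_attention(X, Wq, Wk, Wv, rng):
    """Approximate dense self-attention over a point set X of shape (n, d)
    by attending only along a randomly sampled Hamiltonian cycle: each
    element attends to itself and its successor on the cycle, yielding
    O(n) attention pairs instead of O(n^2)."""
    n, d = X.shape
    perm = rng.permutation(n)                   # a random permutation defines a Hamiltonian cycle
    succ = np.empty(n, dtype=int)
    succ[perm] = perm[(np.arange(n) + 1) % n]   # cycle successor of each element
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    out = np.empty_like(V)
    for i in range(n):
        idx = [i, succ[i]]                      # attend to self and cycle neighbor only
        scores = Q[i] @ K[idx].T / np.sqrt(d)
        w = np.exp(scores - scores.max())       # numerically stable softmax
        w /= w.sum()
        out[i] = w @ V[idx]
    return out
```

Averaging the outputs over many independently sampled cycles approaches, in expectation, a dense attention pattern over all O(n^2) pairs; random element sampling additionally shards the set into subsets of n_s points before applying this routine to each shard.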

2. RELATED WORK

The transformer (Vaswani et al., 2017) is widely used in language (Raffel et al., 2020; Dai et al., 2019; Yang et al., 2019b) and image (Dosovitskiy et al., 2020; Liu et al., 2021; Touvron et al., 2021; Ramachandran et al., 2019) processing. For example, Raffel et al. (2020) explored the transformer by unifying a



Figure 1: Attention mechanisms: (a) original dense attention; (b) the attention matrix after random element sampling; (c) a special case of attention, Hamiltonian (self-)attention, for each subset; (d) combining all subsets (which have overlapping elements per (b)) connects the individual Hamiltonian attention sub-matrices, giving cycle attention, which forms a Hamiltonian cycle; (e) permuting the points permutes the elements of the cycle attention matrix; (f) the resulting sampled attention, viewed as a Hamiltonian cycle sampled from the edges of the complete attention graph.

