SAMPLED TRANSFORMER FOR POINT SETS

Abstract

The sparse transformer can reduce the computational complexity of the self-attention layers to O(n), whilst still being a universal approximator of continuous sequence-to-sequence functions. However, this permutation-variant operation is not appropriate for direct application to sets. In this paper, we propose an O(n) complexity sampled transformer that can process point set elements directly without any additional inductive bias. Our sampled transformer introduces random element sampling, which randomly splits point sets into subsets, followed by applying a shared Hamiltonian self-attention mechanism to each subset. The overall attention mechanism can be viewed as a Hamiltonian cycle in the complete attention graph, and the permutation of point set elements is equivalent to randomly sampling Hamiltonian cycles. This mechanism implements a Monte Carlo simulation of the O(n^2) dense attention connections. We show that it is a universal approximator for continuous set-to-set functions. Experimental results for classification and few-shot learning on point clouds show comparable or better accuracy with significantly reduced computational complexity compared to the dense transformer or alternative sparse attention schemes.

1. INTRODUCTION

Encoding structured data has become a focal point of modern machine learning. In recent years, the de facto choice has been to use transformer architectures for sequence data, e.g., in language (Vaswani et al., 2017) and image (Dosovitskiy et al., 2020) processing pipelines. Indeed, transformers have not only shown strong empirical results, but have also been proven to be universal approximators for sequence-to-sequence functions (Yun et al., 2019). Although the standard transformer is a natural choice for set data, with permutation-invariant dense attention, its versatility is limited by the costly O(n^2) computational complexity. To decrease the cost, a common trick is to use sparse attention, reducing the complexity from O(n^2) to O(n) (Yun et al., 2020; Zaheer et al., 2020; Guo et al., 2019). However, in general this results in an attention mechanism that is not permutation invariant: swapping two set elements changes which elements they attend to. As a result, sparse attention cannot be directly used for set data. Recent work has explored the representation power of transformers on point sets as a plug-in module (Lee et al., 2019), as a pretraining-finetuning pipeline (Yu et al., 2022; Pang et al., 2022), and with a hierarchical structure (Zhao et al., 2021). However, these set transformers introduced additional inductive biases to (theoretically) approach the same performance as the densely connected case in language and image processing applications. For example, to achieve permutation invariance with efficient computational complexity, previous work has required nearest neighbor search (Zhao et al., 2021) or inducing point sampling (Lee et al., 2019). This motivates a natural research question: can O(n) complexity sparse attention mechanisms be applied directly to sets, without introducing unneeded inductive bias?
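The core idea sketched in the abstract can be illustrated concretely: randomly permute the point set, split it into equal subsets, and restrict attention within each subset to a Hamiltonian cycle, so the total number of attention links grows linearly in n. The following is a minimal illustrative sketch, not the paper's implementation; the function names and the choice of NumPy are our own, and for simplicity we assume the subset size n_s divides n.

```python
import numpy as np

def random_element_sampling(points, n_s, rng):
    """Randomly permute a point set and split it into subsets of n_s elements.

    Shuffling the elements before a fixed sparse pattern corresponds to
    sampling a random Hamiltonian cycle in the complete attention graph.
    """
    n = points.shape[0]
    assert n % n_s == 0, "illustrative assumption: n_s divides n"
    perm = rng.permutation(n)
    return points[perm].reshape(n // n_s, n_s, -1)

def hamiltonian_attention_mask(n_s):
    """Binary mask in which element i attends to itself and to its cyclic
    successor (i+1) mod n_s, forming a Hamiltonian cycle over the subset."""
    mask = np.eye(n_s) + np.roll(np.eye(n_s), 1, axis=1)
    return (mask > 0).astype(float)

rng = np.random.default_rng(0)
pts = rng.standard_normal((8, 3))          # n = 8 points in R^3
subsets = random_element_sampling(pts, n_s=4, rng=rng)
mask = hamiltonian_attention_mask(4)
print(subsets.shape)    # (2, 4, 3): two subsets of 4 points each
print(int(mask.sum()))  # 8 nonzero links in total per subset pair of masks
```

Each draw of the permutation realizes a different sparse attention pattern on the original set; averaging over draws is what makes the scheme a Monte Carlo approximation of dense attention.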
We propose the sampled transformer to address this question, which is distinguished from the original sparse transformer by mapping the permutation of set elements to the permutation of attention matrix elements. Viewing this permutation sampling as attention matrix sampling, the proposed sampled attention approximates O(n^2) dense attention. This is achieved with the proposed random element sampling and Hamiltonian self-attention. To be specific, in random element sampling the input point set is first randomly split into several subsets of n_s points (Fig. 1b), each of which will be processed by shared self-attention layers. In addition, a sparse attention mechanism, namely

