A THEORETICAL UNDERSTANDING OF SHALLOW VISION TRANSFORMERS: LEARNING, GENERALIZATION, AND SAMPLE COMPLEXITY

Abstract

Vision Transformers (ViTs) with self-attention modules have recently achieved great empirical success in many vision tasks. Due to the non-convex interactions across layers, however, their theoretical learning and generalization analysis remains mostly elusive. Based on a data model characterizing both label-relevant and label-irrelevant tokens, this paper provides the first theoretical analysis of training a shallow ViT, i.e., one self-attention layer followed by a two-layer perceptron, for a classification task. We characterize the sample complexity required to achieve zero generalization error. Our sample complexity bound is positively correlated with the inverse of the fraction of label-relevant tokens, the token noise level, and the initial model error. We also prove that training with stochastic gradient descent (SGD) leads to a sparse attention map, which formally verifies the general intuition about the success of attention. Moreover, this paper indicates that proper token sparsification can improve test performance by removing label-irrelevant and/or noisy tokens, including spurious correlations. Empirical experiments on synthetic data and the CIFAR-10 dataset justify our theoretical results and generalize to deeper ViTs.
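The shallow architecture studied in this paper, one self-attention layer followed by a two-layer perceptron, can be sketched in a few lines. The following is a minimal illustrative NumPy implementation; the weight shapes, the ReLU nonlinearity, and the mean pooling over token outputs are assumptions for concreteness, not the paper's exact parameterization.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def shallow_vit(X, Wq, Wk, Wv, W1, w2):
    """One self-attention layer followed by a two-layer perceptron.

    X: (n_tokens, d) token embeddings (image patches).
    Returns a scalar score whose sign can serve as the binary label.
    """
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    attn = softmax(Q @ K.T / np.sqrt(K.shape[1]))  # (n_tokens, n_tokens) attention map
    Z = attn @ V                                   # attended token features
    H = np.maximum(Z @ W1, 0.0)                    # hidden layer with ReLU
    return float((H @ w2).mean())                  # pool per-token outputs to one score

rng = np.random.default_rng(0)
n, d, m = 8, 16, 32  # hypothetical sizes for illustration
X = rng.normal(size=(n, d))
out = shallow_vit(X,
                  rng.normal(size=(d, d)), rng.normal(size=(d, d)),
                  rng.normal(size=(d, d)), rng.normal(size=(d, m)),
                  rng.normal(size=(m,)))
print(np.isfinite(out))
```

The sparse attention map the paper proves SGD converges to would show up here as `attn` concentrating its mass on the label-relevant tokens.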

1. INTRODUCTION

As the backbone of Transformers (Vaswani et al., 2017), the self-attention mechanism (Bahdanau et al., 2014) computes the feature representation by globally modeling long-range interactions within the input. Transformers have demonstrated tremendous empirical success in numerous areas, including natural language processing (Kenton & Toutanova, 2019; Radford et al., 2019; 2018; Brown et al., 2020), recommendation systems (Zhou et al., 2018; Chen et al., 2019; Sun et al., 2019), and reinforcement learning (Chen et al., 2021; Janner et al., 2021; Zheng et al., 2022). Since the advent of the Vision Transformer (ViT) (Dosovitskiy et al., 2020), Transformer-based models (Touvron et al., 2021; Jiang et al., 2021; Wang et al., 2021; Liu et al., 2021a) have gradually replaced convolutional neural network (CNN) architectures and become prevalent in vision tasks. Various techniques have been developed to train ViTs efficiently. Among them, token sparsification (Pan et al., 2021; Rao et al., 2021; Liang et al., 2022; Tang et al., 2022; Yin et al., 2022) removes redundant tokens (image patches) from the data to reduce the computational cost while maintaining comparable learning performance. For example, Liang et al. (2022) and Tang et al. (2022) prune tokens following criteria designed based on the magnitude of the attention map.

Despite the remarkable empirical success, one fundamental question about training Transformers remains largely open: under what conditions does a Transformer achieve satisfactory generalization? Some recent works analyze Transformers theoretically from the perspectives of the Lipschitz constant of self-attention (James Vuckovic, 2020; Kim et al., 2021), properties of the neural tangent kernel (Hron et al., 2020; Yang, 2020), and expressive power and Turing-completeness (Dehghani et al., 2018; Yun et al., 2019; Bhattamishra et al., 2020a;b; Edelman et al., 2022; Dong et al., 2021).

* Rensselaer Polytechnic Institute. Email: lih35@rpi.edu
† Rensselaer Polytechnic Institute. Email: wangm7@rpi.edu
‡ Michigan State University & IBM Research. Email: liusiji5@msu.edu
§ IBM Research. Email: Pin-Yu.Chen@ibm.com
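Token sparsification based on attention magnitude can be illustrated with a short sketch. This is not the actual criterion of Liang et al. (2022) or Tang et al. (2022); scoring each token by the total attention it receives (column sums of the attention map) is a simplified stand-in used here for concreteness.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def prune_tokens(X, Wq, Wk, keep):
    """Keep the `keep` tokens that receive the most attention.

    X: (n_tokens, d) token embeddings. The scoring rule (column sums of
    the attention map) is an illustrative assumption, not a published one.
    """
    A = softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(Wk.shape[1]))
    received = A.sum(axis=0)            # total attention each token receives
    idx = np.sort(np.argsort(received)[-keep:])  # top-`keep` tokens, original order
    return X[idx], idx

rng = np.random.default_rng(1)
X = rng.normal(size=(10, 8))
Wq, Wk = rng.normal(size=(8, 8)), rng.normal(size=(8, 8))
Xp, kept = prune_tokens(X, Wq, Wk, keep=4)
print(Xp.shape)  # (4, 8)
```

Under the paper's data model, a well-chosen criterion of this kind would tend to retain label-relevant tokens and drop noisy or label-irrelevant ones, which is the mechanism behind the claimed test-performance improvement.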

