LAMBDANETWORKS: MODELING LONG-RANGE INTERACTIONS WITHOUT ATTENTION

Abstract

We present lambda layers -an alternative framework to self-attention -for capturing long-range interactions between an input and structured contextual information (e.g. a pixel surrounded by other pixels). Lambda layers capture such interactions by transforming available contexts into linear functions, termed lambdas, and applying these linear functions to each input separately. Similar to linear attention, lambda layers bypass expensive attention maps, but in contrast, they model both content and position-based interactions which enables their application to large structured inputs such as images. The resulting neural network architectures, LambdaNetworks, significantly outperform their convolutional and attentional counterparts on ImageNet classification, COCO object detection and instance segmentation, while being more computationally efficient. Additionally, we design LambdaResNets, a family of hybrid architectures across different scales, that considerably improves the speed-accuracy tradeoff of image classification models. LambdaResNets reach excellent accuracies on ImageNet while being 3.2 -4.4x faster than the popular EfficientNets on modern machine learning accelerators. In large-scale semi-supervised training with an additinal 130M pseudo-labeled images, LambdaResNets achieve up to 86.7% ImageNet accuracy while being 9.5x faster than EfficientNet NoisyStudent and 9x faster than a Vision Transformer with comparable accuracies 1 .

1. INTRODUCTION

Modeling long-range dependencies in data is a central problem in machine learning. Selfattention (Bahdanau et al., 2015; Vaswani et al., 2017) has emerged as a popular approach to do so, but the costly memory requirement of self-attention hinders its application to long sequences and multidimensional data such as imagesfoot_1 . Linear (or efficient) attention mechanisms (Katharopoulos et al., 2020; Choromanski et al., 2020) offer a scalable remedy for high memory usage but fail to model internal data structure, such as relative distances between pixels or edge relations between nodes in a graph. This work addresses both issues. We propose lambda layers which model long-range interactions between a query and a structured set of context elements at a reduced memory cost. Lambda layers transform each available context into a linear function, termed a lambda, which is then directly applied to the corresponding query. Whereas self-attention defines a similarity kernel between the query and the context elements, a lambda layer instead summarizes contextual information into a fixed-size linear function (i.e. a matrix), thus bypassing the need for memory-intensive attention maps. This difference is illustrated in Figure 1 . Lambda layers are versatile and can be implemented to model both content-based and position-based interactions in global, local or masked contexts. The resulting neural networks, LambdaNetworks, are computationally efficient, model long-range dependencies at a small memory cost and can therefore be applied to large structured inputs such as high resolution images. We evaluate LambdaNetworks on computer vision tasks where works using self-attention are hindered by large memory costs (Wang et al., 2018; Bello et al., 2019) , suffer impractical implementations (Ramachandran et al., 2019) , or require vast amounts of data (Dosovitskiy et al., 2020) . In our experiments spanning ImageNet classification, COCO object detection and instance segmentation, LambdaNetworks significantly outperform their convolutional and attentional counterparts, while being more computationally efficient and faster than the latter. We summarize our contributions: • Lambda layers: a class of layers, that model content-based and position-based interactions without materializing attention maps. Lambda layers offer a unifying view of channel, spatial and linear attention (Appendix D.4). Some of our observations, such as the computational benefits of a multi-query formulation, extend to linear attention. Lambda layers are easily implemented with einsum operations and convolution kernels, operations with efficient implementations on modern machine learning accelerators. • Lambda layers significantly outperform their convolution and attention counterparts on the ImageNet classification task while being more computationally efficient. For example, simply replacing the 3x3 convolutions in the bottleneck blocks of the ResNet-50 architecture (He et al., 2016) with lambda layers yields a +1.5% top-1 ImageNet accuracy improvement while reducing parameters by 40% (Section 5.1). • Lambda layers achieve considerable computational benefits, both in latency and memory requirements, over multiple self-attention alternatives, including local and axial attention (Ramachandran et al., 2019; Wang et al., 2020a) . When used in a ResNet-50 architecture at image resolution 224, lambda layers reduce memory consumption by ∼200x compared to global attention (∼7x compared to axial attention) while being ∼3.7x faster than local attention (Section 5.2). • A study of hybrid convolution-lambda models as a means to maximize the speed-accuracy tradeoff (Section 5.3). Hybrid designs that first employ convolutions at the highest resolution and lambda layers in intermediate to low resolutions achieve the best speed-accuracy tradeoff. • LambdaResNets: a family of hybrids based on the training and scaling strategies recommended in Bello et al. (2021) . LambdaResNets achieve up to a 4.4x speed-up over EfficientNets on ImageNet, while being more memory-efficient. LambdaResNets can also be designed for parameter or flops efficiency. For example, a LambdaResNet with 42M parameters achieves 84.3% top-1 ImageNet accuracy at image resolution 320 (Section E.4). • In large-scale semi-supervised training with an additional 130M pseudo-labeled images, LambdaResNets achieve up to 86.7% top-1 ImageNet accuracy while being 9.5x faster than EfficientNet NoisyStudent (Xie et al., 2020) and 9x faster than a Vision Transformer (Dosovitskiy et al., 2020) with comparable accuracies (Section 5.3). • An evaluation of LambdaResNets on COCO object detection and instance segmentation using Mask-RCNN (He et al., 2017) . LambdaResNet backbones yield consistent gains across all metrics on both tasks (e.g. +1.8% mAP improvement for detecting small objects). A content-based interaction considers the content of the context but ignores the relation between the query position and the context (e.g. relative distance between two pixels). A position-based interaction considers the relation between the query position and the context position. Table 1 : Definition of content-based vs position-based interactions.

2. MODELING LONG-RANGE INTERACTIONS

In this section, we formally define queries, contexts and interactions. Starting from first principles, we motivate keys and relative position embeddings as a requirement for capturing structured interactions between queries and their contexts. We then show that lambda layers arise as an alternative to attention mechanisms for capturing long-range interactions. Notation. We denote scalars, vectors and tensors using lower-case, bold lower-case and bold upper-case letters, e.g., n, x and X. We denote |n| the cardinality of a set whose elements are indexed by n. We denote x n the n-th row of X. We denote x ij the |ij| elements of X. When possible, we adopt the terminology of self-attention to ease readability and highlight differences.

2.1. MOTIVATING QUERIES, KEYS, POSITION EMBEDDINGS AND VALUES

Defining queries and contexts. Let Q = {(q n , n)} and C = {(c m , m)} denote structured collections of vectors, respectively referred to as the queries and the context. Each query (q n , n) is characterized by its content q n ∈ R |k| and position n. Similarly, each context element (c m , m) is characterized by its content c m and its position m in the context. The (n, m) pair may refer to any pairwise relation between structured elements, e.g. relative distances between pixels or edges between nodes in a graph. Defining interactions. We consider the general problem of mapping a query (q n , n) to an output vector y n ∈ R |v| given the context C with a function F : ((q n , n), C) → y n . Such a function may act as a layer in a neural network when processing structured inputs. We refer to (q n , c m ) interactions as content-based and (q n , (n, m)) interactions as position-based. We note that while absolute positional information is sometimes directly added to the query (or context element) contentfoot_2 , we consider this type of interaction to be content-based as it ignores the relation (n, m) between the query and context element positions. Introducing keys and relative position embeddings to capture long-range interactions. In the context of deep learning, we prioritize fast batched linear operations and use dot-product operations as our interactions. This motivates introducing vectors that can interact with the queries via a dotproduct operation and therefore have the same dimension as the queries. In particular, content-based interactions (q n , c m ) require a |k|-dimensional vector that depends on c m , commonly referred to as the key k m . Conversely, position-based interactions (q n , (n, m)) require a relative position embedding e nm ∈ R |k| (Shaw et al., 2018) . As the query/key depth |k| and context spatial dimension |m| are not in the output y n ∈ R |v| , these dimensions need to be contracted as part of the layer computations. Therefore Every layer capturing long-range interactions can be characterized based on whether it contracts (1) the query depth or (2) the context positions first.

2.2. ATTENTION VS LAMBDA LAYERS.

(1) Attention layers. Contracting the query depth first creates a similarity kernel (the attention map) between the query and context elements and is known as the attention operation. As the number of context positions |m| grows larger and the input and output dimensions |k| and |v| remain fixed, one may hypothesize that computing attention maps become wasteful, given that the layer output is a vector of comparatively small dimension |v| |m|. (2) Lambda layers. Instead, it may be more efficient to simply map each query to its output as  y n = F ((q n , n), C) = λ(C, n)(q n ) for some linear function λ(C, n) : R |k| → R |v| . In this Context V K σ(K) λ c λ n p λ 1 p En E1 λ 1 λ n q 1 q n x 1 x n y 1 y n

Generating lambdas from the context

X ∈ R |n|×d inputs C ∈ R |m|×d context Q = XW Q ∈ R |n|×|k| queries K = CW K ∈ R |m|×|k| keys V = CW V ∈ R |m|×|v| values σ(K) = softmax(K, axis=m) normalized keys En ∈ R |m|×|k| relative position embeddings λ c = KT V ∈ R |k|×|v| content lambda λ p n = E T n V ∈ R |k|×|v| position lambdas λn = λ c + λ p n ∈ R |k|×|v| lambdas Figure 2 : Computational graph of the lambda layer. Contextual information for query position n is summarized into a lambda λ n ∈ R |k|×|v| . Applying the lambda dynamically distributes contextual features to produce the output as y n = λ T n q n . This process captures content-based and position-based interactions without producing attention maps. scenario, the context is aggregated into a fixed-size linear function λ n = λ(C, n). Each λ n acts as a small linear functionfoot_3 that exists independently of the context (once computed) and is discarded after being applied to its associated query q n .

3.1. LAMBDA LAYER: TRANSFORMING CONTEXTS INTO LINEAR FUNCTIONS.

A lambda layer takes the inputs X ∈ R |n|×din and the context C ∈ R |m|×dc as input and generates linear function lambdas that are then applied to the queries, yielding outputs Y ∈ R |n|×dout . Without loss of generality, we assume d in = d c = d out = d. As is the case with self -attention, we we may have C = X. In the rest of this paper, we focus on a specific instance of a lambda layer and show that it captures long-range content and position-based interactions without materializing attention maps. Figure 2 presents the computational graph of the lambda layer. We first describe the lambda layer when applied to a single query (q n , n). Generating the contextual lambda function. We wish to generate a linear function R |k| → R |v| , i.e. a matrix λ n ∈ R |k|×|v| . The lambda layer first computes keys K and values V by linearly projecting the context, and keys are normalized across context positions via a softmax operation yielding normalized keys K. The λ n matrix is obtained by using the normalized keys K and position embeddings E n to aggregate the values V as λ n = m ( km + e nm )v T m = KT V content lambda + E T n V position lambda ∈ R |k|×|v| (1) where we also define the content lambda λ c and position lambda λ p n . • The content lambda λ c is shared across all query positions n and is invariant to permutation of the context elements. It encodes how to transform the query q n solely based on the context content. • The position lambda λ p n depends on the query position n via the position embedding E n . It encodes how to transform the query q n based on the context elements c m and their relative positions to the query (n, m). Applying lambda to its query. The query q n ∈ R |k| is obtained from the input x n via a learned linear projection and the output of the lambda layer is obtained as y n = λ T n q n = (λ c + λ p n ) T q n ∈ R |v| . ( ) Interpretation of lambda layers. The columns of the λ n ∈ R |k|×|v| matrix can be viewed as a fixed-size set of |k| contextual features. These contextual features are aggregated based on the context's content (content-based interactions) and structure (position-based interactions). Applying the lambda then dynamically distributes these contextual features based on the query to produce the output as y n = k q nk λ nk . This process captures content and position-based interactions without producing attention maps and can be viewed as an efficient relative attention mechanism. Normalization. One may modify Equations 1 and 2 to include non-linearities or normalization operations. Our experiments indicate that applying batch normalization (Ioffe & Szegedy, 2015) after computing the queries and the values is helpful. 3.2 A MULTI-QUERY FORMULATION TO REDUCE COMPLEXITY. Complexity analysis. For a batch of |b| examples, each containing |n| inputs, the number of arithmetic operations and memory footprint required to apply our lambda layer are respectively Θ(bnmkv) and Θ(knm + bnkv). We still have a quadratic memory footprint with respect to the input length due to the e nm relative position embeddings. However this quadratic term does not scale with the batch size as is the case with the attention operation which produces per-example attention maps. In practice, the hyperparameter |k| is set to a small value (such as |k|=16) and we can process large batches of large inputs in cases where attention cannot (see Table 4 ). Additionally, position embeddings can be shared across lambda layers to keep their Θ(knm) memory footprint constantwhereas the memory footprint of attention maps scales with the number of layersfoot_4 . We propose to decouple the time and space complexities of our lambda layer from the output dimension d. Rather than imposing |v|=d, we create |h| queries {q h n }, apply the same lambda λ n to each query q h n , and concatenate the outputs as y n = concat(λ n q 1 n , • • • , λ n q |h| n ). We now have |v|=d/|h|, which reduces complexity by a factor of |h|. The number of heads |h| controls the size of the lambdas λ n ∈ R |k|×|d|/|h| relative to the total size of the queries q n ∈ R |hk| . We refer to this operation as a multi-query lambda layer and present an implementation using einsumfoot_5 in Figure 3 . The lambda layer is robust to |k| and |h| hyperparameter choices (see Appendix E.1), which enables flexibility in controlling its complexity. We use |h|=4 in most experi- ments. We note that while this resembles the multi-head or multi-query (Shazeer, 2019) foot_6 attention formulation, the motivation is different. Using multiple queries in the attention operation increases representational power and complexity. In contrast, using multiple queries in the lambda layer decreases complexity and representational power (ignoring the additional queries). Extending the multi-query formulation to linear attention. Finally, we point that our analysis extends to linear attention which can be viewed as a content-only lambda layer (see Appendix D.3 for a detailed discussion). We anticipate that the multi-query formulation can also bring computational benefits to linear attention mechanisms.

3.3. MAKING LAMBDA LAYERS TRANSLATION EQUIVARIANT.

Using relative position embeddings e nm enables making explicit assumptions about the structure of the context. In particular, translation equivariance (i.e. the property that shifting the inputs results in an equivalent shift of the outputs) is a strong inductive bias in many learning scenarios. We obtain translation equivariance in position interactions by ensuring that the position embeddings satisfy e nm = e t(n)t(m) for any translation t. In practice, we define a tensor of relative position embeddings R ∈ R |r|×|k| , where r indexes the possible relative positions for all (n, m) pairs, and reindexfoot_7 it into E ∈ R |n|×|m|×|k| such that e nm = r r(n,m) . 3.4 LAMBDA CONVOLUTION: LOCAL CONTEXTS ON THE GRID. Despite the benefits of long-range interactions, locality remains a strong inductive bias in many tasks. Using global contexts may prove noisy or computationally excessive. It may therefore be useful to restrict the scope of position interactions to a local neighborhood around the query position n as is the case for local self-attention and convolutions. This can be done by zeroing out the relative embeddings for context positions m outside of the desired scope. However, this strategy remains costly for large values of |m| since the computations still occur -they are only being zeroed out. Lambda convolution In the case where the context is arranged in a multidimensional grid, we can equivalently compute positional lambdas from local contexts by using a regular convolution. We term this operation the lambda convolution. A n-dimensional lambda convolution can be implemented using an n-d depthwise convolution with channel multiplier or (n+1)-d convolution that treats the v dimension in V as an extra spatial dimension. We present both implementations in Appendix C.1. As the computations are now restricted to a local scope, the lambda convolution obtains linear time and memory complexities with respect to the input length 9 . The lambda convolution is readily usable with additional functionalities such as dilation and striding and enjoys optimized implementations on specialized hardware accelerators (Nickolls & Dally, 2010; Jouppi et al., 2017) . This is in stark contrast to implementations of local self-attention that require materializing feature patches of overlapping query and context blocks (Parmar et al., 2018; Ramachandran et al., 2019) , increasing memory consumption and latency (see Table 4 ).

4. RELATED WORK

Table 2 reviews alternatives for capturing long-range interactions and contrasts them with the proposed multi-query lambda layer. We discuss related works in details in the Appendix D. Channel and linear attention The lambda abstraction, i.e. transforming available contexts into linear functions that are applied to queries, is quite general and therefore encompasses many previous works. Closest to our work are channel and linear attention mechanisms (Hu et al., 2018c; Katharopoulos et al., 2020; Choromanski et al., 2020) . Such mechanisms also capture long-range interactions without materializing attention maps and can be viewed as specific instances of a contentonly lambda layer. Lambda layers formalize and extend such approaches to consider both contentbased and position-based interactions, enabling their use as a stand-alone layer on highly structured data such as images. Rather than attempting to closely approximate an attention kernel as is the case with linear attention, we focus on the efficient design of contextual lambda functions and repurpose a multi-query formulation (Shazeer, 2019) to further reduce computational costs. Self-attention in the visual domain In contrast to natural language processing tasks where it is now the de-facto standard, self-attention has enjoyed steady but slower adoption in the visual domain (Wang et al., 2018; Bello et al., 2019; Ramachandran et al., 2019; Carion et al., 2020) . Concurrently to this work, Dosovitskiy et al. ( 2020) achieve a strong 88.6% accuracy on ImageNet by pre-training a Transformer on sequences of image patches on a large-scale dataset of 300M images.

5. EXPERIMENTS

In subsequent experiments, we evaluate lambda layers on standard computer vision benchmarks: ImageNet classification (Deng et al., 2009) , COCO object detection and instance segmentation (Lin et al., 2014) . The visual domain is well-suited to showcase the flexibility of lambda layers since (1) the memory footprint of self-attention becomes problematic for high-resolution imagery and (2) images are highly structured, making position-based interactions crucial. LambdaResNets We construct LambdaResNets by replacing the 3x3 convolutions in the bottleneck blocks of the ResNet architecture (He et al., 2016) . When replacing all such convolutions, we simply denote the name of the layer being tested (e.g. conv + channel attention or lambda layer). We denote LambdaResNets the family of hybrid architectures described in Table 19 (Appendix F.2). Unless specified otherwise, all lambda layers use |k|=16, |h|=4 with a scope size of |m|=23x23 and are implemented as in Figure 3 . Additional experiments and details can be found in the Appendix.

5.1. LAMBDA LAYERS OUTPERFORM CONVOLUTIONS AND ATTENTION LAYERS

We first consider the standard ResNet-50 architecture with input image size 224x224. In Table 3 , we compare the lambda layer against (a) the standard convolution (i.e. the baseline ResNet-50) (b) channel attention (squeeze-and-excitation) and (c) multiple self-attention variants. The lambda layer strongly outperforms all baselines at a fraction of the parameter cost and notably obtains a +0.8% improvement over channel attention.

5.2. COMPUTATIONAL BENEFITS OF LAMBDA LAYERS OVER SELF-ATTENTION

In Table 4 : The lambda layer reaches higher ImageNet accuracies while being faster and more memory-efficient than self-attention alternatives. Memory is reported assuming full precision for a batch of 128 inputs using default hyperparameters. The memory cost for storing the lambdas matches the memory cost of activations in the rest of the network and is therefore ignored. b: batch size, h: number of heads/queries, n: input length, m: context length, k: query/key depth, l: number of layers.

5.3. HYBRIDS IMPROVE THE SPEED-ACCURACY TRADEOFF OF IMAGE CLASSIFICATION

Studying hybrid architectures. In spite of the memory savings compared to self-attention, capturing global contexts with the lambda layer still incurs a quadratic time complexity (Table 2 ), which remains costly at high resolution. In Appendix 5.3, we study hybrid designs that use standard convolutions to capture local contexts and lambda layers to capture global contexts. We find that such convolution-lambda hybrids have increased representational power at a negligible decrease in throughput compared to their purely convolutional counterparts. LambdaResNets significantly improve the speed-accuracy tradeoff of ImageNet classification. We design a family of hybrids based on our study of hybrid architectures and the scaling/training strategies from Bello et al. (2021) (Section F.2). Figure 4 presents the speed-accuracy Pareto curve of LambdaResNets compared to EfficientNets (Tan & Le, 2019 ) on TPUv3 hardware. In order to isolate the benefits of lambda layers, we additionally compare against the same architectures when replacing lambda layers by (1) standard 3x3 convolutions (denoted ResNet-RS wo/ SE) and (2) 3x3 convolutions with squeeze-and-excitation (denoted ResNet-RS w/ SE). All architectures are trained for 350 epochs using the same regularization methods and evaluated at the same resolution they are trained at. LambdaResNets outperform the baselines across all scales on the speed-accuracy trade-off. Scaling to larger datasets with pseudo-labels We train LambdaResNets in a semi-supervised learning setting using 130M pseudo-labeled images from the JFT dataset, as done for training the EfficientNet-NoisyStudent checkpoints (Xie et al., 2020) . 

6. CONCLUSION

We propose a new class of layers, termed lambda layers, which provide a scalable framework for capturing structured interactions between inputs and their contexts. Lambda layers summarize available contexts into fixed-size linear functions, termed lambdas, that are directly applied to their associated queries. The resulting neural networks, LambdaNetworks, are computationally efficient and capture long-range dependencies at a small memory cost, enabling their application to large structured inputs such as high-resolution images. Extensive experiments on computer vision tasks showcase their versatility and superiority over convolutional and attentional networks. We introduce LambdaResNets, a family of hybrid LambdaNetworks which reach excellent ImageNet accuracies and achieve up to 9.5x speed-ups over the popular EfficientNets and Vision Transformers, significantly improving the speed-accuracy tradeoff of image classification models. How do lambda layers compare to the attention operation? Lambda layers scale favorably compared to self-attention. Vanilla Transformers using self-attention have Θ(blhn 2 ) memory footprint, whereas LambdaNetworks have Θ(lkn 2 ) memory footprint (or Θ(kn 2 ) when sharing positional embeddings across layers). This enables the use of lambda layers at higher-resolution and on larger batch sizes. Additionally, the lambda convolution enjoys a simpler and faster implementation than its local self-attention counterpart. Finally, our ImageNet experiments show that lambda layers outperforms self-attention, demonstrating that the benefits of lambda layers go beyond improved speed and scalability. How are lambda layers different than linear attention mechanisms? Lambda layers generalize and extend linear attention formulations to capture position-based interactions, which is crucial for modeling highly structured inputs such as images (see Table 9 in Appendix E.1). As the aim is not to approximate an attention kernel, lambda layers allow for more flexible non-linearities and normalizations which we also find beneficial (see Table 11 in Appendix E.1). Finally, we propose multi-query lambda layers as a means to reduce complexity compared to the multi-head (or single-head) formulation typically used in linear attention works. Appendix D.3 presents a detailed discussion of linear attention. How to best use lambda layers in the visual domain? The improved scalability, speed and ease of implementation of lambda layers compared to global or local attention makes them a strong candidate for use in the visual domain. Our ablations demonstrate that lambda layers are most beneficial in the intermediate and low-resolution stages of vision architectures when optimizing for the speed-accuracy tradeoff. It is also possible to design architectures that rely exclusively on lambda layers which can be more parameter and flops efficient. We discuss practical modeling recommendations in Appendix B.

A.2 EXTENDING LAMBDA LAYERS TO OTHER MODALITIES

While this work focuses on static image recognition, we note that lambda layers may be instantiated to model structured interactions on structures as diverse as graphs, time series, spatial lattices, etc. We anticipate that lambda layers will be helpful in more modalities, including multimodal tasks. We discuss masked contexts and auto-regressive tasks in the Appendix C.2. Lambda layers can be instantiated on other tasks simply by adapting the choice of structural/position embeddings to the task of interest and following the pseudo-code presented in Figure 3 . The choice of embeddings dictates the memory costs of the lambda layer. The assumption underlying the Θ(knm) space complexity of the lambda layer (Section 3.2) is that all examples in the batch share the same structure, i.e. relative position embeddings have shape k × n × m. This assumption does not hold when the data structure is different across examples (e.g. graphs with variable edge relations between nodes), in which case embeddings have shape b × k × n × m. In such cases, the lambda layer has Θ(bknm) space complexity, similar to self-attention.

B PRACTICAL MODELING RECOMMENDATIONS

I want to make it faster on TPUs/GPUs... Hybrid models reach a better speed-accuracy tradeoff. Global contexts can be computationally wasteful, especially in the early high resolution layers where features lack semantic information, and can be replaced by lambda convolutions with smaller scopes (e.g. |m|=5x5 or 7x7) or the standard 3x3 convolution. Additionally, using a hybrid can require less tuning when starting from a working model/training setup. I want to make to minimize FLOPS (e.g. embedded applications)... Consider a hybrid with inverted bottlenecks, as done in Section E.4.2. To further reduce FLOPS, prefer lambda convolutions with smaller scopes (e.g. |m|=5x5 or 7x7). I encounter memory issues... Memory footprint can be reduced by sharing position embeddings across layers (especially layers with the highest resolution). Using the lambda convolution is more memory efficient. Reducing the query depth |k| or increasing the number of heads |h| also decreases memory consumption. I'm experiencing instability... We found it important to initialize the γ parameter in the last batchnorm layer of the ResNet's bottleneck blocks to 0 (this is the default in most codebases). Normalizing the keys (i.e. with the softmax) along the context's length is important. Early experiments which employed 2 lambda layers sequentially in the same residual block were unstable, suggesting that using 2 lambda layers in sequence should be avoided. Which implementation of the lambda convolution should I use? In our experiments using Tensorflow 1.x on TPUv3 hardware, we found both the n-d depthwise and (n+1)-d convolution implementations to have similar speed. We point out that this can vary across software/hardware stacks. 

C.2 GENERATING LAMBDAS FROM MASKED CONTEXTS

In some applications, such as denoising tasks or auto-regressive training, it is necessary to restrict interactions to a sub-context C n ⊂ C when generating λ n for query position n. For example, parallel auto-regressive training requires masking the future to ensure that the output y n only depends on past context positions m < n. Self-attention achieves this by zeroing out the irrelevant attention weights a nm = 0 ∀m / ∈ C n , thus guaranteeing that y n = m a nm v m only depends on C n . Similarly, one can block interactions between queries and masked context positions when generating lambdas by applying a mask before summing the contributions of context positions. As long as the mask is shared across all elements in the batch, computing masked lambdas does not require materializing per-example attention maps and the complexities are the same as for global context case. See Figure 6 for an implementation.

C.3 MULTI-HEAD VS MULTI-QUERY LAMBDA LAYERS

In this section, we motivate using a multi-query formulation as opposed to the usual multi-head formulation used in self-attention. Figure 7 presents the implementation of a multi-head lambda layer. Table 6 compares complexities for multi-head and multi-query lambda layers. Using a multi-query formulation reduces computations by a factor of |h| (the number of queries per lambda) compared to the multi-head formulation. We also found in early experimentation that multi-query lambdas yield a better speed-accuracy trade-off. Additionally, the multi-head lambda layer does not enjoy a simple local implementation as the lambda convolution. b, m, v] . mask: a tensor of 0 and 1s with shape [n, m]. """ # We show the general case but a cumulative sum may be faster for masking the future. # Note that each query now also has its own content lambda since every query # interacts with a different context. # Keys should be normalized by only considering the elements in their contexts. content mu = einsum(normalized keys, values, 'bmk,bmv->bmkv') content lambdas = einsum(content mu, mask, 'bmkv,nm->bnkv') embeddings = einsum(embeddings, mask, 'knm,nm->knm') # apply mask to embeddings position lambdas = einsum (embeddings, values, 'knm,bmv->bnkv') content output = einsum(queries, content lambda, 'bhnk,bnkv->bnhv') position output = einsum(queries, position lambdas, 'bhnk,bnkv->bnhv') output = reshape(content output + position output, [b, n, d]) return output Table 6 : Complexity comparison between a multi-head and a multi-query lambda layer. Using a multi-query formulation reduces complexity by a factor |h| (the number of queries per lambda) compared to the standard multi-head formulation.

C.4 ADDING EXPRESSIVITY WITH AN EXTRA DIMENSION

We briefly experiment with a variant that enables increasing the cost of computing the lambdas while keeping the cost of applying them constant. This is achieved by introducing an additional dimension, termed the intra-depth with corresponding hyperparameter |u|, in keys, position embeddings and values. Each key (or positional embedding) is now a |k| × |u| matrix instead of a |k|-dimensional vector. Similarly, each value is now a |v| × |u| matrix instead of a |v|-dimensional vector. The lambdas are obtained via summing over context positions and the intra-depth position |u| and have |k| × |v| shape similar to the default case. See Figure 8 for an implementation and Table 7 for the complexities. Experiments (see Appendix E.1) demonstrate that this variant results in accuracy improvements but we find that using |u|=1 (i.e. the default case) is optimal when controlling for speed on modern machine learning accelerators. Table 7 : Complexity for a multi-query lambda layer with intra-depth |u|.

D ADDITIONAL RELATED WORK

In this section, we review the attention operation and related works on improving its scalability. We discuss connections between lambda layers and channel, spatial or linear attention mechanisms and show how they can be cast as less flexible specific instances of lambda layers. We conclude with a brief review of self-attention in the visual domain and discuss connections with expert models.

D.1 SOFTMAX ATTENTION

Softmax attention Softmax-attention produces a distribution over the context for each query q n as a n = softmax(Kq n ) ∈ R |m| where the keys K are obtained from the context C. The attention distribution a n is then used to form a linear combination of values V obtained from the context as y n = V T a n = m a nm v m ∈ R |v| . As we take a weighted sum of the valuesfoot_11 , we transform the query q n into the output y n and discard its attention distribution a n . This operation captures content-based interactions, but not position-based interactions.

Relative attention

In order to model position-based interactions, relative attention (Shaw et al., 2018) introduces a learned matrix of |m| positional embeddings E n ∈ R |m|×|k| and computes the attention distribution as a n = softmax((K + E n )q n ) ∈ R |m| . The attention distribution now also depends on the query position n relative to positions of context elements m. Relative attention therefore captures both content-based and position-based interactions. 2) Lambda layers do not necessarily attempt to approximate an attention kernel. While approximations of the attention kernel are theoretically motivated, we argue that they may be unnecessarily restrictive. For example, the kernel approximation in Equation 3 requires the same feature function φ on both Q and K and precludes the use of more flexible non-linearities and normalization schemes. In contrast, lambda layers do not attempt to approximate an attention kernel. This simplifies their design and allows for more flexible non-linearity and normalization schemes, which we find useful in our ablations (See Table 11 in Appendix E.1). Considering the position embeddings independently of the keys notably enables a simple and efficient local implementation with the lambda convolution. Approximating the relative attention kernel would require normalizing the position embeddings with the keys (i.e., φ(K + E n ) instead of φ(K) + E n ), which cannot be implemented in the local context case with a convolution. 3) The lambda abstraction reveals the computational benefits of the multi-query formulation. Finally, this work proposes to abstract the KT V and E T n V matrices as linear functions (the content and position lambdas) that are directly applied to the queries. The lambda abstraction reveals the benefits of multi-query formulation (as opposed to the traditional multi-head attention formulation) as a means to reduce computational costs.

D.4 CASTING CHANNEL AND SPATIAL ATTENTION AS LAMBDA LAYERS.

We show that the lambda abstraction generalizes channel and spatial attention mechanisms, both of which can be viewed as specific instances of lambda layers. This observation is consistent with our experiments which demonstrate that lambda layers outperform both channel and spatial attention while being more computationally efficient. Channel attention Channel attention mechanisms, such as Squeeze-and-Excitation (SE) (Hu et al., 2018c; b) and FiLM layers (Perez et al., 2017) , recalibrate features via cross-channel interactions by aggregating signals from the entire feature map. In particular, the SE operation can be written as y nk = w k q nk where w k is the excitation weight for channel k in the query q n . This can be viewed as using a diagonal lambda which is shared across query positions λ n = diag(w 1 • • • w |k| ). Channel attention mechanisms have proven useful to complement convolutions but cannot be used as a stand-alone layer as they discard spatial information.

Spatial attention

Conversely, spatial attention mechanisms, reweigh each position based on signals aggregated from all channels (Xu et al., 2015; Park et al., 2018; Woo et al., 2018) . These mechanisms can be written as y nk = w n q nk where w n is the attention weight for position n in the input query Q. This can be viewed as using (position-dependent) scalar lambdas λ n = w n I where I is the identity matrix. Spatial attention has also proven helpful to complement convolutions but cannot be used as a stand-alone layer as it discards channel information.

D.5 SELF-ATTENTION IN THE VISUAL DOMAIN

Self-attention has been used in a myriad of tasks in the visual domain. These include image classification (Bello et al., 2019; Ramachandran et al., 2019; Cordonnier et al., 2019; Zhao et al., 2020; Wu et al., 2020; Dosovitskiy et al., 2020) ; object detection and object-centric tasks (Wang et al., 2018; Hu et al., 2018a; Carion et al., 2020; Locatello et al., 2020) ; video tasks (Sun et al., 2019; Liao et al., 2019) ; autoregressive/adversarial generative modeling (Parmar et al., 2018; Zhang et al., 2019; Brock et al., 2019; Chen et al., 2020a) and multi-modal text-vision tasks (Chen et al., 2020b; Lu et al., 2019; Li et al., 2019; Radford et al., 2021) The first use of self-attention in vision dates back to the non-local block (Wang et al., 2018) , which added a single-head global self-attention residual in the low resolution stages of a ConvNet for longrange dependency modeling. The non-local block has proven useful to complement convolutions but cannot be used as a stand-alone layer as it does not model position-based interactions. Global relative attention replaces convolutions at low resolution. Bello et al. (2019) introduced a 2d relative attention mechanism that proved competitive as a replacement to convolutions but gives even stronger results when used to concatenate convolutional features with self-attention features. The spatial convolutions in the bottleneck block of the ResNet architecture were replaced with a global multi-head self-attention mechanism with 2d relative position embeddings. Due to the large memory constraints of global attention, this operation was restricted to low resolution feature maps and the proposed architecture was a conv-transformer hybrid. A similar hybrid design has recently been revisited by Srinivas et al. (2021) using modern training and scaling techniques. Srinivas et al. (2021) , rather than concatenating convolutional feature maps, propose to use a stride of 1 in the last stage of the ResNet architecture for improved performance. Local/axial relative attention replaces convolutions at high resolution. The large memory footprint of global attention was quickly solved by multiple works which proposed to limit the size of the attention contexts such as local attention (Ramachandran et al., 2019; Hu et al., 2019) and axial attention (Ho et al., 2019; Wang et al., 2020a; Shen et al., 2020) (See Section D.2). Such approaches enable using attention at higher resolution and facilitate fully-attentional models but can be slow due to the use of specialized attention patterns. Scaling trumps inductive bias Concurrently to this work, ViT (Dosovitskiy et al., 2020) propose to simply apply attention on pixel patches (as opposed to individual pixels) as a remedy to large memory requirements. While patch-based attention does not maintain accurate positional information or translation equivariance, the loss of inductive bias is recovered by pre-training on large-scale datasets (e.g. 300M images). Most remarkably, ViT achieves close to state-of-the-art accuracy when fine-tuned on the ImageNet dataset, while requiring less training compute that convolutional alternatives (Kolesnikov et al., 2020; Xie et al., 2020) . This result has reinvigorated interest in using self-attention in the visual domain with multiple follow-up works already building upon this approach (Touvron et al., 2021) 13 . In spite of the impressive image classification results, concerns remain as to whether the patch-based approach can scale to larger images and transfer to tasks that require precise localization such as detection. We stress that reducing memory by working with pixel patches is orthogonal to the specific operation used and that lambda layers (or linear attention) can also be used on pixel patches.

D.6 HYPERNETWORKS, EXPERT MODELS AND CONTEXT-DEPENDENT WEIGHTS

LambdaNetworks generate their own computations, i.e. lambdas such that y n = λ T n q n . As such, they can alternatively be viewed as an extension of HyperNetworks (Ha et al., 2016) that dynamically generate their computations based on structured contextual information. The concept of generating context-dependent weights is also related to fast weights (Ba et al., 2016) . Lastly, LambdaNetworks share some connections with sparsely-activated expert models (Shazeer et al., 2017; Fedus et al., 2021) . Whereas sparsely-activated expert models select the computation (i.e. the lambda) from a bank of weights based on the input query, LambdaNetworks generate their computations based on contextual information.

E ADDITIONAL EXPERIMENTS E.1 ABLATION STUDY

We perform several ablations and validate the importance of positional interactions, long-range interactions and flexible normalization schemes. Unless specified otherwise, all experimental results in this section report ImageNet accuracies obtained by training a LambdaNetwork architecture that replaces the spatial convolutions in the ResNet-50 with lambda layers. Varying query depth, number of heads and intra-depth. Table 10 : Impact of varying the scope size for positional lambdas on the ImageNet classification task. We replace the 3x3 spatial convolutions in the last 2 stages of a ResNet-50 with lambda layers (input image size is 224x224). Flops significantly increase with the scope size, however we stress that larger scopes do not translate to slower latencies when using the einsum implementation (see Figure 3 ). Normalization Table 11 ablates normalization operations in the design of the lambda layer. We find that normalizing the keys is crucial for performance and that other normalization functions besides the softmax can be considered. Applying batch normalization to the queries and values is also helpful. Table 11 : Impact of normalization schemes in the lambda layer. Normalization of the keys along the context spatial dimension m, normalization of the queries along the query depth k.

E.2 HYBRID MODELS STUDY

In this section, we study hybrid designs that use standard convolutions to capture local contexts and lambda layers to capture global contexts. 15Where are lambda layers most useful? Table 12 presents the throughputs and accuracies of hybrid LambdaNetwork architectures as a function of the location of convolutions and lambda layers in a ResNet-50 architecture. We observe that lambda layers are most helpful in the last two stages (commonly referred to as c4 and c5) when considering their speed-accuracy tradeoff. We refer to architectures that replaces 3x3 convolutions in the last 2 stages of the ResNet with lambda layers as LambdaResNet-C4. Further pushing the speed-accuracy Pareto frontier. In Table 13 , we further study how throughput and accuracy are impacted by the number of lambda layers in the c4 stage. Our results reveal that most benefits from lambda layers can be obtained by (a) replacing a few 3x3 convolutions with lambda layers in the c4 stage and (b) replacing all 3x3 convolutions in c5. The resulting hybrid LambdaResNets architectures have increased representational power at a virtually negligible decrease in throughput compared to their vanilla ResNet counterparts. Comparing hybrid lambda vs attention models. The memory savings of lambda layers compared to attention are less significant in the aforementioned hybrid design, since the operations occur at lower resolution. Therefore, it is natural to ask whether lambda layers still have benefits over self-attention when considering hybrid designs. We consider our largest hybrid as an example (see Table 19 ). LambdaResNet-420 is trained on 320x320 inputs, employs 8 lambda layers in c4 and can fit 32 examples per TPU-v3 core. This adds up to a cost of 38.4MB for lambda layers (4.8MB if sharing positional embeddings), whereas using attention layers instead would incur 0.625GB. The increase might not be significant in practice and it will be interesting to carefully benchmark the hybrid attention variants 16 . We point that experiments from Table 4 suggest that the benefits of lambda layers go beyond improved scalability and stress that the memory savings are more pronounced for tasks that require larger inputs such as object detection.

E.3 OBJECT DETECTION AND INSTANCE SEGMENTATION RESULTS

In Table 14 , we evaluate LambdaResNets as a backbone in Mask-RCNN (He et al., 2017) on the COCO object detection and instance segmentation tasks. Using lambda layers yields consistent gains across all object sizes, especially the small objects which are the hardest to locate. This indicates that lambda layers are also competitive for more complex visual tasks that require localization information. 49.4 (+0.5) 30.0 (+0.1) / 52.3 / 66.7 43.5 (+0.3) 24.6 (+0.4) / 46.8 / 61.8 LambdaResNet-152 50.0 (+1.1) 31.8 (+1.9) / 53.4 / 67.0 43.9 (+0.7) 25.5 (+1.3) / 47.3 / 62.0 Table 14: COCO object detection and instance segmentation with Mask-RCNN architecture on 1024x1024 inputs. We compare LambdaResNets against ResNets with or without squeeze-andexcitation (SE) and report Mean Average Precision (AP) for small, medium, large objects (AP s/m/l ). Using lambda layers yields consistent gains across all object sizes, especially small objects.

E.4.1 COMPUTATIONAL EFFICIENCY COMPARISONS TO LARGE EFFICIENTNETS

In Table 15 and Table 16 , we showcase the parameter and flops-efficiency of LambdaNetworks. We find that LambdaResNet-C4 which replaces the 3x3 convolutions in the last 2 stages of the ResNet architecture, where they incur the highest parameter costs, improves upon parameter and flops efficiency of large EfficientNets. These results are significant because EfficientNets were specifically designed by neural architecture search (Zoph & Le, 2017) to minimize computational costs using highly computationally efficient depthwise convolutions (Tan & Le, 2019) . Lastly, we briefly study lambda layers in a resource-constrained scenario using the MobileNetv2 architecture (Sandler et al., 2018) . MobileNets (Howard et al., 2017; Sandler et al., 2018; Howard et al., 2019) employ lightweight inverted bottleneck blocks which consist of the following sequence: 1) a pointwise convolution for expanding the number of channels, 2) a depthwise convolution for spatial mixing and 3) a final pointwise convolution for channel mixing. The use of a depthwise convolution (as opposed to a regular convolution) reduces parameters and flops, making inverted bottlenecks particularly well-suited for embedded applications. Lightweight lambda block. We construct a lightweight lambda block as follows. We replace the depthwise convolution in the inverted bottleneck with a lambda convolution with small scope size |m|=5x5, query depth |k|=32, number of heads |h|=4. We also change the first pointwise convolution to output the same number of channels (instead of increasing the number of channels) to further reduce computations. Adding lambda layers in MobileNetv2. We wish to assess whether lambda layers can improve the flops-accuracy (or parameter-accuracy) tradeoff of mobilenet architectures. We experiment with a simple strategy of replacing a few inverted bottlenecks with our proposed lightweight lambda block, so that the resulting architectures have similar computational demands as their baselines. A simple procedure of replacing the 10-th and 16-th inverted bottleneck blocks with lightweight lambda blocks in the MobileNet-v2 architecture reduces parameters and flops by ∼10% while improving ImageNet accuracy by 0.6%. This suggest that lambda layers may be well suited for use in resource constrained scenarios such as embedded vision applications (Howard et al., 2017; Sandler et al., 2018; Howard et al., 2019) . 



An updated version of this paper can be found on arXiv. For example, applying a single multi-head attention layer to a batch of 128 64x64 input images with 8 heads requires 64GB of memory, which is prohibitive in practice. This approach is often used in natural language processing tasks(Vaswani et al., 2017) but has had limited success in the visual domain where relative position information between pixels is crucial(Bello et al., 2019). This mechanism is reminiscent of functional programming and λ-calculus which motivates the lambda terminology. Attention maps typically need to be stored for back-propagation(Kitaev et al., 2020). The einsum operation denotes general contractions between tensors of arbitrary dimensions. It is numerically equivalent to broadcasting its inputs to share the union of their dimensions, multiplying element-wise and summing across all dimensions not specified in the output. (Shazeer, 2019) proposes a multi-query formulation to speed-up attention-based decoding. We refer the reader to the code for more details. FLOPs (time complexity) is not necessarily a good proxy for latency on TPUs/GPUs. Eventhough the lambda convolution has linear time/space complexities, it can be slower than than the global lambda layer in practice, especially for large convolution scope sizes. See Table4for an example. Latencies for local self-attention were provided privately byRamachandran et al. (2019) based on an implementation that relies on query blocks and overlapping memory blocks(Parmar et al., 2018). Specialized attention kernels may greatly speed up local self-attention, making it a promising avenue for future research. Note that the content-based lambda still captures global interactions. Sometimes the attention operation is instead used to point to specific context elements(Vinyals et al., 2015;Bello et al., 2016), which is not supported by lambda layers. Most follow-up works advertise improvements over ViT on smaller datasets which is not the intended purpose of ViT. This observation is challenged by concurrent work(Dosovitskiy et al., 2020) which demonstrates that content-based interactions can be sufficient for image classification when pre-training on large scale datasets (e.g. 300M images). We could alternatively use the lambda convolution to capture local contexts. We will benchmark such architectures in a future version of this draft.



Figure 1: Comparison between self-attention and lambda layers. (Left) An example of 3 queries and their local contexts within a global context. (Middle) Self-attention associates each query with an attention distribution over its context. (Right) The lambda layer transforms each context into a linear function lambda that is applied to the corresponding query.

Figure4: Speed-accuracy comparison between LambdaResNets and EfficientNets with matching training and regularization setups. LambdaResNets (annotated with (depth, image size)) are 3.2-4.4x faster than EfficientNets and 1.6-2.3x faster than ResNet-RS with squeeze-and-excitation, thus significantly improving the speed-accuracy Pareto curve of image classification. LambdaResNet-420 (image size 320), reaches a strong 84.9% top-1 accuracy, 0.9% over the corresponding architecture with standard 3x3 convolutions and 0.65% over the corresponding architecture with squeezeand-excitation.

queries, keys, position embeddings and values . . . . . . . . . . . . . . 3 2.2 Attention vs lambda layers. . . . . . . . . . . . . . . . . . . . . . . . . . . . layer: transforming contexts into linear functions. . . . . . . . . . . . . . 4 3.2 A multi-query formulation to reduce complexity. . . . . . . . . . . . . . . . . . . 5 3.3 Making lambda layers translation equivariant. . . . . . . . . . . . . . . . . . . . . 6 3.4 Lambda convolution: local contexts on the grid. . . . . . . . . . . . . . . . . layers outperform convolutions and attention layers . . . . . . . . . . . . 7 5.2 Computational benefits of lambda layers over self-attention . . . . . . . . . . . . 7 5.3 Hybrids improve the speed-accuracy tradeoff of image classification . . . . . . . discussion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 16 A.2 Extending lambda layers to other modalities . . . . . . . . . . . . . . . . . . . . . 16 code with lambda convolution . . . . . . . . . . . . . . . . . . . . . . . 18 C.2 Generating lambdas from masked contexts . . . . . . . . . . . . . . . . . . . . . . 18 C.3 Multi-head vs multi-query lambda layers . . . . . . . . . . . . . . . . . . . . . . . 18 C.4 Adding expressivity with an extra dimension . . . . . . . . . . . . . . . . . . . . . 19 D Additional Related Work 20 D.1 Softmax attention . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 20 D.2 Sparse attention . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 21 D.3 Linear attention: connections and differences . . . . . . . . . . . . . . . . . . . . 21 D.4 Casting channel and spatial attention as lambda layers. . . . . . . . . . . . . . . . 22 D.5 Self-Attention in the visual domain . . . . . . . . . . . . . . . . . . . . . . . . . 22 D.6 HyperNetworks, expert models and context-dependent weights . . . . . . . . . . . 23 E Additional Experiments 24 E.1 Ablation study . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 24 E.2 Hybrid models study . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 25 E.3 Object detection and instance segmentation results . . . . . . . . . . . . . . . . . 26 E.4 Parameter and FLOPs efficiency results . . . . . . . . . . . . . . . . . . . . . . . 27 F Experimental Details 29 F.1 Detailed LambdaResNets results . . . . . . . . . . . . . . . . . . . . . . . . . . . 29 F.2 Architectural details . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 29 F.3 Training details . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

Figure 5: Pseudo-code for the multi-query lambda layer and the 1d lambda convolution. A n-d lambda convolution can equivalently be implemented via a regular (n+1)-d convolution or a nd depthwise convolution with channel multiplier. The embeddings can be made to satisfy various conditions (e.g. translation equivariance and masking) when computing positional lambdas with the einsum implementation.

Figure 6: Pseudo-code for masked multi-query lambda layer.

Figure7: Pseudo-code for the multi-head lambda layer. This is only shown as an example as we recommend the multi-query lambda laayer instead.

Figure8: Pseudo-code for the multi-query lambda layer with intra-depth |u|. Lambdas are obtained by reducing over the context positions and the intra-depth dimension. This variant allocates more computation for generating the lambdas while keeping the cost of applying them constant. The equivalent n-d lambda convolution can be implemented with a regular (n+1)-d convolution.

Recall that the lambda layer maps inputs x n ∈ R d to outputs y n ∈ R d . As presented in Equation2, this implies that |v|=d. Small values of |v| may therefore act as a bottleneck on the feature vector y n but larger output dimensions |v| can incur an excessively large computational cost given our Θ(bnmkv) and Θ(knm + bnkv) time and space complexities.

Alternatives for capturing long-range interactions.

Table 4, we compare lambda layers against self-attention and present throughputs, memory complexities and ImageNet accuracies. Our results highlight the weaknesses of self-attention: selfattention cannot model global interactions due to large memory costs, axial self-attention is still

Comparison of the lambda layer and attention mechanisms on ImageNet classification with a ResNet50 architecture. The lambda layer strongly outperforms attention alternatives at a fraction of the parameter cost. All models are trained in mostly similar setups (see Appendix F.3) and we include the reported improvements compared to the convolution baseline in parentheses. See Appendix C.4 for a description of the |u| hyperparameter.

Table5compares the throughputs and ImageNet accuracies of a representative set of models with similar accuracies when trained using the JFT dataset. LambdaResNet-152, trained and evaluated at image size 288, achieves a strong 86.7% top-1 ImageNet accuracy while being more parameter-efficient and 9.5x faster than the EfficientNet-NoisyStudent checkpoint with the same accuracy. Comparison of models trained on extra data. ViT-L/16 is pre-trained on JFT and finetuned on ImageNet at resolution 384x384, while EfficientNet and LambdaResNet are co-trained on ImageNet and JFT pseudo-labels. Training and inference throughput is shown for 8 TPUv3 cores.

Table8presents the impact of the query depth |k|, number of heads |h| and intra depth |u| on performance (See Appendix C.4 for a presentation of the intra-depth |u|). Our experiments indicate that the lambda layer outperforms convolutional and attentional baselines for a wide range of hyperparameters, demonstrating the robustness of the method. Ablations on the ImageNet classification task when using the lambda layer in a ResNet50 architecture. All configurations outpeform the convolutional baseline at a lower parameter cost. As expected, we get additional improvements by increasing the query depth |k| or intra-depth |u|. The number of heads is best set to intermediate values such as |h|=4. A large number of heads |h| excessively decreases the value depth |v| = d/|h|, while a small number of heads translates to too few queries, both of which hurt performance.Content vs position interactions Table9presents the relative importance of content-based and position-based interactions on the ImageNet classification task. We find that position-based interactions are crucial to reach high accuracies, while content-based interactions only bring marginal improvements over position-based interactions 14 .

Contributions of content and positional interactions. As expected, positional interactions are crucial to perform well on the image classification task.Importance of scope size The small memory footprint of LambdaNetworks enables considering global contexts, even at relatively high resolution. Table10presents flops counts and top-1 ImageNet accuracies when varying scope sizes in a LambdaNetwork architecture. We find benefits from using larger scopes, with a plateau around |m|=15x15, which validates the importance of longer range interactions compared to the usual 3x3 spatial convolutions used in the ResNet architecture. In our main experiments, we choose |m|=23x23 as the default to account for experiments that use larger image sizes.

Table19presents the detailed block configurations and placement of lambda layers for our family of LambdaResNets. Hybrid models achieve a better speed-accuracy trade-off. Inference throughput and top-1 accuracy as a function of lambda (L) vs convolution (C) layers' placement in a ResNet50 architecture on 224x224 inputs. Lambda layers in the c5 stage incur almost no speed decrease compared to standard 3x3 convolutions. Lambda layers in the c4 stage are relatively slower than standard 3x3 convolutions but yield significant accuracy gains.

Impact of number of lambda layers in the c4 stage of LambdaResNets. Most benefits from lambda layers can be obtained by having a few lambda layers in the c4 stage. Such hybrid designs maximize the speed-accuracy tradeoff. LambdaResNet-C4 architectures exclusively employ lambda layers in c4 and c5. LambdaResNet block configurations can be found in Table19. Models are trained for 350 epochs on the ImageNet classification task.

+0.3) 29.9 (+0.0) / 51.5 / 65.3 42.8 (+0.2) 24.0 (-0.2) / 46.0 / 60.2 LambdaResNet-101 49.4 (+1.2) 31.7 (+1.8) / 52.2 / 65.6 43.5 (+0.9) 25.9 (+1.7) / 46.5 / 60.8

Parameter-efficiency comparison between LambdaResNet-C4 and EfficientNet-B6. LambdaResNet-C4 is more parameter-efficient in spite of using a smaller image size. Increasing the image size would likely result in improved accuracy while keeping the number of parameters fixed. Models are trained for 350 epochs.

Flops-efficiency comparison between LambdaResNet-C4 and EfficientNet-B6. We use smaller local scopes (|m|=7x7) to reduce FLOPS in the lambda layers. Models are trained for 350 epochs.E.4.2 LAMBDA LAYERS IN A RESOURCE CONSTRAINED SCENARIO

Lambda layers improve ImageNet accuracy in a resource-constrained scenario. Replacing the 10-th and 16-th inverted bottleneck blocks with lightweight lambda blocks in the MobileNet-v2 architecture reduces parameters and flops by ∼10% while improving ImageNet accuracy by 0.6%.

ACKNOWLEDGMENTS

The author would like to thank Barret Zoph and William Fedus for endless discussions, fruitful suggestions and careful revisions; Jonathon Shlens, Mike Mozer, Prajit Ramachandran, Ashish Vaswani, Quoc Le, Neil Housby, Jakob Uszkoreit, Margaret Li, Krzysztof Choromanski for many insightful comments; Hedvig Rausing for the antarctic infographics; Zolan Brinnes for the OST; Andrew Brock, Sheng Li for assistance with profiling EfficientNets; Adam Kraft, Thang Luong and Hieu Pham for assistance with the semi-supervised experiments and the Google Brain team for useful discussions on the paper.

annex

Published as a conference paper at ICLR 2021 D.2 SPARSE ATTENTION A significant challenge in applying (relative) attention to large inputs comes from the quadratic Θ(|bnm|) memory footprint required to store attention maps. Many recent works therefore propose to impose specific patterns to the attention maps as a means to reduce the context size |m| and consequently the memory footprint of the attention operation. These approaches include:• local attention patterns (Dai et al., 2019; Parmar et al., 2018; Ramachandran et al., 2019) • axial attention patterns (Ho et al., 2019; Wang et al., 2020a; Shen et al., 2020) • static sparse attention patterns (Child et al.; Beltagy et al., 2020) • dynamic sparse attention patterns (Kitaev et al., 2020) See Tay et al. (2020) for a review. Their implementations can be rather complex, sometimes require low-level kernel implementations to get computational benefits or may rely on specific assumptions on the shape of the inputs (e.g., axial attention). In contrast, lambda layers are simple to implement for both global and local contexts using simple einsum and convolution primitives and capture dense content and position-based interactions with no assumptions on the input shape.

D.3 LINEAR ATTENTION: CONNECTIONS AND DIFFERENCES

Another approach to reduce computational requirements of attention mechanisms consists in approximating the attention operation in linear space and time complexity, which is referred to as linear (or efficient) attention. Linear attention mechanisms date back to de Brébisson & Vincent (2016); Britz et al. (2017) and were later introduced in the visual domain by Chen et al. (2018) ; Shen et al. (2018) . They are recently enjoying a resurgence of popularity with many works modifying the popular Transformer architecture for sequential processing applications (Katharopoulos et al., 2020; Wang et al., 2020b; Choromanski et al., 2020; Xiong et al., 2021) .Linear attention via kernel factorization Linear attention is typically obtained by reinterpreting attention as a similarity kernel and leveraging a low-rank kernel factorization asfor some feature function φ. Computing φ(K T )V ∈ R |k|×|v| first bypasses the need to materialize the attention maps φ(Q)φ(K T ) and the operation therefore has linear complexity with respect to the input length |n|.Multiple choices for the feature function φ have been proposed. For example, Katharopoulos et al. (2020) use φ(x) = elu(x) + 1, while Choromanski et al. (2020) use positive orthogonal random features to approximate the original softmax attention kernel. In the visual domain, both Chen et al. (2018) and Shen et al. (2018) use φ(x) = softmax(x). This choice is made to guarantee that the rows of the (non-materialized) attention maps φ(Q)φ(K) T sum to 1 as is the case in the regular attention operation.We discuss the main differences between lambda layers and linear attention mechanisms.1) Lambda layers extend linear attention to also consider position-based interactions. The kernel approximation from Equation 3 can be rewritten for a single query q n aswhich resembles the output of the content lambda y c n = (λ c ) T q n = ( KT V ) T q n from Equation 1. Lambda layers extend linear attention mechanisms to also consider position-based interactions asIn the above equation, computing the position (or content) lambda has Θ(bmkv) time complexity.As the position lambdas are not shared across query positions n, this cost is repeated for all |n| queries, leading to a total time complexity Θ(bnmkv 3 . Local positional lambdas can be implemented interchangeably with the lambda convolution or by using the global einsum implementation and masking the position embeddings outside of the local contexts (Figure 5 ). The latter can be faster but has higher FLOPS and memory footprint due to the Θ(knm) term (see Table 2 ). In our experiments, we use the convolution implementation only for input length |n| > 85 2 or intra-depth |u| > 1. When the intra-depth is increased to |u| >1, we switch to the convolution implementation and reduce the scope size to |m|=7x7 to reduce flops.Positional embeddings are initialized at random using the unit normal distribution N (0, 1). We use fan-in initialization for the linear projections in the lambda layer. The projections to compute K and V are initialized at random with the N (0, |d| -1/2 ) distribution. The projection to compute Q is initialized at random with the N (0, |kd| -1/2 ) distribution (this is similar to the scaled dotproduct attention mechanism, except that the scaling is absorbed in the projection). We apply batch normalization on Q and V and the keys K are normalized via a softmax operation.ResNets. We use the ResNet-v1 implementation and initialize the γ parameter in the last batch normalization (Ioffe & Szegedy, 2015) layer of the bottleneck blocks to 0. Squeeze-and-Excitation layers employ a squeeze ratio of 4. Similarly to ResNet-RS (Bello et al., 2021) , we use the ResNet-D (He et al., 2018) and additionally replace the max pooling layer in the stem by a strided 3x3 convolution. Our block allocation and scaling strategy (i.e. selected resolution as a function of model depth) also follow closely the scaling recommendations from ResNet-RS (Bello et al., 2021) .LambdaResNets. We construct our LambdaResNets by replacing the spatial 3x3 convolutions in the bottleneck blocks of the ResNet-RS architectures by our proposed lambda layer, with the exception of the stem which is left unchanged. We apply 3x3 average-pooling with stride 2 after the lambda layers to downsample in place of the strided convolution. Lambda layers are uniformly spaced in the c4 stage and all bottlenecks in c5 use lambda layers. Table 19 presents the exact block configuration and the location of the lambda layers for our hybrid LambdaResNets. We do not use squeeze-and-excitation in the bottleneck blocks that employ a lambda layer instead of the standard 3x3 convolution. Supervised ImageNet 90 epochs training setup with vanilla ResNet. In the 90 epoch setup, we use the vanilla ResNet for fair comparison with prior works. We used the default hyperparameters as found in official implementations without doing additional tuning. All networks are trained endto-end for 90 epochs via backpropagation using SGD with momentum 0.9. The batch size B is 4096 distributed across 32 TPUv3 cores (Jouppi et al., 2017) and the weight decay is set to 1e-4. The learning rate is scaled linearly from 0 to 0.1B/256 for 5 epochs and then decayed using the cosine schedule (Loshchilov & Hutter, 2017) . We use batch normalization with decay 0.9999 and exponential moving average with weight 0.9999 over trainable parameters and a label smoothing of 0.1. The input image size is set to 224x224. We use standard training data augmentation (random crops and horizontal flip with 50% probability).Most works compared against in Table 3 use a similar training setup and also replace the 3x3 spatial convolutions in the ResNet architecture by their proposed methods. We note that Ramachandran et al. (2019) train for longer (130 epochs instead of 90) but do not use label smoothing which could confound our comparisons.Supervised ImageNet 350 epochs training setup. Higher accuracies on ImageNet are commonly obtained by training longer with increased augmentation and regularization (Lee et al., 2020; Tan & Le, 2019) . Similarly to Bello et al. (2021) , the weiht decay is reduced to 4e-5 and we employ RandAugment (Cubuk et al., 2019) with 2 layers, dropout (Srivastava et al., 2014) and stochastic depth (Huang et al., 2016) . See Table 20 for exact hyperparameters. All architectures are trained for 350 epochs with a batch size B of 4096 or 2048 distributed across 32 or 64 TPUv3 cores, depending on memory constraints.We tuned our models using a held-out validation set comprising ∼2% of the ImageNet training set (20 shards out of 1024). We perform early stopping on the held-out validation set for the largest models, starting with LambdaResNet-350 at resolution 288x288, and simply report the final accuracies for the smaller models.Semi-supervised learning with pseudo-labels. (instead of 1024), we normalize the reported latency by multiplying the measured latency by 2. Table 4 , Table 12 and Table 13 report inference throughput on 8 TPUv3 cores using full precision (i.e. float32 activations). Latency for ViT (Dosovitskiy et al., 2020) was privately communicated by the authors.FLOPS count. We do not count zeroed out flops when computing positional lambdas with the einsum implementation from Figure 3 . Flops count is highly dependent on the scope size which is rather large by default (|m|=23x23). In Table 10 , we show that it is possible to significantly reduce the scope size and therefore FLOPS at a minimal degradation in performance.COCO object detection. We employ the architecture from the improved ImageNet training setup as the backbone in the Mask-RCNN architecture. All models are trained on 1024x1024 images from scratch for 130k steps with a batch size of 256 distributed across 128 TPUv3 cores with synchronized batch normalization. We apply multi-scale jitter of [0.1, 2.0] during training. The learning rate is warmed up for 1000 steps from 0 to 0.32 and divided by 10 at steps 90, 95 and 97.5% of training. The weight decay is set to 4e-5.Mobilenet training setup. All mobilenet architectures are trained for 350 epochs on Imagenet with standard preprocessing at 224x224 resolution. We use the same hyperparameters as Howard et al. (2019) . More specifically, we use RMSProp with 0.9 momentum and a batch size of 4096 split across 32 TPUv3 cores. The learning rate is warmed up linearly to 0.1 and then multiplied by 0.99 every 3 epochs. We use a weight decay 1e-5 and dropout with drop probability of 0.2

