ENERGY TRANSFORMER

Abstract

Transformers have become the de facto models of choice in machine learning, typically achieving impressive performance on many applications. At the same time, architectural development in the transformer world is mostly driven by empirical findings, and the theoretical understanding of the architectural building blocks remains rather limited. In contrast, Dense Associative Memory models, or Modern Hopfield Networks, have a well-established theoretical foundation, but have not yet demonstrated truly impressive practical results. We propose a transformer architecture that replaces the sequence of feedforward transformer blocks with a single large Associative Memory model. Our novel architecture, called the Energy Transformer (or ET for short), has many of the familiar architectural primitives that are often used in the current generation of transformers. However, it is not identical to the existing architectures. The sequence of transformer layers in ET is purposely designed to minimize a specifically engineered energy function, which is responsible for representing the relationships between the tokens. As a consequence of this computational principle, the attention in ET differs from the conventional attention mechanism. In this work, we introduce the theoretical foundations of ET, explore its empirical capabilities using the image completion task, and obtain strong quantitative results on the graph anomaly detection task.

1. INTRODUCTION

Transformers have become pervasive models in various domains of machine learning, including language, vision, and audio processing. Every transformer block uses four fundamental operations: attention, a feed-forward multi-layer perceptron (MLP), the residual connection, and layer normalization. Different variations of transformers result from combining these four operations in various ways. For instance, Press et al. (2019) propose to frontload additional attention operations and backload additional MLP layers in a sandwich-like rather than interleaved arrangement, Lu et al. (2019) prepend an MLP layer before the attention in each transformer block, So et al. (2019) use neural architecture search methods to evolve even more sophisticated transformer blocks, and so on. There also exist various methods for approximating the attention operation, along with multiple modifications of the normalization operation and of the block's connectivity; see, for example, (Lin et al., 2021) for a taxonomy of different models. At present, however, the search for new transformer architectures is driven mostly by empirical evaluations, and the theoretical principles behind this growing list of architectural variations are missing. Additionally, the computational role of the four elements remains the subject of discussion. Originally, Vaswani et al. (2017) emphasized attention as the most important part of the transformer block, arguing that learnable long-range dependencies are more powerful than the local inductive biases of convolutional networks. On the other hand, more recent investigations (Yu et al., 2021) argue that the entire transformer block is important. The "correct" way to combine the four basic operations inside the block remains unclear, as does an understanding of the core computational function of the entire block and each of its four elements.
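To make the four operations concrete, here is a minimal, hypothetical single-head transformer block in NumPy. The pre-norm arrangement, dimensions, and weight initialization are illustrative assumptions for the sketch, not a specific published architecture:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token vector to zero mean and unit variance.
    mu = x.mean(-1, keepdims=True)
    sigma = x.std(-1, keepdims=True)
    return (x - mu) / (sigma + eps)

def attention(x, Wq, Wk, Wv):
    # Single-head self-attention over a sequence of token vectors.
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores -= scores.max(-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(-1, keepdims=True)  # row-wise softmax
    return weights @ v

def mlp(x, W1, W2):
    # Two-layer feed-forward network with ReLU, applied token-wise.
    return np.maximum(x @ W1, 0.0) @ W2

def transformer_block(x, params):
    # One possible "pre-norm" composition of the four operations:
    # residual + attention, then residual + MLP, each on normalized input.
    x = x + attention(layer_norm(x), params["Wq"], params["Wk"], params["Wv"])
    x = x + mlp(layer_norm(x), params["W1"], params["W2"])
    return x

rng = np.random.default_rng(0)
D, H, N = 16, 32, 5  # token dim, hidden dim, sequence length (illustrative)
params = {
    "Wq": rng.normal(0, 0.1, (D, D)),
    "Wk": rng.normal(0, 0.1, (D, D)),
    "Wv": rng.normal(0, 0.1, (D, D)),
    "W1": rng.normal(0, 0.1, (D, H)),
    "W2": rng.normal(0, 0.1, (H, D)),
}
tokens = rng.normal(size=(N, D))
out = transformer_block(tokens, params)
print(out.shape)  # (5, 16)
```

Reordering the lines inside `transformer_block` is exactly the kind of empirical variation (sandwich layouts, MLP-first blocks, etc.) that the architectures cited above explore.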
In a seemingly unrelated line of work, Associative Memory models, also known as Hopfield Networks (Hopfield, 1982; 1984), have been gaining popularity in the machine learning community thanks to theoretical advancements pertaining to their memory storage capacity and novel architectural modifications. Specifically, it has been shown that increasing the sharpness of the activation functions can lead to super-linear (Krotov & Hopfield, 2016) and even exponential (Demircigil et al., 2017) memory storage capacity for these models, which is important for machine learning applications. This new class of Hopfield Networks is called Dense Associative Memories or Modern Hopfield Networks.
There are high-level conceptual similarities between transformers and Dense Associative Memories, since both architectures are designed for some form of denoising of the input. Transformers are typically pre-trained on a masked-token task: e.g., in the domain of Natural Language Processing (NLP), certain tokens in a sentence are masked and the model predicts the masked tokens. Dense Associative Memory models are designed for completing incomplete patterns. For instance, a pattern can be the concatenation of an image and its label, and the model can be trained to predict the part of the input (the label) that is masked, given the query (the image). They can also be trained in a self-supervised way by predicting the occluded parts of an image, or by denoising the image. There are also high-level differences between the two classes of models. Associative Memories are recurrent networks with a global energy function, so that the network dynamics converge to a fixed-point attractor state corresponding to a local minimum of the energy function. Transformers are typically not described as dynamical systems at all. Rather, they are thought of as feed-forward networks built of the four computational elements discussed above. Even if one thinks about them as dynamical systems with tied weights, e.g., (Bai et al., 2019), there is no reason to expect that their dynamics converge to a fixed-point attractor (see the discussion in (Lan et al., 2020)).
Additionally, a recent study (Yang et al., 2022) uses a form of Majorization-Minimization algorithms (Sun et al., 2016) to interpret the forward path in the transformer block as an optimization process. This interpretation requires imposing certain constraints on the operations inside the block and attempting to find an energy function that describes the constrained block. We take a complementary



Figure 1: Overview of the Energy Transformer (ET). Instead of a sequence of conventional transformer blocks, a single recurrent ET block is used. The operation of this block is dictated by the global energy function. The token representations are updated according to a continuous time differential equation; the time-discretized update step, with step size α = dt/τ, reads x_{t+1} = x_t − α ∇_g E_t. On the image domain, images are split into non-overlapping patches that are linearly encoded into tokens with added learnable positional embeddings (POS). Some patches are randomly masked. These tokens are recurrently passed through ET, and each iteration reduces the energy of the set of tokens. The token representations at or near the fixed point are then decoded using the decoder network to obtain the reconstructed image. The network is trained by minimizing the mean squared error loss between the reconstructed image and the original image. On the graph domain, the same general pipeline is used. Each token represents a node, and each node has its own positional encoding. The token representations at or near the fixed point are used for the prediction of the anomaly status of each node.
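The energy-descent view of Associative Memories can be illustrated with a small, hypothetical Dense Associative Memory sketch in NumPy. The particular energy function here (a quadratic term minus a log-sum-exp over stored patterns, in the style of modern Hopfield networks), the inverse temperature `beta`, and the step size `alpha` are illustrative assumptions, not the ET energy introduced in this paper; the loop mirrors the generic discretized form x_{t+1} = x_t − α ∇E:

```python
import numpy as np

def energy(x, patterns, beta=4.0):
    # Hypothetical dense associative memory energy:
    # E(x) = 0.5 * ||x||^2 - (1/beta) * logsumexp(beta * patterns @ x)
    s = beta * (patterns @ x)
    m = s.max()  # shift for numerical stability of logsumexp
    return 0.5 * (x @ x) - (m + np.log(np.exp(s - m).sum())) / beta

def energy_grad(x, patterns, beta=4.0):
    # Gradient of the energy above: x minus a softmax-weighted
    # combination of the stored patterns.
    s = beta * (patterns @ x)
    p = np.exp(s - s.max())
    p /= p.sum()
    return x - patterns.T @ p

def retrieve(query, patterns, alpha=0.5, steps=50, beta=4.0):
    # Discretized gradient descent on the energy; each step can only
    # lower the energy, so the state settles near a local minimum.
    x = query.copy()
    for _ in range(steps):
        x = x - alpha * energy_grad(x, patterns, beta)
    return x

rng = np.random.default_rng(1)
patterns = rng.choice([-1.0, 1.0], size=(3, 8))  # three stored +-1 patterns
query = patterns[0].copy()
query[4:] = 0.0  # mask half of the entries, mimicking an incomplete pattern
restored = retrieve(query, patterns)
print(energy(query, patterns), energy(restored, patterns))
```

The descent dynamics complete the masked half of the query by pulling the state toward the nearest stored pattern, which is the pattern-completion behavior described above.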

