LEARNING ASSOCIATIVE INFERENCE USING FAST WEIGHT MEMORY

Abstract

Humans can quickly associate stimuli to solve problems in novel contexts. Our novel neural network model learns state representations of facts that can be composed to perform such associative inference. To this end, we augment the LSTM model with an associative memory, dubbed Fast Weight Memory (FWM). Through differentiable operations at every step of a given input sequence, the LSTM updates and maintains compositional associations stored in the rapidly changing FWM weights. Our model is trained end-to-end by gradient descent and yields excellent performance on compositional language reasoning problems, meta-reinforcement learning for POMDPs, and small-scale word-level language modelling.¹

1. INTRODUCTION

Humans continually adapt in order to understand new situations in changing environments. One important adaptive ability is associative inference: composing features extracted from distinct experiences and relating them to each other (Schlichting & Preston, 2015; Gershman et al., 2015). Suppose Alice has shared with you pictures of her toddler. Later, at the office party, you see a man carrying the depicted toddler. Since the toddler is a shared feature of two different contexts, it is plausible to infer that the man is Alice's partner, without ever having seen him and Alice together. The ability to rapidly associate and bind together novel stimuli can help to derive knowledge systematically, in addition to the knowledge gained directly from observation.

Virtually all modern cognitive architectures applied to challenging artificial intelligence problems are based on deep artificial neural networks (NNs). Despite their empirical successes and theoretical generality, NNs tend to struggle to generalise in situations like the example above (Lake et al., 2017; Phillips, 1995; Lake & Baroni, 2017). This weakness becomes even more severe if the training and test data exhibit systematic differences (Atzmon et al., 2016; Agrawal et al., 2017). For example, during training, the man's representation might never be associated with the toddler's, but during testing, this association might be necessary to make a useful prediction. In problems where humans excel, this sort of inference is likely ubiquitous, since the data is often combinatorially complex and the observations available during training cover only a small fraction of all possible compositions. Such a lack of productivity and systematicity is a long-standing argument against the use of NNs as a substrate of an artificial cognitive architecture (Fodor & Pylyshyn, 1988; Hadley, 1994; McLaughlin, 2009).

The hidden state of a neural model is a learned representation of the task-relevant information extracted from the input. To generalise to never-seen-before compositions of stimuli, the function which produces the state representation must be able to systematically construct all possible states. This requires a general and preferably differentiable method, such as the Tensor Product Representation (TPR; Smolensky (1990)). TPRs provide a general and differentiable method for embedding symbolic structures in vector spaces. A TPR state representation is constructed via the tensor product (i.e. the generalised outer product) of learned component representations. Under certain constraints, such a mechanism guarantees a unique representation for every possible combination of components (Smolensky, 1990; 2012). A minimal illustrative sketch of this binding mechanism is given at the end of this section.

In this work, we augment a recurrent NN (RNN) with an additional TPR-like memory representation. To facilitate the learning of multi-step associative inference, the TPR memory can be queried multiple times in a row, allowing the model to chain together various independent associations. In contrast to previous work on fast weights, we apply our memory-augmented RNN to much longer sequences, which requires the model to continually update its associative memory. Furthermore, we demonstrate the generality of our method by applying it to meta-reinforcement learning and small-scale language modelling problems. In the next section, we cover related memory-augmented NNs. Section 3 describes the FWM in detail. Section 4 demonstrates the generality of our method through experiments in the supervised, self-supervised, and meta-reinforcement learning settings. The supervised-learning experiments in Section 4.1 are based on a more challenging version of the bAbI dataset dubbed concatenated-bAbI or catbAbI. The meta-reinforcement learning experiment in Section 4.2 demonstrates the FWM's ability to learn to explore a partially observable environment by performing associative inference. Finally, the self-supervised experiments in Section 4.3 demonstrate that the FWM can compete with state-of-the-art word-level language models on small benchmark datasets.
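To make the TPR binding idea above concrete, the following minimal sketch (an illustration of outer-product binding in general, not the specific FWM update rule of Section 3) stores two role-filler pairs by superimposing their outer products and retrieves a filler by multiplying the memory with a role vector. All dimensions and vectors here are hypothetical placeholders; in our model the component representations are produced by a trained network.

```python
# Illustrative sketch of TPR-style outer-product binding and unbinding.
# This is NOT the FWM of Section 3; it only demonstrates the general mechanism.
import numpy as np

rng = np.random.default_rng(0)
d = 64  # hypothetical component dimension

# Two role-filler pairs; in the paper these would be learned representations.
role_a, filler_a = rng.standard_normal(d), rng.standard_normal(d)
role_b, filler_b = rng.standard_normal(d), rng.standard_normal(d)

# Bind: each association is a tensor (outer) product; the memory is their sum.
memory = np.outer(role_a, filler_a) + np.outer(role_b, filler_b)  # shape (d, d)

# Unbind: querying with a role approximately recovers the bound filler,
# up to crosstalk from the other stored pair (exact if the roles are orthonormal).
retrieved = role_a @ memory
cosine = retrieved @ filler_a / (np.linalg.norm(retrieved) * np.linalg.norm(filler_a))
print(f"cosine similarity between retrieved vector and filler_a: {cosine:.3f}")
```

With random high-dimensional roles the retrieval is only approximate; the constraints mentioned above (e.g. orthonormal role vectors) are what guarantee a unique, exactly recoverable representation for every combination of components.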

2. RELATED WORK

RNNs such as the Long Short-Term Memory (LSTM; Hochreiter & Schmidhuber (1997); Gers et al. (2000)) are in theory capable of implementing any algorithm (Siegelmann & Sontag, 1991). However, the linear growth of the hidden state of a fully connected RNN leads to quadratic growth in the number of trainable weights. Early work addressed this issue through the use of additional memory (Das et al., 1992; Mozer & Das, 1993) and differentiable fast weights (Schmidhuber, 1992; 1993). Recently, memory-augmented NNs have solved algorithmic toy problems (Graves et al., 2014; 2016) as well as reasoning and inference problems in synthetic and natural language (Weston et al., 2015b; Xiong et al., 2016). Inspired by the random-access memory of computer architectures, a common approach is to incorporate a soft and differentiable lookup table into the NN model. Such slot-based memory matrices have been shown to be difficult to train (Munkhdalai & Yu, 2017b) and to require sophisticated mechanisms for the allocation and deallocation of memory (Csordas & Schmidhuber, 2019). The Transformer-XL (TXL; Dai et al. (2019)), an autoregressive language model variant of the Transformer (Vaswani et al., 2017), can be understood as a slot-based memory-augmented RNN where every new state is pushed into an immutable queue of finite size. Although it is recurrent, the layers of a Transformer architecture are strictly forced to use inputs from a lower layer, which limits its generality. Nevertheless, a sufficiently deep and well-regularised TXL model has achieved state-of-the-art performance in large-scale language modelling tasks.

A biologically more plausible alternative for increasing the memory capacity of NNs is fast-changing weights, i.e. stateful weights that can adapt as a function of their input. Non-differentiable fast weights or "dynamic links" date back to at least 1981 (von der Malsburg, 1981; Feldman, 1982; Hinton & Plaut, 1987). Subsequent work showed that a regular network can be trained by gradient descent to control the fast weights of a separate network (Schmidhuber, 1992) or of itself (Schmidhuber, 1993) in an end-to-end differentiable fashion. Recently, fast weights have made a comeback and achieved good results in small toy problems where regular NNs fall short (Ba et al., 2016a; Schlag & Schmidhuber, 2017; Munkhdalai & Yu, 2017a; Pritzel et al., 2017; Ha et al., 2017; Zhang & Zhou, 2017; Miconi et al., 2018; 2019; Schlag & Schmidhuber, 2018; Munkhdalai et al., 2019; Bartunov et al., 2020). A minimal sketch of such a fast-weight mechanism is given at the end of this section.

Most memory-augmented NNs are based on content-based or key-based lookup mechanisms. An alternative to the storage of patterns in a lookup table is the idea that patterns are reconstructed through the implicit iterative minimisation of an energy function, such as in the classical Hopfield network (Steinbuch, 1961; Willshaw et al., 1969; Hopfield, 1982; Kanerva, 1988) or the modern Hopfield network (Krotov & Hopfield, 2016; Demircigil et al., 2017; Ramsauer et al., 2020). This is
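As a concrete illustration of the fast-weight mechanism discussed above, here is a minimal sketch assuming a simple outer-product (Hebbian-like) update with a decay factor. It is not the FWM update of Section 3, and the projection matrices, dimensions, and decay value are hypothetical placeholders; in a real model the slow weights would be trained by gradient descent.

```python
# Minimal sketch of fast weights: slow weights emit key/value/query vectors,
# a fast weight matrix is rewritten at every step, and reading is a
# matrix-vector product. Illustration only; not the FWM of Section 3.
import numpy as np

rng = np.random.default_rng(1)
d_in, d_mem = 32, 16  # hypothetical input and memory dimensions

# "Slow" weights: fixed here for brevity, learned by gradient descent in practice.
W_k = 0.1 * rng.standard_normal((d_mem, d_in))
W_v = 0.1 * rng.standard_normal((d_mem, d_in))
W_q = 0.1 * rng.standard_normal((d_mem, d_in))

F = np.zeros((d_mem, d_mem))  # fast weights: stateful, changed at every step
decay = 0.95                  # assumed forgetting factor that keeps F bounded

def fast_weight_step(x, F):
    """Write the association key -> value into F, then read F with a query."""
    k, v, q = np.tanh(W_k @ x), np.tanh(W_v @ x), np.tanh(W_q @ x)
    F = decay * F + np.outer(v, k)  # rapid, input-dependent weight change
    read = F @ q                    # retrieval by matrix-vector product
    return read, F

for t in range(5):
    read, F = fast_weight_step(rng.standard_normal(d_in), F)
print("fast-weight matrix norm after 5 steps:", np.linalg.norm(F))
```

The point of the sketch is only that the weight matrix F itself carries state between time steps, in contrast to slot-based memories, where the state lives in an explicit lookup table.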



¹ Source code and data used in this paper are available at github.com/ischlag/Fast-Weight-Memory-public

