LEARNING ASSOCIATIVE INFERENCE USING FAST WEIGHT MEMORY

Abstract

Humans can quickly associate stimuli to solve problems in novel contexts. Our novel neural network model learns state representations of facts that can be composed to perform such associative inference. To this end, we augment the LSTM model with an associative memory, dubbed Fast Weight Memory (FWM). Through differentiable operations at every step of a given input sequence, the LSTM updates and maintains compositional associations stored in the rapidly changing FWM weights. Our model is trained end-to-end by gradient descent and yields excellent performance on compositional language reasoning problems, meta-reinforcement learning for POMDPs, and small-scale word-level language modelling.
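To make the mechanism summarised above concrete, the following is a minimal, simplified sketch rather than the paper's exact FWM equations: an LSTM controller produces a key and a value at every step, adds their outer product to a rapidly changing fast weight matrix, and queries that matrix to retrieve earlier associations. The module names, shapes, and the rank-one write rule are illustrative assumptions of ours, not the architecture as specified in the paper.

```python
import torch
import torch.nn as nn

class FastWeightMemorySketch(nn.Module):
    """Simplified sketch: an LSTM controller with an outer-product fast weight memory."""

    def __init__(self, input_size, hidden_size, mem_size):
        super().__init__()
        # the controller also receives the previous memory read-out as input
        self.lstm = nn.LSTMCell(input_size + mem_size, hidden_size)
        self.to_write = nn.Linear(hidden_size, 2 * mem_size)  # write key and value
        self.to_read = nn.Linear(hidden_size, mem_size)        # read query
        self.mem_size = mem_size

    def forward(self, x_seq):  # x_seq: (seq_len, batch, input_size)
        batch = x_seq.size(1)
        h = x_seq.new_zeros(batch, self.lstm.hidden_size)
        c = torch.zeros_like(h)
        r = x_seq.new_zeros(batch, self.mem_size)                  # last read-out
        F = x_seq.new_zeros(batch, self.mem_size, self.mem_size)   # fast weights
        outputs = []
        for x in x_seq:
            h, c = self.lstm(torch.cat([x, r], dim=-1), (h, c))
            k, v = self.to_write(h).chunk(2, dim=-1)
            F = F + torch.einsum('bi,bj->bij', v, k)   # Hebb-like outer-product write
            q = self.to_read(h)
            r = torch.einsum('bij,bj->bi', F, q)       # associative read
            outputs.append(h)
        return torch.stack(outputs), F
```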

1. INTRODUCTION

Humans continually adapt in order to understand new situations in changing environments. One important adaptive ability is associative inference: composing features extracted from distinct experiences and relating them to each other (Schlichting & Preston, 2015; Gershman et al., 2015). Suppose Alice has shared with you pictures of her toddler. Later, at the office party, you see a man carrying the depicted toddler. Since the toddler is a feature shared between the two contexts, it is plausible to infer that the man is Alice's partner, without ever seeing him and Alice together. The ability to rapidly associate and bind together novel stimuli can help to derive knowledge systematically, in addition to the knowledge gained directly from observation.

Virtually all modern cognitive architectures applied to challenging artificial intelligence problems are based on deep artificial neural networks (NNs). Despite their empirical successes and theoretical generality, NNs tend to struggle to generalise in situations similar to the given example (Lake et al., 2017; Phillips, 1995; Lake & Baroni, 2017). This weakness becomes even more severe if the training and test data exhibit systematic differences (Atzmon et al., 2016; Agrawal et al., 2017). For example, during training, the man's representation might never be associated with the toddler's, but during testing, this association might be necessary to make a useful prediction. In problems where humans excel, this sort of inference is likely ubiquitous, since the data is often combinatorially complex: the observations available during training cover only a small fraction of all possible compositions. Such a lack of productivity and systematicity is a long-standing argument against the use of NNs as a substrate of an artificial cognitive architecture (Fodor & Pylyshyn, 1988; Hadley, 1994; McLaughlin, 2009).

The hidden state of a neural model is a learned representation of the task-relevant information extracted from the input. To generalise to never-before-seen compositions of stimuli, the function which produces the state representation must be able to systematically construct all possible states. This requires a general and preferably differentiable method, such as the Tensor Product Representation (TPR; Smolensky (1990)). TPRs provide a general and differentiable method for embedding symbolic structures in vector spaces.
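As a concrete illustration of the role-filler binding that TPRs are built on, here is a small numpy sketch. The orthonormal role vectors and all variable names are illustrative assumptions of ours, not details taken from the paper: a TPR is formed by summing filler-role outer products, and a filler is recovered by multiplying the TPR with the corresponding role vector.

```python
import numpy as np

d = 4
roles = np.eye(d)                    # r_1 ... r_d: orthonormal role vectors (assumption)
fillers = np.random.randn(d, d)      # f_1 ... f_d: arbitrary filler vectors

# binding: the TPR is the sum of filler-role outer products, T = sum_i f_i r_i^T
T = sum(np.outer(fillers[i], roles[i]) for i in range(d))

# unbinding: with orthonormal roles, multiplying T by a role vector
# recovers exactly the filler that was bound to it
recovered = T @ roles[2]
assert np.allclose(recovered, fillers[2])
```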



Source code and data used in this paper are available at github.com/ischlag/Fast-Weight-Memory-public

