MULTI-TIMESCALE REPRESENTATION LEARNING IN LSTM LANGUAGE MODELS

Abstract

Language models must capture statistical dependencies between words at timescales ranging from very short to very long. Earlier work has demonstrated that dependencies in natural language tend to decay with distance between words according to a power law. However, it is unclear how this knowledge can be used for analyzing or designing neural network language models. In this work, we derived a theory for how the memory gating mechanism in long short-term memory (LSTM) language models can capture power law decay. We found that unit timescales within an LSTM, which are determined by the forget gate bias, should follow an Inverse Gamma distribution. Experiments then showed that LSTM language models trained on natural English text learn to approximate this theoretical distribution. Further, we found that explicitly imposing the theoretical distribution upon the model during training yielded better language model perplexity overall, with particular improvements for predicting low-frequency (rare) words. Moreover, the explicit multi-timescale model selectively routes information about different types of words through units with different timescales, potentially improving model interpretability. These results demonstrate the importance of careful, theoretically-motivated analysis of memory and timescale in language models.

1. INTRODUCTION

Autoregressive language models are functions that estimate a probability distribution over the next word in a sequence given the past words, p(w_t | w_{t-1}, ..., w_1). This requires capturing statistical dependencies between words over short timescales, where syntactic information likely dominates (Adi et al., 2017; Linzen et al., 2016), as well as long timescales, where semantic and narrative information likely dominate (Zhu et al., 2018; Conneau et al., 2018; Gulordava et al., 2018). Because the number of possible contexts grows exponentially with sequence length, some approaches simplify the problem by ignoring long-range dependencies. Classical n-gram models, for example, assume word w_t is independent of all but the last n-1 words, typically with n = 5 (Heafield, 2011). Hidden Markov models (HMMs) assume that the influence of previous words decays exponentially with distance from the current word (Lin & Tegmark, 2016). In contrast, neural network language models such as recurrent (Hochreiter & Schmidhuber, 1997; Merity et al., 2018; Melis et al., 2018) and transformer networks (Melis et al., 2019; Krause et al., 2019; Dai et al., 2019) include longer-range interactions, but simplify the problem by working in lower-dimensional representational spaces. Attention-based networks combine position- and content-based information in a small number of attention heads to flexibly capture different types of dependencies within a sequence (Vaswani et al., 2017; Cordonnier et al., 2019). Gated recurrent neural networks (RNNs) compress information about past words into a fixed-length state vector (Hochreiter & Schmidhuber, 1997). The influence each word has on this state vector tends to decay exponentially over time.
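To make the n-gram simplification concrete, the following is a minimal count-based estimator of p(w_t | w_{t-1}) under the Markov assumption described above. The toy corpus and function name are purely illustrative (not from this work), and real n-gram models add smoothing such as Kneser-Ney, which this sketch omits:

```python
import collections

def ngram_probs(tokens, n=2):
    """Count-based estimate of p(w_t | w_{t-n+1}, ..., w_{t-1}).

    Returns a dict mapping (context, word) -> conditional probability.
    Illustrates the n-gram independence assumption: the model ignores
    everything before the last n-1 words.
    """
    ctx_counts = collections.Counter()
    joint_counts = collections.Counter()
    for i in range(n - 1, len(tokens)):
        ctx = tuple(tokens[i - n + 1:i])
        ctx_counts[ctx] += 1
        joint_counts[(ctx, tokens[i])] += 1
    return {(ctx, w): count / ctx_counts[ctx]
            for (ctx, w), count in joint_counts.items()}

corpus = "the cat sat on the mat the cat ran".split()
probs = ngram_probs(corpus, n=2)
# p(sat | cat) and p(ran | cat) are each estimated from raw counts
```

Any dependency longer than n-1 words is invisible to such a model, which is exactly the limitation the power-law analysis below addresses.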
However, each element of the state vector can have a different exponential time constant, or "timescale" (Tallec & Ollivier, 2018), enabling gated RNNs like the long short-term memory (LSTM) network to flexibly learn many different types of temporal relationships (Hochreiter & Schmidhuber, 1997). Stacked LSTM networks can be reduced to a single layer (Turek et al., 2020), suggesting that network depth has little influence on how the LSTM captures temporal relationships. Yet in all these networks the shape of the temporal dependencies must be learned directly from the data. This seems particularly problematic for very long-range dependencies, which are only sparsely informative (Lin & Tegmark, 2016). This raises two related questions: what should the temporal dependencies in a language model look like? And how can that information be incorporated into a neural network language model?

To answer the first question, we look to empirical and theoretical work that has explored the dependency statistics of natural language. Lin & Tegmark (2016) quantified temporal dependencies in English and French language corpora by measuring the mutual information between tokens as a function of the distance between them. They observed that mutual information decays as a power law, i.e. MI(w_k, w_{k+t}) ∝ t^{-d} for some constant d. This behavior is common to hierarchically structured natural languages (Lin & Tegmark, 2016; Sainburg et al., 2019) as well as sequences generated from probabilistic context-free grammars (PCFGs) (Lin & Tegmark, 2016).

Now to the second question: if temporal dependencies in natural language follow a power law, how can this information be incorporated into neural network language models? To our knowledge, little work has explored how to control the temporal dependencies learned in attention-based models.
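The mutual-information measurement described above can be sketched with a simple plug-in estimator. This is illustrative only (the sequence and names are ours): reliable MI estimates on real corpora require large samples and bias correction, which this sketch omits:

```python
import collections
import math

def mutual_information(tokens, t):
    """Plug-in estimate of MI between tokens at distance t, in bits:
    MI = sum_{a,b} p(a, b) * log2( p(a, b) / (p(a) p(b)) ).
    """
    pairs = list(zip(tokens, tokens[t:]))
    n = len(pairs)
    p_ab = collections.Counter(pairs)            # joint counts at lag t
    p_a = collections.Counter(a for a, _ in pairs)
    p_b = collections.Counter(b for _, b in pairs)
    mi = 0.0
    for (a, b), count in p_ab.items():
        # (count/n) / ((p_a/n) * (p_b/n)) simplifies to count*n/(p_a*p_b)
        mi += (count / n) * math.log2(count * n / (p_a[a] * p_b[b]))
    return mi

tokens = list("ab" * 500)                 # perfectly periodic toy sequence
mi_lag1 = mutual_information(tokens, 1)   # ≈ 1 bit: each token determines the next
```

Lin & Tegmark's observation is that for natural-language corpora this curve falls off as a power of t rather than exponentially, which is what distinguishes it from the decay an HMM can represent.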
However, many approaches have been proposed for controlling gated RNNs, including updating different groups of units at different intervals (El Hihi & Bengio, 1996; Koutnik et al., 2014; Liu et al., 2015; Chung et al., 2017), gating units across layers (Chung et al., 2015), and explicitly controlling the input and forget gates that determine how information is stored in and removed from memory (Xu et al., 2016; Shen et al., 2018; Tallec & Ollivier, 2018). Yet none of these proposals incorporates a specific shape of temporal dependencies based on the known statistics of natural language.

In this work, we build on the framework of Tallec & Ollivier (2018) to develop a theory for how the memory mechanism in LSTM language models can capture temporal dependencies that follow a power law. This relies on defining the timescale of an individual LSTM unit based on how the unit retains and forgets information. We show that this theory predicts the distribution of unit timescales for LSTM models trained on both natural English (Merity et al., 2018) and formal languages (Suzgun et al., 2019). Further, we show that forcing models to follow this theoretical distribution improves language modeling performance. These results highlight the importance of combining theoretical modeling with an understanding of how language models capture temporal dependencies over multiple scales.
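As one example of explicit gate control in the spirit of Tallec & Ollivier (2018), a unit's timescale T can be set through its forget-gate bias, since T ≈ 1/(1 − σ(b_f)) implies b_f = log(T − 1). The sketch below draws timescales from a heavy-tailed (Inverse-Gamma-shaped) distribution and inverts this relation; the shape/scale parameters and the shift ensuring T > 1 are illustrative placeholders of our own, not values fitted in this work:

```python
import numpy as np

def forget_bias_for_timescale(T):
    """Invert T = 1/(1 - sigmoid(b_f)): a unit whose forget gate sits at
    sigmoid(b_f) = 1 - 1/T retains information with time constant ~ T
    (chrono-style initialization; Tallec & Ollivier, 2018).
    Requires T > 1 so that log(T - 1) is defined.
    """
    T = np.asarray(T, dtype=float)
    return np.log(T - 1.0)

rng = np.random.default_rng(0)
# If X ~ Gamma(a, scale=s), then 1/X is Inverse-Gamma distributed.
# The +1 shift (our choice, for illustration) keeps every timescale
# strictly above one step.
timescales = 1.0 + 1.0 / rng.gamma(shape=2.0, scale=1.0, size=256)
b_f = forget_bias_for_timescale(timescales)   # one bias per hidden unit
```

Fixing these biases (rather than learning them) is one way to impose a chosen timescale distribution on the network before training begins.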

2.1. TIMESCALE OF INFORMATION

We are interested in understanding how LSTM language models capture dependencies across time. Tallec & Ollivier (2018) elegantly argued that memory in individual LSTM units tends to decay exponentially with a time constant determined by weights within the network. We refer to the time constant of that exponential decay as the unit's representational timescale. Timescale is directly related to the LSTM memory mechanism (Hochreiter & Schmidhuber, 1997), which involves the LSTM cell state c_t, input gate i_t, and forget gate f_t:

i_t = σ(W_ix x_t + W_ih h_{t-1} + b_i)
f_t = σ(W_fx x_t + W_fh h_{t-1} + b_f)
c̃_t = tanh(W_cx x_t + W_ch h_{t-1} + b_c)
c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t,

where σ is the logistic sigmoid and ⊙ denotes elementwise multiplication.
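These update equations can be written out as a single NumPy step. The sketch below is our own minimal implementation (the packed parameter layout is an arbitrary illustrative choice, not tied to any library); the demo at the end shows the decay behavior the theory relies on: with inputs switched off, the cell state shrinks by a factor σ(b_f) per step, giving timescale T ≈ 1/(1 − σ(b_f)):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One step of the LSTM memory update from the equations above.
    W stacks the four weight blocks [input, forget, candidate, output],
    each acting on the concatenation [x_t, h_{t-1}].
    """
    n = h_prev.size
    z = W @ np.concatenate([x_t, h_prev]) + b
    i_t = sigmoid(z[:n])                  # input gate
    f_t = sigmoid(z[n:2 * n])             # forget gate
    c_tilde = np.tanh(z[2 * n:3 * n])     # candidate cell update
    o_t = sigmoid(z[3 * n:])              # output gate
    c_t = f_t * c_prev + i_t * c_tilde    # gated memory update
    h_t = o_t * np.tanh(c_t)
    return h_t, c_t

# Demo: with zero weights and inputs, memory decays as sigmoid(b_f)**t.
n, d = 4, 3
W = np.zeros((4 * n, d + n))
b = np.zeros(4 * n)
b[:n] = -20.0      # input gate ~ 0: nothing new is written to memory
b[n:2 * n] = 2.0   # forget gate sigmoid(2) ~ 0.88: timescale ~ 8 steps
h, c = np.zeros(n), np.ones(n)
for _ in range(10):
    h, c = lstm_step(np.zeros(d), h, c, W, b)
```

Because the decay rate is set by the forget gate, the forget-gate bias is the natural handle for analyzing (or prescribing) each unit's timescale.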

