THE GEOMETRY OF INTEGRATION IN TEXT CLASSIFICATION RNNS

Abstract

Despite the widespread application of recurrent neural networks (RNNs), a unified understanding of how RNNs solve particular tasks remains elusive. In particular, it is unclear what dynamical patterns arise in trained RNNs, and how those patterns depend on the training dataset or task. This work addresses these questions in the context of text classification, building on earlier work studying the dynamics of binary sentiment-classification networks (Maheswaranathan et al., 2019). We study text-classification tasks beyond the binary case, exploring the dynamics of RNNs trained on both natural and synthetic datasets. These dynamics, which we find to be both interpretable and low-dimensional, share a common mechanism across architectures and datasets: specifically, these text-classification networks use low-dimensional attractor manifolds to accumulate evidence for each class as they process the text. The dimensionality and geometry of the attractor manifold are determined by the structure of the training dataset, with the dimensionality reflecting the number of scalar quantities the network remembers in order to classify. In categorical classification, for example, we show that this dimensionality is one less than the number of classes. Correlations in the dataset, such as those induced by ordering, can further reduce the dimensionality of the attractor manifold; we show how to predict this reduction using simple word-count statistics computed on the training dataset. To the degree that integration of evidence towards a decision is a common computational primitive, this work continues to lay the foundation for using dynamical systems techniques to study the inner workings of RNNs.

1. INTRODUCTION

Modern recurrent neural networks (RNNs) can achieve strong performance in natural language processing (NLP) tasks such as sentiment analysis, document classification, language modeling, and machine translation. However, the inner workings of these networks remain largely mysterious. As RNNs are parameterized dynamical systems tuned to perform specific tasks, a natural way to understand them is to leverage tools from dynamical systems analysis. A challenge inherent to this approach is that the state space of modern RNN architectures, whose dimension is the number of units comprising the hidden state, is often high-dimensional, with layers routinely comprising hundreds of neurons. This dimensionality renders the application of standard representation techniques, such as phase portraits, difficult. Another difficulty arises from the fact that RNNs are monolithic systems trained end-to-end. Rather than comprising modular components with clearly delineated responsibilities that can be understood and tested independently, a neural network can learn an intertwined blend of the different mechanisms needed to solve a task, making it that much harder to understand.

