DISENTANGLING REPRESENTATIONS OF TEXT BY MASKING TRANSFORMERS

Abstract

Representations from large pretrained models such as BERT encode a range of features in a single vector, affording strong predictive accuracy on a multitude of downstream tasks. In this paper we explore whether it is possible to learn disentangled representations by identifying existing subnetworks within pretrained models that encode distinct, complementary aspect representations. Concretely, we learn binary masks over transformer weights or hidden units to uncover the subset of features that correlate with a specific factor of variation; this eliminates the need to train a disentangled model from scratch for a particular domain. We evaluate the ability of this method to disentangle representations of syntax and semantics, and to disentangle sentiment from genre in the context of movie reviews. By combining masking with magnitude pruning we find that we can identify sparse subnetworks within BERT that strongly encode particular aspects (e.g., movie sentiment) while only weakly encoding others (e.g., movie genre). Moreover, despite only learning masks, we find that disentanglement-via-masking performs as well as, and often better than, previously proposed methods based on variational autoencoders and adversarial training.

1. INTRODUCTION AND MOTIVATION

Large-scale pretrained models such as ELMo (Peters et al., 2018), BERT (Devlin et al., 2018), and XLNet (Yang et al., 2019) have come to dominate modern natural language processing (NLP). Such models rely on self-supervision over large datasets to learn general-purpose representations of text that achieve competitive predictive performance across a spectrum of downstream tasks (Liu et al., 2019). A downside of such learned representations is that it is not obvious what information they encode, which hinders model robustness and interpretability. The opacity of representations produced by models such as BERT has motivated a line of NLP research on designing probing tasks as a means of uncovering what properties of input texts are encoded into token- and sentence-level representations (Rogers et al., 2020; Linzen et al., 2019; Tenney et al., 2019).

In this paper we investigate whether we can uncover disentangled representations from pretrained models. That is, rather than mapping inputs onto a single vector that captures arbitrary combinations of features, our aim is to extract a representation that factorizes into distinct, complementary properties of the input. An advantage of explicitly factorized representations is that they aid interpretability, in the sense that it becomes more straightforward to determine what factors of variation inform predictions in downstream tasks. Further, disentangled representations may facilitate increased robustness under distributional shifts by capturing a notion of invariance: if syntactic changes in a sentence do not affect the representation of semantic features, and vice versa, then we can hope to learn models that are less sensitive to any incidental correlations between these factors.
A general motivation for learning disentangled representations is to try to minimize, or at least expose, model reliance on spurious correlations, i.e., relationships between (potentially sensitive) attributes and labels that exist in the training data but which are not causally linked (Kaushik et al., 2020). This is particularly relevant in the context of large pretrained models like BERT, as we do not know what the representations produced by such models encode. To date, most research on disentangled representations has focused on applications in computer vision (Locatello et al., 2019b; Kulkarni et al., 2015; Chen et al., 2016; Higgins et al., 2016), where there exist comparatively clear independent factors of variation such as size, position, and color.

In this paper we ask whether complementary factors of variation might already be captured by pretrained models, and whether it is possible to uncover these by identifying appropriate "subnetworks". The intuition for this hypothesis is that generalization across a sufficiently large and diverse training set may implicitly necessitate representations that admit some notion of invariance, as the many factors of variation in the training data give rise to a combinatorial explosion of possible inputs. Intriguing prior work examining the correlation between sentiment and individual nodes within pretrained networks (Radford et al., 2017) offers some additional support for this intuition. To test this hypothesis, we propose to use masking as a mechanism to isolate representations of individual factors. Recent work on lottery tickets (Frankle & Carbin, 2018) suggests that overparameterized networks are redundant, in that a network reduced to a small subset of weights set to "winning" initial values can achieve predictive performance similar to the full network.
Building on this intuition, we hypothesize that it may be possible to uncover a representation for a factor of interest by starting with a pretrained representation and simply masking out weights or hidden units that correlate with other factors of variation. We use BERT (Devlin et al., 2018) as an archetypal pretrained transformer to test two variants of this basic idea. In the first variant we learn binary masks over all weight matrices in the model; in the second we derive masks over all hidden units (intermediate representations). To learn these masks we minimize a triplet loss that encourages instances that are similar with respect to an aspect of interest to have representations that are relatively near one another, independent of other factors.

Our approach of uncovering existing subnetworks within pretrained models that yield disentangled representations differs substantially from prior work on disentangling representations in NLP, which has relied either on adversarial debiasing (Elazar & Goldberg, 2018; Barrett et al., 2019) or on variational autoencoders (Chen et al., 2019; Esmaeili et al., 2019). We evaluate masking in the context of two tasks. In the first, we aim to disentangle a representation of features for a target task from information encoding a secondary, non-target attribute (e.g., sensitive information, or simply an unrelated factor). In the second, we follow prior work in attempting to induce separate representations of semantics and syntax. In both settings, our surprising finding is that masking alone often outperforms previously proposed approaches (which learn or fine-tune networks in their entirety). While a small amount of masking generally suffices to achieve disentanglement, we can further increase sparsity by combining masking with weight pruning. The main contributions of this paper are as follows.
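The triplet objective over masked representations can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the paper's implementation: the paper learns masks by gradient descent over BERT, whereas here the mask, the `triplet_loss` helper, and the toy 8-dimensional representations are all illustrative.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, mask, margin=1.0):
    """Triplet loss measured only on the masked subspace.

    `mask` is a binary vector selecting the hidden units hypothesized to
    encode the aspect of interest; all other units are zeroed out before
    distances are computed.
    """
    a, p, n = anchor * mask, positive * mask, negative * mask
    d_pos = np.linalg.norm(a - p)  # anchor should sit near the positive
    d_neg = np.linalg.norm(a - n)  # and far from the negative
    return max(0.0, d_pos - d_neg + margin)

# Toy 8-d representations: the first 4 units encode the target aspect,
# the last 4 encode an unrelated factor.
rng = np.random.default_rng(0)
aspect = np.array([1.0, -1.0, 0.5, 0.5])
anchor = np.concatenate([aspect, rng.normal(size=4)])
positive = np.concatenate([aspect + 0.1, rng.normal(size=4)])
negative = np.concatenate([-aspect, rng.normal(size=4)])

# A mask that keeps only the aspect-encoding units: the anchor and
# positive share the aspect, so the loss is driven to zero.
mask = np.array([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
loss = triplet_loss(anchor, positive, negative, mask)  # 0.0 for this toy setup
```

In the actual model the mask entries would be relaxed (e.g., via a continuous or stochastic parameterization) so that the objective is differentiable; the sketch above only shows how masked distances enter the loss.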
(1) We propose a novel method of disentangling representations in NLP: masking weights or hidden units within pretrained transformers (here, BERT). (2) We empirically demonstrate that we can indeed identify subnetworks within pretrained transformers that yield disentangled representations, and that these outperform existing approaches.

Figure 1: Masking weights and hidden activations in BERT. We show a linear layer with weights W, inputs h, and outputs h′. We learn a mask for each disentangled factor, which is applied either to the weights W or to the intermediate representations h.
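The two masking variants for a single linear layer can be sketched as follows. This is a minimal NumPy illustration: the masks here are random or hand-picked rather than learned, and real BERT layers also involve biases, attention, and nonlinearities, all omitted.

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 6))  # weights of one linear layer: h' = W h
h = rng.normal(size=6)       # incoming hidden state

# Variant 1: mask individual weights (one binary value per weight),
# analogous to lottery-ticket-style pruning.
weight_mask = (rng.random(W.shape) > 0.5).astype(W.dtype)
h_out_weights = (W * weight_mask) @ h

# Variant 2: mask whole hidden units of the resulting representation
# (one binary value per unit).
unit_mask = np.array([1.0, 0.0, 1.0, 1.0])
h_out_units = (W @ h) * unit_mask
```

The design difference is granularity: a unit mask zeroes an entire output coordinate, while a weight mask can remove individual connections into a unit while keeping that unit active.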

