DISENTANGLING REPRESENTATIONS OF TEXT BY MASKING TRANSFORMERS

Abstract

Representations from large pretrained models such as BERT encode a range of features in a single vector, affording strong predictive accuracy on a multitude of downstream tasks. In this paper we explore whether it is possible to learn disentangled representations by identifying existing subnetworks within pretrained models that encode distinct, complementary aspect representations. Concretely, we learn binary masks over transformer weights or hidden units to uncover the subset of features that correlate with a specific factor of variation; this eliminates the need to train a disentangled model from scratch for a particular domain. We evaluate the ability of this method to disentangle representations of syntax from semantics, and of sentiment from genre in the context of movie reviews. By combining masking with magnitude pruning, we find that we can identify sparse subnetworks within BERT that strongly encode particular aspects (e.g., movie sentiment) while only weakly encoding others (e.g., movie genre). Moreover, despite only learning masks, we find that disentanglement-via-masking performs as well as, and often better than, previously proposed methods based on variational autoencoders and adversarial training.
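The core mechanism described above, applying a learned binary mask element-wise to frozen pretrained weights so that only a subnetwork contributes to the forward pass, can be illustrated with a minimal sketch. This is not the paper's implementation: the dimensions are toy-sized, the mask is derived from fixed random scores rather than trained (in practice the scores would be optimized, e.g., with a straight-through estimator), and `masked_linear` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions; real BERT weight matrices are 768x768 or larger.
d_in, d_out = 8, 4
W = rng.normal(size=(d_in, d_out))     # frozen pretrained weights

# Real-valued scores stand in for learnable mask parameters; the
# binary mask is obtained by thresholding them at zero.
scores = rng.normal(size=(d_in, d_out))
mask = (scores > 0).astype(W.dtype)    # hard 0/1 mask over weights

def masked_linear(x, W, mask):
    """Forward pass through a layer whose weights are element-wise masked."""
    return x @ (W * mask)

x = rng.normal(size=(2, d_in))
y = masked_linear(x, W, mask)

# Fraction of weights zeroed out, i.e., the sparsity of the subnetwork.
sparsity = 1.0 - mask.mean()
print(y.shape, sparsity)
```

Because `W` stays frozen, distinct masks over the same pretrained weights can carve out distinct subnetworks, one per factor of variation, which is what allows a single BERT to yield several aspect-specific representations.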

1. INTRODUCTION AND MOTIVATION

Large-scale pretrained models such as ELMo (Peters et al., 2018), BERT (Devlin et al., 2018), and XLNet (Yang et al., 2019) have come to dominate modern natural language processing (NLP). Such models rely on self-supervision over large datasets to learn general-purpose representations of text that achieve competitive predictive performance across a spectrum of downstream tasks (Liu et al., 2019). A downside of such learned representations is that it is not obvious what information they encode, which hinders model robustness and interpretability. The opacity of representations produced by models such as BERT has motivated a line of NLP research on designing probing tasks as a means of uncovering what properties of input texts are encoded into token- and sentence-level representations (Rogers et al., 2020; Linzen et al., 2019; Tenney et al., 2019). In this paper we investigate whether we can uncover disentangled representations from pretrained models. That is, rather than mapping inputs onto a single vector that captures arbitrary combinations of features, our aim is to extract a representation that factorizes into distinct, complementary properties of the input. An advantage of explicitly factorizing representations is that it aids interpretability, in the sense that it becomes more straightforward to determine what factors of variation inform predictions in downstream tasks. Further, disentangled representations may facilitate increased robustness under distributional shifts by capturing a notion of invariance: if syntactic changes in a sentence do not affect the representation of semantic features, and vice versa, then we can hope to learn models that are less sensitive to any incidental correlations between these factors.
A general motivation for learning disentangled representations is to try to minimize, or at least expose, model reliance on spurious correlations, i.e., relationships between (potentially sensitive) attributes and labels that exist in the training data but which are not causally linked (Kaushik et al., 2020). This is particularly relevant in the context of large pretrained models like BERT, as we do not know what the representations produced by such models encode. To date, most research on disentangled representations has focused on applications in computer vision (Locatello et al., 2019b; Kulkarni et al., 2015; Chen et al., 2016; Higgins et al., 2016), where there exist comparatively clear independent factors of variation such as size, position, color, and

