INTERPRETING & IMPROVING PRETRAINED LANGUAGE MODELS: A PROBABILISTIC CONCEPTUAL APPROACH

Abstract

Pretrained Language Models (PLMs) such as BERT and its variants have achieved remarkable success in natural language processing. To date, the interpretability of PLMs has primarily relied on the attention weights in their self-attention layers. However, these attention weights only provide word-level interpretations, failing to capture higher-level structures, and are therefore lacking in readability and intuitiveness. In this paper, we propose a hierarchical Bayesian deep learning model, dubbed continuous latent Dirichlet allocation (CLDA), to go beyond word-level interpretations and provide concept-level interpretations. Our CLDA is compatible with any attention-based PLM and can work as either (1) an interpreter, which interprets model predictions at the concept level without any performance sacrifice, or (2) a regulator, which is jointly trained with PLMs during finetuning to further improve performance. Experimental results on various benchmark datasets show that our approach can successfully provide conceptual interpretation and performance improvement for state-of-the-art PLMs.

1. INTRODUCTION

Pretrained language models (PLMs) such as BERT Devlin et al. (2018) and its variants Lan et al. (2019); Liu et al. (2019); He et al. (2021) have achieved remarkable success in natural language processing. These PLMs are usually large attention-based neural networks that follow a pretrain-finetune paradigm, where models are first pretrained on large datasets and then finetuned for a specific task. As with any machine learning model, interpretability in PLMs has always been a desideratum, especially in decision-critical applications (e.g., healthcare). To date, the interpretability of PLMs has primarily relied on the attention weights in their self-attention layers. However, these attention weights only provide raw word-level importance scores as interpretations. Such low-level interpretations fail to capture higher-level semantic structures and are therefore lacking in readability, intuitiveness, and stability. For example, low-level interpretations often fail to capture the influence of similar words on predictions, leading to unstable or even unreasonable explanations (see Sec. 4.2 for details).

In this paper, we make an attempt to go beyond word-level attention and interpret PLM predictions at the concept (topic) level. Such higher-level semantic interpretations are complementary to word-level importance scores and tend to be more readable and intuitive. The core of our idea is to treat a PLM's contextual word embeddings (and their corresponding attention weights) as observed variables and build a probabilistic generative model to automatically infer higher-level semantic structures (e.g., concepts or topics) from these embeddings and attention weights, thereby interpreting the PLM's predictions at the concept level.
Specifically, we propose a class of hierarchical Bayesian deep learning models, dubbed continuous latent Dirichlet allocation (CLDA), to (1) discover concepts (topics) from contextual word embeddings and attention weights in PLMs and (2) interpret individual model predictions using these concepts. It is worth noting that CLDA is 'continuous' because it treats attention weights as continuous-valued word counts and models contextual word embeddings with continuous-valued entries; this is in stark contrast to typical latent Dirichlet allocation Blei et al. (2003), which can only handle bag-of-words representations (where both words and word counts are discrete). Our CLDA is compatible with any attention-based PLM and can work as either an interpreter, which interprets model predictions at the concept level without any performance sacrifice, or a regulator, which is jointly trained with PLMs during finetuning to further improve performance. Our contributions are as follows:
• We propose a novel class of models, CLDA, to go beyond word-level interpretations and interpret PLM predictions at the concept level, thereby improving readability and intuitiveness.
• Our CLDA is compatible with any attention-based PLM and can work as either an interpreter, which interprets model predictions without performance sacrifice, or a regulator, which is jointly trained with PLMs during finetuning to further improve performance.
• We provide empirical results across various benchmark datasets showing that CLDA can successfully interpret predictions from various PLM variants at the concept level and improve PLMs' performance when working as a regulator.
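As a reading aid (not part of the paper's own exposition), a minimal sketch of what treating attention weights as continuous-valued word counts could look like in practice. The aggregation rule used here, summing the attention mass each word receives across all query positions of one head, is our own assumption for illustration; the paper's exact construction may differ.

```python
import math
import random

random.seed(0)
n = 6  # toy sentence length

# Toy raw self-attention scores for one head: scores[i][j] is query i -> key j.
scores = [[random.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]

def softmax(row):
    """Numerically stable softmax over one query's key scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

# Each row of attn is a distribution over keys (sums to 1).
attn = [softmax(row) for row in scores]

# Continuous-valued "count" of word j: total attention mass it receives.
# Unlike bag-of-words counts, these are positive reals, not integers.
counts = [sum(attn[i][j] for i in range(n)) for j in range(n)]

# Sanity check: the n rows each contribute total mass 1, so counts sum to n.
assert abs(sum(counts) - n) < 1e-9
```

Because the softmax is differentiable in the attention scores, such counts can flow gradients back into the PLM, which is what makes joint finetuning (the "regulator" mode) possible.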

2. RELATED WORK

Pretrained Language Models. Pretrained language models are large attention-based neural networks that follow a pretrain-finetune paradigm: they are first pretrained on large datasets in a self-supervised manner and then finetuned for a specific downstream task. BERT Devlin et al. (2018) is a pioneering PLM that has shown impressive performance across multiple downstream tasks. Following BERT, variants such as ALBERT Lan et al. (2019), DistilBERT Sanh et al. (2019), and TinyBERT Jiao et al. (2019) achieve performance comparable to BERT with fewer parameters. Other variants such as RoBERTa Liu et al. (2019) and BART Lewis et al. (2019) improve performance using more sophisticated training schemes for the masked language modeling objective. More recently, there have also been BERT variants that design different self-supervised learning objectives to achieve better performance; examples include DeBERTa He et al. (2021). While these PLMs naturally provide attention weights for each word to interpret model predictions, such low-level interpretations fail to capture higher-level semantic structures and are therefore lacking in readability and intuitiveness. In contrast, our CLDA goes beyond word-level attention and interprets PLM predictions at the concept (topic) level; these higher-level semantic interpretations are complementary to word-level importance scores and tend to be more readable and intuitive.

Topic Models. Our work is also related to topic models Blei (2012); Blei et al. (2003), which typically build upon latent Dirichlet allocation (LDA) Blei et al. (2003). Topic models take the (discrete) bag-of-words representations of documents (i.e., vocabulary-length vectors that count word occurrences) as input, discover hidden topics from them during training, and infer the topic proportion vector for each document during inference Blei et al. (2003); Blei & Lafferty (2006); Wang et al. (2012); Chang & Blei (2009). Besides these 'shallow' topic models, recent work employs 'deep' neural networks to learn topic models more efficiently Card et al. (2017); Xing et al. (2017); Peinelt et al. (2020), using techniques such as amortized variational inference. There is also work that improves upon traditional topic models by either leveraging word similarity as a regularizer for topic-word distributions Das et al. (2015); Batmanghelich et al. (2016) or including word embeddings in the generative process Hu et al. (2012); Dieng et al. (2020); Bunk & Krestel (2018). Here we note several key differences between our CLDA and the methods above. (1) These methods focus on learning topic models from scratch given a collection of raw documents, while our CLDA learns topic models directly from the latent representations inside PLMs. (2) They assume word representations are static (i.e., the representation of a word remains the same across different documents), while PLMs' word representations are contextual (i.e., the representation of a word varies across documents according to context); our CLDA does not make such an assumption. (3) They assume word counts are discrete, which is not applicable to PLMs, where each word has a continuous-valued word count (i.e., its attention weight); in contrast, our CLDA naturally handles continuous-valued word counts from PLMs in a differentiable manner to enable end-to-end training. These prior methods are therefore not applicable to PLMs.

3. METHODS

In this section, we formalize the problem of conceptual interpretation of PLMs and describe our methods for addressing this problem.
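To make the setting concrete before the formal development, one plausible Gaussian-LDA-style instantiation of a generative model over contextual embeddings, consistent with the description so far, is sketched below. All symbols ($\theta_d$, $z_{dn}$, $e_{dn}$, $a_{dn}$, $\mu_k$, $\Sigma_k$, $\alpha$) are our own illustrative notation, not necessarily the paper's:

```latex
% A sketch of one possible generative process for document d with N_d words:
\begin{align*}
\theta_d &\sim \mathrm{Dirichlet}(\alpha)
  && \text{concept (topic) proportions of document } d \\
z_{dn} &\sim \mathrm{Categorical}(\theta_d)
  && \text{concept assignment of word } n \\
e_{dn} \mid z_{dn} = k &\sim \mathcal{N}(\mu_k, \Sigma_k)
  && \text{continuous-valued contextual word embedding}
\end{align*}
% Attention enters as continuous-valued counts a_{dn} >= 0 weighting the
% (marginalized) word likelihoods:
\begin{equation*}
\log p(\mathbf{e}_d \mid \theta_d)
  \;\propto\; \sum_{n=1}^{N_d} a_{dn}
  \log \sum_{k} \theta_{dk}\, \mathcal{N}\!\left(e_{dn}; \mu_k, \Sigma_k\right)
\end{equation*}
```

Under this reading, discrete word types and integer counts in standard LDA are replaced by real-valued embeddings and attention-derived weights, which is what allows gradients to flow between CLDA and the PLM during joint finetuning.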

