INTERPRETING & IMPROVING PRETRAINED LANGUAGE MODELS: A PROBABILISTIC CONCEPTUAL APPROACH

Abstract

Pretrained Language Models (PLMs) such as BERT and its variants have achieved remarkable success in natural language processing. To date, the interpretability of PLMs has primarily relied on the attention weights in their self-attention layers. However, these attention weights only provide word-level interpretations, failing to capture higher-level structures, and are therefore lacking in readability and intuitiveness. In this paper, we propose a hierarchical Bayesian deep learning model, dubbed continuous latent Dirichlet allocation (CLDA), to go beyond word-level interpretations and provide concept-level interpretations. Our CLDA is compatible with any attention-based PLM and can work as either (1) an interpreter, which interprets model predictions at the concept level without any performance sacrifice, or (2) a regulator, which is jointly trained with PLMs during finetuning to further improve performance. Experimental results on various benchmark datasets show that our approach can successfully provide conceptual interpretations and performance improvements for state-of-the-art PLMs.

1. INTRODUCTION

Pretrained language models (PLMs) such as BERT Devlin et al. (2018) and its variants Lan et al. (2019); Liu et al. (2019); He et al. (2021) have achieved remarkable success in natural language processing. These PLMs are usually large attention-based neural networks that follow a pretrain-finetune paradigm, where models are first pretrained on large datasets and then finetuned for a specific task. As with any machine learning model, interpretability in PLMs has always been a desideratum, especially in decision-critical applications (e.g., healthcare). To date, the interpretability of PLMs has primarily relied on the attention weights in their self-attention layers. However, these attention weights only provide raw word-level importance scores as interpretations. Such low-level interpretations fail to capture higher-level semantic structures, and are therefore lacking in readability, intuitiveness, and stability. For example, low-level interpretations often fail to capture the influence of similar words on predictions, leading to unstable or even unreasonable explanations (see Sec. 4.2 for details). In this paper, we make an attempt to go beyond word-level attention and interpret PLM predictions at the concept (topic) level. Such higher-level semantic interpretations are complementary to word-level importance scores and tend to be more readable and intuitive. The core of our idea is to treat a PLM's contextual word embeddings (and their corresponding attention weights) as observed variables and build a probabilistic generative model to automatically infer the higher-level semantic structures (e.g., concepts or topics) from these embeddings and attention weights, thereby interpreting the PLM's predictions at the concept level.
Specifically, we propose a class of hierarchical Bayesian deep learning models, dubbed continuous latent Dirichlet allocation (CLDA), to (1) discover concepts (topics) from contextual word embeddings and attention weights in PLMs and (2) interpret individual model predictions using these concepts. It is worth noting that CLDA is 'continuous' because it treats attention weights as continuous-valued word counts and models contextual word embeddings with continuous-valued entries; this is in stark contrast to typical latent Dirichlet allocation Blei et al. (2003), which can only handle bag-of-words representations (where both words and word counts are discrete). Our CLDA is compatible with any attention-based PLM and can work as either an interpreter, which interprets model predictions at the concept level without any performance sacrifice, or a regulator, which is jointly trained with PLMs during finetuning to further improve performance. Our contributions are as follows:
• We propose a novel class of models, CLDA, to go beyond word-level interpretations and interpret PLM predictions at the concept level, thereby improving readability and intuitiveness.
• Our CLDA is compatible with any attention-based PLM and can work as either an interpreter, which interprets model predictions without performance sacrifice, or a regulator, which is jointly trained with PLMs during finetuning to further improve performance.
• We provide empirical results across various benchmark datasets showing that CLDA can successfully interpret predictions from various PLM variants at the concept level and improve PLMs' performance when working as a regulator.

2. RELATED WORK

Pretrained Language Models. Pretrained language models are large attention-based neural networks that follow a pretrain-finetune paradigm. Usually they are first pretrained on large datasets in a self-supervised manner and then finetuned for a specific downstream task. BERT Devlin et al. (2018) is a pioneering PLM that has shown impressive performance across multiple downstream tasks. Following BERT, there have been variants, such as ALBERT Lan et al. (2019), DistilBERT Sanh et al. (2019), and TinyBERT Jiao et al. (2019), that achieve performance comparable to BERT with fewer parameters. Other variants such as RoBERTa Liu et al. (2019) and BART Lewis et al. (2019) improve performance using more sophisticated training schemes for the masked language modeling objective. More recently, there have also been BERT variants that design different self-supervised learning objectives to achieve better performance; examples include DeBERTa He et al. (2021), Electra Clark et al. (2020), and XLNet Yang et al. (2019). While these PLMs naturally provide attention weights for each word to interpret model predictions, such low-level interpretations fail to capture higher-level semantic structures, and are therefore lacking in readability and intuitiveness. In contrast, our CLDA goes beyond word-level attention and interprets PLM predictions at the concept (topic) level. These higher-level semantic interpretations are complementary to word-level importance scores and tend to be more readable and intuitive.
Topic Models. Our work is also related to topic models Blei (2012); Blei et al. (2003), which typically build upon latent Dirichlet allocation (LDA) Blei et al. (2003). Topic models take the (discrete) bag-of-words representations of documents (i.e., vocabulary-length vectors that count word occurrences) as input, discover hidden topics from them during training, and infer the topic proportion vector for each document during inference Blei et al. (2003); Blei & Lafferty (2006); Wang et al. (2012); Chang & Blei (2009). Besides these 'shallow' topic models, there has been recent work that employs 'deep' neural networks to learn topic models more efficiently Card et al. (2017); Xing et al. (2017); Peinelt et al. (2020), using techniques such as amortized variational inference. There is also work that improves upon traditional topic models by either leveraging word similarity as a regularizer for topic-word distributions Das et al. (2015); Batmanghelich et al. (2016) or including word embeddings in the generative process Hu et al. (2012); Dieng et al. (2020); Bunk & Krestel (2018). Here we note several key differences between our CLDA and the methods above. (1) These methods focus on learning topic models from scratch given a collection of raw documents, while our CLDA learns topic models directly from the latent representations inside PLMs. (2) They assume word representations are static (i.e., the representation of a word remains the same across different documents), while PLMs' word representations are contextual (i.e., the representation of a word varies across different documents according to context). In contrast, our CLDA does not make such an assumption. (3) They assume word counts are discrete numbers, which is not applicable to PLMs, where each word has a continuous-valued (or real-valued) word count (i.e., its attention weight). In contrast, our CLDA naturally handles continuous-valued word counts from PLMs in a differentiable manner to enable end-to-end training. Therefore these prior methods are not applicable to PLMs.

3. METHODS

In this section, we formalize the problem of conceptual interpretation of PLMs, and describe our methods for addressing this problem. Formally, we consider a corpus of M documents, where document m contains L_m words. Feeding document m into a PLM yields, for each word j, a contextual word embedding e_mj ∈ R^d and attention weights a^(h)_mj between each word and the last-layer CLS token, where h denotes the h'th attention head. We denote the average attention weight over H heads as a_mj = (1/H) Σ_{h=1}^H a^(h)_mj and correspondingly a_m ≜ [a_mj]_{j=1}^{L_m} (see the PLM at the bottom of Fig. 1). In PLMs, these last-layer CLS embeddings are used as document-level representations for downstream tasks (e.g., document classification). Furthermore, our CLDA assumes K concepts (topics) for the corpus. For document m, our CLDA interpreter tries to infer a concept distribution vector θ_m ∈ R^K (also known as the topic proportion in topic models) for the whole document and a concept distribution vector ϕ_mj = [ϕ_mji]_{i=1}^K ∈ R^K for word j in document m. In our continuous embedding space, the i'th concept is represented by a Gaussian distribution, N(μ_i, Σ_i), over contextual word embeddings. The goal is to interpret PLMs' predictions at the concept level using the inferred document-level concept vector θ_m, word-level concept vector ϕ_mj, and the learned embedding distributions {N(μ_i, Σ_i)}_{i=1}^K for each concept (see Sec. 4.2 for more detailed descriptions and visualizations). To effectively discover latent concept structures learned by PLMs at the dataset level and interpret PLM predictions at the data-instance level, our CLDA treats both the contextual word embeddings and their associated attention weights as observations and learns a probabilistic generative model of these observations, as shown in Fig. 1. The key idea is to use the attention weights from PLMs to compute a virtual continuous count for each word, and to model the contextual word embedding distributions with Gaussian mixtures. The generative process of CLDA is as follows (we mark key differences from LDA in blue and show the corresponding graphical model in Fig. 2):

3.1. CONTINUOUS LATENT DIRICHLET ALLOCATION

(Figure 1: Overview of CLDA working with a PLM (e.g., BERT): the PLM's contextual word embeddings e_m1, ..., e_m(L-1) and CLS attention weights a_m0, ..., a_m(L-1) are fed into CLDA and the predictor.)
1. For each document m, 1 ≤ m ≤ M,
   (a) Draw the document-level concept distribution vector θ_m ∼ Dirichlet(α)
   (b) For each word j, 1 ≤ j ≤ L_m,
      i. Draw the word-level concept index z_mj ∼ Categorical(θ_m)
      ii. With a continuous word count w_mj ∈ R from the PLM's attention weights, draw the contextual word embedding of the PLM from the corresponding Gaussian component, e_mj ∼ N(μ_{z_mj}, Σ_{z_mj})
Given the generative process above, discovery of latent concept structures in PLMs at the dataset level boils down to learning the parameters {μ_i, Σ_i}_{i=1}^K for the K concepts. Intuitively, the global parameters {μ_i, Σ_i}_{i=1}^K are shared across different documents, and they define a mixture of K Gaussian distributions. Each Gaussian distribution describes a 'cluster' of words and their contextual word embeddings. Similarly, interpretation of PLM predictions at the data-instance level is equivalent to inferring the latent variables, i.e., the document-level concept distribution vectors θ_m and word-level concept indices z_mj. Below we highlight several important aspects of our CLDA design.
Attention Weights as Continuous Word Counts. Different from typical topic models Blei et al. (2003); Blei (2012) and word embeddings Mikolov et al. (2013) that can only handle discrete word counts, our CLDA can handle continuous (virtual) word counts; this better aligns with the continuous attention weights in PLMs. Specifically, we denote as w_mj the continuous word count for the j'th word in document m. We explore three schemes of computing w_mj:
• Identical Weights: Use identical weights for different words, i.e., w_mj = 1, ∀m, j. This is equivalent to typical discrete word counts.
• Attention-Based Weights with Variable Length: Set w_mj proportional to the attention weight a_mj, with a document's total virtual word count scaling with its length L_m.
• Attention-Based Weights with Fixed Length: Set w_mj proportional to a_mj, but normalize so that every document has the same fixed total virtual word count (as if all documents had the same length).
Contextual Word Embeddings. Different from typical topic models Blei et al. (2003) and typical word embeddings Mikolov et al. (2013); Dieng et al. (2020), where word representations are static, word representations in PLMs are contextual; specifically, the same word can have different embeddings in different documents (contexts). For example, the word 'soft' can appear as the j_1'th word in document m_1 and as the j_2'th word in document m_2, and therefore have two different embeddings (i.e., e_{m1 j1} ≠ e_{m2 j2}). Correspondingly, in our CLDA, we do not constrain the same word to have a static embedding; instead we assume that a word embedding is drawn from a Gaussian distribution corresponding to its latent topic. It is also worth noting that word representations in CLDA are continuous, which is different from typical topic models Blei et al. (2003) based on (discrete) bag-of-words representations.
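The generative process above can be sketched as an ancestral sampler. Below is a minimal NumPy illustration; the sizes K, d, and L_m are arbitrary toy choices for illustration, not values from the paper:

```python
import numpy as np

def clda_generate_document(alpha, mus, sigmas, L_m, rng):
    """Ancestral sampling for one document under the CLDA generative process.

    alpha:  Dirichlet hyperparameter, shape (K,)
    mus:    concept means mu_i, shape (K, d)
    sigmas: concept covariances Sigma_i, shape (K, d, d)
    L_m:    number of words in the document
    Returns (theta_m, z_m, E_m): the document-level concept proportions,
    per-word concept indices, and per-word embeddings of shape (L_m, d).
    """
    K = alpha.shape[0]
    theta_m = rng.dirichlet(alpha)                  # step (a)
    z_m = rng.choice(K, size=L_m, p=theta_m)        # step (b)-i
    # step (b)-ii: each embedding is drawn from its concept's Gaussian
    E_m = np.stack([rng.multivariate_normal(mus[z], sigmas[z]) for z in z_m])
    return theta_m, z_m, E_m

# Toy run: K=3 concepts in a d=2 embedding space, L_m=5 words.
rng = np.random.default_rng(0)
K, d, L_m = 3, 2, 5
alpha = np.ones(K)
mus = rng.normal(size=(K, d))
sigmas = np.stack([np.eye(d) * 0.1] * K)
theta_m, z_m, E_m = clda_generate_document(alpha, mus, sigmas, L_m, rng)
```

Note that the continuous word count w_mj does not change the sampling path itself; it only reweights each word's likelihood contribution during inference and learning.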

3.2. INFERENCE AND LEARNING

Below we discuss the inference and learning procedure for CLDA. We start by introducing the inference of document-level and word-level concepts (i.e., θ_m and z_mj) given the global concept parameters (i.e., {(μ_i, Σ_i)}_{i=1}^K), and then introduce the learning of these global concept parameters.

3.2.1. INFERENCE

Inferring Document-Level and Word-Level Concepts. We formulate the problem of interpreting PLM predictions at the concept level as inferring document-level and word-level concepts. Specifically, given the global concept parameters {(μ_i, Σ_i)}_{i=1}^K, the contextual word embeddings e_m ≜ [e_mj]_{j=1}^{L_m}, and the associated attention weights a_m ≜ [a_mj]_{j=1}^{L_m} that a PLM produces for each document m, our CLDA infers the posterior distribution of the document-level concept vector θ_m, i.e., p(θ_m | e_m, a_m, {(μ_i, Σ_i)}_{i=1}^K), and the posterior distribution of the word-level concept index z_mj, i.e., p(z_mj | e_m, a_m, {(μ_i, Σ_i)}_{i=1}^K).
Variational Distributions. These posterior distributions are intractable; we therefore resort to variational inference Jordan et al. (1998); Blei et al. (2003) and use variational distributions q(θ_m | γ_m) and q(z_mj | ϕ_mj) to approximate them. Here γ_m ∈ R^K and ϕ_mj ≜ [ϕ_mji]_{i=1}^K ∈ R^K are variational parameters to be estimated during inference. This leads to the following joint variational distribution:
q(θ_m, {z_mj}_{j=1}^{L_m} | γ_m, {ϕ_mj}_{j=1}^{L_m}) = q(θ_m | γ_m) · Π_{j=1}^{L_m} q(z_mj | ϕ_mj).   (1)
Evidence Lower Bound. For each document m, finding the optimal variational distributions is then equivalent to maximizing the following evidence lower bound (ELBO):
L(γ_m, {ϕ_mj}_{j=1}^{L_m}; α, {(μ_i, Σ_i)}_{i=1}^K) = E_q[log p(θ_m | α)] + Σ_{j=1}^{L_m} E_q[log p(z_mj | θ_m)] + Σ_{j=1}^{L_m} E_q[log p(e_mj | z_mj, μ_{z_mj}, Σ_{z_mj})] − E_q[log q(θ_m)] − Σ_{j=1}^{L_m} E_q[log q(z_mj)],   (2)
where the expectation is taken over the joint variational distribution in Eqn. 1.
Likelihood with Continuous Word Counts. One key difference between CLDA and typical topic models Blei et al. (2003); Blei (2012) is the virtual continuous word counts (discussed in Sec. 3.1). Specifically, we define the likelihood in the third term of Eqn. 2 as:
p(e_mj | z_mj, μ_{z_mj}, Σ_{z_mj}) = [N(e_mj; μ_{z_mj}, Σ_{z_mj})]^{w_mj}.   (3)
Note that Eqn. 3 is the likelihood of w_mj (virtual) words, where w_mj can be a continuous value derived from the PLM's attention weights (details in Sec. 3.1). Correspondingly, for the third term of Eqn. 2, we have:
Σ_{j=1}^{L_m} E_q[log p(e_mj | z_mj, μ_{z_mj}, Σ_{z_mj})] = Σ_{j,i} ϕ_mji w_mj log N(e_mj | μ_i, Σ_i) = Σ_{j,i} ϕ_mji w_mj {−(1/2)(e_mj − μ_i)^T Σ_i^{-1} (e_mj − μ_i) − log[(2π)^{d/2} |Σ_i|^{1/2}]}.   (4)
Update Rules. Taking the derivative of the ELBO in Eqn. 2 w.r.t. ϕ_mji (see the Appendix for details) and setting it to 0 yields the update rule for ϕ_mji:
ϕ_mji ∝ (1 / |Σ_i|^{w_mj/2}) exp[Ψ(γ_mi) − Ψ(Σ_{i'} γ_{mi'}) − (w_mj/2)(e_mj − μ_i)^T Σ_i^{-1} (e_mj − μ_i)],   (5)
with the normalization constraint Σ_{i=1}^K ϕ_mji = 1. Similarly, the update rule for γ_mi is:
γ_mi = α_i + Σ_j ϕ_mji w_mj,   (6)
where α ≜ [α_i]_{i=1}^K is the hyperparameter for the Dirichlet prior distribution of θ_m. In summary, the inference algorithm alternates between updating ϕ_mji for all (m, j, i) tuples and updating γ_mi for all (m, i) tuples.
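The alternating updates of Eqn. 5 and Eqn. 6 can be sketched for a single document as follows. This is a minimal NumPy illustration under the notation above; the finite-difference digamma helper keeps it dependency-free (in practice one would use scipy.special.digamma), and all toy sizes are arbitrary:

```python
import math
import numpy as np

def digamma(x, h=1e-5):
    # Numerical digamma via a centered difference of log-gamma; adequate
    # for a sketch (scipy.special.digamma is the usual choice).
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2 * h)

def update_phi_gamma(E, w, alpha, mus, sigmas, n_iters=20):
    """Alternate the CLDA variational updates (Eqn. 5 and Eqn. 6) for one document.

    E: (L, d) contextual word embeddings; w: (L,) continuous word counts;
    alpha: (K,) Dirichlet prior; mus: (K, d); sigmas: (K, d, d).
    Returns (phi, gamma), where each row of phi sums to 1.
    """
    L, _ = E.shape
    K = alpha.shape[0]
    gamma = alpha + w.sum() / K                      # common initialization
    inv = np.stack([np.linalg.inv(S) for S in sigmas])
    logdet = np.array([np.linalg.slogdet(S)[1] for S in sigmas])
    for _ in range(n_iters):
        dig = np.array([digamma(g) for g in gamma]) - digamma(gamma.sum())
        log_phi = np.empty((L, K))
        for j in range(L):
            diff = mus - E[j]                        # (K, d)
            quad = np.einsum('ki,kij,kj->k', diff, inv, diff)
            # Eqn. 5 in log space; the i-independent (2*pi) factor cancels
            log_phi[j] = dig - 0.5 * w[j] * (quad + logdet)
        log_phi -= log_phi.max(axis=1, keepdims=True)
        phi = np.exp(log_phi)
        phi /= phi.sum(axis=1, keepdims=True)        # sum_i phi_mji = 1
        gamma = alpha + phi.T @ w                    # Eqn. 6
    return phi, gamma

# Toy run with arbitrary sizes.
rng = np.random.default_rng(0)
K, d, L = 3, 2, 6
mus = rng.normal(size=(K, d))
sigmas = np.stack([np.eye(d)] * K)
E = rng.normal(size=(L, d))
w = rng.uniform(0.5, 1.5, size=L)
alpha = np.full(K, 0.5)
phi, gamma = update_phi_gamma(E, w, alpha, mus, sigmas)
```

Since each row of phi sums to 1, the converged gamma always satisfies Σ_i γ_mi = Σ_i α_i + Σ_j w_mj, which is a handy sanity check.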

3.2.2. LEARNING

Learning Dataset-Level Concept Parameters. The inference algorithm in Sec. 3.2.1 assumes availability of the dataset-level (global) concept parameters {(μ_i, Σ_i)}_{i=1}^K. To learn these parameters, one iterates between (1) inferring the document-level variational parameters γ_m and word-level variational parameters ϕ_mj as in Sec. 3.2.1 and (2) learning the dataset-level concept parameters {(μ_i, Σ_i)}_{i=1}^K.
Update Rules. Similar to Sec. 3.2.1, we expand the ELBO in Eqn. 2 (see the Appendix for details), take its derivative w.r.t. μ_i, and set it to 0, yielding the update rule for learning μ_i:
μ_i = Σ_{m,j} ϕ_mji w_mj e_mj / Σ_{m,j} ϕ_mji w_mj.   (7)
Similarly, setting the derivative w.r.t. Σ_i to 0, we have
Σ_i = Σ_{m,j} ϕ_mji w_mj (e_mj − μ_i)(e_mj − μ_i)^T / Σ_{m,j} ϕ_mji w_mj.   (8)
Effect of Attention Weights. From Eqn. 7 and Eqn. 8, we can observe that the attention weight of the j'th word in document m, i.e., a_mj, affects the virtual continuous word count w_mj (see Sec. 3.1), thereby affecting the updates of the dataset-level concept center and covariance μ_i and Σ_i.
(Figure 3: Probabilistic graphical model of our smoothed CLDA.)
Specifically, if we use attention-based weights with fixed length or variable length in Sec. 3.1, the continuous word count w_mj will be proportional to the attention weight a_mj. Therefore, when updating the concept center μ_i as a weighted average of different word embeddings e_mj, CLDA naturally places more focus on words with higher attention weights a_mj from PLMs, consequently making the interpretations sharper and improving performance (see Sec. 4.2 for detailed results and the Appendix for theoretical analysis). Interestingly, we also observe that PLMs' attention weights on stop words such as 'the' and 'a' tend to be much lower; therefore CLDA can naturally ignore these concept-irrelevant stop words when learning and inferring concepts (topics).
This is in contrast to typical topic models Blei et al. (2003); Blei (2012) that require preprocessing to remove stop words.
Smoothing with Prior Distributions on {(μ_i, Σ_i)}_{i=1}^K. To alleviate overfitting and prevent singularity in numerical computation, we impose prior distributions on μ_i and Σ_i to smooth the learning process (Fig. 3). Specifically, we use a Normal-Inverse-Wishart prior on μ_i and Σ_i as follows:
Σ_i ∼ IW(Λ_0, ν_0), μ_i | Σ_i ∼ N(μ_0, Σ_i / κ_0),
where Λ_0, ν_0, μ_0, and κ_0 are hyperparameters for the prior distributions. With the prior distributions above, the update rules for μ_i and Σ_i become:
μ_i = (κ_0 μ_0 + n μ̂_i) / (κ_0 + n),
Σ_i = [Λ_0 + S_i + (κ_0 n / (κ_0 + n)) (μ̂_i − μ_0)(μ̂_i − μ_0)^T] / (ν_0 + n − d − 1),   (9)
S_i = Σ_{m,j} ϕ_mji w_mj (e_mj − μ̂_i)(e_mj − μ̂_i)^T,
where μ̂_i denotes the unsmoothed estimate from Eqn. 7 and n = Σ_{m,j} ϕ_mji w_mj is the total virtual word count used to estimate μ_i and Σ_i. Eqn. 9 gives the smoothed versions of Eqn. 7 and Eqn. 8. From the Bayesian perspective, these updates correspond to the expectations of μ_i's and Σ_i's posterior distributions (see the Appendix for the detailed derivation).
Online Learning of μ_i and Σ_i. PLMs are deep neural networks trained using minibatches of data, while Eqn. 7 and Eqn. 8 need to go through the whole dataset before each update. We therefore use an exponential moving average (EMA) to work with minibatches (see the Appendix for details).
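The smoothed update rules (Eqn. 9) can be sketched by pooling all (m, j) pairs into a single array. This is a minimal NumPy illustration; the data sizes and the weakly-informative hyperparameter values in the toy run are arbitrary choices, not values from the paper:

```python
import numpy as np

def smoothed_m_step(E, w, phi, mu0, kappa0, Lambda0, nu0):
    """Smoothed concept-parameter updates (Eqn. 9), pooling all (m, j) pairs.

    E: (N, d) embeddings; w: (N,) continuous word counts; phi: (N, K)
    word-level concept responsibilities. Returns posterior-mean (mus, sigmas).
    """
    N, d = E.shape
    K = phi.shape[1]
    mus = np.empty((K, d))
    sigmas = np.empty((K, d, d))
    for i in range(K):
        r = phi[:, i] * w                            # phi_mji * w_mj
        n = r.sum()                                  # total virtual word count
        mu_hat = (r[:, None] * E).sum(axis=0) / n    # Eqn. 7 (unsmoothed)
        diff = E - mu_hat
        S = (r[:, None, None] * np.einsum('ni,nj->nij', diff, diff)).sum(axis=0)
        mus[i] = (kappa0 * mu0 + n * mu_hat) / (kappa0 + n)
        gap = mu_hat - mu0
        sigmas[i] = (Lambda0 + S
                     + kappa0 * n / (kappa0 + n) * np.outer(gap, gap)) / (nu0 + n - d - 1)
    return mus, sigmas

# Toy run: arbitrary data and weakly-informative hyperparameters.
rng = np.random.default_rng(1)
N, d, K = 40, 3, 2
E = rng.normal(size=(N, d))
w = rng.uniform(0.5, 1.5, size=N)
phi = rng.dirichlet(np.ones(K), size=N)
mus, sigmas = smoothed_m_step(E, w, phi, np.zeros(d), 1.0, np.eye(d), d + 2.0)
```

With Λ_0 positive definite, each Σ_i stays symmetric positive definite even when a concept receives little virtual mass, which is exactly the singularity the smoothing is meant to prevent.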

3.3. CLDA AS PLM INTERPRETERS AND REGULATORS

CLDA can be used as either a PLM interpreter, which interprets model predictions at the concept level without any performance sacrifice, or a PLM regulator, which is jointly trained with PLMs during finetuning to further improve performance. This is possible because both the word counts and the contextual word embeddings in CLDA are continuous and differentiable. Below we start with more details on this differentiability and then introduce our CLDA interpreter and CLDA regulator.
Algorithm 1: Algorithm for CLDA Regulators (w/o EMA)
Input: Pretrained f_ae(·) and f_c(·); initialized g(·), {γ_m}_{m=1}^M, {Φ_m}_{m=1}^M, and Ω; documents {D_m}_{m=1}^M; number of epochs T.
for t = 1 : T do
  for m = 1 : M do
    Update Φ_m and γ_m using Eqn. 5 and Eqn. 6, respectively.
    Update f_ae(·), f_c(·), and g(·) using Eqn. 11.
    Update Ω using Eqn. 9.
Differentiable Continuous Word Counts and Contextual Word Embeddings. One of CLDA's advantages is that it handles continuous word counts and word embeddings. Such continuity translates to better differentiability, which is particularly desirable when one wants to jointly train CLDA and a PLM to further improve PLM performance. As shown in Fig. 2, CLDA connects to a PLM through the attention weights a_mj (related to the word counts w_mj) and the contextual word embeddings e_mj. Therefore, if the CLDA learning objective (Eqn. 2) is differentiable w.r.t. a_mj and e_mj, these gradients can be propagated to the PLM parameters and help finetune the PLM. The derivative of the ELBO in Eqn. 2 w.r.t. e_mj is:
∂L/∂e_mj = Σ_i ϕ_mji w_mj Σ_i^{-1} (μ_i − e_mj).
Similarly, we can get the derivative w.r.t. a_mj using the chain rule, where ∂w_mj/∂a_mj depends on the choice of scheme for computing w_mj from a_mj (described in Sec. 3.1), and ∂L/∂w_mj is:
∂L/∂w_mj = Σ_i ϕ_mji {−(1/2)(e_mj − μ_i)^T Σ_i^{-1} (e_mj − μ_i) − log[(2π)^{d/2} |Σ_i|^{1/2}]}.
CLDA as a PLM Interpreter. Using CLDA as a PLM interpreter is straightforward. One only needs to first learn the global concept parameters μ_i and Σ_i according to Sec. 3.2.2, and then infer the document-level concept vectors θ_m and word-level concept indices z_mj. Together, they provide dataset-level, document-level, and word-level conceptual interpretations for PLM predictions.
CLDA as a PLM Regulator. One could also use CLDA as a regulator (or regularizer) when finetuning a PLM. Assume a PLM that produces attention weights and contextual word embeddings for document m, i.e., (a_m, e_m) = f_ae(D_m), as well as the CLS embedding c_m = f_c(c̃_m, a_m, e_m); here c̃_m is the CLS embedding of the second-to-last layer. To better see the connection between CLDA and PLMs, we can rewrite the ELBO in Eqn. 2 as:
L_c(γ_m, Φ_m; Ω, a_m, e_m) = L_c(γ_m, Φ_m; Ω, f_ae(D_m)),
where Φ_m ≜ {ϕ_mj}_{j=1}^{L_m} is the collection of word-level concept parameters for document m, and Ω ≜ {(μ_i, Σ_i)}_{i=1}^K is the collection of global concept parameters. Assuming a document-level predictor g(c_m) and denoting the ground-truth label as y_m, we have the PLM loss during finetuning:
L_p(g(f_c(c̃_m, f_ae(D_m))), y_m).
Putting them together, we have the joint loss:
L_j = L_p(g(f_c(c̃_m, f_ae(D_m))), y_m) + λ L_c(γ_m, Φ_m; Ω, f_ae(D_m)).   (11)
Theoretical Analysis. In the Appendix, we provide theoretical guarantees that, under mild assumptions, our CLDA can learn concept-level interpretations for PLMs, especially from noisy data.
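The closed-form gradient ∂L/∂e_mj above can be checked numerically. The sketch below implements the third ELBO term for one word together with its closed-form gradient; all toy shapes and parameters are arbitrary illustrations:

```python
import numpy as np

def word_elbo_term(e, phi_j, w_j, mus, inv_sigmas, logdets):
    """Third ELBO term for one word: sum_i phi_ji * w_j * log N(e | mu_i, Sigma_i)."""
    d = e.shape[0]
    diff = mus - e                                   # (K, d)
    quad = np.einsum('ki,kij,kj->k', diff, inv_sigmas, diff)
    log_n = -0.5 * (quad + logdets + d * np.log(2 * np.pi))
    return w_j * float(phi_j @ log_n)

def grad_wrt_embedding(e, phi_j, w_j, mus, inv_sigmas):
    """Closed-form dL/de_mj = sum_i phi_ji * w_j * Sigma_i^{-1} (mu_i - e)."""
    diff = mus - e                                   # (K, d)
    return w_j * np.einsum('k,kij,kj->i', phi_j, inv_sigmas, diff)

# Toy setup with random SPD covariances.
rng = np.random.default_rng(2)
K, d = 3, 4
mus = rng.normal(size=(K, d))
A = rng.normal(size=(K, d, d))
sigmas = np.einsum('kij,klj->kil', A, A) + np.eye(d)   # SPD covariances
inv_sigmas = np.linalg.inv(sigmas)
logdets = np.linalg.slogdet(sigmas)[1]
phi_j = rng.dirichlet(np.ones(K))
w_j, e = 0.7, rng.normal(size=d)
g = grad_wrt_embedding(e, phi_j, w_j, mus, inv_sigmas)
```

A central finite difference over each coordinate of e reproduces g, which is the same consistency the PLM relies on when these gradients are backpropagated during joint finetuning.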

4. EXPERIMENTS

4.1. EXPERIMENT SETUP

Datasets. We use the GLUE benchmark Wang et al. (2018) to evaluate our methods. This benchmark includes multiple prediction sub-tasks, with single sentences or sentence pairs as inputs. In this paper, we use six datasets from GLUE (CoLA, MRPC, STS-B, QQP, RTE, and SST-2) for evaluation.

4.2. QUALITATIVE RESULTS

Document-Level Interpretations. For document-level conceptual interpretations, we sample two example documents from MRPC (Fig. 4 (left)) and three from RTE (Fig. 4 (right)), where each document contains a pair of sentences. The MRPC task is to predict whether one sentence paraphrases the other. For example, in the first document of MRPC, we can see that our CLDA correctly interprets the model prediction 'True' with Concept 24 (politics). The RTE task is to predict whether one sentence entails the other. For example, in the second document of RTE, CLDA correctly interprets the model prediction 'True' with Concept 13 (countries).


Word-Level Interpretations. For word-level conceptual interpretations, we can observe that CLDA interprets the PLM's prediction on MRPC's first document (Fig. 4 (left)) using words such as 'senate' and 'bitty' that are related to politics. Note that the word 'bitty' is commonly used (with 'little') by politicians to refer to the small size of tax relief/cut plans. Similarly, for RTE's first document (Fig. 4 (right)), CLDA correctly identifies Concept 67 (Islam) and interprets the model prediction 'False' by distinguishing between keywords such as 'Jihad' and 'Al Qaeda'.

4.3. QUANTITATIVE RESULTS

To evaluate CLDA as a PLM regulator, we use BERT, RoBERTa, BART, and DeBERTa as base models for our CLDA, leading to four different CLDA models: BERT-CLDA, RoBERTa-CLDA, BART-CLDA, and DeBERTa-CLDA, respectively. Table 1 shows the performance of our CLDA variants and the corresponding base PLMs on six benchmark datasets: CoLA, MRPC, STS-B, QQP, RTE, and SST-2. The last two rows show the average predictive performance across different PLM base models for each dataset. We can observe that on average, our CLDA significantly outperforms the baselines on all datasets. Notably, on the largest dataset, QQP, our CLDA improves upon the baselines by 5.6% in terms of F1 score and by 3.5% in terms of accuracy. Moreover, even for 'difficult' natural language inference tasks such as RTE, CLDA can still improve the average accuracy by a large margin of 2.9%. When using RoBERTa as the base model, CLDA achieves absolute improvements of 17.0% and 10.0% in F1 score and accuracy, respectively. Note that STS-B and SST-2 are relatively 'easy' datasets: even the BERT-Base model achieves correlation higher than 82% and accuracy higher than 88%, respectively, so the room for improvement is minimal. However, our CLDA still leads to slight improvements in Pearson and Spearman correlation on STS-B, as well as a reasonable accuracy improvement on SST-2. To evaluate the different schemes for computing the virtual word counts w_mj from attention weights (introduced in Sec. 3.1), we perform ablation studies on the CoLA dataset using different PLM base models. Table 2 shows the results for the base model and CLDA with identical weights (CLDA-Identical), attention-based weights with variable length (CLDA-Variable), and attention-based weights with fixed length (CLDA-Fixed).

4.4. ABLATION STUDIES

One observation is that CLDA-Identical tends to underperform the PLM base models, while both CLDA-Variable and CLDA-Fixed significantly outperform them. This verifies the importance of using attention weights to compute the virtual continuous word counts. Another interesting observation is that CLDA-Fixed slightly outperforms CLDA-Variable. Note that both CLDA-Fixed and CLDA-Variable use attention weights to compute virtual continuous word counts; the difference is that CLDA-Variable assigns longer documents more (total) weight when learning the global concept parameters (Eqn. 7 and Eqn. 8), while CLDA-Fixed treats each document equally (as if all documents had a fixed length). Therefore, Table 2 shows that it is beneficial to assume different documents have the same fixed total virtual word count in CLDA.
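The three weighting schemes compared above can be sketched as follows. Note that the exact proportionality constants for the two attention-based schemes are not spelled out in the text; the totals used below (L_m for the variable-length scheme and a shared constant for the fixed-length scheme) are our assumptions, chosen only to match the qualitative behavior described in this section:

```python
import numpy as np

def word_counts(a_m, scheme, fixed_total=20.0):
    """Hypothetical schemes for turning CLS attention weights into counts.

    a_m: (L,) average attention weights for one document.
    'identical': w_mj = 1 for every word (plain discrete counts).
    'variable':  w_mj proportional to a_mj, totalling L_m, so longer
                 documents carry more total weight (assumed constant).
    'fixed':     w_mj proportional to a_mj, totalling a constant shared
                 by all documents (assumed constant, fixed_total).
    """
    L = a_m.shape[0]
    if scheme == 'identical':
        return np.ones(L)
    if scheme == 'variable':
        return L * a_m / a_m.sum()
    if scheme == 'fixed':
        return fixed_total * a_m / a_m.sum()
    raise ValueError(scheme)

a_m = np.array([0.1, 0.2, 0.3, 0.4])   # toy attention weights for one document
```

Under this sketch, CLDA-Variable lets a 100-word document contribute 25 times the virtual mass of a 4-word one to Eqn. 7 and Eqn. 8, whereas CLDA-Fixed equalizes the contribution, matching the ablation result above.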

5. CONCLUSION

We develop CLDA as a general framework to interpret pretrained word embeddings at the concept level. Our CLDA is compatible with any attention-based PLM. It can not only interpret how PLMs make predictions, but also help improve contextual word embeddings in an end-to-end manner, thereby boosting predictive performance.

A DETAILS ON LEARNING CLDA

Update Rules. Similar to Sec. 3.2.1 of the main paper, we expand the ELBO in Eqn. 2 of the main paper, take its derivative w.r.t. μ_i, and set it to 0:
∂L/∂μ_i = Σ_{m,j} ϕ_mji w_mj Σ_i^{-1} (e_mj − μ_i) = 0,
yielding the update rule for learning μ_i:
μ_i = Σ_{m,j} ϕ_mji w_mj e_mj / Σ_{m,j} ϕ_mji w_mj,
where Σ_i^{-1} is canceled out. Similarly, setting the derivative w.r.t. Σ_i to 0, i.e.,
∂L/∂Σ_i = (1/2) Σ_{m,j} ϕ_mji w_mj (−Σ_i^{-1} + Σ_i^{-1} (e_mj − μ_i)(e_mj − μ_i)^T Σ_i^{-1}) = 0,
we have
Σ_i = Σ_{m,j} ϕ_mji w_mj (e_mj − μ_i)(e_mj − μ_i)^T / Σ_{m,j} ϕ_mji w_mj.
Online Learning of μ_i and Σ_i. Note that PLMs are deep neural networks trained using minibatches of data, while Eqn. 7 and Eqn. 8 need to go through the whole dataset before each update. Inspired by Hoffman et al. (2010); Oord et al. (2017), we use an exponential moving average (EMA) to work with minibatches. Specifically, we update the parameters as:
μ_i' ← ρ · N · μ_i + (1 − ρ) · B · μ̃_i,
Σ_i' ← ρ · N · Σ_i + (1 − ρ) · B · Σ̃_i,
N ← ρ · N + (1 − ρ) · B,
μ_i ← μ_i' / N, Σ_i ← Σ_i' / N,
where B is the minibatch size, N is a running count, and ρ ∈ (0, 1) is the momentum hyperparameter. Here μ̃_i and Σ̃_i are the estimates obtained by applying Eqn. 7 and Eqn. 8 only on the current minibatch.
Paired Sentences as a Document. Many modern natural language processing tasks involve predicting a label from a pair of sentences (for example, given two sentences, predict whether one sentence paraphrases the other).
In this case, one document may contain a pair of sentences (with lengths L_m1 and L_m2, where L_m = L_m1 + L_m2) as PLM inputs, and the γ of each sentence can be inferred as:
γ_{m1,i} = α_i + Σ_{j=1}^{L_m1} ϕ_mji w_mj, γ_{m2,i} = α_i + Σ_{j=L_m1+1}^{L_m} ϕ_mji w_mj.
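The EMA updates above can be sketched as a single function; this is a minimal NumPy illustration of the update equations, with arbitrary toy values:

```python
import numpy as np

def ema_update(mu, sigma, n_run, mu_batch, sigma_batch, batch_size, rho=0.99):
    """One EMA step for the concept parameters, following Appendix A.

    (mu, sigma) are the running estimates backed by an effective count
    n_run; (mu_batch, sigma_batch) come from applying Eqn. 7 and Eqn. 8
    to the current minibatch of size batch_size alone.
    """
    mu_num = rho * n_run * mu + (1 - rho) * batch_size * mu_batch
    sigma_num = rho * n_run * sigma + (1 - rho) * batch_size * sigma_batch
    n_new = rho * n_run + (1 - rho) * batch_size
    return mu_num / n_new, sigma_num / n_new, n_new

# Toy step: the running estimate moves slightly toward the minibatch one.
d = 2
mu, sigma, n_run = np.zeros(d), np.eye(d), 100.0
mu_b, sigma_b = np.ones(d), 2.0 * np.eye(d)
mu_new, sigma_new, n_new = ema_update(mu, sigma, n_run, mu_b, sigma_b, 10.0)
```

A useful property of this form is stationarity: when the minibatch estimate equals the running one, the update leaves the parameters unchanged regardless of ρ, B, or N.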

B EXPERIMENTAL SETTINGS AND IMPLEMENTATION DETAILS

We will release all code, models, and data. Below we provide more details on the experimental settings and practical implementation.
Data Preprocessing. Our training/validation/test data split of the GLUE datasets follows exactly Devlin et al. (2018). We train our model on the training data, perform model selection (i.e., select hyperparameters) using the validation data, and evaluate methods on the test data. Our tokenization and BERT configurations follow the previous PLMs that we compare with. For fair comparison, we use lowercase tokenization both in the base models and in our CLDA models. Depending on the CLDA weighting scheme, we can choose whether to calculate TF-IDF scores over the documents. If we use identical-weight CLDA, additionally computing TF-IDF scores is necessary to filter out words carrying little information for our CLDA-based topic models.
Implementation. All PLMs are base models with vanilla settings and a hidden dimension of 768. We initialize the models with the seed 2021. The models are optimized with AdamW, a variant of Adam Kingma & Ba (2014) with decoupled weight decay, using a learning rate of 10^-4 with linear warmup and linear learning rate decay. We finetune the models until the metrics on the validation sets reach their highest scores. We treat the training batch size, the CLDA prior parameters, and λ (in Eqn. 11 of the main paper) as hyperparameters, and run a grid search over training and validation to find the models with the highest possible performance. To alleviate overfitting of CLDA during joint training, we periodically include the CLDA loss term L_c in the joint loss L_j, i.e., using the CLDA term every 1/3/5 epochs (also a training hyperparameter), alongside the original base-PLM finetuning loss. We use the fixed scheme for CLDA training by default to produce the results in Table 1 of the main paper. We use the penultimate-layer word embeddings for CLDA in Eqn. 3 of the main paper, because our preliminary results show that using the penultimate layer instead of the output layer improves performance.
Baselines. To ensure fair comparison, during finetuning, we select the epoch for both the baselines and CLDA entirely based on validation accuracy and report the test accuracy; in contrast, Devlin et al. (2018) directly choose the 3rd epoch, which we argue is not rigorous and potentially 'overfits' the test sets. Also, as mentioned above, we follow the convention of topic models and preprocess the documents into lowercase words for both the baselines and CLDA; in contrast, Devlin et al. (2018) keep the words unchanged. Nevertheless, note that our CLDA can interpret any PLM without accuracy sacrifice; therefore the exact accuracy of BERT-base is less relevant in our case.
Visualization Postprocessing. To better showcase the dataset-level concepts as in Fig. 4 of the main paper, we may apply simple linear transformations to the word embeddings after the aforementioned PCA step, in order to scatter all the informative words within the same figure. For some datasets, such as STS-B, this is not necessary, so we do not use it.
Topic (Concept) Identification. Inspired by Blei et al. (2003), we identify meaningful topics by listing the top-5 topics for each word, computing the inverse document frequency (IDF), and filtering out topics with the lowest IDF scores. Note that the GLUE benchmark consists of short documents, which makes it particularly challenging for traditional topic models (such as LDA) to learn topics; interestingly, our CLDA can still learn topics well. We attribute this to the following observations: (1) Compared to traditional LDA using discrete word representations, CLDA uses continuous word embeddings.
In such a continuous space, topics learned for one word can also help neighboring words; this alleviates the sparsity issue caused by short documents and therefore learns better topics. (2) CLDA's attention-based continuous word counts further improves sample efficiency. In CLDA, important words have larger attention weights and therefore larger continuous word counts. In this case, one important word in a sentence possesses statistical (sample) power equivalent to multiple words; this leads to better sample efficiency in CLDA.
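To make the topic-identification step above concrete, here is a minimal Python sketch (not our released implementation; the function name and data layout are illustrative assumptions). It treats each word as a 'document' whose 'terms' are the topics in that word's top-5 list, so a topic appearing in nearly every word's list behaves like a background topic and receives a low IDF score:

```python
import math
from collections import Counter

def filter_topics_by_idf(top_topics_per_word, keep_ratio=0.8):
    """Identify meaningful topics via topic-level IDF.

    top_topics_per_word: dict mapping word -> list of its top-5 topic ids.
    Returns (kept_topic_ids, idf_scores); topics with the lowest IDF
    (i.e., the most 'background-like' ones) are filtered out.
    """
    n_words = len(top_topics_per_word)
    # document frequency: in how many words' top lists each topic appears
    df = Counter(t for topics in top_topics_per_word.values() for t in set(topics))
    idf = {t: math.log(n_words / df[t]) for t in df}
    # keep the topics with the highest IDF (most discriminative)
    ranked = sorted(idf, key=idf.get, reverse=True)
    return ranked[: max(1, int(keep_ratio * len(ranked)))], idf
```

A topic appearing in every word's list gets IDF log(1) = 0 and is dropped first when `keep_ratio` trims the ranking.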

C EXPANSION OF ELBO

We can expand the ELBO in Eqn. 2 of the main paper as:
$$\begin{aligned}
\mathcal{L}(\gamma,\phi;\alpha,\{\mu\}_K,\{\Sigma\}_K)
&= \log\Gamma\Big(\sum_{i=1}^K \alpha_i\Big) - \sum_{i=1}^K \log\Gamma(\alpha_i) + \sum_{i=1}^K (\alpha_i-1)\Big(\Psi(\gamma_i)-\Psi\Big(\sum_{j=1}^K\gamma_j\Big)\Big) \\
&\quad + \sum_{j=1}^L \sum_{i=1}^K \phi_{ji}\Big(\Psi(\gamma_i)-\Psi\Big(\sum_{k=1}^K\gamma_k\Big)\Big) \\
&\quad + \sum_{m,j,i} \phi_{mji}\, w_{mj}\Big\{-\tfrac{1}{2}(e_{mj}-\mu_i)^\top \Sigma_i^{-1}(e_{mj}-\mu_i) - \log\big[(2\pi)^{d/2}|\Sigma_i|^{1/2}\big]\Big\} \\
&\quad - \log\Gamma\Big(\sum_{j=1}^K \gamma_j\Big) + \sum_{i=1}^K \log\Gamma(\gamma_i) - \sum_{i=1}^K(\gamma_i-1)\Big(\Psi(\gamma_i)-\Psi\Big(\sum_{j=1}^K\gamma_j\Big)\Big) \\
&\quad - \sum_{j=1}^L \sum_{i=1}^K \phi_{ji}\log\phi_{ji}.
\end{aligned}$$

D DERIVATION ON SMOOTHED CLDA

To alleviate overfitting and prevent singularities in numerical computation, we impose prior distributions on μ_i and Σ_i to smooth the learning process. Specifically, we use a Normal-Inverse-Wishart prior on μ_i and Σ_i as follows:
$$\Sigma_i \sim \mathcal{IW}(\Lambda_0, \nu_0), \qquad \mu_i\,|\,\Sigma_i \sim \mathcal{N}(\mu_0, \Sigma_i/\kappa_0),$$
where Λ_0, ν_0, μ_0, and κ_0 are hyperparameters of the prior distributions. With the prior above, the update rules for the parameters of the posterior distribution $\mathcal{NIW}(\mu_i,\Sigma_i\,|\,\mu_i^{(n)},\Lambda_i^{(n)},\kappa_i^{(n)},\nu_i^{(n)})$ become:
$$\mu_i^{(n)} = \frac{\kappa_0\mu_0 + n\bar\mu_i}{\kappa_0+n}, \qquad \Lambda_i^{(n)} = \Lambda_0 + S_i + \frac{\kappa_0 n}{\kappa_0+n}(\bar\mu_i-\mu_0)(\bar\mu_i-\mu_0)^\top,$$
$$\kappa_i^{(n)} = \kappa_0+n, \qquad \nu_i^{(n)} = \nu_0+n, \qquad S_i = \sum_{m,j}\phi_{mji}\,w_{mj}(e_{mj}-\bar\mu_i)(e_{mj}-\bar\mu_i)^\top,$$
where $\bar\mu_i = \frac{1}{n}\sum_{m,j}\phi_{mji}\,w_{mj}\,e_{mj}$ is the weighted empirical mean and $n=\sum_{m,j}\phi_{mji}\,w_{mj}$ is the total virtual word count used to estimate μ_i and Σ_i. Taking the expectations of μ_i and Σ_i over the posterior distribution $\mathcal{NIW}(\mu_i,\Sigma_i\,|\,\mu_i^{(n)},\Lambda_i^{(n)},\kappa_i^{(n)},\nu_i^{(n)})$, we obtain the update rules:
$$\mu_i \leftarrow \mathbb{E}_{\mathcal{NIW}}[\mu_i] = \frac{\kappa_0\mu_0+n\bar\mu_i}{\kappa_0+n}, \qquad \Sigma_i \leftarrow \mathbb{E}_{\mathcal{NIW}}[\Sigma_i] = \frac{\Lambda_0 + S_i + \frac{\kappa_0 n}{\kappa_0+n}(\bar\mu_i-\mu_0)(\bar\mu_i-\mu_0)^\top}{\nu_0+n-d-1},$$
with $S_i$ as above. Eqn. 20 and Eqn. 21 are the smoothed versions of Eqn. 7 and Eqn. 8, respectively. From the Bayesian perspective, they correspond to the expectations of μ_i's and Σ_i's posterior distributions.
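The smoothed update for a single topic can be sketched numerically as follows. This is an illustrative implementation (hypothetical function name and data layout, not our released code), assuming the inverse-Wishart degrees-of-freedom correction uses the embedding dimension d:

```python
import numpy as np

def niw_smoothed_update(phi_w, E, mu0, kappa0, Lambda0, nu0):
    """Posterior-mean update for one topic's (mu_i, Sigma_i) under a
    Normal-Inverse-Wishart prior, following the smoothed rules above.

    phi_w : (M,) virtual word counts phi_{mji} * w_{mj} for this topic,
            flattened over documents and words.
    E     : (M, d) corresponding word embeddings e_{mj}.
    """
    d = E.shape[1]
    n = phi_w.sum()                                 # total virtual word count
    mu_bar = (phi_w[:, None] * E).sum(0) / n        # weighted empirical mean
    diff = E - mu_bar
    S = (phi_w[:, None] * diff).T @ diff            # weighted scatter S_i
    mu_new = (kappa0 * mu0 + n * mu_bar) / (kappa0 + n)
    dm = (mu_bar - mu0)[:, None]
    Lambda_n = Lambda0 + S + (kappa0 * n / (kappa0 + n)) * (dm @ dm.T)
    Sigma_new = Lambda_n / (nu0 + n - d - 1)        # E[Sigma_i] under the IW posterior
    return mu_new, Sigma_new
```

With a symmetric toy cloud centered at the prior mean, the updated mean stays at the prior mean and the covariance shrinks toward the prior scale matrix, as expected from the formulas above.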

E MORE CONCEPTUAL INTERPRETATION RESULTS

Dataset-Level Interpretations. As in the main paper, we use CLDA as an interpreter on STS-B and QQP, respectively, sample 4 concepts (topics) for each dataset, and plot the word embeddings of the top words (those closest to the center μ_i) in these concepts using PCA. Fig. 5 shows the concepts from STS-B. We observe that Concept 63 is mostly about household and daily life, including words such as 'trash', 'flowers', 'airs', and 'garden'. Concept 60 is mostly about tools, including words such as 'stations', 'rope', 'parachute', and 'hose'. Concept 84 is mostly about national security, including words such as 'guerilla', 'NSA', 'espionage', and 'raided'. Concept 55 contains mostly countries and cities, such as 'Kiev', 'Moscow', 'Algeria', and 'Ukrainian'. Similarly, Fig. 6 shows the concepts from QQP. We observe that Concept 12 is mostly about negative attitudes, including words such as 'boring', 'criticism', and 'blame'. Concept 73 is mostly about psychology, including words such as 'adrenaline', 'haunting', and 'paranoia'. Concept 34 is mostly about prevention and conservatives, including words such as 'destroys', 'unacceptable', and 'prohibits'. Concept 64 is mostly about strategies, including words such as 'rumours', 'boycott', and 'deportation'.

Document-Level Interpretations. For document-level conceptual interpretations, we sample six example documents from STS-B (Fig. 5) and eight from QQP (Fig. 6), respectively, where each document contains a pair of sentences. The STS-B task is to predict the semantic similarity between two sentences, with scores in the range [0, 5]; see, for example, Document (a) of Fig. 5.

The approximate marginal log-likelihood of the word embeddings, i.e., the third term of the ELBO as mentioned in Eqn. 4 of the main paper, is:
$$\mathcal{L}^{(train)} = \sum_{j=1}^{L_m}\mathbb{E}_q[\log p(e_{mj}\,|\,z_{mj},\mu_{z_{mj}},\Sigma_{z_{mj}})] = \sum_{m,j,i}\phi_{mji}\,w_{mj}\Big\{-\tfrac12(e_{mj}-\mu_i)^\top\Sigma_i^{-1}(e_{mj}-\mu_i)-\log\big[(2\pi)^{d/2}|\Sigma_i|^{1/2}\big]\Big\}.$$
The above equation is the training objective; for fair comparison across different training schemes, however, we evaluate the approximate likelihood with word count 1 for all words:
$$\mathcal{L}^{(eval)} = \sum_{j=1}^{L_m}\mathbb{E}_q[\log p'(e_{mj}\,|\,z_{mj},\mu_{z_{mj}},\Sigma_{z_{mj}})] = \sum_{m,j,i}\phi_{mji}\Big\{-\tfrac12(e_{mj}-\mu_i)^\top\Sigma_i^{-1}(e_{mj}-\mu_i)-\log\big[(2\pi)^{d/2}|\Sigma_i|^{1/2}\big]\Big\}.$$

F.1 GAUSSIAN MIXTURE MODELS

Suppose we have a ground-truth GMM with parameters $\pi^* \in \mathbb{R}^K$ and $\{\mu_k^*, \Sigma_k^*\}_{k=1}^K$, i.e., $K$ different Gaussian distributions. In the dataset, let $N$ and $N_s$ denote the numbers of non-stop-words and stop-words, respectively. The marginal likelihood of a learned GMM on a given data sample $e$ can be written as
$$p(e\,|\,\{\mu,\Sigma\},\pi) = \sum_{k=1}^K \pi_k\,\mathcal{N}(e;\mu_k,\Sigma_k).$$
Assuming a dataset of $N+N_s$ words $\{e_i\}_{i=1}^{N+N_s}$ and taking the associated weight $w_i$ of each word into account, the log-likelihood of the dataset can be written as
$$\sum_{i=1}^{N+N_s}\log p(e_i\,|\,\{\mu_k,\Sigma_k\}_{k=1}^K,\pi) = \sum_{i=1}^{N}\log\sum_{k=1}^K w_i\pi_k\,\mathcal{N}(e_i;\mu_k,\Sigma_k)+\sum_{i=N+1}^{N+N_s}\log\sum_{k=1}^K w_i\pi_k\,\mathcal{N}(e_i;\mu_k,\Sigma_k).$$
Leveraging Jensen's inequality, we obtain a lower bound of the above quantity (denoting by $\Theta$ the collection of parameters $\{\mu_k,\Sigma_k\}_{k=1}^K$ and $\pi$):
$$\mathcal{L}_{GMM}(\Theta,\{w_i\}) = \sum_{i=1}^N w_i\log\sum_{k=1}^K\pi_k\,\mathcal{N}(e_i;\mu_k,\Sigma_k) + \sum_{i=N+1}^{N+N_s} w_i\log\sum_{k=1}^K\pi_k\,\mathcal{N}(e_i;\mu_k,\Sigma_k) + C,$$
where $C$ is a constant. In the following theoretical analysis, we consider three different configurations of the weights $w_i$.

Definition F.1 (Weight Configurations). We define three weight configurations as follows:
• Identical Weights: $w_i = \frac{1}{N+N_s}$ for $i\in\{1,2,\dots,N+N_s\}$.
• Ground-Truth Weights: $w_i = \frac{1}{N}$ for $i\in\{1,2,\dots,N\}$ and $w_i=0$ for $i\in\{N+1,N+2,\dots,N+N_s\}$.
• Attention-Based Weights: $w_i = \lambda_1\in[\frac{1}{N+N_s},\frac{1}{N}]$ for $i\in\{1,2,\dots,N\}$ and $w_i=\lambda_2\in[0,\frac{1}{N+N_s}]$ for $i\in\{N+1,N+2,\dots,N+N_s\}$.

Definition F.2 (Advanced Weight Configurations).
We define three weight configurations as follows:
• Identical Weights: $w_i = \frac{1}{N+N_s}$ for $i\in\{1,2,\dots,N+N_s\}$.
• Ground-Truth Weights: $w_i = \frac{1}{N}$ for $i\in\{1,2,\dots,N\}$ and $w_i=0$ for $i\in\{N+1,N+2,\dots,N+N_s\}$.
• Attention-Based Weights: $w_i\in[\frac{1}{N+N_s},\frac{1}{N}]$ for $i\in\{1,2,\dots,N\}$ and $w_i\in[0,\frac{1}{N+N_s}]$ for $i\in\{N+1,N+2,\dots,N+N_s\}$.

Definition F.3 (Optimal Parameters). With Definition F.1, the corresponding optimal parameters are defined as follows:
$$\Theta_I = \arg\max_\Theta \mathcal{L}(\Theta;\, w\to\text{Identical}),\qquad \Theta_G = \arg\max_\Theta \mathcal{L}(\Theta;\, w\to\text{GT}),\qquad \Theta_A=\arg\max_\Theta\mathcal{L}(\Theta;\, w\to\text{Attention}),$$
where $w\to$Identical, $w\to$GT, and $w\to$Attention indicate that 'Identical Weights', 'Ground-Truth Weights', and 'Attention-Based Weights' are used, respectively.

Lemma F.4. Suppose we have two series of functions $\{f_{1,i}(x)\}$ and $\{f_{2,i}(x)\}$, and two non-negative weighting parameters $\lambda_1,\lambda_2$ satisfying $N\lambda_1+N_s\lambda_2=1$. We define the final objective function $f(\cdot)$ as:
$$f(x;\lambda_1,\lambda_2)=\lambda_1\sum_{i=1}^N f_{1,i}(x)+\lambda_2\sum_{i=N+1}^{N+N_s} f_{2,i}(x).$$
We assume two pairs of parameters $(\lambda_1,\lambda_2)$ and $(\lambda_1',\lambda_2')$, where
$$\lambda_1\ge\lambda_1', \quad(33)\qquad \lambda_2\le\lambda_2'. \quad(34)$$
Defining the optimal values of the objective function for the two weightings as
$$\hat x=\arg\max_x f(x;\lambda_1,\lambda_2), \quad(35)\qquad \hat x'=\arg\max_x f(x;\lambda_1',\lambda_2'), \quad(36)$$
we then have that
$$f(\hat x;\tfrac1N,0)\ge f(\hat x';\tfrac1N,0). \quad(37)$$
Proof. We prove this lemma by contradiction. Suppose that
$$f(\hat x;\tfrac1N,0) < f(\hat x';\tfrac1N,0). \quad(38)$$
According to Eqn. 33, i.e., $\lambda_1\ge\lambda_1'$, and the constraint $N\lambda_1+N_s\lambda_2=1$, we have
$$\lambda_1\lambda_2'=\lambda_1\,\frac{1-N\lambda_1'}{N_s}\ \ge\ \lambda_1'\,\frac{1-N\lambda_1}{N_s}=\lambda_1'\lambda_2. \quad(39)$$
According to Eqn. 36, we have the following inequality:
$$f(\hat x;\lambda_1',\lambda_2')\le f(\hat x';\lambda_1',\lambda_2'). \quad(40)$$
Combined with the aforementioned assumption in Eqn.
38, we have that
$$\begin{aligned}
\lambda_2' f(\hat x;\lambda_1,\lambda_2) &= \lambda_1\lambda_2'\sum_{i=1}^N f_{1,i}(\hat x) + \lambda_2\lambda_2'\sum_{i=N+1}^{N+N_s} f_{2,i}(\hat x) \quad&(41)\\
&= \Big(\lambda_1'\lambda_2\sum_{i=1}^N f_{1,i}(\hat x)+\lambda_2'\lambda_2\sum_{i=N+1}^{N+N_s} f_{2,i}(\hat x)\Big) + N(\lambda_1\lambda_2'-\lambda_1'\lambda_2)\cdot\frac1N\sum_{i=1}^N f_{1,i}(\hat x) \quad&(42)\\
&= \lambda_2 f(\hat x;\lambda_1',\lambda_2') + N(\lambda_1\lambda_2'-\lambda_1'\lambda_2)\, f(\hat x;\tfrac1N,0) \quad&(43)\\
&< \lambda_2 f(\hat x';\lambda_1',\lambda_2') + N(\lambda_1\lambda_2'-\lambda_1'\lambda_2)\, f(\hat x';\tfrac1N,0) \quad&(44)\\
&= \Big(\lambda_1'\lambda_2\sum_{i=1}^N f_{1,i}(\hat x')+\lambda_2'\lambda_2\sum_{i=N+1}^{N+N_s} f_{2,i}(\hat x')\Big) + N(\lambda_1\lambda_2'-\lambda_1'\lambda_2)\cdot\frac1N\sum_{i=1}^N f_{1,i}(\hat x') \quad&(45)\\
&= \lambda_1\lambda_2'\sum_{i=1}^N f_{1,i}(\hat x') + \lambda_2\lambda_2'\sum_{i=N+1}^{N+N_s} f_{2,i}(\hat x') \quad&(46)\\
&= \lambda_2' f(\hat x';\lambda_1,\lambda_2),
\end{aligned}$$
Proof. First, by definition one can easily see that $\Theta_G$ achieves the largest $\mathcal{L}(\cdot;\,w\to\text{GT})$ among the three:
$$\max[\mathcal{L}_{GMM}(\Theta_I;\,w\to\text{GT}),\ \mathcal{L}_{GMM}(\Theta_A;\,w\to\text{GT})] \le \max_\Theta \mathcal{L}_{GMM}(\Theta;\,w\to\text{GT}) = \mathcal{L}_{GMM}(\Theta_G;\,w\to\text{GT}). \quad(60)$$
Next, we set $\{w_i\}_{i=1}^N$ to $\lambda_1$ and $\{w_i\}_{i=N+1}^{N+N_s}$ to $\lambda_2$, respectively, and rewrite $\log\sum_{k=1}^K\pi_k\,\mathcal{N}(e_i;\mu_k,\Sigma_k)$ as $f_{1,i}(x)$ for $i\in\{1,2,\dots,N\}$ and $f_{2,i}(x)$ for $i\in\{N+1,N+2,\dots,N+N_s\}$, where $x$ corresponds to $\Theta\triangleq(\pi,\{\mu_k,\Sigma_k\}_{k=1}^K)$. By Lemma F.4, we have that
$$\mathcal{L}_{GMM}(\Theta_A;\,w\to\text{GT})\le\mathcal{L}_{GMM}(\Theta_G;\,w\to\text{GT}). \quad(61)$$
Combining Eqn. 60 and Eqn. 61 concludes the proof.

Theorem F.7 shows that, under mild assumptions, the attention-based weights help produce better estimates $\Theta$ in the presence of noisy stop-words, and therefore learn higher-quality topics from the corpus, improving both the generalization performance and the interpretability of PLMs.

Theorem F.7 (Advantage of $\Theta_A$ in the General Case). With Definition F.2 and Definition F.3, comparing $\Theta_I$, $\Theta_G$, and $\Theta_A$ by evaluating them on the marginal log-likelihood of non-stop-words, i.e., $\mathcal{L}_{GMM}(\cdot,\,w\to\text{GT})$, we have that
$$\mathcal{L}_{GMM}(\Theta_I;\,w\to\text{GT})\le\mathcal{L}_{GMM}(\Theta_A;\,w\to\text{GT})\le\mathcal{L}_{GMM}(\Theta_G;\,w\to\text{GT}).$$
Proof.
First, by definition one can easily see that $\Theta_G$ achieves the largest $\mathcal{L}(\cdot;\,w\to\text{GT})$ among the three:
$$\max[\mathcal{L}_{GMM}(\Theta_I;\,w\to\text{GT}),\ \mathcal{L}_{GMM}(\Theta_A;\,w\to\text{GT})]\le\max_\Theta\mathcal{L}_{GMM}(\Theta;\,w\to\text{GT})=\mathcal{L}_{GMM}(\Theta_G;\,w\to\text{GT}). \quad(63)$$
Next, we invoke Lemma F.5 by (1) setting $\{w_i\}_{i=1}^N$ to $\lambda_1$ and $\{w_i\}_{i=N+1}^{N+N_s}$ to $\lambda_2$, respectively, and (2) rewriting $\log\sum_{k=1}^K\pi_k\,\mathcal{N}(e_i;\mu_k,\Sigma_k)$ as $f_{1,i}(x)$ for $i\in\{1,2,\dots,N\}$ and $f_{2,i}(x)$ for $i\in\{N+1,N+2,\dots,N+N_s\}$, where $x$ corresponds to $\Theta\triangleq(\pi,\{\mu_k,\Sigma_k\}_{k=1}^K)$. By Lemma F.5, we then have that
$$\mathcal{L}_{GMM}(\Theta_A;\,w\to\text{GT})\le\mathcal{L}_{GMM}(\Theta_G;\,w\to\text{GT}). \quad(64)$$
Note that because $f_{1,i}(\cdot)$ and $f_{2,i}(\cdot)$ are Gaussian, Assumptions 1 and 2 in Lemma F.5 hold naturally under mild regularity conditions. Combining Eqn. 63 and Eqn. 64 concludes the proof.
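To make Definition F.1 and the objective $\mathcal{L}_{GMM}$ concrete, the following toy numpy sketch (hypothetical helper names; not part of our implementation) evaluates the weighted GMM log-likelihood under the three weight configurations:

```python
import numpy as np

def gauss_logpdf(e, mu, Sigma):
    """Log density of a multivariate Gaussian N(e; mu, Sigma)."""
    d = len(mu)
    diff = e - mu
    _, logdet = np.linalg.slogdet(Sigma)
    return -0.5 * (diff @ np.linalg.solve(Sigma, diff)
                   + d * np.log(2 * np.pi) + logdet)

def weighted_gmm_ll(E, w, pi, mus, Sigmas):
    """L_GMM(Theta; {w_i}) = sum_i w_i log sum_k pi_k N(e_i; mu_k, Sigma_k)."""
    total = 0.0
    for e, wi in zip(E, w):
        comp = [np.log(pk) + gauss_logpdf(e, mk, Sk)
                for pk, mk, Sk in zip(pi, mus, Sigmas)]
        total += wi * np.logaddexp.reduce(comp)
    return total

def make_weights(N, Ns, scheme, lam1=None, lam2=None):
    """Weight configurations of Definition F.1; the first N entries are
    non-stop-words, the last Ns entries are stop-words."""
    if scheme == "identical":
        return np.full(N + Ns, 1.0 / (N + Ns))
    if scheme == "ground_truth":
        return np.concatenate([np.full(N, 1.0 / N), np.zeros(Ns)])
    if scheme == "attention":
        return np.concatenate([np.full(N, lam1), np.full(Ns, lam2)])
    raise ValueError(scheme)
```

All three configurations sum to 1 (the attention-based one whenever $N\lambda_1 + N_s\lambda_2 = 1$), and the ground-truth weights zero out the stop-word terms, matching the definitions above.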

F.2 CLDA AS INTERPRETERS

As mentioned in Eqn. 4 of the main paper, the ELBO of the marginal likelihood (denoting by Θ the collection of parameters φ, γ, and {μ_k, Σ_k}_{k=1}^K) is as follows:



Figure 4: Visualization of CLDA's learned topics of contextual word embeddings. Left: MRPC's dataset-level interpretation with two example documents. Concept 83 is relatively far from the other three concepts in the embedding space; therefore we omit it on the left panel for better readability. Right: RTE's dataset-level interpretation with three example documents.

When finetuning a PLM, one can iterate between updating (1) the PLM parameters in f_ae(•), f_c(•), and g(•), (2) the CLDA global concept parameters Ω, and (3) the CLDA document-level and word-level concept parameters γ_m and Φ_m. Note that f_ae(•) contains most of the parameters in PLMs and appears in the CLDA loss term L_c; therefore CLDA can improve the finetuning process of PLMs. Alg. 1 gives an overview of CLDA training when used as a PLM regulator (to prevent clutter, we show the version without EMA).
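The alternating scheme above, together with the periodic inclusion of L_c described in the implementation details, can be sketched as follows. All callables and names here are placeholders standing in for the actual PLM/CLDA updates, not our released code:

```python
def joint_finetune(num_epochs, clda_period, plm_loss_fn, clda_loss_fn,
                   update_global_concepts, update_doc_concepts, lam=0.1):
    """Sketch of CLDA-as-regulator training: the PLM finetuning loss is
    always used; the CLDA term L_c joins the joint loss L_j only every
    `clda_period` epochs, at which point the CLDA parameters are also
    updated in alternation with the PLM."""
    history = []
    for epoch in range(num_epochs):
        use_clda = (epoch % clda_period == 0)
        loss = plm_loss_fn(epoch)                   # (1) PLM parameters
        if use_clda:
            loss = loss + lam * clda_loss_fn(epoch) # L_j = L_plm + lambda * L_c
            update_global_concepts(epoch)           # (2) global concepts Omega
            update_doc_concepts(epoch)              # (3) gamma_m and Phi_m
        history.append((epoch, use_clda, loss))
    return history
```

With `clda_period=3`, the CLDA term enters the joint loss on epochs 0, 3, 6, ..., mirroring the "every 1/3/5 epochs" hyperparameter described above.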

Figure 5: Visualization of CLDA's learned topics of contextual word embeddings. We show STS-B's dataset-level interpretation with six example documents. CLDA's predictions lie in the range [0, 5].

$$\mathcal{L}_{CLDA}(\Theta;\{w_i\}) = \sum_{j=1}^{L_m}\mathbb{E}_q[\log p(e_{mj}\,|\,z_{mj},\mu_{z_{mj}},\Sigma_{z_{mj}})] = \sum_{m,j,i}\phi_{mji}\,w_{mj}\Big\{-\tfrac12(e_{mj}-\mu_i)^\top\Sigma_i^{-1}(e_{mj}-\mu_i)-\log\big[(2\pi)^{d/2}|\Sigma_i|^{1/2}\big]\Big\}. \quad(65)$$

Based on the definitions and lemmas above, we have the following theorems:

Theorem F.8 (Advantage of $\Theta_A$ in the Simplified Case). With Definition F.1 and Definition F.3, comparing $\Theta_I$, $\Theta_G$, and $\Theta_A$ by evaluating them on the marginal log-likelihood of non-stop-words, i.e., $\mathcal{L}(\cdot,\,w\to\text{GT})$, we have that
$$\mathcal{L}_{CLDA}(\Theta_I;\,w\to\text{GT})\le\mathcal{L}_{CLDA}(\Theta_A;\,w\to\text{GT})\le\mathcal{L}_{CLDA}(\Theta_G;\,w\to\text{GT}). \quad(66)$$
Proof. First, by definition one can easily see that $\Theta_G$ achieves the largest $\mathcal{L}(\cdot;\,w\to\text{GT})$ among the three:
$$\max[\mathcal{L}_{CLDA}(\Theta_I;\,w\to\text{GT}),\ \mathcal{L}_{CLDA}(\Theta_A;\,w\to\text{GT})]\le\max_\Theta\mathcal{L}_{CLDA}(\Theta;\,w\to\text{GT})=\mathcal{L}_{CLDA}(\Theta_G;\,w\to\text{GT}). \quad(67)$$

Problem Setting and Notation. We consider a corpus of M documents, where the m'th document contains L_m words, and a PLM f(D_m), which takes as input the m'th document (denoted as D_m) with L_m words and outputs (1) a CLS embedding c_m ∈ R^d, (2) L_m contextual word embeddings e_m ≜ [e_mj]_{j=1}^{L_m}, and (3) the attention weights a_m ≜ [a_mj]_{j=1}^{L_m}.

• Attention-Based Weights with Fixed Length: use $w_{mj} = L'\,a_{mj}$, where $L'$ is a fixed sequence length shared across all documents.
• Attention-Based Weights with Variable Length: use $w_{mj} = L_m\,a_{mj} / \sum_{k=1}^{L_m} a_{mk}$, where $L_m$ is the true sequence length without padding.
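The two attention-based weighting schemes can be sketched as follows (hypothetical function names; a_mj denotes the PLM's attention weight for word j of document m):

```python
import numpy as np

def attention_weights_fixed(a, L_prime):
    """w_mj = L' * a_mj, with a fixed sequence length L' shared across
    all documents."""
    return L_prime * np.asarray(a, dtype=float)

def attention_weights_variable(a, L_m):
    """w_mj = L_m * a_mj / sum_k a_mk, using the true (unpadded) length
    L_m; padding positions beyond L_m are dropped."""
    a = np.asarray(a[:L_m], dtype=float)
    return L_m * a / a.sum()
```

With the variable-length scheme, the weights of a document always sum to L_m, so one document contributes the same total virtual word count as it has real words, regardless of padding.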

[Figure: dimensionality reduction of embeddings with topics. Example document: "He playfully chided the Senate's little bitty tax relief plan. [SEP] We don't need a little bitty tax relief plan."]


Results on GLUE benchmark datasets.

Ablation studies for the CoLA dataset in terms of Matthew's correlation.

Example concepts on RTE dataset learned by CLDA.

STS-B concepts (from Fig. 5):
Concept (Topic) 63: trash, flowers, airs, garden, wild, closet, sofa, vase, carrot, seeds, turf, playground, floors
Concept (Topic) 60: stations, rope, parachute, hose, clarinet, sink, axe, rifle
Concept (Topic) 84: guerrilla, NSA, espionage, raided, Canadian, Croatia, historic
Concept (Topic) 55: Kiev, Moscow, resistance, Algeria, agrees, Ukrainian, emerge, Qaeda, final
Example document: 'A man and a woman watch two dogs. [SEP] A man in a maroon bathing suit swings on a rope on a lake.'

Word-Level Interpretations. For word-level conceptual interpretations, we observe that CLDA interprets the PLM's prediction on Document (c) of Fig. 5 using words such as 'cat', 'floor', and 'garlic', which are related to household and daily life. Likewise, CLDA interprets the PLM's prediction on Document (e) of Fig. 5 using words such as 'soldier' and 'border', which are related to national security. Similarly, for QQP's Document (d) (Fig. 6), CLDA correctly interprets the model prediction 'True' by identifying keywords such as 'sabotage' and 'oppose' with similar meanings in the topic of strategies. For QQP's Document (g) (Fig. 6), CLDA interprets words in both sentences with the same semantics, such as 'conservative', which is related to prevention and conservatives (note that in politics, 'conservative' refers to parties that tend to prevent or block new policies or legislation), and thereby predicts the correct label 'True'.

Example Concepts. Following Blei et al. (2003), we show the learned concepts on the RTE dataset in Table 3, complementing the explanations above. We select several different topics from Fig. 4 of the main paper. As in Sec. 4.2 of the main paper, we obtain the top words of each concept by first averaging each word's contextual embeddings over the dataset, and then taking the words nearest to each topic center (μ_i) in the embedding space.
As we can see in Table 3, CLDA captures various concepts with rich and accurate semantics. Therefore, although PLM embeddings are contextual and continuous, our CLDA can still discover conceptual patterns of words at the dataset level.
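The top-word extraction procedure described above can be sketched as follows; the function name and data layout are illustrative assumptions, not our released code:

```python
import numpy as np

def concept_top_words(word_embs, centers, top_n=8):
    """For each concept center mu_i, return the nearest words in embedding
    space, where each word is represented by the average of its contextual
    embeddings over the dataset.

    word_embs : dict mapping word -> list of contextual embedding vectors
    centers   : (K, d) array of concept centers mu_i
    """
    words = list(word_embs)
    # average each word's contextual embeddings over the dataset: (V, d)
    avg = np.stack([np.mean(word_embs[w], axis=0) for w in words])
    out = []
    for mu in centers:
        dist = np.linalg.norm(avg - mu, axis=1)   # Euclidean distance to mu_i
        order = np.argsort(dist)[:top_n]
        out.append([words[i] for i in order])
    return out
```

Averaging first collapses each word's many contextual embeddings into one point, so a word is assigned to a concept by where it sits on average rather than by any single occurrence.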

F THEORETICAL ANALYSIS ON CONTINUOUS WORD COUNTS

Before going to the claims and proofs, we first specify some basic problem settings and assumptions. Suppose there are K+1 topic groups, each regarded as sampled from a parameterized multivariate Gaussian distribution. Specifically, the (K+1)'th topic distribution has a much larger covariance and is, at the same time, close to the center of the embedding space. These properties can be measured by a series of inequalities:

which contradicts the definition of $\hat{x}$ in Eqn. 35 (i.e., $\hat{x}$ maximizes $f(x; \lambda_1, \lambda_2)$), completing the proof.

Lemma F.5. Suppose we have two series of functions $\{f_{1,i}(x)\}$ and $\{f_{2,i}(x)\}$, with two series of non-negative weighting parameters. We define the final objective function $f(\cdot)$ as:

We assume two pairs of parameters $(\lambda_1, \lambda_2)$ and $(\lambda_1', \lambda_2')$. Defining the optimal values of the objective function for the two weightings as $\hat{x} = \arg\max_x f(x; \lambda_1, \lambda_2)$ and $\hat{x}' = \arg\max_x f(x; \lambda_1', \lambda_2')$, and under the following Assumptions (with $\mathbf{1}$ and $\mathbf{0}$ denoting vectors with all entries equal to 1 and 0, respectively):

we have that

Proof. We start by proving the following equality by contradiction:

Specifically, if

then, leveraging Assumptions 1 and 2 above, we have that

which contradicts Eqn. 51. Therefore, Eqn. 55 holds. Combining Eqn. 55 and Assumption 2 above, we have that

concluding the proof.

Based on the definitions and lemmas above, we have the following theorems:

Theorem F.6 (Advantage of $\Theta_A$ in the Simplified Case). With Definition F.1 and Definition F.3, comparing $\Theta_I$, $\Theta_G$, and $\Theta_A$ by evaluating them on the marginal log-likelihood of non-stop-words, i.e., $\mathcal{L}(\cdot,\,w\to\text{GT})$, we have that

Next, we set $\cup_m \{w_{mj}\}_{j=1}^{N_m}$ to $\lambda_1$ and $\cup_m \{w_{mj}\}_{j=N_m+1}^{N_m+N_{m,s}}$ to $\lambda_2$, respectively; we rewrite

Combining Eqn. 67 and Eqn.
68 concludes the proof.

Theorem F.8 shows that, under mild assumptions, the attention-based weights help produce better estimates of $\Theta$ in the presence of noisy stop-words, and therefore learn higher-quality topics from the corpus, improving both the generalization performance and the interpretability of PLMs.

Theorem F.9 (Advantage of $\Theta_A$ in the General Case). With Definition F.2 and Definition F.3, comparing $\Theta_I$, $\Theta_G$, and $\Theta_A$ by evaluating them on the marginal log-likelihood of non-stop-words, i.e., $\mathcal{L}_{CLDA}(\cdot,\,w\to\text{GT})$, we have that
$$\mathcal{L}_{CLDA}(\Theta_I;\,w\to\text{GT})\le\mathcal{L}_{CLDA}(\Theta_A;\,w\to\text{GT})\le\mathcal{L}_{CLDA}(\Theta_G;\,w\to\text{GT}).$$
Proof. First, by definition one can easily see that $\Theta_G$ achieves the largest $\mathcal{L}(\cdot;\,w\to\text{GT})$ among the three:
$$\max[\mathcal{L}_{CLDA}(\Theta_I;\,w\to\text{GT}),\ \mathcal{L}_{CLDA}(\Theta_A;\,w\to\text{GT})]\le\max_\Theta\mathcal{L}_{CLDA}(\Theta;\,w\to\text{GT})=\mathcal{L}_{CLDA}(\Theta_G;\,w\to\text{GT}). \quad(70)$$
Next, we rewrite $\{-\tfrac12(e_{mj}-\mu_i)^\top\Sigma_i^{-1}(e_{mj}-\mu_i)-\log[(2\pi)^{d/2}|\Sigma_i|^{1/2}]\}$ as $f_{1,j}(x)$ for $j\in\cup_m\{1,2,\dots,N_m\}$ and $f_{2,j}(x)$ for $j\in\cup_m\{N_m+1,N_m+2,\dots,N_m+N_{m,s}\}$, where $x$ corresponds to $\Theta\triangleq(\phi,\gamma,\{\mu_k,\Sigma_k\}_{k=1}^K)$. By Lemma F.5, we then have that
$$\mathcal{L}_{CLDA}(\Theta_A;\,w\to\text{GT})\le\mathcal{L}_{CLDA}(\Theta_G;\,w\to\text{GT}). \quad(71)$$
Note that because $f_{1,j}(\cdot)$ and $f_{2,j}(\cdot)$ are very close to Gaussian, Assumptions 1 and 2 in Lemma F.5 hold naturally under mild regularity conditions. Combining Eqn. 70 and Eqn. 71 concludes the proof.

