PMI-MASKING: PRINCIPLED MASKING OF CORRELATED SPANS

Abstract

Masking tokens uniformly at random constitutes a common flaw in the pretraining of Masked Language Models (MLMs) such as BERT. We show that such uniform masking allows an MLM to minimize its training objective by latching onto shallow local signals, leading to pretraining inefficiency and suboptimal downstream performance. To address this flaw, we propose PMI-Masking, a principled masking strategy based on the concept of Pointwise Mutual Information (PMI), which jointly masks a token n-gram if it exhibits high collocation over the corpus. PMI-Masking motivates, unifies, and improves upon prior, more heuristic approaches that attempt to address the drawback of random uniform token masking, such as whole-word masking, entity/phrase masking, and random-span masking. Specifically, we show experimentally that PMI-Masking reaches the performance of prior masking approaches in half the training time, and consistently improves performance at the end of training.

1. INTRODUCTION

In the couple of years since BERT was introduced in a seminal paper by Devlin et al. (2019a), Masked Language Models (MLMs) have rapidly advanced the NLP frontier (Sun et al., 2019; Liu et al., 2019; Joshi et al., 2020; Raffel et al., 2019). At the heart of the MLM approach is the task of predicting a masked subset of the text given the remaining, unmasked text. The text itself is broken up into tokens, each token consisting of a word or part of a word; thus "chair" constitutes a single token, but out-of-vocabulary words like "e-igen-val-ue" are broken up into several sub-word tokens. In BERT, 15% of tokens are chosen to be masked uniformly at random. It is the random choice of single tokens that we address in this paper: we show that this approach is suboptimal and offer a principled alternative.

To see why Random-Token Masking is suboptimal, consider the special case of sub-word tokens. Given the masked sentence "To approximate the matrix, we use the eigenvector corresponding to its largest e-[mask]-val-ue", an MLM will quickly learn to predict "igen" based only on the context "e-[mask]-val-ue", rendering the rest of the sentence redundant. The question is whether the network will also learn to relate the broader context to the tokens comprising "eigenvalue". When they are masked together, the network is forced to do so, but such masking occurs with vanishingly small probability. One might hypothesize that the network would nonetheless be able to piece such meaning together from local cues; however, we show that it often struggles to do so.

We establish this via a controlled experiment, in which we reduced the size of the vocabulary, thereby breaking more words into sub-word tokens. We compared the extent to which such vocabulary reduction degraded regular BERT relative to so-called Whole-Word Masking BERT (WW-BERT) (Devlin et al., 2019b), a version of BERT that jointly masks all sub-word tokens comprising an out-of-vocabulary word during training.
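The difference between the two schemes is easy to see in code. The sketch below illustrates Whole-Word Masking, assuming WordPiece-style tokenization in which continuation pieces carry a "##" prefix; the function name, the unit-level sampling, and the omission of BERT's 80/10/10 replacement rule are our simplifications for illustration, not the reference implementation.

```python
import random

def whole_word_mask(tokens, mask_prob=0.15, mask_token="[MASK]"):
    """Jointly mask all sub-word pieces of each selected word."""
    # Group token indices into whole-word units: a "##"-prefixed
    # piece continues the unit started by the preceding token.
    units = []
    for i, tok in enumerate(tokens):
        if tok.startswith("##") and units:
            units[-1].append(i)
        else:
            units.append([i])
    # Mask whole units, in random order, until roughly mask_prob of
    # the tokens are covered (at least one unit is always masked).
    budget = max(1, round(mask_prob * len(tokens)))
    masked = list(tokens)
    for unit in random.sample(units, len(units)):
        if budget <= 0:
            break
        for i in unit:
            masked[i] = mask_token
        budget -= len(unit)
    return masked
```

Under Random-Token Masking, by contrast, each index would be sampled independently, so the pieces of "e-igen-val-ue" would almost never be masked together.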
We show that vanilla BERT's performance degrades much more rapidly than that of WW-BERT as the vocabulary size shrinks. The intuitive explanation is that Random-Token Masking is wasteful: it overtrains on easy sub-word tasks (such as predicting "igen") and undertrains on harder whole-word tasks (such as predicting "eigenvalue"). The advantage of Whole-Word Masking over Random-Token Masking is relatively modest for standard vocabularies, because out-of-vocabulary words are rare. However, the tokenization of words is a very special case of a much broader statistical linguistic phenomenon of collocation: the cooccurrence of series of tokens at levels much greater than would be predicted simply by their individual frequencies in the corpus. There are millions of collocated word n-grams (multi-word expressions, phrases, and other common word combinations), whereas there are only tens of thousands of words in frequent use. It is therefore reasonable to hypothesize that Random-Token Masking generates many wastefully easy problems and too few usefully hard problems because of multi-word collocations, and that this affects performance even more than the rarer case of tokenized words; we show that this is indeed the case.

Several prior works have considered the idea of masking spans longer than a single word. Sun et al. (2019) and Guu et al. (2020) proposed Knowledge Masking and Salient Span Masking, respectively, in which tokens comprising entities or phrases, as identified by external parsers, are jointly masked. While these approaches extend the scope of Whole-Word Masking, their restriction to specific types of correlated n-grams, along with their reliance on imperfect tools for identifying such n-grams, has limited the gains they achieve. With a similar motivation in mind, the SpanBERT model of Joshi et al. (2020) introduced Random-Span Masking, which masks word spans of lengths sampled from a geometric distribution at random positions in the text.
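As a rough illustration (not SpanBERT's actual implementation), Random-Span Masking can be sketched as follows. The parameter values p = 0.2 and a maximum span length of 10 follow SpanBERT's reported setup; the budget handling, function signature, and token-level (rather than word-level) spans are our own simplifications.

```python
import random

def random_span_mask(tokens, mask_ratio=0.15, p=0.2, max_span=10,
                     mask_token="[MASK]"):
    """Mask contiguous spans at random start positions, with span
    lengths drawn from a geometric distribution clipped at max_span."""
    budget = max(1, round(mask_ratio * len(tokens)))
    masked = list(tokens)
    covered = set()
    while budget > 0:
        # Span length ~ Geometric(p): count trials until "success".
        length = 1
        while random.random() >= p and length < max_span:
            length += 1
        start = random.randrange(len(tokens) - min(length, len(tokens)) + 1)
        for i in range(start, min(start + length, len(tokens))):
            if i not in covered and budget > 0:
                covered.add(i)
                masked[i] = mask_token
                budget -= 1
    return masked
```

Because span start positions are uniform, a masked span frequently begins or ends in the middle of a collocation, leaving an informative fragment such as "e-" or "-val-ue" unmasked; this is precisely the local-cue leakage discussed above.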
Random-Span Masking was shown to consistently outperform Knowledge Masking, is simple to implement, and has inspired prominent MLMs (Raffel et al., 2019). However, while Random-Span Masking increases the chances of masking collocations, with high probability the selected spans break up correlated n-grams, such that the prediction task can often be performed by relying on local cues.

In this paper we offer a principled approach to masking spans that consistently provide high signal, unifying the intuitions behind the above approaches while also outperforming them. Our approach, dubbed PMI-Masking, uses Pointwise Mutual Information (PMI) to identify collocations, which we then mask jointly. At a high level, PMI-Masking consists of two stages. First, given any pretraining corpus, we identify a set of contiguous n-grams that exhibit high cooccurrence probability relative to the individual occurrence probabilities of their components; we formalize this notion by proposing an extended definition of Pointwise Mutual Information from bigrams to longer n-grams. Second, we treat these collocated n-grams as single units: the masking strategy selects at random both from these units and from standard tokens that do not participate in such units. Figure 1, detailed and reinforced by further experiments in Section 5, shows that (1) PMI-Masking dramatically accelerates training, matching the end-of-pretraining performance of existing approaches in roughly half of the training time; and (2) PMI-Masking improves upon previous masking approaches at the end of pretraining.
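To give a flavor of the first stage, the sketch below scores contiguous n-grams with the naive extension of bigram PMI, log p(w1..wn) - sum_i log p(w_i). This is an illustration only: the paper's actual n-gram measure refines this naive extension, and the min_count cutoff and all names here are our assumptions.

```python
import math
from collections import Counter

def pmi_ngrams(corpus_tokens, n=2, min_count=5):
    """Score contiguous n-grams by a naive PMI extension:
    log p(w1..wn) - sum_i log p(w_i), over a tokenized corpus."""
    total = len(corpus_tokens)
    unigram = Counter(corpus_tokens)
    ngrams = Counter(tuple(corpus_tokens[i:i + n])
                     for i in range(total - n + 1))
    total_ngrams = total - n + 1
    scores = {}
    for gram, count in ngrams.items():
        if count < min_count:  # ignore unreliable rare n-grams
            continue
        log_joint = math.log(count / total_ngrams)
        log_indep = sum(math.log(unigram[w] / total) for w in gram)
        scores[gram] = log_joint - log_indep
    return scores
```

In the second stage, the top-scoring n-grams would be collected into a masking vocabulary and thereafter treated as single units when sampling which positions to mask, analogously to the whole-word units above.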

2. MOTIVATION: MLMS ARE SENSITIVE TO TOKENIZATION

In this section we describe a simple experiment that motivates our PMI-Masking approach. We examined BERT's ability to learn effective representations for words consisting of multiple sub-word tokens, treating this setting as an easily controlled analogue for the multi-word collocation problem that truly interests us. Our experiment sought to assess the performance gain obtained from always masking whole words, as opposed to masking each individual token uniformly at random. We



Figure 1: SQuAD2.0 development set F1 scores of BERT BASE models trained with different masking schemes, evaluated every 200K steps during pretraining.

