OVERTHINKING THE TRUTH: UNDERSTANDING HOW LANGUAGE MODELS PROCESS FALSE DEMONSTRATIONS

Abstract

Through few-shot learning or chain-of-thought prompting, modern language models can detect and imitate complex patterns in their prompt. This behavior allows language models to complete challenging tasks without fine-tuning, but it can be at odds with completion quality: if the context is inaccurate or harmful, the model may reproduce these defects in its completions. In this work, we show that this harmful context-following appears late in a model's computation: given an inaccurate context, models perform better after zeroing out later layers. More concretely, at early layers models achieve similar performance given either accurate or inaccurate few-shot prompts, but a gap appears at later layers (e.g., layers 13-14 for GPT-J). This gap appears at a consistent depth across datasets and coincides with the appearance of "induction heads" that attend to previous answers in the prompt. We restore performance for inaccurate contexts by ablating a small subset of these heads, reducing the gap by 23.2% on average across 14 datasets. Our results suggest that studying early stages of computation is a promising strategy for preventing misleading outputs, and that understanding and editing internal mechanisms can help correct unwanted model behavior.

1. INTRODUCTION

A key behavior of modern language models is context-following: neural networks like GPT-3 are able to infer and imitate the patterns in their prompt. At its best, this allows language models to perform well on benchmarks without the need for fine-tuning (Brown et al., 2020; Rae et al., 2021; Hoffmann et al., 2022; Chowdhery et al., 2022; Srivastava et al., 2022). This has led researchers to study how properties of the context affect few-shot performance (Min et al., 2022b; Kim et al., 2022; Xie et al., 2021; Zhao et al., 2021), and what internal mechanisms underlie context-following (Olsson et al., 2022). However, context-following can also lead to incorrect, toxic, or unsafe model outputs (Rong, 2021). For example, if an inexperienced programmer prompts Codex (Chen et al., 2021) with poorly written or vulnerable code, the model is likely to produce poorly written or vulnerable code completions. Similarly, in this work we study few-shot learning for classification tasks: prompting the model with inaccurate demonstrations reduces model accuracy (Figure 1, left), because the model learns to reproduce the false demonstrations. We thus ask: can we attribute this "false context-following" behavior to specific model components, and can we mitigate it by intervening on these components?

We show that, perhaps surprisingly, false context-following in text classification is primarily a property of late stages of computation. In particular, stopping the model early, by zeroing out the later layers (Nostalgebraist, 2020), actually improves performance (Figure 1, center). Moreover, true and false contexts yield similar accuracy until some "critical layer" at which they sharply diverge. This demonstrates that even with false demonstrations, the model often "knows" the correct answer (it can easily be decoded from the latent states) but later replaces it with an incorrect answer that is more likely given the context.
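The early-exit intervention can be sketched concretely. The following is a minimal, self-contained numpy illustration, not the authors' implementation: `decode_at_layer` is a hypothetical helper that applies a final layer norm and unembedding matrix to a hidden state taken from an intermediate layer, which is the standard "logit lens" style of early decoding (Nostalgebraist, 2020). All weights and shapes here are toy values chosen for illustration.

```python
import numpy as np

def layer_norm(h, gamma, beta, eps=1e-5):
    """Standard layer normalization over the feature dimension."""
    mean = h.mean(-1, keepdims=True)
    var = h.var(-1, keepdims=True)
    return gamma * (h - mean) / np.sqrt(var + eps) + beta

def decode_at_layer(hidden, ln_gamma, ln_beta, W_U):
    """Early-exit decoding: treat an intermediate hidden state as if it
    were the final one, i.e. apply the final layer norm and the
    unembedding matrix. Returns next-token logits of shape (vocab,)."""
    return layer_norm(hidden, ln_gamma, ln_beta) @ W_U

# Toy example with d_model = 2 and a 2-token vocabulary.
hidden = np.array([1.0, -1.0])      # hypothetical residual stream at layer l
gamma, beta = np.ones(2), np.zeros(2)
W_U = np.eye(2)                     # hypothetical unembedding matrix
logits = decode_at_layer(hidden, gamma, beta, W_U)
```

In the paper's setting, `hidden` would come from the model's residual stream after layer l; tracking how true- and false-context accuracy diverge as a function of l is what localizes the critical layer.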
To identify the underlying mechanism for false context-following, we turn to Olsson et al. (2022), who identify "induction heads" that attend to and reproduce previous patterns in the input. Motivated by this, we searched for heads that consistently attend to previous examples that have the same (true) answer as the current prompt. We found many such heads, primarily concentrated in later layers of the model (after the critical layer). By removing 10 of these heads, we reduce the accuracy gap between accurate and inaccurate prompts by an average of 23.2% over 14 datasets, with negligible effects on performance given true prefixes (Figure 1, right).

Figure 1: Left: Given a prompt of inaccurate demonstrations, language models are more likely to output incorrect labels. Center: When demonstrations are incorrect, zeroing out the later layers increases the classification accuracy, here on SST-2. Right: We identify 10 attention heads and remove them from the model: this reduces the effect of incorrect demonstrations by 36.7% on SST-2, averaged over 15 prompt formats, without decreasing the accuracy given correct demonstrations.

Our findings show how analyzing and editing model internals can help practitioners understand and mitigate model failures. Indeed, one intuition for why early exiting succeeds is that the attention heads we identified cannot in general occur at the earliest layers: these heads must recognize which inputs belong to the same class, which likely requires multiple layers of processing. Thus, early exiting may be a generally promising strategy for detecting dishonest behavior in models.
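Removing a head amounts to zeroing its output before the attention layer's output projection. The sketch below shows where that intervention enters a multi-head attention computation; it uses toy numpy weights, omits causal masking for brevity, and the head indices are illustrative, not the specific GPT-J heads identified in the paper.

```python
import numpy as np

def mha_with_ablation(x, W_q, W_k, W_v, W_o, n_heads, ablate=()):
    """Multi-head self-attention over a sequence x of shape (T, d), where
    the output of each head index in `ablate` is zeroed before the output
    projection. Causal masking is omitted for brevity."""
    T, d = x.shape
    d_head = d // n_heads
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    # Reshape (T, d) -> (n_heads, T, d_head).
    split = lambda m: m.reshape(T, n_heads, d_head).transpose(1, 0, 2)
    q, k, v = split(q), split(k), split(v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    attn = np.exp(scores - scores.max(-1, keepdims=True))
    attn /= attn.sum(-1, keepdims=True)
    out = attn @ v                               # (n_heads, T, d_head)
    for h in ablate:
        out[h] = 0.0                             # zero the ablated head
    out = out.transpose(1, 0, 2).reshape(T, d)   # concatenate heads
    return out @ W_o

rng = np.random.default_rng(0)
d, T, H = 8, 4, 2
Wq, Wk, Wv, Wo = (rng.normal(size=(d, d)) * 0.1 for _ in range(4))
x = rng.normal(size=(T, d))
full = mha_with_ablation(x, Wq, Wk, Wv, Wo, H)
none = mha_with_ablation(x, Wq, Wk, Wv, Wo, H, ablate=(0, 1))
```

On a real model the same effect is typically achieved with forward hooks that zero the relevant slices of the attention output; the sketch only shows which tensor the intervention targets.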

2. PRELIMINARIES: FEW-SHOT LEARNING WITH FALSE DEMONSTRATIONS

We begin by introducing the setting we study: few-shot learning for classification, given demonstrations with correct or incorrect labels. Incorrect demonstrations consistently reduce classification performance; this is the phenomenon that we aim to study and mitigate in this work.

Few-shot learning. We consider autoregressive transformer language models, which produce a conditional probability distribution p(t_{n+1} | t_1, ..., t_n) over the next token t_{n+1} given previous tokens. We focus on the few-shot learning setting (Brown et al., 2020) for classification tasks: we sample k demonstrations (input-label pairs) from the task dataset, denoted (x_1, y_1), ..., (x_k, y_k). To query the model on a new input x, we use the predictive distribution p(y | x_1, y_1, ..., x_k, y_k, x).

Datasets and models. We consider fourteen text classification datasets: SST-2 (Socher et al., 2013), Poem Sentiment (Sheng & Uthus, 2020), Financial Phrasebank (Malo et al., 2014), Ethos (Mollas et al., 2020), TweetEval-Hate (Barbieri et al., 2020), TweetEval-Atheism (Barbieri et al., 2020), TweetEval-Feminist (Barbieri et al., 2020), Medical Questions Pairs (McCreery et al., 2020), MRPC (Wang et al., 2019), SICK (Marelli et al., 2014), RTE (Wang et al., 2019), AGNews (Zhang et al., 2015), TREC (Voorhees & Tice, 2000), and DBpedia (Zhang et al., 2015). We use the same prompt formats as in Min et al. (2022b) and Zhao et al. (2021) (Tables 2, 3). For SST-2 we use the first of the 15 prompt formats in Zhao et al. (Table 5). We evaluate 3 autoregressive language models: GPT-J (Wang & Komatsuzaki, 2021), GPT2-XL (Radford et al., 2019), and GPT-NeoX-20B (Black et al., 2022).

Evaluation metrics. Given our focus on classification tasks, we are interested in how often the model assigns higher probability to the true label than to all other labels. However, model predictions can be very unstable with respect to small prompt perturbations (Gao et al., 2021). To mitigate this variability, we measure the calibrated classification accuracy (Zhao et al., 2021). Concretely, for a 2-class classification task, we measure how often the correct label has a higher probability than its median probability over the dataset. Assuming the dataset is balanced (which is true for us), this step has been shown to improve performance and reduce variability across prompts. Calibration for multi-class tasks follows a similar procedure, detailed in Appendix A.1.
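The median-based calibration for the 2-class case can be made concrete with a short sketch. The function name and toy probabilities below are ours, for illustration only: an example counts as correct when the probability assigned to its true label exceeds the median of that quantity over the dataset, which cancels a constant per-prompt bias toward one label.

```python
import numpy as np

def calibrated_accuracy_2class(p_correct):
    """p_correct[i] is the model's probability for the *true* label of
    example i. Calibrated accuracy counts an example as correct when that
    probability exceeds its median over the dataset (Zhao et al., 2021)."""
    p = np.asarray(p_correct, dtype=float)
    return float((p > np.median(p)).mean())

# Toy example: a model biased toward one label assigns high probability to
# the true label on one class and low probability on the other.
probs = [0.9, 0.8, 0.85, 0.3, 0.2, 0.35]
acc = calibrated_accuracy_2class(probs)  # -> 0.5
```

Because the threshold is the dataset median rather than a fixed 0.5, a uniform shift in the model's label probabilities leaves the metric unchanged, which is why it is more stable across prompt formats.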

