DATA-EFFICIENT FINETUNING USING CROSS-TASK NEAREST NEIGHBORS

Abstract

Language models trained on massive prompted multitask datasets like T0 (Sanh et al., 2021) or FLAN (Wei et al., 2021a) can generalize to tasks unseen during training. We show that training on a carefully chosen subset of instances can outperform training on all available data on a variety of datasets. We assume access to a small number (250-1000) of unlabeled target-task instances, select their nearest neighbors from a pool of multitask data, and use the retrieved data to train target-task-specific models. Our method is more data-efficient than training a single multitask model, while still outperforming it by large margins. We evaluate across a diverse set of tasks not in the multitask pool we retrieve from, including those used to evaluate T0 and additional complex tasks including legal and scientific document QA. We retrieve small subsets of P3 (the collection of prompted datasets from which T0's training data was sampled) and finetune T5 models that outperform the 3-billion parameter variant of T0 (T0-3B) by 3-30% on 12 out of 14 evaluation datasets while using at most 2% of the data used to train T0-3B. These models also provide a better initialization than T0-3B for few-shot finetuning on target-task data, as shown by a 2-23% relative improvement over few-shot finetuned T0-3B models on 8 datasets.

1. INTRODUCTION

Finetuning large models on data from a diverse set of tasks, augmented to include brief descriptions of the tasks (i.e., prompts), has been shown to help models generalize to unseen tasks (Wei et al., 2021a; Sanh et al., 2021). This cross-task generalization capability is particularly helpful in cases where it is expensive to collect labeled target-task training sets. Prior work trained single models on as much prompted data as possible: for example, Sanh et al. (2021) train a model on roughly 11 million instances (counting different prompt variations). The training datasets were selected without using any information about the target tasks, with the goal of allowing models to generalize to new tasks from instructions alone, making the evaluation "zero-shot". However, it is unclear whether all of this training data is required for doing well on any given target task. Furthermore, given that neural network models have previously been shown to suffer from negative interference (wherein training on more datasets results in worse performance on certain downstream tasks) in multitask setups (Aribandi et al., 2022) and to benefit from pretraining on domain-relevant data (Gururangan et al., 2020; Phang et al., 2018), it is possible that training only on relevant prompted data could further improve task generalization. Based on this hypothesis, we seek to find small subsets of relevant training data within the massive pool of multitask data that cause models to generalize better to a given target task than the rest of the pool. Manually finding relevant training data in a massive pool is infeasible, since it is not obvious which source tasks are relevant for a given target task, or which instances within a source task dataset are most relevant for target-task generalization (see Section 5.1). Hence we rely on a simple method to automatically select these subsets.
Additionally, as only some samples within a given dataset may be relevant to a target task, we select data per-instance rather than per-dataset, unlike prior work that identifies useful datasets for transfer learning (Aribandi et al., 2022; Phang et al., 2018) and trains on all data within the chosen datasets. We use a setup similar to contemporary work examining retrieval-augmented cross-task generalization (Lin et al., 2022): we assume access to a small number of unlabeled target-task instances and use these to retrieve cross-task nearest neighbors: labeled instances from the massive pool of data most similar to our unlabeled target-task instances. The similarity is computed as the distance between the representations produced by the encoder of a pretrained seq2seq model. Unlike prior work, we then finetune target-task-specific models on these neighbors alone, without using any target-task labeled data or any extra data from the pool of multitask data. We expect that the similarity between the cross-task neighbors and our target-task data will enable greater generalization to the target task, with dissimilar examples that may cause interference removed from the training mixture. We also aim to produce models that perform at least as well as models trained on the entire multitask pool despite being trained on a fraction of the data, greatly reducing the cost of training. We run experiments with T5 (Raffel et al., 2020) models, and use the Public Pool of Prompts (P3) (Sanh et al., 2021) as the pool of prompted multitask data from which to retrieve cross-task nearest neighbors.
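The retrieval step described above can be sketched as follows. This is a minimal brute-force illustration, not the paper's implementation: the function name and the toy 2-d vectors are ours, and in practice the embeddings would come from the encoder of a pretrained seq2seq model over a much larger pool, likely with an approximate nearest-neighbor index.

```python
import numpy as np

def retrieve_cross_task_neighbors(query_embs, pool_embs, k):
    """Return the union of the top-k nearest pool instances (by cosine
    similarity) over all unlabeled target-task query instances."""
    # L2-normalize so that dot products equal cosine similarities.
    q = query_embs / np.linalg.norm(query_embs, axis=1, keepdims=True)
    p = pool_embs / np.linalg.norm(pool_embs, axis=1, keepdims=True)
    sims = q @ p.T  # shape: (num_queries, pool_size)
    # Indices of the k most similar pool instances for each query.
    topk = np.argsort(-sims, axis=1)[:, :k]
    # The retrieved training set is the deduplicated union over queries.
    return sorted(set(topk.ravel().tolist()))

# Toy example: 2-d vectors standing in for encoder representations.
pool = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]])
queries = np.array([[1.0, 0.05]])
print(retrieve_cross_task_neighbors(queries, pool, k=2))  # → [0, 2]
```

Taking the union over per-query top-k lists means the size of the retrieved training set grows with the number of unlabeled target-task instances but is capped at num_queries × k, keeping it a small fraction of the full pool.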
We evaluate on the 11 datasets originally used to evaluate T0 (a collection of natural language understanding and commonsense tasks), as well as 3 additional datasets with varied domains (e.g., legal, NLP domains). Our findings are as follows:
• For all the target tasks, we find that their cross-task nearest neighbors in P3 are much more relevant as training data than the rest of the pool: training T5 models, sometimes even variants smaller than T0-3B, on these subsets yields models with performance 3-30% better than T0-3B evaluated zero-shot across 12 out of 14 datasets.
• For some target tasks on which T0-3B performs close to random chance, T5 models of the same size trained using cross-task nearest neighbors perform significantly above chance, confirming our hypothesis that massive multitask prompted training could lead to negative interference between tasks.
• When target-task labeled data is available for few-shot finetuning, we find that T5 models trained with cross-task nearest neighbors provide a better initialization for parameter-efficient finetuning methods than T0-3B, performing 2-23% better than T0-3B with few-shot finetuning across 10 out of 11 datasets.
• An analysis of what relevant data gets retrieved shows that most of the tasks in the massive pool of multitask data are not retrieved for any target task, confirming our hypothesis that only a small subset of data within the pool is relevant to any given target task.
These findings suggest that training these models on all the available multitask prompted data results in negative interference, even in relatively large (3-billion-parameter) models. Furthermore,



Figure 1: Overview of the DEFT method. Given some unlabeled target-task instances, we find the most similar instances in a large pool of multitask data. We train a model on these instances. If we have access to labeled data, we optionally few-shot finetune the DEFT model.

