ROBUST ACTIVE DISTILLATION

Abstract

Distilling knowledge from a large teacher model to a lightweight one is a widely successful approach for generating compact, powerful models in the semi-supervised learning setting where a limited amount of labeled data is available. In large-scale applications, however, the teacher tends to provide a large number of incorrect soft-labels that impair student performance. The sheer size of the teacher additionally constrains the number of soft-labels that can be queried due to prohibitive computational and/or financial costs. The difficulty of achieving simultaneous efficiency (i.e., minimizing soft-label queries) and robustness (i.e., avoiding student inaccuracies due to incorrect labels) hinders the widespread application of knowledge distillation to many modern tasks. In this paper, we present a parameter-free approach with provable guarantees to query the soft-labels of points that are simultaneously informative and correctly labeled by the teacher. At the core of our work lies a game-theoretic formulation that explicitly considers the inherent trade-off between the informativeness and correctness of input instances. We establish bounds on the expected performance of our approach that hold even on worst-case distillation instances. We present empirical evaluations on popular benchmarks that demonstrate the improved distillation performance enabled by our work relative to that of state-of-the-art active learning and active distillation methods.

1. INTRODUCTION

Deep neural network models have been unprecedentedly successful in many high-impact application areas such as Natural Language Processing (Ramesh et al., 2021; Brown et al., 2020) and Computer Vision (Ramesh et al., 2021; Niemeyer & Geiger, 2021). However, this has come at the cost of using increasingly large labeled data sets and high-capacity network models that tend to contain billions of parameters (Devlin et al., 2018). These models are often prohibitively costly to use for inference and require millions of dollars in compute to train (Patterson et al., 2021). Their sheer size also precludes their use in time-critical applications where fast decisions have to be made, e.g., autonomous driving, and deployment to resource-constrained platforms, e.g., mobile phones and small embedded systems (Baykal et al., 2022). To alleviate these issues, a vast amount of recent work in machine learning has focused on methods to generate compact, powerful network models without the need for massive labeled data sets. Knowledge Distillation (KD) (Buciluǎ et al., 2006; Hinton et al., 2015; Gou et al., 2021; Beyer et al., 2021) is a general-purpose approach that has shown promise in generating lightweight, powerful models even when a limited amount of labeled data is available (Chen et al., 2020). The key idea is to use a large teacher model trained on labeled examples to train a compact student model so that its predictions imitate those of the teacher. The premise is that even a small student is capable enough to represent complicated solutions, even though it may lack the inductive biases to appropriately learn representations from limited data on its own (Stanton et al., 2021; Menon et al., 2020). In practice, KD often leads to significantly more predictive models than otherwise possible with training in isolation (Chen et al., 2020; Xie et al., 2020; Gou et al., 2021; Cho & Hariharan, 2019).
Knowledge Distillation has recently been used to obtain state-of-the-art results in the semi-supervised setting where a small number of labeled and a large number of unlabeled examples are available (Chen et al., 2020; Pham et al., 2021; Xie et al., 2020). Semi-supervised KD entails training a teacher model on the labeled set and using its soft-labels on the unlabeled data to train the student. The teacher is often a pre-trained model and may also be a generic large model such as GPT-3 (Brown et al., 2020) or PaLM (Chowdhery et al., 2022). The premise is that a large teacher model can more aptly extract knowledge and learn from a labeled data set, which can subsequently be distilled into a small student. Despite its widespread success, KD generally suffers from various degrees of confirmation bias and inefficiency in modern applications to semi-supervised learning. Confirmation bias (Pham et al., 2021; Liu & Tan, 2021; Arazo et al., 2020; Beyer et al., 2021) is the phenomenon where the student exhibits poor performance due to training on noisy or inaccurate teacher soft-labels. Here, inaccuracy refers to the inconsistency between the teacher's predictions for the unlabeled inputs and their ground-truth labels. Feeding the student inaccurate soft-labels leads to increased confidence in incorrect predictions, which consequently produces a model that tends to resist new changes and perform poorly overall (Liu & Tan, 2021; Arazo et al., 2020). At the same time, large-scale applications often require the teacher's predictions for billions of unlabeled points. For instance, consider distilling knowledge from GPT-3 to train a powerful student model. As of this writing, OpenAI charges $0.06 per 1,000 tokens of predictions (OpenAI, 2022). Assuming just 1M examples to label and an average of 100 tokens per example leads to a total cost of $6,000; at the billion-example scale common in practice, the cost runs into the millions of dollars.
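This back-of-the-envelope accounting can be reproduced in a few lines of Python. The per-token price and token count follow the figures quoted above; the function name and structure are illustrative, not part of any API.

```python
# Rough soft-labeling cost estimate for querying a commercial teacher model.
# Price follows the figure quoted in the text: $0.06 per 1,000 tokens.
PRICE_PER_1K_TOKENS = 0.06  # USD


def labeling_cost(num_examples: int, tokens_per_example: int = 100) -> float:
    """Total cost in USD of querying soft-labels for `num_examples` inputs."""
    total_tokens = num_examples * tokens_per_example
    return total_tokens / 1_000 * PRICE_PER_1K_TOKENS


print(labeling_cost(1_000_000))      # 1M examples  -> 6000.0 USD
print(labeling_cost(1_000_000_000))  # 1B examples  -> 6000000.0 USD
```

The linear growth in `num_examples` is precisely what motivates querying only a carefully chosen subset of the unlabeled pool.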
Hence, it is highly desirable to acquire the most helpful (i.e., informative and correct) soft-labels subject to a labeling budget (GPT-3 API calls) to obtain the most powerful student model for the target application. Thus, it has become increasingly important to develop KD methods that are both query-efficient and robust to labeling inaccuracies. Prior work in this realm is limited to tackling either distillation efficiency (Liang et al., 2022; Xu et al., 2020), by combining mix-up (Zhang et al., 2017) and uncertainty-based sampling (Roth & Small, 2006), or robustness (Pham et al., 2021; Liu & Tan, 2021; Arazo et al., 2020; Zheng et al., 2021; Zhang et al., 2020), through clever training and weighting strategies, but not both of these objectives at the same time. In this paper, we present a simple-to-implement method that strikes a balance between the two objectives and improves over standard techniques. Relatedly, there has been prior work on learning under label noise (see Song et al. (2022) for a survey); however, these works generally assume that the noisy labels are available (i.e., there is no active learning component) or impose assumptions on the type of label noise (Younesian et al., 2021). In contrast, we assume that the label noise can be fully adversarial and that we do not have full access to even the noisy labels. To the best of our knowledge, this work is the first to consider the problem of importance sampling for simultaneous efficiency and robustness in knowledge distillation. To bridge this research gap, we present an efficient algorithm with provable guarantees to identify unlabeled points with soft-labels that tend to be simultaneously informative and accurate. Our approach is parameter-free, imposes no assumptions on the problem setting, and can be widely applied to any network architecture and data set. At its core lies the formulation of an optimization problem that simultaneously captures the objectives of efficiency and robustness in an appropriate way.
In particular, this paper contributes:
1. A mathematical problem formulation that captures the joint objective of training on informative soft-labels that are accurately labeled by the teacher in a query-efficient way.
2. A near-linear-time, parameter-free algorithm to solve it optimally.
3. Empirical results on benchmark data sets and architectures with varying configurations that demonstrate the improved effectiveness of our approach relative to the state-of-the-art.
4. Extensive empirical evaluations that support the widespread applicability and robustness of our approach to varying scenarios and practitioner-imposed constraints.

2. PROBLEM STATEMENT

We consider the semi-supervised classification setting where we are given a small labeled set X_L, typically tens or hundreds of thousands of examples, together with a large unlabeled set X_U, typically on the order of millions or billions. The goal is to leverage both the labeled and unlabeled sets to efficiently and reliably train a compact, powerful model θ_student. To do so, we use knowledge distillation (Xie et al., 2020; Liang et al., 2020), where the labeled points are used to train a larger (often pre-trained) teacher model that can then be used to educate a small model (the student). We emphasize that the teacher may be a pre-trained model; however, it is not trained on the unlabeled set X_U. The distillation process entails using the soft-labels of the teacher for the unlabeled points. The student is then trained on these soft-labeled points along with the original labeled data set. The key insight is that the large, pre-trained teacher model can more aptly learn representations from the limited data, which can then be imitated by the student.

