SAAL: SHARPNESS-AWARE ACTIVE LEARNING

Abstract

While modern deep neural networks play significant roles in many research areas, they are also prone to overfitting under limited data instances. This overfitting, or generalization issue, is particularly problematic in the framework of active learning because active learning selects only a few data instances for learning over time. To account for generalization, this paper introduces the first active learning method to incorporate the sharpness of the loss space into the design of the acquisition function, inspired by sharpness-aware minimization (SAM). SAM maximally perturbs the model around the training dataset so that optimization is led to flat minima, which are known to have better generalization ability. Specifically, our active learning method, Sharpness-Aware Active Learning (SAAL), constructs its acquisition function by selecting unlabeled instances whose perturbed loss is maximal. In adapting SAM to SAAL, we design a pseudo labeling mechanism to anticipate the perturbed loss w.r.t. the ground-truth label. Furthermore, we present a theoretical analysis connecting SAAL to recent active learning methods, showing that those methods reduce to SAAL under specific conditions. We conduct experiments on various benchmark datasets for vision-based tasks in image classification and object detection. The experimental results confirm that SAAL outperforms the baselines by selecting instances with the potentially maximal perturbation on the loss.

1. INTRODUCTION

Recently, deep learning has been widely utilized in many research areas, such as computer vision, natural language processing, and recommender systems, but its success deeply depends on large-scale labeled datasets for training deep neural networks. The importance of the dataset is related to the generalization issue in deep learning: a model learned on the training dataset suffers a degradation of performance when an unseen test dataset is encountered at deployment. This degradation arises because neural networks are prone to overfitting when training data is scarce (Keskar et al., 2016; Neyshabur et al., 2017; Kawaguchi et al., 2017). The dependency on the dataset also motivates adaptive data selection via acquisition functions, or active learning, which aims at the efficient use of a limited budget for annotations from an oracle (Cohn et al., 1996; Tong, 2001; Settles, 2009). Recently, various methods for active learning have been proposed, but a model trained with a small number of adaptively selected instances often fails to generalize (Dasgupta & Hsu, 2008). Although some prior works deal with the generalization issue in active learning, those methods solve the problem by either proposing a new risk function (Farquhar et al., 2020) or adopting a new classifier network (Wan et al., 2021), rather than by inventing a new acquisition function that accounts for generalization.

In this paper, we propose a new active learning algorithm, named Sharpness-Aware Active Learning (SAAL), that connects active learning and generalization ability in the construction of the acquisition function. Specifically, we are inspired by Sharpness-Aware Minimization, or SAM (Foret et al., 2020), which minimizes the maximally perturbed loss of the training dataset, thereby minimizing the loss sharpness as well as the task loss itself.
Such optimization leads to flat minima of the loss landscape, which are shown to have a strong correlation with generalization performance (Jiang et al., 2019). Hence, SAAL adopts the maximally perturbed loss as the acquisition score. When calculating the acquisition score, we cannot observe the labels of the unlabeled instances, so it is infeasible to compute the perturbed loss directly. To overcome this challenge, we utilize pseudo labels predicted by the current model, and we theoretically show that our proposed pseudo labeling conservatively estimates the maximally perturbed loss w.r.t. the ground-truth label. We also theoretically derive the upper bound of the acquisition score of SAAL, which comprises the loss, the norm of the gradient, and the first eigenvalue of the loss Hessian. Among these three terms, the loss and gradient terms are widely used metrics in active learning that capture the model change induced by acquiring an instance (Yoo & Kweon, 2019; Ash et al., 2020; Settles et al., 2007). Meanwhile, the first eigenvalue, which is newly considered by SAAL, is connected to the loss sharpness (Keskar et al., 2017). Therefore, the instances selected by SAAL may contribute to the generalization of the model.

We summarize our contributions in three points. First, we propose Sharpness-Aware Active Learning (SAAL), which incorporates the loss sharpness into the acquisition function. Since the loss sharpness is related to the generalization of the model, selecting instances with high loss sharpness may lead to a model with better generalization performance. Second, we theoretically derive the upper bound of the acquisition score of SAAL and show its connection to recent active learning methods. Specifically, we find that the upper bound also contains the first eigenvalue of the loss Hessian, which is related to generalization ability.
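The acquisition step described above (pseudo label, then maximally perturbed loss) can be sketched for a toy linear softmax model. All shapes, the `rho` value, and helper names such as `saal_score` are illustrative assumptions; the paper applies the idea to deep networks:

```python
import numpy as np

# A minimal sketch of the SAAL acquisition score for a linear model W x.
# For an unlabeled instance x:
#   1) pseudo label  y_hat = argmax_j softmax(W x)_j
#   2) epsilon = rho * dl/dW / ||dl/dW||   (maximal weight perturbation)
#   3) score = l(x, y_hat; W + epsilon)    (maximally perturbed loss)
rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def ce_loss(W, x, y):
    return -np.log(softmax(W @ x)[y] + 1e-12)

def ce_grad(W, x, y):
    p = softmax(W @ x)
    p[y] -= 1.0
    return np.outer(p, x)  # dl/dW for cross-entropy over softmax(W x)

def saal_score(W, x, rho=0.05):
    y_hat = int(np.argmax(softmax(W @ x)))        # pseudo label from the current model
    g = ce_grad(W, x, y_hat)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    return ce_loss(W + eps, x, y_hat)             # perturbed loss as acquisition score

W = rng.normal(size=(3, 5))                       # 3 classes, 5 features (illustrative)
pool = [rng.normal(size=5) for _ in range(20)]    # toy unlabeled pool X_U
scores = [saal_score(W, x) for x in pool]
query = int(np.argmax(scores))                    # instance sent to the oracle
```

The pseudo label stands in for the unobserved ground truth, which is exactly the conservative estimate the theoretical analysis justifies.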
Third, we empirically show that SAAL outperforms the baselines in various vision-based tasks on benchmark datasets.

2. PRELIMINARIES

2.1 NOTATIONS

Throughout this paper, we assume a classification problem and represent our current deep learning model parameterized by $\theta$ as $f_\theta : \mathbb{R}^d \rightarrow \mathbb{R}^{|Y|}$, where $d$ is the dimension of a data instance, $x$, and $Y$ is the set of candidate classes that $x$ can have. There are two datasets: a labeled dataset, $X_L$, and an unlabeled dataset, $X_U$. We denote the acquisition function of active learning as $f_{acq} : \mathbb{R}^d \rightarrow \mathbb{R}$, where $f_{acq}$ receives a data instance as input and calculates the informativeness, or the acquisition score, of the instance as output. The loss of a data instance, $x$, w.r.t. a given label $y$ is represented as $l(x, y; \theta) := l_{CE}(\sigma(f_\theta(x)), y)$, where $\sigma(\cdot)$ is the softmax function. The total loss of a dataset, $S = \{(x_i, y_i) \mid i = 1, \ldots, N\}$, is represented as $L_S(\theta) = \frac{1}{N} \sum_{i=1}^{N} l(x_i, y_i; \theta)$. Lastly, we define the pseudo label as $\hat{y} = \operatorname{argmax}_{j \in Y} \sigma(f_\theta(x))_j$, and we denote the ground-truth label as $\bar{y}$.
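The notation above can be made concrete with a toy linear model standing in for $f_\theta$. Dimensions and helper names here are illustrative, not the paper's setup:

```python
import numpy as np

# Toy instantiation of the notation:
#   f_theta(x) = W x,  l(x, y; theta) = CE(softmax(f_theta(x)), y),
#   L_S(theta) = (1/N) sum_i l(x_i, y_i; theta),
#   pseudo label y_hat = argmax_j softmax(f_theta(x))_j
def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def l(W, x, y):
    return -np.log(softmax(W @ x)[y] + 1e-12)   # cross-entropy on softmax outputs

def L_S(W, S):
    return np.mean([l(W, x, y) for x, y in S])  # dataset loss L_S(theta)

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 3))                     # |Y| = 4 classes, d = 3
S = [(rng.normal(size=3), int(rng.integers(4))) for _ in range(8)]  # labeled set
x_u = rng.normal(size=3)                        # an unlabeled instance from X_U
y_hat = int(np.argmax(softmax(W @ x_u)))        # pseudo label for x_u
```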

2.2. ACTIVE LEARNING

There are several active learning scenarios that differ in the setting of data accessibility, including membership-query synthesis (Angluin, 1988; 2004), stream-based active learning (Atlas et al., 1989; Cohn et al., 1994), and pool-based active learning (Lewis & Gale, 1994). In this paper, we focus on pool-based active learning, where the unlabeled data instances are provided as a large data pool, and the active learner sequentially selects informative instances by a certain criterion. Pool-based active learning is categorized by the definition of informativeness, which includes uncertainty-, diversity-, and hybrid-based methods. Uncertainty-based active learning adopts an acquisition function, $f_{acq}$, to calculate the uncertainty of each unlabeled instance with regard to the current deep learning model, and an oracle provides the ground-truth labels of the selected unlabeled instances with the highest uncertainty. Since the acquisition score is usually calculated for an unlabeled instance, $x_u \in X_U$, w.r.t. the current model, $f_\theta$, it is expanded as $f_{acq}(x_u; f_\theta)$, resulting in the selection rule below:

$$X_S = \operatorname*{argmax}_{X'_S \subset X_U} \sum_{x_u \in X'_S} f_{acq}(x_u; f_\theta) \quad (1)$$

Entropy, denoted as $f^{\mathrm{Ent}}_{acq}(x_u; f_\theta) = H[f_\theta(x_u)] = -\sum_j \sigma(f_\theta(x_u))_j \log_2 \sigma(f_\theta(x_u))_j$, and the variation ratio, denoted as $f^{\mathrm{Var}}_{acq}(x_u; f_\theta) = 1 - \max_j \sigma(f_\theta(x_u))_j$, are the most widely used measures of uncertainty (Shannon, 1948; Freeman, 1965). More recently, additional networks have been used to approximate the uncertainty of each instance. Learning Loss for Active Learning (LL4AL) (Yoo & Kweon, 2019) trains a loss prediction module, $f_{LPM}$, which takes the hidden feature maps
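The entropy and variation-ratio scores, together with the usual greedy top-$k$ approximation of the selection rule in Eq. (1), can be sketched as follows. This is a toy numpy sketch with illustrative dimensions and a linear model standing in for $f_\theta$:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Entropy and variation-ratio acquisition scores on predicted class probabilities.
def f_ent(probs):
    return -np.sum(probs * np.log2(probs + 1e-12))  # H[f_theta(x_u)]

def f_var(probs):
    return 1.0 - probs.max()                        # variation ratio

rng = np.random.default_rng(2)
W = rng.normal(size=(5, 8))                         # 5 classes, 8 features (illustrative)
pool = [rng.normal(size=8) for _ in range(50)]      # toy unlabeled pool X_U
scores = np.array([f_ent(softmax(W @ x)) for x in pool])

# Greedy approximation of Eq. (1): pick the k instances with the highest scores.
k = 5
X_S = np.argsort(-scores)[:k]
```

Because the score of a batch in Eq. (1) is a sum of per-instance scores, the batch argmax decomposes exactly into this top-$k$ selection.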

