DECOUPLED TRAINING FOR LONG-TAILED CLASSIFICATION WITH STOCHASTIC REPRESENTATIONS

Abstract

Decoupling representation learning and classifier learning has been shown to be effective for classification with long-tailed data. There are two main ingredients in constructing a decoupled learning scheme: 1) how to train the feature extractor so that it provides generalizable representations, and 2) how to re-train the classifier so that it constructs proper decision boundaries while handling the class imbalance in long-tailed data. In this work, we first apply Stochastic Weight Averaging (SWA), an optimization technique for improving the generalization of deep neural networks, to obtain better-generalizing feature extractors for long-tailed classification. We then propose a novel classifier re-training algorithm based on stochastic representations obtained from SWA-Gaussian (SWAG), a Gaussian-perturbed SWA, together with a self-distillation strategy that harnesses the diverse stochastic representations, guided by uncertainty estimates, to build more robust classifiers. Extensive experiments on the CIFAR10/100-LT, ImageNet-LT, and iNaturalist-2018 benchmarks show that our proposed method improves upon previous methods in terms of both prediction accuracy and uncertainty estimation.

1. INTRODUCTION

While deep neural networks have achieved remarkable performance on various computer vision benchmarks (e.g., image classification (Russakovsky et al., 2015) and object detection (Lin et al., 2014)), many challenges remain when applying them to real-world problems. One such challenge is that real-world classification data are long-tailed: the distribution of class frequencies exhibits a long tail, and many classes have only a few observations belonging to them. As a consequence, the class distribution of such data is extremely imbalanced, and a standard classification model trained under the balanced-class assumption degrades due to the paucity of samples from tail classes (Van Horn et al., 2018; Liu et al., 2019). It is therefore worth exploring novel techniques for handling long-tailed data in real-world deployments. Several works have diagnosed performance bottlenecks of long-tailed recognition as distinct from the balanced setting (e.g., improper decision boundaries over the representation space (Kang et al., 2020), low-quality representations from the feature extractor (Samuel and Chechik, 2021)), but their shared design principle is to give tail classes a chance to compete with head classes. Decoupling (Kang et al., 2020) is one of the learning strategies proven to be effective for long-tailed data, where representation learning via the feature extractor and classifier learning via the last classification layer are decoupled. Even for a classification network failing on long-tailed data, the representations obtained from the penultimate layer can be flexible and generalizable, provided that the feature extractor is expressive enough (Donahue et al., 2014; Zeiler and Fergus, 2014; Girshick et al., 2014). The main motivation behind decoupling is that the performance bottleneck of long-tailed classification lies in the improper decision boundaries set over the representation space.
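The decoupled scheme described above can be sketched in a few lines: train the feature extractor on the imbalanced data, freeze it, and re-train only the linear classifier on class-balanced mini-batches. This is a minimal NumPy illustration of the two-stage structure only; the random-projection "feature extractor", the synthetic per-class data, and all sizes are stand-ins, not the paper's actual implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Stage 1 (stand-in): representation learning yields a frozen feature
# extractor; here a fixed random projection plays that role.
W_feat = rng.normal(size=(16, 8))            # D=16 inputs -> L=8 features
extract = lambda x: np.tanh(x @ W_feat)

# Stage 2: classifier re-training on class-balanced mini-batches.
# Only the linear classifier (w, b) is updated; features stay frozen.
K, L, lr = 3, 8, 0.1
w, b = np.zeros((L, K)), np.zeros(K)
x_by_class = {k: rng.normal(loc=k, size=(50, 16)) for k in range(K)}
for step in range(200):
    # Balanced sampling: equally many examples per class, head or tail.
    xb = np.vstack([x_by_class[k][rng.integers(50, size=4)] for k in range(K)])
    yb = np.repeat(np.arange(K), 4)
    f = extract(xb)
    p = softmax(f @ w + b)
    p[np.arange(len(yb)), yb] -= 1.0         # d(cross-entropy)/d(logits)
    w -= lr * f.T @ p / len(yb)
    b -= lr * p.mean(axis=0)
```

The point of the sketch is the separation of concerns: the class-imbalance handling (balanced sampling) touches only the cheap second stage, leaving the representations intact.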
Based on this, Kang et al. (2020) showed that simply re-training the last-layer parameters can significantly improve performance. While SWA (Izmailov et al., 2018) has been shown to improve generalization in a variety of settings, to the best of our knowledge it has never been explored for long-tailed classification. In Section 3, we empirically show that a naïve application of SWA to long-tailed classification fails due to a similar bottleneck issue, but that when combined with decoupling, SWA significantly improves classification performance owing to its ability to produce generalizable representations. Having confirmed that SWA can benefit long-tailed classification, we take a step further and propose a novel classifier re-training strategy. For this, we first obtain stochastic representations: outputs of the penultimate layer computed with multiple feature extractor parameters drawn from an approximate posterior, which we construct with SWA-Gaussian (SWAG; Maddox et al., 2019). SWAG is a Bayesian extension of SWA that adds Gaussian noise to the parameters obtained from SWA to approximate posterior parameter uncertainty. In Section 3.2, we first show that the diverse stochastic representations obtained from SWAG samples reflect the uncertainty of inputs well. Hinging on this observation, we propose a novel self-distillation algorithm in which the stochastic representations are used to construct an ensemble of virtual teachers, and classifier re-training is formulated as distillation (Hinton et al., 2015) from the virtual teachers. Fig. 1 depicts the overall composition of this paper as a diagram. Using the CIFAR10/100-LT (Cao et al., 2019), ImageNet-LT (Liu et al., 2019), and iNaturalist-2018 (Van Horn et al., 2018) benchmarks for long-tailed image classification, we empirically validate that our proposed method improves upon previous approaches in terms of both prediction accuracy and uncertainty estimation.
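The SWA/SWAG machinery referenced above amounts to keeping running moments of the SGD iterates and then sampling parameters around the average. The sketch below shows a diagonal-covariance version for illustration only; the actual SWAG of Maddox et al. (2019) additionally maintains a low-rank covariance term, and the "SGD step" here is a stand-in random walk rather than a real optimizer.

```python
import numpy as np

rng = np.random.default_rng(0)

theta = rng.normal(size=5)      # current parameters (stand-in, 5 scalars)
mean = np.zeros(5)              # SWA first moment (running average)
sq_mean = np.zeros(5)           # second moment, used for the SWAG variance
n = 0
for epoch in range(20):
    theta = theta + 0.1 * rng.normal(size=5)   # stand-in for an SGD epoch
    n += 1
    mean += (theta - mean) / n                 # online average of iterates
    sq_mean += (theta**2 - sq_mean) / n

theta_swa = mean                               # the SWA solution
var = np.maximum(sq_mean - mean**2, 0.0)       # diagonal SWAG variance

def sample_swag():
    # One approximate-posterior sample: SWA mean plus Gaussian noise.
    return theta_swa + np.sqrt(var) * rng.normal(size=5)

# Stochastic representations arise from forward passes of the feature
# extractor under several such parameter samples.
samples = np.stack([sample_swag() for _ in range(4)])
```

In the full method, each `sample_swag()` draw parameterizes one forward pass of the feature extractor, and the spread of the resulting penultimate-layer outputs is what the text calls the stochastic representations.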

2. PRELIMINARIES

Figure 1: A schematic diagram depicting the overall composition of the paper. Left: We first apply SWA to obtain a better-generalizing feature extractor (Section 3). Right: We then propose a novel self-distillation strategy distilling SWAG into SWA to obtain a more robust classifier (Section 4).

DECOUPLED LEARNING FOR LONG-TAILED CLASSIFICATION

Let F_θ : R^D → R^L be a neural network parameterized by θ that produces L-dimensional outputs for given D-dimensional inputs. For the K-way classification problem, an output from F_θ is first transformed into K-dimensional logits via a linear classification layer parameterized by ϕ = {(w_k ∈ R^L, b_k ∈ R)}_{k=1}^{K}, and then turned into a classification probability with the softmax function,

p^(k)(x; Θ) = exp(w_k^⊤ F_θ(x) + b_k) / Σ_{j=1}^{K} exp(w_j^⊤ F_θ(x) + b_j), for k = 1, ..., K,

where Θ = (θ, ϕ) denotes the set of trainable parameters. Given a training set D consisting of pairs of an input x ∈ R^D and a corresponding label y ∈ {1, . . . , K}, Θ is trained to minimize the cross-entropy loss

L(Θ) = -(1/|D|) Σ_{(x,y)∈D} log p^(y)(x; Θ).
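The softmax classification probability defined in this section translates directly into code. The sketch below checks only the shapes and normalization; F_θ is a stand-in random projection, and D, L, K are arbitrary small sizes chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
D, L, K = 6, 4, 3

W_proj = rng.normal(size=(D, L))       # stand-in parameters θ
W = rng.normal(size=(K, L))            # classifier weights w_1, ..., w_K
b = rng.normal(size=K)                 # biases b_1, ..., b_K

def F_theta(x):
    # Stand-in feature extractor R^D -> R^L.
    return np.tanh(x @ W_proj)

def p(x):
    # p^(k)(x; Θ): softmax over the logits w_k^T F_theta(x) + b_k.
    logits = W @ F_theta(x) + b
    logits -= logits.max()             # numerical stability
    e = np.exp(logits)
    return e / e.sum()

probs = p(rng.normal(size=D))          # a valid K-way distribution
```

Decoupled learning then corresponds to fixing θ (here `W_proj`) after the first stage and updating only ϕ = (W, b) during classifier re-training.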

