HOW TO EXPLOIT HYPERSPHERICAL EMBEDDINGS FOR OUT-OF-DISTRIBUTION DETECTION?

Abstract

Out-of-distribution (OOD) detection is a critical task for reliable machine learning. Recent advances in representation learning give rise to distance-based OOD detection, where test samples are detected as OOD if they are relatively far away from the centroids or prototypes of in-distribution (ID) classes. However, prior methods directly take off-the-shelf contrastive losses that suffice for classifying ID samples, but are not optimally designed for when test inputs contain OOD samples. In this work, we propose CIDER, a novel representation learning framework that exploits hyperspherical embeddings for OOD detection. CIDER jointly optimizes two losses to promote strong ID-OOD separability: a dispersion loss that promotes large angular distances among different class prototypes, and a compactness loss that encourages samples to be close to their class prototypes. We analyze and establish the unexplored relationship between OOD detection performance and the embedding properties in the hyperspherical space, and demonstrate the importance of dispersion and compactness. CIDER establishes superior performance, outperforming the latest rival by 13.33% in FPR95. Code is available at https://github.com/deeplearning-wisc/cider.

1. INTRODUCTION

When deploying machine learning models in the open world, it is important to ensure the reliability of the model in the presence of out-of-distribution (OOD) inputs: samples from an unknown distribution that the network has not been exposed to during training, and that therefore should not be predicted with high confidence at test time. We desire models that are not only accurate when the input is drawn from the known distribution, but also aware of the unknowns outside the training categories. This gives rise to the task of OOD detection, where the goal is to determine whether an input is in-distribution (ID) or not.

A plethora of OOD detection algorithms have been developed recently, among which distance-based methods have demonstrated promise (Lee et al., 2018; Xing et al., 2020). These approaches circumvent a shortcoming of using the model's confidence score for OOD detection: the confidence can be abnormally high on OOD samples (Nguyen et al., 2015) and hence indistinguishable from ID data. Distance-based methods instead leverage feature embeddings extracted from a model and operate under the assumption that test OOD samples lie relatively far away from the clusters of ID data. Arguably, the efficacy of distance-based approaches depends largely on the quality of the feature embeddings.

Recent works including SSD+ (Sehwag et al., 2021) and KNN+ (Sun et al., 2022) directly employ off-the-shelf contrastive losses for OOD detection. In particular, these works use the supervised contrastive loss (SupCon) (Khosla et al., 2020) for learning the embeddings, which are then used for OOD detection with either the parametric Mahalanobis distance (Lee et al., 2018; Sehwag et al., 2021) or a non-parametric KNN distance (Sun et al., 2022). However, existing training objectives produce embeddings that suffice for classifying ID samples but remain sub-optimal for OOD detection.
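The non-parametric KNN-style score mentioned above can be sketched as follows: score each test point by its (negated) distance to the k-th nearest ID training embedding on the unit hypersphere. This is a minimal illustration in the spirit of Sun et al. (2022), not their implementation; the function names, the toy Gaussian clusters, and the choice k=5 are assumptions for the sketch.

```python
import numpy as np

def knn_ood_score(z_test, z_train, k=5):
    """Negative distance to the k-th nearest ID embedding.

    Higher scores mean "more in-distribution", matching the
    convention that S(x) >= lambda implies ID.
    """
    # L2-normalize so distances live on the unit hypersphere
    z_test = z_test / np.linalg.norm(z_test, axis=1, keepdims=True)
    z_train = z_train / np.linalg.norm(z_train, axis=1, keepdims=True)
    # Pairwise Euclidean distances (monotone in cosine distance for unit vectors)
    d = np.linalg.norm(z_test[:, None, :] - z_train[None, :, :], axis=-1)
    kth = np.sort(d, axis=1)[:, k - 1]   # distance to k-th nearest neighbor
    return -kth

# Toy check: an OOD cluster far from the ID cluster gets lower scores
rng = np.random.default_rng(0)
id_train = rng.normal(loc=1.0, size=(200, 8))
id_test = rng.normal(loc=1.0, size=(10, 8))
ood_test = rng.normal(loc=-1.0, size=(10, 8))
s_id = knn_ood_score(id_test, id_train)
s_ood = knn_ood_score(ood_test, id_train)
assert s_id.mean() > s_ood.mean()
```

Because all vectors are normalized first, ranking by Euclidean distance here is equivalent to ranking by cosine (angular) distance, which is why such scores pair naturally with hyperspherical embeddings.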
For example, when trained on CIFAR-10 with the SupCon loss, the average angular distance between ID and OOD data is only 29.86 degrees in the embedding space, which is too small for effective ID-OOD separation. This raises an important question: how can we exploit representation learning methods that maximally benefit OOD detection?

In this work, we propose CIDER, a Compactness and DispErsion Regularized learning framework designed for OOD detection. Our method is motivated by the desirable properties of hyperspherical embeddings, which can be naturally modeled by the von Mises-Fisher (vMF) distribution, a classical and important distribution in directional statistics (Mardia et al., 2000) that is analogous to the spherical Gaussian distribution for features with unit norm. Our key idea is to design an end-to-end trainable loss function that optimizes hyperspherical embeddings into a mixture of vMF distributions satisfying two properties simultaneously: (1) each sample has a higher probability under its correct class than under incorrect classes, and (2) different classes are far apart from each other.

To formalize this idea, CIDER introduces two losses: a dispersion loss that promotes large angular distances among different class prototypes, and a compactness loss that encourages samples to be close to their class prototypes. These two terms are complementary in shaping hyperspherical embeddings for both OOD detection and ID classification. Unlike previous contrastive losses, CIDER explicitly formalizes the latent representations as vMF distributions, thereby providing a direct theoretical interpretation of hyperspherical embeddings. In particular, we show that promoting large inter-class dispersion is key to strong OOD detection performance, which has not been explored in previous literature.
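The interplay of the two losses can be sketched numerically: a cross-entropy-style compactness term that pulls each embedding toward its class prototype, plus a dispersion term that penalizes pairwise prototype similarity. The NumPy sketch below illustrates the idea only; the exact formulation, temperature, and prototype-update rule in CIDER may differ, and all names and values here are assumptions.

```python
import numpy as np

def normalize(v, axis=-1):
    """Project vectors onto the unit hypersphere."""
    return v / np.linalg.norm(v, axis=axis, keepdims=True)

def compactness_loss(z, labels, prototypes, tau=0.1):
    """Cross-entropy over prototype similarities: pulls each sample
    toward its own class prototype."""
    z, mu = normalize(z), normalize(prototypes)
    logits = z @ mu.T / tau                        # cosine similarity / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(labels)), labels].mean()

def dispersion_loss(prototypes, tau=0.1):
    """Penalizes pairwise prototype similarity: pushes class
    prototypes apart on the hypersphere."""
    mu = normalize(prototypes)
    sim = mu @ mu.T / tau
    C = len(mu)
    off_diag = sim[~np.eye(C, dtype=bool)].reshape(C, C - 1)
    return np.log(np.exp(off_diag).mean(axis=1)).mean()

# Toy check: antipodal prototypes incur lower dispersion loss than
# nearly collinear ones, and samples aligned with their own prototype
# incur lower compactness loss than mislabeled ones.
tight = np.array([[1.0, 0.01], [1.0, -0.01]])
spread = np.array([[1.0, 0.0], [-1.0, 0.0]])
assert dispersion_loss(spread) < dispersion_loss(tight)
assert (compactness_loss(spread, np.array([0, 1]), spread)
        < compactness_loss(spread, np.array([1, 0]), spread))
```

Minimizing the dispersion term drives prototypes toward large mutual angles, while the compactness term tightens each class around its prototype, which is precisely the mixture-of-vMF picture described above.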
Previous methods including SSD+ directly use the off-the-shelf SupCon loss, which produces embeddings that lack the inter-class dispersion needed for OOD detection. CIDER mitigates this issue by explicitly optimizing for large inter-class margins, leading to more desirable hyperspherical embeddings. Notably, when trained on CIFAR-10, CIDER achieves a 42.36% relative improvement in ID-OOD separability over SupCon. We further show that CIDER's strong representations benefit different distance-based OOD scores, outperforming the recent competitive methods SSD+ (Sehwag et al., 2021) and KNN+ (Sun et al., 2022) by a significant margin. Our key results and contributions are:

1. We propose CIDER, a novel representation learning framework designed for OOD detection. Compared to the latest rival (Sun et al., 2022), CIDER produces superior embeddings that lead to a 13.33% error reduction (in FPR95) on the challenging CIFAR-100 benchmark.
2. We are the first to establish the previously unexplored relationship between OOD detection performance and embedding quality in the hyperspherical space, and we provide measurements based on the notions of compactness and dispersion. This allows future research to quantify hyperspherical embeddings for effective OOD detection.
3. We offer new insights into the design of representation learning for OOD detection. We also conduct extensive ablations to understand the efficacy and behavior of CIDER, which remains effective and competitive under various settings, including the ImageNet dataset.
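One plausible way to quantify the ID-OOD separability referred to above is the mean pairwise angular distance between ID and OOD embeddings on the unit hypersphere. The sketch below is illustrative only and is not necessarily the exact metric used in the paper; the function name is an assumption.

```python
import numpy as np

def mean_angular_distance(z_id, z_ood):
    """Mean angle (in degrees) between all ID/OOD embedding pairs."""
    z_id = z_id / np.linalg.norm(z_id, axis=1, keepdims=True)
    z_ood = z_ood / np.linalg.norm(z_ood, axis=1, keepdims=True)
    cos = np.clip(z_id @ z_ood.T, -1.0, 1.0)   # clip for arccos stability
    return np.degrees(np.arccos(cos)).mean()

# Sanity check: orthogonal clusters are 90 degrees apart
z_id = np.array([[1.0, 0.0], [1.0, 0.0]])
z_ood = np.array([[0.0, 1.0]])
assert abs(mean_angular_distance(z_id, z_ood) - 90.0) < 1e-6
```

Under a measure of this kind, a larger mean angle between ID and OOD data directly translates into an easier thresholding problem for any angular distance-based score.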

2. PRELIMINARIES

We consider multi-class classification, where $\mathcal{X}$ denotes the input space and $\mathcal{Y}_{\text{in}} = \{1, 2, \dots, C\}$ denotes the ID labels. The training set $\mathcal{D}_{\text{tr}}^{\text{in}} = \{(\mathbf{x}_i, y_i)\}_{i=1}^{N}$ is drawn i.i.d. from $P_{\mathcal{X}\mathcal{Y}_{\text{in}}}$. Let $P_{\mathcal{X}}$ denote the marginal distribution over $\mathcal{X}$, which is called the in-distribution (ID).

Out-of-distribution detection. OOD detection can be viewed as a binary classification problem. At test time, the goal of OOD detection is to decide whether a sample $\mathbf{x} \in \mathcal{X}$ is from $P_{\mathcal{X}}$ (ID) or not (OOD). In practice, OOD is often defined by a distribution that simulates unknowns encountered during deployment, such as samples from an irrelevant distribution whose label set has no intersection with $\mathcal{Y}_{\text{in}}$ and that therefore should not be predicted by the model. Mathematically, let $\mathcal{D}_{\text{test}}^{\text{ood}}$ denote an OOD test set whose label space satisfies $\mathcal{Y}_{\text{ood}} \cap \mathcal{Y}_{\text{in}} = \emptyset$. The decision can be made via a level-set estimation: $G_\lambda(\mathbf{x}) = \mathbb{1}\{S(\mathbf{x}) \geq \lambda\}$, where samples with higher scores $S(\mathbf{x})$ are classified as ID and vice versa. The threshold $\lambda$ is typically chosen so that a high fraction of ID data (e.g., 95%) is correctly classified.

Hyperspherical embeddings. A hypersphere is a topological space that is homeomorphic to a standard $n$-sphere, which is the set of points in $(n+1)$-dimensional Euclidean space located at a constant distance from the center. When the sphere has unit radius, it is called the unit hypersphere. Formally, the $n$-dimensional unit hypersphere is $\mathbb{S}^n := \{\mathbf{z} \in \mathbb{R}^{n+1} \mid \|\mathbf{z}\|_2 = 1\}$. Geometrically, hyperspherical embeddings lie on the surface of a hypersphere.
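The level-set decision rule above can be sketched directly: choose the threshold as the score quantile that accepts, say, 95% of ID data, then compare each score against it. The snippet below is a minimal illustration with toy scores; the function names are assumptions.

```python
import numpy as np

def choose_threshold(scores_id, tpr=0.95):
    """Pick lambda so that a fraction `tpr` of ID samples satisfy
    S(x) >= lambda (i.e., lambda is the (1 - tpr) quantile)."""
    return np.quantile(scores_id, 1.0 - tpr)

def detect(scores, lam):
    """Level-set rule G_lambda(x) = 1{S(x) >= lambda}: 1 = ID, 0 = OOD."""
    return (scores >= lam).astype(int)

# Toy ID scores; roughly 95% should be accepted at the chosen threshold
rng = np.random.default_rng(0)
s_id = rng.normal(loc=1.0, scale=0.3, size=1000)
lam = choose_threshold(s_id, tpr=0.95)
assert abs(detect(s_id, lam).mean() - 0.95) < 0.01
```

Evaluating `detect` on OOD scores at this same `lam` gives the false-positive rate at 95% TPR, i.e., the FPR95 metric reported throughout the paper.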

