CONSTRAINING REPRESENTATIONS YIELDS MODELS THAT KNOW WHAT THEY DON'T KNOW

Abstract

A well-known failure mode of neural networks is that they may confidently return erroneous predictions. Such unsafe behaviour is particularly frequent when the use case differs slightly from the training context, and/or in the presence of an adversary. This work presents a novel direction to address these issues in a broad, general manner: imposing class-aware constraints on a model's internal activation patterns. Specifically, we assign to each class a unique, fixed, randomly-generated binary vector (hereafter called a class code) and train the model so that its cross-depth activation patterns predict the appropriate class code according to the input sample's class. The resulting predictors are dubbed total activation classifiers (TAC); TACs may either be trained from scratch, or used at negligible cost as a thin add-on on top of a frozen, pre-trained neural network. The distance between a TAC's activation pattern and the closest valid code acts as an additional confidence score, alongside that of the default, unTAC'ed prediction head. In the add-on case, the original neural network's inference head is completely unaffected (so its accuracy remains the same), but we now have the option to use TAC's own confidence and prediction when determining which course of action to take in a hypothetical production workflow. In particular, we show that TAC strictly improves the value derived from models allowed to reject/defer. We provide further empirical evidence that TAC works well on multiple types of architectures and data modalities, and that it is at least as good as state-of-the-art alternative confidence scores derived from existing models.

1. INTRODUCTION

Recent work has revealed interesting emergent properties of representations learned by neural networks (Papernot & McDaniel, 2018; Kalibhat et al., 2022; Bäuerle et al., 2022). In particular, simple class-dependent patterns were observed after training: there are groups of representations that consistently activate more strongly depending on high-level features of the inputs. This behaviour can be useful to define predictors able to reject/defer test data that do not follow common patterns, provided that one can efficiently verify similarities between new data and common patterns. Well-known limitations of this model class can then be addressed, such as its lack of robustness to natural distribution shifts (Ben-David et al., 2006), or against small but carefully crafted perturbations to its inputs (Szegedy et al., 2013; Goodfellow et al., 2014). This empirical evidence and these potential use cases naturally lead to the question: can we enforce simple class-dependent structure rather than hope it emerges? In this work, we address this question and show that one can indeed constrain representations to follow simple, class-dependent, and efficiently verifiable patterns. In particular, we turn the label set into a set of hard-coded class-specific binary codes and define models such that activations obtained from different layers match those patterns. In other words, class codes constrain representations and define a discrete set of valid internal configurations by indicating which groups of features should be strongly activated for a given class. At test time, any measure of how well a model matches one of the valid patterns can be used to reject potentially erroneous predictions. Codes are chosen a priori and designed to approximate pair-wise orthogonality.

Motivation. The motivation for constraining internal representations to satisfy a simple class-dependent structure is two-fold:

1. Given data, we can measure how close to a valid pattern the activations of a model are, and use such a measure as a confidence score. That is, if the model is far from a valid activation pattern, then its prediction should be deemed unreliable. Moreover, we can make codes higher-dimensional than standard one-hot representations. Long enough codes enable us to represent classes with very distinct, hence discriminative, features.

2. Tying internal representations to the labels adds constraints on attackers. To illustrate the advantage of this scheme, consider an adversary trying to fool a standard classifier: its only job is to make any output unit fire more strongly than the correct one, and any internal configuration that satisfies that condition is valid. In our proposal, an attack is only valid if the entire set of activations matches the pattern of the wrong class, adding constraints to the attack problem and effectively making it harder for an attacker to succeed under a given compute/perturbation budget, as compared to a standard classifier whose decisions are based solely on the output layer.

Intuitively, we seek to define model classes and learning algorithms such that intermediate representations follow a class-dependent structure that can be efficiently verified. Concretely, we introduce total activation classifiers (TAC): a component that can be added to any class of multi-layer classifiers. Given data and a set of class codes, TAC decides on an output class depending on which class code best matches an observed activation pattern. To obtain activation patterns, TAC slices and reduces (e.g., sums or averages) the activations of a stack of layers. Concatenating the results of the slice/reduce steps across the depth of the model yields a vector that we refer to as the activation profile. TAC learns by matching activation profiles to the underlying codes.
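To make the slice/reduce step concrete, the following minimal numpy sketch builds an activation profile from per-layer activations. The layer sizes, the number of slices, and the choice of summation as the reduction are illustrative assumptions, not the exact configuration used in the paper:

```python
import numpy as np

def activation_profile(layer_activations, n_slices):
    """Slice each layer's activation vector into n_slices contiguous
    chunks, reduce each chunk by summation, and concatenate the
    per-layer results into a single activation profile."""
    parts = []
    for a in layer_activations:  # one 1-D activation vector per layer
        for chunk in np.array_split(a, n_slices):
            parts.append(chunk.sum())
    return np.array(parts)

# Toy example: activations from a 3-layer model, 4 slices per layer,
# yielding a 12-dimensional activation profile.
acts = [np.random.rand(64), np.random.rand(128), np.random.rand(32)]
profile = activation_profile(acts, n_slices=4)
print(profile.shape)  # (12,)
```

The profile's dimensionality is the number of layers times the number of slices per layer, so it matches the class-code length by construction.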
At inference, TAC assigns the class whose code is closest to the activation profile yielded by a given test instance, and the corresponding distance behaves as a strong predictor of the prediction's quality: at test time, one can decide to reject whenever the activation profile does not match a valid code up to a threshold.
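The predict-or-reject rule above can be sketched as follows; the Euclidean distance and the threshold value are illustrative assumptions:

```python
import numpy as np

def predict_or_reject(profile, codes, threshold):
    """codes: (K, L) array of class codes; profile: (L,) activation
    profile. Predict the class with the closest code, but reject the
    prediction if even the closest code is farther than `threshold`."""
    dists = np.linalg.norm(codes - profile, axis=1)  # distance to each code
    k = int(np.argmin(dists))
    if dists[k] > threshold:
        return None, dists[k]  # reject / defer
    return k, dists[k]         # accept prediction k

codes = np.array([[1.0, 0.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0, 1.0]])
label, d = predict_or_reject(np.array([0.9, 0.1, 0.8, 0.0]),
                             codes, threshold=1.0)
print(label)  # 0
```

Sweeping the threshold trades off coverage against error rate, which is how a rejection budget would be tuned in practice.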

Contributions. Our contributions are summarized as follows:

1. We introduce a model class along with a learning procedure, referred to as total activation classifiers, which satisfies the requirement of representations that follow class-dependent patterns. Resulting models require no access to out-of-distribution data during training, and offer inexpensive, easy-to-obtain confidence scores without affecting prediction performance.

2. We propose simple and efficient strategies leveraging statistics of TAC's activations to spot low-confidence, likely erroneous predictions. In particular, we empirically observed TAC to be effective in the rejection setting, strictly improving the value of rejecting classifiers.

3. We provide extra results and show that TAC's scores can be used to detect data from unseen classes, and that it can be used as a robust surrogate of the base classifier if it is kept hidden from attackers, while preserving its clean accuracy to a greater extent than alternative robust predictors.

2.1. REPRESENTING LABELS AS CODES

We focus on the setting of K-way classification. In this case, data instances correspond to pairs x, y ∼ X × Y, with X ⊂ ℝ^d and Y = {1, 2, 3, ..., K}, K ∈ ℕ. Usually, model families targeting such a setting parameterize data-conditional categorical distributions over Y. That is, a given model f ∈ F : X → Δ^{K-1} will project data onto the probability simplex Δ^{K-1}. Alternatively, we will consider classes of predictors of the form f′ ∈ F′ : X → [0, 1]^L with L ≫ K, i.e., models that map data onto the unit cube in ℝ^L. We thus associate each element in Y with a vertex of the cube. In other words, we represent class labels with a set of binary codes C = {C_1, C_2, C_3, ..., C_K}, and training can be performed by searching arg min_{f′ ∈ F′} E_{(x,y)∼(X,Y)} D(f′(x), C_y), for a distance D : ℝ^L × ℝ^L → ℝ. In doing so, we seek models able to project data such that the results are close to the vertex corresponding to the correct class label in terms of the distance D. At prediction time, one can predict via arg min_i D(f′(x), C_i) or leverage D(f′(x), C ∈ C) to estimate confidence.
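A small sketch of this label-coding scheme, with K, L, and the uniform sampling of codes chosen purely for illustration: random binary codes of length L ≫ K are approximately pairwise orthogonal once mapped from {0, 1} to {-1, +1}, and prediction reduces to an argmin over distances to the codes:

```python
import numpy as np

rng = np.random.default_rng(0)
K, L = 10, 512                    # number of classes, code length (L >> K)
codes = rng.integers(0, 2, size=(K, L)).astype(float)  # binary class codes

# Approximate pairwise orthogonality: in the {-1, +1} representation the
# expected inner product between two independent random codes is 0, so
# off-diagonal entries of the Gram matrix are small relative to L.
signed = 2 * codes - 1
gram = signed @ signed.T
off_diag = gram[~np.eye(K, dtype=bool)]
print(np.abs(off_diag).max() / L)  # small fraction of L

# Prediction: arg min_i D(f'(x), C_i), here with Euclidean D.
def predict(fx, codes):
    return int(np.argmin(np.linalg.norm(codes - fx, axis=1)))
```

Since the codes concentrate near orthogonality at this length, a model output close to one vertex is necessarily far from all others, which is what makes the distance usable as a confidence score.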

