MASSIVELY SCALING HETEROSCEDASTIC CLASSIFIERS

Abstract

Heteroscedastic classifiers, which learn a multivariate Gaussian distribution over prediction logits, have been shown to perform well on image classification problems with hundreds to thousands of classes. However, compared to standard classifiers, they introduce extra parameters that scale linearly with the number of classes. This makes them infeasible to apply to larger-scale problems. In addition, heteroscedastic classifiers introduce a critical temperature hyperparameter that must be tuned. We propose HET-XL, a heteroscedastic classifier whose parameter count, when compared to a standard classifier, scales independently of the number of classes. In our large-scale settings, we show that we can remove the need to tune the temperature hyperparameter by directly learning it on the training data. On large image classification datasets with up to 4B images and 30k classes, our method requires 14× fewer additional parameters, does not require tuning the temperature on a held-out set, and performs consistently better than the baseline heteroscedastic classifier. HET-XL also improves ImageNet 0-shot classification in a multimodal contrastive learning setup, which can be viewed as a 3.5-billion-class classification problem.

1. INTRODUCTION

Heteroscedastic models learn an input-dependent noise term to capture uncertainty in their predictions. In deep learning, they have been used successfully in large-scale image classification (Collier et al., 2021), image segmentation (Kendall & Gal, 2017; Collier et al., 2020), regression (Lakshminarayanan et al., 2017), uncertainty quantification (Tran et al., 2022; Nado et al., 2021) and in bandit problems (Osband et al., 2021). It is known from the economics literature that heteroscedastic classifiers are particularly suited to modelling classification problems with many classes (Train, 2009), and this has been further observed in deep learning (Collier et al., 2021). However, heteroscedastic classifiers add parameters on top of standard "deterministic" classifiers (DET) to define their K × K covariance matrix, with K the number of classes. Even with low-rank approximations, the number of additional parameters scales linearly in K, thus imposing a significant cost in large-scale settings. These additional parameters must also be kept in long-term storage and loaded into memory, which can pose problems for both storage-bound and memory-bound applications. For example, on JFT-4B, a dataset with 29,593 classes, HET (Collier et al., 2021), to the best of our knowledge the state-of-the-art and most scalable heteroscedastic classification method, does not fit in memory on a large TPU slice (64 TPU v3 cells with 128 cores) when using a modest-sized ViT-L/32 base architecture.

In this paper, we propose HET-XL, whose extra parameter count over DET scales independently of the number of classes. In addition, HET requires tuning a temperature hyperparameter τ, which hinders the adoption of heteroscedastic classifiers in large-scale settings where hyperparameter sweeps are either very costly or not feasible at all. HET-XL, in contrast, learns τ directly on the training set.
We argue and demonstrate empirically that this is feasible precisely in this very large-scale setting. Despite the improved efficiency and ease of adoption, HET-XL performs consistently better across large-scale image classification tasks compared to DET and HET, and improves upon a DET baseline in the contrastive learning setting.

*Equal contribution.

Contributions. In summary, our contributions are:

(1) We develop the HET-XL heteroscedastic classifier, whose cost of deployment is significantly reduced compared to HET, the prior state of the art. In large-scale settings, HET-XL does not require tuning a temperature hyperparameter and has a massively reduced parameter count compared to HET. Moreover, HET-XL allows for plug-in compatibility with existing large-scale classifiers, which is not the case for HET.

(2) On three image classification benchmarks (JFT-300M, ImageNet-21k and JFT-4B) and for two different popular model classes, ResNet152 and ViT-L/32, HET-XL consistently outperforms HET and HET-H, a new hashing-based baseline we introduce. For example, with a ResNet152 on JFT-300M, we increase precision@1 by 2.3% compared to HET while adding about 9 times fewer parameters.

(3) We extend HET-XL to contrastive learning, where the method improves ImageNet 0-shot classification accuracy from 85.29% for a DET model to 85.56% in a LiT setup (Zhai et al., 2022).
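As a back-of-the-envelope illustration of the parameter-count gap discussed above, the sketch below counts the extra parameters of a heteroscedastic head whose low-rank-plus-diagonal covariance is computed linearly from the D-dimensional pre-logits. The width D, rank R, and the exact parameterizations are illustrative assumptions for this sketch, not the configurations used in the paper.

```python
# Extra parameters added on top of a deterministic classifier (DET).
# A HET-style head maps pre-logits to a K x R low-rank factor plus a
# K-dim diagonal over the logits, so its cost grows linearly in K.
# A HET-XL-style head instead models noise in the D-dim pre-logit
# space, so its extra parameter count is independent of K.
# (Both parameterizations here are simplified assumptions.)

def het_extra_params(D: int, K: int, R: int) -> int:
    # Linear maps pre-logits -> (K x R factor, K-dim diagonal).
    return D * K * R + D * K

def het_xl_extra_params(D: int, R: int) -> int:
    # Linear maps pre-logits -> (D x R factor, D-dim diagonal).
    return D * D * R + D * D

D, R = 1024, 15                 # assumed ViT-L-like width and rank
for K in (1_000, 29_593):       # ImageNet-scale vs JFT-4B class counts
    print(K, het_extra_params(D, K, R), het_xl_extra_params(D, R))
```

Doubling K doubles the HET overhead but leaves the HET-XL overhead untouched, which is the scaling behaviour the contributions above describe.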

2. BACKGROUND ON HETEROSCEDASTIC CLASSIFIERS

We now review the core prior work that our method builds on; a wider review of related work is provided in Appendix A. We focus on classification tasks where we learn classifiers of the form

softmax(W φ(x; θ)) and sigmoid(W φ(x; θ)),  (1)

based on some training data D = {(x_n, y_n)}_{n=1..N}. A pair (x_n, y_n) corresponds to an input x_n, e.g., an image, together with its label y_n ∈ {0, 1}^K belonging to one, or multiple, of the K classes, in the multi-class and multi-label settings, respectively. The model is parametrized by W ∈ R^{D×K} and the D-dimensional representation φ(•; θ) output by a neural network with parameters θ. We have omitted the bias term to ease the presentation. Throughout the paper, we refer to W φ(x; θ) ∈ R^K as the logits, while we use the term pre-logits for φ(x; θ) ∈ R^D. We denote the elementwise product between tensors by •.

Heteroscedastic classifiers learn an additional input-dependent noise distribution placed on the logits to capture uncertainty in the predictions of the model (Kendall & Gal, 2017; Collier et al., 2021; Train, 2009). We consider the setting where this noise is modelled by a Gaussian, leading to the class predictions

p(y | x; {W, θ, θ_cov}) = E_{u ∼ N(0, Σ(x; θ_cov))} [σ(W φ(x; θ) + u)],  (2)

where σ can be either the softmax or the sigmoid transformation (see Fig. 1, right). Above, we have introduced the covariance matrix Σ(x; θ_cov), which is parametrized by θ_cov; we describe Σ in more detail in Section 2.2. The resulting conditional probability p(y | x; {W, θ, θ_cov}) in Eq. (2) is used to train the model on D by maximum likelihood and to make predictions at evaluation time.
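The expectation over the Gaussian logit noise has no closed form under the softmax, so it is typically estimated by Monte Carlo sampling. The following NumPy sketch illustrates this for a single input, assuming a low-rank-plus-diagonal covariance Σ = V Vᵀ + diag(d²); the specific covariance structure and all shapes here are illustrative assumptions, not the paper's exact parameterization.

```python
import numpy as np

def heteroscedastic_probs(pre_logits, W, V, d, num_samples=1000, rng=None):
    """Monte Carlo estimate of E_u[softmax(logits + u)], u ~ N(0, Sigma),
    with Sigma = V V^T + diag(d^2) (low-rank plus diagonal, an assumption).

    pre_logits: (D,) representation phi(x; theta)
    W:          (D, K) classifier weights, so logits = pre_logits @ W
    V:          (K, R) low-rank covariance factor
    d:          (K,)  diagonal standard deviations
    """
    if rng is None:
        rng = np.random.default_rng(0)
    logits = pre_logits @ W                        # (K,) mean logits
    K, R = V.shape
    eps = rng.standard_normal((num_samples, R))    # low-rank noise seeds
    diag_eps = rng.standard_normal((num_samples, K))
    u = eps @ V.T + diag_eps * d                   # (num_samples, K) ~ N(0, Sigma)
    z = logits + u                                 # noisy logits per sample
    z = z - z.max(axis=1, keepdims=True)           # numerically stable softmax
    p = np.exp(z)
    p /= p.sum(axis=1, keepdims=True)
    return p.mean(axis=0)                          # averaged class probabilities
```

Averaging the softmax of noisy logits, rather than taking the softmax of averaged logits, is what makes the prediction genuinely heteroscedastic: inputs with larger predicted covariance yield flatter, less confident class probabilities.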

2.1. ESTIMATING THE HETEROSCEDASTIC PREDICTIONS

We will focus on the HET method from Collier et al. (2021), as it obtains state-of-the-art performance on several benchmarks and operates at the largest scale. As shown in Eq. (2), heteroscedastic modelling requires

