WEIGHTED ENSEMBLE SELF-SUPERVISED LEARNING

Abstract

Ensembling has proven to be a powerful technique for boosting model performance, uncertainty estimation, and robustness in supervised learning. Advances in self-supervised learning (SSL) enable leveraging large unlabeled corpora for state-of-the-art few-shot and supervised learning performance. In this paper, we explore how ensemble methods can improve recent SSL techniques by developing a framework that permits data-dependent weighted cross-entropy losses. We refrain from ensembling the representation backbone; this choice yields an efficient ensemble method that incurs only a small training cost and requires no architectural changes or computational overhead in downstream evaluation. The effectiveness of our method is demonstrated with two state-of-the-art SSL methods, DINO (Caron et al., 2021) and MSN (Assran et al., 2022). Our method outperforms both on multiple evaluation metrics on ImageNet-1K, particularly in the few-shot setting. We explore several weighting schemes and find that those which increase the diversity of the ensemble heads lead to better downstream evaluation results. Thorough experiments yield improved prior-art baselines, which our method still surpasses; e.g., our overall improvement with MSN ViT-B/16 is 3.9 p.p. for 1-shot learning.



1. INTRODUCTION

The promise of self-supervised learning (SSL) is to extract information from unlabeled data and leverage it in downstream tasks (He et al., 2020; Caron et al., 2021), e.g., semi-supervised learning (Chen et al., 2020a;b), robust learning (Radford et al., 2021; Ruan et al., 2022; Lee et al., 2021), few-shot learning (Assran et al., 2022), and supervised learning (Tomasev et al., 2022). These successes have encouraged increasingly advanced SSL techniques (e.g., Grill et al., 2020; Zbontar et al., 2021; He et al., 2022). Perhaps surprisingly, however, a simple and otherwise common idea has received limited consideration: ensembling. Ensembling combines predictions from multiple trained models and has proven effective at improving model accuracy (Hansen & Salamon, 1990; Perrone & Cooper, 1992) and at capturing predictive uncertainty in supervised learning (Lakshminarayanan et al., 2017; Ovadia et al., 2019). Ensembling in the SSL regime is nuanced, however: since the goal is to learn useful representations from unlabeled data, it is less obvious where and how to ensemble. We explore these questions in this work.

We develop an efficient ensemble method tailored for SSL that replicates the non-representation parts (e.g., projection heads) of the SSL model. In contrast with traditional "post-training" ensembling, our ensembles are used only during training to facilitate the learning of a single representation encoder, so they add no extra cost to downstream evaluation. We further present a family of weighted cross-entropy losses to effectively train the ensembles. The key component of our losses is the introduction of data-dependent importance weights for ensemble members. We empirically compare different choices from our framework and find that the choice of weighting scheme critically impacts ensemble diversity, and that greater ensemble diversity correlates with improved downstream performance.
Our method is potentially applicable to many SSL methods; we focus on DINO (Caron et al., 2021) and MSN (Assran et al., 2022) to demonstrate its effectiveness. Fig. 1 shows DINO improvements from using our ensembling and weighted cross-entropy loss. Our contributions are to:
• Develop a downstream-efficient ensemble method suitable for many SSL techniques (Sec. 3.1).
• Characterize an ensemble loss family of weighted cross-entropy objectives (Sec. 3.2).
• Conduct extensive ablation studies that improve the prior-art baselines by up to 6.3 p.p. (Sec. 5.1).
• Further improve those baselines with ensembling (e.g., up to a 5.5 p.p. gain for 1-shot) (Table 2).
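As a toy illustration of the training-time ensemble, the sketch below replicates only the projection heads on top of a single shared encoder output and combines their cross-entropy losses with data-dependent per-head weights. All names, shapes, and the particular weighting scheme are our own illustrative assumptions, not the authors' code (the paper studies several weighting schemes):

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Shared encoder output for one input; the backbone is NOT ensembled.
z = rng.normal(size=16)

# E independent projection heads, each mapping to c codebook logits.
E, c = 4, 8
heads = [rng.normal(size=(16, c)) for _ in range(E)]

# Teacher target distribution over codes for this input (placeholder).
t = softmax(rng.normal(size=c))

# Per-head student predictions and cross-entropy losses.
student = [softmax(z @ W) for W in heads]
losses = np.array([-(t * np.log(s)).sum() for s in student])

# Data-dependent importance weights over ensemble members; this choice
# up-weights heads that already fit this input (one illustrative option).
w = softmax(-losses)

weighted_loss = (w * losses).sum()
```

Only the encoder's parameters are what downstream tasks consume; the heads and weights exist solely to shape its training signal.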

2. BACKGROUND

In this section, we frame SSL methods from the perspective of maximum likelihood estimation (MLE) and use this as the notational basis to describe state-of-the-art clustering-based SSL methods, as well as to derive their ensembled variants in Sec. 3.

From maximum likelihood to SSL. Denote the unnormalized KL divergence (Dikmen et al., 2014) between non-negative integrable functions $p, q$ by $K[p(X), q(X)] = H^{\times}[p(X), q(X)] - H[p(X)]$, where
$$H^{\times}[p(X), q(X)] = -\int_{\mathcal{X}} p(x) \log q(x)\,dx + \int_{\mathcal{X}} q(x)\,dx - 1$$
is the unnormalized cross-entropy (with $0 \log 0 = 0$) and $H[p(X)] = H^{\times}[p(X), p(X)]$. These quantities simplify to their usual definitions when $p, q$ are normalized, but critically they enable flexible weighting of distributions for the derivation of our weighted ensemble losses in Sec. 3.2.

Let $\nu(X, Y) = \nu(X)\,\nu(Y|X)$ be nature's distribution of input/target pairs over the space $\mathcal{X} \times \mathcal{Y}$, and let $s(Y|\theta, X)$ be a predictive model of the target given the input, parameterized by $\theta \in \mathcal{T}$. Supervised maximum likelihood seeks the minimum expected conditional population risk with respect to $\theta$,
$$\mathbb{E}_{\nu(X)} K[\nu(Y|X), s(Y|\theta, X)] = \mathbb{E}_{\nu(X)} H^{\times}[\nu(Y|X), s(Y|\theta, X)] - \mathbb{E}_{\nu(X)} H[\nu(Y|X)].$$
Henceforth we omit $\mathbb{E}_{\nu(X)} H[\nu(Y|X)]$ since it is constant in $\theta$. Since $\nu(X, Y)$ is unknown, a finite sample approximation over a size-$n$ training set $D_n$ is employed (defined below).

In SSL, we interpret $\nu(Y|x)$ as the oracle teacher under a presumption of how the representations will be evaluated on a downstream task. This assumption is similar to that made in Arora et al. (2019); Nozawa et al. (2020). We also assume $\nu(Y|X)$ is inaccessible and/or unreliable. Under this view, some SSL techniques substitute a weakly learned target, or "teacher", $t(Y|x)$ for $\nu(Y|x)$. We do not generally expect $t(Y|x)$ to recover $\nu(Y|x)$; we only assume that an optimal teacher exists and that it is $\nu(Y|x)$. With the teacher providing the targets, the loss becomes
$$\frac{1}{n} \sum_{x \in D_n} H^{\times}[t(Y|x), s(Y|\theta, x)].$$
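For intuition, the unnormalized cross-entropy and KL above can be sketched for discrete distributions. The function names are ours, and the code is a minimal illustration under the definitions above, not the paper's implementation:

```python
import numpy as np

def unnorm_cross_entropy(p, q, eps=1e-12):
    # H_x[p, q] = -sum_y p(y) log q(y) + sum_y q(y) - 1, with 0 log 0 = 0.
    p, q = np.asarray(p, float), np.asarray(q, float)
    plogq = np.where(p > 0, p * np.log(q + eps), 0.0)
    return -plogq.sum() + q.sum() - 1.0

def unnorm_kl(p, q):
    # K[p, q] = H_x[p, q] - H[p], where H[p] = H_x[p, p].
    return unnorm_cross_entropy(p, q) - unnorm_cross_entropy(p, p)

# When p and q are normalized, the extra terms sum(q) - 1 cancel and
# these reduce to the usual cross-entropy and KL divergence.
p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])
```

The `sum(q) - 1` term is what later lets ensemble-member weights scale $q$ away from a normalized distribution without breaking the divergence.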
Teacher and student in clustering SSL methods. Clustering SSL methods such as SwAV (Caron et al., 2020), DINO (Caron et al., 2021), and MSN (Assran et al., 2022) employ a student model characterized by proximity between learned codebook entries and a data-dependent code,
$$s(Y = y \mid \theta, x) = \mathrm{softmax}\!\left( \left\{ \frac{1}{\tau} \, \frac{(h_{\psi} \circ r_{\omega})(x) \cdot \mu_{y}}{\|(h_{\psi} \circ r_{\omega})(x)\|_{2} \, \|\mu_{y}\|_{2}} : y \in [c] \right\} \right), \qquad (2)$$
with $\theta = \{\omega, \psi, \{\mu_y\}_{y \in [c]}\} \in \mathcal{T}$, where the encoder $r_{\omega} : \mathcal{X} \to \mathcal{Z}$ produces the representations used for downstream tasks, and the projection head $h_{\psi} : \mathcal{Z} \to \mathbb{R}^{d}$ and codebook entries $\{\mu_y\}_{y \in [c]} \subset \mathbb{R}^{d}$ characterize the SSL loss. Eq. (2) can be viewed as "soft clustering", where the input is assigned to those centroids that are closest to the projection head's output. The projection head and codebook are used during training but discarded for evaluation, which is empirically found to be vital for downstream tasks (Chen et al., 2020a;b). The hyperparameters $\tau \in \mathbb{R}_{>0}$ and $c \in \mathbb{Z}_{>0}$ are the temperature and codebook size. The teacher is defined as $t(Y|x) = s(Y \mid \mathrm{stopgrad}(g(\theta)), x)$, where $g : \mathcal{T} \to \mathcal{T}$. Commonly, $g(\theta)$ is an exponential moving average of the gradient descent iterates, and the teacher uses a lower temperature than the student. To capture desirable invariances and prevent degeneracy, data augmentation and regularization (e.g., Sinkhorn-Knopp normalization (Caron et al., 2020), mean entropy maximization (Assran et al., 2022)) are essential; as these are not directly relevant to our method, we omit them for brevity.
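Eq. (2) amounts to a temperature-scaled softmax over cosine similarities between the projected embedding and the codebook entries. A minimal sketch of the student assignment (the function name, shapes, and toy values are our assumptions, not the paper's code):

```python
import numpy as np

def student_probs(z, codebook, tau=0.1):
    """Soft-clustering assignment s(Y | theta, x) in the style of Eq. (2).

    z:        projected embedding (h_psi o r_omega)(x), shape (d,)
    codebook: prototype matrix of entries mu_y, shape (c, d)
    tau:      temperature; lower values sharpen the assignment
    """
    z = z / np.linalg.norm(z)
    mu = codebook / np.linalg.norm(codebook, axis=1, keepdims=True)
    logits = (mu @ z) / tau            # cosine similarity scaled by 1/tau
    logits -= logits.max()             # numerical stability
    e = np.exp(logits)
    return e / e.sum()                 # softmax over the c codebook entries

# Toy usage: a 2-D embedding against three prototypes.
z = np.array([1.0, 0.0])
codebook = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
p_sharp = student_probs(z, codebook, tau=0.1)
p_soft = student_probs(z, codebook, tau=1.0)
```

The teacher would apply the same computation with stop-gradient, EMA parameters, and a lower temperature, per the definition of $t(Y|x)$ above.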



Figure 1: Our improvements to DINO, including baseline improvements and ensembling.

Since $\nu(X, Y)$ is unknown, a finite sample approximation is often employed. Denote a size-$n$ i.i.d. training set by $D_n = \{x_i\}_{i \in [n]} \sim \nu^{\otimes n}$ and the empirical distribution by $\hat{\nu}(X, Y) = \frac{1}{n} \sum_{x \in D_n,\, y \sim \nu(Y|x)} \delta(X - x)\,\delta(Y - y)$, where $\delta : \mathbb{R} \to \{0, 1\}$ is 1 when its argument is 0 and 0 otherwise. The sample risk is thus $\frac{1}{n} \sum_{x \in D_n} H^{\times}[\nu(Y|x), s(Y|\theta, x)]$.
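As a worked illustration of the sample risk, the sketch below averages the cross-entropy (in its normalized-distribution form, where the extra terms of $H^{\times}$ vanish) between per-input target distributions and model predictions over a toy dataset; all arrays are illustrative placeholders of our own:

```python
import numpy as np

def cross_entropy(p, q, eps=1e-12):
    # For normalized p, q, H_x[p, q] reduces to -sum_y p(y) log q(y).
    return -(np.asarray(p) * np.log(np.asarray(q) + eps)).sum()

# Toy dataset: n = 3 inputs, each with a target distribution (rows of T)
# and a model prediction (rows of S) over c = 2 codes.
T = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
S = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])

# Sample risk: the average per-input cross-entropy over D_n.
sample_risk = np.mean([cross_entropy(t, s) for t, s in zip(T, S)])
```

Predicting the targets exactly ($S = T$) would lower the risk to the average target entropy, its minimum over predictions.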

