SIMPLICIAL EMBEDDINGS IN SELF-SUPERVISED LEARNING AND DOWNSTREAM CLASSIFICATION

Abstract

Simplicial Embeddings (SEM) are representations learned through self-supervised learning (SSL), wherein a representation is projected into L simplices of V dimensions each using a softmax operation. This procedure constrains the representation to a structured space during pre-training and imparts an inductive bias toward discrete representations. For downstream classification, we provide an upper bound and argue that using SEM leads to a lower expected error than using the unnormalized representation. Furthermore, we empirically demonstrate that SSL methods trained with SEMs generalize better on natural image datasets such as CIFAR-100 and ImageNet. Finally, when used in a downstream classification task, we show that SEM features exhibit emergent semantic coherence, where small groups of learned features are distinctly predictive of semantically-relevant classes.

1. INTRODUCTION

Self-supervised learning (SSL) is an emerging family of methods that aim to learn representations of data without manual supervision, such as class labels. Recent works (Hjelm et al., 2019; Grill et al., 2020; Saeed et al., 2020; You et al., 2020) learn dense representations that can solve complex tasks by simply fitting a linear model on top of the learned representation. While SSL is already highly effective, we show that changing the type of representation learned can improve both the performance and interpretability of these methods. For this we draw inspiration from overcomplete representations: representations of an input that are non-unique combinations of a number of basis vectors greater than the input's dimensionality (Lewicki & Sejnowski, 2000). Mostly studied in the sparse coding literature (Gregor & LeCun, 2010; Goodfellow et al., 2012; Olshausen, 2013), sparse overcomplete representations have been shown to increase stability in the presence of noise (Donoho et al., 2006), have applications in neuroscience (Olshausen & Field, 1996; Lee et al., 2007), and lead to more interpretable representations (Murphy et al., 2012; Fyshe et al., 2015; Faruqui et al., 2015). However, these basis vectors are typically learned with linear models (Lewicki & Sejnowski, 2000; Teh et al., 2003). In this work, we show that SSL may be used to learn discrete, sparse, and overcomplete representations. Prior work has considered sparse representations, but not sparse and overcomplete representation learning with SSL; for example, Dessì et al. (2021) propose to discretize the output of the encoder in an SSL model using Gumbel-Softmax (Jang et al., 2017). However, we show that discretization during pre-training is not necessary to achieve a sparse representation.
Instead, we propose to project the encoder's output into L vectors of V dimensions each, onto which we apply a softmax function to impart an inductive bias toward sparse, approximately one-hot vectors (Correia et al., 2019; Goyal et al., 2022), which also alleviates the need for high-variance gradient estimators to train the encoder. We refer to this embedding as Simplicial Embeddings (SEM), as the softmax functions map the unnormalized representations onto L simplices. The procedure to induce SEM is simple, efficient, and generally applicable. The SSL pre-training phase, used with SEM, learns a set of L approximately one-hot vectors. Key to controlling the inductive bias of SEM during pre-training is the softmax temperature parameter: the lower the temperature, the stronger the bias toward sparsity. Consistent with earlier attempts at sparse representation learning (Coates & Ng, 2011), we find that the optimal sparsity for pre-training need not match the optimal level for downstream learning. For downstream classification, we may discretize the learned representation by, for example, taking the argmax of each simplex. Alternatively, we can use SEM to control the representation's expressivity via the softmax temperature. We provide a theoretical bound showing that the expected error follows a trade-off between the training error and the representation's expressivity, which can be controlled by the softmax temperature used to normalize the representation for downstream classification. Our bound also shows improved expected error as we increase L and V for SEM. SEM is generally applicable to recent SSL methods: applying it to seven different methods (Chen et al., 2020b; He et al., 2020; Grill et al., 2020; Caron et al., 2020; 2021; Zbontar et al., 2021; Bardes et al., 2022), we find accuracy increases of 2% to 4% on CIFAR-100.
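The two downstream options described above, hard argmax discretization and soft temperature renormalization, can be sketched as follows. This is a minimal NumPy illustration; the function name, shapes, and default behavior are ours, not the paper's actual code.

```python
import numpy as np

def downstream_representation(z, tau=None):
    """Turn unnormalized SEM logits z of shape (L, V) into a downstream
    feature: a hard one-hot vector per simplex via argmax (tau=None), or
    a soft renormalization with downstream temperature tau.
    Illustrative helper, not the paper's implementation."""
    if tau is None:
        one_hot = np.zeros_like(z)
        one_hot[np.arange(z.shape[0]), z.argmax(axis=-1)] = 1.0
        return one_hot
    # numerically stable softmax per simplex at temperature tau
    s = np.exp((z - z.max(axis=-1, keepdims=True)) / tau)
    return s / s.sum(axis=-1, keepdims=True)
```

Lowering tau pushes each simplex toward its one-hot argmax vector, while a larger tau retains a softer, more expressive representation, which is the trade-off the bound above formalizes.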
We observe monotonic improvement as we increase the number of vectors L, showing the benefit of the overcomplete representations learned by SEM; this improvement is absent when we do not use softmax normalization. When training an SSL method with SEM on ImageNet, we also observe improvements on the in-distribution test set compared to the baseline (Figure 1). We further observe improvements on out-of-distribution test sets, semi-supervised learning benchmarks, and transfer learning datasets, demonstrating the potential of SEM for large-scale applications. Finally, we find that SEM learns features that are closely aligned with the semantic categories in the data. This demonstrates that SEM learns disentangled and interpretable representations, as previously observed with overcomplete representations (Faruqui et al., 2015).

2. RELATED WORK

The softmax operation has been used in other contexts, notably as an architectural component for models to attend to context-dependent queries via, for example, an attention mechanism (Bahdanau et al., 2016; Vaswani et al., 2017; Correia et al., 2019; Goyal et al., 2022), a mixture of experts (Jordan & Jacobs, 1993), or memory-augmented networks (Graves et al., 2014). This operation is also used in the computation of several SSL objectives such as InfoNCE (van den Oord et al., 2018; Hjelm et al., 2019), and as a normalization of the output to compute the objective in DINO and SwAV (Caron et al., 2020; 2021). Different from these, our method places the softmax at the output of an encoder to constrain the representation into a set of L sparse vectors. Similar to our approach, other architectural constraints such as Dropout (Srivastava et al., 2014), BatchNorm (Ioffe & Szegedy, 2015), and LayerNorm (Ba et al., 2016) also improve the training of large neural networks. However, contrary to SEMs, they are not used to induce sparsity in the representation or to control its expressivity for downstream tasks. Closer to our work, Liu et al. (2021) propose to constrain the expressivity of a neural network's representation with a set of discrete-valued symbols obtained using Vector Quantized (Oord et al., 2018) bottlenecks. Similarly, Dessì et al. (2021) propose a communication game with a discrete bottleneck. The idea of discretizing the encoder's output corresponds to using SEM vectors that are exactly one-hot (i.e., temperature τ → 0) with only one symbol (i.e., L = 1, V = 2048). In our work, we find success in removing the hard discretization and setting L > 1, which can be interpreted as combining several symbols.
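To make the limiting behavior above concrete, a small numerical sketch (with illustrative logits, not values from the paper) shows that a temperature-scaled softmax approaches a hard one-hot discretization as the temperature shrinks:

```python
import numpy as np

def softmax(z, tau):
    # numerically stable softmax at temperature tau
    s = np.exp((z - z.max()) / tau)
    return s / s.sum()

z = np.array([2.0, 0.5, -1.0])  # illustrative logits for one simplex
for tau in (1.0, 0.1, 0.01):
    print(tau, np.round(softmax(z, tau), 3))
# as tau shrinks, the mass concentrates on the argmax entry
```

At tau = 1.0 the output remains a soft mixture over all V entries; by tau = 0.01 it is numerically indistinguishable from the one-hot argmax vector, matching the equivalence with hard discretization described above.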

3. SIMPLICIAL EMBEDDINGS

Simplicial Embeddings (SEM) are representations that can be integrated easily into a contrastive learning model (Hjelm et al., 2019; Chen et al., 2020b), the BYOL method (Grill et al., 2020), and other SSL methods (Caron et al., 2020; 2021; Zbontar et al., 2021). For example, in BYOL, we insert the SEM after the encoder and before the projector; the rest is unchanged, as shown in Figure 2c. In this figure, t and t′ are augmentations defined by the practitioner, and ξ are the parameters of the target network, updated as a moving average of the parameters θ of the online network trained with SGD. That is, ξ is updated as follows: ξ ← αξ + (1 − α)θ, with α ∈ [0, 1]. To produce the SEM representation, the encoder's output e is embedded into L vectors z_i ∈ R^V. A temperature parameter τ scales each z_i, and a softmax then normalizes each vector z_i to produce ẑ_i.
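This operation can be sketched as follows, assuming a simple linear projection from the encoder output; the weight matrix W and all shapes are illustrative, not the paper's exact architecture.

```python
import numpy as np

def simplicial_embedding(e, W, L, V, tau=1.0):
    """Map an encoder output e (shape (B, D)) to SEM features:
    project to L * V logits with W (shape (D, L * V)), reshape into
    L vectors of V dimensions, and apply a softmax per vector at
    temperature tau. Minimal sketch, not the paper's implementation."""
    z = (e @ W).reshape(-1, L, V)                  # (B, L, V) logits z_i
    z = (z - z.max(axis=-1, keepdims=True)) / tau  # temperature-scaled, stabilized
    z_hat = np.exp(z)
    z_hat /= z_hat.sum(axis=-1, keepdims=True)     # each z_i mapped onto a simplex
    return z_hat.reshape(-1, L * V)                # concatenated SEM representation
```

In a BYOL-style pipeline, this block would sit between the encoder and the projector on both the online and target branches, with the target's copy updated through the same moving average as the rest of ξ.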



Figure 1: Linear probe accuracy of BYOL and BYOL + SEM on ImageNet trained for 200 epochs with a ResNet-50 architecture.

