IMPROVING TRANSFORMATION INVARIANCE IN CONTRASTIVE REPRESENTATION LEARNING

Abstract

We propose methods to strengthen the invariance properties of representations obtained by contrastive learning. While existing approaches implicitly induce a degree of invariance as representations are learned, we look to more directly enforce invariance in the encoding process. To this end, we first introduce a training objective for contrastive learning that uses a novel regularizer to control how the representation changes under transformation. We show that representations trained with this objective perform better on downstream tasks and are more robust to the introduction of nuisance transformations at test time. Second, we propose a change to how test-time representations are generated: a feature averaging approach that combines encodings from multiple transformations of the original input, which we find leads to across-the-board performance gains. Finally, we introduce the novel Spirograph dataset to explore our ideas in the context of a differentiable generative process with multiple downstream tasks, showing that our techniques for learning invariance are highly beneficial.

1. INTRODUCTION

Learning meaningful representations of data is a central endeavour in artificial intelligence. Such representations should retain important information about the original input whilst using fewer bits to store it (van der Maaten et al., 2009; Gregor et al., 2016). Semantically meaningful representations may discard a great deal of information about the input, whilst capturing what is relevant. Knowing what to discard, as well as what to keep, is key to obtaining powerful representations. By defining transformations that are believed a priori to distort the original without altering semantic features of interest, we can learn representations that are (approximately) invariant to these transformations (Hadsell et al., 2006). Such representations may be more efficient and more generalizable than lossless encodings. Whilst less effective for reconstruction, these representations are useful in many downstream tasks that relate only to the semantic features of the input. Representation invariance is also a critically important task in and of itself: it can lead to improved robustness and remove noise (Du et al., 2020), afford fairness in downstream predictions (Jaiswal et al., 2020), and enhance interpretability (Xu et al., 2018).

Contrastive learning is a recent and highly successful self-supervised approach to representation learning that has achieved state-of-the-art performance in tasks that rely on semantic features, rather than exact reconstruction (van den Oord et al., 2018; Hjelm et al., 2018; Bachman et al., 2019; He et al., 2019). These methods learn to match two different transformations of the same object in representation space, distinguishing them from contrasts that are representations of other objects. The objective functions used for contrastive learning encourage representations to remain similar under transformation, whilst simultaneously requiring different inputs to be well spread out in representation space (Wang & Isola, 2020).
As such, the choice of transformations is key to their success (Chen et al., 2020a). Typical choices include random cropping and colour distortion. However, representations are compared using a similarity function that can be maximized even for representations that are far apart, meaning that the invariance learned is relatively weak. Unfortunately, directly changing the similarity measure hampers the algorithm (Wu et al., 2018; Chen et al., 2020a). We therefore investigate methods to improve contrastive representations by explicitly encouraging stronger invariance to the set of transformations, without changing the core self-supervised objective; we look to extract more information about how representations change under transformation, and use this to direct the encoder towards greater invariance.

To this end, we first develop a gradient regularization term that, when included in the training loss, forces the encoder to learn a representation function that varies slowly with continuous transformations. This can be seen as constraining the encoder to be approximately transformation invariant. We demonstrate empirically that while the parameters of the transformation can be recovered from standard contrastive learning representations using just linear regression, this is no longer the case when our regularization is used. Moreover, our representations perform better on downstream tasks and are robust to the introduction of nuisance transformations at test time. Test representations are conventionally produced using untransformed inputs (Hjelm et al., 2018; Kolesnikov et al., 2019), but this fails to combine information from different transformations and views of the object, or to emulate settings in which transformation noise cannot simply be removed at test time.
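The idea of penalizing how quickly the representation changes with a continuous transformation parameter can be illustrated with a minimal sketch. Everything here is an assumed toy setup, not the paper's implementation: a fixed linear map stands in for the encoder f_θ, an additive brightness-style shift stands in for a transformation t_α, and the gradient with respect to α is estimated by central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))  # hypothetical encoder weights


def encoder(x):
    # Toy stand-in for a trained encoder f_theta.
    return W @ x


def transform(x, alpha):
    # A continuous transformation t_alpha, differentiable in alpha.
    return x + alpha


def gradient_penalty(x, alpha, eps=1e-4):
    """Central finite-difference estimate of ||d f(t_alpha(x)) / d alpha||^2.

    Adding a term like this to the training loss penalizes encoders whose
    output changes rapidly as the transformation parameter alpha varies.
    """
    z_plus = encoder(transform(x, alpha + eps))
    z_minus = encoder(transform(x, alpha - eps))
    grad = (z_plus - z_minus) / (2.0 * eps)
    return float(np.sum(grad ** 2))


x = rng.normal(size=8)
penalty = gradient_penalty(x, alpha=0.1)
```

In practice one would compute this gradient with automatic differentiation rather than finite differences; the sketch only shows the quantity being penalized.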
Our second key proposal is to instead create test-time representations by feature averaging over multiple, differently transformed, inputs to address these concerns and to more directly impose invariance. We show theoretically that this leads to improved performance under linear evaluation protocols, further confirming this result empirically. We evaluate our approaches first on CIFAR-10 and CIFAR-100 (Krizhevsky et al., 2009), using transformations appropriate to natural images and evaluating on a downstream classification task. To validate that our ideas transfer to other settings, and to use our gradient regularizer within a fully differentiable generative process, we further introduce a new synthetic dataset called Spirograph. This provides a greater variety of downstream regression tasks, and allows us to explore the interplay between nuisance transformations and generative factors of interest. We confirm that using our regularizer during training and our feature averaging at test time both improve performance in terms of transformation invariance, downstream tasks, and robustness to train-test distributional shift. In summary, the contributions of this paper are as follows:

• We derive a novel contrastive learning objective that leads to more invariant representations.
• We propose test time feature averaging to enforce further invariance.
• We introduce the Spirograph dataset.
• We show empirically that our approaches lead to more invariant representations and achieve state-of-the-art performance for existing downstream task benchmarks.
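Test-time feature averaging can be sketched in a few lines. The setup below is again an assumed toy model (a fixed linear "encoder" and additive-shift transformations); the names `encoder` and `feature_average` are illustrative, not the paper's code.

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 8))  # hypothetical trained encoder weights


def encoder(x):
    # Toy stand-in for a trained encoder f_theta.
    return W @ x


def feature_average(x, num_samples=16, seed=0):
    """Encode several randomly transformed copies of x and average them.

    This replaces the conventional single encoding of the untransformed
    input with a Monte Carlo average over the transformation distribution.
    """
    t_rng = np.random.default_rng(seed)
    zs = [encoder(x + t_rng.normal()) for _ in range(num_samples)]
    return np.mean(zs, axis=0)


x = rng.normal(size=8)
z_bar = feature_average(x)   # averaged test-time representation
z_single = encoder(x)        # conventional untransformed encoding
```

Under a linear evaluation protocol, a downstream linear model is then fit on averaged representations like `z_bar` instead of `z_single`.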

2. PROBABILISTIC FORMULATION OF CONTRASTIVE LEARNING

The goal of unsupervised representation learning is to encode high-dimensional data, such as images, retaining information that may be pertinent to downstream tasks and discarding information that is not. To formalize this, we consider a data distribution p(x) on X and an encoder f θ : X → Z, a parametrized function mapping from data space to representation space. Contrastive learning is a self-supervised approach to representation learning that learns to make representations of differently transformed versions of the same input more similar than representations of other inputs. Of central importance is the set of transformations, also called augmentations (Chen et al., 2020a) or views (Tian et al., 2019), used to distort the data input x. In the common application of computer vision, it is typical to include resized cropping; brightness, contrast, saturation and hue distortion; greyscale conversion; and horizontal flipping. We will later introduce the Spirograph dataset, which uses quite different transformations. In general, transformations are assumed to change the input only cosmetically, so all semantic features such as the class label are preserved; the set of transformations indicates changes which can be safely ignored by the encoder. Formally, we consider a transformation set T ⊆ {t : X → X } and a probability distribution p(t) on this set. A representation z of x is obtained by applying a random transformation t to x and then encoding the result using f θ. Therefore, we do not have one representation of x, but rather an implicit distribution p(z|x). A sample from p(z|x) is obtained by sampling t ∼ p(t) and setting z = f θ (t(x)).
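The implicit distribution p(z|x) can be made concrete with a small sketch: sample a transformation parameter, apply the transformation, then encode. The encoder and transformation family below are toy stand-ins chosen for illustration, not the paper's models.

```python
import numpy as np

rng = np.random.default_rng(2)
W = rng.normal(size=(4, 8))  # stand-in for a parametrized encoder f_theta


def f_theta(x):
    return W @ x


def sample_representation(x, rng):
    """One draw z ~ p(z|x): sample t ~ p(t), then set z = f_theta(t(x))."""
    alpha = rng.uniform(-1.0, 1.0)   # random transformation parameter
    tx = (1.0 + 0.1 * alpha) * x     # e.g. a contrast-style scaling t(x)
    return f_theta(tx)


x = rng.normal(size=8)
z1 = sample_representation(x, rng)
z2 = sample_representation(x, rng)
# Distinct draws of t yield distinct samples from the implicit p(z|x).
```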

