REPRESENTATION LEARNING VIA INVARIANT CAUSAL MECHANISMS

Abstract

Self-supervised learning has emerged as a strategy to reduce the reliance on costly supervised signals by pretraining representations using only unlabeled data. These methods combine heuristic proxy classification tasks with data augmentations and have achieved significant success, but our theoretical understanding of this success remains limited. In this paper we analyze self-supervised representation learning using a causal framework. We show how data augmentations can be more effectively utilized through explicit invariance constraints on the proxy classifiers employed during pretraining. Based on this, we propose a novel self-supervised objective, Representation Learning via Invariant Causal Mechanisms (RELIC), that enforces invariant prediction of proxy targets across augmentations through an invariance regularizer, which yields improved generalization guarantees. Further, using causality we generalize contrastive learning, a particular kind of self-supervised method, and provide an alternative theoretical explanation for the success of these methods. Empirically, RELIC significantly outperforms competing methods in terms of robustness and out-of-distribution generalization on ImageNet, while also significantly outperforming these methods on Atari, achieving above human-level performance on 51 out of 57 games.

1. INTRODUCTION

Training deep networks often relies heavily on large amounts of useful supervisory signal, such as labels for supervised learning or rewards for reinforcement learning. These training signals can be costly or otherwise impractical to acquire. On the other hand, unsupervised data is often abundantly available. Therefore, pretraining representations for unknown downstream tasks without the need for labels or extrinsic reward holds great promise for reducing the cost of applying machine learning models. To pretrain representations, self-supervised learning makes use of proxy tasks defined on unsupervised data. Recently, self-supervised methods using contrastive objectives have emerged as one of the most successful strategies for unsupervised representation learning (Oord et al., 2018; Hjelm et al., 2018; Chen et al., 2020a). These methods learn a representation by classifying every datapoint against all other datapoints (negative examples). Under assumptions on how the negative examples are sampled, minimizing the resulting contrastive loss has been justified as maximizing a lower bound on the mutual information (MI) between representations (Poole et al., 2019). However, Tschannen et al. (2019) have shown that performance on downstream tasks may be more tightly correlated with the choice of encoder architecture than with the achieved MI bound, highlighting issues with the MI theory of contrastive learning. Further, contrastive approaches compare different views of the data (usually under different data augmentations) to calculate similarity scores. Computing scores in this way has been empirically observed to be a key success factor of contrastive methods, but has yet to be theoretically justified. This lack of a solid theoretical explanation for the effectiveness of contrastive methods hinders their further development.

To remedy these theoretical shortcomings, we analyze the problem of self-supervised representation learning through a causal lens.
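For concreteness, the contrastive objective sketched above — classifying each datapoint against all others using similarity scores between two views — can be written in its standard InfoNCE-style form. The sketch below is illustrative only; the function name, normalization, and temperature value are assumptions, not the exact implementation used by any particular method.

```python
import numpy as np

def info_nce_loss(z1, z2, temperature=0.1):
    """InfoNCE-style contrastive loss: row i of z1 is classified against all
    rows of z2, with row i of z2 as the positive and the rest as negatives."""
    # L2-normalise so the dot product is a cosine similarity.
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = (z1 @ z2.T) / temperature                   # (N, N) similarity scores
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy with the matching pairs (the diagonal) as targets.
    return -np.mean(np.diag(log_probs))

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 4))
aligned = info_nce_loss(z, z)        # views of each point agree perfectly
mismatched = info_nce_loss(z, z[::-1])  # positives point at the wrong rows
```

As expected, the loss is lower when the two views of each datapoint produce matching representations than when the positives are misaligned.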
We formalize intuitions about the data generating process using a causal graph and leverage causal tools to derive properties of the optimal representation. We show that a representation should be an invariant predictor of proxy targets under interventions on features that are only correlated with, but not causally related to, the downstream targets of interest. Since neither the causally related nor the purely correlated features are observed, performing actual interventions on them is not feasible; to learn a representation with this property, we instead use data augmentations to simulate a subset of the possible interventions. Based on our causal interpretation, we propose a regularizer which enforces that the prediction of the proxy targets is invariant across data augmentations, and we propose a novel objective for self-supervised representation learning called REpresentation Learning via Invariant Causal mechanisms (RELIC). We show that this explicit invariance regularization leverages augmentations more effectively than previous self-supervised methods and that representations learned using RELIC are guaranteed to generalize well to downstream tasks under weaker assumptions than those required by previous work (Saunshi et al., 2019).

Next, we generalize contrastive learning and provide an alternative theoretical explanation to MI for the success of these methods. We generalize the proxy task of instance discrimination, commonly used in contrastive learning, using the causal concept of refinements (Chalupka et al., 2014). Intuitively, a refinement of a task can be understood as a more fine-grained variant of the original problem. For example, a refinement of classifying cats against dogs would be the task of classifying individual cat and dog breeds. The instance discrimination task corresponds to the most fine-grained refinement, i.e. discriminating individual cats and dogs from one another.
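One way to instantiate the invariance regularizer described above is to compare the proxy-classification distributions produced under two different augmentations of the same inputs and penalize their disagreement. The sketch below uses a symmetrised KL divergence over softmax similarity scores; the names, the symmetrised form, and the temperature are illustrative assumptions, not the exact RELIC objective.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def invariance_penalty(z_a, z_b, anchors, temperature=0.1):
    """Penalty on disagreement between proxy-target predictions under two
    augmentations: p(proxy target | aug_a(x)) vs p(proxy target | aug_b(x)).
    `anchors` are the representations the proxy targets are scored against."""
    def predict(z):
        z = z / np.linalg.norm(z, axis=1, keepdims=True)
        a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
        return softmax(z @ a.T / temperature, axis=1)
    p, q = predict(z_a), predict(z_b)
    # Symmetrised KL: zero iff the predictions agree across augmentations.
    kl = lambda p, q: np.sum(p * (np.log(p) - np.log(q)), axis=1)
    return np.mean(kl(p, q) + kl(q, p))
```

In training, such a penalty would be added to a contrastive loss so that minimizing the total objective both discriminates instances and keeps the proxy predictions invariant across augmentations.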
We show that using refinements as proxy tasks enables us to learn representations that are useful for downstream tasks. Specifically, using causal tools, we show that learning a representation on refinements such that it is an invariant predictor of the proxy targets across augmentations is a sufficient condition for that representation to generalize to downstream tasks (cf. Theorem 1). In summary, we provide theoretical support both for the general form of the contrastive objective and for the use of data augmentations. Thus, we provide an alternative explanation to mutual information for the success of recent contrastive approaches, namely that of causal refinements of downstream tasks.

We test RELIC on a variety of prediction and reinforcement learning problems. First, we evaluate the quality of representations pretrained on ImageNet, with a special focus on robustness and out-of-distribution generalization. RELIC performs competitively with current state-of-the-art methods on ImageNet, while significantly outperforming competing methods on robustness and out-of-distribution generalization of the learned representations when tested on corrupted ImageNet (ImageNet-C (Hendrycks & Dietterich, 2019)) and a version of ImageNet that consists of different renditions of the same classes (ImageNet-R (Hendrycks et al., 2020)). In terms of robustness, RELIC also significantly outperforms the supervised baseline, with an absolute reduction of 4.9% in error. Unlike much prior work that focuses specifically on computer vision tasks, we also test RELIC for representation learning in the context of reinforcement learning on the Atari suite (Bellemare et al., 2013). There, RELIC significantly outperforms competing methods and achieves above human-level performance on 51 out of 57 games.

Contributions.

• We formalize the problem of self-supervised representation learning using causality and propose to leverage data augmentations more effectively through invariant prediction.
• We propose a new self-supervised objective, REpresentation Learning via Invariant Causal mechanisms (RELIC), that enforces invariant prediction through an explicit regularizer, and show improved generalization guarantees.
• We generalize contrastive learning using refinements and show that learning on refinements is a sufficient condition for learning useful representations; this provides an alternative explanation to MI for the success of contrastive methods.

2. REPRESENTATION LEARNING VIA INVARIANT CAUSAL MECHANISMS

Problem setting. Let X denote the unlabelled observed data and Y = {Y_t}_{t=1}^T be a set of unknown tasks, with Y_t denoting the targets for task t. The tasks {Y_t}_{t=1}^T can represent both a multi-environment and a multi-task setup. Our goal is to use unsupervised data to pretrain a representation f(X) that will be useful for solving the downstream tasks Y.

Causal interpretation. To effectively leverage common assumptions and intuitions about the data generation of the unknown downstream tasks for the learning algorithm, we propose to formalize

