CONTRASTIVE UNSUPERVISED LEARNING OF WORLD MODEL WITH INVARIANT CAUSAL FEATURES

Anonymous

Abstract

In this paper we present a world model that learns causal features using the invariance principle. In particular, we use contrastive unsupervised learning to learn the invariant causal features, enforcing invariance across augmentations of the irrelevant parts or styles of the observation. World-model-based reinforcement learning methods optimize representation learning and the policy independently, so a naïve contrastive loss implementation collapses due to the lack of a supervisory signal to the representation learning module. We propose an intervention-invariant auxiliary task to mitigate this issue. Specifically, we use data augmentation as a style intervention on the RGB observation space and depth prediction as an auxiliary task to explicitly enforce the invariance. Our proposed method significantly outperforms current state-of-the-art model-based and model-free reinforcement learning methods on out-of-distribution point-goal navigation tasks on the iGibson dataset. Moreover, our proposed model excels at sim-to-real transfer of the learned perception module. Finally, we evaluate our approach on the DeepMind control suite, where depth is not available and invariance can therefore only be enforced implicitly; nevertheless, our proposed model performs on par with its state-of-the-art counterparts.

1. INTRODUCTION

An important branching point among reinforcement learning (RL) methods is whether the agent learns with or without a predictive model of the environment. In model-based methods, an explicit predictive model of the world is learned, enabling the agent to plan by thinking ahead (Deisenroth & Rasmussen, 2011; Silver et al., 2018; Ha & Schmidhuber, 2018; Hafner et al., 2021). Model-free methods, in contrast, do not learn an explicit predictive model of the environment, as the control policy is learned end-to-end from pixels. As a consequence, model-free methods do not account for future downstream tasks, so we expect model-based methods to be more suitable for out-of-distribution (OoD) generalization and sim-to-real transfer.

A model-based approach has to learn the model of the environment purely from experience, which poses several challenges. The main problem is training bias in the model, which an agent can exploit, leading to poor performance at test time (Ha & Schmidhuber, 2018). Further, model-based RL methods typically learn the representation with an observation reconstruction loss, for example using variational autoencoders (VAE) (Kingma & Welling, 2014). The downside of such a state abstraction is that it is not suited to separating task-relevant states from irrelevant ones, with the result that current RL algorithms often overfit to environment-specific characteristics (Zhang et al., 2020). Hence, task-relevant state abstraction is essential for a robust RL model, which is the aim of this paper.

Causality is the study of cause-and-effect relationships. Learning causality in pixel-based control involves two tasks: first, abstracting the causal variables from images, and second, learning the causal structure. Causal inference uses graphical modelling (Lauritzen & Spiegelhalter, 1988), structural equation modelling (Bollen, 1989), or counterfactuals (Dawid, 2000); Pearl (2009) provides an excellent overview of these methods.
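As a concrete illustration of the reconstruction-based state abstraction discussed above, the conventional VAE objective can be sketched as follows. This is a minimal NumPy sketch under illustrative assumptions (placeholder linear encoder/decoder maps and toy dimensions, not the architecture used in the cited works):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "observation": batch of flattened 8x8 frames; 4-d latent (illustrative sizes).
obs_dim, lat_dim = 64, 4
W_enc = rng.normal(scale=0.1, size=(obs_dim, 2 * lat_dim))  # outputs mean and log-variance
W_dec = rng.normal(scale=0.1, size=(lat_dim, obs_dim))

def vae_loss(x):
    """ELBO-style objective: pixel reconstruction error plus KL to a unit
    Gaussian. Every pixel is weighted equally, so task-irrelevant texture
    drives the loss just as much as task-relevant geometry."""
    h = x @ W_enc
    mu, log_var = h[:, :lat_dim], h[:, lat_dim:]
    z = mu + np.exp(0.5 * log_var) * rng.normal(size=mu.shape)  # reparameterization trick
    recon = z @ W_dec
    recon_err = np.mean(np.sum((x - recon) ** 2, axis=1))
    kl = np.mean(0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=1))
    return recon_err + kl

batch = rng.normal(size=(8, obs_dim))
loss = vae_loss(batch)
```

Because the reconstruction term sums over all pixels, a latent code that captures background texture can score as well as one that captures task-relevant structure, which is exactly the limitation the paper addresses.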
However, in complex visual control tasks the number of state variables involved is high, so inference of the underlying causal structure of the model becomes intractable (Peters et al., 2016). Causal discovery using the invariance principle tries to overcome this issue and is therefore gaining attention in the literature (Peters et al., 2016). In these approaches, spurious or irrelevant features are learnt using environment-specific encoders; however, they need multiple source environments with specific interventions or variations. In contrast, we propose using data augmentation as a source of intervention, so that samples can come from as few as a single environment, and we use contrastive learning for invariant feature abstraction. Related to our work, Mitrovic et al. (2021) proposed a regularizer for self-supervised contrastive learning; we instead propose an intervention-invariant auxiliary task for robust feature learning.

Model-based RL methods do not learn the features and the control policy jointly, in order to prevent greedy feature learning; the aim is that the features of model-based RL will be more useful for various downstream tasks. Hence, state abstraction uses reward prediction, a reconstruction loss, or both (Ha & Schmidhuber, 2018; Zhang et al., 2020; Hafner et al., 2021). Contrastive learning, on the other hand, does not reconstruct the inputs and instead applies the loss in the embedding space. We therefore propose a causally invariant auxiliary task for learning invariant causal features. Specifically, we utilize depth prediction to extract the geometric features needed for navigation, which do not depend on texture. Finally, we emphasize that depth is not required at deployment, enabling wider applicability of the proposed model. Importantly, our setup allows us to apply popular contrastive learning to model-based RL methods and improves both sample efficiency and OoD generalization.
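The representation objective described above (contrastive invariance across style interventions, plus a depth-prediction auxiliary loss) can be sketched as follows. This is a minimal NumPy illustration under assumed shapes, with placeholder linear maps standing in for the encoder and depth head; the names `info_nce` and `wmc_representation_loss` are hypothetical, not the paper's code:

```python
import numpy as np

rng = np.random.default_rng(0)

def info_nce(anchors, positives, temperature=0.1):
    """InfoNCE contrastive loss: the augmented view of the same observation
    is the positive; the other samples in the batch serve as negatives."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = (a @ p.T) / temperature                    # (N, N) cosine similarities
    logits -= logits.max(axis=1, keepdims=True)         # numerical stability
    log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_softmax))               # matching pairs lie on the diagonal

def wmc_representation_loss(obs, styled_obs, depth, encoder, depth_head, alpha=1.0):
    """Invariant-feature objective: contrast the original and style-intervened
    (augmented) views, and regress depth from the embedding so that geometry,
    not texture, must be retained."""
    z = encoder(obs)             # embedding of the original observation
    z_aug = encoder(styled_obs)  # embedding of the augmented observation
    contrastive = info_nce(z, z_aug)
    depth_err = np.mean((depth_head(z) - depth) ** 2)   # auxiliary depth prediction
    return contrastive + alpha * depth_err

# Toy usage with placeholder linear maps (illustrative only).
W = rng.normal(scale=0.1, size=(64, 16))   # stand-in "encoder"
V = rng.normal(scale=0.1, size=(16, 64))   # stand-in "depth head"
obs = rng.normal(size=(8, 64))
styled = obs + 0.01 * rng.normal(size=obs.shape)   # mild style intervention
depth = rng.normal(size=(8, 64))
loss = wmc_representation_loss(obs, styled, depth, lambda x: x @ W, lambda z: z @ V)
```

The depth term is what supplies the supervisory signal the abstract says a naïve contrastive loss lacks: depth is invariant under the style intervention, so the embedding cannot collapse to a trivial solution without incurring a large depth-prediction error.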
In summary, we propose a World Model with invariant Causal features (WMC), which can extract and predict the causal features (Figure 1). Our WMC is verified on the point-goal navigation task from Gibson (Xia et al., 2018) and iGibson 1.0 (Shen et al., 2021), as well as on the DeepMind control suite (DMControl) (Tunyasuvunakool et al., 2020). Our main contributions are:

1. to propose a world model with invariant causal features, which outperforms state-of-the-art models on out-of-distribution generalization and sim-to-real transfer of learned features;

2. to propose intervention-invariant auxiliary tasks that improve performance;

3. to show that the world model benefits from contrastive unsupervised representation learning.



Figure 1: Flow diagram of the proposed World Model with invariant Causal features (WMC). It consists of three components: i) unsupervised causal representation learning, ii) memory, and iii) controller.

