VARIATIONAL CAUSAL DYNAMICS: DISCOVERING MODULAR WORLD MODELS FROM INTERVENTIONS

Abstract

Latent world models allow agents to reason about complex environments with high-dimensional observations. However, adapting to new environments and effectively leveraging previous knowledge remain significant challenges. We present variational causal dynamics (VCD), a structured world model that exploits the invariance of causal mechanisms across environments to achieve fast and modular adaptation. By causally factorising a transition model, VCD is able to identify reusable components across different environments. This is achieved by combining causal discovery and variational inference to learn a latent representation and transition model jointly in an unsupervised manner. Specifically, we optimise the evidence lower bound jointly over a representation model and a transition model structured as a causal graphical model. In evaluations on simulated environments with state and image observations, we show that VCD is able to successfully identify causal variables, and to discover consistent causal structures across different environments. Moreover, given a small number of observations in a previously unseen, intervened environment, VCD is able to identify the sparse changes in the dynamics and to adapt efficiently. In doing so, VCD significantly extends the capabilities of the current state-of-the-art in latent world models while also comparing favourably in terms of prediction accuracy.

1. INTRODUCTION

The ability to adapt flexibly and efficiently to novel environments is one of the most distinctive and compelling features of the human mind. It has been suggested that humans do so by learning internal models which not only contain abstract representations of the world, but also encode generalisable, structural relationships within the environment (Behrens et al., 2018). It is conjectured that this latter aspect is what allows humans to adapt efficiently and selectively. Recent efforts have been made to mimic this kind of representation in machine learning. World models (e.g. Ha and Schmidhuber, 2018) aim to capture the dynamics of an environment by distilling past experience into a parametric predictive model. Advances in latent variable models have enabled the learning of world models in a compact latent space (Ha and Schmidhuber, 2018; Watter et al., 2015; Hafner et al., 2019b; Buesing et al., 2018; Zhang et al., 2019) from high-dimensional observations such as images. Whilst these models have enabled agents to act in complex environments via planning (e.g. Hafner et al., 2019b; Sekar et al., 2020) or learning parametric policies (e.g. Hafner et al., 2019a; Ha and Schmidhuber, 2018), structurally adapting to changes in the environment remains a significant challenge. This limitation is particularly pronounced when deploying learning agents to environments where distribution shifts occur. As such, we argue that it is beneficial to build structural world models that afford modular and efficient adaptation, and that causal modelling offers a tantalising prospect for discovering such structure from observations. Causality plays a central role in understanding distribution changes, which can be modelled as causal interventions (Schölkopf et al., 2021).
The Sparse Mechanism Shift (SMS) hypothesis (Schölkopf et al., 2021; Bengio et al., 2019) states that naturally occurring shifts in the data distribution can be attributed to sparse and local changes in the causal generative process. This implies that many causal mechanisms remain invariant across domains (Schölkopf et al., 2012; Peters et al., 2016; Zhang et al., 2015). In this light, learning a causal model of the environment enables agents to reason about distribution shifts and to exploit the invariance of learnt causal mechanisms across different environments. Hence, we posit that world models with a causal structure can facilitate modular transfer of knowledge. To date, however, methods for causal discovery (Spirtes et al., 2000; Pearl, 2009; Peters et al., 2017; Brouillard et al., 2020; Ke et al., 2020) require access to abstract causal variables to learn causal models from data. These are not typically available in the context of world model learning, where we wish to operate directly on high-dimensional observations. In order to benefit from the structure of causal models and the ability to represent high-dimensional observations, we propose Variational Causal Dynamics (VCD), which combines causal discovery with variational inference. Specifically, we train a latent state-space model with a structural transition model using variational inference and sparsity regularisation from causal discovery. By jointly training a representation and a transition model, VCD learns a causally factorised world model that can modularly adapt to different environments. The key intuition behind our approach is that, since sparse causal structures can only be discovered on abstract causal variables, training the representation and the causal discovery module in an end-to-end manner acts as an inductive bias that encourages causally meaningful representations.
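To make the causal factorisation concrete, the following is a minimal illustrative sketch (not the authors' implementation): each latent dimension is predicted from its masked causal parents only. The latent dimensionality, the fixed adjacency matrix `G`, and the linear-plus-tanh mechanisms are all hypothetical simplifications; in VCD the graph and mechanisms are learnt.

```python
import numpy as np

D = 4  # number of latent causal variables (hypothetical size)

# Binary adjacency matrix: G[i, j] = 1 means z_j is a causal parent of z_i.
# In VCD this graph is learnt via sparsity regularisation; fixed here for illustration.
G = np.array([
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 1],
], dtype=float)

# One simple mechanism per latent variable (uniform weights, purely illustrative).
W = 0.5 * np.ones((D, D))

def causal_transition(z, G, W):
    """Predict z_{t+1}: each dimension sees only its masked causal parents,
    implementing p(z_{t+1} | z_t) = prod_i p(z^i_{t+1} | pa(z^i_t))."""
    return np.tanh((G * W) @ z)

def edge_sparsity(G, lam=0.01):
    """L1-style penalty on graph edges, as used in score-based causal discovery."""
    return lam * np.abs(G).sum()

z_t = np.array([0.1, -0.2, 0.3, 0.4])
z_next = causal_transition(z_t, G, W)
```

Because dimension 0 has no incoming edge from z_1, perturbing z_1 leaves its prediction unchanged; this locality is what later permits mechanism-wise adaptation.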
By leveraging the learnt causal structure, VCD is able to identify the sparse mechanism changes in the environment and re-learn only the intervened mechanisms. This enables fast and modular adaptation to changes in dynamics.
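A minimal sketch of this adaptation scheme, under our illustrative reading rather than the paper's code: score each mechanism by its prediction error in the new environment, flag outliers as intervened, and update only those mechanisms' parameters. The error threshold and learning rate are hypothetical choices.

```python
import numpy as np

def intervened_mechanisms(per_dim_error, threshold=1.0):
    """Under the SMS hypothesis only a few mechanisms shift: flag the latent
    dimensions whose per-mechanism prediction error exceeds the threshold."""
    return np.flatnonzero(per_dim_error > threshold)

def adapt(W, grad, shifted, lr=0.1):
    """Re-learn only the flagged mechanisms; all others stay frozen,
    so knowledge from previous environments is reused modularly."""
    W_new = W.copy()
    W_new[shifted] -= lr * grad[shifted]
    return W_new

# Example: only mechanism 1 shows a large error in the intervened environment.
errors = np.array([0.05, 3.2, 0.08, 0.11])
shifted = intervened_mechanisms(errors)
```

Freezing the invariant mechanisms is what makes adaptation both fast (few parameters to fit) and modular (unrelated mechanisms cannot be corrupted by the new data).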

2. RELATED WORK

Figure 1: Left: The general architecture of VCD. The dynamics model is trained using observation sequences from multiple environments by minimising the KL divergence between the predicted state distribution and the encoded posterior state distribution. Rollouts in the latent space can be performed by recursively applying the learnt transition model. Right: The structure of the causal transition model. Each dimension of the latent space is treated as a causal variable. Predictions are made using only the causal parents of each variable, according to a learnt causal graph.

Predictive models of the environment can be used to derive exploration- (Sekar et al., 2020) or reward-driven (Ha and Schmidhuber, 2018; Hafner et al., 2019a;b) behaviours. In this paper, we focus on the learning of latent dynamics models. World models (Ha and Schmidhuber, 2018) train a representation encoder and an RNN-based transition model in a two-stage process. Other approaches (Hafner et al., 2019b; Zhang et al., 2019; Watter et al., 2015) learn a generative model by jointly training the representation and the transition model via variational inference. PlaNet (Hafner et al., 2019b) parameterises the transition model with RNNs. E2C (Watter et al., 2015; Banijamali et al., 2018) and SOLAR (Zhang et al., 2019) use locally linear transition models, arguing that including constraints in the dynamics model yields structured latent spaces that are suitable for control. Other approaches (Goyal et al., 2021b; Becker-Ehmck et al., 2019) consider latent transitions as discrete mechanisms. In a similar vein, latent prediction models have also been explored in the context of video prediction (Villegas et al., 2019; Denton and Fergus, 2018; Assouel et al., 2022). Our proposed approach shares the general principle that latent representations can be shaped by structured transition mechanisms (Ahuja et al., 2021). However, to the best of our knowledge, VCD is the first approach that implements a causal transition model with high-dimensional inputs.

Causal discovery methods enable the learning of causal structure from data. Approaches can be categorised as constraint-based (e.g. Spirtes et al., 2000) and score-based (e.g. Hauser and Bühlmann, 2012). The reader is referred to Peters et al. (2017) for a detailed review of causal discovery methods. Motivated by the fact that these methods require access to abstract causal variables, recent efforts have been made to reconcile machine learning, which can operate on low-level data, with causality (Schölkopf et al., 2021). Recent advances in this area include theoretical works exploring the conditions under which disentanglement of representations is possible (Yao et al., 2022;
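The training signal described in the caption of Figure 1, matching the predicted state distribution to the encoded posterior, reduces to a closed-form KL divergence when both distributions are diagonal Gaussians, a common parameterisation in latent state-space models. The sketch below is illustrative and assumes that convention; it is not tied to the paper's exact parameterisation.

```python
import numpy as np

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) for diagonal Gaussians: q is the encoded posterior state
    distribution, p the distribution predicted by the transition model."""
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    return 0.5 * np.sum(
        logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
    )

# Identical distributions incur zero loss; any mismatch between prediction
# and posterior is penalised, driving the two models to agree.
mu = np.zeros(4)
lv = np.zeros(4)
```

Minimising this term jointly over the encoder and the causally factorised transition model is what couples representation learning with causal discovery.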

