VARIATIONAL CAUSAL DYNAMICS: DISCOVERING MODULAR WORLD MODELS FROM INTERVENTIONS

Abstract

Latent world models allow agents to reason about complex environments with high-dimensional observations. However, adapting to new environments and effectively leveraging previous knowledge remain significant challenges. We present variational causal dynamics (VCD), a structured world model that exploits the invariance of causal mechanisms across environments to achieve fast and modular adaptation. By causally factorising a transition model, VCD is able to identify reusable components across different environments. This is achieved by combining causal discovery and variational inference to jointly learn a latent representation and a transition model in an unsupervised manner. Specifically, we optimise the evidence lower bound jointly over a representation model and a transition model structured as a causal graphical model. In evaluations on simulated environments with state and image observations, we show that VCD successfully identifies causal variables and discovers consistent causal structures across different environments. Moreover, given a small number of observations in a previously unseen, intervened environment, VCD identifies the sparse changes in the dynamics and adapts efficiently. In doing so, VCD significantly extends the capabilities of the current state of the art in latent world models while also comparing favourably in terms of prediction accuracy.

1. INTRODUCTION

The ability to adapt flexibly and efficiently to novel environments is one of the most distinctive and compelling features of the human mind. It has been suggested that humans do so by learning internal models which not only contain abstract representations of the world, but also encode generalisable, structural relationships within the environment (Behrens et al., 2018). It is conjectured that this latter aspect is what allows humans to adapt efficiently and selectively. Recent efforts have been made to mimic this kind of representation in machine learning. World models (e.g. Ha and Schmidhuber, 2018) aim to capture the dynamics of an environment by distilling past experience into a parametric predictive model. Advances in latent variable models have enabled the learning of world models in a compact latent space (Ha and Schmidhuber, 2018; Watter et al., 2015; Hafner et al., 2019b; Buesing et al., 2018; Zhang et al., 2019) from high-dimensional observations such as images. Whilst these models have enabled agents to act in complex environments via planning (e.g. Hafner et al., 2019b; Sekar et al., 2020) or learning parametric policies (e.g. Hafner et al., 2019a; Ha and Schmidhuber, 2018), structurally adapting to changes in the environment remains a significant challenge. The consequence of this limitation is particularly pronounced when deploying learning agents to environments where distribution shifts occur. As such, we argue that it is beneficial to build structural world models that afford modular and efficient adaptation, and that causal modelling offers a tantalising prospect for discovering such structure from observations. Causality plays a central role in understanding distribution changes, which can be modelled as causal interventions (Schölkopf et al., 2021).
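The view of distribution shifts as interventions can be made concrete via the causal factorisation of a joint distribution. The following is a standard textbook identity, sketched here with generic notation that is not taken from this paper:

```latex
% Causal (disentangled) factorisation over variables x_1, ..., x_n,
% where pa_i denotes the causal parents of x_i:
p(x_1, \dots, x_n) = \prod_{i=1}^{n} p\!\left(x_i \mid \mathrm{pa}_i\right)

% An intervention on a single mechanism, say the j-th, replaces only
% that factor and leaves all other mechanisms unchanged:
\tilde{p}(x_1, \dots, x_n)
  = \tilde{p}\!\left(x_j \mid \mathrm{pa}_j\right)
    \prod_{i \neq j} p\!\left(x_i \mid \mathrm{pa}_i\right)
```

A shift that alters only a few such factors leaves the remaining conditionals intact, which is precisely what makes individual mechanisms reusable across environments.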
The Sparse Mechanism Shift (SMS) hypothesis (Schölkopf et al., 2021; Bengio et al., 2019) states that naturally occurring shifts in the data distribution can be attributed to sparse and local changes in the causal generative process. This implies that many causal mechanisms remain invariant across domains (Schölkopf et al., 2012; Peters et al., 2016; Zhang et al., 2015). In this light, learning a causal model of the environment enables agents to reason about distribution shifts and to exploit the invariance of learnt causal mechanisms across different environments. Hence, we posit that world models with a causal structure can facilitate modular

