UNSUPERVISED LEARNING OF CAUSAL RELATIONSHIPS FROM UNSTRUCTURED DATA

Anonymous authors
Paper under double-blind review

Abstract

Endowing deep neural networks with the ability to reason about cause and effect would be an important step to make them more robust and interpretable. In this work we propose a variational framework that allows deep networks to learn latent variables and their causal relationships from unstructured data, with no supervision, or labeled interventions. Starting from an abstract Structural Equation Model (SEM), we show that maximizing its posterior probability yields a similar construction to a Variational Auto-Encoder (VAE), but with a structured prior coupled by non-linear equations. This prior represents an interpretable SEM with learnable parameters (such as a physical model or dependence structure), which can be fitted to data while simultaneously learning the latent variables. Unfortunately, computing KL-divergences with this non-linear prior is intractable. We show how linearizing arbitrary SEMs via back-propagation produces local non-isotropic Gaussian priors, for which the KL-divergences can be computed efficiently and differentiably. We propose two versions, one for IID data (such as images) which detects related causal variables within a sample, and one for non-IID data (such as video) which detects variables that are also related over time. Our proposal is complementary to causal discovery techniques, which assume given variables, and instead discovers both variables and their causal relationships. We experiment with recovering causal models from images, and learning temporal relations based on the Super Mario Bros videogame.

1. INTRODUCTION

Human reasoning and decision-making are often underpinned by cause and effect: we take actions to achieve a desired effect, or reason that events would have happened differently had we acted a certain way, or if conditions had been different. Scientific inquiry uses the same tools, albeit more formalized, to build knowledge about the world and how our society can affect it (Popper, 1962). When building algorithms that automatically build statistical models of the world, as is common in machine learning practice, it would then be desirable to imbue them with similar inductive priors about cause and effect (Glymour et al., 2016). In addition to being more robust than statistical models which only characterize the observational distribution (Peters et al., 2017), such models would allow reasoning about changing conditions outside the observed distribution (e.g. counterfactual reasoning). They would also communicate their inner workings more effectively, allowing us to ask "why" a given conclusion was reached, much in the same way that we do in scientific communication.

Despite still being actively researched, there is now a mature body of work on understanding whether two or more variables are related as cause and effect (Peters et al., 2017). Many techniques assume that the variables are given, and concern themselves with finding the relationships between them (Spirtes & Glymour, 1991; Chickering, 2003; Lorch et al., 2021). On the other hand, an advantage of modern deep neural networks is that they learn intermediate representations that do not have to be manually labeled (Yosinski et al., 2015), and effective models can be trained without supervision (Kingma & Welling, 2014). An important question then arises: can a deep network simultaneously discover latent variables in the data and establish cause-effect relationships between them?

We focus on learning Additive Noise Models (ANM) with Gaussian noise, which are identifiable (i.e. causal directions are distinguishable) as long as the functions relating the variables of interest are not linear (Hoyer et al., 2008). This model fits a variational learning framework well, and so we are able to derive an analogue of a Variational Auto-Encoder (VAE) (Kingma & Welling, 2014) where the prior, rather than being an uninformative Gaussian, corresponds exactly to the ANM. When the ANM is linear with Gaussian noise, the joint probability of the variables also becomes Gaussian, and it is easy to perform variational inference. The dependencies between variables will then be expressed in the covariance matrix's sparsity structure. However, as mentioned earlier, to make the causal directions identifiable the model cannot be linear (Hoyer et al., 2008). We resolve this difficulty by learning models that are locally linear, but globally non-linear. This approach affords the full generality of a non-linear ANM, with the simplicity of variational inference on Gaussian models.

In summary, our contributions are:
• A rigorous derivation of the variational Evidence Lower Bound (ELBO) of an Additive Noise Model (ANM), allowing efficient inference of Structural Equation Models (SEM) with deep networks.
• A linearization method leveraging automatic differentiation to construct a local Gaussian approximation of arbitrary non-linear ANMs.
• A temporally-aware specialization of the causal ANM that encodes causal directions implicit in the arrow-of-time and is suitable for high-dimensional time series data such as video.
• Experiments demonstrating that the proposed method is able to fit latent variables with a dependence structure in high-dimensional data, namely a synthetic image dataset and video game based data.

2. RELATED WORK

Our work lies at the intersection of causality, variational inference, representation learning, and high-dimensional unstructured input domains.

Causal inference deals with determining causes and effects from data. Causal discovery methods generally focus on recovering the causal graph responsible for generating the observed data, e.g. Spirtes & Glymour (1991); Chickering (2003) (for an overview of methods see Peters et al. (2017)). However, these methods are largely applied to structured datasets such as medical (Brooks-Gunn et al., 1992; Sachs et al., 2005; Louizos et al., 2017) or economics data (LaLonde, 1986), where the observed variables are provided by domain specialists. In contrast, we focus on unstructured data where the variables are not provided a priori.

Variational inference is a way of performing inference by solving an optimisation problem. A popular instance is the Variational Auto-Encoder (VAE) (Kingma & Welling, 2014), which aims to extract a useful latent representation of the data by encoding it and decoding it back. Traditionally the VAE prior is assumed to be an isotropic Gaussian distribution and the aim is to extract independent latent variables, as in the β-VAE (Higgins et al., 2016b) and FactorVAE (Kim & Mnih, 2018). Other works use hierarchical priors, such as iteratively conditioning each variable on its preceding variable in the Ladder-VAE (Sønderby et al., 2016), or conditioning each variable on all its predecessors in NVAE (Vahdat & Kautz, 2020) and VDVAE (Child, 2021). We also use a prior conditioning each variable on its predecessors, but this comes as a natural consequence of basing our prior on a structural equation model (SEM).

Recently there has been a growing interest in representation learning based on causal principles. For instance, the CausalVAE (Yang et al., 2021)

3. BACKGROUND

In this section, we give a self-contained overview of several results from variational and causal inference that we build upon. While they are not new, bringing them together under one formulation offers new insights and challenges, which we solve in sec. 4. Our goal is to fit a distribution $p(x)$, defined over an input space $x \in \mathbb{R}^m$, to an empirical distribution $\hat{p}(x)$ composed of finite samples, by choosing the optimal $p \in \mathcal{P}$ out of a set of candidate distributions $\mathcal{P}$ (e.g. parameterized by a neural network). We do this by minimizing their KL-divergence $D_x$ computed over $x$, or equivalently maximizing the expected log-likelihood of $p$ over the dataset $\hat{p}$:

$$p^* = \arg\min_{p \in \mathcal{P}} D_x(\hat{p}(x) \,\|\, p(x)) = \arg\max_{p \in \mathcal{P}} \mathbb{E}_{x \sim \hat{p}(x)} [\ln p(x)]. \tag{1}$$

We now introduce a set of latent variables $z \in \mathbb{R}^d$, which we can marginalize to compute

$$p(x) = \int p(z)\, p(x|z)\, dz, \tag{2}$$

in terms of a conditional distribution $p(x|z)$ and a "prior" distribution over latents $p(z)$. In a standard VAE, this prior is an isotropic Gaussian distribution (Kingma & Welling, 2014), while other structured priors are possible (Sønderby et al., 2016; Vahdat & Kautz, 2020; Tomczak & Welling, 2018). In this work, however, we will define it as a Structural Equation Model (SEM) (Pearl, 2009) (section 3.2).

3.1. VARIATIONAL INFERENCE

We can now apply standard tools of variational inference (Bishop, 2006, Ch. 10) to eq. 1 and replace the intractable marginalization (eq. 2) with an optimization of $q \in \mathcal{Q}$ over a variational family of distributions $\mathcal{Q}$ (in essence, training an additional neural network $q$). Eq. 1, when marginalized (eq. 2), is equivalent to (Kingma & Welling, 2014):

$$p^* = \arg\max_{p \in \mathcal{P},\, q \in \mathcal{Q}} \mathbb{E}_{x \sim \hat{p}(x)} \Big[ \mathbb{E}_{z \sim q(z|x)} [\ln p(x|z)] - D_z(q(z|x) \,\|\, p(z)) \Big], \tag{3}$$

where $\hat{p}(x)$ denotes the empirical data distribution. The first term in eq. 3 amounts to a reconstruction error, and the second term matches the latent variables to the prior. In practice, $q(z|x)$ and $p(x|z)$ in eq. 3 are often defined as Gaussian distributions parameterized by neural networks. These are the functions $\mu_q(x) \in \mathbb{R}^d$ and $\Sigma_q(x) \in \mathbb{R}^{d \times d}$ for the encoder $q$, and $\mu_p(z) \in \mathbb{R}^m$ and $\sigma_p(z) \in \mathbb{R}^m$ for the decoder $p$ (whose outputs live in the input space $\mathbb{R}^m$), which parameterize the means and covariances of the distributions:

$$q(z|x) = \mathcal{N}(z \,|\, \mu_q(x), \Sigma_q(x)), \qquad p(x|z) = \mathcal{N}(x \,|\, \mu_p(z), \mathrm{diag}(\sigma_p(z))). \tag{4}$$

The KL-divergence between the prior $p(z)$ and the variational posterior $q(z|x)$ can be computed in closed form when both are Gaussians. Note that the encoder usually outputs a diagonal covariance matrix (motivated by the fact that an isotropic prior is also diagonal), but here we allow it to output full covariances $\Sigma_q(x)$, which will be important later (sec. 4).
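The closed-form KL-divergence mentioned above can be written out for two Gaussians with full covariance matrices. The following is a minimal NumPy sketch of that standard formula, not the authors' implementation; the function name `gaussian_kl` and its argument names are illustrative assumptions.

```python
import numpy as np

def gaussian_kl(mu_q, cov_q, mu_p, cov_p):
    """KL( N(mu_q, cov_q) || N(mu_p, cov_p) ) for full covariance matrices."""
    d = mu_q.shape[0]
    diff = mu_p - mu_q
    # Solve linear systems against cov_p instead of forming its inverse.
    trace_term = np.trace(np.linalg.solve(cov_p, cov_q))
    maha_term = diff @ np.linalg.solve(cov_p, diff)
    # slogdet returns (sign, log|det|); covariances are positive definite.
    _, logdet_q = np.linalg.slogdet(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    return 0.5 * (trace_term + maha_term - d + logdet_p - logdet_q)

# Example: posterior with correlated latents against an isotropic prior.
mu_q = np.array([0.3, -0.2])
cov_q = np.array([[1.0, 0.4], [0.4, 0.5]])
kl_value = gaussian_kl(mu_q, cov_q, np.zeros(2), np.eye(2))
```

Because both arguments may carry off-diagonal terms, the same routine applies when the prior is a non-isotropic Gaussian rather than the identity-covariance prior of a standard VAE.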

3.2. STRUCTURAL EQUATION MODEL

We now consider a set of variables $y$, which have a dependency structure defined by a directed acyclic graph (DAG) $G$ (i.e. there is an edge $(i \to j) \in G$ if $y_i$ is required to compute $y_j$). Additionally, the generation process for each variable $y_i$ can be described by a sequence of non-linear equations:

$$y_i = f_i(y_{\mathrm{pa}_G(i)}) + n_i, \tag{5}$$

where $f_i : \mathbb{R}^{|\mathrm{pa}_G(i)|} \to \mathbb{R}$ is an arbitrary deterministic function, $\mathrm{pa}_G(i)$ denotes the indices of parent nodes of $i$ in $G$, and $n_i \sim \mathcal{N}(n_i \,|\, 0, \sigma_i^2)$ is an independent zero-mean noise variable, assumed to be Gaussian. This represents a Structural Equation Model (SEM), more specifically an Additive Noise Model (ANM) (Peters et al., 2017, Ch. 4.1.4) with Gaussian noise, which has the joint probability:

$$p(y) = \prod_{i=1}^{d} p_i(y_i \,|\, y_{\mathrm{pa}(i)}) = \prod_{i=1}^{d} \mathcal{N}(y_i - f_i(y_{\mathrm{pa}(i)}) \,|\, 0, \sigma_i^2). \tag{6}$$

Practical methods for causal learning with ANMs typically assume the variables $y$ are observed, and are concerned with recovering the true causal graph $G$ (and sometimes the functions $f$) that generated the data (Hoyer et al., 2008; Peters et al., 2017).
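To make the generative process and the joint density of the ANM concrete, here is a small NumPy sketch that draws a sample by visiting nodes in topological order and evaluates the factorized Gaussian log-density. The names `sample_anm` and `log_prob_anm`, the list-of-parents encoding, and the two-variable parabolic example are assumptions made for illustration only.

```python
import numpy as np

def sample_anm(funcs, parents, sigmas, rng):
    """Draw one sample y; node indices are assumed to be in topological order."""
    y = np.zeros(len(funcs))
    for i, (f, pa, s) in enumerate(zip(funcs, parents, sigmas)):
        pa_idx = np.asarray(pa, dtype=int)
        y[i] = f(y[pa_idx]) + rng.normal(0.0, s)  # y_i = f_i(parents) + n_i
    return y

def log_prob_anm(y, funcs, parents, sigmas):
    """Sum of Gaussian log-densities of the residuals y_i - f_i(parents)."""
    total = 0.0
    for i, (f, pa, s) in enumerate(zip(funcs, parents, sigmas)):
        resid = y[i] - f(y[np.asarray(pa, dtype=int)])
        total += -0.5 * np.log(2.0 * np.pi * s**2) - 0.5 * (resid / s) ** 2
    return total

# Hypothetical two-node example: y_1 is pure noise, y_2 = y_1^2 + noise.
funcs = [lambda pa: 0.0, lambda pa: pa[0] ** 2]
parents = [[], [0]]
rng = np.random.default_rng(0)
sample = sample_anm(funcs, parents, [1.0, 0.1], rng)
```

The log-density depends on the data only through the residuals, which is why fitting the functions $f_i$ and the noise scales is enough to specify the whole joint distribution.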

3.2.1. SCORE-BASED MODEL SELECTION

The formulation so far has assumed that the causal graph $G$, which encodes all the dependencies between variables, is given. We can, however, make use of a result by Nowzohour & Bühlmann (2016) that shows that a penalized likelihood score can be used to select between different models $p^*_G(y)$ (each fit to an empirical distribution $\hat{p}(y)$), with each assuming a different graph $G \in \mathcal{G}$ (e.g. out of all graphs with up to $d$ nodes). We thus select the graph that has the maximum score:

$$G^* = \arg\max_{G \in \mathcal{G}} \frac{1}{s} \sum_{y \sim \hat{p}(y)} \ln p^*_G(y) - \frac{|G|}{\ln s}, \tag{7}$$

with $s$ the number of samples, and $|G|$ the number of edges of $G$. This method finds the true causal graph when the causal dependencies in the ANM are all non-linear (Nowzohour & Bühlmann, 2016).
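As a rough illustration of how such a score trades fit against graph size, the sketch below computes the average log-likelihood minus an edge penalty of $|G| / \ln s$ and picks the best candidate. The function names, the dictionary layout, and the log-likelihood numbers are hypothetical and chosen only to show the mechanics.

```python
import math

def penalized_score(log_likelihoods, num_edges):
    """Mean per-sample log-likelihood minus the edge penalty |G| / ln(s)."""
    s = len(log_likelihoods)
    return sum(log_likelihoods) / s - num_edges / math.log(s)

def select_graph(candidates):
    """candidates maps a graph label to (per-sample log-likelihoods, edge count)."""
    return max(candidates, key=lambda g: penalized_score(*candidates[g]))

# Made-up fits over s = 100 samples: extra edges must pay for themselves.
candidates = {
    "empty": ([-2.0] * 100, 0),
    "one_edge": ([-1.5] * 100, 1),
    "full": ([-1.45] * 100, 3),
}
best = select_graph(candidates)
```

In this toy setting the single-edge graph wins: it improves the likelihood substantially over the empty graph, while the two additional edges of the full graph gain too little to offset their penalty.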

4. METHOD

4.1. VARIATIONAL INFERENCE WITH A CAUSAL PRIOR

The previous exposition suggests a natural way to simultaneously learn latent variables and fit a causal model: use the ANM (eq. 6) as the prior in the variational optimization problem (eq. 3), by setting $y \equiv z$. This amounts to assuming that the latent variables $z$ have a dependency structure defined by a DAG $G$, and are generated sequentially by application of the non-linear functions $f_i$ corrupted by Gaussian noise $n_i$. Model selection (i.e. finding the optimal graph $G$) can then be done by the score-based search of sec. 3.2.1. Note that this model reduces to a VAE for a DAG with no edges and thus null functions $f_i$:

$$\forall i : \mathrm{pa}(i) = \emptyset \Rightarrow z_i = n_i \quad \text{(reduction to standard VAE)}. \tag{8}$$

At a high level, the same tools used to train a VAE should be applicable in this new setting. However, the ANM is only identifiable (i.e. the true causal directions can be recovered) if the $f_i$ are non-linear (Hoyer et al., 2008), and this makes computing the KL-divergence in eq. 3 intractable, since it is no longer defined in closed form. We will resolve this difficulty by local linearization (sec. 4.4), although the model will still be globally non-linear to ensure identifiability.

One major difference from the exposition by Nowzohour and Bühlmann is that they propose using a non-parametric fitting method over known variables $x$, while we want to simultaneously recover the causal structure and latent variables $z$ (e.g. unknown parameters of objects depicted in images), which must be estimated from inputs $x$ (e.g. raw pixels). Another difference is that their procedure is computationally expensive, as it entails enumerating all graphs $G \in \mathcal{G}$ explicitly. However, given the expressiveness of deep neural network models, in sec. 4.2 we will show how we can replace this search with a simpler model fitting procedure.

4.2. MAXIMAL DAG

For the purposes of fitting the model $G$ and $f$ to observed data $x$, we will consider a simplification where we take all ancestors of a variable $z_i$ to be its parents $\mathrm{pa}(i)$, i.e. we replace $G$ with the full DAG with edges $\{i \to j : i < j,\ i, j = 1, \dots, d\}$. The following proposition shows that this will be able to model any underlying SEM, although some of the independence relations will have to be modeled by the learnable edge probabilities described in sec. 4.5.

Proposition 1. The set of all SEMs $S = \{(z_i = f_i(z_{\mathrm{pa}_G(i)}))_{i=1}^{d} \mid \forall G, f\}$ with arbitrary DAGs $G$ and probabilistic functions $f$ is contained within the set $S_\Omega = \{(z_i = f_i(z_{1,\dots,i-1}))_{i=1}^{d} \mid \forall f\}$, up to reorderings of the variables $z$.

Proof. See Appendix A.

Prop. 1 says that, as long as the function class of $f$ is expressive enough (as is usually the case with deep networks (Yosinski et al., 2015)), we can find an equivalent SEM using a fixed maximal DAG $G_\Omega$ with edges $\{i \to j : i < j,\ i, j = 1, \dots, d\}$. This simplifies the exposition in the following sections and provides justification for the approximation in sec. 4.5.
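The reordering behind Proposition 1 can be illustrated with a topological sort: once the nodes of any DAG are relabelled in a topological order, every parent index is smaller than its child's index, so the graph is a subgraph of the maximal DAG. This is a sketch using Python's standard `graphlib` module; the helper name `reorder_to_maximal_dag` and the three-node example are assumptions for illustration.

```python
from graphlib import TopologicalSorter

def reorder_to_maximal_dag(parents):
    """parents maps each node to the set of its parent nodes in a DAG.

    Returns a topological ordering and the parent lists relabelled by rank,
    so that every parent rank is strictly smaller than the child's rank.
    """
    # TopologicalSorter expects node -> predecessors, which is exactly `parents`.
    order = list(TopologicalSorter(parents).static_order())
    rank = {node: k for k, node in enumerate(order)}
    relabelled = {
        rank[node]: sorted(rank[p] for p in parents.get(node, ()))
        for node in order
    }
    return order, relabelled

# Hypothetical graph given in arbitrary order: a -> b, a -> c, b -> c.
order, relabelled = reorder_to_maximal_dag({"c": {"a", "b"}, "b": {"a"}, "a": set()})
```

Any edge absent from the original graph simply corresponds to an input that the function of the child node ignores.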

4.3. PRIOR PROBABILITY FOR LINEAR SEMS

In the simple case when all functions $f_i$ of the ANM are linear, the resulting joint probability must be Gaussian, as it is a linear combination of independent Gaussian noise variables $n_i$ (eq. 5, recalling that $y \equiv z$). However, we need to compute its precise form for general linear ANMs.

Theorem 2. Consider a linear ANM defined as $z_i = a_i^T z_{1,\dots,i-1} + b_i + n_i$, with $a_i \in \mathbb{R}^{i-1}$, $b_i \in \mathbb{R}$ and $n_i \sim \mathcal{N}(n_i \,|\, 0, \sigma_i^2)$ for $i = 1, \dots, d$. Missing edges in the causal graph can be represented as zeros in $a_i$. Then the ANM's joint probability is given by $p(z) = \mathcal{N}(z \,|\, \mu, \Sigma)$ with

$$\mu = \Big( \prod_{i=d}^{2} A_i \Big) b, \qquad \Sigma = \Big( \prod_{i=d}^{2} A_i \Big) \, \mathrm{diag}_{i=1,\dots,d}(\sigma_i^2) \, \Big( \prod_{i=d}^{2} A_i \Big)^T, \tag{9}$$

$$A_i = \begin{bmatrix} I_{(i-1)\times(i-1)} & O_{i-1} & O_{(i-1)\times(d-i)} \\ a_i^T & 1 & O_{d-i}^T \\ O_{(d-i)\times(i-1)} & O_{d-i} & I_{(d-i)\times(d-i)} \end{bmatrix}, \qquad b = \begin{bmatrix} b_1 \\ \vdots \\ b_d \end{bmatrix},$$

where $I_{k \times k}$ denotes an identity matrix, while $O_k$ and $O_{k \times l}$ denote a zero column vector and matrix.

Proof. See Appendix C.

The process in Theorem 2 can be interpreted as a form of mean and covariance propagation: at each stage the ANM applies a linear transformation to the mean and covariance from the previous stage. The matrix $\prod_{i=d}^{2} A_i$ is a lower-triangular matrix representing the edge strengths from parent to child nodes in $G$, obtained from the SEM. This linear transformation is then applied to the SEM biases $b$ to obtain the total mean $\mu$, as well as to the SEM noise variances $\sigma_i^2$ to obtain the total covariance $\Sigma$. For identifiability, however, we must generalize this process to non-linear ANMs, which we do next.
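The mean and covariance propagation of Theorem 2 can be sketched in a few lines of NumPy. The helper names `stage_matrix` and `linear_anm_gaussian`, the list layout of the edge weights, and the numeric values are assumptions for illustration; the product is accumulated so that the later stages multiply from the left, following the ordering $A_d \cdots A_2$.

```python
import numpy as np

def stage_matrix(a_i, i, d):
    """A_i: identity whose row i (1-based) holds [a_i^T, 1, 0, ..., 0]."""
    A = np.eye(d)
    A[i - 1, : i - 1] = a_i
    return A

def linear_anm_gaussian(a, b, sigma2):
    """a[i - 2] holds the edge weights a_i (length i - 1) for i = 2..d."""
    d = len(b)
    M = np.eye(d)
    for i in range(2, d + 1):
        M = stage_matrix(a[i - 2], i, d) @ M  # accumulates A_d ... A_2
    mu = M @ b
    Sigma = M @ np.diag(sigma2) @ M.T
    return mu, Sigma, M

# Hypothetical 3-variable linear SEM.
a = [np.array([0.5]), np.array([-1.0, 2.0])]
b = np.array([1.0, 0.0, -1.0])
mu, Sigma, M = linear_anm_gaussian(a, b, np.array([1.0, 0.25, 4.0]))
```

A useful sanity check is that the accumulated matrix equals $(I - L)^{-1}$, where $L$ is the strictly lower-triangular matrix whose rows are the $a_i$, since the linear SEM can also be written as $z = Lz + b + n$.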

4.4. PRIOR PROBABILITY FOR NON-LINEAR SEMS

Our approach to dealing with non-linear ANMs is to linearize them around a pivot point $z^\circ$. Due to Taylor's theorem this approximation will be accurate within a sufficiently small neighborhood (Nocedal & Wright, 1999, Ch. 1). The ANM's joint probability $p(z)$ will then be Gaussian by Theorem 2.

Theorem 3. The best linear approximation (in the least-squares sense) of an ANM $z_i = f_i(z_{1,\dots,i-1}) + n_i$ (eq. 5, with missing edges in the causal graph corresponding to ignored inputs in $f_i$), around a pivot point $z^\circ$, is given by eq. 9 (Theorem 2) with

$$a_i = \left. \frac{\partial f_i(z_{1,\dots,i-1})}{\partial z_{1,\dots,i-1}} \right|_{z_{1,\dots,i-1} = z^\circ_{1,\dots,i-1}} \quad \text{and} \quad b_i = f_i(z^\circ_{1,\dots,i-1}) + n_i - a_i^T z^\circ_{1,\dots,i-1}, \tag{10}$$

where $z^\circ_1 = n_1$, $z^\circ_i = f_i(z^\circ_{1,\dots,i-1}) + n_i$ and $n_i \sim \mathcal{N}(0, \sigma_i^2)$ with learnable $\sigma_i$.

Proof. See Appendix D.

The advantage afforded by Theorem 3 is that the ANM's joint probability $p(z)$ is locally Gaussian, and we can compute its parameters by mean and covariance propagation (eq. 9). Another advantage is that, in the context of training deep networks, one can use automatic differentiation (back-propagation) to compute eq. 10 for arbitrary functions $f_i$, including very expressive ones such as multi-layer perceptrons (MLP). Since this locally-linear SEM is represented as an explicit Gaussian distribution, the objective's KL-divergence (eq. 3) can be computed in closed form (Kingma & Welling, 2014). We can then obtain the full prior distribution by sampling many pivot points $z^\circ$ by ancestral sampling according to Theorem 3.
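The linearization step can be illustrated as follows. This sketch samples a pivot by ancestral sampling and computes the local slopes and intercepts of eq. 10; as a loudly stated assumption, it uses central finite differences in NumPy as a stand-in for the automatic differentiation described in the text, and the names `numerical_grad` and `linearize_anm` are hypothetical.

```python
import numpy as np

def numerical_grad(f, z, eps=1e-5):
    """Central finite differences; a stand-in for automatic differentiation."""
    g = np.zeros_like(z)
    for k in range(z.size):
        step = np.zeros_like(z)
        step[k] = eps
        g[k] = (f(z + step) - f(z - step)) / (2.0 * eps)
    return g

def linearize_anm(funcs, sigmas, rng):
    """funcs[i] maps the preceding latents (array of length i) to a scalar."""
    d = len(funcs)
    pivot, a, b = np.zeros(d), [], np.zeros(d)
    for i in range(d):
        prev = pivot[:i]
        # Ancestral sampling of the pivot: f_i(previous pivot values) + n_i.
        pivot[i] = funcs[i](prev) + rng.normal(0.0, sigmas[i])
        a_i = numerical_grad(funcs[i], prev)  # local slope a_i (eq. 10)
        a.append(a_i)
        b[i] = pivot[i] - a_i @ prev          # intercept b_i as written in eq. 10
    return pivot, a, b

# Hypothetical parabolic SEM: z_1 = n_1, z_2 = z_1^2 + n_2.
funcs = [lambda z: 0.0, lambda z: z[0] ** 2]
pivot, a, b = linearize_anm(funcs, [1.0, 0.1], np.random.default_rng(0))
```

For the parabola the local slope is $2 z^\circ_1$, so different pivots give different local Gaussians; repeating the procedure over many pivots is what keeps the overall prior non-linear.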

4.5. GRAPH SEARCH VIA SPARSITY

To efficiently search for the graphs in eq. 7, we use Proposition 1 to justify using a fixed maximal graph $G_\Omega$ to learn a single SEM $p_f(z)$ (where we make explicit the dependence of $p(z)$ on the learned causal functions $f_i$, cf. eq. 5). This is in contrast to the previously-proposed technique of enumerating all graphs $G \in \mathcal{G}$ and learning a different $p_{f,G}(z)$ for each $G$ (Nowzohour & Bühlmann, 2016). We define the graph of dependencies:

Definition 4. We define the implicit dependency graph $G(f)$ as the graph with $d$ nodes, in which each edge $i \to j$ exists if $f_j(z_{1,\dots,j-1})$ depends on its $i$th input.

We can then define the overall objective (from eq. 7 and eq. 3) as

$$p^* = \arg\max_{p \in \mathcal{P},\, q \in \mathcal{Q},\, f} \frac{1}{s} \sum_{x \sim \hat{p}(x)} \Big[ \mathbb{E}_{z \sim q(z|x)} [\ln p(x|z)] - D_z(q(z|x) \,\|\, p_f(z)) \Big] - \frac{|G(f)|}{\ln s}, \tag{11}$$

where $\hat{p}(x)$ is the empirical data distribution and $s$ the number of samples. To approximate the $|G(f)|$ term that counts the edges of the graph $G(f)$, we introduce edge probabilities $p_{i \to j}$ denoting the probability that an edge from node $i$ to $j$ exists, which allows us to write $|G(f)| = \sum_{j,\, i<j} p_{i \to j}$ (i.e. for binary probabilities, this recovers the exact edge count). We define each $p_{i \to j}$ as a binary Gumbel-Softmax distribution (Jang et al., 2017), which is suitable for gradient-based optimization. The edge probabilities mask the inputs of the SEM's functions $f_j(z_1, \dots, z_{j-1})$, as $f_j(p_{1 \to j} z_1, \dots, p_{j-1 \to j} z_{j-1})$. During model evaluation we can sample binary edge probabilities $p_{i \to j} \in \{0, 1\}$, producing a discrete graph $G$. We can thus use standard stochastic gradient methods to optimize eq. 11 and obtain an encoder $q$, decoder $p$, and SEM defined by $f$ and $p_{i \to j}$, the latter of which models well-defined causal directions (as long as the underlying functions are non-linear) and reflects independence relations in the data.
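The edge-masking mechanism can be sketched with a relaxed binary gate per edge. The code below assumes the common formulation of the binary Gumbel-Softmax as a sigmoid applied to logits perturbed by logistic noise and divided by a temperature; the function names, the logit parameterization, and the temperature value are illustrative assumptions rather than details given in the text.

```python
import numpy as np

def sample_edge_gates(logits, temperature, rng):
    """Relaxed binary gates; entry [j, i] plays the role of p_{i->j} for i < j."""
    d = logits.shape[0]
    u = rng.uniform(1e-6, 1.0 - 1e-6, size=(d, d))
    logistic_noise = np.log(u) - np.log1p(-u)  # difference of two Gumbel samples
    gates = 1.0 / (1.0 + np.exp(-(logits + logistic_noise) / temperature))
    return np.tril(gates, k=-1)  # keep only edges from earlier to later nodes

def masked_inputs(z, gates, j):
    """Inputs passed to f_j: each preceding latent scaled by its edge gate."""
    return gates[j, :j] * z[:j]

def expected_edge_count(gates):
    """Soft edge count that stands in for |G(f)| in the penalty term."""
    return gates.sum()

rng = np.random.default_rng(0)
gates = sample_edge_gates(np.zeros((4, 4)), temperature=0.5, rng=rng)
penalty_edges = expected_edge_count(gates)
```

Lowering the temperature pushes the gates towards 0 or 1, at which point the soft count coincides with the number of edges in the sampled discrete graph.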

4.6. IDENTIFIABILITY

In Theorem 7, we demonstrate the identifiability of the proposed model up to some unavoidable indeterminacies. The only transformations that can be implemented by the learned deep networks (encoder, decoder and SEM) and can result in an identical objective value to the optimal model are reorderings of the latent variables, as well as shifts and orthogonal transformations of the input (implemented by both the encoder at the input and the decoder at the output). The conditions required to achieve this result depend on common features of standard deep networks, namely an encoder and decoder composed of ReLUs and linear operators, batch-normalization in the outputs of the SEM f, and a SEM with fixed scale, such as the quadratic used in experiments. The proof is given in Appendix B.

5. EXPERIMENTS

In this section we validate our method on two experiments: the first one learning atemporal variables and the second one learning time-varying variables. In the atemporal experiment (Section 5.1) we show that our method correctly recovers the positions of a shape placed on a parabola given a fixed quadratic prior. In the time-varying experiment (Section 5.2) we show that our method correctly recovers the positions of a moving character and their relationships over time, using a learnable linear prior.

5.1. LEARNING ATEMPORAL CAUSAL RELATIONS

In this section we demonstrate the method's ability to learn causal relationships that are not time-dependent. We use a dataset of images of oval shapes at different rotations and scales, placed at positions that follow a parabolic arc $y = x^2 + n$ (with noise $n$). We thus have a causal graph with a single edge $x \to y$. We train our causal-prior VAE with a 3-layer MLP and a parabolic prior (see Appendix E for details), and use the $L_2$ loss between each image and its prediction from the decoder as the reconstruction error (see Appendix ??). We also train high-β and low-β VAEs for comparison (where β denotes the multiplicative factor of the KL divergence, as in Higgins et al. (2016a)).

[Figure panels: VAE (β = 1), VAE (β = 4), Causal-prior VAE]

Latent space visualisation. In Figure 1 we show the prior (blue) and the predicted distribution (red) for the low-β VAE, the high-β VAE, and our causal-prior VAE. For the high-β VAE (center) the two distributions match, but the latent variables are entangled (i.e. both follow an isotropic Gaussian distribution), and so do not reflect the underlying parabolic relationship between the variables that generate the data ($y = x^2 + n$). For the low-β VAE (left), although the observed variables recover the underlying parabolic shape (red), up to a vertical reflection of the coordinates, this distribution does not match the prior at all (blue), since the prior is constrained to be Gaussian. As such, this generative model $p(x)$ (eq. 2) cannot generate samples that respect the causal relationships in the empirical data distribution. For our causal-prior VAE (right), the two distributions match and both follow the underlying noisy parabolic equation. Our model can thus be used as a generator $p(x)$ that reproduces the true causal structure of the data, a capability that we will explore next in more detail.

Latent response to position.
Figure 2 shows, for the different VAEs, the value of each latent variable as the position of the shape in the input image is varied, averaged over different rotations and scales. The variables are ordered by their standard deviation (σ). The goal is to assess how sensitive each variable is to the depicted shape's position. For the causal-prior VAE, the variable $z_1$ clearly corresponds to the shape's $x$ position (with values linearly increasing towards the left), and variable $z_2$ to the $y$ position (with values increasing towards the top), while other variables show minimal response. For the high-β VAE the position is entangled across several variables; for the low-β VAE the variables $z_4$ and $z_1$ correspond approximately to the shape's $x$ and $y$ position, but there is still significant entanglement with the other variables, which are also highly sensitive to the shape's position.

Latent noise traversal. In the first and third columns of Figure 3, we show grids of images decoded by each method, obtained by taking one sample and varying the noise $n_i$ of the two latent variables $z_i$ (see eq. 5) that best correspond to horizontal and vertical position. In the first column, decoded images are color-coded by their horizontal position. Since the positions of shapes in this grid are difficult to compare, the second column shows the same images but superimposed, which reveals the underlying parabolic structure of the data. The third column shows the same grid of decoded images, color-coded by vertical position, while the fourth shows them superimposed. While all methods seem to generate images following roughly the parabolic data distribution, we can observe that in both β-VAE configurations the shapes' positions are entangled (i.e. vary jointly) with their rotation and scale, while the same is not true for the proposed causal-prior VAE.
It is interesting to note that in the latter case, the noise $n_1$ (corresponding to variable $z_1$) does not map directly to the shape's vertical position (as observed for the β-VAE), but rather expresses its offset from the parabola, according to $z_1 = z_2^2 + n_1$. In the bottom-right panel, the red tint of images generated from positive noise offsets ($n_1 > 0$) is visible above the parabola, and the blue tint of negative offsets is visible below.

5.2. LEARNING TEMPORAL CAUSAL RELATIONS

We now evaluate our method on temporal data, namely short videos. We apply the encoder and decoder independently to each frame, and append an encoded background frame to the decoder's input (to allow the latent variables to ignore background/time-invariant appearance). The SEM only contains edges in the direction of the arrow of time (see Appendix E for details).

Figure 3: Reconstructed images (as a 2D grid and as superimposed images) as latent noise values are traversed at regular intervals. Panels: VAE (β = 1), VAE (β = 4), Causal-prior VAE. For the β = 1 and β = 4 VAEs the depicted shape's position is entangled with its rotation and scale, while for the causal-prior VAE the horizontal position and vertical offset from the parabola are disentangled correctly.

We create a dataset based on the Super Mario Bros video game depicting an object moving with linear motion in different directions, at different speeds, and on varied backgrounds. Three frames are observed, with the object's coordinates $(x_i, y_i)$ for $i \in \{1, 2\}$ sampled uniformly from $[7, 12]$, and the third obeying $x_3 = 2x_2 - x_1$, $y_3 = 2y_2 - y_1$. We train our method with a linear prior, and otherwise similarly to Sec. 5.1 (see Appendix F for details), and select the sparsest graph that still achieves the best reconstruction accuracy. We use the reconstruction error given by $\sum_{t=1}^{3} L_2(x_t, \hat{x}_t) + L_2(x_3, \hat{x}^f_3)$, where $x_t$ denotes a frame at time $t$, $\hat{x}_t$ denotes a prediction made using the decoder $p(x_t \,|\, z_t)$ and $\hat{x}^f_3$ denotes a prediction made using the decoder $p(x^f_3 \,|\, f_3(z_1, z_2))$ (see Appendix ?? for details). For quantitative evaluation of learned graphs please refer to Appendix G.

Learning a causal graph. Figure 4 shows the graph learned by the causal-prior VAE, with the bottom row $(z_1, z_2, z_3, z_4)$ corresponding to latent variables at time $t = 1$, the middle row $(z_5, z_6, z_7, z_8)$ corresponding to $t = 2$ and the top row $(z_9, z_{10}, z_{11}, z_{12})$ corresponding to $t = 3$.
Learned edges are shown in black and missing ones in beige; learned functions f i are represented as a number (multiplicative factor) next to the corresponding edge, since they are linear in this case. The graph shows that each component of the object's 2D position at time t = 3 (expressed as (z 10 , z 11 )) depends on its 2D positions at times t = 1 (z 2 , z 3 ) and t = 2 (z 6 , z 7 ), following the model z 10 = f 10 (z) = 2.3z 6 -1.1z 2 and z 11 = f 11 (z) = 2.2z 7 -1.1z 3 , which matches the data generation process up to a scale factor. Latent response to position. Figure 5 shows the response of each latent variable as the object's horizontal and vertical position is varied on the input. The rows correspond to variables from times t = 3, 2, 1 respectively. We can see that z 2 , z 6 , z 10 encode the object's position along the bottom-right to top-left diagonal at different moments in time, while z 3 , z 7 , z 11 encode the position along the orthogonal direction. The remaining variables do not exhibit any significant response to position. This shows the variables correctly match the data generation process up to a rotation.

Interventions visualisation.

Once a SEM relating the latent variables has been learned, we can perform interventions on some of its variables. In Figure 6 we encode a randomly sampled reference clip from the dataset and separately intervene by assigning different values (rows) to variables $z_1, \dots, z_{12}$ (columns), propagating the changes to their child variables using the SEM, and decoding the results. The results are decoded using a black background and the resulting 3-frame decoded sample is combined into a single frame by colouring the $t = 1$ decoded image red, the $t = 2$ image white, and the $t = 3$ image green. We observe that intervening on the variables for $t = 1$ ($z_2, z_3$) changes the object's position both at $t = 1$ (red, columns 1,2) and at $t = 3$ (green, columns 1,2), following the SEM. Intervening on the variables for $t = 2$ ($z_6, z_7$) changes the object's $t = 2$ position (white, columns 3,4) and $t = 3$ position (green, columns 3,4), while intervening on the variables for $t = 3$ ($z_{10}, z_{11}$) changes only its $t = 3$ position (green, columns 5,6), which reflects the fact that other positions (in the past) do not depend on it. Intervening on other variables has no effect (last column). This confirms that the model has learned the correct temporal causal structure, as each intervention affects the right variables.

Extrapolation visualisation. Having learned a SEM relating the variables at time $t$ to those at times $(t-1, t-2)$, it is now possible to extrapolate the variable values into the future. In Figure 7 we start with 3 samples (rows) of 2 consecutive frames from the dataset, $x_1, x_2$, encode them to obtain the $t = 1, 2$ latent variables $(z_1, \dots, z_8)$, and pass them through the SEM to compute the $t = 3$ latent variables $(z_9, \dots, z_{12})$, decoding these into $\hat{x}_3$. We then repeat this process by taking $z_5, \dots, z_{12}$ as the $t = 1, 2$ variables and using the SEM to predict the $t = 3$ variables, which we decode into $\hat{x}_4$, and similarly for $\hat{x}_5$, thus obtaining predictions for 3 time steps into the future. The future predictions confirm that the SEM has learned that the object moves linearly, and show that it can be used to predict future frames accurately.
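The iterative extrapolation described above amounts to rolling a two-step linear recurrence forward. The sketch below is a plain-Python illustration that uses the coefficients of the data-generating rule $x_3 = 2x_2 - x_1$ as default weights, whereas a trained model would use its own learned edge weights; the function name `extrapolate` and the example coordinates are assumptions.

```python
def extrapolate(z_prev, z_curr, steps, w_curr=2.0, w_prev=-1.0):
    """Roll the linear rule z_next = w_curr * z_curr + w_prev * z_prev forward.

    z_prev and z_curr are tuples of position-like latents at two
    consecutive time steps; returns the list of predicted future states.
    """
    trajectory = [tuple(z_prev), tuple(z_curr)]
    for _ in range(steps):
        older, newer = trajectory[-2], trajectory[-1]
        nxt = tuple(w_curr * c + w_prev * p for p, c in zip(older, newer))
        trajectory.append(nxt)  # the prediction becomes an input for the next step
    return trajectory[2:]

# Hypothetical object moving by (+1, -1) per frame.
future = extrapolate((7.0, 12.0), (8.0, 11.0), steps=3)
```

Each predicted state is fed back as an input for the next step, mirroring the way the predicted latents are reused as the new $t = 1, 2$ variables in the text.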

6. CONCLUSIONS

In this work we proposed a general model that naturally extends the variational learning framework to learn a non-linear Structural Equation Model (SEM) as the prior, which enables causal learning to be performed on perceptual modalities such as images and video. To reconcile the non-linearity of SEMs with the use of Gaussian variables, we proposed a fully differentiable method that locally linearises the SEM to obtain a locally-Gaussian distribution. Furthermore, we relaxed the search over causal graphs into a joint continuous optimization over non-linear causal functions $f$. The proposed method shows promise to scale to high-dimensional data and moderately complex SEMs; however, future work should explore larger-scale data such as long videos and other input modalities.

Proof. Because $G$ is a DAG, it has an acyclic ordering of its vertices (Bang-Jensen & Gutin, 2009); i.e. if $G$ has vertices $V = \{v_1, \dots, v_n\}$ we can always construct a sequence $(z_1, \dots, z_n)$ such that the set $Z = \{z_1, \dots, z_n\}$ is mapped one-to-one to the set $V$, and such that every ancestor $z_i$ of $z_j$ (i.e. every $z_i$ with a path to $z_j$ in $G$) satisfies $i < j$. Therefore, the set of parents $\mathrm{pa}_G(z_j)$ has to be a subset of the set $\{z_i \mid i < j\}$, from which Proposition 5 follows.

B JOINT IDENTIFIABILITY OF REPRESENTATION AND CAUSAL MECHANISM

Before proving our identifiability result, we must first introduce a lemma about the commutativity of certain piece-wise linear functions.

Lemma 6. Consider a Leaky ReLU activation function (He et al., 2015), defined as

$$R_{C,D}(x) = \begin{cases} Cx & \text{if } x \ge 0 \\ Dx & \text{if } x < 0 \end{cases} \tag{12}$$

with constants $C > 0$, $C \neq 1$, $D > 0$, $D \neq 1$, $C \neq D$. Then the class of functions $\phi$ that commute with $R_{C,D}$, i.e.

$$\phi(R_{C,D}(x)) = R_{C,D}(\phi(x)), \quad \forall x, \tag{13}$$

is the set of monotonic piece-wise linear homogeneous functions with one piece in each half-line, i.e.

$$\phi(x) = \begin{cases} Ax & \text{if } x \ge 0 \\ Bx & \text{if } x < 0 \end{cases}, \quad A, B > 0. \tag{14}$$

Proof. To solve the commutativity identity in Eq. 13, we partition it into four domains, namely $\{x \ge 0, x < 0\} \times \{\phi(x) \ge 0, \phi(x) < 0\}$. Applying the definition in Eq. 12 to each domain gives:

\begin{align}
\text{If } x \ge 0,\ \phi(x) \ge 0 &\Rightarrow C\phi(x) = \phi(Cx) \Rightarrow \phi(x) = Ax \tag{15} \\
\text{If } x < 0,\ \phi(x) \ge 0 &\Rightarrow C\phi(x) = \phi(Dx) \Rightarrow \phi(x) = 0 \tag{16} \\
\text{If } x \ge 0,\ \phi(x) < 0 &\Rightarrow D\phi(x) = \phi(Cx) \Rightarrow \phi(x) = 0 \tag{17} \\
\text{If } x < 0,\ \phi(x) < 0 &\Rightarrow D\phi(x) = \phi(Dx) \Rightarrow \phi(x) = Bx \tag{18}
\end{align}

Equations 15 and 18 amount to solving the equality $E\phi(x) = \phi(Ex)$, which is satisfied by any homogeneous linear function $\phi$, assuming $E \neq 1$. Equations 16 and 17 amount to solving the equality $E\phi(x) = \phi(Fx)$, which is only satisfied by $\phi(x) = 0$, assuming $E \neq 1$, $E \neq F$. Combining the results from Equations 15-18, the result follows.

We can now introduce our main result.

Theorem 7. For a causal VAE model (Eq. 3 and Eq. 5 with $y \equiv z$), assume the following conditions:

1. The encoder $q(z|x)$ and decoder $p(x|z)$ are Gaussians (Eq. 4) parameterized by deep networks, containing a Leaky ReLU (Eq. 12) as their last layer and first layer, respectively.
2. The outputs of the SEM $f_i$ are de-meaned and normalized, i.e. they are composed with batch normalization (BN) operators: $z_i = \mathrm{BN}(f_i(z_{\mathrm{pa}(i)}))$ where $\mathrm{BN}(u) = (u - \mathbb{E}[u]) / \sqrt{\mathrm{Var}[u]}$, or they have a fixed constant scale.

Then the model is identifiable up to the following indeterminacies, i.e. denoting the optimal parameters by $\theta^*$, there is a different set of parameters $\theta$ such that $p_{\theta^*}(x) = p_\theta(x)$ only under the following learnable transformations:

1. Simultaneous shifts by $b$ and orthogonal transformations $R$ of the encoder's input and decoder's output, i.e.: $\mu_q(x) \leftarrow \mu_q(Rx + b)$, $\mu_p(z) \leftarrow R^{-1}\mu_p(z) - b$.
2. Latent variable permutations, i.e. reorderings of $z$.
3. If the SEM mechanisms $f_i$ contain additional symmetries (e.g. $f_i = f_i \circ S$ for some operator $S$), then there will be indeterminacy up to application of that operator.

Proof. The training objective is to find the model parameters $\theta$ of the generative distribution $p_\theta(x)$ such that it matches the observed empirical distribution $p(x)$. The parameters are assumed to exist if the function class of $p_\theta(x)$ is sufficiently expressive, as is the case in over-parameterized deep networks. We also assume that there exists a true model $\theta^*$ which generates the data distribution $p_{\theta^*}(x)$. Then the claim of non-identifiability is that there exists at least one different parameterization $\theta$ for which $p_\theta(x) = p_{\theta^*}(x)$, i.e. it is possible to learn another model that also fits the true data distribution but has different parameters $\theta$. We can express the lower bound on the data distribution using the evidence lower bound (ELBO, Eq. 3)

$$\ln p_{\theta^*}(x) \ge \mathbb{E}_{x \sim p(x)}\big[\mathbb{E}_{q_\theta(z|x)}[\ln p_\theta(x|z)] - \mathrm{KL}[q_\theta(z|x) \,\|\, p_\theta(z)]\big] \tag{19}$$

with encoder $q_\theta(z|x)$, decoder $p_\theta(x|z)$, and learnable prior $p_\theta(z)$ generated using the SEM (slightly overloading $\theta$ to include all parameters of the model). Under non-identifiability, we assume that we can find another set of parameters $\theta$ that fit the data perfectly: $p_\theta(x) = p_{\theta^*}(x)$. In that case Eq. 19 holds with equality, implying the following two conditions:

\begin{align}
\ln p_{\theta^*}(x) &= \mathbb{E}_{x \sim p(x)}\big[\mathbb{E}_{q_\theta(z|x)}[\ln p_\theta(x|z)]\big], \tag{20} \\
0 &= \mathbb{E}_{x \sim p(x)}\big[\mathrm{KL}[q_\theta(z|x) \,\|\, p_\theta(z)]\big]. \tag{21}
\end{align}

In words, the expectation of the log likelihood (first term of the right-hand side of Eq. 19) becomes the exact log evidence $\ln p_{\theta^*}(x)$, while the posterior $q_\theta(z|x)$ fits the prior $p_\theta(z)$ exactly (their KL-divergence is zero).
Furthermore, for the optimal model $\theta^*$ we know that

$$\ln p_{\theta^*}(x) = \mathbb{E}_{x \sim p(x)}\big[\mathbb{E}_{q_{\theta^*}(z|x)}[\ln p_{\theta^*}(x|z)]\big]. \tag{22}$$

Note that the subtle difference from Eq. 20 is the use of $\theta^*$ instead of $\theta$. We can thus combine the results for both models $\theta^*$ (Equation 22) and $\theta$ (Equation 20) as

$$\mathbb{E}_{x \sim p(x)}\big[\mathbb{E}_{q_\theta(z|x)}[\ln p_\theta(x|z)]\big] = \mathbb{E}_{x \sim p(x)}\big[\mathbb{E}_{q_{\theta^*}(z|x)}[\ln p_{\theta^*}(x|z)]\big]. \tag{23}$$

We will write the left-hand side (LHS) of Eq. 23 in a form that makes its equivalence classes more obvious, using function composition. In order to do that, we first recall Eq. 4, which we rewrite here (making $\theta$ explicit) for convenience:

$$q_\theta(z|x) = \mathcal{N}(z \mid \mu_\theta(x), \Sigma_\theta(x)). \tag{24}$$

The reparametrization trick, which is at the core of VAE implementations, allows us to sample from the probabilistic encoder $q_\theta(z|x)$ with a deterministic function $q_{\theta,\epsilon}(x)$ (typically a deep network) that takes samples $\epsilon$ from a fixed normal distribution and applies an affine transformation to them:

$$q_{\theta,\epsilon}(x) = \mu_\theta(x) + \Sigma_\theta(x)\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, 1). \tag{25}$$

The inner expectation in the LHS of Eq. 23 can then be written as:

$$\mathbb{E}_{z \sim q_\theta(z|x)}[\ln p_\theta(x|z)] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0,1)}[\ln p_\theta(x \mid q_{\theta,\epsilon}(x))]. \tag{26}$$

By Eq. 4 the decoder $p_\theta(x|z)$ is also a Gaussian with parameterized mean $\mu_{p_\theta}(z)$ (assuming identity covariance for simplicity; see footnote 1), so the LHS of Eq. 23 is equivalent to:

\begin{align}
\mathbb{E}_{x \sim p(x)} \mathbb{E}_{z \sim q_\theta(z|x)} \ln p_\theta(x|z) &= \mathbb{E}_{x \sim p(x)} \mathbb{E}_{z \sim q_\theta(z|x)} \big[-\gamma \|\mu_{p_\theta}(z) - x\|^2\big] \tag{27} \\
&= \mathbb{E}_{x \sim p(x)} \mathbb{E}_{z \sim q_\theta(z|x)} \big[L_x(\mu_{p_\theta}(z))\big], \quad L_x(x') = -\gamma \|x' - x\|^2, \tag{28}
\end{align}

absorbing constants into $\gamma$. Finally, the LHS of Eq. 23 is then equivalent to the expectation of a composition of deterministic functions (applied from right to left):

$$\mathbb{E}_{x \sim p(x),\, \epsilon \sim p(\epsilon)}\big[L_x \circ \mu_{p_\theta} \circ q_{\theta,\epsilon} \circ x\big]. \tag{29}$$

This makes explicit the order of operations in computing the reconstruction loss, and separates out all non-deterministic elements into the expectation (namely by using the reparameterization trick with $\epsilon \sim p(\epsilon) = \mathcal{N}(0, 1)$). We can now enumerate the equivalence classes of Eq.
23, by inserting identity operators $I = g^{-1} \circ g = h^{-1} \circ h$ (for arbitrary invertible functions $g$, $h$) between compositions of learnable functions in Eq. 29. Inserting these identities, Eq. 29 is equivalent to

\begin{align}
& \mathbb{E}_{x \sim p(x),\, \epsilon \sim p(\epsilon)}\big[L_x \circ (h \circ h^{-1}) \circ \mu_{p_\theta} \circ (g^{-1} \circ g) \circ q_{\theta,\epsilon} \circ (h \circ h^{-1}) \circ x\big] \tag{30} \\
={}& \mathbb{E}_{x \sim p(x),\, \epsilon \sim p(\epsilon)}\big[(L_x \circ h) \circ \underbrace{(h^{-1} \circ \mu_{p_\theta} \circ g^{-1})}_{\tilde{\mu}_{p_\theta}} \circ \underbrace{(g \circ q_{\theta,\epsilon} \circ h)}_{\tilde{q}_{\theta,\epsilon}} \circ \underbrace{(h^{-1} \circ x)}_{\tilde{x}}\big], \tag{31}
\end{align}

where we grouped the arbitrary functions with the decoder as $\tilde{\mu}_{p_\theta}$, with the encoder as $\tilde{q}_{\theta,\epsilon}$, and with the input as $\tilde{x}$. We can then apply the substitution $\tilde{x} = h^{-1} \circ x$, using the fact that

$$L_x(x') = -\gamma \|h(x') - h(x)\|^2 = L_{h \circ x}(h \circ x') \tag{32}$$

is only ever true if $h$ is an orthogonal linear transformation plus a constant (since Euclidean distances are invariant only under generalized rotations and translations; see footnote 2), and thus obtaining the equivalent reparameterization

$$\mathbb{E}_{\tilde{x} \sim p(\tilde{x}),\, \epsilon \sim p(\epsilon)}\big[L_{\tilde{x}} \circ \tilde{\mu}_{p_\theta} \circ \tilde{q}_{\theta,\epsilon} \circ \tilde{x}\big]. \tag{33}$$

Thus the function class of $h$ must necessarily be restricted to orthogonal transformations, induced by the Euclidean structure of $L$ (Eq. 32). As for $g$, it is restricted by the fact that it must be absorbed into the encoder, i.e. in Eq. 31 we must be able to group it with the previous encoder $q_{\theta,\epsilon}$ and define a new encoder

$$\tilde{q}_{\theta,\epsilon} = g \circ q_{\theta,\epsilon} \circ h \tag{34}$$

that can still be implemented as a deep network (of the same function class as $q_{\theta,\epsilon}$). Since we require that the encoder be followed by a Leaky ReLU $R_{C,D}$ (assumption 1 of Theorem 7), for this to be true, $g$ must commute with $R_{C,D}$, i.e.:

$$\tilde{q}_{\theta,\epsilon} = g \circ q_{\theta,\epsilon} \circ h = g \circ \underbrace{R_{C,D} \circ q'_{\theta,\epsilon}}_{q_{\theta,\epsilon}} \circ h = R_{C,D} \circ g \circ q'_{\theta,\epsilon} \circ h, \tag{35}$$

where the first equality is taken from Eq. 31, in the second equality we decompose the network $q_{\theta,\epsilon}$ into its final Leaky ReLU $R_{C,D}$ and the rest of the network $q'_{\theta,\epsilon}$, and finally in the last step we use the commutativity of $g$ and $R_{C,D}$ (Lemma 6). Since by Lemma
6 only monotonic piece-wise linear homogeneous functions commute with $R_{C,D}$, this restricts the class of admissible functions for $g$ to that class. An identical conclusion follows for $g^{-1}$, as long as the first layer of the decoder is also a Leaky ReLU (by assumption 1). This means that the latent variables are identifiable only up to $g$ implementing an individual rescaling of each variable and a variable permutation, unless we impose further constraints.

We can impose mild structural constraints on the SEM to fix the rescaling, namely assuming that the output of each $f_{\theta,i}(z_{\mathrm{pa}(i)})$ has fixed scale. This can be achieved by either 1) batch normalization or 2) a SEM with no learnable scale. Batch normalization, which is used by many common deep network models, fixes the distribution mean to zero and the variance to one for each dimension, removing the scaling and shifting degrees of freedom. As a special case, the quadratic model used in Section 5.1 does not have learnable scale parameters. The only remaining degrees of freedom are variable permutations, which are unavoidable in latent variable models.

Having characterized the equivalence classes of Eq. 20, we must now consider those of Eq. 21. Using the same reparameterization trick as in Eq. 25, the KL-divergence in Eq. 21 is equivalent to

\begin{align}
\mathrm{KL}[q_\theta(z|x) \,\|\, p_\theta(z)] &= \mathbb{E}_{z \sim q_\theta(z|x)}[\ln q_\theta(z|x) - \ln p_\theta(z)] \tag{36} \\
&= \mathbb{E}_{\epsilon \sim p(\epsilon)}[\ln q_{\theta,\epsilon}(x) - \ln p_\theta(q_{\theta,\epsilon}(x))]. \tag{37}
\end{align}

Therefore, we can also express it using a composition of operators:

$$\mathbb{E}_{x \sim p(x),\, \epsilon \sim p(\epsilon)}\big[L \circ q_{\theta,\epsilon} \circ x - L \circ \mu_{f_\theta} \circ q_{\theta,\epsilon} \circ x\big], \tag{38}$$

with $\mu_{f_\theta}$ the mean of the Gaussian output by $f$ (by Theorem 2), assuming variance one as before for simplicity. We can now insert the same identity operators and follow an identical derivation to Eq. 30-33, which recovers the exact same equivalence classes as before. The only identity that can be added differently from Eq. 30 will be an operator $S$ for which $\mu_{f_\theta}$ is invariant ($\mu_{f_\theta} = \mu_{f_\theta} \circ S$), so indeterminacy up to such symmetries of $f$ is the only other possibility.
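Lemma 6, on which the restriction of $g$ relies, can be checked numerically. The sketch below is only an illustration under assumed constants ($C = 1.5$, $D = 0.2$, $A = 3$, $B = 0.5$, all chosen arbitrarily to satisfy the lemma's conditions): a piece-wise linear homogeneous function with positive slopes commutes with the Leaky ReLU, whereas a shift or a sign flip does not.

```python
import numpy as np

def leaky_relu(x, C=1.5, D=0.2):
    """R_{C,D}(x) as in Eq. 12: slope C for x >= 0, slope D for x < 0."""
    return np.where(x >= 0, C * x, D * x)

def phi(x, A=3.0, B=0.5):
    """Candidate commuting function from Eq. 14, with A, B > 0."""
    return np.where(x >= 0, A * x, B * x)

x = np.linspace(-5.0, 5.0, 1001)

# phi(R(x)) == R(phi(x)) for the homogeneous piece-wise linear family.
phi_commutes = np.allclose(phi(leaky_relu(x)), leaky_relu(phi(x)))

# Counter-examples: a shift and a sign flip are not in the family.
shift_commutes = np.allclose(leaky_relu(x) + 1.0, leaky_relu(x + 1.0))
flip_commutes = np.allclose(-leaky_relu(x), leaky_relu(-x))
```

The sign flip fails precisely because $C \neq D$, which is one of the conditions required in the lemma.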

C PROOF OF MEAN AND VARIANCE PROPAGATION FOR LINEAR ANMS

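Theorem 8. Consider a linear ANM defined as $z_i = a_i^T z_{1,\dots,i-1} + b_i + n_i$, with $a_i \in \mathbb{R}^{i-1}$, $b_i \in \mathbb{R}$ and $n_i \sim \mathcal{N}(n_i \mid 0, \sigma_i^2)$ for $i = 1, \dots, d$. Missing edges in the causal graph can be represented as zeros in $a_i$. Then the ANM's joint probability is given by $p(z) = \mathcal{N}(z \mid \mu, \Sigma)$ with

$$\mu = \Big(\prod_{i=d}^{2} A_i\Big) b, \qquad \Sigma = \Big(\prod_{i=d}^{2} A_i\Big) \operatorname{diag}_{i=1,\dots,d}(\sigma_i^2) \Big(\prod_{i=d}^{2} A_i\Big)^T, \tag{39}$$

$$A_i = \begin{pmatrix} I_{(i-1)\times(i-1)} & O_{i-1} & O_{(i-1)\times(d-i)} \\ a_i^T & 1 & O_{d-i}^T \\ O_{(d-i)\times(i-1)} & O_{d-i} & I_{(d-i)\times(d-i)} \end{pmatrix}, \qquad b = \begin{pmatrix} b_1 \\ \vdots \\ b_d \end{pmatrix},$$

where $I_{k \times k}$ denotes an identity matrix, while $O_k$ and $O_{k \times l}$ denote a zero column vector and a zero matrix, respectively. The product runs in decreasing index order, i.e. $\prod_{i=d}^{2} A_i = A_d A_{d-1} \cdots A_2$.

Proof. We can write the ANM $z_i = a_i^T z_{1,\dots,i-1} + b_i + n_i$ in a recursive multivariable form as

$$z_{1,\dots,i} = \begin{pmatrix} I_{(i-1)\times(i-1)} \\ a_i^T \end{pmatrix} z_{1,\dots,i-1} + \begin{pmatrix} O_{i-1} \\ b_i \end{pmatrix} + \begin{pmatrix} O_{i-1} \\ 1 \end{pmatrix} n_i. \tag{40}$$

Applying the formula for a linear combination of independent Gaussians (Bishop, 2006, Ch. 8.1.4)

$$x \sim \mathcal{N}(x \mid \mu_x, \Sigma_x),\ y \sim \mathcal{N}(y \mid \mu_y, \Sigma_y) \implies Ax + By + c \sim \mathcal{N}(A\mu_x + B\mu_y + c,\ A\Sigma_x A^T + B\Sigma_y B^T) \tag{41}$$

to eq. 40, we can express the mean and covariance of $z_{1,\dots,i}$ as a function of the mean and covariance of $z_{1,\dots,i-1}$ as

\begin{align}
\mu_{1,\dots,i} &= \begin{pmatrix} I_{(i-1)\times(i-1)} \\ a_i^T \end{pmatrix} \mu_{1,\dots,i-1} + \begin{pmatrix} O_{i-1} \\ b_i \end{pmatrix}, \tag{42} \\
\Sigma_{1,\dots,i} &= \begin{pmatrix} I_{(i-1)\times(i-1)} \\ a_i^T \end{pmatrix} \Sigma_{1,\dots,i-1} \begin{pmatrix} I_{(i-1)\times(i-1)} \\ a_i^T \end{pmatrix}^T + \begin{pmatrix} O_{i-1} \\ 1 \end{pmatrix} \sigma_i^2 \begin{pmatrix} O_{i-1} \\ 1 \end{pmatrix}^T, \tag{43}
\end{align}

where we have used $\mu_{1,\dots,i} = \mathbb{E}[z_{1,\dots,i}]$, $\Sigma_{1,\dots,i} = \mathbb{E}[(z_{1,\dots,i} - \mu_{1,\dots,i})(z_{1,\dots,i} - \mu_{1,\dots,i})^T]$, and for the noise variables $n$ it holds that $\mathbb{E}[n] = 0$ and $\mathbb{E}[nn^T] = \operatorname{diag}_i(\sigma_i^2)$.

Now assume that eq. 39 holds for some $d = k$. Inserting the expressions for $\mu_{1,\dots,k}$ and $\Sigma_{1,\dots,k}$ from eq. 39 into eqs. 42 and 43, and extending the matrices with zeros and ones, we obtain

\begin{align}
\mu_{1,\dots,k+1} &= \begin{pmatrix} I_{k \times k} & O_k \\ a_{k+1}^T & 1 \end{pmatrix} \prod_{i=k}^{2} \begin{pmatrix} A_{i,d=k} & O_k \\ O_k^T & 1 \end{pmatrix} \begin{pmatrix} b_{1,\dots,k} \\ 0 \end{pmatrix} + \begin{pmatrix} O_k \\ b_{k+1} \end{pmatrix}, \tag{44} \\
\Sigma_{1,\dots,k+1} &= \begin{pmatrix} I_{k \times k} & O_k \\ a_{k+1}^T & 1 \end{pmatrix} \prod_{i=k}^{2} \begin{pmatrix} A_{i,d=k} & O_k \\ O_k^T & 1 \end{pmatrix} \begin{pmatrix} \operatorname{diag}_{i=1,\dots,k}(\sigma_i^2) & O_k \\ O_k^T & 0 \end{pmatrix} \Bigg(\prod_{i=k}^{2} \begin{pmatrix} A_{i,d=k} & O_k \\ O_k^T & 1 \end{pmatrix}\Bigg)^T \begin{pmatrix} I_{k \times k} & O_k \\ a_{k+1}^T & 1 \end{pmatrix}^T + \begin{pmatrix} O_{k \times k} & O_k \\ O_k^T & 1 \end{pmatrix} \sigma_{k+1}^2. \tag{45}
\end{align}

Identifying

$$A_{k+1,d=k+1} = \begin{pmatrix} I_{k \times k} & O_k \\ a_{k+1}^T & 1 \end{pmatrix}, \quad A_{i,d=k+1} = \begin{pmatrix} A_{i,d=k} & O_k \\ O_k^T & 1 \end{pmatrix}, \quad b_{1,\dots,k+1} = \begin{pmatrix} b_{1,\dots,k} \\ b_{k+1} \end{pmatrix} \tag{46}$$

recovers eq. 39 for $d = k + 1$, which completes the inductive step $d = k \to k + 1$. Finally, we apply eq. 41 to relation 40 for $d = 2$:

\begin{align}
\mu_{1,2} &= \begin{pmatrix} 1 \\ a_2 \end{pmatrix} \mu_1 + \begin{pmatrix} 0 \\ b_2 \end{pmatrix} = A_2 b, \tag{47} \\
\Sigma_{1,2} &= \begin{pmatrix} 1 \\ a_2 \end{pmatrix} \Sigma_1 \begin{pmatrix} 1 \\ a_2 \end{pmatrix}^T + \begin{pmatrix} 0 \\ 1 \end{pmatrix} \sigma_2^2 \begin{pmatrix} 0 \\ 1 \end{pmatrix}^T = A_2 \operatorname{diag}(\sigma_1^2, \sigma_2^2) A_2^T, \tag{48}
\end{align}

where we have used $\mu_1 = b_1$ and $\Sigma_1 = \sigma_1^2$, which shows that eq. 39 holds for $d = 2$. Therefore, by the induction principle, relation 39 holds for all $d \ge 2$.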
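Theorem 8 can be sanity-checked numerically on a small example. The sketch below builds the matrices $A_i$, evaluates eq. 39, and compares the result against the standard closed form $z = (I - L)^{-1}(b + n)$, where $L$ is the strictly lower-triangular matrix whose $i$-th row holds $a_i^T$. The coefficient values and the helper names `build_A` and `anm_moments` are illustrative assumptions.

```python
import numpy as np

def build_A(i, a_i, d):
    """A_i from eq. 39 (1-indexed, i >= 2): identity with a_i^T in row i, columns 1..i-1."""
    A = np.eye(d)
    A[i - 1, : i - 1] = a_i
    return A

def anm_moments(a, b, sigma2):
    """Mean and covariance of a linear ANM via eq. 39; a[i] holds a_i for i = 2..d."""
    d = len(b)
    P = np.eye(d)
    for i in range(2, d + 1):
        P = build_A(i, a[i], d) @ P  # accumulates A_i ... A_2
    return P @ b, P @ np.diag(sigma2) @ P.T

# Hypothetical 3-variable ANM: z2 = 2 z1 + n2, z3 = 0.5 z1 - z2 + 2 + n3.
a = {2: [2.0], 3: [0.5, -1.0]}
b = np.array([1.0, 0.0, 2.0])
sigma2 = np.array([1.0, 0.25, 0.5])
mu, Sigma = anm_moments(a, b, sigma2)

# Reference solution: z = (I - L)^{-1} (b + n) with L strictly lower triangular.
L = np.zeros((3, 3))
for i, a_i in a.items():
    L[i - 1, : i - 1] = a_i
M = np.linalg.inv(np.eye(3) - L)
mu_ref, Sigma_ref = M @ b, M @ np.diag(sigma2) @ M.T
```

The agreement reflects the fact that $A_d \cdots A_2 = (I - L)^{-1}$ when $L$ is strictly lower triangular.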

D PROOF OF LINEARISATION OF NON-LINEAR SEMS

Theorem 9. The best linear approximation (in the least-squares sense) of an ANM $z_i = f_i(z_{1,\dots,i-1}) + n_i$ (eq. 5, with missing edges in the causal graph corresponding to ignored inputs in $f_i$), around a pivot point $z^\circ$, is given by eq. 39 (Theorem 8) with

$$a_i = \left.\frac{\partial f_i(z_{1,\dots,i-1})}{\partial z_{1,\dots,i-1}}\right|_{z_{1,\dots,i-1} = z^\circ_{1,\dots,i-1}} \quad \text{and} \quad b_i = f_i(z^\circ_{1,\dots,i-1}) + n_i - a_i^T z^\circ_{1,\dots,i-1}, \tag{49}$$

where $z^\circ_{1,\dots,i-1} = f(z_{1,\dots,i-1}) + n_i$ and $n_i \sim \mathcal{N}(0, \sigma_i^2)$ with learnable $\sigma_i$.

Proof. By Taylor's theorem, expanding $f_i(z_{1,\dots,i-1}) + n_i$ around $z^\circ_{1,\dots,i-1}$ up to first order gives

\begin{align}
f_i(z_{1,\dots,i-1}) + n_i &\approx f_i(z^\circ_{1,\dots,i-1}) + n_i + (z_{1,\dots,i-1} - z^\circ_{1,\dots,i-1})^T \left.\frac{\partial f_i(z_{1,\dots,i-1})}{\partial z_{1,\dots,i-1}}\right|_{z_{1,\dots,i-1} = z^\circ_{1,\dots,i-1}} \tag{50} \\
&= a_i^T z_{1,\dots,i-1} + b_i, \tag{51}
\end{align}

where $a_i$ and $b_i$ are given by eq. 49. We can now use this linearisation of $f_i$ to define a linearised SEM as $z_i = a_i^T z_{1,\dots,i-1} + b_i$, and using this with Theorem 8 the result follows.

Each of the 3 frames is encoded separately into a 4-variate Gaussian distribution, and these are concatenated to form a 12-variate Gaussian, which is compared with its closest match in the SEM-sampled local Gaussian distribution. The distribution is then sampled and split into 4 samples per time frame; these are concatenated with the background latents and decoded back into 3 frames, which are compared with the 3 input frames using an $L_2$ loss. Additionally, the latents from the first two frames are passed through the SEM to predict their values at the next time frame; the prediction is decoded and compared with the third input frame using an $L_2$ loss. The architecture is shown in Figure 9.

G QUANTITATIVE EVALUATION OF LEARNED TEMPORAL GRAPHS

Reconstruction accuracy vs. graph complexity. Figure 10 shows, for the temporal experiments, the reconstruction accuracy achieved with the model as a function of the parameter controlling graph sparsity ($-1/\ln(s)$ in Equation 7) after training for 250 epochs. The plot shows that for weak edge penalisation the reconstruction accuracy is good (around 99.7%), while if the edge penalisation is too large the accuracy drops (to around 99.1%). This is because for weak edge penalisation the graph is relatively dense, which allows the SEM to model the time-based causal relationships between the variables, whereas when the penalisation is too large the graph becomes too sparse to model these relationships. Around a coefficient value of $10^{-2}$ the graph becomes as sparse as possible while still keeping the reconstruction accuracy high, and this is the region from which we select the graph.

Structural Hamming Distance vs. graph complexity. Figure 11 shows, for the temporal experiments, the Structural Hamming Distance (SHD) between the learned and the ground-truth graph as a function of the parameter controlling graph sparsity ($-1/\ln(s)$ in Equation 7) after training for 250 epochs. The SHD is computed by counting how many edges need to be inserted or removed to obtain the ground-truth graph (up to a permutation of variables within each time step). Where the coefficient is below approximately $10^{-3}$, the edge penalisation is too weak, resulting in a graph with too many edges (and thus a high SHD); above approximately $10^{-2}$, the graph becomes too sparse, with no edges (also resulting in a high SHD). The region between $10^{-3}$ and $10^{-2}$ is where the learned graph has exactly the same structure as the ground-truth graph (and thus the SHD is zero).
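The SHD computation described above can be sketched as follows. This is a minimal illustration rather than the evaluation code used for Figure 11: graphs are assumed to be given as binary adjacency matrices with entry $(i, j)$ set when there is an edge $z_i \to z_j$, variables are assumed to be grouped into contiguous blocks of equal size per time step, and the permutation invariance is handled by brute-force search, which is only practical for small blocks. The function names are hypothetical.

```python
from itertools import permutations, product

import numpy as np

def shd(learned, truth):
    """Number of directed edges to insert or remove to turn `learned` into `truth`."""
    return int(np.sum(np.asarray(learned, bool) != np.asarray(truth, bool)))

def shd_up_to_block_permutation(learned, truth, block_size):
    """Minimum SHD over all relabelings of variables within each time-step block."""
    learned = np.asarray(learned)
    n = learned.shape[0]
    blocks = [range(start, start + block_size) for start in range(0, n, block_size)]
    best = None
    for choice in product(*(permutations(block) for block in blocks)):
        perm = [index for block in choice for index in block]
        distance = shd(learned[np.ix_(perm, perm)], truth)
        best = distance if best is None else min(best, distance)
    return best

# Toy example with 2 time steps of 2 variables each.
truth = np.zeros((4, 4), dtype=int)
truth[0, 2] = truth[1, 3] = 1            # z1 -> z3, z2 -> z4
learned = np.zeros((4, 4), dtype=int)
learned[1, 2] = learned[0, 3] = 1        # same structure with z1 and z2 swapped

plain = shd(learned, truth)
aligned = shd_up_to_block_permutation(learned, truth, block_size=2)
```

Under this counting convention a reversed edge contributes 2 (one removal and one insertion).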



Footnote 1. Assuming identity covariance for simplicity, which is a common assumption in implementations and does not materially change the result.

Footnote 2. This is the same reason why the function $h$ and its inverse are inserted twice, instead of inserting two different functions, e.g. $h_1$ and $h_2$, and their inverses: the same generalized rotation $h$ must be applied to both inputs of the Euclidean distance in Eq. 32.



Figure 1: Samples from the observed and prior distributions for different VAEs, showing that the distributions match only for the β = 4 and the causal-prior VAEs. Only the causal-prior VAE recovers the underlying causal structure that relates the latent variables, a parabola.

Figure 2: Latent response to the object's position for different VAEs (higher values are brighter). The β = 4 and β = 1 VAEs have entangled responses (although it is less entangled for β = 1), while the causal-prior VAE correctly disentangles the object's horizontal and vertical positions.

Figure 4: Learned graph edges and mechanisms relating the latent variables, recovering the data generation process up to a scaling factor.

Figure 7: Extrapolation results obtained by iteratively passing the encoded variables through the learned SEM to predict their future values.

Figure 8: Overview of the causal-prior VAE architecture.

$b_{1,\dots,k+1} = \begin{pmatrix} b_{1,\dots,k} \\ b_{k+1} \end{pmatrix}$ inductive step $d = k \to k + 1$. Finally, we apply eq. 41 to relation 40 for d =

learns independent latent variables which are then composed to form causal relationships; however, they only consider linear relationships between variables. Other works take approaches other than VAEs to causal learning, such as CausalGAN (Kocaoglu et al., 2018), which uses generative adversarial networks. Yet another line of work focuses on modelling object dynamics from video, such as Li et al. (2020); however, they use specialised modules for detecting keypoints and for future prediction. Another line of work uses graph neural networks to infer an interaction graph, such as Kipf et al.


We can now introduce our main result. Theorem 7. For a causal VAE model (Eq. 3 and Eq. 5 with y ≡ z), assume the following conditions:

E TRAINING DETAILS FOR THE CAUSAL-PRIOR VAE

We used a custom dataset of 76800 binary images of size 64 × 64 containing ovals generated at 6 different scales $s$, 40 rotations $r$, 32 horizontal positions $x$ and 10 vertical positions $y$, where $y = x^2 + n_y$ and the $s$ and $r$ factors are sampled independently. We trained our causal-prior VAE and two isotropic-prior VAEs for comparison, one with a lower $\beta$ and one with a higher $\beta$ (where $\beta$ denotes the factor multiplying the KL divergence, as defined in Higgins et al. (2016a)). The encoder is a 3-layer MLP with 64 hidden ReLU units in each layer; the decoder is identical but with 4 layers, and there are 5 latent variables $z$. We use the Adagrad optimizer (Duchi et al., 2011) with learning rate 0.003 on batches of 100 samples, until convergence. The architecture is shown in Figure 8.
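For the parabolic mechanism underlying this dataset, the local linearisation of Theorem 9 can be sketched as below. This is an illustration under stated assumptions: the Jacobian is estimated with central finite differences as a simple stand-in for back-propagation, the noise term $n_i$ is omitted so that only the deterministic part of $b_i$ is computed, and the function name `linearise` and the pivot value are hypothetical.

```python
import numpy as np

def linearise(f, pivot, eps=1e-6):
    """Return (a, b) such that f(z) is approximated by a @ z + b near `pivot`.

    `a` is a central finite-difference estimate of the gradient of f at the pivot,
    and `b` = f(pivot) - a @ pivot is the noise-free part of b_i in eq. 49.
    """
    pivot = np.asarray(pivot, dtype=float)
    a = np.zeros_like(pivot)
    for k in range(pivot.size):
        step = np.zeros_like(pivot)
        step[k] = eps
        a[k] = (f(pivot + step) - f(pivot - step)) / (2.0 * eps)
    b = f(pivot) - a @ pivot
    return a, b

# Quadratic mechanism y = x^2 with a single parent x, linearised around x = 1.5.
parabola = lambda z: z[0] ** 2
pivot = np.array([1.5])
a, b = linearise(parabola, pivot)
tangent_at_pivot = a @ pivot + b
```

For this mechanism the linearisation is the tangent line $y \approx 3x - 2.25$, which coincides with the parabola at the pivot; in a deep-learning framework the same gradient would typically be obtained by automatic differentiation.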

F TRAINING DETAILS FOR THE TEMPORAL CAUSAL-PRIOR VAE

We created a custom dataset of 3-frame 20 × 20 px video sequences of the main Super Mario Bros character moving linearly in a random direction and with a random speed on different backgrounds, where the character's positions are given by $x_1, y_1, x_2, y_2 \sim \mathcal{U}(7, 12)$, $x_3 = 2x_2 - x_1$, $y_3 = 2y_2 - y_1$, where $x_i$ and $y_i$ are the horizontal and vertical positions at frame $i$ and $\mathcal{U}$ is the uniform distribution. We train our causal-prior VAE with 12 variables, 4 per time frame, and allow the SEM to learn arbitrary linear relationships between them. The architecture is shown in Figure 9. Each dataset sample is a tuple consisting of a background and 3 consecutive frames in which the character moves linearly. Each of the 3 frames is then encoded separately into a 4-variate Gaussian distribution.
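The position sampling described above can be sketched as follows. Only the trajectory generation is shown; rendering the character and backgrounds is not covered, and the function name `sample_positions`, the array layout and the random seed are assumptions made for illustration.

```python
import numpy as np

def sample_positions(n, rng):
    """Sample n trajectories of shape (3 frames, 2 coordinates).

    Frames 1 and 2 are drawn from U(7, 12); frame 3 follows the linear rule
    p_3 = 2 * p_2 - p_1 for both the x and y coordinates.
    """
    first_two = rng.uniform(7.0, 12.0, size=(n, 2, 2))   # (sample, frame, coordinate)
    third = 2.0 * first_two[:, 1] - first_two[:, 0]
    return np.concatenate([first_two, third[:, None, :]], axis=1)

rng = np.random.default_rng(0)
positions = sample_positions(1000, rng)
```

With this sampling range the third-frame coordinates always lie between $2 \cdot 7 - 12 = 2$ and $2 \cdot 12 - 7 = 17$, so the extrapolated position stays inside the 20 × 20 px frame.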

