GROUP-DISENTANGLING CONDITIONAL SHIFT

Abstract

We propose a novel group disentanglement method called the Context-Aware Variational Autoencoder (CxVAE). Our model can learn disentangled representations on datasets with conditional shift. This phenomenon occurs when the distribution of the instance-level latent variable z conditional on the input observation x, p(z|x), changes from one group to another (i.e. p_i(z|x) ≠ p_j(z|x), where i, j are two different groups). We show that existing methods fail to learn disentangled representations under this scenario because they infer the group representation u and the instance representation z separately. CxVAE overcomes this limitation by conditioning the instance inference on the group variable, q(z|x, u). Our model has the novel ability to disentangle ambiguous observations (those with incomplete information about the generative factors), which we evaluate on an image dataset. Additionally, we use a fair-comparison task to demonstrate empirically that conditional shift is the cause of our model's improved performance.

1. INTRODUCTION

Group disentanglement is the goal of learning representations that separate group-level variation from instance-level variation. Consider a dataset of observations organised into N groups of the form x_{n,1:K_n} = {x_{n,1}, ..., x_{n,K_n}}, n ∈ 1:N. These could be pictures grouped by author, clinical outcomes grouped by patient, or film ratings grouped by user. We train a representation network r(x_{n,1:K_n}) that encodes a group of observations {x_{n,1}, ..., x_{n,K_n}} into one group code u_n and a set of instance codes {z_{n,1}, ..., z_{n,K_n}}, one for each observation. We want u to capture only the variation across groups and z only the variation within groups.

The current state-of-the-art approaches to group disentanglement train the representation network r by using it as the variational latent posterior distribution in a Variational Autoencoder (Bouchacourt et al., 2018; Hosoya, 2019; Németh, 2020). They assume a hierarchical generative model whereby the observation x_{n,k} is generated by combining a group latent variable u_n and an independent instance latent variable z_{n,k} (Figure 2, left). The standard setup involves training the variational latent posterior q(u_n, z_{n,1:K_n} | x_{n,1:K_n}) by maximising a lower bound on the data likelihood (Kingma & Welling, 2014; Rezende et al., 2014).

In our work, we show that the variational latent posterior, as defined in existing models, is unsuited to datasets with conditional shift. This is a property of the data-generating process whereby the true conditional distribution of the instance latent variable z changes from one group to another, p_i(z|x) ≠ p_j(z|x), where i, j are two groups (Zhang et al., 2013; Gong et al., 2016). In our case, the conditional instance distribution for group i, p_i(z|x), corresponds to p(z_{i,k} | x_{i,k}, u_i), where k indexes an instance in the group. Conditional shift occurs in many real-world datasets that we would like to group-disentangle.
For example, in the 3DIdent dataset (Zimmermann et al., 2021), if we want to infer the colour of the teapot z_{n,k} from an image of that teapot x_{n,k}, we should take into account the colour of the spotlight that illuminates the scene, u_n; spotlights of different colours make the same object appear in different colours, as can be seen in Figure 1. Existing group-disentanglement methods, which infer the instance variable (teapot colour) independently of the group variable (spotlight colour), fail to disentangle the two colours.

Existing VAE-based methods fail to disentangle in the conditional-shift setting because they assume that the group and instance variables can be inferred independently of each other from the input observation (Figure 2, middle): When defining the variational latent posterior, existing works

