GROUP-DISENTANGLING CONDITIONAL SHIFT

Abstract

We propose a novel group disentanglement method called the Context-Aware Variational Autoencoder (CxVAE). Our model can learn disentangled representations on datasets with conditional shift. This phenomenon occurs when the distribution of the instance-level latent variable z conditional on the input observation x, p(z|x), changes from one group to another (i.e. p_i(z|x) ≠ p_j(z|x), where i, j are two different groups). We show that existing methods fail to learn disentangled representations under this scenario because they infer the group representation u and the instance representation z separately. CxVAE overcomes this limitation by conditioning the instance inference on the group variable, q(z|x, u). Our model has the novel ability to disentangle ambiguous observations (those with incomplete information about the generative factors), which we evaluate on an image dataset. Additionally, we use a fair-comparisons task to demonstrate empirically that conditional shift is the cause of our model's improved performance.

1. INTRODUCTION

Group disentanglement is the task of learning representations that separate group-level variation from instance-level variation. Consider a dataset of observations organised into N groups of the form x_{n,1:K_n} = {x_{n,1}, ..., x_{n,K_n}}, n ∈ 1:N. These could be pictures grouped by author, clinical outcomes grouped by patient, or film ratings grouped by user. We train a representation network r(x_n) that encodes a group of observations {x_{n,1}, ..., x_{n,K_n}} into one group code u_n and a set of instance codes {z_{n,1}, ..., z_{n,K_n}}, one for each observation. We want u to capture only the variation across groups and z only the variation within groups. The current state-of-the-art approaches for group disentanglement train the representation network r by using it as the variational latent posterior distribution in a Variational Autoencoder (Bouchacourt et al., 2018; Hosoya, 2019; Németh, 2020). They assume a hierarchical generative model whereby the observation x_{n,k} is generated by combining a group latent variable u_n and an independent instance latent variable z_{n,k} (Figure 2, left). The standard setup involves training the variational latent posterior q(u_n, z_{n,1:K_n} | x_n) by maximising a lower bound on the data likelihood (Kingma & Welling, 2014; Rezende et al., 2014).

In our work, we show that the variational latent posterior, as defined in existing models, is unsuited to datasets with conditional shift. Conditional shift is a property of the data-generating process whereby the true conditional distribution of the instance latent variable z changes from one group to another: p_i(z|x) ≠ p_j(z|x), where i, j are two groups (Zhang et al., 2013; Gong et al., 2016). In our case, the conditional instance distribution for group i, p_i(z|x), corresponds to p(z_{i,k} | x_{i,k}, u_i), where k indexes an instance in the group. Conditional shift occurs in many real-world datasets that we would like to group-disentangle.
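A toy numerical sketch can make the definition concrete. Assume (purely for illustration; this is not the paper's actual generative process) an additive model in which the observed colour x is the sum of the instance colour z and the group-level spotlight colour u. Then the same observation x implies a different z in each group, i.e. p_i(z|x) ≠ p_j(z|x):

```python
# Illustrative additive model (an assumption for this sketch, not the
# paper's generative process): observed colour x = object colour z
# plus group-level spotlight colour u.

# Two groups with different spotlight colours.
u_i, u_j = 0.3, -0.4

# The same observed value x is produced by different object colours
# depending on the group, since z = x - u within each group.
x = 0.5
z_given_x_in_i = x - u_i  # object colour implied by x in group i
z_given_x_in_j = x - u_j  # object colour implied by x in group j

# p_i(z|x) and p_j(z|x) concentrate on different values of z:
# this is conditional shift.
assert z_given_x_in_i != z_given_x_in_j
```

The point of the sketch is that no encoder of the form q(z|x) can output both answers at once; the group variable u is needed to resolve the ambiguity.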
For example, in the 3DIdent dataset (Zimmermann et al., 2021), if we want to infer the colour of the teapot z_{n,k} from an image of that teapot x_{n,k}, we should take into account the colour of the spotlight that illuminates the scene u_n; spotlights of different colours make the same object appear to have different colours, as can be seen in Figure 1. Existing group-disentanglement methods, which infer the instance variable (teapot colour) independently of the group variable (spotlight colour), fail to disentangle the two colours. They fail in the conditional shift setting because, when defining the variational latent posterior, existing works based on the Group VAE (GVAE) (Bouchacourt et al., 2018; Hosoya, 2019; Németh, 2020; Chen & Batmanghelich, 2020) assume that the group and instance variables are conditionally independent given the observations (Figure 2, middle):

q(u_n, z_{n,1:K_n} | x_{n,1:K_n}) = q(u_n | x_{n,1:K_n}) ∏_{k=1}^{K_n} q(z_{n,k} | x_{n,k}). (1)

The limitations of this assumption have not been identified so far in the literature because the datasets used to test disentanglement, such as Shapes3D (Kim & Mnih, 2018), SmallNORB (LeCun et al., 2004), dSprites (Higgins et al., 2017), Cars3D (Reed et al., 2015), and MPI3D (Gondal et al., 2019), have the property that a single image is always sufficient to accurately infer its latent variables. For example, we only require a single image from the MPI3D dataset to uniquely identify the colour, position, and rotation of the depicted object. In this work, we show that conditioning the instance encoder on the group latent vector enables the model to learn disentangled representations on datasets with conditional shift.

1. First, we show that only our method correctly disentangles object colour from spotlight colour in the 3DIdent dataset (Zimmermann et al., 2021), illustrated in Figure 1.
2. Second, we use the task of fair comparisons between student test scores (Figure 3) to show that the amount of conditional shift in the dataset determines the performance gap between our model and the other approaches.
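The gap between the factorised posterior of Equation 1 and a group-conditioned posterior q(z|x, u) can be sketched numerically. The snippet below is a deliberately minimal stand-in for the amortised encoder networks: under a hypothetical additive model x_{n,k} = u_n + z_{n,k}, it fits two least-squares "encoders" for z, one using x alone (as in Equation 1) and one using both x and u:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic grouped data under a toy additive model (an illustrative
# assumption): x_{n,k} = u_n + z_{n,k}.
N, K = 200, 10
u = rng.normal(size=(N, 1))   # group factor (e.g. spotlight colour)
z = rng.normal(size=(N, K))   # instance factor (e.g. object colour)
x = u + z                     # observations

# Factorised "encoder" q(z|x): best linear predictor of z from x alone.
X = x.ravel()
Z = z.ravel()
a = (X @ Z) / (X @ X)                      # least-squares slope
err_factorised = np.mean((a * X - Z) ** 2)

# Group-conditioned "encoder" q(z|x, u): predict z from both x and u_n.
U = np.repeat(u, K, axis=1).ravel()
A = np.stack([X, U], axis=1)
coef, *_ = np.linalg.lstsq(A, Z, rcond=None)
err_conditional = np.mean((A @ coef - Z) ** 2)

# Conditioning on the group variable removes the ambiguity left by x
# alone: the conditional encoder recovers z almost exactly (z = x - u),
# while the factorised one cannot.
assert err_conditional < err_factorised
```

Even with unlimited data, the x-only encoder is stuck with irreducible error (here roughly half the variance of z), because the same x corresponds to different z in different groups; the group-conditioned encoder recovers z essentially exactly. The real models replace these linear maps with neural networks, but the underlying limitation of Equation 1 is the same.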

2. RELATED WORK

Group Disentanglement. This class of problems comes under different names: style-content disentanglement (Tenenbaum & Freeman, 2000), content-transformation disentanglement (Hosoya, 2019), and disentanglement with group supervision (Shu et al., 2020), to name a few. Recent work (Shu et al., 2020; Locatello et al., 2020) has contextualised group disentanglement as a subproblem of weakly-supervised disentanglement, where disentangled representations are learned with the help of non-datapoint supervision (e.g. grouping, ranking, restricted labelling). Early work in this area focused on separating visual concepts (Kulkarni et al., 2015; Reed et al., 2015). The area has received renewed interest after the theoretical impossibility result of Locatello et al. (2019) and the identifiability proofs of Khemakhem et al. (2020) and Mita et al. (2021). A key aspect of recent weakly-supervised models is the interpretation of the grouping as a signal of similarity between datapoints (Chen & Batmanghelich, 2020).



Figure 1: Conditional shift in the 3DIdent dataset (Zimmermann et al., 2021). The two images in the leftmost column appear to show objects of the same colour, even though each image was generated by a different combination of object colour and spotlight colour: one teapot is actually orange, while the other is bright green. We can see this by looking at other examples of the same objects under different lighting. The set of images to the right depicts different views of the same object on each corresponding row.

