PROTOVAE: USING PROTOTYPICAL NETWORKS FOR UNSUPERVISED DISENTANGLEMENT

Abstract

Generative modeling and self-supervised learning have in recent years made great strides towards learning from data in a completely unsupervised way. However, guiding a neural network to encode the data into representations that are interpretable or explainable remains an open area of investigation. The problem of unsupervised disentanglement is of particular importance, as it proposes to discover the different latent factors of variation, or semantic concepts, from the data alone, without labeled examples, and to encode them into structurally disjoint latent representations. Without additional constraints or inductive biases placed on the network, a generative model may learn the data distribution and encode the factors, but not necessarily in a disentangled way. Here, we introduce a novel deep generative VAE-based model, ProtoVAE, that leverages a deep metric learning prototypical network trained using self-supervision to impose these constraints. The prototypical network constrains the mapping from the representation space to the data space to ensure that controlled changes in the representation space map to changes in the factors of variation in the data space. Our model is completely unsupervised and requires no a priori knowledge of the dataset, including the number of factors. We evaluate our proposed model on the benchmark dSprites, 3DShapes, and MPI3D disentanglement datasets, showing state-of-the-art results against previous methods via qualitative traversals in the latent space as well as quantitative disentanglement metrics. We further qualitatively demonstrate the effectiveness of our model on the real-world CelebA dataset.

1. INTRODUCTION

One theory of the success of deep learning models for supervised learning revolves around their ability to learn mappings from the input space to a lower-dimensional abstract representation space that is most predictive of the corresponding labels (Tishby & Zaslavsky, 2015). However, for the models to be robust to noise and adversarial examples, transferable to different domains and distributions, and interpretable, we need to impose additional constraints on the learning paradigm. As a promising solution, models can be encouraged to focus on all the latent "distinctive properties" of the data distribution and encode them into a representation for downstream supervised tasks. These latent distinctive properties, or factors of variation, are the interpretable abstract concepts that describe the data. The intuitive notion of disentanglement, first proposed in Bengio (2013), is to discover all the different factors of variation from the data and encode each factor in a separate subspace or dimension of the learned latent representation. These disentangled representations are not only interpretable and give valuable insights into the data distribution, but are also more robust for multiple downstream tasks (Bengio, 2013; Schoelkopf et al., 2012), which might depend only on a subset of factors (Suter et al., 2019). The problem of learning these disentangled representations in a completely unsupervised way is particularly challenging, as we have access neither to ground-truth factor labels nor to the true number of factors or their nature. Recent works have proposed to solve this problem by training generative networks to effectively model the data distribution and, in turn, the factors of variation.
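To make this notion concrete, the following toy sketch (our own illustration, not part of the paper's method) generates data from two ground-truth factors and compares a "disentangled" code, which keeps each factor in its own dimension, against an "entangled" code that mixes them through a rotation. The per-dimension factor correlations make the distinction visible:

```python
import numpy as np

rng = np.random.default_rng(1)

# Two ground-truth factors of variation (e.g. position and scale).
factors = rng.uniform(size=(1000, 2))

# A disentangled code keeps each factor in a separate dimension
# (here, up to an axis-aligned scaling); an entangled code mixes them.
z_disent = factors * np.array([2.0, -1.0])
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
z_entangled = factors @ R.T

def per_dim_corr(z, f):
    """|correlation| of each latent dimension with each ground-truth factor."""
    c = np.corrcoef(np.hstack([z, f]).T)
    return np.abs(c[:2, 2:])

print(per_dim_corr(z_disent, factors).round(2))     # near-identity: one factor per dim
print(per_dim_corr(z_entangled, factors).round(2))  # large off-diagonals: factors mixed
```

A near-diagonal correlation matrix is exactly what a change-in-one-factor, change-in-one-dimension correspondence looks like statistically; unsupervised methods must achieve this without ever observing `factors`.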
From this generative perspective of disentanglement, higher-dimensional data is assumed to be a non-linear mapping of these factors of variation, where each factor assumes different values to generate specific examples in the data distribution. Locatello et al. (2019) intuitively characterize representations which encode the factors as disentangled if a change in a single underlying factor of variation in the data produces a change in a single factor of the learned representation (or a change in the subspace of the representation that encodes that factor). Conversely, from the generative perspective, for a representation to be disentangled, a change in a single subspace of the learned representation, when mapped to the data space, must produce a change in a single factor of variation. For this generative mapping between changes in the representation space and changes in the factors of variation (in the data space) to be injective, we propose constraints on the changes in the factors of variation for pre-determined changes in the representation space. Each separate subspace of the representation, when changed, must map to a change in a unique factor of variation, which in turn encourages information about the different factors to be encoded in separate subspaces of the representation. Moreover, each separate subspace must consistently map to a change in a single factor throughout the subspace's range. This encourages the different subspaces of the representation to encode information only about a single factor of variation. The recent work of Horan et al. (2021) also demonstrated empirically that local isometry is a good inductive bias for unsupervised disentanglement and can aid generative models in discovering a "natural" decomposition of data into factors of variation. This local isometry constraint enforces that changes in the data space are proportional to any changes made in the representation space.
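The local-isometry bias can be made concrete with a minimal numerical sketch (ours, under toy assumptions, not the paper's implementation): perturb one latent dimension at a time and compare the size of the induced change in data space to the size of the latent change. The decoders below are linear stand-ins; a map with orthonormal columns is exactly isometric, while a generic linear map is not:

```python
import numpy as np

rng = np.random.default_rng(0)

def isometry_gap(decode, z, dim, eps=0.1, scale=1.0):
    """Local-isometry check for one latent dimension: perturb dimension
    `dim` of code z by eps and compare the induced change in data space
    to the latent change. A locally isometric mapping gives a gap of zero."""
    delta = np.zeros_like(z)
    delta[dim] = eps
    dx = np.linalg.norm(decode(z + delta) - decode(z))
    dz = np.linalg.norm(delta)
    return (dx - scale * dz) ** 2

# Toy "decoders" (hypothetical stand-ins for a VAE generator):
# orthonormal columns => exactly isometric; a generic matrix is not.
Q, _ = np.linalg.qr(rng.normal(size=(8, 4)))
decode_iso = lambda z: Q @ z
decode_generic = lambda z, A=rng.normal(size=(8, 4)): A @ z

z = rng.normal(size=4)
gaps_iso = [isometry_gap(decode_iso, z, k) for k in range(4)]
gaps_gen = [isometry_gap(decode_generic, z, k) for k in range(4)]
print(max(gaps_iso))  # ~0: isometric map
print(max(gaps_gen))  # > 0: proportionality violated
```

In a trained model such a gap would be a differentiable penalty on the decoder rather than a post-hoc check, but the quantity being constrained is the same.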
In order to effectively impose the above constraints in an unsupervised manner, we turn towards deep metric learning. In recent years, metric learning has emerged as a powerful unsupervised learning paradigm for deep neural networks, in conjunction with self-supervised data augmentation. One of the more successful metric learning models, the prototypical network, projects the data into a new metric space where examples from the same class cluster around a prototype representation of the class and away from the prototypes of other classes. We use this ability of the network to cluster the different changes in the data space produced by the corresponding changes in the representation space, and thereby enforce the constraints described above. We develop a novel deep generative model, ProtoVAE, consisting of a prototypical network and a variational autoencoder (VAE). The VAE acts as the generative component, while the prototypical network guides the VAE in separating out the representation space by imposing the constraints for disentanglement. Since the prototypical network needs labeled data for clustering, we train it on generated self-supervised datasets to learn these representations in an unsupervised way. To produce the self-supervised dataset, we perform interventions in the representation space, which change individual elements of the latent representation, and map the intervened representations to the data space. Owing to this self-supervised training, our model is able to disentangle without any explicit prior knowledge of the data, including the number of desired factors. In this work, our core contributions are:

• We design a self-supervised data generation mechanism using a VAE that creates new samples via a process of intervention to train a metric-learning prototypical network.
• We design and implement a novel model, ProtoVAE, which combines a VAE and a prototypical network to perform disentanglement without any prior knowledge of the underlying data.
• We empirically evaluate ProtoVAE on the standard benchmark dSprites, 3DShapes, MPI3D, and CelebA datasets, showing state-of-the-art results.

The VAE consists of an inference network, which encodes the data into lower-dimensional latent representations, and a generator network, which maps the representations back into the data space. To implicitly encourage the inference network to encode disentangled representations, we impose constraints on the generative mapping from changes in the representation space to changes in the factors of variation in the data space. This generative mapping is determined by both the generator and the inference networks. To generate self-supervised data for the prototypical network, we perform interventions (Sec 2.2), which change individual dimensions of

