SUFFICIENT AND DISENTANGLED REPRESENTATION LEARNING

Abstract

We propose a novel representation learning approach called sufficient and disentangled representation learning (SDRL). With SDRL, we seek a data representation that maps the input data to a lower-dimensional space with two properties: sufficiency and disentanglement. First, the representation is sufficient in the sense that the original input data is conditionally independent of the response or label given the representation. Second, the representation is maximally disentangled with mutually independent components and is rotation invariant in distribution. We show that such a representation always exists under mild conditions on the input data distribution based on optimal transport theory. We formulate an objective function characterizing conditional independence and disentanglement. This objective function is then used to train a sufficient and disentangled representation with deep neural networks. We provide strong statistical guarantees for the learned representation by establishing an upper bound on the excess error of the objective function and show that it reaches the nonparametric minimax rate under mild conditions. We also validate the proposed method via numerical experiments and real data analysis.

1. INTRODUCTION

Representation learning is a fundamental problem in machine learning and artificial intelligence (Bengio et al., 2013). Certain deep neural networks are capable of learning effective data representations automatically and achieve impressive prediction results. For example, convolutional neural networks, which encode the basic characteristics of visual observations directly into the network architecture, are able to learn effective representations of image data (LeCun et al., 1989). Such representations can in turn be used for constructing classifiers with outstanding performance. Convolutional neural networks learn data representations with a simple structure that captures the essential information through the convolution operator. However, in other application domains, optimizing the standard cross-entropy or least squares loss function does not guarantee that the learned representations enjoy any desired properties (Alain & Bengio, 2016). Therefore, it is imperative to develop general principles and approaches for constructing effective representations for supervised learning. There is a growing literature on representation learning in the context of deep neural network modeling. Several authors studied the internal mechanism of supervised deep learning from the perspective of information theory (Tishby & Zaslavsky, 2015; Shwartz-Ziv & Tishby, 2017; Saxe et al., 2019), where they showed that training a deep neural network that optimizes the information bottleneck (Tishby et al., 2000) is a trade-off between the representation and prediction at each layer. To make the information bottleneck idea more practical, a deep variational approximation of the information bottleneck (VIB) is considered in Alemi et al. (2016).
Information-theoretic objectives describing conditional independence, such as mutual information, are utilized as loss functions to train a representation-learning function, i.e., an encoder in the unsupervised setting (Hjelm et al., 2018; Oord et al., 2018; Tschannen et al., 2019; Locatello et al., 2019; Srinivas et al., 2020). There are several interesting extensions of the variational autoencoder (VAE) (Kingma & Welling, 2013) in the form of VAE plus a regularizer, including beta-VAE (Higgins et al., 2017), Annealed-VAE (Burgess et al., 2018), factor-VAE (Kim & Mnih, 2018), beta-TC-VAE (Chen et al., 2018), and DIP-VAE (Kumar et al., 2018). The idea of using a latent variable model has also been used in adversarial auto-encoders (AAE) (Makhzani et al., 2016) and Wasserstein auto-encoders (WAE) (Tolstikhin et al., 2018). However, these existing works focus on unsupervised representation learning. A challenge of supervised representation learning that distinguishes it from standard supervised learning is the difficulty in formulating a clear and simple objective function. In classification, the objective is clear: to minimize the number of misclassifications; in regression, a least squares criterion for model fitting error is usually used. In representation learning, the objective differs from the ultimate objective, which is typically learning a classifier or a regression function for prediction. How to establish a simple criterion for supervised representation learning has remained an open question (Bengio et al., 2013).

We propose a sufficient and disentangled representation learning (SDRL) approach in the context of supervised learning. With SDRL, we seek a data representation with two characteristics: sufficiency and disentanglement. In the context of representation learning, sufficiency means that a good representation should preserve all the information in the data about the supervised learning task. This is a basic requirement and a long-standing principle in statistics, closely related to the fundamental concept of sufficient statistics in parametric statistical models (Fisher, 1922).
A sufficient representation can be naturally characterized by the conditional independence principle, which stipulates that, given the representation, the original input data does not contain any additional information about the response variable. In addition to the basic sufficiency property, the representation should have a simple statistical structure. Disentangling is based on the general notion that some latent causes underlie the data generation process: although the observed data are typically high-dimensional, complex, and noisy, the underlying factors are low-dimensional, independent, and have a relatively simple statistical structure. There is a range of definitions of disentangling (Higgins et al., 2018; Eastwood & Williams, 2018; Ridgeway & Mozer, 2018; Do & Tran, 2020). Several metrics have been proposed for the evaluation of disentangling. However, none of these definitions and metrics have been turned into empirical criteria and algorithms for learning disentangled representations. We adopt a simple definition of disentangling, which defines a representation to be disentangled if its components are independent (Achille & Soatto, 2018). This definition requires the representation to be maximally disentangled in the sense that the total correlation is zero, where the total correlation is defined as the KL divergence between the joint distribution of g(x) and the product of the marginal distributions of its components (Watanabe, 1960). In the rest of the paper, we first discuss the motivation and the theoretical framework for learning a sufficient and disentangled representation map (SDRM). This framework leads to the formulation of an objective function based on the conditional independence principle and a metric for disentanglement and invariance adopted in this work. We estimate the target SDRM based on the sample version of the objective function using deep neural networks and develop an efficient algorithm for training the SDRM.
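To make the total correlation concrete, the following sketch evaluates it in the one case where it has a simple closed form: a multivariate Gaussian, for which the total correlation equals the sum of the marginal entropies minus the joint entropy. The function name `gaussian_total_correlation` is our own illustration, not part of the proposed method.

```python
import numpy as np

def gaussian_total_correlation(cov):
    """Total correlation of a zero-mean Gaussian with covariance `cov`.

    For a Gaussian, TC = sum_i H(z_i) - H(z)
                       = 0.5 * (sum_i log cov_ii - log det cov).
    It is zero iff `cov` is diagonal, i.e. the components are
    independent (maximally disentangled in the sense used here).
    """
    cov = np.asarray(cov, dtype=float)
    sign, logdet = np.linalg.slogdet(cov)
    assert sign > 0, "covariance must be positive definite"
    return 0.5 * (np.sum(np.log(np.diag(cov))) - logdet)

# Diagonal covariance: independent components, TC ≈ 0 (up to floating point).
print(gaussian_total_correlation(np.diag([1.0, 2.0, 3.0])))

# Correlated components: TC > 0.
print(gaussian_total_correlation(np.array([[1.0, 0.8], [0.8, 1.0]])))
```

For non-Gaussian representations the total correlation has no closed form, which is precisely why the paper formulates an objective function that can be estimated from samples.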
We establish an upper error bound on the measure of conditional independence and disentanglement and show that it reaches the nonparametric minimax rate under mild regularity conditions. This result provides strong statistical guarantees for the proposed method. We validate the proposed SDRL via numerical experiments and real data examples.

2. SUFFICIENT AND DISENTANGLED REPRESENTATION

Consider a pair of random vectors (x, y) ∈ R^p × R^q, where x is a vector of input variables and y is a vector of response variables or labels. Our goal is to find a sufficient and disentangled representation of x.

Sufficiency. We say that a measurable map g : R^p → R^d with d ≤ p is a sufficient representation of x if

y ⊥ x | g(x),     (1)

that is, y and x are conditionally independent given g(x). This condition holds if and only if the conditional distribution of y given x and that of y given g(x) are equal. Therefore, the information in x about y is completely encoded by g(x). Such a g always exists, since if we simply take g(x) = x, then (1) holds trivially. This formulation is a nonparametric generalization of the basic condition in sufficient dimension reduction (Li, 1991; Cook, 1998), where it is assumed that g(x) = B^T x with B ∈ R^{p×d} belonging to the Stiefel manifold, i.e., B^T B = I_d.
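The linear special case of (1) can be illustrated with a short sketch. We draw a matrix B on the Stiefel manifold via a QR decomposition and generate a toy model in which y depends on x only through B^T x, so that g(x) = B^T x is a sufficient representation by construction. The dimensions and the single-index link function are illustrative choices, not part of the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(0)
p, d, n = 10, 3, 1000

# A point on the Stiefel manifold: B in R^{p x d} with B^T B = I_d,
# obtained from the QR decomposition of a random Gaussian matrix.
B, _ = np.linalg.qr(rng.standard_normal((p, d)))
assert np.allclose(B.T @ B, np.eye(d))

# Toy model: y depends on x only through z = B^T x, so the linear map
# g(x) = B^T x satisfies the sufficiency condition (1) by construction.
x = rng.standard_normal((n, p))
z = x @ B                                    # representation g(x), shape (n, d)
y = np.sin(z[:, 0]) + 0.1 * rng.standard_normal(n)
```

SDRL replaces the linear map B^T x with a neural network g, which is why the sufficiency condition must be characterized nonparametrically rather than through a basis matrix.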




