SUFFICIENT AND DISENTANGLED REPRESENTATION LEARNING

Abstract

We propose a novel representation learning approach called sufficient and disentangled representation learning (SDRL). With SDRL, we seek a data representation that maps the input data to a lower-dimensional space with two properties: sufficiency and disentanglement. First, the representation is sufficient in the sense that the original input data is conditionally independent of the response or label given the representation. Second, the representation is maximally disentangled with mutually independent components and is rotation invariant in distribution. We show that such a representation always exists under mild conditions on the input data distribution based on optimal transport theory. We formulate an objective function characterizing conditional independence and disentanglement. This objective function is then used to train a sufficient and disentangled representation with deep neural networks. We provide strong statistical guarantees for the learned representation by establishing an upper bound on the excess error of the objective function and show that it reaches the nonparametric minimax rate under mild conditions. We also validate the proposed method via numerical experiments and real data analysis.

1. INTRODUCTION

Representation learning is a fundamental problem in machine learning and artificial intelligence (Bengio et al., 2013). Certain deep neural networks are capable of learning effective data representations automatically and achieve impressive prediction results. For example, convolutional neural networks, which encode basic characteristics of visual observations directly into the network architecture, are able to learn effective representations of image data (LeCun et al., 1989). Such representations can subsequently be used for constructing classifiers with outstanding performance. Convolutional neural networks learn a data representation with a simple structure that captures the essential information through the convolution operator. In other application domains, however, optimizing the standard cross-entropy or least squares loss does not guarantee that the learned representations enjoy any desired properties (Alain & Bengio, 2016). Therefore, it is imperative to develop general principles and approaches for constructing effective representations for supervised learning.

There is a growing literature on representation learning in the context of deep neural network modeling. Several authors studied the internal mechanism of supervised deep learning from the perspective of information theory (Tishby & Zaslavsky, 2015; Shwartz-Ziv & Tishby, 2017; Saxe et al., 2019), showing that training a deep neural network that optimizes the information bottleneck (Tishby et al., 2000) involves a trade-off between representation and prediction at each layer. To make the information bottleneck idea more practical, a deep variational approximation of the information bottleneck (VIB) was considered in Alemi et al. (2016).
Information-theoretic objectives describing conditional independence, such as mutual information, have been utilized as loss functions to train a representation-learning function, i.e., an encoder, in the unsupervised setting (Hjelm et al., 2018; Oord et al., 2018; Tschannen et al., 2019; Locatello et al., 2019; Srinivas et al., 2020). There are several interesting extensions of the variational autoencoder (VAE) (Kingma & Welling, 2013) in the form of a VAE plus a regularizer, including beta-VAE (Higgins et al., 2017), Annealed-VAE (Burgess et al., 2018), factor-VAE (Kim & Mnih, 2018), beta-TC-VAE (Chen et al., 2018), and DIP-VAE (Kumar et al., 2018). The idea of using a latent variable model has also been used in adversarial auto-encoders (AAE) (Makhzani et al., 2016) and Wasserstein auto-encoders (WAE) (Tolstikhin et al., 2018). However, these existing works focus on unsupervised representation learning.

A challenge of supervised representation learning that distinguishes it from standard supervised learning is the difficulty in formulating a clear and simple objective function. In classification, the objective is clear: to minimize the number of misclassifications; in regression, a least squares criterion for the model fitting error is usually used. In representation learning, the objective differs from the ultimate goal, which is typically learning a classifier or a regression function for prediction. How to establish a simple criterion for supervised representation learning has remained an open question (Bengio et al., 2013). We propose a sufficient and disentangled representation learning (SDRL) approach in the context of supervised learning. With SDRL, we seek a data representation with two characteristics: sufficiency and disentanglement. In the context of representation learning, sufficiency means that a good representation should preserve all the information in the data about the supervised learning task.
This is a basic requirement and a long-standing principle in statistics, closely related to the fundamental concept of sufficient statistics in parametric statistical models (Fisher, 1922). A sufficient representation can be naturally characterized by the conditional independence principle, which stipulates that, given the representation, the original input data contain no additional information about the response variable. In addition to the basic sufficiency property, the representation should have a simple statistical structure. Disentangling is based on the general notion that some latent causes underlie the data generation process: although the observed data are typically high-dimensional, complex and noisy, the underlying factors are low-dimensional, independent and have a relatively simple statistical structure. There is a range of definitions of disentangling (Higgins et al., 2018; Eastwood & Williams, 2018; Ridgeway & Mozer, 2018; Do & Tran, 2020), and several metrics have been proposed for evaluating disentanglement. However, none of these definitions and metrics have been turned into empirical criteria and algorithms for learning disentangled representations. We adopt a simple definition of disentangling, which declares a representation disentangled if its components are independent (Achille & Soatto, 2018). This definition requires the representation to be maximally disentangled in the sense that the total correlation is zero, where the total correlation is defined as the KL divergence between the joint distribution of the representation and the product of the marginal distributions of its components (Watanabe, 1960). In the rest of the paper, we first discuss the motivation and the theoretical framework for learning a sufficient and disentangled representation map (SDRM).
This framework leads to the formulation of an objective function based on the conditional independence principle and a metric for disentanglement and invariance adopted in this work. We estimate the target SDRM based on the sample version of the objective function using deep neural networks and develop an efficient algorithm for training the SDRM. We establish an upper error bound on the measure of conditional independence and disentanglement and show that it reaches the nonparametric minimax rate under mild regularity conditions. This result provides strong statistical guarantees for the proposed method. We validate the proposed SDRL via numerical experiments and real data examples.

2. SUFFICIENT AND DISENTANGLED REPRESENTATION

Consider a pair of random vectors (x, y) ∈ R^p × R^q, where x is a vector of input variables and y is a vector of response variables or labels. Our goal is to find a sufficient and disentangled representation of x.

Sufficiency We say that a measurable map g : R^p → R^d with d ≤ p is a sufficient representation of x if

y ⊥⊥ x | g(x), (1)

that is, y and x are conditionally independent given g(x). This condition holds if and only if the conditional distribution of y given x equals that of y given g(x). Therefore, the information in x about y is completely encoded by g(x). Such a g always exists: if we simply take g(x) = x, then (1) holds trivially. This formulation is a nonparametric generalization of the basic condition in sufficient dimension reduction (Li, 1991; Cook, 1998), where it is assumed that g(x) = B^T x with B ∈ R^{p×d} belonging to the Stiefel manifold, i.e., B^T B = I_d. Denote the class of sufficient representations satisfying (1) by F = {g : R^p → R^d, g measurable and y ⊥⊥ x | g(x)}. We refer to F as a Fisher class because of its close connection with the concept of sufficient statistics (Fisher, 1922; Cook, 2007). For an injective measurable transformation T : R^d → R^d and g ∈ F, T ∘ g(x) is also sufficient by the basic property of conditional probability. Therefore, the Fisher class F is invariant in the sense that T ∘ F = F for injective T, where T ∘ F = {T ∘ g : g ∈ F}. An important class of such transformations is the affine maps T ∘ g = Ag + b, where A is a d × d nonsingular matrix and b ∈ R^d.

Disentanglement We focus on the disentangled representations among those that are sufficient. Therefore, we start from the sufficient representations in the Fisher class F. For any sufficient and disentangled representation g, let Σ_g = Var(g(x)). Since the components of g(x) are disentangled in the sense that they are independent, Σ_g is a diagonal matrix, and thus Σ_g^{-1/2} g(x) also has independent components. Therefore, we can always rescale g(x) so that it has the identity covariance matrix. To further simplify the statistical structure of a representation g, we also require it to be rotation invariant in distribution, that is, Qg(x) = g(x) in distribution for any orthogonal matrix Q ∈ R^{d×d}. The Fisher class F is rotation invariant in terms of conditional independence, but not all its members are rotation invariant in distribution. By the Maxwell characterization of the Gaussian distributions (Maxwell, 1860; Bartlett, 1934; Bryc, 1995; Gyenis, 2017), a random vector of dimension two or more with independent components is rotation invariant in distribution if and only if it is Gaussian with zero mean and a spherical covariance matrix. Therefore, after absorbing the scaling factor, a sufficient representation map that is disentangled and rotation invariant is necessarily distributed as N_d(0, I_d). Let M be the Maxwell class of functions g : R^p → R^d such that g(x) is disentangled and rotation invariant in distribution. By the Maxwell characterization, we can write

M = {g : R^p → R^d, g(x) ∼ N(0, I_d)}. (2)

Now our problem becomes that of finding a representation in F ∩ M, the intersection of the Fisher class and the Maxwell class. The first question to ask is whether such a representation exists. The following result from optimal transport theory provides an affirmative answer and guarantees that F ∩ M is nonempty under mild conditions (Brenier, 1991; McCann, 1995; Villani, 2008).

Lemma 2.1. Let µ be a probability measure on R^d with finite second moment that is absolutely continuous with respect to the standard Gaussian measure, denoted by γ_d. Then µ admits a unique optimal transport map T : R^d → R^d such that T#µ = γ_d ≡ N(0, I_d), where T#µ denotes the pushforward distribution of µ under T. Moreover, T is injective µ-almost everywhere.

Denote the law of a random vector z by µ_z. Lemma 2.1 implies that, for any g ∈ F with E‖g(x)‖^2 < ∞ and µ_{g(x)} absolutely continuous with respect to γ_d, there exists a map T* transforming the distribution of g(x) into N(0, I_d). Therefore, R* := T* ∘ g ∈ F ∩ M, that is,

x ⊥⊥ y | R*(x) and R*(x) ∼ N(0, I_d), (3)

i.e., R* is a sufficient and disentangled representation map (SDRM).
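In one dimension, the optimal transport map of Lemma 2.1 has a well-known closed form: T = Φ^{-1} ∘ F_µ, where F_µ is the CDF of µ and Φ is the standard normal CDF. The following sketch (an illustration only, not part of the paper's method; the function name `gaussianize_1d` is ours) Gaussianizes samples with the empirical analogue of this map:

```python
import numpy as np
from statistics import NormalDist

def gaussianize_1d(samples):
    """Push 1-d samples toward N(0, 1) via the empirical analogue of the
    optimal transport map T = Phi^{-1} o F_mu (Lemma 2.1 with d = 1)."""
    n = len(samples)
    ranks = np.argsort(np.argsort(samples))     # rank of each sample point
    u = (ranks + 0.5) / n                       # empirical CDF values in (0, 1)
    return np.array([NormalDist().inv_cdf(v) for v in u])  # apply Phi^{-1}

rng = np.random.default_rng(0)
x = rng.exponential(scale=2.0, size=5000)       # a markedly non-Gaussian law
z = gaussianize_1d(x)
print(z.mean(), z.std())                        # close to 0 and 1
```

Because the map is monotone, it is exactly the 1-d optimal transport map of the empirical distribution, and the transformed sample is approximately N(0, 1).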

3. OBJECTIVE FUNCTION FOR SDRL

The above discussion lays the theoretical foundation for formulating an objective function for constructing an SDRM R* satisfying (3), or equivalently, R* ∈ F ∩ M. Let V be a measure of dependence between random vectors x and y with the following properties: (a) V[x, y] ≥ 0, with V[x, y] = 0 if and only if x ⊥⊥ y; (b) V[x, y] ≥ V[R(x), y] for every measurable function R; (c) V[x, y] = V[R*(x), y] if and only if R* ∈ F. Properties (a)-(c) imply that R* ∈ F ⟺ R* ∈ arg max_R V[R(x), y] = arg min_R {-V[R(x), y]}. We use a divergence measure D to quantify the difference between µ_{R(x)} and γ_d; any measure satisfying D(µ_{R(x)} ‖ γ_d) ≥ 0 for all measurable R, with D(µ_{R(x)} ‖ γ_d) = 0 if and only if R ∈ M, will do. Then the problem of finding an R* ∈ F ∩ M can be expressed as the constrained minimization problem

arg min_R -V[R(x), y] subject to D(µ_{R(x)} ‖ γ_d) = 0.

Its Lagrangian form is

L(R) = -V[R(x), y] + λ D(µ_{R(x)} ‖ γ_d), (4)

where λ ≥ 0 is a tuning parameter balancing the sufficiency property and the disentanglement constraint: a small λ puts more emphasis on sufficiency, while a large λ puts more emphasis on disentanglement. We show in Theorem 4.2 below that any R* satisfying (3) is a minimizer of L(R). Therefore, we can train an SDRM by minimizing the empirical version of L(R). There are several options for V with properties (a)-(c). For example, we can take V to be the mutual information V[R(x), y] = I(R(x); y). However, in addition to the estimation of the SDRM R, this choice requires estimating the density ratio between p(y, R(x)) and p(y)p(R(x)), which is not an easy task. We could also use conditional covariance operators on reproducing kernel Hilbert spaces (Fukumizu et al., 2009). In this work we use the distance covariance (Székely et al., 2007) of y and R(x), which has an elegant U-statistic expression, involves no additional unknown quantities, and is easy to compute. For the divergence measure between two distributions, we use the f-divergence (Ali & Silvey, 1966), which includes the KL divergence as a special case.
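To make the distance-covariance term concrete, here is a minimal numerical sketch. It computes the (biased) V-statistic version of the squared distance covariance via double-centered pairwise distance matrices, following Székely et al. (2007); the function name is ours, and the paper's actual estimator is the U-statistic given in Section 4:

```python
import numpy as np

def dcov2(z, y):
    """Squared distance covariance (biased V-statistic version) of samples
    z and y, via double-centered pairwise distance matrices."""
    z = z.reshape(len(z), -1)
    y = y.reshape(len(y), -1)
    a = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)  # ||z_i - z_j||
    b = np.linalg.norm(y[:, None, :] - y[None, :, :], axis=-1)  # ||y_i - y_j||
    A = a - a.mean(0) - a.mean(1)[:, None] + a.mean()           # double centering
    B = b - b.mean(0) - b.mean(1)[:, None] + b.mean()
    return (A * B).mean()

rng = np.random.default_rng(1)
x = rng.normal(size=(1000, 1))
d_dep = dcov2(x, 2 * x)                          # dependent pair: clearly positive
d_ind = dcov2(x, rng.normal(size=(1000, 1)))     # independent pair: near zero
print(d_dep, d_ind)
```

The statistic vanishes (in population) exactly under independence, which is what makes it usable as the sufficiency term of the objective.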

4. LEARNING SUFFICIENT AND DISENTANGLED REPRESENTATION

We first describe some essentials about the distance covariance and the f-divergence.

Distance covariance

We first recall the concept of distance covariance (Székely et al., 2007), which characterizes the dependence of two random vectors. Let i be the imaginary unit (-1)^{1/2}. For any t ∈ R^d and s ∈ R^q, let ψ_z(t) = E[exp(i t^T z)], ψ_y(s) = E[exp(i s^T y)], and ψ_{z,y}(t, s) = E[exp(i(t^T z + s^T y))] be the characteristic functions of the random vectors z ∈ R^d, y ∈ R^q, and the pair (z, y), respectively. The squared distance covariance V[z, y] is defined as

V[z, y] = ∫_{R^{d+q}} |ψ_{z,y}(t, s) - ψ_z(t)ψ_y(s)|^2 / (c_d c_q ‖t‖^{d+1} ‖s‖^{q+1}) dt ds,

where c_d = π^{(d+1)/2} / Γ((d+1)/2). Given n i.i.d. copies {(z_i, y_i)}_{i=1}^n of (z, y), an unbiased estimator of V is the empirical distance covariance V_n, which can be elegantly expressed as a U-statistic (Huo & Székely, 2016)

V_n[z, y] = (1 / (n choose 4)) Σ_{1≤i1<i2<i3<i4≤n} h((z_{i1}, y_{i1}), ..., (z_{i4}, y_{i4})), (5)

where h is the kernel defined by

h((z_1, y_1), ..., (z_4, y_4)) = (1/4) Σ_{1≤i,j≤4, i≠j} ‖z_i - z_j‖ ‖y_i - y_j‖ - (1/4) Σ_{i=1}^{4} (Σ_{j≠i} ‖z_i - z_j‖)(Σ_{j≠i} ‖y_i - y_j‖) + (1/24) (Σ_{1≤i,j≤4, i≠j} ‖z_i - z_j‖)(Σ_{1≤i,j≤4, i≠j} ‖y_i - y_j‖).

f-divergence Let µ and γ be two probability measures on R^d with µ ≪ γ. The f-divergence (Ali & Silvey, 1966) between µ and γ is defined as D_f(µ ‖ γ) = ∫_{R^d} f(dµ/dγ) dγ, where f : R_+ → R is a differentiable convex function satisfying f(1) = 0. Let f* be the Fenchel conjugate of f (Rockafellar, 1970), defined as f*(t) = sup_{x∈R} {tx - f(x)}, t ∈ R. The f-divergence admits the following variational formulation (Keziou, 2003; Nguyen et al., 2010; Nowozin et al., 2016).

Lemma 4.1. D_f(µ ‖ γ) = max_{D : R^d → dom(f*)} E_{z∼µ}[D(z)] - E_{w∼γ}[f*(D(w))], (6)

where the maximum is attained at D(z) = f′(dµ/dγ(z)).

Commonly used divergence measures include the Kullback-Leibler (KL) divergence, the Jensen-Shannon (JS) divergence, and the χ^2-divergence.

Learning SDRM We are now ready to formulate an empirical objective function for learning the SDRM R*. Let R ∈ M, where M is the Maxwell class defined in (2).
By the variational formulation (6), we can write the population version of the objective function (4) as

L(R) = -V[R(x), y] + λ max_{D : R^d → dom(f*)} {E_{x∼µ_x}[D(R(x))] - E_{w∼γ_d}[f*(D(w))]}. (7)

This expression is convenient since we can simply replace the expectations by the corresponding empirical averages.

Theorem 4.2. If (3) holds, then R* ∈ arg min_{R∈M} L(R).

According to Theorem 4.2, it is natural to estimate R* based on the empirical version of the objective function (7) when a random sample {(x_i, y_i)}_{i=1}^n is available. We estimate R* using deep neural networks. We employ two networks:

• Representer network R_θ: this network is used for training R*. Let R denote the set of such neural networks R_θ : R^p → R^d.
• Discriminator network D_φ: this network serves as the witness function for checking whether the distribution of the estimator of R* is approximately N(0, I_d). Let D denote the set of such neural networks D_φ : R^d → R.

Let {w_i}_{i=1}^n be n i.i.d. random vectors drawn from γ_d. The estimated SDRM is defined by

R̂_θ ∈ arg min_{R_θ∈R} L̂(R_θ) = -V̂_n[R_θ(x), y] + λ D̂_f(µ_{R_θ(x)} ‖ γ_d), (8)

where V̂_n[R_θ(x), y] is the unbiased and consistent estimator of V[R_θ(x), y] defined in (5), computed from {(R_θ(x_i), y_i), i = 1, ..., n}, and

D̂_f(µ_{R_θ(x)} ‖ γ_d) = max_{D_φ∈D} (1/n) Σ_{i=1}^n [D_φ(R_θ(x_i)) - f*(D_φ(w_i))].

Statistical guarantee Since an SDRM R* is only identifiable up to orthogonal transformations under the constraint R*(x) ∼ N(0, I_d), no consistency result for R̂_θ itself can be obtained. This is not a flaw of the proposed method: the most important statistical guarantee for the learned representation is that the objective of conditional independence and disentanglement is achieved. Therefore, we establish an upper bound on the excess risk L(R̂_θ) - L(R*) of the deep nonparametric estimator R̂_θ in (8). We make the following assumptions.
(A1) For any ε > 0, there is a constant B_1 > 0 such that µ_x([-B_1, B_1]^p) > 1 - ε, and R* is Lipschitz continuous on [-B_1, B_1]^p with Lipschitz constant L_1.

(A2) For R ∈ M, the density ratio r(z) = dµ_{R(x)}/dγ_d(z) is Lipschitz continuous with Lipschitz constant L_2 and satisfies 0 < c_1 ≤ r(z) ≤ c_2. Denote B_2 = max{|f(c_1)|, |f(c_2)|} and B_3 = max_{|s| ≤ 2L_2 √d log n + B_2} |f*(s)|.

The specifications of the network parameters, including the depth, width, size and the supremum norm over the domains of the representer R_θ and the discriminator D_φ, are given in Appendix B.

Theorem 4.3. Suppose λ > 0 and λ = O(1). Suppose conditions (A1)-(A2) hold and set the network parameters according to (i)-(ii) in Appendix B. Then

E_{{(x_i, y_i, w_i)}_{i=1}^n} [L(R̂_θ) - L(R*)] ≤ C((L_1 + L_2)√(dp) n^{-2/(2+p)} + L_2 √d (log n) n^{-2/(2+d)}),

where C is a constant that depends on B_1, B_2 and B_3 but not on n, q, p and d. The proof of this theorem is given in Appendix B. Theorem 4.3 provides strong statistical guarantees for the proposed method. The rate n^{-2/(2+p)} matches the minimax rate for nonparametric estimation of Lipschitz functions on R^p (Stone, 1982; Tsybakov, 2008). Up to a log n factor, the rate (log n) n^{-2/(2+d)} matches the minimax rate for nonparametric estimation of Lipschitz densities via GANs (Singh et al., 2018; Liang, 2018).
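To make the variational formulation of Lemma 4.1 concrete: for the KL divergence, f(x) = x log x, its Fenchel conjugate is f*(t) = exp(t - 1), and the optimal witness is D(z) = f′(r(z)) = 1 + log r(z). The following Monte Carlo check (our illustration, not the paper's estimator) plugs the optimal witness for µ = N(m, 1) against γ = N(0, 1) into the variational bound and recovers the closed-form KL divergence m^2/2:

```python
import numpy as np

# KL as an f-divergence: f(x) = x log x, Fenchel conjugate f*(t) = exp(t - 1).
# Optimal witness (Lemma 4.1): D(z) = f'(r(z)) = 1 + log r(z), r = dmu/dgamma.
m = 1.5                                    # mean shift; KL(N(m,1) || N(0,1)) = m^2 / 2
log_r = lambda z: m * z - m**2 / 2         # log density ratio of N(m,1) vs N(0,1)
D = lambda z: 1.0 + log_r(z)

rng = np.random.default_rng(2)
z = rng.normal(m, 1.0, size=200000)        # samples from mu
w = rng.normal(0.0, 1.0, size=200000)      # samples from gamma
estimate = D(z).mean() - np.exp(D(w) - 1.0).mean()
print(estimate, m**2 / 2)                  # both close to 1.125
```

In the empirical objective (8), the discriminator network D_φ plays the role of this witness, with the maximization over D approximating the supremum in Lemma 4.1.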

5. COMPUTATION

We can update θ and φ alternately as in training GANs (Goodfellow et al., 2014). However, this approach suffers from instability issues. In our implementation, we utilize the more stable particle method based on gradient flow (Gao et al., 2019; 2020). The key idea is to find a sequence of nonlinear but simple residual maps T(z) = z + s v(z), pushing the samples from µ_{R_θ(x)} to the target distribution γ_d along the velocity field v(z) = -∇f′(r(z)) = -f″(r(z)) ∇r(z) that most decreases the f-divergence D_f(· ‖ γ_d) at µ_{R_θ(x)}, where r = dµ_{R_θ(x)}/dγ_d. The residual maps can be estimated via deep density-ratio estimators and take the form T̂(z) = z + s v̂(z), z ∈ R^d, where s is a step size and v̂(z) = -f″(r̂(z)) ∇r̂(z). Here r̂(z) is an estimated ratio of the density of R_θ(x) at the current value of θ over the density of the reference distribution γ_d. We use T̂ to transform z_i = R_θ(x_i), i = 1, ..., n, into approximately Gaussian samples. Once this is done, we update θ by minimizing the loss -V̂_n[R_θ(x), y] + λ Σ_{i=1}^n ‖R_θ(x_i) - z_i‖^2 / n. We describe the algorithm below.

• Input: {(x_i, y_i)}_{i=1}^n. Tuning parameters: s, λ, d. Sample {w_i}_{i=1}^n ∼ γ_d.
• Outer loop for θ:
  - Inner loop (particle method):
    * Let z_i = R_θ(x_i), i = 1, 2, ..., n.
    * Solve D̂_φ ∈ arg min_{D_φ} (1/n) Σ_{i=1}^n [log(1 + exp(D_φ(z_i))) + log(1 + exp(-D_φ(w_i)))].
    * Define the residual map T̂(z) = z - s f″(r̂(z)) ∇r̂(z) with r̂(z) = exp(-D̂_φ(z)).
    * Update the particles: z_i ← T̂(z_i), i = 1, 2, ..., n.
  - End inner loop.
  - Update θ by minimizing -V̂_n[R_θ(x), y] + λ Σ_{i=1}^n ‖R_θ(x_i) - z_i‖^2 / n using SGD.
• End outer loop.
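To illustrate the inner particle loop, here is a one-dimensional toy version under our own simplifications (it is not the paper's implementation): the discriminator is a quadratic D_φ(z) = θ_0 + θ_1 z + θ_2 z^2 fitted by plain gradient descent on the logistic loss above, rather than a neural network. For the KL divergence, f″(r) = 1/r, so with r̂(z) = exp(-D̂_φ(z)) the velocity simplifies to v(z) = -∇ log r̂(z) = ∇D̂_φ(z), and the residual map is T̂(z) = z + s(θ_1 + 2θ_2 z):

```python
import numpy as np

rng = np.random.default_rng(3)
n, s = 2000, 0.2
z = rng.normal(2.0, 1.0, n)                # particles: current representation samples
w = rng.normal(0.0, 1.0, n)                # reference samples from N(0, 1)

def feats(v):                              # quadratic discriminator features [1, z, z^2]
    return np.stack([np.ones_like(v), v, v * v], axis=1)

sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))

for _ in range(40):                        # outer particle iterations
    fz, fw = feats(z), feats(w)
    theta = np.zeros(3)
    for _ in range(500):                   # fit D_phi by gradient descent on the logistic loss
        gz = sigmoid(fz @ theta)           # grad of log(1 + exp(D(z_i))) terms
        gw = sigmoid(-(fw @ theta))        # grad of log(1 + exp(-D(w_i))) terms
        grad = (fz * gz[:, None]).mean(0) - (fw * gw[:, None]).mean(0)
        theta -= 0.1 * grad
    # KL velocity: v(z) = grad D_phi(z) = theta_1 + 2*theta_2*z; T(z) = z + s*v(z)
    z = z + s * (theta[1] + 2.0 * theta[2] * z)

print(z.mean(), z.std())                   # particles drift toward N(0, 1)
```

Starting from N(2, 1), the particle mean contracts toward 0 at each step while the spread stays near 1, mirroring how the residual maps push µ_{R_θ(x)} toward γ_d.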

6. EXPERIMENTS

We evaluate the proposed SDRL with the KL-divergence using both simulated and real data. The goal of our experiments is to demonstrate that representations trained with the proposed method perform well. The method does not learn a classifier or a regression function directly; rather, it learns representations that preserve all the information about the response. Our experiments are therefore designed to evaluate the performance of simple classification and regression methods that take the learned representations as input. The results demonstrate that a simple classification or regression model using the trained representations performs better than or comparably to the best classification or regression methods using deep neural networks. Details on the network structures and hyperparameters are included in Appendix A. Our experiments were conducted on an Nvidia DGX Station workstation using a single Tesla V100 GPU. The PyTorch code of SDRL is available at https://github.com/anonymous/SDRL.

6.1. SIMULATED DATA

In this subsection, we evaluate SDRL on simulated regression and classification problems.

Regression

We generate 5,000 data points from two models. Model A: y = x_1 [0.5 + (x_2 + 1.5)^2]^{-1} + (1 + x_2)^2 + σε, where x ∼ N(0, I_4); Model B: y = sin^2(πx_1 + 1) + σε, where x ∼ Uniform([0, 1]^4). In both models, ε ∼ N(0, 1). We use a 3-layer network with ReLU activation for R_θ and a single-hidden-layer ReLU network for D_φ. We compare SDRL with two prominent sufficient dimension reduction methods: sliced inverse regression (SIR) (Li, 1991) and sliced average variance estimation (SAVE) (Cook & Weisberg, 1991). We fit a linear model with the learned features and the response variable and report the prediction errors in Table 1. SDRL outperforms SIR and SAVE in terms of prediction error.

Classification We visualize the learned features of SDRL on two simulated datasets. We first generate 2-dimensional concentric circles from two classes, as in Figure 1(i). In each dataset, we generate 5,000 data points for each class. We next map the data into a 100-dimensional space using matrices with entries i.i.d. Uniform([0, 1]). Finally, we apply SDRL to these 100-dimensional datasets to learn 2-dimensional features. We use a 10-layer dense convolutional network (DenseNet) (Huang et al., 2017) as R_θ and a 4-layer network as D_φ. We display the evolution of the learned 2-dimensional features in Figure 1. For ease of visualization, we push all the distributions onto the uniform distribution on the unit circle, which is done by normalizing the standard Gaussian random vectors to length one. Clearly, the learned features for the different classes are well disentangled.

6.2. REAL DATA

Regression We next evaluate SDRL on the YearPredictionMSD dataset, which has 515,345 observations with 90 predictors. The problem is to predict the year of song release. We randomly split the data into five parts for cross-validated evaluation of the prediction performance. We employ a 3-layer network for both D_φ and R_θ. A linear regression model is fitted using the learned representations and the response.
The mean prediction errors and their standard errors for SDRL, principal component analysis (PCA), sparse principal component analysis (SPCA), and ordinary least squares (OLS) regression on the original data are reported in Table 2. SDRL outperforms PCA, SPCA and OLS in terms of prediction accuracy.

Classification We benchmark the classification performance of SDRL on MNIST (LeCun et al., 2010), FashionMNIST (Xiao et al., 2017), and CIFAR-10 (Krizhevsky et al., 2009) against alternative methods. The classification accuracies are reported in Tables 3 and 4. The classification accuracies of SDRL are comparable with those of CN and dCorAE. As shown in Table 4, a CN leveraging SDRL outperforms the CN alone. We also compute the estimated distance correlation (DC) between the learned features and their labels, ρ^2_{z,y} = V[z, y] / (V[z] V[y])^{1/2}, where V[z] and V[y] are the distance variances V[z] = V[z, z] and V[y] = V[y, y]; see Székely et al. (2007) for details. Figure 2 shows the DC values on the MNIST, FashionMNIST and CIFAR-10 data. SDRL and SDRL+CN achieve higher DC values.
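The distance-correlation diagnostic reported above can be sketched as follows. This is a self-contained illustration using the (biased) V-statistic estimator and the standard normalization of Székely et al. (2007); the function names are ours:

```python
import numpy as np

def dcov2(z, y):
    """Squared distance covariance (V-statistic) via double-centered distances."""
    z = z.reshape(len(z), -1)
    y = y.reshape(len(y), -1)
    a = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)
    b = np.linalg.norm(y[:, None, :] - y[None, :, :], axis=-1)
    A = a - a.mean(0) - a.mean(1)[:, None] + a.mean()
    B = b - b.mean(0) - b.mean(1)[:, None] + b.mean()
    return (A * B).mean()

def dcor2(z, y):
    """Squared distance correlation: dcov^2 normalized by the distance variances."""
    denom = np.sqrt(dcov2(z, z) * dcov2(y, y))
    return dcov2(z, y) / denom if denom > 0 else 0.0

rng = np.random.default_rng(4)
x = rng.normal(size=(800, 2))
dc_dep = dcor2(x, (x ** 2).sum(1))       # nonlinear dependence: DC clearly positive
dc_ind = dcor2(x, rng.normal(size=800))  # independence: DC near zero
print(dc_dep, dc_ind)
```

Unlike the Pearson correlation, the DC detects the purely nonlinear dependence in the first pair, which is why it is a natural metric for comparing learned features against labels.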

7. CONCLUSION AND FUTURE WORK

In this work, we formulate a framework for sufficient and disentangled representation learning and construct an objective function characterizing conditional independence and disentanglement. This enables us to learn a representation with the desired properties empirically. We provide statistical guarantees for the learned representation by deriving an upper bound on the excess risk of the objective function. Several questions deserve further study. First, one can adopt different measures of conditional independence, including mutual information and conditional covariance operators on reproducing kernel Hilbert spaces (Fukumizu et al., 2009). We can also use other divergence measures, such as the Wasserstein distance, in the objective function. Finally, Lemma 2.1 suggests that the intersection of the Fisher class F and the Maxwell class M can still be large, and there can be many statistically equivalent representations in F ∩ M. We can further reduce F ∩ M by imposing additional constraints, for example certain minimality properties, sparsity, or robustness against noise perturbations.

A APPENDIX: EXPERIMENTAL DETAILS

A.1 SIMULATION STUDIES

The values of the hyper-parameters for the simulated experiments are given in Table A1, where λ is the penalty parameter, d is the dimension of the SDRM, n is the mini-batch size in SGD, T_1 is the number of inner loops to push forward the particles z_i, T_2 is the number of outer loops for training R_θ, and s is the step size for updating the particles. For the regression models, the neural network architectures are shown in Table A2. As shown in Table A3, a multilayer perceptron (MLP) is utilized for D_φ in the classification problems. The detailed architecture of the 10-layer dense convolutional network (DenseNet) (Huang et al., 2017; Amos & Kolter) deployed for R_θ is shown in Table A4. For all settings, we adopted the Adam optimizer (Kingma & Ba, 2014) with an initial learning rate of 0.001 and weight decay of 0.0001.

A.2 REAL DATA

Regression:

In the regression problems, hyper-parameters are presented in Table A5 . The Adam optimizer with an initial learning rate of 0.001 and weight decay of 0.0001 is adopted. The MLP architectures of D φ and R θ for the YearPredictionMSD data are shown in Table A6 . Classification: For the classification problems, hyper-parameters are shown in Table A7 . We again use Adam as the SGD optimizers for both D φ and R θ . Specifically, learning rate of 0.001 and weight decay of 0.0001 are used for D φ in all datasets and for R θ on MNIST (LeCun et al., 2010) . We customized the SGD optimizers with momentum at 0.9, weight decay at 0.0001, and learning rate ρ in Table A8 for FashionMNIST (Xiao et al., 2017) and CIFAR-10 (Krizhevsky et al., 2012) . For the transfer learning of CIFAR-10, we use customized SGD optimizer with initial learning rate of 0.001 and momentum of 0.9 for R θ . MLP architectures of the discriminator network D φ for MNIST, FashionMNIST and CIFAR-10 are given in Table A3 . The 20-layer DenseNet networks shown in Table A9 were utlized for R θ on the MNIST dataset, while the 100-layer DenseNet networks shown in Table A10 and A11 are fitted for R θ on FashionMNIST and CIFAR-10. Proof. By assumption µ and γ d are both absolutely continuous with respect to the Lebesgue measure. The desired result holds since it is a spacial case of the well known results on the existence of optimal transport (Brenier, 1991; McCann, 1995) , see, Theorem 1.28 on page 24 of (Philippis, 2013) for details. Plugging the above display with t = dµ Z dγ (x) into the definition of f -divergence, we derive (6). B.3 PROOF OF THEOREM 4.2 Proof. Without loss of generality, we assume d = 1. For R * satisfying (3) and any R ∈ R, we have R = ρ (R,R * ) R * + ε R , where ρ (R,R * ) is the correlation coefficient between R and R * , ε R = R -ρ (R,R * ) R * . It is easy to see that ε R R * and thus Y ε R . 
As (ρ (R,R * ) R * , Y ) is independent of (ε R , 0), then by Theorem 3 of Székely & Rizzo (2009) V[R, y] =V[ρ (R,R * ) R * + ε R , y] ≤ V[ρ (R,R * ) R * , y] + V(ε R , 0) =V[ρ (R,R * ) R * , y] = |ρ (R,R * ) |V[R * , y] ≤V[R * , y]. As R(x) ∼ N (0, 1) and R * (x) ∼ N (0, 1), then D f (µ R(x) γ d ) = D f (µ R * (x) γ d ) = 0, and L(R) -L(R * ) = V[R * , y] -V[R, y] ≥ 0. The proof is completed.  1 p + 4p, 12n p 2+p / log n + 14}, size S = dn p-2 p+2 / log 4 (npd), B = (2B 3 L 1 √ p + log n) √ d, (ii) Discriminator network M D, W, S, B parameters: depth D = 9 log n + 12, width W = max{8d(n d 2+d / log n) 1 d + 4d, 12n d 2+d / log n + 14}, size S = n d-2 d+2 /(log 4 npd), B = 2L 2 √ d log n + B 2 . Before getting into the details of the proof of Theorem 4.3, we first give an outline of the basic structure of the proof. Without loss of generality, we assume that λ = 1 and m = 1, i.e. y ∈ R. First we consider the scenario that y is bounded almost surely, say |y| ≤ C 1 . We also assume B 1 < ∞. We can utilize the truncation technique to transfer the unbounded cases to the bounded ones under some common tail assumptions. Consequently, an additional log n multiplicative term will appear in the final results. For any R ∈ N D,W,S,B , we have, L( R θ ) -L(R * ) = L( R θ ) -L( R θ ) + L( R θ ) -L( R) + L( R) -L( R) + L( R) -L(R * ) ≤ 2 sup R∈N D,W,S,B |L(R) -L(R)| + inf R∈N D,W,S,B |L( R) -L(R * )|, where we use the definition of R θ in (8) and the feasibility of R. Next we bound the two error terms in (10), i.e., the approximation error inf  ( R) -L(R * )| ≤ 2600C 1 B 1 L 1 pdn -2 p+2 . Proof. By (3) and ( 6) and the definition of L, we have inf R∈N D,W,S,B |L( R) -L(R * )| ≤ |D f (µ Rθ(x) γ d )| + |V[R * (x), y] -V[ Rθ(x), y]|, where Rθ ∈ N D,W,S,B is specified in Lemma B.2 below. We finish the proof by ( 14) in Lemma B.3 and (15) in Lemma B.4, which will be proved below.  √ p + log n) √ d, such that Rθ -R * L 2 (µx) ≤ 160L 1 B 1 pdn -2 p+2 . 
( p 2+p / log n) 1 p + 4d, 12n p 2+p / log n + 14}, Rθ i L ∞ = 2B 1 L 1 √ p + log n, such that Rθ i L ∞ ≤ 2B 1 L 1 √ p + log n, and R * i -Rθ i L ∞ ([-B1,B1] p \H) ≤ 80L 1 B 1 √ pn -2 p+2 , µ x (H) ≤ 80L 1 B 1 √ pn -2 p+2 2B 1 L 1 √ p + log n . Define Rθ = [ Rθ 1 , . . . , Rθ d ] ∈ N D,W,S,B . The above three display implies Rθ -R * L 2 (µx) ≤ 160L 1 B 1 pdn -2 p+2 . Lemma B.3. |V[R * (x), y] -V[ Rθ(x), y]| ≤ 2580C 1 B 1 L 1 pdn -2 p+2 . Proof. Recall that Székely et al. ( 2007) V[z, y] =E [ z 1 -z 2 |y 1 -y 2 |] -2E [ z 1 -z 2 |y 1 -y 3 |] + E [ z 1 -z 2 ] E [|y 1 -y 2 |] , where (z i , y i ), i = 1, 2, 3 are i.i.d. copies of (z, y). We have |V[R * (x), y] -V[ Rθ(x), y]| ≤ |E ( R * (x 1 ) -R * (x 2 ) -Rθ(x1) -Rθ(x2) )|y 1 -y 2 | | + 2|E ( R * (x 1 ) -R * (x 2 ) -Rθ(x1) -Rθ(x2) )|y 1 -y 3 | | + |E R * (x 1 ) -R * (x 2 ) -Rθ(x1) -Rθ(x2) E [ y 1 -y 2 ] | ≤ 8C 1 E | R * (x 1 ) -R * (x 2 ) -Rθ(x1) -Rθ(x2) | ≤ 16C 1 E | R * (x) -Rθ(x) ≤ 16C 1 (E R * (x) -Rθ(x) + E R * (x)1 R * (x)∈Ball c (0,log n) ), where in the first and third inequalities we use the triangle inequality, and second one follows from the boundedness of y. By (13), the first term in the last line is bounded by 2560C 1 B 1 L 1 √ pdn -1 p+2 . Some direct calculation shows that E R * (x)1 R * (x)∈Ball c (0,log n) ≤ C 2 (log n) d n . We finish the proof by comparing the order of the above two terms, i.e., C 2 dγ d (z). By definition we have (log n) d n ≤ 20C 1 B 1 L 1 √ pdn -2 p+2 . Lemma B.4. |D f (µ Rθ(x) γ d )| ≤ 20C 1 B 1 L 1 pdn -2 p+2 . ( D f (µ R * (x) γ d ) = E W ∼γ d [f (r * (W ))] = E W ∼γ d [f (r * (W ))1 W ∈Ball(0,log n) ] + E W ∼γ d [f (r * (W ))1 W ∈Ball c (0,log n) ]. (We can represent D f (µ Rθ γ d ) similarly. 
Then
$|D_f(\mu_{\hat{R}_\theta(\mathbf{x})}\|\gamma_d)| = |D_f(\mu_{\hat{R}_\theta(\mathbf{x})}\|\gamma_d) - D_f(\mu_{R^*(\mathbf{x})}\|\gamma_d)| \le E_{W\sim\gamma_d}[|f(r^*(W)) - f(\hat{r}(W))|\,1_{W\in\mathrm{Ball}(0,\log n)}] + E_{W\sim\gamma_d}[|f(r^*(W)) - f(\hat{r}(W))|\,1_{W\in\mathrm{Ball}^c(0,\log n)}]$
$\le \int_{\|z\|\le\log n}|f'(\tilde{r}(z))|\,|r^*(z) - \hat{r}(z)|\,d\gamma_d(z) + \int_{\|z\|>\log n}|f'(\tilde{r}(z))|\,|r^*(z) - \hat{r}(z)|\,d\gamma_d(z) \le C_3\int_{\|z\|\le\log n}|r^*(z) - \hat{r}(z)|\,d\gamma_d(z) + C_4\int_{\|z\|>\log n}|r^*(z) - \hat{r}(z)|\,d\gamma_d(z),$
where $\tilde{r}(z)$ lies between $r^*(z)$ and $\hat{r}(z)$ by the mean value theorem. The first term is small because $\hat{R}_\theta$ approximates $R^*$ well; the second term is small due to the boundedness of $\hat{r}$ and the exponential decay of the Gaussian tails.
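The quantity $\mathcal{V}[\cdot,\cdot]$ manipulated in these proofs is the (squared) distance covariance of Székely et al. (2007). Two facts used above, the scaling identity $\mathcal{V}[\rho\,\mathbf{z}, y] = |\rho|\,\mathcal{V}[\mathbf{z}, y]$ and nonnegativity, can be checked numerically with a minimal NumPy sketch of the double-centered V-statistic estimator; the function name `dcov_sq` and the toy data are our own choices, not the paper's:

```python
import numpy as np

def dcov_sq(z, y):
    """Double-centered V-statistic estimator of the squared distance
    covariance between samples z of shape (n, d) and y of shape (n, m)."""
    a = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)
    b = np.linalg.norm(y[:, None, :] - y[None, :, :], axis=-1)
    # double-center each pairwise distance matrix
    A = a - a.mean(axis=0) - a.mean(axis=1)[:, None] + a.mean()
    B = b - b.mean(axis=0) - b.mean(axis=1)[:, None] + b.mean()
    return (A * B).mean()

rng = np.random.default_rng(0)
x = rng.normal(size=(300, 1))
dep = np.hstack([x, x**2])        # strongly dependent on x
ind = rng.normal(size=(300, 2))   # independent of x

v_dep, v_ind = dcov_sq(x, dep), dcov_sq(x, ind)
```

The double-centering formula is algebraically equivalent to the three-term expectation formula recalled in the proof of Lemma B.3, and the scaling identity holds exactly at the sample level because the estimator is linear in the pairwise distances of its first argument.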

B.4.2 THE STATISTICAL ERROR

Lemma B.5 (Statistical error). $\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\mathcal{L}(R) - \hat{\mathcal{L}}(R)| \le C_{15}\big(B_1(L_1 + L_2)\sqrt{pd}\,n^{-2/(2+p)} + (L_2\sqrt{d} + B_2 + B_3)\log n \cdot n^{-2/(2+d)}\big)$. (16)

Proof. By the definition and the triangle inequality we have
$E\big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\mathcal{L}(R) - \hat{\mathcal{L}}(R)|\big] \le E\big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\hat{\mathcal{V}}_n[R(\mathbf{x}), y] - \mathcal{V}[R(\mathbf{x}), y]|\big] + E\big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\hat{D}_f(\mu_{R(\mathbf{x})}\|\gamma_d) - D_f(\mu_{R(\mathbf{x})}\|\gamma_d)|\big].$
We finish the proof based on (17) in Lemma B.6 and (22) in Lemma B.7, which will be proved below.

Lemma B.6. $E\big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\hat{\mathcal{V}}_n[R(\mathbf{x}), y] - \mathcal{V}[R(\mathbf{x}), y]|\big] \le 4C_6C_7C_{10}\,B_1L_1\sqrt{pd}\,n^{-2/(p+2)}$. (17)

Proof. We first fix some notation for simplicity. Denote $O = (\mathbf{x}, y) \in \mathbb{R}^p \times \mathbb{R}$ and let $O_i = (\mathbf{x}_i, y_i)$, $i = 1, \ldots, n$, be i.i.d. copies of $O$; write $P$ and $P_n$ for $\mu_{\mathbf{x},y}$ and its empirical counterpart. For any $R \in \mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$, let $\tilde{O} = (R(\mathbf{x}), y)$ and let $\tilde{O}_i = (R(\mathbf{x}_i), y_i)$, $i = 1, \ldots, n$, be i.i.d. copies of $\tilde{O}$. Define the centered kernel $h_R : (\mathbb{R}^d \times \mathbb{R})^4 \to \mathbb{R}$ by
$h_R(\tilde{O}_1, \tilde{O}_2, \tilde{O}_3, \tilde{O}_4) = \frac{1}{4}\sum_{1\le i,j\le 4,\,i\ne j}\|R(\mathbf{x}_i) - R(\mathbf{x}_j)\|\,|y_i - y_j| - \frac{1}{4}\sum_{i=1}^{4}\Big(\sum_{j\ne i}\|R(\mathbf{x}_i) - R(\mathbf{x}_j)\|\Big)\Big(\sum_{j\ne i}|y_i - y_j|\Big) + \frac{1}{24}\Big(\sum_{i\ne j}\|R(\mathbf{x}_i) - R(\mathbf{x}_j)\|\Big)\Big(\sum_{i\ne j}|y_i - y_j|\Big) - \mathcal{V}[R(\mathbf{x}), y].$
Then the centered statistic $\hat{\mathcal{V}}_n[R(\mathbf{x}), y] - \mathcal{V}[R(\mathbf{x}), y]$ can be represented as the $U$-statistic
$U_n(h_R) = \binom{n}{4}^{-1}\sum_{1\le i_1<i_2<i_3<i_4\le n} h_R(\tilde{O}_{i_1}, \tilde{O}_{i_2}, \tilde{O}_{i_3}, \tilde{O}_{i_4}).$
Our goal is to bound the supremum of the centered $U$-process $U_n(h_R)$ with the nondegenerate kernel $h_R$. By the symmetrization randomization Theorem 3.5.3 in De la Peña & Giné (2012), we have
$E\big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |U_n(h_R)|\big] \le C_5\,E\Big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}}\Big|\binom{n}{4}^{-1}\sum_{1\le i_1<i_2<i_3<i_4\le n}\epsilon_{i_1}h_R(\tilde{O}_{i_1}, \tilde{O}_{i_2}, \tilde{O}_{i_3}, \tilde{O}_{i_4})\Big|\Big],$
where $\epsilon_{i_1}$, $i_1 = 1, \ldots, n$, are i.i.d. Rademacher variables independent of $\tilde{O}_i$, $i = 1, \ldots, n$. We finish the proof by upper bounding the above Rademacher process with the metric entropy of $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$. To this end we need the following maximal inequality.

Lemma B.8. If $\xi_i$, $i = 1, \ldots, m$, are $m$ finite linear combinations of Rademacher variables $\epsilon_j$, $j = 1, \ldots, J$, then
$E_{\epsilon_j,\,j=1,\ldots,J}\max_{1\le i\le m}|\xi_i| \le C_6(\log m)^{1/2}\max_{1\le i\le m}(E\xi_i^2)^{1/2}.$ (20)

Proof. This result follows directly from Corollary 3.2.6 and inequality (4.3.1) in De la Peña & Giné (2012) with $\Phi(x) = \exp(x^2)$.

By the boundedness assumption on $y$ and the boundedness of $R \in \mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$, the kernel $h_R$ is also bounded, say $\|h_R\|_{L^\infty} \le C_7(2B_1L_1\sqrt{p} + \log n)\sqrt{d}$. (21)

For any $R, \bar{R} \in \mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$, define the random empirical distance (depending on $O_i$, $i = 1, \ldots, n$)
$e_{n,1}(R, \bar{R}) = E_{\epsilon_{i_1},\,i_1=1,\ldots,n}\Big|\binom{n}{4}^{-1}\sum_{1\le i_1<i_2<i_3<i_4\le n}\epsilon_{i_1}(h_R - h_{\bar{R}})(\tilde{O}_{i_1}, \ldots, \tilde{O}_{i_4})\Big|.$
Conditional on $O_i$, $i = 1, \ldots, n$, let $C(\mathcal{N}, e_{n,1}, \delta)$ be the covering number of $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ with respect to the empirical distance $e_{n,1}$ at scale $\delta > 0$, and let $\mathcal{N}_\delta$ be a covering set of $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ with cardinality $C(\mathcal{N}, e_{n,1}, \delta)$. Then
$E_{\epsilon_{i_1}}\Big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}}\Big|\binom{n}{4}^{-1}\sum_{1\le i_1<i_2<i_3<i_4\le n}\epsilon_{i_1}h_R(\tilde{O}_{i_1}, \ldots, \tilde{O}_{i_4})\Big|\Big] \le \delta + E_{\epsilon_{i_1}}\Big[\sup_{R\in\mathcal{N}_\delta}\Big|\binom{n}{4}^{-1}\sum_{1\le i_1<i_2<i_3<i_4\le n}\epsilon_{i_1}h_R(\tilde{O}_{i_1}, \ldots, \tilde{O}_{i_4})\Big|\Big]$
$\le \delta + C_6\binom{n}{4}^{-1}(\log C(\mathcal{N}, e_{n,1}, \delta))^{1/2}\max_{R\in\mathcal{N}_\delta}\Big[\sum_{i_1=1}^n\Big(\sum_{i_2<i_3<i_4}h_R(\tilde{O}_{i_1}, \ldots, \tilde{O}_{i_4})\Big)^2\Big]^{1/2}$
$\le \delta + C_6C_7(2B_1L_1\sqrt{p} + \log n)\sqrt{d}\,(\log C(\mathcal{N}, e_{n,1}, \delta))^{1/2}\binom{n}{4}^{-1}\Big[n\Big(\frac{(n-1)!}{3!\,(n-4)!}\Big)^2\Big]^{1/2}$
$\le \delta + 2C_6C_7(2B_1L_1\sqrt{p} + \log n)\sqrt{d}\,(\log C(\mathcal{N}, e_{n,1}, \delta))^{1/2}/\sqrt{n}$
$\le \delta + 2C_6C_7(2B_1L_1\sqrt{p} + \log n)\sqrt{d}\,\Big(\mathrm{VC}_{\mathcal{N}}\log\frac{2e\mathcal{B}n}{\delta\,\mathrm{VC}_{\mathcal{N}}}\Big)^{1/2}/\sqrt{n}$
$\le \delta + C_6C_7C_{10}(B_1L_1\sqrt{p} + \log n)\sqrt{d}\,\Big(\mathcal{D}\mathcal{S}\log\mathcal{S}\,\log\frac{\mathcal{B}n}{\delta\mathcal{D}\mathcal{S}\log\mathcal{S}}\Big)^{1/2}/\sqrt{n},$
where the first inequality follows from the triangle inequality, the second inequality uses (20) and (21), the third and fourth inequalities follow after some algebra, and the fifth inequality
Finally, Theorem 4.3 is a direct consequence of (11) in Lemma B.1 and (16) in Lemma B.5. This completes the proof.



The simulated datasets are: (1) 2-dimensional concentric circles data from two classes as in Figure 1(a); (2) 2-dimensional moons data from two classes as in Figure 1(e); (3) 3-dimensional Gaussian data from six classes.

Figure 1: Evolving learned features. The first, second and third rows show concentric circles, moons and 3D Gaussian datasets, respectively.
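For reference, the three simulated dataset families described above (concentric circles, moons, and 3-dimensional Gaussian clusters) can be generated with a few lines of NumPy; the radii, noise levels, and class means below are our own illustrative choices rather than the paper's exact settings:

```python
import numpy as np

rng = np.random.default_rng(1)

def circles(n, noise=0.05):
    """Two concentric circles: class 0 on radius 1, class 1 on radius 2."""
    t = rng.uniform(0, 2 * np.pi, n)
    y = (np.arange(n) % 2).astype(int)
    r = np.where(y == 0, 1.0, 2.0)
    x = np.c_[r * np.cos(t), r * np.sin(t)] + noise * rng.normal(size=(n, 2))
    return x, y

def moons(n, noise=0.05):
    """Two interleaving half-circles in the plane."""
    t = rng.uniform(0, np.pi, n)
    upper = np.c_[np.cos(t), np.sin(t)]
    lower = np.c_[1.0 - np.cos(t), -np.sin(t) + 0.5]
    y = (np.arange(n) % 2).astype(int)
    x = np.where(y[:, None] == 0, upper, lower) + noise * rng.normal(size=(n, 2))
    return x, y

def gaussians3d(n, k=6, scale=4.0, noise=0.5):
    """k Gaussian blobs in R^3 with randomly drawn means."""
    means = scale * rng.normal(size=(k, 3))
    y = rng.integers(0, k, n)
    x = means[y] + noise * rng.normal(size=(n, 3))
    return x, y
```

Each generator returns a feature matrix and integer class labels, which is the input format a representer network would be trained on.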

(a) MNIST, d = 16 (b) Fashion MNIST, d = 16 (c) CIFAR-10 (from scratch), d = 16 (d) CIFAR-10 (transfer learning), d = 16

Figure 2: The distance correlations between labels and learned features based on SDRL, CN, SDRL+CN and dCorAE for the MNIST, FashionMNIST and CIFAR-10 data.

B.4 PROOF OF THEOREM 4.3

Denote $B_2 = \max\{|f'(c_1)|, |f'(c_2)|\}$ and $B_3 = \max_{|s|\le 2L_2\sqrt{d}\log n + B_2}|f^*(s)|$. We set the network parameters of the representer $R_\theta$ and the discriminator $D_\phi$ as follows. (i) Representer network $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ parameters: depth $\mathcal{D} = 9\log n + 12$, width $\mathcal{W} = d\max\{8d(n^{p/(2+p)}/\log n)$

the approximation error $\inf_{\bar{R}\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\mathcal{L}(\bar{R}) - \mathcal{L}(R^*)|$ and the statistical error $\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\mathcal{L}(R) - \hat{\mathcal{L}}(R)|$ separately. Theorem 4.3 then follows after bounding these two error terms.

Lemma B.2. Define $\bar{R}^*(\mathbf{x}) = \min\{R^*(\mathbf{x}), \log n\}$. There exists an $\hat{R}_\theta \in \mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ with depth $\mathcal{D} = 9\log n + 12$, width $\mathcal{W} = d\max\{8d(n^{p/(2+p)}/\log n)^{1/p} + 4p,\; 12n^{p/(2+p)}/\log n + 14\}$, size $\mathcal{S} = d\,n^{(p-2)/(p+2)}/\log^4(npd)$, and $\mathcal{B} = (2B_1L_1\sqrt{p} + \log n)\sqrt{d}$, such that
$\|\hat{R}_\theta - R^*\|_{L^2(\mu_{\mathbf{x}})} \le 160L_1B_1\sqrt{pd}\,n^{-2/(p+2)}.$ (13)

Proof. By Lemma B.2, $\hat{R}_\theta$ can approximate $R^*$ arbitrarily well, so the desired result follows from the fact that $D_f(\mu_{R^*(\mathbf{x})}\|\gamma_d) = 0$ and the continuity of $D_f(\mu_{R(\mathbf{x})}\|\gamma_d)$ in $R$. We present a sketch of the proof and omit the details here. Let $r^*(z) = \frac{d\mu_{R^*(\mathbf{x})}}{d\gamma_d}(z)$ and $\hat{r}(z) = \frac{d\mu_{\hat{R}_\theta(\mathbf{x})}}{d\gamma_d}(z)$.

holds due to $C(\mathcal{N}, e_{n,1}, \delta) \le C(\mathcal{N}, e_{n,\infty}, \delta)$ and the relationship between the metric entropy and the VC-dimension of the ReLU networks $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ (Anthony & Bartlett, 2009), i.e., $\log C(\mathcal{N}, e_{n,\infty}, \delta) \le \mathrm{VC}_{\mathcal{N}}\log\frac{2e\mathcal{B}n}{\delta\,\mathrm{VC}_{\mathcal{N}}}$, and the last inequality holds due to the bounds on the VC-dimension of the ReLU network $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$, namely $C_8\mathcal{D}\mathcal{S}\log\mathcal{S} \le \mathrm{VC}_{\mathcal{N}} \le C_9\mathcal{D}\mathcal{S}\log\mathcal{S}$; see Bartlett et al. (2019). Then (17) holds by the selection of the network parameters and setting $\delta = \frac{1}{n}$.

Lemma B.7. $E\big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}} |\hat{D}_f(\mu_{R(\mathbf{x})}\|\gamma_d) - D_f(\mu_{R(\mathbf{x})}\|\gamma_d)|\big] \le C_{14}(L_2\sqrt{d} + B_2 + B_3)(n^{-2/(2+p)} + \log n\cdot n^{-2/(2+d)})$. (22)

Proof. For any $R \in \mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$, let $r(z) = \frac{d\mu_{R(\mathbf{x})}}{d\gamma_d}(z)$ and $g_R(z) = f'(r(z))$. By assumption, $g_R : \mathbb{R}^d \to \mathbb{R}$ is Lipschitz continuous with Lipschitz constant $L_2$ and $\|g_R\|_{L^\infty} \le B_2$. Without loss of generality, we assume $\mathrm{supp}(g_R) \subseteq [-\log n, \log n]^d$. Then, similarly to the proof of Lemma B.2, we can show that there exists a $\hat{D}_\phi \in \mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}$ with depth $\bar{\mathcal{D}} = 9\log n + 12$, width $\bar{\mathcal{W}} = \max\{8d(n^{d/(2+d)}/\log n)^{1/d} + 4d,\; 12n^{d/(2+d)}/\log n + 14\}$, size $\bar{\mathcal{S}} = n^{(d-2)/(d+2)}/\log^4(npd)$, and $\bar{\mathcal{B}} = 2L_2\sqrt{d}\log n + B_2$, such that for $z \sim \gamma_d$ and for $z \sim \mu_{R(\mathbf{x})}$,
$E_z[|\hat{D}_\phi(z) - g_R(z)|] \le 160L_2\sqrt{d}\log n\,n^{-2/(d+2)}.$
(23)
For any $g : \mathbb{R}^d \to \mathbb{R}$, define $\mathcal{E}(g) = E_{\mathbf{x}\sim\mu_{\mathbf{x}}}[g(R(\mathbf{x}))] - E_{W\sim\gamma_d}[f^*(g(W))]$ and $\hat{\mathcal{E}}(g) = \hat{\mathcal{E}}(g, R) = \frac{1}{n}\sum_{i=1}^n[g(R(\mathbf{x}_i)) - f^*(g(W_i))]$. By (6) we have
$\mathcal{E}(g_R) = D_f(\mu_{R(\mathbf{x})}\|\gamma_d) = \sup_{\text{measurable } D:\mathbb{R}^d\to\mathbb{R}}\mathcal{E}(D).$ (24)
Then
$|D_f(\mu_{R(\mathbf{x})}\|\gamma_d) - \hat{D}_f(\mu_{R(\mathbf{x})}\|\gamma_d)| = \big|\mathcal{E}(g_R) - \max_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}\hat{\mathcal{E}}(D_\phi)\big| \le \big|\mathcal{E}(g_R) - \sup_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}\mathcal{E}(D_\phi)\big| + \big|\sup_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}\mathcal{E}(D_\phi) - \max_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}\hat{\mathcal{E}}(D_\phi)\big|$
$\le |\mathcal{E}(g_R) - \mathcal{E}(\hat{D}_\phi)| + \sup_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}|\mathcal{E}(D_\phi) - \hat{\mathcal{E}}(D_\phi)|$
$\le E_{z\sim\mu_{R(\mathbf{x})}}[|g_R - \hat{D}_\phi|(z)] + E_{W\sim\gamma_d}[|f^*(g_R) - f^*(\hat{D}_\phi)|(W)] + \sup_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}|\mathcal{E}(D_\phi) - \hat{\mathcal{E}}(D_\phi)|$
$\le 160(1 + B_3)L_2\sqrt{d}\log n\,n^{-2/(d+2)} + \sup_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}|\mathcal{E}(D_\phi) - \hat{\mathcal{E}}(D_\phi)|,$
where we use the triangle inequality in the first inequality, $\mathcal{E}(g_R) \ge \sup_{D_\phi\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}\mathcal{E}(D_\phi)$ (which follows from (24)) together with the triangle inequality in the second, the triangle inequality again in the third, and (23) with the mean value theorem in the last. We finish the proof by bounding the empirical process
$U(\mathcal{D}, \mathcal{R}) = E\big[\sup_{R\in\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}},\,D\in\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}}|\mathcal{E}(D) - \hat{\mathcal{E}}(D)|\big].$
Let $S = (\mathbf{x}, z) \sim \mu_{\mathbf{x}}\otimes\gamma_d$ and let $S_i$, $i = 1, \ldots, n$, be $n$ i.i.d. copies of $S$. Denote $b(D, R; S) = D(R(\mathbf{x})) - f^*(D(z))$; then $\mathcal{E}(D, R) = E_S[b(D, R; S)]$. Let $G(\mathcal{M}\times\mathcal{N}) = E_{\epsilon_i}\big[\sup_{D,R}\big|\frac{1}{n}\sum_{i=1}^n\epsilon_i b(D, R; S_i)\big|\big]$ be the Rademacher complexity of $\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}\times\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ (Bartlett & Mendelson, 2002). Let $C(\mathcal{M}\times\mathcal{N}, e_{n,1}, \delta)$ be the covering number of $\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}\times\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ with respect to the empirical distance (depending on $S_i$)
$e_{n,1}((D, R), (\bar{D}, \bar{R})) = \frac{1}{n}E_{\epsilon_i}\Big[\sum_{i=1}^n|\epsilon_i(b(D, R; S_i) - b(\bar{D}, \bar{R}; S_i))|\Big]$
at scale $\delta > 0$, and let $\mathcal{M}_\delta\times\mathcal{N}_\delta$ be a corresponding covering set. Then
$U(\mathcal{D}, \mathcal{R}) \le 2G(\mathcal{M}\times\mathcal{N}) = 2E_{S_1,\ldots,S_n}\big[E_{\epsilon_i,\,i=1,\ldots,n}[G(\mathcal{M}\times\mathcal{N}) \mid S_1, \ldots, S_n]\big]$
$\le 2\delta + C_{12}\frac{1}{n}E_{S_1,\ldots,S_n}\Big[(\log C(\mathcal{M}\times\mathcal{N}, e_{n,1}, \delta))^{1/2}\max_{(D,R)\in\mathcal{M}_\delta\times\mathcal{N}_\delta}\Big[\sum_{i=1}^n b(D, R; S_i)^2\Big]^{1/2}\Big]$
$\le 2\delta + C_{12}\frac{1}{n}E_{S_1,\ldots,S_n}\big[(\log C(\mathcal{M}\times\mathcal{N}, e_{n,1}, \delta))^{1/2}\sqrt{n}(2L_2\sqrt{d}\log n + B_2 + B_3)\big]$
$\le 2\delta + C_{12}\frac{1}{\sqrt{n}}(2L_2\sqrt{d}\log n + B_2 + B_3)(\log C(\mathcal{M}, e_{n,1}, \delta) + \log C(\mathcal{N}, e_{n,1}, \delta))^{1/2}$
$\le 2\delta + C_{13}\frac{L_2\sqrt{d}\log n + B_2 + B_3}{\sqrt{n}}\Big(\bar{\mathcal{D}}\bar{\mathcal{S}}\log\bar{\mathcal{S}}\,\log\frac{\bar{\mathcal{B}}n}{\delta\bar{\mathcal{D}}\bar{\mathcal{S}}\log\bar{\mathcal{S}}} + \mathcal{D}\mathcal{S}\log\mathcal{S}\,\log\frac{\mathcal{B}n}{\delta\mathcal{D}\mathcal{S}\log\mathcal{S}}\Big)^{1/2},$
where the first inequality follows from the standard symmetrization technique together with the law of iterated conditional expectation, the second inequality follows from the triangle inequality, the third inequality uses the fact that $b(D, R; S)$ is bounded, i.e., $\|b(D, R; S)\|_{L^\infty} \le 2L_2\sqrt{d}\log n + B_2 + B_3$, the fourth inequality follows after some algebra, and the fifth inequality follows from $C(\mathcal{N}, e_{n,1}, \delta) \le C(\mathcal{N}, e_{n,\infty}, \delta)$ (with a similar result for $\mathcal{M}$), $\log C(\mathcal{N}, e_{n,\infty}, \delta) \le \mathrm{VC}_{\mathcal{N}}\log\frac{2e\mathcal{B}n}{\delta\,\mathrm{VC}_{\mathcal{N}}}$, and $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ satisfying $C_8\mathcal{D}\mathcal{S}\log\mathcal{S} \le \mathrm{VC}_{\mathcal{N}} \le C_9\mathcal{D}\mathcal{S}\log\mathcal{S}$; see Bartlett et al. (2019). Then (22) follows from the above display with the stated network parameters of $\mathcal{M}_{\bar{\mathcal{D}},\bar{\mathcal{W}},\bar{\mathcal{S}},\bar{\mathcal{B}}}$ and $\mathcal{N}_{\mathcal{D},\mathcal{W},\mathcal{S},\mathcal{B}}$ and with $\delta = \frac{1}{n}$.
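The variational representation (24), $D_f(\mu\|\gamma) = \sup_D\{E_\mu[D] - E_\gamma[f^*(D)]\}$, can be illustrated by Monte Carlo. The sketch below uses the KL generator $f(t) = t\log t$ with conjugate $f^*(s) = e^{s-1}$, and takes $\mu = N(1,1)$, $\gamma = N(0,1)$ (our illustrative choices, not the paper's setting), for which the optimal discriminator $g^*(z) = f'(r(z)) = 1 + \log r(z)$ is available in closed form and the supremum equals $\mathrm{KL}(\mu\|\gamma) = 1/2$:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
w = rng.normal(size=n)             # samples from gamma = N(0, 1)
z = rng.normal(loc=1.0, size=n)    # samples from mu = N(1, 1)

f_star = lambda s: np.exp(s - 1.0)   # conjugate of f(t) = t * log(t)
log_r = lambda t: t - 0.5            # log density ratio d mu / d gamma
g_opt = lambda t: 1.0 + log_r(t)     # optimal discriminator f'(r(z))

def variational_obj(g):
    """Monte Carlo estimate of E_mu[g] - E_gamma[f*(g)]."""
    return g(z).mean() - f_star(g(w)).mean()

kl_est = variational_obj(g_opt)                      # close to KL = 0.5
kl_sub = variational_obj(lambda t: 0.5 * g_opt(t))   # any other g gives less
```

Any suboptimal discriminator yields a strictly smaller objective value, which is why restricting the supremum to a network class $\mathcal{M}$ only introduces the discriminator approximation error handled above.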

Averaged prediction errors and their standard errors (based on 5-fold validation).

Prediction error ± standard error: YearPredictionMSD dataset

With CN, we use the feature extractor obtained by dropping the cross-entropy layer of the DenseNet trained for classification. The MNIST and FashionMNIST datasets each consist of 60k grayscale training images and 10k grayscale test images with 28 × 28 pixels, while the CIFAR-10 dataset (Krizhevsky et al., 2009) contains 50k color training images and 10k color test images with 32 × 32 pixels. For the learning-from-scratch strategy, the representer network R_θ has 20 layers for the MNIST data and 100 layers for the CIFAR-10 data. We apply the transfer learning technique to the combination of SDRL and CN on CIFAR-10: the WideResNet-101 model (Zagoruyko & Komodakis, 2016) pretrained on the ImageNet dataset with Spinal FC (Kabir et al., 2020) is adopted for R_θ. The discriminator network D_φ is a 4-layer network. The architecture of R_θ and most hyperparameters are shared across all four methods: SDRL, CN, SDRL+CN and dCorAE. Finally, we use the k-nearest-neighbor (k = 5) classifier on the learned features for all methods.
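The final classification step above is a plain k-nearest-neighbor vote (k = 5) on the learned features. A minimal NumPy version of that evaluation step is sketched below; the synthetic feature arrays stand in for actual learned representations, and all names are our own:

```python
import numpy as np

def knn_predict(train_x, train_y, test_x, k=5):
    """Majority vote among the k nearest training features (Euclidean)."""
    d = np.linalg.norm(test_x[:, None, :] - train_x[None, :, :], axis=-1)
    nn = np.argsort(d, axis=1)[:, :k]      # indices of k nearest neighbors
    votes = train_y[nn]                    # (n_test, k) neighbor labels
    return np.array([np.bincount(v).argmax() for v in votes])

# synthetic stand-in for learned features: two well-separated classes in R^16
rng = np.random.default_rng(0)
n = 200
feats = np.vstack([rng.normal(0, 1, (n, 16)), rng.normal(4, 1, (n, 16))])
labels = np.repeat([0, 1], n)
idx = rng.permutation(2 * n)
tr, te = idx[:300], idx[300:]

pred = knn_predict(feats[tr], labels[tr], feats[te], k=5)
acc = (pred == labels[te]).mean()
```

Because the classifier is nonparametric and shared across methods, differences in its accuracy reflect the quality of the learned representations rather than the downstream model.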

Classification accuracy for MNIST and FashionMNIST

Hyper-parameters for simulated examples, where s varies according to epoch

MLP architectures for D φ and R θ in regression

DenseNet architecture for R θ in the simulated classification examples

Hyper-parameters for YearPredictionMSD data

MLP architectures for D φ and R θ for YearPredictionMSD data

Hyper-parameters for the classification benchmark datasets

Learning rate ρ varies during training.

Architecture for MNIST, reduced feature size is d

Architecture for FashionMNIST, reduced feature size is d

Architecture for CIFAR-10, reduced feature size is d

Proof. Our proof follows Keziou (2003). Since $f(t)$ is convex, for all $t \in \mathbb{R}$ we have $f(t) = f^{**}(t)$, where $f^{**}(t) = \sup_{s\in\mathbb{R}}\{st - f^*(s)\}$ is the convex biconjugate of $f$.
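The biconjugation identity $f = f^{**}$ invoked here can be checked numerically for a concrete generator. The sketch below uses the KL generator $f(t) = t\log t$, whose conjugate is $f^*(s) = e^{s-1}$ (an illustrative choice on our part), and approximates the supremum over a fine grid of $s$:

```python
import numpy as np

f = lambda t: t * np.log(t)          # convex generator (defined for t > 0)
f_star = lambda s: np.exp(s - 1.0)   # its convex conjugate

def f_biconjugate(t, s_grid):
    """f**(t) = sup_s { s*t - f*(s) }, approximated over a grid of s."""
    return np.max(s_grid * t - f_star(s_grid))

s_grid = np.linspace(-5.0, 5.0, 200_001)
vals = [(t, f_biconjugate(t, s_grid)) for t in (0.5, 1.0, 2.0)]
```

The maximizing $s$ is $1 + \log t$, so for moderate $t$ it lies well inside the grid and the grid supremum matches $f(t)$ to high accuracy.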

Proof. Let $R^*_i(\mathbf{x})$ be the $i$-th entry of $R^*(\mathbf{x}) : \mathbb{R}^p \to \mathbb{R}^d$. By the assumption on $R^*$, it is easy to check that $R^*_i(\mathbf{x})$ is Lipschitz continuous on $[-B_1, B_1]^p$ with Lipschitz constant $L_1$, and $\|\bar{R}^*_i\|_{L^\infty} \le \log n$. By Theorem 4.3 in Shen et al. (2019), there exists a ReLU network $\hat{R}_{\theta,i}$ with depth $9\log n + 12$, width $\max\{8d(n^{p/(2+p)}/\log n)^{1/p} + 4p,\; 12n^{p/(2+p)}/\log n + 14\}$, and $\|\hat{R}_{\theta,i}\|_{L^\infty} \le 2B_1L_1\sqrt{p} + \log n$,

