DCI-ES: AN EXTENDED DISENTANGLEMENT FRAMEWORK WITH CONNECTIONS TO IDENTIFIABILITY

Abstract

In representation learning, a common approach is to seek representations which disentangle the underlying factors of variation. Eastwood & Williams (2018) proposed three metrics for quantifying the quality of such disentangled representations: disentanglement (D), completeness (C) and informativeness (I). In this work, we first connect this DCI framework to two common notions of linear and nonlinear identifiability, thereby establishing a formal link between disentanglement and the closely-related field of independent component analysis. We then propose an extended DCI-ES framework with two new measures of representation quality, explicitness (E) and size (S), and point out how D and C can be computed for black-box predictors. Our main idea is that the functional capacity required to use a representation is an important but thus-far neglected aspect of representation quality, which we quantify via explicitness or ease-of-use (E). We illustrate the relevance of our extensions on the MPI3D and Cars3D datasets.

1. INTRODUCTION

A primary goal of representation learning is to learn representations r(x) of complex data x that "make it easier to extract useful information when building classifiers or other predictors" (Bengio et al., 2013). Disentangled representations, which aim to recover and separate (or, more formally, identify) the underlying factors of variation z that generate the data as x = g(z), are a promising step in this direction. In particular, it has been argued that such representations are not only interpretable (Kulkarni et al., 2015; Chen et al., 2016) but also make it easier to extract useful information for downstream tasks by recombining previously-learnt factors in novel ways (Lake et al., 2017). While there is no single, widely-accepted definition, many evaluation protocols have been proposed to capture different notions of disentanglement based on the relationship between the learnt representation or code c = r(x) and the ground-truth data-generating factors z (Higgins et al., 2017; Eastwood & Williams, 2018; Ridgeway & Mozer, 2018; Kim & Mnih, 2018; Chen et al., 2018; Suter et al., 2019; Shu et al., 2020). In particular, the metrics of Eastwood & Williams (2018), namely disentanglement (D), completeness (C) and informativeness (I), estimate this relationship by learning a probe f to predict z from c, and can be used to relate many other notions of disentanglement (see Locatello et al. 2020, § 6).

In this work, we extend this DCI framework in several ways. Our main idea is that the functional capacity required to recover z from c is an important but thus-far neglected aspect of representation quality. For example, consider the case of recovering z from: (i) a noisy version thereof; (ii) raw, high-dimensional data (e.g. images); and (iii) a linearly-mixed version thereof, with each c_i containing the same amount of information about each z_j (precise definition in § 6.1).
The noisy version (i) will do quite well with just linear capacity, but is fundamentally limited by the noise corruption; the raw data (ii) will likely do quite poorly with linear capacity, but eventually outperform (i) given sufficient capacity; and the linearly-mixed version (iii) will perfectly recover z with just linear capacity, yet achieve the worst-possible disentanglement score of D = 0. Motivated by this observation, we introduce a measure of explicitness or ease-of-use based on a representation's loss-capacity curve (see Fig. 1).

Structure and contributions. First, we connect the DCI metrics to two common notions of linear and nonlinear identifiability (§ 3). Next, we propose an extended DCI-ES framework (§ 4) in which we: (i) introduce two new complementary measures of representation quality, namely explicitness (E), derived from a representation's loss-capacity curve, and size (S); and then (ii) elucidate a means to compute the D and C scores for arbitrary black-box probes (e.g., MLPs). Finally, in our experiments (§ 6), we use our extended framework to compare different representations on the MPI3D-Real (Gondal et al., 2019) and Cars3D (Reed et al., 2015) datasets, illustrating the practical usefulness of our E score through its strong correlation with downstream performance.
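As a concrete illustration of the loss-capacity idea, the sketch below fits ridge-regression probes on random Fourier features of increasing width (a simple capacity knob) and summarises the resulting curve with a toy explicitness score. The function names, the capacity grid, and the normalisation by a zero-capacity (constant-predictor) baseline are our own illustrative assumptions, not the exact E definition of § 4.1.

```python
import numpy as np

rng = np.random.default_rng(0)

def rff_probe_loss(c_tr, z_tr, c_te, z_te, n_features, reg=1e-3):
    """Fit a ridge-regression probe on random Fourier features of the code and
    return its held-out mean-squared error in predicting the factors.
    `n_features` acts as the probe's capacity."""
    W = rng.normal(size=(c_tr.shape[1], n_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)
    feats = lambda c: np.cos(c @ W + b)
    # Closed-form ridge solution on the training split.
    A = feats(c_tr).T @ feats(c_tr) + reg * np.eye(n_features)
    beta = np.linalg.solve(A, feats(c_tr).T @ z_tr)
    return float(np.mean((feats(c_te) @ beta - z_te) ** 2))

def explicitness(c, z, capacities=(1, 4, 16, 64, 256)):
    """Toy explicitness: one minus the mean normalised loss across capacities,
    i.e. one minus a discretised area under the loss-capacity curve, with
    losses normalised by a constant (zero-capacity) baseline predictor."""
    n = len(c) // 2
    c_tr, c_te, z_tr, z_te = c[:n], c[n:], z[:n], z[n:]
    baseline = float(np.mean((z_te - z_tr.mean(axis=0)) ** 2))
    losses = np.array([rff_probe_loss(c_tr, z_tr, c_te, z_te, m)
                       for m in capacities])
    return 1.0 - float(np.mean(np.clip(losses / baseline, 0.0, 1.0)))
```

In this sketch, a code that is a linear mixture of z drives the probe loss towards zero already at modest capacities and scores highly, while a code carrying no information about z stays at the baseline loss at every capacity and scores near zero.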

2. BACKGROUND

Given a synthetic dataset of observations x = g(z) along with the corresponding K-dimensional data-generating factors z ∈ R^K, the DCI framework quantitatively evaluates an L-dimensional data representation or code c = r(x) ∈ R^L in two steps: (i) train a probe f to predict z from c, i.e., ẑ = f(c) = f(r(x)) = f(r(g(z))); and then (ii) quantify f's prediction error and its deviation from the ideal one-to-one mapping, namely a permutation matrix (with extra "dead" units in c whenever L > K).¹ For step (i), Eastwood & Williams (2018) use Lasso (Tibshirani, 1996) or Random Forests (RFs; Breiman, 2001) as linear or nonlinear predictors, respectively, for which it is straightforward to read off suitable "relative feature importances".

Definition 2.1. R ∈ R^{L×K} is a matrix of relative importances for predicting z from c via ẑ = f(c) if R_ij captures some notion of the contribution of c_i to predicting z_j such that, for all i, j: R_ij ≥ 0 and ∑_{i=1}^L R_ij = 1.

For step (ii), Eastwood & Williams use R and the prediction error to define and quantify three desiderata of disentangled representations: disentanglement (D), completeness (C) and informativeness (I).

Disentanglement. Disentanglement (D) measures the average number of data-generating factors z_j that are captured by any single code variable c_i. The score of code c_i is given by D_i = 1 − H_K(P_{i·}), where H_K(P_{i·}) = −∑_{k=1}^K P_ik log_K P_ik denotes the entropy of the distribution P_{i·} over row i of R, with P_ij = R_ij / ∑_{k=1}^K R_ik. If c_i is important only for predicting a single z_j, we get a perfect score of D_i = 1; if c_i is equally important for predicting all z_j (j = 1, …, K), we get the worst score of D_i = 0. The overall score is the weighted average D = ∑_{i=1}^L ρ_i D_i, with ρ_i = (1/K) ∑_{k=1}^K R_ik.

Completeness. Completeness (C) measures the average number of code variables c_i required to capture any single factor z_j; it has also been called compactness (Ridgeway & Mozer, 2018). The score for factor z_j is given by C_j = 1 − H_L(P̃_{·j}), where H_L(P̃_{·j}) = −∑_{ℓ=1}^L P̃_ℓj log_L P̃_ℓj denotes the entropy of the distribution P̃_{·j} over column j of R.
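The D and C definitions above reduce to simple operations on the importance matrix R. The following minimal sketch computes both (the function name and the epsilon smoothing are our own choices):

```python
import numpy as np

def dci_scores(R):
    """Compute disentanglement (D) and completeness (C) from a relative-
    importance matrix R of shape (L, K), where R[i, j] >= 0 is the importance
    of code c_i for predicting factor z_j and each column of R sums to 1."""
    L, K = R.shape
    eps = 1e-12  # smoothing to avoid log(0)

    # Row-normalise: P[i, j] = R[i, j] / sum_k R[i, k].
    P = R / (R.sum(axis=1, keepdims=True) + eps)
    # Per-code disentanglement D_i = 1 - H_K(P_i.), entropy with base K.
    H_rows = -(P * np.log(P + eps) / np.log(K)).sum(axis=1)
    D_i = 1.0 - H_rows
    # Weights rho_i proportional to the total importance of code c_i.
    rho = R.sum(axis=1) / R.sum()
    D = float(rho @ D_i)

    # Column-normalise: Pt[l, j] = R[l, j] / sum_i R[i, j].
    Pt = R / (R.sum(axis=0, keepdims=True) + eps)
    # Per-factor completeness C_j = 1 - H_L(Pt_.j), entropy with base L.
    H_cols = -(Pt * np.log(Pt + eps) / np.log(L)).sum(axis=0)
    C = float((1.0 - H_cols).mean())
    return D, C
```

A permutation matrix yields D = C = 1, while a uniform R, as produced by a linear mixture spreading information equally across codes, yields D = C = 0, matching the worst-case example from the introduction.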



¹ W.l.o.g., it can be assumed that z_i and c_j are normalised to have zero mean and unit variance for all i, j, since any such normalisation can be "absorbed" into g(·) and r(·).



Figure 1: Loss-capacity curves. Empirical loss-capacity curves (see § 4.1) for various representations (see legend), datasets (top: MPI3D-Real, bottom: Cars3D) and probe types (left: multi-layer perceptrons / MLPs; middle: random Fourier features / RFFs; right: Random Forests / RFs). The loss was first averaged over factors z_j, and then means and 95% confidence intervals were computed over 3 random seeds. Details in § 6.

