

Abstract

Autoencoders, or nonlinear factor models parameterized by neural networks, have become an indispensable tool for generative modeling and representation learning in high dimensions. Imposing structural constraints such as conditional independence on the latent variables (representation, or factors), in order to capture invariance or fairness with autoencoders, has been attempted by adding ad hoc penalties to the loss function, mostly in the variational autoencoder (VAE) context and often based on heuristic arguments. In this paper, we demonstrate that Wasserstein autoencoders (WAEs) are highly flexible in embracing such structural constraints. Well-known extensions of VAEs for this purpose are gracefully handled within the framework of the seminal result of Tolstikhin et al. (2018). In particular, given a conditional independence structure of the generative model (decoder), the corresponding encoder structure and penalties are induced from the functional constraints that define the WAE. This property of WAEs opens up a principled way of penalizing autoencoders to impose structural constraints. Utilizing this generative model structure, we present results on fair representation and conditional generation tasks, and compare them with preceding methods.

1. INTRODUCTION

The ability to learn informative representations of data with minimal supervision is a key challenge in machine learning (Tschannen et al., 2018), and autoencoders have become an indispensable toolkit toward this goal. An autoencoder consists of an encoder, which maps the input to a low-dimensional representation, and a decoder, which maps a representation back to a reconstruction of the input. Thus an autoencoder can be considered a nonlinear factor analysis model: the latent variable provided by the encoder carries the meaning of "representation," and the decoder can be used for generative modeling of the input data distribution. Most autoencoders can be formulated as minimizing some "distance" between the distribution $P_X$ of the input random variable $X$ and the distribution $g_\sharp P_Z$ of the reconstruction $G = g(Z)$, where $Z$ is the latent variable (representation) with distribution $P_Z$ and $g$ is either a deterministic or probabilistic decoder (in the latter case $g$ is read as the conditional distribution of $G$ given $Z$); the distance is variationally described in terms of an encoder $Q_{Z|X}$. For instance, the variational autoencoder (VAE, Kingma & Welling, 2014) minimizes

$D_{\mathrm{VAE}}(P_X, g_\sharp P_Z) = \inf_{Q_{Z|X} \in \mathcal{Q}} \mathbb{E}_{P_X}\left[ D_{\mathrm{KL}}(Q_{Z|X} \| P_Z) - \mathbb{E}_{Q_{Z|X}} \log g(X|Z) \right]$ (1)

over the set of probabilistic decoders, i.e., conditional densities $g$ of $G$ given $Z$, where $D_{\mathrm{KL}}$ is the Kullback-Leibler (KL) divergence, and the Wasserstein autoencoder (WAE, Tolstikhin et al., 2018) minimizes

$D_{\mathrm{WAE}}(P_X, g_\sharp P_Z) = \inf_{Q_{Z|X} \in \mathcal{Q}} \mathbb{E}_{P_X} \mathbb{E}_{Q_{Z|X}} d^p(X, g(Z))$ (2)

over the set of deterministic decoders $g$, where $d$ is the metric on the space of the input $X$ and $p \ge 1$. The set $\mathcal{Q}$ restricts the search space for the encoder. In VAEs, a popular choice is a class of normal distributions $\mathcal{Q} = \{Q_{Z|X} \text{ regular conditional distribution} : Z \mid \{X = x\} \sim N(\mu(x), \Sigma(x)),\ (\mu, \Sigma) \in \mathcal{NN}\}$, where $\mathcal{NN}$ is a class of functions parametrized by neural networks.
In WAEs, the choice

$\mathcal{Q} = \{Q_{Z|X} \text{ regular conditional distribution} : Q_Z \triangleq \mathbb{E}_{P_X} Q_{Z|X} = P_Z\}$ (3)

makes the left-hand side of Eq. (2) equal to the ($p$-th power of the) $p$-Wasserstein distance between $P_X$ and $g_\sharp P_Z$ (Tolstikhin et al., 2018, Theorem 1); $Q_Z$ is called the aggregate posterior of $Z$. If $\mathcal{Q}$ is a set of Dirac measures, i.e., $\mathcal{Q} = \{Q_{Z|X} : Q_{Z|X=x} = \delta_{f(x)},\ f \in \mathcal{NN}\}$, then minimizing Eq. (2) reduces to the learning problem of a deterministic unregularized autoencoder. Of course, the notion of "informativeness" depends on the downstream task. The variation in the observations that is not relevant to the particular task is often called "nuisance," and it is desirable to suppress it from the representation. For example, in finding "fair representations" (Zemel, 2013), sensitive information such as gender or socioeconomic status should be removed from the latent representation; in obtaining representations of facial images, those invariant to lighting conditions, poses, or the wearing of eyeglasses are often sought. A popular approach to this goal is to explicitly separate informative and nuisance variables in the generative model by factorization. This approach imposes a structure on the decoder. Additionally, the encoder is further factorized, and a penalty promoting independence between the encoded representation and the nuisance variable can be added. A well-known example is the variational fair autoencoder (VFAE, Louizos et al., 2016), in which a variant of the "M1+M2" graphical model (Kingma et al., 2014) is used to factorize the decoder and a resembling factorization of the encoder (variational posterior) is assumed. Independence of the representation from the nuisance variable is encouraged by adding a maximum mean discrepancy (MMD, Gretton et al., 2007) penalty between conditional variational posteriors; in Lopez et al. (2018), the MMD is replaced by the Hilbert-Schmidt Independence Criterion (HSIC, Gretton et al., 2007).
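As a concrete reference point, the MMD penalty mentioned above admits a simple plug-in estimate from two samples. The following is a minimal sketch, assuming a Gaussian kernel with a fixed bandwidth; the kernel choice, bandwidth, and variable names are illustrative and not prescribed by the cited works.

```python
import numpy as np

def gaussian_gram(a, b, sigma=1.0):
    # Pairwise Gaussian kernel matrix: k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2)).
    sq = (a**2).sum(1)[:, None] + (b**2).sum(1)[None, :] - 2.0 * a @ b.T
    return np.exp(-sq / (2.0 * sigma**2))

def mmd2(x, y, sigma=1.0):
    # Biased (V-statistic) estimate of the squared MMD between samples x and y.
    return (gaussian_gram(x, x, sigma).mean()
            + gaussian_gram(y, y, sigma).mean()
            - 2.0 * gaussian_gram(x, y, sigma).mean())

rng = np.random.default_rng(0)
z_agg = rng.standard_normal((100, 2))          # stand-in for draws from the aggregate posterior Q_Z
z_prior = rng.standard_normal((100, 2)) + 3.0  # stand-in for draws from a (mismatched) prior P_Z
penalty = mmd2(z_agg, z_prior)                 # large when the two distributions differ
```

The estimate vanishes when the two samples coincide and grows with the discrepancy between the underlying distributions, which is what makes it usable as a distribution-matching penalty.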
Other authors employ penalties derived from mutual information (MI) (Moyer et al., 2018; Song et al., 2019; Creager et al., 2019). Another example is the Fader Networks (Lample et al., 2018), in which the deterministic decoder takes an additional input of the attribute (such as whether or not eyeglasses are present in a portrait), and an adversarial penalty that hinders accurate prediction of the attribute from the output of the deterministic, unfactorized encoder is added. These examples illustrate that, while the generative model (decoder structure) can be chosen suitably for the downstream task, there is no principled way of imposing the corresponding encoder structure. In this paper, we show that the WAE framework allows us to automatically determine the encoder structure corresponding to the imposed decoder structure. Specifically, when the deterministic decoder $g$ in Eq. (2) is modified to handle the conditional independence structure of the imposed generative model, the constraint set (amounting to the $\mathcal{Q}$ in Eq. (3)) that makes the LHS of Eq. (2) a proper (power of the) Wasserstein distance determines the factorization of the (deterministic) encoder. In practice, the hard constraints in $\mathcal{Q}$ are relaxed and Eq. (2) is solved in a penalized form. Following the approach of Tolstikhin et al. (2018), the cited constraint set can be systematically translated into penalties. Therefore, in addition to the theoretical advantage that the penalized form equals a genuine distributional distance for a sufficiently large penalty parameter, while that of Eq. (1) remains only an upper bound of the negative log-likelihood of the model, the ad hoc manner of designing penalties prevalent in the VAE literature can be avoided in the WAE framework. Further, the allowance of a deterministic encoder/decoder promotes better generation performance in many downstream tasks. We explain how the WAE framework leads to structured encoders given a generative model through examples reflecting downstream tasks in Sect.
3, after providing necessary background in Sect. 2. We call these structured uses of WAEs the Wasserstein Fair Autoencoders (WFAEs). After reviewing related ideas in Sect. 4, WFAEs are evaluated experimentally in Sect. 5 on datasets including VGGFace2 (Cao et al., 2018). We conclude the paper in Sect. 6.

2. PRELIMINARIES

In fitting a given probability distribution $P_X$ of a random variable $X$ on a measurable space $(\mathcal{X}, \mathcal{B}(\mathcal{X}))$, where $\mathcal{X} \subset \mathbb{R}^D$ is equipped with a metric $d$, by a generative model $P_G$ of a sample $G$ on the same measurable space, one may consider minimizing the ($p$th power of the) $p$-Wasserstein distance between the two distributions, i.e.,

$\min_{P_G \in \mathcal{M}} W_p^p(P_X, P_G) := \inf_{\pi \in \mathcal{P}(P_X, P_G)} \mathbb{E}_\pi d^p(X, G)$.

Here, $\mathcal{M}$ is the model space of probability distributions, and $\mathcal{P}(P_X, P_G)$ is the coupling, or the set of all joint distributions on $(\mathcal{X} \times \mathcal{X}, \mathcal{B}(\mathcal{X} \times \mathcal{X}))$ having marginals $P_X$ and $P_G$. Often the sample $G$ is generated by transforming a variable in a latent space. When $G = g(Z)$ a.s. for a latent variable $Z$ in a probability space $(\mathcal{Z}, \mathcal{B}(\mathcal{Z}), P_Z)$, $\mathcal{Z} \subset \mathbb{R}^l$, and a measurable function $g$, then $P_G$ is denoted by $g_\sharp P_Z$, where $\sharp$ is the push-forward operator. In this setting, as discussed in Sect. 1, Tolstikhin et al. (2018) show that $W_p^p(P_X, g_\sharp P_Z) = D_{\mathrm{WAE}}(P_X, g_\sharp P_Z)$ (Eq. (2)), with the constraint set $\mathcal{Q}$ on the probabilistic encoders $Q_{Z|X}$ given in Eq. (3). It is further claimed by Patrini et al. (2020) that the set of conditional distributions $Q_{Z|X}$ can be reduced to deterministic ones, i.e., $Z = f(X)$ a.s. for $f$ measurable. However, this claim is not in general true unless $g$ is injective:

Theorem 1 Let $d(x, y) = \|x - y\|_2$ for $x, y \in \mathcal{X}$. If $P_X$ has a density with respect to the Lebesgue measure, and the measurable function $g : \mathcal{Z} \to \mathcal{X}$ is injective, then

$W_2^2(P_X, g_\sharp P_Z) = \inf_{f \in \mathcal{Q}} \mathbb{E}_{P_X} d^2(X, g(f(X)))$, (4)

where $\mathcal{Q}$ is the set of all measurable functions from $\mathcal{X}$ to $\mathcal{Z}$ such that $f_\sharp P_X = P_Z$.

The proof of this result is provided in Appendix A of the Supplement.

Remark 1 In Patrini et al. (2020, Theorem A.2), it is incorrectly claimed that for the right inverse $\check g$ of $g$ when the codomain of the latter is restricted to its range, $(\check g \circ g)_\sharp P_Z(F)$ is equal to $P_Z(\check g^{-1}(g^{-1}(F)))$, instead of the correct $P_Z(g^{-1}(\check g^{-1}(F)))$.
This confusion invalidates the rest of the argument of the cited theorem. In practice the set $\mathcal{Q}$ can be relaxed to $\mathcal{F}$, a class of measurable functions parameterized by deep neural networks, which contains a minimizer of the right-hand side (RHS) of Eq. (4); the constraint $f_\sharp P_X = P_Z$ can be met by adding a penalty $\lambda D(f_\sharp P_X \| P_Z)$ for a sufficiently large multiplier $\lambda > 0$ and a divergence $D$ between two distributions. Thus if we define the distortion criterion

$\delta(f, g) = \mathbb{E}_{P_X} d^2(X, g(f(X))) + \lambda D(f_\sharp P_X \| P_Z)$,

then the generative modeling problem based on the 2-Wasserstein distance can be formulated as

$\inf_{g \in \mathcal{G}} \inf_{f \in \mathcal{F}} \delta(f, g)$, (5)

for $\mathcal{G}$ a set of injective measurable functions from $\mathcal{Z}$ to $\mathcal{X}$, typically parameterized by deep neural networks. The function $f : \mathcal{X} \to \mathcal{Z}$ has the interpretation of an encoder and $g : \mathcal{Z} \to \mathcal{X}$ that of a decoder. Typically $l \ll D$.
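The distortion criterion $\delta(f, g)$ can be estimated on a minibatch once a sample-based divergence $D$ is chosen. Below is a toy sketch with linear encoder/decoder maps and an MMD stand-in for $D$; the shapes, the linear maps, and the choice $D = \mathrm{MMD}$ are all illustrative assumptions, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def mmd2(a, b, sigma=1.0):
    # Biased squared-MMD estimate with a Gaussian kernel; stands in for the divergence D.
    def gram(u, v):
        sq = (u**2).sum(1)[:, None] + (v**2).sum(1)[None, :] - 2.0 * u @ v.T
        return np.exp(-sq / (2.0 * sigma**2))
    return gram(a, a).mean() + gram(b, b).mean() - 2.0 * gram(a, b).mean()

def distortion(f, g, x, z_prior, lam=10.0):
    # delta(f, g) = E_PX d^2(X, g(f(X))) + lam * D(f# P_X || P_Z), minibatch estimate.
    z = f(x)
    recon = ((x - g(z)) ** 2).sum(axis=1).mean()
    return recon + lam * mmd2(z, z_prior)

# Toy linear encoder/decoder with D = 5 observed dims and l = 2 latent dims.
W = rng.standard_normal((5, 2)) / np.sqrt(5)
f = lambda x: x @ W        # encoder f : R^5 -> R^2
g = lambda z: z @ W.T      # decoder g : R^2 -> R^5
x = rng.standard_normal((64, 5))
z_prior = rng.standard_normal((64, 2))  # sample from P_Z = N(0, I)
loss = distortion(f, g, x, z_prior)
```

In an actual WAE, $f$ and $g$ are neural networks and this scalar is minimized by stochastic gradient descent; the sketch only illustrates how the two terms of $\delta(f, g)$ combine.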

3. LEARNING INVARIANT REPRESENTATIONS WITH WFAES

Often generative modeling is more complicated than just involving a latent variable $Z$ in $\mathcal{Z}$ and its reconstruction $G$ in $\mathcal{X}$. For example, data may come with labels, which can be employed in the generation process to learn invariant representations.

Example 1 Let us begin with the simple generative model shown in Fig. 1a (Louizos et al., 2016, Fig. 1); see also (Kingma et al., 2014, M2). Here, the variable $S \in \mathcal{S} \subset \mathbb{R}^B$ represents the observed nuisance variation, and $Z$ models the remaining information on $G$ (with which we want to mimic the observable variable $X$) that is independent of $S$. Thus $Z$ encodes the representation invariant to the unwanted variation in $S$. Denoting the marginal distribution of the nuisance variable $S$ by $P_S$, the distribution of the model $G$ is $g_\sharp(P_Z \otimes P_S)$, where $\otimes$ denotes a product distribution. The goal is to make the joint distribution $P_{GS}$ of $(G, S)$ close to the observable $P_{XS}$ of $(X, S)$. If we let $\tilde g(z, s) = (g(z, s), s)$, then $P_{GS} = \tilde g_\sharp(P_Z \otimes P_S)$. Recall that $\mathcal{X}$ is equipped with the metric $d$. Equip $\mathcal{S}$ with another metric $d'$ and $\mathcal{X} \times \mathcal{S}$ with $\tilde d = \sqrt{d^2 + (d')^2}$. Then, by applying Theorem 1 to $P_{XS}$, $P_Z \otimes P_S$, and $\tilde g$, we obtain

$W_2^2(P_{XS}, \tilde g_\sharp(P_Z \otimes P_S)) = \inf_{\tilde f \in \tilde{\mathcal{F}}} \mathbb{E}_{P_{XS}} \tilde d^2([X, S], \tilde g(\tilde f(X, S))) = \inf_{f \in \mathcal{F}} \mathbb{E}_{P_{XS}} \tilde d^2([X, S], [g(f(X, S), S), S]) = \inf_{f \in \mathcal{F}} \mathbb{E}_{P_{XS}} d^2(X, g(f(X, S), S))$,

where $\tilde{\mathcal{F}} = \{\tilde f : \mathcal{X} \times \mathcal{S} \to \mathcal{Z} \times \mathcal{S} : \tilde f_\sharp P_{XS} = P_Z \otimes P_S\}$, $\mathcal{F} = \{f : \mathcal{X} \times \mathcal{S} \to \mathcal{Z} : (f, \Pi_S)_\sharp P_{XS} = P_Z \otimes P_S\}$, and $\Pi_S : \mathcal{X} \times \mathcal{S} \to \mathcal{S}$, $\Pi_S(x, s) = s$, is the orthogonal projection from $\mathcal{X} \times \mathcal{S}$ onto $\mathcal{S}$. The second equality holds by writing $\tilde f(x, s) = (f(x, s), h(x, s))$ and taking $h = \Pi_S$. The latter constraint set $\mathcal{F}$ means that

$f(X, S) \stackrel{d}{=} Z, \quad f(X, S) \perp\!\!\!\perp S$. (6)
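The independence requirement $f(X, S) \perp\!\!\!\perp S$ in the constraint set above is relaxed into a penalty in practice; one sample-based choice is the HSIC of Lopez et al. (2018). Below is a minimal biased estimator as a sketch; the Gaussian kernels and the fixed bandwidth are illustrative assumptions.

```python
import numpy as np

def hsic(x, y, sigma=1.0):
    # Biased empirical HSIC between paired samples x (codes) and y (nuisance):
    # HSIC = tr(K H L H) / (n - 1)^2 with Gaussian Gram matrices K, L
    # and the centering matrix H.
    n = x.shape[0]
    def gram(a):
        sq = (a**2).sum(1)[:, None] + (a**2).sum(1)[None, :] - 2.0 * a @ a.T
        return np.exp(-sq / (2.0 * sigma**2))
    h = np.eye(n) - np.ones((n, n)) / n   # centering matrix H
    return np.trace(gram(x) @ h @ gram(y) @ h) / (n - 1) ** 2

rng = np.random.default_rng(0)
z = rng.standard_normal((200, 2))
s_dep = z + 0.1 * rng.standard_normal((200, 2))   # nuisance strongly dependent on z
s_ind = rng.standard_normal((200, 2))             # nuisance independent of z
# hsic(z, s_dep) is much larger than hsic(z, s_ind)
```

Added to the reconstruction loss with a multiplier, this pushes the encoded representation toward independence from the nuisance variable.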
Following the formulation in Eq. (5) for the unstructured case, we can incorporate the constraint (6) into the learning problem in a penalized form:

$\min_g \min_f \mathbb{E}_{P_{XS}} d^2(X, g(f(X, S), S)) + \lambda_1 D(f_\sharp P_{XS} \| P_Z) + \lambda_2 H((f, \Pi_S)_\sharp P_{XS})$,

where $D$ is an appropriate divergence between two probability distributions, such as the MMD or the generative adversarial network (GAN) loss as suggested by Tolstikhin et al. (2018), and $H$ promotes independence between the two random variables $f(X, S)$ and $S$, such as the HSIC (Lopez et al., 2018).

Example 2 Consider the more involved generative model shown in Fig. 1b, which is employed by the VFAE (Louizos et al., 2016, Fig. 2) as an extension of the "M1+M2" semi-supervised model (Kingma et al., 2014). This graphical model actually describes the conditional distribution $P_{X|S}$ of $X$ given $S$, since $S$ and $Y$ are allowed to be correlated. Instead, it is required that

$Z_1 \perp\!\!\!\perp S$ (7)

in order to impose invariance to the nuisance variable $S$. Define $\tilde g : \mathcal{Y} \times \mathcal{Z}_2 \times \mathcal{S} \to \mathcal{X} \times \mathcal{Y} \times \mathcal{S}$ as $\tilde g(y, z_2, s) = (g_1(g_2(y, z_2), s), y, s)$. Denoting the marginal distribution of the nuisance variable $S$ by $P_S$ and the joint distribution of $Y$ and $S$ by $P_{YS}$, the distribution of the model $G$ is $g_\sharp(P_{YS} \otimes P_{Z_2})$. The goal is to make the joint distribution $P_{GS}$ of $(G, S)$ close to $P_{XS}$ of $(X, S)$ when $Y$ is not observed, and $P_{GYS}$ of $(G, Y, S)$ close to $P_{XYS}$ of $(X, Y, S)$ when the data are fully observed. First consider the case that $Y$ is missing. Let $\Pi_{XS}$ be the orthogonal projection from $\mathcal{X} \times \mathcal{Y} \times \mathcal{S}$ onto $\mathcal{X} \times \mathcal{S}$. Then by applying Theorem 1 to $P_{XS}$, $P_{YS} \otimes P_{Z_2}$, and $\Pi_{XS} \tilde g$, we obtain

$W_2^2(P_{XS}, (\Pi_{XS} \tilde g)_\sharp(P_{YS} \otimes P_{Z_2})) = \inf_{f \in \mathcal{F}_{\mathrm{unobs}}} \mathbb{E}_{P_{XS}} \tilde d^2([X, S], \Pi_{XS} \tilde g(f(X, S), S)) = \inf_{f \in \mathcal{F}_{\mathrm{unobs}}} \mathbb{E}_{P_{XS}} \tilde d^2([X, S], [g_1(g_2(f(X, S)), S), S]) = \inf_{f \in \mathcal{F}_{\mathrm{unobs}}} \mathbb{E}_{P_{XS}} d^2(X, g_1(g_2(f(X, S)), S))$,

where $\mathcal{F}_{\mathrm{unobs}} = \{(f_1^{\mathrm{unobs}}, f_2^{\mathrm{unobs}}) : f_1^{\mathrm{unobs}} : \mathcal{X} \times \mathcal{S} \to \mathcal{Y},\ f_2^{\mathrm{unobs}} : \mathcal{X} \times \mathcal{S} \to \mathcal{Z}_2,\ (f_1^{\mathrm{unobs}}, \Pi_S, f_2^{\mathrm{unobs}})_\sharp P_{XS} = P_{YS} \otimes P_{Z_2}\}$.
The latter constraint set means

$(f_1^{\mathrm{unobs}}(X, S), S) \stackrel{d}{=} (Y, S), \quad f_2^{\mathrm{unobs}}(X, S) \stackrel{d}{=} Z_2, \quad (f_1^{\mathrm{unobs}}(X, S), S) \perp\!\!\!\perp f_2^{\mathrm{unobs}}(X, S)$. (8)

Now consider the case that $Y$ is observed. Equip $\mathcal{Y}$ with a metric $d''$ and $\mathcal{X} \times \mathcal{Y} \times \mathcal{S}$ with $\tilde d = \sqrt{d^2 + (d'')^2 + (d')^2}$. Then by applying Theorem 1 to $P_{XYS}$, $P_{YS} \otimes P_{Z_2}$, and $\tilde g$, we obtain

$W_2^2(P_{XYS}, \tilde g_\sharp(P_{YS} \otimes P_{Z_2})) = \inf_{f_2^{\mathrm{obs}} \in \mathcal{F}_{\mathrm{obs}}} \mathbb{E}_{P_{XYS}} \tilde d^2([X, Y, S], \tilde g(Y, f_2^{\mathrm{obs}}(X, Y, S), S)) = \inf_{f_2^{\mathrm{obs}} \in \mathcal{F}_{\mathrm{obs}}} \mathbb{E}_{P_{XYS}} \tilde d^2([X, Y, S], [g_1(g_2(Y, f_2^{\mathrm{obs}}(X, Y, S)), S), Y, S]) = \inf_{f_2^{\mathrm{obs}} \in \mathcal{F}_{\mathrm{obs}}} \mathbb{E}_{P_{XYS}} d^2(X, g_1(g_2(Y, f_2^{\mathrm{obs}}(X, Y, S)), S))$,

where $\mathcal{F}_{\mathrm{obs}} = \{f_2^{\mathrm{obs}} : \mathcal{X} \times \mathcal{Y} \times \mathcal{S} \to \mathcal{Z}_2 : (\Pi_Y, \Pi_S, f_2^{\mathrm{obs}})_\sharp P_{XYS} = P_{YS} \otimes P_{Z_2}\}$ and $\Pi_Y : (x, y, s) \mapsto y$, $\Pi_S : (x, y, s) \mapsto s$ are projections. The latter constraint set means

$f_2^{\mathrm{obs}}(X, Y, S) \stackrel{d}{=} Z_2, \quad (Y, S) \perp\!\!\!\perp f_2^{\mathrm{obs}}(X, Y, S)$. (9)

In order to combine the two Wasserstein losses and constraints (7) to (9), let us extend $\mathcal{Y}$ to $\bar{\mathcal{Y}} = \mathcal{Y} \cup \{*\}$, where '$*$' represents a missing value. For any $(f_1^{\mathrm{unobs}}, f_2^{\mathrm{unobs}}) \in \mathcal{F}_{\mathrm{unobs}}$ and $f_2^{\mathrm{obs}} \in \mathcal{F}_{\mathrm{obs}}$, define $f_1 : \mathcal{X} \times \bar{\mathcal{Y}} \times \mathcal{S} \to \mathcal{Y}$ and $f_2 : \mathcal{X} \times \bar{\mathcal{Y}} \times \mathcal{S} \to \mathcal{Z}_2$ as

$f_1(x, y, s) = \begin{cases} y, & y \ne *, \\ f_1^{\mathrm{unobs}}(x, s), & y = *, \end{cases} \qquad f_2(x, y, s) = \begin{cases} f_2^{\mathrm{obs}}(x, y, s), & y \ne *, \\ f_2^{\mathrm{unobs}}(x, s), & y = *. \end{cases}$

Then we can formulate the learning problem for the WFAE in a penalized form:

$\min_{g_1, g_2} \min_{f_1, f_2} \mathbb{E}_{P_{XYS}} d^2(X, g_1(g_2(Y, f_2(X, Y, S)), S)) + \mathbb{E}_{P_{XS}} d^2(X, g_1(g_2(f_1(X, *, S), f_2(X, *, S)), S)) + \lambda_1 D_1((f_1, \Pi_S)_\sharp P_{XYS} \| P_{YS}) + \lambda_2 D_2(f_{2\sharp} P_{XYS} \| P_{Z_2}) + \lambda_3 H_3((f_1, \Pi_S, f_2)_\sharp P_{XYS}) + \lambda_4 H_4((g_2 \star f_1)_\sharp(P_{XYS} \otimes P_{Z_2}))$,

where $g_2 \star f_1(x, y, z_2, s) = (g_2(f_1(x, y, s), z_2), s)$; $D_1$ and $D_2$ are appropriate divergences between two probability distributions, and $H_3$, $H_4$ promote independence between two random variables. Note that, unlike Example 1, in which only the encoder $f$ is constrained, Eq. (7) imposes a constraint on the decoder $g_2$.
Also note that the divergence $D_1$ can be estimated in a two-sample fashion, namely from a sample drawn from $P_{YS}$, i.e., $(y_i, s_i)$ with $y_i$ observed, and another sample drawn from $(f_1, \Pi_S)_\sharp P_{XYS}$, either as $(y_j, s_j)$ if $y_j$ is observed or as $(f_1(x_j, *, s_j), s_j)$ otherwise. Hence all the data in the minibatch can be utilized. Likewise, the divergence $D_2$ and the independence penalties $H_3$ and $H_4$ can utilize the full minibatch.

Remark 2 VAE-based models, e.g., the VFAE, assume a specific factorization of the variational posterior (encoder). Since the factor $q_\phi(y|z)$ for imputing $Y$ does not appear in the evidence lower bound (ELBO) of the observed likelihood, an additional penalty on this factor evaluated on the fully observed sample is coined (Louizos et al., 2016, Eq. 5), making the bound not tight. In the WFAE, on the contrary, the $D_1$ term that arises naturally from constraint (8) for the Wasserstein distance penalizes the imputing encoder $f_1$ for both fully (by requiring $f_1(x_j, y_j, s_j) = y_j$) and partially (by the divergence) observed samples.

Example 3 The model shown in Fig. 1c extends Example 1 with two independent nuisance variables that can be missing. Here $Y$ may represent a person's identity in her portrait, which may be missing, and $S$ partially observed attributes (e.g., sunglasses on/off, mouth open/closed, and gender). In this setup we want two different portraits of a person to have similar values of $Z$, and those of two different people to have quite distinct values of $Z$, even if the encoder does not know whose portraits they are. We may also want $Z$ to represent something immune even to a gender switch.
Proceeding as in Example 2, define $\tilde g : \mathcal{Y} \times \mathcal{Z} \times \mathcal{S} \to \mathcal{X} \times \mathcal{Y} \times \mathcal{S}$ by $\tilde g(y, z, s) = (g(y, z, s), y, s)$. We obtain

$W_2^2(P_X, (\Pi_X \tilde g)_\sharp(P_Y \otimes P_Z \otimes P_S)) = \inf_{(f_1^X, f_2^X, f_3^X) \in \mathcal{F}^X} \mathbb{E}_{P_X} d^2(X, g(f_1^X(X), f_2^X(X), f_3^X(X)))$, with $\mathcal{F}^X = \{(f_1, f_2, f_3) : (f_1, f_2, f_3)_\sharp P_X = P_Y \otimes P_Z \otimes P_S,\ f_1 : \mathcal{X} \to \mathcal{Y},\ f_2 : \mathcal{X} \to \mathcal{Z},\ f_3 : \mathcal{X} \to \mathcal{S}\}$,

when both $Y$ and $S$ are unobserved;

$W_2^2(P_{XS}, (\Pi_{XS} \tilde g)_\sharp(P_Y \otimes P_Z \otimes P_S)) = \inf_{(f_1^{XS}, f_2^{XS}) \in \mathcal{F}^{XS}} \mathbb{E}_{P_{XS}} d^2(X, g(f_1^{XS}(X, S), f_2^{XS}(X, S), S))$, with $\mathcal{F}^{XS} = \{(f_1, f_2) : (f_1, f_2, \Pi_S)_\sharp P_{XS} = P_Y \otimes P_Z \otimes P_S,\ f_1 : \mathcal{X} \times \mathcal{S} \to \mathcal{Y},\ f_2 : \mathcal{X} \times \mathcal{S} \to \mathcal{Z}\}$,

when only $Y$ is unobserved;

$W_2^2(P_{XY}, (\Pi_{XY} \tilde g)_\sharp(P_Y \otimes P_Z \otimes P_S)) = \inf_{(f_2^{XY}, f_3^{XY}) \in \mathcal{F}^{XY}} \mathbb{E}_{P_{XY}} d^2(X, g(Y, f_2^{XY}(X, Y), f_3^{XY}(X, Y)))$, with $\mathcal{F}^{XY} = \{(f_2, f_3) : (\Pi_Y, f_2, f_3)_\sharp P_{XY} = P_Y \otimes P_Z \otimes P_S,\ f_2 : \mathcal{X} \times \mathcal{Y} \to \mathcal{Z},\ f_3 : \mathcal{X} \times \mathcal{Y} \to \mathcal{S}\}$,

when only $S$ is unobserved; and

$W_2^2(P_{XYS}, \tilde g_\sharp(P_Y \otimes P_Z \otimes P_S)) = \inf_{f_2^{XYS} \in \mathcal{F}^{XYS}} \mathbb{E}_{P_{XYS}} d^2(X, g(Y, f_2^{XYS}(X, Y, S), S))$, with $\mathcal{F}^{XYS} = \{f_2 : \mathcal{X} \times \mathcal{Y} \times \mathcal{S} \to \mathcal{Z} : (\Pi_Y, f_2, \Pi_S)_\sharp P_{XYS} = P_Y \otimes P_Z \otimes P_S\}$,

when the data are fully observed. If we extend $\mathcal{Y}$ to $\bar{\mathcal{Y}} = \mathcal{Y} \cup \{*\}$ and $\mathcal{S}$ to $\bar{\mathcal{S}} = \mathcal{S} \cup \{*\}$, the learning problem in a penalized form is

$\min_g \min_{(f_1, f_2, f_3) \in \mathcal{F}} \mathbb{E}_{P_{XYS}} d^2(X, g(Y, f_2(X, Y, S), S)) + \mathbb{E}_{P_{XY}} d^2(X, g(Y, f_2(X, Y, *), f_3(X, Y, *))) + \mathbb{E}_{P_{XS}} d^2(X, g(f_1(X, *, S), f_2(X, *, S), S)) + \mathbb{E}_{P_X} d^2(X, g(f_1(X, *, *), f_2(X, *, *), f_3(X, *, *))) + \lambda_1 D_1(f_{1\sharp} P_{XYS} \| P_Y) + \lambda_2 D_2(f_{2\sharp} P_{XYS} \| P_Z) + \lambda_3 D_3(f_{3\sharp} P_{XYS} \| P_S) + \lambda_4 H_4((f_1, f_2, f_3)_\sharp P_{XYS})$,

where $H_4$ measures the dependence of three random variables, e.g., the $d$-variate HSIC (Lopez et al., 2018) with $d = 3$, and $\mathcal{F} = \{(f_1, f_2, f_3) : f_1 : \mathcal{X} \times \bar{\mathcal{Y}} \times \bar{\mathcal{S}} \to \mathcal{Y},\ f_2 : \mathcal{X} \times \bar{\mathcal{Y}} \times \bar{\mathcal{S}} \to \mathcal{Z},\ f_3 : \mathcal{X} \times \bar{\mathcal{Y}} \times \bar{\mathcal{S}} \to \mathcal{S}\}$, for

$f_1(x, y, s) = \begin{cases} f_1^X(x), & y = *, s = *, \\ f_1^{XS}(x, s), & y = *, s \ne *, \\ y, & y \ne *, \end{cases} \qquad f_3(x, y, s) = \begin{cases} f_3^X(x), & y = *, s = *, \\ f_3^{XY}(x, y), & y \ne *, s = *, \\ s, & s \ne *, \end{cases}$

and $f_2(x, y, s)$ is equal to $f_2^X(x)$ if $y = *, s = *$; to $f_2^{XY}(x, y)$ if $y \ne *, s = *$; to $f_2^{XS}(x, s)$ if $y = *, s \ne *$; and to $f_2^{XYS}(x, y, s)$ otherwise.
Remark 3 If the variable $Y$ is removed and $S$ is fully observed, Example 3 reduces to Example 1, where the $f_2(x, y, s)$ from the former corresponds to the $f(x, s)$ from the latter. The Fader Networks (Lample et al., 2018) implicitly use this model to obtain attribute-invariant representations of facial images. The adversarial penalty for training the network (Lample et al., 2018, Eq. 2) can be understood as promoting independence between $S$ and $\hat Z = f(X, S)$. While in Lample et al. (2018) the encoder $f$ does not depend on $S$, Example 1 shows that it is more natural to take $S$ as an input in order to remove its effect on $\hat Z$. Example 3 can be considered a generalization of the Fader Networks to missing attributes and unknown identities.
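Operationally, the combined encoders $f_1$, $f_2$, $f_3$ of Example 3 simply dispatch among the four case-specific encoders according to which of $(y, s)$ is observed. The following sketch makes this explicit; the callable names and the '*' sentinel are illustrative placeholders, not the paper's code.

```python
MISSING = "*"  # sentinel for an unobserved value

def combine_encoders(fX, fXS, fXY, fXYS2):
    # fX = (f1^X, f2^X, f3^X), fXS = (f1^XS, f2^XS), fXY = (f2^XY, f3^XY),
    # fXYS2 = f2^XYS; each entry is a plain callable.
    def f1(x, y, s):
        if y != MISSING:
            return y                          # observed label passes through
        return fX[0](x) if s == MISSING else fXS[0](x, s)
    def f2(x, y, s):
        if y == MISSING and s == MISSING:
            return fX[1](x)
        if y == MISSING:
            return fXS[1](x, s)
        if s == MISSING:
            return fXY[0](x, y)
        return fXYS2(x, y, s)
    def f3(x, y, s):
        if s != MISSING:
            return s                          # observed attribute passes through
        return fX[2](x) if y == MISSING else fXY[1](x, y)
    return f1, f2, f3

# Smoke test with tagging stubs standing in for trained networks.
fX    = (lambda x: "f1X", lambda x: "f2X", lambda x: "f3X")
fXS   = (lambda x, s: "f1XS", lambda x, s: "f2XS")
fXY   = (lambda x, y: "f2XY", lambda x, y: "f3XY")
fXYS2 = lambda x, y, s: "f2XYS"
f1, f2, f3 = combine_encoders(fX, fXS, fXY, fXYS2)
```

In training, each minibatch element is routed through the branch matching its missingness pattern, which is how all four reconstruction terms of the penalized objective can share one set of networks.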

4. RELATED WORK

The literature on VAEs is vast. The β-VAE (Higgins et al., 2017) is one of the well-known ways of adding penalties to the ELBO of a VAE; it adds one proportional to the expected KL divergence between the variational posterior (encoder) and the prior $P_Z$. It has been observed that this penalty promotes factorization of the aggregate posterior $Q_Z$ (Kim & Mnih, 2018). In fair representation, the VFAE (Louizos et al., 2016) and the HSIC-constrained VAE (HCV, Lopez et al., 2018) add penalties to the ELBO for semi-supervised disentanglement along this line. Adversarial penalties have also been considered (Edwards & Storkey, 2016; Madras et al., 2018). Song et al. (2019) bring an information-theoretic interpretation to these approaches. In this regard, penalizing the MI between the nuisance variable $S$ and the encoded latent variable $Z$ (Moyer et al., 2018; Song et al., 2019; Creager et al., 2019), or its tractable upper bounds, e.g., based on a variational approximation (Rodríguez-Gálvez et al., 2021), has been advocated. The recently proposed FairDisCo (Liu et al., 2022) uses the $L_2$ distance between the joint density of $S$ and $Z$ and the product of their marginal densities, showing its asymptotic equivalence to the MI. However, as stated in Sect. 1, these penalties promoting desired structures are chosen rather ad hoc and loosen the already-not-tight ELBO. Furthermore, there is no principle for choosing the encoder structure corresponding to the imposed decoder structure. The WAE framework discussed in the previous section can overcome these pitfalls of VAEs. The WAE literature has focused on improving the divergence in the penalized form of Eq. (2) that matches the prior $P_Z$ and the aggregate posterior. The original proposal of Tolstikhin et al. (2018) is to employ either the MMD or a GAN divergence. Kolouri et al. (2019) propose to use the sliced Wasserstein distance in order to simplify computation. Patrini et al.
(2020) consider the Sinkhorn divergence (Genevay et al., 2018), whose computation can be accelerated by the Sinkhorn algorithm (Cuturi, 2013). Xu et al. (2020) and Nguyen et al. (2021) propose and improve a relational divergence called the fused Gromov-Wasserstein distance. The latter three works consider the setting in which the prior $P_Z$ is structured. In contrast, we focus on the setting in which the decoder is structured and nuisance information is (partially) available. According to the taxonomy of Tschannen et al. (2018), the former is close to the clustering meta-prior whereas the latter is close to the disentangling one. We emphasize that the cited divergences are compatible with our framework.

5. EXPERIMENTS

We evaluated WFAEs on various real-world datasets. The generative models for these datasets mainly follow Examples 2 and 3, in most of which the variable $Y$ (and sometimes $S$) has the meaning of a label and is thus categorical. In order to embed this variable into the Euclidean space $\mathbb{R}^B$, where $B$ does not necessarily depend on the number of categories, we employed the entity embedding network (Guo & Berkhahn, 2016) for observed labels. The trained embedding network naturally becomes a pretrained encoder $f_1$ or $f_3$ from Examples 2 and 3. A by-product of this embedding is that it is even possible to impute categories not present in the training data.

5.1. FAIR REPRESENTATIONS

To demonstrate the performance of WFAEs on fair representation, we reproduced the experiments in Liu et al. (2022) using two categorical datasets, namely the Adult Income and Health datasets. Refer to the appendix for data summaries and network implementation. The generative model for the WFAE had the structure of Example 2. With $Z_1 = g_2(Y, f_2(X, Y, S))$ encoded from the trained model, we quantified the trade-off between fairness and utility (Zhao et al., 2017): we classified $S$ and $Y$ using a random forest, calculated the area under the ROC curve (AUC) on the test data (sAUC and yAUC) as a function of the demographic parity $\Delta_{DP}$, and compared the performance with the HSIC-constrained VFAE (HCV) and FairDisCo. The results are summarized in Fig. 2. While the WFAE shows a clear trade-off, the other methods are relatively insensitive to demographic parity.
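The demographic parity gap used as the fairness axis can be computed directly from classifier decisions as $\Delta_{DP} = |P(\hat Y = 1 \mid S = 0) - P(\hat Y = 1 \mid S = 1)|$. A minimal sketch, assuming binary $S$ and binary predictions for illustration:

```python
import numpy as np

def demographic_parity_gap(y_hat, s):
    # |P(y_hat = 1 | s = 0) - P(y_hat = 1 | s = 1)| for binary arrays.
    y_hat, s = np.asarray(y_hat), np.asarray(s)
    return abs(y_hat[s == 0].mean() - y_hat[s == 1].mean())

# Example: predictions that favor group s = 0
# (positive rate 3/4 for s = 0 vs. 1/4 for s = 1).
y_hat = np.array([1, 1, 1, 0, 1, 0, 0, 0])
s     = np.array([0, 0, 0, 0, 1, 1, 1, 1])
gap = demographic_parity_gap(y_hat, s)  # |3/4 - 1/4| = 0.5
```

A gap of 0 means the positive-prediction rate is identical across the sensitive groups; larger values indicate that the representation still leaks $S$ into the decisions.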

5.2. INVARIANT REPRESENTATIONS

The same structure as in Sect. 5.1 was used to test the ability of WFAEs to learn invariant representations of controlled photographs. The cropped Extended Yale B dataset (Georghiades et al., 2001; Lee et al., 2005) comprises facial images of 38 human subjects under various lighting conditions. For each subject, the pictures were split into training and test data with a fixed ratio, resulting in 1,664 training and 750 test images. We set the identity of the image as $Y$ and the lighting condition (elevation and azimuth of the light direction normalized into $[-1, 1] \times [-1, 1]$) as $S$. In the training stage, we first trained $f_1$ to estimate $Y$, then trained the rest of the network with $f_1$ held fixed. In consequence, we were able to encode and decode the test data without information about $Y$ by replacing it with $f_1(X, *, S)$. Although we trained the model with continuous $S$, we present some of the results with $S$ categorized into 5 directions, as in Lopez et al. (2018). The results are presented in Table 1. The $Z_1$ encoded by the WFAE shows better performance in predicting $Y$ and worse in predicting $S$ than the others, suggesting a better invariant representation. The t-SNE visualization of $Z_1$ in Fig. 3 accords with this result, showing noticeable separation with respect to $Y$ but not with respect to $S$. In panel A of Fig. 4 (top left), the green box depicts images generated by encoding the test image $X$ and nuisance data $S$ into $Z_1 = g_2(f_1(X, *, S), f_2(X, *, S))$ and then computing $g_1(Z_1, S)$. Those in the red box were generated by using the same $Z_1$ but setting $S = (\pm 0.3, \pm 0.3)$. The WFAE produced reconstructions closer to the input than HCV and FairDisCo, and perturbing $S$ kept only the identity of the input in the generated images. The sharpness and Fréchet inception distance (FID) scores assessing sample generation quality are shown in Table 2. The WFAE produced sharper images than FairDisCo, confirming the visual inspection.
The FID scores should be taken with caution, though. Since sample generation is conducted by varying the "lighting direction" attribute (the $S$ variable), the generated samples should differ from the test data, in which such images are scarce. Rather, the scores may indicate that samples generated from FairDisCo are less sensitive to $S$, which can also be verified visually.

5.3. CONDITIONAL GENERATION

We further investigated the conditional generation capability of WFAEs using the MNIST and VGGFace2 (Cao et al., 2018) datasets.

MNIST. We treated the digit attribute as $S$. The generative model for the data is similar to Example 3, but without $Y$. We first trained the encoder $f_3$ that estimates $S$, then trained the rest of the network. The final network was tested on images with the digit information removed. We also trained a network without the encoder $f_2$ for $Z$; hence it decodes an image using only the estimated $S$. Fig. 4 summarizes the results, all of which were generated from test data without information on $S$. Panel A (top right) shows samples decoded by $g$ with the estimated $S$ and i) not using $f_2$ (blue box), ii) using $Z = f_2(X, S)$ encoded from the test data (green box), and iii) using $Z$ sampled from the prior $P_Z$ (red box). Decoded images with the same $S$ all retained their digit information. Reconstruction without using $f_2$, although recognizable, produced degraded images, implying a loss of information. FairDisCo with a similar architecture produced quite degraded results; see also Table 2.¹ In panel B (top), we estimated $Z$ from the source and $S$ from the target and generated new images by $g(Z, S)$.

VGGFace2. This dataset contains 3.14M training images of faces of 8,631 subjects in total and 169k test images of 500 subjects, with partially observed binary attributes such as gender, wearing of sunglasses, and openness of mouth available for a subset of 30,000 images. Here, we treat the identity of the image as the class $Y$ and the vector of attributes as $S$. The generative model for this dataset is the same as Example 3. The class-preserving generation and style transfer tasks were conducted in the same manner as for MNIST. In addition, we also tried generating samples with manipulated attributes. Since the attribute encoder $f_3$ embeds $S$ in the Euclidean space, we could extrapolate the input $S$ to the decoder $g$ beyond 0 and 1.
For this attribute manipulation task, we compared results with Fader Networks trained with a similar architecture. Fig. 4 shows sample images for all tasks. Although images of persons who were not in the training data were used, the WFAE could successfully generate images retaining the identity while varying the identity-invariant attributes.

6. CONCLUSION

We have shown that the WAE framework is rich enough to handle various conditional independence structures, leading to a much more principled formulation of learning problems than the VAE counterparts. Importantly, a conditional independence structure imposed on the decoder automatically determines the encoder structure and the associated constraints. We hope this paper stimulates further research on extensions of WAEs in this direction, for instance, to complex hierarchical structures.

A PROOF OF THEOREM 1

Proof 1 Under the conditions of the theorem statement, the Monge-Kantorovich equivalence holds (see, e.g., Peyré & Cuturi, 2019, Theorem 2.1):

$W_2^2(P_X, P_G) = \inf_{T : \mathcal{X} \to \mathcal{X},\ T_\sharp P_X = P_G} \mathbb{E}_{P_X} d^2(X, T(X))$.

Hence it suffices to show that

$\inf_{f : \mathcal{X} \to \mathcal{Z},\ f_\sharp P_X = P_Z} \int_{\mathcal{X}} d^2(x, g(f(x))) \, dP_X = \inf_{T : \mathcal{X} \to \mathcal{X},\ T_\sharp P_X = P_G} \int_{\mathcal{X}} d^2(x, T(x)) \, dP_X$,

or equivalently

$\{g \circ f : f : \mathcal{X} \to \mathcal{Z},\ f_\sharp P_X = P_Z\} = \{T : \mathcal{X} \to \mathcal{X} : T_\sharp P_X = P_G\}$.

The forward inclusion $\subset$ holds since for any measurable $f : \mathcal{X} \to \mathcal{Z}$ such that $f_\sharp P_X = P_X f^{-1} = P_Z$, the map $g \circ f : \mathcal{X} \to \mathcal{X}$ is measurable and for any Borel set $E \subset \mathcal{X}$,

$(g \circ f)_\sharp P_X(E) = P_X (g \circ f)^{-1}(E) = P_X(f^{-1}(g^{-1}(E))) = g_\sharp[P_X f^{-1}](E) = g_\sharp f_\sharp P_X(E) = g_\sharp P_Z(E) = P_G(E)$.

For the backward inclusion $\supset$, suppose $T : \mathcal{X} \to \mathcal{X}$ is measurable and satisfies $T_\sharp P_X = P_G$. Since $g$ is injective, it has a left inverse $g^\dagger : \mathcal{X} \to \mathcal{Z}$. Let $f = g^\dagger \circ T$. Then for any Borel set $F \subset \mathcal{Z}$,

$f_\sharp P_X(F) = P_X (g^\dagger \circ T)^{-1}(F) = P_X(T^{-1}((g^\dagger)^{-1}(F))) = T_\sharp P_X((g^\dagger)^{-1}(F)) = P_G((g^\dagger)^{-1}(F)) = g_\sharp P_Z((g^\dagger)^{-1}(F)) = P_Z(g^{-1}((g^\dagger)^{-1}(F))) = P_Z((g^\dagger \circ g)^{-1}(F)) = P_Z(F)$,

which completes the proof.

B ADDITIONAL DETAILS FOR THE FAIR REPRESENTATION EXPERIMENT

Following Liu et al. (2022), the fair representation experiments were conducted on the Adult Income and Health datasets, whose characteristics are described in Table 3. Note that all variables were categorized: one-hot encoding was used for variables with multiple categories to make all data either 0 or 1. The encoder-decoder architecture of the network was adopted from Louizos et al. (2016) (Table 4).

C FURTHER IMPLEMENTATION DETAILS

In the attached source code, the settings for all experiments are gathered as configuration files in the directory configs/train info. All the network architectures are listed in .py files in the src/model directory, and the model and architecture keywords in a configuration file specify which architecture to use among them. The running script run.sh, which states which configuration was used for each experiment (managed by Hydra 1.1.1; Yadan, 2019), can be found in the experiments directory. The prior $P_Z$ for the encoded $Z$ was set to a normal distribution $N(0, 2I_l)$, where $l$ is the dimension of the latent space $\mathcal{Z}$. For the penalty divergences $D_i$, we used the generative adversarial network (GAN) loss, which requires an additional discriminator (Tolstikhin et al., 2018). All the networks were trained using ADAM (Kingma & Ba, 2014) without any learning rate scheduling.

Extended Yale B. The cropped version of the Extended Yale Face Database B (Georghiades et al., 2001; Lee et al., 2005) was resized to 128×128. The encoder-decoder architecture of the network had a total of 18.5M parameters, and the discriminator had 881 parameters (Table 5). After pre-training the $Y$-encoder for 2,100 iterations, we optimized the network for 5,200 iterations, which took about 40 minutes. The results were compared with the HSIC-constrained variational fair autoencoder (HCV, Table 6) and FairDisCo (Table 7).

MNIST. The encoder-decoder architecture of the network had 3.8M parameters, and the discriminator had 7.4k parameters (Table 8). We pre-trained the $S$-encoder for 6,000 iterations, then optimized the rest of the network for 11,700 iterations, which took about half an hour. The results were compared with HCV and FairDisCo with the $S$ information available for decoding (Tables 9 and 10).

VGGFace2

The face regions of the collected data were cropped and resized to 128×128. The encoder-decoder architecture of the network had 88.4M parameters, and the discriminator had 206k parameters (Table 11). We pre-trained the $(Y, S)$-encoder for 3,000 iterations, then optimized the rest of the network for 30,000 iterations, which took 16 hours. The results were compared with a Fader Network having an encoder-decoder architecture with 70.2M parameters and a discriminator with 483k parameters (Table 12), trained for 20,000 iterations, which took 11 hours.

Computing infrastructure. We trained the networks with Intel® Xeon® CPU E5-2650 v4 @ 2.20GHz processors and Nvidia Titan X Pascal GPUs with 12GB memory. For the VGGFace2 experiments, we trained the network using four GPUs; the networks for the other experiments were all trained using a single GPU. All the implementations were based on Python 3.6, PyTorch 1.10.2, PyTorch Lightning 1.5.10, and CUDA 10.2.



In this experiment there is no $Y$ and the digit class plays the role of $S$, so the generation is the usual class-conditional one; hence the FID scores are lower for the WFAE, as expected from the visual comparison.



Figure 1: Examples of generative models for WFAEs


Figure 2: Fair representation trade-off.

Figure 3: T-SNE map of Z 1 in Extended Yale B

Figure 4: Conditional generation examples of WFAE

Under review as a conference paper at ICLR 2023

Table 2: Sample generation quality measures.


Table 3: Information on categorical datasets for the fair representation task.

