GM-VAE: REPRESENTATION LEARNING WITH VAE ON GAUSSIAN MANIFOLD

Abstract

We propose a Gaussian manifold variational auto-encoder (GM-VAE) whose latent space consists of a set of diagonal Gaussian distributions. It is known that the set of diagonal Gaussian distributions equipped with the Fisher information metric forms a product of hyperbolic spaces, which we call a Gaussian manifold. To learn a VAE endowed with the Gaussian manifold, we first propose a pseudo Gaussian manifold normal distribution, based on the Kullback-Leibler divergence, a local approximation of the squared Fisher-Rao distance, to define a density over the latent space. With the newly proposed distribution, we introduce geometric transformations at the last layer of the encoder and the first layer of the decoder to support the transition between the Euclidean and Gaussian manifolds. Through empirical experiments, we show the competitive generalization performance of GM-VAE against other variants of hyperbolic and Euclidean VAEs. Our model also achieves strong numerical stability, a common limitation reported for previous hyperbolic VAEs.

1. INTRODUCTION

The geometry of the latent space in generative models, such as variational auto-encoders (VAE) (Kingma & Welling, 2013) and generative adversarial networks (GAN) (Goodfellow et al., 2020), reflects the structure of the representation of the data. Mathieu et al. (2019), Nagano et al. (2019), and Cho et al. (2022) show that employing a hyperbolic space as the latent space helps preserve the hierarchical structure of the data in the latent space. The choice of geometry is not limited to hyperbolic space, as the latent space can be other types of Riemannian manifolds, such as spherical manifolds (Xu & Durrett, 2018; Davidson et al., 2018) or products of Riemannian manifolds with mixed curvatures (Skopek et al., 2019). Meanwhile, it is known that the univariate Gaussian distributions equipped with the Fisher information metric (FIM) form a Riemannian manifold whose metric tensor is akin to that of the Poincaré half-plane, one of the four isometric hyperbolic models (Costa et al., 2015), providing a way to view this statistical manifold as a hyperbolic space. Furthermore, the diagonal Gaussian distributions form a product of such Riemannian manifolds, yielding an extended statistical manifold. Based on this connection between hyperbolic spaces and statistical manifolds, in this work we add an alternative perspective on hyperbolic VAEs from the viewpoint of information geometry.

Previously proposed hyperbolic VAEs rely on distributions defined over the hyperbolic space. The Riemannian normal and the wrapped normal are commonly used as prior and variational distributions over the hyperbolic space. Unlike the Gaussian distribution in Euclidean space, these distributions suffer from numerical instability (Mathieu et al., 2019; Skopek et al., 2019). In addition, the Riemannian normal requires rejection sampling, which often discards many candidate samples.
From the information-geometric perspective of the hyperbolic space, we introduce a new distribution, named the pseudo Gaussian manifold normal distribution (PGM normal). The Gaussian manifold here refers to the statistical manifold of univariate Gaussian distributions. The newly proposed distribution uses the KL divergence as a statistical distance between two distributions in the Gaussian manifold. Since the KL divergence locally approximates the squared Riemannian distance of the statistical manifold derived from the FIM, the proposed distribution follows the geometric properties of the Gaussian distributions. We show that the PGM normal is easy to sample from, and that the KL divergence between two PGM normals can be computed analytically. With the PGM normal as the prior and variational distributions, we define a Gaussian manifold VAE (GM-VAE), whose latent space is defined over the Gaussian manifold. Nevertheless, the data points are still assumed to lie in Euclidean space. To correct the mismatch between the data space and the latent space, we introduce a transformation between Euclidean and hyperbolic space at the last layer of the encoder and the first layer of the decoder. Empirical experiments with multiple datasets show that GM-VAE achieves competitive generalization performance against existing hyperbolic VAEs. During the experiments, we observe that the PGM normal is robust in terms of sampling and computation of the KL divergence compared to the commonly-used hyperbolic distributions; we briefly explain why the others are numerically unstable. Analysis of the latent space shows that the geometric structure and probabilistic semantics of the dataset are captured in the representations learned with GM-VAE. We summarize our contributions as follows:

• We propose a variant of VAE whose latent space is defined on a statistical manifold formed by diagonal Gaussian distributions.
• We propose a new distribution, called the pseudo Gaussian manifold normal distribution, which is easy to sample from and has a closed-form KL divergence, to train the VAE on the manifold.
• We propose new encoder and decoder structures to support a proper transition between the Euclidean (data) space and the statistical manifold.
• We empirically verify that the newly proposed model performs on par with existing hyperbolic VAEs while achieving stable training without numerical issues.

2. PRELIMINARIES

In this section, we first review the fundamental concepts of Riemannian manifolds. We then explain the commonly-used distributions over Riemannian manifolds and review the concepts of Riemannian geometry for statistical objects.

2.1. REVIEW OF RIEMANNIAN MANIFOLD

An n-dimensional Riemannian manifold consists of a manifold M and a metric tensor g : M → R^{n×n}, a smooth map from each point x ∈ M to a symmetric positive-definite matrix. The metric tensor g(x) defines an inner product of two tangent vectors at each point of the manifold, ⟨·, ·⟩_x : T_x M × T_x M → R, where T_x M is the tangent space at x. A Riemannian manifold can be characterized by the curvature of the curves defined on it. The curvature can be computed at each point, and some manifolds have constant curvature. For example, the unit sphere S has constant positive curvature +1, and the Poincaré half-plane U has constant negative curvature -1.

The hyperbolic models. Among the hyperbolic models, the Klein model, the Poincaré disk model, the Lorentz (hyperboloid) model, and the Poincaré half-plane model are known to be isometric and to share the same curvature of -1 (Nickel & Kiela, 2018; Gulcehre et al., 2018; Tifrea et al., 2018).

The metric tensor induces the basic operations of the Riemannian manifold, such as geodesics, the exponential map, the log map, and parallel transport. Given two points x, y ∈ M, a geodesic γ : [0, 1] → M is a unit-speed curve on M that is the shortest path between γ(0) = x and γ(1) = y; it generalizes the straight line of Euclidean space. The exponential map exp_x : T_x M → M is defined as exp_x(v) = γ(1), where γ is the geodesic starting from x with γ′(0) = v for a tangent vector v ∈ T_x M. The log map log_x : M → T_x M is the inverse of the exponential map, i.e., log_x(exp_x(v)) = v. The parallel transport PT_{x→y} : T_x M → T_y M moves a tangent vector v along the geodesic between x and y. The distance function d_M(x, y) is induced from the metric tensor as follows:

$d_{\mathcal{M}}(x, y) = \int_0^1 \sqrt{\langle \dot{\gamma}(t), \dot{\gamma}(t) \rangle_{\gamma(t)}}\, dt.$  (1)
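For the Poincaré half-plane mentioned above, the length integral of Equation 1 admits a well-known closed form, which makes the abstract definition concrete. The sketch below (the function name is ours) computes it; for two points on a vertical line the distance reduces to |log(y2/y1)|, a handy sanity check.

```python
import math

def halfplane_distance(p, q):
    """Geodesic distance on the Poincare half-plane U = {(x, y) : y > 0}
    with curvature -1, using the standard closed form
    d(p, q) = arccosh(1 + ((x2-x1)^2 + (y2-y1)^2) / (2*y1*y2))."""
    (x1, y1), (x2, y2) = p, q
    sq = (x2 - x1) ** 2 + (y2 - y1) ** 2
    return math.acosh(1.0 + sq / (2.0 * y1 * y2))

# Vertical geodesic: d((0, 1), (0, e)) = |log(e/1)| = 1
d = halfplane_distance((0.0, 1.0), (0.0, math.e))
```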

2.2. DISTRIBUTIONS OVER RIEMANNIAN MANIFOLD

Given a squared distance function d²_M : M × M → R_{>0} of a Riemannian manifold M, the probability density function of the Riemannian normal distribution is:

$p_{\mu,\sigma}(z) = \frac{1}{Z_{\mathcal{M}}} \exp\left(-\frac{d^2_{\mathcal{M}}(z, \mu)}{2\sigma^2}\right),$

where µ ∈ M is the Fréchet mean of the distribution, σ ∈ R_{>0} is the dispersion parameter, and Z_M is the normalizing constant. This distribution is known to preserve the maximum-entropy property of the Gaussian distribution (Pennec, 2006). Note that the distribution requires computing the integral in Equation 1, which often has no analytic solution; in some special cases the distance can be computed analytically, but in general the computation is intractable. Mathieu et al. (2019) propose a rejection-sampling method for the Riemannian normal defined on the Poincaré disk model, which we call the Poincaré normal distribution.

An alternative to the Riemannian normal is the wrapped normal distribution, constructed by transforming a sample from a Gaussian distribution via parallel transport and an exponential map:

$z = \exp_{\mu}\left(\mathrm{PT}_{0_{\mathcal{M}} \to \mu}(f(v))\right), \quad v \sim \mathcal{N}(0, \Sigma),$

where µ ∈ M is the mean of the distribution, 0_M is the origin of M, f(·) maps a Euclidean vector to a tangent vector at 0_M, and v is a sample from a Euclidean normal with zero mean and covariance Σ. The probability density of the sample can be computed with the change-of-variables technique. Note that f(·) is well-defined in hyperbolic spaces: in the Lorentz model, it prepends a zero as the first coordinate of the vector, and in the Poincaré disk model, it is the identity function. Nagano et al. (2019) propose the wrapped normal distribution on hyperbolic space, which we call the hyperbolic wrapped normal distribution.
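The wrapped-normal construction above can be sketched concretely on the Lorentz model with curvature -1, using the standard parallel-transport and exponential-map formulas for that model (the function names are ours; this is a minimal illustration, not the authors' implementation):

```python
import numpy as np

def lorentz_inner(u, v):
    # Lorentzian inner product <u, v>_L = -u0*v0 + sum_i ui*vi
    return -u[0] * v[0] + np.dot(u[1:], v[1:])

def sample_wrapped_normal(mu, Sigma, rng):
    """One sample from the wrapped normal on the Lorentz model (curvature -1):
    v ~ N(0, Sigma) in R^n, f(v) = (0, v) in the tangent space at the origin o,
    parallel-transport to the tangent space at mu, then apply exp_mu."""
    n = len(mu) - 1
    o = np.zeros(n + 1)
    o[0] = 1.0                                # origin of the Lorentz model
    v = rng.multivariate_normal(np.zeros(n), Sigma)
    u = np.concatenate([[0.0], v])            # f(v): prepend a zero coordinate
    # parallel transport from o to mu
    alpha = -lorentz_inner(o, mu)
    u = u + lorentz_inner(mu - alpha * o, u) / (alpha + 1.0) * (o + mu)
    # exponential map at mu
    norm_u = np.sqrt(max(lorentz_inner(u, u), 0.0))
    if norm_u < 1e-12:
        return mu
    return np.cosh(norm_u) * mu + np.sinh(norm_u) * u / norm_u
```

Every sample produced this way stays on the hyperboloid, i.e., it satisfies ⟨z, z⟩_L = -1 with z[0] > 0, which is easy to verify numerically.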

2.3. STATISTICAL MANIFOLD

The parameter manifold M of probability distributions p_θ : X → R, where θ ∈ M, equipped with the Fisher information metric (FIM) forms a Riemannian manifold (Rao, 1992). The FIM is defined as:

$g_{ij}(\theta) = \int_{\mathcal{X}} \frac{\partial \log p_\theta(x)}{\partial \theta_i} \frac{\partial \log p_\theta(x)}{\partial \theta_j}\, p_\theta(x)\, dx.$

In the parameter space of univariate Gaussian distributions {(µ, σ) | µ ∈ R, σ ∈ R_{>0}}, the FIM simplifies to the two-dimensional diagonal matrix $\sigma^{-2}\,\mathrm{diag}(1, 0.5)$ (Costa et al., 2015). The diagonal form of the FIM implies that the Riemannian manifold of {(µ, σ)} has the same set of points as the Poincaré half-plane, but with a different curvature of -0.5. The parameter space of n-dimensional diagonal Gaussian distributions is the product of n copies of the manifold of univariate Gaussian parameters. The operations on a product of Riemannian manifolds $\prod_{i=1}^n \mathcal{M}_i$ are defined manifold-wise. For example, an exponential map applied to a point $(p_i)_{i=1}^n \in \prod_{i=1}^n \mathcal{M}_i$, with tangent vector $v_i \in T_{p_i}\mathcal{M}_i$ for each i ∈ {1, …, n}, can be written as $(\exp_{p_i}(v_i))_{i=1}^n$.

2.4. STATISTICAL DISTANCE

A statistical distance is a notion of distance, though possibly not a metric, between two statistical objects such as random variables or probability density functions, and it provides a measure of similarity between two densities. On a statistical manifold equipped with the FIM, a statistical distance called the Fisher-Rao distance can be derived: it is the Riemannian distance induced from the Fisher information metric via Equation 1. For example, the Fisher-Rao distance on the statistical manifold of univariate Gaussian distributions can be obtained from the Riemannian distance of the Poincaré half-plane model, whose Riemannian metric is similar (Costa et al., 2015). The Kullback-Leibler (KL) divergence is another widely-used statistical distance, defined as

$D_{KL}(p(x) \,\|\, q(x)) := \int_{\mathcal{X}} p(x) \log \frac{p(x)}{q(x)}\, dx$

for two distributions p(x), q(x) on the same statistical manifold. For example, the KL divergence between two univariate Gaussian distributions N(µ₁, σ₁) and N(µ₂, σ₂) is:

$D_{KL}\left(\mathcal{N}(\mu_1, \sigma_1) \,\|\, \mathcal{N}(\mu_2, \sigma_2)\right) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}.$

For n-dimensional diagonal Gaussians, the KL divergence is the sum of the KL divergences of the univariate Gaussians over the dimensions. One notable property of the KL divergence is that it locally approximates the squared Fisher-Rao distance.
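The closed form above, and its per-dimension decomposition for diagonal Gaussians, can be sketched directly (function names are ours):

```python
import numpy as np

def kl_univariate_gaussian(mu1, s1, mu2, s2):
    """KL(N(mu1, s1^2) || N(mu2, s2^2)) via the closed form above;
    s1 and s2 are standard deviations."""
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2 * s2**2) - 0.5

def kl_diagonal_gaussian(mu1, s1, mu2, s2):
    """For diagonal Gaussians, the KL decomposes into a per-dimension sum."""
    return float(np.sum(kl_univariate_gaussian(np.asarray(mu1, float),
                                               np.asarray(s1, float),
                                               np.asarray(mu2, float),
                                               np.asarray(s2, float))))
```

For instance, KL(N(0, 1) || N(1, 1)) evaluates to 0.5, and the divergence of any distribution from itself is zero.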

3. METHOD

In this section, we first derive a reparameterization of the Gaussian distribution to form a statistical manifold with an arbitrary curvature. We then propose a Pseudo Gaussian manifold (PGM) normal distribution. Finally, we suggest a new variant of the variational auto-encoder, whose latent space is defined over the statistical manifold.

3.1. MANIFOLD WITH ARBITRARY CURVATURE

As shown in Section 2.3, the univariate Gaussian distributions form a statistical manifold with constant curvature -0.5, which coincides with the manifold of the Poincaré half-plane. Previous studies on hyperbolic spaces emphasize the importance of allowing an arbitrary curvature (Skopek et al., 2019; Mathieu et al., 2019), and empirically show that the generalization performance of hyperbolic VAEs can be improved with varying curvatures. We show that the statistical manifold of univariate Gaussians can have an arbitrary curvature by reparameterizing the Gaussian distribution appropriately. Let $\mathcal{N}(\sqrt{2c}\,\mu, \sigma)$ be the reparameterized Gaussian distribution with an additional parameter c > 0. The reparameterization leads to the FIM of $\sigma^{-2}\,\mathrm{diag}(1, c)$, showing that the curvature of the statistical manifold is -c. With the arbitrary curvature, we also verify that the KL divergence still approximates the Riemannian distance:

$\frac{D_{KL}\left(\mathcal{N}(\sqrt{2c}(\mu + d\mu), \sigma + d\sigma) \,\|\, \mathcal{N}(\sqrt{2c}\,\mu, \sigma)\right)}{2c} = \frac{1}{2}\begin{pmatrix} d\mu \\ d\sigma \end{pmatrix}^{\!T}\begin{pmatrix} \frac{1}{\sigma^2} & 0 \\ 0 & \frac{1}{c\sigma^2} \end{pmatrix}\begin{pmatrix} d\mu \\ d\sigma \end{pmatrix} + O\!\left(d\sigma^3\right),$

where the first term is the squared Riemannian norm of the vector (dµ, dσ) in the manifold, which approximates the squared Fisher-Rao distance between (µ, σ) and (µ + dµ, σ + dσ). The derivation of the FIM and the KL divergence with the reparameterized normal is given in Appendices A.1 and A.2. We call the statistical manifold of Gaussian distributions with curvature -c the Gaussian manifold $\mathcal{G}_c$.
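The local approximation above is easy to check numerically: for small displacements (dµ, dσ), the scaled KL divergence should match the quadratic form up to third-order error. A minimal sketch (function names are ours):

```python
import math

def kl_reparam(mu1, s1, mu2, s2, c):
    """KL between the reparameterized Gaussians N(sqrt(2c)*mu1, s1) and
    N(sqrt(2c)*mu2, s2); only the mean term picks up the 2c factor."""
    return math.log(s2 / s1) + (s1**2 + 2 * c * (mu1 - mu2)**2) / (2 * s2**2) - 0.5

def quadratic_form(dmu, dsigma, sigma, c):
    """Squared Riemannian norm (1/2)(dmu^2/sigma^2 + dsigma^2/(c*sigma^2))
    from the diagonal metric in the expansion above."""
    return 0.5 * (dmu**2 / sigma**2 + dsigma**2 / (c * sigma**2))

# Small perturbation: the two sides agree up to O(dsigma^3)
c, mu, sigma, dmu, dsigma = 0.7, 0.3, 1.2, 1e-3, 1e-3
lhs = kl_reparam(mu + dmu, sigma + dsigma, mu, sigma, c) / (2 * c)
rhs = quadratic_form(dmu, dsigma, sigma, c)
```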

3.2. PSEUDO GAUSSIAN MANIFOLD NORMAL DISTRIBUTION

We propose a pseudo Gaussian manifold normal distribution (PGM normal), which defines a distribution over the Gaussian manifold. Let (µ, σ) ∈ G_c be a point in the Gaussian manifold. Inspired by the Riemannian normal, we define the probability density function of the PGM normal with the KL divergence as:

$\mathcal{K}_c(\mu, \sigma; \alpha, \beta, \gamma^2) = \frac{(\sigma/\beta)^3}{Z(c, \gamma^2)} \exp\left(-\frac{D_{KL}\left(\mathcal{N}(\sqrt{2c}\,\mu, \sigma) \,\|\, \mathcal{N}(\sqrt{2c}\,\alpha, \beta)\right)}{(\sqrt{2c}\,\gamma)^2}\right),$  (5)

where α, β, and γ² are the parameters of the distribution and -c is the curvature. As shown in the previous section, the KL divergence approximates the squared Fisher-Rao distance from the Gaussian distribution $\mathcal{N}(\sqrt{2c}\,\alpha, \beta)$ to $\mathcal{N}(\sqrt{2c}\,\mu, \sigma)$. Therefore, the PGM normal accounts for the geometric structure of the Gaussian distributions.

[Figure 1: The overall architecture of GM-VAE. The encoder output is mapped through the Lorentz model to a point (µ, σ) on the Gaussian manifold G_c, and the decoder reverses the transformation to produce the reconstructed image.]

The factorization of the probability density function in Equation 5, multiplied with the square root of the determinant of the metric tensor, shows the advantages of the PGM normal. It can be written as:

$\mathcal{K}_c(\mu, \sigma; \alpha, \beta, \gamma^2) \cdot \sqrt{\det(g)} = \mathcal{N}\!\left(\mu;\, \alpha,\, \beta^2\gamma^2\right) \cdot \mathrm{Gamma}\!\left(\sigma^2;\, \frac{1}{4c\gamma^2} + 1,\, \frac{1}{4c\beta^2\gamma^2}\right),$

where $\mathrm{Gamma}(z; a, b) = \frac{b^a}{\Gamma(a)} z^{a-1} \exp(-bz)$ and g is the Fisher information metric of the Gaussian manifold. Note that the factorization has the same form as the well-known conjugate prior of the Gaussian distribution. In that sense, the PGM normal incorporates the geometric structure into the prior distribution explicitly. Thanks to the properties of the Gaussian and Gamma distributions, the PGM normal is easy to sample from and has a closed-form KL divergence. The detailed derivation is available in Appendix B.
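Assuming the Gaussian-Gamma factorization above, sampling from the PGM normal reduces to two standard draws: µ from a Gaussian and σ² from a Gamma distribution. A minimal sketch (function name is ours; note that the Gamma in the text is rate-parameterized, while numpy's sampler takes scale = 1/rate):

```python
import numpy as np

def sample_pgm_normal(alpha, beta, gamma2, c, rng, n_samples=1):
    """Sample (mu, sigma) pairs using the Gaussian-Gamma factorization:
    mu ~ N(alpha, beta^2 * gamma2) and
    sigma^2 ~ Gamma(shape = 1/(4*c*gamma2) + 1, rate = 1/(4*c*beta^2*gamma2))."""
    shape = 1.0 / (4.0 * c * gamma2) + 1.0
    scale = 4.0 * c * beta**2 * gamma2        # numpy scale = 1 / rate
    mu = rng.normal(alpha, np.sqrt(beta**2 * gamma2), size=n_samples)
    sigma = np.sqrt(rng.gamma(shape, scale, size=n_samples))
    return mu, sigma
```

As a sanity check, the sample mean of µ should concentrate around α, and E[σ²] = shape · scale = β²(1 + 4cγ²) under this parameterization.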

3.3. GAUSSIAN MANIFOLD VAE

We propose a Gaussian manifold VAE (GM-VAE), whose latent space is defined over the Gaussian manifold. Specifically, we place a PGM normal prior over the latent space of the VAE and add a proper geometric transformation at the last layer of the encoder and the first layer of the decoder for the conversion between the Euclidean space and the Gaussian manifold. The evidence lower bound (ELBO) of GM-VAE can be formalized over the Gaussian manifold {(µ, Σ) | µ ∈ R^n, Σ ∈ R^n_{>0}} as:

$\mathbb{E}_{q_\phi(\mu, \Sigma \mid x) \cdot \sqrt{\det(g)}}\left[\log p_\theta(x \mid \mu, \Sigma)\right] - D_{KL}\left(q_\phi(\mu, \Sigma \mid x) \cdot \sqrt{\det(g)} \,\|\, p(\mu, \Sigma) \cdot \sqrt{\det(g)}\right),$

where p_θ(x | µ, Σ) is the decoder network, q_ϕ(µ, Σ | x) is the encoder network, and p(µ, Σ) is the prior. The variational distribution is set to $q_\phi(\mu, \Sigma \mid x) = \mathcal{K}(\alpha_\phi(x), \beta_\phi(x), \gamma^2_\phi(x))$, where $\alpha_\phi(x) \in \mathbb{R}^n$ and $\beta_\phi(x), \gamma^2_\phi(x) \in \mathbb{R}^n_{>0}$, and the prior is set to $p(\mu, \Sigma) = \mathcal{K}(0, I, I)$ in our experiments. GM-VAE is trained by maximizing the ELBO with respect to the parameters θ and ϕ.

Algorithm 1 Encoder
Input: Input data x, encoding layers Enc(·)
Output: Parameters (α, β) ∈ G_c, γ² ∈ R_{>0}
1: ṽ, γ² = Enc(x)            ▷ ṽ ∈ E
2: v = exp^c_{0_L}(f(ṽ))     ▷ v ∈ L²_c
3: (α, β) = T_c(v)           ▷ (α, β) ∈ G_c
4: return (α, β), γ²

Algorithm 2 Decoder
Input: Sample (µ, σ) ∼ K(·), decoding layers Dec(·)
Output: Reconstruction x′
1: v = T_c^{-1}(µ, σ)        ▷ (µ, σ) ∈ G_c, v ∈ L²_c
2: ṽ = log^c_{0_L}(v)        ▷ ṽ ∈ E
3: x′ = Dec(ṽ)
4: return x′

Geometric transformations on GM-VAE. Mathieu et al. (2019) propose a transformation from a Euclidean space to the Poincaré disk to define a latent space over the Poincaré disk. We propose a novel transformation in VAE from a Euclidean space to the Gaussian manifold and vice versa. For a numerically stable transformation between the two spaces, we adopt operations defined on the Lorentz model, which is isometric to the half-plane manifold (Nickel & Kiela, 2018).
The isometry $T_c : \mathcal{L}^2_c \to \mathcal{G}_c$ between the two-dimensional Lorentz model with curvature -c and the Gaussian manifold with curvature -c can be defined as:

$T_c((t, x, y)) = \left(\frac{-y}{\sqrt{c}\,(t - x)},\; \frac{1}{\sqrt{c}\,(t - x)}\right),$

and its inverse as:

$T_c^{-1}((x, y)) = \left(\frac{1 + cx^2 + y^2}{2\sqrt{c}\,y},\; \frac{-1 + cx^2 + y^2}{2\sqrt{c}\,y},\; -\frac{x}{y}\right).$

In the encoder, we convert the output of the last layer, which lies in Euclidean space, to the Lorentz model using the exponential map at the origin, and then to the Gaussian manifold using T_c. In the decoder, we convert the input of the first layer, which lies in the Gaussian manifold, to the Lorentz model using the inverse transformation T_c^{-1}, and then to Euclidean space using the log map at the origin of the Lorentz model. Figure 1 illustrates the architecture of GM-VAE, and the pseudo-code for the encoder and decoder is shown in Algorithms 1 and 2.

Remark. Unlike a typical VAE, where the latent space consists of samples from a Gaussian distribution, the latent space of GM-VAE consists of a set of Gaussian distributions. In this respect, GM-VAE can be considered a hierarchical VAE with an additional prior over the Gaussian prior. However, instead of sampling another latent variable from the latent distribution, the decoder network directly transforms the latent distribution itself into x via the transformation from hyperbolic to Euclidean space. From this perspective, GM-VAE can also be considered a variant of the Poincaré VAE (Mathieu et al., 2019), whose latent space can be interpreted via the Gaussian manifold.
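The isometry and its inverse can be transcribed directly; a quick numerical round-trip also confirms that $T_c^{-1}$ lands on the hyperboloid $-t^2 + x^2 + y^2 = -1/c$ (function names are ours):

```python
import math

def T_c(p, c):
    """Isometry from the 2-D Lorentz model L^2_c (curvature -c) to the
    Gaussian manifold G_c, as defined above."""
    t, x, y = p
    return (-y / (math.sqrt(c) * (t - x)), 1.0 / (math.sqrt(c) * (t - x)))

def T_c_inv(q, c):
    """Inverse isometry from G_c back to L^2_c."""
    x, y = q
    return ((1 + c * x**2 + y**2) / (2 * math.sqrt(c) * y),
            (-1 + c * x**2 + y**2) / (2 * math.sqrt(c) * y),
            -x / y)

# Round trip: G_c -> L^2_c -> G_c recovers the original point,
# and the intermediate point satisfies the Lorentz constraint.
c = 0.5
p = T_c_inv((0.3, 1.7), c)
q = T_c(p, c)
```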

4. RELATED WORK

Information geometry on VAE Bridging probability theory and differential geometry, the adaptation of information geometry to the deep learning framework has been investigated in various aspects (Karakida et al., 2019; Bay & Sengupta, 2017; Gomes et al., 2022). In particular, Han et al. (2020) show that the training process of VAE can be seen as minimizing the distance between two statistical manifolds: the manifolds induced by the parameters of the decoder and of the encoder. Not only the parameters but also the outputs of the VAE decoder can be modeled as probability distributions. Arvanitidis et al. (2021) suggest using the pull-back metric defined by an arbitrary decoder on the latent space. Our work focuses instead on the statistical manifold formed by the outputs of the encoder, drawing on the benefits of information geometry.

VAE with Riemannian manifold latent space

The latent space of VAE reflects the geometrical property of the representations of the data. The efficacy of setting the latent space to a hyperbolic space (Mathieu et al., 2019; Nagano et al., 2019; Cho et al., 2022) or a spherical space (Xu & Durrett, 2018; Davidson et al., 2018) has been verified on various datasets. Skopek et al. (2019) further extend the approach to allow the latent space to be a product of Riemannian manifolds with different learnable curvatures. On top of these works, we explore setting the latent space to a diagonal Gaussian manifold, which is isometric to a product of hyperbolic spaces, providing a novel viewpoint on prior work through information geometry.

Distributions on the hyperbolic space Defining a distribution on the hyperbolic space that admits easy sampling is challenging. Nagano et al. (2019) suggest the hyperbolic wrapped normal distribution, building on the observation that the tangent space is a Euclidean space; leveraging operations defined on tangent spaces, e.g., parallel transport, enables an easy sampling algorithm. Mathieu et al. (2019) propose a sampling method for the Riemannian normal defined on the Poincaré disk model using rejection sampling. This rejects pathological samples and enables accurate sampling from the distribution, but at a high cost in time. These distributions have been applied in many settings (Cho et al., 2022; Skopek et al., 2019; Mathieu & Nickel, 2020) but suffer from stability issues due to the absence of a closed-form KL divergence. Our proposed distribution not only shares their merits but also overcomes the stability problem with a closed-form KL divergence. Table 1 summarizes the properties of each distribution.

Table 1: A comparison of the PGM normal (ours) with the commonly-used distributions on the hyperbolic space: the hyperbolic wrapped normal and the Poincaré normal. Our method enables easy sampling and computation of a closed-form KL, with the utilization of information geometry.

                            Easy sampling   Information geometry   Closed-form KL
Hyperbolic wrapped normal        ⃝                 ×                    ×
Poincaré normal                  △                 ⃝                    ×
PGM normal (ours)                ⃝                 ⃝                    ⃝

5. EXPERIMENTS

In this section, we compare the performance of GM-VAE with three baselines: the Euclidean VAE, the hyperbolic wrapped normal VAE (HWN VAE), and the Poincaré VAE. The Euclidean VAE is the standard VAE with a Euclidean latent space. The HWN VAE uses a product of two-dimensional Lorentz models as the latent space and uses the hyperbolic wrapped normal to model the prior and variational distributions. The Poincaré VAE uses a product of two-dimensional Poincaré disk models as the latent space and uses the Poincaré normal to model the prior and variational distributions. The Euclidean VAE, HWN VAE, and Poincaré VAE are denoted E-VAE, L-VAE, and P-VAE, respectively, in the following results.

5.1. DENSITY ESTIMATION

We first conduct a density estimation task to check the generalization ability of the different models. We use three datasets: binarized-MNIST (Deng, 2012), binarized-Omniglot (Lake et al., 2015), and binarized images from Atari 2600 Breakout (binarized-Breakout) (Nagano et al., 2019). The binarized-Breakout frames are collected from plays of a pre-trained Deep Q-Network (Mnih et al., 2015). The image sizes are 28 × 28, 28 × 28, and 80 × 80 for binarized-MNIST, binarized-Omniglot, and binarized-Breakout, respectively. The binarization threshold is set to 0.5, 0.5, and 0.1 for binarized-MNIST, binarized-Omniglot, and binarized-Breakout, respectively.

Table 2: Density estimation on real-world datasets. d denotes the latent dimension. We report the negative test log-likelihoods averaged over 10 runs for binarized-MNIST and binarized-Omniglot, and over 5 runs for binarized-Breakout, with the 95% confidence interval. N/A in the log-likelihood indicates that the results are not available due to the failure of all runs, and N/A in the standard deviation indicates that the results are not available due to failures of some runs. The best results are bolded.
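The binarization described above is a simple thresholding step. A sketch, assuming pixel intensities already scaled to [0, 1]:

```python
import numpy as np

def binarize(images, threshold):
    """Threshold grayscale images in [0, 1] into {0, 1}, as in Section 5.1
    (0.5 for binarized-MNIST/Omniglot, 0.1 for binarized-Breakout)."""
    return (np.asarray(images, dtype=np.float64) > threshold).astype(np.float32)

frame = np.array([[0.05, 0.40], [0.60, 0.95]])
```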


The results are reported in Table 2. On binarized-MNIST and binarized-Omniglot, the models learned on the product hyperbolic space and the Gaussian manifold mostly outperform the Euclidean VAE. On binarized-Breakout, the GM-VAE with curvature values 1 and 3/2 outperforms the baselines, while the Poincaré VAE fails to run in all settings and the HWN VAE fails in some settings due to numerical issues, which we investigate in detail below.

Numerical stability

We conduct an analysis of the numerical stability of the PGM normal distribution compared to the HWN and the Poincaré normal. During the density estimation experiment, the HWN VAE and the Poincaré VAE often turn out to be numerically unstable and fail to run on binarized-Breakout. Similar observations have been reported in several previous works (Mathieu et al., 2019; Chen et al., 2021; Skopek et al., 2019). The hyperbolic wrapped normal uses the exponential map both when transforming the output of the encoder to the Lorentz model and during sampling, as described in Equation 3. This repeated application of the Lorentz exponential map often causes an overflow. In the training of Poincaré VAE, the KL divergence between the variational distribution and the prior needs to be approximated by the log-probability of samples, due to the absence of a closed-form KL divergence for the Poincaré normal. Computing the log-probability of a given sample requires the distance between two points on the Poincaré disk model, the sample and the Fréchet mean of the distribution, whose denominator term is numerically unstable. The PGM normal, on the other hand, is free from such instability, as its KL divergence can be computed stably using the log-covariance. Detailed arguments with equations are given in Appendix E.

Geometric transformations

We conduct an ablation study on the geometric transformations of GM-VAE. We compare GM-VAEs incorporating the geometric transformations against GM-VAEs that use only the exponential function to send the output of the encoder to the Gaussian manifold, with no additional geometric transformation at the first layer of the decoder. The results are in Table 3. The geometric transformations enhance the performance of GM-VAE in all but two cases, where the performance remains comparable.

5.2. LATENT SPACE ANALYSIS

To check whether the latent representation coincides with the known labels, we first plot the latent space of binarized-MNIST via t-SNE (van der Maaten & Hinton, 2008) visualization, using representations from all dimensions. The visualization in Figure 2a shows that the label semantics are well clustered in the learned latent space. We also analyze the changes in the reconstructed images along geodesics of the latent space. Figure 2b shows the reconstructed images from a geodesic interpolation between two latent representations, with a fixed value of α. The interpolation is performed within one dimension while fixing the representations in all other dimensions. As β increases, the reconstructed images become more ambiguous, matching our intuition on the role of the variance. Reconstructed images with a fixed value of β are available in Appendix F. Figure 3 shows the analysis with binarized-Breakout. The images in binarized-Breakout possess a hierarchy, as the cumulative reward and the amount of broken-out bricks, or the portion of blank space in the image, are highly correlated (Nagano et al., 2019). We observe a high correlation between β and this hierarchy. For example, as shown in Figure 3a, increasing β reconstructs a more general image in the hierarchy. The highest Pearson correlation between the β values and the negative cumulative reward is 0.655. Again, as β represents the variance, we conjecture that increasing the variance induces a more general image in the dataset.
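The geodesic traversal can be reproduced by mapping half-plane points to the Lorentz model with T_c⁻¹ (Appendix C), interpolating there with the exponential and log maps, and mapping back with T_c. A numpy sketch under that construction; the helper names are ours:

```python
import numpy as np

def T(p, c=1.0):
    """Isometry L^2_c -> G_c."""
    t, x, y = p
    return np.array([-y / (np.sqrt(c) * (t - x)), 1.0 / (np.sqrt(c) * (t - x))])

def T_inv(q, c=1.0):
    """Inverse isometry G_c -> L^2_c."""
    x, y = q
    s = c * x**2 + y**2
    return np.array([(1 + s) / (2 * np.sqrt(c) * y),
                     (-1 + s) / (2 * np.sqrt(c) * y),
                     -x / y])

def lorentz_inner(a, b):
    return -a[0] * b[0] + a[1] * b[1] + a[2] * b[2]

def exp_map(x, v, c=1.0):
    alpha = np.sqrt(c * lorentz_inner(v, v))
    if alpha < 1e-12:
        return x
    return np.cosh(alpha) * x + np.sinh(alpha) / alpha * v

def log_map(x, y, c=1.0):
    beta = -c * lorentz_inner(x, y)
    if beta <= 1.0:            # guard: x == y up to rounding
        return np.zeros(3)
    return np.arccosh(beta) / np.sqrt(beta**2 - 1.0) * (y - beta * x)

def geodesic(p, q, t, c=1.0):
    """Point at fraction t along the geodesic from p to q on the Gaussian manifold."""
    x, y = T_inv(p, c), T_inv(q, c)
    return T(exp_map(x, t * log_map(x, y, c), c), c)

p, q = np.array([1.0, 2.0]), np.array([-0.5, 0.5])
path = [geodesic(p, q, t) for t in np.linspace(0.0, 1.0, 5)]
```

The endpoints of the path recover p and q, and every intermediate point keeps a strictly positive scale coordinate, i.e. stays on the manifold.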

6. CONCLUSION

In this work, we propose a novel method of representation learning with GM-VAE, which uses the Gaussian manifold as the latent space. With the newly-proposed PGM normal distribution defined over the Gaussian manifold, which shows better stability and ease of sampling compared to the commonly-used alternatives, we verify the efficacy of our method on several real-world datasets. Our analysis of the latent space and representations shows that GM-VAE is beneficial for capturing both geometrical structure and probabilistic semantics. We believe that the connection between statistical manifolds and hyperbolic spaces provides new insight to the research community, and we hope to see more interesting connections and analyses in the future.

A GAUSSIAN MANIFOLD

A.1 CURVATURE OF THE GAUSSIAN MANIFOLD

We construct a Riemannian manifold {(µ, σ) | µ ∈ R, σ ∈ R_{>0}} with the metric tensor g = diag(1/σ², 1/(cσ²)), which we name the Gaussian manifold, and compute its curvature. First, we compute the Christoffel symbols of the Gaussian manifold, defined as:

Γ^k_{ij} = (1/2) g^{kl} ( ∂g_{jl}/∂x^i + ∂g_{il}/∂x^j − ∂g_{ij}/∂x^l ),

where g_{ij} is the (i, j) element of the metric tensor and g^{ij} is the (i, j) element of its inverse. The Christoffel symbols of the Gaussian manifold are:

Γ¹ = [ 0, −1/σ ; −1/σ, 0 ],    Γ² = [ c/σ, 0 ; 0, −1/σ ].    (9)

Then, the sectional curvature κ_g of the space is computed as:

κ_g = Rm(µ, σ, σ, µ) / det g,

where Rm is the Riemannian curvature tensor, computed as:

Rm(µ, σ, σ, µ) = g_{1m} ( ∂Γ^m_{22}/∂µ − ∂Γ^m_{12}/∂σ + Γ^p_{22} Γ^m_{1p} − Γ^p_{12} Γ^m_{2p} ).

Putting the metric tensor and the Christoffel symbols together, the curvature of the Gaussian manifold is:

κ_g = Rm(µ, σ, σ, µ) / det g = (−1/σ⁴) / (1/(cσ⁴)) = −c.

A.2 GAUSSIAN MANIFOLD WITH KL-DIVERGENCE

Between two univariate Gaussian distributions N(µ₁, σ₁²) and N(µ₂, σ₂²), the KL divergence is:

D_KL(N(µ₁, σ₁) ∥ N(µ₂, σ₂)) = (1/2) [ log(σ₂²/σ₁²) + (σ₁² + (µ₁ − µ₂)²)/σ₂² − 1 ].

We extend the KL divergence for an arbitrary curvature of the Gaussian manifold as:

G^c_KL((µ₁, σ₁), (µ₂, σ₂)) := D_KL(N(√(2c) µ₁, σ₁) ∥ N(√(2c) µ₂, σ₂)) / (2c).

Now, we show that the extended KL divergence still locally approximates the squared Riemannian distance of the manifold:

G^c_KL((µ + dµ, σ + dσ), (µ, σ))
= (1/(2·2c)) [ log(σ²/(σ + dσ)²) + ((σ + dσ)² + 2c(dµ)²)/σ² − 1 ]    (16)
= (1/(2·2c)) [ −2 log(1 + dσ/σ) + (2σ dσ + (dσ)²)/σ² + 2c(dµ)²/σ² ]    (17)
= (1/(2·2c)) [ −2( dσ/σ − (dσ)²/(2σ²) ) + (2σ dσ + (dσ)²)/σ² + 2c(dµ)²/σ² + O((dσ)³) ]    (18)
= (1/2) (dµ, dσ)ᵀ diag(1/σ², 1/(cσ²)) (dµ, dσ) + O((dσ)³).
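The second-order expansion above can be checked numerically: for a small displacement (dµ, dσ), G^c_KL should agree with the quadratic form (1/2)(dµ²/σ² + dσ²/(cσ²)) up to O((dσ)³). A quick numpy check:

```python
import numpy as np

def g_kl(p, q, c=1.0):
    """Extended KL divergence G^c_KL((mu1, sigma1), (mu2, sigma2)) from A.2."""
    (m1, s1), (m2, s2) = p, q
    kl = np.log(s2 / s1) + (s1**2 + 2 * c * (m1 - m2)**2) / (2 * s2**2) - 0.5
    return kl / (2 * c)

c, mu, sigma = 1.5, 0.3, 0.8
dmu, dsigma = 2e-3, -1e-3
quad = 0.5 * (dmu**2 / sigma**2 + dsigma**2 / (c * sigma**2))  # metric quadratic form
approx = g_kl((mu + dmu, sigma + dsigma), (mu, sigma), c)
```

The divergence vanishes at coinciding points and matches the quadratic form to well below the O((dσ)³) remainder.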

B PSEUDO GAUSSIAN MANIFOLD NORMAL DISTRIBUTION

In this section, we propose a pseudo Gaussian manifold normal distribution for the Gaussian manifold:

K_c(µ, σ; α, β, γ²) = σ³/(Z(γ) β³) · exp( −G^c_KL((µ, σ), (α, β)) / γ² ).

The probability density function needs to satisfy the normalization condition:

∫_G K_c(µ, σ; α, β, γ²) √|g(µ, σ)| d(µ, σ) = 1,

where √|g(µ, σ)| d(µ, σ) is the probability measure over the Gaussian manifold induced from the Lebesgue measure d(µ, σ) via the Lebesgue-Radon-Nikodym theorem. We can find the normalizing factor Z(γ) as:

Z(γ) = ∫₀^∞ ∫_{−∞}^∞ (σ³/β³) exp( −G^c_KL((µ, σ), (α, β)) / γ² ) · (1/(√c σ²)) dµ dσ
= (1/(√c β³)) ∫₀^∞ ∫_{−∞}^∞ σ exp( −G^c_KL((µ, σ), (α, β)) / γ² ) dµ dσ
= (1/(√c β³)) β^{−1/(2cγ²)} exp(1/(4cγ²)) ∫₀^∞ σ (σ²)^{(1/(4cγ²)+1)−1} exp( −σ²/(4cβ²γ²) ) dσ × ∫_{−∞}^∞ exp( −(µ − α)²/(2β²γ²) ) dµ
= (1/(2√c β³)) √(2π) β³ γ exp(1/(4cγ²)) Γ(1/(4cγ²)) (1/(4cγ²))^{−1/(4cγ²)} × ∫₀^∞ Gamma( σ²; 1/(4cγ²) + 1, 1/(4cβ²γ²) ) dσ² ∫_{−∞}^∞ N(µ; α, β²γ²) dµ
= (√(2π)/(2√c)) γ exp(1/(4cγ²)) Γ(1/(4cγ²)) (1/(4cγ²))^{−1/(4cγ²)}.

Finally, the logarithm of the normalizing factor is computed as:

log Z(γ) = (1/2) log(2π) − (1/2) log c − log 2 + (1/2) log γ² + log Γ(1/(4cγ²)) + (1/(4cγ²))(1 + log(4cγ²)).

C ISOMETRIES BETWEEN THE HYPERBOLIC MODELS

In this section, we derive the isometries between the two-dimensional hyperbolic models (the Lorentz model and the Poincaré disk model) and the Gaussian manifold, with arbitrary curvatures.

Isometry between the Poincaré disk model and the Lorentz model T_{Lc→Pc} : L_c → P_c is computed as:

T_{Lc→Pc}((t, x, y)) = ( x/(√c t + 1), y/(√c t + 1) ),

and its inverse is:

T⁻¹_{Lc→Pc}((x, y)) = ( (1 + (x² + y²)c)/(√c (1 − (x² + y²)c)), 2x/(1 − (x² + y²)c), 2y/(1 − (x² + y²)c) ).

Isometry between the Gaussian manifold and the Poincaré disk model T_{Pc→Gc} : P_c → G_c is computed as:

T_{Pc→Gc}((x, y)) = ( −2y/((√c x − 1)² + cy²), (1 − (x² + y²)c)/((√c x − 1)² + cy²) ),

and its inverse is:

T⁻¹_{Pc→Gc}((x, y)) = ( (√c x² + (y² − 1)/√c)/(cx² + (y + 1)²), −2x/(cx² + (y + 1)²) ).

Finally, the isometry T_{Lc→Gc} between the Lorentz model and the Gaussian manifold can be derived by composing T_{Lc→Pc} and T_{Pc→Gc} as:

T_{Lc→Gc}((t, x, y)) = ( −y/(√c (t − x)), 1/(√c (t − x)) ),

and its inverse is:

T⁻¹_{Lc→Gc}((x, y)) = ( (1 + cx² + y²)/(2√c y), (−1 + cx² + y²)/(2√c y), −x/y ).

We then empirically show that the isometries preserve the distances between points when transformed to the other models.
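The closed-form normalizer derived in this appendix can be cross-checked by brute-force integration of σ³/β³ · exp(−G^c_KL/γ²) against the Riemannian measure 1/(√c σ²) dµ dσ. A numpy sketch; the grid bounds are ad hoc choices that cover the mass for c = 1, γ² = 1:

```python
import math
import numpy as np

def log_Z(gamma2, c=1.0):
    """Closed-form log normalizer of the PGM normal (end of Appendix B)."""
    a = 1.0 / (4 * c * gamma2)
    return (0.5 * math.log(2 * math.pi) - 0.5 * math.log(c) - math.log(2)
            + 0.5 * math.log(gamma2) + math.lgamma(a)
            + a * (1 + math.log(4 * c * gamma2)))

def Z_numeric(gamma2, c=1.0, alpha=0.0, beta=1.0):
    """Riemann-sum the normalizer over a half-plane grid."""
    sigma = np.linspace(1e-4, 20.0, 2001)
    mu = np.linspace(-10.0, 10.0, 1001)
    S, M = np.meshgrid(sigma, mu, indexing="ij")
    gkl = (np.log(beta**2 / S**2)
           + (S**2 + 2 * c * (M - alpha)**2) / beta**2 - 1) / (4 * c)
    integrand = S * np.exp(-gkl / gamma2) / (math.sqrt(c) * beta**3)
    return integrand.sum() * (mu[1] - mu[0]) * (sigma[1] - sigma[0])

z_closed = math.exp(log_Z(1.0))
z_brute = Z_numeric(1.0)
```

The brute-force value should agree with the closed form to within the quadrature error, and Z is independent of the choice of α and β, as the derivation shows.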
We randomly sampled 1,000 pairs of Gaussian manifold points with µ ∈ [−100, 100] and σ ∈ (0, 100]. We report the average difference between the distance of each pair before and after the transformation, varying the curvature value from 0.25 to 2. For the Gaussian manifold, we use the following distance function for arbitrary curvature:

d_{Gc}((x₁, y₁), (x₂, y₂)) = (1/√c) log [ ( √(c(x₁ − x₂)² + (y₁ + y₂)²) + √(c(x₁ − x₂)² + (y₁ − y₂)²) ) / ( √(c(x₁ − x₂)² + (y₁ + y₂)²) − √(c(x₁ − x₂)² + (y₁ − y₂)²) ) ].

Table 4 shows that the proposed isometries preserve the distances well.

The n-dimensional Lorentz model with curvature −c is L^n_c = {x ∈ R^{n+1} | ⟨x, x⟩_{Lc} = −1/c}, where ⟨x, y⟩_{Lc} is the Lorentzian product, computed as ⟨x, y⟩_{Lc} = −x₀y₀ + Σ_{i=1}^{n} xᵢyᵢ. The exponential map of the Lorentz model is defined as:

exp^c_x(v) = cosh(α) x + sinh(α) v/α,

and the log map of the Lorentz model is defined as:

log^c_x(y) = ( cosh⁻¹(β)/√(β² − 1) ) (y − βx),

where α = √(c ⟨v, v⟩_{Lc}) and β = −c ⟨x, y⟩_{Lc}.
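The distance-preservation experiment reported in Table 4 is straightforward to reproduce: sample pairs on the Gaussian manifold and compare d_{Gc} with the Lorentz distance of their images under T⁻¹_{Lc→Gc}. A numpy sketch (we keep σ bounded away from 0 to avoid the degenerate boundary):

```python
import numpy as np

def dist_gaussian(p1, p2, c=1.0):
    """Distance on the Gaussian manifold for arbitrary curvature (half-plane formula)."""
    (x1, y1), (x2, y2) = p1, p2
    a = np.sqrt(c * (x1 - x2)**2 + (y1 + y2)**2)
    b = np.sqrt(c * (x1 - x2)**2 + (y1 - y2)**2)
    return np.log((a + b) / (a - b)) / np.sqrt(c)

def T_inv(q, c=1.0):
    """Inverse isometry G_c -> L^2_c from Appendix C."""
    x, y = q
    s = c * x**2 + y**2
    return np.array([(1 + s) / (2 * np.sqrt(c) * y),
                     (-1 + s) / (2 * np.sqrt(c) * y),
                     -x / y])

def dist_lorentz(u, v, c=1.0):
    inner = -u[0] * v[0] + u[1] * v[1] + u[2] * v[2]
    return np.arccosh(np.maximum(-c * inner, 1.0)) / np.sqrt(c)

rng = np.random.default_rng(0)
gaps = []
for c in (0.25, 0.5, 1.0, 2.0):
    for _ in range(250):
        p1 = (rng.uniform(-100, 100), rng.uniform(1e-2, 100))
        p2 = (rng.uniform(-100, 100), rng.uniform(1e-2, 100))
        gaps.append(abs(dist_gaussian(p1, p2, c)
                        - dist_lorentz(T_inv(p1, c), T_inv(p2, c), c)))
```

Across 1,000 random pairs and four curvatures, the two distances agree to floating-point precision, mirroring the result in Table 4.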

D IMPLEMENTATION DETAILS

In this section, we introduce the implementation details for the density estimation experiment. For the encoder and the decoder on binarized-MNIST and binarized-Omniglot, we use a two-layer fully connected neural network, where the dimension of the hidden units is 200 with the hyperbolic tangent activation, following the setting used in the importance weighted VAE (Burda et al., 2015). For binarized-Breakout, the encoder is a four-layer convolutional neural network with leaky ReLU activations followed by a fully connected layer; the decoder consists of a fully connected layer followed by three transposed convolutional layers with ReLU activations. For both the encoder and the decoder used on binarized-Breakout, we place a batch normalization layer at the end of every layer. Since the datasets are binarized, we use the Bernoulli distribution as the output of the decoders for all tasks. For training, we use the Adam optimizer (Kingma & Ba, 2014) with a constant learning rate of 1e-3 and a batch size of 100. We train the VAEs for 300 epochs on binarized-MNIST and binarized-Omniglot and for 200 epochs on binarized-Breakout. The trained VAEs are evaluated by the log-likelihood on the test set of each task with importance-weighted sampling (Burda et al., 2015), using 500 samples for binarized-MNIST and binarized-Omniglot and 50 samples for binarized-Breakout.
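The importance-weighted evaluation computes log p(x) ≈ log (1/K) Σ_k p_θ(x, z_k)/q_ϕ(z_k | x) with z_k ∼ q_ϕ, stabilized with the log-sum-exp trick (Burda et al., 2015). A model-agnostic numpy sketch:

```python
import numpy as np

def iwae_log_likelihood(log_joint, log_q):
    """Importance-weighted estimate of log p(x) from K posterior samples.

    log_joint[k] = log p(x, z_k) and log_q[k] = log q(z_k | x).
    Uses log-sum-exp so large-magnitude log-weights do not overflow."""
    log_w = np.asarray(log_joint) - np.asarray(log_q)
    m = log_w.max()
    return m + np.log(np.mean(np.exp(log_w - m)))
```

When q equals the true posterior, every log-weight equals log p(x) and the estimate is exact regardless of K; otherwise it is a lower bound on log p(x) that tightens as K grows.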

E NUMERICAL STABILITY

We conduct an analysis of the numerical stability of the PGM normal distribution compared to the HWN and the Poincaré normal. During the density estimation experiment, the HWN VAE and the Poincaré VAE often turn out to be numerically unstable and fail to run on binarized-Breakout. Similar observations have been reported in several previous works (Mathieu et al., 2019; Chen et al., 2021; Skopek et al., 2019). The hyperbolic wrapped normal uses the Lorentz exponential map both when transforming the output of the encoder to the Lorentz model and during sampling, as described in Equation 3. This repeated application of the exponential map often causes an overflow due to the hyperbolic functions cosh and sinh, which grow exponentially with positive inputs. In the training of Poincaré VAE, the KL divergence between the variational distribution and the prior needs to be approximated by the log-probability of samples, due to the absence of a closed-form KL divergence for the Poincaré normal. Computing the log-probability of a given sample requires the distance between two points on the Poincaré disk model, the sample and the Fréchet mean of the distribution, where the distance function of the Poincaré disk model is defined as:

d^c_P(x, y) = (1/√c) cosh⁻¹( 1 + 2c ∥x − y∥² / ((1 − c∥x∥²)(1 − c∥y∥²)) ),

where ∥·∥ is the Euclidean norm. The denominator is unstable when ∥x∥ or ∥y∥ is close to 1/√c, which occurs when x or y is near the border of the Poincaré disk. For the PGM normal, on the other hand, the KL divergence between an arbitrary PGM normal and K_c(0, I, I), which is the only such operation used during the training of GM-VAE, can be stably computed using the log-covariance. For example, the KL divergence between a univariate Gaussian distribution N(µ, σ) and the prior distribution mentioned above, written as Equation 14, can be computed with log σ².
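The contrast can be demonstrated directly: in single precision, a coordinate rounds onto the disk boundary and the Poincaré distance overflows, while a Gaussian KL parameterized by the log-variance stays finite for extreme values. A numpy sketch:

```python
import numpy as np

def poincare_dist(x, y, c=1.0):
    """Poincare-disk distance; the denominator vanishes as ||x|| -> 1/sqrt(c)."""
    x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
    num = 2 * c * np.sum((x - y)**2)
    den = (1 - c * np.sum(x**2)) * (1 - c * np.sum(y**2))
    with np.errstate(divide="ignore"):
        return np.arccosh(1 + num / den) / np.sqrt(c)

def kl_to_std_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)) from the log-variance: stable for any logvar."""
    return 0.5 * (np.exp(logvar) + mu**2 - 1.0 - logvar)

# float32 rounds 1 - 1e-8 onto the boundary of the unit disk (c = 1):
boundary_x = float(np.float32(1.0 - 1e-8))
d = poincare_dist([boundary_x, 0.0], [0.0, 0.0])   # overflows to inf
kl = kl_to_std_normal(0.0, -80.0)                  # extreme log-variance, still finite
```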
The KL divergence between two Gamma distributions Gamma(a₁, b₁) and Gamma(a₂, b₂), written as:

D_KL(Gamma(a₁, b₁) ∥ Gamma(a₂, b₂)) = a₂ log(b₁/b₂) − log Γ(a₁) + log Γ(a₂) + (a₁ − a₂) ψ(a₁) − (1 − b₂/b₁) a₁,    (29)

where ψ is the digamma function, can be stably computed using log b₁ when b₁ is large, which happens for small β and γ in the factorization of Equation 6.
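Equation 29 can be evaluated stably by carrying the rates in log-space, so that a₂ log b₁ never materializes a huge b₁ and the last term uses b₂/b₁ = exp(log b₂ − log b₁). A self-contained sketch with a small digamma approximation; the helper names are ours:

```python
import math

def digamma(x):
    """Digamma psi(x) for x > 0 via recurrence plus an asymptotic series."""
    r = 0.0
    while x < 6.0:
        r -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # psi(x) ~ ln x - 1/(2x) - 1/(12 x^2) + 1/(120 x^4) - 1/(252 x^6)
    return r + math.log(x) - 0.5 / x - inv2 * (1/12 - inv2 * (1/120 - inv2 / 252))

def gamma_kl(a1, log_b1, a2, log_b2):
    """KL(Gamma(a1, b1) || Gamma(a2, b2)), rate parametrization, rates in log-space."""
    return ((a1 - a2) * digamma(a1) - math.lgamma(a1) + math.lgamma(a2)
            + a2 * (log_b1 - log_b2) - a1 * (1.0 - math.exp(log_b2 - log_b1)))
```

With log_b1 = 200, a rate far beyond floating-point range if exponentiated directly, the KL remains finite because only differences of log-rates are exponentiated.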

F LATENT SPACE ANALYSIS

In this section, we provide additional visualizations of the learned representations of GM-VAE. Figure 4 shows the t-SNE visualization of the Euclidean VAE learned on binarized-MNIST. Figure 5 shows the interpolation results of the binarized-MNIST representations along the µ axis, obtained by changing the value of α in three dimensions while fixing all other values, including β. Figure 7 and Figure 6 show the latent representations from all dimensions for binarized-MNIST and binarized-Breakout, respectively.



Figure 1: An example architecture of GM-VAE. The illustration shows the architecture of GM-VAE, which sets the latent space to a three-dimensional diagonal Gaussian manifold. The encoder outputs parameters of the PGM normal, which are points of the Gaussian manifold. The gray line refers to the sampling process. The decoder reconstructs data from the samples.

(a) t-SNE visualization. (b) Latent traversal of representations.

Figure 2: Analysis of the learned latent space of GM-VAE with binarized-MNIST. (a) t-SNE visualization of the representation with respect to the class labels. (b) Increasing the value of β, along the gray line, results in an increasing degree of uncertainty in the reconstructed images.

(a) Latent traversal of representations. (b) Hierarchy in the data.

Figure 3: Analysis of the latent space learned from GM-VAE with the binarized-Breakout. (a) Reconstructing the representations, along the gray line, shows a similar hierarchical structure, where (b) the hierarchy between the images is expressed as the dotted line.



Figure 4: t-SNE visualization with representations learned from Euclidean-VAE.

Figure 5: Interpolation results of the representations learned from GM-VAE along the µ axis.

Figure 6: The binarized-Breakout representations learned from GM-VAE.

Figure 7: The binarized-MNIST representations learned from GM-VAE.

Table 2 (continued): binarized-Omniglot.

 d    E-VAE         L-VAE         P-VAE         GM-VAE (c = 1)  GM-VAE (c = 1/2)  GM-VAE (c = 3/2)
 10   136.53 ±.30   136.25 ±.36   134.95 ±.47   134.01 ±.28     135.20 ±.21       133.79 ±.30
 20   121.18 ±.33   119.95 ±.40   117.79 ±.13   118.79 ±.53     118.73 ±.39       119.03 ±.57
 30   118.67 ±.67   117.16 ±.48   115.09 ±.56   117.97 ±.35     117.70 ±.47       117.95 ±.37

Table 3: Ablation study on the geometric transformations of GM-VAE. Vanilla denotes the models without the geometric transformations, and Geo denotes the models with the geometric transformations. The geometric transformations enhance the generalization performance in most cases.

Table 4: Validation of the proposed isometries between the hyperbolic models.

