THE BURES METRIC FOR TAMING MODE COLLAPSE IN GENERATIVE ADVERSARIAL NETWORKS

Abstract

Generative Adversarial Networks (GANs) are powerful generative methods yielding high-quality samples. However, under certain circumstances, the training of GANs can lead to mode collapse or mode dropping, i.e., the generative model being unable to sample from the entire probability distribution. To address this problem, we use the last layer of the discriminator as a feature map with which to study the distribution of the real and the fake data. During training, we propose to match the diversity of real batches to that of fake batches by using the Bures distance between covariance matrices in feature space. The Bures distance can be conveniently computed in either feature space or kernel space, in terms of the covariance matrix and the kernel matrix respectively. We observe that diversity matching substantially reduces mode collapse and has a positive effect on sample quality. On the practical side, we propose a very simple training procedure, which does not require additional hyperparameter tuning, and assess it on several datasets.

1. INTRODUCTION

In several machine learning applications, data is assumed to be sampled from an implicit probability distribution. Estimating this implicit distribution is often intractable, especially in high dimensions. To tackle this issue, generative models are trained to provide an algorithmic procedure for sampling from the unknown distribution. Popular approaches are Variational Auto-Encoders, proposed by Kingma & Welling (2014), normalizing flow models by Rezende & Mohamed (2015), and Generative Adversarial Networks (GANs), initially developed by Goodfellow et al. (2014). The latter are particularly successful at producing high-quality samples, especially in the case of natural images, though their training is notoriously difficult. The vanilla GAN consists of two networks: a generator and a discriminator. The generator maps random noise, usually drawn from a multivariate normal, to fake data in input space. The discriminator estimates the likelihood ratio of the generator distribution to the data distribution. It often happens that a GAN generates samples from only a few of the many modes of the distribution; this phenomenon is called 'mode collapse'.

Contribution. We propose BuresGAN: a generative adversarial network whose objective function is that of a vanilla GAN complemented by an additional term, given by the squared Bures distance between the covariance matrices of real and fake batches in a latent space. This loss function promotes a matching of fake and real data in a feature space R^f, so that mode collapse is reduced. Conveniently, the Bures distance admits both a feature space and a kernel-based expression. Contrary to other related approaches, such as those of Che et al. (2017) or Srivastava et al. (2017), the architecture of the GAN is unchanged; only the objective is modified.
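To make the proposed loss term concrete, the following is a minimal NumPy sketch of the squared Bures distance between the covariance matrices of a "real" and a "fake" feature batch. The function and variable names are illustrative, not taken from the paper's implementation, which computes this term from discriminator features inside the training graph; matrix square roots are taken here via an eigendecomposition for simplicity.

```python
import numpy as np

def psd_sqrt(M):
    """Symmetric square root of a (symmetrized) positive semi-definite matrix."""
    M = 0.5 * (M + M.T)                 # guard against rounding asymmetry
    w, V = np.linalg.eigh(M)
    w = np.clip(w, 0.0, None)           # clip tiny negative eigenvalues
    return (V * np.sqrt(w)) @ V.T

def bures_distance_squared(A, B):
    """Squared Bures distance: Tr(A) + Tr(B) - 2 Tr((A^{1/2} B A^{1/2})^{1/2})."""
    sqrt_A = psd_sqrt(A)
    cross = psd_sqrt(sqrt_A @ B @ sqrt_A)
    return np.trace(A) + np.trace(B) - 2.0 * np.trace(cross)

# Toy example: covariances of hypothetical real and fake feature batches.
rng = np.random.default_rng(0)
real_feats = rng.normal(size=(64, 8))   # batch of 64 features in R^8
fake_feats = rng.normal(size=(64, 8))
C_real = np.cov(real_feats, rowvar=False)
C_fake = np.cov(fake_feats, rowvar=False)
d2 = bures_distance_squared(C_real, C_fake)
```

In a GAN training loop, `d2` would be added (differentiably) to the generator objective so that the fake batch is penalized when its feature-space diversity deviates from that of the real batch.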
A variant called alt-BuresGAN, which is trained with alternating minimization, achieves competitive performance with a simple training procedure that requires neither hyperparameter tuning nor additional regularization such as a gradient penalty. We show empirically that the proposed methods are robust to the choice of architecture and do not require an additional fine-grained architecture search. Finally, an extra asset of BuresGAN is that it yields competitive or improved Inception Score (IS) and Fréchet Inception Distance (FID) compared with the state of the art on CIFAR-10 and STL-10 using a ResNet architecture.

Related works. The Bures distance is closely related to the Fréchet distance (Dowson & Landau, 1982), which is a 2-Wasserstein distance between multivariate normal distributions. Namely,
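The equation is missing from this extraction; as a hedged reconstruction of the standard identity being referred to, for positive semi-definite covariance matrices $A$ and $B$ the squared Bures distance is

```latex
\mathcal{B}(A,B)^2 = \operatorname{Tr}\!\left(A + B - 2\left(A^{1/2} B A^{1/2}\right)^{1/2}\right),
```

and the squared Fréchet (2-Wasserstein) distance between $\mathcal{N}(m_1, A)$ and $\mathcal{N}(m_2, B)$ equals $\|m_1 - m_2\|^2 + \mathcal{B}(A,B)^2$, so that for centered distributions the Fréchet distance reduces to the Bures distance.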

