A PROVABLY CONVERGENT AND PRACTICAL ALGORITHM FOR MIN-MAX OPTIMIZATION WITH APPLICATIONS TO GANS

Abstract

We present a first-order algorithm for nonconvex-nonconcave min-max optimization problems such as those that arise in training GANs. Our algorithm provably converges in poly(d, L, b) steps for any loss function f : ℝ^d × ℝ^d → ℝ that is b-bounded with an L-Lipschitz gradient. To achieve convergence, we (1) give a novel approximation to the global strategy of the max-player based on first-order algorithms such as gradient ascent, and (2) empower the min-player to look ahead and simulate the max-player's response for arbitrarily many steps, while restricting the min-player to moves sampled from a stochastic gradient oracle. When used to train GANs on synthetic and real-world datasets, our algorithm does not cycle, appears to avoid mode collapse, and has a per-iteration training time and memory requirement similar to gradient descent-ascent.

1. INTRODUCTION

We consider the problem of min-max optimization min_{x∈ℝ^d} max_{y∈ℝ^d} f(x, y), where the loss function f may be nonconvex in x and nonconcave in y. Min-max optimization of such loss functions has many applications in machine learning, including GANs (Goodfellow et al., 2014) and adversarial training (Madry et al., 2018). In particular, following Goodfellow et al. (2014), GAN training can be formulated as a min-max optimization problem where x encodes the parameters of a "generator" network and y encodes the parameters of a "discriminator" network. Unlike standard minimization problems, the min-max nature of GANs makes them particularly difficult to train (Goodfellow, 2017), and this difficulty has received wide attention. A common algorithm for such min-max optimization problems, gradient descent-ascent (GDA), alternates between stochastic gradient descent steps for x and ascent steps for y.¹ The advantage of GDA is that it requires only first-order access to f, and each iteration is efficient in memory and time, making it quite practical. However, as many works have observed, GDA can suffer from issues such as cycling (Arjovsky & Bottou, 2017) and "mode collapse" (Dumoulin et al., 2017; Che et al., 2017; Santurkar et al., 2018). Several recent works have focused on finding convergent first-order algorithms for min-max optimization (Rafique et al., 2018; Daskalakis et al., 2018; Liang & Stokes, 2019; Gidel et al., 2019b; Mertikopoulos et al., 2019; Nouiehed et al., 2019; Lu et al., 2020; Lin et al., 2020; Mokhtari et al., 2019; Thekumparampil et al., 2019; Mokhtari et al., 2020). However, these algorithms are also not guaranteed to converge for general nonconvex-nonconcave min-max problems. The challenge is that min-max optimization generalizes nonconvex minimization, which is intractable in general; algorithms for nonconvex minimization resort to finding "local" optima or assume a starting point "close" to a global optimum.
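The cycling behavior of GDA can be seen on a toy instance. The sketch below (our own illustration, not from the paper) runs simultaneous GDA on the bilinear loss f(x, y) = x·y, whose unique equilibrium is (0, 0); the iterates spiral outward instead of converging.

```python
import numpy as np

# Toy bilinear loss f(x, y) = x * y, with gradients df/dx = y and df/dy = x.
def gda(x0, y0, eta=0.1, steps=100):
    """Simultaneous gradient descent-ascent on f(x, y) = x * y."""
    x, y = x0, y0
    for _ in range(steps):
        gx, gy = y, x                       # gradients w.r.t. x and y
        x, y = x - eta * gx, y + eta * gy   # descent in x, ascent in y
    return x, y

x_T, y_T = gda(1.0, 1.0)
r0 = np.hypot(1.0, 1.0)        # distance of the start from the equilibrium (0, 0)
rT = np.hypot(x_T, y_T)        # distance after 100 GDA steps
print(rT > r0)  # True: each step multiplies the distance by sqrt(1 + eta^2)
```

A short calculation explains the spiral: the update maps (x, y) to (x − ηy, y + ηx), so the squared distance to the origin grows by a factor of exactly (1 + η²) per step.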
However, unlike minimization problems, where local notions of optima exist (Nesterov & Polyak, 2006), it has been challenging to define a notion of convergent points for min-max optimization, and most notions of local optima considered in previous works (Daskalakis & Panageas, 2018; Jin et al., 2020; Fiez et al., 2019) require significant restrictions for existence.

Our contributions. Our main result is a new first-order algorithm for min-max optimization (Algorithm 1) that, for any ε > 0, any nonconvex-nonconcave loss function, and any starting point, converges in poly(d, L, b, 1/ε) steps, provided f is b-bounded with an L-Lipschitz gradient (Theorem 2.3). A key ingredient in our result is an approximation to the global max function max_{z∈ℝ^d} f(x, z). Unlike GDA and related algorithms, which alternate between updating the discriminator and generator incrementally, our algorithm lets the discriminator run a convergent algorithm (such as gradient ascent) until it reaches a first-order stationary point. We then empower the generator to simulate the discriminator's response for arbitrarily many gradient ascent updates. Roughly, at each iteration of our algorithm, the min-player proposes a stochastic (batch) gradient update for x and simulates the response of the max-player with gradient ascent steps for y until it reaches a first-order stationary point. If the resulting loss has decreased, the updates for x and y are accepted; otherwise, they are accepted only with a small probability (à la simulated annealing). The point (x*, y*) returned by our algorithm satisfies the following guarantee: if the min-player proposes a stochastic gradient descent update to x*, and the max-player is allowed to respond by updating y* using any "path" that increases the loss at a rate of at least ε, then, with high probability, the final loss cannot decrease by more than ε.
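The propose/simulate/accept loop described above can be sketched on a one-dimensional toy problem. The instance, hyperparameter names, and stopping rule below are our own illustrative choices, not the paper's Algorithm 1: we take f(x, y) = x² − (y − x)², for which the max-player's best response is y = x, so max_y f(x, y) = x², minimized at x = 0.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x, y):
    # Toy loss: the max-player's best response is y = x, giving max_y f = x**2.
    return x**2 - (y - x)**2

def ascend_to_stationary(x, y, eta=0.1, tol=1e-6, max_steps=10_000):
    """Simulate the max-player: gradient ascent on y until ~first-order stationary."""
    for _ in range(max_steps):
        gy = -2.0 * (y - x)          # df/dy
        if abs(gy) < tol:
            break
        y += eta * gy
    return y

x, y = 3.0, -1.0
y = ascend_to_stationary(x, y)
loss = f(x, y)
eta_x, temp = 0.05, 1e-3
for _ in range(200):
    gx = 2.0 * x + 2.0 * (y - x)     # df/dx (a stochastic batch gradient in the paper)
    x_prop = x - eta_x * gx          # min-player's proposed update
    y_prop = ascend_to_stationary(x_prop, y)  # simulate the max-player's response
    new_loss = f(x_prop, y_prop)
    # Accept if the loss decreased; otherwise accept only with small probability
    # (a simulated-annealing-style rule).
    if new_loss < loss or rng.random() < np.exp(-(new_loss - loss) / temp):
        x, y, loss = x_prop, y_prop, new_loss
print(abs(loss) < 1e-6)  # True: the loss approaches 0, the min-max value
```

Because the min-player always evaluates its proposal against the max-player's (approximate) best response rather than against a single incremental discriminator step, the accepted loss tracks the global max function x², which is what rules out cycling in this example.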
See Section 2 for our convergence guarantees, Section 4 for the key ideas in our proof, and Appendix C for a comparison to previous notions of convergence.

Empirically, we apply our algorithm to training GANs (with the cross-entropy loss) on both synthetic (mixture of Gaussians) and real-world (MNIST and CIFAR-10) datasets (Section 3). We compare our algorithm's performance against two related algorithms: gradient/ADAM descent-ascent (with one or multiple discriminator steps) and Unrolled GANs (Metz et al., 2017). Our simulations on MNIST (Figure 1) and the mixture of Gaussians (Figure 2) indicate that training GANs with our algorithm can avoid mode collapse and cycling. For instance, on the Gaussian mixture dataset, we found that by around the 1500th iteration GDA had learned only one mode in 100% of the runs and cycled between multiple modes. In contrast, our algorithm learned all four modes in 68% of the runs and three modes in 26% of the runs. On 0-1 MNIST, we found that GDA tends to briefly generate shapes that look like a combination of 0's and 1's, then switches between generating only 1's and only 0's. In contrast, our algorithm seems to learn to generate both 0's and 1's early on and does not stop generating either digit. GANs trained using our algorithm generated both digits by the 1000th iteration in 86% of the runs, while those trained using GDA did so in only 23% of the runs. Our CIFAR-10 simulations (Figure 3) indicate that our algorithm trains more stably, resulting in a lower mean and standard deviation of FID scores compared to GDA. Furthermore, the per-step computational and memory cost of our algorithm is similar to that of GDA, indicating that our algorithm can scale to larger datasets.

Related work

Guaranteed convergence for min-max optimization. Several works have studied GDA dynamics in GANs (Nagarajan & Kolter, 2017; Mescheder et al., 2017; Li et al., 2018; Balduzzi et al., 2018; Daskalakis & Panageas, 2018; Jin et al., 2020) and established that GDA suffers from severe limitations: GDA can exhibit rotation around some points, or otherwise fail to converge. Thus, we cannot expect global convergence guarantees for GDA. To address these convergence issues, multiple works have proposed algorithms based on Optimistic Mirror Descent (OMD), the extragradient method, or similar approaches (Gidel et al., 2019b; Daskalakis et al., 2018; Liang & Stokes, 2019; Daskalakis & Panageas, 2019; Mokhtari et al., 2019; 2020). These algorithms avoid some of the pathological behaviors of GDA and achieve guaranteed convergence in poly(κ, log(1/ε)) iterations, where κ is the condition number of f. However, all these results either require convexity/concavity assumptions on f, which usually do not hold for GANs, or require that the starting point lie in a small region around an equilibrium point, and hence provide no guarantees for an arbitrary initialization. Some works also provide convergence guarantees for min-max optimization (Nemirovski & Yudin, 1978; Kinderlehrer & Stampacchia, 1980; Nemirovski, 2004; Rafique et al., 2018; Lu et al., 2020; Lin et al., 2020; Nouiehed et al., 2019; Thekumparampil et al., 2019), but they require f to be concave in y, again limiting their applicability. As for nonconvex-nonconcave min-max optimization, Heusel et al. (2017) prove convergence of finite-step GDA under the assumption that the underlying continuous dynamics converge to a local min-max optimum (this assumption may not hold even for bilinear f). Jin et al. (2020) present a version of GDA for min-max optimization (generalized by Fiez et al. (2019)) such that if the algorithm converges, the convergence point is a local min-max optimum.
Both of these results require that the min-player use a vanishingly small step size relative to the max-player, resulting in slow convergence. Wang et al. (2020) present an algorithm that can converge for nonconvex-nonconcave functions, but it requires the initial point to lie in a region close to a local min-max optimum (such optima are not guaranteed to exist). In contrast to the above works, our algorithm is guaranteed to converge for any nonconvex-nonconcave loss function and any starting point.



¹ In practice, gradient steps are often replaced by ADAM steps; we ignore this distinction in this discussion.

