OPTIMAL NEURAL NETWORK APPROXIMATION OF WASSERSTEIN GRADIENT DIRECTION VIA CONVEX OPTIMIZATION

Abstract

The computation of the Wasserstein gradient direction is essential for posterior sampling problems and scientific computing. Approximating the Wasserstein gradient with finite samples requires solving a variational problem. We study this variational problem in the family of two-layer networks with squared-ReLU activations, for which we derive a semi-definite programming (SDP) relaxation. This SDP can be viewed as an approximation of the Wasserstein gradient in a broader function family including two-layer networks. By solving the convex SDP, we obtain the optimal approximation of the Wasserstein gradient direction in this class of functions. We also propose practical algorithms based on subsampling and dimension reduction. Numerical experiments, including PDE-constrained Bayesian inference and parameter estimation in COVID-19 modeling, demonstrate the effectiveness and efficiency of the proposed method.

1. INTRODUCTION

Bayesian inference plays an essential role in learning model parameters from observational data, with applications in inverse problems, scientific computing, information science, and machine learning (Stuart, 2010). The central problem in Bayesian inference is to draw samples from a posterior distribution, which characterizes the parameter distribution given data and a prior distribution. The Wasserstein gradient flow (Otto, 2001; Ambrosio et al., 2005; Junge et al., 2017) has been shown to be effective in drawing samples from a posterior distribution and has attracted increasing attention in recent years. For instance, the Wasserstein gradient flow of the Kullback-Leibler (KL) divergence connects to the overdamped Langevin dynamics, and the time discretization of the overdamped Langevin dynamics yields the classical Langevin Markov chain Monte Carlo (MCMC) algorithm. In this sense, the computation of the Wasserstein gradient flow offers a different viewpoint on sampling algorithms. In particular, the Wasserstein gradient direction also provides a deterministic update of the particle system (Carrillo et al., 2021b). Based on approximations or generalizations of the Wasserstein gradient direction, many efficient sampling algorithms have been developed, including Wasserstein gradient descent (WGD) with kernel density estimation (KDE) (Liu et al., 2019), Stein variational gradient descent (SVGD) (Liu & Wang, 2016), and neural variational gradient descent (di Langosco et al., 2021).

Meanwhile, neural networks exhibit remarkable optimization and generalization performance in learning complicated functions from data, and they have wide applications in Bayesian inverse problems (Rezende & Mohamed, 2015; Onken et al., 2020; Kruse et al., 2019; Lan et al., 2021). According to the universal approximation theorem of neural networks (Hornik et al., 1989; Lu et al., 2017), arbitrarily complicated functions can be learned by a two-layer neural network with nonlinear activations and a sufficient number of neurons. Functions represented by neural networks therefore naturally provide an approximation of the Wasserstein gradient direction. However, due to the nonlinear and nonconvex structure of neural networks, optimization algorithms such as stochastic gradient descent may not find the global optima of the training problem. Recently, based on a line of works (Pilanci & Ergen, 2020; Sahiner et al., 2020; Bartan & Pilanci, 2021a), the regularized training problem of two-layer neural networks with ReLU/polynomial activation can be formulated as a convex program. Indeed, by solving the convex program, one can construct the entire set of global optima of the nonconvex training problem (Wang et al., 2020). Theoretical analysis (Wang et al., 2022) shows that global optima of the training problem correspond to the simplest models with good generalization properties. Moreover, numerical results (Pilanci & Ergen, 2020) show that neural networks found by solving the convex program can achieve higher training and test accuracy than neural networks trained by SGD with the same number of parameters.

In this paper, we study a variational problem whose optimal solution corresponds to the Wasserstein gradient direction. Focusing on the family of two-layer neural networks with squared-ReLU activation, we formulate the regularized variational problem in terms of samples.
Directly training the neural network to minimize the loss may leave the network stuck at local minima or saddle points, which often leads to biased sample distributions from the posterior. Instead, we analyze the convex dual of the training problem and study its semi-definite programming (SDP) relaxation by analyzing the geometry of the dual constraints. The resulting SDP can be efficiently solved by convex optimization solvers such as CVXPY (Diamond & Boyd, 2016). We then derive the corresponding relaxed bi-dual problem (the dual of the relaxed dual problem). The optimal solution of the relaxed dual problem thus yields an optimal approximation of the Wasserstein gradient direction in a broader function family. We also analyze the choice of the regularization parameter and present a practical implementation using subsampling and parameter dimension reduction to improve computational efficiency. Numerical results for experiments including PDE-constrained inference problems and COVID-19 parameter estimation illustrate the effectiveness and efficiency of our method.

1.1. RELATED WORKS

The time and spatial discretizations of Wasserstein gradient flows are extensively studied in the literature (Jordan et al., 1998; Junge et al., 2017; Carrillo et al., 2021a;b; Bonet et al., 2021; Liutkus et al., 2019; Frogner & Poggio, 2020). Recently, neural networks have been applied to solve or approximate Wasserstein gradient flows (Mokrov et al., 2021; Lin et al., 2021b;a; Alvarez-Melis et al., 2021; Bunne et al., 2021; Hwang et al., 2021; Fan et al., 2021). For sampling algorithms, di Langosco et al. (2021) learn the transportation function by solving an unregularized variational problem in the family of vector-output deep neural networks. Compared to these studies, we focus on a convex SDP relaxation of the variational problem induced by the Wasserstein gradient direction. Meanwhile, Feng et al. (2021) formulate the Wasserstein gradient direction as the minimizer of the Bregman score and apply deep neural networks to solve the induced variational problem. In comparison to previous works on convex optimization formulations of neural networks using SDPs (Bartan & Pilanci, 2021a;b), which focus on the polynomial activation and give an exact convex optimization formulation (instead of a convex relaxation), we focus on neural networks with the squared-ReLU activation, which have not been considered before. Our method also applies to the analysis of supervised learning problems with squared-ReLU activated neural networks.

2. BACKGROUND

In this section, we briefly review Wasserstein gradient descent and present its variational formulation. In particular, we focus on the Wasserstein gradient descent direction of the KL divergence functional. Later on, we design a convex neural network optimization problem to approximate the Wasserstein gradient in terms of samples.

2.1. WASSERSTEIN GRADIENT DESCENT

Consider an optimization problem in the probability space:
\[
\inf_{\rho \in \mathcal{P}} \ D_{\mathrm{KL}}(\rho \,\|\, \pi) = \int \rho(x)\,\big(\log \rho(x) - \log \pi(x)\big)\,dx. \tag{1}
\]
Here the integral is taken over $\mathbb{R}^d$ and the objective functional $D_{\mathrm{KL}}(\rho\,\|\,\pi)$ is the KL divergence from $\rho$ to $\pi$. The variable is the density function $\rho$ in the space $\mathcal{P} = \{\rho \in C^\infty(\mathbb{R}^d) \mid \int \rho\,dx = 1, \ \rho > 0\}$. The function $\pi \in C^\infty(\mathbb{R}^d)$ is a known probability density function of the posterior distribution. By solving the optimization problem (1), we can generate samples from the posterior distribution. A known fact (Villani, 2003, Chapter 8.3.1) is that the Wasserstein gradient flow for the optimization problem (1) satisfies
\[
\partial_t \rho_t = \nabla \cdot \Big( \rho_t \nabla \frac{\delta}{\delta \rho_t} D_{\mathrm{KL}}(\rho_t \,\|\, \pi) \Big)
= \nabla \cdot \big( \rho_t (\nabla \log \rho_t - \nabla \log \pi) \big)
\overset{(a)}{=} \Delta \rho_t - \nabla \cdot (\rho_t \nabla \log \pi), \tag{2}
\]
where $\rho_t(x) = \rho(x, t)$, $\frac{\delta}{\delta \rho_t}$ is the $L^2$ first-variation operator w.r.t. $\rho_t$, $\nabla \cdot F$ denotes the divergence of a vector-valued function $F : \mathbb{R}^d \to \mathbb{R}^d$, and $\Delta$ is the Laplace operator. In step (a) we use the fact that $\rho_t \nabla \log \rho_t = \nabla \rho_t$. This equation is also known as the gradient-drift Fokker-Planck equation. It corresponds to the following update in terms of samples:
\[
dx_t = -\big( \nabla \log \rho_t(x_t) - \nabla \log \pi(x_t) \big)\,dt,
\]
where $x_t$ follows the distribution $\rho_t$. Clearly, when $\rho_t = \pi$, the above dynamics reach the equilibrium, which implies that the samples $x_t$ are distributed according to the posterior. To solve the Wasserstein gradient flow (2), we consider a forward Euler discretization in time. In the $l$-th iteration, suppose that $\{x_l^n\}$ are samples drawn from $\rho_l$. The update rule of Wasserstein gradient descent (WGD) on the particle system $\{x_l^n\}$ follows
\[
x_{l+1}^n = x_l^n - \alpha_l \nabla \Phi_l(x_l^n), \tag{3}
\]
where $\Phi_l : \mathbb{R}^d \to \mathbb{R}$ is a function approximating $\log \rho_l - \log \pi$ and $\alpha_l > 0$ is the step size.
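As a concrete illustration of the particle update (3), the following minimal sketch performs WGD steps on a particle system; the kernel-density score, the standard Gaussian target, and all parameter values are illustrative assumptions, and the later sections replace this crude score estimate by the neural network / convex optimization approximation.

```python
import numpy as np

def wgd_step(X, grad_Phi, step_size):
    """One WGD update (3): x_{l+1} = x_l - alpha_l * grad Phi_l(x_l).

    X        : (N, d) array of particles.
    grad_Phi : callable returning an (N, d) approximation of
               grad(log rho_l - log pi) at the particles.
    """
    return X - step_size * grad_Phi(X)

def kde_score(X, bandwidth=0.5):
    """Crude kernel-density estimate of grad log rho at the particles."""
    diffs = X[:, None, :] - X[None, :, :]                        # (N, N, d) pairwise differences
    w = np.exp(-np.sum(diffs**2, axis=-1) / (2 * bandwidth**2))  # (N, N) Gaussian kernel weights
    num = -(w[:, :, None] * diffs).sum(axis=1) / bandwidth**2    # sum_j grad_x k(x_i - x_j)
    return num / w.sum(axis=1, keepdims=True)                    # grad log of the KDE density

# Illustrative target: standard Gaussian, so grad log pi(x) = -x.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2)) + 3.0                               # particles start far from the target
for _ in range(200):
    X = wgd_step(X, lambda Z: kde_score(Z) + Z, step_size=0.05)  # grad log rho - grad log pi
print(X.mean(axis=0))                                            # drifts toward the target mean 0
```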

2.2. VARIATIONAL FORMULATION OF WGD

Given the particles $\{x_n\}_{n=1}^N$, we design the following variational problem to choose a suitable function $\Phi$ approximating $\log \rho - \log \pi$. Consider
\[
\inf_{\Phi \in C^1(\mathbb{R}^d)} \ \frac{1}{2} \int \big\| \nabla \Phi(x) - \big( \nabla \log \rho(x) - \nabla \log \pi(x) \big) \big\|_2^2 \, \rho(x)\,dx. \tag{4}
\]
The objective functional evaluates the least-squares discrepancy between $\nabla \log \rho - \nabla \log \pi$ and $\nabla \Phi$ weighted by the density $\rho$. The optimal solution is $\Phi = \log \rho - \log \pi$, up to a constant shift. Let $\mathcal{H} \subseteq C^1(\mathbb{R}^d)$ be a finite-dimensional function space. The following proposition gives a formulation of (4) in $\mathcal{H}$.

Proposition 1 Let $\mathcal{H} \subseteq C^1(\mathbb{R}^d)$ be a function space. The variational problem (4) over the domain $\mathcal{H}$ can be reformulated as
\[
\inf_{\Phi \in \mathcal{H}} \ \int \Big( \frac{1}{2} \|\nabla \Phi(x)\|_2^2 + \Delta \Phi(x) + \langle \nabla \log \pi(x), \nabla \Phi(x) \rangle \Big) \rho(x)\,dx. \tag{5}
\]

Remark 1 A similar variational problem has been studied in (di Langosco et al., 2021). If we replace $\nabla \Phi$ for $\Phi \in \mathcal{H}$ by a vector field $\Psi$ in a certain function family, then the quantity in (5) is the negative regularized Stein discrepancy between $\rho$ and $\pi$ based on $\Psi$, as defined in (di Langosco et al., 2021). This problem is also similar to the variational problem for the score matching estimator in (Hyvärinen & Dayan, 2005), obtained by parameterizing $\Phi$ in a given probabilistic model. In comparison, our method can be viewed as a special case of score matching using a two-layer neural network.

By replacing the density $\rho$ with finite samples $\{x_n\}_{n=1}^N \sim \rho$, the problem (5) in terms of finite samples becomes
\[
\inf_{\Phi \in \mathcal{H}} \ \frac{1}{N} \sum_{n=1}^N \Big( \frac{1}{2} \|\nabla \Phi(x_n)\|_2^2 + \Delta \Phi(x_n) + \langle \nabla \log \pi(x_n), \nabla \Phi(x_n) \rangle \Big). \tag{6}
\]
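The finite-sample objective (6) depends on $\Phi$ only through $\nabla \Phi$ and $\Delta \Phi$ at the particles. A minimal sketch, assuming hypothetical callables `grad_Phi`, `lap_Phi`, and `grad_log_pi`:

```python
import numpy as np

def empirical_objective(X, grad_log_pi, grad_Phi, lap_Phi):
    """Finite-sample objective (6):
    (1/N) sum_n [ 0.5*||grad Phi(x_n)||^2 + lap Phi(x_n) + <grad log pi(x_n), grad Phi(x_n)> ]."""
    G = grad_Phi(X)       # (N, d) gradients of Phi at the particles
    L = lap_Phi(X)        # (N,)   Laplacians of Phi at the particles
    S = grad_log_pi(X)    # (N, d) posterior score at the particles
    return np.mean(0.5 * np.sum(G**2, axis=1) + L + np.sum(S * G, axis=1))
```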

3. OPTIMAL NEURAL NETWORK APPROXIMATION OF WASSERSTEIN GRADIENT

In this section, we focus on the functional space $\mathcal{H}$ of functions represented by two-layer neural networks. We derive the primal and dual problems of the regularized Wasserstein variational problem. By analyzing the dual constraints, a convex SDP relaxation of the dual problem is obtained. We also present a practical implementation for the estimation of $\nabla \log \rho - \nabla \log \pi$ and discuss the choice of the regularization parameter.

Let $\psi$ be an activation function. Consider the case where $\mathcal{H}$ is the class of two-layer neural networks with activation function $\psi$:
\[
\mathcal{H} = \big\{ \Phi_\theta \in C^1(\mathbb{R}^d) \ \big|\ \Phi_\theta(x) = \alpha^T \psi(W^T x) \big\}, \tag{7}
\]
where $\theta = (W, \alpha)$ are the parameters of the neural network with $W \in \mathbb{R}^{d \times m}$ and $\alpha \in \mathbb{R}^m$.

Remark 2 We can extend this model to handle bias terms by appending an entry of $1$ to $x_1, \ldots, x_n$.

For two-layer neural networks, we can compute the gradient and Laplacian of $\Phi_\theta \in \mathcal{H}$ as follows:
\[
\nabla \Phi_\theta(x) = \sum_{i=1}^m \alpha_i w_i \psi'(w_i^T x) = W \big( \psi'(W^T x) \circ \alpha \big), \tag{8}
\]
\[
\Delta \Phi_\theta(x) = \sum_{i=1}^m \alpha_i \|w_i\|_2^2 \psi''(w_i^T x). \tag{9}
\]
Here $\circ$ denotes element-wise multiplication. Adding a regularization term to the variational problem (6), we obtain
\[
\min_\theta \ \frac{1}{2N} \sum_{n=1}^N \Big\| \sum_{i=1}^m \alpha_i w_i \psi'(w_i^T x_n) \Big\|_2^2
+ \frac{1}{N} \sum_{n=1}^N \Big\langle \sum_{i=1}^m \alpha_i w_i \psi'(w_i^T x_n), \nabla \log \pi(x_n) \Big\rangle
+ \frac{1}{N} \sum_{n=1}^N \sum_{i=1}^m \alpha_i \|w_i\|_2^2 \psi''(w_i^T x_n)
+ \frac{\beta}{2} R(\theta), \tag{10}
\]
where $\beta > 0$ is the regularization parameter. We focus on the squared-ReLU activation $\psi(z) = (z)_+^2 = (\max\{z, 0\})^2$. Note that a non-vanishing second derivative is required for the Laplacian term in (9), which makes the plain ReLU activation inadequate. For this activation function, we consider the regularization function $R(\theta) = \sum_{i=1}^m \big( \|w_i\|_2^3 + |\alpha_i|^3 \big)$.

Remark 3 We note that $\nabla \Phi_\theta(x)$ and $\Delta \Phi_\theta(x)$ are piece-wise degree-3 polynomials of the parameters $\theta$. Hence, we consider the specific cubic regularization term above, analogous to (Bartan & Pilanci, 2021a). By choosing this regularization term, we can derive a simplified dual problem.

By the arithmetic-geometric mean (AM-GM) inequality, we can rescale the first- and second-layer parameters and reformulate the regularized variational problem (10) as follows.

Proposition 2 (Primal problem) The regularized variational problem (10) can be reformulated as
\[
\begin{aligned}
\min_{W, \alpha} \quad & \frac{1}{2} \sum_{n=1}^N \Big\| \sum_{i=1}^m \alpha_i w_i \psi'(w_i^T x_n) \Big\|_2^2
+ \sum_{n=1}^N \sum_{i=1}^m \alpha_i \|w_i\|_2^2 \psi''(w_i^T x_n)
+ \sum_{n=1}^N \Big\langle \sum_{i=1}^m \alpha_i w_i \psi'(w_i^T x_n), \nabla \log \pi(x_n) \Big\rangle
+ \bar\beta \|\alpha\|_1, \\
\text{s.t.} \quad & \|w_i\|_2 \le 1, \ i \in [m],
\end{aligned} \tag{11}
\]
where $\bar\beta = 3 \cdot 2^{-5/3} N \beta$ and we denote $[m] = \{1, \ldots, m\}$.

In short, the optimal values of (10) and (11) are the same, and we can obtain an optimal solution of (11) by rescaling an optimal solution of (10) and vice versa. For simplicity, we write $Y \in \mathbb{R}^{N \times d}$ for the matrix whose $n$-th row is $\nabla \log \pi(x_n)$ for $n \in [N]$. We introduce the slack variables $z_n = \sum_{i=1}^m \alpha_i w_i \psi'(x_n^T w_i)$ for $n \in [N]$ and denote $Z = [z_1, \ldots, z_N]^T \in \mathbb{R}^{N \times d}$. Then, we can rewrite the problem (11) as
\[
\begin{aligned}
\min_{W, \alpha, Z} \quad & \frac{1}{2} \|Z\|_F^2 + \sum_{n=1}^N \sum_{i=1}^m \alpha_i \|w_i\|_2^2 \psi''(w_i^T x_n) + \mathrm{tr}(Y^T Z) + \bar\beta \|\alpha\|_1, \\
\text{s.t.} \quad & z_n = \sum_{i=1}^m \alpha_i w_i \psi'(x_n^T w_i), \ n \in [N], \qquad \|w_i\|_2 \le 1, \ i \in [m].
\end{aligned} \tag{12}
\]

To derive the convex relaxation of the neural network training problem, the dual problem plays an important role. Applying Lagrangian duality, we can derive the dual problem of (12) as follows.

Proposition 3 (Dual problem) The dual problem of the regularized variational problem (12) is
\[
\max_{\Lambda} \ -\frac{1}{2} \|\Lambda + Y\|_F^2, \qquad
\text{s.t.} \ \max_{w : \|w\|_2 \le 1} \ \sum_{n=1}^N \Big( \|w\|_2^2 \psi''(x_n^T w) - \lambda_n^T w\, \psi'(x_n^T w) \Big) \le \bar\beta, \tag{13}
\]
where $\lambda_n^T$ denotes the $n$-th row of $\Lambda$. The dual problem provides a lower bound on (12). We note that the dual problem can be infeasible if the regularization parameter $\beta$ is below a certain threshold.
In other words, if the regularization term is missing or the regularization parameter is not large enough, the optimal value of the dual problem is -∞ and the primal problem is not lower bounded.
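For concreteness, the following is a minimal numpy sketch of the quantities in (8)-(9) and the regularized nonconvex objective (10) for the squared-ReLU activation; the function names are illustrative.

```python
import numpy as np

def squared_relu_terms(X, W, alpha):
    """Gradient and Laplacian of Phi_theta(x) = alpha^T (W^T x)_+^2 at each row of X.

    X : (N, d) particles, W : (d, m) first-layer weights, alpha : (m,) second layer.
    Returns grad (N, d) and lap (N,) following (8)-(9) with psi(z) = max(z, 0)^2.
    """
    A = X @ W                                    # (N, m) pre-activations w_i^T x_n
    psi1 = 2.0 * np.maximum(A, 0.0)              # psi'(z)  = 2 max(z, 0)
    psi2 = 2.0 * (A > 0.0).astype(float)         # psi''(z) = 2 * 1[z > 0]  (psi''(0) := 0)
    grad = (psi1 * alpha) @ W.T                  # sum_i alpha_i psi'(w_i^T x) w_i
    lap = psi2 @ (alpha * np.sum(W**2, axis=0))  # sum_i alpha_i ||w_i||^2 psi''(w_i^T x)
    return grad, lap

def regularized_loss(X, Y, W, alpha, beta):
    """Nonconvex training objective (10); the rows of Y hold grad log pi(x_n)."""
    grad, lap = squared_relu_terms(X, W, alpha)
    fit = np.mean(0.5 * np.sum(grad**2, axis=1) + lap + np.sum(Y * grad, axis=1))
    reg = 0.5 * beta * (np.sum(np.linalg.norm(W, axis=0)**3) + np.sum(np.abs(alpha)**3))
    return fit + reg
```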

3.1. ANALYSIS OF DUAL CONSTRAINTS AND THE RELAXED DUAL PROBLEM

Now, we analyze the constraint in the dual problem. We note that it is closely related to the regularization parameter, which we will discuss later. For simplicity, we take $\psi''(0) = 0$ as the subgradient of $\psi'(z)$ at $z = 0$, i.e., the left derivative of $\psi'(z)$ at $z = 0$. Let $X = [x_1, \ldots, x_N]^T \in \mathbb{R}^{N \times d}$. Denote the set of all possible hyper-plane arrangements corresponding to the rows of $X$ by
\[
\mathcal{S} = \big\{ \mathrm{diag}(\mathbb{I}(Xw \ge 0)) \ \big|\ w \in \mathbb{R}^d, \ w \ne 0 \big\}, \tag{14}
\]
where $r = \mathrm{rank}(X)$ will be used later to bound the number of such arrangements. Based on the analysis of the dual constraints, we can derive a convex SDP as a relaxed dual problem.

Proposition 4 (Relaxed dual problem) The relaxed dual problem is the following SDP:
\[
\begin{aligned}
\max_{\Lambda, \{r^{(j,-)}, r^{(j,+)}\}_{j=1}^p} \quad & -\frac{1}{2} \|\Lambda + Y\|_F^2, \\
\text{s.t.} \quad & \tilde{A}_j(\Lambda) + \tilde{B}_j + \sum_{n=0}^N r^{(j,-)}_n H^{(j)}_n + \bar\beta\, e_{d+1} e_{d+1}^T \succeq 0, \quad r^{(j,-)} \ge 0, \ j \in [p], \\
& -\tilde{A}_j(\Lambda) - \tilde{B}_j + \sum_{n=0}^N r^{(j,+)}_n H^{(j)}_n + \bar\beta\, e_{d+1} e_{d+1}^T \succeq 0, \quad r^{(j,+)} \ge 0, \ j \in [p],
\end{aligned} \tag{15}
\]
where we denote $[p] = \{1, \ldots, p\}$. For $j \in [p]$, we define
\[
A_j(\Lambda) = -\Lambda^T D_j X - X^T D_j \Lambda, \quad B_j = 2\,\mathrm{tr}(D_j)\, I_d, \quad
\tilde{A}_j(\Lambda) = \begin{bmatrix} A_j(\Lambda) & 0 \\ 0 & 0 \end{bmatrix}, \quad
\tilde{B}_j = \begin{bmatrix} B_j & 0 \\ 0 & 0 \end{bmatrix},
\]
\[
H^{(j)}_0 = \begin{bmatrix} I_d & 0 \\ 0 & -1 \end{bmatrix}, \qquad
H^{(j)}_n = \begin{bmatrix} 0 & (1 - 2(D_j)_{nn})\, x_n \\ (1 - 2(D_j)_{nn})\, x_n^T & 0 \end{bmatrix}, \quad n \in [N].
\]
The vector $e_{d+1} \in \mathbb{R}^{d+1}$ satisfies $(e_{d+1})_i = 0$ for $i \in [d]$ and $(e_{d+1})_{d+1} = 1$. The optimal value of (15) gives a lower bound on the dual problem (13), and hence on the primal problem (12).

The relaxed bi-dual problem, derived below as the dual of (15), provides insight into approximating the primal problem via convex optimization. As an equivalent formulation of the convex relaxed dual problem (15), it can be viewed as a convex relaxation of the primal problem (12).

Proposition 5 (Relaxed bi-dual problem) The dual of the relaxed dual problem (15) is
\[
\begin{aligned}
\min_{Z, \{(S^{(j,+)}, S^{(j,-)})\}_{j=1}^p} \quad & \frac{1}{2} \|Z + Y\|_F^2 - \frac{1}{2} \|Y\|_F^2
+ \sum_{j=1}^p \mathrm{tr}\big( \tilde{B}_j (S^{(j,+)} - S^{(j,-)}) \big)
+ \bar\beta \sum_{j=1}^p \mathrm{tr}\big( (S^{(j,+)} + S^{(j,-)})\, e_{d+1} e_{d+1}^T \big), \\
\text{s.t.} \quad & Z = \sum_{j=1}^p \tilde{A}^*_j \big( S^{(j,-)} - S^{(j,+)} \big), \\
& \mathrm{tr}(S^{(j,-)} H^{(j)}_n) \le 0, \quad \mathrm{tr}(S^{(j,+)} H^{(j)}_n) \le 0, \quad n = 0, \ldots, N, \ j \in [p].
\end{aligned} \tag{16}
\]
Here $\tilde{A}^*_j$ is the adjoint operator of the linear operator $\tilde{A}_j$. As (15) is a convex problem and Slater's condition is satisfied, the optimal values of (15) and (16) are the same.

The bi-dual problem (16) is closely related to the primal problem (12). Indeed, any feasible solution of the primal problem (11) can be mapped to a feasible solution of (16). We note that this mapping does not go in both directions unless the two problems are equivalent.

Theorem 1 Suppose that $(Z, W, \alpha)$ is feasible for the primal problem (12). Then, there exist matrices $\{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p$ constructed from $(W, \alpha)$ such that $(Z, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$ is feasible for the relaxed bi-dual problem (16). Moreover, the objective value of the relaxed bi-dual problem (16) at $(Z, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$ equals the objective value of the primal problem (12) at $(Z, W, \alpha)$.

Let $J(Z, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$ denote the objective value of the relaxed bi-dual problem (16) at a feasible solution $(Z, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$. Let $(Z^*, W^*, \alpha^*)$ denote a globally optimal solution of the primal problem (12). By Theorem 1, there exist matrices $\{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p$ such that $(Z^*, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$ is feasible for the relaxed bi-dual problem (16) and $J(Z^*, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$ equals the objective value of (12) at its global minimum $(Z^*, W^*, \alpha^*)$. On the other hand, let $(\hat{Z}^*, \{\hat{S}^{(j,+)}, \hat{S}^{(j,-)}\}_{j=1}^p)$ denote an optimal solution of the relaxed bi-dual problem (16).
From the optimality of $(\hat{Z}^*, \{\hat{S}^{(j,+)}, \hat{S}^{(j,-)}\}_{j=1}^p)$, we have $J(\hat{Z}^*, \{\hat{S}^{(j,+)}, \hat{S}^{(j,-)}\}_{j=1}^p) \le J(Z^*, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$. Note that at $(Z^*, W^*, \alpha^*)$ we obtain the optimal approximation of $\nabla \log \rho - \nabla \log \pi$ at $x_1, \ldots, x_N$ in the family of two-layer squared-ReLU networks (7). The relaxed bi-dual problem (16) achieves an objective value at $(\hat{Z}^*, \{\hat{S}^{(j,+)}, \hat{S}^{(j,-)}\}_{j=1}^p)$ that is no larger than its value at $(Z^*, \{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p)$. Therefore, we can view $\hat{Z}^*$ as an optimal approximation of $\nabla \log \rho - \nabla \log \pi$ evaluated at $x_1, \ldots, x_N$ in a broader function family that includes the two-layer squared-ReLU neural networks. From the derivation of the relaxed bi-dual problem, we have the relation $\hat{Z}^* = -\Lambda^* - Y$, where $(\Lambda^*, \{r^{(j,+)}, r^{(j,-)}\}_{j=1}^p)$ is optimal for the relaxed dual problem (15) and $(\hat{Z}^*, \{\hat{S}^{(j,+)}, \hat{S}^{(j,-)}\}_{j=1}^p)$ is optimal for the relaxed bi-dual problem (16). Therefore, after solving for $\Lambda^*$ from the relaxed dual problem (15), we can use $-\Lambda^* - Y$ as the approximation of $\nabla \log \rho - \nabla \log \pi$ evaluated at $x_1, \ldots, x_N$.

Remark 4 We note that solving the proposed convex optimization problem (15) yields the approximation of the Wasserstein gradient direction. Compared to two-layer squared-ReLU networks, it induces a broader class of functions represented by $\{S^{(j,+)}, S^{(j,-)}\}_{j=1}^p$, which contains more variables than the neural network parameterization.
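To make the construction in Proposition 4 concrete, the following is a minimal CVXPY sketch that assembles $\tilde{A}_j(\Lambda)$, $\tilde{B}_j$, and $H^{(j)}_n$ for a subsampled set of arrangement patterns (anticipating Section 3.2) and solves the relaxed dual SDP. The function name, the SCS solver choice, and the subsampling scheme are illustrative assumptions rather than the paper's exact implementation.

```python
import numpy as np
import cvxpy as cp

def solve_relaxed_dual(X, Y, beta, n_arrangements=20, seed=0):
    """Sketch of the relaxed dual SDP (15) with subsampled hyperplane arrangements.
    Returns Lambda; -Lambda - Y approximates grad log rho - grad log pi at the particles.
    Lambda is None if the SDP is infeasible (beta too small)."""
    N, d = X.shape
    rng = np.random.default_rng(seed)
    # Subsample arrangement patterns D_j = diag(1[X u >= 0]) from random directions.
    D_list, seen = [], set()
    for _ in range(n_arrangements):
        u = rng.normal(size=d)
        pattern = tuple((X @ u >= 0).astype(int))
        if pattern not in seen:
            seen.add(pattern)
            D_list.append(np.diag(np.array(pattern, dtype=float)))

    Lam = cp.Variable((N, d))
    e = np.zeros((d + 1, 1)); e[d, 0] = 1.0
    constraints = []
    for D in D_list:
        diag_D = np.diag(D)
        A = -Lam.T @ D @ X - X.T @ D @ Lam                    # A_j(Lambda), d x d
        A_t = cp.bmat([[A, np.zeros((d, 1))], [np.zeros((1, d)), np.zeros((1, 1))]])
        B_t = np.zeros((d + 1, d + 1)); B_t[:d, :d] = 2.0 * np.trace(D) * np.eye(d)
        # H_0 (norm constraint) and H_n, n = 1..N (arrangement pattern constraints)
        H = [np.zeros((d + 1, d + 1)) for _ in range(N + 1)]
        H[0][:d, :d] = np.eye(d); H[0][d, d] = -1.0
        for n in range(N):
            s = 1.0 - 2.0 * diag_D[n]
            H[n + 1][:d, d] = s * X[n]; H[n + 1][d, :d] = s * X[n]
        for sign in (+1.0, -1.0):
            r = cp.Variable(N + 1, nonneg=True)
            M = sign * (A_t + B_t) + beta * (e @ e.T)
            M = M + sum(r[n] * H[n] for n in range(N + 1))
            constraints += [M >> 0]                           # PSD constraint
    prob = cp.Problem(cp.Maximize(-0.5 * cp.sum_squares(Lam + Y)), constraints)
    prob.solve(solver=cp.SCS)
    return Lam.value
```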

3.2. PRACTICAL IMPLEMENTATION

Although the number $p$ of all possible hyper-plane arrangements is upper bounded by $2r \big( e(N-1)/r \big)^r$ with $r = \mathrm{rank}(X)$, it is computationally costly to enumerate all $p$ matrices $D_1, \ldots, D_p$ to represent the constraints in the relaxed dual problem (15). In practice, we first sample $M$ i.i.d. random vectors $u_1, \ldots, u_M \sim \mathcal{N}(0, I_d)$ and generate the subset
\[
\hat{\mathcal{S}} = \big\{ \mathrm{diag}(\mathbb{I}(X u_j \ge 0)) \ \big|\ j \in [M] \big\}
\]
of $\mathcal{S}$. Then, we solve the randomly subsampled version of the relaxed dual problem based on the subset $\hat{\mathcal{S}}$ and obtain the solution $\Lambda$. Here $-\Lambda - Y$ is used as the direction to update the particle system $X$. If the regularization parameter is too large, then we will have $-\Lambda - Y = 0$, which leaves the particle system unchanged. Therefore, to ensure that the regularization parameter is not too large, we decay it by a factor $\gamma_1 \in (0, 1)$; a similar strategy appears in (Ergen et al., 2021). On the other hand, if the regularization parameter is too small, making the relaxed dual problem (15) infeasible, we increase it by multiplying by $\gamma_2^{-1}$, where $\gamma_2 \in (0, 1)$. A detailed explanation of the adjustment of the regularization parameter can be found in Appendix D. The overall algorithm is summarized in Algorithm 1.

Algorithm 1 Convex neural Wasserstein descent

Require: initial positions $\{x_0^n\}_{n=1}^N$, step size $\alpha_l$, initial regularization parameter $\hat\beta_0$, $\gamma_1, \gamma_2 \in (0, 1)$.
1: while not converged do
2:   Form $X_l$ and $Y_l$ based on $\{x_l^n\}_{n=1}^N$ and $\{\nabla \log \pi(x_l^n)\}_{n=1}^N$.
3:   Solve $\Lambda_l$ from the relaxed dual problem (15) with regularization parameter $\hat\beta_l$.
4:   if the relaxed dual problem with $\hat\beta_l$ is infeasible then
5:     Set $X_{l+1} = X_l$ and set $\hat\beta_{l+1} = \gamma_2^{-1} \hat\beta_l$.
6:   else
7:     Update $X_{l+1} = X_l + \alpha_l (\Lambda_l + Y_l)$ and set $\hat\beta_{l+1} = \gamma_1 \hat\beta_l$.
8:   end if
9: end while

Applying a standard interior-point method (Boyd et al., 2004) to the relaxed dual problem leads to a computational time of $O\big( (\max\{N, d^2\}\, p)^6 \big)$. For high-dimensional problems, i.e., when $d$ is large, the cost of solving (15) can be large. In this case, we apply dimension-reduction techniques (Zahm et al., 2018; Chen & Ghattas, 2020; Wang et al., 2021a) to reduce the parameter dimension $d$ to a data-informed intrinsic dimension $\tilde{d}$, which is often very low, i.e., $\tilde{d} \ll d$; this can dramatically decrease the computational time above.
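The outer loop of Algorithm 1 can be summarized by the following minimal sketch; the routine `solve_relaxed_dual(X, Y, beta)` (returning None on infeasibility, as in the earlier illustrative sketch) and the default parameter values are assumptions rather than the exact implementation.

```python
import numpy as np

def convex_neural_wasserstein_descent(X0, grad_log_pi, solve_relaxed_dual,
                                      n_iters=100, step_size=1e-3,
                                      beta0=1.0, gamma1=0.95, gamma2=0.95**10):
    """Sketch of Algorithm 1 (convex neural Wasserstein descent)."""
    X, beta = X0.copy(), beta0
    for _ in range(n_iters):
        Y = grad_log_pi(X)                  # rows hold grad log pi(x_n)
        Lam = solve_relaxed_dual(X, Y, beta)
        if Lam is None:                     # infeasible SDP: beta too small
            beta /= gamma2                  # increase the regularization parameter
        else:
            X = X + step_size * (Lam + Y)   # particle update along -(-Lam - Y)
            beta *= gamma1                  # decay the regularization parameter
    return X
```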

4. NUMERICAL EXPERIMENTS

In this section, we present numerical results comparing WGD approximated by neural networks (WGD-NN) and WGD approximated using the convex optimization formulation of neural networks (WGD-cvxNN). The performance of the compared methods is assessed by the goodness-of-fit of the samples to the posterior. For WGD-NN, in each iteration the particle system is updated using (3) with a function $\Phi$ represented by a two-layer squared-ReLU neural network, whose parameters are obtained by directly solving the nonconvex optimization problem (10). For high-dimensional problems, we apply the dimension-reduction technique and compare the projected versions (pWGD-NN and pWGD-cvxNN). We note that although the cost of solving the relaxed dual problem (15) with standard convex optimization solvers in WGD-cvxNN can be higher than that of direct neural network training in WGD-NN, this difference is negligible when the overall optimization is dominated by the likelihood evaluation, i.e., when the model (e.g., a PDE) is expensive to solve. In such cases WGD-cvxNN and WGD-NN have similar computational complexity while WGD-cvxNN achieves better performance. We use the standard convex optimization solver CVXPY (Diamond & Boyd, 2016) with the MOSEK (ApS, 2019) inner solver. Applying randomized SDP solvers (Yurtsever et al., 2021), randomized second-order methods (Pilanci & Wainwright, 2017; Lacotte et al., 2021), or advanced SDP solvers (Zhao et al., 2010; Yang et al., 2015; Wang et al., 2021b) for large-scale problems could further improve the computation time. Moreover, the induced SDPs have a specific structure with many similar constraints, so solving the SDP (15) could be accelerated by designing a specialized convex optimization solver, which is left for future work.

4.1. A TOY EXAMPLE

We test the performance of WGD on a bimodal 2-dimensional double-banana posterior distribution introduced in (Detommaso et al., 2018). We first generate 300 posterior samples by a Stein variational Newton (SVN) method (Detommaso et al., 2018) as the reference, as shown in Figure 1. We observe that the samples produced by WGD-cvxNN achieve a much smaller MMD with respect to the reference SVN samples than those of WGD-NN, which is consistent with the results shown in Figure 1.

4.2. PDE-CONSTRAINED NONLINEAR BAYESIAN INFERENCE

In this experiment, we consider a nonlinear Bayesian inference problem constrained by the PDE
\[
v + e^{x} \nabla u = 0 \ \text{in } D, \qquad \nabla \cdot v = h \ \text{in } D,
\]
where $u$ is the pressure, $v$ is the velocity, $h$ is a force term, and $e^{x}$ is a random (permeability) field equipped with a Gaussian prior $x \sim \mathcal{N}(x_0, \mathcal{C})$ with covariance operator $\mathcal{C} = (-\delta \Delta + \gamma I)^{-\alpha}$, where we set $\delta = 0.1$, $\gamma = 1$, $\alpha = 2$, and $x_0 = 0$. This problem is widely used in many areas, for instance estimating permeability in groundwater flow, thermal conductivity in material science, or electrical impedance in medical imaging. We impose Dirichlet boundary conditions $u = 1$ on the top boundary and $u = 0$ on the bottom boundary, and homogeneous Neumann boundary conditions for $u$ on the left and right boundaries. We use a finite element method with piecewise linear elements for the discretization of the problem, resulting in 81 dimensions for the discrete parameter. The data are generated as pointwise observations of the pressure field at 49 points equidistantly distributed in $(0, 1)^2$, corrupted with additive 5% Gaussian noise. We use a DILI-MCMC algorithm (Cui et al., 2016) with 10000 effective samples to compute the sample mean and sample variance, which are used as the reference values to assess the goodness of the samples. We run pWGD-cvxNN and pWGD-NN with 64 samples for ten trials with step size $\alpha_l = 10^{-3}$, where we set $\beta = 10$, $\gamma_1 = 0.95$, and $\gamma_2 = 0.95^{10}$ for both methods. The RMSE of the sample mean and sample variance are shown in Figure 3 for the two methods at each iteration. We observe that pWGD-cvxNN achieves smaller errors for both the sample mean and the sample variance than pWGD-NN at each iteration. Moreover, pWGD-cvxNN exhibits much smaller variation of the sample mean and sample variance across the ten trials than pWGD-NN. Furthermore, by an effective reduction of the parameter dimension from 81 to a data-informed dimension of 20 in pWGD-cvxNN, as used and analyzed in (Zahm et al., 2018; Chen & Ghattas, 2020; Wang et al., 2021a), the time for solving the SDP is significantly reduced from about 800 seconds on average to less than 1 second (about 0.7 seconds on average), making pWGD-cvxNN computationally efficient.

4.3. BAYESIAN INFERENCE FOR COVID-19

In this experiment, we use Bayesian inference to learn the dynamics of the transmission and severity of COVID-19 from recorded data for New York State. We use the same model, parameters, and data as in Chen & Ghattas (2020). More specifically, we use a compartmental model for the transmission and outcome of COVID-19. We take the number of hospitalized cases as the observation data to infer a social distancing parameter, a time-dependent stochastic process equipped with a Tanh-Gaussian prior to model the transmission reduction effect of social distancing, which has 96 dimensions after discretization. We use the projected Stein variational gradient descent (pSVGD) method (Chen & Ghattas, 2020) as the reference to evaluate the goodness of the samples. We run pWGD-cvxNN and pWGD-NN using 64 samples for 100 iterations with step size $\alpha_l = 10^{-3}$, where we set $\beta = 10$, $\gamma_1 = 0.95$, and $\gamma_2 = 0.95^{10}$ for both methods as in the last example. From Figure 4 we observe that pWGD-cvxNN produces results that are more consistent with the reference pSVGD results than pWGD-NN, for both the sample mean and the 90% credible interval, in both the inference of the social distancing parameter and the prediction of the hospitalized cases.

5. CONCLUSION

In the context of Bayesian inference, we approximate the Wasserstein gradient direction by the gradient of functions in the family of two-layer neural networks. We propose a convex SDP relaxation of the dual of the variational primal problem, which can be solved efficiently using convex optimization methods instead of directly training the neural network via nonconvex optimization. In particular, we establish that the gradient obtained by the new formulation and convex optimization is at least as good as the one approximated by functions in the family of two-layer neural networks, which is demonstrated by various numerical experiments. By stacking the two-layer neural networks from each step together, our proposed method formulates a deep neural network that learns the transportation map from the prior to the posterior. In future studies, specialized solvers for structured SDPs, including the relaxed dual problem, could lead to drastic accelerations of the proposed method, which is of central importance for the practical application of our algorithms to real-world problems. We also expect to extend our convex optimization formulation of neural networks to the calculation/approximation of generalized Wasserstein flows, and to apply deep neural networks to the approximation of Wasserstein gradient flows based on recent works on convex optimization formulations of deep neural networks (Wang et al., 2021c; Ergen & Pilanci, 2021a;b).

D CHOICE OF THE REGULARIZATION PARAMETER

Consider the SDP of minimizing $\beta$ subject to the constraints in (15), where the variables are $\beta$, $\Lambda$, and $\{r^{(j,+)}, r^{(j,-)}\}_{j=1}^p$. Let $\hat\beta_1$ be the optimal value of this problem. Then, only for $\beta \ge \hat\beta_1$ does there exist $\Lambda \in \mathbb{R}^{N \times d}$ satisfying the constraints in (15); in other words, only then is the relaxed dual problem (15) feasible. We also note that $\hat\beta_1$ depends only on the samples $X$ and not on the values of $\nabla \log \pi$ evaluated at $x_1, \ldots, x_N$. On the other hand, consider the following SDP:
\[
\begin{aligned}
\min \quad & \beta, \\
\text{s.t.} \quad & \tilde{A}_j(Y) + \tilde{B}_j + \sum_{n=0}^N r^{(j,-)}_n H^{(j)}_n + \beta\, e_{d+1} e_{d+1}^T \succeq 0, \qquad
-\tilde{A}_j(Y) - \tilde{B}_j + \sum_{n=0}^N r^{(j,+)}_n H^{(j)}_n + \beta\, e_{d+1} e_{d+1}^T \succeq 0, \\
& r^{(j,-)} \ge 0, \quad r^{(j,+)} \ge 0, \quad j \in [p],
\end{aligned}
\]
where the variables are $\beta$ and $\{r^{(j,+)}, r^{(j,-)}\}_{j=1}^p$. Let $\hat\beta_2$ be the optimal value of this problem. For $\beta \ge \hat\beta_2$, as $Y$ is feasible for the constraints in (15), the optimal value of the relaxed dual problem (15) is 0. In short, only when $\beta \in [\hat\beta_1, \hat\beta_2]$ is the variational problem (15) non-trivial. To ensure that solving the relaxed dual problem (15) gives a good approximation of the Wasserstein gradient direction, we should avoid choosing $\beta$ either too small or too large.
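As a practical alternative to solving the two SDPs above exactly, the lower feasibility threshold can also be estimated numerically by bisection on the feasibility of the (subsampled) relaxed dual problem. A minimal sketch, assuming a hypothetical routine `solve_relaxed_dual(X, Y, beta)` that returns None when the SDP is infeasible (as in the earlier sketch):

```python
def estimate_beta_lower_threshold(X, Y, solve_relaxed_dual, lo=1e-3, hi=1e3, iters=30):
    """Estimate the feasibility threshold of the (subsampled) relaxed dual problem by
    bisection; the threshold depends only on X, but Y is passed through to the solver."""
    for _ in range(iters):
        mid = (lo * hi) ** 0.5            # geometric bisection over several decades
        feasible = solve_relaxed_dual(X, Y, mid) is not None
        lo, hi = (lo, mid) if feasible else (mid, hi)
    return hi                              # smallest beta found to be feasible
```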

E PROOFS

E.1 PROOF OF PROPOSITION 1

PROOF We first note that
\[
\frac{1}{2} \int \|\nabla \Phi - \nabla \log \rho + \nabla \log \pi\|_2^2 \, \rho\,dx
= \frac{1}{2} \int \|\nabla \Phi\|_2^2 \, \rho\,dx
+ \int \langle \nabla \log \pi - \nabla \log \rho, \nabla \Phi \rangle \, \rho\,dx
+ \frac{1}{2} \int \|\nabla \log \rho - \nabla \log \pi\|_2^2 \, \rho\,dx.
\]
We notice that the term $-\int \langle \nabla \log \rho, \nabla \Phi \rangle \, \rho\,dx = -\int \langle \nabla \rho, \nabla \Phi \rangle \, dx = \int \Delta \Phi \, \rho\,dx$ by integration by parts, and the last term does not depend on $\Phi$.

E.2 PROOF OF PROPOSITION 2

PROOF Suppose that $\hat{w}_i = \beta_i^{-1} w_i$ and $\hat\alpha_i = \beta_i^2 \alpha_i$, where $\beta_i > 0$ is a scaling parameter for $i \in [m]$. Let $\hat\theta = \{(\hat{w}_i, \hat\alpha_i)\}_{i=1}^m$. We note that
\[
\hat\alpha_i \hat{w}_i\, \psi'(\hat{w}_i^T x_n) = \beta_i \alpha_i w_i\, \psi'(\beta_i^{-1} w_i^T x_n) = \alpha_i w_i\, \psi'(w_i^T x_n),
\]
and
\[
\hat\alpha_i \|\hat{w}_i\|_2^2\, \psi''(\hat{w}_i^T x_n) = \alpha_i \|w_i\|_2^2\, \psi''(w_i^T x_n).
\]
This implies that $\Phi_{\hat\theta}(x) = \Phi_\theta(x)$, $\nabla \Phi_{\hat\theta}(x) = \nabla \Phi_\theta(x)$, and $\Delta \Phi_{\hat\theta}(x) = \Delta \Phi_\theta(x)$. For the regularization term $R(\theta)$, we note that
\[
\|\hat{w}_i\|_2^3 + |\hat\alpha_i|^3
= \beta_i^6 |\alpha_i|^3 + \beta_i^{-3} \|w_i\|_2^3
= \beta_i^6 |\alpha_i|^3 + \frac{1}{2} \beta_i^{-3} \|w_i\|_2^3 + \frac{1}{2} \beta_i^{-3} \|w_i\|_2^3
\ge 3 \cdot 2^{-2/3}\, \|w_i\|_2^2\, |\alpha_i|,
\]
by the AM-GM inequality, with equality at the optimal scaling parameter
\[
\beta_i = 2^{-1/9} \big( \|w_i\|_2 / |\alpha_i| \big)^{1/3}. \tag{28}
\]

For fixed $W$, the constraints on $Z$ and $\alpha$ are linear and strong duality holds. Thus, we can exchange the order of $\min_{Z, \alpha}$ and $\max_{\Lambda}$. The resulting dual constraint must hold for all $w \in \mathbb{R}^d$ satisfying $\|w\|_2 \le 1$ and $(2D_j - I)Xw \ge 0$. This is equivalent to saying that, for all $j \in [p]$,
\[
\min_w \ \big\{\, 2\,\mathrm{tr}(D_j)\,\|w\|_2^2 - 2 w^T \Lambda^T D_j X w \ :\ \|w\|_2 \le 1,\ (2D_j - I)Xw \ge 0 \,\big\} \ \ge\ -\bar\beta, \tag{32}
\]
\[
\max_w \ \big\{\, 2\,\mathrm{tr}(D_j)\,\|w\|_2^2 - 2 w^T \Lambda^T D_j X w \ :\ \|w\|_2 \le 1,\ (2D_j - I)Xw \ge 0 \,\big\} \ \le\ \bar\beta.
\]



Here $\mathbb{I}(s) = 1$ if the statement $s$ is true and $\mathbb{I}(s) = 0$ otherwise. Let $p = |\mathcal{S}|$ be the cardinality of $\mathcal{S}$, and write $\mathcal{S} = \{D_1, \ldots, D_p\}$. According to (Cover, 1965), we have the upper bound
\[
p \le 2r \left( \frac{e(N-1)}{r} \right)^r,
\]
where $r = \mathrm{rank}(X)$.

We evaluate the performance of WGD-NN and WGD-cvxNN by calculating the maximum mean discrepancy (MMD) between their samples in each iteration and the reference samples. In the comparison, we use $N = 50$ samples and run for 100 iterations with step size $\alpha_l = 10^{-3}$. For WGD-cvxNN, we set $\beta = 1$, $\gamma_1 = 0.95$, and $\gamma_2 = 0.95^{10}$. For WGD-NN, we use $m = 200$ neurons and optimize the regularized training problem (10) using all samples with the Adam optimizer (Kingma & Ba, 2014) with learning rate $10^{-3}$ for 200 sub-iterations. We also set the regularization parameter $\beta = 1$ and decrease it by a factor of 0.95 in each iteration; we find this parameter setup to be more suitable. The posterior density and the sample distributions by WGD-cvxNN and WGD-NN at the final step of 100 iterations are shown in Figure 1. It can be observed that WGD-cvxNN provides more representative samples of the posterior density than WGD-NN. In Figure 2, we plot the MMD of the samples by WGD-cvxNN and WGD-NN compared to the reference SVN samples at each iteration.
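The MMD comparison above uses a kernel two-sample statistic. The following is a minimal sketch of a biased (V-statistic) RBF-kernel MMD² estimate between a sample set and the reference samples; the Gaussian kernel and the bandwidth value are illustrative assumptions, since the kernel settings are not specified here.

```python
import numpy as np

def mmd2_rbf(X, Y, bandwidth=1.0):
    """Biased (V-statistic) estimate of MMD^2 between sample sets X and Y
    with a Gaussian kernel; the bandwidth is an assumed setting."""
    def gram(A, B):
        d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
        return np.exp(-d2 / (2 * bandwidth**2))
    return gram(X, X).mean() + gram(Y, Y).mean() - 2 * gram(X, Y).mean()
```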

Figure 1: Posterior density and sample distributions by WGD-cvxNN and WGD-NN at the final step of 100 iterations, compared to the reference SVN samples (right).

Figure 2: MMD of WGD-cvxNN and WGD-NN samples compared to the reference SVN samples.

Figure 3: Ten trials and the RMSE of the sample mean (top) and sample variance (bottom) by pWGD-NN and pWGD-cvxNN at different iterations. Nonlinear inference with PDE constraint.

Figure 4: Comparison of pWGD-cvxNN and pWGD-NN to the reference by pSVGD for Bayesian inference of the social distancing parameter (left) from the data of the hospitalized cases (right) with sample mean and 90% credible interval.

Figure 5: Ten trials and the RMSE of the sample mean (top) and sample variance (bottom) by pWGD-NN and pWGD-cvxNN at different iterations. Linear inference problem.

As the scaling operation does not change $\|w_i\|_2^2 |\alpha_i|$, we can simply let $\|w_i\|_2 = 1$. Thus, the regularization term $\frac{\beta}{2} R(\theta)$ becomes $\frac{\bar\beta}{N} \sum_{i=1}^m \|u_i\|_1$. This completes the proof.

E.3 PROOF OF PROPOSITION 3

PROOF The dual constraint for each neuron takes the form
\[
\sum_{n=1}^N \Big( \|w_i\|_2^2\, \psi''(w_i^T x_n) - y_n^T w_i\, \psi'(x_n^T w_i) \Big) \le \bar\beta, \quad i \in [m]. \tag{29}
\]
By exchanging the order of min and max, we can derive the dual problem with the constraint
\[
\sum_{n=1}^N \Big( \|w\|_2^2\, \psi''(w^T x_n) - y_n^T w\, \psi'(x_n^T w) \Big) \le \bar\beta \quad \text{for all } \|w\|_2 \le 1. \tag{30}
\]
This completes the proof.

E.4 PROOF OF PROPOSITION 4

PROOF Based on the hyper-plane arrangements $D_1, \ldots, D_p$, the dual constraint is equivalent to requiring that, for all $j \in [p]$,
\[
2\,\mathrm{tr}(D_j)\,\|w\|_2^2 - 2 w^T \Lambda^T D_j X w \le \bar\beta \tag{31}
\]
holds for all $w$ with $\|w\|_2 \le 1$ and $(2D_j - I)Xw \ge 0$.

Yifei Wang, Yixuan Hua, Emmanuel Candès, and Mert Pilanci. Overparameterized ReLU neural networks learn the simplest models: Neural isometry and exact recovery. arXiv preprint arXiv:2209.15265, 2022.

By restricting the domain $C^\infty(\mathbb{R}^d)$ to $\mathcal{H}$, we complete the proof.

APPENDIX

A CODES FOR NUMERICAL EXPERIMENT

All codes for the numerical experiment can be found in https://github.com/ai-submit/OptimalWasserstein.

B COMPARISON WITH PREVIOUS WORKS ON CONVEX OPTIMIZATION FORMULATION OF NEURAL NETWORKS

Previous works on convex optimization formulations of neural networks mainly focus on the supervised learning problem of two-layer neural networks with convex loss functions (e.g., squared loss, logistic loss). Our work utilizes a similar convex analytic framework to solve the variational problem of approximating the Wasserstein gradient direction, which is different from supervised learning. The convex optimization approach is related to the idea of infinite-width neural networks modeled as probability measures. The dual problem itself is equivalent to the convex dual problem when the neural network in the primal problem has infinitely many neurons. However, the convex optimization approach handles networks of arbitrary width that are able to learn useful representations, whereas the infinite-width limit is quite restrictive (essentially limited to kernel methods).

C.1 PDE-CONSTRAINED LINEAR BAYESIAN INFERENCE

In this experiment, we consider a linear Bayesian inference problem constrained by a partial differential equation (PDE) model for contaminant diffusion in environmental engineering in the domain $D = (0, 1)$, where $x$ is a contaminant source field parameter in $D$, $u$ is the contaminant concentration, which we can observe at some locations, and $\kappa$ and $\nu$ are diffusion and reaction coefficients. For simplicity, we set $\kappa = \nu = 1$, $u(0) = u(1) = 0$, and consider 15 pointwise observations of $u$ with 1% noise, equidistantly distributed in $D$. We consider a Gaussian prior distribution $x \sim \mathcal{N}(0, \mathcal{C})$ with covariance given by a differential operator $\mathcal{C} = (-\delta \Delta + \gamma I)^{-\alpha}$ with $\delta, \gamma, \alpha > 0$ representing the correlation length and variance, which is commonly used in geoscience. We set $\delta = 0.1$, $\gamma = 1$, $\alpha = 1$. In this linear setting, the posterior is Gaussian with the mean and covariance given analytically, which are used as the reference to assess the sample goodness. We solve the forward model by a finite element method with piecewise linear elements on a uniform mesh of size $2^k$, $k \ge 1$. We project this high-dimensional parameter to data-informed low dimensions as in Wang et al. (2021a) to alleviate the curse of dimensionality when applying WGD-cvxNN and WGD-NN, which we call pWGD-cvxNN and pWGD-NN, respectively. For $k = 4$ we have 17 dimensions for the discrete parameter and 4 dimensions after projection. We run pWGD-cvxNN and pWGD-NN using 16 samples for 200 iterations with $\alpha_l = 10^{-3}$, $\beta = 5$, $\gamma_1 = 0.95$, and $\gamma_2 = 0.95^{10}$ for both methods. We use $m = 200$ neurons for pWGD-NN and train it with the Adam optimizer for 200 sub-iterations as in the first example. From Figure 5, we observe that pWGD-cvxNN achieves a better root mean squared error (RMSE) than pWGD-NN for both the sample mean and the sample variance compared to the reference.

From a convex optimization perspective, the natural idea to interpret the constraint (32) is to transform the minimization problem into a maximization problem. We can rewrite the minimization problem in (32) as a trust-region problem (33) with inequality constraints. As (33) is a convex problem, by taking the dual of (33) w.r.t. $w$, we can transform (33) into a maximization problem. However, as (33) is a trust-region problem with inequality constraints, the dual problem of (33) can be very complicated. According to (Jeyakumar & Li, 2014), the optimal value of (33) is bounded by the optimal value of the SDP (34). The constraints on $\Lambda$ in the dual problem (13) include that the optimal value of (34) is bounded from below by $-\bar\beta$. According to Lemma 1, this constraint is equivalent to the existence of $r \in \mathbb{R}^{N+1}$ and $\gamma$ satisfying the corresponding matrix inequality. As $e_{d+1} e_{d+1}^T$ is positive semi-definite, the above condition on $\Lambda$ is also equivalent to the existence of $r \in \mathbb{R}^{N+1}$ satisfying the matrix inequality in (15). Therefore, replacing the dual constraint on the maximization over $w$ by this condition, we obtain the relaxed dual problem. As its feasible domain is a subset of the feasible domain of the dual problem, the optimal value of the relaxed dual problem gives a lower bound on the optimal value of the dual problem.

E.5 PROOF OF PROPOSITION 5

PROOF Consider the Lagrangian function of (15), where we write
\[
r = \big( r^{(1,-)}, \ldots, r^{(p,-)}, r^{(1,+)}, \ldots, r^{(p,+)} \big) \in \big( \mathbb{R}^{N+1} \big)^{2p}, \qquad
S = \big( S^{(1,-)}, \ldots, S^{(p,-)}, S^{(1,+)}, \ldots, S^{(p,+)} \big) \in \big( \mathbb{S}^{d+1}_+ \big)^{2p}. \tag{47}
\]
Here we write $\mathbb{S}^{d+1}_+ = \{ S \in \mathbb{S}^{d+1} \mid S \succeq 0 \}$. By maximizing w.r.t. $\Lambda$ and $r$, we derive the bi-dual problem (16).

E.6 PROOF OF THEOREM 1

Suppose that $(Z, W, \alpha)$ is a feasible solution to (11). Let $D_{j_1}, \ldots, D_{j_k}$ be the enumeration of $\{\mathrm{diag}(\mathbb{I}(X w_i \ge 0)) \mid i \in [m]\}$, with $S^{(j,+)}$ and $S^{(j,-)}$ constructed from $(W, \alpha)$ accordingly. For $j \notin \{j_1, \ldots, j_k\}$, we simply set $S^{(j,+)} = 0$ and $S^{(j,-)} = 0$. As $\|w_i\|_2 \le 1$ and $D_{j_i} = \mathrm{diag}(\mathbb{I}(X w_i \ge 0))$, we can verify that $\mathrm{tr}(S^{(j,-)} H^{(j)}_n) \le 0$ and $\mathrm{tr}(S^{(j,+)} H^{(j)}_n) \le 0$ are satisfied for $j = j_1, \ldots, j_k$ and $n = 0, 1, \ldots, N$: for $n = 0$ this follows from the structure of $H^{(j)}_0$, and for $n = 1, \ldots, N$ it follows from the sign pattern encoded by $D_{j_i}$. Based on the above transformation, we can rewrite the bi-dual problem in the form of the primal problem (12). For $S \in \mathbb{S}^{d+1}$, evaluating the adjoint on the constructed matrices yields terms of the form $2\alpha_i (X w_i)_+ w_i^T$, and therefore
\[
\sum_{j=1}^p \tilde{A}^*_j \big( S^{(j,-)} - S^{(j,+)} \big) = 2 \sum_{i=1}^m \alpha_i (X w_i)_+ w_i^T.
\]
As the $n$-th row of $Z$ satisfies $z_n = 2 \sum_{i=1}^m \alpha_i w_i (x_n^T w_i)_+$, this implies that $Z = \sum_{j=1}^p \tilde{A}^*_j (S^{(j,-)} - S^{(j,+)})$. Hence $(Z, \{(S^{(j,-)}, S^{(j,+)})\}_{j=1}^p)$ is feasible for the relaxed bi-dual problem (16). We can also compute that the primal problem (12) at $(Z, W, \alpha)$ and the relaxed bi-dual problem (16) at $(Z, \{(S^{(j,-)}, S^{(j,+)})\}_{j=1}^p)$ have the same objective value.

