PARTICLE-BASED VARIATIONAL INFERENCE WITH PRECONDITIONED FUNCTIONAL GRADIENT FLOW

Abstract

Particle-based variational inference (VI) minimizes the KL divergence between model samples and the target posterior with gradient flow estimates. With the popularity of Stein variational gradient descent (SVGD), the focus of particle-based VI algorithms has been on the properties of functions in a reproducing kernel Hilbert space (RKHS) used to approximate the gradient flow. However, the requirement of an RKHS restricts the function class and algorithmic flexibility. This paper offers a general solution to this problem by introducing a functional regularization term that encompasses the RKHS norm as a special case. This allows us to propose a new particle-based VI algorithm: preconditioned functional gradient flow (PFG). Compared to SVGD, PFG has several advantages: a larger function class, improved scalability for large particle sizes, better adaptation to ill-conditioned distributions, and provable continuous-time convergence in KL divergence. Additionally, non-linear function classes such as neural networks can be incorporated to estimate the gradient flow. Our theory and experiments demonstrate the effectiveness of the proposed framework.

1. INTRODUCTION

Sampling from an unnormalized density is a fundamental problem in machine learning and statistics, especially for posterior sampling. Markov chain Monte Carlo (MCMC) (Welling & Teh, 2011; Hoffman et al., 2014; Chen et al., 2014) and variational inference (VI) (Ranganath et al., 2014; Jordan et al., 1999; Blei et al., 2017) are the two mainstream solutions: MCMC is asymptotically unbiased but sample-intensive; VI is computationally efficient but usually biased. Recently, particle-based VI algorithms (Liu & Wang, 2016; Detommaso et al., 2018; Liu et al., 2019) minimize the Kullback-Leibler (KL) divergence between particle samples and the posterior, combining the advantages of both MCMC and VI: (1) non-parametric flexibility and asymptotic unbiasedness; (2) sample efficiency from the interaction between particles; (3) deterministic updates. These algorithms are thus competitive in sampling tasks such as Bayesian inference (Liu & Wang, 2016; Feng et al., 2017; Detommaso et al., 2018) and probabilistic models (Wang & Liu, 2016; Pu et al., 2017).

Given a target distribution p*(x), particle-based VI aims to find g(t, x) so that, starting from X_0 ∼ p_0, the distribution p(t, x) of the following dynamics:

dX_t = g(t, X_t) dt,

converges to p*(x) as t → ∞. By the continuity equation (Jordan et al., 1998), we can capture the evolution of p(t, x) by

∂p(t, x)/∂t = -∇ · (p(t, x) g(t, x)).

To measure the "closeness" between p(t, ·) and p*, we typically adopt the KL divergence,

D_KL(t) = ∫ p(t, x) ln [p(t, x)/p*(x)] dx.

Using the chain rule and integration by parts, we have

dD_KL(t)/dt = -∫ p(t, x) [∇ · g(t, x) + g(t, x)⊤ ∇_x ln p*(x)] dx,

which captures the evolution of the KL divergence. To minimize the KL divergence, one needs to define a "gradient" to update the particles as our g(t, x).
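As a toy illustration (ours, not from the paper), the dynamics dX_t = g(t, X_t) dt can be simulated with an Euler discretization. We take a 1-D Gaussian target, where the ideal choice g = ∇ ln p* - ∇ ln p_t has a closed form under a Gaussian ansatz for p_t estimated from the particles:

```python
import numpy as np

# Toy sketch: deterministic particle flow dX_t = g(t, X_t) dt for a 1-D
# Gaussian target p* = N(mu_star, s2_star).  We use the ideal direction
# g = grad ln p* - grad ln p_t, estimating p_t by the particles'
# empirical mean/variance (a Gaussian ansatz, valid here since the
# dynamics keep a Gaussian initial distribution Gaussian).

rng = np.random.default_rng(0)
mu_star, s2_star = 3.0, 2.0

x = rng.normal(0.0, 1.0, size=2000)   # X_0 ~ N(0, 1)
dt = 0.05
for _ in range(400):
    mu_t, s2_t = x.mean(), x.var()
    g = -(x - mu_star) / s2_star + (x - mu_t) / s2_t   # grad ln p* - grad ln p_t
    x = x + dt * g                                     # Euler step of dX_t = g dt

print(x.mean(), x.var())   # approaches mu_star = 3.0 and s2_star = 2.0
```

The empirical mean and variance of the particles converge to those of p*, matching the fact that dD_KL(t)/dt < 0 along this flow.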
The most standard approach, the Wasserstein gradient (Ambrosio et al., 2005), defines a gradient for p(t, x) in the Wasserstein space, which contains probability measures with bounded second moments. In particular, for any functional L that maps a probability density p(t, x) to a non-negative scalar, we say that the particle density p(t, x) follows the Wasserstein gradient flow of L if g(t, x) is the gradient field of the L²(ℝᵈ)-functional derivative of L (Villani, 2009). For the KL divergence, the solution is ∇ ln [p*(x)/p(t, x)]. However, the computation of this deterministic and time-inhomogeneous Wasserstein gradient is non-trivial, and it is necessary to restrict the function class of g(t, x) to obtain a tractable form.

Stein variational gradient descent (SVGD) is the most popular particle-based algorithm; it provides a tractable form to update particles with a kernelized gradient flow (Chewi et al., 2020; Liu, 2017). It updates particles by minimizing the KL divergence with a functional gradient measured in an RKHS. By restricting the functional gradient to a bounded RKHS norm, it admits an explicit formulation: g(t, x) can be obtained by minimizing Eq. (3). Nonetheless, the restriction to an RKHS brings some limitations: (1) the expressive power is limited, because kernel methods are known to suffer from the curse of dimensionality (Geenens, 2011); (2) with n particles, an O(n²) computational overhead for the kernel matrix is required. Further, we identify another crucial limitation of SVGD: kernel design is highly non-trivial. Even in the simple Gaussian case, where particles start from N(0, I) and p* = N(μ*, Σ*), commonly used kernels such as the linear and RBF kernels have fundamental drawbacks in the SVGD algorithm (Example 1).

Our motivation originates from functional gradient boosting (Friedman, 2001; Nitanda & Suzuki, 2018; Johnson & Zhang, 2019): for each p(t, x), we find a proper function g(t, x) in the function class F to minimize Eq. (3).
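For reference, a minimal SVGD update (Liu & Wang, 2016) can be written in a few lines; the RBF bandwidth below uses the common median heuristic, and the explicit n × n kernel matrix is the O(n²) overhead discussed above:

```python
import numpy as np

# Minimal SVGD step (Liu & Wang, 2016) with an RBF kernel.  The n x n
# kernel matrix makes the per-step cost O(n^2) in the particle count.

def svgd_step(x, grad_log_p, step=0.1):
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]          # diff[i, j] = x_i - x_j
    sq = (diff ** 2).sum(-1)                      # pairwise squared distances
    h = np.median(sq) / np.log(n + 1) + 1e-12     # median-heuristic bandwidth
    k = np.exp(-sq / h)                           # n x n kernel matrix
    # phi(x_i) = (1/n) sum_j [k(x_j, x_i) grad ln p*(x_j) + grad_{x_j} k(x_j, x_i)]
    grad_k = (2.0 / h) * (k[:, :, None] * diff).sum(1)
    phi = (k @ grad_log_p(x) + grad_k) / n
    return x + step * phi

# Example: sample p* = N(2, 1) in one dimension, grad ln p*(x) = -(x - 2).
rng = np.random.default_rng(1)
x = rng.normal(0.0, 1.0, size=(200, 1))
for _ in range(500):
    x = svgd_step(x, lambda z: -(z - 2.0))
print(x.mean())
```

The first term in phi drives particles toward high-density regions of p*; the second (kernel-gradient) term is the repulsive force that keeps the particle set spread out.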
In this context, we design a regularizer for the functional gradient to approximate variants of the "gradient" explicitly. We propose a regularization family that penalizes the output of the particle distribution's functional gradient. For well-conditioned -∇² ln p*, we can approximate the Wasserstein gradient directly; for ill-conditioned -∇² ln p*, we can adapt our regularizer to approximate a preconditioned one. Thus, our functional gradient is an approximation to the preconditioned Wasserstein gradient. Regarding the function space, we do not restrict functions to an RKHS; instead, we can use non-linear function classes such as neural networks to obtain better approximation capacity. The flexibility of the function space can lead to a better sampling algorithm, which is supported by our empirical results.

Contributions. We present a novel particle-based VI framework that incorporates functional gradient flow with general regularizers. We leverage a special family of regularizers to approximate the preconditioned Wasserstein gradient flow, which proves to be more effective than SVGD. The functional gradient in our framework explicitly approximates the preconditioned Wasserstein gradient, making it well-suited to ill-conditioned cases and delivering provable convergence rates. Additionally, our proposed algorithm eliminates the need for the computationally expensive O(n²) kernel matrix, improving computational efficiency for larger particle sizes. Both theoretical and empirical results demonstrate the superior performance of our framework and algorithm.
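A toy sketch of the ill-conditioning issue (our own illustration, not an experiment from the paper): for a Gaussian target N(0, Σ) with Σ = diag(1, 100), the score ∇ ln p*(x) = -Σ⁻¹x contracts the two coordinates at rates 1 and 0.01, so a plain gradient flow stalls along the flat direction, while preconditioning by (-∇² ln p*)⁻¹ = Σ equalizes the rates:

```python
import numpy as np

# Ill-conditioned Gaussian target p* = N(0, Sigma), Sigma = diag(1, 100).
# Plain flow:          x <- x - dt * Sigma^{-1} x        (rates 1 and 0.01)
# Preconditioned flow: y <- y - dt * Sigma Sigma^{-1} y  (rate 1 in both)
Sigma = np.diag([1.0, 100.0])
Sigma_inv = np.diag([1.0, 0.01])

x = np.array([5.0, 5.0])   # plain gradient flow
y = x.copy()               # preconditioned flow
dt = 0.5                   # step size limited by the stiffest direction
for _ in range(100):
    x = x - dt * Sigma_inv @ x
    y = y - dt * Sigma @ (Sigma_inv @ y)

print(x)   # first coordinate ~ 0, second barely moved (~ 3)
print(y)   # both coordinates ~ 0
```

This is exactly the situation the preconditioned regularizer is designed for: the fitted functional gradient approximates the preconditioned Wasserstein gradient rather than the raw one.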

2. ANALYSIS

Notations. In this paper, we use x to denote particle samples in ℝᵈ. All distributions are assumed to be absolutely continuous w.r.t. the Lebesgue measure. The probability density function of the posterior is denoted by p*, and p(t, x) (or p_t) refers to the particle distribution at time t. For a scalar function p(t, x), ∇_x p(t, x) denotes its gradient w.r.t. x, and ∂_t p(t, x) denotes its partial derivative w.r.t. t. For a vector function g(t, x), ∇_x g(t, x), ∇_x · g(t, x), and ∇²_x g(t, x) denote its Jacobian matrix, divergence, and Hessian w.r.t. x. Without ambiguity, ∇ stands for ∇_x for conciseness. The notation ∥x∥²_H stands for x⊤Hx, ∥x∥_I is abbreviated as ∥x∥, and ∥ · ∥_{H^d} denotes the RKHS norm on ℝᵈ.

We let g(t, x) belong to a vector-valued function class F, and find the best functional gradient direction. Inspired by the gradient boosting algorithm for regression and classification problems, we approximate the gradient flow by a function g(t, x) ∈ F with a regularization term, which solves the regularized problem

min_{g ∈ F}  -E_{x∼p_t}[∇ · g(x) + g(x)⊤ ∇ ln p*(x)] + R(g),    (3)

where R(g) is the regularization term on g and the first term is the (negated) instantaneous decrease of D_KL(t) derived above.

For any positive-definite matrix, the condition number is the ratio of the maximal eigenvalue to the minimal eigenvalue. A low condition number is well-conditioned, while a high condition number is ill-conditioned.
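The fitting step above can be sketched concretely. The following is a hypothetical simplification (ours, not the paper's exact algorithm): we take F to be the *linear* class g(x) = Ax + b, so that the divergence ∇ · g = trace(A) is closed-form, use the regularizer R(g) = (λ/2) E_{p_t}∥g(x)∥², fit (A, b) by gradient descent on the inner objective, and then move the particles along the fitted g:

```python
import numpy as np

# Hypothetical sketch of one functional-gradient fit with a LINEAR class
# g(x) = A x + b (in place of a neural network), so div g = trace(A).
# Inner problem:  max_{A,b}  E_pt[div g + g(x)^T grad ln p*(x)]
#                            - (lam/2) E_pt ||g(x)||^2 ,
# whose maximizer over all functions is (1/lam)(grad ln p* - grad ln p_t).
# Target: p* = N(mu_star, I), so grad ln p*(x) = mu_star - x.

rng = np.random.default_rng(0)
n, d, lam = 1000, 2, 1.0
mu_star = np.array([1.0, -1.0])
x = rng.normal(size=(n, d))                      # particles, X_0 ~ N(0, I)

for _ in range(200):                             # outer particle updates
    score = mu_star - x                          # grad ln p*(x_i)
    Exx, Ex = x.T @ x / n, x.mean(0)             # empirical moments of p_t
    Esx = score.T @ x / n                        # E[score x^T]
    A, b = np.zeros((d, d)), np.zeros(d)
    for _ in range(200):                         # inner fit by gradient ascent
        grad_A = np.eye(d) + Esx - lam * (A @ Exx + np.outer(b, Ex))
        grad_b = score.mean(0) - lam * (A @ Ex + b)
        A, b = A + 0.2 * grad_A, b + 0.2 * grad_b
    x = x + 0.1 * (x @ A.T + b)                  # Euler step along fitted g

print(x.mean(0))   # approaches mu_star
```

Replacing the linear class with a neural network (and estimating the divergence, e.g., by automatic differentiation) recovers the flexible function spaces discussed above, without any n × n kernel matrix.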

