GFLOWNETS AND VARIATIONAL INFERENCE

Abstract

This paper builds bridges between two families of probabilistic algorithms: (hierarchical) variational inference (VI), which is typically used to model distributions over continuous spaces, and generative flow networks (GFlowNets), which have been used for distributions over discrete structures such as graphs. We demonstrate that, in certain cases, VI algorithms are equivalent to special cases of GFlowNets in the sense of equality of expected gradients of their learning objectives. We then point out the differences between the two families and show how these differences emerge experimentally. Notably, GFlowNets, which borrow ideas from reinforcement learning, are more amenable than VI to off-policy training without the cost of high gradient variance induced by importance sampling. We argue that this property of GFlowNets can provide advantages for capturing diversity in multimodal target distributions.

1. INTRODUCTION

Many probabilistic generative models produce a sample through a sequence of stochastic choices. Non-neural latent variable models (e.g., Blei et al., 2003), autoregressive models, hierarchical variational autoencoders (Sønderby et al., 2016), and diffusion models (Ho et al., 2020) can be said to rely upon a shared principle: richer distributions can be modeled by chaining together a sequence of simple actions, whose conditional distributions are easy to describe, than by generating the object in a single sampling step. When many intermediate sampled variables could generate the same object, making exact likelihood computation intractable, hierarchical models are trained with variational objectives that involve the posterior over the sampling sequence (Ranganath et al., 2016b).

This work connects variational inference (VI) methods for hierarchical models (i.e., sampling through a sequence of choices conditioned on the previous ones) with the emerging area of research on generative flow networks (GFlowNets; Bengio et al., 2021a). GFlowNets have been formulated as a reinforcement learning (RL) algorithm -- with states, actions, and rewards -- that constructs an object by a sequence of actions so as to make the marginal likelihood of producing an object proportional to its reward. While hierarchical VI is typically used for distributions over real-valued objects, GFlowNets have been successful at approximating distributions over discrete structures for which exact sampling is intractable, such as for molecule discovery (Bengio et al., 2021a), for Bayesian posteriors over causal graphs (Deleu et al., 2022), or as an amortized learned sampler for approximate maximum-likelihood training of energy-based models (Zhang et al., 2022b). Although GFlowNets appear to have different foundations (Bengio et al., 2021b) and applications than hierarchical VI algorithms, we show here that the two are closely connected.
As our main theoretical contribution, we show that special cases of variational algorithms and GFlowNets coincide in their expected gradients. In particular, hierarchical VI (Ranganath et al., 2016b) and nested VI (Zimmermann et al., 2021) are related to the trajectory balance and detailed balance objectives for GFlowNets (Malkin et al., 2022; Bengio et al., 2021b). We also point out the differences between VI and GFlowNets: notably, that GFlowNets automatically perform gradient variance reduction by estimating a marginal quantity (the partition function) that acts as a baseline, and allow off-policy learning without the need for reweighted importance sampling. Our theoretical results are accompanied by experiments that examine what similarities and differences emerge when one applies hierarchical VI algorithms to discrete problems where GFlowNets have been used before. These experiments serve two purposes. First, they supply a missing hierarchical VI baseline for problems where GFlowNets have been used in past work. The relative performance of this baseline illustrates the aforementioned similarities and differences between VI and GFlowNets. Second, the experiments demonstrate the ability of GFlowNets, not shared by hierarchical VI, to learn from off-policy distributions without introducing high gradient variance. We show that this ability to learn with exploratory off-policy sampling is beneficial in discrete probabilistic modeling tasks, especially in cases where the target distribution has many modes.

2.1. GFLOWNETS: NOTATION AND BACKGROUND

We consider the setting of Bengio et al. (2021a). We are given a pointed¹ directed acyclic graph (DAG) G = (S, A), where S is a finite set of vertices (states) and A ⊂ S × S is a set of directed edges (actions). If s → s′ is an action, we say s is a parent of s′ and s′ is a child of s. There is exactly one state with no incoming edge, called the initial state s₀ ∈ S. States with no outgoing edges are called terminating; we denote by X the set of terminating states. A complete trajectory is a sequence τ = (s₀ → … → sₙ) such that each sᵢ → sᵢ₊₁ is an action and sₙ ∈ X. We denote by T the set of complete trajectories and by x_τ the last state of a complete trajectory τ.

GFlowNets are a class of models that amortize the cost of sampling from an intractable target distribution over X by learning a functional approximation of the target distribution using its unnormalized density, or reward function, R : X → ℝ₊. While there exist different parametrizations and loss functions for GFlowNets, they all define a forward transition probability function, or forward policy, P_F(· | s), which is a distribution over the children of every state s ∈ S. The forward policy is typically parametrized by a neural network that takes a representation of s as input and produces the logits of a distribution over its children. Any forward policy P_F induces a distribution over complete trajectories τ ∈ T (also denoted by P_F), which in turn defines a marginal distribution over terminating states x ∈ X (denoted by P_F^⊤):

    P_F(τ = (s₀ → … → sₙ)) = ∏_{i=0}^{n−1} P_F(s_{i+1} | s_i)    ∀τ ∈ T,    (1)

    P_F^⊤(x) = ∑_{τ ∈ T : x_τ = x} P_F(τ)    ∀x ∈ X.    (2)

Given a forward policy P_F, terminating states x ∈ X can be sampled from P_F^⊤ by sampling trajectories τ from P_F(τ) and taking their final states x_τ. GFlowNets aim to find a forward policy P_F for which P_F^⊤(x) ∝ R(x).
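The trajectory distribution and terminating-state marginal above can be illustrated with a small sketch (our illustration, not code from the paper; the toy DAG, state names, and forward-policy probabilities are all hypothetical). It enumerates every complete trajectory of a tiny pointed DAG under a fixed forward policy P_F and sums trajectory probabilities to obtain the marginal P_F^⊤ over terminating states.

```python
# Toy pointed DAG: s0 -> {a, b}, a -> {x1, x2}, b -> {x2}; x1 and x2 terminate.
children = {"s0": ["a", "b"], "a": ["x1", "x2"], "b": ["x2"], "x1": [], "x2": []}

# Hand-picked forward policy P_F(child | state): a distribution over children.
P_F = {
    "s0": {"a": 0.6, "b": 0.4},
    "a": {"x1": 0.5, "x2": 0.5},
    "b": {"x2": 1.0},
}

def complete_trajectories(s="s0", prefix=("s0",)):
    """Enumerate all complete trajectories starting from state s."""
    if not children[s]:  # terminating state: trajectory is complete
        yield prefix
        return
    for c in children[s]:
        yield from complete_trajectories(c, prefix + (c,))

def traj_prob(tau):
    """P_F(tau): product of forward transition probabilities along tau."""
    p = 1.0
    for s, s_next in zip(tau, tau[1:]):
        p *= P_F[s][s_next]
    return p

# Marginal over terminating states: sum P_F(tau) over trajectories ending at x.
marginal = {}
for tau in complete_trajectories():
    x = tau[-1]
    marginal[x] = marginal.get(x, 0.0) + traj_prob(tau)

print(marginal)  # x2 is reachable via two trajectories: 0.6*0.5 + 0.4*1.0 = 0.7
```

The state x2 is reached by two distinct trajectories, which is exactly why the marginal is a sum and why, for realistically large DAGs, this sum is intractable to compute exactly.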
Because the sum in (2) is typically intractable to compute exactly, training objectives for GFlowNets introduce auxiliary objects into the optimization. For example, the trajectory balance objective (TB; Malkin et al., 2022) introduces an auxiliary backward policy P_B, a learned distribution P_B(· | s) over the parents of every state s ∈ S, and an estimated partition function Z, typically parametrized as exp(log Z) where log Z is the learned parameter. The TB objective for a complete trajectory τ is defined as

    L_TB(τ; P_F, P_B, Z) = ( log [ Z · P_F(τ) / (R(x_τ) · P_B(τ | x_τ)) ] )²,    where    P_B(τ | x_τ) = ∏_{(s → s′) ∈ τ} P_B(s | s′).    (3)

If L_TB is made equal to 0 for every complete trajectory τ, then P_F^⊤(x) ∝ R(x) for all x ∈ X and Z is the normalizing constant: Z = ∑_{x ∈ X} R(x). The objective (3) is minimized by sampling trajectories τ from some distribution and making gradient steps on (3) with respect to the parameters of P_F, P_B, and log Z. The choice of distribution from which τ is sampled amounts to a choice of scalarization weights for the multi-objective problem of minimizing (3) over all τ ∈ T. If τ is sampled from P_F(τ) itself (note that this is a nonstationary scalarization, since P_F changes over the course of training), we say the algorithm runs on-policy. If τ is sampled from another distribution, the algorithm runs off-policy; typical choices are to sample τ from a tempered version of P_F to encourage exploration (Bengio et al., 2021a; Deleu et al., 2022) or to sample τ from the backward policy P_B(τ | x) starting from given terminating states x (Zhang et al., 2022b). By analogy with RL nomenclature, we call the distribution that samples τ for the purpose of obtaining a stochastic gradient (e.g., the gradient of the objective L_TB in (3) for the sampled τ) the behavior policy.
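The TB objective can be made concrete with a minimal sketch (our illustration, not code from the paper's codebase; the numbers, the trajectory, and the assumption that every state has a single parent are hypothetical). Working in log space mirrors the exp(log Z) parametrization: the per-trajectory loss is the squared difference log Z + log P_F(τ) − log R(x_τ) − log P_B(τ | x_τ), which vanishes exactly when log Z matches the other three terms.

```python
import math

def tb_loss(log_Z, log_pf_steps, log_pb_steps, log_reward):
    """Trajectory balance loss for one complete trajectory, in log space:
    (log Z + sum log P_F - log R(x) - sum log P_B)^2."""
    return (log_Z + sum(log_pf_steps) - log_reward - sum(log_pb_steps)) ** 2

# Hypothetical trajectory tau = (s0 -> a -> x1) in a DAG where every state
# has exactly one parent, so each backward step has probability 1.
log_pf = [math.log(0.6), math.log(0.5)]  # P_F(a | s0), P_F(x1 | a)
log_pb = [math.log(1.0), math.log(1.0)]  # P_B(a | x1), P_B(s0 | a)
log_R = math.log(2.0)                    # reward R(x1) = 2

# The value of log Z that zeroes this trajectory's loss:
# log Z = log R(x) + log P_B - log P_F = log(2 / 0.3).
log_Z = math.log(2.0 / 0.3)
print(tb_loss(log_Z, log_pf, log_pb, log_R))  # ~0 (up to float rounding)
```

In training, log Z is not set analytically as above but learned jointly with P_F and P_B by gradient descent on this loss over sampled trajectories.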



¹ A pointed DAG is one with a designated initial state.

Code availability: https://github.com/GFNOrg/GFN_vs_HVI.

