BRIDGE THE INFERENCE GAPS OF NEURAL PROCESSES VIA EXPECTATION MAXIMIZATION

Abstract

The neural process (NP) is a family of computationally efficient models for learning distributions over functions. However, it suffers from under-fitting and shows suboptimal performance in practice. Researchers have primarily focused on incorporating diverse structural inductive biases, e.g. attention or convolution, in modeling. The inference suboptimality of NPs, and an analysis of the NP from the perspective of its optimization objective, have hardly been studied in earlier work. To address this issue, we propose a surrogate objective of the target log-likelihood of the meta dataset within the expectation maximization framework. The resulting model, referred to as the Self-normalized Importance weighted Neural Process (SI-NP), can learn a more accurate functional prior and comes with an improvement guarantee on the target log-likelihood. Experimental results show the competitive performance of SI-NP over other NP objectives and illustrate that structural inductive biases, such as attention modules, can also augment our method to achieve SOTA performance.

1. INTRODUCTION

Figure 1: Deep Latent Variable Models for Neural Processes. Here D^C and D^T respectively denote the context points used for functional prior inference and the target points used for function prediction. The global latent variable z summarizes function properties. The model involves a functional prior distribution p(z|D^C; ϑ) and a functional generative distribution p(D^T|z; ϑ). Please refer to Section (2) for detailed notation descriptions.

The combination of deep neural networks and stochastic processes provides a promising framework for modeling data points with correlations (Ghahramani, 2015). It exploits the high capacity of deep neural networks and enables uncertainty quantification for distributions over functions. As an example, consider the deep Gaussian process (Damianou & Lawrence, 2013). However, the run-time complexity of predictive distributions in Gaussian processes is cubic w.r.t. the number of predicted data points. To circumvent this, Garnelo et al. (2018a; b) developed the family of neural processes (NPs) as an alternative, which can model more flexible function distributions and capture predictive uncertainty at a lower computational cost. In this paper, we study the vanilla NP as a deep latent variable model and show the generative process in Fig. (1). In particular, let us recap the inference method used in vanilla NPs: it learns to approximate the functional posterior q_ϕ(z) ≈ p(z|D^T; ϑ) and a functional prior q_ϕ(z|D^C) ≈ p(z|D^C; ϑ), both of which are permutation invariant to the order of data points. The predictive distribution for a data point [x*, y*] can then be formulated as E_{q_ϕ(z|D^C)}[p(y*|[x*, z]; ϑ)]. While the NP provides a computationally efficient framework for modeling exchangeable stochastic processes, in practice it exhibits underfitting and fails to capture accurate uncertainty (Garnelo et al., 2018b; Kim et al., 2019).
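The predictive distribution E_{q_ϕ(z|D^C)}[p(y*|[x*, z]; ϑ)] above can be approximated by Monte Carlo: encode the context set into a Gaussian over z, sample z repeatedly, and average the decoder's likelihood of y*. The sketch below illustrates only this computation; the encoder and decoder are hypothetical toy stand-ins (simple pooling and linear maps), not the paper's networks.

```python
import numpy as np

rng = np.random.default_rng(0)
DIM_Z = 4  # hypothetical latent dimensionality; the paper does not fix this here

def encode_prior(x_ctx, y_ctx):
    """Permutation-invariant functional prior q(z | D^C): mean-pool a
    per-point feature, then map to Gaussian parameters (toy stand-ins
    for mu_theta(D^C) and Sigma_theta(D^C))."""
    feats = np.concatenate([x_ctx, y_ctx], axis=-1)  # (n, d_x + d_y)
    pooled = feats.mean(axis=0)                      # order-invariant summary
    mu = np.tile(pooled.mean(), DIM_Z)
    sigma = np.full(DIM_Z, 0.5)
    return mu, sigma

def decode(x_star, z):
    """Toy generative distribution p(y* | x*, z): Gaussian with unit variance."""
    return x_star.sum() + z.sum(), 1.0

# Monte Carlo estimate of E_{q(z|D^C)}[ p(y* | x*, z) ] for one target point.
x_ctx, y_ctx = rng.normal(size=(5, 1)), rng.normal(size=(5, 1))
x_star, y_star = np.array([0.3]), np.array([0.1])
mu, sigma = encode_prior(x_ctx, y_ctx)
z_samples = mu + sigma * rng.normal(size=(100, DIM_Z))   # z ~ q(z | D^C)
means = np.array([decode(x_star, z)[0] for z in z_samples])
densities = np.exp(-0.5 * (y_star - means) ** 2) / np.sqrt(2 * np.pi)
pred_density = densities.mean()  # average likelihood over prior samples
print(pred_density)
```

Because the context set enters only through a mean-pooled summary, shuffling the context points leaves the inferred prior, and hence the prediction, unchanged.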
To improve its generalization capability, researchers have focused much attention on finding appropriate inductive biases to incorporate in modeling, e.g. attention (Kim et al., 2019), convolutional modules (Gordon et al., 2019; Kawano et al., 2020), Bayesian mixture structures (Wang & van Hoof, 2022) or Bayesian hierarchical structures (Naderiparizi et al., 2020).

Research Motivations. Most previous work (Garnelo et al., 2018a; b; Kim et al., 2019; Gordon et al., 2019; Wang & van Hoof, 2022) ignores why the vanilla NP suffers from a performance bottleneck and what kind of functional priors vanilla NPs can represent. In particular, we point out two crucial issues that have not been sufficiently investigated in this domain: (i) understanding the inference suboptimality of vanilla NPs; (ii) quantifying statistical traits of the learned functional priors. To this end, we diagnose the vanilla NP through its optimization objective. Our primary interest is to find a tractable way to optimize NPs and to examine the statistics of functional priors learned from diverse optimization objectives.

Developed Methods. To understand the inference suboptimality of vanilla NPs, we establish connections among a collection of optimization objectives, e.g. approximate evidence lower bounds (ELBOs) and Monte Carlo estimates of log-likelihoods, in Section (3). We then formulate a tractable optimization objective within the variational expectation maximization framework and obtain the Self-normalized Importance weighted Neural Process (SI-NP) in Section (4).

Contributions.
To summarize, our primary contributions are three-fold: (i) we analyze the inherent inference suboptimality of NPs from an optimization objective perspective; (ii) we demonstrate the equivalence of conditional NPs (Garnelo et al., 2018a) and SI-NPs under a one-sample Monte Carlo estimate, which closely relates to the prior collapse in Definition (3.1); (iii) our developed SI-NPs have an improvement guarantee with respect to the log-likelihood of the meta dataset during optimization and show a significant advantage over baselines with other objectives.

2. PRELIMINARIES

General Notations. We study NPs in a meta learning setup. T defines a set of tasks, with τ a sampled task. Let D^C_τ = {(x_i, y_i)}_{i=1}^{n} and D^T_τ = {(x_i, y_i)}_{i=1}^{n+m} denote the context points for the functional prior inference and the target points for the function prediction, respectively. The latent variable z is a functional representation of a task τ with observed data points. We refer to ϑ ∈ Θ as the parameters of the deep latent variable model for NPs. In detail, ϑ consists of the encoder parameters of a functional prior p(z|D^C_τ; ϑ) and the decoder parameters of a generative distribution p(D^T_τ|z; ϑ). ϕ refers to the parameters of a variational posterior distribution q_ϕ(z) = q_ϕ(z|D^T_τ), while η refers to the parameters of a proposal distribution q_η(z) used in the following self-normalized importance sampling. Gaussian distributions with diagonal covariance matrices are the default choice for these distributions, e.g. p(z|D^C_τ; ϑ) = N(z; μ_ϑ(D^C_τ), Σ_ϑ(D^C_τ)), q_ϕ(z) = N(z; μ_ϕ(D^T_τ), Σ_ϕ(D^T_τ)) and q_η(z|D^T_τ) = N(z; μ_η(D^T_τ), Σ_η(D^T_τ)).

NPs as Exchangeable Stochastic Processes. In vanilla NPs, the element-wise generative process can be translated into Eq. (1), where the mean and variance functions are respectively denoted by μ and Σ:

$$\rho_{x_{1:n+m}}(y_{1:n+m}) = \int p(z) \prod_{i=1}^{n+m} \mathcal{N}\left(y_i; \mu(x_i, z), \Sigma(x_i, z)\right) dz \tag{1}$$

Based on the Kolmogorov extension theorem (Klenke, 2013) and de Finetti's theorem (Kerns & Székely, 2006), ρ_{x_{1:n+m}}(y_{1:n+m}) is verified to be a well-defined exchangeable stochastic process.

NPs in Meta Learning Tasks. Given a collection of tasks T, we can decompose the marginal distribution p(D^T_T | D^C_T; ϑ) with a global latent variable z in Eq. (2). Here the conditional distribution p(z|D^C_τ; ϑ) with τ ∈ T is permutation invariant w.r.t. the order of data points and encodes the functional prior in the generative process.
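Exchangeability of the marginal in Eq. (1) can be checked numerically: for a fixed set of prior samples z ~ p(z), the product over data points is invariant under any joint permutation of the pairs (x_i, y_i). The sketch below uses hypothetical toy mean/variance functions mu(x, z) and Sigma(x, z) (linear map and constant variance, not the paper's networks) and a standard normal p(z).

```python
import numpy as np

def gaussian_pdf(y, mean, var):
    """Elementwise Gaussian density N(y; mean, var)."""
    return np.exp(-0.5 * (y - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

# Toy stand-ins for the decoder's mean and variance functions.
def mean_fn(x, z):
    return x * z[0] + z[1]

def var_fn(x, z):
    return 0.1 + 0.0 * x  # constant variance, broadcast to x's shape

def rho(x, y, n_samples=5000):
    """Monte Carlo estimate of
    rho_{x_{1:k}}(y_{1:k}) = \\int p(z) \\prod_i N(y_i; mu(x_i,z), Sigma(x_i,z)) dz
    with p(z) = N(0, I). A fixed seed reuses the same z samples across calls."""
    local = np.random.default_rng(1)
    z = local.normal(size=(n_samples, 2))  # samples from the prior p(z)
    dens = np.array(
        [np.prod(gaussian_pdf(y, mean_fn(x, zi), var_fn(x, zi))) for zi in z]
    )
    return dens.mean()

x = np.array([0.0, 0.5, 1.0])
y = np.array([0.2, 0.1, -0.3])
perm = np.array([2, 0, 1])
est = rho(x, y)
est_perm = rho(x[perm], y[perm])  # jointly permuting (x_i, y_i) leaves rho unchanged
print(est, est_perm)
```

Since the product over i in Eq. (1) does not depend on the ordering of the pairs, the two estimates agree, which is exactly the exchangeability property the Kolmogorov extension and de Finetti arguments require.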

