BRIDGE THE INFERENCE GAPS OF NEURAL PROCESSES VIA EXPECTATION MAXIMIZATION

Abstract

The neural process (NP) is a family of computationally efficient models for learning distributions over functions. However, it suffers from underfitting and shows suboptimal performance in practice. Researchers have primarily focused on incorporating diverse structural inductive biases, e.g. attention or convolution, in modeling. Inference suboptimality and an analysis of the NP from the perspective of the optimization objective have hardly been studied in earlier work. To address this issue, we propose a surrogate objective of the target log-likelihood of the meta dataset within the expectation maximization framework. The resulting model, referred to as the Self-normalized Importance weighted Neural Process (SI-NP), can learn a more accurate functional prior and comes with an improvement guarantee with respect to the target log-likelihood. Experimental results show the competitive performance of SI-NP over other NP objectives and illustrate that structural inductive biases, such as attention modules, can also augment our method to achieve SOTA performance.

1. INTRODUCTION

The combination of deep neural networks and stochastic processes provides a promising framework for modeling data points with correlations (Ghahramani, 2015). It exploits the high capacity of deep neural networks while enabling uncertainty quantification for distributions over functions; the deep Gaussian process (Damianou & Lawrence, 2013) is one example. However, the run-time complexity of computing predictive distributions in Gaussian processes is cubic in the number of predicted data points. To circumvent this, Garnelo et al. (2018a;b) developed the family of neural processes (NPs) as an alternative, which can model more flexible function distributions and capture predictive uncertainty at a lower computational cost.

In this paper, we study the vanilla NP as a deep latent variable model and show the generative process in Fig. (1). In particular, let us recap the inference methods used in vanilla NPs: the model learns to approximate the functional posterior q_ϕ(z|D_T) ≈ p(z|D_T; ϑ) and the functional prior q_ϕ(z|D_C) ≈ p(z|D_C; ϑ), both of which are permutation invariant to the order of data points. The predictive distribution for a data point [x_*, y_*] can then be formulated in the form E_{q_ϕ(z|D_C)}[p(y_*|[x_*, z]; ϑ)].

While the NP provides a computationally efficient framework for modeling exchangeable stochastic processes, it exhibits underfitting and fails to capture accurate uncertainty in practice (Garnelo et al., 2018b; Kim et al., 2019). To improve its generalization capability, researchers have focused much attention on finding appropriate inductive biases, e.g. attention (Kim et al., 2019).



Figure 1: Deep Latent Variable Models for Neural Processes. Here D_C and D_T respectively denote the context points for the functional prior inference and the target points for the function prediction. The global latent variable z summarizes function properties. The model involves a functional prior distribution p(z|D_C; ϑ) and a functional generative distribution p(D_T|z; ϑ). Please refer to Section (2) for detailed notation descriptions.
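The generative process above can be illustrated with a minimal sketch of a vanilla NP. This is not the paper's SI-NP model: all architecture choices (MLP sizes, mean pooling as the permutation-invariant set encoder, a Gaussian functional prior, and the number of Monte Carlo samples) are hypothetical placeholders chosen for brevity. The sketch only shows how the context set D_C induces a prior over z and how the predictive distribution E_{q(z|D_C)}[p(y_*|[x_*, z]; ϑ)] is estimated by sampling z.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp_init(sizes):
    # Random, untrained weights for a small feed-forward net (illustrative only).
    return [(rng.normal(0, 0.5, (a, b)), np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]

def mlp(params, x):
    # Forward pass with tanh on all but the last layer.
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if i < len(params) - 1:
            x = np.tanh(x)
    return x

# Hypothetical dimensions: x_dim = y_dim = 1, representation dim 8, z_dim 4.
enc = mlp_init([2, 8, 8])        # encodes [x, y] context pairs into representations
to_mu = mlp_init([8, 4])         # pooled representation -> mean of p(z | D_C)
to_logvar = mlp_init([8, 4])     # pooled representation -> log-variance of p(z | D_C)
dec = mlp_init([1 + 4, 8, 1])    # decodes [x_*, z] into a prediction for y_*

def functional_prior(xc, yc):
    # Mean-pooling over context points makes the prior permutation invariant
    # to the order of the data points in D_C.
    r = mlp(enc, np.concatenate([xc, yc], axis=1)).mean(axis=0)
    return mlp(to_mu, r[None])[0], mlp(to_logvar, r[None])[0]

def predict(xc, yc, xt, n_samples=16):
    # Monte Carlo estimate of the predictive mean E_{p(z|D_C)}[p(y_* | x_*, z)]:
    # sample global latents z from the functional prior, decode each at the targets.
    mu, logvar = functional_prior(xc, yc)
    preds = []
    for _ in range(n_samples):
        z = mu + np.exp(0.5 * logvar) * rng.normal(size=mu.shape)
        z_rep = np.tile(z, (xt.shape[0], 1))
        preds.append(mlp(dec, np.concatenate([xt, z_rep], axis=1)))
    return np.mean(preds, axis=0)
```

Because z is sampled once per function draw and shared across all target inputs, predictions at different x_* are correlated through the global latent, which is what lets the NP model a distribution over functions rather than independent pointwise outputs.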

