PAVI: PLATE-AMORTIZED VARIATIONAL INFERENCE

Abstract

Given observed data and a probabilistic generative model, Bayesian inference aims to obtain the distribution of a model's latent parameters that could have yielded the data. This task is challenging for large population studies, where thousands of measurements are performed over a cohort of hundreds of subjects, resulting in a massive latent parameter space. This large cardinality renders off-the-shelf Variational Inference (VI) computationally impractical. In this work, we design structured VI families that can efficiently tackle large population studies. Our main idea is to share the parameterization and learning across the different i.i.d. variables in a generative model, symbolized by the model's plates. We name this concept plate amortization and illustrate the powerful synergies it entails, resulting in expressive, parsimoniously parameterized variational distributions for large-scale hierarchical models that are orders of magnitude faster to train. We illustrate the practical utility of PAVI through a challenging Neuroimaging example featuring a million latent parameters, demonstrating a significant step towards scalable and expressive Variational Inference.

1. INTRODUCTION

Population studies analyse measurements over large cohorts of human subjects. These studies are ubiquitous in health care (Fayaz et al., 2016; Towsley et al., 2011), and typically involve hundreds of subjects and measurements per subject. For instance, in Neuroimaging (Kong et al., 2019), measurements X can correspond to signals at hundreds of locations in the brain for a thousand subjects. Given this observed data X, and a generative model that can produce data given model parameters Θ, we want to recover the parameters Θ that could have yielded the observed X. In our Neuroimaging example, Θ can consist of local labels for each brain location and subject, together with global parameters common to all subjects, such as the connectivity corresponding to each label. We want to recover the distribution of the Θ that could have produced X. Following the Bayesian formalism (Gelman et al., 2004), we cast both Θ and X as Random Variables (RVs), and our goal is to recover the posterior distribution p(Θ|X). Due to the nested structure of our applications, we focus on the case where p corresponds to a Hierarchical Bayesian Model (HBM) (Gelman et al., 2004).

In the context of population studies, the multitude of subjects and measurements per subject implies a large dimensionality for both Θ and X. This large dimensionality in turn creates computational hurdles that we tackle through our method. Several inference methods exist in the literature. The earliest works resorted to Markov Chain Monte Carlo (Koller & Friedman, 2009), which tends to be slow in high-dimensional settings (Blei et al., 2017). More recent approaches, coined Variational Inference (VI), cast inference as an optimization problem (Blei et al., 2017; Zhang et al., 2019; Ruiz & Titsias, 2019): inference reduces to finding the distribution q(Θ; ϕ) ∈ Q closest to the unknown posterior p(Θ|X) within a variational family Q chosen by the experimenter.
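The optimization view of inference described above can be made concrete on a toy model. The sketch below (a minimal illustration, not the paper's method) fits a Gaussian variational family q(μ) = N(m, s²) to the posterior of a conjugate model μ ~ N(0, 1), x_i ~ N(μ, 1), for which the ELBO gradients happen to be available in closed form; the model, learning rate, and iteration count are all illustrative choices.

```python
import numpy as np

# Toy conjugate model (assumed for illustration):
#   mu ~ N(0, 1)        global latent
#   x_i ~ N(mu, 1)      n i.i.d. observations
# Variational family: q(mu) = N(m, s^2), parameterized by phi = (m, log s).
# For this model, ELBO(m, s) = -0.5 * (n + 1) * (m^2 + s^2) + m * sum(x)
#                              + log s + const, so gradient ascent suffices.
rng = np.random.default_rng(0)
n = 50
x = rng.normal(1.5, 1.0, size=n)

m, log_s = 0.0, 0.0   # variational parameters; log_s keeps s > 0
lr = 0.01
for _ in range(2000):
    s = np.exp(log_s)
    grad_m = -(n + 1) * m + x.sum()            # d ELBO / d m
    grad_log_s = (-(n + 1) * s + 1.0 / s) * s  # chain rule through s = exp(log_s)
    m += lr * grad_m
    log_s += lr * grad_log_s

# The exact posterior here is N(sum(x) / (n + 1), 1 / (n + 1)); VI recovers it
# because the true posterior lies inside the chosen family Q.
print(m, np.exp(log_s) ** 2)
```

In realistic HBMs the ELBO gradients have no closed form and are estimated by Monte Carlo with reparameterization, but the optimization loop keeps this same shape.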
Historically, VI required manually deriving and optimizing over Q, which remains an effective method where applicable (Dao et al., 2021). In contrast, we follow the idea of automatic VI: deriving an efficient family Q directly from the HBM p (Kucukelbir et al., 2016; Ambrogioni et al., 2021a;b). Our method also relies on the pervasive idea of amortization in VI: deriving posteriors usable across multiple data points, which can be linked to meta-learning (Zhang et al., 2019; Ravi & Beatson, 2019; Iakovleva et al., 2020; Yao et al., 2019). In particular, Neural Processes share with our method the idea of conditioning a density estimator on the output of a permutation-invariant encoder (Garnelo et al., 2018; Dubois et al., 2020; Zaheer et al., 2018). Meta-learning
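The permutation-invariant conditioning mentioned above can be sketched in a few lines. The example below (an illustrative DeepSets-style forward pass with untrained random weights, not the paper's architecture) embeds each observation, mean-pools the embeddings so that the order of i.i.d. observations cannot matter, and maps the pooled code to the parameters of an amortized posterior q(μ | X) = N(m, s²); all layer sizes and names are assumptions.

```python
import numpy as np

# Illustrative amortized posterior with a permutation-invariant encoder.
rng = np.random.default_rng(0)
d_hidden = 8
W_embed = rng.normal(size=(1, d_hidden))  # per-observation embedding weights
W_head = rng.normal(size=(d_hidden, 2))   # pooled code -> (m, log s)

def amortized_posterior(x):
    h = np.tanh(x[:, None] @ W_embed)  # (n, d_hidden): embed each x_i
    code = h.mean(axis=0)              # mean-pooling => permutation-invariant
    m, log_s = code @ W_head
    return m, np.exp(log_s)

x = rng.normal(size=20)
m1, s1 = amortized_posterior(x)
m2, s2 = amortized_posterior(rng.permutation(x))
# Shuffling the observations leaves q(mu | X) unchanged.
print(np.isclose(m1, m2), np.isclose(s1, s2))
```

Because the encoder conditions on the data X rather than on per-dataset variational parameters, the same network produces a posterior for any new dataset without re-optimization, which is what makes amortization attractive at population scale.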

