PAVI: PLATE-AMORTIZED VARIATIONAL INFERENCE

Abstract

Given observed data and a probabilistic generative model, Bayesian inference aims at obtaining the distribution of the model's latent parameters that could have yielded the data. This task is challenging for large population studies, where thousands of measurements are performed over a cohort of hundreds of subjects, resulting in a massive latent parameter space. This large cardinality renders off-the-shelf Variational Inference (VI) computationally impractical. In this work, we design structured VI families that can efficiently tackle large population studies. Our main idea is to share the parameterization and learning across the different i.i.d. variables in a generative model, symbolized by the model's plates. We name this concept plate amortization, and illustrate the powerful synergies it entails, resulting in large-scale hierarchical variational distributions that are expressive, parsimoniously parameterized, and orders of magnitude faster to train. We illustrate the practical utility of PAVI through a challenging Neuroimaging example featuring a million latent parameters, demonstrating a significant step towards scalable and expressive Variational Inference.

1. INTRODUCTION

Population studies analyse measurements over large cohorts of human subjects. These studies are ubiquitous in health care (Fayaz et al., 2016; Towsley et al., 2011) and can typically involve hundreds of subjects and measurements per subject. For instance, in Neuroimaging (Kong et al., 2019), measurements X can correspond to signals in hundreds of locations in the brain for a thousand subjects. Given this observed data X, and a generative model that can produce data given model parameters Θ, we want to recover the parameters Θ that could have yielded the observed X. In our Neuroimaging example, Θ can be local labels for each brain location and subject, together with global parameters common to all subjects, such as the connectivity corresponding to each label.

We want to recover the distribution of the Θ that could have produced X. Following the Bayesian formalism (Gelman et al., 2004), we cast both Θ and X as Random Variables (RVs), and our goal is to recover the posterior distribution p(Θ|X). Due to the nested structure of our applications, we focus on the case where p corresponds to a Hierarchical Bayesian Model (HBM) (Gelman et al., 2004). In the context of population studies, the multitude of subjects and measurements per subject implies a large dimensionality for both Θ and X. This large dimensionality in turn creates computational hurdles that we tackle through our method.

Several inference methods exist in the literature. The earliest works resorted to Markov Chain Monte Carlo (Koller & Friedman, 2009), which tends to be slow in high-dimensional settings (Blei et al., 2017). More recent approaches, coined Variational Inference (VI), cast inference as an optimization problem (Blei et al., 2017; Zhang et al., 2019; Ruiz & Titsias, 2019): inference reduces to finding the distribution q(Θ; ϕ) ∈ Q closest to the unknown posterior p(Θ|X) within a variational family Q chosen by the experimenter.
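To make the VI objective concrete, the sketch below maximizes the Evidence Lower BOund (ELBO) on a toy conjugate-Gaussian model where the exact posterior is known. This is our own minimal illustration, not code from the paper; all variable names are of our choosing.

```python
import numpy as np
from scipy.optimize import minimize

# Toy conjugate model: theta ~ N(0, 1), x | theta ~ N(theta, 1), observed x = 2.
# The exact posterior is N(1, 0.5); since q is also Gaussian, maximizing the
# ELBO over q's parameters should recover it exactly.
x = 2.0

def neg_elbo(params):
    m, log_s2 = params
    s2 = np.exp(log_s2)
    # E_q[log p(x | theta)]: expected log-likelihood under q = N(m, s2)
    e_loglik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s2)
    # E_q[log p(theta)]: expected log-prior under q
    e_logprior = -0.5 * np.log(2 * np.pi) - 0.5 * (m ** 2 + s2)
    # H(q): entropy of the Gaussian variational distribution
    entropy = 0.5 * np.log(2 * np.pi * np.e * s2)
    return -(e_loglik + e_logprior + entropy)

res = minimize(neg_elbo, x0=np.array([0.0, 0.0]))
m_opt, s2_opt = res.x[0], np.exp(res.x[1])
print(m_opt, s2_opt)  # approaches the exact posterior N(1, 0.5)
```

Here the ELBO is available in closed form; in realistic HBMs it is estimated by Monte Carlo sampling from q, which is what makes the dimensionality of Θ a computational bottleneck.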
Historically, VI required manually deriving the family Q and optimizing over it, which remains an effective method where applicable (Dao et al., 2021). In contrast, we follow the idea of automatic VI: deriving an efficient family Q directly from the HBM p (Kucukelbir et al., 2016; Ambrogioni et al., 2021a;b). Our method also relies on the pervasive idea of amortization in VI: deriving posteriors usable across multiple data points, which can be linked to meta-learning (Zhang et al., 2019; Ravi & Beatson, 2019; Iakovleva et al., 2020; Yao et al., 2019). In particular, Neural Processes share with our method the conditioning of a density estimator on the output of a permutation-invariant encoder (Garnelo et al., 2018; Dubois et al., 2020; Zaheer et al., 2018). Meta-learning studies problems with few hierarchies, as do many VI methods (Ravi & Beatson, 2019; Tran et al., 2017). Agrawal & Domke (2021) notably studied the case of 2-level HBMs, providing theoretical guarantees. In contrast, our focus is computational, and we study generic HBMs with an arbitrary number of hierarchies, aiming to tackle large population studies efficiently.

Modern VI is effective in low-dimensional settings, but does not scale up to large population studies, which involve millions of random variables. In this work we identify and tackle two challenges to enable this scale-up. A first challenge is a detrimental trade-off between expressivity and high dimensionality (Rouillard & Wassermann, 2022). To reduce the inference gap, VI requires the variational family Q to contain distributions closely approximating p(Θ|X) (Blei et al., 2017). Yet the form of p(Θ|X) is usually unknown to the experimenter. Instead of a lengthy search for a valid family, one can resort to universal density approximators: normalizing flows (Papamakarios et al., 2019). But the cost of this generality is a heavy parameterization, and normalizing flows scale poorly with the dimensionality of Θ.
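As a rough back-of-the-envelope illustration of this scaling (our own sketch, not a figure from the paper), consider the parameter count of a single affine coupling layer whose conditioner is a one-hidden-layer MLP mapping half the dimensions to a scale and a shift for the other half:

```python
def coupling_layer_params(d, hidden=256):
    # One affine coupling layer: an MLP conditioner maps the first d // 2
    # dimensions to a scale and a shift for the remaining d - d // 2 dimensions.
    d_in = d // 2
    d_out = 2 * (d - d // 2)  # one scale + one shift per transformed dimension
    # Weights and biases of the two dense layers of the conditioner.
    return d_in * hidden + hidden + hidden * d_out + d_out

for d in (10, 1_000, 1_000_000):
    print(d, coupling_layer_params(d))
```

Even a single such layer over a million-dimensional Θ holds hundreds of millions of weights, before stacking the several layers a flow typically needs, which motivates sharing flow parameterization across a model's structure.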
In large population studies, as this dimensionality grows, the parameterization of normalizing flows becomes prohibitively large. To tackle this challenge, Rouillard & Wassermann (2022) recently proposed, in the ADAVI architecture, to partially share the parameterization of normalizing flows across the hierarchies of a generative model. While ADAVI tackled the over-parameterization of VI in population studies, it still could not perform inference in very large data regimes. This is due to a second challenge with scalability: as the size of Θ increases, the evaluation of a single gradient over the entirety of an architecture's weights quickly requires too much memory and compute. This second challenge can be overcome using Stochastic VI (SVI) (Hoffman et al., 2013), which subsamples the parameters Θ inferred at each optimization step. However, using SVI, the weights of the posterior for a local parameter θ ∈ Θ are only updated when θ is visited by the algorithm. In the presence of hundreds of thousands of such local parameters, SVI can become prohibitively slow.

In this work, we introduce the concept of plate amortization (PAVI) for fast inference in large-scale HBMs. Instead of considering inference over the local parameters θ as separate problems, our main idea is to share both the parameterization and the learning across those local parameters, or equivalently across a model's plates. We first propose an algorithm to automatically derive an expressive yet parsimoniously-parameterized variational family from a plate-enriched HBM. We then propose a hierarchical stochastic optimization scheme to train this architecture efficiently, obtaining orders-of-magnitude faster convergence. We detail two variants of our method, with different trade-offs between parameterization and inference quality. PAVI leverages the repeated structure of plate-enriched HBMs via an original combination of amortization and stochastic training.
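A minimal sketch of the plate-amortization idea (our own illustration under simplifying assumptions, not the actual PAVI architecture): a single shared network maps a permutation-invariant summary of each group's data to that group's posterior parameters, so the weight count no longer grows with the number of groups.

```python
import numpy as np

rng = np.random.default_rng(0)
n_groups, n_subjects, hidden = 1000, 50, 16

# Simulated per-group observations (e.g. subject heights within each group).
group_means = rng.normal(size=(n_groups, 1))
X = group_means + rng.normal(size=(n_groups, n_subjects))

# Shared conditional posterior: one tiny MLP maps a permutation-invariant
# summary of each group's data to that group's Gaussian posterior parameters.
W1 = 0.1 * rng.normal(size=(2, hidden))   # input: [mean, std] of the group's data
b1 = np.zeros(hidden)
W2 = 0.1 * rng.normal(size=(hidden, 2))   # output: [mu_n, log_sigma_n]
b2 = np.zeros(2)

summary = np.stack([X.mean(axis=1), X.std(axis=1)], axis=1)  # (n_groups, 2)
h = np.tanh(summary @ W1 + b1)
posterior_params = h @ W2 + b2                                # (n_groups, 2)

# The same weights are reused ("amortized") across every member of the plate,
# so the parameter count is independent of n_groups.
shared_count = W1.size + b1.size + W2.size + b2.size
naive_count = 2 * n_groups  # separate (mu_n, log_sigma_n) per group
print(posterior_params.shape, shared_count, naive_count)
```

Every gradient step on the shared weights improves the posterior of all groups at once, which hints at why plate amortization speeds up training compared to updating one set of local weights per visited group.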
Through this combination, our main claim is to enable inference over arbitrarily large population studies (up to a million RVs) with a parameterization and a training time that remain modest as the cardinality of the problem increases. We illustrate this by applying PAVI to a challenging human brain cortex parcellation, featuring inference over a cohort of 1,000 subjects with tens of thousands of measurements per subject, demonstrating a significant step towards scalable, expressive and fast VI.

2.1. HIERARCHICAL BAYESIAN MODELS (HBMS), TEMPLATES AND PLATES

Our objective is to perform inference in large population studies. As an example of how inference becomes impractical in this context, consider M in fig. 1 as a model for the height distribution in a population. θ2,0 denotes the mean height across the population. θ1,0, θ1,1, θ1,2 denote the mean heights for 3 groups of subjects, distributed around the population mean. X0, X1 represent the observed heights of 2 subjects from group 0, distributed around the group mean. Given the observed subject heights X, the goal is to determine the posterior distributions of the group means p(θ1,n|X) and of the population mean p(θ2,0|X). As the number of groups and of subjects per group increases, the parameterization and the time required to infer the posteriors of a growing number of RVs become prohibitively large. Our goal is to keep this inference computationally tractable. In the rest of this section, we define the inference problem and the associated notations.

Population studies can be modelled using Hierarchical Bayesian Models (HBMs). Those HBMs feature samples from common conditional distributions at multiple levels: for instance, subject heights are i.i.d. given group heights, which are themselves i.i.d. given the population height. Due to data being collected across hundreds of subjects (Kong et al., 2019), HBMs representing population studies feature
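The toy height model M can be read as a generative sampling procedure; the sketch below makes its plate structure explicit. The specific means and variances are illustrative assumptions of ours, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy height model from fig. 1, sampled top-down through the hierarchy.
theta_2_0 = rng.normal(loc=170.0, scale=10.0)            # population mean height
theta_1 = rng.normal(loc=theta_2_0, scale=5.0, size=3)   # 3 group means (group plate)
X = rng.normal(loc=theta_1[0], scale=2.0, size=2)        # 2 observed subjects of group 0

print(theta_2_0, theta_1, X)
```

Inference runs this process in reverse: given only X, recover the distributions of theta_1 and theta_2_0; the plates appear as the `size` arguments, and it is along these axes that the number of latent RVs explodes in real population studies.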

