REAL-TIME VARIATIONAL METHOD FOR LEARNING NEURAL TRAJECTORY AND ITS DYNAMICS

Abstract

Latent variable models have become instrumental in computational neuroscience for reasoning about neural computation. This has fostered the development of powerful offline algorithms for extracting latent neural trajectories from neural recordings. However, despite the potential of real-time alternatives to give immediate feedback to experimentalists and enhance experimental design, they have received markedly less attention. In this work, we introduce the exponential family variational Kalman filter (eVKF), an online recursive Bayesian method aimed at inferring latent trajectories while simultaneously learning the dynamical system generating them. eVKF works for arbitrary likelihoods and utilizes constant base measure exponential families to model the latent state stochasticity. We derive a closed-form variational analogue to the predict step of the Kalman filter, which leads to a provably tighter bound on the ELBO compared to another online variational method. We validate our method on synthetic and real-world data, and, notably, show that it achieves competitive performance.

1. INTRODUCTION

Populations of neurons, especially in higher-order perceptual and motor cortices, show coordinated patterns of activity constrained to an approximately low-dimensional 'neural manifold' (Sohn et al., 2019; Churchland et al., 2012; Saxena et al., 2022). The dynamical structure of latent trajectories evolving along the neural manifold is thought to be a valid substrate of neural computation. This idea has fostered extensive experimental studies and the development of computational methods to extract these trajectories directly from electrophysiological recordings. Great strides have been made in developing computational tools for extracting latent neural trajectories in post hoc neural data analysis. However, while recently developed tools have proven their efficacy in accurately inferring latent neural trajectories (Pandarinath et al., 2018; Pei et al., 2021; Yu et al., 2009; Zhao & Park, 2017), learning their underlying dynamics has received markedly less attention. Furthermore, even less focus has been placed on real-time methods that allow for online learning of neural trajectories and their underlying dynamics. Real-time learning of neural dynamics would facilitate more efficient experimental design and increase the capability of closed-loop systems, where an accurate picture of the dynamical landscape leads to more precise predictions (Peixoto et al., 2021; Bolus et al., 2021).

In this work, we consider the problem of inferring latent trajectories while simultaneously learning the dynamical system generating them in an online fashion. We introduce the exponential family variational Kalman filter (eVKF), a novel variational inference scheme that draws inspiration from the 'predict' and 'update' steps used in the classic Kalman filter (Anderson & Moore, 1979).
We theoretically justify our variational inference scheme by proving that it leads to a tighter 'filtering' evidence lower bound (ELBO) than a 'single step' approximation; the proof utilizes the closed-form solution of the proposed 'variational prediction' step. Finally, we show how parameterizing the dynamics with a universal function approximator, in tandem with exponential family properties, facilitates an alternative optimization procedure for learning the generative model. Our contributions are as follows: (i) we propose a novel variational inference scheme for online learning analogous to the predict and update steps of the Kalman filter; (ii) we show that the variational prediction step admits a closed-form solution when we restrict our variational approximations to constant base measure exponential families (Theorem 1); (iii) we justify our two-step procedure by showing that it achieves a tighter bound on the ELBO compared to directly finding a variational approximation to the filtering distribution (Theorem 2); (iv) we show that, when using universal function approximators to model the dynamics, we can optimize our model of the dynamics without propagating gradients through the ELBO, as is typically done in variational expectation maximization (vEM) or variational autoencoders (VAEs) (Kingma & Welling, 2014).

2. BACKGROUND

2.1 STATE-SPACE MODELS

In this paper, we consider observations (e.g. neural recordings), $y_t$, arriving in a sequential fashion. It is assumed these observations depend directly on a latent Markov process (e.g. structured neural dynamics), $z_t$, allowing us to write the generative model in state-space form:

$$z_t \mid z_{t-1} \sim p_\theta(z_t \mid z_{t-1}) \quad \text{(latent dynamics model)}$$
$$y_t \mid z_t \sim p_\psi(y_t \mid z_t) \quad \text{(observation model)}$$

where $z_t \in \mathbb{R}^L$, $y_t \in \mathbb{R}^N$, $\psi$ parameterizes the observation model, and $\theta$ parameterizes the dynamics model. After observing $y_t$, any statistical quantities of interest related to $z_t$ can be computed from the filtering distribution, $p(z_t \mid y_{1:t})$. Since we are considering a periodically sampled data streaming setting, it is important that we are able to compute $p(z_t \mid y_{1:t})$ in a recursive fashion, with constant time and space complexity. In addition to inferring the filtering distribution over latent states, we will also be interested in learning the dynamics as the (prior) conditional probability distribution, $p_\theta(z_t \mid z_{t-1})$, which captures the underlying dynamical law that governs the latent state $z$ and may implement neural computation. Learning the dynamics facilitates higher-quality inference of the latent state, accurate forecasting, and generation of new data. In this paper we focus mainly on models whose dynamics are nonlinear and parameterized by flexible function approximators. For example, we may model the dynamics as $z_t \mid z_{t-1} \sim \mathcal{N}(f_\theta(z_{t-1}), Q)$, with $f_\theta : \mathbb{R}^L \to \mathbb{R}^L$ parameterized by a neural network.
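To make the generative model concrete, the following sketch simulates a state-space model of the kind described above: nonlinear Gaussian latent dynamics with Poisson spike-count observations, a common choice for neural recordings. The particular dynamics function, readout, and dimensions are illustrative assumptions, not the model used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
L, N, T = 2, 10, 100            # latent dim, observation dim, number of time steps

def f(z):
    """Hypothetical nonlinear dynamics f_theta: a mild rotation with tanh saturation."""
    A = np.array([[0.99, -0.10],
                  [0.10,  0.99]])
    return np.tanh(A @ z)

Q = 0.01 * np.eye(L)            # state noise covariance
C = rng.normal(size=(N, L))     # linear readout
b = np.full(N, -1.0)            # bias, keeps firing rates moderate

z = np.zeros((T, L))
y = np.zeros((T, N), dtype=np.int64)
for t in range(1, T):
    z[t] = rng.multivariate_normal(f(z[t - 1]), Q)   # z_t | z_{t-1} ~ N(f(z_{t-1}), Q)
    y[t] = rng.poisson(np.exp(C @ z[t] + b))         # y_t | z_t ~ Poisson(exp(C z_t + b))
```

Here `y[t]` arrives sequentially, and the filtering problem is to recover `z[t]` from `y[:t+1]` online.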

2.2. KALMAN FILTER

Before diving into the general case, let's revisit the well-established Kalman filter (Särkkä, 2013). Given linear Gaussian dynamics and observations, the state-space model description is given by

$$p_\theta(z_t \mid z_{t-1}) = \mathcal{N}(z_t \mid A z_{t-1}, Q), \qquad \theta = \{A, Q\}$$
$$p_\psi(y_t \mid z_t) = \mathcal{N}(y_t \mid C z_t + b, R), \qquad \psi = \{C, b, R\}$$

The Kalman filter recursively computes the Bayes optimal estimate of the latent state $z_t$. Given the filtering posterior of the previous time step, $p(z_{t-1} \mid y_{1:t-1}) = \mathcal{N}(m_{t-1}, P_{t-1})$, we first predict the latent state distribution (a.k.a. the filtering prior) at time $t$:

$$p(z_t \mid y_{1:t-1}) = \mathbb{E}_{p(z_{t-1} \mid y_{1:t-1})}\left[p_\theta(z_t \mid z_{t-1})\right] \tag{1}$$
$$= \mathcal{N}(z_t \mid A m_{t-1},\, A P_{t-1} A^\top + Q) \tag{2}$$

Secondly, we update our belief about the current state with the observation $y_t$ by Bayes' rule:

$$p(z_t \mid y_{1:t}) \propto p(y_t \mid z_t)\, p(z_t \mid y_{1:t-1}) = \mathcal{N}(z_t \mid m_t, P_t) \tag{3}$$

In order to learn the underlying dynamics $A$, the linear readout $C$, state noise $Q$, and observation noise $R$, the EM algorithm can be employed (Ghahramani & Hinton, 1996). If a calibrated measure of uncertainty over the model parameters is important, then a prior can be placed over those quantities, and approximate Bayesian methods can be used to find the posterior (Barber & Chiappa, 2006). When the dynamics are nonlinear, approximate Bayesian inference can be used to compute the posterior over latent states (Kamthe et al., 2022; Hernandez et al., 2018; Pandarinath et al., 2018). Note that these methods learn the parameters in the offline setting.
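The predict and update steps of the classic Kalman filter can be sketched in a few lines of numpy. This is a minimal textbook implementation for illustration (the variable names `m_pred`, `P_pred`, `S`, `K` are our own), not the eVKF method introduced later.

```python
import numpy as np

def kf_predict(m, P, A, Q):
    """Predict step: propagate N(m, P) through the linear dynamics.

    Returns the filtering prior N(A m, A P A^T + Q)."""
    m_pred = A @ m
    P_pred = A @ P @ A.T + Q
    return m_pred, P_pred

def kf_update(m_pred, P_pred, y, C, b, R):
    """Update step: condition the filtering prior on the observation y_t."""
    S = C @ P_pred @ C.T + R                 # innovation covariance
    K = np.linalg.solve(S, C @ P_pred).T     # Kalman gain (S is symmetric)
    m = m_pred + K @ (y - (C @ m_pred + b))  # posterior mean m_t
    P = P_pred - K @ C @ P_pred              # posterior covariance P_t
    return m, P
```

Iterating `kf_predict` followed by `kf_update` over a data stream yields the filtering distribution recursively, with constant time and space cost per step; eVKF mirrors this two-step structure with variational analogues for non-Gaussian models.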

