REAL-TIME VARIATIONAL METHOD FOR LEARNING NEURAL TRAJECTORY AND ITS DYNAMICS

Abstract

Latent variable models have become instrumental in computational neuroscience for reasoning about neural computation. This has fostered the development of powerful offline algorithms for extracting latent neural trajectories from neural recordings. However, despite the potential of real-time alternatives to give immediate feedback to experimentalists and enhance experimental design, they have received markedly less attention. In this work, we introduce the exponential family variational Kalman filter (eVKF), an online recursive Bayesian method aimed at inferring latent trajectories while simultaneously learning the dynamical system generating them. eVKF works for arbitrary likelihoods and utilizes the constant base measure exponential family to model the latent state stochasticity. We derive a closed-form variational analogue to the predict step of the Kalman filter, which leads to a provably tighter bound on the ELBO compared to another online variational method. We validate our method on synthetic and real-world data, and, notably, show that it achieves competitive performance.

1. INTRODUCTION

Populations of neurons, especially in higher-order perceptual and motor cortices, show coordinated patterns of activity constrained to an approximately low-dimensional 'neural manifold' (Sohn et al., 2019; Churchland et al., 2012; Saxena et al., 2022). The dynamical structure of latent trajectories evolving along the neural manifold is thought to be a valid substrate of neural computation. This idea has fostered extensive experimental studies and the development of computational methods to extract these trajectories directly from electrophysiological recordings. Great strides have been made in developing computational tools for extracting latent neural trajectories in post hoc neural data analysis. However, while recently developed tools have proven their efficacy in accurately inferring latent neural trajectories (Pandarinath et al., 2018; Pei et al., 2021; Yu et al., 2009; Zhao & Park, 2017), learning their underlying dynamics has received markedly less attention. Furthermore, even less focus has been placed on real-time methods that allow for online learning of neural trajectories and their underlying dynamics. Real-time learning of neural dynamics would facilitate more efficient experimental design, and increase the capability of closed-loop systems, where an accurate picture of the dynamical landscape leads to more precise predictions (Peixoto et al., 2021; Bolus et al., 2021). In this work, we consider the problem of inferring latent trajectories while simultaneously learning the dynamical system generating them in an online fashion. We introduce the exponential family variational Kalman filter (eVKF), a novel variational inference scheme that draws inspiration from the 'predict' and 'update' steps used in the classic Kalman filter (Anderson & Moore, 1979).
We theoretically justify our variational inference scheme by proving it leads to a tighter 'filtering' evidence lower bound (ELBO) than a 'single step' approximation that utilizes the closed form solution of the proposed 'variational prediction' step. Finally, we show how parameterization of the dynamics via a universal function approximator in tandem with exponential family properties facilitates an alternative optimization procedure for learning the generative model. Our contributions are as follows: (i) We propose a novel variational inference scheme for online learning analogous to the predict and update steps of the Kalman filter. (ii) We show the variational prediction step offers a closed form solution when we restrict our variational approximations to constant base measure exponential families (Theorem 1). (iii) We justify our two step procedure by showing that we achieve a tighter bound on the ELBO, when compared to directly finding a variational approximation to the filtering distribution (Theorem 2). (iv) We show that when using universal function approximators for modeling the dynamics, we can optimize our model of the dynamics without propagating gradients through the ELBO as is typically done in variational expectation maximization (vEM) or variational autoencoders (VAEs) (Kingma & Welling, 2014) .

2. BACKGROUND

2.1. STATE-SPACE MODELS

In this paper, we consider observations (e.g. neural recordings), y_t, arriving in a sequential fashion. It is assumed these observations depend directly on a latent Markov process (e.g. structured neural dynamics), z_t, allowing us to write the generative model in state-space form:

    z_t | z_{t-1} ~ p_θ(z_t | z_{t-1})    (latent dynamics model)
    y_t | z_t ~ p_ψ(y_t | z_t)            (observation model)

where z_t ∈ R^L, y_t ∈ R^N, ψ parameterizes the observation model, and θ parameterizes the dynamics model. After observing y_t, any statistical quantity of interest related to z_t can be computed from the filtering distribution, p(z_t | y_{1:t}). Since we are considering a periodically sampled data-streaming setting, it is important that we are able to compute p(z_t | y_{1:t}) in a recursive fashion, with constant time and space complexity. In addition to inferring the filtering distribution over latent states, we will also be interested in learning the dynamics, i.e. the (prior) conditional probability distribution p_θ(z_t | z_{t-1}), which captures the underlying dynamical law that governs the latent state z and may implement neural computation. Learning the dynamics facilitates higher-quality inference of the latent state, accurate forecasting, and generation of new data. In this paper we will focus mainly on models where the dynamics are nonlinear and parameterized by flexible function approximators. For example, we may model the dynamics as z_t | z_{t-1} ~ N(f_θ(z_{t-1}), Q), with f_θ : R^L → R^L parameterized by a neural network.
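To make the generative model concrete, here is a minimal simulation sketch of such a state-space model. Everything specific here is an assumption for illustration: the dimensions, the noise scales, and the small random two-layer map standing in for a trained f_θ.

```python
import numpy as np

rng = np.random.default_rng(0)
L, N, T = 2, 10, 100          # latent dim, observation dim, time steps

# Hypothetical nonlinear dynamics f_theta: a fixed random two-layer map.
W1 = rng.normal(scale=0.5, size=(8, L))
W2 = rng.normal(scale=0.5, size=(L, 8))
f = lambda z: z + 0.1 * (W2 @ np.tanh(W1 @ z) - z)   # mild nonlinear drift

C = rng.normal(size=(N, L))    # observation readout
Q = 0.01 * np.eye(L)           # state noise covariance
R = 0.1 * np.eye(N)            # observation noise covariance

z = np.zeros((T, L))
y = np.zeros((T, N))
z[0] = rng.normal(size=L)
for t in range(T):
    if t > 0:
        # z_t | z_{t-1} ~ N(f(z_{t-1}), Q)
        z[t] = f(z[t - 1]) + rng.multivariate_normal(np.zeros(L), Q)
    # y_t | z_t ~ N(C z_t, R)
    y[t] = C @ z[t] + rng.multivariate_normal(np.zeros(N), R)
```

A filtering algorithm would consume the rows of y one at a time, never revisiting past observations.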

2.2. KALMAN FILTER

Before diving into the general case, let's revisit the well-established Kalman filter (Särkkä, 2013). Given linear Gaussian dynamics and observations, the state-space model description is

    p_θ(z_t | z_{t-1}) = N(z_t | A z_{t-1}, Q),    θ = {A, Q}
    p_ψ(y_t | z_t) = N(y_t | C z_t + b, R),        ψ = {C, b, R}

The Kalman filter recursively computes the Bayes-optimal estimate of the latent state z_t. Given the filtering posterior of the previous time step, p(z_{t-1} | y_{1:t-1}) = N(m_{t-1}, P_{t-1}), we first predict the latent state distribution (a.k.a. the filtering prior) at time t:

    p(z_t | y_{1:t-1}) = E_{p(z_{t-1} | y_{1:t-1})}[p_θ(z_t | z_{t-1})]    (1)
                       = N(z_t | A m_{t-1}, A P_{t-1} A^T + Q)             (2)

Second, we update our belief about the current state with the observation y_t by Bayes' rule:

    p(z_t | y_{1:t}) ∝ p(y_t | z_t) p(z_t | y_{1:t-1}) = N(z_t | m_t, P_t)    (3)

In order to learn the underlying dynamics A, the linear readout C, the state noise Q, and the observation noise R, the EM algorithm can be employed (Ghahramani & Hinton, 1996). If a calibrated measure of uncertainty over the model parameters is important, then a prior can be placed over those quantities, and approximate Bayesian methods can be used to find the posterior (Barber & Chiappa, 2006). When the dynamics are nonlinear, approximate Bayesian inference can be used to compute the posterior over latent states (Kamthe et al., 2022; Hernandez et al., 2018; Pandarinath et al., 2018). Note that these methods learn the parameters in the offline setting.
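The predict and update steps of Eqs. (1)-(3) can be sketched as a single recursive function; the matrices in the usage example are arbitrary placeholders.

```python
import numpy as np

def kalman_step(m, P, y, A, Q, C, b, R):
    """One predict-update cycle of the Kalman filter."""
    # Predict: p(z_t | y_{1:t-1}) = N(A m, A P A^T + Q)
    m_pred = A @ m
    P_pred = A @ P @ A.T + Q
    # Update: condition on y_t via Bayes' rule
    S = C @ P_pred @ C.T + R              # innovation covariance
    K = P_pred @ C.T @ np.linalg.inv(S)   # Kalman gain
    m_new = m_pred + K @ (y - (C @ m_pred + b))
    P_new = P_pred - K @ C @ P_pred
    return m_new, P_new

# Usage with placeholder parameters: filter one observation.
A = 0.9 * np.eye(2); Q = 0.1 * np.eye(2)
C = np.eye(2); b = np.zeros(2); R = 0.5 * np.eye(2)
m1, P1 = kalman_step(np.zeros(2), np.eye(2), np.ones(2), A, Q, C, b, R)
```

Running this once per incoming y_t gives exactly the constant time-and-space recursion that eVKF generalizes beyond the linear Gaussian case.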

3. EXPONENTIAL FAMILY VARIATIONAL KALMAN FILTER (EVKF)

When the model is not linear and Gaussian, the filtering prior, Eq. (1), and the filtering distribution, Eq. (3), are often intractable. This is unfortunate, since most models of practical interest deviate in one way or another from these linear Gaussian assumptions. Drawing inspiration from the predict and update procedure for recursive Bayesian estimation, we propose the exponential family variational Kalman filter (eVKF), a recursive variational inference procedure for exponential family models that jointly infers latent trajectories and learns their underlying dynamics.

3.1. EXPONENTIAL FAMILY DISTRIBUTIONS

We first take time to recall exponential family distributions, as their theoretical properties make them convenient to work with, especially when performing Bayesian inference. An exponential family distribution can be written as

    p(z) = h(z) exp(λ^T t(z) - A(λ))

where h(z) is the base measure, λ is the natural parameter, t(z) is the sufficient statistic, and A(λ) is the log-partition function (Wainwright & Jordan, 2008). Many widely used distributions reside in the exponential family; a Gaussian distribution, p(z) = N(m, P), for example, has

    t(z) = (z, zz^T),    λ = (P^{-1} m, -1/2 P^{-1}),    h(z) = (2π)^{-L/2}

Note that the base measure h does not depend on z for a Gaussian distribution. We hereby call such an exponential family distribution a constant base measure distribution if its base measure, h, is constant w.r.t. z. This class encapsulates many well-known distributions such as the Gaussian, Bernoulli, Beta, and Gamma distributions. An additional and important fact we use is that, for a minimal exponential family distribution, there exists a one-to-one mapping between the natural parameters, λ, and the mean parameters, µ := E_{p(z)}[t(z)]. This mapping is given by µ = ∇_λ A(λ), and its inverse by λ = ∇_µ E_{p(z; λ(µ))}[log p(z; λ(µ))] (Seeger, 2005), though E_{p(z; λ(µ))}[log p(z; λ(µ))] is usually intractable. If we have a conditional exponential family distribution, p_θ(z_t | z_{t-1}), then the natural parameters of z_t | z_{t-1} are a function of z_{t-1}. In this case, we can write the conditional density function as

    p_θ(z_t | z_{t-1}) = h(z_t) exp(λ_θ(z_{t-1})^T t(z_t) - A(λ_θ(z_{t-1})))

where λ_θ(·) maps z_{t-1} to the space of valid natural parameters for z_t. This allows us to use expressive natural parameter mappings, while keeping the conditional distribution in the constant base measure exponential family.
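The Gaussian natural and mean parameterizations above can be checked numerically; this sketch implements the two maps for a multivariate Gaussian and relies only on the identities stated in the text (µ = (m, P + m m^T) for the sufficient statistics t(z) = (z, zz^T)).

```python
import numpy as np

def gaussian_natural(m, P):
    """Natural parameters of N(m, P): λ1 = P^{-1} m, λ2 = -1/2 P^{-1}."""
    Pinv = np.linalg.inv(P)
    return Pinv @ m, -0.5 * Pinv

def gaussian_mean_params(lam1, lam2):
    """Mean parameters µ = (E[z], E[z z^T]) recovered from (λ1, λ2)."""
    P = np.linalg.inv(-2.0 * lam2)   # invert λ2 = -1/2 P^{-1}
    m = P @ lam1                     # invert λ1 = P^{-1} m
    return m, P + np.outer(m, m)

# Round trip: natural -> mean parameters recovers m and E[z z^T].
m = np.array([1.0, 2.0])
P = np.array([[2.0, 0.3], [0.3, 1.0]])
lam1, lam2 = gaussian_natural(m, P)
m_rec, Ezz = gaussian_mean_params(lam1, lam2)
```

The one-to-one correspondence between the two parameterizations is what lets eVKF move freely between natural-parameter computations (prediction) and moment-based quantities (evaluation).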
Assume that at time t we have an approximation to the filtering distribution, q(z_{t-1}), and that this approximation is a constant base measure exponential family distribution, so that

    p(z_{t-1} | y_{1:t-1}) ≈ q(z_{t-1}) = h exp(λ^T t(z_{t-1}) - A(λ))

The primary goal of filtering is to efficiently compute a good approximation q(z_t) of p(z_t | y_{1:t}), the filtering distribution at time t. As we will show, following the two-step variational prescription of predict, then update, leads to a natural variational inference scheme and a provably tighter ELBO than a typical single-step variational approximation.

3.2. VARIATIONAL PREDICTION STEP

Now that we have relaxed the linear Gaussian assumption, the first problem we encounter is computing the predictive distribution (a.k.a. filtering prior)

    p(z_t | y_{1:t-1}) = E_{p(z_{t-1} | y_{1:t-1})}[p_θ(z_t | z_{t-1})]

This is generally intractable, since the filtering distribution, p(z_{t-1} | y_{1:t-1}), can only be found analytically for simple SSMs. Similar to other online variational methods (Marino et al., 2018; Zhao & Park, 2020; Campbell et al., 2021), we substitute an approximation for the filtering distribution, q(z_{t-1}) ≈ p(z_{t-1} | y_{1:t-1}), and consider

    E_{q(z_{t-1})}[p_θ(z_t | z_{t-1})]    (8)

Unfortunately, due to the nonlinearity in p_θ(z_t | z_{t-1}), Eq. (8) is still intractable, making further approximation necessary. We begin by considering an approximation, q̄(z_t), restricted to a minimal exponential family distribution with natural parameter λ̄, i.e.

    E_{q(z_{t-1})}[p_θ(z_t | z_{t-1})] ≈ q̄(z_t) = h exp(λ̄^T t(z_t) - A(λ̄))    (9)

Taking a variational approach (Hoffman et al., 2013), our goal is to find the natural parameter λ̄ that minimizes D_KL(q̄(z_t) || E_{q(z_{t-1})}[p_θ(z_t | z_{t-1})]). Since this quantity cannot be minimized directly, we can consider the following upper bound:

    F = -H(q̄(z_t)) - E_{q̄(z_t)} E_{q(z_{t-1})}[log p_θ(z_t | z_{t-1})] ≥ D_KL(q̄(z_t) || E_{q(z_{t-1})}[p_θ(z_t | z_{t-1})])    (10)

Rather than minimizing F with respect to λ̄ through numerical optimization, if we take q(z_{t-1}), q̄(z_t), and p_θ(z_t | z_{t-1}) to be in the same constant base measure exponential family, then the following theorem tells us how to compute the λ̄* that minimizes F.

Theorem 1 (Variational prediction distribution). If p_θ(z_t | z_{t-1}), q(z_{t-1}), and q̄(z_t) are chosen to be in the same minimal and constant base measure exponential family, E_c, then q̄*(z_t) = argmin_{q̄ ∈ E_c} F(q̄) has a closed-form solution given by q̄*(z_t) with natural parameters

    λ̂_θ = E_{q(z_{t-1})}[λ_θ(z_{t-1})]    (11)

Eq. (11) demonstrates that the optimal natural parameters of q̄ are the expected natural parameters of the prior dynamics under the variational filtering posterior. While λ̂_θ cannot generally be found analytically, computing a Monte Carlo approximation is simple; we only have to draw samples from q(z_{t-1}) and then pass those samples through λ_θ(·). This also reveals a very nice symmetry between closed-form conjugate Bayesian updates and variationally inferring the prediction distribution: in the former case we calculate E_{q(z_{t-1})}[p_θ(z_t | z_{t-1})], while in the latter we calculate E_{q(z_{t-1})}[λ_θ(z_{t-1})]. We summarize the eVKF two-step procedure in Algorithm 1, located in Appendix E.3.
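The Monte Carlo version of Theorem 1 is a few lines of code. In this sketch the map λ_θ(·) is a hypothetical linear Gaussian example (with assumed A and Q), so the estimate can be checked against the exact expectation Q^{-1} A m.

```python
import numpy as np

rng = np.random.default_rng(1)

A = np.array([[0.9, -0.1], [0.1, 0.9]])   # assumed dynamics matrix
Qinv = np.eye(2) / 0.1                    # assumed inverse state-noise covariance

def lam_theta(z):
    """Natural-parameter map of z_t | z_{t-1} ~ N(A z_{t-1}, Q)."""
    return Qinv @ (A @ z), -0.5 * Qinv

def variational_predict(m, P, n_samples=4000):
    """λ̂_θ = E_{q(z_{t-1})}[λ_θ(z_{t-1})] estimated by Monte Carlo (Theorem 1)."""
    zs = rng.multivariate_normal(m, P, size=n_samples)   # samples from q(z_{t-1})
    lam1s, lam2s = zip(*(lam_theta(z) for z in zs))
    return np.mean(lam1s, axis=0), np.mean(lam2s, axis=0)

lam1_hat, lam2_hat = variational_predict(np.array([1.0, 0.0]), 0.01 * np.eye(2))
```

For nonlinear λ_θ (e.g. a neural network) the same two lines apply unchanged; only the map being averaged differs.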

3.3. VARIATIONAL UPDATE STEP

Analogous to the Kalman filter, we update our belief about the latent state after observing y_t. When the likelihood is conjugate to the filtering prior, we can calculate a Bayesian update in closed form by using q̄(z_t) as our prior and computing

    p(z_t | y_{1:t}) ≈ q(z_t) ∝ p(y_t | z_t) q̄(z_t)

where q(z_t), with natural parameter λ, belongs to the same family as q(z_{t-1}). In the absence of conjugacy, we use variational inference to find q(z_t) by maximizing the evidence lower bound (ELBO)

    λ* = argmax_λ L_t(λ, θ) = argmax_λ E_{q(z_t)}[log p(y_t | z_t)] - D_KL(q(z_t; λ) || q̄(z_t))    (12)

If the likelihood happens to be an exponential family distribution, then one way to maximize Eq. (12) is through conjugate computation variational inference (CVI) (Khan & Lin, 2017). CVI is appealing in this case because it is equivalent to natural gradient descent, and thus converges faster, and it conveniently operates in the natural parameter space that we are already working in.
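In the conjugate case the update step is literally additive in natural parameters. As a hypothetical illustration (not the paper's general update), take a Gaussian prior in natural form and a Gaussian likelihood y | z ~ N(z, R) with identity readout: the likelihood contributes R^{-1} y to λ1 and -1/2 R^{-1} to λ2.

```python
import numpy as np

def conjugate_update(lam1, lam2, y, Rinv):
    """Posterior natural parameters for a Gaussian prior (lam1, lam2)
    combined with a Gaussian likelihood y | z ~ N(z, R)."""
    return lam1 + Rinv @ y, lam2 - 0.5 * Rinv

# Prior N(0, I) in natural form, observe y = 1 with R = I.
lam1, lam2 = np.zeros(1), -0.5 * np.eye(1)
p1, p2 = conjugate_update(lam1, lam2, np.array([1.0]), np.eye(1))
P_post = np.linalg.inv(-2.0 * p2)   # posterior covariance
m_post = P_post @ p1                # posterior mean
```

The result matches the textbook Gaussian posterior, (I + R^{-1})^{-1} and (I + R^{-1})^{-1} R^{-1} y; when the likelihood is not conjugate, the ELBO of Eq. (12) is maximized instead (e.g. via CVI).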

3.4. TIGHT LOWER BOUND BY THE PREDICT-UPDATE PROCEDURE

A natural alternative to the prescribed variational predict-then-update procedure is to directly find a variational approximation to the filtering distribution. One way is to substitute E_{q(z_{t-1})}[p_θ(z_t | z_{t-1})] for q̄(z_t) in the ELBO (Marino et al., 2018; Zhao & Park, 2020). Further details are provided in Appendix B, but after making this substitution and invoking Jensen's inequality we get the following lower bound on the log-marginal likelihood at time t:

    M_t = E_{q(z_t)}[log p(y_t | z_t)] - E_{q(z_t)}[log q(z_t) - E_{q(z_{t-1})}[log p_θ(z_t | z_{t-1})]]

However, as we prove in Appendix B, this leads to a provably looser bound on the evidence compared to eVKF, as we state in the following theorem.

Theorem 2 (Tightness of L_t). If we set

    Δ(q) = L_t(q) - M_t(q),    (14)

then we have that

    Δ(q) = E_{q(z_{t-1})}[A(λ_θ(z_{t-1}))] - A(λ̂_θ) ≥ 0,    (15)

so that

    log p(y_t) ≥ L_t(q) ≥ M_t(q).    (16)

In other words, the bound on the evidence given by the variational predict-then-update procedure is always at least as tight as that of the one-step procedure. Thus, not only do the variational predict and update steps simplify computations and make leveraging conjugacy possible, they also facilitate a better approximation to the posterior filtering distribution.

3.5. LEARNING THE DYNAMICS

Our remaining desideratum is the ability to learn the parameters of the dynamics model p_θ(z_t | z_{t-1}). One way of learning θ is to use variational expectation maximization: with λ* fixed, we find the θ* that maximizes the ELBO

    θ* = argmax_θ L_t(λ*, θ)    (17)
       = argmin_θ D_KL(q(z_t; λ*) || q̄_θ(z_t; λ̂_θ))    (18)

This objective may require expensive computation in practice, e.g. the log-determinant and Cholesky decomposition for Gaussian q and q̄_θ. However, since we chose q̄_θ and q to be in the same exponential family, then as described in the following proposition, we can consider the more computationally tractable squared loss as an optimization objective.

Proposition 1 (Optimal θ). If the mapping from z_{t-1} to the natural parameters of z_t, given by λ_θ(z_{t-1}), is a universal function approximator with trainable parameters θ, then setting

    θ* = argmin_θ (1/2) ||λ* - λ̂_θ||²    (19)

is equivalent to finding θ* = argmax_θ L_t(λ*, θ).

The proposition indicates that we find the optimal θ* that matches the natural parameters of the predictive distribution to those of the filtering distribution. The proof can be found in Appendix C. Empirically, we have found that, even for small neural networks, following Eq. (19) works better in practice than directly minimizing the KL term.
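To see what the squared-loss objective of Eq. (19) buys, consider a deliberately simple toy case (not the paper's neural-network parameterization): if λ_θ(z) = W z is linear, matching natural parameters across many filtering steps reduces to ordinary least squares. All the data below are synthetic assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical per-step data: x_t stands for E_q[z_{t-1}] and Lam[t] for the
# target natural parameter λ*_t produced by the update step.
W_true = np.array([[2.0, 0.5], [-0.3, 1.5]])
X = rng.normal(size=(200, 2))                            # x_t, one row per step
Lam = X @ W_true.T + 0.01 * rng.normal(size=(200, 2))    # noisy λ*_t targets

# θ* = argmin Σ_t ||λ*_t - W x_t||², an ordinary least-squares problem:
B, *_ = np.linalg.lstsq(X, Lam, rcond=None)
W_hat = B.T
```

No log-determinants or Cholesky factors appear; with a nonlinear λ_θ the same residual λ* - λ̂_θ is simply driven to zero by gradient descent instead of a closed-form solve.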

3.6. CORRECTING FOR THE UNDERESTIMATION OF VARIANCE

It is illuminating to take a linear Gaussian dynamical system and compare the variational approximation of eVKF to the closed-form solution given by Kalman filtering. Given p_θ(z_t | z_{t-1}) = N(z_t | A z_{t-1}, Q), the mapping from a realization of z_{t-1} to the natural parameters of z_t is given by

    λ_θ(z_{t-1}) = (Q^{-1} A z_{t-1}, -1/2 Q^{-1})

With this mapping, we can determine, in closed form, the prediction distribution given by eVKF. Assuming that q(z_{t-1}) = N(z_{t-1} | m_{t-1}, P_{t-1}), we can find the optimal variational prediction distribution by plugging λ_θ(z_{t-1}) into Eq. (11), giving

    q̄(z_t) = N(z_t | A m_{t-1}, Q)    (20)

However, we know that the prediction step of the Kalman filter returns

    p(z_t | y_{1:t-1}) = N(z_t | A m_{t-1}, Q + A P_{t-1} A^T)    (21)

so eVKF underestimates the true variance by an amount A P_{t-1} A^T; this issue of applying VI to time-series models has been examined before, as in Turner & Sahani (2011). In this example, because the second natural parameter does not depend on at least second-order moments of z_{t-1}, the uncertainty provided by P_{t-1} will not be propagated forward. At least for the linear Gaussian case, we can correct this with a post hoc fix by adding A P_{t-1} A^T to the variance of the variational prediction. If we consider nonlinear Gaussian dynamics with p_θ(z_t | z_{t-1}) = N(z_t | m_θ(z_{t-1}), Q), then there does not exist an exact correction, since the true prediction distribution will not be Gaussian. Empirically, we have found that adding an extended Kalman filter-like (Särkkä, 2013) correction of M_{t-1} P_{t-1} M_{t-1}^T to the prediction distribution variance, where M_{t-1} = ∇m_θ(m_{t-1}), helps to avoid overconfidence. In Appendix E.4 we show a case where not including an additional variance term gives unsatisfactory results when dynamical transitions are Gamma distributed.
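The EKF-style correction above is easy to implement even without autodiff; this sketch estimates the Jacobian M_{t-1} = ∇m_θ(m_{t-1}) by central finite differences (an assumption for illustration; in practice the Jacobian of the dynamics network would come from automatic differentiation).

```python
import numpy as np

def corrected_predict_cov(m_prev, P_prev, f, Q, eps=1e-5):
    """Prediction variance Q + M P M^T, with M = ∇f(m_prev) estimated
    by central finite differences."""
    L = m_prev.size
    M = np.zeros((L, L))
    for j in range(L):
        d = np.zeros(L)
        d[j] = eps
        M[:, j] = (f(m_prev + d) - f(m_prev - d)) / (2.0 * eps)
    return Q + M @ P_prev @ M.T

# Sanity check: for linear dynamics f(z) = A z this recovers Q + A P A^T.
A = np.array([[0.9, 0.0], [0.0, 0.8]])
Q = 0.1 * np.eye(2)
P = np.eye(2)
P_corr = corrected_predict_cov(np.zeros(2), P, lambda z: A @ z, Q)
```

For linear dynamics the finite-difference Jacobian is exact, so the corrected variance coincides with the Kalman prediction of Eq. (21).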

4. RELATED WORKS

Classic recursive Bayesian methods such as the particle filter (PF), extended Kalman filter (EKF), and unscented Kalman filter (UKF) are widely used for online state estimation (Särkkä, 2013). Typically, these methods assume a known generative model, but unknown parameters can also be learned by including them through expectation maximization (EM) or dual filtering (Haykin, 2002; Wan & Van Der Merwe, 2000; Wan & Nelson, 1997). While the PF can be used to learn the parameters of the dynamics in an online fashion, as in Kantas et al. (2015), it suffers from the well-known issue of "weight degeneracy", limiting its applicability to low-dimensional systems. While methods from the subspace identification literature are frequently employed to estimate the underlying dynamics in an offline setting, they are often limited to linear systems (Buesing et al., 2012). Note that it updates the second most recent state with the most recent observation, so it is technically smoothing rather than filtering; furthermore, the computational complexity of this method can be prohibitive in the online setting, as is evident from Table 2.

5.1. SYNTHETIC DATA AND PERFORMANCE MEASURES

We first evaluate and compare eVKF to other online variational methods as well as classic filtering methods using synthetic data. Since the ground truth is available for synthetic examples, we can measure the goodness of the inferred latent states and the learned dynamical system in reference to the true ones. To measure filtering performance, we use the temporal average log density of the inferred filtering distribution evaluated at the true state trajectory,

    T^{-1} Σ_{t=1}^{T} log q(Z_t; λ*_t)

where λ*_t are the optimal variational parameters of the approximation of the filtering distribution at time t. To assess the learning of the dynamics model, we sample points around the attractor manifold, evolve them one step forward, and calculate the KL divergence to the true dynamics,

    S^{-1} Σ_{i=1}^{S} D_KL(p_{θ*}(z_{t+1} | Z^i_t) || p_θ(z_{t+1} | Z^i_t))

where Z^i_t are the perturbed samples around the attractor manifold (e.g. a stable limit cycle) of the true dynamics, p_{θ*} is the learned distribution over the dynamics, and p_θ is the true distribution over the dynamics. This helps us evaluate the learned dynamics in the vicinity of the attractor, where most samples originate. The above divergence measures only the local structure of the learned dynamical system. To evaluate the global structure, we employ the Chamfer distance (Wu et al., 2021),

    D_CD(S_1 || S_2) = |S_1|^{-1} Σ_{x ∈ S_1} min_{y ∈ S_2} ||x - y||_2 + |S_2|^{-1} Σ_{y ∈ S_2} min_{x ∈ S_1} ||y - x||_2    (22)

where S_1 and S_2 are two distinct sets of points. Usually, this metric is used to evaluate the similarity of point clouds. Intuitively, a low Chamfer distance means that trajectories from the learned dynamics generate a manifold (point cloud) close to that of the true dynamics, a signature that the attractor structure can be reproduced. Since the Chamfer distance is not symmetric, we symmetrize it as D_CD(S_1, S_2) = (1/2)(D_CD(S_1 || S_2) + D_CD(S_2 || S_1)) and take the logarithm.

Chaotic recurrent neural network dynamics.
We first evaluate the filtering performance of eVKF. We consider the chaotic recurrent neural network (CRNN) system used in Campbell et al. (2021) and Zhao et al. (2022),

    p_θ(z_{t+1} | z_t) = N(z_{t+1} | z_t + Δτ^{-1}(γ W tanh(z_t) - z_t), Q)

and vary the latent dimensionality. Since we restrict ourselves to filtering, we fix the model parameters at their true values. In addition to the online variational methods, we also include classical filtering algorithms: the ensemble Kalman filter (enKF) and the bootstrap particle filter (BPF) (Douc et al., 2014).

    METHOD          L = 2            L = 16           L = 32           L = 64
    EVKF (OURS)     0.047 ± 6.4e-4   0.150 ± 5.8e-4   0.250 ± 1.5e-3   0.450 ± 5.8e-3
    OVS             0.103 ± 6.4e-4   0.178 ± 5.8e-4   0.302 ± 1.5e-3   0.323 ± 1.5e-3
    VJF             0.105 ± 2.8e-2   0.288 ± 4.0e-2   0.400 ± 1.1e-2   0.711 ± 4.4e-2
    ENKF (1,000)    0.115 ± 3.3e-3   0.437 ± 6.0e-2   0.619 ± 8.2e-2   0.620 ± 2.8e-2
    BPF (10,000)    0.047 ± 6.7e-4   0.422 ± 9.3e-3   0.877 ± 2.5e-2   1.660 ± 4.2e-2

Table 1: RMSEs of state estimation for chaotic RNN dynamics. We show the mean ± one standard deviation (over 10 trials) of latent state RMSEs. The latent dimensionality L varies from 2 up to 64. Values in parentheses are the ensemble size and the number of particles.

Table 1 shows the RMSEs (mean ± standard deviation over 10 trials of length 250) under increasing latent dimensionality. Surprisingly, eVKF offers competitive performance to the BPF in the 2D case, a regime where the BPF is known to excel. The results show eVKF offers satisfactory results compared to the classic filtering algorithms as well as similar online variational algorithms. We see that OVS performs better in the case L = 64; however, this comes at the cost of significantly higher computational complexity, as shown in Table 2.

Learning nonlinear dynamics. In this experiment we evaluate how well eVKF can learn the dynamics of a nonlinear system that we only have knowledge of through a sequential stream of observations y_1, y_2, and so on.
These observations follow a Poisson likelihood with intensity given by a linear readout of the latent state. For the model of the dynamics we consider a noise-corrupted Van der Pol oscillator, so that the state-space model for this system is given by

    z_{t+1,1} = z_{t,1} + (1/τ_1) Δ z_{t,2} + σ ε
    z_{t+1,2} = z_{t,2} + (1/τ_2) Δ (γ (1 - z_{t,1}²) z_{t,2} - z_{t,1}) + σ ε    (23)
    y_t | z_t ∼ Poisson(y_t | Δ exp(C z_t + b))

where exp(·) is applied element-wise, Δ is the time bin size, and ε ∼ N(0, 1). In order to focus on learning the dynamical system, we fix ψ = {C, b} at the true values, and randomly initialize the parameters of the dynamics model, so that we can evaluate how well eVKF performs filtering and learning of the dynamics. We train each method on 3500 data points, freeze the dynamics model, then infer the filtering posterior for 500 subsequent time steps. In Table 2 we report all metrics, in addition to the average time per step, for both the Poisson and Gaussian likelihood cases. In Figure 1E, we see that eVKF quickly becomes the lowest-RMSE filter and remains so for all 4000 steps. The downside of using Gaussian approximations is most apparent when we look at the Chamfer distance, which is always worse within each method. Note that we do not calculate the KL measure when Gaussian approximations are used. To examine the computational cost, we report the actual run time per step; note that OVS took a multi-fold amount of time per step.

Continuous Bernoulli dynamics. The constant base measure exponential family opens up interesting possibilities for modeling dynamics beyond additive, independent, Gaussian state noise. Such dynamics could be bounded (e.g. Gamma dynamics) or exist over a compact space (e.g. Beta dynamics). In this example, we consider nonlinear dynamics that are conditionally continuous Bernoulli (CB) distributed (Loaiza-Ganem & Cunningham, 2019), i.e.
    p_θ(z_{t+1} | z_t) = Π_i CB(z_{t+1,i} | f_θ(z_t)_i)
    p(y_{n,t} | z_t) = N(y_{n,t} | C_n z_t, r_n²)

where f_θ : [0, 1]^L → [0, 1]^L, and n = 1, ..., N. We choose a factorized variational filtering distribution such that q(z_t) = Π_i CB(z_{t,i} | λ_{t,i}), where λ_{t,i} is the i-th natural parameter at time t. In Fig. 2 we show that eVKF is able to learn an accurate representation of the dynamics underlying the observed data. Fig. 2B also demonstrates that a CB prior over the dynamics is able to generate trajectories much more representative of the true data compared to a Gaussian approximation. These results show that CB dynamics can be a proper modeling choice if, a priori, the dynamics are known to be compact and to exhibit switching-like behavior. In Table 3 we report the performance of eVKF and the other methods on synthetic data generated from the state-space model above, using both CB and Gaussian approximations. Notably, we see that the Chamfer metric is lower within each method when using the CB approximation, showing that even though the true filtering distribution might not be exactly a CB distribution, it is still a good choice.
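Sampling from a continuous Bernoulli is straightforward because its CDF inverts in closed form. A sketch, assuming the standard mean-like parameterization λ ∈ (0, 1) of Loaiza-Ganem & Cunningham (2019), where the density is proportional to λ^x (1-λ)^(1-x) on [0, 1]:

```python
import numpy as np

rng = np.random.default_rng(4)

def sample_cb(lam, size):
    """Inverse-CDF sampling from CB(λ), λ ∈ (0, 1).
    For λ = 1/2 the distribution reduces to Uniform(0, 1)."""
    u = rng.uniform(size=size)
    if abs(lam - 0.5) < 1e-9:
        return u
    c = np.log(lam / (1.0 - lam))
    # CDF F(x) = (e^{c x} - 1) / (e^c - 1), inverted in closed form:
    return np.log1p(u * (2.0 * lam - 1.0) / (1.0 - lam)) / c

x_hi = sample_cb(0.8, 10_000)   # mass pushed toward 1
x_lo = sample_cb(0.2, 10_000)   # mass pushed toward 0
x_mid = sample_cb(0.5, 10_000)  # uniform
```

The saturation toward the boundaries as λ moves away from 1/2 is exactly the switching-like behavior that makes CB transitions attractive for compact latent spaces.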

5.2. ELECTROPHYSIOLOGICAL RECORDING DURING A REACHING TASK

To evaluate eVKF with real-world neural data, we considered electrophysiological recordings taken from monkey motor cortex during a reaching task (Churchland et al., 2012). This dataset has typically been used to evaluate latent variable modeling of neural population activity (Pei et al., 2021). In each trial of the experiment, a target position is presented to the monkey, after which it must wait a randomized amount of time until a "Go" cue, signifying that the monkey should reach toward the target. We first take 250 random trials from the experiment, and use latent states inferred by Gaussian process factor analysis (GPFA) (Yu et al., 2009) to pretrain eVKF's model of the dynamics. Then, we use eVKF to perform filtering and update the dynamics model on a disjoint set of 250 trials.

Figure 2: Continuous Bernoulli dynamics. A) Velocity field for both E(z_t | f_θ(z_{t-1})) and f_θ(z_{t-1}) from the synthetically created continuous Bernoulli dynamics, and those inferred by eVKF. We see that in mean there are limit cycle dynamics, but for the states to actually saturate at the boundary there have to be strong attractor dynamics in parameter space. B) Inferred filtering distributions when using Gaussian approximations compared to continuous Bernoulli approximations; Gaussian distributions are able to infer the latent state well, but they cannot generate similar trajectories, as we see from trajectories propagated forward through the learned dynamics (shaded in gray). The hand position given by the velocity that we linearly decode using eVKF's inferred firing rates. C) Same as previous, but for GPFA. We see that the R² value and decoded hand positions using eVKF are competitive with GPFA. D) Single-trial (thin lines) and condition-average (bold lines) firing rates for select neurons and tasks, aligned to movement onset (demarcated with green dots).
In order to determine whether eVKF learns a useful latent representation, we examine whether the velocity of the monkey's movement can be linearly decoded from the inferred filtering distribution. In Fig. 3B, we show the hand position decoded from the smoothed firing rates inferred by eVKF, alongside the result of GPFA. eVKF is able to achieve competitive performance even though GPFA is a smoothing method. In Fig. 3C, we plot the single-trial firing rates of some neurons over selected reaching conditions, showing that even for single trials, eVKF can recover firing rates decently.

6. CONCLUSION

We tackled the problem of inferring latent trajectories and learning the dynamical system generating them in real time; for Poisson observations, processing took ~10 ms per sample. We proposed a novel online recursive variational Bayesian joint filtering method, eVKF, which allows rich and flexible stochastic state transitions from any constant base measure exponential family, for arbitrary observation distributions. Our two-step variational procedure is analogous to the Kalman filter, and achieves a tighter bound on the ELBO than previous methods. We demonstrated that eVKF performs on par with competitive online variational methods of filtering and parameter learning. For future work, we will focus on extensions to the full exponential family of distributions, characterizing the variance lost in more generality, and improving performance as latent dimensionality is scaled up. Future work will also incorporate learning the parameters of the likelihood, ψ, into eVKF, rather than focusing only on the dynamics model parameters and filtering states.

A PROOF OF THEOREM 1

Theorem 1 (Variational prediction distribution). If p_θ(z_t | z_{t-1}), q(z_{t-1}), and q̄(z_t) are chosen to be in the same minimal and constant base measure exponential family, E_c, then q̄*(z_t) = argmin_{q̄ ∈ E_c} F(q̄) has a closed-form solution given by q̄*(z_t) with natural parameters

    λ̂_θ = E_{q(z_{t-1})}[λ_θ(z_{t-1})]    (11)

Proof: The upper bound we want to minimize is given by

    F = -H(q̄(z_t)) - E_{q̄(z_t)} E_{q(z_{t-1})}[log p_θ(z_t | z_{t-1})]

where H(q̄) is the entropy of q̄. For exponential family distributions, we recall that the negative entropy coincides with the conjugate dual of the log-partition function, -H(q_µ) = A*(µ) (Wainwright & Jordan, 2008).
Then, we have that

    F = -E_{q̄(z_t)} E_{q(z_{t-1})}[λ_θ(z_{t-1})^T t(z_t) - A(λ_θ(z_{t-1}))] - E_{q̄(z_t)}[log h(z_t)] + A*(μ̄)    (27)
      = -E_{q̄(z_t)} E_{q(z_{t-1})}[λ_θ(z_{t-1})^T t(z_t) - A(λ_θ(z_{t-1}))] - log h + A*(μ̄)                    (28)
      = -μ̄^T E_{q(z_{t-1})}[λ_θ(z_{t-1})] + A*(μ̄) + constants                                                  (29)
      = -μ̄^T λ̂_θ + A*(μ̄) + constants                                                                           (30)

where in the second line we use the fact that p_θ(z_t | z_{t-1}) has a constant base measure; in the third line we use the fact that E_{q̄(z_t)}[t(z_t)] = μ̄ and separate out terms that are constant with respect to λ̄; and in the fourth line we use the definition λ̂_θ := E_{q(z_{t-1})}[λ_θ(z_{t-1})]. Since our objective is in terms of the mean parameters of q̄, thanks to the minimality of q̄, we can just consider the optimal variational parameters in their mean parameterization. Then, by considering maximization of -F rather than minimization of F, we can write that the optimal variational parameters satisfy

    μ̄* = argmax_{μ̄} μ̄^T λ̂_θ - A*(μ̄)    (31)

Taking derivatives of the right-hand side and setting them equal to 0, we have

    λ̂_θ - ∇_{μ̄} A*(μ̄*) = 0    (32)
    λ̂_θ - λ̄* = 0              (33)
    λ̄* = λ̂_θ                  (34)

where in the second line we use the fact that λ = ∇_µ A*(µ). As stated in Theorem 1, we have λ̄* = λ̂_θ, as claimed.

B PROOF OF THEOREM 2

Theorem 2 (Tightness of $\mathcal{L}_t$). If we set
$$\Delta(q) = \mathcal{L}_t(q) - \mathcal{M}_t(q) \tag{14}$$
then we have
$$\Delta(q) = \mathbb{E}_{q(z_{t-1})}[A(\lambda_\theta(z_{t-1}))] - A(\bar{\lambda}_\theta) \ge 0 \tag{15}$$
so that
$$\log p(y_t) \ge \mathcal{L}_t(q) \ge \mathcal{M}_t(q) \tag{16}$$

Proof: We write the two ELBOs as
$$\mathcal{L}_t(q) = \mathbb{E}_{q(z_t)}[\log p(y_t \mid z_t)] - D_{KL}(q(z_t)\,\|\,\bar{q}(z_t))$$
$$\mathcal{M}_t(q) = \mathbb{E}_{q(z_t)}[\log p(y_t \mid z_t)] - \mathbb{E}_{q(z_t)}[\log q(z_t)] + \mathbb{E}_{q(z_t)}\mathbb{E}_{q(z_{t-1})}[\log p(z_t \mid z_{t-1})]$$
This means the difference, $\Delta(q) := \mathcal{L}_t(q) - \mathcal{M}_t(q)$, can be written as
$$\begin{aligned}
\Delta(q) &= \mathbb{E}_{q(z_t)}[\log \bar{q}(z_t)] - \mathbb{E}_{q(z_t)}\mathbb{E}_{q(z_{t-1})}[\log p(z_t \mid z_{t-1})]\\
&= \mathbb{E}_{q(z_t)}\big[\log h(z_t) + \bar{\lambda}_\theta^\top t(z_t) - A(\bar{\lambda}_\theta)\big] - \mathbb{E}_{q(z_t)}\mathbb{E}_{q(z_{t-1})}\big[\log h(z_t) + \lambda_\theta(z_{t-1})^\top t(z_t) - A(\lambda_\theta(z_{t-1}))\big]\\
&= \big(\underbrace{\bar{\lambda}_\theta - \mathbb{E}_{q(z_{t-1})}[\lambda_\theta(z_{t-1})]}_{=\,0}\big)^\top \mathbb{E}_{q(z_t)}[t(z_t)] + \mathbb{E}_{q(z_{t-1})}[A(\lambda_\theta(z_{t-1}))] - A(\bar{\lambda}_\theta)\\
&= \mathbb{E}_{q(z_{t-1})}[A(\lambda_\theta(z_{t-1}))] - A(\bar{\lambda}_\theta)
\end{aligned}$$
Invoking Jensen's inequality, and the fact that the log-partition function is convex in its arguments, we can write
$$\mathbb{E}_{q(z_{t-1})}[A(\lambda_\theta(z_{t-1}))] \ge A\big(\mathbb{E}_{q(z_{t-1})}[\lambda_\theta(z_{t-1})]\big) = A(\bar{\lambda}_\theta)$$
which means that $\Delta(q) = \mathbb{E}_{q(z_{t-1})}[A(\lambda_\theta(z_{t-1}))] - A(\bar{\lambda}_\theta) \ge 0$, as claimed.
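To make the gap $\Delta(q)$ tangible, the following sketch (our illustration; the model constants are arbitrary) evaluates it by Monte Carlo for a one-dimensional linear Gaussian model, where $\Delta(q)$ is available in closed form as $a^2 p_0 / (2 Q)$, the prior covariance term dropped by the variational prediction.

```python
import numpy as np

# Linear Gaussian setup: p(z_t | z_{t-1}) = N(a z_{t-1}, q_var),
# q(z_{t-1}) = N(m0, p0).
a, q_var = 0.9, 0.5
m0, p0 = 1.2, 0.3

def A(l1, l2):
    # Gaussian log-partition in natural parameters (l1, l2):
    # A(lambda) = -l1^2 / (4 l2) - 0.5 * log(-2 l2).
    return -l1**2 / (4 * l2) - 0.5 * np.log(-2 * l2)

# lambda_theta(z) = (a z / q_var, -1/(2 q_var))
rng = np.random.default_rng(0)
z = rng.normal(m0, np.sqrt(p0), size=1_000_000)
E_A = A(a * z / q_var, -1 / (2 * q_var)).mean()   # E_q[A(lambda_theta(z))]
A_bar = A(a * m0 / q_var, -1 / (2 * q_var))       # A(lambda_bar)

gap = E_A - A_bar
# Closed form here: gap = a^2 * p0 / (2 * q_var) >= 0, consistent with Jensen.
print(gap)
```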

C PROOF OF PROPOSITION 1

Proposition 1 (Optimal $\theta$). If the mapping from $z_{t-1}$ to the natural parameters of $z_t$, given by $\lambda_\theta(z_{t-1})$, is a universal function approximator with trainable parameters $\theta$, then setting
$$\theta^* = \operatorname{argmin}_\theta \tfrac{1}{2}\|\lambda^* - \bar{\lambda}_\theta\|^2 \tag{19}$$
is equivalent to finding $\theta^* = \operatorname{argmax}_\theta \mathcal{L}_t(\lambda^*, \theta)$.

Proof: Since $\mathcal{L}(\lambda^*, \theta) = \mathbb{E}_{q(z_t)}[\log p(y_t \mid z_t)] - D_{KL}(q(z_t)\,\|\,\bar{q}_\theta(z_t))$, we have $\nabla_\theta \mathcal{L}(\lambda^*, \theta) = -\nabla_\theta D_{KL}(q(z_t)\,\|\,\bar{q}_\theta(z_t))$, so that
$$\begin{aligned}
\nabla_\theta D_{KL}(q(z_t)\,\|\,\bar{q}_\theta(z_t)) &= \nabla_\theta\big(\mathbb{E}_{q(z_t)}[t(z_t)]^\top(\lambda^* - \bar{\lambda}_\theta) + A(\bar{\lambda}_\theta)\big) && (39)\\
&= -[\nabla_\theta \bar{\lambda}_\theta]\,\mu^* + [\nabla_\theta \bar{\lambda}_\theta]\,\nabla_{\bar{\lambda}_\theta} A(\bar{\lambda}_\theta) && (40)\\
&= -[\nabla_\theta \bar{\lambda}_\theta]\,\mu^* + [\nabla_\theta \bar{\lambda}_\theta]\,\bar{\mu}_\theta && (41)\\
&= [\nabla_\theta \bar{\lambda}_\theta](\bar{\mu}_\theta - \mu^*) && (42)
\end{aligned}$$
whereas, for the alternative objective, we have
$$\nabla_\theta \tfrac{1}{2}\|\lambda^* - \bar{\lambda}_\theta\|^2 = [\nabla_\theta \bar{\lambda}_\theta](\bar{\lambda}_\theta - \lambda^*) \tag{43}$$
Assume that $\lambda_\theta(\cdot)$ is a flexible enough function approximator so that $[\nabla_\theta \bar{\lambda}_\theta]$ has full column rank. Then, if the gradient of the KL term is zero, either $[\nabla_\theta \bar{\lambda}_\theta]$ is zero, in which case Eq. (43) is also zero, or $(\bar{\mu}_\theta - \mu^*)$ is zero, which, by continuity and invertibility (under minimality) of the mapping between mean and natural parameters, implies that $(\bar{\lambda}_\theta - \lambda^*)$ is zero. This shows the equivalence of stationary points of the two objectives.
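The equivalence can be checked numerically. In this toy construction of ours (not from the paper), $\theta = (m, \log s^2)$ directly parameterizes the predicted natural parameters of a one-dimensional Gaussian; running gradient descent on the squared natural-parameter distance of Eq. (19) also drives the KL divergence to the target to zero.

```python
import numpy as np

# Target natural parameters lambda* of a 1-D Gaussian N(m*, s2*).
m_star, s2_star = 0.7, 0.4
lam_star = np.array([m_star / s2_star, -1.0 / (2.0 * s2_star)])

def lam_bar(theta):
    # theta = (mean, log variance) -> natural parameters (m/s2, -1/(2 s2)).
    m, s2 = theta[0], np.exp(theta[1])
    return np.array([m / s2, -1.0 / (2.0 * s2)])

def jac(theta):
    # Analytic Jacobian of lam_bar with respect to theta.
    m, s2 = theta[0], np.exp(theta[1])
    return np.array([[1.0 / s2, -m / s2],
                     [0.0,      1.0 / (2.0 * s2)]])

# Gradient descent on 0.5 * ||lam_star - lam_bar(theta)||^2.
theta = np.array([0.0, 0.0])
for _ in range(20000):
    theta -= 0.02 * jac(theta).T @ (lam_bar(theta) - lam_star)

m, s2 = theta[0], np.exp(theta[1])
# KL(N(m*, s2*) || N(m, s2)) should be (near) zero at this stationary point.
kl = 0.5 * (np.log(s2 / s2_star) + (s2_star + (m_star - m)**2) / s2 - 1.0)
print(kl)
```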

D VARIANCE CORRECTION FOR NONLINEAR GAUSSIAN DYNAMICS

For the case of nonlinear Gaussian dynamics, we could consider directly linearizing the dynamics in order to forgo solving a variational problem (e.g., linearizing the dynamics of Eq. 8 about the mean of $z_{t-1}$ to evaluate the expectation directly). Concretely, consider nonlinear Gaussian dynamics specified via $p(z_t \mid z_{t-1}) = N(z_t \mid m_\theta(z_{t-1}), Q)$; assuming we have a variational approximation to the filtering distribution at time $t-1$ given by $q(z_{t-1}) = N(m_{t-1}, P_{t-1})$, then the prediction step could be approximated as
$$\begin{aligned}
p(z_t \mid y_{1:t-1}) &= \mathbb{E}_{q(z_{t-1})}[p_\theta(z_t \mid z_{t-1})] && (44)\\
&= \mathbb{E}_{q(z_{t-1})}[N(z_t \mid m_\theta(z_{t-1}), Q)] && (45)\\
&\approx \mathbb{E}_{q(z_{t-1})}[N(z_t \mid m_\theta(m_{t-1}) + M_{t-1}(z_{t-1} - m_{t-1}), Q)] && (46)\\
&= N(z_t \mid m_\theta(m_{t-1}),\, M_{t-1} P_{t-1} M_{t-1}^\top + Q) && (47)\\
&:= \bar{q}(z_t)
\end{aligned}$$
where $M_{t-1} := \nabla m_\theta(z_{t-1})|_{m_{t-1}}$. This prediction distribution coincides exactly with the one returned by the extended Kalman filter (Särkkä, 2013), as well as the one prescribed by eVKF. Similar procedures to facilitate tractable inference in nonlinear Gaussian models are covered extensively in Kamthe et al. (2022).

Figure 4: Learned phase portraits and inferred distribution of the latent states for the Van der Pol system with Poisson observations (eVKF, OVS, VJF, SVMC). For comparison, we plot the ground truth latent state in black. We see that eVKF infers much smoother latent states than the other methods.
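The linearized prediction of Eqs. (44)-(47) can be sketched as follows (our sketch; the Jacobian is taken by finite differences for simplicity, and the linear-dynamics check exploits the fact that linearization is exact in that case):

```python
import numpy as np

def linearized_predict(m_fn, m, P, Q, eps=1e-6):
    """EKF-style predict: returns (m_theta(m), M P M^T + Q), where M is the
    Jacobian of m_theta at the filtering mean m (finite differences)."""
    d = m.shape[0]
    f0 = m_fn(m)
    M = np.zeros((d, d))
    for i in range(d):
        dm = m.copy()
        dm[i] += eps
        M[:, i] = (m_fn(dm) - f0) / eps
    return f0, M @ P @ M.T + Q

# Sanity check on linear dynamics, where the linearization is exact.
A = np.array([[0.9, 0.1], [0.0, 0.8]])
Q = 0.1 * np.eye(2)
m = np.array([1.0, -0.5])
P = np.array([[0.2, 0.05], [0.05, 0.3]])

m_pred, P_pred = linearized_predict(lambda z: A @ z, m, P, Q)
print(m_pred)   # = A @ m
print(P_pred)   # ≈ A P A^T + Q
```

For a nonlinear `m_fn` (e.g. an MLP mean function), the same call gives the approximation of Eq. (46) without any model-specific derivation.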

E EXPERIMENTAL DETAILS

E.1 VAN DER POL

Data was generated according to Eq. (23) with $\gamma = 1.5$, $\tau_1 = \tau_2 = 0.1$, $\sigma = 0.1$. For the Poisson likelihood example, we can take advantage of CVI as mentioned in the main text. For this model, the expected log-likelihood has an analytical solution,
$$\mathbb{E}_{q(z_t)}[\log p(y_t \mid z_t)] = \sum_n y_{t,n}\,(C_n m_t + b_n) - \Delta \exp\big(C_n m_t + b_n + \tfrac{1}{2} C_n P_t C_n^\top\big) + \text{const.}$$
where $C_n$ is the $n$-th row of $C$ and $\tilde{C}_n = C_n^\top C_n$. To parameterize the dynamics, $p_\theta(z_t \mid z_{t-1})$, we use a single-layer MLP with 32 hidden units and SiLU (Elfwing et al., 2018) nonlinearity. We use Adam and update our dynamics model every 100 time steps. For this example, the synthetic data is a sequence of length 500.
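The closed-form Poisson expected log-likelihood above rests on the Gaussian moment-generating identity $\mathbb{E}[\exp(c^\top z)] = \exp(c^\top m + \tfrac{1}{2} c^\top P c)$ for $z \sim N(m, P)$. A quick Monte Carlo check of that identity (our illustration; the vectors play the role of a hypothetical loading row $C_n$ and posterior moments):

```python
import numpy as np

rng = np.random.default_rng(0)
m = np.array([0.5, -1.0])                      # posterior mean
P = np.array([[0.3, 0.1], [0.1, 0.2]])         # posterior covariance
c = np.array([0.8, 0.4])                       # hypothetical loading row

z = rng.multivariate_normal(m, P, size=1_000_000)
mc = np.exp(z @ c).mean()                      # Monte Carlo estimate
closed = np.exp(c @ m + 0.5 * c @ P @ c)       # log-normal mean identity
print(mc, closed)   # should agree to a few decimal places
```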

E.3 EVKF ALGORITHM

Below we present the algorithm for performing inference with eVKF. Instead of updating $\theta$ at every data point, we could accumulate gradients for a fixed number of steps, so that the variance of the gradient steps is reduced.

Algorithm 1 eVKF
Input: $y_t \in \mathbb{R}^N$, $\theta$ (dynamics parameters)
for each $y_t$, or until done, do
    $\bar{\lambda}_t \leftarrow \mathbb{E}_{q(z_{t-1})}[\lambda_\theta(z_{t-1})]$   (predict)
    $\lambda_t \leftarrow \operatorname{argmax}_{\lambda} \; \mathbb{E}_{q(z_t;\lambda)}[\log p(y_t \mid z_t)] - D_{KL}\big(q(z_t;\lambda)\,\|\,\bar{q}(z_t;\bar{\lambda}_t)\big)$   (update)
    $\ell_t \leftarrow \|\lambda_t - \mathbb{E}_{q(z_{t-1})}[\lambda_\theta(z_{t-1})]\|_2^2$
    $\theta_t \leftarrow \theta_{t-1} - \nabla_\theta \ell_t$
end for
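A minimal sketch of the filtering loop above for a one-dimensional linear Gaussian model with known dynamics (our simplification: with a Gaussian likelihood the "update" step is conjugate, so natural parameters simply add and no inner optimization is needed):

```python
import numpy as np

a, q_var, r_var = 0.95, 0.1, 0.2   # dynamics, transition noise, obs noise
rng = np.random.default_rng(0)

# Simulate a latent AR(1) trajectory and noisy observations.
T = 200
z = np.zeros(T)
for t in range(1, T):
    z[t] = a * z[t - 1] + rng.normal(0.0, np.sqrt(q_var))
y = z + rng.normal(0.0, np.sqrt(r_var), size=T)

m, s2 = 0.0, 1.0
for t in range(T):
    # Predict: lambda_bar = E_q[lambda_theta(z_{t-1})] -> N(a*m, q_var).
    # (Note the prediction variance is q_var, not a^2 s2 + q_var: the
    # variance "lost" per Theorem 2, which the correction in the appendix
    # restores.)
    m_pred, s2_pred = a * m, q_var
    # Update: add the Gaussian likelihood's natural parameters.
    l1 = m_pred / s2_pred + y[t] / r_var
    l2 = -0.5 / s2_pred - 0.5 / r_var
    s2 = -0.5 / l2
    m = l1 * s2

# With this uncorrected predict step, the posterior precision is the same
# every step: 1/s2 = 1/q_var + 1/r_var.
print(s2)
```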

E.4 EXAMPLE OF GAMMA DYNAMICS

We consider synthetic data where observations are conditionally Poisson and dynamic transitions are Gamma distributed, so that the state-space model description is
$$\begin{aligned}
p(z_{t+1} \mid z_t) &= \mathrm{Gamma}\big(z_{t+1} \mid b_0 f(z_t)^2,\; b_0 f(z_t)\big) && (51)\\
p(y_t \mid z_t) &= \mathrm{Poisson}\big(y_t \mid \Delta \exp(C z_t + b)\big) && (52)
\end{aligned}$$
Similar to a standard Gaussian dynamical system, under this specification we have $\mathbb{E}_{p(z_{t+1} \mid z_t)}[z_{t+1}] = f(z_t)$, like the gamma dynamical system presented in Schein et al. (2016); unlike that work, however, both $\alpha$ and $\beta$ are functions of $z_t$, so that the variance is constant. We choose a variational approximation that factors as a product of Gamma distributions, $q(z_t) = \prod_i \mathrm{Gamma}(z_{t,i} \mid \alpha_i, \beta_i)$. Since we use the canonical link function, the expected log-likelihood can be calculated in closed form,
$$\mathbb{E}_{q(z_t)}[\log p(y_t \mid z_t)] = \sum_n (C_n m_t + b_n)\, y_{t,n} - \Delta \exp(b_n) \prod_{l=1}^L \Big(1 - \frac{C_{n,l}}{\beta_{t,l}}\Big)^{-\alpha_{t,l}} + \text{const.}$$
which allows us to use CVI for inference. We take $f(\cdot)$ to be a 64-hidden-unit MLP with SiLU nonlinearity and softplus output. To see how to correct for the underestimation of variance, the notation is less cluttered if our discussion is in terms of means and variances; so, let $z_{t+1} \sim \bar{q}(z_{t+1})$ have mean $\bar{m}_{\theta,t+1}$. Then the corrected variance, $\bar{s}^2_{\theta,t+1}$, should be equal to the prior transition variance (i.e. $1/b_0$) plus the correction term, so that
$$\bar{s}^2_{\theta,t+1} = 1/b_0 + \big(s_t \nabla \bar{m}_{\theta,t+1}\big)^2$$
where $s_t$ is the standard deviation of $z_t \sim q(z_t)$. As shown in Figure 6, without this correction the quality of inference is noticeably worse.

Figure 6: On the left: phase portrait and filtered latent states of a system with gamma-distributed transitions, using the prediction step variance correction. On the right: same as the left, but without the prediction step correction.
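The closed-form expected log-likelihood above uses the Gamma moment-generating function: for $z \sim \mathrm{Gamma}(\alpha, \beta)$ (rate parameterization) and $c < \beta$, $\mathbb{E}[\exp(c z)] = (1 - c/\beta)^{-\alpha}$. A Monte Carlo check of this identity (our illustration; the constants are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, beta, c = 3.0, 2.0, 0.5   # shape, rate, loading (must satisfy c < beta)

# NumPy's gamma sampler uses a scale parameter, so scale = 1 / rate.
z = rng.gamma(shape=alpha, scale=1.0 / beta, size=1_000_000)
mc = np.exp(c * z).mean()                  # Monte Carlo estimate
closed = (1.0 - c / beta) ** (-alpha)      # Gamma MGF
print(mc, closed)
```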



Minimality means that all sufficient statistics are linearly independent.



Figure 1: Van der Pol oscillator with Poisson observations. A) The filtering distribution inferred by eVKF over time; shading indicates the 95% credible interval. B) Zoomed-in view of the beginning observations. We plot the mean, and trajectories evolved 5 steps ahead from the filtered mean using a "snapshot" of the dynamics at that time; their ending positions are marked by ×'s. C) Same as before, but at the final observations; eVKF has learned the dynamics, leading to better filtering capabilities. D) True Van der Pol velocity field compared to the dynamics inferred by eVKF. E) Moving-average RMSE of the filtering mean to the true dynamics, averaged over 10 trials; error bars indicate two standard errors.

Figure 3: A) True hand movements from fixation point to target. B) The hand position given by the velocity that we linearly decode using eVKF's inferred firing rates. C) Same as previous, but for GPFA. We see that the R² value and decoded hand positions using eVKF are competitive with GPFA. D) Single-trial (thin lines) and condition-average (bold lines) firing rates for select neurons and tasks, aligned to movement onset (demarcated with green dots).

Figure 5: Learned phase portraits and inferred distribution of the latent states for the continuous Bernoulli example. The open source code for running OVS was not immediately compatible with non-Gaussian approximations. Top: portraits and filtered latent states when approximations are constrained to be Gaussian. Bottom: same as top, but for approximations constrained to be CB.

Marino et al. (2018) and Zhao & Park (2020) (VJF), in contrast to eVKF, perform a single-step approximation at each time instant, which leads to a provably looser bound on the ELBO, as stated in Theorem 2. Zhao et al. (2022) (SVMC) use particle filtering to infer the filtering distribution and derive a surrogate ELBO for parameter learning, but because of weight degeneracy this method is hard to scale to higher-dimensional SSMs. Campbell et al. (2021) (OVS) use a backward factorization of the joint posterior.

Metrics of inference for Van der Pol dynamics. We report the log-likelihood of the ground truth under the inferred filtering distributions, the KL divergence of one-step transitions, the log symmetric Chamfer distance from trajectories drawn from the learned prior to trajectories realized from the true system, and the computation time per time step. SVMC uses 5000 particles.

Metrics of inference for continuous Bernoulli dynamics. We use both CB and Gaussian approximations for the methods that are applicable. eVKF achieves the highest log-likelihood of latent trajectories, lowest KL-divergence of the learned dynamics, and lowest Chamfer distance.

During training we use Adam (Kingma & Ba, 2014), and update the dynamics every 150 time steps. In total, we use 3500 time points for training the dynamics model for all methods. For measuring the time per step, as in Table 2, the experiments were run on a computer with an Intel Xeon E5-2690 CPU at 2.60 GHz.

E.2 CONTINUOUS BERNOULLI

For the continuous Bernoulli example, we can take advantage of CVI as mentioned in the main text. For this, we require derivatives of the expected log-likelihood, which for a Gaussian likelihood, $p(y_{t,n} \mid z_t)$,

ACKNOWLEDGEMENTS

MD and IP were supported by an NSF CAREER Award (IIS-1845836) and NIH RF1DA056404. YZ was supported in part by the National Institute of Mental Health Intramural Research Program (ZIC-MH002968). We thank the anonymous reviewers for their helpful feedback and comments, and Josue Nassar for helpful suggestions for improving the manuscript.

