DYNAMIC RELATIONAL INFERENCE IN MULTI-AGENT TRAJECTORIES

Abstract

Unsupervised learning of interactions from multi-agent trajectories has broad applications in physics, vision, and robotics. However, existing neural relational inference works are limited to static relations. In this paper, we consider a more general setting of dynamic relational inference, where interactions change over time. We propose the DYnamic multi-Agent Relational Inference (DYARI) model, a deep generative model that can reason about dynamic relations. Using a simulated physics system, we study various dynamic relation scenarios, including periodic and additive dynamics. We perform a comprehensive study of the trade-off between the dynamic and inference periods, and of the impact of the training scheme and model architecture on dynamic relational inference accuracy. We also showcase an application of our model to inferring coordination and competition patterns from real-world multi-agent basketball trajectories.

1. INTRODUCTION

Particles, friends, and teams are multi-agent relations at different scales. Learning multi-agent interactions is essential to our understanding of the structures and dynamics underlying many systems. Practical examples include understanding social dynamics among pedestrians.



Figure 1. Physical simulation of 2D particles coupled by invisible springs (left) according to a latent interaction graph (right). In this example, solid lines between two particle nodes denote connections via springs, whereas dashed lines denote the absence of a coupling. In general, multiple, directed edge types, each with a different associated relation, are possible.

A useful model should be able to reason about the different types of interactions that might arise, e.g. defending a player or setting a screen for a teammate. It might be feasible, though tedious, to manually annotate certain interactions given a task of interest. It is more promising to learn the underlying interactions, perhaps shared across many tasks, in an unsupervised fashion. Recently there has been a considerable amount of work in this direction (Kipf et al., 2018). Relational inference aims to discover hidden interactions from data and has been studied for decades. Statistical relational learning is based on probabilistic graphical models such as probabilistic relational models (Kemp & Tenenbaum, 2008; Getoor et al., 2001; Koller et al., 2007; Shum et al., 2019). However, these methods may require significant feature engineering and incur high computational costs. Recently, Battaglia et al. (2016) and Santoro et al. (2017) proposed to reason about relations using graph neural networks, but these still require supervision. One exception is Neural Relational Inference (NRI) (Kipf et al., 2018), a flexible deep generative model that can infer potential relations in an unsupervised fashion. As shown in Figure 1, NRI simultaneously learns the dynamics from multi-agent trajectories and infers their relations. In particular, NRI builds upon the variational auto-encoder (VAE) (Kingma & Welling, 2013) and introduces latent variables to represent the hidden relations. Despite its flexibility, a major limiting factor of NRI is that it assumes the relations among the agents are static.
That is, two agents are either always interacting or never interacting, regardless of their states at different time steps, which is rather restrictive. In this paper, we study a more realistic setting: dynamic relational inference. For example, in game plays, players can coordinate and compete dynamically depending on the strategy. We propose a novel deep generative model, which we call DYnamic multi-Agent Relational Inference (DYARI). DYARI encodes trajectory interactions at different time steps. It utilizes deep temporal CNN models with pyramid pooling to extract rich representations from the interactions. DYARI infers the relations for each sub-sequence dynamically and jointly decodes a sequence of relations. As relational inference is unsupervised, we use simulated physics systems with known dynamics as ground truth for validation. We find that the performance of the static NRI model deteriorates significantly with shorter output trajectories, making it unsuitable for dynamic relational inference. In contrast, DYARI is able to accurately infer the hidden relations under various dynamics scenarios. We also perform an extensive ablative study to understand the effect of the inference period, training schemes, and model architecture. Finally, we showcase our DYARI model on real-world basketball trajectories. In summary, our contributions include:
• We tackle the challenging problem of unsupervised learning of hidden dynamic relations given multi-agent trajectories.
• We develop a novel deep generative model called DYARI to handle time-varying interactions and predict a sequence of hidden relations in an end-to-end fashion.
• We demonstrate the effectiveness of our method on both simulated physics dynamics and a real-world basketball game play dataset.

2. RELATED WORK

Deep sequence models. Deep sequence models include both deterministic models (Alahi et al., 2016; Li et al., 2019; Mittal et al., 2020) and stochastic models (Chung et al., 2015; Fraccaro et al., 2016; Krishnan et al., 2017; Rangapuram et al., 2018; Chen et al., 2018; Huang et al., 2018; Yoon et al., 2019). Yoon et al. (2019) propose a GAN-like model for time series, and Huang et al. (2018) combine normalizing flows with autoregressive models. However, all existing models only capture the temporal latent states of individual sequences rather than their interactions.

Relational inference. Graph neural networks (GNNs) seek to learn representations over relational data; see several recent surveys on GNNs and the references therein, e.g. (Wu et al., 2019; Goyal & Ferrara, 2018). Unfortunately, most existing works assume the graph structure is observed and train with supervised learning. In contrast, relational inference aims to discover the hidden interactions and is unsupervised. Earlier work in relational reasoning (Koller et al., 2007) uses probabilistic graphical models but requires significant feature engineering. The seminal work of NRI (Kipf et al., 2018) uses neural networks to reason about dynamic physical systems. Alet et al. (2019) reformulate NRI as meta-learning and propose simulated annealing to search for graph structures. Relational inference has also been posed as Granger causal inference for sequences (Louizos et al., 2017; Löwe et al., 2020). Nevertheless, all existing works are limited to static relations, while we focus on dynamic relations.

Multi-agent learning. Multi-agent trajectories arise frequently in reinforcement learning (RL) and imitation learning (IL) (Albrecht & Stone, 2018; Jaderberg et al., 2019). Modeling agent interactions given dynamic observations from the environment remains a central topic. In the RL setting, for example, Sukhbaatar et al. (2016) model the control policy in a fully cooperative multi-agent setting and apply a GNN to represent the communications. Le et al. (2017) model the agents' coordination as a latent variable for imitation learning. Song et al. (2018) generalize GAIL (Ho & Ermon, 2016) to the multi-agent setting through a shared generator. However, these coordination models only capture the global interactions implicitly, without an explicit graph structure. Tacchetti et al. (2019) combine a GNN with a forward dynamics model to model multi-agent coordination, but also require supervision. Grover et al. (2018) directly model episodes of interaction data with graph networks for learning multi-agent policies. Our method instantiates the multi-agent imitation learning framework but focuses on relational inference. Our approach is also applicable to dynamics modeling in model-based RL.

3. DYNAMIC MULTI-AGENT RELATIONAL INFERENCE

Given a collection of multi-agent trajectories, we aim to reason about their hidden relations over time. First we describe the underlying probabilistic inference problem.

3.1. PROBABILISTIC INFERENCE FORMULATION

For each agent i ∈ {1, ..., N}, define its state (coordinates) at time t as x_t^{(i)} ∈ R^D. A trajectory τ^{(i)} = (x_1^{(i)}, x_2^{(i)}, ..., x_T^{(i)}) is a sequence of states sampled from a policy. Given trajectories from N agents {τ^{(i)}}_{i=1}^{N}, dynamic relational inference aims to infer the pairwise interactions of the N agents at every time step. We use the bold form x_t := (x_t^{(1)}, ..., x_t^{(N)}) to denote the concatenation of all agents' observations, and x_{<t} := (x_1, ..., x_t). Mathematically, the joint distribution of the trajectories can be written as:

p(τ^{(1)}, ..., τ^{(N)}) = ∏_{t=1}^{T-1} p(x_{t+1} | x_t, ..., x_1),

where p(x_{t+1} | x_t, ..., x_1) represents the state transition dynamics. We introduce latent variables z_t^{(ij)} to denote the interaction between agents i and j at time t. To make the problem tractable, we restrict z_t^{(ij)} to be categorical, representing discrete interactions such as coordination or competition. We assume that the dynamics model can be decomposed into the individual dynamics in conjunction with the pairwise interactions. This substantially reduces the dimensionality of the distribution and simplifies learning. Therefore, we can rewrite the transition dynamics as:

p(x_{t+1} | x_{<t}) ≈ ∫_z ∏_{i=1}^{N} p(x_{t+1}^{(i)} | x_{<t}^{(i)}, z_t^{(ij)}) ∏_{i=1}^{N} ∏_{j=1, j≠i}^{N} p(z_t^{(ij)} | x_{<t}^{(i)}, x_{<t}^{(j)}) dz.

Here each p(x_{t+1}^{(i)} | x_{<t}^{(i)}, z_t^{(ij)}) captures the state transition dynamics of a single agent, and p(z_t^{(ij)} | x_{<t}^{(i)}, x_{<t}^{(j)}) represents the latent interaction between two agents. Figure 2 visualizes the graphical model representation for three agents over t time steps, with observed states x_t^{(i)} as shaded nodes and latent relation variables z_t^{(ij)} as unshaded nodes. Dynamic relational inference amounts to estimating the distributions of the hidden variables {z_t^{(ij)}} at different time steps.
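The factorization above can be illustrated with a small numerical sketch. The code below is our toy example, not the paper's implementation: per-agent dynamics plus spring-like pairwise terms gated by a binary relation matrix standing in for the categorical z_t^{(ij)}. The constants k and dt are arbitrary choices.

```python
import numpy as np

def step(x, z, k=0.1, dt=0.1):
    """One transition x_t -> x_{t+1}.

    x : (N, D) array of agent states (positions).
    z : (N, N) binary relation matrix; z[i, j] = 1 means agents i and j
        interact (e.g. are coupled by a spring) at this time step.
    The update decomposes into individual dynamics plus pairwise terms,
    mirroring the factorized transition p(x_{t+1} | x_t, z_t).
    """
    force = np.zeros_like(x)
    N = x.shape[0]
    for i in range(N):
        for j in range(N):
            if i != j and z[i, j]:
                force[i] += k * (x[j] - x[i])  # spring pulls i toward j
    return x + dt * force                      # identity self-dynamics + interactions

x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
z = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]])  # only agents 0 and 1 coupled
x_next = step(x, z)
```

Because agent 2 has no active relations, its state is unchanged, while agents 0 and 1 move toward each other.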

3.2. DYNAMIC MULTI-AGENT RELATIONAL INFERENCE (DYARI)

We propose a deep generative model: Dynamic multi-Agent Relational Inference (DYARI). Given the trajectories (x 1 , • • • , x T ) of all agents, DYARI first concatenates the trajectories based on a fully connected graph. The concatenated trajectories are used as interaction features for the encoder. Then we sample the sequence of relations from the encoded hidden states. Finally, we generate the future trajectory predictions conditioned on the sampled relations. Figure 3 visualizes the overall architecture of our model which encodes and decodes multi-agent trajectories. The bottom cut-out diagram shows the architecture of our encoder.

Encoder.

A key ingredient of DYARI is an encoder, inspired by PSPNet (Zhao et al., 2017), that learns rich representations of trajectories at different scales. In particular, we define a residual block as a two-layer CNN with residual connections (He et al., 2016). Our encoder has four modules: feature extraction, pyramid pooling, an aggregation module, and an interpolation module.
• Feature extraction: the feature extraction module consists of multiple residual blocks interleaved with pooling layers to extract rich temporal features.
• Pyramid pooling: the pyramid pooling module learns multi-scale temporal representations from the extracted features. First, the output of the feature extraction module is downsampled by 2x and 5x through average pooling. Then, the downsampled features are passed through two residual blocks and finally upsampled by 2x and 5x to generate features of the same size as the input. The representations learned at the 2x and 5x resolutions are concatenated with the input to generate composite multi-scale features.
• Aggregation module: a 1-D convolution that aggregates the multi-scale features.
• Interpolation module: it average-pools the aggregated features according to the dynamic period. The outputs are then upsampled through nearest-neighbour interpolation to obtain the hidden representations for the relations.

Sampling. We utilize variational inference (Kingma & Welling, 2013) to sample the latent variables from the hidden representations. Specifically, we assume the interaction posterior over z_t^{(ij)} to be categorical:

q_φ(z_t^{(ij)} | x_{<t}) ∼ Cat(p_1, ..., p_k).

Using the Gumbel-Max trick (Jang et al., 2017), we can reparameterize the categorical distribution as:

z_t^{(ij)} = Softmax(h_t^{(ij)} + g_t^{(ij)}).

Here h_t^{(ij)} is the hidden state of the encoder and g_t^{(ij)} is a random Gumbel vector. Note that a defining feature of DYARI is that the latent variable z_t^{(ij)} is time-dependent, requiring fine-grained modeling. Our encoder ensures that the learned representations are expressive enough to capture such complex dynamics.

Decoder.
Given the sampled latent variables, the decoder generates the predictions auto-regressively following a Gaussian distribution:

p(x_{t+1}^{(i)} | x_{<t}, z_t^{(ij)}) = N(x_{t+1}^{(i)} | μ_{t+1}^{(i)}, σ^2 I),    (3)

μ_{t+1}^{(i)} = f_dec(Σ_{j≠i} Σ_k z_{t,k}^{(ij)} u_k ; θ),    u_k = f_mlp^k(x_t^{(i)}, x_t^{(j)}).    (4)

Here the output x_{t+1}^{(i)} is reparameterized by a Gaussian distribution with mean μ_{t+1}^{(i)} and a fixed variance σ^2. The mean vector μ_{t+1}^{(i)} of agent i is computed by aggregating the hidden states of all other agents. We use a separate MLP f_mlp^k to encode the previous inputs for each of the k edge types, gated by the one-hot relation vector z_t^{(ij)}. To generate long-term predictions using the model in Eqn. (4), we can also feed in the predictions from the previous time step. The decoder architecture is the same as in NRI at a given time step: it consists of message-passing GNN operations followed by a GRU (Cho et al., 2014) decoder.

Inference. At every time step t, we learn a different distribution for the hidden relations z_t. We assume a uniform prior p_θ(z_t) and use the ELBO as the optimization objective:

L_ELBO = E_{q_φ}[log p_θ(x_{<T} | z_{<T})] - β KL[q_φ(z_{<T} | x_{<T}) || p_θ(z_{<T})]    (5)
       = - Σ_{i=1}^{N} Σ_{t=1}^{T} (μ_t^{(i)} - x_t^{(i)})^2 / (2σ^2) + β Σ_{i,j} Σ_{t=1}^{T} H(q_φ(z_t^{(ij)} | x_{<t})),

where the mean vector μ_t^{(i)} is parameterized by the decoder, H is the entropy function, and β balances the two terms in the ELBO (Higgins et al., 2016).
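The Gumbel-Softmax sampling step used above can be sketched numerically. This is an illustrative re-implementation: the temperature tau (as in Jang et al., 2017), the random seed, and the example logits are our choices, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax(h, tau=1.0):
    """Differentiable approximate sample from a categorical with logits h.

    g is a vector of i.i.d. Gumbel(0, 1) noise; softmax((h + g) / tau)
    approaches a one-hot sample as tau -> 0.
    """
    g = -np.log(-np.log(rng.uniform(size=np.shape(h))))  # Gumbel(0, 1) noise
    y = (np.asarray(h) + g) / tau
    y = y - y.max()                                      # numerical stability
    e = np.exp(y)
    return e / e.sum()

z = gumbel_softmax(np.array([2.0, 0.0, -1.0]))           # relaxed one-hot over 3 edge types
```

The output is a point on the probability simplex, so it can gate the k edge-type MLPs in the decoder while keeping the whole model differentiable.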

4. EXPERIMENTS

We conduct extensive experiments on simulated physics dynamics and real-world basketball trajectories. The majority of our experiments are based on the physics simulation in the Spring environment, which is ideal for model verification and ablative studies because we know the ground-truth relations.

Data Generation. The Spring environment (Kipf et al., 2018) simulates the movements of a group of particles connected by springs in a box. The hidden relation is whether there is a spring connecting two particles. To simulate dynamic relations, we generate the trajectories by removing and adding back the springs following certain patterns. Figure 4 visualizes the trajectories resulting from such dynamic relations: starting from the bottom, the two-particle trajectories appear as straight lines, bend in the middle due to the spring force, and return to straight lines after the removal of the spring. We define the number of time steps between changes of relations as the dynamic period. The primary challenges for dynamic relational inference arise along two dimensions:
1. The shorter the dynamic period, the more frequently the relations change. Hence, it becomes more difficult to infer relations with a shorter dynamic period.
2. If the dynamic period itself also changes, the task becomes much harder because the model also needs to adapt to the unknown period while the relations change.
Note that the way relations change in the trajectories must follow a certain pattern and not be completely random; otherwise it would be impossible to learn anything meaningful. We experiment with two types of dynamic relations: periodic dynamics and additive dynamics. For periodic dynamics, we generate the trajectories by periodically removing and adding back the springs, and we investigate model performance by generating data with different periods. For additive dynamics, we assume the dynamic period increases arithmetically.
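The two relation schedules can be sketched as on/off flags per time step. The function names and defaults below are ours, chosen so that the additive schedule matches the setup described in the experiments (period growing in steps of 4 from 4):

```python
def periodic_schedule(T, period):
    """1 = spring present, 0 = spring removed; the flag flips every `period` steps."""
    return [1 - (t // period) % 2 for t in range(T)]

def additive_schedule(T, start=4, step=4):
    """Dynamic period grows arithmetically: with start=step=4 and T=40,
    the relation flips at t = 4, 12 and 24."""
    flags, flag, t, period = [], 1, 0, start
    while t < T:
        flags.extend([flag] * min(period, T - t))
        t += period
        period += step
        flag = 1 - flag
    return flags
```

For example, `periodic_schedule(8, 2)` toggles the spring every two steps, and `additive_schedule(40)` holds each relation for 4, 8, 12, then 16 steps.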
Each trajectory is of length 50 and the decoding length is 40; see details of the generated dataset in the Appendix.

Baselines and Setup

We consider several baselines for comparison: (1) NRI (static): unsupervised NRI with an encoder trained using the entire trajectory, with inference repeated over time. This corresponds to NRI (learned) in (Kipf et al., 2018). (2) NRI (adaptive): NRI (static) with an encoder trained over sub-trajectories, where the encoding length corresponds exactly to the dynamic period of the dataset. We use the NRI (static) decoder to predict the entire trajectory in an auto-regressive fashion. (3) Interaction Networks (IN) (Battaglia et al., 2016): a supervised GNN model which uses the ground-truth relations to predict future trajectories. We include this supervised model as the "gold standard" for our inference tasks. It is important to note that Graber & Schwing (2020) also propose a model, dNRI, for this problem, but the focus of their work is trajectory prediction, whereas we focus on unsupervised relational inference. In our experiments with their model, we observed a relational inference accuracy of 0.505 on our periodic data with dynamic period 20. On the other hand, the same model gives an accuracy of 0.66 on the 3-particle synthetic data presented in their paper. Therefore, dNRI is unable to infer relations in an unsupervised setting.

In practice, we do not know the dynamic period beforehand. Therefore, how often we infer the relations is a difficult choice: infrequent inference can miss the time steps where relations change, while inferring too frequently introduces more latent variables and complicates the inference. To investigate this trade-off, we define the inference period as the number of time steps between two predicted relations. Unless otherwise noted, the inference period in our experiments is the same as the dynamic period. All models are trained to predict the sequence in an auto-regressive fashion: the prediction of the current time step is fed as the input to the next time step.
We use the Adam (Kingma & Ba, 2014) optimizer with learning rate 5e-4 and weight decay 1e-4, and train for 300 epochs.

4.1.1. PATHOLOGICAL CASES OF NEURAL RELATIONAL INFERENCE

It is known that latent variable models suffer from the problem of identifiability (Koopmans & Reiersol, 1950), meaning certain parameters, in principle, cannot be estimated consistently. NRI infers correlation-like relations between trajectories, which depend heavily on the length of the time lag. To test this hypothesis, we follow the exact same setting as Kipf et al. (2018) to infer the interaction graph, but instead of decoding 50 time steps, we vary the lengths of the input and output sequences. Table 1 summarizes the inference accuracy with different sequence lengths in the encoder and decoder. We can see that the performance of NRI deteriorates drastically with shorter training sequences; simply increasing the capacity of the encoder (NRI++) does not help. One plausible explanation is that NRI is learning correlation-like interactions: shorter decoding sequences carry less information about correlations, making them harder to learn. Meanwhile, we also observed that auto-regressive training achieves better inference accuracy than teacher forcing. These pathological cases highlight the issue of NRI for dynamic relational inference: if the interactions change every few time steps, repeatedly applying NRI to different time steps suffers from short decoding sequences. Therefore, a model that can jointly infer a sequence of relations is critical.

Table 1: Inference accuracy (%) of NRI trained with different trajectory lengths. Note that the performance deteriorates significantly when the output length decreases. For NRI++, we added two more hidden layers to the MLP encoder of NRI.

4.1.2. DYNAMIC RELATIONAL INFERENCE

We compare the performance of different models on dynamic relational inference tasks.

Periodic dynamics. In the periodic scenario, the dynamic period is fixed. We generate four datasets with dynamic periods of 40, 20, 8, and 4 to simulate relational dynamics with increasing frequency. The columns "40, 20, 8, 4" of Table 2 show the trajectory prediction mean square error (MSE) and the interaction inference accuracy of the different methods. All methods achieve almost perfect predictions of the trajectories with very low MSE. However, there is a sharp difference in relational inference accuracy: NRI (static) is incapable of learning dynamic interactions, and NRI (adaptive) can learn but has lower accuracy due to short decoding sequences. With a more expressive encoder and joint decoding, DYARI reaches higher accuracy. When the dynamic period is very small, at 4, even DYARI struggles slightly, suggesting a fundamental difficulty with frequently changing dynamics.

Table 2: Trajectory prediction MSE (left) and relational inference accuracy (right) for dynamic periods 40, 20, 8, 4 and the additive ("Add") setting.

                 MSE                                          Accuracy
                 40       20       8        4        Add      40     20     8      4      Add
NRI (static)     2.2e-4   5.2e-3   2.7e-3   2.4e-3   3.6e-3   0.99   0.52   0.51   0.50   0.53
NRI (adaptive)   2.2e-4   2.7e-3   1.3e-3   5.9e-4   3.1e-3   0.99   0.81   0.80   0.69   0.81
DYARI            2.6e-5   4.1e-5   4.6e-6   3.6e-6   7.6e-6   0.99   0.92   0.91   0.74   0.87
IN               2.9e-5   2.3e-5   4.3e-5   4.7e-5   3.9e-5   0.99   0.99   0.99   0.98   0.99

Additive dynamics. In the additive scenario, we allow the dynamic period itself to increase arithmetically. We increase the dynamic period in steps of 4, starting from a dynamic period of 4 timesteps. In a sequence of 40 timesteps, this implies that the relations (spring connections) get flipped at timesteps 4, 12 and 24. We use four NRI (static) models, each trained separately with 4, 8, 12, and 16 encoding timesteps, and combine the ensemble model predictions into NRI (adaptive). The "Add" columns of Table 2 show the performance comparison. As in the periodic scenario, DYARI outperforms the baselines on this challenging task as well. Note that NRI (adaptive) is a close competitor w.r.t. inference accuracy, but it is a four-model ensemble and takes a long time to train.
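The accuracy metric compared here can be made explicit with a short sketch. This is our illustrative re-implementation of the evaluation, not the paper's code: it scores a predicted relation sequence against the ground-truth schedule entry by entry.

```python
import numpy as np

def relation_accuracy(pred, true):
    """Fraction of (pair, time) entries where the predicted relation
    matches the ground truth, i.e. the inference accuracy reported above."""
    pred, true = np.asarray(pred), np.asarray(true)
    return float((pred == true).mean())

# Toy example: one particle pair over 8 steps with dynamic period 4.
true = [1, 1, 1, 1, 0, 0, 0, 0]
pred = [1, 1, 1, 1, 0, 0, 1, 0]   # one mistake shortly after the relation flips
acc = relation_accuracy(pred, true)
```

Errors concentrate around the flip points, which is why shorter dynamic periods (more flips per sequence) drive the accuracy down.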

4.1.3. ABLATIVE STUDY

We perform ablative study to further validate our experiment design and understand the behavior of DYARI. In particular, we study the trade-off between dynamic and inference period, the effect of training scheme, as well as the ablative study of model architecture design. In Table 3 , we observe that dynamic relational inference reaches the highest accuracy when the inference period matches the dynamic period. If the inference period is longer than the dynamic period, the model can miss the changes in the relations and completely fails to perform inference. Meanwhile, if the inference period is shorter than the dynamic period, the model still can learn but suffers from low accuracy. This is potentially due to the extra uncertainty introduced by estimating more latent variables. Decoder Training Scheme Another fundamental challenge in sequence prediction is covariate shift (Bickel et al., 2009 ) -a mismatch between distribution in training and testing -due to sequential dependency. Common solutions to mitigate covariate shift include teacher forcing (Williams & Zipser, 1989 ) and scheduled sampling (Bengio et al., 2015) . However, all these work are focused the prediction of observed sequence while our sequence predictions are on the latent variables. It is not evident that covariate shift exists in this setting. We demonstrate the empirical evidence for the effect of different training schemes on the accuracy of relational inference. Quite surprisingly, we found that auto-regressive training is most effective for dynamic relations inference. Figure . 5 summarizes the difference in learning curve between using teacher forcing and auto-regressive for different dynamic periods. We also include a version of scheduled sampling (hybrid): in the first 30 time-steps, we train the model with teacher forcing and then switch to auto-regressive in the last 10 time-steps. We observe that while teacher forcing converges faster, it leads to lower accuracy. 
This observation is consistent across different dynamic periods. Therefore, auto-regressive training is preferred for dynamic relational inference.
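The three training schemes differ only in what the decoder is conditioned on at each step. A minimal sketch of the distinction (the `decoder_step` here is a toy stand-in, not the learned NRI decoder, and the shapes and switch point are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

def decoder_step(state, rng):
    # Stand-in for the learned one-step decoder: a noisy linear map.
    return 0.9 * state + 0.01 * rng.standard_normal(state.shape)

def rollout(observed, mode, switch_step=30):
    """Roll the decoder over a trajectory of T steps.

    mode: 'teacher' - condition every step on the ground-truth state,
          'auto'    - condition on the model's own previous prediction,
          'hybrid'  - teacher forcing up to `switch_step`, then auto-regressive.
    """
    T = observed.shape[0]
    preds = []
    prev = observed[0]
    for t in range(1, T):
        pred = decoder_step(prev, rng)
        preds.append(pred)
        if mode == 'teacher' or (mode == 'hybrid' and t < switch_step):
            prev = observed[t]  # feed ground truth (teacher forcing)
        else:
            prev = pred         # feed own prediction (auto-regressive)
    return np.stack(preds)

# Toy trajectory: 40 time steps, 2 particles in 2D.
traj = rng.standard_normal((40, 2, 2)).cumsum(axis=0)
preds = rollout(traj, 'hybrid')
```

Under auto-regressive training, gradients see the model's own compounding errors, which matches the test-time rollout and is consistent with the accuracy gap observed above.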

4.2. REAL-WORLD BASKETBALL DATA EXPERIMENTS

To showcase the practical value of dynamic relational inference, we apply DYARI to a real-world basketball trajectory dataset. The goal of the experiment is to extract meaningful "hidden" relations in competitive basketball plays. The basketball dataset contains trajectories for 10 players in a game. As the ground-truth relations are unknown, we use the trajectory prediction MSE and negative ELBO as indirect measures of dynamic relational inference performance. We assume there are two types of hidden relations: coordination and competition. We defer the details of the dataset and training setup to the Appendix.

We report performance comparisons for different inference periods. As shown in Table 4, we observe lower MSE loss and negative ELBO with a shorter inference period. Intuitively, interactions in the real world may change constantly, so a shorter inference period can capture the dynamics better. DYARI outperforms the baselines in trajectory prediction MSE and negative ELBO loss. Notice that NRI (adaptive) uses an encoder and decoder that are trained separately, which results in a high negative ELBO loss on the test set. Fig. 6 visualizes a sample trajectory of 10 basketball players with relations inferred by DYARI over different time steps, with coordination and competition interactions in separate rows. In Fig. 6, Kobe Bryant is moving along the three-point line while guarded by a defender. We can see clear attention drawn to the specific players throughout the play. See Appendix for other inferred relations.
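Since ground-truth relations are unavailable here, model quality is measured indirectly via the negative ELBO. A minimal sketch of how an NRI-style negative ELBO decomposes into a Gaussian reconstruction term plus a KL term against a uniform edge-type prior (the `sigma` value and array shapes below are illustrative assumptions, not the paper's exact settings):

```python
import numpy as np

def negative_elbo(pred, target, edge_logits, num_edge_types=2, sigma=0.05):
    """Negative ELBO for an NRI-style model with a uniform edge-type prior.

    pred, target: predicted / observed trajectories, shape (T, agents, dims)
    edge_logits:  posterior logits over edge types, shape (edges, num_edge_types)
    """
    # Gaussian reconstruction term (up to an additive constant).
    recon = np.sum((pred - target) ** 2) / (2 * sigma ** 2)
    # KL(q(z|x) || uniform prior) for categorical edge-type posteriors.
    logits = edge_logits - edge_logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    kl = np.sum(probs * (np.log(probs + 1e-16) + np.log(num_edge_types)))
    return recon + kl
```

A confident, uniform posterior incurs zero KL; the loss therefore trades off trajectory fit against how far the inferred edge distribution departs from the prior.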

5. CONCLUSION

We investigate unsupervised learning of dynamic relations in multi-agent trajectories. We propose a novel deep generative model, DYnamic multi-Agent Relational Inference (DYARI), to infer relations that change over time. We conduct extensive experiments using a simulated physics system to study the performance of DYARI under various dynamic relations. We perform an ablative study to understand the effects of the dynamic and inference periods, the training scheme, and model design choices. Compared with static NRI and its variant, our DYARI model demonstrates significant improvements on simulated physics systems as well as on a real-world basketball trajectory dataset.

C ADDITIONAL EXPERIMENTS

Stochastic dynamics In order to make the problem even harder and to unify all the previous settings, we generate a dataset where the edge types are flipped randomly with probability p after each dynamic period of 4 timesteps. The static data generation corresponds to p = 0, and the periodic dynamics to p = 1. Table 5 shows the MSE and inference accuracy of NRI, DYARI, and Interaction Networks on the stochastic dataset for flipping probabilities p = 0.8 and p = 0.9.

The Effect of Average Pooling We perform an ablation study where we remove the average pooling corresponding to the inference period in the interpolation module, instead interpolating directly to the sequence length. We find that without average pooling, inference accuracy decreases from 0.92 to 0.59 (as shown in Table 6) at an inference period of 20. This shows that intermediate average pooling is critical for relational inference performance.
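The stochastic edge dynamics described above can be sketched as follows. The binary edge representation and the independence of flips across edges are illustrative assumptions, not details confirmed by the paper:

```python
import numpy as np

def stochastic_edges(num_agents, num_steps, dynamic_period=4, flip_prob=0.8, seed=0):
    """Generate a sequence of edge matrices where each binary edge is
    independently flipped with probability `flip_prob` after every
    `dynamic_period` timesteps.  flip_prob=0 recovers the static setting,
    flip_prob=1 the deterministic periodic setting.
    """
    rng = np.random.default_rng(seed)
    edges = rng.integers(0, 2, size=(num_agents, num_agents))
    np.fill_diagonal(edges, 0)  # no self-interactions
    sequence = []
    for t in range(num_steps):
        if t > 0 and t % dynamic_period == 0:
            flips = rng.random(edges.shape) < flip_prob
            np.fill_diagonal(flips, False)
            edges = np.where(flips, 1 - edges, edges)
        sequence.append(edges.copy())
    return np.stack(sequence)

seq = stochastic_edges(5, 16, dynamic_period=4, flip_prob=1.0)
```

With `flip_prob=1.0` every off-diagonal edge flips at each period boundary, so the graph returns to its initial state after two periods, matching the periodic setting.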

Additional Inferred Relations in Basketball Trajectories

We set the number of relation types in the basketball dataset to two. In Sec. 4.2, we visualized one of the inferred relations; Figure 8 visualizes the second. Notice that the first relation captures focus on the rightmost red player, while the second captures focus on the leftmost red player.



Figure 1: Neural Relational Inference for learning the interaction graph. Picture taken from (Kipf et al., 2018).

Figure 2: Probabilistic graphical model representation of dynamic multi-agent relational inference. Shaded nodes are observed and unshaded nodes are hidden variables.


Figure 3: Visualization of the DYARI model. It infers pairwise relations at different time steps given trajectories. The bottom diagram shows the details of the encoder and the decoder is the same as NRI.

Figure 4: Example trajectories of two particles with two relation changes. The trajectories start from the lighter-colored end and gradually become darker.

Figure 5: Learning curves of DYARI trained with teacher forcing (blue), auto-regressive (yellow), and hybrid (green) decoders. The hybrid model is first trained with teacher forcing for the first 30 time steps and then auto-regressively for the last 10 time steps. We report the relational inference accuracy on the validation data for different dynamic periods.

Figure 6: Visualization of the inferred relations (dashed links) in the basketball player trajectories over time by DYARI with an inference period of 8. The blue dashed links in the top row are the inferred interactions within the same team (coordination) and the red dashed links in the bottom row are between different teams (competition). Different columns represent different time steps.

Figure 8: Visualization of the basketball players trajectories with inference period = 8. The top row visualizes the inferred interactions from the same team (coordination) and the bottom row visualizes the inferred interactions from different teams (competition). Different columns represent different time steps.

Performance comparison between our method and the baselines in both the periodic (40, 20, 8, 4) and additive (Add) dynamic scenarios. MSE measures trajectory prediction and Accuracy quantifies dynamic relational inference performance.

Table 3: Inference accuracy for different combinations of dynamic and inference periods with DYARI.

Table 4: Performance comparison for DYARI and baselines on the real-world basketball trajectory dataset with inference periods 40, 20, 8, and 4.

Table 5: Quantitative results for stochastic dynamics. Accuracy improves with increased model capacity. During training, the inference period of the two DYARI models matches the dynamic period.

Table 6: Results with and without average pooling in the interpolation module of DYARI.

A MODEL IMPLEMENTATION DETAILS

In this section, we include details about the model implementation, especially the encoder. Our encoder is analogous to the ResNet CNNs (He et al., 2016) used in image recognition, with the task abstracted as a classification problem over a 1D sequence. Inspired by PSPNet for visual scene semantic parsing (Zhao et al., 2017), we add two additional global feature extractors to combine the whole-sequence (global) features with the sub-sequence (local) features. For DYARI, each residual block shown in Fig. 3 consists of 4 skip-connection structures.
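A toy sketch of this local/global feature combination, using plain linear layers in place of the actual 1D convolutions. The layer shapes, the mean-pooled global feature, and the concatenation are illustrative assumptions rather than the exact architecture:

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, weights):
    """One residual block with 4 skip connections: each sub-layer adds its
    transformed output back onto its input, in the style of He et al. (2016)."""
    for w in weights:          # 4 weight matrices -> 4 skip connections
        x = x + relu(x @ w)    # y = x + F(x)
    return x

def encode(sequence, local_weights, w_global):
    """Combine per-timestep (local) features with a whole-sequence (global)
    feature, in the spirit of PSPNet-style pooling (Zhao et al., 2017)."""
    local = residual_block(sequence, local_weights)             # (T, d)
    global_feat = local.mean(axis=0, keepdims=True) @ w_global  # (1, d)
    tiled = np.repeat(global_feat, local.shape[0], axis=0)      # (T, d)
    return np.concatenate([local, tiled], axis=1)               # (T, 2d)

d = 8
weights = [rng.standard_normal((d, d)) * 0.1 for _ in range(4)]
w_global = rng.standard_normal((d, d)) * 0.1
features = encode(rng.standard_normal((50, d)), weights, w_global)
```

The key design point is that every timestep's feature vector carries both its local context and a summary of the entire sequence, so the relational classifier can compare sub-sequence behavior against the whole trajectory.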

B EXPERIMENTAL DETAILS

Particle dataset In general, we use the same pre-processing as NRI. Each raw simulated trajectory has a length of 5000, and we subsample every 100 steps so that each sample in our dataset has length 50. Correspondingly, the dynamic period and inference period are expressed relative to the subsampled length: for instance, dynamic period = 10 means that in the raw trajectory, the state of a node changes every 1000 time steps. In addition, the trajectory values are all normalized to the range [0, 1], and evaluation is performed on the same range.

Basketball dataset details The basketball dataset consists of trajectories from 30 teams. The raw trajectories are captured every 25 ms. For our experiments, we subsample the trajectories every 50 ms to obtain more evident player movements. We use an inference period expressed relative to the sample length: for instance, inference period = 10 means that our model produces a prediction every 500 ms. The resulting dataset includes 50,000 training samples, 10,000 validation samples, and 10,000 test samples. We normalize the trajectory values to the range [0, 1] and train all models in an auto-regressive fashion. We use the same training setup as in the physics simulation experiments, with a batch size of 64.
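The subsampling and normalization described above can be sketched as follows. The per-coordinate min-max normalization is an assumption; the text only states that values are scaled to [0, 1]:

```python
import numpy as np

def preprocess(raw, stride=100):
    """Subsample a raw trajectory every `stride` steps and min-max
    normalize each coordinate to [0, 1]."""
    sampled = raw[::stride]  # e.g. 5000 raw steps -> 50 samples
    lo = sampled.min(axis=(0, 1), keepdims=True)
    hi = sampled.max(axis=(0, 1), keepdims=True)
    return (sampled - lo) / (hi - lo + 1e-12)

rng = np.random.default_rng(0)
raw = rng.standard_normal((5000, 5, 2)).cumsum(axis=0)  # 5 particles in 2D
traj = preprocess(raw)
```

With this indexing, a dynamic period of 10 in the subsampled trajectory indeed corresponds to a relation change every 10 * stride = 1000 raw time steps, consistent with the description above.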

