LEARNING ONLINE DATA ASSOCIATION

Abstract

When an agent interacts with a complex environment, it receives a stream of percepts in which it may detect entities, such as objects or people. To build up a coherent, low-variance estimate of the underlying state, it is necessary to fuse information from multiple detections over time. To do this fusion, the agent must decide which detections to associate with one another. We address this data-association problem in the setting of an online filter, in which each observation is processed by aggregating it into an existing object hypothesis. Classic methods with strong probabilistic foundations exist, but they are computationally expensive and require models that can be difficult to acquire. In this work, we use the deep-learning tools of sparse attention and representation learning to learn a machine that processes a stream of detections and outputs a set of hypotheses about objects in the world. We evaluate this approach on simple clustering problems, problems with dynamics, and a complex image-based domain. We find that it generalizes well from short to long observation sequences and from few to many hypotheses, outperforming other learning approaches and classical non-learning methods.

1. INTRODUCTION

Consider a robot operating in a household, making observations of multiple objects as it moves around over the course of days or weeks. The objects may be moved by the inhabitants, even when the robot is not observing them, and we expect the robot to be able to find any of the objects when requested. We will call this type of problem entity monitoring. It occurs in many applications, but we are particularly motivated by the robotics applications where the observations are very high dimensional, such as images. Such systems need to perform online data association, determining which individual objects generated each observation, and state estimation, aggregating the observations of each individual object to obtain a representation that is lower variance and more complete than any individual observation. This problem can be addressed by an online recursive filtering algorithm that receives a stream of object detections as input and generates, after each input observation, a set of hypotheses corresponding to the actual objects observed by the agent. When observations are closely spaced in time, the entity monitoring problem becomes one of tracking and it can be constrained by knowledge of the object dynamics. In many important domains, such as the household domain, temporally dense observations are not available, and so it is important to have systems that do not depend on continuous visual tracking. A classical solution to the entity monitoring problem, developed for the tracking case but extensible to other dynamic settings, is a data association filter (DAF) (the tutorial of Bar-Shalom et al. (2009) provides a good introduction). A Bayes-optimal solution to this problem can be formulated, but it requires representing a number of possible hypotheses that grows exponentially with the number of observations. 
A much more practical, though much less robust, approach is a maximum likelihood DAF (ML-DAF), which commits, on each step, to a maximum likelihood data association: the algorithm maintains a set of object hypotheses, one for each object (generally starting with the empty set), and for each observation it decides to either (a) associate the observation with an existing object hypothesis and perform a Bayesian update on that hypothesis with the new data, (b) start a new object hypothesis based on this observation, or (c) discard the observation as noise. The engineering approach to constructing an ML-DAF requires many design choices, including the specification of a latent state space for object hypotheses, a generative model relating observations to objects, and thresholds or other decision rules for choosing, for a new observation, whether to associate it with an existing hypothesis, use it to start a new hypothesis, or discard it. In any particular application, the engineer must tune all of these models and parameters to build a DAF that performs well. This is a time-consuming process that must be repeated for each new application. A special case of entity monitoring is one in which the objects' state is static and does not change over time. In this case, a classical solution is online (robust) clustering. Clustering algorithms perform data association (cluster assignment) and state estimation (computing a cluster center). In this paper we explore training neural networks to perform as DAFs for dynamic entity monitoring and as online clustering methods for static entity monitoring. Although it is possible to train an unstructured RNN to solve these problems, we believe that building in some aspects of the structure of the DAF will allow faster learning with less data and allow the system to address problems with a longer horizon. We begin by briefly surveying the related literature, with particular focus on learning-based approaches.
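As a concrete illustration, the per-observation decision loop of an ML-DAF can be sketched for scalar Gaussian hypotheses. The gating threshold and the Kalman-style update below are illustrative choices, not the paper's, and the noise-discard case (c) is omitted for brevity:

```python
import math

def ml_daf(observations, obs_var=0.1, gate=2.0):
    """Minimal maximum-likelihood DAF sketch for scalar observations.

    Each hypothesis is (mean, var, count). `gate` is an illustrative
    association threshold; a full ML-DAF would also test whether an
    observation is better explained as noise and discard it (case c).
    """
    hyps = []
    for z in observations:
        # Normalized distance of z to each existing hypothesis.
        dists = [abs(z - m) / math.sqrt(v + obs_var) for m, v, _ in hyps]
        k = min(range(len(dists)), key=dists.__getitem__) if dists else None
        if k is not None and dists[k] < gate:
            m, v, n = hyps[k]                 # (a) associate: Bayesian (Kalman) update
            g = v / (v + obs_var)
            hyps[k] = (m + g * (z - m), (1 - g) * v, n + 1)
        else:
            hyps.append((z, obs_var, 1))      # (b) start a new hypothesis
    return hyps
```

Running this on an interleaved stream such as `[0.0, 5.0, 0.1, 4.9]` yields two hypotheses whose means fuse the observations assigned to each.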
We then describe a neural-network architecture that uses self-attention as a mechanism for data association, and demonstrate its effectiveness on several illustrative problems. We find that it outperforms a raw RNN as well as domain-agnostic online clustering algorithms, and performs competitively with batch clustering strategies that can see all available data at once and with state-of-the-art DAFs for tracking that have hand-built dynamics and observation models. Finally, we illustrate its application to problems with images as observations, in which both data association and the use of an appropriate latent space are critical.

2. RELATED WORK

Online clustering methods The typical setting for clustering problems is batch, where all the data is presented to the algorithm at once, and it computes either an assignment of data points to clusters or a set of cluster means, centers, or distributions. We are interested in the online setting, with observations arriving sequentially and a cumulative set of hypotheses output after each observation. One of the most basic online clustering methods is vector quantization, articulated originally by Gray (1984) and understood as a stochastic gradient method by Kohonen (1995). It initializes cluster centers at random, assigns each new observation to the closest cluster center, and updates that center to be closer to the observation. Methods with stronger theoretical guarantees, and methods that handle unknown numbers of clusters, have also been developed. Charikar et al. (2004) formulate the problem of online clustering and present several algorithms with provable properties. Liberty et al. (2016) explore online clustering in terms of the facility allocation problem, using a probabilistic threshold to allocate new clusters in data. Choromanska and Monteleoni (2012) formulate online clustering as a mixture of separate expert clustering algorithms.

Dynamic domains In the setting where the underlying entities have dynamics, such as airplanes observed via radar, a large number of DAFs have been developed. The most basic filter, for the case of a single entity and no data association problem, is the Kalman filter (Welch and Bishop, 2006). In the presence of data-association uncertainty, the Kalman filter can be extended by considering assignments of observations to multiple existing hypotheses under the multiple hypothesis tracking (MHT) filter.
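The assign-and-move update of vector quantization described above can be sketched as follows; the step size `lr` is an illustrative choice:

```python
def vector_quantize(stream, centers, lr=0.1):
    """Online vector quantization: assign each observation to the nearest
    center and move that center toward the observation (a stochastic-
    gradient step on the quantization error)."""
    centers = [list(c) for c in centers]
    for z in stream:
        # Index of the nearest center under squared Euclidean distance.
        k = min(range(len(centers)),
                key=lambda i: sum((a - b) ** 2 for a, b in zip(centers[i], z)))
        # Move only that center a fraction lr of the way toward z.
        centers[k] = [a + lr * (b - a) for a, b in zip(centers[k], z)]
    return centers
```

Note that only the winning center is updated per observation, which is what makes the method online and constant-memory.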
A more practical approach that does not suffer from the combinatorial explosion of the MHT is the joint probabilistic data association (JPDA) filter, which keeps only one hypothesis but explicitly reasons about the most likely assignment of observations to hypotheses. Bar-Shalom et al. (2009) provide a detailed overview and comparison of these approaches, all of which require hand-tuned transition and observation models.

Learning for clustering There is a great deal of work using deep-learning methods to find latent spaces for clustering complex objects, particularly images. Min et al. (2018) provide an excellent survey, including methods based on auto-encoders, GANs, and VAEs. Relevant to our approach are amortized inference methods, including the set transformer (Lee et al., 2018) and its specialization to deep amortized clustering (Lee et al., 2019), in which a neural network is trained to map directly from data to be clustered into cluster assignments or centers. A related method is neural clustering processes (Pakman et al., 2019), which includes an online version and focuses on generating samples from a distribution over cluster assignments, including an unknown number of clusters.

Visual data-association methods Data association has been explored in the context of visual object tracking (Luo et al., 2014; Xiang et al., 2015; Bewley et al., 2016; Brasó and Leal-Taixé, 2020; Ma et al., 2019; Sun et al., 2019; Frossard and Urtasun, 2018). In these problems, there is typically a fixed visual field populated with many smoothly moving objects. This is an important special case of the general data-association problem. It enables specialized techniques that take advantage of the fact that the observations of each object typically vary smoothly in space-time, and that incorporate additional visual appearance cues.
In contrast, in our setting there is no fixed spatial field for observations, and they may be widely spaced in time, as would be the case when a robot moves through the rooms of a house, encountering and re-encountering different objects as it does so. Our emphasis is on this long-term data association and estimation, and our methods are not competitive with specialized techniques on fixed-visual-field tracking problems.

Learning for data association There is relatively little work in the area of generalized data association, but Liu et al. (2019) provide a recent application of LSTMs (Hochreiter and Schmidhuber, 1997) to a rich version of the data association problem, in which batches of observations arrive simultaneously, with a constraint that each observation can be assigned to at most one object hypothesis. The sequential structure of the LSTM is used here not for recursive filtering, but to handle the variable numbers of observations and hypotheses. It is assumed that Euclidean distance is an appropriate metric and that the observation and state spaces are the same. Milan et al. (2017) combine a similar use of an LSTM for data association with a recurrent network that learns to track multiple targets. Their system learns a dynamics model for the targets, including birth and death processes, but operates in simple state and observation spaces.

Algorithmic priors for neural networks One final comparison is to other methods that integrate algorithmic structure with end-to-end neural network training. This approach has been applied to sequential decision making by Tamar et al. (2016), particle filters by Jonschkowski et al. (2018), and Kalman filters by Krishnan et al. (2015), as well as to a complex multi-module robot control system by Karkus et al. (2019). The results are generally much more robust than completely hand-built models and much more sample-efficient than completely unstructured deep learning. We view our work as an instance of this general approach.

3. PROBLEM FORMULATION

The problem of learning to perform online data association requires careful formulation. When the DAF is executed online, it will receive a stream of input detections $z_1, \ldots, z_T$ where $z_t \in \mathbb{R}^{d_z}$, and after each input $z_t$, it will output two vectors, $y_t = [y_{tk}]_{k \in (1..K)}$ and $c_t = [c_{tk}]_{k \in (1..K)}$, where $y_{tk} \in \mathbb{R}^{d_y}$, $c_{tk} \in (0, 1)$, and $\sum_k c_{tk} = 1$. The $y$ values in the output represent the predicted properties of the hypothesized objects, and the $c$ values represent a measure of confidence in the hypotheses, in terms of the proportion of data that each one has accounted for. The maximum number of hypothesis "slots" is limited in advance to $K$. In some applications, the $z$ and $y$ values will be in the same space with the same representation, but this is not necessary. We have training data representing $N$ different data-association problems, $D = \{(z_t^{(i)}, m_t^{(i)})_{t \in (1..L_i)}\}_{i \in (1..N)}$, where each training example is an input/output sequence of length $L_i$, each element of which consists of a pair of an input $z$ and $m = \{m_j\}_{j \in (1..J_t^{(i)})}$, a set of nominal object hypotheses representing the true current state of the objects that have actually been observed so far in the sequence. It will always be true that $m_t^{(i)} \subseteq m_{t+1}^{(i)}$ and $J_t^{(i)} \leq K$. Our objective is to train a recurrent computational model to perform DAF effectively on problems drawn from the same distribution as those in the training set. To do so, we formulate a model (described in section 4) with parameters $\theta$, which transduces the input sequence $z_1, \ldots, z_L$ into an output sequence $(y_1, c_1), \ldots, (y_L, c_L)$, and train it to minimize the following loss function:

$$L(\theta; D) = \sum_{i=1}^{N} \sum_{t=1}^{L_i} \left[ L_{obj}(y_t^{(i)}, m_t^{(i)}) + L_{slot}(y_t^{(i)}, c_t^{(i)}, m_t^{(i)}) + L_{sparse}(c_t^{(i)}) \right].$$
The $L_{obj}$ term is a chamfer loss (Barrow et al., 1977), which looks for the predicted $y_k$ that is closest to each actual $m_j$ and sums their distances, making sure the model has found a good, high-confidence representation for each true object:

$$L_{obj}(y, m) = \sum_j \min_k \left( \frac{1}{c_k} + \|y_k - m_j\| \right).$$

The $L_{slot}$ term is similar, but makes sure that each object the model has found is a true object; we multiply by $c_k$ so as not to penalize predicted objects in which we have low confidence:

$$L_{slot}(y, c, m) = \sum_k \min_j c_k \|y_k - m_j\|.$$

The sparsity loss discourages the model from using multiple outputs to represent the same true object:

$$L_{sparse}(c) = -\log \|c\|.$$

[Figure 1: DAF-Net architecture, showing the encoded input $e$, slot memory $s$, counts $n$, slot updates $u$, attention weights $w$, suppressed assignment $a$, relevance $r$, outputs $y$, confidences $c$, and the transition module.]
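The three loss terms can be made concrete with a minimal sketch for a single time step; we assume the Euclidean norm throughout, including for $L_{sparse}$ (the paper does not pin down the norm in this excerpt):

```python
import math

def daf_losses(y, c, m):
    """Sketch of the three DAF-Net loss terms for one time step.
    y: list of K hypothesis vectors, c: K confidences summing to 1,
    m: list of J true-object vectors."""
    dist = lambda a, b: math.sqrt(sum((x - z) ** 2 for x, z in zip(a, b)))
    # L_obj: every true object is matched by some nearby, high-confidence
    # hypothesis (the 1/c_k term penalizes low-confidence matches).
    l_obj = sum(min(1.0 / c[k] + dist(y[k], mj) for k in range(len(y)))
                for mj in m)
    # L_slot: every confident hypothesis must lie near some true object.
    l_slot = sum(c[k] * min(dist(y[k], mj) for mj in m)
                 for k in range(len(y)))
    # L_sparse: with sum(c) = 1 fixed, maximizing ||c|| drives c toward
    # a one-hot vector, so different slots represent different objects.
    l_sparse = -math.log(math.sqrt(sum(ck ** 2 for ck in c)))
    return l_obj, l_slot, l_sparse
```

For example, with one true object at the origin, two hypotheses, and confidence split evenly, the sparsity term is $-\log(1/\sqrt{2}) \approx 0.347$, and it vanishes as the confidence concentrates on one slot.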

4. DAF-NETS

Inspired by the basic form of classic DAF algorithms and the ability of modern neural-network techniques to learn complex models, we have designed the DAF-Net architecture for learning DAFs, together with a customized procedure for training it from data, guided by several design considerations. First, because object hypotheses must be available after each individual input, and because observations will generally be too large and the problem too difficult to solve from scratch each time, the network has the structure of a recursive filter, with new memory values computed on each observation and then fed back for the next. Second, because the loss function is set-based, that is, it does not matter in what order the object hypotheses are delivered, the memory structure should also be permutation invariant, and so the memory processing is in the style of an attention mechanism. Finally, because in some applications the observations $z$ may be in a representation not well suited to hypothesis representation and aggregation, the memory operates on a latent representation that is related to observations and hypotheses via encoder and decoder modules.

Figure 1 shows the architecture of the DAF-Net model. There are six modules with adaptable weights, and memory that is stored in two recurrent quantities, $s$ and $n$. The main memory is $s$, which consists of $K$ elements, each in $\mathbb{R}^{d_s}$; the length-$K$ vector $n$ of positive values encodes how many observations so far have been assigned to each slot. When an input $z$ arrives, it is immediately encoded into a vector $e \in \mathbb{R}^{d_s}$. The update network operates on the encoded input and the contents of each hypothesis slot, intuitively producing an update of the hypothesis in that slot under the assumption that the current $z$ is an observation of the object represented by that slot; so for all slots $k$, $u_k = \text{update}(s_k, n_k, e)$.
The attention weights $w$ represent the degree to which the current input "matches" each current slot value: $w_k = \exp(\text{attend}(s_k, n_k, e)) / \sum_{j=1}^{K} \exp(\text{attend}(s_j, n_j, e))$. To force the network to commit to a sparse assignment of observations to object hypotheses while retaining the ability to train effectively with gradient descent, the suppress module sets all but the top $M$ values in $w$ to 0 and renormalizes, to obtain the vector $a$ of $M$ nonzero values that sum to 1. The $a$ vectors are integrated over time to obtain $n$, which is normalized to obtain the final output confidence values $c$. Additionally, a scalar relevance value, $r \in (0, 1)$, is computed from $s$ and $e$; this value is used to modulate the degree to which slot values are updated, and gives the machine the ability to ignore or downweight an input. It is computed as $r = NN_1(\operatorname{avg}_{k=1}^{K} NN_2(e, s_k, n_k))$, where $NN_1$ is a fully connected network with the same input and output dimensions and $NN_2$ is a fully connected network with a sigmoid output unit. The attention output $a$ and relevance $r$ are then used to combine the candidate slot updates $u$ with the old slot values $s_{t-1}$ using the following fixed formula for each slot $k$: $s_{tk} = (1 - r a_k)\, s_{(t-1)k} + r a_k u_k$. Because most of the $a_k$ values have been set to 0, this results in a sparse update which will ideally concentrate on a single slot to which this observation is being "assigned." To compute the outputs, the $s_t$ slot values are decoded into the representation required for the outputs: $y_k = \text{decode}(s_{tk})$. Finally, to handle the setting in which object state evolves over time, we can further add a dynamics model, which computes the state $s_{t+1}$ from the new slot values $s_t$ using an additional neural network: $s_{(t+1)k} = NN_3(s_t)_k$. These values are fed back, recurrently, as inputs to the overall system. Given a data set $D$, we train the DAF-Net model end-to-end to minimize the loss function $L$, with a slight modification.
We find that including the $L_{sparse}$ term from the beginning of training results in poor learning, but a training scheme in which $L_{sparse}$ is first omitted and then reintroduced over training epochs results in reliable training that is efficient in both time and data.
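Putting the pieces together, one memory update of the filter can be sketched as follows. Here `attend` and `update` stand in for the learned modules, and the relevance network is replaced by a fixed scalar for brevity; these simplifications are ours, not the paper's:

```python
import numpy as np

def daf_net_step(s, n, e, attend, update, M=2, relevance=1.0):
    """One DAF-Net memory update (sketch).
    s: (K, d) slot memory, n: (K,) assignment counts, e: (d,) encoded input."""
    K = s.shape[0]
    logits = np.array([attend(s[k], n[k], e) for k in range(K)])
    w = np.exp(logits - logits.max())
    w /= w.sum()                        # softmax attention over slots
    a = np.where(w >= np.sort(w)[-M], w, 0.0)
    a /= a.sum()                        # suppress: keep top-M weights, renormalize
    u = np.stack([update(s[k], n[k], e) for k in range(K)])
    # Convex combination of old slot values and candidate updates;
    # zeroed a_k leave their slots untouched.
    s_new = (1 - relevance * a)[:, None] * s + (relevance * a)[:, None] * u
    n_new = n + a                       # integrate assignments; c = n / n.sum()
    return s_new, n_new
```

With a distance-based `attend` and an `update` that simply writes the encoded input, an observation near one slot overwrites that slot and leaves distant slots unchanged, which is the sparse-assignment behavior the suppress module is designed to produce.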

5. EMPIRICAL RESULTS

We evaluate DAF-Net on several entity monitoring tasks, including simple online clustering, monitoring objects with dynamics, and high-dimensional image pose prediction in which the observation space is not the same as the hypothesis space. Our experiments aim to substantiate the following claims:

• DAF-Net outperforms non-learning clustering methods, even those that operate in batch mode rather than online, because those methods cannot learn from experience to take advantage of information about the distribution of observations and true object properties (tables 1, 2 and 5).
• DAF-Net outperforms clustering methods that can learn from previous example problems when data is limited, because it provides useful structural bias for learning (tables 1, 2 and 5).
• DAF-Net generalizes to differences between training and testing in (a) the number of actual objects, (b) the number of hypothesis slots, and (c) the number of observations (tables 1 and 3).
• DAF-Net works when significant encoding and decoding are required (table 5).
• DAF-Net is able to learn dynamics models and observation functions for the setting in which the entities are moving over time (table 4), nearly matching the performance of strong data association filters with known ground-truth models.

We compare with the following alternative methods: Batch, non-learning: K-means++ (Arthur and Vassilvitskii, 2007) and expectation maximization (EM) (Dempster et al., 1977) on a Gaussian mixture model (SciKit Learn implementation); Online, non-learning: vector quantization (Gray, 1984); Batch, learning: set transformer (Lee et al., 2018); Online, learning: LSTM (Hochreiter and Schmidhuber, 1997) and an online variant of the set transformer (Lee et al., 2018). All models are trained for a thousand iterations on problems with a total of 3 components and 30 observations; we report standard error in parentheses.
Each cluster observation and center is drawn between -1 and 1, except for Angular, which is drawn between -π and π, with the reported error being the L2 distance between predicted and ground-truth means. All architectures are set to have about 50,000 parameters. We provide additional details about architecture and training in the appendix. The set transformer is a standard architecture that has been evaluated on clustering problems in the past. All models except DAF-Net are given the ground-truth number of components K, while DAF-Net uses 10 hypothesis slots. Results are reported in terms of the loss Σ_j min_k ||y_k − m_j|| (with the most confident K hypotheses selected for DAF-Net).

Gaussian domains To check the basic operation of the model and understand the types of problems for which it performs well, we tested on simple clustering problems with the same input and output spaces but different types of data distributions, each a mixture of three components. We train on 1000 problems drawn from each problem distribution and test on 5000 from the same distribution. In every case, the means of the three components are drawn at random for each problem.
1. Normal: Each component is a 2D Gaussian with fixed identical variance across each individual dimension and across distributions. This is a basic "sanity check."
2. Elongated: Each component is a 2D Gaussian, where the variance along each dimension is drawn from a uniform distribution but fixed across distributions.
3. Mixed: Each component is a 2D Gaussian with fixed identical variance across each individual dimension, but with the variance of each distribution drawn from a uniform distribution.
4. Angular: Each component is a 2D Gaussian with identical variance across dimensions and distributions, but points above π are wrapped around to -π and points below -π are wrapped to π.
5. Noise: Each component has 2 dimensions parameterized by Gaussian distributions, with the values of the remaining 30 dimensions drawn from a uniform distribution centered at 0.

We compare our approach to each of the non-dynamic baselines for the five problem distributions in Table 1; a complete listing of results for all the distributions can be found in the appendix. The results in this table show that on the Normal, Mixed, and Elongated tasks, DAF-Net performs comparably to the offline clustering algorithms, even though it is running and being evaluated online. On the Angular and Noise tasks, DAF-Net is able to learn a useful metric for clustering and outperforms both offline and online alternatives (with additional analysis in the appendix showing that DAF-Net outperforms all other learning baselines given more training distributions in the Angular task). In Table 1 we also evaluate the quality of predictions after 10, 30, 50, and 100 observations on the Normal distribution. We find that DAF-Net generalizes well to increased numbers of observations, with predictions becoming more accurate as the observation sequence length increases, despite the fact that it is trained only on observation sequences of length 30. This is in contrast with the other online learning baselines, the set transformer and LSTM, which both see increases in error after 50 or 100 observations. This pattern holds across all the test problem distributions (see the appendix).

Table 3: Quantitative evaluation of DAF-Net on distributions with different numbers of true components and hypothesis slots at test time with 30 observations. In all cases, DAF-Net is trained with 3-component problems, 10 slots, and 30 observations. We compare with an offline set transformer trained with different numbers of problem components as well as with vector quantization.
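The reported loss Σ_j min_k ||y_k − m_j|| over the K most confident hypotheses can be computed as follows (a sketch; the function and argument names are ours):

```python
import numpy as np

def eval_loss(hypotheses, confidences, true_means, k):
    """Sum over true components j of the distance from m_j to the nearest of
    the K most confident hypotheses y_k (sketch of the metric in the text)."""
    top = np.argsort(confidences)[-k:]       # indices of the K most confident slots
    y = hypotheses[top]                      # (K, d) selected hypotheses
    # Pairwise distances between each true mean and each selected hypothesis.
    d = np.linalg.norm(y[None, :, :] - true_means[:, None, :], axis=-1)  # (J, K)
    return d.min(axis=1).sum()
```

Note this is one-sided: unused extra hypotheses are not penalized, which is what allows DAF-Net to run with more slots than true components.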
In Table 3, we investigate the generalization of DAF-Net to increases in both the number of hypothesis slots and the underlying number of mixture components from which observations are drawn. We compare to the offline set transformer and to VQ, both of which know the correct number of components at test time. Recall that, to evaluate DAF-Net even when it has a large number of extra slots, we use its K most confident hypotheses. We find that DAF-Net generalizes well to increases in hypothesis slots and exhibits improved performance with large numbers of underlying components, performing comparably to or better than the VQ algorithm. We note that none of the learning baselines can adapt to different numbers of cluster components at test time, but we find that DAF-Net outperforms the set transformer even when the latter is trained on the ground-truth number of clusters used in the test. We also ablated each component of our model and found that each of our proposed components enables both better performance and generalization. Detailed results of the ablations and a figure illustrating the clustering process are in the appendix.

Dynamic Domains We next evaluate the ability of DAF-Net to perform data association in domains where objects are moving dynamically over time. This setting is typical of the tracking problems considered by data association filters, and we compare with the de-facto standard method, Joint Probabilistic Data Association (JPDA), which uses hand-built ground-truth models. We consider a setup consisting of 3 different moving objects in 2D. Their velocity is perturbed at each step by an additive component drawn from a Gaussian distribution, and observations of their positions (but not of their velocities) are made with Gaussian error. To perform well in this task, a model must discover that it needs to estimate the latent velocity of each object, as well as learn the underlying dynamics and observation models.
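The dynamic-domain generative process just described can be sketched as follows; the noise scales, initial ranges, and function name are illustrative assumptions, not the exact parameters used in the experiments.

```python
import numpy as np

def simulate(T=30, n_objects=3, q=0.05, r=0.1, seed=0):
    """Sketch of the dynamic domain: each object follows a constant-velocity
    model whose velocity is perturbed by Gaussian noise each step, and each
    step emits a noisy position observation of one randomly chosen object.
    The scales q (process noise) and r (observation noise) are assumptions."""
    rng = np.random.default_rng(seed)
    pos = rng.uniform(-1, 1, (n_objects, 2))
    vel = rng.uniform(-0.1, 0.1, (n_objects, 2))
    observations, states = [], []
    for _ in range(T):
        vel = vel + q * rng.standard_normal((n_objects, 2))  # perturb velocity
        pos = pos + vel                                      # move objects
        i = rng.integers(n_objects)                          # pick one object
        z = pos[i] + r * rng.standard_normal(2)              # noisy position only
        observations.append(z.copy())
        states.append(pos.copy())
    return observations, states
```

Because only positions are observed, velocity is a latent variable that a filter (or DAF-Net's learned slot state) must infer from the observation history.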
We compare our approach to the Set Transformer and LSTM methods, as well as to JPDA with ground-truth models. The basic clustering methods have no ability to handle dynamic systems, so we omit them from the comparison. The learning methods (DAF-Net, Set Transformer, and LSTM) are all trained on observation sequences of length 30. We test the performance of all four methods on sequences of several lengths. Quantitative performance, measured in terms of prediction error on true object locations, is reported in Table 4. We can see that the online Set Transformer cannot learn a reasonable model at all. The LSTM performs reasonably well for short (length-10) sequences but quickly degrades relative to DAF-Net and JPDA as sequence length increases. We note that DAF-Net performs comparably to, but just slightly worse than, JPDA. This is very strong performance, because DAF-Net is generic and can be adapted to new domains given training data, without the need to hand-design the models used by JPDA.

Image-based domains We now evaluate the ability of DAF-Net to perform data association in domains with substantially more complex observation spaces, where the outputs are not simple averages of the inputs. This requires the network to synthesize a latent representation for slots in which the simple additive update performs effectively. We investigate this with two image-based domains. In each domain, we have a set of similar objects (digits or airplanes). A problem is constructed by selecting K objects from the domain; the desired y values are images of those objects in a canonical viewpoint. The input observation sequence is generated by repeatedly selecting one of those K objects at random and generating an image of it from a random viewpoint as the observation z.
Our two domains are: (1) MNIST: each object is a random digit image from MNIST, with observations corresponding to that same image rotated and the desired outputs being the un-rotated images; (2) Airplane: each object is a random object from the Airplane class in ShapeNet (Chang et al., 2015), with observations corresponding to airplane renderings (using Blender) at different viewpoints and the desired outputs being the objects rendered in a canonical viewpoint. For MNIST, we use the 50,000 digit images in the training set to construct the training problems and the 10,000 images in the non-overlapping test set to construct the test problems. For the Airplane dataset, we use 1,895 airplane objects to construct the training problems and 211 different airplane objects to construct the test problems. Each object is rendered from 300 viewpoints. Of our baseline methods, only batch K-means (in pixel space) can be directly applied to this problem with even reasonable results. We also include versions of the LSTM and of batch K-means that operate on a latent representation learned first with an auto-encoder. In Table 5, we find that our approach significantly outperforms the other comparable baselines in both accuracy and generalization. We visualize qualitative predictions from our model in Figure 3.
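A problem in the MNIST domain can be constructed roughly as follows. For simplicity this sketch restricts rotations to multiples of 90 degrees via np.rot90 (the actual setup uses arbitrary in-plane rotations), and the function name is ours.

```python
import numpy as np

def make_problem(canonical_images, T=30, seed=0):
    """Construct one image-association problem: the targets are K canonical
    images, and each observation is a randomly chosen target under a random
    rotation. 90-degree rotations stand in for arbitrary ones here."""
    rng = np.random.default_rng(seed)
    obs, labels = [], []
    for _ in range(T):
        k = rng.integers(len(canonical_images))            # pick a target
        z = np.rot90(canonical_images[k], k=rng.integers(4))  # random rotation
        obs.append(z)
        labels.append(k)  # ground-truth association, hidden from the model
    return obs, labels
```

The model sees only `obs`; the association labels exist only to define the generative process and for evaluation.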

6. DISCUSSION

This work has demonstrated that by combining algorithmic bias, inspired by a classical solution to the problem of filtering to estimate the state of multiple objects simultaneously, with modern machine-learning techniques, we can arrive at solutions that learn to perform well and to generalize from a comparatively small amount of training data.

A.1 APPENDIX

A.2 DISCOVERY OF OBJECTS

In contrast to the other algorithms, DAF-Net learns to predict both a set of object properties y_k and a set of confidences c_k, one per hypothesized object. This corresponds to the task of predicting both the number of objects in a set of observations and the associated object properties. We evaluate the ability of DAF-Net to regress the object number in scenarios where the number of objects differs from that seen in training. We evaluate on the Normal distribution with a variable number of component distributions, and measure the inferred number of components by thresholding the confidences. DAF-Net is trained on a dataset with 3 underlying components. We find in Figure A1 that DAF-Net is able to infer the presence of more component distributions (as they vary from 3 to 6), with improved performance when the cluster centers are sharply separated (right plot of Figure A1).
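The confidence-threshold rule for inferring the number of components can be sketched in one line; the threshold value below is an illustrative assumption, not the one used in the experiments.

```python
import numpy as np

def inferred_components(confidences, tau=0.5):
    """Count hypothesis slots whose confidence exceeds a threshold tau.
    tau = 0.5 is an illustrative choice."""
    return int(np.sum(np.asarray(confidences) > tau))
```

With 10 slots and, say, confidences [0.9, 0.8, 0.7, 0.1, ...], this rule reports 3 components, matching the number of high-confidence slots.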


Figure A1: Plots of the inferred number of components using a confidence threshold in DAF-Net, compared to the ground-truth number of clusters (DAF-Net is trained on only 3 clusters). We consider two scenarios: a noisy scenario where cluster centers are randomly drawn from -1 to 1 (left), and a scenario where all added cluster components are well separated from each other (right). DAF-Net is able to infer more clusters in both scenarios, with better performance when the cluster centers are more distinct from each other.

A.3 QUALITATIVE VISUALIZATIONS

We provide an illustration of our results on the Normal clustering task in Figure A2. We plot the decoded values of the hypothesis slots in red, with size scaled according to confidence, and the ground-truth cluster locations in black. DAF-Net is able to selectively refine slot clusters to be close to the ground-truth cluster locations, even with much longer observation sequences than it was trained on. We find that each component learned by DAF-Net is interpretable. We visualize the attention weights of each hypothesis slot in Figure A3 and find that each hypothesis slot learns to attend to a local region near the value it decodes to. We further visualize a plot of relevance weights in Figure A4 across increasing numbers of observations and different levels of noise in each distribution.

A.7 DISTRIBUTION DETAILS

We provide detailed quantitative values for each distribution below. Gaussian centers are drawn uniformly from -1 to 1.

We provide overall architecture details for the LSTM in Figure A5a, for the set transformer in Figure A5b, and for DAF-Net in Figure A5c. For the image experiments, we provide the architecture of the encoder in Figure A6a and of the decoder in Figure A6b. The LSTM, DAF-Net, and the autoencoding baselines all use the same image encoder and decoder. In the DAF-Net memory, the function update(s_k, n_k, e) is implemented by a 2-layer MLP with h hidden units that takes the concatenation of the vectors s_k, n_k, and e as input and outputs a new state u_k of dimension h. The value n_k is encoded using the function 1/(1 + n_k), to normalize the range of the input to be between 0 and 1. The function attend(s_k, n_k, e) is implemented analogously to update, using a 2-layer MLP that outputs a single real value for each hypothesis slot. For the function relevance(s_k, n_k, e), we apply NN_1 per hypothesis slot, implemented as a 2-layer MLP with h hidden units that outputs an intermediate state of dimension h; (s_k, n_k, e) are fed into NN_1 in the same manner as for update. NN_2 is applied to the average of the intermediate representations of the hypothesis slots and is also implemented as a 2-layer MLP with hidden-unit size h, followed by a sigmoid activation. We use the ReLU activation for all MLPs.

A.10 BASELINE DETAILS

All baseline models are trained with a number of prediction slots equal to the ground-truth number of components. To modify the set transformer to act in an online manner, we follow the approach of Santoro et al. (2018) and apply the set transformer sequentially to the concatenation of an input observation with the hypothesis slots. The hypothesis slots are updated based on the new values of the slots after applying self-attention (the Set Transformer encoder). We use the Chamfer loss to train the baseline models, with confidence set to 1.
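As a concrete illustration of the slot update described above, here is a minimal numpy sketch of update(s_k, n_k, e): a 2-layer ReLU MLP applied to the concatenation of the slot state, the count encoded as 1/(1 + n_k), and the observation embedding. The weight shapes and parameter packaging are illustrative assumptions; the trained model's dimensions may differ.

```python
import numpy as np

def mlp2(x, W1, b1, W2, b2):
    # 2-layer MLP with a ReLU hidden activation, as used for the slot heads.
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2

def update_slot(s_k, n_k, e, params):
    # Sketch of update(s_k, n_k, e): concatenate the slot state s_k, the
    # count encoded as 1/(1 + n_k), and the observation embedding e, then
    # apply a 2-layer MLP. Shapes here are illustrative assumptions.
    x = np.concatenate([s_k, [1.0 / (1.0 + n_k)], e])
    return mlp2(x, *params)
```

The attend and relevance heads differ only in their output dimension (a single scalar per slot, with a sigmoid at the end of the relevance network).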

A.11 ABLATION

We investigate ablations of our model in Table 9. We ablate the sparsity loss, the learned memory update, the suppression of attention weights, and the relevance weights. We find that each component of our model contributes to improved performance.




Figure 1: Architecture of DAF-Net. Grey boxes represent fixed computations; white boxes represent neural networks with adjustable parameters; boxes with internal vertical bars represent a replication of the same computation on slot values in parallel. Red lines indicate information derived from an input observation, green lines indicate information derived from hypothesis slot values, and blue lines indicate information derived from the counts on each hypothesis slot.

Figure 2: Visualizations of the Dynamic and Gaussian domains. Observations are shown transparent, while ground-truth states are bolded.

Figure 3: Results on two image-based association tasks (left: MNIST, right: airplanes). At the top of each is an example training problem, illustrated by the true objects and an observation sequence. Each of the next rows shows an example test problem, with the ground truth objects and decoded slot values. The three highest-confidence hypotheses for each problem are highlighted in red, and correspond nicely to the ground-truth objects.

Figure A2: Illustration of the clustering process. Decoded value of hypothesis (with size corresponding to confidence) shown in red, with ground truth clusters in black. Observations are shown in blue.

Figure A6: Architectures of encoder and decoder models on image experiments.

Comparison of performance after training on one thousand Normal distributions for a thousand iterations. We use 3 components and train models with 30 observations. We report standard error in parentheses. Each cluster observation and center is drawn between -1 and 1, with the reported error being the L2 distance between predicted and ground-truth means.

Comparison of clustering performance after 30 iterations when training on 1000 different distributions.

Comparison of performance on position estimation of 3 dynamically moving objects. All learning models are trained with 1000 sequences of 30 observations. We report standard error in parentheses. JPDA uses the ground-truth observation and dynamics models.

Comparison of performance of online clustering on MNIST and on rendered Airplane dataset. For DAF-Net, LSTM and K-means (Learned) we use a convolutional encoder/decoder trained on the data; for K-means (Pixel) there is no encoding. We use a total of 3 components and train models with 30 observations. Models are trained on 20000 problems on both datasets.

Comparison of performance under different settings after training on different distributions for a thousand iterations. We use a total of 3 components and train models with 30 observations. We report standard error in parentheses.

Comparison of performance on the Normal distribution. We use 30 components and train models with 50 observations.


As we see more observations, the relevance weight of new observations decreases over time, indicating that DAF-Net learns to pay the most attention to the first set of observations it sees. In addition, we find that in distributions with higher variance, the relevance weight decreases more slowly, as later observations are then more informative in determining the cluster centers.

A.4 QUANTITATIVE RESULTS

We report the full performance of each model across the different distributions in Table 6. We find that DAF-Net obtains better performance with an increased number of observations across the different distributions. In addition, DAF-Net outperforms the neural-network baselines when evaluated on 30 observations across all distributions except Angular. For the Angular task, we find that when training with 10,000 different distributions, DAF-Net exhibits a better error of 0.555, compared to 0.647 for the online set transformer and 0.727 for the LSTM.

A.5 SPARSITY LOSS

In this section, we show that L_sparse(c) encourages the confidences c to be sparse. Recall that L_sparse(c) = -||c||, where ||c|| is the L2 norm, which is convex. The confidence vector c lies in a polyhedron, since it is constrained to the set of points that are non-negative and whose elements sum to one. The maximum of a convex function over a polyhedron must occur at a vertex, which here corresponds to assigning 1 to a single c_i and 0 to every other entry. Next, we consider the minimum of ||c|| given that its elements sum to one. This is equivalent to finding the stationary points of the Lagrangian ||c||^2 + λ(1 − Σ_i c_i). Taking the gradient of this expression, we find that the stationary point corresponds to all c_i being equal; since the function is convex, this corresponds to the minimum of ||c||. Thus L_sparse(c) is maximized when the individual confidences are all equal and minimized when a single confidence equals 1, so minimizing L_sparse drives the confidence vector toward sparse, one-hot solutions.
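This argument is easy to check numerically. The sketch below assumes L_sparse(c) = -||c||_2 (the negative L2 norm of the confidence vector, matching the analysis above): on the simplex, the loss is largest at the uniform vector and smallest at a one-hot vertex.

```python
import numpy as np

def l_sparse(c):
    # Sparsity loss: negative L2 norm of the confidences, so that minimizing
    # it pushes c toward a one-hot vertex of the probability simplex.
    # (The exact form of the loss is our assumption from the analysis above.)
    return -np.linalg.norm(c)
```

For example, on the 4-dimensional simplex, l_sparse of a one-hot vector is -1, of [0.5, 0.5, 0, 0] is about -0.707, and of the uniform vector is -0.5, so gradient descent on this loss moves probability mass onto fewer slots.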

A.6 PERFORMANCE USING MORE CLUSTERS

We measure the performance of DAF-Net in the presence of a large number of clusters and slots. We consider the Normal distribution setting, where the input observations are generated by a total of 30 different clusters. We train DAF-Net with 50 observations and measure its performance at inferring cluster centers with either 50 or 100 observations. We report performance in Table 7 and find that DAF-Net obtains good performance in this setting, outperforming both the online and offline baselines.

… digits move over time. For DAF-Net, LSTM and K-means (Learned) we use a convolutional encoder/decoder trained on the data; for K-means (Pixel) there is no encoding. We use a total of 3 components and train models with 30 observations. We report MSE error with respect to the ground-truth unrotated images.

1. Normal: Each 2D Gaussian has standard deviation 0.2.
2. Mixed: Each distribution is a 2D Gaussian with fixed identical variance across each individual dimension, but with the standard deviation of each distribution drawn from a uniform distribution on (0.04, 0.4).
3. Elongated: Each distribution is a 2D Gaussian, where the standard deviation along each dimension is drawn from a uniform distribution on (0.04, 0.4) but fixed across distributions.
4. Angular: Each distribution is a 2D Gaussian with identical standard deviation across dimensions and distributions, but points above π are wrapped around to -π and points below -π are wrapped to π. Gaussian means are selected between (-π, -2π/3) and between (2π/3, π). The standard deviation of the distributions is 0.3π.
5. Noise: Each distribution has 2 dimensions parameterized by Gaussian distributions with standard deviation 0.5, with the values of the remaining 30 dimensions drawn from a uniform distribution on (-1, 1).
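The Angular wrapping described in item 4 can be reproduced with simple modular arithmetic. This sketch (with an assumed sample count and our own function name) draws 2D Gaussian samples and wraps them into [-π, π):

```python
import numpy as np

def sample_angular(n, mean, sigma=0.3 * np.pi, seed=0):
    """Sample from an Angular component: 2D Gaussian around `mean`, with
    values above pi wrapped to -pi and values below -pi wrapped to pi.
    The function name and sample count are illustrative."""
    rng = np.random.default_rng(seed)
    x = rng.normal(mean, sigma, size=(n, 2))
    return (x + np.pi) % (2 * np.pi) - np.pi  # wrap into [-pi, pi)
```

With a mean near π, a substantial fraction of samples wrap to values near -π, which is exactly what makes Euclidean clustering methods fail on this distribution.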

A.8 DYNAMIC IMAGES

We further compare DAF-Net with the other baselines in a setting where the rendered images move over time. We follow the same setup described in the image-based domain, but now consider the MNIST setup with each digit centered at a random position in the image (with the parts of a digit that fall outside the image wrapped around to the other side). At each timestep, the center of each digit moves with a constant velocity, and the goal is to predict the un-rotated image at the current center of the digit. We report results in Table 8 and find that our approach performs well in this setting as well.

A.9 MODEL/BASELINE ARCHITECTURES
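The wrap-around placement described above can be implemented with np.roll, which translates an image toroidally. The canvas size and centering convention here are illustrative assumptions.

```python
import numpy as np

def place_with_wrap(img, center, canvas=64):
    """Place a digit image on a square canvas with its center at `center`,
    wrapping pixels that fall outside the canvas to the opposite side (as in
    the moving-digit setup above). Canvas size is an illustrative assumption."""
    h, w = img.shape
    out = np.zeros((canvas, canvas))
    out[:h, :w] = img
    # Roll so the digit's center lands at `center`, with toroidal wraparound.
    dy = int(center[0]) - h // 2
    dx = int(center[1]) - w // 2
    return np.roll(out, (dy, dx), axis=(0, 1))
```

Because np.roll is a pure translation on the torus, the digit's pixel mass is conserved no matter where the center lies, including positions near the image boundary.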

