SPATIALLY STRUCTURED RECURRENT MODULES

Abstract

Capturing the structure of a data-generating process by means of appropriate inductive biases can help in learning models that generalize well and are robust to changes in the input distribution. While methods that harness spatial and temporal structures find broad application, recent work has demonstrated the potential of models that leverage sparse and modular structure using an ensemble of sparingly interacting modules. In this work, we take a step towards dynamic models that are capable of simultaneously exploiting both modular and spatiotemporal structures. To this end, we model the dynamical system as a collection of autonomous but sparsely interacting sub-systems that interact according to a learned topology which is informed by the spatial structure of the underlying system. This gives rise to a class of models that are well suited for capturing the dynamics of systems that only offer local views into their state, along with corresponding spatial locations of those views. On the tasks of video prediction from cropped frames and multi-agent world modeling from partial observations in the challenging Starcraft2 domain, we find our models to be more robust to the number of available views and capable of better generalization to novel tasks without additional training than strong baselines that perform equally well or better on the training distribution.

1. INTRODUCTION

Many spatiotemporal complex systems can be abstracted as a collection of autonomous but sparsely interacting sub-systems, where sub-systems tend to interact if they are in each other's vicinity. As an illustrative example, consider a grid of traffic intersections. Traffic flows from a given intersection to the adjacent ones, and the actions taken by some "agent", say an autonomous vehicle, may at first only affect its immediate surroundings. Now suppose we want to forecast the future state of the traffic grid (say for the purpose of avoiding traffic jams). There is a spectrum of possible strategies for modeling the system at hand. On one extreme lies the most general strategy, which considers the entirety of all intersections simultaneously to predict the next state of the grid (Figure 1c). The resulting model class can in principle account for interactions between any two intersections, irrespective of their spatial distance. However, the number of interactions such models must consider does not scale well with the size of the grid, and this strategy might be rendered infeasible for large grids with hundreds of intersections. On the other end of the spectrum is a strategy which abstracts the dynamics of each intersection as an autonomous sub-system, with each sub-system interacting only with its immediate neighbors (Figure 1a). The interactions may manifest as messages that one sub-system passes to another, possibly containing information about how many vehicles are headed in which direction, resulting in a collection of message-passing entities (i.e., sub-systems) that collectively model the entire grid. By adopting this strategy, one assumes that the immediate future of any given intersection is affected only by the present states of the neighboring intersections, and not by some intersection at the opposite end of the grid.
The resulting class of models scales well with the size of the grid, but is possibly unable to model certain long-range interactions that could be leveraged to efficiently distribute traffic flow. The spectrum above parameterizes the extent to which the spatial structure of the underlying system informs the design of the model. The former extreme ignores spatial structure altogether, resulting in a class of models that can be expressive but whose sample and computational complexity do not scale well with the size of the system. The latter extreme results in a class of models that can scale well, but its adequacy (in terms of expressivity) is contingent on a predefined notion of locality (in the example above: the immediate four-neighborhood of an intersection). In this work, we aim to explore a middle ground between the two extremes: namely, by proposing a class of models that learns a notion of locality instead of relying on a predefined one (Figure 1b). Reconsidering the traffic grid example: the proposed strategy results in a model that may learn to abstract (say) entire avenues with a single sub-system, if doing so is useful towards solving the prediction task.

Figure 1: A schematic representation of the spectrum of modeling strategies. Solid arrows with speech bubbles denote (dynamic) messages being passed between sub-systems (dotted arrows denote the lack thereof). Gist: on one end of the spectrum (Figure 1a), we have the strategy of abstracting each intersection as a sub-system that interacts with neighboring sub-systems. On the other end of the spectrum (Figure 1c), we have the strategy of modeling the entire grid with one monolithic system. The middle ground (Figure 1b) we explore involves letting the model develop a notion of locality by (say) abstracting entire avenues with a single sub-system.
This yields a scheme where a single sub-system might account for events that are spatially distant (such as those at the opposite ends of an avenue), while events that are spatially closer together (like those on two adjacent avenues of the same street, where streets run perpendicular to avenues) might be assigned to different sub-systems. To implement this scheme, we build on a framework wherein the sub-systems are modelled as independent recurrent neural network (RNN) modules that interact sparsely via a bottleneck of attention (a variant of which is explored in Goyal et al. (2019)), and extend it along two salient dimensions. First, we learn an interaction topology between the sub-systems, instead of assuming that all sub-systems interact with all others in an all-to-all topology. We achieve this by learning to embed each sub-system in a space endowed with a metric, and attenuating the interaction between two given sub-systems according to their distance in this space (i.e., sub-systems too far away from each other in this space are not allowed to interact). Second, we relax the common assumption that the entire system is perceived simultaneously; instead, we only assume access to local (partial) observations along with the associated spatial locations, resulting in a setting that partially resembles that of Eslami et al. (2018). Expressed in the language of the example above: we do not expect a bird's eye view of the traffic grid, but only (say) LIDAR observations from autonomous vehicles at known GPS coordinates, or video streams from traffic cameras at known locations. The spatial location associated with an observation plays a crucial role in the proposed architecture, in that we map it to the embedding space of sub-systems and address the corresponding observation only to sub-systems whose embeddings lie in close vicinity.
Likewise, to predict future observations at a queried spatial location, we again map said location to the embedding space and poll the states of sub-systems situated nearby. The result is a model that can learn which spatial locations are to be associated with each other and accounted for by the same sub-system. As an added benefit, the parameterization we obtain is agnostic not only to the number of available observations and query locations, but also to the number of sub-systems. To evaluate the proposed model, we choose a problem setting where (a) the task is composed of different sub-systems or processes that interact locally both in space and in time, and (b) the environment offers local views into its state paired with their corresponding spatial locations. The challenge here lies in building and maintaining a consistent representation of the global state of the system given only a set of partial observations. To succeed, a model must learn to efficiently capture the available observations and place them in an appropriate spatial context. The first problem we consider is that of video prediction from crops, analogous to the one faced by the visual systems of many animals: given a set of small crops of the video frames centered around stochastically sampled pixels (corresponding to where the fovea is focused), the task is to predict the content of a crop around any queried pixel position at a future time. The second problem is that of multi-agent world modeling from partial observations in spatial domains, such as the challenging Starcraft2 domain (Samvelyan et al., 2019; Vinyals et al., 2017). The task here is to model the dynamics of the global state of the environment given local observations made by cooperating agents and their corresponding actions. Finally, we also include visualizations on a multi-agent grid-world environment designed for simulating railroad traffic (Eichenberger et al., 2019).
Importantly, and unlike prior work (Sun et al., 2019), our parameterization is agnostic to the number of agents in the environment, which can be flexibly adjusted on the fly as new agents become available or existing agents retire. This is beneficial for generalization in settings where the number of agents differs between training and testing.

Contributions. (a) We propose a new class of models, which we call Spatially Structured Recurrent Modules or S2RMs, which perform attention-driven modular computations according to a learned spatial topology. (b) We evaluate S2RMs (along with several strong baselines) on a selection of challenging problems and find that S2RMs are robust to the number of available observations and can generalize to novel tasks.

2. PROBLEM STATEMENT

In this section, we build on the intuition from the previous section to formally specify the problem we aim to approach with the methods described in the later sections. To that end, let X be a metric space, O a set of possible observations, and O^X a set of mappings X → O. Now, consider the evolution function of a discrete-time dynamical system, φ : Z × O^X → O^X, satisfying:

φ(0, o) = o and φ(t_2, φ(t_1, o)) = φ(t_1 + t_2, o) for o ∈ O^X and t_1, t_2 ∈ Z   (1)

Informally, o can be interpreted as the world state of the system; together with a spatial location x ∈ X, it gives the local observation O = o(x) ∈ O. Given an initial world state o, the mapping φ(t, o) yields the world state o_t at some (future) time t, thereby characterizing the dynamics of the system (which might be stochastic). The problem we consider is the following:

Problem: At every time step t = 0, ..., T, we are given a set of positions {x^a_t}_{a=1}^{A} and the corresponding observations {O^a_t}_{a=1}^{A}, where O^a_t := o_t(x^a_t). The task is to infer the world state o_{t′} at some future time step t′ > T in order to predict O^q_{t′} = o_{t′}(x^q) at an arbitrary query position x^q.

In the traffic grid example of Section 1, one could imagine a as indexing traffic cameras or autonomous vehicles (i.e., observers), x^a_t as the GPS coordinates of observer a, and O^a_t as the corresponding sensor feed (e.g., LIDAR observations or video streams from vehicles or traffic cameras).
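To make the formalism concrete, here is a minimal toy instance of our own construction (not from the paper): a single particle drifting on a periodic 1-D grid, where the world state maps each location to an occupancy observation.

```python
SIZE = 10  # spatial extent of the toy world (hypothetical)

def evolve(t, pos):
    """phi acting on the underlying state: the particle drifts one cell per step."""
    return (pos + t) % SIZE

def observe(pos, x):
    """Local observation O = o(x): 1 iff the particle currently sits at location x."""
    return 1 if pos == x else 0

# Flow properties analogous to Eq. (1):
assert evolve(0, 3) == 3                            # phi(0, o) = o
assert evolve(2, evolve(1, 3)) == evolve(1 + 2, 3)  # phi(t2, phi(t1, o)) = phi(t1 + t2, o)

# Predicting the observation at query position x_q = 6 at future time t = 3,
# given the initial particle position 3:
print(observe(evolve(3, 3), 6))  # → 1
```

The prediction problem above amounts to inferring `evolve` (and the underlying state) from scattered `(x, observe(·, x))` pairs, without ever seeing the full world state at once.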

3. MODELLING ASSUMPTIONS

Given the problem in Section 2, we now make certain modelling assumptions. These assumptions will ultimately inform the inductive biases we select for the model (proposed in Section 4); nevertheless, we remark beforehand that, as with any inductive bias, their applicability is subject to the properties of the system being modeled and the objectives being optimized (OOD generalization, etc.).

Recurrent Dynamics Modeling. While there exist multiple ways of modeling dynamical systems, we shall focus on recurrent neural networks (RNNs). Typically, RNN-based dynamics models are expressed as functions of the form:

h_{t+1} = F(O_t, h_t);  O_t = D(h_t)   (2)

where O_t is the observation at time t ∈ Z, and h_{t+1} is the hidden state of the model at the next time step. F can be thought of as the parameterized forward-evolution function of the hidden state h conditioned on the observation O, whereas D is a decoder that maps the hidden state to observations.

Decomposition into Locally Interacting Sub-systems. We make the assumption that the dynamical system φ can be decomposed into constituent sub-systems (φ_1, φ_2, ..., φ_M) that dynamically and sparsely interact with each other while respecting some interaction topology. By interaction topology, we mean that each module φ_i can be identified with an embedding p_i in a topological space S equipped with a similarity kernel Z, and that the sub-system φ_i may preferentially interact with another sub-system φ_j if their respective embeddings are close in S with respect to Z, i.e., if Z(p_i, p_j) is large. Intuitively, one may think of Z as inducing a notion of locality between sub-systems, according to which φ_j lies in the local vicinity of φ_i.

(Figure 2: schematic of the proposed architecture, with panels labeled Input Attention, Set of Interacting RNNs F_1, ..., F_6, and Output Attention.)
Locality of Observations. The notion of locality between sub-systems induced by Z is distinct from that induced by the metric of the space X of locations in the environment (cf. Section 2), and one important modelling decision is how these two should interact. We propose to embed the position x ∈ X associated with an observation O into the metric space of sub-systems S via a continuous and injective mapping P : X → S. This allows us to match the observation O to all sub-systems φ_m that are in the vicinity of P(x) ∈ S, i.e., where Z(P(x), p_m) is sufficiently large. Each sub-system φ_m therefore accounts for observations made at a set of positions X_m ⊂ X, which we call its enclave.
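A minimal sketch of this matching follows, with a sinusoidal stand-in for P and an exponential stand-in for Z (both made concrete in Section 4); the specific frequencies, bandwidth and truncation threshold below are our own choices.

```python
import numpy as np

def P_embed(x, d=8):
    # Sinusoidal embedding onto the unit sphere, in the spirit of Section 4;
    # the frequency spacing here is our assumption.
    x = np.asarray(x, dtype=float)
    freqs = 10000.0 ** (-np.arange(d // (2 * len(x))))
    s = np.concatenate([f(xi * freqs) for xi in x for f in (np.sin, np.cos)])
    return s / np.linalg.norm(s)

def Z(p, s, tau=0.2):
    # Truncated similarity kernel on the sphere (cf. Section 4).
    d = float(p @ s)
    return float(np.exp(-2.0 * (1.0 - d))) if d >= tau else 0.0

# Hypothetical module embeddings p_m (here: embeddings of a few anchor points):
anchors = [(0.0, 0.0), (5.0, 5.0), (9.0, 1.0)]
p = [P_embed(a) for a in anchors]

# An observation at position x is matched to all modules in its vicinity,
# i.e. those with Z(P(x), p_m) > 0; x then belongs to their enclaves.
s_x = P_embed((0.3, 0.1))
members = [m for m in range(len(p)) if Z(p[m], s_x) > 0.0]
print(members)
```

Here the observation near (0, 0) is routed to the module anchored there, since their embeddings are nearby on the sphere.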

4. SPATIALLY STRUCTURED RECURRENT MODULES (S2RMS)

Informed by the model assumptions detailed in the previous section, we now describe the proposed model (Figure 2), comprising the following components. Model Inputs. Recall from Section 2 that for every time step t = 0, ..., T we have a set of tuples of positions and observations {(x^a_t, O^a_t)}_{a=1}^{A}, where x^a_t ∈ X and O^a_t ∈ O for all t and a. To simplify, we assume that X ⊂ R^n, and denote by x_i the i-th component of the vector x ∈ X. Encoder. The encoder E is a parameterized function mapping observations O to a corresponding vector representation e = E(O). Here, E processes all observations in parallel across t and a to yield representations e^a_t. Positional Embedding. The positional embedding P is a fixed mapping from X to S. We choose S to be the unit sphere in d dimensions, d being a multiple of 2n, and the positional encoder to be the function P(x) = s/‖s‖ ∈ S, where (s_{i+m}, s_{i+1+m}) = (sin(x_m / 10000^i), cos(x_m / 10000^i)) with m = 0, ..., n − 1 and i = 0, 2, ..., d/n − 1. This function finds common use (Vaswani et al., 2017; Mildenhall et al., 2020; Zhong et al., 2020) and can be motivated from the perspective of Reproducing Kernel Hilbert Spaces (Rahimi & Recht, 2007) (see Appendix C.3 for a discussion). We henceforth refer to P(x) as s and P(x^a_t) as s^a_t. Set of Interacting RNNs. To model the dynamics of the world state, we use a set of M independent RNN modules, which we denote {F_m}_{m=1}^{M}. To each F_m, we associate an embedding vector p_m ∈ S, where all {p_m}_{m=1}^{M} are learnable parameters. The RNNs F_m interact with each other via an inter-cell attention, and with the input representations e^a_t via an input attention.
Precisely, at a given time step t, each F_m expects an input u^m_t, an aggregated hidden state h̄^m_t and, optionally, a memory state c^m_t, and yields the hidden and memory states at the next time step:

(h^m_{t+1}, c^m_{t+1}) = F_m(u^m_t, h̄^m_t, c^m_t)

where the input u^m_t results from the input attention and h̄^m_t from the inter-cell attention (see below). Kernel Modulated Dot-Product Attention. A central component of the proposed architecture is the kernel modulated dot-product attention (KMDPA), which we now define. First, we let Z : S × S → [0, 1] be the following kernel:

Z(p, s) = exp[−2ℓ(1 − p · s)] if p · s ≥ τ, and 0 otherwise

where ℓ ∈ (0, ∞) is the kernel bandwidth, and τ ∈ [−1, 1) is the truncation parameter (additional details in Appendix C.1). Now, KMDPA maps two sets A and B to a third set Â, where:

A = {(a_i, y_i)}_{i=1}^{I};  B = {(b_j, z_j)}_{j=1}^{J};  Â = {(a_i, ŷ_i)}_{i=1}^{I} = KMDPA(A, B)

Here, a_i, b_j ∈ S, and y_i, z_j are vectors of not necessarily the same dimension. In order to evaluate Â, we first compute the interaction weights W_ij between any two pairs of entities (a_i, y_i) and (b_j, z_j), which depend on a local term W^(L)_ij and a non-local term W̃_ij. We have:

W̃_ij = softmax_j [Θ^(Query)(y_i) · Θ^(Key)(z_j)];  W^(L)_ij = Z(a_i, b_j);  W_ij = W̃_ij W^(L)_ij   (7)

where Θ^(Query) and Θ^(Key) are learnable linear mappings that project y_i and z_j to the same space. The penultimate step computes the following two quantities:

ỹ_i = Σ_j W^(L)_ij y_j;  ȳ_i = Σ_j W_ij Θ^(Value)(z_j)   (8)

where Θ^(Value) is another learnable linear mapping from the vector space of z to that of y. Finally, we have:

ŷ_i = G(ỹ_i, ȳ_i) ⊙ ỹ_i + (1 − G(ỹ_i, ȳ_i)) ⊙ ȳ_i   (9)

where G is a gating layer with a sigmoid non-linearity implementing a soft selection mechanism between the linear combination of values (ȳ_i) and the inputs weighted by the local weights (ỹ_i). In what follows, we will refer to the set A as the query set, B as the key set and Â as the output set. Input Attention.
The input attention mechanism is a KMDPA, mapping the set of observation tuples {(s^a_t, e^a_t)}_{a=1}^{A} (key set) and the current RNN states {(p_m, h^m_t)}_{m=1}^{M} (query set) to the set of RNN inputs {(p_m, u^m_t)}_{m=1}^{M} (output set). On the one hand, we observe that the input u^m_t to RNN F_m can contain information about an observation e^a_t only if the embedded location s^a_t of said observation is close enough to the embedding p_m of the RNN in S (i.e., if Z(s^a_t, p_m) > 0), thereby implementing the assumption of locality of observations. On the other hand, the non-local term allows a module F_m to reject (or accept) an observation based on its content, which can be beneficial if two modules attend to overlapping regions of the environment but specialize in different aspects of the dynamics. Please refer to Appendix C.1 for a precise description of the mechanism, in particular the (optional) use of multiple dot-product attention heads. Inter-cell Attention. The inter-cell attention mechanism is another KMDPA, mapping two copies of the current RNN states {(p_m, h^m_t)}_{m=1}^{M} (one as query and the other as key set) to the set of aggregated hidden states {(p_m, h̄^m_t)}_{m=1}^{M} (output set). This enables local interaction between the RNNs, in that the local term ensures that RNN F_m interacts with RNN F_n only if their respective embeddings p_m and p_n are close enough in S (i.e., if Z(p_m, p_n) > 0), thereby implementing the assumption of local interactions between sub-systems. The non-local term allows two modules to interact based on their hidden states, i.e., it provides a mechanism for a module to (not) interact with another even if their embeddings are similar enough in S. Appendix C.2 contains a precise description of the attention mechanism. Output Attention.
The output attention mechanism, together with the decoder (described below), serves as an apparatus to evaluate the world state modeled implicitly by the set of RNNs {F_m}_{m=1}^{M} at time t + 1 (for one-step forward models). Given a query location x^q ∈ X and its corresponding embedding s^q, the output attention mechanism polls the RNNs F_m whose embeddings p_m are similar enough to s^q, as measured by the kernel Z. Denoting by h_mj the j-th component of h^m_{t+1}, and by d^q_j the j-th component of the vector d^q_{t+1} associated with the query location x^q, we have:

d^q_j = Σ_m Z(s^q, p_m) h_mj

Decoder. The decoder D is a parameterized function that predicts the observation Ô^q_{t+1} ∈ O at x^q given the representation d^q_{t+1} from the output attention. This concludes the description of the generic architecture, which allows for flexibility in the choice of the RNN architecture (i.e., the internal architecture of F_m). In practice, we find Gated Recurrent Units (GRUs) (Cho et al., 2014) to work well, and call the resulting model Spatially Structured GRU or S2GRU. Moreover, Relational Memory Cores (RMCs) (Santoro et al., 2018) also profit from our architecture (with a modification detailed in Appendix E.3), and we call the resulting model S2RMC.
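As a concrete illustration, here is a minimal single-head sketch of KMDPA for the inter-cell case (where the query and key sets coincide, so all terms are well-typed). The weight matrices below are random stand-ins for the learnable maps Θ^(Query), Θ^(Key), Θ^(Value) and the gate G, and fixing the kernel bandwidth to 1 is our simplification.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def Z(a, b, tau=0.0):
    # Truncated kernel on the unit sphere; bandwidth fixed to 1 for brevity.
    d = float(a @ b)
    return float(np.exp(-2.0 * (1.0 - d))) if d >= tau else 0.0

def kmdpa_self(P, H, Wq, Wk, Wv, Wg, bg):
    """Single-head KMDPA over one set (the inter-cell case):
    P: (M, d_s) module embeddings on the sphere; H: (M, d_h) hidden states."""
    M = P.shape[0]
    out = np.zeros_like(H)
    for i in range(M):
        logits = np.array([(Wq @ H[i]) @ (Wk @ H[j]) for j in range(M)])
        W_nl = softmax(logits)                             # non-local term
        W_l = np.array([Z(P[i], P[j]) for j in range(M)])  # local term W^(L)
        W = W_nl * W_l
        y_tilde = W_l @ H                                  # values weighted by local term
        y_bar = W @ (H @ Wv.T)                             # attention readout of values
        g = 1.0 / (1.0 + np.exp(-(Wg @ np.concatenate([y_tilde, y_bar]) + bg)))
        out[i] = g * y_tilde + (1.0 - g) * y_bar           # gated soft selection
    return out

M, dh, ds, dk = 4, 6, 3, 5
P = rng.normal(size=(M, ds)); P /= np.linalg.norm(P, axis=1, keepdims=True)
H = rng.normal(size=(M, dh))
Wq, Wk = rng.normal(size=(dk, dh)), rng.normal(size=(dk, dh))
Wv = rng.normal(size=(dh, dh))
Wg, bg = 0.1 * rng.normal(size=(dh, 2 * dh)), np.zeros(dh)
H_agg = kmdpa_self(P, H, Wq, Wk, Wv, Wg, bg)
assert H_agg.shape == (M, dh)
```

Note how the truncated kernel zeroes out module pairs whose embeddings are far apart on the sphere, so those pairs exchange no information regardless of their dot-product attention scores.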

5. RELATED WORK

Problem Setting. Recall that the problem setting we consider is one where the environment offers local (partial) views into its global state paired with the corresponding spatial locations. With Generative Query Networks (GQNs), Eslami et al. (2018) investigate a similar setting where 2D images of 3D scenes are paired with the corresponding viewpoint (camera position, yaw, pitch and roll). Given that GQNs are feedforward models, they do not consider the dynamics of the underlying scene and as such cannot be expected to be consistent over time (Kumar et al., 2018). Singh et al. (2019) and Kumar et al. (2018) propose variants that are temporally consistent, but unlike us, they do not focus on the problem of predicting the future state of the system. Modularity. Modularity has been a recurring topic in the context of meta-learning (Alet et al., 2018; Bengio et al., 2019; Ke et al., 2019), sequence modeling (Ghahramani & Jordan, 1996; Henaff et al., 2016; Li et al., 2018; Goyal et al., 2019; Mei et al., 2020; Mittal et al., 2020) and beyond (Jacobs et al., 1991; Shazeer et al., 2017; Parascandolo et al., 2017). In the context of RNNs, Li et al. (2018) explore a setting where the recurrent units operate entirely independently of each other. Closer to our work, Goyal et al. (2019) explore a setting where autonomous RNN modules interact with each other via the bottleneck of sparse attention. However, instead of leveraging the spatial structure of the environment, they induce sparsity using a scheme inspired by the k-winners-take-all principle (Majani et al., 1988), where only the k modules that attend the most to the input are activated and propagate their state forward, whereas the remaining modules that do not receive an input follow default dynamics in that their hidden states are not updated.
This can be contrasted with S2RMs, where the modules that do not receive inputs may still evolve their states forward in time, reflecting that the environment may evolve even when no observations are available. Attention Mechanisms and Information Flow. Attention mechanisms have been used to attenuate the flow of information between components of the network, e.g. (Graves et al., 2014; 2016; Santoro et al., 2018; Ke et al., 2018; Veličković et al., 2017; Battaglia et al., 2018) . There is a growing interest in efficient attention mechanisms for use in transformers (Vaswani et al., 2017) , and like KMDPA, some recently proposed methods rely on learned sparsity (Kitaev et al., 2020; Tay et al., 2020) . However, these induce sparsity by dynamically clustering or sorting based on content, while we make explicit use of the spatial information accompanying observations to learn a spatially-grounded sparsity pattern. Moreover, mechanisms for spatial attention have also been studied (Jaderberg et al., 2015; Wang et al., 2017; Zhang et al., 2018; Parmar et al., 2018) , but they typically operate on image pixels. Our setting is different in that we do not assume that the world-state (from which we sample local observations) can be represented as an image.

6. EXPERIMENTS

In this section, we present a selection of experiments to quantitatively evaluate S2RMs and gauge their performance against strong baselines in two data domains, namely video prediction from crops in the well-known bouncing-balls domain and multi-agent world modelling from partial observations in the challenging Starcraft2 domain. We also include qualitative visualizations on a grid-world task in Appendix A. Additional tables, results and supporting plots can be found in Appendix F. Baselines. To draw fair comparisons between various RNN architectures, we require an architectural scaffolding that is agnostic to the number of observations A, is invariant to the ordering of {(x^a_t, O^a_t)}_{a=1}^{A} with respect to a, and features a querying mechanism to extract a predicted observation O^q_{t′} at a given query location x^q at a future time step t′ > t. Fortunately, it is possible to obtain a performant class of models fulfilling our requirements by extending prior work on Generative Query Networks or GQNs (Eslami et al., 2018). The resulting model has three components: an encoder, an RNN, and a decoder, which we describe in detail in Appendix D. In our experiments, we fix the encoder and decoder to be essentially identical to those in S2RMs, but vary the architecture of the RNN, where we experiment with LSTMs (Hochreiter & Schmidhuber, 1997), RMCs (Santoro et al., 2018) and RIMs (Goyal et al., 2019). As a sanity check, we also show results with a Time Travelling Oracle (TTO), which at time step t has access to the (partially observed) state at t + 1. Its purpose is to verify that the architectural scaffolding around the baseline RNNs (defined in Appendix D) does not constrain their performance and that the comparison to S2RMs is indeed fair. Video Prediction from Crops. We consider the problem of predicting the future frames of simulated videos of balls bouncing in a closed box, given only crops from the past video frames, centered at known pixel positions.
Using the notation introduced in Section 2: at every time step t, we sample A = 10 pixel positions {x^a_t}_{a=1}^{10} from the t-th full video frame o_t of size 48 × 48. Around each sampled central pixel position x^a_t, we extract an 11 × 11 crop, which we use as the local observation O^a_t. The task is then to predict the 11 × 11 crop O^q_{t′} corresponding to a query central pixel position x^q at a future time step t′ > t. Observe that at any given time step, the model has access to at most 52% of the global video frame, assuming that the crops never overlap (which is rather unlikely). Having trained on the training dataset with 3 bouncing balls, we evaluate the forward-prediction performance on all test datasets with 1 to 6 bouncing balls. Given that we treat the prediction problem as a pixel-wise binary classification problem, we report the balanced accuracy (i.e., the arithmetic mean of recall and specificity) or F1-scores (i.e., the harmonic mean of precision and recall) to account for class imbalance. In Figure 4, we see that S2GRUs outperform all non-oracle baselines on the one-step forward prediction task and strike a good balance with regard to in-distribution and OOD performance.

Figure 5: Visualization of the spatial locations each module is responsible for modeling (i.e., the enclaves X_m, defined in Section 3). The central ball does not bounce, i.e., it is stationary in all sequences. Gist: the modules focus attention on challenging regions, e.g., the corners of the arena and the surface of the fixed ball.

Figure 8: Performance metrics (larger is better) as a function of the probability that an agent will not supply information to the world model but still query it. Gist: while all models lose performance as fewer agents share observations, we find S2RMs to be the most robust.
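The cropping scheme just described can be sketched as follows (a minimal construction of ours; zero-padding at the frame border is our assumption, as the paper defers dataset details to its appendix).

```python
import numpy as np

rng = np.random.default_rng(0)
FRAME, CROP, A = 48, 11, 10  # frame size, crop size, crops per step (from the text)

def extract_crops(frame, centers):
    """Cut CROP x CROP patches around the sampled central pixels. We zero-pad
    the frame so crops near the border stay well-defined."""
    r = CROP // 2
    padded = np.pad(frame, r)
    return [padded[y:y + CROP, x:x + CROP] for (y, x) in centers]

frame = (rng.random((FRAME, FRAME)) > 0.95).astype(np.uint8)  # stand-in binary frame
centers = [tuple(rng.integers(0, FRAME, size=2)) for _ in range(A)]
crops = extract_crops(frame, centers)
assert all(c.shape == (CROP, CROP) for c in crops)

# Maximal fraction of the frame visible if no crops overlap:
print(A * CROP**2 / FRAME**2)  # ≈ 0.525, i.e. the "at most 52%" quoted above
```

The coverage arithmetic (10 crops of 121 pixels over a 2304-pixel frame) confirms the bound stated in the text.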
In Figure 3, we qualitatively show reconstructions from 25-step rollouts on the out-of-distribution dataset with 5 balls, demonstrating that S2GRUs can perform OOD rollouts over long horizons without losing track of the balls. Figure 6 shows the result of an ablation study in which we disable the local and non-local terms in each of the two KMDPAs while keeping everything else the same. We see that both the local and non-local terms contribute to the overall performance; moreover, the input attention relies on the non-local term, and performance is severely affected by its absence, whereas the inter-cell attention depends on the local term to yield good performance. This suggests that the modules indeed rely on the content of the input observations as they select their inputs, and that learning an interaction topology between modules is a strong contributor to the final performance. In Figure 5, we show for each module its corresponding enclave, i.e., the spatial region that it is responsible for modelling: for pixels at position x, we plot {Z(P(x), p_m)}_{m=1}^{10} (cf. Section 4). We find that the modules learn to share the responsibility of modelling the entire spatial domain. Finally, in Figure 7 we see the effect of removing (randomly sampled) modules at test time, i.e., without additional retraining. The performance degrades gracefully as fewer modules are available, suggesting that the individual modules can function while other modules are missing. We include details and additional results in Appendix F.1.

Multi-Agent World Modeling from Partial Observations. Recall the problem defined in Section 2, where O^a_t = φ(t, o)(x^a_t). Under certain restrictions, this problem can be mapped to that of multi-agent world modeling from partial and local observations, allowing us to evaluate the proposed model in a rich and challenging setting. In particular, we consider environments that are (a) spatial, i.e., all agents a have a well-defined and known location x^a_t at time t; (b) such that the agents' actions u^a_t are local, in that their effects propagate away from the agent only at a finite speed; and (c) such that the observations are local and centered around agents, in the sense that an agent only observes the events O^a_t in its local vicinity. Observe that we do not fix the number of agents in the environment, and allow agents to dynamically enter or exit it. Now, the task is: given observations O^a_t from a team of cooperating agents at positions x^a_t and their corresponding actions u^a_t, predict the observation O^q_{t′} that would be made at time t′ = t + 1 by an agent at position x^q.

The observations O^a_t and actions u^a_t are both multi-channel images represented in polar coordinates centered around the agent position x^a_t. The field of view (FOV) of each agent is therefore a circle of fixed radius centered around it. The channels of the image correspond to (a) a binary indicator marking whether a position in the FOV is occupied by a living friendly agent (friendly marker), (b) a categorical indicator marking the type of living units at a given position in the FOV (unit-type marker), and (c) four channels marking the health, energy, weapon-cooldown and shields (HECS markers) of all agents in the FOV. With a heuristic, we gather a total of 9K trajectories ({x^a_t, O^a_t, u^a_t}_{a=1}^{A})_{t=1}^{100} spread over three training scenarios, corresponding to 1c3s5z, 3s5z and 2s5z in Samvelyan et al. (2019). We also sample 1K trajectories each from two OOD scenarios, 1s2z and 5s3z. Details in Appendix B.1. Having trained all models on scenarios 1c3s5z, 3s5z and 2s5z, we test their robustness to dropped agents (Figure 8) and their performance on OOD scenarios (Table 1). We only include baselines that achieve similar or better validation scores than S2RMs. Figure 8 shows that S2RMCs remain robust when fewer agents supply their observations to the world model, whereas Table 1 shows that S2GRUs outperform the baselines in the OOD scenario 1s2z but are matched by RMCs in 5s3z (see Appendix F.2 for details). The strong performance of RMCs suggests that the task benefits from the inductive bias of relational memory. One hypothesis as to why is that the pace of the considered environments requires fast communication between agents, which can be achieved by a shared memory that all agents may read from and write to. Further, we observe that while the oracle (TTO) can generalize well out of distribution, Figure 8 shows that it is less robust to the number of available observations.
This is explained by the fact that unlike recurrent models, TTO does not leverage the temporal dynamics to fill in the missing information due to fewer available observations. This pattern also holds for the bouncing balls task, cf. Figures 21e and 20e in Appendix F.1.

CONCLUSIONS, LIMITATIONS AND FUTURE WORK

We proposed Spatially Structured Recurrent Modules, a new class of models constructed to jointly leverage both spatial and modular structure in data, and explored its potential in the challenging problem setting of predicting forward dynamics from partial observations at known spatial locations. On the tasks of video prediction from crops and multi-agent world modeling in the Starcraft2 domain, we found that it compares favorably against strong baselines in terms of out-of-distribution generalization and robustness to the number of available observations. Future work may attempt to extend the idea to parallel-in-time methods like universal transformers (Dehghani et al., 2018) and thereby address the computational bottleneck of recurrent processing, which is a current limitation. Another interesting avenue of research could be to explore how latent random variables can be used in tandem with the spatial structure to obtain a variational version of S2RMs. Finally, efficient implementations using block-sparse methods (Gray et al., 2017) might hold the key to unlocking applications to significantly larger-scale spatiotemporal forecasting problems encountered in domains like climate change research (Rolnick et al., 2019).

A QUALITATIVE VISUALIZATIONS: A GRID-WORLD NAVIGATION TASK

In this section, we show qualitative results on a grid-world task defined in Eichenberger et al. (2019), which formulates the problem of navigation on a railway network in a multi-agent reinforcement learning framework. The environment comprises a network of railroads, on which agents (trains) may move in order to reach their destination. In our experiments, the entire railway network is defined on a 60 × 60 grid-world and we let each agent only observe a partial and local view of the environment, namely a 5 × 5 crop centered around itself. We gather 10000 multi-agent trajectories with 10 agents and maximum length 128, of which we use 8000 for training and reserve 2000 for validation. We train S2GRU with 10 modules for 100 epochs and early-stop when the validation loss is at its minimum. With the trained model, we visualize the following two things. First: for each module F_m, we visualize the spatial locations it may attend to. To this end, we consider all 60 × 60 = 3600 pixel locations in the grid-world, say x_ij where i, j ∈ {1, ..., 60}. For each such x_ij, we evaluate the quantity X^m_ij = Z(p_m, s_ij), where s_ij = P(x_ij) (see Eqn 3) and p_m is the embedding of module F_m; X^m is a 60 × 60 image indexed by i and j, which we call the enclave of module F_m. Note that this is identical to what we visualize in Figure 5. Next: we identify each module with its enclave, and visualize the graph of interactions between them. In Figure 9, we plot the enclaves X^m as nodes. Further, we draw an edge between enclaves X^m and X^n iff Z(p_m, p_n) > 0. We make the following two observations. First, the images in Figure 9 show that each module learns to account for a spatial region in the environment, as we imagined in Figure 1b in Section 1. Second, we find that the modules interact sparsely with each other: while some modules learn to interact with up to five other modules, other modules learn to operate independently.
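As a concrete illustration, the enclave computation described above can be sketched in a few lines. Here, `embed` and `enclave` are compact stand-ins (with assumed, small dimensions) for the positional embedding P and the kernel Z defined in Appendix C; this is a sketch for illustration, not the implementation used in our experiments.

```python
import numpy as np

def embed(x, num_freqs=4):
    """Compact stand-in for P(x): interleaved sin/cos features, normalized onto S."""
    freqs = 1.0 / 10000.0 ** (np.arange(num_freqs) / num_freqs)
    angles = np.outer(np.asarray(x, float), freqs).ravel()
    s = np.concatenate([np.sin(angles), np.cos(angles)])
    return s / np.linalg.norm(s)

def enclave(p, size=60, ell=1.0, tau=0.0):
    """X^m_ij = Z(p_m, P(x_ij)) evaluated over all grid positions x_ij."""
    X = np.zeros((size, size))
    for i in range(size):
        for j in range(size):
            cos = p @ embed((i + 1, j + 1))
            # Truncated spherical Gaussian kernel (cf. Appendix C.1)
            X[i, j] = np.exp(-2.0 * ell * (1.0 - cos)) if cos >= tau else 0.0
    return X
```

Plotting the resulting X as an image (one per module) yields the enclaves shown in Figure 9.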

B DETAILED TASK DESCRIPTIONS

B.1 STARCRAFT2

The Starcraft2 environment we use is a modified version of the SMAC-Env proposed in Samvelyan et al. (2019), built on the PySC2 wrapper around the Blizzard SC2 API (Vinyals et al., 2017). Starcraft2 is a real-time-strategy (RTS) game where players are tasked with manufacturing and controlling armies of units (airborne or land-based) to defeat the opponent's army (where the opponent can be an AI or another human). The players must choose their race before starting the game; the available options are Protoss, Terran and Zerg. All unit types (of all races) have their strengths and weaknesses against other unit types, be it in terms of maximum health, shields (Protoss), energy (Terran), DPS (damage per second, related to weapon cooldown), splash damage, or manufacturing costs (measured in minerals and vespene gas, which must be mined). The key engineering contribution of Samvelyan et al. (2019) is to repurpose the RTS game as a multi-agent environment, where the individual units in the army become individual agents. The result is a rich and challenging environment where heterogeneous teams of agents must defeat each other in melee and ranged combat. The composition of the teams varies between scenarios, of which Samvelyan et al. (2019) provide a selection. Further, new scenarios can easily be created with the SC2MapEditor, which allows for practically endless possibilities. We build on Samvelyan et al. (2019) by modifying their environment to better expose the transfer and out-of-distribution aspects of the domain, by (a) standardizing the state and action space across a large class of scenarios and (b) standardizing the unit stats to better reflect the game-defined notion of hit-points.

B.1.1 STANDARDIZED STATE SPACE FOR ALL SCENARIOS

In the environment provided by Samvelyan et al. (2019), the dimensionality of the vector state space varies with the number of friendly and enemy agents, which in turn varies with the scenario. While this is not an issue in the typical use case of training MARL agents in a fixed scenario, it is not convenient for designing models that seamlessly handle multiple scenarios. In the following, we propose an alternate state representation that preserves the spatial structure and is consistent across multiple scenarios. Instead of representing the state of an agent a with a vector of variable dimension, we represent it with a multi-channel polar image I^a of shape C × I × J, where C is the number of channels and (I, J) is the image size. Given the radial and angular resolutions ρ and ϕ (respectively), the pixel coordinate (i, j), with i = 0, ..., I - 1 and j = 0, ..., J - 1, corresponds to coordinates (i·ρ, j·ϕ) with respect to a polar coordinate system centered on the agent a, where the positive x-axis (j = 0) points east. Further, the field of view (FOV) of an agent is characterized by a circle of radius I·ρ centered on the agent at 2D game coordinates x^a = (x^a_1, x^a_2), to which the Starcraft2 API (Vinyals et al., 2017) provides raw access. The polar image I^a therefore provides an agent-centric view of the environment, where pixel coordinates (i, j) in I^a can be mapped to global game coordinates x = (x_1, x_2) in FOV via:

x_1 = i·ρ cos(j·ϕ) + x^a_1    (12)
x_2 = i·ρ sin(j·ϕ) + x^a_2    (13)

In what follows, we denote this transformation by T_a, as in T_a(i, j) = (x_1, x_2). Now, the channels of the polar image can encode various aspects of the observation; in our case: friendly markers (one channel), unit-type markers (nine channels, one-hot), health-energy-cooldown-shields (HECS, four channels) and terrain height (one channel). As an example, consider the friendly marker, a binary indicator marking units that are friendly.
If we have an agent at game position (x_1, x_2) that is friendly to agent a, then we would expect the pixel coordinate (i, j) = T_a^{-1}(x_1, x_2) of the corresponding channel in the polar image I^a to be 1, and 0 otherwise. Likewise, the value of I^a at the channels corresponding to HECS at pixel position (i, j) gives the HECS of the corresponding unit at T_a(i, j). This representation has the following advantages: (a) it does not depend on the number of units in the field of view, and (b) it exposes the spatial structure in the arrangement of units, which can naturally be processed by convolutional neural networks (e.g. with circular convolutions). It has the disadvantage, however, that positions are quantized to pixels, and the Euclidean distance between the locations represented by pixels (i, j) and (i, j + 1) increases with increasing i. Consequently, this representation may not remain suitable for larger FOVs. Further, this representation is also appropriate for the action space. Given an agent, we represent the one-hot categorical actions of all friendly agents in FOV as a multi-channel polar image. In this representation, the pixel position (i, j) gives the action taken by an agent at position T_a(i, j). Unfriendly agents get assigned an "unknown action", whereas positions not occupied by a living agent are assigned a "no-op" action.
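The transformation T_a of Equations 12 and 13, and its (quantizing) inverse, can be sketched as follows; the function names, the rounding-based inversion, and the handling of angles are illustrative assumptions.

```python
import numpy as np

def to_game_coords(i, j, agent_xy, rho, phi):
    """T_a(i, j): map a pixel (i, j) of the polar image to global game coordinates."""
    x1 = i * rho * np.cos(j * phi) + agent_xy[0]
    x2 = i * rho * np.sin(j * phi) + agent_xy[1]
    return x1, x2

def to_pixel_coords(xy, agent_xy, rho, phi, J):
    """Approximate inverse T_a^{-1}: quantize a game coordinate to a pixel (i, j)."""
    dx, dy = xy[0] - agent_xy[0], xy[1] - agent_xy[1]
    r = np.hypot(dx, dy)                      # radial distance from the agent
    theta = np.arctan2(dy, dx) % (2 * np.pi)  # angle w.r.t. east (j = 0)
    return int(round(r / rho)), int(round(theta / phi)) % J
```

Round-tripping a pixel through both functions recovers it up to the quantization discussed above.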

B.1.2 STANDARDIZED UNIT STATS

At any given point in time, an active unit in Starcraft2 has certain stats, e.g. its health, energy (Terran), shields (Protoss) and weapon-cooldown (for armed units). A large and expensive unit type like the Colossus has more max-health (hit-points) than smaller units like Stalkers and Marines. Likewise, unit types differ in the rate at which they deal damage (measured in damage-per-second or DPS, excluding splash damage), which in turn depends on the cooldown duration of the active weapon. Now, the environment provided by Samvelyan et al. (2019) normalizes the stats by their respective maximum values, resulting in values between 0 and 1. However, given that different units may have different normalizations, the stats are rendered incomparable between unit types (without additionally accounting for the unit type). We address this by standardizing the stats (instead of normalizing them), i.e. by dividing them by a fixed value. In this scheme, the stats are scaled uniformly across all unit types, enabling models to directly rely on them instead of having to account for the respective unit types.

B.2 VIDEO PREDICTION FROM CROPS ON THE BOUNCING BALLS TASK

The bouncing balls task is a well-known test-bed for evaluating the performance of video prediction models (Fraccaro et al., 2017; Watters et al., 2017; Miladinović et al., 2019; Kossen et al., 2019; Cenzato et al., 2019). We modify the problem by introducing partial observability: concretely, instead of providing the model with the full image frames, we only provide it with crops at randomly sampled locations. As mentioned in Section 6, at every time step t we sample A = 10 pixel positions {x^a_t}_{a=1}^{10} from the t-th full video frame o_t of size 48 × 48. Around the sampled central pixel positions x^a_t, we extract 11 × 11 crops, which we use as the local observations O^a_t. The task is now to predict the 11 × 11 crop O^q_{t'} corresponding to a query central pixel position x^q_{t'} at a future time-step t' > t. Observe that at any given time-step t, the model has access to at most 52% of the global video frame, assuming that the crops never overlap (which is rather unlikely). We train all models on a training dataset of 20K video sequences with 100 frames of 3 balls bouncing in an arena of size 48 × 48. We also include an additional fixed ball in the center to make the task more challenging. We use another 1K video sequences of the same length and the same number of balls as a held-out validation set. In addition, we have 5 out-of-distribution (OOD) test sets with various numbers of bouncing balls (ranging from 1 to 6), each containing 1K sequences of length 100. In Figure 4, for each number of balls (i.e. point on the x-axis), we plot the respective metrics aggregated over 10 randomly selected 11 × 11 crops of a total of 100,000 frames spread over 1000 trajectories with 100 frames each.
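The crop-extraction pipeline described above can be sketched as follows; the zero-padding at the frame border and the uniform sampling of central positions are assumptions made for illustration.

```python
import numpy as np

def sample_crops(frame, num_crops=10, crop=11, rng=None):
    """Sample central pixel positions and extract crop x crop local observations.

    A sketch of the data pipeline described above; handling of the frame
    border (zero-padding) is an assumption. Assumes a square frame.
    """
    rng = np.random.default_rng() if rng is None else rng
    half = crop // 2
    padded = np.pad(frame, half)  # zero-pad so crops at the border stay valid
    positions = rng.integers(0, frame.shape[0], size=(num_crops, 2))
    # A center (x, y) in the original frame sits at (x + half, y + half) in
    # the padded frame, so the crop spans padded[x : x + crop, y : y + crop].
    crops = np.stack([padded[x:x + crop, y:y + crop] for x, y in positions])
    return positions, crops
```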

C PRECISE DESCRIPTION OF ATTENTION MECHANISMS

C.1 INPUT ATTENTION

Recall from Section 4 that the input attention mechanism is a mapping between sets: namely, from the set of observation encodings {e^a_t}_{a=1}^{A} to the set of RNN inputs {u^m_t}_{m=1}^{M}. In what follows, we use the einsum notation (indices not appearing on both sides of an equation are summed over) to succinctly describe the exact mechanism. But before that, we repeat the definition of the truncated spherical Gaussian kernel (Fasshauer, 2011) used to quantify the similarity between two points p, s ∈ S:

Z(p, s) = exp[-2ℓ(1 - p·s)] if p·s ≥ τ, and 0 otherwise    (14)

where ℓ ∈ R+ and τ ∈ [-1, 1) are hyper-parameters (the kernel bandwidth and truncation parameter, respectively), and 0 ≤ Z ≤ 1 since p and s are unit vectors. We observe that both τ and ℓ control the sparsity of the kernel: τ determines the size of the neighborhood of p, i.e. the size of the set B(p) ⊂ S of all s such that p·s ≥ τ and accordingly Z(p, s) > 0, whereas the bandwidth ℓ controls how the attention decays inside B(p). Intuitively, τ determines a lower bound on the amount of sparsity that the kernel induces (irrespective of the bandwidth ℓ), whereas for fixed τ, sparsity can be increased by increasing ℓ. We find τ ∈ [-1, 0.6] and ℓ ∈ [0.9, 2] to work well; setting τ and ℓ to much larger values destabilizes training due to excessive sparsity, whereas setting ℓ to much smaller values results in Z being flat inside B(p) and therefore in poor propagation of gradients. Now, we use k to index the attention heads, d to index the dimensions of the key and query vectors, and denote by e_{ai} the i-th component of e^a_t and by h_{mj} the j-th component of h^m_t.
Given learnable parameters Θ^(Q), Θ^(K), Θ^(V), we obtain:

Q_{akd} = e_{ai} Θ^(Q)_{ikd},    K_{mkd} = h_{mj} Θ^(K)_{jkd}    (15)
V_{akv} = e_{ai} Θ^(V)_{ikv},    W̃_{mak} = Q_{akd} K_{mkd}    (16)
Ŵ_{mak} = sm_a(W̃_{mak}),    W^(L)_{ma} = Z(p_m, s_a)    (17)
W_{mak} = W^(L)_{ma} Ŵ_{mak},    ũ_{m(kv)} = W_{mak} V_{akv}    (18)

where sm_a denotes a softmax along the a-dimension, W^(L) is what we call the local weights, we omit the time subscript in s_a for notational clarity, and ũ_{m(kv)} is the (kv)-th component of the vector ũ_m. Finally, we obtain the components u_{mi} of the RNN inputs u^m_t via a gating operation:

u_{mi} = G^(inp)_m · b_{mi} + (1 - G^(inp)_m) · ũ_{mi}    (19)

where the gating weight G^(inp)_m ∈ (0, 1) is obtained by passing ũ_m and b_{mi} = W^(L)_{ma} e_{ai} through a two-layer MLP with sigmoidal output (in parallel across m). Now, observe that by weighting the MHDPA attention outputs (Ŵ in Equation 18) by the kernel Z (via W^(L)), we construct a scheme wherein an interaction between the input O^a_t and the RNN F_m is allowed only if the embedding s^a_t of the corresponding position x^a_t has a large enough cosine similarity (≥ τ) to the embedding p_m of F_m. This partially implements the assumption of Locality of Observation detailed in Section 3.
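A minimal numpy sketch of the kernel (Equation 14) and the locally weighted input attention (Equations 15-18) might look as follows. The gating operation (Equation 19) is omitted, and all tensor sizes are assumptions; this illustrates the einsum structure, not the trained implementation.

```python
import numpy as np

def kernel_Z(p, s, ell=1.0, tau=0.0):
    """Z(p, s) = exp[-2*ell*(1 - p.s)] if p.s >= tau, else 0 (p, s unit vectors)."""
    cos = p @ s.T  # (M, A) cosine similarities between module and position embeddings
    return np.where(cos >= tau, np.exp(-2.0 * ell * (1.0 - cos)), 0.0)

def input_attention(e, h, p, s, Theta_Q, Theta_K, Theta_V, ell=1.0, tau=0.0):
    """Locally weighted MHDPA from encodings e (A, i) to per-module inputs (M, k*v)."""
    Q = np.einsum('ai,ikd->akd', e, Theta_Q)         # queries from observation encodings
    K = np.einsum('mj,jkd->mkd', h, Theta_K)         # keys from module hidden states
    V = np.einsum('ai,ikv->akv', e, Theta_V)         # values from observation encodings
    W_tilde = np.einsum('akd,mkd->mak', Q, K)        # raw attention logits
    W_tilde = np.exp(W_tilde - W_tilde.max(1, keepdims=True))
    W_hat = W_tilde / W_tilde.sum(1, keepdims=True)  # softmax along the a-axis
    W_L = kernel_Z(p, s, ell, tau)                   # local weights Z(p_m, s_a)
    W = W_L[:, :, None] * W_hat                      # locality-masked attention
    return np.einsum('mak,akv->mkv', W, V).reshape(h.shape[0], -1)  # u_tilde per module
```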

C.2 INTER-CELL ATTENTION

Recall from Section 4 that the inter-cell attention maps the set of hidden states {h^m_t}_{m=1}^{M} to the set of aggregated hidden states {h̃^m_t}_{m=1}^{M}, thereby enabling interaction between the RNNs F_m. While its mechanism is identical to that of the input attention, we formulate it below for completeness. To proceed, we denote by h_{li} the i-th component of h^l_t (in addition to the notation introduced before Equation 15), and take Φ^(Q), Φ^(K) and Φ^(V) to be learnable parameters. We have:

Q_{mkd} = h_{mj} Φ^(Q)_{jkd},    K_{lkd} = h_{li} Φ^(K)_{ikd}    (20)
V_{lkv} = h_{li} Φ^(V)_{ikv},    W̃_{mlk} = Q_{mkd} K_{lkd}    (21)
Ŵ_{mlk} = sm_l(W̃_{mlk}),    W^(L)_{ml} = Z(p_m, p_l)    (22)
W_{mlk} = Ŵ_{mlk} W^(L)_{ml},    ĥ_{m(kv)} = W_{mlk} V_{lkv}    (23)

where ĥ_{m(kv)} is the (kv)-th component of the vector ĥ_m. Finally, the j-th component h̃_{mj} of the aggregated hidden state h̃^m_t in Equation 4 is given by a gating operation:

h̃_{mj} = G^(ic)_m · c_{mj} + (1 - G^(ic)_m) · ĥ_{mj}    (24)

where the gating weight G^(ic)_m ∈ (0, 1) is obtained by passing ĥ_m and c_{mj} = W^(L)_{ml} h_{lj} through a two-layer MLP with sigmoidal output (in parallel across m). The weighting by Z (in Equation 23, left) ensures that interactions are constrained to occur only between RNNs whose embeddings in S are similar enough, thereby implementing the assumption of Local Interactions between Sub-systems from Section 3.

C.3 POSITIONAL ENCODING

In Section 4, recall that we used the following positional embedding P:

P(x) = s / ‖s‖ ∈ S, where (s_{i+m}, s_{i+1+m}) = (sin(x_m / 10000^i), cos(x_m / 10000^i))    (25)

In this section, we explore how the choice of the positional embedding function P determines the space of spatial functions (defined on X) that the local attention can represent. To this end, consider the similarity in S between a module with embedding p and an observation made at location x, as a function of x:

w^(L)(x) = p · P(x)

The local weight of interaction between the module at p and an observation made at x is then given by:

Z(p, P(x)) = exp[-2ℓ(1 - w^(L)(x))] if w^(L)(x) ≥ τ, and 0 otherwise

In particular, observe that for two locations x and y to be connected by the module, w^(L) must be flexible enough that w^(L)(x) ≥ τ and w^(L)(y) ≥ τ for a chosen τ. This flexibility stems from the fact that we implicitly express w^(L) as a linear combination of sinusoidal basis functions with learned weights:

w^(L)(x) = Σ_{j=0}^{J} [p_{2j} cos(ω_j · x) + p_{2j+1} sin(ω_j · x)]

Here, p_{2j} and p_{2j+1} are learnable parameters (components of the learnable vector p of dimension 2J), and ω_j are frequency vectors. Now, if the dimension of the embedding vector p were to tend to infinity, we could include a growing number of frequencies ω_j and gradually recover the full Fourier basis of L²(X) (assuming X is Euclidean for simplicity). In the limit, w^(L)(x) can be an arbitrary function on the unit sphere of L²(X) (i.e. ‖w^(L)‖₂ = 1; recall that p is normed to unity). In other words, in a high-dimensional embedding space, the system is afforded a large amount of flexibility to learn any spatial structure or topology on X by connecting pairs of points x and y in X via a module. For computational tractability, however, we require p to be finite-dimensional, implying that the ω_j must be sampled.
By using P as defined in Equation 25, we essentially sample the ω_j as coordinate (one-hot) vectors, with log ω_j sampled on a uniform grid. This sampling step is not only computationally favorable, but also justified by the theory of RKHSs: Rahimi & Recht (2007) use Bochner's theorem to show that any proper distribution p(ω) (from which ω can be sampled) leads to a feature map whose inner product, in expectation over p, corresponds to a positive-definite kernel. The convergence to such a kernel is exponential in the number of samples (equivalent to the dimension of the embedding). Further, we note that while the sampling constrains the function space in which w^(L) can lie, we find (empirically) that this can in fact have a regularizing effect. Nevertheless, this raises the question of whether other choices of basis functions are viable. We speculate that a polynomial basis (e.g. feature maps of a degree-d polynomial kernel) might also be viable, but leave extensive exploration to future work.
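A sketch of the sinusoidal embedding P (Equation 25) under these assumptions (a log-uniform frequency grid per coordinate, followed by projection onto the unit sphere) might read:

```python
import numpy as np

def positional_embedding(x, num_freqs=8):
    """Map a position x in R^D to a unit vector in R^(2 * D * num_freqs).

    A sketch of P(x): sin/cos features at log-uniformly spaced frequencies,
    normalized so that the embedding lies on the unit sphere S. The number of
    frequencies and their spacing are assumptions for illustration.
    """
    x = np.asarray(x, dtype=float)
    freqs = 1.0 / (10000.0 ** (np.arange(num_freqs) / num_freqs))  # log-uniform grid
    angles = np.outer(x, freqs).ravel()             # one set of angles per coordinate
    s = np.concatenate([np.sin(angles), np.cos(angles)])
    return s / np.linalg.norm(s)                    # embeddings live on S
```

Note that the map is continuous in x (tying the two notions of locality, cf. the footnotes) and, for generic frequencies, distinct positions receive distinct embeddings.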

D BASELINE ARCHITECTURE

As mentioned in Section 6: in order to ensure a fair comparison between the baselines and our method, we describe a baseline architecture constructed to satisfy a few critically important desiderata that are naturally satisfied by S2RMs. Namely, (a) it must be parametrically agnostic to the number of available observations and (b) it must be invariant to permutations of the observations. For this, we extend the framework of Generative Query Networks (Eslami et al., 2018) by predicting the forward dynamics of an aggregated representation. While we invest effort in ensuring that the resulting class of models can perform at least as well as S2RMs on in-distribution (validation) data, we do not consider it a novel contribution of this work. Encoder. At a given timestep t, the encoder E jointly maps the embedding s^a_t ∈ S of the position x^a_t ∈ X and the corresponding observation O^a_t to an encoding e^a_t; the encodings are then summed over a to obtain an aggregated representation:

r_t = Σ_{a=1}^{A} E(O^a_t, s^a_t)    (29)

The additive aggregation scheme we use is well known from prior work (Santoro et al., 2017; Eslami et al., 2018; Garnelo et al., 2018) and makes the model agnostic to A and invariant to permutations of (x^a_t, O^a_t) over a. The encoder E is a seven-layer CNN with residual layers, and the positional embedding s^a_t is injected after the second convolutional layer via concatenation with the feature tensor. The exact architectures can be found in Appendices E.1 and E.2. RNN. The aggregated representation r_t is used as input to an RNN model F as follows:

h_{t+1}, c_{t+1} = F(r_t, h_t, c_t)    (30)

where h_t and c_t are the hidden and memory states of the RNN F, respectively. We experiment with various RNN models, including LSTMs (Hochreiter & Schmidhuber, 1997), RMCs (Santoro et al., 2018) and Recurrent Independent Mechanisms (RIMs) (Goyal et al., 2019).
As a sanity check, we also show results with a Time Travelling Oracle (TTO), which has access to r_{t+1} (already at time step t) and produces h_{t+1} = F_TTO(r_{t+1}) with a two-layer MLP F_TTO. The TTO therefore does not model the dynamics, but merely verifies that the additive aggregation scheme (Equation 29) and the querying mechanism (Equation 31) are sufficient for the task at hand. Decoder. Given the embedding s^q of the query position x^q, the decoder D predicts the corresponding observation Ô^q_{t+1}:

Ô^q_{t+1} = D(h_{t+1}, s^q)    (31)

We parameterize D with a deconvolutional network with residual layers, and inject the positional embedding of the query, s^q, after a single convolutional layer by concatenating it with the layer features (see Appendices E.1 and E.2).
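The permutation invariance of the additive aggregation (Equation 29) is easy to verify with a toy stand-in for the encoder E; the linear map below is an illustrative assumption replacing the seven-layer residual CNN.

```python
import numpy as np

def aggregate(observations, positions, W):
    """r_t = sum_a E(O^a_t, s^a_t) with a toy linear encoder E.

    observations: (A, obs_dim), positions: (A, pos_dim), W: (obs_dim + pos_dim, rep_dim).
    Summation over a makes the result agnostic to A and to the ordering of views.
    """
    enc = np.concatenate([observations, positions], axis=1) @ W  # per-view encodings
    return enc.sum(axis=0)                                       # order-invariant sum
```

Because addition is commutative, shuffling the views leaves r_t unchanged, which is exactly the property the baseline needs.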

E HYPERPARAMETERS AND ARCHITECTURES

E.1 ENCODER AND DECODER FOR BOUNCING BALLS

The architectures of the image encoder and decoder were fixed for all models after initial experimentation. We converged to the following architectures.

E.1.1 S2RMS

The encoder (decoder) is a (de)convolutional network with residual connections (Figure 12 ).

E.1.2 BASELINES

Like in the case of S2RMs, the encoder (decoder) is a (de)convolutional network with residual connections (Figure 11 ), but with the positional embeddings injected after the second convolutional layer. This is loosely inspired by the encoders used in Eslami et al. (2018) .

E.2 ENCODER AND DECODER FOR STARCRAFT2

E.2.1 S2RMS

Recall from Appendix B.1 that the states are polar images. We therefore use polar convolutions, which entails zero-padding the input image along the first (radial) dimension but circular padding along the second (angular) dimension. The encoder and decoder architectures can be found in Figure 14 .
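The padding scheme behind these polar convolutions can be sketched as follows: wrap the angular axis and zero-pad the radial axis, so that a subsequent ordinary (valid) convolution behaves like a circular convolution in the angle. In PyTorch one would typically realize the angular wrap with circular padding; the numpy version below is an illustrative sketch.

```python
import numpy as np

def polar_pad(image, pad=1):
    """Pad a (C, I, J) polar image: zeros along the radial axis (I),
    circular wrap along the angular axis (J)."""
    # Wrap the angular axis: prepend the last columns, append the first ones.
    wrapped = np.concatenate([image[..., -pad:], image, image[..., :pad]], axis=-1)
    # Zero-pad the radial axis above and below.
    zeros = np.zeros((image.shape[0], pad, wrapped.shape[-1]))
    return np.concatenate([zeros, wrapped, zeros], axis=1)
```

A standard convolution with 'valid' boundary handling applied to the padded tensor then respects the periodicity of the angular coordinate.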

E.2.2 BASELINES

Like for S2RMs, we use polar convolutions while injecting the positional embeddings further downstream in the network. The corresponding encoder and decoder architectures are illustrated in Figure 13 .

E.3 SPATIALLY STRUCTURED RELATIONAL MEMORY CORES (S2RMCS)

Embedding Relational Memory Cores (Santoro et al., 2018) naïvely in the S2RM architecture did not result in a working model. We therefore had to adapt it by first projecting the memory matrix (M in Santoro et al. (2018)) of the m-th RMC to a message h^m_t. This message is then processed by the inter-cell attention to obtain h̃^m_t, which is finally concatenated with the memory matrix and the current input before applying the attention mechanism (i.e. in Equation 2 of Santoro et al. (2018), we replace [M; x] with [M; x; h̃^m_t]).

E.4 HYPERPARAMETERS

E.4.1 BOUNCING BALL MODELS

The hyperparameters we used can be found in Table 2 . Further, note that in Equation 5, we pass the gradients through the constant region of the kernel as if the kernel had not been truncated.
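The gradient treatment mentioned above can be sketched as a pair of forward/backward rules: the forward pass uses the truncated kernel, while the backward pass uses the derivative of the untruncated kernel, so gradients also flow where Z = 0. In an autodiff framework this is commonly realized with a straight-through-style detach trick (e.g. `soft + (hard - soft).detach()` in PyTorch); the explicit numpy version below is an illustrative sketch with assumed hyper-parameters.

```python
import numpy as np

def z_forward(cos, ell=1.0, tau=0.0):
    """Forward pass: the truncated kernel, zero outside the neighborhood."""
    soft = np.exp(-2.0 * ell * (1.0 - cos))
    return np.where(cos >= tau, soft, 0.0)

def z_backward(cos, ell=1.0, tau=0.0):
    """Backward pass w.r.t. cos: derivative of the *untruncated* kernel,
    i.e. gradients pass through the constant (zero) region as if the
    kernel had not been truncated."""
    return 2.0 * ell * np.exp(-2.0 * ell * (1.0 - cos))
```

This way, a module whose kernel weight has been truncated to zero can still receive a training signal pulling its embedding back into range.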

E.4.2 STARCRAFT2 MODELS

The hyperparameters we used can be found in Table 3 . Note that we only report models that attained a validation loss similar to or better than S2RMs.

E.4.3 TRAINING

All models were trained using Adam (Kingma & Ba, 2014) with an initial learning rate of 0.0003. We use PyTorch's (Paszke et al., 2019) ReduceLROnPlateau learning rate scheduler to decay the learning rate by a factor of 2 if the validation loss does not improve by at least 0.01% over a span of 5 epochs. We initially train all models for 100 epochs, select the best of three successful runs, fine-tune it for another 100 epochs, and finally select the checkpoint with the lowest validation loss (i.e. we early-stop). We train all models with batch size 8 (Starcraft2) or 32 (Bouncing Balls), each on a single V100-32GB GPU.

E.4.4 OBJECTIVE FUNCTIONS

In the Starcraft2 task, predicting the next state entails predicting images of binary friendly markers, categorical unit type markers and real valued HECS markers. Accordingly, the loss function is a sum of a binary cross-entropy term (on friendly markers), a categorical cross-entropy term (on unit-type markers) and a mean squared error term (on HECS markers). In the Bouncing Balls task, the model output is a binary image. Accordingly, we use a pixel-wise binary cross-entropy loss.
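The composite Starcraft2 objective can be sketched as follows; the channel layouts, the mean reductions, and the equal weighting of the three terms are assumptions for illustration.

```python
import numpy as np

def sc2_loss(friendly_prob, friendly_tgt, unit_logits, unit_tgt,
             hecs_pred, hecs_tgt, eps=1e-7):
    """Sum of BCE (friendly marker), CE (unit-type logits) and MSE (HECS).

    friendly_prob/friendly_tgt: (H, W) probabilities and binary targets.
    unit_logits: (K, H, W) class logits; unit_tgt: (H, W) integer labels.
    hecs_pred/hecs_tgt: (4, H, W) real-valued channels.
    """
    # Binary cross-entropy on the friendly marker.
    p = np.clip(friendly_prob, eps, 1 - eps)
    bce = -np.mean(friendly_tgt * np.log(p) + (1 - friendly_tgt) * np.log(1 - p))
    # Categorical cross-entropy via log-softmax over the class axis.
    logz = np.log(np.exp(unit_logits).sum(axis=0, keepdims=True))
    ce = -np.mean(np.take_along_axis(unit_logits - logz, unit_tgt[None], axis=0))
    # Mean squared error on the HECS channels.
    mse = np.mean((hecs_pred - hecs_tgt) ** 2)
    return bce + ce + mse
```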

F ADDITIONAL RESULTS

F.1 BOUNCING BALLS

F.1.1 ROLLOUTS

To obtain the rollouts in Figure 3, we adopt the following strategy. For the first 20 prompt-steps, we present all models with exactly the same 11 × 11 crops around randomly sampled pixel positions. For the next 25 steps, all models are queried at random pixel positions (the same for all models, but changing between time-steps), and the resulting predictions (on crops) are thresholded at 0.5 and fed back into the model for the next step (at the known pixel positions from the previous step). In addition, at every time-step the models are queried for their predictions at 16 pixel locations placed on a 4 × 4 grid. The resulting predictions are stitched together and shown in Figures 15, 16, 17, 18, 3 and 19.

F.1.2 ROBUSTNESS TO DROPPED CROPS

In this section, we evaluate the robustness of all models to dropped crops on in-distribution and OOD data. We measure the performance metrics on the one-step forward prediction task on all datasets (with 1-6 balls), albeit while dropping a given fraction of the available input observations. Figures 20 and 21 visualize the performance of all evaluated models. We find that S2GRU maintains performance on OOD data even with fewer views (or crops) than it was trained on. Interestingly, we find that the time-travelling oracle (TTO), while robust OOD, is adversely affected by the number of available views. This could be because, unlike the other models, it cannot leverage the temporal information to compensate for the missing observations.

F.2 STARCRAFT2

F.2.1 TABULAR RESULTS

The results used to plot Figure 8 can be found tabulated in Tables 4, 5, 6 and 7.

% of Active Agents    LSTM         RMC          S2GRU        S2RMC        TTO
20%                   -0.014035    -0.013569    -0.011491    -0.011921    -0.014174
30%                   -0.013355    -0.012747    -0.010631    -0.011101    -0.013539
40%                   -0.012567    -0.011808    -0.009906    -0.010367    -0.012916
50%                   -0.012220    -0.011305    -0.009637    -0.009887    -0.012481
60%                   -0.010888    -0.009799    -0.008751    -0.009034    -0.010929
70%                   -0.009738    -0.008469    -0.008068    -0.008359    -0.009184
80%                   -0.009081    -0.008027    -0.007873    -0.008162    -0.008466
90%                   -0.007970    -0.007180    -0.007347    -0.007615    -0.007038
100%                  -0.007638    -0.006823    -0.007103    -0.007362    -0.006401



Affiliations: Max-Planck Institute for Intelligent Systems Tübingen; Mila, Québec; Bethgelab, Eberhard Karls Universität Tübingen; Université de Montréal. Correspondence to: <nasim.rahaman@tuebingen.mpg.de>.

Footnotes:
- The continuity of P ties the two notions of locality by requiring that an infinitesimal change in x corresponds to one in S. Injectivity ensures that no two points in X are mapped to the same point in S.
- In doing so, we use the assumptions merely as guiding principles; we do not claim that we infer the true decomposition of the ground-truth system, even if all assumptions are satisfied.
- Here, the code 1c3s5z refers to a scenario where each team comprises 1 colossus (1c), 3 stalkers (3s), and 5 zealots (5z).
- Please note that this is a game-specific notion.
- Note that this is rather unconventional, since each player usually controls entire armies and must switch between macro- and micro-management of units or unit-groups.
- If health drops to zero, the unit is considered dead and the representation does not differentiate between dead and absent units.
- These stats may change with game versions, and are catalogued here: https://liquipedia.net/starcraft2/Units_(StarCraft).
- Indices not appearing on both sides of an equation are summed over; this is implemented as einsum in most DL frameworks.
- https://twitter.com/karpathy/status/801621764144971776?s=20
- These random pixel positions are the same for all models, but change between time-steps.



Figure 1: (a) Fully localized sub-systems. (b) Middle ground. (c) Single, monolithic system.

Figure 2: Schematic illustration of the proposed architecture. An observation is addressed to modules with embeddings situated in vicinity of its embedded location. Likewise, modules with embeddings in the vicinity of an embedded query location are polled to produce a prediction.

Figure 3: Rollouts (OOD) with 5 bouncing balls, from top to bottom: ground-truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state is reconstructed by stitching together 16 patches of size 11 × 11 produced by the models (queried on a 4 × 4 grid). Gist: S2GRU succeeds at keeping track of all bouncing balls over long rollout horizons (25 frames).

Figure 4: Performance metrics on OOD one-step forward prediction task. Gist: S2GRU outperforms all RNN baselines OOD.

Figure 6: Ablation over the local and the non-local terms in the input and inter-cell attention mechanisms (KMDPAs). For a set number of bouncing balls, each sub-plot shows how the balanced accuracy changes with the fraction of views (crops) available to the model. Gist: Both local and non-local terms in KMDPA contribute to the overall performance. The non-local term is more important for the input attention, whereas the local term is more important for the inter-cell attention.

Figure 7: The effect of removing random modules at test time. Gist: Performance degrades gracefully as modules are removed, suggesting that modules can function even when their counterparts are removed, and that there is limited co-adaptation between them.

Figure 9: Joint visualization of spatial enclaves and the interaction graph between modules in the grid-world environment of Eichenberger et al. (2019), as detailed in Appendix A. The images show which spatial locations a module attends to via the local attention (spatial enclaves), whereas the presence of an edge indicates that the corresponding modules may interact via inter-cell attention. Gist: The modules indeed learn a notion of spatial locality, while interacting sparsely with each other.

(a) 1s2z (1 Stalker and 2 Zealots per team). (b) 5s3z (5 Stalkers and 3 Zealots per team). (c) 2s3z (2 Stalkers and 3 Zealots per team). (d) 3s5z (3 Stalkers and 5 Zealots per team). (e) 1c3s5z (1 Colossus, 3 Stalkers and 5 Zealots per team).

Figure 10: Human readable illustrations of the Starcraft2 (SMAC) scenarios we consider in this work. Figures 10a and 10b show the OOD scenarios, whereas Figures 10c, 10d and 10e show the training scenarios (provided by Samvelyan et al. (2019)).

Figure 15: Rollouts (OOD) with 1 bouncing ball, from top to bottom: ground-truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11 × 11 patches from the models (queried on a 4 × 4 grid).

Figure 16: Rollouts (OOD) with 2 bouncing balls, from top to bottom: ground-truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11 × 11 patches from the models (queried on a 4 × 4 grid).

Figure 17: Rollouts (ID) with 3 bouncing balls, from top to bottom: ground-truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11 × 11 patches from the models (queried on a 4 × 4 grid).

Figure 18: Rollouts (OOD) with 4 bouncing balls, from top to bottom: ground-truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11 × 11 patches from the models (queried on a 4 × 4 grid).

Figure 19: Rollouts (OOD) with 6 bouncing balls, from top to bottom: ground-truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11 × 11 patches from the models (queried on a 4 × 4 grid).

Figure 20: Balanced accuracy (arithmetic mean of recall and specificity) achieved by all evaluated models for one-step forward prediction task with various number of balls and fractions of available views. All models were trained on video sequences with 3 balls and a constant number of crops / views (10 views, corresponding to the right-most columns labelled 1.0). The color map is consistent across all plots.

World Modeling on Starcraft2. In Section 2, we formulated the problem of modeling what we called the world state o of a dynamical system φ given local observations {(x a

Table 2: Hyperparameters used for various models on the Bouncing Ball task. Hyperparameters not listed here were left at their respective default values.

Friendly marker F1 scores on the validation set of the training distribution. Larger numbers are better.

Unit-type marker (macro averaged) F1 scores on the validation set of the training distribution. Larger numbers are better.

HECS Negative MSE on the validation set of the training distribution. Larger numbers are better.

Log Likelihood (negative loss) on the validation set of the training distribution. Larger numbers are better.

ACKNOWLEDGEMENTS

The authors would like to thank Georgios Arvanitidis, Luigi Gresele, Michael Cobos for their feedback on the paper, and Murray Shanahan for the discussions. The authors also acknowledge the important role played by their colleagues at the Empirical Inference Department of MPI-IS Tübingen and Mila throughout the duration of this work.

