LATENT HIERARCHICAL IMITATION LEARNING FOR STOCHASTIC ENVIRONMENTS

Abstract

Many applications of imitation learning require the agent to avoid mode collapse and mirror the full distribution of observed behaviours. Existing methods that address this distributional realism typically rely on hierarchical policies conditioned on sampled types that model agent-internal features like persona, goal, or strategy. However, these methods are often inappropriate for stochastic environments, where internal and external factors of influence on the observed agent trajectories have to be disentangled, and only internal factors should be encoded in the agent type to be robust to changing environment conditions. We formalize this challenge as distribution shift in the conditional distribution of agent types under environmental stochasticity, in addition to the familiar covariate shift in state visitations. We propose Robust Type Conditioning (RTC), which eliminates these shifts with adversarial training under randomly sampled types. Experiments on two domains, including the large-scale Waymo Open Motion Dataset, show improved distributional realism while maintaining or improving task performance compared to state-of-the-art baselines.

1. INTRODUCTION

Learning to imitate behaviour is crucial when reward design is infeasible (Amodei et al., 2016; Hadfield-Menell et al., 2017; Fu et al., 2018; Everitt et al., 2021) , for overcoming hard exploration problems (Rajeswaran et al., 2017; Zhu et al., 2018) , and for realistic modelling of dynamical systems with multiple interacting agents (Farmer and Foley, 2009) . Such systems, including games, driving simulations, and agent-based economic models, often have known state transition functions, but require accurate agents to be realistic. For example, for driving simulations, which are crucial for accelerating the development of autonomous vehicles (Suo et al., 2021; Igl et al., 2022) , faithful reactions of all road users are paramount. Furthermore, it is not enough to mimic a single mode in the data; instead, agents must reproduce the full distribution of behaviours to avoid sim2real gaps in modelled systems (Grover et al., 2018; Liang et al., 2020) , under-explored solutions in complex tasks (Vinyals et al., 2019) and suboptimal policies in games requiring mixed strategies (Nash Jr, 1950) . Current imitation learning (IL) methods fall short of achieving such distributional realism by matching all modes in the data. The required stochastic policy cannot be recovered from a fixed reward function and adversarial methods, while aiming to match the distribution in principle, are known to be prone to mode collapse in practice (Wang et al., 2017; Lucic et al., 2018; Creswell et al., 2018) . Furthermore, progress on distributional realism is hindered by a lack of suitable IL benchmarks, with most relying on unimodal data and only evaluating task performance as measured by rewards, but not mode coverage. By contrast, many applications require distributional realism in addition to good task performance. For example, accurately evaluating the safety of autonomous vehicles in simulation relies on distributionally realistic agents. 
Consequently, our goal is to improve distributional realism while maintaining strong task performance. To mitigate mode collapse in complex environments, previous work uses hierarchical policies in an autoencoder framework (Wang et al., 2017; Suo et al., 2021; Igl et al., 2022) . During training, an encoder infers latent variables from observed trajectories and the agent, conditioned on those latent variables, strives to imitate the original trajectory. At test time, a prior distribution proposes distributionally realistic latent values, without requiring access to privileged future information. We refer to this latent vector as an agent's inferred type since it expresses intrinsic characteristics of the agent that yield the multimodal behaviour. Depending on the environment, the type could, for example, represent the agent's persona, belief, goal, or strategy. However, these hierarchical methods rely on either manually designed type representations (Igl et al., 2022) or the strong assumption that all stochasticity in the environment can be controlled by the agent (Wang et al., 2017; Suo et al., 2021) . Unfortunately, this assumption is violated in most realistic scenarios. For example, in the case of driving simulations, trajectories depend not only on the agent's type, expressing its driving style and intent, but also on external factors such as the behaviour of other road users. Crucially, despite being inferred from future trajectories during training, agent types must be independent of these external factors to avoid leaking information about future events outside the agent's control, which in turn can impair generalization at test time under changed, and ex-ante unknown, environmental conditions. In other words, the challenge in learning hierarchical policies using IL in stochastic environments is to disentangle the internal and external factors of influence on the trajectories and only encode the former into the type. 
Consider the example of an expert approaching an intersection at the same time as another car. The expert passes if the other car brakes and yields to it otherwise. To reconstruct the scene with ease, a naively trained latent model could not only encode the agent's intended direction (an internal decision) but also whether to yield, which depends on the other car (an external factor). This is catastrophic at test time when the latent, and hence the yielding decision, is sampled independently of the other car's behaviour. In contrast, if only the expert's intent were encoded in the latent, the policy would learn to react appropriately to external factors. In this paper, we identify these subtle challenges arising under stochastic environments and formulate them as a new form of distribution shift for hierarchical policies. Unlike the familiar covariate shift in the state distribution (Ross et al., 2011), this conditional type shift occurs in the distribution of the inferred latent type. It greatly reduces performance by yielding causally confused agents that rely on the latent type for information about external factors, instead of inferring them from the latest environment observation. We propose Robust Type Conditioning (RTC) to eliminate this distribution shift and avoid causally confused agents through a coupled adversarial training objective under randomly sampled types. We do not require access to an expert, counterfactuals, or manually specified type labels for trajectories. Experimentally, we show the need for improved distributional realism due to mode collapse in state-of-the-art imitation learning techniques such as GAIL (Ho and Ermon, 2016). Furthermore, we show that naively trained hierarchical models with inferred types improve distributional realism, but exhibit poor task performance in stochastic environments. By contrast, RTC can maintain good task performance in stochastic environments while improving distributional realism and mode coverage.
We evaluate RTC on the illustrative Double Goal Problem as well as the large scale Waymo Open Motion Dataset (Ettinger et al., 2021) of real driving behaviour.

2. BACKGROUND

We are given a dataset $\mathcal{D} = \{\tau_i\}_{i=1}^{N}$ of $N$ trajectories $\tau_i = (s_0^{(i)}, a_0^{(i)}, \ldots, s_T^{(i)})$, drawn from $p(\tau)$ of one or more experts interacting with a stochastic environment $p(s_{t+1}|s_t, a_t)$, where $s_t \in \mathcal{S}$ are states and $a_t \in \mathcal{A}$ are actions. Our goal is to learn a policy $\pi_\theta(a_t|s_t)$ to match $p(\tau)$ when replacing the unknown expert and generating rollouts $\hat\tau \sim \hat p(\hat\tau) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(\hat a_t|\hat s_t)\, p(\hat s_{t+1}|\hat s_t, \hat a_t)$ from the initial states $s_0 \sim p(s_0)$. We simplify notation and write $\hat\tau \sim \pi_\theta(\hat\tau)$ and $\tau \sim \mathcal{D}(\tau)$ to indicate rollouts generated by the policy or drawn from the data respectively. Expectations $\mathbb{E}_{\tau \sim \mathcal{D}}$ and $\mathbb{E}_{\hat\tau \sim \pi_\theta}$ are taken over all pairs $(s_t, a_t) \in \tau$ and $(\hat s_t, \hat a_t) \in \hat\tau$. Previous work (e.g., Ross et al., 2011; Ho and Ermon, 2016) shows that a core challenge of learning from demonstration is reducing or eliminating the covariate shift in the state-visitation frequencies $p(s)$ caused by accumulating errors when using $\pi_\theta$. Unfortunately, Behavioural Cloning (BC), a simple supervised training objective optimising $\max_\theta \mathbb{E}_{\tau \sim \mathcal{D}}[\log \pi_\theta(a_t|s_t)]$, is not robust to it. To overcome covariate shift, generative adversarial imitation learning (GAIL) (Ho and Ermon, 2016) optimises $\pi_\theta$ to fool a learned discriminator $D_\phi(\hat a_t, \hat s_t)$ that is trained to distinguish between trajectories in $\mathcal{D}$ and those generated by $\pi_\theta$:

$$\min_\theta \max_\phi \; \mathbb{E}_{\hat\tau \sim \pi_\theta}\big[\log D_\phi(\hat a_t, \hat s_t)\big] + \mathbb{E}_{\tau \sim \mathcal{D}}\big[\log(1 - D_\phi(a_t, s_t))\big]. \qquad (1)$$

The policy can be optimised using reinforcement learning, by treating the log-discriminator scores as costs, $r_t = -\log D_\phi(\hat a_t, \hat s_t)$. Alternatively, if the policy can be reparameterized (Kingma and Welling, 2013) and the environment is differentiable, the sum of log discriminator scores can be optimised directly without relying on high-variance score function estimators by backpropagating through the transition dynamics, $\mathcal{L}_{adv}(\hat\tau) = \mathbb{E}_{\hat\tau \sim \pi_\theta}\big[\sum_t -\log D_\phi(\hat a_t, \hat s_t)\big]$. We refer to this as Model-based GAIL (MGAIL), though in contrast to Baram et al. (2016), we assume a known differentiable environment instead of a learned model.

In this work, we are concerned with multimodal distributions $p(\tau)$ and how mode collapse can be avoided when learning $\pi_\theta$. To this end, we assume the dataset is sampled from $p(\tau) = p(s_0) \int\!\!\int p(g)\, p(\xi) \prod_{t=0}^{T} p(a_t|s_t, g)\, p(s_{t+1}|s_t, a_t, \xi)\, d\xi\, dg$, where $g$ is the agent type, expressing agent characteristics such as persona, goal, or strategy, and $\xi$ is a random variable capturing the stochasticity in the environment, i.e. $p(s_{t+1}|s_t, a_t, \xi)$ is a delta distribution $\delta_{f(s_t, a_t, \xi)}(s_{t+1})$ for some transition function $f$. Learned agents matching $p(\tau)$, i.e., with $\hat p(\hat\tau) \approx p(\tau)$, are distributionally realistic, whereas realism describes single trajectories when $\hat\tau$ lies in the support of $p(\tau)$. As we show in section 6, current non-hierarchical adversarial methods (Ho and Ermon, 2016) exhibit mode collapse and are not distributionally realistic. To combat mode collapse, hierarchical methods (e.g., Wang et al., 2017; Lynch et al., 2020; Suo et al., 2021; Igl et al., 2022) often rely on an encoder to infer latent agent types $\hat g_e$ from trajectories during training, $\hat g_e \sim e_\theta(\hat g_e|\tau)$, and optimise the control policy $\pi_\theta(\hat a_t|\hat s_t, \hat g_e)$ to generate trajectories $\hat\tau_e$ similar to $\tau$: $\hat\tau_e \sim p(\hat\tau_e|\hat g_e) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(\hat a_t|\hat s_t, \hat g_e)\, p(\hat s_{t+1}|\hat a_t, \hat s_t)$, with $\hat g_e \sim e_\theta(\hat g_e|\tau)$. If ground truth trajectories are not accessible during testing, a prior $p_\theta(\hat g_p)$ can be used to sample distributionally realistic types $\hat g_p$. We indicate by subscript $\hat g_p$ or $\hat g_e$ whether the inferred type and trajectory are drawn from the prior distribution $p_\theta(\hat g_p)$ or encoder $e_\theta(\hat g_e|\tau)$. Subscripts are omitted for states and actions to simplify notation. Inferred types and predicted trajectories without subscripts indicate that either sampling distribution could be used.
For discussing information theoretic quantities, we will use capital letters S, A, Â, Ĝ and Ξ to denote the random variables for values s, a, â, ĝ and ξ.

3. CONDITIONAL TYPE SHIFT IN STOCHASTIC ENVIRONMENTS

In this section we outline a challenge that arises for hierarchical policies in stochastic environments. A shift in the conditional type distribution can arise because latent types are drawn from the encoder e θ (ĝ e |τ ) during training but from the prior p θ (ĝ p ) during testing. While the prior is trained to match the marginal distribution of the encoder p(ĝ e ) = E g,ξ [e θ (ĝ e |τ )], this is not the case for the conditional distribution p(ĝ e |ξ) = E g [e θ (ĝ e |τ )]. We show that this conditional type shift can result in policies ignoring environmental information. In section 6 we experimentally confirm that this translates to reduced task performance. We use the simplified model in fig. 1 to describe the consequences of the conditional type shift. This model has two sources of randomness in the data D: the environmental noise ξ and the multimodal type g of the expert we are mimicking. The crucial difference between g and ξ is that ξ represents external factors that the agent cannot control but to which it has to react, while g encodes agent-internal decisions that can be taken independently of ξ. In this simplified model, the state s is a deterministic function of only ξ as we disregard cross-temporal dependencies. During training, when real trajectories τ = (s, a) are available, the inferred type ĝe is drawn from the encoder e θ (ĝ e |τ ). During testing, without access to τ , a prior p θ (ĝ p ) is used. Actions â are drawn from the learned control policy π θ (â|s, ĝ) and optimisation is performed to minimise a reconstruction loss L rec (a, â). We express the core result of the policy 'ignoring environmental information' using mutual information I(S, A), entropy H(A, Ĝe ) and conditional entropy H(A| Ĝe ). Proofs are in appendix A. The core result is that the mutual information between states and actions is lower for prior policies than in the data, implying that such policies ignore action-relevant information in the states. 
This happens if $H(A|\hat G_e) < I(S, A)$, i.e. if the type $\hat g$ captures too much information about $\hat a$ during training. The condition $H(A, \hat G_e) = H(\hat A, \hat G_p)$ assures that the entropy in the system and mutual information between variables remains comparable between training and testing. In the extreme case that the type fully determines the action, i.e. $H(A|\hat G_e) = 0$, the policy ignores the state entirely, i.e. $I(S, \hat A) = 0$. Because the encoder and policy are trained jointly, the failure case $H(A|\hat G_e) < I(S, A)$ requires the encoder to capture too much information about $s$ in $\hat g_e$ and the policy to rely too much on $\hat g_e$ to predict $a$, which constitutes a form of causal confusion. Without the encoder providing excessive information about $s$, the policy could not learn to over-rely on $\hat g_e$. Conversely, even with excessive information about $s$ in $\hat g_e$, the policy could still ignore it and avoid $H(A|\hat G_e) < I(S, A)$. As an example, for the data given in fig. 1c, this is a solution satisfying $I(S, A) > H(\hat A|\hat G) = 0$:

$$\hat g_e(s_i, a_j) = \begin{cases} 0 & \text{if } j = 0 \\ 1 & \text{if } j = 1 \end{cases} \quad \text{implies} \quad \pi_\theta(\hat a|s_i, \hat g) = \begin{cases} a_0 & \text{if } \hat g = 0 \\ a_1 & \text{if } \hat g = 1 \end{cases} \quad \text{and} \quad p_\theta(\hat g_p) = B(0.5), \qquad (2)$$

with Bernoulli distribution $B$. The latent type fully determines the action and the policy ignores the state. This allows perfect reconstruction during training, but fails at test time when $p_\theta(\hat g_p)\,\pi_\theta(\hat a|s, \hat g_p)$ would randomly sample $a_0$ and $a_1$ with equal probability. In the example of a car approaching an intersection, this corresponds to entering the intersection independently of whether another car is approaching quickly, leading to increased collision rates. Also note that the conditional type distribution is not independent of the state, i.e. $p(\hat g_e|s_0) = B(1 - \epsilon) \neq p(\hat g_e|s_1) = B(\epsilon)$, leading to a conditional type shift to $p_\theta(\hat g_p) = B(0.5)$ at test time. Other training solutions exist for which $I(S, A) = H(\hat A|\hat G)$ and $I(S, \hat A) = I(S, A)$. For example:

$$\hat g_e(s_i, a_j) = \begin{cases} 0 & \text{if } i = j \\ 1 & \text{if } i \neq j \end{cases}, \quad \pi_\theta(\hat a|s_i, \hat g) = \begin{cases} a_i & \text{if } \hat g = 0 \\ a_j & \text{if } \hat g = 1 \end{cases} \ \text{for } i \neq j, \quad p_\theta(\hat g_p) = B(1 - \epsilon). \qquad (3)$$

Here the latent type $\hat g$ only captures the agent-internal randomness, the conditional type distribution matches the prior, i.e. $p(\hat g_e|s_1) = p(\hat g_e|s_2) = p_\theta(\hat g_p) = B(1 - \epsilon)$, and the test policy correctly reproduces the data. For temporally extended data, the states $s_t$ will depend not only on $\xi$, but also on $g$ or $\hat g$, complicating theoretical treatment. Nevertheless, seeing $\xi$ as all future stochasticity in the environment, the same threat of conditional type shift arises. In the next section, we introduce two training interventions. The first discourages causally confused policies, the other discourages the encoder from capturing excessive information about $\xi$. In section 6 we show that using both interventions jointly allows hierarchical policies to be used for improved distributional realism while avoiding the sub-optimal solution for which $I(S, A) > I(S, \hat A)$.

4. ROBUST TYPE CONDITIONING

We present Robust Type Conditioning (RTC), a method for improving distributional realism in imitation learning while maintaining high task performance. RTC follows the auto-encoder framework discussed in sections 2 and 3 but avoids conditional type shifts and causally confused policies that ignore environmental information in stochastic environments. This is achieved through two augmentations. First, during training, latent types are not only sampled from the encoder, but also the prior. Because we do not have ground truth trajectories for these prior sampled types, an adversarial loss is used in place of the reconstruction loss. Second, we regularise the mutual information $I(S, \hat G)$ using a variational information bottleneck (Alemi et al., 2016) to avoid excessive information in $\hat g$. RTC combines four losses: the reconstruction loss $\mathcal{L}_{rec}$, the information bottleneck loss $\mathcal{L}_{ib}$, the adversarial loss $\mathcal{L}_{adv}$, and the prior loss $\mathcal{L}_{prior}$ (see fig. 2):

$$\mathcal{L}_{RTC} = \mathbb{E}_{\mathcal{D}(\tau)\, e_\theta(\hat g_e|\tau)\, \pi_\theta(\hat\tau_e|\hat g_e)}\big[\mathcal{L}_{rec}(\tau, \hat\tau_e) + \beta \mathcal{L}_{ib} + \lambda_{adv} \mathcal{L}_{adv}(\hat\tau_e) + \mathcal{L}_{prior}(\tau)\big]$$
$$\qquad\quad + \mathbb{E}_{\mathcal{D}(\tau)\, p_\theta(\hat g_p)\, \pi_\theta(\hat\tau_p|\hat g_p)}\big[\lambda_{adv} \mathcal{L}_{adv}(\hat\tau_p) + \mathcal{L}_{prior}(\tau)\big]. \qquad (4)$$

$p_\theta(\hat g_p)$ is a learned prior and $\pi_\theta(\hat\tau|\hat g)$ is shorthand for generating trajectories $\hat\tau$ by rolling out the learned control policy $\pi_\theta(\hat a|\hat s, \hat g)$ in the environment. Parameters θ are held fixed and $\lambda_{adv}$ and $\beta$ are scalar weights. We now introduce the individual terms. First, $\mathcal{L}_{rec}$ is a reconstruction loss between $\tau$ and $\hat\tau_e$, encouraging the hierarchical policy to be distributionally realistic and encode useful information about the trajectory in the inferred type $\hat g_e$. The loss $\mathcal{L}_{rec}$ can take different forms. For example, in section 6.1 we use the BC loss $\mathcal{L}_{rec}(\tau) = -\log \pi_\theta(a_t|s_t, \hat g_e)$ while in section 6.2 we minimise the $L_2$ distance between agent positions in $s_t$ and $\hat s_t$.
Note that state-based losses like the L 2 reconstruction loss require access to a training environment able to resimulate the conditions of the original trajectory τ as we assume that τ is still approximately optimal. The loss L prior (τ ) = E ĝe ∼eθ (ĝ e |τ ) [log p θ (ĝ e )] optimises the prior to propose distributionally realistic types by matching the marginal encoder distribution. The key algorithmic contribution of RTC is to also optimise the policy under types sampled from the prior (second line in eq. ( 4)), not only the encoder (first line in eq. ( 4)). For these prior-sampled types, the reconstruction loss cannot be used as the correct ground truth trajectories are unavailable. Instead, we use the adversarial loss L adv (τ ) = t -log D φ (â t , ŝt ), where D φ (â t , ŝt ) is a learned discriminator (see section 2). This reduces conditional type shift because the prior distribution is already used during training. It also reduces causal confusion: Because some types ĝ are now sampled independently of the trajectory τ and hence ξ, their information about ξ is now less reliable and the policy is incentivised to rely on s as much as possible. One can view sampling from the prior as a causal intervention do(ĝ) in which ĝ is changed independently of the environmental factor ξ. De Haan et al. (2019) show that causal confusion can be avoided by applying such interventions and optimising the policy to correctly predict the counterfactual expert trajectory distribution, in our case p expert (τ |ξ, do(ĝ)). Unfortunately, we do not have access to this counterfactual trajectory. Instead, we rely on the generalisation of π θ to get us 'close' to such a counterfactual trajectory for types do(ĝ) and then refine the policy locally using the adversarial objective. We experimentally found that optimising the policy under prior types is sufficient to improve task performance and avoid causal confusion in hierarchical policies. 
However, it also eliminated improvements in distributional realism gained through the use of hierarchies, likely because the policy simply learned to ignore the latent type altogether. As a solution, we employ an information bottleneck on types $\hat g_e \sim e_\theta(\hat g_e|\tau)$. This filters information about $\xi$ while still encoding information about $g$ in $\hat g_e$, thereby making the information in $\hat g_e$ more reliable and useful to the policy. We experimentally show that this, when combined with prior-type sampling during training, achieves improved distributional realism while maintaining excellent task performance. Without prior-type sampling, task performance degrades considerably, indicating that the bottleneck is insufficient for filtering out information about $\xi$ entirely. The information bottleneck preferentially filters information about $\xi$, because the control policy $\pi_\theta$ also has direct access to it through the visited states $s$. By contrast, information about $g$ can only be accessed by the policy through $\hat g_e$ and is hence preferentially encoded in the bottleneck when information bandwidth costs are applied. We found that both continuous type representations with $\mathcal{L}_{ib} = \mathrm{KL}[p(\hat g_e)\,\|\,\mathcal{N}(\hat g_e; 0, I)]$ and discrete type representations using straight-through gradient estimation work well in practice (see section 6.2). To accommodate optimisation under inferred types drawn from both the encoder $e_\theta$ and the prior $p_\theta$, we split each minibatch $B = \{\tau^{(b)}\}_{b=1}^{N_b}$ of $N_b$ trajectories sampled from $\mathcal{D}$ into two parts. For the fraction $f$ of trajectories in $B$, the rollouts $\hat\tau_e$ are generated from types sampled from the encoder $\hat g_e \sim e_\theta(\hat g_e|\tau)$ and all four losses are optimised (first line in eq. (4)). For the remaining fraction $(1 - f)$ of trajectories, types are sampled from the prior $p_\theta(\hat g_p)$ and only $\mathcal{L}_{adv}$ and $\mathcal{L}_{prior}$ are optimised (second line in eq. (4)).
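The minibatch split just described can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names, the dictionary of precomputed per-rollout loss values, and the choice to split by simple slicing are all hypothetical, and each loss value is assumed to already carry the sign under which it is minimised.

```python
def split_minibatch(batch, f):
    """Split a minibatch into an encoder part (fraction f) and a prior part (1 - f)."""
    n_enc = round(f * len(batch))
    return batch[:n_enc], batch[n_enc:]


# Which loss terms are active for each source of the latent type (cf. eq. (4)).
ENCODER_LOSSES = ("rec", "ib", "adv", "prior")  # first line of eq. (4)
PRIOR_LOSSES = ("adv", "prior")                 # second line of eq. (4)


def combine_losses(terms, type_source, beta, lambda_adv):
    """Weighted sum of loss terms for one rollout.

    terms: dict with keys "rec", "ib", "adv", "prior" (values to be minimised).
    type_source: "encoder" or "prior", i.e. where the latent type was sampled from.
    """
    weights = {"rec": 1.0, "ib": beta, "adv": lambda_adv, "prior": 1.0}
    active = ENCODER_LOSSES if type_source == "encoder" else PRIOR_LOSSES
    return sum(weights[name] * terms[name] for name in active)


# Usage with made-up numbers: 10 trajectories, f = 0.7.
enc_part, prior_part = split_minibatch(list(range(10)), f=0.7)
example_terms = {"rec": 1.0, "ib": 2.0, "adv": 3.0, "prior": 4.0}
enc_loss = combine_losses(example_terms, "encoder", beta=0.5, lambda_adv=2.0)
prior_loss = combine_losses(example_terms, "prior", beta=0.5, lambda_adv=2.0)
```

Setting `f = 1` in this sketch leaves the prior part empty, so only the first line of eq. (4) is optimised.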
Optimisation of L adv and L rec can either be performed directly, similar to MGAIL (Baram et al., 2016) , by using a differentiable environment and reparameterised policies and encoder (Kingma and Welling, 2013) or by treating them as rewards and using RL methods such as TRPO (Schulman et al., 2015; Ho and Ermon, 2016) or PPO (Schulman et al., 2017) . The losses L prior and L ib can always be optimised directly.
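For the continuous type representation, the bottleneck term introduced earlier in this section is a KL divergence to a standard normal, which has a closed form when the encoder outputs a diagonal Gaussian. The sketch below makes two assumptions that the text does not confirm: the encoder is parameterised by a mean and a log-variance per latent dimension, and the KL is evaluated per encoded trajectory, as is common for variational information bottlenecks. The function name is hypothetical.

```python
import math


def kl_diag_gaussian_to_standard_normal(mean, log_var):
    """KL[N(mean, diag(exp(log_var))) || N(0, I)] in nats.

    Closed form: 0.5 * sum(mean^2 + var - 1 - log(var)) over latent dimensions.
    """
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mean, log_var))


# An encoder output equal to the standard normal carries no information cost.
kl_zero = kl_diag_gaussian_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# Shifting one mean by 1 costs 0.5 nats.
kl_shift = kl_diag_gaussian_to_standard_normal([1.0, 0.0], [0.0, 0.0])
```

Scaling this quantity by $\beta$ gives the bandwidth cost that is traded off against the reconstruction loss.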

5. RELATED WORK

Several previous works combine adversarial training with autoencoder architectures in the image domain. Makhzani et al. (2016) use an adversarial loss on the latent variable in place of the KL-regularization used in VAEs. However, this eliminates the information bandwidth regularization for continuous latents which we show to be important for hierarchical imitation learning. Larsen et al. (2016) aim to learn a similarity metric for visual inputs using latent representations of the discriminator. This is valuable for imitation learning from raw images (Rahmatizadeh et al., 2018) , but is not required for our experimental domains. Lastly, Chrysos et al. (2018) , similar to our work, use an additional autoencoding loss to better capture the data distribution in the latent space. However, they consider denoising images instead of imitation learning under stochasticity. Hierarchical policies have been extensively studied in RL (e.g., Sutton et al., 1999; Bacon et al., 2017; Vezhnevets et al., 2017; Nachum et al., 2019; Igl et al., 2020) and IL. In RL, they improve exploration, sample efficiency and fast adaptation. By contrast, in IL, hierarchies are used to capture multimodal distributions, improve data efficiency (Krishnan et al., 2017; Le et al., 2018) , and enable goal conditioning (Shiarlis et al., 2018) . Similar to our work, Wang et al. (2017) and Lynch et al. (2020) learn to encode trajectories into latent types that influence a control policy. Crucially, both only consider deterministic environments and hence avoid the distribution shifts and unwanted information leakage we address. They extend prior work in which the type, or context, is provided in the dataset (Merel et al., 2017) , which is also assumed in (Fei et al., 2020) . Tamar et al. (2018) use a sampling method to infer latent types. Khandelwal et al. (2020) and Igl et al. (2022) use manually designed encoders specific to road users by expressing future goals as sequences of lane segments. 
This avoids information leakage but cannot express all characteristics of human drivers, such as persona, and cannot transfer to other tasks. Future states in deterministic environments (Ding et al., 2019), language (Pashevich et al., 2021), and predefined strategy statistics (Vinyals et al., 2019) have also been used as types. Information theoretic regularization offers an alternative to learning hierarchical policies using the autoencoder framework (Li et al., 2017; Hausman et al., 2017). However, these methods are less expressive since their prior distribution cannot be learned and only aim to cluster modes already captured by the agent but not penalize dropping modes in the data. This provides a useful inductive bias but often struggles in complex environments with high diversity, requiring manual feature engineering (Eysenbach et al., 2019; Pathak et al., 2019). Lastly, TrafficSim (Suo et al., 2021) uses IL to model driving agents and controls all stochasticity in the scene but uses independent prior distributions for separate agents. Hence, while no conditional distribution shift in $p(\hat g|\xi)$ can occur (as $\xi$ is constant), distribution shifts in $p(\hat g^{(i)}|\hat g^{(j)})$, and hence the joint marginal $p(\hat g^{(1:N)})$, can occur for latent types $\hat g^{(i)}, \hat g^{(j)}$ of agents $i \neq j$ with $i, j \in \{1 \ldots N\}$ and $\hat g^{(1:N)} = [\hat g^{(1)} \ldots \hat g^{(N)}]$: when drawn from the encoder, goals $\hat g^{(i)}$ and $\hat g^{(j)}$ are coordinated through conditioning on the joint agent future, while they are independent when drawn from the prior. They use a biased "common sense" collision avoidance loss, motivated by covariate shift in visited states. Our work suggests that marginal type shift might also explain the benefits gained. In contrast, our adversarial objective is unbiased. See appendix B for more related work on agent modelling in multi-agent settings, behavioural prediction and causal confusion.

6. EXPERIMENTS

We show in two stochastic environments with multimodal expert behaviour that i) existing adversarial methods suffer from insufficient distributional realism, ii) existing hierarchical methods cannot achieve good task performance and distributional realism and iii) RTC improves distributional realism while maintaining excellent task performance. We discuss differences in realism, coverage and distributional realism in fig. 5. We compare the following models: MGAIL uses a learned discriminator and backpropagates gradients through the differentiable environment. It also optimises a BC loss as we found this to improve performance. Symphony (Igl et al., 2022), building on MGAIL, utilises future lane segments as manually specified types (see appendix D.3). Our implementation of Symphony outperforms the results from Igl et al. (2022) due to the additional use of a value function. InfoMGAIL (Li et al., 2017) augments MGAIL to elicit distinct trajectories for different types. This introduces an inductive bias but does not directly penalise mode collapse. Our methods, RTC-C and RTC-D, use a continuous or discrete type respectively. We also perform the ablation Hierarchy-NoPT (No Prior Training) which only uses the first line in eq. (4), i.e. f = 1. Hierarchy-NoPT is similar to existing hierarchical methods, such as the proprietary TrafficSim (Suo et al., 2021), in that it learns the prior but does not use it during training, only at inference. It thereby does not account for distribution shifts in the latent types, as discussed in section 3.

6.1. DOUBLE GOAL PROBLEM

In the double goal problem, the expert starts from the origin and creates a multimodal trajectory distribution by randomly choosing and approaching one of two possible, slowly moving goals located on the 2D plane. Stochasticity is introduced through randomized initial goal locations and movement directions. Nevertheless, the lower and upper goal {g l , g u } remain identifiable by their location as y l < 0 for g l and y u > 0 for g u (see fig. 3 ). While both goals are equally easy to reach, the expert has a preference P (G = g l ) = 0.75. Sufficiently complex expert trajectories prevent BC from achieving optimal performance, requiring more advanced approaches. The expert follows a curved path and randomly resamples the selected goal for the first ten steps to avoid a simple decision boundary along the x-axis in which experts in the lower half-plane always target goal g l . RTC uses the BC loss as reconstruction loss L rec (τ ) = -log π θ (a t |s t , ĝe ) and continuous types. All policies use a bimodal Gaussian mixture model as action distribution. Performance is measured as the number of steps for which the agent is within δ = 0.1 distance of one of the goals. We take h s = sign(y T ) of the final agent position [x T , y T ] to indicate the approached goal and measure distributional realism as the divergence between the empirical distributions, JSD (p agent (h s ) p expert (h s )). Details can be found in appendix D.1. Figure 3 shows that MGAIL improves task performance compared to BC. Our method, RTC, improves it further, possibly because given a type, the required action distribution is unimodal. Importantly, RTC substantially improves distributional realism, achieving lower JSD values. To analyse this result, we show p agent (h s = -1), the frequency of targeting the lower goal. Not only is RTC's average value of p RTC (h s = -1) closer to the true value of 0.75, it is also more stable across seeds, resulting in a lower JSD. 
The bias introduced by InfoMGAIL reduces task performance without improving distributional realism. As expected, the ablation Hierarchy-NoPT achieves excellent distributional realism through the learned hierarchy but suffers reduced task performance due to unaccounted distribution shifts. Lastly, the rightmost plot of fig. 3 shows that the information bottleneck is necessary.
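The distributional realism metric used in this section, the Jensen-Shannon divergence between the empirical distributions of $h_s = \mathrm{sign}(y_T)$, can be computed as in the following minimal sketch. The function names are hypothetical, natural logarithms are an assumption (the text does not state the base), and a final position with $y_T = 0$ is arbitrarily counted as the upper goal.

```python
import math


def goal_histogram(final_ys):
    """Empirical distribution over h_s = sign(y_T), returned as [P(lower), P(upper)]."""
    lower = sum(1 for y in final_ys if y < 0) / len(final_ys)
    return [lower, 1.0 - lower]


def kl_divergence(p, q):
    """KL[p || q] in nats; terms with p_i = 0 contribute zero."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


# Made-up final y-positions: the expert targets the lower goal 3 times out of 4.
p_expert = goal_histogram([-1.0, -2.0, -0.5, 3.0])
# A mode-collapsed agent that always targets the lower goal.
p_collapsed = goal_histogram([-1.0, -1.5, -0.2, -3.0])
jsd_collapsed = jsd(p_collapsed, p_expert)
```

A perfectly matching agent yields a divergence of zero, whereas the collapsed agent in this example yields a strictly positive value.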

6.2. WAYMO OPEN MOTION DATASET (WOMD)

To evaluate RTC on a complex environment we use the Waymo Open Motion Dataset (Ettinger et al., 2021) consisting of 487K segments of real world driving behaviour. Distributionally realistic agents are critical for driving simulations, for example for estimating safety metrics. Diverse intents and driving styles cause the data to be highly multimodal. Stochasticity is induced through the unpredictable behaviour of other cars, cyclists and pedestrians. We use $\mathcal{L}_{rec}(\tau, \hat\tau) = \sum_{t}^{T} L_{Huber}(s_t, \hat s_t)$ where $L_{Huber}$ is the average Huber loss of the four vehicle bounding box corners. More details can be found in appendix D.2. Mode coverage is measured by $\mathrm{minADE} = \mathbb{E}_{\tau \sim \mathcal{D}, \{\hat\tau_i\}_{i=1}^{K} \sim \pi_\theta}\big[\min_{\hat\tau_i} \frac{1}{T} \sum_{t=1}^{T} \delta(s_t, \hat s_{i,t})\big]$, where $\delta$ is the Euclidean distance between agent positions and we find the minimum over K = 16 rollouts (hierarchical methods use K independently sampled types). Lower minADE implies better mode coverage, but does not directly measure the relative frequency of modes, e.g., low probability modes may be overrepresented. To measure distribution matching in driving intent, we use the Curvature JSD (Igl et al., 2022): in lane branching regions, such as intersections, it maps trajectories to the nearest lane and extracts its curvature as feature $h_{cur}$. To compute $\mathrm{JSD}(p_{agent}(h_{cur})\,\|\,p_{expert}(h_{cur}))$, the value of $h_{cur}$ is discretized into 100 equally sized bins. To measure the driving style distribution, we extract the progress feature $h_{style} = \delta(\hat s_0, \hat s_T)$ and use the same discretization to compute the JSD. Results are provided in table 1. Both versions of RTC improve task performance (collisions and off-road events) and distributional realism metrics (minADE and divergences) compared to the flat MGAIL baseline and previous hierarchical approaches (Symphony, InfoMGAIL, Hierarchy-NoPT). Both type representations, RTC-C and RTC-D, perform similarly, showing robustness of RTC to different implementations.
The advantage of RTC in achieving both good task performance and distributional realism becomes clearest by comparing it to Hierarchy-NoPT and RTC-NoIB. While Hierarchy-NoPT achieves some improvements in distributional realism, is has nearly an order of magnitude more collisions. This is a consequence of the challenges discussed in section 3, which RTC is able to avoid. On the other hand, RTC-NoIB, also avoids these challenges and achieves excellent task performance by using prior-sampled types during training. However, as discussed in section 4, it does not improve on distributional realism compared to flat baselines, indicating that the learned policy simply ignores the latent type. Combining prior-type sampling and the information bottleneck achieves better distributional realism and task performance than all baselines.
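The minADE metric used in this section can be illustrated with a short sketch. It assumes trajectories are given as equal-length lists of 2D positions; the function names `ade`, `min_ade` and `dataset_min_ade` are hypothetical and not taken from the paper's code.

```python
import math

def ade(expert, rollout):
    """Average Euclidean displacement between two equal-length 2D trajectories."""
    assert len(expert) == len(rollout)
    return sum(math.dist(p, q) for p, q in zip(expert, rollout)) / len(expert)

def min_ade(expert, rollouts):
    """Smallest ADE among the K rollouts generated for one expert trajectory."""
    return min(ade(expert, r) for r in rollouts)

def dataset_min_ade(experts, rollouts_per_expert):
    """Mean of the per-segment minADE over all expert trajectories."""
    values = [min_ade(e, rs) for e, rs in zip(experts, rollouts_per_expert)]
    return sum(values) / len(values)

# Toy example with K = 2 rollouts for a single expert trajectory.
expert = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
rollouts = [
    [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)],  # parallel path, offset by 1
    [(0.0, 0.0), (1.0, 0.0), (2.0, 0.3)],  # deviates only at the last step
]
score = min_ade(expert, rollouts)
```

Because only the closest rollout counts, a low score shows that some rollout covers the expert's mode, but says nothing about how often that mode is sampled, which is why the JSD metrics are reported alongside it.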

7. CONCLUSIONS, LIMITATIONS, AND FUTURE WORK

This paper identified new challenges in learning hierarchical policies from demonstration to capture multimodal trajectory distributions in stochastic environments. We expressed them as conditional type shifts and causal confusion in the hierarchical policy. We proposed Robust Type Conditioning (RTC) to eliminate these distribution shifts and showed improved distributional realism while maintaining or improving task performance on two stochastic environments, including the Waymo Open Motion Dataset (Ettinger et al., 2021). Future work will address conditional distributional realism by matching not only the marginal distribution p(τ), but also the conditional distribution p(τ|ξ) under a specific realization of the environment. For example, drivers might change their intent based on the current traffic situation, or players might adapt their strategy as the game unfolds. Achieving such conditional distributional realism will also require new models and metrics.

A PROOFS

We restate the theorem and corollary for convenience.

Theorem 1. We assume the model $p_\theta(\hat a|s, a) = \int e_\theta(\hat g_e|a, s)\, \pi_\theta(\hat a|s, \hat g_e)\, d\hat g_e$ achieves optimal reconstruction loss $L_{rec} = 0$ on $P_D(s, a)$. The test policy is $p_\theta(\hat a|s) = \int p_\theta(\hat g_p)\, \pi_\theta(\hat a|s, \hat g_p)\, d\hat g_p$ with the marginal encoder $p_\theta(\hat g) = E_{P_D}[e_\theta(\hat g|a, s)]$ as prior distribution. Then, for the training distribution $P(s, a, \hat g_e) = P_D(s, a)\, e_\theta(\hat g_e|s, a)$ and the testing distribution $P(s, \hat a, \hat g_p) = P_D(s)\, p_\theta(\hat g_p)\, \pi_\theta(\hat a|s, \hat g_p)$: if $H(A|\hat G_e) < I(S, A)$ and $H(A, \hat G_e) = H(\hat A, \hat G_p)$, then $I(S, \hat A) < I(S, A)$.

Corollary 1. If $H(A|\hat G_e) = 0$, the assumption $H(A, \hat G_e) = H(\hat A, \hat G_p)$ becomes unnecessary in theorem 1 and we have $I(S, \hat A) = 0 < I(S, A)$.

A.1 PRELIMINARIES

We denote by $H(X)$ the entropy, by $H(X|Y)$ the conditional entropy, by $I(X, Y)$ the mutual information and by $I(X, Y|Z)$ the conditional mutual information between random variables. Furthermore, the proof relies on the interaction information $I(X, Y, Z)$, an extension of mutual information to three variables. Importantly, the interaction information can be positive or negative. A positive interaction information indicates that one variable explains some of the correlation between the other two, while a negative interaction information indicates that one variable enhances their correlation. Our model $e_\theta(\hat g_e|a, s)\, \pi_\theta(\hat a|s, \hat g_e)$ is trained on the dataset $P_D(s, a)$. To achieve minimal reconstruction loss, the model is required to predict $\hat a = a$ with certainty, implying $H(\hat A|S, \hat G_e) = H(\hat A|S, \hat G_p) = 0$. At test time, latents are drawn from the prior $p_\theta(\hat g_p)$, which we assume perfectly matches the marginal distribution of the encoder, i.e. $E_{P_D}[e_\theta(\hat g_e|a, s)]$. We use the following relations:

• $I(X, Y, Z) = I(X, Y) - I(X, Y|Z)$ (and permutations, as $I(X, Y, Z)$ is symmetric)
• $I(X, Y|Z) = H(X|Z) - H(X|Y, Z)$ (and permutations)
• $H(X|Y, Z) \le H(X|Y)$
• $H(X, Y) = H(X) + H(Y|X)$ (and permutations)
• $I(X, Y) \ge 0$

A.2 PROOF OF THEOREM

During training on the dataset $P_D(s, a)$ the interaction information is positive because $H(A|\hat G_e) < I(S, A)$:

$$I(A, \hat G_e, S) = I(S, A) - I(S, A|\hat G_e) = I(S, A) - H(A|\hat G_e) + \underbrace{H(A|S, \hat G_e)}_{=0} > 0. \tag{5}$$

On the other hand, during testing, we have $I(\hat G_p, S) = 0$ because now $\hat G_p$ is drawn independently of $S$ from $p_\theta(\hat G_p)$. Consequently, the interaction information becomes non-positive:

$$I(\hat A, \hat G_p, S) = \underbrace{I(\hat G_p, S)}_{=0} - I(\hat G_p, S|\hat A) = H(\hat G_p|\hat A, S) - H(\hat G_p|\hat A) \le 0.$$

Furthermore, if hierarchical methods are used to capture the multimodality in the data (Tang and Salakhutdinov, 2019; Casas et al., 2020; Ivanovic and Pavone, 2019; Salzmann et al., 2020; Yuan et al., 2021; Hu et al., 2021), they are vulnerable to the same marginal and conditional type shifts we consider. While none of these works take these challenges into account, they often use small discrete latent spaces (e.g., Tang and Salakhutdinov, 2019; Ivanovic and Pavone, 2019; Salzmann et al., 2020), mitigating the severity of the distribution shifts and future information leakage by limiting the information bandwidth of latent types. Furthermore, prediction quality metrics such as displacement-based metrics or log-likelihood are less sensitive to covariate shift, which primarily impacts interactions with the environment, such as collisions.

As discussed in section 3, the conditional type shift is exacerbated by causally confused policies relying on the latent type for information about environmental noise. Unlike in most literature on causal confusion (De Haan et al., 2019), our nuisance variables are hence not part of the current state, but of the learned latent state. Distribution shift is induced not through earlier actions but through sampling from the prior instead of the encoder. Prior work on causal confusion typically relies on problem-specific regularization (e.g. Wen et al., 2020; Park et al., 2021) or has access to an expert or task rewards (e.g. De Haan et al., 2019; Ortega et al., 2021). Instead, our work relies on generalisation over latent types to generate counterfactual trajectories. This generalisation is enabled by the information bottleneck and results are refined by the adversarial loss.

C LIMITATIONS AND SOCIETAL IMPACT

While RTC notably improves distributional realism (see section 6), it does not achieve it perfectly, especially in the long tail of the data distribution. This has implications for its use, for example in economic simulations to evaluate policy proposals or in driving simulations to evaluate autonomous vehicles, where this limitation has to be taken into account and the simulation results should not be trusted unconditionally. As RTC is application agnostic, the societal impact depends on where it is used. Here, we focus on agent-based simulations, as we anticipate these to have the highest impact. Examples include better policy decisions through economic simulations, safer autonomous vehicles through driving simulations, better AI in games or improved safety precautions for large crowds of people. For other use-cases, e.g., in armed conflicts, the societal impact will depend on the intention of the simulation. Furthermore, we stress once more that, for many use-cases, precautions have to be taken to account for remaining errors in the learned agents. Lastly, depending on the use case, algorithmic bias has to be taken into account if RTC prevents mode collapse more effectively for certain strata of the population than for others.

D ADDITIONAL EXPERIMENTAL DETAILS

D.1 DETAILS ON DOUBLE GOAL PROBLEM

The agent observation $\mathbf{s}_t = [s_t, g_{l,t}, g_{u,t}, a_{t-1}] \in \mathbb{R}^8$ with $s_t = [x_t, y_t]$, $g_{i,t} = [x_{i,t}, y_{i,t}]$, $i \in \{u, l\}$ contains the 2D position of the agent, $s_t$, as well as two marked locations $g_{l,t}, g_{u,t}$ of the lower and upper goal. Because the current agent position cannot uniquely identify the currently selected goal, the observation also contains the last agent action $a_{t-1}$, with the simple transition function $s_{t+1} = s_t + a_t$. The goal locations are randomly sampled at the beginning of each episode. The lower (upper) goal is always located in the lower (upper) half of the x, y plane. Their horizontal and vertical distances from the initial agent position are uniformly sampled within rectangular bounds x …

For the first 10 timesteps, the agent randomly resamples the target goal with $P(G = g_l) = 0.75$ to avoid a simple decision boundary along the x-axis in which experts in the lower half-plane always target goal $g_l$. The expert action is $a_t = 0.1\, \|\Delta_t\|\, \frac{d_t}{\|d_t\|}$, where $d_t = \begin{pmatrix} 0.1 & 0 \\ 0 & 0.05 \end{pmatrix} \Delta_t$ and $\Delta_t = (g_t - s_t)$. The expert approaches the goal faster along the x-axis, hence creating a curved path. To avoid over-shooting, the step size shrinks with $\|\Delta_t\|$ as the agent proceeds towards the goal.

All networks use simple MLPs with two latent layers and a latent dimension of 256. To capture the shape of trajectories, the discriminator acts on entire trajectories, aggregating across time using max-pooling over a 32-dimensional per-timestep embedding. All policies are parameterised as Gaussian mixture models with two modes. RTC uses a continuous bottleneck of size 2 with the additional regularization term $L^{\beta}_{ib}(\tau) = \beta\, KL[e_\theta(\hat g_e|\tau) \,\|\, N(0, I)]$ to regulate the information bandwidth of $\hat g_e$. We use the BC loss for $L_{rec}(\tau, \hat\tau_e) = -\log \pi_\theta(a_t|s_t, \hat g_e)$. Batch size is 1024 for training and 10K for evaluation. Results shown in fig. 3 are evaluated every 100 steps and exponentially smoothed with a decay rate of 0.9.
The learning rate is 0.01 for BC (lower learning rates performed worse) and 0.004 for both MGAIL and RTC, which were tuned independently for values lr ∈ [0.02, 0.01, 0.004, 0.002, 0.001]. f = 0.5 was used to split between B encoder and B prior (no tuning was performed). Lastly, without further tuning, λ adv = 1 was used. Training time is about 7h without hardware acceleration.
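The expert in the Double Goal Problem can be sketched in a few lines. The sketch assumes the expert action has magnitude 0.1‖Δ_t‖ and direction d_t/‖d_t‖, which is one reading of the formula in this section; the function names and the omission of goal movement and goal resampling are simplifications made for illustration only.

```python
import math

def expert_action(pos, goal, step_scale=0.1, axis_weights=(0.1, 0.05)):
    """Assumed expert action: 0.1 * |delta| * d / |d| with d = diag(0.1, 0.05) @ delta."""
    delta = (goal[0] - pos[0], goal[1] - pos[1])
    dist = math.hypot(*delta)
    d = (axis_weights[0] * delta[0], axis_weights[1] * delta[1])
    d_norm = math.hypot(*d)
    if d_norm == 0.0:  # already at the goal
        return (0.0, 0.0)
    scale = step_scale * dist / d_norm
    return (scale * d[0], scale * d[1])

def rollout(pos, goal, steps=30):
    """Apply the transition s_{t+1} = s_t + a_t for a static goal."""
    path = [pos]
    for _ in range(steps):
        a = expert_action(pos, goal)
        pos = (pos[0] + a[0], pos[1] + a[1])
        path.append(pos)
    return path

path = rollout((0.0, 0.0), (1.0, 1.0))
first_action = expert_action((0.0, 0.0), (1.0, 1.0))
```

Because the x-component of the direction is weighted twice as strongly as the y-component, the path bends towards the goal's x-coordinate first, and the step length shrinks in proportion to the remaining distance.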

D.2 DETAILS ON WAYMO OPEN MOTION DATASET

The Waymo Open Motion Dataset (Ettinger et al., 2021) (published under Apache License 2.0) consists of segments of length 9s sampled at 10Hz. The available training and validation splits in the dataset consist of 487K and 49K segments respectively, which are used for training and testing the agent. Due to memory constraints, we filter for segments with fewer than 256 agents and 10K points describing the lane geometry, resulting in 428K train and 39K test segments. 250 segments from the training split are used for validation to select the training checkpoint for evaluation. In each segment, we learn to control two agents at a frequency of 3.33Hz, repeating actions three times. Similar to Igl et al. (2022), the actions of other agents are replayed from the logged data. The collision metric measures the number of segments, in percent, for which at least one pair of bounding boxes overlaps for at least one timestep. The off-road metric similarly measures how much time the agent's bounding box overlaps with off-road areas.

… $+\, a_t$. Dynamic features and roadgraph users are replayed from the logged data, similar to Igl et al. (2022). All positions and headings are first normalised to be relative to the observing ego-agent. MLPs are used to encode each object and point individually, and per-type max-pooling is used to aggregate over a variable number of inputs. The resulting three embeddings (one for the ego-agent, one for other road-users and one for the scene), each of size 64, are concatenated and passed to the policy, the discriminator or the value function. Their encoders are not shared and consist of MLPs with two latent layers of size 64 for the discriminator and value function and 128 for the policy. The inference encoder $e_\theta$ for RTC only observes future agent positions $s^{(a)}_{1:T}$, which are each concatenated with an eight-dimensional learned positional embedding, individually encoded to dimension 128 and max-pooled along the time dimension.

A Gaussian mixture model with 8 modes was used for all policies, although we find empirically that typically only up to three are used after training. We train for 200K gradient steps and select the model checkpoint for evaluation with the lowest sum of collision rates and off-road time on the validation set. To stabilize training for all methods, we discount gradients through time with γ = 0.9 and bootstrap from a learned value function every 10 steps. We anneal f from 1 to $f_{min} = 0.5$ over the course of training. Initially, high values of f encourage meaningful information in $\hat g_e$, while lower values address covariate and type shifts and improve performance. A learning rate of 0.0001, which was tuned for MGAIL, was used for all evaluated methods. Each batch contained 24 segments and training was performed on a single V100 (per seed) and required about 4-5 days. We used $\lambda_{adv} = 4.0$ and β = 0.01 for $L_{ib}(\tau)$ for continuous type representations of size 2. Discrete type representations used three one-hot vectors of size 16, trained using straight-through gradient estimation. We found performance to be marginally better for three vectors than for one, without noticeable performance increases for additional or larger vectors. Smaller vectors with only four values performed only slightly worse. The Huber loss $L_{rec}$ uses δ = 30.
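The information bottleneck term for continuous types has a closed form when the encoder output is a diagonal Gaussian. The sketch below assumes exactly that parameterisation (mean and log standard deviation per latent dimension), which the text does not spell out; the function names are hypothetical.

```python
import math

def kl_diag_gaussian_to_standard_normal(mean, log_std):
    """KL[N(mean, diag(exp(log_std))^2) || N(0, I)], summed over latent dimensions."""
    kl = 0.0
    for m, ls in zip(mean, log_std):
        var = math.exp(2.0 * ls)
        # Per dimension: 0.5 * (var + m^2 - 1) - log(std)
        kl += 0.5 * (var + m * m - 1.0) - ls
    return kl

def ib_loss(mean, log_std, beta=0.01):
    """Bottleneck penalty beta * KL, with beta = 0.01 as reported for the size-2 latent."""
    return beta * kl_diag_gaussian_to_standard_normal(mean, log_std)

# A latent of size 2 whose first dimension has been shifted away from the prior.
penalty = ib_loss(mean=[1.0, 0.0], log_std=[0.0, 0.0])
```

Larger values of β pull the encoder towards the standard normal and thereby limit how much information about the trajectory the type can carry.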

D.3 DETAILS ON SYMPHONY BASELINE

Symphony implements the hierarchical policy proposed in Igl et al. (2022) (called 'MGAIL+H' in their results) . Agent types represent high-level driving intent and are expressed as a sequence of road-segments to be followed. They are encoded into a latent vector by expressing them as a fixed-length sequence of points {[x i , y i ]} Ns i=1 . Each point is concatenated with a positional embedding, then encoded individually, and subsequently max-pooled along the time-dimension. The pooled embedding is provided as additional inputs to both the discriminator and the policy. During training, lane sequences are extracted from the given data trajectory τ . The prior p θ (ĝ), which is used during testing when no ground truth trajectories τ are available, predicts a categorical distribution over all possible sequences of lane-segments which the agent could follow in a scene. To allow for a variable number of such sequences, the logits are predicted individually per sequence.

D.4 DETAILS ON INFOGAIL BASELINE

Like RTC, InfoMGAIL (Li et al., 2017) is a general method for learning a hierarchical agent from demonstrations. What makes it a suitable baseline is that, like in RTC, the higher level policy captures a distribution over alternative trajectories that can be taken. It does so in an unsupervised fashion by introducing an additional reward that incentivizes the policy to produce state-action pairs from which an additionally trained discriminator (also called 'posterior') can infer the type on which the policy was conditioned. In other words, it rewards the policy for producing distinct trajectories for different types, where the type is drawn from a fixed prior when generating rollouts.

A crucial difference between InfoMGAIL and RTC is that InfoMGAIL's goal is to disentangle trajectories, but it does not target distributional realism directly. In particular, because the prior from which the types are sampled is fixed, it might not even be able to properly capture the true distribution of trajectories. This is especially true for the uniform discrete prior used in the original InfoMGAIL paper, which assumes a uniform distribution over trajectory modes. Furthermore, the additional posterior reward introduces bias, potentially harming task performance. Lastly, because mode collapse is not directly penalized in the additional loss (only 'non-distinctiveness' of trajectories), it might not improve distributional realism at all.

In our experiments, we augment InfoMGAIL in several ways:

• We not only try discrete latents, but also continuous ones. For continuous priors we use the same GMM as posterior as we use as prior in RTC.
• We additionally provide the posterior with the initial state as input. Unlike in the examples used in the InfoMGAIL paper, we believe that for more complex WOMD data, the current state is insufficient to determine modes.
• To make it comparable in our setup, we optimise it using Info(M)GAIL, i.e. the posterior score of the true type is added as a differentiable loss term, not as a reward for TRPO. The network architecture of the posterior is the same as the one we used for our MGAIL discriminator.
• We greatly increase the number of latent dimensions. In Li et al. (2017), 2 and 3 dimensions were used for the two experiments. We tried d ∈ [3, 10, 30, 100]. We also tried λ_1 ∈ [0.01, 0.03, 0.1, 0.3, 1.0] as regularization strength for the additional loss term.
• Lastly, we also add a BC term to InfoMGAIL, as we found that this greatly stabilizes training.
• In contrast to the original implementation, we do not use pre-training and do not make use of additional shaping rewards.



(a) Encoder: ĝe ∼ e θ (ĝ e |τ ) (b) Prior: ĝp ∼ p θ (ĝ p ). (c) Example dataset.

Figure 1: Simplified, non-temporal setup with environmental noise ξ and multi-modality induced by the unobserved agent type g. We denote τ = (s, a). The inferred type ĝ is sampled from e θ (ĝ e |τ ) during training (left) and p θ (ĝ p ) otherwise (middle). The control policy is π θ (â|s, ĝ). Circles are random variables and squares deterministic functions. The loss L(a, â) penalises differences between a and â. Right: Example data, B denotes Bernoulli distributions.

Figure 2: Robust Type Conditioning (RTC): The control policy π θ (ât|ŝt, ĝ) is trained under inferred types ĝ sampled from both the encoder e θ (ĝ e |τ ) and the prior p θ (ĝ p ). The reconstruction loss Lrec(τ, τe) avoids mode collapse. The adversarial loss Ladv(τp) under prior types prevents causally confused policies and ensures good task performance. Lprior optimises the prior to sample distributionally realistic types and the information bottleneck loss Lib reduces covariate shift.

Figure 3: Top: Visualization of ten randomly sampled goal pairs and associated trajectories. Bottom: Training curves, exponentially smoothed and averaged over 20 seeds. Shading shows the standard deviation. We show task performance as 'Test Return' and distributional realism as 'JSD' between the goal distribution of expert and agent (lower is better). 'Frequency Lower Goal' shows the data from which the JSD is computed. The inset shows the distribution at the last training step. Boxes show quartiles, whiskers extreme values, diamonds outliers, and stars the mean.

Each episode has a fixed horizon of T = 30 steps over the course of which each goal moves by s

Figure 4: Waymo Open Motion Dataset: Performance on the validation set during training. Distributional realism metrics are not shown as their evaluation is high variance on the small validation set.

Figure 5: Differences between realism, coverage and distributional realism. The data distribution P (XD) ∈ ∆(X ) is shown in green, blue denotes a learned distribution P θ (XL) ∈ ∆(X ). (a) Data from learned distribution is realistic, i.e. supp(XL) ⊆ supp(XD), but not distributionally realistic. (b) The learned distribution achieves coverage but not distributional realism: the frequencies of modes are not matched. (c) The learned distribution is distributionally realistic. In practice, the dimensionality of X is often too high, requiring us to measure distributional realism only in selected features h(X). Consequently, distributional realism in h(X) does not necessarily imply good realism, i.e. task performance.

Table 1: Averages and standard deviations over 10 training runs on WOMD. We use the percentage of segments with collisions and time spent off-road as proxy metrics for task performance. Mode coverage is measured by the minimum average displacement error (minADE).

The state s t = [s , s(RU ) ] contains the agent's position and heading s


We also have, similarly to eq. (5), … By assumption, we have … Furthermore, because the marginals of $\hat G_e$ and $\hat G_p$ are matched, we have $H(\hat G_e) = H(\hat G_p)$ and hence the required $H(A|\hat G_e) = H(\hat A|\hat G_p)$ for eq. (8) to hold.

Can we remove $H(A, \hat G_e) = H(\hat A, \hat G_p)$ as an assumption? Unfortunately, only if $H(A|\hat G_e) = 0$, in which case it is automatically true (see next subsection). Otherwise, this assumption is needed to make sure that the entropy in the system remains comparable between training and testing.

If $H(A, \hat G_e) \neq H(\hat A, \hat G_p)$, the main result $I(S, \hat A) < I(S, A)$ could still hold, but one could also construct environments and encoders in which it does not. The reason is that $H(A|\hat G)$ depends on the distribution $p(s|\hat g)$, which changes between training, where it is $p(s|\hat g_e)$, and testing, where it is $P_D(s)$ due to the independent drawing of $\hat g_p$. This can be used to construct environments and encoders that change $H(A|\hat G)$ and $H(A, \hat G)$ arbitrarily between training and testing, hence making a comparison of the mutual information $I(S, \hat A)$ and $I(S, A)$ meaningless.
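For reference, the two steps that the text refers to as eq. (7) and eq. (8) can plausibly be written as follows. This is a sketch derived from the identities listed in A.1 and the assumptions of theorem 1, not a verbatim copy of the original derivation; in particular, the assignment of the equation numbers is an assumption.

```latex
% Assumed form of eq. (7): expand the test-time interaction information
% along (S, \hat A) and use H(\hat A | S, \hat G_p) = 0 together with
% I(\hat A, \hat G_p, S) \le 0.
\begin{align}
I(\hat A, \hat G_p, S)
  &= I(S, \hat A) - I(S, \hat A \mid \hat G_p) \notag \\
  &= I(S, \hat A) - H(\hat A \mid \hat G_p)
     + \underbrace{H(\hat A \mid S, \hat G_p)}_{=0} \le 0
  \quad\Longrightarrow\quad
  I(S, \hat A) \le H(\hat A \mid \hat G_p). \tag{7}
\end{align}

% Assumed form of eq. (8): with H(A, \hat G_e) = H(\hat A, \hat G_p) and
% H(\hat G_e) = H(\hat G_p), the chain rule gives
% H(A | \hat G_e) = H(A, \hat G_e) - H(\hat G_e)
%                 = H(\hat A, \hat G_p) - H(\hat G_p)
%                 = H(\hat A | \hat G_p).
\begin{align}
I(S, \hat A) \le H(\hat A \mid \hat G_p) = H(A \mid \hat G_e) < I(S, A). \tag{8}
\end{align}
```

Read this way, the last inequality is exactly the condition $H(A|\hat G_e) < I(S, A)$ of theorem 1, and the middle equality is where the entropy-matching assumption enters.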

A.3 PROOF OF COROLLARY

We have $p(a|\hat g) = \int_s p(s|\hat g)\, \pi_\theta(a|s, \hat g)\, ds$. We also know that $H(A|S, \hat G_e) = H(\hat A|S, \hat G_p) = 0$ and hence $\pi_\theta(\hat a|s, \hat g) \in \{0, 1\}$. Furthermore, if $H(A|\hat G_e) = 0$, the action is fully determined by $\hat G_e$, i.e. $\pi_\theta(\hat a|s, \hat g) = \pi_\theta(\hat a|\hat g) \in \{0, 1\}$. Hence, because $\int_s p(s|\hat g_e)\, ds = \int_s P_D(s)\, ds = 1$, the switch from $p(s|\hat g_e)$ to $p(s|\hat g_p) = P_D(s)$ does not impact $p(a|\hat g)$, so we have $H(A|\hat G_e) = H(\hat A|\hat G_p) = 0$.

The result that $I(S, \hat A) = 0$ follows directly from eq. (7) and $I(S, \hat A) \ge 0$.
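The corollary can be checked numerically on a toy example. The construction below is an illustrative assumption, not an experiment from the paper: the state is a fair coin, the expert copies the state into the action, the encoder copies the action into the type (so that H(A|Ĝe) = 0), and the policy outputs the type.

```python
import math
from collections import defaultdict
from itertools import product

def mutual_information(joint):
    """I(X;Y) in bits from a dict {(x, y): probability}."""
    px, py = defaultdict(float), defaultdict(float)
    for (x, y), p in joint.items():
        px[x] += p
        py[y] += p
    return sum(
        p * math.log2(p / (px[x] * py[y]))
        for (x, y), p in joint.items()
        if p > 0.0
    )

# Training data: S ~ Bernoulli(0.5) and the expert action is A = S.
p_s = {0: 0.5, 1: 0.5}
train_joint = {(s, s): p for s, p in p_s.items()}  # joint over (s, a)

# The encoder sets g_e = a and the policy returns a_hat = g, so reconstruction
# is perfect. The prior equals the encoder marginal, here uniform over {0, 1}.
prior = {0: 0.5, 1: 0.5}

# Test time: the type is drawn from the prior, independently of the state.
test_joint = defaultdict(float)  # joint over (s, a_hat)
for (s, ps), (g, pg) in product(p_s.items(), prior.items()):
    a_hat = g
    test_joint[(s, a_hat)] += ps * pg

i_train = mutual_information(train_joint)
i_test = mutual_information(dict(test_joint))
```

In this construction the policy has learned to rely on the type alone, so all dependence between state and action present in the data disappears once the type is sampled from the prior.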

B ADDITIONAL RELATED WORK

Unlike our work, agent modelling (Grover et al., 2018; Papoudakis and Albrecht, 2020) often assumes knowledge of agent identities in multi-agent systems and aims at learning a useful representation for each identity. In contrast, we neither know the true type g of the imitated agent, nor the identity of the external stochastic noise source ξ. Furthermore, applications of opponent modelling in RL settings (e.g., Papoudakis and Albrecht, 2020; He et al., 2016; Raileanu et al., 2018; Hernandez-Leal et al., 2019; Xie et al., 2020) are generally unconcerned about distributional realism and do not consider distribution shifts.

Behaviour Prediction (BP) also forecasts future trajectories. Unless future steps are predicted independently of the evolution of the scene (e.g., not auto-regressively) (Chai et al., 2019; Cui et al., 2019; Phan-Minh et al., 2020; Liu et al., 2021), these methods also suffer from covariate shift in the state visitations (Bengio et al., 2015; Lamb et al., 2016). Furthermore, if hierarchical methods are used to capture the multimodality in the …

D.5 COVERAGE AND DISTRIBUTIONAL REALISM METRICS

While coverage is easy to achieve on the Double Goal Problem, we measure it on the Waymo Open Motion Dataset (section 6.2) using minADE with a fixed number of K = 16 rollouts per segment. Intuitively, the more modes are covered by a given agent, the closer one of the K rollouts should be to a given trajectory from the dataset, resulting in a lower minADE.

We want to measure distributional realism as the divergence between the expert distribution $p_{expert}(\tau)$ and the predicted distribution $p_{agent}(\tau)$. However, since the space of possible trajectories is far too large to directly measure JSD$(p_{agent}(\tau) \,\|\, p_{expert}(\tau))$, we extract scalar features h from trajectories and measure the divergence on those features. For the Double Goal Problem, we would like to capture whether the agent is approaching $g_l$ or $g_u$, for which we extract $h_s = \text{sign}(y_T)$, i.e. whether the agent is in the lower or upper half of the plane at the last timestep. In the driving domain, we measure progress as the total distance travelled over the 9s segment, i.e., $h_{style} = \delta(\hat s_0, \hat s_T)$. Measuring this distance as a straight line avoids measurement noise through swerving or jittering of the agent. Lastly, to measure high-level intent, i.e. whether the agent prefers going left, right or straight at branching points such as intersections, we follow Igl et al. (2022) and extract as feature $h_{cur}$, i.e., the curvature of the lane segments being followed right after possible branching points in the road.
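Measuring the divergence on a scalar feature reduces to comparing two histograms. The sketch below assumes equal-width bins over a caller-supplied range and a base-2 logarithm, so that the JSD lies in [0, 1]; neither the range nor the logarithm base is specified in the text, and the function names are hypothetical.

```python
import math

def histogram(values, lo, hi, bins=100):
    """Normalised histogram with equal-width bins over [lo, hi]; outliers are clipped."""
    assert hi > lo and len(values) > 0
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        idx = min(bins - 1, max(0, int((v - lo) / width)))
        counts[idx] += 1
    return [c / len(values) for c in counts]

def kl_bits(p, q):
    """KL divergence in bits; terms with p_i = 0 contribute nothing."""
    return sum(a * math.log2(a / b) for a, b in zip(p, q) if a > 0.0)

def jsd(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    m = [(a + b) / 2.0 for a, b in zip(p, q)]
    return 0.5 * kl_bits(p, m) + 0.5 * kl_bits(q, m)

# Toy features: agent and expert either share a bin or occupy different bins.
agent = histogram([0.1, 0.2], lo=0.0, hi=1.0, bins=2)
expert_same = histogram([0.3, 0.4], lo=0.0, hi=1.0, bins=2)
expert_other = histogram([0.9], lo=0.0, hi=1.0, bins=2)
same = jsd(agent, expert_same)
disjoint = jsd(agent, expert_other)
```

Unlike minADE, this quantity is sensitive to the relative frequency of modes: an agent that visits every mode but with the wrong proportions still incurs a positive divergence.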

