SIMPLIFYING MODEL-BASED RL: LEARNING REPRESENTATIONS, LATENT-SPACE MODELS, AND POLICIES WITH ONE OBJECTIVE

Abstract

While reinforcement learning (RL) methods that learn an internal model of the environment have the potential to be more sample-efficient than their model-free counterparts, learning to model raw observations from high-dimensional sensors can be challenging. Prior work has addressed this challenge by learning low-dimensional representations of observations through auxiliary objectives, such as reconstruction or value prediction. However, the alignment between these auxiliary objectives and the RL objective is often unclear. In this work, we propose a single objective which jointly optimizes a latent-space model and policy to achieve high returns while remaining self-consistent. This objective is a lower bound on expected returns. Unlike prior bounds for model-based RL, which concern policy exploration or model guarantees, our bound applies directly to the overall RL objective. We demonstrate that the resulting algorithm matches or improves the sample efficiency of the best prior model-based and model-free RL methods. While sample-efficient methods are typically computationally demanding, our method attains the performance of SAC in about 50% less wall-clock time.¹

1. INTRODUCTION

While RL algorithms that learn an internal model of the world can learn more quickly than their model-free counterparts (Hafner et al., 2018; Janner et al., 2019), figuring out exactly what these models should predict has remained an open problem: the real world and even realistic simulators are too complex to model accurately. Although model errors may be rare under the training distribution, a learned RL agent will often seek out the states where an otherwise accurate model makes mistakes (Jafferjee et al., 2020). Simply training the model with maximum likelihood will not, in general, produce a model that is good for model-based RL (MBRL). The discrepancy between the policy objective and the model objective is called the objective mismatch problem (Lambert et al., 2020), and it remains an active area of research. The objective mismatch problem is especially important in settings with high-dimensional observations, which are challenging to predict with high fidelity.

In this paper, we present a simple yet principled answer to this problem by devising a single objective that jointly optimizes the three components of a model-based algorithm: the representation, the model, and the policy. As shown in Fig. 1, this is in contrast to prior methods, which use three separate objectives. We build upon prior work that views model-based RL as a latent-variable problem: the objective is to maximize the returns (the likelihood), which is an expectation over trajectories (the unobserved latent variable) (Botvinick & Toussaint, 2012; Attias, 2003; Eysenbach et al., 2021a). This differs from prior work that maximizes the likelihood of observed data, independent of the reward function (Hafner et al., 2019; Lee et al., 2020). This perspective suggests that model-based RL algorithms should resemble inference algorithms: sampling trajectories (the latent variable) and then maximizing returns (the likelihood) on those trajectories.
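To make the latent-variable perspective concrete, a standard variational bound (written here in our own schematic notation, not the paper's exact derivation) treats the exponentiated return as a likelihood and any tractable trajectory distribution q as a variational posterior:

```latex
% Schematic control-as-inference bound (notation illustrative).
% p(\tau): trajectory distribution under the real dynamics and policy;
% q(\tau): a tractable distribution, e.g., rollouts of a learned model.
\log \mathbb{E}_{\tau \sim p(\tau)}\!\left[ e^{R(\tau)} \right]
\;\ge\; \mathbb{E}_{\tau \sim q(\tau)}\!\left[ R(\tau) \right]
\;-\; D_{\mathrm{KL}}\!\left( q(\tau) \,\|\, p(\tau) \right).
```

Maximizing the right-hand side over q (i.e., over the model and policy) rewards high returns on trajectories where q and p agree, which is one way to see the self-consistency that this paper aims for.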
However, sampling trajectories is challenging when observations are high-dimensional. The key to our work is to infer both the trajectories (observations, actions) and the representations of those observations. Crucially, we show how to maximize the expected returns under this inferred distribution by sampling only the representations, without needing to sample high-dimensional observations.

The main contribution of this paper is Aligned Latent Models (ALM), an MBRL algorithm that jointly optimizes the observation representations, a model that predicts those representations, and a policy that acts based on those representations. To the best of our knowledge, this objective is the first lower bound for a model-based RL method with a latent-space model. Across a range of continuous control tasks, we demonstrate that ALM achieves higher sample efficiency than prior model-based and model-free RL methods, including on tasks that stymie prior MBRL methods. Because ALM does not require ensembles (Chua et al., 2018; Janner et al., 2019) or decision-time planning (Deisenroth & Rasmussen, 2011; Sikchi et al., 2020; Morgan et al., 2021), our open-source implementation performs updates 10× and 6× faster than MBPO (Janner et al., 2019) and REDQ (Chen et al., 2021) respectively, and achieves near-optimal returns in about 50% less time than SAC.
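As a minimal sketch of how a single objective over all three components might be computed, the toy code below rolls a policy forward purely in latent space and combines predicted rewards with a consistency penalty against the encodings of real next observations. All names, dimensions, the linear "networks", and the squared-error consistency term are illustrative stand-ins chosen for this sketch; they are not ALM's actual architecture or its KL-based bound.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions, chosen only for illustration.
OBS_DIM, LATENT_DIM, ACT_DIM, HORIZON = 8, 4, 2, 3

# Linear stand-ins for the three jointly trained components.
W_enc = rng.normal(size=(LATENT_DIM, OBS_DIM)) * 0.1              # encoder  e(s) -> z
W_mod = rng.normal(size=(LATENT_DIM, LATENT_DIM + ACT_DIM)) * 0.1 # model    m(z, a) -> z'
W_pi = rng.normal(size=(ACT_DIM, LATENT_DIM)) * 0.1               # policy   pi(z) -> a
w_rew = rng.normal(size=(LATENT_DIM + ACT_DIM,)) * 0.1            # reward head r(z, a)

def encode(obs):
    return np.tanh(W_enc @ obs)

def joint_objective(obs_seq, consistency_weight=0.1):
    """One scalar combining predicted return and a self-consistency penalty.

    obs_seq holds real observations s_0..s_H from a replay buffer. The rollout
    happens entirely in latent space; real observations enter only through the
    encoder, as targets for the model's latent predictions.
    """
    z = encode(obs_seq[0])
    total = 0.0
    for t in range(HORIZON):
        a = np.tanh(W_pi @ z)                            # act from the latent state
        total += float(w_rew @ np.concatenate([z, a]))   # predicted reward
        z_pred = np.tanh(W_mod @ np.concatenate([z, a]))
        z_target = encode(obs_seq[t + 1])                # encoding of the real next obs
        # Squared error stands in for the KL consistency term of the bound.
        total -= consistency_weight * float(np.sum((z_pred - z_target) ** 2))
        z = z_pred                                       # continue the imagined rollout
    return total

obs_seq = rng.normal(size=(HORIZON + 1, OBS_DIM))
value = joint_objective(obs_seq)
```

Because the reward and penalty terms both depend on the encoder, model, and policy parameters, maximizing this one scalar (with a differentiable framework, in practice) would update all three components toward the same goal, rather than training each on a separate auxiliary loss.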

2. RELATED WORK

Prior model-based RL methods use models in many ways: to search for optimal action sequences (Garcia et al., 1989; Springenberg et al., 2020; Hafner et al., 2018; Chua et al., 2018; Hafner et al., 2019; Xie et al., 2020), to generate synthetic data (Sutton, 1991; Luo et al., 2018; Hafner et al., 2019; Janner et al., 2019; Shen et al., 2020), to better estimate the value function (Deisenroth & Rasmussen, 2011; Chua et al., 2018; Buckman et al., 2018; Feinberg et al., 2018), or some combination thereof (Schrittwieser et al., 2020; Hamrick et al., 2020; Hansen et al., 2022). Similar to prior work on stochastic value gradients (Heess et al., 2015; Hafner et al., 2019; Clavera et al., 2020; Amos et al., 2020), our approach uses model rollouts to estimate the value function for a policy gradient. Prior work finds that taking gradients through a learned dynamics model can be unstable (Metz et al., 2021; Parmas et al., 2019). However, similar to Dreamer (Hafner et al., 2019; 2020), we found that backpropagation through time (BPTT) can work successfully if done with a latent-space model with appropriately regularized representations. Unlike these prior works, the precise form of the regularization emerges from our principled objective. Because learning a model of high-dimensional observations is challenging, many prior model-based methods first learn a compact representation using a representation learning objective (e.g., image reconstruction (Kaiser et al., 2019; Oh et al., 2015; Buesing et al., 2018; Ha & Schmidhuber, 2018; Hafner et al., 2018; 2019; 2020), value and action prediction (Oh et al., 2017; Schrittwieser et al., 2020; Grimm et al., 2020), planning performance (Tamar et al., 2016; Racanière et al., 2017; Okada et al., 2017), or self-supervised learning (Deng et al., 2021; Nguyen et al., 2021; Okada & Taniguchi, 2020)). These methods then learn the dynamics of these representations (not of the raw observations) and use the model for RL.
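The stochastic value gradient idea can be written schematically (in our own notation, not that of the cited works): with a reparameterized latent model and a deterministic-in-parameters policy, the policy gradient is taken directly through the imagined rollout via BPTT:

```latex
% Schematic stochastic value gradient through a latent rollout.
\nabla_{\phi}\, \mathbb{E}_{\epsilon_{0:H}}
\left[ \sum_{t=0}^{H} \gamma^{t}\, r(z_t, a_t) \right],
\qquad a_t = \pi_{\phi}(z_t), \quad z_{t+1} = m_{\theta}(z_t, a_t, \epsilon_t).
```

Because each z_{t+1} depends differentiably on a_t, the gradient propagates through every step of the rollout, which is why the conditioning of the latent representations matters for stability.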
The success of these methods depends on the representation (Arumugam & Roy, 2022): the representations should be compact (i.e., easy to predict) while retaining task-relevant information. However, prior work does not optimize for this criterion, but instead optimizes the representation using some auxiliary objective. The standard RL objective is to maximize expected returns, but models are typically learned via a different objective (maximum likelihood), and representations are learned via a third objective (e.g., image reconstruction). To address this objective mismatch (Lambert et al., 2020; Joseph et al., 2013; Grimm et al., 2020), prior works study decision-aware loss functions, which optimize the model to minimize the difference between true and imagined next-step values (Farahmand et al., 2017; Farahmand, 2018; D'Oro et al., 2020; Abachi et al., 2020; Voelcker et al., 2022) or directly optimize



Prior model-based methods have coped with the difficulty of modeling high-dimensional observations by learning the dynamics of a compact representation of observations, rather than the dynamics of the raw observations. Depending on their learning objective, these representations might still be hard to predict or might not contain task-relevant information. Moreover, the accuracy of prediction depends not just on the model's parameters, but also on the states visited by the policy. Hence, another way of reducing prediction errors is to optimize the policy to avoid transitions where the model is inaccurate, while still achieving high returns. In the end, we want to train the model, representations, and policy to be self-consistent: the policy should only visit states where the model is accurate, and the representation should encode information that is task-relevant and predictable. Can we design a model-based RL algorithm that automatically learns compact yet sufficient representations for model-based reasoning?

¹ Project website with code: https://alignedlatentmodels.github.io/



Figure 1: (left) Most model-based RL methods learn the representations, latent-space model, and policy using three different objectives. (right) We derive a single objective for all three components, which is a lower bound on expected returns. Based on this objective, we develop a practical deep RL algorithm.

