JOINT-PREDICTIVE REPRESENTATIONS FOR MULTI-AGENT REINFORCEMENT LEARNING

Abstract

Recent advances in reinforcement learning have demonstrated the effectiveness of vision-based self-supervised learning (SSL). However, most efforts in this direction have focused on the single-agent setting, leaving multi-agent reinforcement learning (MARL) lagging behind. Two significant obstacles prevent off-the-shelf SSL approaches from being applied to MARL in partially observable multi-agent systems: (a) each agent only receives a partial observation, and (b) previous SSL approaches only take temporally consistent representations into account, ignoring the characterization that captures the interaction and fusion among agents. In this paper, we propose Multi-Agent Joint-Predictive Representations (MAJOR), a novel framework for self-supervised learning in cooperative MARL. Specifically, we treat the latent representations of all agents' local observations as a sequence of masked contexts of the global state, and we then learn effective representations by predicting the future latent representations of each agent with the help of agent-level information interactions in a joint transition model. We conduct extensive experiments on a wide range of MARL environments, covering both vision-based and state-based scenarios, and show that our proposed MAJOR achieves superior asymptotic performance and sample efficiency against other state-of-the-art methods.

1. INTRODUCTION

Representation learning has played an important role in recent developments of reinforcement learning (RL) algorithms. In particular, self-supervised representation learning (SSL) has attracted increasing attention due to its success in other research fields (He et al., 2019; Devlin et al., 2018; Liu et al., 2019; Lan et al., 2019). Recently, numerous works (Srinivas et al., 2020; Zhu et al., 2022; Yarats et al., 2021; Schwarzer et al., 2021a; Yu et al., 2022) have borrowed insights from these areas and used SSL-based auxiliary tasks to learn more effective representations for RL, thereby improving empirical performance. By augmenting the inputs, these methods construct multiple views for building SSL objectives, allowing the agent to improve data efficiency and generalization and to obtain better task-related representations. Moreover, several auxiliary self-supervision priors learn predictive representations with the help of an additionally learned dynamics model, which encourages the representations to be temporally predictive and consistent. However, in partially observable multi-agent systems, it is challenging to apply such self-supervision priors to learn compact and informative feature representations for multi-agent reinforcement learning (MARL). The critical obstacle is that agents in partially observable multi-agent systems only have access to their own observations, which means that each agent's observations are influenced by the behavior of the other agents. As a result, independently building representation priors for each agent may fail due to imperfect information. Furthermore, in the MARL context, it is more important to focus on representations that embody the interaction and fusion between agents in the environment than on temporal representations for each agent in isolation. In other words, it is necessary to learn representations that take the other agents into account.
In this work, we propose a novel representation learning framework for MARL, named Multi-Agent Joint-Predictive Representations (MAJOR), which trains better representations for MARL by forcing representations to be temporally predictive and consistent for all agents at the same timestep. We posit that, at any timestep, the latent representations of all agents' local observations can be treated as a sequence of masked contexts of the global state, so that we can predict each agent's future latent representation together with the representations of the corresponding actions. Accordingly, we construct a joint transition model as a bridge that connects all agents and implements the interaction of their individual information. The joint transition model treats the encoded representations of individual observations and actions as a sequence and predicts future representations in latent space. Meanwhile, we obtain another view of the subsequent timestep's representations by feeding sampled observations into the encoder. In this way, we build the SSL objective by enforcing consistency across the different views of each observation. Moreover, our proposed framework is a plug-and-play module for most commonly used MARL methods. To maximize its power, we implement an instantiation of MAJOR on top of the recently proposed MARL algorithm Multi-Agent Transformer (MAT, Wen et al. (2022)), which solves MARL problems via a sequential update mechanism and an encoder-decoder architecture. In MAT, the encoder extracts post-interaction representations from observations; the decoder then uses them to generate actions sequentially through a cross-attention mechanism. Our proposed MAJOR can employ the representations generated by both the encoder and the decoder, and the gradient derived from our representation learning objective can be back-propagated to both.
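The joint-predictive objective described above can be sketched in a few lines of Python. Everything below is our own illustrative stand-in, not the paper's actual networks: `encoder` and `joint_transition` are toy functions, and a negative cosine similarity stands in for the consistency loss (MSE would also fit the description).

```python
import math

def encoder(obs):
    # Stand-in observation encoder: any map from observation -> latent vector.
    return [math.tanh(x) for x in obs]

def joint_transition(latents, action_embs):
    # Stand-in joint transition model: mixes ALL agents' latents
    # (here a simple mean over agents) before predicting each agent's
    # next latent, mimicking agent-level interaction in a shared model.
    n = len(latents)
    mixed = [sum(l[i] for l in latents) / n for i in range(len(latents[0]))]
    return [[m + a for m, a in zip(mixed, emb)] for emb in action_embs]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u)) or 1e-8
    nv = math.sqrt(sum(b * b for b in v)) or 1e-8
    return dot / (nu * nv)

def major_loss(obs_t, action_embs, obs_tp1):
    # Consistency objective: compare the transition model's predicted
    # next-step latents against a second view obtained by encoding the
    # actually observed next observations, averaged over agents.
    preds = joint_transition([encoder(o) for o in obs_t], action_embs)
    targets = [encoder(o) for o in obs_tp1]
    return -sum(cosine(p, t) for p, t in zip(preds, targets)) / len(preds)
```

The key point the sketch preserves is that the prediction for each agent is computed from the *joint* sequence of all agents' latents, never from one agent's latent alone.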
To evaluate our proposed algorithm, we conduct extensive experiments on several commonly used cooperative MARL benchmarks, including vision- and state-based environments with discrete and continuous scenarios, against current state-of-the-art baselines such as HAPPO, MAPPO, and MAT. Results demonstrate state-of-the-art performance across all tested tasks.

2.1. DEC-POMDP

Cooperative MARL problems are often modeled as decentralized Partially Observable Markov Decision Processes (Dec-POMDPs, Oliehoek & Amato (2016)), defined by the tuple (N, S, {A^i}, T, R, Ω, O, γ). Here, N = {1, . . . , n} is the set of agents, S is a set of states, A = ×_i A^i is the set of joint actions, T is the set of conditional transition probabilities between states, with T(s, a, s′) = P(s′ | s, a), R : S × A → ℝ is the reward function, O = ×_i O^i is the set of joint observations, where O^i is the observation set of agent i, Ω is the set of conditional observation probabilities Ω(s′, a, o) = P(o | s′, a), and γ ∈ [0, 1] is the discount factor. At each time step, each agent takes an action a^i, and the state is updated according to the transition function (using the current state and the joint action). Each agent then receives an observation according to the observation function Ω(s′, a, o) (using the next state and the joint action), and a single reward is generated for the entire team according to the reward function R(s, a). The goal is to maximize the expected cumulative reward over a finite or infinite horizon.

2.2. MULTI-AGENT TRANSFORMER

The Multi-Agent Transformer (MAT, Wen et al. (2022)) effectively casts cooperative MARL as a sequence modeling (SM) problem in which the task is to map the observation sequence of the agents to an optimal action sequence. Its sequential update scheme is built on the Multi-Agent Advantage Decomposition Theorem (Kuba et al., 2021) and Heterogeneous-Agent Proximal Policy Optimization (HAPPO, Kuba et al. (2022)). The theorem provides an intuition guiding the choice of incrementally improved actions, and HAPPO fully leverages it to implement multi-agent trust-region learning with a monotonic improvement guarantee. Unfortunately, HAPPO requires a sequential update scheme over a permutation of the agents' order, meaning that HAPPO cannot be run in parallel. To address this drawback, MAT provides a Transformer-based implementation of multi-agent trust-region learning. Concretely, MAT maintains an encoder-decoder structure where the encoder maps an input sequence of tokens to latent representations. The decoder then generates a sequence of desired outputs in an auto-regressive manner, where at each step of inference the Transformer takes all previously generated tokens as input. In other words, MAT treats a team of agents as a sequence, thus
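The encoder-decoder flow can be illustrated with toy stand-ins. Nothing below is MAT's real architecture: `encode` replaces self-attention with a simple team-wide average, and `decode_step` replaces cross-attention with a threshold; the sketch only preserves the auto-regressive order in which agents' actions are generated.

```python
# Toy sketch of MAT-style sequential decoding over a team of agents.

def encode(observations):
    # Encoder stand-in: each representation incorporates a summary of
    # the whole team (real MAT uses self-attention over all agents).
    ctx = sum(observations) / len(observations)
    return [o + ctx for o in observations]

def decode_step(rep, prev_actions):
    # Decoder stand-in: agent i's action depends on its representation
    # AND all previously generated actions (cross-attention in MAT).
    return 1 if rep + sum(prev_actions) > 0 else 0

def mat_actions(observations):
    reps, actions = encode(observations), []
    for rep in reps:  # auto-regressive pass over the agent sequence
        actions.append(decode_step(rep, actions))
    return actions
```

Even in this toy form, later agents condition on earlier agents' chosen actions, which is the property the sequential update scheme relies on.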

