MUTUAL INFORMATION REGULARIZED OFFLINE REINFORCEMENT LEARNING

Abstract

Offline reinforcement learning (RL) aims at learning an effective policy from offline datasets without active interactions with the environment. The major challenge of offline RL is the distribution shift that appears when out-of-distribution actions are queried, which makes the policy improvement direction biased by extrapolation errors. Most existing methods address this problem by penalizing the policy for deviating from the behavior policy during policy improvement or making conservative updates for value functions during policy evaluation. In this work, we propose a novel MISA framework to approach offline RL from the perspective of Mutual Information between States and Actions in the dataset by directly constraining the policy improvement direction. Intuitively, mutual information measures the mutual dependence of actions and states, which reflects how a behavior agent reacts to certain environment states during data collection. To effectively utilize this information to facilitate policy learning, MISA constructs lower bounds of mutual information parameterized by the policy and Q-values. We show that optimizing this lower bound is equivalent to maximizing the likelihood of a one-step improved policy on the offline dataset. In this way, we constrain the policy improvement direction to lie in the data manifold. The resulting algorithm simultaneously augments the policy evaluation and improvement by adding a mutual information regularization. MISA is a general offline RL framework that unifies conservative Q-learning (CQL) and behavior regularization methods (e.g., TD3+BC) as special cases. Our experiments show that MISA performs significantly better than existing methods and achieves new state-of-the-art on various tasks of the D4RL benchmark.

1. INTRODUCTION

Reinforcement learning (RL) has made remarkable achievements in solving sequential decision-making problems, ranging from game playing (Mnih et al., 2013; Silver et al., 2017; Berner et al., 2019) to robot control (Levine et al., 2016; Kahn et al., 2018; Savva et al., 2019). However, its success relies heavily on 1) an environment to interact with for data collection and 2) an online algorithm that improves the agent based only on its own trial-and-error experiences. These requirements make RL algorithms impractical in real-world safety-sensitive scenarios where interactions with the environment are dangerous or prohibitively expensive, such as autonomous driving and robot manipulation involving humans (Levine et al., 2020; Kumar et al., 2020). Offline RL is therefore proposed to study the problem of learning decision-making agents from previously collected experiences when interacting with the environment is costly or not allowed. Though much in demand, extending RL algorithms to offline datasets is challenged by the distributional shift between the data-collecting policy and the learning policy. Specifically, a typical RL algorithm alternates between evaluating the Q values of a policy and improving the policy to achieve a better cumulative return under the current value estimation. In the offline setting, policy improvement often involves querying out-of-distribution (OOD) state-action pairs that never appear in the dataset, for which the Q values are over-estimated due to the extrapolation error of neural networks. As a result, the policy improvement direction is erroneously affected, and the accumulated error eventually leads to a catastrophic explosion of value estimates as well as policy collapse.
Existing methods (Kumar et al., 2020; Wang et al., 2020; Fujimoto & Gu, 2021; Yu et al., 2021) tackle this problem by either forcing the learned policy to stay close to the behavior policy (Fujimoto et al., 2019; Wu et al., 2019; Fujimoto & Gu, 2021) or generating low value estimates for OOD actions (Nachum et al., 2017; Kumar et al., 2020; Yu et al., 2021). Though these methods are effective at alleviating the distributional shift of the learning policy, the improved policy is unconstrained and might still deviate from the data distribution. A natural question thus arises: can we directly constrain the policy improvement direction to lie in the data manifold? In this paper, we step back and consider the offline dataset from a new perspective, i.e., the Mutual Information between States and Actions (MISA). Viewing state and action as two random variables, the mutual information represents the reduction of uncertainty about actions given certain states, a.k.a. information gain in information theory (Nowozin, 2012). Therefore, mutual information is an appealing metric for sufficiently acquiring knowledge from a dataset and characterizing a behavior policy. We introduce it into offline RL, for the first time, as a regularization that directly constrains the policy improvement direction. Specifically, to allow practical optimization of state-action mutual information estimates, we introduce the MISA lower bound of state-action pairs, which connects mutual information with RL by treating a parameterized policy as a variational distribution and the Q-values as energy functions. We show that this lower bound can be interpreted as the likelihood of a non-parametric policy on the offline dataset, which represents the one-step improvement of the current policy under the current value estimation. Maximizing the MISA lower bound is thus equivalent to directly regularizing the policy improvement within the dataset manifold.
However, the constructed lower bound involves integration over a self-normalized energy-based distribution, whose gradient estimation is intractable. To alleviate this issue, we adopt Markov Chain Monte Carlo (MCMC) sampling to produce an unbiased gradient estimate of the MISA lower bound. Theoretically, MISA is a general framework for offline RL that unifies several existing offline RL paradigms, including behavior regularization and conservative learning. As examples, we show that TD3+BC (Fujimoto & Gu, 2021) and CQL (Kumar et al., 2020) are degenerate cases of MISA. In our experiments, we demonstrate that MISA achieves significantly better performance than state-of-the-art methods on various environments of the D4RL (Fu et al., 2020) benchmark. Additional ablation studies, visualizations, and limitations are discussed to better understand the proposed method. Our code will be released upon publication.

2. RELATED WORKS

Offline Reinforcement Learning The most critical challenge for extending an off-policy RL algorithm to an offline setup is the distribution shift between the behavior policy, i.e., the policy for data collection, and the learning policy. To tackle this challenge, most offline RL algorithms adopt a conservative learning framework. They either regularize the learning policy to stay close to the behavior policy (Fujimoto et al., 2019; Wu et al., 2019; Fujimoto & Gu, 2021; Siegel et al., 2020; Wang et al., 2020), or force Q values to be low for OOD state-action pairs (Nachum et al., 2017; Kumar et al., 2020; Yu et al., 2021). For example, TD3+BC (Fujimoto & Gu, 2021) adds an additional behavior cloning (BC) loss to TD3 (Fujimoto et al., 2018), which encourages the policy to stay in the data manifold; CQL (Kumar et al., 2020), from the Q-value perspective, penalizes OOD state-action pairs for generating high Q-value estimates and learns a lower bound of the true value function. However, their policy improvement direction is unconstrained and might deviate from the data distribution. On the other hand, SARSA-style updates (Sutton & Barto, 2018) have been considered to query only in-distribution state-action pairs (Peng et al., 2019; Kostrikov et al., 2022). Nevertheless, by never invoking the Bellman optimality backup, they limit the policy from producing unseen actions. Our proposed MISA follows the conservative framework and directly regularizes the policy improvement direction to lie within the data manifold via mutual information, which exploits the dataset information more fully while learning a conservative policy.

Mutual Information Estimation. Mutual information is a fundamental quantity in information theory, statistics, and machine learning. However, direct computation of mutual information is intractable as it involves computing a log partition function of a high-dimensional variable.
Thus, how to estimate the mutual information I(X; Z) between random variables X and Z, accurately and efficiently, is a critical issue. One straightforward lower bound for mutual information estimation is the Barber-Agakov bound (Barber & Agakov, 2004), which introduces an additional variational distribution q(z|x) to approximate the unknown posterior p(z|x). Instead of using an explicit "decoder" q(z|x), we can use unnormalized distributions for the variational family (Donsker & Varadhan, 1975; Belghazi et al., 2018; Oord et al., 2018), i.e., approximate the distribution as q(z|x) = p(z) e^{f(x,z)} / E_{p(z)}[e^{f(x,z)}], where f(x,z) is an arbitrary critic function. As an example, InfoNCE (Oord et al., 2018) has been widely used in the representation learning literature (Oord et al., 2018; He et al., 2020; Chen et al., 2020). To further improve the mutual information estimation, a combination of normalized and unnormalized variational distribution families can be considered (Brekelmans et al., 2022; Poole et al., 2019). Our MISA connects mutual information estimation with RL by parameterizing a tractable lower bound with a policy network as the variational distribution and the Q values as critics. In this way, MISA explicitly regularizes the policy improvement direction to lie in the data manifold and produces strong empirical performance.
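As a concrete illustration of the Barber-Agakov bound mentioned above, I(X; Z) ≥ E_{p(x,z)}[log q(z|x)] + H(Z) can be checked by exact enumeration on a tiny discrete joint distribution. The distribution and decoders below are arbitrary toy choices, not from the paper:

```python
import math

# Toy joint distribution over (x, z) with x, z in {0, 1}.
p_joint = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.4}
p_x = {x: sum(v for (xx, z), v in p_joint.items() if xx == x) for x in (0, 1)}
p_z = {z: sum(v for (x, zz), v in p_joint.items() if zz == z) for z in (0, 1)}

def true_mi():
    """Exact I(X; Z) by enumeration."""
    return sum(v * math.log(v / (p_x[x] * p_z[z])) for (x, z), v in p_joint.items())

def ba_bound(q):
    """Barber-Agakov: I(X; Z) >= E_p(x,z)[log q(z|x)] + H(Z); q keyed by (z, x)."""
    h_z = -sum(v * math.log(v) for v in p_z.values())
    return sum(v * math.log(q[(z, x)]) for (x, z), v in p_joint.items()) + h_z

uniform_q = {(z, x): 0.5 for z in (0, 1) for x in (0, 1)}
# With q equal to the true posterior p(z|x), the bound is tight.
posterior_q = {(z, x): p_joint[(x, z)] / p_x[x] for z in (0, 1) for x in (0, 1)}
print(true_mi(), ba_bound(uniform_q), ba_bound(posterior_q))
```

With the uninformative uniform decoder the bound collapses to zero, while the true posterior recovers I(X; Z) exactly, which is what makes the choice of variational family matter.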

3. PRELIMINARIES

Reinforcement Learning We consider a Markov Decision Process (MDP) denoted as a tuple M = (S, A, p_0(s), p(s'|s,a), r(s,a), γ), where S is the state space, A is the action space, p_0(s) is the initial state distribution, p(s'|s,a) is the transition function, r(s,a) is the reward function, and γ is the discount factor. The goal of a learning agent is to find a policy π*(a|s) that maximizes the cumulative reward obtained by interacting with the environment:

π* = argmax_π E_π[Σ_{t=0}^∞ γ^t r(s_t, a_t) | s_0 ∼ p_0(s), a_t ∼ π(a|s_t)]. (1)

Q-learning refers to a family of off-policy RL algorithms that use the Bellman optimality operator B* Q(s,a) = r(s,a) + γ E_{s'∼p(s'|s,a)}[max_{a'} Q(s',a')] to learn a Q function. In contrast, the Bellman expectation operator B^π Q(s,a) = r(s,a) + γ E_{s'∼p(s'|s,a); a'∼π(·|s')}[Q(s',a')] gives an actor-critic framework that alternates between policy evaluation and policy improvement. Consider a value network Q_ϕ(s,a) parameterized by ϕ and a policy network π_θ(a|s) parameterized by θ. Let μ_π(s) denote the stationary state distribution induced by policy π, also called the occupancy measure (Schulman et al., 2015). Given the current policy, policy evaluation learns a Q network that accurately predicts its values by minimizing E_{μ_{π_θ}(s)π_θ(a|s)}[(Q_ϕ(s,a) - B^{π_θ} Q_ϕ(s,a))^2]. Policy improvement then seeks the policy that maximizes E_{μ_π(s)π(a|s)}[Q_ϕ(s,a)]. In practical implementations, the Bellman operator is replaced by its sample-based version B̂, and the expectation over μ_π(s)π(a|s) is approximated by an online replay buffer or an offline dataset D. Nevertheless, since the maximization over actions unavoidably queries OOD actions, over-estimated Q values are selected and the error accumulates over Bellman updates.
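The two sample-based Bellman backups above can be sketched in a minimal tabular setting, with Q stored as a dict keyed by (state, action). The setup and function names are illustrative, not from the paper:

```python
GAMMA = 0.99  # discount factor

def bellman_expectation_target(q, transition, policy, n_action_samples=10):
    """Sample-based Bellman expectation backup B^pi Q(s, a) for one transition.

    transition: (s, a, r, s_next); policy(s) draws one action a' ~ pi(.|s).
    """
    s, a, r, s_next = transition
    # Monte-Carlo estimate of E_{a' ~ pi(.|s')}[Q(s', a')]
    v_next = sum(q[(s_next, policy(s_next))] for _ in range(n_action_samples))
    return r + GAMMA * v_next / n_action_samples

def bellman_optimality_target(q, transition, actions):
    """Sample-based Bellman optimality backup B* Q(s, a): max over actions at s'.

    This max is exactly the step that can query OOD actions in the offline setting.
    """
    s, a, r, s_next = transition
    return r + GAMMA * max(q[(s_next, a_next)] for a_next in actions)
```

For example, with q = {(1, 0): 1.0, (1, 1): 3.0} and transition (0, 0, 0.5, 1), the optimality backup uses the (possibly over-estimated) maximum 3.0, while the expectation backup averages under the current policy.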
Conservative RL methods, in turn, perform "conservative" updates of the value / policy function during optimization by constraining the updates to in-distribution samples only, which minimizes the negative impact of OOD actions.

KL Divergence Given two probability distributions p(x) and q(x) on the same probability space, the KL divergence (i.e., relative entropy) from q to p is given by D_KL(p||q) = E_{p(x)}[log p(x)/q(x)] ≥ 0. The minimum value is achieved when the two densities are identical. We consider two dual representations that yield tractable estimators of the KL divergence.

Lemma 3.1 (f-divergence representation (Nowozin et al., 2016)). The KL divergence admits the following lower bound: D_KL(p||q) ≥ sup_{T∈F} E_{p(x)}[T(x)] - E_{q(x)}[e^{T(x)-1}], where the supremum is taken over a function family F satisfying the integrability constraints.

Lemma 3.2 (Donsker-Varadhan representation (Nguyen et al., 2010)). The KL divergence admits the lower bound: D_KL(p||q) ≥ sup_{T∈F} E_{p(x)}[T(x)] - log(E_{q(x)}[e^{T(x)}]), where the supremum is taken over a function family F satisfying the integrability constraints.

The above two bounds are tight for sufficiently large families F.
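As a quick sanity check of the Donsker-Varadhan representation, the bound can be evaluated with the optimal critic T(x) = log p(x) - log q(x) for two 1-D Gaussians, where the true KL is (μ_p - μ_q)^2 / (2σ^2). A toy Monte-Carlo sketch (assumed setup, not the paper's code):

```python
import math, random

random.seed(0)
MU_P, MU_Q, SIGMA = 1.0, 0.0, 1.0

def log_gauss(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def dv_bound(n=100_000):
    """Donsker-Varadhan estimate E_p[T] - log E_q[e^T] with the optimal critic."""
    t = lambda x: log_gauss(x, MU_P, SIGMA) - log_gauss(x, MU_Q, SIGMA)
    xs_p = [random.gauss(MU_P, SIGMA) for _ in range(n)]
    xs_q = [random.gauss(MU_Q, SIGMA) for _ in range(n)]
    term_p = sum(t(x) for x in xs_p) / n
    term_q = math.log(sum(math.exp(t(x)) for x in xs_q) / n)
    return term_p - term_q

true_kl = (MU_P - MU_Q) ** 2 / (2 * SIGMA ** 2)  # analytic KL = 0.5
print(dv_bound(), true_kl)
```

With the optimal critic, E_q[e^T] = E_q[p/q] = 1, so the estimate converges to the true KL as n grows; with any other critic it stays a lower bound, which is the property the MISA bounds in Sec. 4.2 inherit.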

4. MUTUAL INFORMATION REGULARIZED OFFLINE RL

In this paper, we propose to approach the offline RL problem from the perspective of mutual information and develop a novel framework (MISA) that estimates the Mutual Information between States and Actions of a given offline dataset. We show that MISA is a general framework that unifies multiple existing offline RL algorithms as special cases, including standard behavior cloning, TD3+BC (Fujimoto & Gu, 2021), and CQL (Kumar et al., 2020).

4.1. MUTUAL INFORMATION REGULARIZATION

Consider the state S and action A as two random variables. Let p_(S,A)(s,a) denote the joint distribution of state-action pairs, and let p_S(s) and p_A(a) be the marginal distributions; the subscripts are omitted in the following for simplicity. The mutual information between S and A is defined as:

I(S; A) = E_{p(s,a)}[log p(s,a) / (p(s)p(a))] = E_{p(s,a)}[log p(a|s) / p(a)] = H(A) - H(A|S), (4)

where H is the Shannon entropy and H(A|S) is the conditional entropy of A given S. Higher mutual information between S and A means lower uncertainty about A given S. This coincides with the observation that the actions selected by a well-performing agent are usually tightly coupled with certain states. Therefore, given a joint distribution of state-action pairs induced by a (sub-optimal) behavior agent, it is natural to learn a policy that recovers the dependence between states and actions produced by the behavior agent. By regularizing the agent with an estimate of I(S; A), we encourage it to 1) perform policy updates within the dataset distribution and 2) avoid being over-conservative and make sufficient use of the dataset information. Let π_β(a|s) denote the behavior policy and p_β(s,a) the joint distribution of state-action pairs it induces. Calculating the mutual information is often intractable since access to p_β(s,a) is infeasible. Fortunately, in offline reinforcement learning, a dataset D = {(s_t, a_t, r_t, s_{t+1})} of transitions drawn independently from p_β(s,a) is given. This dataset can thus be treated as a sample-based empirical joint distribution p_D(s,a) of p_β. Let I(θ, ϕ) denote a mutual information lower bound that relies on parameterized functions with parameters θ and ϕ, which in the context of RL are usually the policy network and the Q network. We defer the derivation of such bounds to Sec. 4.2.
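The definition I(S; A) = H(A) - H(A|S) can be made concrete with a plug-in estimate on a toy dataset of (s, a) pairs: a state-dependent behavior policy yields high mutual information, while a state-independent (random) one yields nearly zero. The datasets below are hypothetical illustrations, not the paper's setup:

```python
import math, random
from collections import Counter

def empirical_mi(pairs):
    """Plug-in estimate of I(S; A) = H(A) - H(A|S) from (s, a) pairs."""
    n = len(pairs)
    c_sa = Counter(pairs)
    c_s = Counter(s for s, _ in pairs)
    c_a = Counter(a for _, a in pairs)
    h = lambda c: -sum(v / n * math.log(v / n) for v in c.values())
    return h(c_a) - (h(c_sa) - h(c_s))  # H(A|S) = H(S,A) - H(S)

random.seed(0)
# A deterministic behavior policy: the action is fully determined by the state.
expert_data = [(s, s % 2) for s in range(1000)]
# A random behavior policy: actions are independent of states, so I(S; A) ~ 0.
random_data = [(s % 4, random.randint(0, 1)) for s in range(1000)]
print(empirical_mi(expert_data), empirical_mi(random_data))
```

This contrast also previews the limitation discussed in Sec. 5.2: on data from a random policy, the mutual information signal that MISA relies on is nearly absent.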
Based on the above motivation, we aim to learn a policy that approximates the mutual information of the dataset while being optimized for the best possible cumulative return. We focus on the actor-critic framework and formulate offline RL with mutual information regularization as follows:

min_ϕ E_{s,a,s'∼D}[1/2 (Q_ϕ(s,a) - B^{π_θ} Q_ϕ(s,a))^2] - α_1 Î_D(θ, ϕ), (Policy Evaluation) (5)

max_θ E_{s∼D, a∼π_θ(a|s)}[Q_ϕ(s,a)] + α_2 Î_D(θ, ϕ), (Policy Improvement) (6)

where α_1 and α_2 are coefficients balancing the RL objective and the mutual information objective, and Î_D(θ, ϕ) denotes the sample-based version of I(θ, ϕ) estimated from the dataset D.

4.2. STATE-ACTION MUTUAL INFORMATION ESTIMATION

In this section, we develop practical solutions to approximate the mutual information I(S; A) from samples of the joint distribution. Using the learning policy π_θ(a|s) as a variational distribution, Eqn. 4 can be rewritten as:

I(S; A) = E_{p(s,a)}[log π_θ(a|s) / p(a)] + E_{p(s)}[D_KL(p(a|s)||π_θ(a|s))] ≥ I_BA ≜ E_{p(s,a)}[log π_θ(a|s) / p(a)]. (7)

We have I(S; A) ≥ I_BA as the KL divergence is always non-negative. This is exactly the Barber-Agakov (BA) lower bound developed by Barber & Agakov (2004).

To obtain tighter bounds, we turn to the dual representations of the KL term in Eqn. 7. To this end, we choose F to be a set of parameterized functions T_ϕ : S × A → R, ϕ ∈ Φ, which can be seen as energy functions. Noting that E_{p(s)}[D_KL(p(a|s)||π_θ(a|s))] = D_KL(p(s,a)||p(s)π_θ(a|s)) and applying the f-divergence representation (Lemma 3.1), we derive MISA-f:

I_MISA-f ≜ E_{p(s,a)}[log π_θ(a|s) / p(a)] + E_{p(s,a)}[T_ϕ(s,a)] - E_{p(s)π_θ(a|s)}[e^{T_ϕ(s,a)-1}]. (8)

The I_MISA-f bound is tight when p(a|s) ∝ π_θ(a|s) e^{T_ϕ(s,a)-1}. Similarly, using the DV representation (Lemma 3.2), we obtain another bound I_MISA-DV ≤ I(S; A):

I_MISA-DV ≜ E_{p(s,a)}[log π_θ(a|s) / p(a)] + E_{p(s,a)}[T_ϕ(s,a)] - log E_{p(s)π_θ(a|s)}[e^{T_ϕ(s,a)}], (9)

which is tight when p(s,a) = (1/Z) p(s) π_θ(a|s) e^{T_ϕ(s,a)}, where Z = E_{p(s)π_θ(a|s)}[e^{T_ϕ(s,a)}]. Finally, we observe that the KL term in Eqn. 7 can be kept in its conditional form, D_KL(p(s,a)||p(s)π_θ(a|s)) = E_{p(s)}[D_KL(p(a|s)||π_θ(a|s))], and the DV representation applied per state to D_KL(p(a|s)||π_θ(a|s)). This yields a new lower bound I_MISA:

I_MISA ≜ E_{p(s,a)}[log π_θ(a|s) / p(a)] + E_{p(s,a)}[T_ϕ(s,a)] - E_{p(s)}[log E_{π_θ(a|s)}[e^{T_ϕ(s,a)}]]. (10)

The bound is tight when p(a|s) = (1/Z(s)) π_θ(a|s) e^{T_ϕ(s,a)}, where Z(s) = E_{π_θ(a|s)}[e^{T_ϕ(s,a)}].

Theorem 4.1. Given the joint distribution of state s and action a, the lower bounds of the mutual information I(S; A) defined in Eqns. 8-10 satisfy the following relations:

I(S; A) ≥ I_MISA ≥ I_MISA-DV ≥ I_MISA-f. (11)

The proof is deferred to the appendix due to the space limit.
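The ordering in Theorem 4.1 can be verified numerically by exact enumeration on a small discrete problem. The joint distribution, variational policy, and critic below are arbitrary (hypothetical) choices; the chain of inequalities holds for any such choice:

```python
import math

S, A = (0, 1), (0, 1)
p = {(0, 0): 0.35, (0, 1): 0.15, (1, 0): 0.1, (1, 1): 0.4}   # joint p(s, a)
pi = {(0, 0): 0.7, (0, 1): 0.3, (1, 0): 0.2, (1, 1): 0.8}    # pi(a|s), keyed (s, a)
T = {(0, 0): 0.5, (0, 1): -0.2, (1, 0): 0.1, (1, 1): 0.9}    # critic T(s, a)

p_s = {s: sum(p[(s, a)] for a in A) for s in S}
p_a = {a: sum(p[(s, a)] for s in S) for a in A}

i_true = sum(v * math.log(v / (p_s[s] * p_a[a])) for (s, a), v in p.items())
i_ba = sum(v * math.log(pi[(s, a)] / p_a[a]) for (s, a), v in p.items())
e_t = sum(v * T[sa] for sa, v in p.items())                  # E_p(s,a)[T]

# Per-state partition function Z(s) = E_{pi(a|s)}[e^{T(s,a)}]
z_s = {s: sum(pi[(s, a)] * math.exp(T[(s, a)]) for a in A) for s in S}
i_misa = i_ba + e_t - sum(p_s[s] * math.log(z_s[s]) for s in S)     # Eqn. 10
i_dv = i_ba + e_t - math.log(sum(p_s[s] * z_s[s] for s in S))       # Eqn. 9
i_f = i_ba + e_t - sum(p_s[s] * z_s[s] for s in S) / math.e         # Eqn. 8
print(i_true, i_misa, i_dv, i_f)
```

The printed values decrease from the true mutual information down to the loosest bound, matching I(S; A) ≥ I_MISA ≥ I_MISA-DV ≥ I_MISA-f.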
Algorithm 1 Mutual Information Regularized Offline RL
Input: Initialized Q network Q_ϕ, policy network π_θ, dataset D, hyperparameters α_1 and α_2.
for t ∈ {1, . . . , MAX_STEP} do
    Train the Q network by gradient descent with the objective J_Q(ϕ) in Eqn. 12: ϕ := ϕ - η_Q ∇_ϕ J_Q(ϕ)
    Improve the policy network by gradient ascent with the objective J_π(θ) in Eqn. 13: θ := θ + η_π ∇_θ (E_{s∼D, a∼π_θ(a|s)}[Q_ϕ(s,a)] + α_2 I_MISA)
end for
Output: the trained policy π_θ.

4.3. INTEGRATION WITH OFFLINE REINFORCEMENT LEARNING

We now describe how our MISA lower bound is integrated into the above framework to give a practical offline RL algorithm. We propose to use the Q network Q_ϕ(s,a) as the energy function T_ϕ(s,a) and p_D(s,a) as the joint distribution in Eqn. 10. This yields the following objective for learning the Q network during policy evaluation:

J_Q(ϕ) = J^B_Q(ϕ) - α_1 (E_{s,a∼D}[Q_ϕ(s,a)] - E_{s∼D}[log E_{π_θ(a|s)}[e^{Q_ϕ(s,a)}]]), (12)

where J^B_Q(ϕ) = E_{s,a,s'∼D}[1/2 (Q_ϕ(s,a) - B^{π_θ} Q_ϕ(s,a))^2] is the TD error. For policy improvement, note that the entropy term H(a) in Eqn. 10 can be omitted as it is a constant for a given dataset D. We thus have the following objective to maximize:

J_π(θ) = E_{s∼D, a∼π_θ(a|s)}[Q_ϕ(s,a)] + α_2 (E_{s,a∼D}[log π_θ(a|s)] - E_{s∼D}[log E_{π_θ(a|s)}[e^{Q_ϕ(s,a)}]]). (13)

The formulations for other regularizers (e.g., I_MISA-DV and I_MISA-f) can be derived similarly. A detailed description of the MISA algorithm for offline RL can be found in Algo. 1.

Intuitive Explanation on the Mutual Information Regularizer. By rearranging the terms in Eqn. 10, MISA can be written as:

I_MISA = E_{s,a∼D}[log (π_θ(a|s) e^{Q_ϕ(s,a)} / E_{π_θ(a'|s)}[e^{Q_ϕ(s,a')}])],

where the log term can be seen as the log probability of a one-step improved policy. More specifically, for policy improvement with KL-divergence regularization, max_π E_{s∼D, a∼π}[Q_ϕ(s,a)] - D_KL(π||π_θ), the optimal solution is given by π*_{θ,ϕ} ∝ π_θ(a|s) e^{Q_ϕ(s,a)} (Abdolmaleki et al., 2018; Peng et al., 2019). Therefore, I_MISA can be rewritten as I_MISA = E_{s,a∼D}[log π*_{θ,ϕ}(a|s)], and maximizing it means maximizing the log-likelihood of the dataset under the improved policy. In other words, instead of directly fitting the policy to the dataset, which is short-sighted, this objective considers the optimization direction of the policy improvement step.
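The identity I_MISA = E_{s,a∼D}[log π*_{θ,ϕ}(a|s)] can be checked directly in a discrete-action toy case, where the normalizing constant of the one-step improved policy is computed exactly. All numbers below are hypothetical:

```python
import math

A = (0, 1, 2)
pi = {0: [0.5, 0.3, 0.2], 1: [0.1, 0.6, 0.3]}    # pi(a|s), toy values
Q = {0: [1.0, -0.5, 0.2], 1: [0.3, 0.8, -1.0]}   # Q(s, a), toy values
data = [(0, 0), (0, 2), (1, 1)]                  # (s, a) pairs from D

def regularizer(s, a):
    """log pi(a|s) + Q(s,a) - log E_{a'~pi}[e^{Q(s,a')}], the per-sample MISA term."""
    z = sum(pi[s][ap] * math.exp(Q[s][ap]) for ap in A)
    return math.log(pi[s][a]) + Q[s][a] - math.log(z)

def log_improved_policy(s, a):
    """Log-probability of the one-step improved policy pi*(a|s) ∝ pi(a|s) e^{Q(s,a)}."""
    w = [pi[s][ap] * math.exp(Q[s][ap]) for ap in A]
    return math.log(w[a] / sum(w))

lhs = sum(regularizer(s, a) for s, a in data) / len(data)
rhs = sum(log_improved_policy(s, a) for s, a in data) / len(data)
print(lhs, rhs)  # identical up to floating point
```

Maximizing the left-hand side over θ is thus literally maximum-likelihood training of the improved policy on the dataset, which is the intuition stated above.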
Given the current policy and policy evaluation results, it first computes the analytic improved policy, and then maximizes the dataset likelihood under this improved policy. In this way, even if an out-of-distribution state-action pair gets an over-estimated Q value, I_MISA suppresses this value and ensures that in-distribution data have relatively higher value estimates.

Unbiased Gradient Estimation. For policy improvement with Eqn. 13, differentiating through a sampling distribution π_θ(a|s) is required for E_{s∼D}[log E_{π_θ(a|s)}[e^{Q_ϕ(s,a)}]]. For a Gaussian policy π_θ(a|s) = N(μ_θ, σ_θ), one could apply the reparameterization trick (Kingma & Welling, 2014) and convert the objective to E_{s∼D}[log E_{ϵ∼N(0,I)}[e^{Q_ϕ(s, μ_θ + ϵ·σ_θ)}]]. However, this introduces high variance in offline reinforcement learning setups because the policy improvement is conditioned directly on the Q values of out-of-distribution actions, which eventually gives a noisy policy. Hence, we aim to minimize the influence of the Q values during policy improvement. Differentiating Eqn. 10 with respect to the policy parameters θ, we have:

∂I_MISA / ∂θ = E_{s,a∼D}[∂ log π_θ(a|s) / ∂θ] - E_{s∼D, a∼p_{θ,ϕ}(a|s)}[∂ log π_θ(a|s) / ∂θ], (15)

where p_{θ,ϕ}(a|s) = π_θ(a|s) e^{Q_ϕ(s,a)} / E_{π_θ(a|s)}[e^{Q_ϕ(s,a)}] is a self-normalized distribution. See Appendix A.2 for the derivation. By optimizing Eqn. 15, we obtain an unbiased gradient estimate of the MISA objective with respect to the policy parameters, while minimizing the negative effects of the Q values of OOD actions. To sample from p_{θ,ϕ}(a|s), one can use Markov chain Monte Carlo (MCMC) methods, e.g., Hamiltonian Monte Carlo (Betancourt, 2017).
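For discrete actions, the two-term gradient above can be computed exactly (no MCMC needed) and checked against finite differences of the objective. This sketch uses a toy softmax policy and arbitrary Q values, not the paper's setup:

```python
import numpy as np

rng = np.random.default_rng(0)
n_s, n_a = 2, 3
theta = rng.normal(size=(n_s, n_a))   # softmax policy logits (toy)
Q = rng.normal(size=(n_s, n_a))       # fixed critic values (toy)
data = [(0, 1), (1, 2), (0, 0)]       # (s, a) pairs from the dataset

def pi(th):
    e = np.exp(th - th.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def objective(th):
    # theta-dependent part of I_MISA with T = Q (constants dropped):
    # E_D[log pi(a|s)] - E_{s~D}[log E_{pi(.|s)}[e^{Q(s,.)}]]
    p = pi(th)
    return np.mean([np.log(p[s, a]) - np.log(p[s] @ np.exp(Q[s])) for s, a in data])

def analytic_grad(th):
    # Two-term form of the gradient: E_D[grad log pi] - E_{a~p*}[grad log pi],
    # with p*(a|s) ∝ pi(a|s) e^{Q(s,a)} computed exactly instead of by MCMC.
    p = pi(th)
    g = np.zeros_like(th)
    for s, a in data:
        p_star = p[s] * np.exp(Q[s])
        p_star /= p_star.sum()
        g[s] += (np.eye(n_a)[a] - p[s]) - (p_star - p[s])
    return g / len(data)

# Central finite-difference check of the analytic gradient.
num = np.zeros_like(theta)
eps = 1e-6
for i in range(n_s):
    for j in range(n_a):
        d = np.zeros_like(theta)
        d[i, j] = eps
        num[i, j] = (objective(theta + d) - objective(theta - d)) / (2 * eps)
print(np.max(np.abs(num - analytic_grad(theta))))
```

The self-normalized distribution p* is tractable here only because the action space is enumerable; with continuous actions one falls back to MCMC sampling as described above.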

4.4. CONNECTIONS TO EXISTING OFFLINE RL METHODS

We show that several existing offline RL methods can be viewed as special cases of the MISA framework.

Behavior Cloning and BC Regularized RL We first show that behavior cloning is a form of mutual information regularization. As shown in Eqn. 7, I_BA ≜ E_{s,a∼D}[log π_θ(a|s) / p(a)] gives a lower bound of the mutual information. Since H(a) is a constant for a given dataset, maximizing I_BA is equivalent to maximizing E_{s,a∼D}[log π_θ(a|s)], which is exactly the behavior cloning objective. As for TD3+BC (Fujimoto & Gu, 2021), the policy evaluation is unchanged, while the policy improvement objective is augmented by an MSE regularization term, i.e., E_{s∼D}[Q(s, π_θ(s))] - λ E_{s,a∼D}[(π_θ(s) - a)^2], where λ is a hyperparameter. Maximizing the negative MSE term is equivalent to maximizing E_{s,a∼D}[log p_{π_θ}(a|s)], where p_{π_θ} = C e^{-1/2 (π_θ(s) - a)^2} is a Gaussian distribution and C is a constant. This is a special case of Eqn. 13 obtained by removing the last log-mean-exp term.

Conservative Q-Learning CQL (Kumar et al., 2020) was proposed to alleviate the over-estimation issue of Q learning by making conservative updates to the Q values during policy evaluation, while the policy improvement is kept unchanged compared to standard Q learning. We focus on the entropy-regularized policy evaluation of CQL:

min_ϕ J^B_Q(ϕ) - α_1 E_{s∼D}[E_{a∼π_D(a|s)}[Q_ϕ(s,a)] - log Σ_a e^{Q_ϕ(s,a)}], (16)

where the log-sum-exp term marks the main difference from our MISA policy evaluation (Eqn. 12). Let π_U(a|s) denote a uniform distribution over actions and |A| the number of actions. The log-sum-exp term can be written as log E_{a∼π_U(a|s)}[e^{Q_ϕ(s,a)}] + log |A|. Substituting it into Eqn. 16 and discarding the constant log |A|, we recover the formulation in Eqn. 12 with π_θ replaced by π_U. Therefore, CQL is in fact performing mutual information regularization during policy evaluation. The key difference is that it does not use the current policy network as the variational distribution.
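The log-sum-exp rewriting used in the CQL reduction above is a simple identity, log Σ_a e^{Q(s,a)} = log E_{a∼π_U}[e^{Q(s,a)}] + log |A|, which a few lines verify (the Q values below are arbitrary toy numbers):

```python
import math, random

random.seed(0)
q_values = [random.uniform(-2.0, 2.0) for _ in range(5)]  # Q(s, .) for |A| = 5

log_sum_exp = math.log(sum(math.exp(q) for q in q_values))
log_mean_exp = math.log(sum(math.exp(q) for q in q_values) / len(q_values))
# log-sum-exp = log-mean-exp under the uniform distribution + log|A|
assert abs(log_sum_exp - (log_mean_exp + math.log(len(q_values)))) < 1e-12
```

Since log |A| is constant in ϕ, penalizing the log-sum-exp is equivalent to penalizing the log-mean-exp under a uniform variational policy, which is the connection drawn above.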
Instead, a manually designed distribution is used in CQL. However, a uniform policy is usually sub-optimal in environments with continuous actions. CQL thus constructs a mixed variational policy from samples drawn from the current policy network, a uniform distribution, and the dataset. In our formulation, the variational distribution is optimized to give a better mutual information estimate, which may explain why MISA achieves better performance than CQL.

5. EXPERIMENTS

We perform extensive experiments on various tasks of the D4RL benchmark (Fu et al., 2020) to demonstrate the effectiveness of the proposed MISA. To provide a better understanding of MISA, we also present ablation studies, visualizations, and a discussion of its limitations.

5.1. OFFLINE REINFORCEMENT LEARNING ON D4RL BENCHMARKS

Experiment Setups. For all D4RL environments, we follow the network architectures of CQL (Kumar et al., 2020) and IQL (Kostrikov et al., 2022), where a neural network with 2 encoding layers of size 256 is used, followed by an output layer. We use the ELU activation function (Clevert et al., 2015) and SAC (Haarnoja et al., 2018) as the base RL algorithm. When approximating E_{π_θ(a|s)}[e^{T_ψ(s,a)}], we use 50 Monte-Carlo samples. In addition, for unbiased gradient estimation with MCMC samples, we use 5 burn-in steps. For all tasks, we average the mean returns over 10 evaluation trajectories and 5 random seeds. Detailed setups and hyper-parameters are given in the appendix.

Table 2: Ablation studies on gym-locomotion-v2. k denotes the number of Monte-Carlo samples for estimating E_{π_θ(a|s)}[e^{T_ψ(s,a)}], BI represents the burn-in steps for MCMC simulation, and BA denotes the use of the Barber-Agakov bound. In addition, MISA-x denotes the different variants of MISA.

Gym Locomotion Tasks. We first evaluate MISA on the standard MuJoCo-style continuous control tasks, reported as gym-locomotion-v2 in Table 1. We observe that MISA improves over the baselines by a large margin. In particular, MISA is less sensitive to the characteristics of the data distributions. The medium datasets contain trajectories collected by an SAC agent trained to reach 1/3 of the performance of an expert; the medium-replay datasets contain all samples in the replay buffer during the training of the medium SAC agent, covering its noisy exploration process. Prior methods are generally sensitive to the noisy sub-optimal data in the medium and medium-replay environments, while MISA outperforms them by a large margin. In particular, MISA achieves near-expert performance on walker2d-medium-replay with only sub-optimal trajectories.
This indicates that by regularizing the policy and Q-values with the mutual information of the dataset, we can fully exploit the data and perform safe and accurate policy improvement during RL. Moreover, on the medium-expert environments, where the datasets are mixtures of medium and expert trajectories, MISA successfully captures the multi-modality of the datasets and allows further improvement of the policy over the baselines.

Adroit Tasks. According to Fujimoto & Gu (2021), adroit tasks require strong policy regularization to overcome the extrapolation error, because the datasets are either generated by humans (adroit-human-v0), which yields a narrow policy distribution, or a mixture of human demonstrations and a behavior cloning policy (adroit-cloned-v0). We observe that MISA provides stronger regularization and significantly outperforms the baselines on the adroit domains.

Kitchen Tasks. An episode of the Kitchen environment consists of multiple sub-tasks that can be mixed in an arbitrary order. We observe that MISA outperforms the baselines on both kitchen-complete-v0 and kitchen-mixed-v0, while achieving slightly worse performance on kitchen-partial-v0. In particular, on kitchen-mixed, the result is consistent with our assumption that by better regularizing the policy, MISA guarantees a safer, in-distribution policy improvement step in offline RL.

Antmaze Tasks. On the challenging AntMaze domain with sparse, delayed rewards, MISA generally outperforms CQL and achieves the best performance on the umaze environments. However, MISA performs worse than IQL on the challenging large environments. Multi-step value updates are often necessary for learning a robust value estimation in these scenarios (Kostrikov et al., 2022), while MISA adopts a single-step SAC as the base RL algorithm.

5.2. ABLATION STUDIES

To better understand MISA, we conduct extensive ablation studies on each component (Table 2).

MISA requires careful Monte-Carlo approximation. First, we vary the number k of Monte-Carlo samples for approximating E_{π_θ(a|s)}[e^{T_ψ(s,a)}] and reduce the burn-in steps of the MCMC sampling process. Both operations introduce additional Monte-Carlo approximation errors into MISA. Comparing k = 5, k = 20, and MISA (k = 50), the performance increases monotonically; comparing MISA (BI = 5) with BI = 1, we observe a sharp performance drop. We therefore conclude that MISA requires careful Monte-Carlo approximation for good performance.

Accurately estimating the mutual information I(S; A) is critical for offline RL. In Table 2, BA stands for the Barber-Agakov bound, while "no BA" stands for removing the Barber-Agakov term in Eqn. 10, which gives an inaccurate estimate (neither an upper bound nor a lower bound) of I(S; A). We observe a clear performance drop when comparing these variants with MISA. In addition, as discussed in Sect. 4.2, considering the tightness of the various lower bounds, we have BA ≤ MISA-f ≤ MISA-DV ≤ MISA. Empirically, in Table 2, the overall performance of these four variants is consistent with the tightness of the bounds. This suggests that accurately estimating I(S; A) is crucial for offline RL and that tighter bounds often give better performance.

Unbiased gradient estimation improves the performance of MISA. Lastly, we study the importance of the unbiased estimation discussed in Sect. 4.3. MISA-biased ignores the bias-correction term in Eqn. 15. Although MISA-biased outperforms the baselines, it still performs worse than MISA. This suggests that by correcting the gradient estimate with additional MCMC samples, MISA achieves better-regularized policy learning in offline RL.

In Fig. 1, we visualize the embeddings before the output layer of the Q-value networks under different mutual information bounds (BA and MISA).
We select a subset of the walker2d-medium-v2 dataset to study the separation of low-reward (blue) and high-reward (red) (s, a) pairs, coloring each point by the reward r(s, a). As discussed in Sect. 4.2, BA gives the loosest bound for mutual information estimation while MISA gives the tightest. Fig. 1 shows a consistent result: the embeddings of BA converge to a set of regular curves and fail to cluster the high-reward pairs, because the Q values converge to indistinguishably high values (3 × 10^12) for all (s, a) pairs. In contrast, MISA successfully learns to group the high-reward (s, a) pairs into a cluster. From this perspective, we argue that regularizing with mutual information encourages learning robust representations in offline RL.

Although MISA achieves strong performance on several benchmarks, as reported in Table 1, we have assumed that the high mutual information comes from the stationary policy of a well-behaving agent. This prevents MISA from being applied to tasks with extremely low-quality data, e.g., data from a random policy whose I(S; A) is near zero. We validate this limitation on the locomotion-random-v2 datasets of the D4RL benchmark; the results are presented in Table 3. On datasets generated by random policies, MISA indeed achieves worse performance than both CQL and IQL.

6. CONCLUSIONS

We present the MISA framework for offline reinforcement learning, which directly regularizes policy improvement and policy evaluation with the mutual information between states and actions of the dataset. MISA connects mutual information estimation with RL by constructing tractable lower bounds that treat the learning policy as a variational distribution and the Q-values as energy functions. The resulting lower bound resembles a non-parametric energy-based distribution, which can be interpreted as the likelihood of a one-step improved policy under the current value estimation. In this way, MISA constrains the policy improvement direction within the dataset manifold. In our experiments, MISA significantly outperforms state-of-the-art methods on the D4RL benchmark. However, MISA assumes a high correspondence between states and actions, and might fail on uncorrelated data generated by a random policy. We leave this for future study.

A PROOFS AND DERIVATIONS

A.1 PROOF FOR THEOREM 4.1

We first show that $I_{\text{MISA}}$, $I_{\text{MISA-DV}}$ and $I_{\text{MISA-}f}$ are lower bounds of the mutual information $I(S; A)$.

Let $\mu_{\theta,\phi}(a|s) \triangleq \frac{1}{Z(s)}\pi_\theta(a|s)e^{T_\phi(s,a)}$, where $Z(s) = \mathbb{E}_{\pi_\theta(a|s)}[e^{T_\phi(s,a)}]$. $I_{\text{MISA}}$ can be written as
\begin{align*}
I_{\text{MISA}} &\triangleq \mathbb{E}_{p(s,a)}\Big[\log\tfrac{\pi_\theta(a|s)}{p(a)}\Big] + \mathbb{E}_{p(s,a)}[T_\phi(s,a)] - \mathbb{E}_{p(s)}\Big[\log\mathbb{E}_{\pi_\theta(a|s)}\big[e^{T_\phi(s,a)}\big]\Big] \\
&= \mathbb{E}_{p(s,a)}\Big[\log\tfrac{p(a|s)}{p(a)}\Big] - \mathbb{E}_{p(s,a)}[\log p(a|s)] + \mathbb{E}_{p(s,a)}[\log\pi_\theta(a|s)] + \mathbb{E}_{p(s,a)}[T_\phi(s,a)] - \mathbb{E}_{p(s)}[\log Z(s)] \\
&= I(S; A) - \mathbb{E}_{p(s)}\big[D_{\mathrm{KL}}\big(p(a|s)\,\|\,\mu_{\theta,\phi}(a|s)\big)\big] \le I(S; A).
\end{align*}
The inequality holds as the KL divergence is always non-negative.

Similarly, let $\mu_{\theta,\phi}(s,a) \triangleq \frac{1}{Z}p(s)\pi_\theta(a|s)e^{T_\phi(s,a)}$, where $Z = \mathbb{E}_{p(s)\pi_\theta(a|s)}[e^{T_\phi(s,a)}]$. $I_{\text{MISA-DV}}$ can be written as
\begin{align*}
I_{\text{MISA-DV}} &\triangleq \mathbb{E}_{p(s,a)}\Big[\log\tfrac{\pi_\theta(a|s)}{p(a)}\Big] + \mathbb{E}_{p(s,a)}[T_\phi(s,a)] - \log\mathbb{E}_{p(s)\pi_\theta(a|s)}\big[e^{T_\phi(s,a)}\big] \\
&= \mathbb{E}_{p(s,a)}\Big[\log\tfrac{p(a|s)}{p(a)}\Big] - \mathbb{E}_{p(s,a)}[\log p(a|s)] + \mathbb{E}_{p(s,a)}[\log\pi_\theta(a|s)] + \mathbb{E}_{p(s,a)}[T_\phi(s,a)] - \log Z \\
&= I(S; A) - D_{\mathrm{KL}}\big(p(s,a)\,\|\,\mu_{\theta,\phi}(s,a)\big) \le I(S; A).
\end{align*}
Again, the inequality holds as the KL divergence is always non-negative.

Consider the generalized KL divergence (Cichocki & Amari, 2010; Brekelmans et al., 2022) between two un-normalized distributions $p(x)$ and $q(x)$, defined by
$$D_{\mathrm{GKL}}(p(x)\,\|\,q(x)) = \int p(x)\log\frac{p(x)}{q(x)} - p(x) + q(x)\,dx,$$
which is always non-negative and reduces to the KL divergence when $p$ and $q$ are normalized. Let $\tilde{\mu}_{\theta,\phi}(a|s) \triangleq \pi_\theta(a|s)e^{T_\phi(s,a)-1}$ denote an un-normalized policy. We can rewrite $I_{\text{MISA-}f}$ as
\begin{align*}
I_{\text{MISA-}f} &\triangleq \mathbb{E}_{p(s,a)}\Big[\log\tfrac{\pi_\theta(a|s)}{p(a)}\Big] + \mathbb{E}_{p(s,a)}[T_\phi(s,a)] - \mathbb{E}_{p(s)\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \\
&= \mathbb{E}_{p(s,a)}\Big[\log\tfrac{p(a|s)}{p(a)}\Big] - \mathbb{E}_{p(s,a)}[\log p(a|s)] + \mathbb{E}_{p(s,a)}[\log\pi_\theta(a|s)] + \mathbb{E}_{p(s,a)}[T_\phi(s,a)-1] + 1 - \mathbb{E}_{p(s)\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \\
&= I(S; A) - \mathbb{E}_{p(s)}\big[D_{\mathrm{GKL}}\big(p(a|s)\,\|\,\tilde{\mu}_{\theta,\phi}(a|s)\big)\big] \le I(S; A).
\end{align*}
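The three lower bounds and their ordering can be checked numerically on a toy discrete joint distribution. The sketch below is our own construction (not part of the paper): it evaluates $I(S;A)$, $I_{\text{MISA}}$, $I_{\text{MISA-DV}}$, and $I_{\text{MISA-}f}$ in closed form for a random joint $p(s,a)$, a random variational policy $\pi_\theta$, and a random energy $T_\phi$.

```python
import numpy as np

rng = np.random.default_rng(1)
nS, nA = 4, 3

# random joint p(s, a) and its marginals
p_sa = rng.random((nS, nA)); p_sa /= p_sa.sum()
p_s = p_sa.sum(axis=1, keepdims=True)   # p(s), shape (nS, 1)
p_a = p_sa.sum(axis=0, keepdims=True)   # p(a), shape (1, nA)

# arbitrary variational policy pi(a|s) and energy T(s, a)
pi = rng.random((nS, nA)); pi /= pi.sum(axis=1, keepdims=True)
T = rng.normal(size=(nS, nA))

# exact mutual information of the toy joint
I_true = float(np.sum(p_sa * np.log(p_sa / (p_s * p_a))))

# shared Barber-Agakov-style term plus the energy term
base = float(np.sum(p_sa * np.log(pi / p_a)) + np.sum(p_sa * T))

# the three variational corrections (per-state log-partition,
# global log-partition, and the f-divergence linearization)
I_misa = base - float(np.sum(p_s[:, 0] * np.log(np.sum(pi * np.exp(T), axis=1))))
I_dv   = base - float(np.log(np.sum(p_s * pi * np.exp(T))))
I_f    = base - float(np.sum(p_s * pi * np.exp(T - 1.0)))

print(I_true, I_misa, I_dv, I_f)
```

For any choice of $\pi_\theta$ and $T_\phi$, the printed values obey $I(S;A) \ge I_{\text{MISA}} \ge I_{\text{MISA-DV}} \ge I_{\text{MISA-}f}$, matching Theorem 4.1.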

B ADDITIONAL EXPERIMENTAL DETAILS

For gym-locomotion-v2, kitchen-v0, and adroit-v0 environments, we average the results over 10 evaluation episodes and 5 random seeds. Following Kostrikov et al. (2022), we evaluate the antmaze-v0 environments for 100 episodes instead. To stabilize training in the antmaze-v0 environments, we follow Kumar et al. (2020) and normalize the reward by r' = (r - 0.5) * 4. In addition, for a fair comparison with baseline methods, we use the same network structure as CQL (Kumar et al., 2020): embedding layers of sizes (256, 256, 256) for the antmaze-v0 environments, and embedding layers of sizes (256, 256) for the other tasks. An ELU activation (Clevert et al., 2015) follows each layer. We use a learning rate of 1 × 10^-4 with a cosine learning rate scheduler for both the policy network and the Q-value network. To sample from the non-parametric distribution $p_{\theta,\phi}(a|s) = \pi_\theta(a|s)e^{Q_\phi(s,a)} / \mathbb{E}_{\pi_\theta(a|s)}[e^{Q_\phi(s,a)}]$, we use the Hamiltonian Monte Carlo (HMC) algorithm. As MCMC sampling is slow, we trade off accuracy for efficiency by choosing moderately small iteration configurations: we set the number of MCMC burn-in steps to 5, the number of leapfrog steps to 2, and the MCMC step size to 1.
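For concreteness, a minimal NumPy sketch of an HMC transition with the same knobs (burn-in steps, leapfrog steps, step size) is shown below. This is an illustrative sketch, not the paper's JAX implementation; the function name, signature, and defaults are our own assumptions, and the example targets a simple Gaussian "policy energy" rather than a learned $\log\pi_\theta + Q_\phi$.

```python
import numpy as np

def hmc_sample(log_prob, grad_log_prob, a0, burn_in=5, leapfrog=2, step=0.1, rng=None):
    """Draw one sample by running `burn_in` HMC transitions from a0."""
    rng = rng if rng is not None else np.random.default_rng(0)
    a = np.array(a0, dtype=float)
    for _ in range(burn_in):
        p = rng.normal(size=a.shape)                 # resample momentum
        a_new, p_new = a.copy(), p.copy()
        # leapfrog integration of the Hamiltonian dynamics
        p_new += 0.5 * step * grad_log_prob(a_new)   # half momentum step
        for i in range(leapfrog):
            a_new += step * p_new                    # full position step
            if i < leapfrog - 1:
                p_new += step * grad_log_prob(a_new)
        p_new += 0.5 * step * grad_log_prob(a_new)   # final half momentum step
        # Metropolis correction keeps the chain exact for the target
        log_accept = (log_prob(a_new) - 0.5 * p_new @ p_new) \
                   - (log_prob(a) - 0.5 * p @ p)
        if np.log(rng.random()) < log_accept:
            a = a_new
    return a

# example: one draw from an unnormalized standard-normal target
log_prob = lambda a: -0.5 * np.sum(a * a)
grad_log_prob = lambda a: -a
draw = hmc_sample(log_prob, grad_log_prob, np.zeros(2), burn_in=5, leapfrog=2, step=0.1)
print(draw)
```

With only 5 burn-in steps the chain is far from mixed, which is exactly the accuracy/efficiency trade-off discussed above; more burn-in steps bring the samples closer to the target distribution.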



Note some lower bounds might only have one parameterized function.



$$I(S; A) = \mathbb{E}_{p(s,a)}\Big[\log\frac{\pi_\theta(a|s)\,p(a|s)}{p(a)\,\pi_\theta(a|s)}\Big] = \mathbb{E}_{p(s,a)}\Big[\log\frac{\pi_\theta(a|s)}{p(a)}\Big] + D_{\mathrm{KL}}\big(p(s,a)\,\|\,p(s)\pi_\theta(a|s)\big), \quad (7)$$
where $p(s)\pi_\theta(a|s)$ is an induced joint distribution. Let $I_{\text{BA}} \triangleq \mathbb{E}_{p(s,a)}\big[\log\frac{\pi_\theta(a|s)}{p(a)}\big]$.

Figure 1: t-SNE of the Q-value network embeddings on the walker2d-medium-v2 dataset, where red denotes high reward and blue denotes low reward.


We have thus proven that $I_{\text{MISA}}$, $I_{\text{MISA-DV}}$ and $I_{\text{MISA-}f}$ are mutual information lower bounds. We now prove their relations, starting from the relation between $I_{\text{MISA}}$ and $I_{\text{MISA-DV}}$:
\begin{align}
I_{\text{MISA}} - I_{\text{MISA-DV}} &= D_{\mathrm{KL}}\big(p(s,a)\,\|\,\mu_{\theta,\phi}(s,a)\big) - \mathbb{E}_{p(s)}\big[D_{\mathrm{KL}}\big(p(a|s)\,\|\,\mu_{\theta,\phi}(a|s)\big)\big] \nonumber\\
&= \mathbb{E}_{p(s)}\mathbb{E}_{p(a|s)}\Big[\log\frac{p(s,a)}{p(a|s)} - \log\frac{\mu_{\theta,\phi}(s,a)}{\mu_{\theta,\phi}(a|s)}\Big] \nonumber\\
&= \mathbb{E}_{p(s)}\mathbb{E}_{p(a|s)}\Big[\log p(s) - \log\frac{1}{Z}p(s)Z(s)\Big] \nonumber\\
&= \mathbb{E}_{p(s)}\Big[\log p(s) - \log\frac{1}{Z}p(s)Z(s)\Big] \nonumber\\
&= D_{\mathrm{KL}}\Big(p(s)\,\Big\|\,\frac{1}{Z}p(s)Z(s)\Big) \ge 0, \tag{21}
\end{align}
where $\frac{1}{Z}p(s)Z(s)$ is a self-normalized distribution, as $Z = \mathbb{E}_{p(s)}[Z(s)]$. Therefore, we have $I_{\text{MISA}} \ge I_{\text{MISA-DV}}$.

Similarly, the relation between $I_{\text{MISA-DV}}$ and $I_{\text{MISA-}f}$ is given by:
\begin{align}
I_{\text{MISA-DV}} - I_{\text{MISA-}f} &= \mathbb{E}_{p(s)}\big[D_{\mathrm{GKL}}\big(p(a|s)\,\|\,\tilde{\mu}_{\theta,\phi}(a|s)\big)\big] - D_{\mathrm{KL}}\big(p(s,a)\,\|\,\mu_{\theta,\phi}(s,a)\big) \nonumber\\
&= \mathbb{E}_{p(s)}\mathbb{E}_{p(a|s)}\Big[\log\frac{p(a|s)}{p(s,a)} - \log\frac{\tilde{\mu}_{\theta,\phi}(a|s)}{\mu_{\theta,\phi}(s,a)}\Big] - 1 + \mathbb{E}_{p(s)}\mathbb{E}_{\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \nonumber\\
&= \mathbb{E}_{p(s)}\mathbb{E}_{p(a|s)}\Big[-\log p(s) - \log\frac{\tilde{\mu}_{\theta,\phi}(a|s)}{\mu_{\theta,\phi}(s,a)}\Big] - 1 + \mathbb{E}_{p(s)}\mathbb{E}_{\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \nonumber\\
&= \mathbb{E}_{p(s)}\mathbb{E}_{p(a|s)}\Big[\log\frac{\mu_{\theta,\phi}(s,a)}{p(s)\tilde{\mu}_{\theta,\phi}(a|s)}\Big] - \mathbb{E}_{\mu_{\theta,\phi}(s,a)}[1] + \mathbb{E}_{p(s)}\mathbb{E}_{\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \nonumber\\
&= \mathbb{E}_{p(s,a)}\Big[\log\frac{e}{Z}\Big] - \mathbb{E}_{\mu_{\theta,\phi}(s,a)}[1] + \mathbb{E}_{p(s)}\mathbb{E}_{\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \nonumber\\
&= \mathbb{E}_{\mu_{\theta,\phi}(s,a)}\Big[\log\frac{e}{Z}\Big] - \mathbb{E}_{\mu_{\theta,\phi}(s,a)}[1] + \mathbb{E}_{p(s)}\mathbb{E}_{\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \nonumber\\
&= \mathbb{E}_{\mu_{\theta,\phi}(s,a)}\Big[\log\frac{\mu_{\theta,\phi}(s,a)}{p(s)\tilde{\mu}_{\theta,\phi}(a|s)}\Big] - \mathbb{E}_{\mu_{\theta,\phi}(s,a)}[1] + \mathbb{E}_{p(s)}\mathbb{E}_{\pi_\theta(a|s)}\big[e^{T_\phi(s,a)-1}\big] \nonumber\\
&= D_{\mathrm{GKL}}\big(\mu_{\theta,\phi}(s,a)\,\|\,p(s)\tilde{\mu}_{\theta,\phi}(a|s)\big) \ge 0, \tag{22}
\end{align}
where $p(s)\tilde{\mu}_{\theta,\phi}(a|s)$ is an un-normalized joint distribution, whose integral is exactly the last term above. Therefore, we have $I(S; A) \ge I_{\text{MISA}} \ge I_{\text{MISA-DV}} \ge I_{\text{MISA-}f}$.

A.2 DERIVATION OF MISA GRADIENTS

We detail how the unbiased gradient in Sec. 4.3 is derived:
\begin{align}
\frac{\partial I_{\text{MISA}}}{\partial\theta} &= \mathbb{E}_{s,a\sim\mathcal{D}}\Big[\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}\Big] - \mathbb{E}_{s\sim\mathcal{D}}\Big[\frac{\partial\log\mathbb{E}_{\pi_\theta(a|s)}[e^{Q_\phi(s,a)}]}{\partial\theta}\Big] \nonumber\\
&= \mathbb{E}_{s,a\sim\mathcal{D}}\Big[\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}\Big] - \mathbb{E}_{s\sim\mathcal{D}}\Bigg[\frac{\mathbb{E}_{\pi_\theta(a|s)}\big[e^{Q_\phi(s,a)}\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}\big]}{\mathbb{E}_{\pi_\theta(a|s)}\big[e^{Q_\phi(s,a)}\big]}\Bigg] \tag{23}\\
&= \mathbb{E}_{s,a\sim\mathcal{D}}\Big[\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}\Big] - \mathbb{E}_{s\sim\mathcal{D},\,a\sim p_{\theta,\phi}(a|s)}\Big[\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}\Big], \tag{24}
\end{align}
where Eqn. 23 uses the log-derivative trick.
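The self-normalized form of the gradient can be verified numerically in a small discrete setting. The sketch below (our own check, not from the paper) parameterizes $\pi_\theta$ as a softmax over a single state's actions, computes the second term of the gradient as the score function averaged under $p_{\theta,\phi}(a|s) \propto \pi_\theta(a|s)e^{Q_\phi(s,a)}$, and compares it against a central finite difference of $\log\mathbb{E}_{\pi_\theta}[e^{Q}]$.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 5
theta = rng.normal(size=n)   # softmax logits of pi_theta(.|s)
Q = rng.normal(size=n)       # fixed Q-values for one state

def pi(th):
    e = np.exp(th - th.max())
    return e / e.sum()

def log_partition(th):
    # log E_{pi_theta(a|s)}[e^{Q(s,a)}]
    return np.log(np.sum(pi(th) * np.exp(Q)))

# self-normalized weights: p_{theta,phi}(a|s) proportional to pi(a|s) e^{Q}
p = pi(theta) * np.exp(Q)
p /= p.sum()

# softmax score: d log pi(a)/d theta_j = 1{a=j} - pi(j)
score = np.eye(n) - pi(theta)
grad_analytic = p @ score            # expectation of the score under p

# central finite differences of the log-partition term
eps = 1e-6
grad_fd = np.array([
    (log_partition(theta + eps * np.eye(n)[j]) -
     log_partition(theta - eps * np.eye(n)[j])) / (2 * eps)
    for j in range(n)
])
print(np.max(np.abs(grad_analytic - grad_fd)))
```

The agreement confirms that sampling actions from $p_{\theta,\phi}(a|s)$ (e.g., via MCMC) and averaging the score yields an unbiased estimate of the partition term's gradient, as in Eqn. 24.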

Table 1: Average normalized score on the D4RL benchmark. Results of baselines are taken directly from

Table 3: Results on locomotion-random-v2.

ETHICS STATEMENT

As discussed in the paper, the MISA framework provides a simple yet effective approach to policy learning from offline datasets. Although the results presented in this paper only consider simulated environments, given the generality of MISA, it could potentially be effective for learning real-robot policies in more complex environments. We should be cautious about misuse of the proposed method: depending on the specific application scenario, it might be harmful to privacy and safety.

REPRODUCIBILITY STATEMENT

In this paper, all experiments are averaged over 5 random seeds for stability and reliability of the results. Our code is attached in the supplementary materials, and we will clean and release it upon publication. For the practical implementation, we follow the CQL-Lagrange (Kumar et al., 2020) approach of constraining the Q-value update with a "budget" variable τ and rewrite Eqn. 12 as Eqn. 25. Eqn. 25 implies that if the expected Q-value difference is less than the threshold τ, γ1 will adjust to be close to 0; if the Q-value difference is higher than τ, γ1 will become larger and penalize the Q-values harder. We set τ = 10 for the antmaze-v0 environments and τ = 3 for the adroit-v0 and kitchen-v0 environments. For gym-locomotion-v2 tasks, we disable this mechanism and directly optimize Eqn. 12, because these tasks have relatively short horizons and dense rewards, so further constraining the Q-values is less necessary. Our code is implemented in JAX (Bradbury et al., 2018) with the Flax (Heek et al., 2020) neural network library.
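The budget mechanism above can be sketched as a dual (Lagrange-multiplier) update. The snippet below is a minimal illustration under our own assumptions (variable names, learning rate, and the softplus parameterization of the multiplier are hypothetical, not the paper's code): the multiplier grows while the Q-value gap exceeds the budget τ and decays toward 0 otherwise.

```python
import numpy as np

def softplus(x):
    # keeps the multiplier gamma_1 non-negative
    return np.logaddexp(0.0, x)

tau, lr = 3.0, 0.1

# case 1: Q-value gap persistently above the budget -> gamma grows
log_gamma = 0.0
for qdiff in [10.0] * 50:
    # dual ascent on the constraint E[Q diff] - tau
    log_gamma += lr * (qdiff - tau)
gamma_high = softplus(log_gamma)

# case 2: gap below the budget -> gamma decays toward 0
log_gamma = 0.0
for qdiff in [1.0] * 50:
    log_gamma += lr * (qdiff - tau)
gamma_low = softplus(log_gamma)

print(gamma_high, gamma_low)
```

This reproduces the qualitative behavior described above: the penalty is strong only while the conservatism constraint is violated, and it switches itself off once the Q-values are within budget.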

