WHEN DATA GEOMETRY MEETS DEEP FUNCTION: GENERALIZING OFFLINE REINFORCEMENT LEARNING

Abstract

In offline reinforcement learning (RL), a major obstacle to policy learning is the accumulation of deep Q-function errors in out-of-distribution (OOD) regions. Unfortunately, existing offline RL methods are often over-conservative, which inevitably hurts generalization beyond the data distribution. In our study, one interesting observation is that deep Q functions approximate well inside the convex hull of the training data. Inspired by this, we propose a new method, DOGE (Distance-sensitive Offline RL with better GEneralization). DOGE marries dataset geometry with deep function approximators in offline RL, enabling exploitation of generalizable OOD regions rather than strictly constraining the policy within the data distribution. Specifically, DOGE trains a state-conditioned distance function that can be readily plugged into standard actor-critic methods as a policy constraint. Simple yet elegant, our algorithm enjoys better generalization than state-of-the-art methods on D4RL benchmarks. Theoretical analysis demonstrates the superiority of our approach over existing methods that are based solely on data-distribution or support constraints.

1. INTRODUCTION

Offline reinforcement learning (RL) offers the possibility of learning optimized policies from large, pre-collected datasets without any environment interaction (Levine et al., 2020). This holds great promise for many real-world problems where online interaction is costly or dangerous yet historical data is readily accessible (Zhan et al., 2022). However, the optimization nature of RL, together with the need for counterfactual reasoning on unseen data in the offline setting, poses great technical challenges for designing effective offline RL algorithms. Evaluating the value function outside the data coverage can produce falsely optimistic values; without corrective feedback from online interaction, such estimation errors can accumulate quickly and misguide the policy learning process (Van Hasselt et al., 2018; Fujimoto et al., 2018; Kumar et al., 2019). Recent model-free offline RL methods address this error-accumulation challenge in several ways: 1) Policy Constraint: directly constraining the learned policy to stay inside the data distribution, or within the support of the dataset (Kumar et al., 2019); 2) Value Regularization: regularizing the value function to assign low values to out-of-distribution (OOD) actions (Kumar et al., 2020b); 3) In-sample Learning: learning the value function purely within data samples (Kostrikov et al., 2021b) or simply treating it as the value function of the behavior policy (Brandfonbrener et al., 2021). All three families of methods share the trait of being conservative and omitting evaluation on OOD data, which minimizes model exploitation error but comes at the expense of poor generalization of the learned policy in OOD regions. Thus, a substantial gap remains when such methods are applied to real-world tasks, where most datasets only partially cover the state-action space with suboptimal policies.
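As an illustration, the three families above can be sketched with the following simplified objectives (the notation is ours and deliberately schematic, not the cited papers' exact formulations; $\pi_\beta$ denotes the behavior policy, $\mathcal{D}$ the offline dataset, and $\mu$ an action-sampling distribution used for regularization):

```latex
% 1) Policy constraint: keep the learned policy close to the behavior policy
\max_\pi \; \mathbb{E}_{s \sim \mathcal{D}}\big[Q(s, \pi(s))\big]
\quad \text{s.t.} \quad D\big(\pi(\cdot \mid s),\, \pi_\beta(\cdot \mid s)\big) \le \epsilon

% 2) Value regularization: push down Q-values on out-of-distribution actions
\min_\theta \; \alpha \Big( \mathbb{E}_{s \sim \mathcal{D},\, a \sim \mu}\big[Q_\theta(s,a)\big]
  - \mathbb{E}_{(s,a) \sim \mathcal{D}}\big[Q_\theta(s,a)\big] \Big)
  + \mathcal{L}_{\mathrm{TD}}(\theta)

% 3) In-sample learning: fit value targets using only actions present in the dataset
\min_\theta \; \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}}
  \Big[ \big( r + \gamma V(s') - Q_\theta(s,a) \big)^2 \Big]
```

In all three cases, the regularization is defined purely by the data, with no term that accounts for where the function approximator itself generalizes well.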
Meanwhile, online deep reinforcement learning (DRL), which leverages powerful deep neural networks (DNNs) with optimistic exploration on unseen samples, can yield high-performing policies with promising generalization performance (Mnih et al., 2015; Silver et al., 2017; Degrave et al., 2022; Packer et al., 2018). This stark contrast propels us to re-think the question: Are we being too conservative? It is well known that DNNs have unparalleled approximation and generalization abilities compared with other function approximators. These attractive abilities have not only led to huge success in computer vision and natural language processing (He et al., 2016; Vaswani et al., 2017), but have also amplified the power of RL. Ideally, to obtain the best policy, an algorithm should enable offline policy learning on unseen state-action pairs where function approximators (e.g., the Q function or policy network) generalize well, and add penalization only in non-generalizable areas. However, existing offline RL methods devote too much conservatism to data-related regularizations, while largely overlooking the generalization ability of deep function approximators. Intuitively, consider the well-known AntMaze task in the D4RL benchmark (Fu et al., 2020), where an ant navigates from the start to the destination in a large maze. We observe that existing offline RL methods fail miserably when we remove only small areas of data on the critical pathways to the destination. As shown in Figure 1, the two missing areas reside in close proximity to the trajectory data. Simply "stitching" existing trajectories together is not sufficient to form a near-optimal policy in the missing regions. Exploiting the generalizability of deep function approximators, however, can potentially compensate for the missing information. In our study, we observe that the value function approximated by a DNN can interpolate well but struggles to extrapolate (see Section 2.2).
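The distinction between interpolation and extrapolation regions can be made concrete: a query point interpolates the data exactly when it lies in the convex hull of the training samples, which reduces to a linear-programming feasibility check. Below is a didactic sketch of that check (the helper name `in_convex_hull` is ours, not part of any cited method):

```python
import numpy as np
from scipy.optimize import linprog

def in_convex_hull(point, data):
    """Return True iff `point` lies in the convex hull of the rows of `data`.

    Feasibility LP: find lambda >= 0 with sum(lambda) = 1 and
    data.T @ lambda = point. The hull membership question is exactly
    whether this system has a solution, so we minimize a zero objective.
    """
    n = data.shape[0]
    A_eq = np.vstack([data.T, np.ones((1, n))])  # convex-combination constraints
    b_eq = np.append(point, 1.0)
    res = linprog(c=np.zeros(n), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
    return bool(res.success)

# Corners of the unit square as "training data".
data = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(in_convex_hull(np.array([0.5, 0.5]), data))  # interior point -> True
print(in_convex_hull(np.array([2.0, 2.0]), data))  # far outside    -> False
```

An exact hull check like this scales poorly with dataset size and dimension, which is one motivation for instead learning a cheap, differentiable distance-to-data proxy, as this paper proposes.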
Such an "interpolate well" phenomenon has also been observed in previous studies on the generalization of DNNs (Haley & Soloway, 1992; Barnard & Wessels, 1992; Arora et al., 2019a; Xu et al., 2020; Florence et al., 2022). This finding motivates us to reconsider the generalization of function approximators in offline RL in the context of dataset geometry. Along this line, we discover that a smaller distance from a sample to the offline dataset often leads to a smaller value-variation range of the learned neural network, which effectively yields more accurate inference of the value function inside the convex hull formed by the dataset. By contrast, outside the convex hull, and especially in areas far from the training data, the value-variation range usually becomes too large to guarantee a small approximation error. Inspired by this, we design a new algorithm, DOGE (Distance-sensitive Offline RL with better GEneralization), from the perspective of the generalization performance of the deep Q function. We first propose a state-conditioned distance function to characterize the geometry of offline datasets, whose output serves as a proxy for network generalization ability. The resulting algorithm uses the learned state-conditioned distance function as a policy constraint within a standard actor-critic RL framework. Theoretical analysis demonstrates a superior performance bound for our method compared to previous policy constraint methods based on data-distribution or support constraints. Evaluations on D4RL benchmarks validate that our algorithm enjoys better performance and generalization abilities than state-of-the-art offline RL methods.

2. DATA GEOMETRY VS. DEEP Q FUNCTIONS

2.1. NOTATIONS

We consider the standard continuous-action Markov decision process (MDP) setting, represented by a tuple $(\mathcal{S}, \mathcal{A}, P, r, \gamma)$, where $\mathcal{S}$ and $\mathcal{A}$ are the state and action spaces, $P(s'|s,a)$ is the transition dynamics, $r(s,a)$ is the reward function, and $\gamma \in [0,1)$ is the discount factor. The objective of RL is to find a policy $\pi(a|s)$ that maximizes the expected cumulative discounted return, which can be represented by a Q function $Q^\pi(s,a) = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t) \,\middle|\, s_0 = s,\ a_0 = a,\ a_t \sim \pi(\cdot|s_t),\ s_{t+1} \sim P(\cdot|s_t, a_t)\right]$. The Q function is typically approximated by a function approximator $Q_\theta$.
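To make the expectation inside the Q-function definition concrete, the discounted sum for a single sampled trajectory can be computed as follows (the function name is ours, for illustration only):

```python
import numpy as np

def discounted_return(rewards, gamma):
    """Compute sum_t gamma^t * r_t for one sampled trajectory."""
    discounts = gamma ** np.arange(len(rewards))  # [1, gamma, gamma^2, ...]
    return float(np.dot(discounts, rewards))

# With a constant reward r = 1 over a long horizon, the return approaches
# the geometric-series limit 1 / (1 - gamma).
print(discounted_return(np.ones(1000), 0.99))  # close to 100
```

A Monte-Carlo estimate of $Q^\pi(s,a)$ averages this quantity over many trajectories that start from $(s, a)$ and then follow $\pi$.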



Figure 1: Left: Visualization of the AntMaze dataset. Data transitions in two small areas on the critical pathways to the destination have been removed (red boxes). Right: Performance of three SOTA offline RL methods.

Code availability: https://github.com/Facebear-ljx/

