LEARNING UNSUPERVISED FORWARD MODELS FROM OBJECT KEYPOINTS

Abstract

Object-centric representation is an essential abstraction for forward prediction. Most existing forward models learn this representation through extensive supervision (e.g., object class and bounding box), although such ground-truth information is not readily accessible in reality. To address this, we introduce KINet (Keypoint Interaction Network), an end-to-end unsupervised framework to reason about object interactions based on a keypoint representation. Using visual observations, our model learns to associate objects with keypoint coordinates and discovers a graph representation of the system as a set of keypoint embeddings and their relations. It then learns an action-conditioned forward model using contrastive estimation to predict future keypoint states. By learning to perform physical reasoning in the keypoint space, our model automatically generalizes to scenarios with a different number of objects, novel backgrounds, and unseen object geometries. Experiments demonstrate the effectiveness of our model in accurately performing forward prediction and learning plannable object-centric representations that can also be used in downstream robotic manipulation tasks.

1. INTRODUCTION

Discovering a structured representation of the world allows humans to perform a wide repertoire of motor tasks such as interacting with objects. The core of this process is learning to predict the response of the environment to an applied action (Miall & Wolpert, 1996; Wolpert & Kawato, 1998). These internal models, often referred to as forward models, estimate the future state of the world given its current state and the action. By cascading the predictions of a forward model, it is also possible to plan a sequence of actions that would bring the world from an initial state to a desired goal state (Wolpert et al., 1998; 1995). Recently, deep learning architectures have been proposed to perform forward prediction using an object-centric representation of the system (Ye et al., 2020; Chen et al., 2021b; Li et al., 2020a; Qi et al., 2020). This representation is learned from visual observations by factorizing the scene into the underlying object instances using ground-truth object states (e.g., object class, position, and bounding box). We identify two major limitations in existing work. First, these methods either assume access to ground-truth object states (Battaglia et al., 2016; Li et al., 2020a) or predict them using idealized techniques such as pre-trained object detection or instance segmentation models (Ye et al., 2020; Qi et al., 2020). However, obtaining ground-truth object states is not always feasible in practice, and relying on object detection and segmentation tools makes the forward model fragile and dependent on the flawless performance of these tools, which is often unattainable in real-world settings. Second, factorizing the scene into a fixed number of object instances limits the generalization of these models to scenarios with a different number of objects. In this paper, we address both of these limitations by proposing to learn forward models using a keypoint representation.
Keypoints represent a set of salient locations of moving entities. Our model, KINet (Keypoint Interaction Network), learns an unsupervised forward model in three steps: (1) A keypoint extractor factorizes the scene into keypoints with no supervision other than raw visual observations. (2) A probabilistic graph representation of the system is learned, where each node corresponds to a keypoint and edges are keypoint relations. Node features carry implicit object-centric features as well as explicit keypoint state information. (3) With probabilistic message passing, our model learns an action-conditioned forward model to predict the future locations of keypoints.
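To make the graph-based forward model concrete, the sketch below shows one round of message passing over a fully connected keypoint graph. It is a minimal illustration, not the paper's implementation: the dimensions, the random-linear "MLPs" (`edge_fn`, `node_fn`), and the additive displacement prediction are all hypothetical stand-ins for learned components.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions: K keypoints, D-dim node embeddings, A-dim action.
K, D, A = 4, 8, 2

def mlp(in_dim, out_dim):
    """Toy stand-in for a learned MLP: a fixed random linear map + ReLU."""
    W = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda x: np.maximum(x @ W, 0.0)

edge_fn = mlp(2 * D, D)       # relation encoder over pairs of keypoint embeddings
node_fn = mlp(D + D + A, 2)   # predicts a (dx, dy) displacement per keypoint

def forward_model(coords, embeds, action):
    """One step of action-conditioned message passing on the keypoint graph."""
    # Aggregate messages to each keypoint from every other keypoint.
    messages = np.zeros((K, D))
    for i in range(K):
        for j in range(K):
            if i != j:
                messages[i] += edge_fn(np.concatenate([embeds[i], embeds[j]]))
    # Each node combines its embedding, incoming messages, and the action
    # to predict a displacement of its keypoint coordinates.
    deltas = np.stack([
        node_fn(np.concatenate([embeds[i], messages[i], action]))
        for i in range(K)
    ])
    return coords + deltas

coords = rng.uniform(size=(K, 2))   # explicit (x, y) keypoint coordinates
embeds = rng.normal(size=(K, D))    # implicit per-keypoint features
action = np.array([0.1, 0.0])       # e.g., a pushing action, for illustration

next_coords = forward_model(coords, embeds, action)  # shape (4, 2)
```

Cascading `forward_model` on its own predictions is what enables the multi-step rollouts used for planning; in the actual system the keypoints and network weights would be learned end-to-end with contrastive estimation rather than fixed at random.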

