LEARNING UNSUPERVISED FORWARD MODELS FROM OBJECT KEYPOINTS

Abstract

Object-centric representation is an essential abstraction for forward prediction. Most existing forward models learn this representation through extensive supervision (e.g., object class and bounding box), although such ground-truth information is not readily accessible in reality. To address this, we introduce KINet (Keypoint Interaction Network), an end-to-end unsupervised framework to reason about object interactions based on a keypoint representation. Using visual observations, our model learns to associate objects with keypoint coordinates and discovers a graph representation of the system as a set of keypoint embeddings and their relations. It then learns an action-conditioned forward model using contrastive estimation to predict future keypoint states. By learning to perform physical reasoning in the keypoint space, our model automatically generalizes to scenarios with a different number of objects, novel backgrounds, and unseen object geometries. Experiments demonstrate the effectiveness of our model in accurately performing forward prediction and learning plannable object-centric representations, which can also be used in downstream robotic manipulation tasks.

1. INTRODUCTION

Discovering a structured representation of the world allows humans to perform a wide repertoire of motor tasks such as interacting with objects. The core of this process is learning to predict the response of the environment to applying an action (Miall & Wolpert, 1996; Wolpert & Kawato, 1998). These internal models, often referred to as forward models, produce an estimate of the future state of the world given its current state and the action. By cascading the predictions of a forward model, it is also possible to plan a sequence of actions that would bring the world from an initial state to a desired goal state (Wolpert et al., 1998; 1995). Recently, deep learning architectures have been proposed to perform forward prediction using an object-centric representation of the system (Ye et al., 2020; Chen et al., 2021b; Li et al., 2020a; Qi et al., 2020). This representation is learned from the visual observation by factorizing the scene into the underlying object instances using ground-truth object states (e.g., object class, position, and bounding box). We identify two major limitations in the existing work. First, these methods either assume access to the ground-truth object states (Battaglia et al., 2016; Li et al., 2020a) or predict them using idealized techniques such as pre-trained object detection or instance segmentation models (Ye et al., 2020; Qi et al., 2020). However, obtaining ground-truth object states is not always feasible in practice. Relying on object detection and segmentation tools, on the other hand, makes the forward model fragile and dependent on the flawless performance of these tools, which is often unattainable in real-world settings. Second, factorizing the scene into a fixed number of object instances limits the model's generalization to scenarios with a different number of objects. In this paper, we address both of these limitations by proposing to learn forward models using a keypoint representation.
Keypoints represent a set of salient locations of moving entities. Our model KINet (Keypoint Interaction Network) learns an unsupervised forward model in three steps: (1) A keypoint extractor factorizes the scene into keypoints with no supervision other than raw visual observations. (2) A probabilistic graph representation of the system is learned where each node corresponds to a keypoint and edges encode keypoint relations. Node features carry implicit object-centric features as well as explicit keypoint state information. (3) With probabilistic message passing, our model learns an action-conditional forward model to predict the future locations of keypoints and reconstruct the future appearance of the scene. We evaluate KINet's forward prediction accuracy and demonstrate that, by learning forward prediction in keypoint coordinates, our model effectively re-purposes this knowledge and generalizes it to complex unseen circumstances. Our key contributions are: (1) We introduce KINet, an end-to-end method for learning unsupervised action-conditional forward models from visual observations. (2) We introduce probabilistic Interaction Networks for efficient message passing to aggregate relevant information. (3) We introduce GraphMPC for accurate action planning using graph similarity. (4) We demonstrate that learning forward models in keypoint coordinates enables zero-shot generalization to complex unseen scenarios.
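To make step (3) concrete, the sketch below shows one deterministic round of action-conditioned message passing over a fully connected keypoint graph. All dimensions, weight shapes, and function names are hypothetical illustrations (not from the paper), and random matrices stand in for the learned relation and node-update MLPs.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

# Hypothetical sizes: K keypoints, D-dim node features, A-dim action vector.
K, D, A = 4, 8, 2

# Random stand-ins for learned MLP weights.
W_rel = rng.normal(size=(2 * D, D)) * 0.1       # relation (edge) function
W_node = rng.normal(size=(2 * D + A, D)) * 0.1  # node update function

def forward_step(nodes, action):
    """One message-passing round over a fully connected keypoint graph."""
    k = nodes.shape[0]
    messages = np.zeros_like(nodes)
    for i in range(k):
        for j in range(k):
            if i == j:
                continue
            # Effect of keypoint j on keypoint i.
            messages[i] += relu(np.concatenate([nodes[i], nodes[j]]) @ W_rel)
    # Node update conditioned on aggregated messages and the external action.
    inp = np.concatenate([nodes, messages, np.tile(action, (k, 1))], axis=1)
    return relu(inp @ W_node)

nodes = rng.normal(size=(K, D))   # keypoint embeddings from the extractor
action = rng.normal(size=(A,))
next_nodes = forward_step(nodes, action)
print(next_nodes.shape)  # (4, 8)
```

In the paper's probabilistic variant the messages would be weighted by keypoint confidence; this sketch keeps only the deterministic skeleton.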

2. RELATED WORK

Unsupervised keypoint extraction. Keypoints have been used in computer vision areas such as pose tracking (Zhang et al., 2018; Yao et al., 2019) and video prediction (Minderer et al., 2019; Zhang et al., 2018; Xue et al., 2016; Manuelli et al., 2020). Recent work explored keypoints for control tasks in reinforcement learning by projecting the visual observation onto a lower-dimensional keypoint space (Kulkarni et al., 2019; Chen et al., 2021a; Jakab et al., 2018). Forward models. The work most closely related to ours is Interaction Networks (IN) (Battaglia et al., 2016; Sanchez-Gonzalez et al., 2018) and follow-up work using graph neural networks for forward simulation (Pfaff et al., 2020; Li et al., 2019; Kipf et al., 2018; Mrowca et al., 2018). These methods rely on the ground-truth states of objects to build explicit object representations. Several approaches extended IN by combining explicit object positional information with implicit visual features from images (Watters et al., 2017; Qi et al., 2020; Ye et al., 2020). However, two main concerns remain unaddressed. First, visual features are often extracted from object bounding boxes using object detection or segmentation models (Janner et al., 2018; Qi et al., 2020; Kipf et al., 2018) that are either pretrained on the setup (Qi et al., 2020) or use the ground-truth object position (Ye et al., 2020). Second, these approaches lack generalization as they are formulated on a fixed number of objects. Minderer et al. (2019) used keypoints for video prediction given a history of previous frames. However, their dynamics model is not conditioned on external actions and cannot be used for action planning applications. Action-conditional forward models. Battaglia et al. (2016) incorporate the action as an external effect appended to the node embeddings. Ye et al. (2020) include the action as an additional node in a graph fully connected with the other nodes representing objects.
For probabilistic forward models, Henaff et al. (2019) suggest a latent-variable dropout mechanism to properly condition the model on the action (Gal & Ghahramani, 2016). In an application more relevant to ours, Yan et al. (2020) highlighted the effectiveness of contrastive estimation (Oord et al., 2018) for learning actionable object-centric representations. Unsupervised forward models. Kipf et al. (2019) use an object-level contrastive loss to learn object-centric abstractions in multibody environments with deterministic structures and minimal visual features such as 2D shapes. In our work, we randomize the properties of the system and examine realistic objects. Veerapaneni et al. (2020) infer a set of entities in the image based on their depth and predict future entity states; entities are mixed using weights derived from their distance to the camera. Our work, on the other hand, makes no assumption about inferring depth and relies on keypoints. Kossen et al. (2019) use images to infer a set of explicit states (e.g., position and velocity) for a fixed number of objects in a dynamic system and predict the future state of each object using graph networks. Although it learns an unsupervised state representation, this method is formulated on a fixed number of objects and only tested on environments with simple 2D geometries. Li et al. (2020b) infer the causal structure and make future predictions on a fixed dynamic system using a keypoint extractor pretrained on top-view images. In our work, we experiment with other camera angles and 3D objects with random properties such as geometry and texture.
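For readers unfamiliar with the contrastive estimation mentioned above, the following is a minimal sketch of an InfoNCE-style loss (Oord et al., 2018) over a batch of predicted and true next-state embeddings. The matching rows serve as positive pairs and all other rows in the batch as negatives; the function name, temperature value, and shapes are illustrative assumptions, not the paper's exact formulation.

```python
import numpy as np

def info_nce(pred, target, temperature=0.1):
    """InfoNCE-style contrastive loss.

    pred, target: (B, D) arrays of predicted and true embeddings.
    Positives are matching rows; every other row acts as a negative.
    """
    pred = pred / np.linalg.norm(pred, axis=1, keepdims=True)
    target = target / np.linalg.norm(target, axis=1, keepdims=True)
    logits = pred @ target.T / temperature           # (B, B) similarities
    logits -= logits.max(axis=1, keepdims=True)      # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))               # positives on the diagonal

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 16))
aligned = info_nce(x, x)        # correct pairing -> low loss
shuffled = info_nce(x, x[::-1]) # wrong pairing -> higher loss
```

Minimizing this loss pulls each predicted next state toward its true counterpart while pushing it away from the other states in the batch, which is what makes the learned keypoint embeddings "actionable" without pixel-level reconstruction targets.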

3. KEYPOINT INTERACTION NETWORKS (KINET)

We assume access to observational data that consists of an RGB image, an action vector, and the resulting image after applying the action: D = {(I_t, u_t, I_{t+1})}. Our goal is to learn a forward model that predicts the future states of the system with no supervision beyond the observational data. We describe our approach in two main steps (see Figure 1): learning to encode visual observations into keypoints and learning an action-conditioned forward model in the keypoint space.
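The dataset structure described above can be sketched as follows. The image resolution, action dimensionality, and helper name are hypothetical placeholders (random arrays stand in for rendered observations); the point is only that each training example is a raw (image, action, next image) transition with no object labels.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_transition():
    """One unlabeled transition: observe, act, observe again."""
    I_t = rng.random((64, 64, 3))    # RGB observation before the action
    u_t = rng.random(2)              # applied action vector (e.g., a 2-D push)
    I_t1 = rng.random((64, 64, 3))   # RGB observation after the action
    return (I_t, u_t, I_t1)

# D = {(I_t, u_t, I_{t+1})}: the only supervision is the transitions themselves.
D = [make_transition() for _ in range(5)]
```

No object classes, positions, or bounding boxes appear anywhere in D; everything the model learns about objects must come from the keypoint extractor and the forward model.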

