LEARNING UNSUPERVISED FORWARD MODELS FROM OBJECT KEYPOINTS

Abstract

Object-centric representation is an essential abstraction for forward prediction. Most existing forward models learn this representation through extensive supervision (e.g., object class and bounding box), although such ground-truth information is not readily accessible in reality. To address this, we introduce KINet (Keypoint Interaction Network), an end-to-end unsupervised framework for reasoning about object interactions based on a keypoint representation. Using visual observations, our model learns to associate objects with keypoint coordinates and discovers a graph representation of the system as a set of keypoint embeddings and their relations. It then learns an action-conditioned forward model using contrastive estimation to predict future keypoint states. By learning to perform physical reasoning in the keypoint space, our model automatically generalizes to scenarios with a different number of objects, novel backgrounds, and unseen object geometries. Experiments demonstrate the effectiveness of our model in accurately performing forward prediction and in learning plannable object-centric representations that can also be used in downstream robotic manipulation tasks.

1. INTRODUCTION

Discovering a structured representation of the world allows humans to perform a wide repertoire of motor tasks such as interacting with objects. At the core of this process is learning to predict the response of the environment to an applied action (Miall & Wolpert, 1996; Wolpert & Kawato, 1998). These internal models, often referred to as forward models, estimate the future states of the world given its current state and the action. By cascading the predictions of a forward model, it is also possible to plan a sequence of actions that brings the world from an initial state to a desired goal state (Wolpert et al., 1998; 1995). Recently, deep learning architectures have been proposed to perform forward prediction using an object-centric representation of the system (Ye et al., 2020; Chen et al., 2021b; Li et al., 2020a; Qi et al., 2020). This representation is learned from visual observations by factorizing the scene into the underlying object instances using ground-truth object states (e.g., object class, position, and bounding box). We identify two major limitations in the existing work. First, these methods either assume access to the ground-truth object states (Battaglia et al., 2016; Li et al., 2020a) or predict them using idealized techniques such as pre-trained object detection or instance segmentation models (Ye et al., 2020; Qi et al., 2020). However, obtaining ground-truth object states is not always feasible in practice, and relying on object detection and segmentation tools makes the forward model fragile and dependent on the flawless performance of these tools, which is often unattainable in real-world settings. Second, factorizing the scene into a fixed number of object instances limits the generalization of such models to scenarios with a different number of objects. In this paper, we address both of these limitations by proposing to learn forward models using a keypoint representation.
Keypoints represent a set of salient locations of moving entities. Our model, KINet (Keypoint Interaction Network), learns an unsupervised forward model in three steps: (1) A keypoint extractor factorizes the scene into keypoints with no supervision other than raw visual observations. (2) A probabilistic graph representation of the system is learned, where each node corresponds to a keypoint and edges are keypoint relations. Node features carry implicit object-centric features as well as explicit keypoint state information. (3) With probabilistic message passing, our model learns an action-conditional forward model to predict the future location of keypoints and reconstruct the future appearance of the scene. We evaluate KINet's forward prediction accuracy and demonstrate that, by learning forward prediction in keypoint coordinates, our model effectively re-purposes this knowledge and generalizes it to complex unseen circumstances. Our key contributions are: (1) We introduce KINet, an end-to-end method for learning unsupervised action-conditional forward models from visual observations. (2) We introduce probabilistic Interaction Networks for efficient message passing that aggregates the most relevant information. (3) We introduce GraphMPC for accurate action planning using graph similarity. (4) We demonstrate that learning forward models in keypoint coordinates enables zero-shot generalization to complex unseen scenarios.

2. RELATED WORK

Unsupervised keypoint extraction. Keypoints have been used in computer vision areas such as pose tracking (Zhang et al., 2018; Yao et al., 2019) and video prediction (Minderer et al., 2019; Zhang et al., 2018; Xue et al., 2016; Manuelli et al., 2020). Recent work explored keypoints for control tasks in reinforcement learning, projecting the visual observation to a lower-dimensional keypoint space (Kulkarni et al., 2019; Chen et al., 2021a; Jakab et al., 2018). Forward models. The work most fundamentally relevant to ours is Interaction Networks (IN) (Battaglia et al., 2016; Sanchez-Gonzalez et al., 2018) and follow-up work using graph neural networks for forward simulation (Pfaff et al., 2020; Li et al., 2019; Kipf et al., 2018; Mrowca et al., 2018). These methods rely on the ground-truth states of objects to build explicit object representations. Several approaches extended IN by combining explicit object positional information with implicit visual features from images (Watters et al., 2017; Qi et al., 2020; Ye et al., 2020). However, two main concerns remain unaddressed. First, visual features are often extracted from object bounding boxes using object detection or segmentation models (Janner et al., 2018; Qi et al., 2020; Kipf et al., 2018) that are either pretrained on the setup (Qi et al., 2020) or use the ground-truth object position (Ye et al., 2020). Second, these approaches lack generalization as they are formulated on a fixed number of objects. Minderer et al. (2019) used keypoints for video prediction given a history of previous frames. However, their dynamics model is not conditioned on an external action and cannot be used for action planning. Action-conditional forward models. Battaglia et al. (2016) treat the action as an external effect appended to the node embeddings. Ye et al. (2020) included the action as an additional node in a fully connected graph whose other nodes represent objects. For probabilistic forward models, Henaff et al. (2019) suggest using a latent-variable dropout mechanism to properly condition the model on the action (Gal & Ghahramani, 2016). In an application closer to ours, Yan et al. (2020) highlighted the effectiveness of contrastive estimation (Oord et al., 2018) for learning actionable object-centric representations. Unsupervised forward models. Kipf et al. (2019) use an object-level contrastive loss to learn object-centric abstractions in multibody environments with deterministic structures and minimal visual features such as 2D shapes. In our work, we randomize the properties of the system and examine realistic objects. Veerapaneni et al. (2020) infer a set of entities in the image based on their depth and predict future entity states; entities are mixed using weights derived from their distance to the camera. Our work, by contrast, makes no assumption about depth and relies on keypoints. Kossen et al. (2019) use images to infer a set of explicit states (e.g., position and velocity) for a fixed number of objects in a dynamic system and predict the future state of each object using graph networks. Although it learns an unsupervised state representation, this method is formulated on a fixed number of objects and is only tested on environments with simple 2D geometries. Li et al. (2020b) infer the causal structure and make future predictions for a fixed dynamic system using a keypoint extractor pretrained on top-view images. In our work, we experiment with other camera angles and 3D objects with random properties such as geometry and texture.

3. KEYPOINT INTERACTION NETWORKS (KINET)

We assume access to observational data consisting of an RGB image, an action vector, and the resulting image after applying the action: $\mathcal{D} = \{(I_t, u_t, I_{t+1})\}$. Our goal is to learn a forward model that predicts the future states of the system with no supervision beyond the observational data. We describe our approach in two main steps (see Figure 1): learning to encode visual observations into keypoints, and learning an action-conditioned forward model in the keypoint space.

3.1. UNSUPERVISED KEYPOINT DETECTION

The keypoint detector ($f_{kp}$, Fig. 1) is a mapping from visual observations to a lower-dimensional set of $K$ keypoint coordinates $\{x_t^k\}_{k=1...K} = f_{kp}(I_t)$. The keypoint coordinates are learned by capturing the spatial appearance of the objects in an unsupervised manner. Specifically, the keypoint detector receives a pair of initial and current images $(I_0, I_t)$ and uses a convolutional encoder to compute a $K$-dimensional feature map for each image, $\Phi(I_0), \Phi(I_t) \in \mathbb{R}^{H' \times W' \times K}$. The expected number of keypoints is set by the dimension $K$. Next, each feature map is marginalized into 2D keypoint coordinates $\{x_0^k, x_t^k\}_{k=1...K} \in \mathbb{R}^2$. We use a convolutional image reconstruction model ($f_{rec}$, Fig. 1) with skip connections to inpaint the current image frame from the initial image and the predicted keypoint coordinates: $\hat{I}_t = f_{rec}(I_0, \{x_0^k, x_t^k\}_{k=1...K})$. With this formulation, $f_{kp}$ and $f_{rec}$ create a bottleneck that encodes the visual observation into a temporally consistent lower-dimensional set of keypoint coordinates (Kulkarni et al., 2019).
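The marginalization of feature maps into keypoint coordinates can be sketched with a standard spatial softmax, a common choice for unsupervised keypoint detectors; we assume that variant here. This is a minimal NumPy illustration, not the paper's implementation:

```python
import numpy as np

def spatial_softmax_keypoints(feature_maps):
    """Marginalize K feature maps of shape (H, W, K) into K 2D keypoints.

    Each channel is normalized with a softmax over all pixels, then the
    expected (x, y) location under that spatial distribution is taken as the
    keypoint coordinate. Coordinates are returned in [-1, 1] image space.
    """
    H, W, K = feature_maps.shape
    flat = feature_maps.reshape(H * W, K)
    probs = np.exp(flat - flat.max(axis=0))   # numerically stable softmax
    probs /= probs.sum(axis=0)
    probs = probs.reshape(H, W, K)
    ys = np.linspace(-1.0, 1.0, H)
    xs = np.linspace(-1.0, 1.0, W)
    # Expected coordinate under each channel's spatial distribution.
    x_coord = (probs.sum(axis=0) * xs[:, None]).sum(axis=0)  # (K,)
    y_coord = (probs.sum(axis=1) * ys[:, None]).sum(axis=0)  # (K,)
    return np.stack([x_coord, y_coord], axis=1)              # (K, 2)
```

A sharp activation peak in a channel thus maps to a keypoint at that pixel's normalized location, which is what makes the bottleneck differentiable end-to-end.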

3.2. GRAPH REPRESENTATION OF SYSTEM

After factorizing the system into $K$ keypoints, we build a graph $G_t = (V_t, E_t, Z)$ (undirected, with no self-loops) where keypoints and their pairwise relations form the graph nodes and edges. Keypoint positional and visual information is encoded into the embeddings of nodes $\{n_t^k\}_{k=1...K} \in V_t$ and edges $\{e_t^{ij}\} \in E_t$. We also use an adjacency matrix $Z \in \mathbb{R}^{K \times K}$ to specify the connectivity, where $[Z]_{ij} \in [0, 1]$ is the probability of the edge $\{e^{ij}\} \in E$. At timestep $t$, node embeddings encode keypoint visual and positional information, $\{n_t^k\} = [\Phi_t^k, x_t^k]$. Edge embeddings contain the relative positional information of each node pair, $\{e_t^{ij}\} = [x_t^i - x_t^j, \|x_t^i - x_t^j\|_2^2]$.
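The node and edge construction above can be sketched as follows. This is an illustrative NumPy snippet under our own naming (`build_graph`, `visual_feats`), not the paper's code:

```python
import numpy as np

def build_graph(keypoints, visual_feats):
    """Assemble node and edge embeddings for K keypoints.

    Node k concatenates its visual feature Phi_k with its coordinate x_k;
    edge (i, j) carries the relative displacement and its squared norm,
    mirroring the embeddings defined in the text.
    """
    K = keypoints.shape[0]
    nodes = np.concatenate([visual_feats, keypoints], axis=1)  # (K, F + 2)
    edges = {}
    for i in range(K):
        for j in range(K):
            if i == j:
                continue  # no self-loops
            rel = keypoints[i] - keypoints[j]
            edges[(i, j)] = np.concatenate([rel, [np.dot(rel, rel)]])
    return nodes, edges
```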

3.3. PROBABILISTIC INTERACTION NETWORKS

To build a forward model, we extend recent approaches and propose a probabilistic variation of Interaction Networks (IN) (Battaglia et al., 2016; Sanchez-Gonzalez et al., 2018). The core of the probabilistic IN is generating edge-level latent variables $z^{ij} \in \mathbb{R}^d$ that represent the edge probability $[Z]_{ij} = z^{ij}$. A posterior network $p_\theta$ infers the elements of the adjacency matrix given the graph representation of the scene. In particular, we model the posterior as $p_\theta(z^{ij} | G_t) = \sigma(f_{enc}([n_t^i, n_t^j]))$, where $\sigma(\cdot)$ is the sigmoid function. The probabilistic IN forward model $\hat{G}_{t+1} = f_{fwd}(G_t, u_t, Z)$ predicts the graph representation at the next timestep given the current graph representation and the action ($f_{fwd}$, Fig. 1). The message-passing operation in the forward model can be described as $\{\hat{e}^{ij}\} \leftarrow f_e(n^i, n^j, e^{ij})$, $\{\hat{n}^k\} \leftarrow f_n(n^k, \sum_{i \in N(k)} Z_{ik} \hat{e}^{ik}, u_t)$, where the edge-specific function $f_e$ first updates the edge embeddings, then the node-specific function $f_n$ updates each node embedding $\hat{n}^k$ by probabilistically aggregating the information of its neighboring nodes $N(k)$ (i.e., the edge probabilities from the inferred adjacency matrix $Z$ are used as weights in the neighbor aggregation). The action vector $u_t$ is also an input to the node update. Note that $f_{enc}$, $f_e$, and $f_n$ are MLPs. Recent models for forward prediction rely on fully connected graphs for message passing (Qi et al., 2020; Ye et al., 2020; Li et al., 2018). Our model, in contrast, learns to probabilistically sample the neighbor information. Intuitively, this adaptive sampling allows the network to efficiently aggregate the most relevant neighboring information, which is especially important in our model since keypoints in very close proximity can provide redundant information.
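As a concrete illustration of the update rules above, here is a minimal NumPy sketch of the posterior and one round of probabilistic message passing. It is not the paper's implementation: `w_enc`, `f_e`, and `f_n` are stand-ins for the learned MLPs (a single linear map and plain callables, our simplification):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def edge_posterior(nodes, w_enc):
    """Posterior over the adjacency matrix: [Z]_ij = sigmoid(f_enc([n_i, n_j])).
    w_enc replaces the learned encoder MLP with one linear map."""
    K = nodes.shape[0]
    Z = np.zeros((K, K))
    for i in range(K):
        for j in range(K):
            if i != j:
                Z[i, j] = sigmoid(np.concatenate([nodes[i], nodes[j]]) @ w_enc)
    return Z

def message_passing_step(nodes, edges, Z, action, f_e, f_n):
    """One round of probabilistic message passing (schematic).

    Edge update: e_hat_ij = f_e(n_i, n_j, e_ij). Node update: aggregate
    incoming messages weighted by the edge probabilities in Z, then append
    the action vector, mirroring the formulation in the text.
    """
    K = nodes.shape[0]
    e_hat = {(i, j): f_e(np.concatenate([nodes[i], nodes[j], e]))
             for (i, j), e in edges.items()}
    new_nodes = []
    for k in range(K):
        agg = sum(Z[i, k] * e_hat[(i, k)] for i in range(K) if i != k)
        new_nodes.append(f_n(np.concatenate([nodes[k], agg, action])))
    return np.stack(new_nodes), e_hat
```

Setting an entry of `Z` near zero effectively prunes that edge from the aggregation, which is how redundant nearby keypoints can be down-weighted.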

3.4. FORWARD PREDICTION

The state decoder ($f_{state}$, Fig. 1) transforms the predicted node embeddings of the updated graph $\hat{G}_{t+1}$ into a first-order difference that is integrated once to predict the position of the keypoints at the next timestep, $\{\hat{x}_{t+1}^k\} = \{x_t^k\} + f_{state}(\{\hat{n}_{t+1}^k\})$. To reconstruct the image at the next timestep, we reuse the reconstruction model $f_{rec}$ from the keypoint detector: $\hat{I}_{t+1} = f_{rec}(I_0, \{x_0^k, \hat{x}_{t+1}^k\})$.
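The first-order integration step can be written compactly; this is a schematic NumPy sketch in which `f_state` stands in for the learned decoder MLP:

```python
import numpy as np

def decode_next_keypoints(keypoints_t, node_embeddings, f_state):
    """Predict a positional difference per keypoint from its node embedding
    and integrate once to obtain the next-timestep coordinates."""
    delta = np.stack([f_state(n) for n in node_embeddings])  # (K, 2)
    return keypoints_t + delta
```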

3.5. LEARNING KINET

Reconstruction loss. The keypoint detector is trained using the distance between the ground-truth image and the reconstructed image at each timestep, $\mathcal{L}_{rec} = \|\hat{I}_t - I_t\|_2^2$. As suggested by Minderer et al. (2019), errors from the keypoint detector are not backpropagated to the other modules of the model. This is necessary to ensure the model does not conflate errors from the image modules with errors from the reasoning modules. Forward loss. The model is also optimized to predict the next state of the keypoints. A forward loss penalizes the distance between the future keypoint locations estimated by the first-order state decoder and the keypoint-extractor predictions: $\mathcal{L}_{fwd} = \sum_k \|\hat{x}_{t+1}^k - f_{kp}(I_{t+1})^k\|_2^2$. Inference loss. Our model is also trained to minimize the KL-divergence between the posterior and prior distributions: $\mathcal{L}_{infer} = D_{KL}(p_\theta(Z|G) \,\|\, p(Z))$. We use an independent Gaussian prior $p(Z) = \prod_i \mathcal{N}(z^i)$ and the reparameterization trick for training (Kingma & Welling, 2013). Contrastive loss. We use contrastive estimation to further enhance the learned graph representations. We add a contrastive loss (Oord et al., 2018; Yan et al., 2020) reframed for graph embeddings: $\mathcal{L}_{ctr} = -\mathbb{E}_{\mathcal{D}}[\log(S(\hat{G}_{t+1}, G_{t+1}^+) / S(\hat{G}_{t+1}, G_{t+1}^-))]$. The predicted graph representations $\hat{G}_{t+1}$ are pushed to be maximally similar to their corresponding positive sample pair $G_{t+1}^+ := G_{t+1}$ and maximally distant from the negative sample pairs $G_{t+1}^- := G_\tau, \forall \tau \neq t+1$. We use a simple node-embedding similarity as the graph matching function, $S(G_1, G_2) = \sum_k n_1^k \cdot n_2^k$. The motivation behind the contrastive loss is to align the graph representations of similar object configurations while pushing apart those of dissimilar configurations in the embedding space. Finally, the combined loss is $\mathcal{L} = \lambda_{rec}\mathcal{L}_{rec} + \lambda_{fwd}\mathcal{L}_{fwd} + \lambda_{infer}\mathcal{L}_{infer} + \lambda_{ctr}\mathcal{L}_{ctr}$.
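The contrastive term can be sketched in NumPy as below. Note one assumption on our part: we implement the ratio of similarities in the common InfoNCE style, with exponentiated similarities and the positive included in the denominator, rather than claiming this is the paper's exact estimator:

```python
import numpy as np

def graph_similarity(nodes_a, nodes_b):
    """S(G1, G2): sum over keypoints of node-embedding dot products."""
    return float(np.sum(nodes_a * nodes_b))

def contrastive_graph_loss(pred_nodes, pos_nodes, neg_nodes_list):
    """InfoNCE-style contrastive loss on graph embeddings (our variant of
    the estimator described in the text): pull the predicted graph toward
    its positive pair and away from the negatives."""
    pos = np.exp(graph_similarity(pred_nodes, pos_nodes))
    negs = sum(np.exp(graph_similarity(pred_nodes, n)) for n in neg_nodes_list)
    return -np.log(pos / (pos + negs))
```

When the predicted graph matches the positive pair and is far from the negatives, the loss approaches zero; a prediction closer to a negative drives the loss up.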

3.6. GRAPHMPC PLANNING WITH KINET

We use a learned KINet model and select actions with a Model Predictive Control (MPC) algorithm (Finn & Levine, 2017) in the graph embedding space (GraphMPC). Using KINet forward prediction, we compute the predicted graph representation of the next timestep for multiple sampled actions. We then select the optimal action that produces the graph representation most similar to the goal graph representation $G_{goal}$. Our GraphMPC objective with time horizon $T$ is $u_t^* = \arg\max_u S(G_{goal}, f_{fwd}(G_t, \{u_{t:T}\})), \; t \in [0, T]$. Unlike conventional MPC performed only on positional states, GraphMPC brings the system to a goal state both explicitly (i.e., position) and implicitly (i.e., pose, orientation, and visual appearance).
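One GraphMPC step can be sketched as a random-shooting search over sampled actions; this is an illustrative NumPy version under our own names (`forward_fn`, `action_sampler`), not the paper's planner:

```python
import numpy as np

def graph_mpc_step(nodes_t, goal_nodes, forward_fn, action_sampler,
                   n_samples=64):
    """One GraphMPC step (schematic random-shooting variant).

    Sample candidate actions, roll each through the forward model
    forward_fn(nodes, action) -> predicted node embeddings, and pick the
    action whose prediction is most similar to the goal graph embedding.
    """
    best_action, best_score = None, -np.inf
    for _ in range(n_samples):
        u = action_sampler()
        pred = forward_fn(nodes_t, u)
        score = float(np.sum(pred * goal_nodes))  # S(G_goal, G_hat)
        if score > best_score:
            best_action, best_score = u, score
    return best_action
```

Repeating this selection at every timestep over the horizon yields the receding-horizon behavior described above.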

4. EXPERIMENTAL SETUPS

Our experiments seek to address the following: (1) How accurate is our forward model? (2) Can we use the model for action planning? and (3) Does the model generalize to unseen circumstances?

4.1. DATASET

We apply our approach to learn a forward model for multi-object manipulation tasks. The task involves rearranging multiple objects in the scene and bringing them to a desired goal state using pushing actions. In the RealBlocks dataset, we use the Sawyer robot pushing dataset of Ye et al. (2020) to exemplify how our proposed framework applies to real settings. RealBlocks includes 5K random pushing actions on 7 blocks. For each action, an RGB image pair with the action vector in image coordinates is captured. We use MuJoCo 2.0 (Todorov et al., 2012) to generate two sets of simulated scenarios. In the BlockObjs dataset, we simulate 10K episodes of random pushing on multiple objects (1-5 objects), where a simple robot end-effector applies randomized pushes for 90 timesteps per episode (Figure A.1.2). Each object has randomized geometry and color. We sample the object geometry from a predefined continuous range denoted geom_train for training and geom_gen for generalization to unseen geometries. Unseen geometries are designed to have elongated shapes to create complex out-of-distribution cases (Figure A.1.1). In the YCBObjs dataset, we simulate 10K episodes (60 timesteps) of randomly pushing a subset of YCB objects placed on a wooden table (1-5 objects) (Calli et al., 2015), which includes everyday objects with diverse properties such as shape, size, and texture (Figure A.1.3). We collect the 4-dimensional action vector (pushing start and end location) and RGB images before and after each action is applied. Images are obtained using an overhead camera (Top View) and an angled camera (Angled View) (see Appendix A.1 for more details).

4.2. BASELINES

We compare our approach with existing methods for learning object-centric forward models. Forward Model (Forw): We train a convolutional encoder to extract visual features of the scene image (Img) and learn a forward model in the feature space. Visual Interaction Network (VisIN): Following Ye et al. (2020) and Watters et al. (2017), we train a convolutional encoder to extract visual features of fixed-size bounding boxes centered on ground-truth object locations (GT state + Img). We use the extracted visual features as object representations in the Interaction Network. This approach requires prior knowledge of the number of objects. Causal Discovery from Videos (V-CDN): Li et al. (2020b) infers the causal structure of a fixed physical system from visual observations and makes future predictions for that system. A pretrained perception module extracts keypoints, an inference module predicts a graph representation for the visual observation, and an IN-based dynamics module predicts the future location of the keypoints.

4.3. TRAINING AND EVALUATION SETTINGS

All models are trained on a subset of the simulated data (BlockObjs and YCBObjs) with 3 objects (8K episodes: 80% training, 10% validation, and 10% testing sets). To evaluate generalization to a different number of objects, we use other subsets of data with 1, 2, 4, and 5 objects (∼ 400 episodes for each case). We train our model separately on images obtained from the overhead camera (Top View) and the angled camera (Angled View). For RealBlocks data, we only provide qualitative results as the ground-truth location of the objects is unknown. We set the expected number of keypoints K = 6 for BlockObjs, K = 9 for YCBObjs, and K = 14 for RealBlocks.

5. RESULTS

This section is organized to answer a series of questions to thoroughly evaluate our model and justify the choices we made in formulating our approach.

5.1. DOES THE MODEL ACCURATELY LEARN A FORWARD MODEL?

First, we evaluate the forward prediction accuracy. Figure 2 showcases qualitative results of forward predictions on RealBlocks. Our model factorizes the observation into a keypoint representation and accurately estimates the future appearance of the scene conditioned on the external action. The qualitative results highlight that our model learns the effect of the action on objects as well as object-to-object interactions (see Appendix A.4 for more examples).

Figure 2: Qualitative results on the RealBlocks dataset (Ye et al., 2020). Given an image $I_t$, our model factorizes the scene into keypoints $x_t^k$ and, conditioned on the action $u_t$ (green arrows), estimates the next keypoint coordinates $\hat{x}_{t+1}^k$ and appearance $\hat{I}_{t+1}$ of the scene.

We quantify the effectiveness of our model in comparison with the Forw, ForwInv, IN, and VisIN baselines (Table 1) on BlockObjs data. We separately train and examine each model on Top View and Angled View images. The prediction error is computed as the average distance between the predicted and ground-truth positional states. VisIN performs best among the baseline models as it builds object representations with explicit ground-truth object positions and their visual features. Our model, on the other hand, achieves comparable performance to VisIN while not relying on any supervision beyond the scene images. The Forw and ForwInv baselines have similar supervision to ours but are significantly less accurate. This emphasizes the capability of our approach in learning a rich graph representation of the scene and an accurate forward model while relaxing the prevailing assumptions of prior work on the structure of the environment and the availability of ground-truth state information (see Appendix A.2 for the learned graph representations).
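The stated error metric, the average distance between predicted and ground-truth positional states, can be written as a one-liner; this is our re-implementation of the metric as described, not code from the paper:

```python
import numpy as np

def prediction_error(pred_positions, gt_positions):
    """Average Euclidean distance between predicted and ground-truth object
    positions, i.e., the single-step prediction error reported in Table 1."""
    return float(np.mean(
        np.linalg.norm(pred_positions - gt_positions, axis=-1)))
```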

5.2. CAN WE USE THE MODEL IN CONTROL TASKS?

We design a robotic manipulation task of rearranging a set of objects to a desired goal state using MPC with pushing actions. For all models, we run 1K episodes with randomized object geometries, initial poses, and a random goal configuration of objects. The planning horizon is set to T = 40 timesteps in each episode. For our model, we perform GraphMPC based on graph embedding similarity as described in Section 3.6. For all baseline models, we perform MPC directly on the distance to the goal. Figure 3a shows MPC results on BlockObjs based on Top View observations. Our approach is consistently faster than the baseline models in reaching the goal configuration.

5.3. DOES THE MODEL GENERALIZE TO UNSEEN CIRCUMSTANCES?

One of our main motivations for learning a forward model in keypoint space is to remove the model's dependence on the number of objects in the system, which allows for generalization to an unseen number of objects, unseen object geometries, novel background textures, etc. BlockObjs. We train KINet on 3 randomized blocks with geometries from geom_train and test for zero-shot generalization to an unseen number of objects (1, 2, 4, 5), unseen object geometries (geom_gen), and unseen background textures. Figure 4 shows qualitative generalization results. We separately train and examine for generalization on Top View and Angled View observations with a planning horizon of T = 80. Since our model learns to perform forward modeling in the keypoint space, it reassigns the keypoints to unseen objects and makes forward predictions. Table 2 summarizes the generalization performance. As expected, increasing the number of objects increases the average distance to the goal position; objects with out-of-distribution geometries also end up farther from the goal position. Our model's generalization performance is significantly superior to the Forw baseline, a simple action-conditioned video prediction model that does not depend on the number of objects and uses the same supervision as our model. Further, we test the performance of the model on unseen background textures (i.e., the table texture). Since the keypoint extraction relies on the visual features of salient objects, our model is able to perform the control tasks by ignoring the background and assigning keypoints to the moving objects. YCBObjs. We evaluate the generalization of our model on a set of realistic YCB objects (Calli et al., 2015) with challenging, diverse properties such as color, texture, mass, and geometry. We train our model on a random subset of 3 YCB objects and test for generalization to an unseen number of objects (1, 2, 4, 5) with a planning horizon of T = 40.
As shown in Figure 5, our method generalizes well to an unseen number of objects and performs the control task accurately. Importantly, assigning multiple keypoints to each object allows our framework to implicitly capture the orientation of each object as well as its position, without any supervision on the object pose (e.g., compare the power drill poses in Fig. 5). We compare the performance of our framework with the V-CDN baseline (Li et al., 2020b), which is also a keypoint-based model that learns the structure of physical systems, performs future predictions, and can potentially generalize to an unseen number of objects. Although V-CDN is formulated to extract the causal structure of a fixed system from visual observations, we pretrain its perception module on our randomized YCBObjs dataset for a direct comparison to our model. Also, to perform a control task, we condition V-CDN on the action by adding an encoding of the action vector to each keypoint embedding. Figure 6 compares the MPC results (normalized by the number of objects) of our method and V-CDN, both trained on a subset of 3 YCB objects. The MPC performance of the two models is comparable when rearranging 3 objects (green lines). However, V-CDN is most accurate for the number of objects it has been trained on and significantly less accurate when generalizing to an unseen number of objects. We attribute this to two reasons: (1) Although it uses keypoints, V-CDN is formulated to infer an explicit causal structure and attributes for the graph representation of a fixed system setup, which does not necessarily carry over to unseen circumstances. (2) Unlike our method, V-CDN does not take visual features into account and only uses the keypoint positions.

5.4. ANALYSIS AND ABLATION

We justify the major choices in our model formulation with ablation studies. We examine two elements of our approach: the probabilistic graph representation and the contrastive loss. We train two variants of KINet: (1) KINet-deterministic, with a fully connected graph instead of a probabilistic one, and (2) KINet-no-ctr-loss, without the contrastive loss. The best forward prediction for both Top View and Angled View images is achieved when the model is probabilistic and trained with a contrastive loss (Table 1). The contrastive loss is an essential element of our approach to ensure the learned forward model is accurately action-conditional. Also, with a probabilistic graph representation, our model achieves better generalization than the deterministic variant. This performance gap is more evident when generalizing to unseen geometries (Table 2).

6. CONCLUSION

In this paper, we proposed a method for learning action-conditioned forward models based only on image observations. We showed that our approach effectively makes forward predictions through keypoint factorization. We also demonstrated that a keypoint-based forward model, unlike prior work, makes no assumptions about the number of objects, which allows for automatic generalization to a variety of unseen circumstances. Importantly, our model learns a forward model without explicit supervision from ground-truth object state information. One limitation of our formulation is fixing the number of expected keypoints (see Appendix A.3); however, we showed this affords more generalizability than fixing the number of objects. We also observed inconsistency in the predicted keypoints in real-robot scenarios where most of the objects were pushed out of the image frame (see Appendix A.6). An interesting future direction is to improve the keypoint extraction module to further enhance forward models for real settings. Finally, we hope our general approach inspires future research on physical reasoning in settings where ground-truth information is hard to obtain.



Figure 1: KINet performs forward modeling in three major steps: extracting keypoints, inferring a probabilistic graph representation, and estimating the future appearance conditioned on the action. Some arrows are simplified for clarity.

Forward-Inverse Model (ForwInv): We train a convolutional encoder to extract visual features of the scene image (Img) and jointly learn forward and inverse models (Agrawal et al., 2016). Interaction Network (IN): We follow Battaglia et al. (2016) and Sanchez-Gonzalez et al. (2018) to build an Interaction Network based on the ground-truth location of the objects. Each object representation contains the ground-truth position and velocity of the object (GT state). This approach is only applicable to scenarios where the number of objects in the scene is known and fixed.

Figure 3: MPC results on BlockObjs measured as the distance to the goal configuration. (a) Comparison with baselines. (b) MPC for KINet trained on a fixed white background (fixed_train), generalization to random backgrounds (rand_gen), and KINet trained on random backgrounds (rand_train).

Figure 3b compares the MPC results for KINet trained on a fixed white background (fixed_train), zero-shot generalization of KINet trained on the fixed background to randomized backgrounds (rand_gen), and KINet trained directly on randomized backgrounds (rand_train). As expected, although the MPC converges, the final distance to the goal configuration is larger for rand_gen. The final error is statistically at the same level of accuracy for fixed_train and rand_train. For qualitative examples, see Figure 4 (third row) and Figure A.5.2.

Figure 4: Qualitative results of generalization on BlockObjs to an unseen number of objects, geometries (geom_gen), and background textures (domain_gen). The green arrows are the optimal actions.

Figure 5: Qualitative results of generalization on YCBObjs to an unseen number of objects (1, 2, 4, 5) for Top View and Angled View observations.

Figure 6: YCBObjs MPC results for generalization to a varying number of objects.

Table 1: Forward prediction performance on BlockObjs measured as single-step prediction error.

Table 2: Generalization results measured as the average distance to the goal position.

