SET PREDICTION WITHOUT IMPOSING STRUCTURE AS CONDITIONAL DENSITY ESTIMATION

Abstract

Set prediction is about learning to predict a collection of unordered variables with unknown interrelations. Training such models with set losses imposes the structure of a metric space over sets. We focus on stochastic and underdefined cases, where an incorrectly chosen loss function leads to implausible predictions. Example tasks include conditional point-cloud reconstruction and predicting future states of molecules. In this paper, we propose an alternative to training via set losses by viewing learning as conditional density estimation. Our learning framework fits deep energy-based models and approximates the intractable likelihood with gradient-guided sampling. Furthermore, we propose a stochastically augmented prediction algorithm that enables multiple predictions, reflecting the possible variations in the target set. We empirically demonstrate on a variety of datasets the capability to learn multi-modal densities and produce different plausible predictions. Our approach is competitive with previous set prediction models on standard benchmarks. More importantly, it extends the family of addressable tasks beyond those that have unambiguous predictions.

1. INTRODUCTION

This paper strives for set prediction. Making multiple predictions with intricate interactions is essential in a variety of applications. Examples include predicting the set of attributes given an image (Rezatofighi et al., 2020), detecting all pedestrians in video footage (Wang et al., 2018), or predicting the future state of a group of molecules (Noé et al., 2020). Because of their unordered nature, sets pose a challenge for both the choice of machine learning model and the training objective. Models that violate permutation invariance suffer from lower performance, due to the additional difficulty of needing to learn it. Similarly, loss functions should be indifferent to permutations in both the ground-truth and the predictions. Additional ambiguity in the target set exacerbates the problem of defining a suitable set loss.

We propose Deep Energy-based Set Prediction (DESP) to address the permutation symmetries in both the model and the loss function, with a focus on situations where multiple plausible predictions exist. DESP respects the permutation symmetry by training a permutation-invariant energy-based model with a likelihood-based objective.

In the literature, assignment-based set distances are applied as loss functions (Zhang et al., 2019; Kosiorek et al., 2020). Examples include the Chamfer loss (Fan et al., 2017) and the Hungarian loss (Kuhn, 1955). Both compare individual elements in the predicted set to their assigned ground-truth counterpart and vice versa. While they guarantee permutation invariance, they also introduce a structure over sets, in the form of a metric space. Choosing the wrong set distance can result in implausible predictions, due to interpolations in the set space for underdefined problems. For example, Fan et al. (2017) observe that different set distances lead to trade-offs between fine-grained shape reconstruction and compactness, for 3D reconstruction from RGB images.
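As a concrete illustration (not part of the original text), the following minimal NumPy sketch computes the symmetric Chamfer distance between two point sets and demonstrates the mode-averaging issue on a toy bimodal target. The specific point sets and the grid search are hypothetical choices for the example:

```python
import numpy as np

def chamfer_distance(pred, target):
    """Symmetric Chamfer distance between point sets of shape (n, d) and (m, d)."""
    # Pairwise squared Euclidean distances, shape (n, m).
    d2 = ((pred[:, None, :] - target[None, :, :]) ** 2).sum(-1)
    # Match each predicted point to its nearest ground-truth point and
    # vice versa, so the loss ignores the ordering of either set.
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()

# Permutation invariance: reordering the target leaves the loss unchanged.
a = np.array([[0.0, 0.0], [1.0, 0.0]])
b = np.array([[1.0, 0.0], [0.0, 0.0]])  # same set, permuted
assert chamfer_distance(a, b) == 0.0

# Toy bimodal target: the true set is either {(0, 0)} or {(1, 0)} with equal
# probability. The single prediction minimizing the expected Chamfer loss is
# the midpoint, which matches neither plausible outcome.
modes = [np.array([[0.0, 0.0]]), np.array([[1.0, 0.0]])]
candidates = np.linspace(-0.5, 1.5, 401)
losses = [0.5 * chamfer_distance(np.array([[x, 0.0]]), modes[0])
          + 0.5 * chamfer_distance(np.array([[x, 0.0]]), modes[1])
          for x in candidates]
best = candidates[int(np.argmin(losses))]  # midpoint at x = 0.5
```

The grid search stands in for gradient-based training: a model trained with this loss on the bimodal data would likewise converge toward the implausible average of the two modes.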
As an additional shortcoming, optimizing for a set loss during training restricts the family of learnable data distributions. More specifically, conditional multi-modal distributions over sets cannot be learned by minimizing an assignment-based set loss during training. To overcome the challenges of imposed structure and multi-modal distributions, we propose to view set prediction as a conditional density

