SET PREDICTION WITHOUT IMPOSING STRUCTURE AS CONDITIONAL DENSITY ESTIMATION

Abstract

Set prediction is concerned with learning to predict a collection of unordered variables with unknown interrelations. Training such models with set losses imposes the structure of a metric space over sets. We focus on stochastic and underdefined cases, where an incorrectly chosen loss function leads to implausible predictions. Example tasks include conditional point-cloud reconstruction and predicting future states of molecules. In this paper, we propose an alternative to training via set losses by viewing learning as conditional density estimation. Our learning framework fits deep energy-based models and approximates the intractable likelihood with gradient-guided sampling. Furthermore, we propose a stochastically augmented prediction algorithm that enables multiple predictions, reflecting the possible variations in the target set. We empirically demonstrate on a variety of datasets the capability to learn multi-modal densities and produce multiple plausible predictions. Our approach is competitive with previous set prediction models on standard benchmarks. More importantly, it extends the family of addressable tasks beyond those with unambiguous predictions.

1. INTRODUCTION

This paper strives for set prediction. Making multiple predictions with intricate interactions is essential in a variety of applications. Examples include predicting the set of attributes given an image (Rezatofighi et al., 2020), detecting all pedestrians in video footage (Wang et al., 2018), or predicting the future state of a group of molecules (Noé et al., 2020). Because of their unordered nature, sets pose a challenge for both the choice of machine learning model and the training objective. Models that violate permutation invariance suffer from lower performance, since they must additionally learn the invariance from data. Similarly, loss functions should be indifferent to permutations in both the ground truth and the predictions. Ambiguity in the target set further exacerbates the problem of defining a suitable set loss. We propose Deep Energy-based Set Prediction (DESP) to address the permutation symmetries in both the model and the loss function, with a focus on situations where multiple plausible predictions exist. DESP respects the permutation symmetry by training a permutation-invariant energy-based model with a likelihood-based objective.

In the literature, assignment-based set distances are applied as loss functions (Zhang et al., 2019; Kosiorek et al., 2020). Examples include the Chamfer loss (Fan et al., 2017) and the Hungarian loss (Kuhn, 1955). Both compare individual elements in the predicted set to their assigned ground-truth counterpart and vice versa. While they guarantee permutation invariance, they also introduce a structure over sets, in the form of a metric space. Choosing the wrong set distance can result in implausible predictions, due to interpolations in the set space for underdefined problems. For example, Fan et al. (2017) observe that, for 3D reconstruction from RGB images, different set distances lead to trade-offs between fine-grained shape reconstruction and compactness.
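To make the notion of an assignment-based set distance concrete, the following is a minimal NumPy sketch of the Chamfer distance between two point sets; the function name and shapes are illustrative, not taken from any of the cited implementations:

```python
import numpy as np

def chamfer_distance(A, B):
    """Assignment-based set distance between point sets A (n, d) and B (m, d).

    Each element is compared to its nearest neighbour in the other set, so the
    result is invariant to the ordering of points within A and within B.
    """
    # Pairwise squared Euclidean distances, shape (n, m).
    d = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    # Nearest-neighbour terms in both directions.
    return d.min(axis=1).mean() + d.min(axis=0).mean()
```

Because the loss only couples each element to its nearest assigned counterpart, minimizing it drives predictions toward a single "average" configuration, which is exactly the uni-modality issue discussed above.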
As an additional shortcoming, optimizing a set loss during training limits the family of learnable data distributions. More specifically, conditional multi-modal distributions over sets cannot be learned by minimizing an assignment-based set loss. To overcome the challenges of imposed structure and multi-modal distributions, we propose to view set prediction as a conditional density estimation problem, where $P(Y|x)$ denotes the distribution of the target set $Y$ given observed features $x$. In this work we focus on distributions taking the form of deep energy-based models (Ngiam et al., 2011; Zhai et al., 2016; Belanger & McCallum, 2016):

$$P_\theta(Y|x) = \frac{1}{Z(x;\theta)} \exp\left(-E_\theta(x, Y)\right),$$

with $Z$ the partition function and $E_\theta$ the energy function with parameters $\theta$. The expressiveness of neural networks (Cybenko, 1989) allows for learning multi-modal densities $P_\theta(Y|x)$. This sets the approach apart from forward-processing models, which either require conditional independence assumptions (Rezatofighi et al., 2017) or impose an order on the predictions when applying the chain rule (Vinyals et al., 2016).

Energy-based prediction is regarded as a non-linear combinatorial optimization problem (LeCun et al., 2006):

$$\hat{Y} = \arg\min_Y E_\theta(x, Y),$$

which is typically approximated by gradient descent for deep energy-based models (Belanger & McCallum, 2016; Belanger et al., 2017). We replace the deterministic gradient descent with a stochastically augmented prediction algorithm, to account for multiple plausible predictions. We show that our stochastic version outperforms standard gradient descent on set prediction tasks.

Our main contribution is DESP, a training and prediction framework for set prediction that removes the limitations imposed by assignment-based set losses. Sampling plays a key role in DESP: during training, sampling approximates the intractable model gradients, while during prediction, it introduces stochasticity.
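The stochastically augmented prediction step can be sketched as gradient descent on the energy with Gaussian noise added to each update. The sketch below uses a hand-written toy energy gradient (two modes at $\pm x$) in place of a learned network; all function names and hyperparameters are illustrative, not the paper's:

```python
import numpy as np

def predict(energy_grad, x, init_Y, steps=100, step_size=0.05, noise=0.01, rng=None):
    """Minimise E(x, Y) over the set Y by noisy gradient descent.

    The injected Gaussian noise lets repeated runs (or different
    initialisations) land in different modes of P(Y|x), yielding
    multiple plausible predictions instead of a single point estimate.
    """
    rng = rng or np.random.default_rng(0)
    Y = init_Y.astype(float).copy()
    for _ in range(steps):
        Y -= step_size * energy_grad(x, Y) + noise * rng.standard_normal(Y.shape)
    return Y

def toy_grad(x, Y):
    """Gradient of a toy bimodal energy: each element y_i is attracted to
    whichever of the two modes x or -x is closer."""
    to_pos, to_neg = Y - x, Y + x
    use_pos = (to_pos ** 2).sum(-1, keepdims=True) < (to_neg ** 2).sum(-1, keepdims=True)
    return np.where(use_pos, to_pos, to_neg)
```

Running `predict` from different random initializations recovers different modes, which is the behaviour a deterministic gradient-descent predictor cannot exhibit.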
We show the generality of our framework by adapting recently proposed permutation invariant neural networks as set prediction deep energy-based models. We demonstrate that our approach (i) learns multi-modal distributions over sets (ii) makes multiple plausible predictions (iii) generalizes over different deep energy-based model architectures and (iv) is competitive even in non-stochastic settings, without requiring problem specific loss-engineering.

2.1. TRAINING

Our goal is to train a deep energy-based model for set prediction, such that all plausible sets are captured by the model. Regression models with a target in $\mathbb{R}^d$ that are trained with a root mean-square error (RMSE) loss implicitly assume a Gaussian distribution over the target. Analogous to the RMSE, assignment-based set losses assume a uni-modal distribution over the set space. Training with the negative log-likelihood (NLL) circumvents the issues of assignment-based set losses. Notably, the NLL does not necessitate explicit element-wise comparisons, but treats the set holistically. We reformulate the NLL for the training data distribution $P_D$ as:

$$\mathbb{E}_{(x,Y)\sim P_D}\left[-\log P_\theta(Y|x)\right] = \mathbb{E}_{(x,Y)\sim P_D}\left[E_\theta(x, Y)\right] + \mathbb{E}_{x\sim P_D}\left[\log Z(x; \theta)\right].$$

The gradient of the left summand is approximated by sampling a mini-batch of $n$ tuples $\{(x_i, Y_i^+)\}_{i=1}^{n}$ from the training set. The gradient of the right summand is approximated by sampling only input features $\{x_i\}_{i=1}^{m}$. Directly evaluating $\frac{\partial}{\partial\theta}\log Z(x;\theta)$ is intractable; instead we approximate the gradient by sampling $\{Y_j^-\}_{j=1}^{k}$ from the model distribution:

$$\frac{\partial}{\partial\theta}\log Z(x;\theta) = -\mathbb{E}_{Y\sim P_\theta}\left[\frac{\partial}{\partial\theta}E_\theta(x, Y)\right] \approx -\frac{1}{k}\sum_{j=1}^{k}\frac{\partial}{\partial\theta}E_\theta(x, Y_j^-).$$

The resulting approximate NLL objective is equivalent to contrasting the energy values of real and synthesized targets, with the former being minimized and the latter maximized. The objective is reminiscent of the discriminator's loss in generative adversarial networks (Goodfellow et al., 2014), where a real sample is contrasted with a sample synthesized by the generator network. In practice, setting $k = 1$ suffices.

The Langevin MCMC algorithm allows for efficient sampling from high-dimensional spaces (Geman & Geman, 1984; Neal et al., 2011). Access to the derivative of the unnormalized density function
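The contrastive training step described above can be sketched end to end on a tiny analytic energy with a single scalar parameter, so that all gradients can be written by hand; the energy, function names, and hyperparameters below are illustrative stand-ins for a deep network, not the paper's model:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative permutation-invariant energy with one scalar parameter theta:
# E_theta(x, Y) = || mean(Y) - theta * x ||^2  (depends on Y only via its mean).
def energy(theta, x, Y):
    return ((Y.mean(axis=0) - theta * x) ** 2).sum()

def grad_Y(theta, x, Y):
    """dE/dY, needed by the Langevin sampler."""
    n = len(Y)
    return np.broadcast_to(2.0 * (Y.mean(axis=0) - theta * x) / n, Y.shape).copy()

def grad_theta(theta, x, Y):
    """dE/dtheta, needed by the contrastive parameter update."""
    return (-2.0 * x * (Y.mean(axis=0) - theta * x)).sum()

def langevin_sample(theta, x, Y0, steps=50, eta=0.1):
    """Draw a negative sample Y- from the model distribution with Langevin MCMC."""
    Y = Y0.copy()
    for _ in range(steps):
        Y += -eta * grad_Y(theta, x, Y) \
             + np.sqrt(2 * eta) * 0.01 * rng.standard_normal(Y.shape)
    return Y

def training_step(theta, x, Y_pos, lr=0.05):
    """One contrastive NLL step: push the energy down on the real set Y+
    and up on the model sample Y- (k = 1 negative, as in the text)."""
    Y_neg = langevin_sample(theta, x, rng.standard_normal(Y_pos.shape))
    grad = grad_theta(theta, x, Y_pos) - grad_theta(theta, x, Y_neg)
    return theta - lr * grad
```

With training sets whose mean is $2x$, repeated calls to `training_step` drive `theta` toward 2, illustrating how contrasting real and synthesized targets approximates the NLL gradient without any element-wise assignment.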

