LEARNING PROXIMAL OPERATORS TO DISCOVER MULTIPLE OPTIMA

Abstract

Finding multiple solutions of non-convex optimization problems is a ubiquitous yet challenging task. Most past algorithms either apply single-solution optimization methods from multiple random initial guesses or search in the vicinity of found solutions using ad hoc heuristics. We present an end-to-end method to learn the proximal operator of a family of training problems so that multiple local minima can be quickly obtained from initial guesses by iterating the learned operator, emulating the proximal-point algorithm that has fast convergence. The learned proximal operator can be further generalized to recover multiple optima for unseen problems at test time, enabling applications such as object detection. The key ingredient in our formulation is a proximal regularization term, which elevates the convexity of our training loss: by applying recent theoretical results, we show that for weakly-convex objectives with Lipschitz gradients, training of the proximal operator converges globally with a practical degree of over-parameterization. We further present an exhaustive benchmark for multi-solution optimization to demonstrate the effectiveness of our method.

1. INTRODUCTION

Searching for multiple optima of an optimization problem is a ubiquitous yet under-explored task. In applications like low-rank recovery (Ge et al., 2017), topology optimization (Papadopoulos et al., 2021), object detection (Lin et al., 2014), and symmetry detection (Shi et al., 2020), it is desirable to recover multiple near-optimal solutions, either because there are many equally performant global optima or because the optimization objective does not capture user preferences precisely. Even for single-solution non-convex optimization, typical methods look for multiple local optima from random initial guesses before picking the best one. Additionally, it is often desirable to obtain solutions to a family of optimization problems with parameters not known in advance, for instance, the weight of a regularization term, without having to restart from scratch.

Formally, we define a multi-solution optimization (MSO) problem to be the minimization min_{x∈X} f_τ(x), where τ ∈ T encodes parameters of the problem, X is the search space of the variable x, and f_τ : R^d → R is the objective function depending on τ. The goal of MSO is to identify multiple solutions for each τ ∈ T, i.e., the set {x* ∈ X : f_τ(x*) = min_{x∈X} f_τ(x)}, which can contain more than one element or even infinitely many elements. In this work, we assume that X ⊂ R^d is bounded, that d is small, and that T is, in a loose sense, a continuous space, such that the objective f_τ changes continuously as τ varies. To make gradient-based methods viable, we further assume that each f_τ is differentiable almost everywhere. As finding all global minima in the general case is extremely challenging, realistically our goal is to find a diverse set of local minima. As a concrete example, for object detection, T could parameterize the space of images and X could be the 4-dimensional space of bounding boxes (ignoring class labels).
Then, f_τ(x) could be the minimum distance between the bounding box x ∈ X and any ground-truth box for image τ ∈ T. Minimizing f_τ(x) would yield all object bounding boxes for image τ. Object detection can then be cast as solving this MSO on a training set of images and extrapolating to unseen images (Section 5.5). Object detection is a rare example of MSO for which ground-truth annotations are widely available. In such cases, supervised learning can solve MSO by predicting a fixed number of solutions together with confidence scores using a set-based loss such as the Hausdorff distance. Unfortunately, such annotation is not available for most optimization problems in the wild, where we only have access to the objective functions; this is the setting our method aims to tackle.

Our work is inspired by the proximal-point algorithm (PPA), which iteratively applies the proximal operator of the objective function to an initial point to refine it to a local minimum. PPA is known to converge faster than gradient descent even when the proximal operator is approximated, both theoretically (Rockafellar, 1976; 2021) and empirically (e.g., Figure 2 of Hoheisel et al. (2020)). If the proximal operator of the objective function is available, then MSO can be solved efficiently by running PPA from a variety of initial points. However, obtaining a good approximation of the proximal operator for generic functions is difficult, and typically we have to solve a separate optimization problem for each evaluation of the proximal operator (Davis & Grimmer, 2019). In this work, we approximate the proximal operator using a neural network trained with a straightforward loss that includes only the objective and a proximal term penalizing deviation from the input point. Crucially, our training does not require access to the ground-truth proximal operator.
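To make the proximal-point iteration concrete, here is a minimal sketch (ours, not from the paper) that approximates the proximal operator prox_{λf}(x) = argmin_y f(y) + ||y − x||²/(2λ) of a toy 1D double-well objective by brute-force grid search, then iterates it from several initial guesses; distinct starts land on distinct local minima, which is exactly how PPA serves MSO. The objective, grid bounds, and step size λ are illustrative choices.

```python
import numpy as np

def prox(f, x, lam=0.1, grid=np.linspace(-3.0, 3.0, 2001)):
    # prox_{lam f}(x) = argmin_y f(y) + ||y - x||^2 / (2 lam),
    # approximated by exhaustive search on a dense grid (viable in low dimensions).
    return grid[np.argmin(f(grid) + (grid - x) ** 2 / (2.0 * lam))]

def ppa(f, x0, n_iter=30):
    # Proximal-point algorithm: repeatedly apply the proximal operator.
    x = x0
    for _ in range(n_iter):
        x = prox(f, x)
    return x

# Double-well objective with local minima at -1 and +1.
f = lambda y: (y ** 2 - 1.0) ** 2

# Running PPA from a spread of initial guesses recovers both minima.
minima = sorted({round(ppa(f, x0), 2) for x0 in np.linspace(-2.0, 2.0, 8)})
print(minima)  # [-1.0, 1.0]
```

The grid search stands in for the "separate optimization problem per evaluation" that the text mentions; the paper's contribution is to replace it with a single learned network.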
Additionally, the neural parameterization allows us to learn the proximal operator for all {f_τ}_{τ∈T} by treating τ as an input to the network along with an application-specific encoder. Once trained, the learned proximal operator allows us to effortlessly run PPA from any initial point to arrive at a nearby local minimum; from a generative-modeling point of view, the learned proximal operator implicitly encodes the solutions of an MSO problem as the pushforward of a prior distribution by iterated application of the operator. Such a formulation bypasses the need to predict a fixed number of solutions and can represent infinitely many solutions.

The proximal term in our loss promotes the convexity of the formulation: applying recent results (Kawaguchi & Huang, 2019), we show that for weakly-convex objectives with Lipschitz gradients (in particular, objectives with bounded second derivatives), training converges globally with a practical degree of over-parameterization and the ground-truth proximal operator is recovered (Theorem 3.1 below). Such a global convergence result is not known for any previous learning-to-optimize method (Chen et al., 2021).

Literature on MSO is scarce, so we build a benchmark with a wide variety of applications including level-set sampling, non-convex sparse recovery, max-cut, 3D symmetry detection, and object detection in images. When evaluated on this benchmark, our learned proximal operator reliably produces high-quality results compared to reasonable alternatives, while converging in a few iterations.
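As a self-contained illustration of this training loss (our toy example, with assumed values τ = 2 and λ = 0.5, not taken from the paper), consider a single quadratic objective f_τ(x) = τx²/2, whose exact proximal operator is x/(1 + λτ). A one-parameter model g(x) = a·x trained on the loss f_τ(g(x)) + (g(x) − x)²/(2λ) recovers this operator without ever evaluating the true operator:

```python
import numpy as np

rng = np.random.default_rng(0)
tau, lam = 2.0, 0.5  # problem parameter and proximal weight (assumed values)

# One-parameter "network" g(x) = a * x, trained on the proximal loss
#   L(a) = E_x[ tau/2 * g(x)^2 + (g(x) - x)^2 / (2 lam) ],
# which references only the objective, never the ground-truth operator.
a, lr = 0.0, 0.05
xs = rng.uniform(-1.0, 1.0, 256)  # training inputs sampled from the search space
for _ in range(500):
    # dL/da = E[ tau * a * x^2 + (a - 1) * x^2 / lam ]
    grad = np.mean(tau * a * xs**2 + (a - 1.0) * xs**2 / lam)
    a -= lr * grad

# Learned slope vs. the closed-form prox slope 1 / (1 + lam * tau) = 0.5.
print(a, 1.0 / (1.0 + lam * tau))
```

Here the loss is strongly convex in a, so gradient descent converges to the exact operator; the paper's Theorem 3.1 extends this picture to over-parameterized networks and weakly-convex objectives.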

2. RELATED WORKS

Learning to optimize. Learning-to-optimize (L2O) methods utilize past optimization experience to optimize future problems more effectively; see (Chen et al., 2021) for a survey. Model-free L2O uses recurrent neural networks to discover new optimizers suitable for similar problems (Andrychowicz et al., 2016; Li & Malik, 2016; Chen et al., 2017; Cao et al., 2019); while shown to be practical, these methods have almost no theoretical guarantee that training converges (Chen et al., 2021). In comparison, we learn a problem-dependent proximal operator so that at test time we do not need access to objective functions or their gradients, which can be costly to evaluate (e.g., symmetry detection in Section 5.4) or unavailable (e.g., object detection in Section 5.5). Model-based L2O substitutes components of a specialized optimization framework or schematically unrolls an optimization procedure with neural networks. Related to proximal methods, Gregor & LeCun (2010) emulate a few iterations of proximal gradient descent using neural networks for sparse recovery with an ℓ1 regularizer, extended to non-convex regularizers by Yang et al. (2020); a similar technique is applied to susceptibility-tensor imaging in Fang et al. (2022). Gilton et al. (2021) propose a deep equilibrium model with proximal gradient descent for inverse problems in imaging that circumvents expensive backpropagation through unrolled iterations. Meinhardt et al. (2017) use a fixed denoising neural network as a surrogate proximal operator for inverse imaging problems. All these works use

