LEARNING TO LEARN WITH GENERATIVE MODELS OF NEURAL NETWORK CHECKPOINTS Anonymous

Abstract

We explore a data-driven approach for learning to optimize neural networks. We construct a dataset of neural network checkpoints and train a generative model on the parameters. In particular, our model is a conditional diffusion transformer that, given an initial input parameter vector and a prompted loss, error, or return, predicts the distribution over parameter updates that achieve the desired metric. At test time, it can optimize neural networks with unseen parameters for downstream tasks in just one update. We find that our approach successfully generates parameters for a wide range of loss prompts. Moreover, it can sample multimodal parameter solutions and has favorable scaling properties. We apply our method to different neural network architectures and tasks in supervised and reinforcement learning.

1. INTRODUCTION

Gradient-based optimization is the fuel of modern deep learning. Techniques of this class, such as SGD (Robbins & Monro, 1951) and Adam (Kingma & Ba, 2015), are easy to implement, scale reasonably well, and converge to surprisingly good solutions, even in high-dimensional, non-convex neural network loss landscapes. Over the past decade, they have enabled impressive results in computer vision (Krizhevsky et al., 2012; Girshick et al., 2014), natural language processing (Vaswani et al., 2017; Radford et al., 2018) and audio generation (Van Den Oord et al., 2016).

While these manual optimization techniques have led to large advances, they suffer from an important limitation: they are unable to improve from past experience. For example, SGD will not converge any faster when used to optimize the same neural network architecture from the same initialization the 100th time versus the first time. Learned optimizers capable of leveraging their past experiences have the potential to overcome this limitation and may accelerate future progress in deep learning.

Of course, the concept of learning improved optimizers is not new; it dates back at least to the 1980s, following early work from Schmidhuber (1987) and Bengio et al. (1991). In recent years, significant effort has been spent on designing algorithms that learn via nested meta-optimization, where the inner loop optimizes the task-level objective and the outer loop learns the optimizer (Andrychowicz et al., 2016; Li & Malik, 2016; Finn et al., 2017). In some instances, these approaches outperform manual optimizers. However, they are challenging to train in practice due to their reliance on unrolled optimization and reinforcement learning.

Taking a modern deep learning perspective suggests a simple, scalable, and data-driven approach to this problem. Over the past decade, our community has trained a massive number of checkpoints.
These checkpoints contain a wealth of information: diverse parameter configurations and rich metrics such as test losses, classification errors, and RL returns that describe the quality of the checkpoint. Instead of leveraging large-scale datasets of images or text, we propose learning from large-scale datasets of checkpoints recorded over the course of many training runs. To this end, we create a dataset of neural network checkpoints (Figure 1, left). Our dataset consists of 23 million checkpoints from over a hundred thousand training runs. We collect data from supervised learning tasks (MNIST, CIFAR-10) as well as reinforcement learning tasks (Cartpole), and across different neural network architectures (MLPs, CNNs). In addition to parameters, we record relevant task-level metrics in each checkpoint, such as test losses and classification errors. Please see our project website in the supplementary materials for additional results and visualizations.
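Concretely, each record in such a dataset pairs a flattened parameter vector with the metrics logged alongside it. The sketch below shows one way to represent a record in Python; the class and field names are illustrative, not the paper's actual schema:

```python
from dataclasses import dataclass
from typing import Optional
import numpy as np

@dataclass
class Checkpoint:
    """One dataset record: a flattened parameter vector plus the
    task-level metrics logged with it. Field names are illustrative,
    not the paper's actual schema."""
    params: np.ndarray                   # flattened neural network parameters
    test_loss: float
    test_error: Optional[float] = None   # supervised tasks (MNIST, CIFAR-10)
    ep_return: Optional[float] = None    # RL tasks (Cartpole)

# A toy supervised-learning checkpoint.
ckpt = Checkpoint(params=np.zeros(7), test_loss=0.42, test_error=0.12)
```

Storing parameters as a single flat vector is convenient here because the generative model described below consumes and emits whole parameter vectors rather than per-layer tensors.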



Figure 1: Left: a dataset of neural network checkpoints. Right: G.pt, a generative model of checkpoints. G.pt takes a parameter vector and a loss/error/return prompt as input and predicts the distribution over updated parameters that achieve the prompt.

Given this data, we explore generative pre-training directly in parameter space (Figure 1, right). Specifically, we train transformer-based diffusion models of neural network parameters. Given an initial input parameter vector and a target loss, error, or return, these models are trained to predict the distribution over updated parameter vectors for a single network architecture that achieve the target metric. Our method is trained using standard generative modeling techniques instead of unrolled optimization and reinforcement learning algorithms. We call our model G.pt, where G and .pt refer to generative models and the checkpoint file extension, respectively.

We show that our approach has a number of favorable properties. First, it is able to rapidly train neural networks from unseen initializations with just one parameter update (Figure 3). Second, it can generate parameters that achieve a wide range of prompted losses, errors, and returns (Figure 5). Third, it is able to generalize to out-of-distribution weight initialization algorithms (Figure 6). Fourth, as a generative model, it is able to sample diverse solutions (Figure 8). Finally, it can optimize non-differentiable objectives, such as RL returns or classification errors.
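To make the test-time interface concrete, the sketch below shows the intended usage pattern: a single conditional sample replaces an entire gradient-descent trajectory. A random-perturbation stub stands in for the trained diffusion model; the function name and signature are our illustration, not the paper's actual API.

```python
import numpy as np

def g_pt_sample(theta, prompt, rng):
    """Stub standing in for the trained model: a real G.pt would run
    the reverse diffusion process conditioned on (theta, prompt).
    Here we simply perturb theta so the interface is runnable."""
    return theta + 0.01 * rng.standard_normal(theta.shape)

rng = np.random.default_rng(0)
theta0 = rng.standard_normal(1000)   # flattened params of an unseen initialization
prompted_loss = 0.0                  # the prompt: desired test loss
theta1 = g_pt_sample(theta0, prompted_loss, rng)  # one update, no gradients
```

Because the prompt is just a conditioning scalar, it can be any recorded metric, including non-differentiable ones such as classification error or RL return.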

2. GENERATIVE PRE-TRAINING FROM NEURAL NETWORK CHECKPOINTS

We pre-train a generative model G.pt on neural network checkpoints. At test time, we use it to generate parameters for neural networks that solve a downstream task.

2.1. A DATASET OF NEURAL NETWORK CHECKPOINTS

In order to train G.pt, we build a dataset of neural network checkpoints. Each checkpoint contains neural network parameters and relevant task-level metrics such as train losses, test errors, or returns. We use standard optimizers like Adam and SGD with momentum to generate the parameters, and we randomly save a subset of checkpoints from each training run. Our methodology for generating each individual training run is explained in detail in Algorithm 1. See Section 3 for additional details.

Augmenting datasets of neural networks. To offset the computational cost of collecting checkpoints, we use data augmentation in neural network parameter space. Given a checkpoint (θ, ℓ), we construct augmented tuples (T(θ), ℓ), where T(·) is the parameter-level augmentation. In order for these augmented tuples to be valid, we need f_{T(θ)}(x) = f_θ(x) for all parameter vectors θ and all inputs x to the neural network. One type of augmentation that meets this criterion is permutation augmentation. Consider an MLP: if we apply some permutation to the outgoing weights (and biases) of the input layer and the corresponding permutation to the incoming weights of the next layer, the output of the neural network is preserved (Roeder et al., 2021; Schürholt et al., 2021). Different permutations can be sampled for each layer up to the output layer. This technique is generic and can be applied to MLPs and CNNs alike. We apply the same permutation to both the input and target parameters during pre-training.
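The permutation augmentation can be checked numerically: shuffling a hidden layer's outgoing weights and biases together with the next layer's incoming weights leaves the MLP's outputs unchanged. A minimal NumPy sketch (the helper names are ours, not the paper's):

```python
import numpy as np

def mlp_forward(params, x):
    """Two-hidden-layer MLP forward pass with ReLU activations."""
    (W1, b1), (W2, b2), (W3, b3) = params
    h1 = np.maximum(x @ W1 + b1, 0)
    h2 = np.maximum(h1 @ W2 + b2, 0)
    return h2 @ W3 + b3

def permute_hidden_units(params, rng):
    """Permutation augmentation: for each hidden layer, permute its
    outgoing weights/biases and, with the same permutation, the next
    layer's incoming weights. The network function is unchanged."""
    (W1, b1), (W2, b2), (W3, b3) = params
    p1 = rng.permutation(W1.shape[1])  # permutation for hidden layer 1
    p2 = rng.permutation(W2.shape[1])  # permutation for hidden layer 2
    return [(W1[:, p1], b1[p1]),          # permute layer-1 outputs
            (W2[p1][:, p2], b2[p2]),      # permute layer-2 inputs and outputs
            (W3[p2], b3)]                 # permute layer-3 inputs only

rng = np.random.default_rng(0)
dims = [4, 8, 8, 3]
params = [(rng.standard_normal((dims[i], dims[i + 1])),
           rng.standard_normal(dims[i + 1])) for i in range(3)]
x = rng.standard_normal((5, 4))
aug = permute_hidden_units(params, rng)
assert np.allclose(mlp_forward(params, x), mlp_forward(aug, x))
```

Note that the output layer's units are never permuted, since reordering them would change which logit corresponds to which class.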

