LEARNING TO LEARN WITH GENERATIVE MODELS OF NEURAL NETWORK CHECKPOINTS

Anonymous

Abstract

We explore a data-driven approach for learning to optimize neural networks. We construct a dataset of neural network checkpoints and train a generative model on the parameters. In particular, our model is a conditional diffusion transformer that, given an initial input parameter vector and a prompted loss, error, or return, predicts the distribution over parameter updates that achieve the desired metric. At test time, it can optimize neural networks with unseen parameters for downstream tasks in just one update. We find that our approach successfully generates parameters for a wide range of loss prompts. Moreover, it can sample multimodal parameter solutions and has favorable scaling properties. We apply our method to different neural network architectures and tasks in supervised and reinforcement learning.

1. INTRODUCTION

Gradient-based optimization is the fuel of modern deep learning. Techniques of this class, such as SGD (Robbins & Monro, 1951) and Adam (Kingma & Ba, 2015), are easy to implement, scale reasonably well and converge to surprisingly good solutions, even in high-dimensional, non-convex neural network loss landscapes. Over the past decade, they have enabled impressive results in computer vision (Krizhevsky et al., 2012; Girshick et al., 2014), natural language processing (Vaswani et al., 2017; Radford et al., 2018) and audio generation (Van Den Oord et al., 2016). While these manual optimization techniques have led to large advances, they suffer from an important limitation: they are unable to improve from past experience. For example, SGD will not converge any faster when used to optimize the same neural network architecture from the same initialization the 100th time versus the first time. Learned optimizers capable of leveraging their past experiences have the potential to overcome this limitation and may accelerate future progress in deep learning.

Of course, the concept of learning improved optimizers is not new and dates back to the 1980s, if not earlier, following early work from Schmidhuber (1987) and Bengio et al. (1991). In recent years, significant effort has been spent on designing algorithms that learn via nested meta-optimization, where the inner loop optimizes the task-level objective and the outer loop learns the optimizer (Andrychowicz et al., 2016; Li & Malik, 2016; Finn et al., 2017). In some instances, these approaches outperform manual optimizers. However, they are challenging to train in practice due to a reliance on unrolled optimization and reinforcement learning.

Taking a modern deep learning perspective suggests a simple, scalable and data-driven approach to this problem. Over the past decade, our community has trained a massive number of checkpoints. These checkpoints contain a wealth of information: diverse parameter configurations and rich metrics such as test losses, classification errors and RL returns that describe the quality of the checkpoint. Instead of leveraging large-scale datasets of images or text, we propose learning from large-scale datasets of checkpoints recorded over the course of many training runs.

To this end, we create a dataset of neural network checkpoints (Figure 1, left). Our dataset consists of 23 million checkpoints from over a hundred thousand training runs. We collect data from supervised learning tasks (MNIST, CIFAR-10) as well as reinforcement learning tasks (Cartpole), and across different neural network architectures (MLPs, CNNs). In addition to parameters, we record relevant task-level metrics in each checkpoint, such as test losses and classification errors. Please see our project website in supplementary materials for additional results and visualizations.
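To make the notion of a checkpoint dataset concrete, the following is a minimal sketch of how parameters and a task-level metric might be recorded together during training, and how such records could be turned into (starting parameters, prompted loss, target parameters) examples of the kind the abstract describes. It is an illustration under stated assumptions, not our actual pipeline: the names `Checkpoint`, `train_run` and `make_pairs`, the toy linear-regression task standing in for MNIST, CIFAR-10 or Cartpole, and the choice to pair every earlier checkpoint with every later one from the same run are all hypothetical.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Checkpoint:
    """One saved snapshot: flattened parameters plus a task-level metric."""
    run_id: int
    step: int
    params: List[float]  # flattened parameter vector (here just [w, b])
    test_loss: float     # metric recorded alongside the parameters


def mse(params, data):
    """Mean squared error of the linear model y = w*x + b on `data`."""
    w, b = params
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def train_run(run_id, train, test, steps=20, lr=0.05, seed=0):
    """Toy training run (assumption: full-batch gradient descent on a linear
    model), saving one checkpoint at initialization and after every step."""
    rng = random.Random(seed)
    w, b = rng.uniform(-1, 1), rng.uniform(-1, 1)
    checkpoints = [Checkpoint(run_id, 0, [w, b], mse([w, b], test))]
    for step in range(1, steps + 1):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in train) / len(train)
        grad_b = sum(2 * (w * x + b - y) for x, y in train) / len(train)
        w, b = w - lr * grad_w, b - lr * grad_b
        checkpoints.append(Checkpoint(run_id, step, [w, b], mse([w, b], test)))
    return checkpoints


def make_pairs(run) -> List[Tuple[List[float], float, List[float]]]:
    """Hypothetical pairing scheme: each example is
    (starting params, prompted loss, target params), where the target is a
    later checkpoint from the same run and the prompt is its recorded loss."""
    return [(start.params, later.test_loss, later.params)
            for i, start in enumerate(run) for later in run[i + 1:]]


# Toy data drawn from y = 2x + 1; test inputs are offset from train inputs.
train = [(x / 10, 2 * (x / 10) + 1) for x in range(-10, 11)]
test = [(x / 10 + 0.05, 2 * (x / 10 + 0.05) + 1) for x in range(-10, 10)]

# A small "dataset of checkpoints": three runs with different initializations.
dataset = [ckpt for r in range(3) for ckpt in train_run(r, train, test, seed=r)]
run0 = train_run(0, train, test, seed=0)
pairs = make_pairs(run0)
```

Each run of 20 steps yields 21 checkpoints, so the three runs give 63 records, and a single run gives 210 ordered pairs. A generative model trained on such examples would condition on the first two elements of each tuple and model the third; that model is outside the scope of this sketch.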

